diff --git a/DI-engine b/DI-engine
deleted file mode 160000
index a57bc3024b938c881aaf6511d1fb26296cd98601..0000000000000000000000000000000000000000
--- a/DI-engine
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit a57bc3024b938c881aaf6511d1fb26296cd98601
diff --git a/DI-engine/.flake8 b/DI-engine/.flake8
new file mode 100644
index 0000000000000000000000000000000000000000..9d86ca5e8cd851b29293c3e979cdde17f76fd5f8
--- /dev/null
+++ b/DI-engine/.flake8
@@ -0,0 +1,4 @@
+[flake8]
+ignore=F401,F841,F403,E226,E126,W504,E265,E722,W503,W605,E741,E122,E731
+max-line-length=120
+statistics
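The flake8 settings above can be exercised locally without remembering the CLI flags; a minimal sketch in Python, assuming flake8 is installed and the command is run from the repository root (the lint target path DI-engine/ding is illustrative):

    import subprocess

    # Run flake8 against the vendored config; --config points at the file added above,
    # so the ignore list and the 120-character line limit are applied explicitly
    # instead of relying on automatic config discovery.
    result = subprocess.run(
        ["flake8", "--config", "DI-engine/.flake8", "DI-engine/ding"],
        capture_output=True,
        text=True,
    )
    print(result.stdout)      # per-file violations, if any
    print(result.returncode)  # 0 when no violations are reported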
diff --git a/DI-engine/.gitignore b/DI-engine/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..8de173d5792bcfcb8e4148dff00340086636e6b9
--- /dev/null
+++ b/DI-engine/.gitignore
@@ -0,0 +1,1431 @@
+# Created by .ignore support plugin (hsz.mobi)
+### ArchLinuxPackages template
+*.tar
+*.tar.*
+*.jar
+*.exe
+*.msi
+*.zip
+*.tgz
+*.log
+*.log.*
+*.sig
+*.mov
+*.pkl
+
+pkg/
+src/
+impala_log/
+
+### CVS template
+/CVS/*
+**/CVS/*
+.cvsignore
+*/.cvsignore
+
+### LibreOffice template
+# LibreOffice locks
+.~lock.*#
+
+### CUDA template
+*.i
+*.ii
+*.gpu
+*.ptx
+*.cubin
+*.fatbin
+
+### Eclipse template
+*.bin
+.metadata
+bin/
+tmp/
+*.tmp
+*.bak
+*.swp
+*~.nib
+local.properties
+.settings/
+.loadpath
+.recommenders
+
+# External tool builders
+.externalToolBuilders/
+
+# Locally stored "Eclipse launch configurations"
+*.launch
+
+# PyDev specific (Python IDE for Eclipse)
+*.pydevproject
+
+# CDT-specific (C/C++ Development Tooling)
+.cproject
+
+# CDT- autotools
+.autotools
+
+# Java annotation processor (APT)
+.factorypath
+
+# PDT-specific (PHP Development Tools)
+.buildpath
+
+# sbteclipse plugin
+.target
+
+# Tern plugin
+.tern-project
+
+# TeXlipse plugin
+.texlipse
+
+# STS (Spring Tool Suite)
+.springBeans
+
+# Code Recommenders
+.recommenders/
+
+# Annotation Processing
+.apt_generated/
+.apt_generated_test/
+
+# Scala IDE specific (Scala & Java development for Eclipse)
+.cache-main
+.scala_dependencies
+.worksheet
+
+# Uncomment this line if you wish to ignore the project description file.
+# Typically, this file would be tracked if it contains build/dependency configurations:
+#.project
+
+### SVN template
+.svn/
+
+### Images template
+# JPEG
+*.jpg
+*.jpeg
+*.jpe
+*.jif
+*.jfif
+*.jfi
+
+# JPEG 2000
+*.jp2
+*.j2k
+*.jpf
+*.jpx
+*.jpm
+*.mj2
+
+# JPEG XR
+*.jxr
+*.hdp
+*.wdp
+
+# Graphics Interchange Format
+*.gif
+*.mp4
+*.mpg
+
+# RAW
+*.raw
+
+# Web P
+*.webp
+
+# Portable Network Graphics
+*.png
+
+# Animated Portable Network Graphics
+*.apng
+
+# Multiple-image Network Graphics
+*.mng
+
+# Tagged Image File Format
+*.tiff
+*.tif
+
+# Scalable Vector Graphics
+*.svg
+*.svgz
+
+# Portable Document Format
+*.pdf
+
+# X BitMap
+*.xbm
+
+# BMP
+*.bmp
+*.dib
+
+# ICO
+*.ico
+
+# 3D Images
+*.3dm
+*.max
+
+### Diff template
+*.patch
+*.diff
+
+### JetBrains template
+# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
+# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
+
+# User-specific stuff
+.idea/**/workspace.xml
+.idea/**/tasks.xml
+.idea/**/usage.statistics.xml
+.idea/**/dictionaries
+.idea/**/shelf
+
+# Generated files
+.idea/**/contentModel.xml
+
+# Sensitive or high-churn files
+.idea/**/dataSources/
+.idea/**/dataSources.ids
+.idea/**/dataSources.local.xml
+.idea/**/sqlDataSources.xml
+.idea/**/dynamic.xml
+.idea/**/uiDesigner.xml
+.idea/**/dbnavigator.xml
+
+# Gradle
+.idea/**/gradle.xml
+.idea/**/libraries
+
+# Gradle and Maven with auto-import
+# When using Gradle or Maven with auto-import, you should exclude module files,
+# since they will be recreated, and may cause churn. Uncomment if using
+# auto-import.
+# .idea/artifacts
+# .idea/compiler.xml
+# .idea/jarRepositories.xml
+# .idea/modules.xml
+# .idea/*.iml
+# .idea/modules
+# *.iml
+# *.ipr
+
+# CMake
+cmake-build-*/
+
+# Mongo Explorer plugin
+.idea/**/mongoSettings.xml
+
+# File-based project format
+*.iws
+
+# IntelliJ
+out/
+
+# mpeltonen/sbt-idea plugin
+.idea_modules/
+
+# JIRA plugin
+atlassian-ide-plugin.xml
+
+# Cursive Clojure plugin
+.idea/replstate.xml
+
+# Crashlytics plugin (for Android Studio and IntelliJ)
+com_crashlytics_export_strings.xml
+crashlytics.properties
+crashlytics-build.properties
+fabric.properties
+
+# Editor-based Rest Client
+.idea/httpRequests
+
+# Android studio 3.1+ serialized cache file
+.idea/caches/build_file_checksums.ser
+
+### CodeIgniter template
+*/config/development
+*/logs/log-*.php
+!*/logs/index.html
+*/cache/*
+!*/cache/index.html
+!*/cache/.htaccess
+
+user_guide_src/build/*
+user_guide_src/cilexer/build/*
+user_guide_src/cilexer/dist/*
+user_guide_src/cilexer/pycilexer.egg-info/*
+
+#codeigniter 3
+application/logs/*
+!application/logs/index.html
+!application/logs/.htaccess
+/vendor/
+
+### Emacs template
+# -*- mode: gitignore; -*-
+*~
+\#*\#
+/.emacs.desktop
+/.emacs.desktop.lock
+*.elc
+auto-save-list
+tramp
+.\#*
+
+# Org-mode
+.org-id-locations
+*_archive
+
+# flymake-mode
+*_flymake.*
+
+# eshell files
+/eshell/history
+/eshell/lastdir
+
+# elpa packages
+/elpa/
+
+# reftex files
+*.rel
+
+# AUCTeX auto folder
+/auto/
+
+# cask packages
+.cask/
+dist/
+
+# Flycheck
+flycheck_*.el
+
+# server auth directory
+/server/
+
+# projectiles files
+.projectile
+
+# directory configuration
+.dir-locals.el
+
+# network security
+/network-security.data
+
+
+### Windows template
+# Windows thumbnail cache files
+Thumbs.db
+Thumbs.db:encryptable
+ehthumbs.db
+ehthumbs_vista.db
+
+# Dump file
+*.stackdump
+
+# Folder config file
+[Dd]esktop.ini
+
+# Recycle Bin used on file shares
+$RECYCLE.BIN/
+
+# Windows Installer files
+*.cab
+*.msix
+*.msm
+*.msp
+
+# Windows shortcuts
+*.lnk
+
+### VisualStudioCode template
+.vscode/*
+!.vscode/settings.json
+!.vscode/tasks.json
+!.vscode/launch.json
+!.vscode/extensions.json
+*.code-workspace
+
+# Local History for Visual Studio Code
+.history/
+
+### CMake template
+CMakeLists.txt.user
+CMakeCache.txt
+CMakeFiles
+CMakeScripts
+Testing
+cmake_install.cmake
+install_manifest.txt
+compile_commands.json
+CTestTestfile.cmake
+_deps
+
+### VisualStudio template
+## Ignore Visual Studio temporary files, build results, and
+## files generated by popular Visual Studio add-ons.
+##
+## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
+
+# User-specific files
+*.rsuser
+*.suo
+*.user
+*.userosscache
+*.sln.docstates
+
+# User-specific files (MonoDevelop/Xamarin Studio)
+*.userprefs
+
+# Mono auto generated files
+mono_crash.*
+
+# Build results
+[Dd]ebug/
+[Dd]ebugPublic/
+[Rr]elease/
+[Rr]eleases/
+x64/
+x86/
+[Ww][Ii][Nn]32/
+[Aa][Rr][Mm]/
+[Aa][Rr][Mm]64/
+bld/
+[Bb]in/
+[Oo]bj/
+[Ll]og/
+[Ll]ogs/
+
+# Visual Studio 2015/2017 cache/options directory
+.vs/
+# Uncomment if you have tasks that create the project's static files in wwwroot
+#wwwroot/
+
+# Visual Studio 2017 auto generated files
+Generated\ Files/
+
+# MSTest test Results
+[Tt]est[Rr]esult*/
+[Bb]uild[Ll]og.*
+
+# NUnit
+*.VisualState.xml
+TestResult.xml
+nunit-*.xml
+
+# Build Results of an ATL Project
+[Dd]ebugPS/
+[Rr]eleasePS/
+dlldata.c
+
+# Benchmark Results
+BenchmarkDotNet.Artifacts/
+
+# .NET Core
+project.lock.json
+project.fragment.lock.json
+artifacts/
+
+# ASP.NET Scaffolding
+ScaffoldingReadMe.txt
+
+# StyleCop
+StyleCopReport.xml
+
+# Files built by Visual Studio
+*_i.c
+*_p.c
+*_h.h
+*.ilk
+*.meta
+*.obj
+*.iobj
+*.pch
+*.pdb
+*.ipdb
+*.pgc
+*.pgd
+*.rsp
+*.sbr
+*.tlb
+*.tli
+*.tlh
+*.tmp_proj
+*_wpftmp.csproj
+*.vspscc
+*.vssscc
+.builds
+*.pidb
+*.svclog
+*.scc
+
+# Chutzpah Test files
+_Chutzpah*
+
+# Visual C++ cache files
+ipch/
+*.aps
+*.ncb
+*.opendb
+*.opensdf
+*.sdf
+*.cachefile
+*.VC.db
+*.VC.VC.opendb
+
+# Visual Studio profiler
+*.psess
+*.vsp
+*.vspx
+*.sap
+
+# Visual Studio Trace Files
+*.e2e
+
+# TFS 2012 Local Workspace
+$tf/
+
+# Guidance Automation Toolkit
+*.gpState
+
+# ReSharper is a .NET coding add-in
+_ReSharper*/
+*.[Rr]e[Ss]harper
+*.DotSettings.user
+
+# TeamCity is a build add-in
+_TeamCity*
+
+# DotCover is a Code Coverage Tool
+*.dotCover
+
+# AxoCover is a Code Coverage Tool
+.axoCover/*
+!.axoCover/settings.json
+
+# Coverlet is a free, cross platform Code Coverage Tool
+coverage*.json
+coverage*.xml
+coverage*.info
+
+# Visual Studio code coverage results
+*.coverage
+*.coveragexml
+
+# NCrunch
+_NCrunch_*
+.*crunch*.local.xml
+nCrunchTemp_*
+
+# MightyMoose
+*.mm.*
+AutoTest.Net/
+
+# Web workbench (sass)
+.sass-cache/
+
+# Installshield output folder
+[Ee]xpress/
+
+# DocProject is a documentation generator add-in
+DocProject/buildhelp/
+DocProject/Help/*.HxT
+DocProject/Help/*.HxC
+DocProject/Help/*.hhc
+DocProject/Help/*.hhk
+DocProject/Help/*.hhp
+DocProject/Help/Html2
+DocProject/Help/html
+
+# Click-Once directory
+publish/
+
+# Publish Web Output
+*.[Pp]ublish.xml
+*.azurePubxml
+# Note: Comment the next line if you want to checkin your web deploy settings,
+# but database connection strings (with potential passwords) will be unencrypted
+*.pubxml
+*.publishproj
+
+# Microsoft Azure Web App publish settings. Comment the next line if you want to
+# checkin your Azure Web App publish settings, but sensitive information contained
+# in these scripts will be unencrypted
+PublishScripts/
+
+# NuGet Packages
+*.nupkg
+# NuGet Symbol Packages
+*.snupkg
+# The packages folder can be ignored because of Package Restore
+**/[Pp]ackages/*
+# except build/, which is used as an MSBuild target.
+!**/[Pp]ackages/build/
+# Uncomment if necessary however generally it will be regenerated when needed
+#!**/[Pp]ackages/repositories.config
+# NuGet v3's project.json files produces more ignorable files
+*.nuget.props
+*.nuget.targets
+
+# Microsoft Azure Build Output
+csx/
+*.build.csdef
+
+# Microsoft Azure Emulator
+ecf/
+rcf/
+
+# Windows Store app package directories and files
+AppPackages/
+BundleArtifacts/
+Package.StoreAssociation.xml
+_pkginfo.txt
+*.appx
+*.appxbundle
+*.appxupload
+
+# Visual Studio cache files
+# files ending in .cache can be ignored
+*.[Cc]ache
+# but keep track of directories ending in .cache
+!?*.[Cc]ache/
+
+# Others
+ClientBin/
+~$*
+*.dbmdl
+*.dbproj.schemaview
+*.jfm
+*.pfx
+*.publishsettings
+orleans.codegen.cs
+
+# Including strong name files can present a security risk
+# (https://github.com/github/gitignore/pull/2483#issue-259490424)
+#*.snk
+
+# Since there are multiple workflows, uncomment next line to ignore bower_components
+# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
+#bower_components/
+
+# RIA/Silverlight projects
+Generated_Code/
+
+# Backup & report files from converting an old project file
+# to a newer Visual Studio version. Backup files are not needed,
+# because we have git ;-)
+_UpgradeReport_Files/
+Backup*/
+UpgradeLog*.XML
+UpgradeLog*.htm
+ServiceFabricBackup/
+*.rptproj.bak
+
+# SQL Server files
+*.mdf
+*.ldf
+*.ndf
+
+# Business Intelligence projects
+*.rdl.data
+*.bim.layout
+*.bim_*.settings
+*.rptproj.rsuser
+*- [Bb]ackup.rdl
+*- [Bb]ackup ([0-9]).rdl
+*- [Bb]ackup ([0-9][0-9]).rdl
+
+# Microsoft Fakes
+FakesAssemblies/
+
+# GhostDoc plugin setting file
+*.GhostDoc.xml
+
+# Node.js Tools for Visual Studio
+.ntvs_analysis.dat
+node_modules/
+
+# Visual Studio 6 build log
+*.plg
+
+# Visual Studio 6 workspace options file
+*.opt
+
+# Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
+*.vbw
+
+# Visual Studio LightSwitch build output
+**/*.HTMLClient/GeneratedArtifacts
+**/*.DesktopClient/GeneratedArtifacts
+**/*.DesktopClient/ModelManifest.xml
+**/*.Server/GeneratedArtifacts
+**/*.Server/ModelManifest.xml
+_Pvt_Extensions
+
+# Paket dependency manager
+.paket/paket.exe
+paket-files/
+
+# FAKE - F# Make
+.fake/
+
+# CodeRush personal settings
+.cr/personal
+
+# Python Tools for Visual Studio (PTVS)
+__pycache__/
+*.pyc
+
+# Cake - Uncomment if you are using it
+# tools/**
+# !tools/packages.config
+
+# Tabs Studio
+*.tss
+
+# Telerik's JustMock configuration file
+*.jmconfig
+
+# BizTalk build output
+*.btp.cs
+*.btm.cs
+*.odx.cs
+*.xsd.cs
+
+# OpenCover UI analysis results
+OpenCover/
+
+# Azure Stream Analytics local run output
+ASALocalRun/
+
+# MSBuild Binary and Structured Log
+*.binlog
+
+# NVidia Nsight GPU debugger configuration file
+*.nvuser
+
+# MFractors (Xamarin productivity tool) working folder
+.mfractor/
+
+# Local History for Visual Studio
+.localhistory/
+
+# BeatPulse healthcheck temp database
+healthchecksdb
+
+# Backup folder for Package Reference Convert tool in Visual Studio 2017
+MigrationBackup/
+
+# Ionide (cross platform F# VS Code tools) working folder
+.ionide/
+
+# Fody - auto-generated XML schema
+FodyWeavers.xsd
+
+### Python template
+# Byte-compiled / optimized / DLL files
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+venv/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+### Backup template
+*.gho
+*.ori
+*.orig
+
+### Node template
+# Logs
+logs
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+lerna-debug.log*
+
+# Diagnostic reports (https://nodejs.org/api/report.html)
+report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
+
+# Runtime data
+pids
+*.pid
+*.seed
+*.pid.lock
+
+# Directory for instrumented libs generated by jscoverage/JSCover
+lib-cov
+
+# Coverage directory used by tools like istanbul
+coverage
+*.lcov
+
+# nyc test coverage
+.nyc_output
+
+# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
+.grunt
+
+# Bower dependency directory (https://bower.io/)
+bower_components
+
+# node-waf configuration
+.lock-wscript
+
+# Compiled binary addons (https://nodejs.org/api/addons.html)
+build/Release
+
+# Dependency directories
+jspm_packages/
+
+# Snowpack dependency directory (https://snowpack.dev/)
+web_modules/
+
+# TypeScript cache
+*.tsbuildinfo
+
+# Optional npm cache directory
+.npm
+
+# Optional eslint cache
+.eslintcache
+
+# Microbundle cache
+.rpt2_cache/
+.rts2_cache_cjs/
+.rts2_cache_es/
+.rts2_cache_umd/
+
+# Optional REPL history
+.node_repl_history
+
+# Output of 'npm pack'
+
+# Yarn Integrity file
+.yarn-integrity
+
+# dotenv environment variables file
+.env.test
+
+# parcel-bundler cache (https://parceljs.org/)
+.parcel-cache
+
+# Next.js build output
+.next
+out
+
+# Nuxt.js build / generate output
+.nuxt
+dist
+
+# Gatsby files
+.cache/
+# Comment in the public line if your project uses Gatsby and not Next.js
+# https://nextjs.org/blog/next-9-1#public-directory-support
+# public
+
+# vuepress build output
+.vuepress/dist
+
+# Serverless directories
+.serverless/
+
+# FuseBox cache
+.fusebox/
+
+# DynamoDB Local files
+.dynamodb/
+
+# TernJS port file
+.tern-port
+
+# Stores VSCode versions used for testing VSCode extensions
+.vscode-test
+
+# yarn v2
+.yarn/cache
+.yarn/unplugged
+.yarn/build-state.yml
+.yarn/install-state.gz
+.pnp.*
+
+### VirtualEnv template
+# Virtualenv
+# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
+[Bb]in
+[Ii]nclude
+[Ll]ib
+[Ll]ib64
+[Ll]ocal
+pyvenv.cfg
+pip-selfcheck.json
+
+### macOS template
+# General
+.DS_Store
+.AppleDouble
+.LSOverride
+
+# Icon must end with two \r
+Icon
+
+# Thumbnails
+._*
+
+# Files that might appear in the root of a volume
+.DocumentRevisions-V100
+.fseventsd
+.Spotlight-V100
+.TemporaryItems
+.Trashes
+.VolumeIcon.icns
+.com.apple.timemachine.donotpresent
+
+# Directories potentially created on remote AFP share
+.AppleDB
+.AppleDesktop
+Network Trash Folder
+Temporary Items
+.apdisk
+
+### Go template
+# Binaries for programs and plugins
+*.exe~
+*.dll
+*.dylib
+
+# Test binary, built with `go test -c`
+*.test
+
+# Output of the go coverage tool, specifically when used with LiteIDE
+*.out
+
+# Dependency directories (remove the comment below to include it)
+# vendor/
+
+### C template
+# Prerequisites
+*.d
+
+# Object files
+*.o
+*.ko
+*.elf
+
+# Linker output
+*.map
+*.exp
+
+# Precompiled Headers
+*.gch
+
+# Libraries
+*.lib
+*.a
+*.la
+*.lo
+
+# Shared objects (inc. Windows DLLs)
+*.so.*
+
+# Executables
+*.app
+*.i*86
+*.x86_64
+*.hex
+
+# Debug files
+*.dSYM/
+*.su
+*.idb
+
+# Kernel Module Compile Results
+*.mod*
+*.cmd
+.tmp_versions/
+modules.order
+Module.symvers
+Mkfile.old
+dkms.conf
+
+### Example user template template
+### Example user template
+
+# IntelliJ project files
+.idea
+*.iml
+gen
+### TextMate template
+*.tmproj
+*.tmproject
+tmtags
+
+### Anjuta template
+# Local configuration folder and symbol database
+/.anjuta/
+/.anjuta_sym_db.db
+
+### XilinxISE template
+# intermediate build files
+*.bgn
+*.bit
+*.bld
+*.cmd_log
+*.drc
+*.ll
+*.lso
+*.msd
+*.msk
+*.ncd
+*.ngc
+*.ngd
+*.ngr
+*.pad
+*.par
+*.pcf
+*.prj
+*.ptwx
+*.rbb
+*.rbd
+*.stx
+*.syr
+*.twr
+*.twx
+*.unroutes
+*.ut
+*.xpi
+*.xst
+*_bitgen.xwbt
+*_envsettings.html
+*_map.map
+*_map.mrp
+*_map.ngm
+*_map.xrpt
+*_ngdbuild.xrpt
+*_pad.csv
+*_pad.txt
+*_par.xrpt
+*_summary.html
+*_summary.xml
+*_usage.xml
+*_xst.xrpt
+
+# iMPACT generated files
+_impactbatch.log
+impact.xsl
+impact_impact.xwbt
+ise_impact.cmd
+webtalk_impact.xml
+
+# Core Generator generated files
+xaw2verilog.log
+
+# project-wide generated files
+*.gise
+par_usage_statistics.html
+usage_statistics_webtalk.html
+webtalk.log
+webtalk_pn.xml
+
+# generated folders
+iseconfig/
+xlnx_auto_0_xdb/
+xst/
+_ngo/
+_xmsgs/
+
+### TortoiseGit template
+# Project-level settings
+/.tgitconfig
+
+### C++ template
+# Prerequisites
+
+# Compiled Object files
+*.slo
+
+# Precompiled Headers
+
+# Compiled Dynamic libraries
+
+# Fortran module files
+*.mod
+*.smod
+
+# Compiled Static libraries
+*.lai
+
+# Executables
+
+### SublimeText template
+# Cache files for Sublime Text
+*.tmlanguage.cache
+*.tmPreferences.cache
+*.stTheme.cache
+
+# Workspace files are user-specific
+*.sublime-workspace
+
+# Project files should be checked into the repository, unless a significant
+# proportion of contributors will probably not be using Sublime Text
+# *.sublime-project
+
+# SFTP configuration file
+sftp-config.json
+sftp-config-alt*.json
+
+# Package control specific files
+Package Control.last-run
+Package Control.ca-list
+Package Control.ca-bundle
+Package Control.system-ca-bundle
+Package Control.cache/
+Package Control.ca-certs/
+Package Control.merged-ca-bundle
+Package Control.user-ca-bundle
+oscrypto-ca-bundle.crt
+bh_unicode_properties.cache
+
+# Sublime-github package stores a github token in this file
+# https://packagecontrol.io/packages/sublime-github
+GitHub.sublime-settings
+
+### Vim template
+# Swap
+[._]*.s[a-v][a-z]
+!*.svg # comment out if you don't need vector files
+[._]*.sw[a-p]
+[._]s[a-rt-v][a-z]
+[._]ss[a-gi-z]
+[._]sw[a-p]
+
+# Session
+Session.vim
+Sessionx.vim
+
+# Temporary
+.netrwhist
+# Auto-generated tag files
+tags
+# Persistent undo
+[._]*.un~
+
+### Autotools template
+# http://www.gnu.org/software/automake
+
+Makefile.in
+/ar-lib
+/mdate-sh
+/py-compile
+/test-driver
+/ylwrap
+.deps/
+.dirstamp
+
+# http://www.gnu.org/software/autoconf
+
+autom4te.cache
+/autoscan.log
+/autoscan-*.log
+/aclocal.m4
+/compile
+/config.guess
+/config.h.in
+/config.log
+/config.status
+/config.sub
+/configure
+/configure.scan
+/depcomp
+/install-sh
+/missing
+/stamp-h1
+
+# https://www.gnu.org/software/libtool/
+
+/ltmain.sh
+
+# http://www.gnu.org/software/texinfo
+
+/texinfo.tex
+
+# http://www.gnu.org/software/m4/
+
+m4/libtool.m4
+m4/ltoptions.m4
+m4/ltsugar.m4
+m4/ltversion.m4
+m4/lt~obsolete.m4
+
+# Generated Makefile
+# (a meta build system such as autotools can regenerate it
+# automatically from the config.status script,
+# which is invoked by the configure script)
+
+### Lua template
+# Compiled Lua sources
+luac.out
+
+# luarocks build files
+*.src.rock
+*.tar.gz
+
+# Object files
+*.os
+
+# Precompiled Headers
+
+# Libraries
+*.def
+
+# Shared objects (inc. Windows DLLs)
+
+# Executables
+
+
+### Vagrant template
+# General
+.vagrant/
+
+# Log files (if you are creating logs in debug mode, uncomment this)
+# *.log
+
+### Xcode template
+# Xcode
+#
+# gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore
+
+## User settings
+xcuserdata/
+
+## compatibility with Xcode 8 and earlier (ignoring not required starting Xcode 9)
+*.xcscmblueprint
+*.xccheckout
+
+## compatibility with Xcode 3 and earlier (ignoring not required starting Xcode 4)
+DerivedData/
+*.moved-aside
+*.pbxuser
+!default.pbxuser
+*.mode1v3
+!default.mode1v3
+*.mode2v3
+!default.mode2v3
+*.perspectivev3
+!default.perspectivev3
+
+## Gcc Patch
+/*.gcno
+
+### Linux template
+
+# temporary files which can be created if a process still has a handle open on a deleted file
+.fuse_hidden*
+
+# KDE directory preferences
+.directory
+
+# Linux trash folder which might appear on any partition or disk
+.Trash-*
+
+# .nfs files are created when an open file is removed but is still being accessed
+.nfs*
+
+### GitBook template
+# Node rules:
+## Grunt intermediate storage (http://gruntjs.com/creating-plugins#storing-task-files)
+
+## Dependency directory
+## Commenting this out is preferred by some people, see
+## https://docs.npmjs.com/misc/faq#should-i-check-my-node_modules-folder-into-git
+node_modules
+
+# Book build output
+_book
+
+# eBook build output
+*.epub
+*.mobi
+
+### CodeSniffer template
+# gitignore for the PHP Codesniffer framework
+# website: https://github.com/squizlabs/PHP_CodeSniffer
+#
+# Recommended template: PHP.gitignore
+
+/wpcs/*
+
+### PuTTY template
+# Private key
+*.ppk
+*_pb2.py
+*.pth
+*.pth.tar
+*.pt
+*.npy
+__pycache__
+*.egg-info
+experiment_config.yaml
+api-log/
+log/
+htmlcov
+*.lock
+.coverage*
+/test_*
+.python-version
+/name.txt
+/summary_log
+policy_*
+/data
+.vscode
+formatted_*
+**/exp
+**/benchmark
+**/model_zoo
+*ckpt*
+log*
+*.puml.png
+*.puml.eps
+*.puml.svg
+default*
+events.*
+
+# DI-engine special key
+*default_logger.txt
+*default_tb_logger
+*evaluate.txt
+*total_config.py
+eval_config.py
+collect_demo_data_config.py
+!ding/**/*.py
+events.*
+
+evogym/*
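The ignore rules above can be probed with git's own matcher; a minimal sketch, assuming git is on PATH and the script runs inside the checkout (the probe path impala_log/run0.log is illustrative):

    import subprocess

    # `git check-ignore -v` reports the winning rule as "<source>:<line>:<pattern>\t<path>".
    probe = subprocess.run(
        ["git", "check-ignore", "-v", "impala_log/run0.log"],
        capture_output=True,
        text=True,
    )
    # Exit code 0 means the path is ignored, 1 means no rule matched.
    print(probe.stdout.strip() or "not ignored")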
diff --git a/DI-engine/.style.yapf b/DI-engine/.style.yapf
new file mode 100644
index 0000000000000000000000000000000000000000..edd867c28237606d759f83a8242d93ec821557b4
--- /dev/null
+++ b/DI-engine/.style.yapf
@@ -0,0 +1,11 @@
+[style]
+# For explanation and more information: https://github.com/google/yapf
+BASED_ON_STYLE=pep8
+DEDENT_CLOSING_BRACKETS=True
+SPLIT_BEFORE_FIRST_ARGUMENT=True
+ALLOW_SPLIT_BEFORE_DICT_VALUE=False
+JOIN_MULTIPLE_LINES=False
+COLUMN_LIMIT=120
+BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF=True
+BLANK_LINES_AROUND_TOP_LEVEL_DEFINITION=2
+SPACES_AROUND_POWER_OPERATOR=True
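The yapf style above can also be applied programmatically; a minimal sketch using yapf's FormatCode API, assuming yapf is installed (the snippet being formatted and the config path are illustrative):

    from yapf.yapflib.yapf_api import FormatCode

    snippet = "def f( a,b ):\n    return a**b\n"
    # FormatCode returns (formatted_source, changed); style_config may point at a style file.
    formatted, changed = FormatCode(snippet, style_config="DI-engine/.style.yapf")
    print(formatted)  # e.g. spaces around ** per SPACES_AROUND_POWER_OPERATOR=True
    print(changed)    # True when the input was modified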
diff --git a/DI-engine/CHANGELOG b/DI-engine/CHANGELOG
new file mode 100644
index 0000000000000000000000000000000000000000..a42ca1f98713979a08067539173c3403fc794b87
--- /dev/null
+++ b/DI-engine/CHANGELOG
@@ -0,0 +1,489 @@
+2023.11.06(v0.5.0)
+- env: add tabmwp env (#667)
+- env: polish anytrading env issues (#731)
+- algo: add PromptPG algorithm (#667)
+- algo: add Plan Diffuser algorithm (#700)
+- algo: add new pipeline implementation of IMPALA algorithm (#713)
+- algo: add dropout layers to DQN-style algorithms (#712)
+- feature: add new pipeline agent for sac/ddpg/a2c/ppo and Hugging Face support (#637) (#730) (#737)
+- feature: add more unittest cases for model (#728)
+- feature: add collector logging in new pipeline (#735)
+- fix: logger middleware problems (#715)
+- fix: ppo parallel bug (#709)
+- fix: typo in optimizer_helper.py (#726)
+- fix: mlp dropout if condition bug
+- fix: drex collecting data unittest bugs
+- style: polish env manager/wrapper comments and API doc (#742)
+- style: polish model comments and API doc (#722) (#729) (#734) (#736) (#741)
+- style: polish policy comments and API doc (#732)
+- style: polish rl_utils comments and API doc (#724)
+- style: polish torch_utils comments and API doc (#738)
+- style: update README.md and Colab demo (#733)
+- style: update metaworld docker image
+
+2023.08.23(v0.4.9)
+- env: add cliffwalking env (#677)
+- env: add lunarlander ppo config and example
+- algo: add BCQ offline RL algorithm (#640)
+- algo: add Dreamerv3 model-based RL algorithm (#652)
+- algo: add tensor stream merge network tools (#673)
+- algo: add scatter connection model (#680)
+- algo: refactor Decision Transformer in new pipeline and support img input and discrete output (#693)
+- algo: add three variants of Bilinear classes and a FiLM class (#703)
+- feature: polish offpolicy RL multi-gpu DDP training (#679)
+- feature: add middleware for Ape-X distributed pipeline (#696)
+- feature: add example for evaluating trained DQN (#706)
+- fix: to_ndarray fails to assign dtype for scalars (#708)
+- fix: evaluator return episode_info compatibility bug
+- fix: cql example entry wrong config bug
+- fix: enable_save_figure env interface
+- fix: redundant env info bug in evaluator
+- fix: to_item unittest bug
+- style: polish and simplify requirements (#672)
+- style: add Hugging Face Model Zoo badge (#674)
+- style: add openxlab Model Zoo badge (#675)
+- style: fix py37 macos ci bug and update default pytorch from 1.7.1 to 1.12.1 (#678)
+- style: fix mujoco-py compatibility issue for cython<3 (#711)
+- style: fix type spell error (#704)
+- style: fix pypi release actions ubuntu 18.04 bug
+- style: update contact information (e.g. wechat)
+- style: polish algorithm doc tables
+
+2023.05.25(v0.4.8)
+- env: fix gym hybrid reward dtype bug (#664)
+- env: fix atari env id noframeskip bug (#655)
+- env: fix typo in gym any_trading env (#654)
+- env: update td3bc d4rl config (#659)
+- env: polish bipedalwalker config
+- algo: add EDAC offline RL algorithm (#639)
+- algo: add LN and GN norm_type support in ResBlock (#660)
+- algo: add normal value norm baseline for PPOF (#658)
+- algo: polish last layer init/norm in MLP (#650)
+- algo: polish TD3 monitor variable
+- feature: add MAPPO/MASAC task example (#661)
+- feature: add PPO example for complex env observation (#644)
+- feature: add barrier middleware (#570)
+- fix: abnormal collector log and add record_random_collect option (#662)
+- fix: to_item compatibility bug (#646)
+- fix: trainer dtype transform compatibility bug
+- fix: pettingzoo 1.23.0 compatibility bug
+- fix: ensemble head unittest bug
+- style: fix incompatible gym version bug in Dockerfile.env (#653)
+- style: add more algorithm docs
+
+2023.04.11(v0.4.7)
+- env: add dmc2gym env support and baseline (#451)
+- env: update pettingzoo to the latest version (#597)
+- env: polish icm/rnd+onppo config bugs and add app_door_to_key env (#564)
+- env: add lunarlander continuous TD3/SAC config
+- env: polish lunarlander discrete C51 config
+- algo: add Procedure Cloning (PC) imitation learning algorithm (#514)
+- algo: add Munchausen Reinforcement Learning (MDQN) algorithm (#590)
+- algo: add reward/value norm methods: popart & value rescale & symlog (#605)
+- algo: polish reward model config and training pipeline (#624)
+- algo: add PPOF reward space demo support (#608)
+- algo: add PPOF Atari demo support (#589)
+- algo: polish dqn default config and env examples (#611)
+- algo: polish comment and clean code about SAC
+- feature: add language model (e.g. GPT) training utils (#625)
+- feature: remove policy cfg sub fields requirements (#620)
+- feature: add full wandb support (#579)
+- fix: confusing shallow copy operation about next_obs (#641)
+- fix: unsqueeze action_args in PDQN when shape is 1 (#599)
+- fix: evaluator return_info tensor type bug (#592)
+- fix: deque buffer wrapper PER bug (#586)
+- fix: reward model save method compatibility bug
+- fix: logger assertion and unittest bug
+- fix: bfs test py3.9 compatibility bug
+- fix: zergling collector unittest bug
+- style: add DI-engine torch-rpc p2p communication docker (#628)
+- style: add D4RL docker (#591)
+- style: correct typo in task (#617)
+- style: correct typo in time_helper (#602)
+- style: polish readme and add treetensor example
+- style: update contributing doc
+
+2023.02.16(v0.4.6)
+- env: add metadrive env and related ppo config (#574)
+- env: add acrobot env and related dqn config (#577)
+- env: add carracing in box2d (#575)
+- env: add new gym hybrid viz (#563)
+- env: update cartpole IL config (#578)
+- algo: add BDQ algorithm (#558)
+- algo: add procedure cloning model (#573)
+- feature: add simplified PPOF (PPO × Family) interface (#567) (#568) (#581) (#582)
+- fix: to_device and prev_state bug when using ttorch (#571)
+- fix: py38 and numpy unittest bugs (#565)
+- fix: typo in contrastive_loss.py (#572)
+- fix: dizoo envs pkg installation bugs
+- fix: multi_trainer middleware unittest bug
+- style: add evogym docker (#580)
+- style: fix metaworld docker bug
+- style: fix setuptools high version incompatibility bug
+- style: extend treetensor lowest version
+
+2022.12.13(v0.4.5)
+- env: add beergame supply chain optimization env (#512)
+- env: add env gym_pybullet_drones (#526)
+- env: rename eval reward to episode return (#536)
+- algo: add policy gradient algo implementation (#544)
+- algo: add MADDPG algo implementation (#550)
+- algo: add IMPALA continuous algo implementation (#551)
+- algo: add MADQN algo implementation (#540)
+- feature: add new task IMPALA-type distributed training scheme (#321)
+- feature: add load and save method for replaybuffer (#542)
+- feature: add more DingEnvWrapper example (#525)
+- feature: add evaluator more info viz support (#538)
+- feature: add trackback log for subprocess env manager (#534)
+- fix: halfcheetah td3 config file (#537)
+- fix: mujoco action_clip args compatibility bug (#535)
+- fix: atari a2c config entry bug
+- fix: drex unittest compatibility bug
+- style: add Roadmap issue of DI-engine (#548)
+- style: update related project link and new env doc
+
+2022.10.31(v0.4.4)
+- env: add modified gym-hybrid including moving, sliding and hardmove (#505) (#519)
+- env: add evogym support (#495) (#527)
+- env: add save_replay_gif option (#506)
+- env: adapt minigrid_env and related config to latest MiniGrid v2.0.0 (#500)
+- algo: add pcgrad optimizer (#489)
+- algo: add some features in MLP and ResBlock (#511)
+- algo: delete mcts related modules (#518)
+- feature: add wandb middleware and demo (#488) (#523) (#528)
+- feature: add new properties in Context (#499)
+- feature: add single env policy wrapper for policy deployment
+- feature: add custom model demo and doc
+- fix: build logger args and unittests (#522)
+- fix: total_loss calculation in PDQN (#504)
+- fix: save gif function bug
+- fix: level sample unittest bug
+- style: update contact email address (#503)
+- style: polish env log and resblock name
+- style: add details button in readme
+
+2022.09.23(v0.4.3)
+- env: add rule-based gomoku expert (#465)
+- algo: fix a2c policy batch size bug (#481)
+- algo: enable activation option in collaq attention and mixer
+- algo: minor fix about IBC (#477)
+- feature: add IGM support (#486)
+- feature: add tb logger middleware and demo
+- fix: the type conversion in ding_env_wrapper (#483)
+- fix: di-orchestrator version bug in unittest (#479)
+- fix: data collection errors caused by shallow copies (#475)
+- fix: gym==0.26.0 seed args bug
+- style: add readme tutorial link(environment & algorithm) (#490) (#493)
+- style: adjust location of the default_model method in policy (#453)
+
+2022.09.08(v0.4.2)
+- env: add rocket env (#449)
+- env: update pettingzoo env and improve related performance (#457)
+- env: add mario env demo (#443)
+- env: add MAPPO multi-agent config (#464)
+- env: add mountain car (discrete action) environment (#452)
+- env: fix multi-agent mujoco gym compatibility bug
+- env: fix gfootball env save_replay variable init bug
+- algo: add IBC (Implicit Behaviour Cloning) algorithm (#401)
+- algo: add BCO (Behaviour Cloning from Observation) algorithm (#270)
+- algo: add continuous PPOPG algorithm (#414)
+- algo: add PER in CollaQ (#472)
+- algo: add activation option in QMIX and CollaQ
+- feature: update ctx to dataclass (#467)
+- fix: base_env FinalMeta bug about gym 0.25.0-0.25.1
+- fix: config inplace modification bug
+- fix: ding cli no argument problem
+- fix: import errors after running setup.py (jinja2, markupsafe)
+- fix: conda py3.6 and cross platform build bug
+- style: add project state and datetime in log dir (#455)
+- style: polish notes for q-learning model (#427)
+- style: revision to mujoco dockerfile and validation (#474)
+- style: add dockerfile for cityflow env
+- style: polish default output log format
+
+2022.08.12(v0.4.1)
+- env: add gym trading env (#424)
+- env: add board games env (tictactoe, gomoku, chess) (#356)
+- env: add sokoban env (#397) (#429)
+- env: add BC and DQN demo for gfootball (#418) (#423)
+- env: add discrete pendulum env (#395)
+- algo: add STEVE model-based algorithm (#363)
+- algo: add PLR algorithm (#408)
+- algo: plugin ST-DIM in PPO (#379)
+- feature: add final result saving in training pipeline
+- fix: random policy randomness bug
+- fix: action_space seed compatibility bug
+- fix: discard message sent by self in redis mq (#354)
+- fix: remove pace controller (#400)
+- fix: import error in serial_pipeline_trex (#410)
+- fix: unittest hang and fail bug (#413)
+- fix: DREX collect data unittest bug
+- fix: remove unused import cv2
+- fix: ding CLI env/policy option bug
+- style: upgrade Python version from 3.6-3.8 to 3.7-3.9
+- style: upgrade gym version from 0.20.0 to 0.25.0
+- style: upgrade torch version from 1.10.0 to 1.12.0
+- style: upgrade mujoco bin from 2.0.0 to 2.1.0
+- style: add buffer api description (#371)
+- style: polish VAE comments (#404)
+- style: unittest for FQF (#412)
+- style: add metaworld dockerfile (#432)
+- style: remove opencv requirement in default setting
+- style: update long description in setup.py
+
+2022.06.21(v0.4.0)
+- env: add MAPPO/MASAC all configs in SMAC (#310) **(SOTA results in SMAC!!!)**
+- env: add dmc2gym env (#344) (#360)
+- env: remove DI-star requirements of dizoo/smac, use official pysc2 (#302)
+- env: add latest GAIL mujoco config (#298)
+- env: polish procgen env (#311)
+- env: add MBPO ant and humanoid config for mbpo (#314)
+- env: fix slime volley env obs space bug when agent_vs_agent
+- env: fix smac env obs space bug
+- env: fix import path error in lunarlander (#362)
+- algo: add Decision Transformer algorithm (#327) (#364)
+- algo: add on-policy PPG algorithm (#312)
+- algo: add DDPPO & add model-based SAC with lambda-return algorithm (#332)
+- algo: add infoNCE loss and ST-DIM algorithm (#326)
+- algo: add FQF distributional RL algorithm (#274)
+- algo: add continuous BC algorithm (#318)
+- algo: add pure policy gradient PPO algorithm (#382)
+- algo: add SQIL + SAC algorithm (#348)
+- algo: polish NGU and related modules (#283) (#343) (#353)
+- algo: add marl distributional td loss (#331)
+- feature: add new worker middleware (#236)
+- feature: refactor model-based RL pipeline (ding/world_model) (#332)
+- feature: refactor logging system in the whole DI-engine (#316)
+- feature: add env supervisor design (#330)
+- feature: support async reset for envpool env manager (#250)
+- feature: add log videos to tensorboard (#320)
+- feature: refactor impala cnn encoder interface (#378)
+- fix: env save replay bug
+- fix: transformer mask inplace operation bug
+- fix: transition_with_policy_data bug in SAC and PPG
+- style: add dockerfile for ding:hpc image (#337)
+- style: fix mpire 2.3.5 which handles default processes more elegantly (#306)
+- style: use FORMAT_DIR instead of ./ding (#309)
+- style: update quickstart colab link (#347)
+- style: polish comments in ding/model/common (#315)
+- style: update mujoco docker download path (#386)
+- style: fix protobuf new version compatibility bug
+- style: fix torch1.8.0 torch.div compatibility bug
+- style: update doc links in readme
+- style: add outline in readme and update wechat image
+- style: update head image and refactor docker dir
+
+2022.04.23(v0.3.1)
+- env: polish and standardize dizoo config (#252) (#255) (#249) (#246) (#262) (#261) (#266) (#273) (#263) (#280) (#259) (#286) (#277) (#290) (#289) (#299)
+- env: add GRF academic env and config (#281)
+- env: update env interface of GRF (#258)
+- env: update D4RL offline RL env and config (#285)
+- env: polish PomdpAtariEnv (#254)
+- algo: DREX algorithm (#218)
+- feature: separate mq and parallel modules, add redis (#247)
+- feature: rename env variables; fix attach_to parameter (#244)
+- feature: env implementation check (#275)
+- feature: adjust and set the max column number of tabulate in log (#296)
+- feature: add drop_extra option for sample collect
+- feature: speed up GTrXL forward method + GRU unittest (#253) (#292)
+- fix: add act_scale in DingEnvWrapper; fix envpool env manager (#245)
+- fix: auto_reset=False and env_ref bug in env manager (#248)
+- fix: data type and deepcopy bug in RND (#288)
+- fix: share_memory bug and multi_mujoco env (#279)
+- fix: some bugs in GTrXL (#276)
+- fix: update gym_vector_env_manager and add more unittest (#241)
+- fix: mdpolicy random collect bug (#293)
+- fix: gym.wrapper save video replay bug
+- fix: collect abnormal step format bug and add unittest
+- test: add buffer benchmark & socket test (#284)
+- style: upgrade mpire (#251)
+- style: add GRF(google research football) docker (#256)
+- style: update policy and gail comment
+
+2022.03.24(v0.3.0)
+- env: add bitfilp HER DQN benchmark (#192) (#193) (#197)
+- env: slime volley league training demo (#229)
+- algo: Gated TransformXL (GTrXL) algorithm (#136)
+- algo: TD3 + VAE(HyAR) latent action algorithm (#152)
+- algo: stochastic dueling network (#234)
+- algo: use log prob instead of using prob in ACER (#186)
+- feature: support envpool env manager (#228)
+- feature: add league main and other improvements in new framework (#177) (#214)
+- feature: add pace controller middleware in new framework (#198)
+- feature: add auto recover option in new framework (#242)
+- feature: add k8s parser in new framework (#243)
+- feature: support async event handler and logger (#213)
+- feature: add grad norm calculator (#205)
+- feature: add gym vector env manager (#147)
+- feature: add train_iter and env_step in serial pipeline (#212)
+- feature: add rich logger handler (#219) (#223) (#232)
+- feature: add naive lr_scheduler demo
+- refactor: new BaseEnv and DingEnvWrapper (#171) (#231) (#240)
+- polish: MAPPO and MASAC smac config (#209) (#239)
+- polish: QMIX smac config (#175)
+- polish: R2D2 atari config (#181)
+- polish: A2C atari config (#189)
+- polish: GAIL box2d and mujoco config (#188)
+- polish: ACER atari config (#180)
+- polish: SQIL atari config (#230)
+- polish: TREX atari/mujoco config
+- polish: IMPALA atari config
+- polish: MBPO/D4PG mujoco config
+- fix: random_collect compatible to episode collector (#190)
+- fix: remove default n_sample/n_episode value in policy config (#185)
+- fix: PDQN model bug on gpu device (#220)
+- fix: TREX algorithm CLI bug (#182)
+- fix: DQfD JE computation bug and move to AdamW optimizer (#191)
+- fix: pytest problem for parallel middleware (#211)
+- fix: mujoco numpy compatibility bug
+- fix: markupsafe 2.1.0 bug
+- fix: framework parallel module network emit bug
+- fix: mpire bug and disable algotest in py3.8
+- fix: lunarlander env import and env_id bug
+- fix: icm unittest repeat name bug
+- fix: buffer thruput close bug
+- test: resnet unittest (#199)
+- test: SAC/SQN unittest (#207)
+- test: CQL/R2D3/GAIL unittest (#201)
+- test: NGU td unittest (#210)
+- test: model wrapper unittest (#215)
+- test: MAQAC model unittest (#226)
+- style: add doc docker (#221)
+
+2022.01.01(v0.2.3)
+- env: add multi-agent mujoco env (#146)
+- env: add delay reward mujoco env (#145)
+- env: fix port conflict in gym_soccer (#139)
+- algo: MASAC algorithm (#112)
+- algo: TREX algorithm (#119) (#144)
+- algo: H-PPO hybrid action space algorithm (#140)
+- algo: residual link in R2D2 (#150)
+- algo: gumbel softmax (#169)
+- algo: move actor_head_type to action_space field
+- feature: new main pipeline and async/parallel framework (#142) (#166) (#168)
+- feature: refactor buffer, separate algorithm and storage (#129)
+- feature: cli in new pipeline(ditask) (#160)
+- feature: add multiprocess tblogger, fix circular reference problem (#156)
+- feature: add multiple seed cli
+- feature: polish eps_greedy_multinomial_sample in model_wrapper (#154)
+- fix: R2D3 abs priority problem (#158) (#161)
+- fix: multi-discrete action space policies random action bug (#167)
+- fix: doc generate bug with enum_tools (#155)
+- style: more comments about R2D2 (#149)
+- style: add doc about how to migrate a new env
+- style: add doc about env tutorial in dizoo
+- style: add conda auto release (#148)
+- style: update zh doc link
+- style: update kaggle tutorial link
+
+2021.12.03(v0.2.2)
+- env: apple key to door treasure env (#128)
+- env: add bsuite memory benchmark (#138)
+- env: polish atari impala config
+- algo: Guided Cost IRL algorithm (#57)
+- algo: ICM exploration algorithm (#41)
+- algo: MP-DQN hybrid action space algorithm (#131)
+- algo: add loss statistics and polish r2d3 pong config (#126)
+- feature: add renew env mechanism in env manager and update timeout mechanism (#127) (#134)
+- fix: async subprocess env manager reset bug (#137)
+- fix: keepdims name bug in model wrapper
+- fix: on-policy ppo value norm bug
+- fix: GAE and RND unittest bug
+- fix: hidden state wrapper h tensor compatibility
+- fix: naive buffer auto config create bug
+- style: add supporters list
+
+2021.11.22(v0.2.1)
+- env: gym-hybrid env (#86)
+- env: gym-soccer (HFO) env (#94)
+- env: Go-Bigger env baseline (#95)
+- env: add the bipedalwalker config of sac and ppo (#121)
+- algo: DQfD Imitation Learning algorithm (#48) (#98)
+- algo: TD3BC offline RL algorithm (#88)
+- algo: MBPO model-based RL algorithm (#113)
+- algo: PADDPG hybrid action space algorithm (#109)
+- algo: PDQN hybrid action space algorithm (#118)
+- algo: fix R2D2 bugs and produce benchmark, add naive NGU (#40)
+- algo: self-play training demo in slime_volley env (#23)
+- algo: add example of GAIL entry + config for mujoco (#114)
+- feature: enable arbitrary policy num in serial sample collector
+- feature: add torch DataParallel for single machine multi-GPU
+- feature: add registry force_overwrite argument
+- feature: add naive buffer periodic thruput seconds argument
+- test: add pure docker setting test (#103)
+- test: add unittest for dataset and evaluator (#107)
+- test: add unittest for on-policy algorithm (#92)
+- test: add unittest for ppo and td (MARL case) (#89)
+- test: polish collector benchmark test
+- fix: target model wrapper hard reset bug
+- fix: learn state_dict target model bug
+- fix: ppo bugs and update atari ppo offpolicy config (#108)
+- fix: pyyaml version bug (#99)
+- fix: small fix on bsuite environment (#117)
+- fix: discrete cql unittest bug
+- fix: release workflow bug
+- fix: base policy model state_dict overlap bug
+- fix: remove on_policy option in dizoo config and entry
+- fix: remove torch in env
+- style: gym version > 0.20.0
+- style: torch version >= 1.1.0, <= 1.10.0
+- style: ale-py == 0.7.0
+
+2021.9.30(v0.2.0)
+- env: overcooked env (#20)
+- env: procgen env (#26)
+- env: modified predator env (#30)
+- env: d4rl env (#37)
+- env: imagenet dataset (#27)
+- env: bsuite env (#58)
+- env: move atari_py to ale-py
+- algo: SQIL algorithm (#25) (#44)
+- algo: CQL algorithm (discrete/continuous) (#37) (#68)
+- algo: MAPPO algorithm (#62)
+- algo: WQMIX algorithm (#24)
+- algo: D4PG algorithm (#76)
+- algo: update multi discrete policy (dqn, ppo, rainbow) (#51) (#72)
+- feature: image classification training pipeline (#27)
+- feature: add force_reproducibility option in subprocess env manager
+- feature: add/delete/restart replicas via cli for k8s
+- feature: add league metric (trueskill and elo) (#22)
+- feature: add tb in naive buffer and modify tb in advanced buffer (#39)
+- feature: add k8s launcher and di-orchestrator launcher, add related unittest (#45) (#49)
+- feature: add hyper-parameter scheduler module (#38)
+- feature: add plot function (#59)
+- fix: acer bug and update atari result (#21)
+- fix: mappo nan bug and dict obs cannot unsqueeze bug (#54)
+- fix: r2d2 hidden state and obs arange bug (#36) (#52)
+- fix: ppo bug when use dual_clip and adv > 0
+- fix: qmix double_q hidden state bug
+- fix: spawn context problem in interaction unittest (#69)
+- fix: formatted config no eval bug (#53)
+- fix: the catch statements that will never succeed and system proxy bug (#71) (#79)
+- fix: lunarlander config
+- fix: c51 head dimension mismatch bug
+- fix: mujoco config typo bug
+- fix: ppg atari config bug
+- fix: max use and priority update special branch bug in advanced_buffer
+- style: add docker deploy in github workflow (#70) (#78) (#80)
+- style: support PyTorch 1.9.0
+- style: add algo/env list in README
+- style: rename advanced_buffer register name to advanced
+
+
+2021.8.3(v0.1.1)
+- env: selfplay/league demo (#12)
+- env: pybullet env (#16)
+- env: minigrid env (#13)
+- env: atari enduro config (#11)
+- algo: on policy PPO (#9)
+- algo: ACER algorithm (#14)
+- feature: polish experiment directory structure (#10)
+- refactor: split doc to new repo (#4)
+- fix: atari env info action space bug
+- fix: env manager retry wrapper raise exception info bug
+- fix: dist entry disable-flask-log typo
+- style: codestyle optimization by lgtm (#7)
+- style: code/comment statistics badge
+- style: github CI workflow
+
+2021.7.8(v0.1.0)
diff --git a/DI-engine/CODE_OF_CONDUCT.md b/DI-engine/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000000000000000000000000000000000000..879bcacfaf91e8f1add41eeafb70ea3b5c193d3e
--- /dev/null
+++ b/DI-engine/CODE_OF_CONDUCT.md
@@ -0,0 +1,128 @@
+# Contributor Covenant Code of Conduct
+
+## Our Pledge
+
+We as members, contributors, and leaders pledge to make participation in our
+community a harassment-free experience for everyone, regardless of age, body
+size, visible or invisible disability, ethnicity, sex characteristics, gender
+identity and expression, level of experience, education, socio-economic status,
+nationality, personal appearance, race, religion, or sexual identity
+and orientation.
+
+We pledge to act and interact in ways that contribute to an open, welcoming,
+diverse, inclusive, and healthy community.
+
+## Our Standards
+
+Examples of behavior that contributes to a positive environment for our
+community include:
+
+* Demonstrating empathy and kindness toward other people
+* Being respectful of differing opinions, viewpoints, and experiences
+* Giving and gracefully accepting constructive feedback
+* Accepting responsibility and apologizing to those affected by our mistakes,
+ and learning from the experience
+* Focusing on what is best not just for us as individuals, but for the
+ overall community
+
+Examples of unacceptable behavior include:
+
+* The use of sexualized language or imagery, and sexual attention or
+ advances of any kind
+* Trolling, insulting or derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or email
+ address, without their explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+ professional setting
+
+## Enforcement Responsibilities
+
+Community leaders are responsible for clarifying and enforcing our standards of
+acceptable behavior and will take appropriate and fair corrective action in
+response to any behavior that they deem inappropriate, threatening, offensive,
+or harmful.
+
+Community leaders have the right and responsibility to remove, edit, or reject
+comments, commits, code, wiki edits, issues, and other contributions that are
+not aligned to this Code of Conduct, and will communicate reasons for moderation
+decisions when appropriate.
+
+## Scope
+
+This Code of Conduct applies within all community spaces, and also applies when
+an individual is officially representing the community in public spaces.
+Examples of representing our community include using an official e-mail address,
+posting via an official social media account, or acting as an appointed
+representative at an online or offline event.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported to the community leaders responsible for enforcement at
+opendilab.contact@gmail.com.
+All complaints will be reviewed and investigated promptly and fairly.
+
+All community leaders are obligated to respect the privacy and security of the
+reporter of any incident.
+
+## Enforcement Guidelines
+
+Community leaders will follow these Community Impact Guidelines in determining
+the consequences for any action they deem in violation of this Code of Conduct:
+
+### 1. Correction
+
+**Community Impact**: Use of inappropriate language or other behavior deemed
+unprofessional or unwelcome in the community.
+
+**Consequence**: A private, written warning from community leaders, providing
+clarity around the nature of the violation and an explanation of why the
+behavior was inappropriate. A public apology may be requested.
+
+### 2. Warning
+
+**Community Impact**: A violation through a single incident or series
+of actions.
+
+**Consequence**: A warning with consequences for continued behavior. No
+interaction with the people involved, including unsolicited interaction with
+those enforcing the Code of Conduct, for a specified period of time. This
+includes avoiding interactions in community spaces as well as external channels
+like social media. Violating these terms may lead to a temporary or
+permanent ban.
+
+### 3. Temporary Ban
+
+**Community Impact**: A serious violation of community standards, including
+sustained inappropriate behavior.
+
+**Consequence**: A temporary ban from any sort of interaction or public
+communication with the community for a specified period of time. No public or
+private interaction with the people involved, including unsolicited interaction
+with those enforcing the Code of Conduct, is allowed during this period.
+Violating these terms may lead to a permanent ban.
+
+### 4. Permanent Ban
+
+**Community Impact**: Demonstrating a pattern of violation of community
+standards, including sustained inappropriate behavior, harassment of an
+individual, or aggression toward or disparagement of classes of individuals.
+
+**Consequence**: A permanent ban from any sort of public interaction within
+the community.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage],
+version 2.0, available at
+https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
+
+Community Impact Guidelines were inspired by [Mozilla's code of conduct
+enforcement ladder](https://github.com/mozilla/diversity).
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see the FAQ at
+https://www.contributor-covenant.org/faq. Translations are available at
+https://www.contributor-covenant.org/translations.
diff --git a/DI-engine/CONTRIBUTING.md b/DI-engine/CONTRIBUTING.md
new file mode 100644
index 0000000000000000000000000000000000000000..ecfc1e0670a60c4c8aa2cbd5929347ea85a0155f
--- /dev/null
+++ b/DI-engine/CONTRIBUTING.md
@@ -0,0 +1,7 @@
+[Git Guide](https://di-engine-docs.readthedocs.io/en/latest/24_cooperation/git_guide.html)
+
+[GitHub Cooperation Guide](https://di-engine-docs.readthedocs.io/en/latest/24_cooperation/issue_pr.html)
+
+ - [Code Style](https://di-engine-docs.readthedocs.io/en/latest/21_code_style/index.html)
+ - [Unit Test](https://di-engine-docs.readthedocs.io/en/latest/22_test/index.html)
+ - [Code Review](https://di-engine-docs.readthedocs.io/en/latest/24_cooperation/issue_pr.html#pr-s-code-review)
diff --git a/DI-engine/LICENSE b/DI-engine/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..afdfe50e72e0e63f3b2bd373e6147a170277ffdc
--- /dev/null
+++ b/DI-engine/LICENSE
@@ -0,0 +1,202 @@
+
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright 2017 Google Inc.
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/DI-engine/Makefile b/DI-engine/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..39810b7871edd80d0e73ccdfde6ed0c7f2700455
--- /dev/null
+++ b/DI-engine/Makefile
@@ -0,0 +1,71 @@
+CI ?=
+
+# Directory variables
+DING_DIR ?= ./ding
+DIZOO_DIR ?= ./dizoo
+RANGE_DIR ?=
+TEST_DIR ?= $(if ${RANGE_DIR},${RANGE_DIR},${DING_DIR})
+COV_DIR ?= $(if ${RANGE_DIR},${RANGE_DIR},${DING_DIR})
+FORMAT_DIR ?= $(if ${RANGE_DIR},${RANGE_DIR},${DING_DIR})
+PLATFORM_TEST_DIR ?= $(if ${RANGE_DIR},${RANGE_DIR},${DING_DIR}/entry/tests/test_serial_entry.py ${DING_DIR}/entry/tests/test_serial_entry_onpolicy.py)
+
+# Workers command
+WORKERS ?= 2
+WORKERS_COMMAND := $(if ${WORKERS},-n ${WORKERS} --dist=loadscope,)
+
+# Duration command
+DURATIONS ?= 10
+DURATIONS_COMMAND := $(if ${DURATIONS},--durations=${DURATIONS},)
+
+docs:
+ $(MAKE) -C ${DING_DIR}/docs html
+
+unittest:
+ pytest ${TEST_DIR} \
+ --cov-report=xml \
+ --cov-report term-missing \
+ --cov=${COV_DIR} \
+ ${DURATIONS_COMMAND} \
+ ${WORKERS_COMMAND} \
+ -sv -m unittest \
+
+algotest:
+ pytest ${TEST_DIR} \
+ ${DURATIONS_COMMAND} \
+ -sv -m algotest
+
+cudatest:
+ pytest ${TEST_DIR} \
+ -sv -m cudatest
+
+envpooltest:
+ pytest ${TEST_DIR} \
+ -sv -m envpooltest
+
+dockertest:
+ ${DING_DIR}/scripts/docker-test-entry.sh
+
+platformtest:
+ pytest ${TEST_DIR} \
+ --cov-report term-missing \
+ --cov=${COV_DIR} \
+ ${WORKERS_COMMAND} \
+ -sv -m platformtest
+
+benchmark:
+ pytest ${TEST_DIR} \
+ --durations=0 \
+ -sv -m benchmark
+
+test: unittest # just for compatibility, can be changed later
+
+cpu_test: unittest algotest benchmark
+
+all_test: unittest algotest cudatest benchmark
+
+format:
+ yapf --in-place --recursive -p --verbose --style .style.yapf ${FORMAT_DIR}
+format_test:
+ bash format.sh ${FORMAT_DIR} --test
+flake_check:
+ flake8 ${FORMAT_DIR}
diff --git a/DI-engine/README.md b/DI-engine/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..4833c8fec0451421f947c14c424b80b7f5c888f8
--- /dev/null
+++ b/DI-engine/README.md
@@ -0,0 +1,475 @@
+
+
+
+
+---
+
+[![Twitter](https://img.shields.io/twitter/url?style=social&url=https%3A%2F%2Ftwitter.com%2Fopendilab)](https://twitter.com/opendilab)
+[![PyPI](https://img.shields.io/pypi/v/DI-engine)](https://pypi.org/project/DI-engine/)
+![Conda](https://anaconda.org/opendilab/di-engine/badges/version.svg)
+![Conda update](https://anaconda.org/opendilab/di-engine/badges/latest_release_date.svg)
+![PyPI - Python Version](https://img.shields.io/pypi/pyversions/DI-engine)
+![PyTorch Version](https://img.shields.io/badge/dynamic/json?color=blue&label=pytorch&query=%24.pytorchVersion&url=https%3A%2F%2Fgist.githubusercontent.com/PaParaZz1/54c5c44eeb94734e276b2ed5770eba8d/raw/85b94a54933a9369f8843cc2cea3546152a75661/badges.json)
+
+![Loc](https://img.shields.io/endpoint?url=https://gist.githubusercontent.com/HansBug/3690cccd811e4c5f771075c2f785c7bb/raw/loc.json)
+![Comments](https://img.shields.io/endpoint?url=https://gist.githubusercontent.com/HansBug/3690cccd811e4c5f771075c2f785c7bb/raw/comments.json)
+
+![Style](https://github.com/opendilab/DI-engine/actions/workflows/style.yml/badge.svg)
+[![Read en Docs](https://github.com/opendilab/DI-engine/actions/workflows/doc.yml/badge.svg)](https://di-engine-docs.readthedocs.io/en/latest)
+[![Read zh_CN Docs](https://img.shields.io/readthedocs/di-engine-docs?label=%E4%B8%AD%E6%96%87%E6%96%87%E6%A1%A3)](https://di-engine-docs.readthedocs.io/zh_CN/latest)
+![Unittest](https://github.com/opendilab/DI-engine/actions/workflows/unit_test.yml/badge.svg)
+![Algotest](https://github.com/opendilab/DI-engine/actions/workflows/algo_test.yml/badge.svg)
+![deploy](https://github.com/opendilab/DI-engine/actions/workflows/deploy.yml/badge.svg)
+[![codecov](https://codecov.io/gh/opendilab/DI-engine/branch/main/graph/badge.svg?token=B0Q15JI301)](https://codecov.io/gh/opendilab/DI-engine)
+
+
+
+![GitHub Org's stars](https://img.shields.io/github/stars/opendilab)
+[![GitHub stars](https://img.shields.io/github/stars/opendilab/DI-engine)](https://github.com/opendilab/DI-engine/stargazers)
+[![GitHub forks](https://img.shields.io/github/forks/opendilab/DI-engine)](https://github.com/opendilab/DI-engine/network)
+![GitHub commit activity](https://img.shields.io/github/commit-activity/m/opendilab/DI-engine)
+[![GitHub issues](https://img.shields.io/github/issues/opendilab/DI-engine)](https://github.com/opendilab/DI-engine/issues)
+[![GitHub pulls](https://img.shields.io/github/issues-pr/opendilab/DI-engine)](https://github.com/opendilab/DI-engine/pulls)
+[![Contributors](https://img.shields.io/github/contributors/opendilab/DI-engine)](https://github.com/opendilab/DI-engine/graphs/contributors)
+[![GitHub license](https://img.shields.io/github/license/opendilab/DI-engine)](https://github.com/opendilab/DI-engine/blob/master/LICENSE)
+[![Hugging Face](https://img.shields.io/badge/%F0%9F%A4%97%20Hugging%20Face-Models-yellow)](https://huggingface.co/OpenDILabCommunity)
+[![Open in OpenXLab](https://cdn-static.openxlab.org.cn/header/openxlab_models.svg)](https://openxlab.org.cn/models?search=opendilab)
+
+Updated on 2023.12.05 DI-engine-v0.5.0
+
+
+## Introduction to DI-engine
+[Documentation](https://di-engine-docs.readthedocs.io/en/latest/) | [中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/) | [Tutorials](https://di-engine-docs.readthedocs.io/en/latest/01_quickstart/index.html) | [Feature](#feature) | [Task & Middleware](https://di-engine-docs.readthedocs.io/en/latest/03_system/index.html) | [TreeTensor](#general-data-container-treetensor) | [Roadmap](https://github.com/opendilab/DI-engine/issues/548)
+
+**DI-engine** is a generalized decision intelligence engine for PyTorch and JAX.
+
+It provides **python-first** and **asynchronous-native** task and middleware abstractions, and modularly integrates several of the most important decision-making concepts: Env, Policy and Model. Based on the above mechanisms, DI-engine supports **various [deep reinforcement learning](https://di-engine-docs.readthedocs.io/en/latest/10_concepts/index.html) algorithms** with superior performance, high efficiency, well-organized [documentation](https://di-engine-docs.readthedocs.io/en/latest/) and [unittest](https://github.com/opendilab/DI-engine/actions):
+
+- Most basic DRL algorithms: such as DQN, Rainbow, PPO, TD3, SAC, R2D2, IMPALA
+- Multi-agent RL algorithms: such as QMIX, WQMIX, MAPPO, HAPPO, ACE
+- Imitation learning algorithms (BC/IRL/GAIL): such as GAIL, SQIL, Guided Cost Learning, Implicit BC
+- Offline RL algorithms: BCQ, CQL, TD3BC, Decision Transformer, EDAC, Diffuser, Decision Diffuser, SO2
+- Model-based RL algorithms: SVG, STEVE, MBPO, DDPPO, DreamerV3, MuZero
+- Exploration algorithms: HER, RND, ICM, NGU
+- LLM + RL Algorithms: PPO-max, DPO, MPDPO
+- Other algorithms: such as PER, PLR, PCGrad
+
+**DI-engine** aims to **standardize different Decision Intelligence environments and applications**, supporting both academic research and prototype applications. Various training pipelines and customized decision AI applications are also supported:
+
+
+
+- Traditional academic environments
+ - [DI-zoo](https://github.com/opendilab/DI-engine#environment-versatility): various decision intelligence demonstrations and benchmark environments with DI-engine.
+- Tutorial courses
+ - [PPOxFamily](https://github.com/opendilab/PPOxFamily): PPO x Family DRL Tutorial Course
+- Real world decision AI applications
+ - [DI-star](https://github.com/opendilab/DI-star): Decision AI in StarCraftII
+ - [DI-drive](https://github.com/opendilab/DI-drive): Auto-driving platform
+ - [DI-sheep](https://github.com/opendilab/DI-sheep): Decision AI in 3 Tiles Game
+ - [DI-smartcross](https://github.com/opendilab/DI-smartcross): Decision AI in Traffic Light Control
+ - [DI-bioseq](https://github.com/opendilab/DI-bioseq): Decision AI in Biological Sequence Prediction and Searching
+ - [DI-1024](https://github.com/opendilab/DI-1024): Deep Reinforcement Learning + 1024 Game
+- Research papers
+ - [InterFuser](https://github.com/opendilab/InterFuser): [CoRL 2022] Safety-Enhanced Autonomous Driving Using Interpretable Sensor Fusion Transformer
+ - [ACE](https://github.com/opendilab/ACE): [AAAI 2023] ACE: Cooperative Multi-agent Q-learning with Bidirectional Action-Dependency
+ - [GoBigger](https://github.com/opendilab/GoBigger): [ICLR 2023] Multi-Agent Decision Intelligence Environment
+ - [DOS](https://github.com/opendilab/DOS): [CVPR 2023] ReasonNet: End-to-End Driving with Temporal and Global Reasoning
+ - [LightZero](https://github.com/opendilab/LightZero): [NeurIPS 2023 Spotlight] A lightweight and efficient MCTS/AlphaZero/MuZero algorithm toolkit
+ - [SO2](https://github.com/opendilab/SO2): [AAAI 2024] A Perspective of Q-value Estimation on Offline-to-Online Reinforcement Learning
+ - [LMDrive](https://github.com/opendilab/LMDrive): LMDrive: Closed-Loop End-to-End Driving with Large Language Models
+- Docs and Tutorials
+ - [DI-engine-docs](https://github.com/opendilab/DI-engine-docs): Tutorials, best practice and the API reference.
+ - [awesome-model-based-RL](https://github.com/opendilab/awesome-model-based-RL): A curated list of awesome Model-Based RL resources
+ - [awesome-exploration-RL](https://github.com/opendilab/awesome-exploration-rl): A curated list of awesome exploration RL resources
+ - [awesome-decision-transformer](https://github.com/opendilab/awesome-decision-transformer): A curated list of Decision Transformer resources
+ - [awesome-RLHF](https://github.com/opendilab/awesome-RLHF): A curated list of reinforcement learning with human feedback resources
+ - [awesome-multi-modal-reinforcement-learning](https://github.com/opendilab/awesome-multi-modal-reinforcement-learning): A curated list of Multi-Modal Reinforcement Learning resources
+ - [awesome-AI-based-protein-design](https://github.com/opendilab/awesome-AI-based-protein-design): a collection of research papers for AI-based protein design
+ - [awesome-diffusion-model-in-rl](https://github.com/opendilab/awesome-diffusion-model-in-rl): A curated list of Diffusion Model in RL resources
+ - [awesome-end-to-end-autonomous-driving](https://github.com/opendilab/awesome-end-to-end-autonomous-driving): A curated list of awesome End-to-End Autonomous Driving resources
+ - [awesome-driving-behavior-prediction](https://github.com/opendilab/awesome-driving-behavior-prediction): A collection of research papers for Driving Behavior Prediction
+
+
+On the low-level end, DI-engine comes with a set of highly re-usable modules, including [RL optimization functions](https://github.com/opendilab/DI-engine/tree/main/ding/rl_utils), [PyTorch utilities](https://github.com/opendilab/DI-engine/tree/main/ding/torch_utils) and [auxiliary tools](https://github.com/opendilab/DI-engine/tree/main/ding/utils).
+
+Besides, **DI-engine** also provides some special **system optimizations and designs** for efficient and robust large-scale RL training:
+
+
+
+- [treevalue](https://github.com/opendilab/treevalue): Tree-nested data structure
+- [DI-treetensor](https://github.com/opendilab/DI-treetensor): Tree-nested PyTorch tensor Lib
+- [DI-toolkit](https://github.com/opendilab/DI-toolkit): A simple toolkit package for decision intelligence
+- [DI-orchestrator](https://github.com/opendilab/DI-orchestrator): RL Kubernetes Custom Resource and Operator Lib
+- [DI-hpc](https://github.com/opendilab/DI-hpc): RL HPC OP Lib
+- [DI-store](https://github.com/opendilab/DI-store): RL Object Store
+
+
+Have fun with exploration and exploitation.
+
+## Outline
+
+- [Introduction to DI-engine](#introduction-to-di-engine)
+- [Outline](#outline)
+- [Installation](#installation)
+- [Quick Start](#quick-start)
+- [Feature](#feature)
+ - [Algorithm Versatility](#algorithm-versatility)
+ - [Environment Versatility](#environment-versatility)
+ - [General Data Container: TreeTensor](#general-data-container-treetensor)
+- [Feedback and Contribution](#feedback-and-contribution)
+- [Supporters](#supporters)
+ - [↳ Stargazers](#-stargazers)
+ - [↳ Forkers](#-forkers)
+- [Citation](#citation)
+- [License](#license)
+
+## Installation
+
+You can simply install DI-engine from PyPI with the following command:
+```bash
+pip install DI-engine
+```
+
+If you use Anaconda or Miniconda, you can install DI-engine from conda-forge through the following command:
+```bash
+conda install -c opendilab di-engine
+```
+
+For more information about installation, you can refer to [installation](https://di-engine-docs.readthedocs.io/en/latest/01_quickstart/installation.html).
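+
+If you want a quick sanity check after installing (a minimal sketch; it assumes the `ding` console script was put on your `PATH` by pip/conda, which is the default behaviour), the following should run without errors:
+
+```bash
+# confirm the Python package can be imported
+python -c "import ding; print(ding.__file__)"
+
+# confirm the command-line entry point is available
+ding --help
+```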
+
+Our DockerHub repo can be found [here](https://hub.docker.com/repository/docker/opendilab/ding). We provide a `base image` and several `env image`s with common RL environments (a pull/run sketch follows the list below).
+
+
+
+- base: opendilab/ding:nightly
+- rpc: opendilab/ding:nightly-rpc
+- atari: opendilab/ding:nightly-atari
+- mujoco: opendilab/ding:nightly-mujoco
+- dmc: opendilab/ding:nightly-dmc2gym
+- metaworld: opendilab/ding:nightly-metaworld
+- smac: opendilab/ding:nightly-smac
+- grf: opendilab/ding:nightly-grf
+- cityflow: opendilab/ding:nightly-cityflow
+- evogym: opendilab/ding:nightly-evogym
+- d4rl: opendilab/ding:nightly-d4rl
+
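+For example, pulling an image and opening a shell inside it only needs standard Docker commands (a minimal sketch; the tags are the ones listed above):
+
+```bash
+# pull the base image and start an interactive container
+docker pull opendilab/ding:nightly
+docker run -it --rm opendilab/ding:nightly /bin/bash
+
+# environment-specific images (e.g. Atari) work the same way
+docker pull opendilab/ding:nightly-atari
+```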
+
+The detailed documentation is hosted at [doc](https://di-engine-docs.readthedocs.io/en/latest/) | [中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/).
+
+## Quick Start
+
+[3 Minutes Kickoff](https://di-engine-docs.readthedocs.io/en/latest/01_quickstart/first_rl_program.html)
+
+[3 Minutes Kickoff (colab)](https://colab.research.google.com/drive/1_7L-QFDfeCvMvLJzRyBRUW5_Q6ESXcZ4)
+
+[DI-engine Huggingface Kickoff (colab)](https://colab.research.google.com/drive/1UH1GQOjcHrmNSaW77hnLGxFJrLSLwCOk)
+
+[How to migrate a new **RL Env**](https://di-engine-docs.readthedocs.io/en/latest/11_dizoo/index.html) | [如何迁移一个新的**强化学习环境**](https://di-engine-docs.readthedocs.io/zh_CN/latest/11_dizoo/index_zh.html)
+
+[How to customize the neural network model](https://di-engine-docs.readthedocs.io/en/latest/04_best_practice/custom_model.html) | [如何定制策略使用的**神经网络模型**](https://di-engine-docs.readthedocs.io/zh_CN/latest/04_best_practice/custom_model_zh.html)
+
+[Example of evaluating/deploying an **RL policy**](https://github.com/opendilab/DI-engine/blob/main/dizoo/classic_control/cartpole/entry/cartpole_c51_deploy.py)
+
+[Differences between the new and old pipelines (in Chinese)](https://di-engine-docs.readthedocs.io/zh_CN/latest/04_best_practice/diff_in_new_pipeline_zh.html)
+
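+As a reference, launching one of the bundled CartPole demos from the command line looks like the following (a hedged sketch: the exact commands are taken from the algorithm table below, and the config files live in the corresponding folders under `dizoo/classic_control/cartpole`, which may change between releases):
+
+```bash
+# off-policy DQN on CartPole with the serial pipeline and seed 0
+ding -m serial -c cartpole_dqn_config.py -s 0
+
+# on-policy PPO on CartPole
+ding -m serial_onpolicy -c cartpole_ppo_config.py -s 0
+```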
+
+## Feature
+### Algorithm Versatility
+
+
+
+![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space, which is the only label for standard DRL algorithms (No. 1-23)
+
+![continuous](https://img.shields.io/badge/-continous-green) means continuous action space, which is the only label for standard DRL algorithms (No. 1-23)
+
+![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) means hybrid (discrete + continuous) action space (No. 1-23)
+
+![dist](https://img.shields.io/badge/-distributed-blue) [Distributed Reinforcement Learning](https://di-engine-docs.readthedocs.io/en/latest/02_algo/distributed_rl.html)|[分布式强化学习](https://di-engine-docs.readthedocs.io/zh_CN/latest/02_algo/distributed_rl_zh.html)
+
+![MARL](https://img.shields.io/badge/-MARL-yellow) [Multi-Agent Reinforcement Learning](https://di-engine-docs.readthedocs.io/en/latest/02_algo/multi_agent_cooperation_rl.html)|[多智能体强化学习](https://di-engine-docs.readthedocs.io/zh_CN/latest/02_algo/multi_agent_cooperation_rl_zh.html)
+
+![exp](https://img.shields.io/badge/-exploration-orange) [Exploration Mechanisms in Reinforcement Learning](https://di-engine-docs.readthedocs.io/en/latest/02_algo/exploration_rl.html)|[强化学习中的探索机制](https://di-engine-docs.readthedocs.io/zh_CN/latest/02_algo/exploration_rl_zh.html)
+
+![IL](https://img.shields.io/badge/-IL-purple) [Imitation Learning](https://di-engine-docs.readthedocs.io/en/latest/02_algo/imitation_learning.html)|[模仿学习](https://di-engine-docs.readthedocs.io/zh_CN/latest/02_algo/imitation_learning_zh.html)
+
+![offline](https://img.shields.io/badge/-offlineRL-darkblue) [Offline Reinforcement Learning](https://di-engine-docs.readthedocs.io/en/latest/02_algo/offline_rl.html)|[离线强化学习](https://di-engine-docs.readthedocs.io/zh_CN/latest/02_algo/offline_rl_zh.html)
+
+
+![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) [Model-Based Reinforcement Learning](https://di-engine-docs.readthedocs.io/en/latest/02_algo/model_based_rl.html)|[基于模型的强化学习](https://di-engine-docs.readthedocs.io/zh_CN/latest/02_algo/model_based_rl_zh.html)
+
+![other](https://img.shields.io/badge/-other-lightgrey) means algorithms from other sub-directions, usually used as plug-ins in the whole pipeline
+
+P.S.: The `.py` files in the `Runnable Demo` column can be found in `dizoo`
+
+
+
+| No. | Algorithm | Label | Doc and Implementation | Runnable Demo |
+| :--: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: | :----------------------------------------------------------: |
+| 1 | [DQN](https://storage.googleapis.com/deepmind-media/dqn/DQNNaturePaper.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [DQN doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/dqn.html) [DQN中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/dqn_zh.html) [policy/dqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqn.py) | python3 -u cartpole_dqn_main.py / ding -m serial -c cartpole_dqn_config.py -s 0 |
+| 2 | [C51](https://arxiv.org/pdf/1707.06887.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [C51 doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/c51.html) [policy/c51](https://github.com/opendilab/DI-engine/blob/main/ding/policy/c51.py) | ding -m serial -c cartpole_c51_config.py -s 0 |
+| 3 | [QRDQN](https://arxiv.org/pdf/1710.10044.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [QRDQN doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/qrdqn.html) [policy/qrdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qrdqn.py) | ding -m serial -c cartpole_qrdqn_config.py -s 0 |
+| 4 | [IQN](https://arxiv.org/pdf/1806.06923.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [IQN doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/iqn.html) [policy/iqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/iqn.py) | ding -m serial -c cartpole_iqn_config.py -s 0 |
+| 5 | [FQF](https://arxiv.org/pdf/1911.02140.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [FQF doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/fqf.html) [policy/fqf](https://github.com/opendilab/DI-engine/blob/main/ding/policy/fqf.py) | ding -m serial -c cartpole_fqf_config.py -s 0 |
+| 6 | [Rainbow](https://arxiv.org/pdf/1710.02298.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [Rainbow doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/rainbow.html) [policy/rainbow](https://github.com/opendilab/DI-engine/blob/main/ding/policy/rainbow.py) | ding -m serial -c cartpole_rainbow_config.py -s 0 |
+| 7 | [SQL](https://arxiv.org/pdf/1702.08165.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![continuous](https://img.shields.io/badge/-continous-green) | [SQL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/sql.html) [policy/sql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/sql.py) | ding -m serial -c cartpole_sql_config.py -s 0 |
+| 8 | [R2D2](https://openreview.net/forum?id=r1lyTjAqYX) | ![dist](https://img.shields.io/badge/-distributed-blue)![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [R2D2 doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/r2d2.html) [policy/r2d2](https://github.com/opendilab/DI-engine/blob/main/ding/policy/r2d2.py) | ding -m serial -c cartpole_r2d2_config.py -s 0 |
+| 9 | [PG](https://proceedings.neurips.cc/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [PG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/a2c.html) [policy/pg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/pg.py) | ding -m serial -c cartpole_pg_config.py -s 0 |
+| 10 | [PromptPG](https://arxiv.org/abs/2209.14610) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [policy/prompt_pg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/prompt_pg.py) | ding -m serial_onpolicy -c tabmwp_pg_config.py -s 0 |
+| 11 | [A2C](https://arxiv.org/pdf/1602.01783.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [A2C doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/a2c.html) [policy/a2c](https://github.com/opendilab/DI-engine/blob/main/ding/policy/a2c.py) | ding -m serial -c cartpole_a2c_config.py -s 0 |
+| 12 | [PPO](https://arxiv.org/abs/1707.06347)/[MAPPO](https://arxiv.org/pdf/2103.01955.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![continuous](https://img.shields.io/badge/-continous-green)![MARL](https://img.shields.io/badge/-MARL-yellow) | [PPO doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ppo.html) [policy/ppo](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ppo.py) | python3 -u cartpole_ppo_main.py / ding -m serial_onpolicy -c cartpole_ppo_config.py -s 0 |
+| 13 | [PPG](https://arxiv.org/pdf/2009.04416.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [PPG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ppg.html) [policy/ppg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ppg.py) | python3 -u cartpole_ppg_main.py |
+| 14 | [ACER](https://arxiv.org/pdf/1611.01224.pdf) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![continuous](https://img.shields.io/badge/-continous-green) | [ACER doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/acer.html) [policy/acer](https://github.com/opendilab/DI-engine/blob/main/ding/policy/acer.py) | ding -m serial -c cartpole_acer_config.py -s 0 |
+| 15 | [IMPALA](https://arxiv.org/abs/1802.01561) | ![dist](https://img.shields.io/badge/-distributed-blue)![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [IMPALA doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/impala.html) [policy/impala](https://github.com/opendilab/DI-engine/blob/main/ding/policy/impala.py) | ding -m serial -c cartpole_impala_config.py -s 0 |
+| 16 | [DDPG](https://arxiv.org/pdf/1509.02971.pdf)/[PADDPG](https://arxiv.org/pdf/1511.04143.pdf) | ![continuous](https://img.shields.io/badge/-continous-green)![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [DDPG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ddpg.html) [policy/ddpg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ddpg.py) | ding -m serial -c pendulum_ddpg_config.py -s 0 |
+| 17 | [TD3](https://arxiv.org/pdf/1802.09477.pdf) | ![continuous](https://img.shields.io/badge/-continous-green)![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [TD3 doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/td3.html) [policy/td3](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3.py) | python3 -u pendulum_td3_main.py / ding -m serial -c pendulum_td3_config.py -s 0 |
+| 18 | [D4PG](https://arxiv.org/pdf/1804.08617.pdf) | ![continuous](https://img.shields.io/badge/-continous-green) | [D4PG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/d4pg.html) [policy/d4pg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/d4pg.py) | python3 -u pendulum_d4pg_config.py |
+| 19 | [SAC](https://arxiv.org/abs/1801.01290)/[MASAC] | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![continuous](https://img.shields.io/badge/-continous-green)![MARL](https://img.shields.io/badge/-MARL-yellow) | [SAC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/sac.html) [policy/sac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/sac.py) | ding -m serial -c pendulum_sac_config.py -s 0 |
+| 20 | [PDQN](https://arxiv.org/pdf/1810.06394.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/pdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/pdqn.py) | ding -m serial -c gym_hybrid_pdqn_config.py -s 0 |
+| 21 | [MPDQN](https://arxiv.org/pdf/1905.04388.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/pdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/pdqn.py) | ding -m serial -c gym_hybrid_mpdqn_config.py -s 0 |
+| 22 | [HPPO](https://arxiv.org/pdf/1903.01344.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/ppo](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ppo.py) | ding -m serial_onpolicy -c gym_hybrid_hppo_config.py -s 0 |
+| 23 | [BDQ](https://arxiv.org/pdf/1711.08946.pdf) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | [policy/bdq](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqn.py) | python3 -u hopper_bdq_config.py |
+| 24 | [MDQN](https://arxiv.org/abs/2007.14430) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | [policy/mdqn](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mdqn.py) | python3 -u asterix_mdqn_config.py |
+| 25 | [QMIX](https://arxiv.org/pdf/1803.11485.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [QMIX doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/qmix.html) [policy/qmix](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qmix.py) | ding -m serial -c smac_3s5z_qmix_config.py -s 0 |
+| 26 | [COMA](https://arxiv.org/pdf/1705.08926.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [COMA doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/coma.html) [policy/coma](https://github.com/opendilab/DI-engine/blob/main/ding/policy/coma.py) | ding -m serial -c smac_3s5z_coma_config.py -s 0 |
+| 27 | [QTran](https://arxiv.org/abs/1905.05408) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [policy/qtran](https://github.com/opendilab/DI-engine/blob/main/ding/policy/qtran.py) | ding -m serial -c smac_3s5z_qtran_config.py -s 0 |
+| 28 | [WQMIX](https://arxiv.org/abs/2006.10800) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [WQMIX doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/wqmix.html) [policy/wqmix](https://github.com/opendilab/DI-engine/blob/main/ding/policy/wqmix.py) | ding -m serial -c smac_3s5z_wqmix_config.py -s 0 |
+| 29 | [CollaQ](https://arxiv.org/pdf/2010.08531.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [CollaQ doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/collaq.html) [policy/collaq](https://github.com/opendilab/DI-engine/blob/main/ding/policy/collaq.py) | ding -m serial -c smac_3s5z_collaq_config.py -s 0 |
+| 30 | [MADDPG](https://arxiv.org/pdf/1706.02275.pdf) | ![MARL](https://img.shields.io/badge/-MARL-yellow) | [MADDPG doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/ddpg.html) [policy/ddpg](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ddpg.py) | ding -m serial -c ant_maddpg_config.py -s 0 |
+| 31 | [GAIL](https://arxiv.org/pdf/1606.03476.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [GAIL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/gail.html) [reward_model/gail](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/gail_irl_model.py) | ding -m serial_gail -c cartpole_dqn_gail_config.py -s 0 |
+| 32 | [SQIL](https://arxiv.org/pdf/1905.11108.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [SQIL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/sqil.html) [entry/sqil](https://github.com/opendilab/DI-engine/blob/main/ding/entry/serial_entry_sqil.py) | ding -m serial_sqil -c cartpole_sqil_config.py -s 0 |
+| 33 | [DQFD](https://arxiv.org/pdf/1704.03732.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [DQFD doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/dqfd.html) [policy/dqfd](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dqfd.py) | ding -m serial_dqfd -c cartpole_dqfd_config.py -s 0 |
+| 34 | [R2D3](https://arxiv.org/pdf/1909.01387.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [R2D3 doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/r2d3.html) [R2D3中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/r2d3_zh.html) [policy/r2d3](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/r2d3_zh.html) | python3 -u pong_r2d3_r2d2expert_config.py |
+| 35 | [Guided Cost Learning](https://arxiv.org/pdf/1603.00448.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [Guided Cost Learning中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/guided_cost_zh.html) [reward_model/guided_cost](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/guided_cost_reward_model.py) | python3 lunarlander_gcl_config.py |
+| 36 | [TREX](https://arxiv.org/abs/1904.06387) | ![IL](https://img.shields.io/badge/-IL-purple) | [TREX doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/trex.html) [reward_model/trex](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/trex_reward_model.py) | python3 mujoco_trex_main.py |
+| 37 | [Implicit Behavorial Cloning](https://implicitbc.github.io/) (DFO+MCMC) | ![IL](https://img.shields.io/badge/-IL-purple) | [policy/ibc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/ibc.py) [model/template/ebm](https://github.com/opendilab/DI-engine/blob/main/ding/model/template/ebm.py) | python3 d4rl_ibc_main.py -s 0 -c pen_human_ibc_mcmc_config.py |
+| 38 | [BCO](https://arxiv.org/pdf/1805.01954.pdf) | ![IL](https://img.shields.io/badge/-IL-purple) | [entry/bco](https://github.com/opendilab/DI-engine/blob/main/ding/entry/serial_entry_bco.py) | python3 -u cartpole_bco_config.py |
+| 39 | [HER](https://arxiv.org/pdf/1707.01495.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [HER doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/her.html) [reward_model/her](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/her_reward_model.py) | python3 -u bitflip_her_dqn.py |
+| 40 | [RND](https://arxiv.org/abs/1810.12894) | ![exp](https://img.shields.io/badge/-exploration-orange) | [RND doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/rnd.html) [reward_model/rnd](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/rnd_reward_model.py) | python3 -u cartpole_rnd_onppo_config.py |
+| 41 | [ICM](https://arxiv.org/pdf/1705.05363.pdf) | ![exp](https://img.shields.io/badge/-exploration-orange) | [ICM doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/icm.html) [ICM中文文档](https://di-engine-docs.readthedocs.io/zh_CN/latest/12_policies/icm_zh.html) [reward_model/icm](https://github.com/opendilab/DI-engine/blob/main/ding/reward_model/icm_reward_model.py) | python3 -u cartpole_ppo_icm_config.py |
+| 42 | [CQL](https://arxiv.org/pdf/2006.04779.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [CQL doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/cql.html) [policy/cql](https://github.com/opendilab/DI-engine/blob/main/ding/policy/cql.py) | python3 -u d4rl_cql_main.py |
+| 43 | [TD3BC](https://arxiv.org/pdf/2106.06860.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [TD3BC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/td3_bc.html) [policy/td3_bc](https://github.com/opendilab/DI-engine/blob/main/ding/policy/td3_bc.py) | python3 -u d4rl_td3_bc_main.py |
+| 44 | [Decision Transformer](https://arxiv.org/pdf/2106.01345.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [policy/dt](https://github.com/opendilab/DI-engine/blob/main/ding/policy/dt.py) | python3 -u d4rl_dt_mujoco.py |
+| 45 | [EDAC](https://arxiv.org/pdf/2110.01548.pdf) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | [EDAC doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/edac.html) [policy/edac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/edac.py) | python3 -u d4rl_edac_main.py |
+| 46 | MBSAC([SAC](https://arxiv.org/abs/1801.01290)+[MVE](https://arxiv.org/abs/1803.00101)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_mbsac_mbpo_config.py / python3 -u pendulum_mbsac_ddppo_config.py |
+| 47 | STEVESAC([SAC](https://arxiv.org/abs/1801.01290)+[STEVE](https://arxiv.org/abs/1807.01675)+[SVG](https://arxiv.org/abs/1510.09142)) | ![continuous](https://img.shields.io/badge/-continous-green)![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [policy/mbpolicy/mbsac](https://github.com/opendilab/DI-engine/blob/main/ding/policy/mbpolicy/mbsac.py) | python3 -u pendulum_stevesac_mbpo_config.py |
+| 48 | [MBPO](https://arxiv.org/pdf/1906.08253.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [MBPO doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/mbpo.html) [world_model/mbpo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/mbpo.py) | python3 -u pendulum_sac_mbpo_config.py |
+| 49 | [DDPPO](https://openreview.net/forum?id=rzvOQrnclO0) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/ddppo](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/ddppo.py) | python3 -u pendulum_mbsac_ddppo_config.py |
+| 50 | [DreamerV3](https://arxiv.org/pdf/2301.04104.pdf) | ![mbrl](https://img.shields.io/badge/-ModelBasedRL-lightblue) | [world_model/dreamerv3](https://github.com/opendilab/DI-engine/blob/main/ding/world_model/dreamerv3.py) | python3 -u cartpole_balance_dreamer_config.py |
+| 51 | [PER](https://arxiv.org/pdf/1511.05952.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [worker/replay_buffer](https://github.com/opendilab/DI-engine/blob/main/ding/worker/replay_buffer/advanced_buffer.py) | `rainbow demo` |
+| 52 | [GAE](https://arxiv.org/pdf/1506.02438.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [rl_utils/gae](https://github.com/opendilab/DI-engine/blob/main/ding/rl_utils/gae.py) | `ppo demo` |
+| 53 | [ST-DIM](https://arxiv.org/pdf/1906.08226.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/loss/contrastive_loss](https://github.com/opendilab/DI-engine/blob/main/ding/torch_utils/loss/contrastive_loss.py) | ding -m serial -c cartpole_dqn_stdim_config.py -s 0 |
+| 54 | [PLR](https://arxiv.org/pdf/2010.03934.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [PLR doc](https://di-engine-docs.readthedocs.io/en/latest/12_policies/plr.html) [data/level_replay/level_sampler](https://github.com/opendilab/DI-engine/blob/main/ding/data/level_replay/level_sampler.py) | python3 -u bigfish_plr_config.py -s 0 |
+| 55 | [PCGrad](https://arxiv.org/pdf/2001.06782.pdf) | ![other](https://img.shields.io/badge/-other-lightgrey) | [torch_utils/optimizer_helper/PCGrad](https://github.com/opendilab/DI-engine/blob/main/ding/data/torch_utils/optimizer_helper.py) | python3 -u multi_mnist_pcgrad_main.py -s 0 |
+
+
+
+### Environment Versatility
+
+
+| No | Environment | Label | Visualization | Code and Doc Links |
+| :--: | :--------------------------------------: | :---------------------------------: | :--------------------------------:|:---------------------------------------------------------: |
+| 1 | [Atari](https://github.com/openai/gym/tree/master/gym/envs/atari) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/atari/atari.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/atari/envs) [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/atari.html) [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/atari_zh.html) |
+| 2 | [box2d/bipedalwalker](https://github.com/openai/gym/tree/master/gym/envs/box2d) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/box2d/bipedalwalker/original.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/box2d/bipedalwalker/envs) [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/bipedalwalker.html) [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/bipedalwalker_zh.html) |
+| 3 | [box2d/lunarlander](https://github.com/openai/gym/tree/master/gym/envs/box2d) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/box2d/lunarlander/lunarlander.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/box2d/lunarlander/envs) [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/lunarlander.html) [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/lunarlander_zh.html) |
+| 4 | [classic_control/cartpole](https://github.com/openai/gym/tree/master/gym/envs/classic_control) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/classic_control/cartpole/cartpole.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/classic_control/cartpole/envs) [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/cartpole.html) [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/cartpole_zh.html) |
+| 5 | [classic_control/pendulum](https://github.com/openai/gym/tree/master/gym/envs/classic_control) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/classic_control/pendulum/pendulum.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/classic_control/pendulum/envs) [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/pendulum.html) [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/pendulum_zh.html) |
+| 6 | [competitive_rl](https://github.com/cuhkrlcourse/competitive-rl) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![selfplay](https://img.shields.io/badge/-selfplay-blue) | ![original](./dizoo/competitive_rl/competitive_rl.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo.classic_control) [环境指南](https://di-engine-docs.readthedocs.io/en/latest/13_envs/competitive_rl_zh.html) |
+| 7 | [gfootball](https://github.com/google-research/football) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![sparse](https://img.shields.io/badge/-sparse%20reward-orange)![selfplay](https://img.shields.io/badge/-selfplay-blue) | ![original](./dizoo/gfootball/gfootball.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo.gfootball/envs) [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/gfootball.html) [环境指南](https://di-engine-docs.readthedocs.io/en/latest/13_envs/gfootball_zh.html) |
+| 8 | [minigrid](https://github.com/maximecb/gym-minigrid) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![sparse](https://img.shields.io/badge/-sparse%20reward-orange) | ![original](./dizoo/minigrid/minigrid.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/minigrid/envs) [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/minigrid.html) [环境指南](https://di-engine-docs.readthedocs.io/en/latest/13_envs/minigrid_zh.html) |
+| 9 | [MuJoCo](https://github.com/openai/gym/tree/master/gym/envs/mujoco) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/mujoco/mujoco.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/mujoco/envs) [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/mujoco.html) [环境指南](https://di-engine-docs.readthedocs.io/en/latest/13_envs/mujoco_zh.html) |
+| 10 | [PettingZoo](https://github.com/Farama-Foundation/PettingZoo) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![continuous](https://img.shields.io/badge/-continous-green) ![marl](https://img.shields.io/badge/-MARL-yellow) | ![original](./dizoo/petting_zoo/petting_zoo_mpe_simple_spread.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/petting_zoo/envs) [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/pettingzoo.html) [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/pettingzoo_zh.html) |
+| 11 | [overcooked](https://github.com/HumanCompatibleAI/overcooked-demo) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![marl](https://img.shields.io/badge/-MARL-yellow) | ![original](./dizoo/overcooked/overcooked.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/overcooked/envs) [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/overcooked.html) |
+| 12 | [procgen](https://github.com/openai/procgen) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/procgen/coinrun.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/procgen) [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/procgen.html) [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/procgen_zh.html) |
+| 13 | [pybullet](https://github.com/benelot/pybullet-gym) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/pybullet/pybullet.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/pybullet/envs) [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/pybullet_zh.html) |
+| 14 | [smac](https://github.com/oxwhirl/smac) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![marl](https://img.shields.io/badge/-MARL-yellow)![selfplay](https://img.shields.io/badge/-selfplay-blue)![sparse](https://img.shields.io/badge/-sparse%20reward-orange) | ![original](./dizoo/smac/smac.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/smac/envs) [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/smac.html) [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/smac_zh.html) |
+| 15 | [d4rl](https://github.com/rail-berkeley/d4rl) | ![offline](https://img.shields.io/badge/-offlineRL-darkblue) | ![ori](dizoo/d4rl/d4rl.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/d4rl) [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/d4rl_zh.html) |
+| 16 | league_demo | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![selfplay](https://img.shields.io/badge/-selfplay-blue) | ![original](./dizoo/league_demo/league_demo.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/league_demo/envs) |
+| 17 | pomdp atari | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/pomdp/envs) |
+| 18 | [bsuite](https://github.com/deepmind/bsuite) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/bsuite/bsuite.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/bsuite/envs) [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs//bsuite.html) [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/bsuite_zh.html) |
+| 19 | [ImageNet](https://www.image-net.org/) | ![IL](https://img.shields.io/badge/-IL/SL-purple) | ![original](./dizoo/image_classification/imagenet.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/image_classification) [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/image_cls_zh.html) |
+| 20 | [slime_volleyball](https://github.com/hardmaru/slimevolleygym) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen)![selfplay](https://img.shields.io/badge/-selfplay-blue) | ![ori](dizoo/slime_volley/slime_volley.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/slime_volley) [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/slime_volleyball.html) [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/slime_volleyball_zh.html) |
+| 21 | [gym_hybrid](https://github.com/thomashirtz/gym-hybrid) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | ![ori](dizoo/gym_hybrid/moving_v0.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/gym_hybrid) [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/gym_hybrid.html) [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/gym_hybrid_zh.html) |
+| 22 | [GoBigger](https://github.com/opendilab/GoBigger) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen)![marl](https://img.shields.io/badge/-MARL-yellow)![selfplay](https://img.shields.io/badge/-selfplay-blue) | ![ori](./dizoo/gobigger_overview.gif) | [dizoo link](https://github.com/opendilab/GoBigger-Challenge-2021/tree/main/di_baseline) [env tutorial](https://gobigger.readthedocs.io/en/latest/index.html) [环境指南](https://gobigger.readthedocs.io/zh_CN/latest/) |
+| 23 | [gym_soccer](https://github.com/openai/gym-soccer) | ![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) | ![ori](dizoo/gym_soccer/half_offensive.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/gym_soccer) [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/gym_soccer_zh.html) |
+| 24 |[multiagent_mujoco](https://github.com/schroederdewitt/multiagent_mujoco) | ![continuous](https://img.shields.io/badge/-continous-green) ![marl](https://img.shields.io/badge/-MARL-yellow) | ![original](./dizoo/mujoco/mujoco.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/multiagent_mujoco/envs) [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/mujoco_zh.html) |
+| 25 |bitflip | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![sparse](https://img.shields.io/badge/-sparse%20reward-orange) | ![original](./dizoo/bitflip/bitflip.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/bitflip/envs) [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/bitflip_zh.html) |
+| 26 |[sokoban](https://github.com/mpSchrader/gym-sokoban) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![Game 2](https://github.com/mpSchrader/gym-sokoban/raw/default/docs/Animations/solved_4.gif?raw=true) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/sokoban/envs) [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/sokoban.html) [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/sokoban_zh.html) |
+| 27 |[gym_anytrading](https://github.com/AminHP/gym-anytrading) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/gym_anytrading/envs/position.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/gym_anytrading) [env tutorial](https://github.com/opendilab/DI-engine/blob/main/dizoo/gym_anytrading/envs/README.md) |
+| 28 |[mario](https://github.com/Kautenja/gym-super-mario-bros) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/mario/mario.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/mario) [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/gym_super_mario_bros.html) [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/gym_super_mario_bros_zh.html) |
+| 29 |[dmc2gym](https://github.com/denisyarats/dmc2gym) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/dmc2gym/dmc2gym_cheetah.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/dmc2gym) [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/dmc2gym.html) [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/dmc2gym_zh.html) |
+| 30 |[evogym](https://github.com/EvolutionGym/evogym) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/evogym/evogym.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/evogym/envs) [env tutorial](https://di-engine-docs.readthedocs.io/en/latest/13_envs/evogym.html) [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/Evogym_zh.html) |
+| 31 |[gym-pybullet-drones](https://github.com/utiasDSL/gym-pybullet-drones) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/gym_pybullet_drones/gym_pybullet_drones.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/gym_pybullet_drones/envs) 环境指南 |
+| 32 |[beergame](https://github.com/OptMLGroup/DeepBeerInventory-RL) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/beergame/beergame.png) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/beergame/envs) 环境指南 |
+| 33 |[classic_control/acrobot](https://github.com/openai/gym/tree/master/gym/envs/classic_control) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/classic_control/acrobot/acrobot.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/classic_control/acrobot/envs) [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/acrobot_zh.html) |
+| 34 |[box2d/car_racing](https://github.com/openai/gym/blob/master/gym/envs/box2d/car_racing.py) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/box2d/carracing/car_racing.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/box2d/carracing/envs) 环境指南 |
+| 35 |[metadrive](https://github.com/metadriverse/metadrive) | ![continuous](https://img.shields.io/badge/-continous-green) | ![original](./dizoo/metadrive/metadrive_env.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/metadrive/env) [环境指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/13_envs/metadrive_zh.html) |
+| 36 |[cliffwalking](https://github.com/openai/gym/blob/master/gym/envs/toy_text/cliffwalking.py) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/cliffwalking/cliff_walking.gif) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/cliffwalking/envs) env tutorial 环境指南 |
+| 37 | [tabmwp](https://promptpg.github.io/explore.html) | ![discrete](https://img.shields.io/badge/-discrete-brightgreen) | ![original](./dizoo/tabmwp/tabmwp.jpeg) | [dizoo link](https://github.com/opendilab/DI-engine/tree/main/dizoo/tabmwp) env tutorial 环境指南|
+
+![discrete](https://img.shields.io/badge/-discrete-brightgreen) means discrete action space
+
+![continuous](https://img.shields.io/badge/-continous-green) means continuous action space
+
+![hybrid](https://img.shields.io/badge/-hybrid-darkgreen) means hybrid (discrete + continuous) action space
+
+![MARL](https://img.shields.io/badge/-MARL-yellow) means multi-agent RL environment
+
+![sparse](https://img.shields.io/badge/-sparse%20reward-orange) means a sparse reward environment that requires effective exploration
+
+![offline](https://img.shields.io/badge/-offlineRL-darkblue) means offline RL environment
+
+![IL](https://img.shields.io/badge/-IL/SL-purple) means Imitation Learning or Supervised Learning Dataset
+
+![selfplay](https://img.shields.io/badge/-selfplay-blue) means an environment that allows agent vs. agent battles
+
+P.S. Some environments in Atari, such as **MontezumaRevenge**, are also sparse reward environments.
+
+
+
+### General Data Container: TreeTensor
+
+DI-engine utilizes [TreeTensor](https://github.com/opendilab/DI-treetensor) as the basic data container in various components. It is easy to use and consistent across different code modules such as environment definition, data processing and DRL optimization. Here are some concrete code examples:
+
+- TreeTensor can easily extend all the operations of `torch.Tensor` to nested data:
+
+
+ ```python
+ import treetensor.torch as ttorch
+
+
+ # create random tensor
+ data = ttorch.randn({'a': (3, 2), 'b': {'c': (3, )}})
+ # clone+detach tensor
+ data_clone = data.clone().detach()
+ # access tree structure like attribute
+ a = data.a
+ c = data.b.c
+ # stack/cat/split
+ stacked_data = ttorch.stack([data, data_clone], 0)
+ cat_data = ttorch.cat([data, data_clone], 0)
+ data, data_clone = ttorch.split(stacked_data, 1)
+ # reshape
+ data = data.unsqueeze(-1)
+ data = data.squeeze(-1)
+ flatten_data = data.view(-1)
+ # indexing
+ data_0 = data[0]
+ data_1to2 = data[1:2]
+ # execute math calculations
+ data = data.sin()
+ data.b.c.cos_().clamp_(-1, 1)
+ data += data ** 2
+ # backward
+ data.requires_grad_(True)
+ loss = data.arctan().mean()
+ loss.backward()
+ # print shape
+ print(data.shape)
+ # result
+ #
+ # ├── 'a' --> torch.Size([1, 3, 2])
+ # └── 'b' -->
+ # └── 'c' --> torch.Size([1, 3])
+ ```
+
+
+
+- TreeTensor makes it simple yet effective to implement a classic deep reinforcement learning pipeline:
+
+
+ ```diff
+ import torch
+ import treetensor.torch as ttorch
+
+ B = 4
+
+
+ def get_item():
+ return {
+ 'obs': {
+ 'scalar': torch.randn(12),
+ 'image': torch.randn(3, 32, 32),
+ },
+ 'action': torch.randint(0, 10, size=(1,)),
+ 'reward': torch.rand(1),
+ 'done': False,
+ }
+
+
+ data = [get_item() for _ in range(B)]
+
+
+ # execute `stack` op
+ - def stack(data, dim):
+ - elem = data[0]
+ - if isinstance(elem, torch.Tensor):
+ - return torch.stack(data, dim)
+ - elif isinstance(elem, dict):
+ - return {k: stack([item[k] for item in data], dim) for k in elem.keys()}
+ - elif isinstance(elem, bool):
+ - return torch.BoolTensor(data)
+ - else:
+ - raise TypeError("not support elem type: {}".format(type(elem)))
+ - stacked_data = stack(data, dim=0)
+ + data = [ttorch.tensor(d) for d in data]
+ + stacked_data = ttorch.stack(data, dim=0)
+
+ # validate
+ - assert stacked_data['obs']['image'].shape == (B, 3, 32, 32)
+ - assert stacked_data['action'].shape == (B, 1)
+ - assert stacked_data['reward'].shape == (B, 1)
+ - assert stacked_data['done'].shape == (B,)
+ - assert stacked_data['done'].dtype == torch.bool
+ + assert stacked_data.obs.image.shape == (B, 3, 32, 32)
+ + assert stacked_data.action.shape == (B, 1)
+ + assert stacked_data.reward.shape == (B, 1)
+ + assert stacked_data.done.shape == (B,)
+ + assert stacked_data.done.dtype == torch.bool
+ ```
+
+
+
+## Feedback and Contribution
+
+- [File an issue](https://github.com/opendilab/DI-engine/issues/new/choose) on GitHub
+- Open or participate in our [forum](https://github.com/opendilab/DI-engine/discussions)
+- Discuss on DI-engine [slack communication channel](https://join.slack.com/t/opendilab/shared_invite/zt-v9tmv4fp-nUBAQEH1_Kuyu_q4plBssQ)
+- Discuss on DI-engine's WeChat group (i.e. add us on WeChat: ding314assist)
+
+
+- Contact us via email (opendilab@pjlab.org.cn)
+- Contribute to our future plan [Roadmap](https://github.com/opendilab/DI-engine/issues/548)
+
+We appreciate all feedback and contributions that improve DI-engine, in both algorithms and system design. `CONTRIBUTING.md` offers the necessary information.
+
+## Supporters
+
+### ↳ Stargazers
+
+[![Stargazers repo roster for @opendilab/DI-engine](https://reporoster.com/stars/opendilab/DI-engine)](https://github.com/opendilab/DI-engine/stargazers)
+
+### ↳ Forkers
+
+[![Forkers repo roster for @opendilab/DI-engine](https://reporoster.com/forks/opendilab/DI-engine)](https://github.com/opendilab/DI-engine/network/members)
+
+
+## Citation
+```latex
+@misc{ding,
+ title={DI-engine: OpenDILab Decision Intelligence Engine},
+ author={OpenDILab Contributors},
+ publisher={GitHub},
+ howpublished={\url{https://github.com/opendilab/DI-engine}},
+ year={2021},
+}
+```
+
+## License
+DI-engine is released under the Apache 2.0 license.
diff --git a/DI-engine/cloc.sh b/DI-engine/cloc.sh
new file mode 100755
index 0000000000000000000000000000000000000000..43bf78e2ae6b95b60a39d6b47542e476599bc0f4
--- /dev/null
+++ b/DI-engine/cloc.sh
@@ -0,0 +1,69 @@
+#!/bin/bash
+
+# This script counts the lines of code and comments in all source files
+# and prints the results to the command line. It uses the command-line tool
+# "cloc". You can pass --loc, --comments or --percentage to show the
+# respective value only.
+# Some parts below need to be adapted to your project!
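+#
+# Example usage (assuming "cloc" is installed and on the PATH):
+#   ./cloc.sh               # print lines of code, comment lines and comment percentage
+#   ./cloc.sh --loc         # print only the lines of code, e.g. "55.3k"
+#   ./cloc.sh --percentage  # print only the comment percentage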
+
+# Get the location of this script.
+SCRIPT_DIR="$( cd "$( dirname "$0" )" && pwd )"
+
+# Run cloc - this counts code lines, blank lines and comment lines
+# for the specified languages. You will need to change this accordingly.
+# For C++, you could use "C++,C/C++ Header" for example.
+# We are only interested in the summary, therefore the tail -1
+SUMMARY="$(cloc "${SCRIPT_DIR}" --include-lang="Python" --md | tail -1)"
+
+# The $SUMMARY is one line of a markdown table and looks like this:
+# SUM:|101|3123|2238|10783
+# We use the following command to split it into an array.
+IFS='|' read -r -a TOKENS <<< "$SUMMARY"
+
+# Store the individual tokens for better readability.
+NUMBER_OF_FILES=${TOKENS[1]}
+COMMENT_LINES=${TOKENS[3]}
+LINES_OF_CODE=${TOKENS[4]}
+
+# To make the estimate of commented lines more accurate, we have to
+# subtract any copyright header which is included in each file.
+# Here, this header is assumed to be five lines long.
+# All dumb comments like those /////////// or those // ------------
+# are also subtracted. As cloc does not count inline comments,
+# the overall estimate should be rather conservative.
+# Change the lines below according to your project.
+DUMB_COMMENTS="$(grep -r -E '//////|// -----' "${SCRIPT_DIR}" | wc -l)"
+COMMENT_LINES=$(($COMMENT_LINES - 5 * $NUMBER_OF_FILES - $DUMB_COMMENTS))
+
+# Print all results if no arguments are given.
+if [[ $# -eq 0 ]] ; then
+ awk -v a=$LINES_OF_CODE \
+ 'BEGIN {printf "Lines of source code: %6.1fk\n", a/1000}'
+ awk -v a=$COMMENT_LINES \
+ 'BEGIN {printf "Lines of comments: %6.1fk\n", a/1000}'
+ awk -v a=$COMMENT_LINES -v b=$LINES_OF_CODE \
+ 'BEGIN {printf "Comment Percentage: %6.1f%\n", 100*a/b}'
+ exit 0
+fi
+
+# Show lines of code if --loc is given.
+if [[ $* == *--loc* ]]
+then
+ awk -v a=$LINES_OF_CODE \
+ 'BEGIN {printf "%.1fk\n", a/1000}'
+fi
+
+# Show lines of comments if --comments is given.
+if [[ $* == *--comments* ]]
+then
+ awk -v a=$COMMENT_LINES \
+ 'BEGIN {printf "%.1fk\n", a/1000}'
+fi
+
+# Show percentage of comments if --percentage is given.
+if [[ $* == *--percentage* ]]
+then
+ awk -v a=$COMMENT_LINES -v b=$LINES_OF_CODE \
+ 'BEGIN {printf "%.1f\n", 100*a/b}'
+fi
+
diff --git a/DI-engine/codecov.yml b/DI-engine/codecov.yml
new file mode 100644
index 0000000000000000000000000000000000000000..0779ada7736e33d08302f815c225877dad53f6ad
--- /dev/null
+++ b/DI-engine/codecov.yml
@@ -0,0 +1,8 @@
+coverage:
+ status:
+ project:
+ default:
+ # basic
+ target: auto
+ threshold: 0.5%
+ if_ci_failed: success #success, failure, error, ignore
diff --git a/DI-engine/conda/conda_build_config.yaml b/DI-engine/conda/conda_build_config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..4c25caf9a7fc0813e167efb86fd61bef9f8ad828
--- /dev/null
+++ b/DI-engine/conda/conda_build_config.yaml
@@ -0,0 +1,2 @@
+python:
+ - 3.7
diff --git a/DI-engine/conda/meta.yaml b/DI-engine/conda/meta.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0dbea5284a344c75397d9d72457584ba9aee5058
--- /dev/null
+++ b/DI-engine/conda/meta.yaml
@@ -0,0 +1,35 @@
+{% set data = load_setup_py_data() %}
+package:
+ name: di-engine
+ version: v0.5.0
+
+source:
+ path: ..
+
+build:
+ number: 0
+ script: python -m pip install . -vv
+ entry_points:
+ - ding = ding.entry.cli:cli
+
+requirements:
+ build:
+ - python
+ - setuptools
+ run:
+ - python
+
+test:
+ imports:
+ - ding
+ - dizoo
+
+about:
+ home: https://github.com/opendilab/DI-engine
+ license: Apache-2.0
+ license_file: LICENSE
+ summary: DI-engine is a generalized Decision Intelligence engine (https://github.com/opendilab/DI-engine).
+ description: Please refer to https://di-engine-docs.readthedocs.io/en/latest/00_intro/index.html#what-is-di-engine
+ dev_url: https://github.com/opendilab/DI-engine
+ doc_url: https://di-engine-docs.readthedocs.io/en/latest/index.html
+ doc_source_url: https://github.com/opendilab/DI-engine-docs
diff --git a/DI-engine/ding/__init__.py b/DI-engine/ding/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..52583dc3c0e2713b5d38a799197276f52b66a927
--- /dev/null
+++ b/DI-engine/ding/__init__.py
@@ -0,0 +1,12 @@
+import os
+
+__TITLE__ = 'DI-engine'
+__VERSION__ = 'v0.5.0'
+__DESCRIPTION__ = 'Decision AI Engine'
+__AUTHOR__ = "OpenDILab Contributors"
+__AUTHOR_EMAIL__ = "opendilab@pjlab.org.cn"
+__version__ = __VERSION__
+
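+# Optional feature switches, read from environment variables at import time:
+# ENABLE_DI_HPC and ENABLE_LINKLINK both default to 'false'; numba support is enabled by default.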
+enable_hpc_rl = os.environ.get('ENABLE_DI_HPC', 'false').lower() == 'true'
+enable_linklink = os.environ.get('ENABLE_LINKLINK', 'false').lower() == 'true'
+enable_numba = True
diff --git a/DI-engine/ding/bonus/__init__.py b/DI-engine/ding/bonus/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3329830b399a2a9671ef0e6649e60744bb4a3bdd
--- /dev/null
+++ b/DI-engine/ding/bonus/__init__.py
@@ -0,0 +1,132 @@
+import ding.config
+from .a2c import A2CAgent
+from .c51 import C51Agent
+from .ddpg import DDPGAgent
+from .dqn import DQNAgent
+from .pg import PGAgent
+from .ppof import PPOF
+from .ppo_offpolicy import PPOOffPolicyAgent
+from .sac import SACAgent
+from .sql import SQLAgent
+from .td3 import TD3Agent
+
+supported_algo = dict(
+ A2C=A2CAgent,
+ C51=C51Agent,
+ DDPG=DDPGAgent,
+ DQN=DQNAgent,
+ PG=PGAgent,
+ PPOF=PPOF,
+ PPOOffPolicy=PPOOffPolicyAgent,
+ SAC=SACAgent,
+ SQL=SQLAgent,
+ TD3=TD3Agent,
+)
+
+supported_algo_list = list(supported_algo.keys())
+
+
+def env_supported(algo: str = None) -> list:
+ """
+    Return the list of envs supported by DI-engine.
+    If ``algo`` is specified, only the envs supported by that algorithm are returned.
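+
+    Examples (illustrative; the returned names come from the example configs in ``ding.config.example``):
+        >>> from ding.bonus import env_supported
+        >>> env_supported('DQN')    # envs with a bundled DQN example config
+        >>> env_supported()         # union of envs over all supported algorithms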
+ """
+
+ if algo is not None:
+ if algo.upper() == "A2C":
+ return list(ding.config.example.A2C.supported_env.keys())
+ elif algo.upper() == "C51":
+ return list(ding.config.example.C51.supported_env.keys())
+ elif algo.upper() == "DDPG":
+ return list(ding.config.example.DDPG.supported_env.keys())
+ elif algo.upper() == "DQN":
+ return list(ding.config.example.DQN.supported_env.keys())
+ elif algo.upper() == "PG":
+ return list(ding.config.example.PG.supported_env.keys())
+ elif algo.upper() == "PPOF":
+ return list(ding.config.example.PPOF.supported_env.keys())
+ elif algo.upper() == "PPOOFFPOLICY":
+ return list(ding.config.example.PPOOffPolicy.supported_env.keys())
+ elif algo.upper() == "SAC":
+ return list(ding.config.example.SAC.supported_env.keys())
+ elif algo.upper() == "SQL":
+ return list(ding.config.example.SQL.supported_env.keys())
+ elif algo.upper() == "TD3":
+ return list(ding.config.example.TD3.supported_env.keys())
+ else:
+ raise ValueError("The algo {} is not supported by di-engine.".format(algo))
+ else:
+ supported_env = set()
+ supported_env.update(ding.config.example.A2C.supported_env.keys())
+ supported_env.update(ding.config.example.C51.supported_env.keys())
+ supported_env.update(ding.config.example.DDPG.supported_env.keys())
+ supported_env.update(ding.config.example.DQN.supported_env.keys())
+ supported_env.update(ding.config.example.PG.supported_env.keys())
+ supported_env.update(ding.config.example.PPOF.supported_env.keys())
+ supported_env.update(ding.config.example.PPOOffPolicy.supported_env.keys())
+ supported_env.update(ding.config.example.SAC.supported_env.keys())
+ supported_env.update(ding.config.example.SQL.supported_env.keys())
+ supported_env.update(ding.config.example.TD3.supported_env.keys())
+ # return the list of the envs
+ return list(supported_env)
+
+
+supported_env = env_supported()
+
+
+def algo_supported(env_id: str = None) -> list:
+ """
+    Return the list of algorithms supported by DI-engine.
+    If ``env_id`` is specified, only the algorithms that support that env are returned.
+ """
+ if env_id is not None:
+ algo = []
+ if env_id.upper() in [item.upper() for item in ding.config.example.A2C.supported_env.keys()]:
+ algo.append("A2C")
+ if env_id.upper() in [item.upper() for item in ding.config.example.C51.supported_env.keys()]:
+ algo.append("C51")
+ if env_id.upper() in [item.upper() for item in ding.config.example.DDPG.supported_env.keys()]:
+ algo.append("DDPG")
+ if env_id.upper() in [item.upper() for item in ding.config.example.DQN.supported_env.keys()]:
+ algo.append("DQN")
+ if env_id.upper() in [item.upper() for item in ding.config.example.PG.supported_env.keys()]:
+ algo.append("PG")
+ if env_id.upper() in [item.upper() for item in ding.config.example.PPOF.supported_env.keys()]:
+ algo.append("PPOF")
+ if env_id.upper() in [item.upper() for item in ding.config.example.PPOOffPolicy.supported_env.keys()]:
+ algo.append("PPOOffPolicy")
+ if env_id.upper() in [item.upper() for item in ding.config.example.SAC.supported_env.keys()]:
+ algo.append("SAC")
+ if env_id.upper() in [item.upper() for item in ding.config.example.SQL.supported_env.keys()]:
+ algo.append("SQL")
+ if env_id.upper() in [item.upper() for item in ding.config.example.TD3.supported_env.keys()]:
+ algo.append("TD3")
+
+ if len(algo) == 0:
+ raise ValueError("The env {} is not supported by di-engine.".format(env_id))
+ return algo
+ else:
+ return supported_algo_list
+
+
+def is_supported(env_id: str = None, algo: str = None) -> bool:
+ """
+    Check whether the env-algo pair is supported by DI-engine.
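+
+    Examples (illustrative):
+        >>> from ding.bonus import is_supported
+        >>> is_supported(env_id='LunarLander-v2', algo='DQN')
+        >>> is_supported(algo='TD3')    # only checks that the algorithm is supported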
+ """
+    if env_id is not None and env_id.upper() in [item.upper() for item in supported_env]:
+ if algo is not None and algo.upper() in supported_algo_list:
+            if env_id.upper() in [item.upper() for item in env_supported(algo)]:
+ return True
+ else:
+ return False
+ elif algo is None:
+ return True
+ else:
+ return False
+ elif env_id is None:
+ if algo is not None and algo.upper() in supported_algo_list:
+ return True
+ elif algo is None:
+ raise ValueError("Please specify the env or algo.")
+ else:
+ return False
+ else:
+ return False
diff --git a/DI-engine/ding/bonus/a2c.py b/DI-engine/ding/bonus/a2c.py
new file mode 100644
index 0000000000000000000000000000000000000000..d10def313bfb30a97d04ece7323a2cb05769957a
--- /dev/null
+++ b/DI-engine/ding/bonus/a2c.py
@@ -0,0 +1,460 @@
+from typing import Optional, Union, List
+from ditk import logging
+from easydict import EasyDict
+import os
+import numpy as np
+import torch
+import treetensor.torch as ttorch
+from ding.framework import task, OnlineRLContext
+from ding.framework.middleware import CkptSaver, trainer, \
+ wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, \
+ gae_estimator, final_ctx_saver
+from ding.envs import BaseEnv
+from ding.envs import setup_ding_env_manager
+from ding.policy import A2CPolicy
+from ding.utils import set_pkg_seed
+from ding.utils import get_env_fps, render
+from ding.config import save_config_py, compile_config
+from ding.model import VAC
+from ding.model import model_wrap
+from ding.bonus.common import TrainingReturn, EvalReturn
+from ding.config.example.A2C import supported_env_cfg
+from ding.config.example.A2C import supported_env
+
+
+class A2CAgent:
+ """
+ Overview:
+ Class of agent for training, evaluation and deployment of Reinforcement learning algorithm \
+        Advantage Actor Critic (A2C).
+        For more information about the system design of RL agents, please refer to \
+        the DI-engine documentation.
+ Interface:
+ ``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
+ """
+ supported_env_list = list(supported_env_cfg.keys())
+ """
+ Overview:
+ List of supported envs.
+ Examples:
+ >>> from ding.bonus.a2c import A2CAgent
+ >>> print(A2CAgent.supported_env_list)
+ """
+
+ def __init__(
+ self,
+ env_id: str = None,
+ env: BaseEnv = None,
+ seed: int = 0,
+ exp_name: str = None,
+ model: Optional[torch.nn.Module] = None,
+ cfg: Optional[Union[EasyDict, dict]] = None,
+ policy_state_dict: str = None,
+ ) -> None:
+ """
+ Overview:
+ Initialize agent for A2C algorithm.
+ Arguments:
+ - env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
+ If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
+ If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
+ ``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
+ - env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
+                If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified. \
+ ``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
+ If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
+ - seed (:obj:`int`): The random seed, which is set before running the program. \
+ Default to 0.
+ - exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
+ log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
+ - model (:obj:`torch.nn.Module`): The model of A2C algorithm, which should be an instance of class \
+ :class:`ding.model.VAC`. \
+ If not specified, a default model will be generated according to the configuration.
+ - cfg (:obj:Union[EasyDict, dict]): The configuration of A2C algorithm, which is a dict. \
+ Default to None. If not specified, the default configuration will be used. \
+ The default configuration can be found in ``ding/config/example/A2C/gym_lunarlander_v2.py``.
+            - policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
+ If specified, the policy will be loaded from this file. Default to None.
+
+ .. note::
+ An RL Agent Instance can be initialized in two basic ways. \
+ For example, we have an environment with id ``LunarLanderContinuous-v2`` registered in gym, \
+ and we want to train an agent with A2C algorithm with default configuration. \
+ Then we can initialize the agent in the following ways:
+ >>> agent = A2CAgent(env_id='LunarLanderContinuous-v2')
+            or, if we want to specify the env_id in the configuration:
+ >>> cfg = {'env': {'env_id': 'LunarLanderContinuous-v2'}, 'policy': ...... }
+ >>> agent = A2CAgent(cfg=cfg)
+ There are also other arguments to specify the agent when initializing.
+ For example, if we want to specify the environment instance:
+ >>> env = CustomizedEnv('LunarLanderContinuous-v2')
+ >>> agent = A2CAgent(cfg=cfg, env=env)
+ or, if we want to specify the model:
+ >>> model = VAC(**cfg.policy.model)
+ >>> agent = A2CAgent(cfg=cfg, model=model)
+ or, if we want to reload the policy from a saved policy state dict:
+ >>> agent = A2CAgent(cfg=cfg, policy_state_dict='LunarLanderContinuous-v2.pth.tar')
+ Make sure that the configuration is consistent with the saved policy state dict.
+ """
+
+ assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
+
+ if cfg is not None and not isinstance(cfg, EasyDict):
+ cfg = EasyDict(cfg)
+
+ if env_id is not None:
+ assert env_id in A2CAgent.supported_env_list, "Please use supported envs: {}".format(
+ A2CAgent.supported_env_list
+ )
+ if cfg is None:
+ cfg = supported_env_cfg[env_id]
+ else:
+ assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
+ else:
+ assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
+ assert cfg.env.env_id in A2CAgent.supported_env_list, "Please use supported envs: {}".format(
+ A2CAgent.supported_env_list
+ )
+ default_policy_config = EasyDict({"policy": A2CPolicy.default_config()})
+ default_policy_config.update(cfg)
+ cfg = default_policy_config
+
+ if exp_name is not None:
+ cfg.exp_name = exp_name
+ self.cfg = compile_config(cfg, policy=A2CPolicy)
+ self.exp_name = self.cfg.exp_name
+ if env is None:
+ self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
+ else:
+ assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
+ self.env = env
+
+ logging.getLogger().setLevel(logging.INFO)
+ self.seed = seed
+ set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
+ if not os.path.exists(self.exp_name):
+ os.makedirs(self.exp_name)
+ save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
+ if model is None:
+ model = VAC(**self.cfg.policy.model)
+ self.policy = A2CPolicy(self.cfg.policy, model=model)
+        if policy_state_dict is not None:
+            # policy_state_dict is a path, so load the checkpoint before restoring the policy
+            policy_state_dict = torch.load(policy_state_dict, map_location=torch.device("cpu"))
+            self.policy.learn_mode.load_state_dict(policy_state_dict)
+ self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
+
+ def train(
+ self,
+ step: int = int(1e7),
+ collector_env_num: int = 4,
+ evaluator_env_num: int = 4,
+ n_iter_log_show: int = 500,
+ n_iter_save_ckpt: int = 1000,
+ context: Optional[str] = None,
+ debug: bool = False,
+ wandb_sweep: bool = False,
+ ) -> TrainingReturn:
+ """
+ Overview:
+ Train the agent with A2C algorithm for ``step`` iterations with ``collector_env_num`` collector \
+ environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
+ recorded and saved by wandb.
+ Arguments:
+ - step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
+            - collector_env_num (:obj:`int`): The collector environment number. Default to 4.
+            - evaluator_env_num (:obj:`int`): The evaluator environment number. Default to 4.
+ - n_iter_save_ckpt (:obj:`int`): The frequency of saving checkpoint every training iteration. \
+ Default to 1000.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ - wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
+ which is a hyper-parameter optimization process for seeking the best configurations. \
+ Default to False. If True, the wandb sweep id will be used as the experiment name.
+ Returns:
+            - (:obj:`TrainingReturn`): The training result, of which the attributes are:
+                - wandb_url (:obj:`str`): The Weights & Biases (wandb) project url of the training experiment.
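+        Examples:
+            >>> # Illustrative short run; the default ``step`` is 1e7 environment steps.
+            >>> agent = A2CAgent(env_id='LunarLanderContinuous-v2')
+            >>> training_return = agent.train(step=int(1e5))
+            >>> print(training_return.wandb_url)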
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ logging.debug(self.policy._model)
+ # define env and policy
+        collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
+        evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
+
+ with task.start(ctx=OnlineRLContext()):
+ task.use(
+ interaction_evaluator(
+ self.cfg,
+ self.policy.eval_mode,
+ evaluator_env,
+ render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
+ )
+ )
+ task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
+ task.use(
+ StepCollector(
+ self.cfg,
+ self.policy.collect_mode,
+ collector_env,
+ random_collect_size=self.cfg.policy.random_collect_size
+ if hasattr(self.cfg.policy, 'random_collect_size') else 0,
+ )
+ )
+ task.use(gae_estimator(self.cfg, self.policy.collect_mode))
+ task.use(trainer(self.cfg, self.policy.learn_mode))
+ task.use(
+ wandb_online_logger(
+ metric_list=self.policy._monitor_vars_learn(),
+ model=self.policy._model,
+ anonymous=True,
+ project_name=self.exp_name,
+ wandb_sweep=wandb_sweep,
+ )
+ )
+ task.use(termination_checker(max_env_step=step))
+ task.use(final_ctx_saver(name=self.exp_name))
+ task.run()
+
+ return TrainingReturn(wandb_url=task.ctx.wandb_url)
+
+ def deploy(
+ self,
+ enable_save_replay: bool = False,
+ concatenate_all_replay: bool = False,
+ replay_save_path: str = None,
+ seed: Optional[Union[int, List]] = None,
+ debug: bool = False
+ ) -> EvalReturn:
+ """
+ Overview:
+ Deploy the agent with A2C algorithm by interacting with the environment, during which the replay video \
+ can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
+ Arguments:
+ - enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
+ - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
+ Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
+ If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
+ the replay video of each episode will be saved separately.
+ - replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
+ If not specified, the video will be saved in ``exp_name/videos``.
+ - seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
+ Default to None. If not specified, ``self.seed`` will be used. \
+ If ``seed`` is an integer, the agent will be deployed once. \
+ If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ Returns:
+            - (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
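+        Examples:
+            >>> # Illustrative usage; the checkpoint path is an assumption and depends on your experiment.
+            >>> agent = A2CAgent(env_id='LunarLanderContinuous-v2', policy_state_dict='./ckpt/eval.pth.tar')
+            >>> result = agent.deploy(enable_save_replay=True)
+            >>> print(result.eval_value, result.eval_value_std)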
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ # define env and policy
+ env = self.env.clone(caller='evaluator')
+
+ if seed is not None and isinstance(seed, int):
+ seeds = [seed]
+ elif seed is not None and isinstance(seed, list):
+ seeds = seed
+ else:
+ seeds = [self.seed]
+
+ returns = []
+ images = []
+ if enable_save_replay:
+ replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
+ env.enable_save_replay(replay_path=replay_save_path)
+ else:
+ logging.warning('No video would be generated during the deploy.')
+ if concatenate_all_replay:
+ logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
+ concatenate_all_replay = False
+
+ def single_env_forward_wrapper(forward_fn, cuda=True):
+
+ if self.cfg.policy.action_space == 'continuous':
+ forward_fn = model_wrap(forward_fn, wrapper_name='deterministic_sample').forward
+ elif self.cfg.policy.action_space == 'discrete':
+ forward_fn = model_wrap(forward_fn, wrapper_name='argmax_sample').forward
+ else:
+ raise NotImplementedError
+
+ def _forward(obs):
+ # unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
+ obs = ttorch.as_tensor(obs).unsqueeze(0)
+ if cuda and torch.cuda.is_available():
+ obs = obs.cuda()
+ action = forward_fn(obs, mode='compute_actor')["action"]
+ # squeeze means delete batch dim, i.e. (1, A) -> (A, )
+ action = action.squeeze(0).detach().cpu().numpy()
+ return action
+
+ return _forward
+
+ forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
+
+ # reset first to make sure the env is in the initial state
+ # env will be reset again in the main loop
+ env.reset()
+
+ for seed in seeds:
+ env.seed(seed, dynamic_seed=False)
+ return_ = 0.
+ step = 0
+ obs = env.reset()
+ images.append(render(env)[None]) if concatenate_all_replay else None
+ while True:
+ action = forward_fn(obs)
+ obs, rew, done, info = env.step(action)
+ images.append(render(env)[None]) if concatenate_all_replay else None
+ return_ += rew
+ step += 1
+ if done:
+ break
+            logging.info(f'A2C deploy is finished, final episode return with {step} steps is: {return_}')
+ returns.append(return_)
+
+ env.close()
+
+ if concatenate_all_replay:
+ images = np.concatenate(images, axis=0)
+ import imageio
+ imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
+
+ return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
+
+ def collect_data(
+ self,
+ env_num: int = 8,
+ save_data_path: Optional[str] = None,
+ n_sample: Optional[int] = None,
+ n_episode: Optional[int] = None,
+ context: Optional[str] = None,
+ debug: bool = False
+ ) -> None:
+ """
+ Overview:
+ Collect data with A2C algorithm for ``n_episode`` episodes with ``env_num`` collector environments. \
+ The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
+ ``exp_name/demo_data``.
+ Arguments:
+ - env_num (:obj:`int`): The number of collector environments. Default to 8.
+ - save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
+ If not specified, the data will be saved in ``exp_name/demo_data``.
+ - n_sample (:obj:`int`): The number of samples to collect. Default to None. \
+ If not specified, ``n_episode`` must be specified.
+ - n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
+ If not specified, ``n_sample`` must be specified.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
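+        Examples:
+            >>> # Illustrative usage with a hypothetical sample count.
+            >>> agent = A2CAgent(env_id='LunarLanderContinuous-v2')
+            >>> agent.collect_data(env_num=8, n_sample=1000)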
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ if n_episode is not None:
+ raise NotImplementedError
+ # define env and policy
+ env_num = env_num if env_num else self.cfg.env.collector_env_num
+ env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
+
+ if save_data_path is None:
+ save_data_path = os.path.join(self.exp_name, 'demo_data')
+
+ # main execution task
+ with task.start(ctx=OnlineRLContext()):
+ task.use(
+ StepCollector(
+ self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
+ )
+ )
+ task.use(offline_data_saver(save_data_path, data_type='hdf5'))
+ task.run(max_step=1)
+ logging.info(
+ f'A2C collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
+ )
+
+ def batch_evaluate(
+ self,
+ env_num: int = 4,
+ n_evaluator_episode: int = 4,
+ context: Optional[str] = None,
+ debug: bool = False
+ ) -> EvalReturn:
+ """
+ Overview:
+ Evaluate the agent with A2C algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
+ environments. The evaluation result will be returned.
+ The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
+ multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
+ will only create one evaluator environment to evaluate the agent and save the replay video.
+ Arguments:
+ - env_num (:obj:`int`): The number of evaluator environments. Default to 4.
+ - n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ Returns:
+            - (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
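+        Examples:
+            >>> agent = A2CAgent(env_id='LunarLanderContinuous-v2')
+            >>> result = agent.batch_evaluate(env_num=4, n_evaluator_episode=8)
+            >>> print(result.eval_value, result.eval_value_std)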
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ # define env and policy
+ env_num = env_num if env_num else self.cfg.env.evaluator_env_num
+ env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
+
+ # reset first to make sure the env is in the initial state
+ # env will be reset again in the main loop
+ env.launch()
+ env.reset()
+
+ evaluate_cfg = self.cfg
+ evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
+
+ # main execution task
+ with task.start(ctx=OnlineRLContext()):
+ task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
+ task.run(max_step=1)
+
+ return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
+
+ @property
+ def best(self) -> 'A2CAgent':
+ """
+ Overview:
+ Load the best model from the checkpoint directory, \
+ which is by default in folder ``exp_name/ckpt/eval.pth.tar``. \
+ The return value is the agent with the best model.
+ Returns:
+ - (:obj:`A2CAgent`): The agent with the best model.
+ Examples:
+ >>> agent = A2CAgent(env_id='LunarLanderContinuous-v2')
+ >>> agent.train()
+ >>> agent = agent.best
+
+ .. note::
+ The best model is the model with the highest evaluation return. If this method is called, the current \
+ model will be replaced by the best model.
+ """
+
+ best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
+ # Load best model if it exists
+ if os.path.exists(best_model_file_path):
+ policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
+ return self
diff --git a/DI-engine/ding/bonus/c51.py b/DI-engine/ding/bonus/c51.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab4f0be85e5ab1031d2b6faa6d11672dd2c9dcbe
--- /dev/null
+++ b/DI-engine/ding/bonus/c51.py
@@ -0,0 +1,459 @@
+from typing import Optional, Union, List
+from ditk import logging
+from easydict import EasyDict
+import os
+import numpy as np
+import torch
+import treetensor.torch as ttorch
+from ding.framework import task, OnlineRLContext
+from ding.framework.middleware import CkptSaver, \
+ wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, data_pusher, \
+ OffPolicyLearner, final_ctx_saver, eps_greedy_handler, nstep_reward_enhancer
+from ding.envs import BaseEnv
+from ding.envs import setup_ding_env_manager
+from ding.policy import C51Policy
+from ding.utils import set_pkg_seed
+from ding.utils import get_env_fps, render
+from ding.config import save_config_py, compile_config
+from ding.model import C51DQN
+from ding.model import model_wrap
+from ding.data import DequeBuffer
+from ding.bonus.common import TrainingReturn, EvalReturn
+from ding.config.example.C51 import supported_env_cfg
+from ding.config.example.C51 import supported_env
+
+
+class C51Agent:
+ """
+ Overview:
+ Class of agent for training, evaluation and deployment of Reinforcement learning algorithm C51.
+        For more information about the system design of RL agents, please refer to \
+        the DI-engine documentation.
+ Interface:
+ ``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
+ """
+ supported_env_list = list(supported_env_cfg.keys())
+ """
+ Overview:
+ List of supported envs.
+ Examples:
+ >>> from ding.bonus.c51 import C51Agent
+ >>> print(C51Agent.supported_env_list)
+ """
+
+ def __init__(
+ self,
+ env_id: str = None,
+ env: BaseEnv = None,
+ seed: int = 0,
+ exp_name: str = None,
+ model: Optional[torch.nn.Module] = None,
+ cfg: Optional[Union[EasyDict, dict]] = None,
+ policy_state_dict: str = None,
+ ) -> None:
+ """
+ Overview:
+ Initialize agent for C51 algorithm.
+ Arguments:
+ - env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
+ If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
+ If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
+ ``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
+ - env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
+                If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified. \
+ ``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
+ If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
+ - seed (:obj:`int`): The random seed, which is set before running the program. \
+ Default to 0.
+ - exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
+ log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
+ - model (:obj:`torch.nn.Module`): The model of C51 algorithm, which should be an instance of class \
+ :class:`ding.model.C51DQN`. \
+ If not specified, a default model will be generated according to the configuration.
+ - cfg (:obj:Union[EasyDict, dict]): The configuration of C51 algorithm, which is a dict. \
+ Default to None. If not specified, the default configuration will be used. \
+ The default configuration can be found in ``ding/config/example/C51/gym_lunarlander_v2.py``.
+            - policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
+ If specified, the policy will be loaded from this file. Default to None.
+
+ .. note::
+ An RL Agent Instance can be initialized in two basic ways. \
+ For example, we have an environment with id ``LunarLander-v2`` registered in gym, \
+ and we want to train an agent with C51 algorithm with default configuration. \
+ Then we can initialize the agent in the following ways:
+ >>> agent = C51Agent(env_id='LunarLander-v2')
+            or, if we want to specify the env_id in the configuration:
+ >>> cfg = {'env': {'env_id': 'LunarLander-v2'}, 'policy': ...... }
+ >>> agent = C51Agent(cfg=cfg)
+ There are also other arguments to specify the agent when initializing.
+ For example, if we want to specify the environment instance:
+ >>> env = CustomizedEnv('LunarLander-v2')
+ >>> agent = C51Agent(cfg=cfg, env=env)
+ or, if we want to specify the model:
+ >>> model = C51DQN(**cfg.policy.model)
+ >>> agent = C51Agent(cfg=cfg, model=model)
+ or, if we want to reload the policy from a saved policy state dict:
+ >>> agent = C51Agent(cfg=cfg, policy_state_dict='LunarLander-v2.pth.tar')
+ Make sure that the configuration is consistent with the saved policy state dict.
+ """
+
+ assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
+
+ if cfg is not None and not isinstance(cfg, EasyDict):
+ cfg = EasyDict(cfg)
+
+ if env_id is not None:
+ assert env_id in C51Agent.supported_env_list, "Please use supported envs: {}".format(
+ C51Agent.supported_env_list
+ )
+ if cfg is None:
+ cfg = supported_env_cfg[env_id]
+ else:
+ assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
+ else:
+ assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
+ assert cfg.env.env_id in C51Agent.supported_env_list, "Please use supported envs: {}".format(
+ C51Agent.supported_env_list
+ )
+ default_policy_config = EasyDict({"policy": C51Policy.default_config()})
+ default_policy_config.update(cfg)
+ cfg = default_policy_config
+
+ if exp_name is not None:
+ cfg.exp_name = exp_name
+ self.cfg = compile_config(cfg, policy=C51Policy)
+ self.exp_name = self.cfg.exp_name
+ if env is None:
+ self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
+ else:
+ assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
+ self.env = env
+
+ logging.getLogger().setLevel(logging.INFO)
+ self.seed = seed
+ set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
+ if not os.path.exists(self.exp_name):
+ os.makedirs(self.exp_name)
+ save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
+ if model is None:
+ model = C51DQN(**self.cfg.policy.model)
+ self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
+ self.policy = C51Policy(self.cfg.policy, model=model)
+        if policy_state_dict is not None:
+            # policy_state_dict is a path, so load the checkpoint before restoring the policy
+            policy_state_dict = torch.load(policy_state_dict, map_location=torch.device("cpu"))
+            self.policy.learn_mode.load_state_dict(policy_state_dict)
+ self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
+
+ def train(
+ self,
+ step: int = int(1e7),
+ collector_env_num: int = None,
+ evaluator_env_num: int = None,
+ n_iter_save_ckpt: int = 1000,
+ context: Optional[str] = None,
+ debug: bool = False,
+ wandb_sweep: bool = False,
+ ) -> TrainingReturn:
+ """
+ Overview:
+ Train the agent with C51 algorithm for ``step`` iterations with ``collector_env_num`` collector \
+ environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
+ recorded and saved by wandb.
+ Arguments:
+ - step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
+ - collector_env_num (:obj:`int`): The collector environment number. Default to None. \
+ If not specified, it will be set according to the configuration.
+ - evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
+ If not specified, it will be set according to the configuration.
+ - n_iter_save_ckpt (:obj:`int`): The frequency of saving checkpoint every training iteration. \
+ Default to 1000.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ - wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
+ which is a hyper-parameter optimization process for seeking the best configurations. \
+ Default to False. If True, the wandb sweep id will be used as the experiment name.
+ Returns:
+            - (:obj:`TrainingReturn`): The training result, of which the attributes are:
+                - wandb_url (:obj:`str`): The Weights & Biases (wandb) project url of the training experiment.
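+        Examples:
+            >>> # Illustrative short run; the default ``step`` is 1e7 environment steps.
+            >>> agent = C51Agent(env_id='LunarLander-v2')
+            >>> training_return = agent.train(step=int(1e5))
+            >>> print(training_return.wandb_url)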
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ logging.debug(self.policy._model)
+ # define env and policy
+ collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
+ evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
+ collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
+ evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
+
+ with task.start(ctx=OnlineRLContext()):
+ task.use(
+ interaction_evaluator(
+ self.cfg,
+ self.policy.eval_mode,
+ evaluator_env,
+ render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
+ )
+ )
+ task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
+ task.use(eps_greedy_handler(self.cfg))
+ task.use(
+ StepCollector(
+ self.cfg,
+ self.policy.collect_mode,
+ collector_env,
+ random_collect_size=self.cfg.policy.random_collect_size
+ if hasattr(self.cfg.policy, 'random_collect_size') else 0,
+ )
+ )
+ task.use(nstep_reward_enhancer(self.cfg))
+ task.use(data_pusher(self.cfg, self.buffer_))
+ task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
+ task.use(
+ wandb_online_logger(
+ metric_list=self.policy._monitor_vars_learn(),
+ model=self.policy._model,
+ anonymous=True,
+ project_name=self.exp_name,
+ wandb_sweep=wandb_sweep,
+ )
+ )
+ task.use(termination_checker(max_env_step=step))
+ task.use(final_ctx_saver(name=self.exp_name))
+ task.run()
+
+ return TrainingReturn(wandb_url=task.ctx.wandb_url)
+
+ def deploy(
+ self,
+ enable_save_replay: bool = False,
+ concatenate_all_replay: bool = False,
+ replay_save_path: str = None,
+ seed: Optional[Union[int, List]] = None,
+ debug: bool = False
+ ) -> EvalReturn:
+ """
+ Overview:
+ Deploy the agent with C51 algorithm by interacting with the environment, during which the replay video \
+ can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
+ Arguments:
+ - enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
+ - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
+ Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
+ If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
+ the replay video of each episode will be saved separately.
+ - replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
+ If not specified, the video will be saved in ``exp_name/videos``.
+ - seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
+ Default to None. If not specified, ``self.seed`` will be used. \
+ If ``seed`` is an integer, the agent will be deployed once. \
+ If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ Returns:
+            - (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
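+        Examples:
+            >>> # Illustrative usage; the checkpoint path is an assumption and depends on your experiment.
+            >>> agent = C51Agent(env_id='LunarLander-v2', policy_state_dict='./ckpt/eval.pth.tar')
+            >>> result = agent.deploy(enable_save_replay=True)
+            >>> print(result.eval_value, result.eval_value_std)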
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ # define env and policy
+ env = self.env.clone(caller='evaluator')
+
+ if seed is not None and isinstance(seed, int):
+ seeds = [seed]
+ elif seed is not None and isinstance(seed, list):
+ seeds = seed
+ else:
+ seeds = [self.seed]
+
+ returns = []
+ images = []
+ if enable_save_replay:
+ replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
+ env.enable_save_replay(replay_path=replay_save_path)
+ else:
+ logging.warning('No video would be generated during the deploy.')
+ if concatenate_all_replay:
+ logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
+ concatenate_all_replay = False
+
+ def single_env_forward_wrapper(forward_fn, cuda=True):
+
+ forward_fn = model_wrap(forward_fn, wrapper_name='argmax_sample').forward
+
+ def _forward(obs):
+ # unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
+ obs = ttorch.as_tensor(obs).unsqueeze(0)
+ if cuda and torch.cuda.is_available():
+ obs = obs.cuda()
+ action = forward_fn(obs)["action"]
+ # squeeze means delete batch dim, i.e. (1, A) -> (A, )
+ action = action.squeeze(0).detach().cpu().numpy()
+ return action
+
+ return _forward
+
+ forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
+
+ # reset first to make sure the env is in the initial state
+ # env will be reset again in the main loop
+ env.reset()
+
+ for seed in seeds:
+ env.seed(seed, dynamic_seed=False)
+ return_ = 0.
+ step = 0
+ obs = env.reset()
+ images.append(render(env)[None]) if concatenate_all_replay else None
+ while True:
+ action = forward_fn(obs)
+ obs, rew, done, info = env.step(action)
+ images.append(render(env)[None]) if concatenate_all_replay else None
+ return_ += rew
+ step += 1
+ if done:
+ break
+ logging.info(f'C51 deploy is finished, final episode return with {step} steps is: {return_}')
+ returns.append(return_)
+
+ env.close()
+
+ if concatenate_all_replay:
+ images = np.concatenate(images, axis=0)
+ import imageio
+ imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
+
+ return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
+
+ def collect_data(
+ self,
+ env_num: int = 8,
+ save_data_path: Optional[str] = None,
+ n_sample: Optional[int] = None,
+ n_episode: Optional[int] = None,
+ context: Optional[str] = None,
+ debug: bool = False
+ ) -> None:
+ """
+ Overview:
+ Collect data with C51 algorithm for ``n_episode`` episodes with ``env_num`` collector environments. \
+ The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
+ ``exp_name/demo_data``.
+ Arguments:
+ - env_num (:obj:`int`): The number of collector environments. Default to 8.
+ - save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
+ If not specified, the data will be saved in ``exp_name/demo_data``.
+ - n_sample (:obj:`int`): The number of samples to collect. Default to None. \
+ If not specified, ``n_episode`` must be specified.
+ - n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
+ If not specified, ``n_sample`` must be specified.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ if n_episode is not None:
+ raise NotImplementedError
+ # define env and policy
+ env_num = env_num if env_num else self.cfg.env.collector_env_num
+ env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
+
+ if save_data_path is None:
+ save_data_path = os.path.join(self.exp_name, 'demo_data')
+
+ # main execution task
+ with task.start(ctx=OnlineRLContext()):
+ task.use(
+ StepCollector(
+ self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
+ )
+ )
+ task.use(offline_data_saver(save_data_path, data_type='hdf5'))
+ task.run(max_step=1)
+ logging.info(
+ f'C51 collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
+ )
+
+ def batch_evaluate(
+ self,
+ env_num: int = 4,
+ n_evaluator_episode: int = 4,
+ context: Optional[str] = None,
+ debug: bool = False
+ ) -> EvalReturn:
+ """
+ Overview:
+ Evaluate the agent with C51 algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
+ environments. The evaluation result will be returned.
+ The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
+ multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
+ will only create one evaluator environment to evaluate the agent and save the replay video.
+ Arguments:
+ - env_num (:obj:`int`): The number of evaluator environments. Default to 4.
+ - n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ Returns:
+            - (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ # define env and policy
+ env_num = env_num if env_num else self.cfg.env.evaluator_env_num
+ env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
+
+ # reset first to make sure the env is in the initial state
+ # env will be reset again in the main loop
+ env.launch()
+ env.reset()
+
+ evaluate_cfg = self.cfg
+ evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
+
+ # main execution task
+ with task.start(ctx=OnlineRLContext()):
+ task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
+ task.run(max_step=1)
+
+ return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
+
+ @property
+ def best(self) -> 'C51Agent':
+ """
+ Overview:
+ Load the best model from the checkpoint directory, \
+ which is by default in folder ``exp_name/ckpt/eval.pth.tar``. \
+ The return value is the agent with the best model.
+ Returns:
+ - (:obj:`C51Agent`): The agent with the best model.
+ Examples:
+ >>> agent = C51Agent(env_id='LunarLander-v2')
+ >>> agent.train()
+ >>> agent = agent.best
+
+ .. note::
+ The best model is the model with the highest evaluation return. If this method is called, the current \
+ model will be replaced by the best model.
+ """
+
+ best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
+ # Load best model if it exists
+ if os.path.exists(best_model_file_path):
+ policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
+ return self
diff --git a/DI-engine/ding/bonus/common.py b/DI-engine/ding/bonus/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d4ddfc71175cccab096d7ca4fae1086acd668c0
--- /dev/null
+++ b/DI-engine/ding/bonus/common.py
@@ -0,0 +1,22 @@
+from dataclasses import dataclass
+import numpy as np
+
+
+@dataclass
+class TrainingReturn:
+ '''
+    Attributes:
+        wandb_url: The Weights & Biases (wandb) project url of the training experiment.
+ '''
+ wandb_url: str
+
+
+@dataclass
+class EvalReturn:
+ '''
+    Attributes:
+        eval_value: The mean of evaluation return.
+        eval_value_std: The standard deviation of evaluation return.
+ '''
+ eval_value: np.float32
+ eval_value_std: np.float32
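+
+
+# Illustrative usage: both dataclasses are plain result containers returned by the bonus agents, e.g.
+#   EvalReturn(eval_value=np.float32(200.0), eval_value_std=np.float32(10.0))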
diff --git a/DI-engine/ding/bonus/config.py b/DI-engine/ding/bonus/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..285eff6586e61b49925621eb1cad616dcbd637a0
--- /dev/null
+++ b/DI-engine/ding/bonus/config.py
@@ -0,0 +1,326 @@
+from easydict import EasyDict
+import os
+import gym
+from ding.envs import BaseEnv, DingEnvWrapper
+from ding.envs.env_wrappers import MaxAndSkipWrapper, WarpFrameWrapper, ScaledFloatFrameWrapper, FrameStackWrapper, \
+ EvalEpisodeReturnWrapper, TransposeWrapper, TimeLimitWrapper, FlatObsWrapper, GymToGymnasiumWrapper
+from ding.policy import PPOFPolicy
+
+
+def get_instance_config(env_id: str, algorithm: str) -> EasyDict:
+ if algorithm == 'PPOF':
+ cfg = PPOFPolicy.default_config()
+ if env_id == 'LunarLander-v2':
+ cfg.n_sample = 512
+ cfg.value_norm = 'popart'
+ cfg.entropy_weight = 1e-3
+ elif env_id == 'LunarLanderContinuous-v2':
+ cfg.action_space = 'continuous'
+ cfg.n_sample = 400
+ elif env_id == 'BipedalWalker-v3':
+ cfg.learning_rate = 1e-3
+ cfg.action_space = 'continuous'
+ cfg.n_sample = 1024
+ elif env_id == 'Pendulum-v1':
+ cfg.action_space = 'continuous'
+ cfg.n_sample = 400
+ elif env_id == 'acrobot':
+ cfg.learning_rate = 1e-4
+ cfg.n_sample = 400
+ elif env_id == 'rocket_landing':
+ cfg.n_sample = 2048
+ cfg.adv_norm = False
+ cfg.model = dict(
+ encoder_hidden_size_list=[64, 64, 128],
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ )
+ elif env_id == 'drone_fly':
+ cfg.action_space = 'continuous'
+ cfg.adv_norm = False
+ cfg.epoch_per_collect = 5
+ cfg.learning_rate = 5e-5
+ cfg.n_sample = 640
+ elif env_id == 'hybrid_moving':
+ cfg.action_space = 'hybrid'
+ cfg.n_sample = 3200
+ cfg.entropy_weight = 0.03
+ cfg.batch_size = 320
+ cfg.adv_norm = False
+ cfg.model = dict(
+ encoder_hidden_size_list=[256, 128, 64, 64],
+ sigma_type='fixed',
+ fixed_sigma_value=0.3,
+ bound_type='tanh',
+ )
+ elif env_id == 'evogym_carrier':
+ cfg.action_space = 'continuous'
+ cfg.n_sample = 2048
+ cfg.batch_size = 256
+ cfg.epoch_per_collect = 10
+ cfg.learning_rate = 3e-3
+ elif env_id == 'mario':
+ cfg.n_sample = 256
+ cfg.batch_size = 64
+ cfg.epoch_per_collect = 2
+ cfg.learning_rate = 1e-3
+ cfg.model = dict(
+ encoder_hidden_size_list=[64, 64, 128],
+ critic_head_hidden_size=128,
+ actor_head_hidden_size=128,
+ )
+ elif env_id == 'di_sheep':
+ cfg.n_sample = 3200
+ cfg.batch_size = 320
+ cfg.epoch_per_collect = 10
+ cfg.learning_rate = 3e-4
+ cfg.adv_norm = False
+ cfg.entropy_weight = 0.001
+ elif env_id == 'procgen_bigfish':
+ cfg.n_sample = 16384
+ cfg.batch_size = 16384
+ cfg.epoch_per_collect = 10
+ cfg.learning_rate = 5e-4
+ cfg.model = dict(
+ encoder_hidden_size_list=[64, 128, 256],
+ critic_head_hidden_size=256,
+ actor_head_hidden_size=256,
+ )
+ elif env_id in ['KangarooNoFrameskip-v4', 'BowlingNoFrameskip-v4']:
+ cfg.n_sample = 1024
+ cfg.batch_size = 128
+ cfg.epoch_per_collect = 10
+ cfg.learning_rate = 0.0001
+ cfg.model = dict(
+ encoder_hidden_size_list=[32, 64, 64, 128],
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ critic_head_layer_num=2,
+ )
+ elif env_id == 'PongNoFrameskip-v4':
+ cfg.n_sample = 3200
+ cfg.batch_size = 320
+ cfg.epoch_per_collect = 10
+ cfg.learning_rate = 3e-4
+ cfg.model = dict(
+ encoder_hidden_size_list=[64, 64, 128],
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ )
+ elif env_id == 'SpaceInvadersNoFrameskip-v4':
+ cfg.n_sample = 320
+ cfg.batch_size = 320
+ cfg.epoch_per_collect = 1
+ cfg.learning_rate = 1e-3
+ cfg.entropy_weight = 0.01
+ cfg.lr_scheduler = (2000, 0.1)
+ cfg.model = dict(
+ encoder_hidden_size_list=[64, 64, 128],
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ )
+ elif env_id == 'QbertNoFrameskip-v4':
+ cfg.n_sample = 3200
+ cfg.batch_size = 320
+ cfg.epoch_per_collect = 10
+ cfg.learning_rate = 5e-4
+ cfg.lr_scheduler = (1000, 0.1)
+ cfg.model = dict(
+ encoder_hidden_size_list=[64, 64, 128],
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ )
+ elif env_id == 'minigrid_fourroom':
+ cfg.n_sample = 3200
+ cfg.batch_size = 320
+ cfg.learning_rate = 3e-4
+ cfg.epoch_per_collect = 10
+ cfg.entropy_weight = 0.001
+ elif env_id == 'metadrive':
+ cfg.learning_rate = 3e-4
+ cfg.action_space = 'continuous'
+ cfg.entropy_weight = 0.001
+ cfg.n_sample = 3000
+ cfg.epoch_per_collect = 10
+ cfg.learning_rate = 0.0001
+ cfg.model = dict(
+ encoder_hidden_size_list=[32, 64, 64, 128],
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ critic_head_layer_num=2,
+ )
+ elif env_id == 'Hopper-v3':
+ cfg.action_space = "continuous"
+ cfg.n_sample = 3200
+ cfg.batch_size = 320
+ cfg.epoch_per_collect = 10
+ cfg.learning_rate = 3e-4
+ elif env_id == 'HalfCheetah-v3':
+ cfg.action_space = "continuous"
+ cfg.n_sample = 3200
+ cfg.batch_size = 320
+ cfg.epoch_per_collect = 10
+ cfg.learning_rate = 3e-4
+ elif env_id == 'Walker2d-v3':
+ cfg.action_space = "continuous"
+ cfg.n_sample = 3200
+ cfg.batch_size = 320
+ cfg.epoch_per_collect = 10
+ cfg.learning_rate = 3e-4
+ else:
+ raise KeyError("not supported env type: {}".format(env_id))
+ else:
+ raise KeyError("not supported algorithm type: {}".format(algorithm))
+
+ return cfg
+
+
+def get_instance_env(env_id: str) -> BaseEnv:
+ if env_id == 'LunarLander-v2':
+ return DingEnvWrapper(gym.make('LunarLander-v2'))
+ elif env_id == 'LunarLanderContinuous-v2':
+ return DingEnvWrapper(gym.make('LunarLanderContinuous-v2', continuous=True))
+ elif env_id == 'BipedalWalker-v3':
+ return DingEnvWrapper(gym.make('BipedalWalker-v3'), cfg={'act_scale': True, 'rew_clip': True})
+ elif env_id == 'Pendulum-v1':
+ return DingEnvWrapper(gym.make('Pendulum-v1'), cfg={'act_scale': True})
+ elif env_id == 'acrobot':
+ return DingEnvWrapper(gym.make('Acrobot-v1'))
+ elif env_id == 'rocket_landing':
+ from dizoo.rocket.envs import RocketEnv
+ cfg = EasyDict({
+ 'task': 'landing',
+ 'max_steps': 800,
+ })
+ return RocketEnv(cfg)
+ elif env_id == 'drone_fly':
+ from dizoo.gym_pybullet_drones.envs import GymPybulletDronesEnv
+ cfg = EasyDict({
+ 'env_id': 'flythrugate-aviary-v0',
+ 'action_type': 'VEL',
+ })
+ return GymPybulletDronesEnv(cfg)
+ elif env_id == 'hybrid_moving':
+ import gym_hybrid
+ return DingEnvWrapper(gym.make('Moving-v0'))
+ elif env_id == 'evogym_carrier':
+ import evogym.envs
+ from evogym import sample_robot, WorldObject
+ path = os.path.join(os.path.dirname(__file__), '../../dizoo/evogym/envs/world_data/carry_bot.json')
+ robot_object = WorldObject.from_json(path)
+ body = robot_object.get_structure()
+ return DingEnvWrapper(
+ gym.make('Carrier-v0', body=body),
+ cfg={
+ 'env_wrapper': [
+ lambda env: TimeLimitWrapper(env, max_limit=300),
+ lambda env: EvalEpisodeReturnWrapper(env),
+ ]
+ }
+ )
+ elif env_id == 'mario':
+ import gym_super_mario_bros
+ from nes_py.wrappers import JoypadSpace
+ return DingEnvWrapper(
+ JoypadSpace(gym_super_mario_bros.make("SuperMarioBros-1-1-v1"), [["right"], ["right", "A"]]),
+ cfg={
+ 'env_wrapper': [
+ lambda env: MaxAndSkipWrapper(env, skip=4),
+ lambda env: WarpFrameWrapper(env, size=84),
+ lambda env: ScaledFloatFrameWrapper(env),
+ lambda env: FrameStackWrapper(env, n_frames=4),
+ lambda env: TimeLimitWrapper(env, max_limit=200),
+ lambda env: EvalEpisodeReturnWrapper(env),
+ ]
+ }
+ )
+ elif env_id == 'di_sheep':
+ from sheep_env import SheepEnv
+ return DingEnvWrapper(SheepEnv(level=9))
+ elif env_id == 'procgen_bigfish':
+ return DingEnvWrapper(
+ gym.make('procgen:procgen-bigfish-v0', start_level=0, num_levels=1),
+ cfg={
+ 'env_wrapper': [
+ lambda env: TransposeWrapper(env),
+ lambda env: ScaledFloatFrameWrapper(env),
+ lambda env: EvalEpisodeReturnWrapper(env),
+ ]
+ },
+ seed_api=False,
+ )
+ elif env_id == 'Hopper-v3':
+ cfg = EasyDict(
+ env_id='Hopper-v3',
+ env_wrapper='mujoco_default',
+ act_scale=True,
+ rew_clip=True,
+ )
+ return DingEnvWrapper(gym.make('Hopper-v3'), cfg=cfg)
+ elif env_id == 'HalfCheetah-v3':
+ cfg = EasyDict(
+ env_id='HalfCheetah-v3',
+ env_wrapper='mujoco_default',
+ act_scale=True,
+ rew_clip=True,
+ )
+ return DingEnvWrapper(gym.make('HalfCheetah-v3'), cfg=cfg)
+ elif env_id == 'Walker2d-v3':
+ cfg = EasyDict(
+ env_id='Walker2d-v3',
+ env_wrapper='mujoco_default',
+ act_scale=True,
+ rew_clip=True,
+ )
+ return DingEnvWrapper(gym.make('Walker2d-v3'), cfg=cfg)
+
+ elif env_id in [
+ 'BowlingNoFrameskip-v4',
+ 'BreakoutNoFrameskip-v4',
+ 'GopherNoFrameskip-v4',
+ 'KangarooNoFrameskip-v4',
+ 'PongNoFrameskip-v4',
+ 'QbertNoFrameskip-v4',
+ 'SpaceInvadersNoFrameskip-v4',
+ ]:
+
+ cfg = EasyDict({
+ 'env_id': env_id,
+ 'env_wrapper': 'atari_default',
+ })
+ ding_env_atari = DingEnvWrapper(gym.make(env_id), cfg=cfg)
+ return ding_env_atari
+ elif env_id == 'minigrid_fourroom':
+ import gymnasium
+ return DingEnvWrapper(
+ gymnasium.make('MiniGrid-FourRooms-v0'),
+ cfg={
+ 'env_wrapper': [
+ lambda env: GymToGymnasiumWrapper(env),
+ lambda env: FlatObsWrapper(env),
+ lambda env: TimeLimitWrapper(env, max_limit=300),
+ lambda env: EvalEpisodeReturnWrapper(env),
+ ]
+ }
+ )
+ elif env_id == 'metadrive':
+ from dizoo.metadrive.env.drive_env import MetaDrivePPOOriginEnv
+ from dizoo.metadrive.env.drive_wrapper import DriveEnvWrapper
+ cfg = dict(
+ map='XSOS',
+ horizon=4000,
+ out_of_road_penalty=40.0,
+ crash_vehicle_penalty=40.0,
+ out_of_route_done=True,
+ )
+ cfg = EasyDict(cfg)
+ return DriveEnvWrapper(MetaDrivePPOOriginEnv(cfg))
+ else:
+ raise KeyError("not supported env type: {}".format(env_id))
+
+
+def get_hybrid_shape(action_space) -> EasyDict:
+ return EasyDict({
+ 'action_type_shape': action_space[0].n,
+ 'action_args_shape': action_space[1].shape,
+ })
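The two factories above are intended to be used together: ``get_instance_env`` wraps the raw environment for DI-engine and ``get_instance_config`` returns the matching PPOF default config. A minimal sketch under that assumption:

    from ding.bonus.config import get_instance_config, get_instance_env

    env = get_instance_env('LunarLander-v2')                      # DingEnvWrapper around gym.make('LunarLander-v2')
    cfg = get_instance_config('LunarLander-v2', algorithm='PPOF')
    print(cfg.n_sample, cfg.value_norm, cfg.entropy_weight)       # 512, 'popart', 1e-3 as set above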
diff --git a/DI-engine/ding/bonus/ddpg.py b/DI-engine/ding/bonus/ddpg.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dade9e38b5bd74c1612f551ae85ad6fa7d2979e
--- /dev/null
+++ b/DI-engine/ding/bonus/ddpg.py
@@ -0,0 +1,456 @@
+from typing import Optional, Union, List
+from ditk import logging
+from easydict import EasyDict
+import os
+import numpy as np
+import torch
+import treetensor.torch as ttorch
+from ding.framework import task, OnlineRLContext
+from ding.framework.middleware import CkptSaver, \
+ wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, data_pusher, \
+ OffPolicyLearner, final_ctx_saver
+from ding.envs import BaseEnv
+from ding.envs import setup_ding_env_manager
+from ding.policy import DDPGPolicy
+from ding.utils import set_pkg_seed
+from ding.utils import get_env_fps, render
+from ding.config import save_config_py, compile_config
+from ding.model import ContinuousQAC
+from ding.data import DequeBuffer
+from ding.bonus.common import TrainingReturn, EvalReturn
+from ding.config.example.DDPG import supported_env_cfg
+from ding.config.example.DDPG import supported_env
+
+
+class DDPGAgent:
+ """
+ Overview:
+ Class of agent for training, evaluation and deployment of the reinforcement learning algorithm \
+ Deep Deterministic Policy Gradient (DDPG).
+ For more information about the system design of the RL agent, please refer to \
+ .
+ Interface:
+ ``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
+ """
+ supported_env_list = list(supported_env_cfg.keys())
+ """
+ Overview:
+ List of supported envs.
+ Examples:
+ >>> from ding.bonus.ddpg import DDPGAgent
+ >>> print(DDPGAgent.supported_env_list)
+ """
+
+ def __init__(
+ self,
+ env_id: str = None,
+ env: BaseEnv = None,
+ seed: int = 0,
+ exp_name: str = None,
+ model: Optional[torch.nn.Module] = None,
+ cfg: Optional[Union[EasyDict, dict]] = None,
+ policy_state_dict: str = None,
+ ) -> None:
+ """
+ Overview:
+ Initialize agent for DDPG algorithm.
+ Arguments:
+ - env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
+ If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
+ If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
+ ``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
+ - env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
+ If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified. \
+ ``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
+ If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
+ - seed (:obj:`int`): The random seed, which is set before running the program. \
+ Default to 0.
+ - exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
+ log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
+ - model (:obj:`torch.nn.Module`): The model of DDPG algorithm, which should be an instance of class \
+ :class:`ding.model.ContinuousQAC`. \
+ If not specified, a default model will be generated according to the configuration.
+ - cfg (:obj:`Union[EasyDict, dict]`): The configuration of DDPG algorithm, which is a dict. \
+ Default to None. If not specified, the default configuration will be used. \
+ The default configuration can be found in ``ding/config/example/DDPG/gym_lunarlander_v2.py``.
+ - policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
+ If specified, the policy will be loaded from this file. Default to None.
+
+ .. note::
+ An RL Agent Instance can be initialized in two basic ways. \
+ For example, we have an environment with id ``LunarLanderContinuous-v2`` registered in gym, \
+ and we want to train an agent with DDPG algorithm with default configuration. \
+ Then we can initialize the agent in the following ways:
+ >>> agent = DDPGAgent(env_id='LunarLanderContinuous-v2')
+ or, if we want to specify the env_id in the configuration:
+ >>> cfg = {'env': {'env_id': 'LunarLanderContinuous-v2'}, 'policy': ...... }
+ >>> agent = DDPGAgent(cfg=cfg)
+ There are also other arguments to specify the agent when initializing.
+ For example, if we want to specify the environment instance:
+ >>> env = CustomizedEnv('LunarLanderContinuous-v2')
+ >>> agent = DDPGAgent(cfg=cfg, env=env)
+ or, if we want to specify the model:
+ >>> model = ContinuousQAC(**cfg.policy.model)
+ >>> agent = DDPGAgent(cfg=cfg, model=model)
+ or, if we want to reload the policy from a saved policy state dict:
+ >>> agent = DDPGAgent(cfg=cfg, policy_state_dict='LunarLanderContinuous-v2.pth.tar')
+ Make sure that the configuration is consistent with the saved policy state dict.
+ """
+
+ assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
+
+ if cfg is not None and not isinstance(cfg, EasyDict):
+ cfg = EasyDict(cfg)
+
+ if env_id is not None:
+ assert env_id in DDPGAgent.supported_env_list, "Please use supported envs: {}".format(
+ DDPGAgent.supported_env_list
+ )
+ if cfg is None:
+ cfg = supported_env_cfg[env_id]
+ else:
+ assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
+ else:
+ assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
+ assert cfg.env.env_id in DDPGAgent.supported_env_list, "Please use supported envs: {}".format(
+ DDPGAgent.supported_env_list
+ )
+ default_policy_config = EasyDict({"policy": DDPGPolicy.default_config()})
+ default_policy_config.update(cfg)
+ cfg = default_policy_config
+
+ if exp_name is not None:
+ cfg.exp_name = exp_name
+ self.cfg = compile_config(cfg, policy=DDPGPolicy)
+ self.exp_name = self.cfg.exp_name
+ if env is None:
+ self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
+ else:
+ assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
+ self.env = env
+
+ logging.getLogger().setLevel(logging.INFO)
+ self.seed = seed
+ set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
+ if not os.path.exists(self.exp_name):
+ os.makedirs(self.exp_name)
+ save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
+ if model is None:
+ model = ContinuousQAC(**self.cfg.policy.model)
+ self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
+ self.policy = DDPGPolicy(self.cfg.policy, model=model)
+ if policy_state_dict is not None:
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
+ self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
+
+ def train(
+ self,
+ step: int = int(1e7),
+ collector_env_num: int = None,
+ evaluator_env_num: int = None,
+ n_iter_log_show: int = 500,
+ n_iter_save_ckpt: int = 1000,
+ context: Optional[str] = None,
+ debug: bool = False,
+ wandb_sweep: bool = False,
+ ) -> TrainingReturn:
+ """
+ Overview:
+ Train the agent with DDPG algorithm for ``step`` iterations with ``collector_env_num`` collector \
+ environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
+ recorded and saved by wandb.
+ Arguments:
+ - step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
+ - collector_env_num (:obj:`int`): The collector environment number. Default to None. \
+ If not specified, it will be set according to the configuration.
+ - evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
+ If not specified, it will be set according to the configuration.
+ - n_iter_save_ckpt (:obj:`int`): How often the checkpoint is saved, measured in training iterations. \
+ Default to 1000.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ - wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
+ which is a hyper-parameter optimization process for seeking the best configurations. \
+ Default to False. If True, the wandb sweep id will be used as the experiment name.
+ Returns:
+ - (:obj:`TrainingReturn`): The training result, whose attributes are:
+ - wandb_url (:obj:`str`): The Weights & Biases (wandb) project URL of the training experiment.
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ logging.debug(self.policy._model)
+ # define env and policy
+ collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
+ evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
+ collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
+ evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
+
+ with task.start(ctx=OnlineRLContext()):
+ task.use(
+ interaction_evaluator(
+ self.cfg,
+ self.policy.eval_mode,
+ evaluator_env,
+ render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
+ )
+ )
+ task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
+ task.use(
+ StepCollector(
+ self.cfg,
+ self.policy.collect_mode,
+ collector_env,
+ random_collect_size=self.cfg.policy.random_collect_size
+ if hasattr(self.cfg.policy, 'random_collect_size') else 0,
+ )
+ )
+ task.use(data_pusher(self.cfg, self.buffer_))
+ task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
+ task.use(
+ wandb_online_logger(
+ metric_list=self.policy._monitor_vars_learn(),
+ model=self.policy._model,
+ anonymous=True,
+ project_name=self.exp_name,
+ wandb_sweep=wandb_sweep,
+ )
+ )
+ task.use(termination_checker(max_env_step=step))
+ task.use(final_ctx_saver(name=self.exp_name))
+ task.run()
+
+ return TrainingReturn(wandb_url=task.ctx.wandb_url)
+
+ def deploy(
+ self,
+ enable_save_replay: bool = False,
+ concatenate_all_replay: bool = False,
+ replay_save_path: str = None,
+ seed: Optional[Union[int, List]] = None,
+ debug: bool = False
+ ) -> EvalReturn:
+ """
+ Overview:
+ Deploy the agent with DDPG algorithm by interacting with the environment, during which the replay video \
+ can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
+ Arguments:
+ - enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
+ - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
+ Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
+ If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
+ the replay video of each episode will be saved separately.
+ - replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
+ If not specified, the video will be saved in ``exp_name/videos``.
+ - seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
+ Default to None. If not specified, ``self.seed`` will be used. \
+ If ``seed`` is an integer, the agent will be deployed once. \
+ If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ Returns:
+ - (:obj:`EvalReturn`): The evaluation result, whose attributes are:
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ # define env and policy
+ env = self.env.clone(caller='evaluator')
+
+ if seed is not None and isinstance(seed, int):
+ seeds = [seed]
+ elif seed is not None and isinstance(seed, list):
+ seeds = seed
+ else:
+ seeds = [self.seed]
+
+ returns = []
+ images = []
+ if enable_save_replay:
+ replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
+ env.enable_save_replay(replay_path=replay_save_path)
+ else:
+ logging.warning('No video would be generated during the deploy.')
+ if concatenate_all_replay:
+ logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
+ concatenate_all_replay = False
+
+ def single_env_forward_wrapper(forward_fn, cuda=True):
+
+ def _forward(obs):
+ # unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
+ obs = ttorch.as_tensor(obs).unsqueeze(0)
+ if cuda and torch.cuda.is_available():
+ obs = obs.cuda()
+ action = forward_fn(obs, mode='compute_actor')["action"]
+ # squeeze means delete batch dim, i.e. (1, A) -> (A, )
+ action = action.squeeze(0).detach().cpu().numpy()
+ return action
+
+ return _forward
+
+ forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
+
+ # reset first to make sure the env is in the initial state
+ # env will be reset again in the main loop
+ env.reset()
+
+ for seed in seeds:
+ env.seed(seed, dynamic_seed=False)
+ return_ = 0.
+ step = 0
+ obs = env.reset()
+ if concatenate_all_replay:
+ images.append(render(env)[None])
+ while True:
+ action = forward_fn(obs)
+ obs, rew, done, info = env.step(action)
+ if concatenate_all_replay:
+ images.append(render(env)[None])
+ return_ += rew
+ step += 1
+ if done:
+ break
+ logging.info(f'DDPG deploy is finished, final episode return with {step} steps is: {return_}')
+ returns.append(return_)
+
+ env.close()
+
+ if concatenate_all_replay:
+ images = np.concatenate(images, axis=0)
+ import imageio
+ imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
+
+ return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
+
+ def collect_data(
+ self,
+ env_num: int = 8,
+ save_data_path: Optional[str] = None,
+ n_sample: Optional[int] = None,
+ n_episode: Optional[int] = None,
+ context: Optional[str] = None,
+ debug: bool = False
+ ) -> None:
+ """
+ Overview:
+ Collect ``n_sample`` transitions with DDPG algorithm using ``env_num`` collector environments \
+ (episode-based collection via ``n_episode`` is not implemented yet). \
+ The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
+ ``exp_name/demo_data``.
+ Arguments:
+ - env_num (:obj:`int`): The number of collector environments. Default to 8.
+ - save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
+ If not specified, the data will be saved in ``exp_name/demo_data``.
+ - n_sample (:obj:`int`): The number of samples to collect. Default to None. \
+ If not specified, ``n_episode`` must be specified.
+ - n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
+ If not specified, ``n_sample`` must be specified.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ if n_episode is not None:
+ raise NotImplementedError
+ # define env and policy
+ env_num = env_num if env_num else self.cfg.env.collector_env_num
+ env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
+
+ if save_data_path is None:
+ save_data_path = os.path.join(self.exp_name, 'demo_data')
+
+ # main execution task
+ with task.start(ctx=OnlineRLContext()):
+ task.use(
+ StepCollector(
+ self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
+ )
+ )
+ task.use(offline_data_saver(save_data_path, data_type='hdf5'))
+ task.run(max_step=1)
+ logging.info(
+ f'DDPG collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
+ )
+
+ def batch_evaluate(
+ self,
+ env_num: int = 4,
+ n_evaluator_episode: int = 4,
+ context: Optional[str] = None,
+ debug: bool = False
+ ) -> EvalReturn:
+ """
+ Overview:
+ Evaluate the agent with DDPG algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
+ environments. The evaluation result will be returned.
+ The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
+ multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
+ will only create one evaluator environment to evaluate the agent and save the replay video.
+ Arguments:
+ - env_num (:obj:`int`): The number of evaluator environments. Default to 4.
+ - n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ Returns:
+ - (:obj:`EvalReturn`): The evaluation result, whose attributes are:
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ # define env and policy
+ env_num = env_num if env_num else self.cfg.env.evaluator_env_num
+ env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
+
+ # reset first to make sure the env is in the initial state
+ # env will be reset again in the main loop
+ env.launch()
+ env.reset()
+
+ evaluate_cfg = self.cfg
+ evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
+
+ # main execution task
+ with task.start(ctx=OnlineRLContext()):
+ task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
+ task.run(max_step=1)
+
+ return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
+
+ @property
+ def best(self) -> 'DDPGAgent':
+ """
+ Overview:
+ Load the best model from the checkpoint directory, \
+ which is by default the file ``exp_name/ckpt/eval.pth.tar``. \
+ The return value is the agent with the best model.
+ Returns:
+ - (:obj:`DDPGAgent`): The agent with the best model.
+ Examples:
+ >>> agent = DDPGAgent(env_id='LunarLanderContinuous-v2')
+ >>> agent.train()
+ >>> agent = agent.best
+
+ .. note::
+ The best model is the model with the highest evaluation return. If this method is called, the current \
+ model will be replaced by the best model.
+ """
+
+ best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
+ # Load best model if it exists
+ if os.path.exists(best_model_file_path):
+ policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
+ return self
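A minimal end-to-end sketch of the ``DDPGAgent`` interface defined above (module path taken from this diff; the small ``step`` budget and experiment name are illustrative only):

    from ding.bonus.ddpg import DDPGAgent

    agent = DDPGAgent(env_id='LunarLanderContinuous-v2', exp_name='LunarLanderContinuous-v2-DDPG-demo')
    agent.train(step=int(1e5), collector_env_num=4, evaluator_env_num=4)
    agent = agent.best                             # reload the best checkpoint if one was saved
    print(agent.deploy(enable_save_replay=True))   # EvalReturn(eval_value=..., eval_value_std=...)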
diff --git a/DI-engine/ding/bonus/dqn.py b/DI-engine/ding/bonus/dqn.py
new file mode 100644
index 0000000000000000000000000000000000000000..4894e2aa6f4aee69d48ad09b93db0e0f310cd1e4
--- /dev/null
+++ b/DI-engine/ding/bonus/dqn.py
@@ -0,0 +1,460 @@
+from typing import Optional, Union, List
+from ditk import logging
+from easydict import EasyDict
+import os
+import numpy as np
+import torch
+import treetensor.torch as ttorch
+from ding.framework import task, OnlineRLContext
+from ding.framework.middleware import CkptSaver, \
+ wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, data_pusher, \
+ OffPolicyLearner, final_ctx_saver, nstep_reward_enhancer, eps_greedy_handler
+from ding.envs import BaseEnv
+from ding.envs import setup_ding_env_manager
+from ding.policy import DQNPolicy
+from ding.utils import set_pkg_seed
+from ding.utils import get_env_fps, render
+from ding.config import save_config_py, compile_config
+from ding.model import DQN
+from ding.model import model_wrap
+from ding.data import DequeBuffer
+from ding.bonus.common import TrainingReturn, EvalReturn
+from ding.config.example.DQN import supported_env_cfg
+from ding.config.example.DQN import supported_env
+
+
+class DQNAgent:
+ """
+ Overview:
+ Class of agent for training, evaluation and deployment of the reinforcement learning algorithm \
+ Deep Q-Learning (DQN).
+ For more information about the system design of the RL agent, please refer to \
+ .
+ Interface:
+ ``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
+ """
+ supported_env_list = list(supported_env_cfg.keys())
+ """
+ Overview:
+ List of supported envs.
+ Examples:
+ >>> from ding.bonus.dqn import DQNAgent
+ >>> print(DQNAgent.supported_env_list)
+ """
+
+ def __init__(
+ self,
+ env_id: str = None,
+ env: BaseEnv = None,
+ seed: int = 0,
+ exp_name: str = None,
+ model: Optional[torch.nn.Module] = None,
+ cfg: Optional[Union[EasyDict, dict]] = None,
+ policy_state_dict: str = None,
+ ) -> None:
+ """
+ Overview:
+ Initialize agent for DQN algorithm.
+ Arguments:
+ - env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
+ If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
+ If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
+ ``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
+ - env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
+ If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified. \
+ ``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
+ If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
+ - seed (:obj:`int`): The random seed, which is set before running the program. \
+ Default to 0.
+ - exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
+ log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
+ - model (:obj:`torch.nn.Module`): The model of DQN algorithm, which should be an instance of class \
+ :class:`ding.model.DQN`. \
+ If not specified, a default model will be generated according to the configuration.
+ - cfg (:obj:`Union[EasyDict, dict]`): The configuration of DQN algorithm, which is a dict. \
+ Default to None. If not specified, the default configuration will be used. \
+ The default configuration can be found in ``ding/config/example/DQN/gym_lunarlander_v2.py``.
+ - policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
+ If specified, the policy will be loaded from this file. Default to None.
+
+ .. note::
+ An RL Agent Instance can be initialized in two basic ways. \
+ For example, we have an environment with id ``LunarLander-v2`` registered in gym, \
+ and we want to train an agent with DQN algorithm with default configuration. \
+ Then we can initialize the agent in the following ways:
+ >>> agent = DQNAgent(env_id='LunarLander-v2')
+ or, if we want to specify the env_id in the configuration:
+ >>> cfg = {'env': {'env_id': 'LunarLander-v2'}, 'policy': ...... }
+ >>> agent = DQNAgent(cfg=cfg)
+ There are also other arguments to specify the agent when initializing.
+ For example, if we want to specify the environment instance:
+ >>> env = CustomizedEnv('LunarLander-v2')
+ >>> agent = DQNAgent(cfg=cfg, env=env)
+ or, if we want to specify the model:
+ >>> model = DQN(**cfg.policy.model)
+ >>> agent = DQNAgent(cfg=cfg, model=model)
+ or, if we want to reload the policy from a saved policy state dict:
+ >>> agent = DQNAgent(cfg=cfg, policy_state_dict='LunarLander-v2.pth.tar')
+ Make sure that the configuration is consistent with the saved policy state dict.
+ """
+
+ assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
+
+ if cfg is not None and not isinstance(cfg, EasyDict):
+ cfg = EasyDict(cfg)
+
+ if env_id is not None:
+ assert env_id in DQNAgent.supported_env_list, "Please use supported envs: {}".format(
+ DQNAgent.supported_env_list
+ )
+ if cfg is None:
+ cfg = supported_env_cfg[env_id]
+ else:
+ assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
+ else:
+ assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
+ assert cfg.env.env_id in DQNAgent.supported_env_list, "Please use supported envs: {}".format(
+ DQNAgent.supported_env_list
+ )
+ default_policy_config = EasyDict({"policy": DQNPolicy.default_config()})
+ default_policy_config.update(cfg)
+ cfg = default_policy_config
+
+ if exp_name is not None:
+ cfg.exp_name = exp_name
+ self.cfg = compile_config(cfg, policy=DQNPolicy)
+ self.exp_name = self.cfg.exp_name
+ if env is None:
+ self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
+ else:
+ assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
+ self.env = env
+
+ logging.getLogger().setLevel(logging.INFO)
+ self.seed = seed
+ set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
+ if not os.path.exists(self.exp_name):
+ os.makedirs(self.exp_name)
+ save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
+ if model is None:
+ model = DQN(**self.cfg.policy.model)
+ self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
+ self.policy = DQNPolicy(self.cfg.policy, model=model)
+ if policy_state_dict is not None:
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
+ self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
+
+ def train(
+ self,
+ step: int = int(1e7),
+ collector_env_num: int = None,
+ evaluator_env_num: int = None,
+ n_iter_save_ckpt: int = 1000,
+ context: Optional[str] = None,
+ debug: bool = False,
+ wandb_sweep: bool = False,
+ ) -> TrainingReturn:
+ """
+ Overview:
+ Train the agent with DQN algorithm for ``step`` iterations with ``collector_env_num`` collector \
+ environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
+ recorded and saved by wandb.
+ Arguments:
+ - step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
+ - collector_env_num (:obj:`int`): The collector environment number. Default to None. \
+ If not specified, it will be set according to the configuration.
+ - evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
+ If not specified, it will be set according to the configuration.
+ - n_iter_save_ckpt (:obj:`int`): How often the checkpoint is saved, measured in training iterations. \
+ Default to 1000.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ - wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
+ which is a hyper-parameter optimization process for seeking the best configurations. \
+ Default to False. If True, the wandb sweep id will be used as the experiment name.
+ Returns:
+ - (:obj:`TrainingReturn`): The training result, whose attributes are:
+ - wandb_url (:obj:`str`): The Weights & Biases (wandb) project URL of the training experiment.
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ logging.debug(self.policy._model)
+ # define env and policy
+ collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
+ evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
+ collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
+ evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
+
+ with task.start(ctx=OnlineRLContext()):
+ task.use(
+ interaction_evaluator(
+ self.cfg,
+ self.policy.eval_mode,
+ evaluator_env,
+ render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
+ )
+ )
+ task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
+ task.use(eps_greedy_handler(self.cfg))
+ task.use(
+ StepCollector(
+ self.cfg,
+ self.policy.collect_mode,
+ collector_env,
+ random_collect_size=self.cfg.policy.random_collect_size
+ if hasattr(self.cfg.policy, 'random_collect_size') else 0,
+ )
+ )
+ if "nstep" in self.cfg.policy and self.cfg.policy.nstep > 1:
+ task.use(nstep_reward_enhancer(self.cfg))
+ task.use(data_pusher(self.cfg, self.buffer_))
+ task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
+ task.use(
+ wandb_online_logger(
+ metric_list=self.policy._monitor_vars_learn(),
+ model=self.policy._model,
+ anonymous=True,
+ project_name=self.exp_name,
+ wandb_sweep=wandb_sweep,
+ )
+ )
+ task.use(termination_checker(max_env_step=step))
+ task.use(final_ctx_saver(name=self.exp_name))
+ task.run()
+
+ return TrainingReturn(wandb_url=task.ctx.wandb_url)
+
+ def deploy(
+ self,
+ enable_save_replay: bool = False,
+ concatenate_all_replay: bool = False,
+ replay_save_path: str = None,
+ seed: Optional[Union[int, List]] = None,
+ debug: bool = False
+ ) -> EvalReturn:
+ """
+ Overview:
+ Deploy the agent with DQN algorithm by interacting with the environment, during which the replay video \
+ can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
+ Arguments:
+ - enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
+ - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
+ Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
+ If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
+ the replay video of each episode will be saved separately.
+ - replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
+ If not specified, the video will be saved in ``exp_name/videos``.
+ - seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
+ Default to None. If not specified, ``self.seed`` will be used. \
+ If ``seed`` is an integer, the agent will be deployed once. \
+ If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ Returns:
+ - (:obj:`EvalReturn`): The evaluation result, whose attributes are:
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ # define env and policy
+ env = self.env.clone(caller='evaluator')
+
+ if seed is not None and isinstance(seed, int):
+ seeds = [seed]
+ elif seed is not None and isinstance(seed, list):
+ seeds = seed
+ else:
+ seeds = [self.seed]
+
+ returns = []
+ images = []
+ if enable_save_replay:
+ replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
+ env.enable_save_replay(replay_path=replay_save_path)
+ else:
+ logging.warning('No video would be generated during the deploy.')
+ if concatenate_all_replay:
+ logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
+ concatenate_all_replay = False
+
+ def single_env_forward_wrapper(forward_fn, cuda=True):
+
+ forward_fn = model_wrap(forward_fn, wrapper_name='argmax_sample').forward
+
+ def _forward(obs):
+ # unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
+ obs = ttorch.as_tensor(obs).unsqueeze(0)
+ if cuda and torch.cuda.is_available():
+ obs = obs.cuda()
+ action = forward_fn(obs)["action"]
+ # squeeze means delete batch dim, i.e. (1, A) -> (A, )
+ action = action.squeeze(0).detach().cpu().numpy()
+ return action
+
+ return _forward
+
+ forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
+
+ # reset first to make sure the env is in the initial state
+ # env will be reset again in the main loop
+ env.reset()
+
+ for seed in seeds:
+ env.seed(seed, dynamic_seed=False)
+ return_ = 0.
+ step = 0
+ obs = env.reset()
+ if concatenate_all_replay:
+ images.append(render(env)[None])
+ while True:
+ action = forward_fn(obs)
+ obs, rew, done, info = env.step(action)
+ if concatenate_all_replay:
+ images.append(render(env)[None])
+ return_ += rew
+ step += 1
+ if done:
+ break
+ logging.info(f'DQN deploy is finished, final episode return with {step} steps is: {return_}')
+ returns.append(return_)
+
+ env.close()
+
+ if concatenate_all_replay:
+ images = np.concatenate(images, axis=0)
+ import imageio
+ imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
+
+ return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
+
+ def collect_data(
+ self,
+ env_num: int = 8,
+ save_data_path: Optional[str] = None,
+ n_sample: Optional[int] = None,
+ n_episode: Optional[int] = None,
+ context: Optional[str] = None,
+ debug: bool = False
+ ) -> None:
+ """
+ Overview:
+ Collect ``n_sample`` transitions with DQN algorithm using ``env_num`` collector environments \
+ (episode-based collection via ``n_episode`` is not implemented yet). \
+ The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
+ ``exp_name/demo_data``.
+ Arguments:
+ - env_num (:obj:`int`): The number of collector environments. Default to 8.
+ - save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
+ If not specified, the data will be saved in ``exp_name/demo_data``.
+ - n_sample (:obj:`int`): The number of samples to collect. Default to None. \
+ If not specified, ``n_episode`` must be specified.
+ - n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
+ If not specified, ``n_sample`` must be specified.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ if n_episode is not None:
+ raise NotImplementedError
+ # define env and policy
+ env_num = env_num if env_num else self.cfg.env.collector_env_num
+ env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
+
+ if save_data_path is None:
+ save_data_path = os.path.join(self.exp_name, 'demo_data')
+
+ # main execution task
+ with task.start(ctx=OnlineRLContext()):
+ task.use(
+ StepCollector(
+ self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
+ )
+ )
+ task.use(offline_data_saver(save_data_path, data_type='hdf5'))
+ task.run(max_step=1)
+ logging.info(
+ f'DQN collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
+ )
+
+ def batch_evaluate(
+ self,
+ env_num: int = 4,
+ n_evaluator_episode: int = 4,
+ context: Optional[str] = None,
+ debug: bool = False
+ ) -> EvalReturn:
+ """
+ Overview:
+ Evaluate the agent with DQN algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
+ environments. The evaluation result will be returned.
+ The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
+ multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
+ will only create one evaluator environment to evaluate the agent and save the replay video.
+ Arguments:
+ - env_num (:obj:`int`): The number of evaluator environments. Default to 4.
+ - n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ Returns:
+ - (:obj:`EvalReturn`): The evaluation result, whose attributes are:
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ # define env and policy
+ env_num = env_num if env_num else self.cfg.env.evaluator_env_num
+ env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
+
+ # reset first to make sure the env is in the initial state
+ # env will be reset again in the main loop
+ env.launch()
+ env.reset()
+
+ evaluate_cfg = self.cfg
+ evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
+
+ # main execution task
+ with task.start(ctx=OnlineRLContext()):
+ task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
+ task.run(max_step=1)
+
+ return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
+
+ @property
+ def best(self) -> 'DQNAgent':
+ """
+ Overview:
+ Load the best model from the checkpoint directory, \
+ which is by default the file ``exp_name/ckpt/eval.pth.tar``. \
+ The return value is the agent with the best model.
+ Returns:
+ - (:obj:`DQNAgent`): The agent with the best model.
+ Examples:
+ >>> agent = DQNAgent(env_id='LunarLander-v2')
+ >>> agent.train()
+ >>> agent = agent.best
+
+ .. note::
+ The best model is the model with the highest evaluation return. If this method is called, the current \
+ model will be replaced by the best model.
+ """
+
+ best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
+ # Load best model if it exists
+ if os.path.exists(best_model_file_path):
+ policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
+ return self
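``DQNAgent`` exposes the same interface; the sketch below focuses on data collection and batch evaluation instead (argument values are illustrative):

    from ding.bonus.dqn import DQNAgent

    agent = DQNAgent(env_id='LunarLander-v2')
    agent.collect_data(env_num=8, n_sample=1024)   # saved as hdf5 under exp_name/demo_data by default
    print(agent.batch_evaluate(env_num=4, n_evaluator_episode=8))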
diff --git a/DI-engine/ding/bonus/model.py b/DI-engine/ding/bonus/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..d33fa4c779143714da762ee67f2030ee0fdd0af9
--- /dev/null
+++ b/DI-engine/ding/bonus/model.py
@@ -0,0 +1,245 @@
+from typing import Union, Optional
+from easydict import EasyDict
+import torch
+import torch.nn as nn
+import treetensor.torch as ttorch
+from copy import deepcopy
+from ding.utils import SequenceType, squeeze
+from ding.model.common import ReparameterizationHead, RegressionHead, MultiHead, \
+ FCEncoder, ConvEncoder, IMPALAConvEncoder, PopArtVHead
+from ding.torch_utils import MLP, fc_block
+
+
+class DiscretePolicyHead(nn.Module):
+
+ def __init__(
+ self,
+ hidden_size: int,
+ output_size: int,
+ layer_num: int = 1,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ ) -> None:
+ super(DiscretePolicyHead, self).__init__()
+ self.main = nn.Sequential(
+ MLP(
+ hidden_size,
+ hidden_size,
+ hidden_size,
+ layer_num,
+ layer_fn=nn.Linear,
+ activation=activation,
+ norm_type=norm_type
+ ), fc_block(hidden_size, output_size)
+ )
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ return self.main(x)
+
+
+class PPOFModel(nn.Module):
+ mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']
+
+ def __init__(
+ self,
+ obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType, EasyDict],
+ action_space: str = 'discrete',
+ share_encoder: bool = True,
+ encoder_hidden_size_list: SequenceType = [128, 128, 64],
+ actor_head_hidden_size: int = 64,
+ actor_head_layer_num: int = 1,
+ critic_head_hidden_size: int = 64,
+ critic_head_layer_num: int = 1,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ sigma_type: Optional[str] = 'independent',
+ fixed_sigma_value: Optional[float] = 0.3,
+ bound_type: Optional[str] = None,
+ encoder: Optional[torch.nn.Module] = None,
+ popart_head=False,
+ ) -> None:
+ super(PPOFModel, self).__init__()
+ obs_shape = squeeze(obs_shape)
+ action_shape = squeeze(action_shape)
+ self.obs_shape, self.action_shape = obs_shape, action_shape
+ self.share_encoder = share_encoder
+
+ # Encoder Type
+ def new_encoder(outsize):
+ if isinstance(obs_shape, int) or len(obs_shape) == 1:
+ return FCEncoder(
+ obs_shape=obs_shape,
+ hidden_size_list=encoder_hidden_size_list,
+ activation=activation,
+ norm_type=norm_type
+ )
+ elif len(obs_shape) == 3:
+ return ConvEncoder(
+ obs_shape=obs_shape,
+ hidden_size_list=encoder_hidden_size_list,
+ activation=activation,
+ norm_type=norm_type
+ )
+ else:
+ raise RuntimeError(
+ "not support obs_shape for pre-defined encoder: {}, please customize your own encoder".
+ format(obs_shape)
+ )
+
+ if self.share_encoder:
+ assert actor_head_hidden_size == critic_head_hidden_size, \
+ "actor and critic network head should have same size."
+ if encoder:
+ if isinstance(encoder, torch.nn.Module):
+ self.encoder = encoder
+ else:
+ raise ValueError("illegal encoder instance.")
+ else:
+ self.encoder = new_encoder(actor_head_hidden_size)
+ else:
+ if encoder:
+ if isinstance(encoder, torch.nn.Module):
+ self.actor_encoder = encoder
+ self.critic_encoder = deepcopy(encoder)
+ else:
+ raise ValueError("illegal encoder instance.")
+ else:
+ self.actor_encoder = new_encoder(actor_head_hidden_size)
+ self.critic_encoder = new_encoder(critic_head_hidden_size)
+
+ # Head Type
+ if not popart_head:
+ self.critic_head = RegressionHead(
+ critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
+ )
+ else:
+ self.critic_head = PopArtVHead(
+ critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
+ )
+
+ self.action_space = action_space
+ assert self.action_space in ['discrete', 'continuous', 'hybrid'], self.action_space
+ if self.action_space == 'continuous':
+ self.multi_head = False
+ self.actor_head = ReparameterizationHead(
+ actor_head_hidden_size,
+ action_shape,
+ actor_head_layer_num,
+ sigma_type=sigma_type,
+ activation=activation,
+ norm_type=norm_type,
+ bound_type=bound_type
+ )
+ elif self.action_space == 'discrete':
+ actor_head_cls = DiscretePolicyHead
+ multi_head = not isinstance(action_shape, int)
+ self.multi_head = multi_head
+ if multi_head:
+ self.actor_head = MultiHead(
+ actor_head_cls,
+ actor_head_hidden_size,
+ action_shape,
+ layer_num=actor_head_layer_num,
+ activation=activation,
+ norm_type=norm_type
+ )
+ else:
+ self.actor_head = actor_head_cls(
+ actor_head_hidden_size,
+ action_shape,
+ actor_head_layer_num,
+ activation=activation,
+ norm_type=norm_type
+ )
+ elif self.action_space == 'hybrid': # HPPO
+ # hybrid action space: action_type(discrete) + action_args(continuous),
+ # such as {'action_type_shape': torch.LongTensor([0]), 'action_args_shape': torch.FloatTensor([0.1, -0.27])}
+ action_shape.action_args_shape = squeeze(action_shape.action_args_shape)
+ action_shape.action_type_shape = squeeze(action_shape.action_type_shape)
+ actor_action_args = ReparameterizationHead(
+ actor_head_hidden_size,
+ action_shape.action_args_shape,
+ actor_head_layer_num,
+ sigma_type=sigma_type,
+ fixed_sigma_value=fixed_sigma_value,
+ activation=activation,
+ norm_type=norm_type,
+ bound_type=bound_type,
+ )
+ actor_action_type = DiscretePolicyHead(
+ actor_head_hidden_size,
+ action_shape.action_type_shape,
+ actor_head_layer_num,
+ activation=activation,
+ norm_type=norm_type,
+ )
+ self.actor_head = nn.ModuleList([actor_action_type, actor_action_args])
+
+ # group the encoder and head of each network; wrapped into nn.ModuleList below
+ if self.share_encoder:
+ self.actor = [self.encoder, self.actor_head]
+ self.critic = [self.encoder, self.critic_head]
+ else:
+ self.actor = [self.actor_encoder, self.actor_head]
+ self.critic = [self.critic_encoder, self.critic_head]
+ # Convenient for calling some apis (e.g. self.critic.parameters()),
+ # but may cause misunderstanding when `print(self)`
+ self.actor = nn.ModuleList(self.actor)
+ self.critic = nn.ModuleList(self.critic)
+
+ def forward(self, inputs: ttorch.Tensor, mode: str) -> ttorch.Tensor:
+ assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
+ return getattr(self, mode)(inputs)
+
+ def compute_actor(self, x: ttorch.Tensor) -> ttorch.Tensor:
+ if self.share_encoder:
+ x = self.encoder(x)
+ else:
+ x = self.actor_encoder(x)
+
+ if self.action_space == 'discrete':
+ return self.actor_head(x)
+ elif self.action_space == 'continuous':
+ x = self.actor_head(x) # mu, sigma
+ return ttorch.as_tensor(x)
+ elif self.action_space == 'hybrid':
+ action_type = self.actor_head[0](x)
+ action_args = self.actor_head[1](x)
+ return ttorch.as_tensor({'action_type': action_type, 'action_args': action_args})
+
+ def compute_critic(self, x: ttorch.Tensor) -> ttorch.Tensor:
+ if self.share_encoder:
+ x = self.encoder(x)
+ else:
+ x = self.critic_encoder(x)
+ x = self.critic_head(x)
+ return x
+
+ def compute_actor_critic(self, x: ttorch.Tensor) -> ttorch.Tensor:
+ if self.share_encoder:
+ actor_embedding = critic_embedding = self.encoder(x)
+ else:
+ actor_embedding = self.actor_encoder(x)
+ critic_embedding = self.critic_encoder(x)
+
+ value = self.critic_head(critic_embedding)
+
+ if self.action_space == 'discrete':
+ logit = self.actor_head(actor_embedding)
+ return ttorch.as_tensor({'logit': logit, 'value': value['pred']})
+ elif self.action_space == 'continuous':
+ x = self.actor_head(actor_embedding)
+ return ttorch.as_tensor({'logit': x, 'value': value['pred']})
+ elif self.action_space == 'hybrid':
+ action_type = self.actor_head[0](actor_embedding)
+ action_args = self.actor_head[1](actor_embedding)
+ return ttorch.as_tensor(
+ {
+ 'logit': {
+ 'action_type': action_type,
+ 'action_args': action_args
+ },
+ 'value': value['pred']
+ }
+ )
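A small sketch exercising ``PPOFModel`` on dummy data, assuming a flat 8-dimensional observation and a 4-way discrete action space (shapes chosen only for illustration):

    import torch
    from ding.bonus.model import PPOFModel

    model = PPOFModel(obs_shape=8, action_shape=4, action_space='discrete')
    obs = torch.randn(2, 8)                        # batch of 2 flat observations
    out = model(obs, mode='compute_actor_critic')
    print(out['logit'].shape)                      # torch.Size([2, 4]) action logits
    print(out['value'])                            # critic value prediction for each sample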
diff --git a/DI-engine/ding/bonus/pg.py b/DI-engine/ding/bonus/pg.py
new file mode 100644
index 0000000000000000000000000000000000000000..59c031d65de20ef7caf7eb3bfd6306b9dcf584e3
--- /dev/null
+++ b/DI-engine/ding/bonus/pg.py
@@ -0,0 +1,453 @@
+from typing import Optional, Union, List
+from ditk import logging
+from easydict import EasyDict
+import os
+import numpy as np
+import torch
+import treetensor.torch as ttorch
+from ding.framework import task, OnlineRLContext
+from ding.framework.middleware import CkptSaver, trainer, \
+ wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, \
+ montecarlo_return_estimator, final_ctx_saver, EpisodeCollector
+from ding.envs import BaseEnv
+from ding.envs import setup_ding_env_manager
+from ding.policy import PGPolicy
+from ding.utils import set_pkg_seed
+from ding.utils import get_env_fps, render
+from ding.config import save_config_py, compile_config
+from ding.model import PG
+from ding.bonus.common import TrainingReturn, EvalReturn
+from ding.config.example.PG import supported_env_cfg
+from ding.config.example.PG import supported_env
+
+
+class PGAgent:
+ """
+ Overview:
+ Class of agent for training, evaluation and deployment of Reinforcement learning algorithm Policy Gradient(PG).
+ For more information about the system design of RL agent, please refer to \
+ .
+ Interface:
+ ``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
+ """
+ supported_env_list = list(supported_env_cfg.keys())
+ """
+ Overview:
+ List of supported envs.
+ Examples:
+ >>> from ding.bonus.pg import PGAgent
+ >>> print(PGAgent.supported_env_list)
+ """
+
+ def __init__(
+ self,
+ env_id: str = None,
+ env: BaseEnv = None,
+ seed: int = 0,
+ exp_name: str = None,
+ model: Optional[torch.nn.Module] = None,
+ cfg: Optional[Union[EasyDict, dict]] = None,
+ policy_state_dict: str = None,
+ ) -> None:
+ """
+ Overview:
+ Initialize agent for PG algorithm.
+ Arguments:
+ - env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
+ If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
+ If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
+ ``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
+ - env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
+            If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified. \
+ ``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
+ If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
+ - seed (:obj:`int`): The random seed, which is set before running the program. \
+ Default to 0.
+ - exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
+ log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
+ - model (:obj:`torch.nn.Module`): The model of PG algorithm, which should be an instance of class \
+ :class:`ding.model.PG`. \
+ If not specified, a default model will be generated according to the configuration.
+            - cfg (:obj:`Union[EasyDict, dict]`): The configuration of PG algorithm, which is a dict. \
+ Default to None. If not specified, the default configuration will be used. \
+ The default configuration can be found in ``ding/config/example/PG/gym_lunarlander_v2.py``.
+            - policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
+ If specified, the policy will be loaded from this file. Default to None.
+
+ .. note::
+ An RL Agent Instance can be initialized in two basic ways. \
+ For example, we have an environment with id ``LunarLanderContinuous-v2`` registered in gym, \
+ and we want to train an agent with PG algorithm with default configuration. \
+ Then we can initialize the agent in the following ways:
+ >>> agent = PGAgent(env_id='LunarLanderContinuous-v2')
+            or, if we want to specify the env_id in the configuration:
+ >>> cfg = {'env': {'env_id': 'LunarLanderContinuous-v2'}, 'policy': ...... }
+ >>> agent = PGAgent(cfg=cfg)
+ There are also other arguments to specify the agent when initializing.
+ For example, if we want to specify the environment instance:
+ >>> env = CustomizedEnv('LunarLanderContinuous-v2')
+ >>> agent = PGAgent(cfg=cfg, env=env)
+ or, if we want to specify the model:
+ >>> model = PG(**cfg.policy.model)
+ >>> agent = PGAgent(cfg=cfg, model=model)
+ or, if we want to reload the policy from a saved policy state dict:
+ >>> agent = PGAgent(cfg=cfg, policy_state_dict='LunarLanderContinuous-v2.pth.tar')
+ Make sure that the configuration is consistent with the saved policy state dict.
+ """
+
+ assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
+
+ if cfg is not None and not isinstance(cfg, EasyDict):
+ cfg = EasyDict(cfg)
+
+ if env_id is not None:
+ assert env_id in PGAgent.supported_env_list, "Please use supported envs: {}".format(
+ PGAgent.supported_env_list
+ )
+ if cfg is None:
+ cfg = supported_env_cfg[env_id]
+ else:
+ assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
+ else:
+ assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
+ assert cfg.env.env_id in PGAgent.supported_env_list, "Please use supported envs: {}".format(
+ PGAgent.supported_env_list
+ )
+ default_policy_config = EasyDict({"policy": PGPolicy.default_config()})
+ default_policy_config.update(cfg)
+ cfg = default_policy_config
+
+ if exp_name is not None:
+ cfg.exp_name = exp_name
+ self.cfg = compile_config(cfg, policy=PGPolicy)
+ self.exp_name = self.cfg.exp_name
+ if env is None:
+ self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
+ else:
+ assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
+ self.env = env
+
+ logging.getLogger().setLevel(logging.INFO)
+ self.seed = seed
+ set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
+ if not os.path.exists(self.exp_name):
+ os.makedirs(self.exp_name)
+ save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
+ if model is None:
+ model = PG(**self.cfg.policy.model)
+ self.policy = PGPolicy(self.cfg.policy, model=model)
+ if policy_state_dict is not None:
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
+ self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
+
+ def train(
+ self,
+ step: int = int(1e7),
+ collector_env_num: int = None,
+ evaluator_env_num: int = None,
+ n_iter_save_ckpt: int = 1000,
+ context: Optional[str] = None,
+ debug: bool = False,
+ wandb_sweep: bool = False,
+ ) -> TrainingReturn:
+ """
+ Overview:
+ Train the agent with PG algorithm for ``step`` iterations with ``collector_env_num`` collector \
+ environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
+ recorded and saved by wandb.
+ Arguments:
+ - step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
+ - collector_env_num (:obj:`int`): The collector environment number. Default to None. \
+ If not specified, it will be set according to the configuration.
+ - evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
+ If not specified, it will be set according to the configuration.
+            - n_iter_save_ckpt (:obj:`int`): The checkpoint-saving frequency, in training iterations. \
+                Default to 1000.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ - wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
+ which is a hyper-parameter optimization process for seeking the best configurations. \
+ Default to False. If True, the wandb sweep id will be used as the experiment name.
+ Returns:
+ - (:obj:`TrainingReturn`): The training result, of which the attributions are:
+                - wandb_url (:obj:`str`): The Weights & Biases (wandb) project URL of the training experiment.
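+        Examples:
+            A minimal usage sketch; the step budget below is illustrative rather than a tuned value.
+            >>> agent = PGAgent(env_id='LunarLanderContinuous-v2')
+            >>> training_return = agent.train(step=int(1e6))
+            >>> print(training_return.wandb_url)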
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ logging.debug(self.policy._model)
+ # define env and policy
+ collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
+ evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
+ collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
+ evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
+
+ with task.start(ctx=OnlineRLContext()):
+ task.use(
+ interaction_evaluator(
+ self.cfg,
+ self.policy.eval_mode,
+ evaluator_env,
+ render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
+ )
+ )
+ task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
+ task.use(EpisodeCollector(self.cfg, self.policy.collect_mode, collector_env))
+ task.use(montecarlo_return_estimator(self.policy))
+ task.use(trainer(self.cfg, self.policy.learn_mode))
+ task.use(
+ wandb_online_logger(
+ metric_list=self.policy._monitor_vars_learn(),
+ model=self.policy._model,
+ anonymous=True,
+ project_name=self.exp_name,
+ wandb_sweep=wandb_sweep,
+ )
+ )
+ task.use(termination_checker(max_env_step=step))
+ task.use(final_ctx_saver(name=self.exp_name))
+ task.run()
+
+ return TrainingReturn(wandb_url=task.ctx.wandb_url)
+
+ def deploy(
+ self,
+ enable_save_replay: bool = False,
+ concatenate_all_replay: bool = False,
+ replay_save_path: str = None,
+ seed: Optional[Union[int, List]] = None,
+ debug: bool = False
+ ) -> EvalReturn:
+ """
+ Overview:
+ Deploy the agent with PG algorithm by interacting with the environment, during which the replay video \
+ can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
+ Arguments:
+ - enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
+ - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
+ Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
+ If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
+ the replay video of each episode will be saved separately.
+ - replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
+ If not specified, the video will be saved in ``exp_name/videos``.
+ - seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
+ Default to None. If not specified, ``self.seed`` will be used. \
+ If ``seed`` is an integer, the agent will be deployed once. \
+ If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ Returns:
+ - (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
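+        Examples:
+            A minimal usage sketch; it assumes an already trained agent (e.g. reloaded via ``policy_state_dict``).
+            >>> agent = PGAgent(env_id='LunarLanderContinuous-v2')
+            >>> result = agent.deploy(enable_save_replay=True)
+            >>> print(result.eval_value, result.eval_value_std)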
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ # define env and policy
+ env = self.env.clone(caller='evaluator')
+
+ if seed is not None and isinstance(seed, int):
+ seeds = [seed]
+ elif seed is not None and isinstance(seed, list):
+ seeds = seed
+ else:
+ seeds = [self.seed]
+
+ returns = []
+ images = []
+ if enable_save_replay:
+ replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
+ env.enable_save_replay(replay_path=replay_save_path)
+ else:
+            logging.warning('No video will be generated during deployment.')
+ if concatenate_all_replay:
+ logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
+ concatenate_all_replay = False
+
+ def single_env_forward_wrapper(forward_fn, cuda=True):
+
+ def _forward(obs):
+ # unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
+ obs = ttorch.as_tensor(obs).unsqueeze(0)
+ if cuda and torch.cuda.is_available():
+ obs = obs.cuda()
+ output = forward_fn(obs)
+ if self.policy._cfg.deterministic_eval:
+ if self.policy._cfg.action_space == 'discrete':
+ output['action'] = output['logit'].argmax(dim=-1)
+ elif self.policy._cfg.action_space == 'continuous':
+ output['action'] = output['logit']['mu']
+ else:
+ raise KeyError("invalid action_space: {}".format(self.policy._cfg.action_space))
+ else:
+ output['action'] = output['dist'].sample()
+ # squeeze means delete batch dim, i.e. (1, A) -> (A, )
+ action = output['action'].squeeze(0).detach().cpu().numpy()
+ return action
+
+ return _forward
+
+ forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
+
+ # reset first to make sure the env is in the initial state
+ # env will be reset again in the main loop
+ env.reset()
+
+ for seed in seeds:
+ env.seed(seed, dynamic_seed=False)
+ return_ = 0.
+ step = 0
+ obs = env.reset()
+ images.append(render(env)[None]) if concatenate_all_replay else None
+ while True:
+ action = forward_fn(obs)
+ obs, rew, done, info = env.step(action)
+ images.append(render(env)[None]) if concatenate_all_replay else None
+ return_ += rew
+ step += 1
+ if done:
+ break
+            logging.info(f'PG deploy is finished, final episode return with {step} steps is: {return_}')
+ returns.append(return_)
+
+ env.close()
+
+ if concatenate_all_replay:
+ images = np.concatenate(images, axis=0)
+ import imageio
+ imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
+
+ return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
+
+ def collect_data(
+ self,
+ env_num: int = 8,
+ save_data_path: Optional[str] = None,
+ n_sample: Optional[int] = None,
+ n_episode: Optional[int] = None,
+ context: Optional[str] = None,
+ debug: bool = False
+ ) -> None:
+ """
+ Overview:
+ Collect data with PG algorithm for ``n_episode`` episodes with ``env_num`` collector environments. \
+ The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
+ ``exp_name/demo_data``.
+ Arguments:
+ - env_num (:obj:`int`): The number of collector environments. Default to 8.
+ - save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
+ If not specified, the data will be saved in ``exp_name/demo_data``.
+ - n_sample (:obj:`int`): The number of samples to collect. Default to None. \
+ If not specified, ``n_episode`` must be specified.
+ - n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
+ If not specified, ``n_sample`` must be specified.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
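+        Examples:
+            A minimal usage sketch; the sample count below is illustrative only.
+            >>> agent = PGAgent(env_id='LunarLanderContinuous-v2')
+            >>> agent.collect_data(env_num=8, n_sample=1000)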
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ if n_episode is not None:
+ raise NotImplementedError
+ # define env and policy
+ env_num = env_num if env_num else self.cfg.env.collector_env_num
+ env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
+
+ if save_data_path is None:
+ save_data_path = os.path.join(self.exp_name, 'demo_data')
+
+ # main execution task
+ with task.start(ctx=OnlineRLContext()):
+ task.use(
+ StepCollector(
+ self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
+ )
+ )
+ task.use(offline_data_saver(save_data_path, data_type='hdf5'))
+ task.run(max_step=1)
+ logging.info(
+ f'PG collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
+ )
+
+ def batch_evaluate(
+ self,
+ env_num: int = 4,
+ n_evaluator_episode: int = 4,
+ context: Optional[str] = None,
+ debug: bool = False
+ ) -> EvalReturn:
+ """
+ Overview:
+ Evaluate the agent with PG algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
+ environments. The evaluation result will be returned.
+ The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
+ multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
+ will only create one evaluator environment to evaluate the agent and save the replay video.
+ Arguments:
+ - env_num (:obj:`int`): The number of evaluator environments. Default to 4.
+ - n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ Returns:
+ - (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
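+        Examples:
+            A minimal usage sketch; the values shown are simply the defaults, not tuned settings.
+            >>> agent = PGAgent(env_id='LunarLanderContinuous-v2')
+            >>> result = agent.batch_evaluate(env_num=4, n_evaluator_episode=4)
+            >>> print(result.eval_value, result.eval_value_std)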
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ # define env and policy
+ env_num = env_num if env_num else self.cfg.env.evaluator_env_num
+ env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
+
+ # reset first to make sure the env is in the initial state
+ # env will be reset again in the main loop
+ env.launch()
+ env.reset()
+
+ evaluate_cfg = self.cfg
+ evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
+
+ # main execution task
+ with task.start(ctx=OnlineRLContext()):
+ task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
+ task.run(max_step=1)
+
+ return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
+
+ @property
+ def best(self) -> 'PGAgent':
+ """
+ Overview:
+ Load the best model from the checkpoint directory, \
+            which is saved by default at ``exp_name/ckpt/eval.pth.tar``. \
+ The return value is the agent with the best model.
+ Returns:
+ - (:obj:`PGAgent`): The agent with the best model.
+ Examples:
+ >>> agent = PGAgent(env_id='LunarLanderContinuous-v2')
+ >>> agent.train()
+ >>> agent = agent.best
+
+ .. note::
+            The best model is the model with the highest evaluation return. If this property is accessed, \
+            the current model will be replaced by the best model.
+ """
+
+ best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
+ # Load best model if it exists
+ if os.path.exists(best_model_file_path):
+ policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
+ return self
diff --git a/DI-engine/ding/bonus/ppo_offpolicy.py b/DI-engine/ding/bonus/ppo_offpolicy.py
new file mode 100644
index 0000000000000000000000000000000000000000..546aecbd6d4c2028ed760b0787d9ae69fa54fc49
--- /dev/null
+++ b/DI-engine/ding/bonus/ppo_offpolicy.py
@@ -0,0 +1,471 @@
+from typing import Optional, Union, List
+from ditk import logging
+from easydict import EasyDict
+import os
+import numpy as np
+import torch
+import treetensor.torch as ttorch
+from ding.framework import task, OnlineRLContext
+from ding.framework.middleware import CkptSaver, final_ctx_saver, OffPolicyLearner, StepCollector, \
+ wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, gae_estimator
+from ding.envs import BaseEnv
+from ding.envs import setup_ding_env_manager
+from ding.policy import PPOOffPolicy
+from ding.utils import set_pkg_seed
+from ding.utils import get_env_fps, render
+from ding.config import save_config_py, compile_config
+from ding.model import VAC
+from ding.model import model_wrap
+from ding.data import DequeBuffer
+from ding.bonus.common import TrainingReturn, EvalReturn
+from ding.config.example.PPOOffPolicy import supported_env_cfg
+from ding.config.example.PPOOffPolicy import supported_env
+
+
+class PPOOffPolicyAgent:
+ """
+ Overview:
+ Class of agent for training, evaluation and deployment of Reinforcement learning algorithm \
+ Proximal Policy Optimization(PPO) in an off-policy style.
+ For more information about the system design of RL agent, please refer to \
+ .
+ Interface:
+ ``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
+ """
+ supported_env_list = list(supported_env_cfg.keys())
+ """
+ Overview:
+ List of supported envs.
+ Examples:
+ >>> from ding.bonus.ppo_offpolicy import PPOOffPolicyAgent
+ >>> print(PPOOffPolicyAgent.supported_env_list)
+ """
+
+ def __init__(
+ self,
+ env_id: str = None,
+ env: BaseEnv = None,
+ seed: int = 0,
+ exp_name: str = None,
+ model: Optional[torch.nn.Module] = None,
+ cfg: Optional[Union[EasyDict, dict]] = None,
+ policy_state_dict: str = None,
+ ) -> None:
+ """
+ Overview:
+ Initialize agent for PPO (offpolicy) algorithm.
+ Arguments:
+ - env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
+ If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
+ If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
+ ``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
+ - env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
+            If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified. \
+ ``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
+ If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
+ - seed (:obj:`int`): The random seed, which is set before running the program. \
+ Default to 0.
+ - exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
+ log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
+ - model (:obj:`torch.nn.Module`): The model of PPO (offpolicy) algorithm, \
+ which should be an instance of class :class:`ding.model.VAC`. \
+ If not specified, a default model will be generated according to the configuration.
+            - cfg (:obj:`Union[EasyDict, dict]`): The configuration of PPO (offpolicy) algorithm, which is a dict. \
+ Default to None. If not specified, the default configuration will be used. \
+                The default configuration can be found in ``ding/config/example/PPOOffPolicy/gym_lunarlander_v2.py``.
+            - policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
+ If specified, the policy will be loaded from this file. Default to None.
+
+ .. note::
+ An RL Agent Instance can be initialized in two basic ways. \
+ For example, we have an environment with id ``LunarLander-v2`` registered in gym, \
+ and we want to train an agent with PPO (offpolicy) algorithm with default configuration. \
+ Then we can initialize the agent in the following ways:
+ >>> agent = PPOOffPolicyAgent(env_id='LunarLander-v2')
+            or, if we want to specify the env_id in the configuration:
+ >>> cfg = {'env': {'env_id': 'LunarLander-v2'}, 'policy': ...... }
+ >>> agent = PPOOffPolicyAgent(cfg=cfg)
+ There are also other arguments to specify the agent when initializing.
+ For example, if we want to specify the environment instance:
+ >>> env = CustomizedEnv('LunarLander-v2')
+ >>> agent = PPOOffPolicyAgent(cfg=cfg, env=env)
+ or, if we want to specify the model:
+ >>> model = VAC(**cfg.policy.model)
+ >>> agent = PPOOffPolicyAgent(cfg=cfg, model=model)
+ or, if we want to reload the policy from a saved policy state dict:
+ >>> agent = PPOOffPolicyAgent(cfg=cfg, policy_state_dict='LunarLander-v2.pth.tar')
+ Make sure that the configuration is consistent with the saved policy state dict.
+ """
+
+ assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
+
+ if cfg is not None and not isinstance(cfg, EasyDict):
+ cfg = EasyDict(cfg)
+
+ if env_id is not None:
+ assert env_id in PPOOffPolicyAgent.supported_env_list, "Please use supported envs: {}".format(
+ PPOOffPolicyAgent.supported_env_list
+ )
+ if cfg is None:
+ cfg = supported_env_cfg[env_id]
+ else:
+ assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
+ else:
+ assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
+ assert cfg.env.env_id in PPOOffPolicyAgent.supported_env_list, "Please use supported envs: {}".format(
+ PPOOffPolicyAgent.supported_env_list
+ )
+ default_policy_config = EasyDict({"policy": PPOOffPolicy.default_config()})
+ default_policy_config.update(cfg)
+ cfg = default_policy_config
+
+ if exp_name is not None:
+ cfg.exp_name = exp_name
+ self.cfg = compile_config(cfg, policy=PPOOffPolicy)
+ self.exp_name = self.cfg.exp_name
+ if env is None:
+ self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
+ else:
+ assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
+ self.env = env
+
+ logging.getLogger().setLevel(logging.INFO)
+ self.seed = seed
+ set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
+ if not os.path.exists(self.exp_name):
+ os.makedirs(self.exp_name)
+ save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
+ if model is None:
+ model = VAC(**self.cfg.policy.model)
+ self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
+ self.policy = PPOOffPolicy(self.cfg.policy, model=model)
+ if policy_state_dict is not None:
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
+ self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
+
+ def train(
+ self,
+ step: int = int(1e7),
+ collector_env_num: int = None,
+ evaluator_env_num: int = None,
+ n_iter_save_ckpt: int = 1000,
+ context: Optional[str] = None,
+ debug: bool = False,
+ wandb_sweep: bool = False,
+ ) -> TrainingReturn:
+ """
+ Overview:
+ Train the agent with PPO (offpolicy) algorithm for ``step`` iterations with ``collector_env_num`` \
+ collector environments and ``evaluator_env_num`` evaluator environments. \
+ Information during training will be recorded and saved by wandb.
+ Arguments:
+ - step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
+ - collector_env_num (:obj:`int`): The collector environment number. Default to None. \
+ If not specified, it will be set according to the configuration.
+ - evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
+ If not specified, it will be set according to the configuration.
+            - n_iter_save_ckpt (:obj:`int`): The checkpoint-saving frequency, in training iterations. \
+                Default to 1000.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ - wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
+ which is a hyper-parameter optimization process for seeking the best configurations. \
+ Default to False. If True, the wandb sweep id will be used as the experiment name.
+ Returns:
+ - (:obj:`TrainingReturn`): The training result, of which the attributions are:
+                - wandb_url (:obj:`str`): The Weights & Biases (wandb) project URL of the training experiment.
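+        Examples:
+            A minimal usage sketch; the step budget below is illustrative rather than a tuned value.
+            >>> agent = PPOOffPolicyAgent(env_id='LunarLander-v2')
+            >>> training_return = agent.train(step=int(1e6))
+            >>> print(training_return.wandb_url)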
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ logging.debug(self.policy._model)
+ # define env and policy
+ collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
+ evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
+ collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
+ evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
+
+ with task.start(ctx=OnlineRLContext()):
+ task.use(
+ interaction_evaluator(
+ self.cfg,
+ self.policy.eval_mode,
+ evaluator_env,
+ render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
+ )
+ )
+ task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
+ task.use(
+ StepCollector(
+ self.cfg,
+ self.policy.collect_mode,
+ collector_env,
+ random_collect_size=self.cfg.policy.random_collect_size
+ if hasattr(self.cfg.policy, 'random_collect_size') else 0,
+ )
+ )
+ task.use(gae_estimator(self.cfg, self.policy.collect_mode, self.buffer_))
+ task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
+ task.use(
+ wandb_online_logger(
+ cfg=self.cfg.wandb_logger,
+ exp_config=self.cfg,
+ metric_list=self.policy._monitor_vars_learn(),
+ model=self.policy._model,
+ anonymous=True,
+ project_name=self.exp_name,
+ wandb_sweep=wandb_sweep,
+ )
+ )
+ task.use(termination_checker(max_env_step=step))
+ task.use(final_ctx_saver(name=self.exp_name))
+ task.run()
+
+ return TrainingReturn(wandb_url=task.ctx.wandb_url)
+
+ def deploy(
+ self,
+ enable_save_replay: bool = False,
+ concatenate_all_replay: bool = False,
+ replay_save_path: str = None,
+ seed: Optional[Union[int, List]] = None,
+ debug: bool = False
+ ) -> EvalReturn:
+ """
+ Overview:
+ Deploy the agent with PPO (offpolicy) algorithm by interacting with the environment, \
+ during which the replay video can be saved if ``enable_save_replay`` is True. \
+ The evaluation result will be returned.
+ Arguments:
+ - enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
+ - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
+ Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
+ If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
+ the replay video of each episode will be saved separately.
+ - replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
+ If not specified, the video will be saved in ``exp_name/videos``.
+ - seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
+ Default to None. If not specified, ``self.seed`` will be used. \
+ If ``seed`` is an integer, the agent will be deployed once. \
+ If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ Returns:
+ - (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
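+        Examples:
+            A minimal usage sketch; it assumes an already trained agent (e.g. reloaded via ``policy_state_dict``).
+            >>> agent = PPOOffPolicyAgent(env_id='LunarLander-v2')
+            >>> result = agent.deploy(enable_save_replay=True)
+            >>> print(result.eval_value, result.eval_value_std)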
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ # define env and policy
+ env = self.env.clone(caller='evaluator')
+
+ if seed is not None and isinstance(seed, int):
+ seeds = [seed]
+ elif seed is not None and isinstance(seed, list):
+ seeds = seed
+ else:
+ seeds = [self.seed]
+
+ returns = []
+ images = []
+ if enable_save_replay:
+ replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
+ env.enable_save_replay(replay_path=replay_save_path)
+ else:
+            logging.warning('No video will be generated during deployment.')
+ if concatenate_all_replay:
+ logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
+ concatenate_all_replay = False
+
+ def single_env_forward_wrapper(forward_fn, cuda=True):
+
+ if self.cfg.policy.action_space == 'discrete':
+ forward_fn = model_wrap(forward_fn, wrapper_name='argmax_sample').forward
+ elif self.cfg.policy.action_space == 'continuous':
+ forward_fn = model_wrap(forward_fn, wrapper_name='deterministic_sample').forward
+ elif self.cfg.policy.action_space == 'hybrid':
+ forward_fn = model_wrap(forward_fn, wrapper_name='hybrid_deterministic_argmax_sample').forward
+ elif self.cfg.policy.action_space == 'general':
+ forward_fn = model_wrap(forward_fn, wrapper_name='base').forward
+ else:
+ raise NotImplementedError
+
+ def _forward(obs):
+ # unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
+ obs = ttorch.as_tensor(obs).unsqueeze(0)
+ if cuda and torch.cuda.is_available():
+ obs = obs.cuda()
+ action = forward_fn(obs, mode='compute_actor')["action"]
+ # squeeze means delete batch dim, i.e. (1, A) -> (A, )
+ action = action.squeeze(0).detach().cpu().numpy()
+ return action
+
+ return _forward
+
+ forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
+
+ # reset first to make sure the env is in the initial state
+ # env will be reset again in the main loop
+ env.reset()
+
+ for seed in seeds:
+ env.seed(seed, dynamic_seed=False)
+ return_ = 0.
+ step = 0
+ obs = env.reset()
+ images.append(render(env)[None]) if concatenate_all_replay else None
+ while True:
+ action = forward_fn(obs)
+ obs, rew, done, info = env.step(action)
+ images.append(render(env)[None]) if concatenate_all_replay else None
+ return_ += rew
+ step += 1
+ if done:
+ break
+ logging.info(f'PPO (offpolicy) deploy is finished, final episode return with {step} steps is: {return_}')
+ returns.append(return_)
+
+ env.close()
+
+ if concatenate_all_replay:
+ images = np.concatenate(images, axis=0)
+ import imageio
+ imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
+
+ return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
+
+ def collect_data(
+ self,
+ env_num: int = 8,
+ save_data_path: Optional[str] = None,
+ n_sample: Optional[int] = None,
+ n_episode: Optional[int] = None,
+ context: Optional[str] = None,
+ debug: bool = False
+ ) -> None:
+ """
+ Overview:
+ Collect data with PPO (offpolicy) algorithm for ``n_episode`` episodes \
+ with ``env_num`` collector environments. \
+ The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
+ ``exp_name/demo_data``.
+ Arguments:
+ - env_num (:obj:`int`): The number of collector environments. Default to 8.
+ - save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
+ If not specified, the data will be saved in ``exp_name/demo_data``.
+ - n_sample (:obj:`int`): The number of samples to collect. Default to None. \
+ If not specified, ``n_episode`` must be specified.
+ - n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
+ If not specified, ``n_sample`` must be specified.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ if n_episode is not None:
+ raise NotImplementedError
+ # define env and policy
+ env_num = env_num if env_num else self.cfg.env.collector_env_num
+ env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
+
+ if save_data_path is None:
+ save_data_path = os.path.join(self.exp_name, 'demo_data')
+
+ # main execution task
+ with task.start(ctx=OnlineRLContext()):
+ task.use(
+ StepCollector(
+ self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
+ )
+ )
+ task.use(offline_data_saver(save_data_path, data_type='hdf5'))
+ task.run(max_step=1)
+ logging.info(
+            f'PPOOffPolicy collecting is finished, more than {n_sample} samples '
+            f'are collected and saved in `{save_data_path}`'
+ )
+
+ def batch_evaluate(
+ self,
+ env_num: int = 4,
+ n_evaluator_episode: int = 4,
+ context: Optional[str] = None,
+ debug: bool = False
+ ) -> EvalReturn:
+ """
+ Overview:
+ Evaluate the agent with PPO (offpolicy) algorithm for ``n_evaluator_episode`` episodes \
+ with ``env_num`` evaluator environments. The evaluation result will be returned.
+ The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
+ multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
+ will only create one evaluator environment to evaluate the agent and save the replay video.
+ Arguments:
+ - env_num (:obj:`int`): The number of evaluator environments. Default to 4.
+ - n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ Returns:
+ - (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
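+        Examples:
+            A minimal usage sketch using the default evaluation settings.
+            >>> agent = PPOOffPolicyAgent(env_id='LunarLander-v2')
+            >>> result = agent.batch_evaluate()
+            >>> print(result.eval_value, result.eval_value_std)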
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ # define env and policy
+ env_num = env_num if env_num else self.cfg.env.evaluator_env_num
+ env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
+
+ # reset first to make sure the env is in the initial state
+ # env will be reset again in the main loop
+ env.launch()
+ env.reset()
+
+ evaluate_cfg = self.cfg
+ evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
+
+ # main execution task
+ with task.start(ctx=OnlineRLContext()):
+ task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
+ task.run(max_step=1)
+
+ return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
+
+ @property
+ def best(self) -> 'PPOOffPolicyAgent':
+ """
+ Overview:
+ Load the best model from the checkpoint directory, \
+            which is saved by default at ``exp_name/ckpt/eval.pth.tar``. \
+ The return value is the agent with the best model.
+ Returns:
+ - (:obj:`PPOOffPolicyAgent`): The agent with the best model.
+ Examples:
+ >>> agent = PPOOffPolicyAgent(env_id='LunarLander-v2')
+ >>> agent.train()
+ >>> agent.best
+
+ .. note::
+            The best model is the model with the highest evaluation return. If this property is accessed, \
+            the current model will be replaced by the best model.
+ """
+
+ best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
+ # Load best model if it exists
+ if os.path.exists(best_model_file_path):
+ policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
+ return self
diff --git a/DI-engine/ding/bonus/ppof.py b/DI-engine/ding/bonus/ppof.py
new file mode 100644
index 0000000000000000000000000000000000000000..88d0b43e1efc0fffddb9092f6cfdc23ef3e10e4d
--- /dev/null
+++ b/DI-engine/ding/bonus/ppof.py
@@ -0,0 +1,509 @@
+from typing import Optional, Union, List
+from ditk import logging
+from easydict import EasyDict
+from functools import partial
+import os
+import gym
+import gymnasium
+import numpy as np
+import torch
+from ding.framework import task, OnlineRLContext
+from ding.framework.middleware import interaction_evaluator_ttorch, PPOFStepCollector, multistep_trainer, CkptSaver, \
+ wandb_online_logger, offline_data_saver, termination_checker, ppof_adv_estimator
+from ding.envs import BaseEnv, BaseEnvManagerV2, SubprocessEnvManagerV2
+from ding.policy import PPOFPolicy, single_env_forward_wrapper_ttorch
+from ding.utils import set_pkg_seed
+from ding.utils import get_env_fps, render
+from ding.config import save_config_py
+from .model import PPOFModel
+from .config import get_instance_config, get_instance_env, get_hybrid_shape
+from ding.bonus.common import TrainingReturn, EvalReturn
+
+
+class PPOF:
+ """
+ Overview:
+ Class of agent for training, evaluation and deployment of Reinforcement learning algorithm \
+ Proximal Policy Optimization(PPO).
+ For more information about the system design of RL agent, please refer to \
+ .
+ Interface:
+ ``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
+ """
+
+ supported_env_list = [
+ # common
+ 'LunarLander-v2',
+ 'LunarLanderContinuous-v2',
+ 'BipedalWalker-v3',
+ 'Pendulum-v1',
+ 'acrobot',
+ # ch2: action
+ 'rocket_landing',
+ 'drone_fly',
+ 'hybrid_moving',
+ # ch3: obs
+ 'evogym_carrier',
+ 'mario',
+ 'di_sheep',
+ 'procgen_bigfish',
+ # ch4: reward
+ 'minigrid_fourroom',
+ 'metadrive',
+ # atari
+ 'BowlingNoFrameskip-v4',
+ 'BreakoutNoFrameskip-v4',
+        'GopherNoFrameskip-v4',
+ 'KangarooNoFrameskip-v4',
+ 'PongNoFrameskip-v4',
+ 'QbertNoFrameskip-v4',
+ 'SpaceInvadersNoFrameskip-v4',
+ # mujoco
+ 'Hopper-v3',
+ 'HalfCheetah-v3',
+ 'Walker2d-v3',
+ ]
+ """
+ Overview:
+ List of supported envs.
+ Examples:
+ >>> from ding.bonus.ppof import PPOF
+ >>> print(PPOF.supported_env_list)
+ """
+
+ def __init__(
+ self,
+ env_id: str = None,
+ env: BaseEnv = None,
+ seed: int = 0,
+ exp_name: str = None,
+ model: Optional[torch.nn.Module] = None,
+ cfg: Optional[Union[EasyDict, dict]] = None,
+ policy_state_dict: str = None
+ ) -> None:
+ """
+ Overview:
+ Initialize agent for PPO algorithm.
+ Arguments:
+ - env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
+ If ``env_id`` is not specified, ``env_id`` in ``cfg`` must be specified. \
+ If ``env_id`` is specified, ``env_id`` in ``cfg`` will be ignored. \
+ ``env_id`` should be one of the supported envs, which can be found in ``PPOF.supported_env_list``.
+ - env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
+ If ``env`` is not specified, ``env_id`` or ``cfg.env_id`` must be specified. \
+ ``env_id`` or ``cfg.env_id`` will be used to create environment instance. \
+ If ``env`` is specified, ``env_id`` and ``cfg.env_id`` will be ignored.
+ - seed (:obj:`int`): The random seed, which is set before running the program. \
+ Default to 0.
+ - exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
+ log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
+ - model (:obj:`torch.nn.Module`): The model of PPO algorithm, which should be an instance of class \
+ ``ding.model.PPOFModel``. \
+ If not specified, a default model will be generated according to the configuration.
+ - cfg (:obj:`Union[EasyDict, dict]`): The configuration of PPO algorithm, which is a dict. \
+ Default to None. If not specified, the default configuration will be used.
+            - policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
+ If specified, the policy will be loaded from this file. Default to None.
+
+ .. note::
+ An RL Agent Instance can be initialized in two basic ways. \
+ For example, we have an environment with id ``LunarLander-v2`` registered in gym, \
+ and we want to train an agent with PPO algorithm with default configuration. \
+ Then we can initialize the agent in the following ways:
+ >>> agent = PPOF(env_id='LunarLander-v2')
+            or, if we want to specify the env_id in the configuration:
+ >>> cfg = {'env': {'env_id': 'LunarLander-v2'}, 'policy': ...... }
+ >>> agent = PPOF(cfg=cfg)
+ There are also other arguments to specify the agent when initializing.
+ For example, if we want to specify the environment instance:
+ >>> env = CustomizedEnv('LunarLander-v2')
+ >>> agent = PPOF(cfg=cfg, env=env)
+ or, if we want to specify the model:
+ >>> model = VAC(**cfg.policy.model)
+ >>> agent = PPOF(cfg=cfg, model=model)
+ or, if we want to reload the policy from a saved policy state dict:
+ >>> agent = PPOF(cfg=cfg, policy_state_dict='LunarLander-v2.pth.tar')
+ Make sure that the configuration is consistent with the saved policy state dict.
+ """
+
+ assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
+
+ if cfg is not None and not isinstance(cfg, EasyDict):
+ cfg = EasyDict(cfg)
+
+ if env_id is not None:
+ assert env_id in PPOF.supported_env_list, "Please use supported envs: {}".format(PPOF.supported_env_list)
+ if cfg is None:
+ cfg = get_instance_config(env_id, algorithm="PPOF")
+
+ if not hasattr(cfg, "env_id"):
+ cfg.env_id = env_id
+ assert cfg.env_id == env_id, "env_id in cfg should be the same as env_id in args."
+ else:
+ assert hasattr(cfg, "env_id"), "Please specify env_id in cfg."
+ assert cfg.env_id in PPOF.supported_env_list, "Please use supported envs: {}".format(
+ PPOF.supported_env_list
+ )
+
+ if exp_name is not None:
+ cfg.exp_name = exp_name
+ elif not hasattr(cfg, "exp_name"):
+ cfg.exp_name = "{}-{}".format(cfg.env_id, "PPO")
+ self.cfg = cfg
+ self.exp_name = self.cfg.exp_name
+
+ if env is None:
+ self.env = get_instance_env(self.cfg.env_id)
+ else:
+ self.env = env
+
+ logging.getLogger().setLevel(logging.INFO)
+ self.seed = seed
+ set_pkg_seed(self.seed, use_cuda=self.cfg.cuda)
+
+ if not os.path.exists(self.exp_name):
+ os.makedirs(self.exp_name)
+ save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
+
+ action_space = self.env.action_space
+ if isinstance(action_space, (gym.spaces.Discrete, gymnasium.spaces.Discrete)):
+ action_shape = int(action_space.n)
+ elif isinstance(action_space, (gym.spaces.Tuple, gymnasium.spaces.Tuple)):
+ action_shape = get_hybrid_shape(action_space)
+ else:
+ action_shape = action_space.shape
+
+        # The following value normalization options are supported currently
+ assert self.cfg.value_norm in ['popart', 'value_rescale', 'symlog', 'baseline']
+ if model is None:
+ if self.cfg.value_norm != 'popart':
+ model = PPOFModel(
+ self.env.observation_space.shape,
+ action_shape,
+ action_space=self.cfg.action_space,
+ **self.cfg.model
+ )
+ else:
+ model = PPOFModel(
+ self.env.observation_space.shape,
+ action_shape,
+ action_space=self.cfg.action_space,
+ popart_head=True,
+ **self.cfg.model
+ )
+ self.policy = PPOFPolicy(self.cfg, model=model)
+ if policy_state_dict is not None:
+ self.policy.load_state_dict(policy_state_dict)
+ self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
+
+ def train(
+ self,
+ step: int = int(1e7),
+ collector_env_num: int = 4,
+ evaluator_env_num: int = 4,
+ n_iter_log_show: int = 500,
+ n_iter_save_ckpt: int = 1000,
+ context: Optional[str] = None,
+ reward_model: Optional[str] = None,
+ debug: bool = False,
+ wandb_sweep: bool = False,
+ ) -> TrainingReturn:
+ """
+ Overview:
+ Train the agent with PPO algorithm for ``step`` iterations with ``collector_env_num`` collector \
+ environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
+ recorded and saved by wandb.
+ Arguments:
+ - step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
+ - collector_env_num (:obj:`int`): The number of collector environments. Default to 4.
+ - evaluator_env_num (:obj:`int`): The number of evaluator environments. Default to 4.
+            - n_iter_log_show (:obj:`int`): The logging frequency, in training iterations. Default to 500.
+            - n_iter_save_ckpt (:obj:`int`): The checkpoint-saving frequency, in training iterations. \
+                Default to 1000.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - reward_model (:obj:`str`): The reward model name. Default to None. This argument is not supported yet.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ - wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
+ which is a hyper-parameter optimization process for seeking the best configurations. \
+ Default to False. If True, the wandb sweep id will be used as the experiment name.
+ Returns:
+ - (:obj:`TrainingReturn`): The training result, of which the attributions are:
+                - wandb_url (:obj:`str`): The Weights & Biases (wandb) project URL of the training experiment.
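+        Examples:
+            A minimal usage sketch; the step budget below is illustrative rather than a tuned value.
+            >>> agent = PPOF(env_id='LunarLander-v2')
+            >>> training_return = agent.train(step=int(1e6))
+            >>> print(training_return.wandb_url)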
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ logging.debug(self.policy._model)
+ # define env and policy
+ collector_env = self._setup_env_manager(collector_env_num, context, debug, 'collector')
+ evaluator_env = self._setup_env_manager(evaluator_env_num, context, debug, 'evaluator')
+
+ if reward_model is not None:
+ # self.reward_model = create_reward_model(reward_model, self.cfg.reward_model)
+ pass
+
+ with task.start(ctx=OnlineRLContext()):
+ task.use(interaction_evaluator_ttorch(self.seed, self.policy, evaluator_env))
+ task.use(CkptSaver(self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
+ task.use(PPOFStepCollector(self.seed, self.policy, collector_env, self.cfg.n_sample))
+ task.use(ppof_adv_estimator(self.policy))
+ task.use(multistep_trainer(self.policy, log_freq=n_iter_log_show))
+ task.use(
+ wandb_online_logger(
+ metric_list=self.policy.monitor_vars(),
+ model=self.policy._model,
+ anonymous=True,
+ project_name=self.exp_name,
+ wandb_sweep=wandb_sweep,
+ )
+ )
+ task.use(termination_checker(max_env_step=step))
+ task.run()
+
+ return TrainingReturn(wandb_url=task.ctx.wandb_url)
+
+ def deploy(
+ self,
+ enable_save_replay: bool = False,
+ concatenate_all_replay: bool = False,
+ replay_save_path: str = None,
+ seed: Optional[Union[int, List]] = None,
+ debug: bool = False
+ ) -> EvalReturn:
+ """
+ Overview:
+ Deploy the agent with PPO algorithm by interacting with the environment, during which the replay video \
+ can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
+ Arguments:
+ - enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
+ - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
+ Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
+ If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
+ the replay video of each episode will be saved separately.
+ - replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
+ If not specified, the video will be saved in ``exp_name/videos``.
+ - seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
+ Default to None. If not specified, ``self.seed`` will be used. \
+ If ``seed`` is an integer, the agent will be deployed once. \
+ If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ Returns:
+ - (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
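+        Examples:
+            A minimal usage sketch; it assumes an already trained agent (e.g. reloaded via ``policy_state_dict``).
+            >>> agent = PPOF(env_id='LunarLander-v2')
+            >>> result = agent.deploy(enable_save_replay=True)
+            >>> print(result.eval_value, result.eval_value_std)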
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ # define env and policy
+ env = self.env.clone(caller='evaluator')
+
+ if seed is not None and isinstance(seed, int):
+ seeds = [seed]
+ elif seed is not None and isinstance(seed, list):
+ seeds = seed
+ else:
+ seeds = [self.seed]
+
+ returns = []
+ images = []
+ if enable_save_replay:
+ replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
+ env.enable_save_replay(replay_path=replay_save_path)
+ else:
+            logging.warning('No video will be generated during deployment.')
+ if concatenate_all_replay:
+ logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
+ concatenate_all_replay = False
+
+ forward_fn = single_env_forward_wrapper_ttorch(self.policy.eval, self.cfg.cuda)
+
+ # reset first to make sure the env is in the initial state
+ # env will be reset again in the main loop
+ env.reset()
+
+ for seed in seeds:
+ env.seed(seed, dynamic_seed=False)
+ return_ = 0.
+ step = 0
+ obs = env.reset()
+ images.append(render(env)[None]) if concatenate_all_replay else None
+ while True:
+ action = forward_fn(obs)
+ obs, rew, done, info = env.step(action)
+ images.append(render(env)[None]) if concatenate_all_replay else None
+ return_ += rew
+ step += 1
+ if done:
+ break
+            logging.info(f'PPOF deploy is finished, final episode return with {step} steps is: {return_}')
+ returns.append(return_)
+
+ env.close()
+
+ if concatenate_all_replay:
+ images = np.concatenate(images, axis=0)
+ import imageio
+ imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
+
+ return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
+
+ def collect_data(
+ self,
+ env_num: int = 8,
+ save_data_path: Optional[str] = None,
+ n_sample: Optional[int] = None,
+ n_episode: Optional[int] = None,
+ context: Optional[str] = None,
+ debug: bool = False
+ ) -> None:
+ """
+ Overview:
+ Collect data with PPO algorithm for ``n_episode`` episodes with ``env_num`` collector environments. \
+ The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
+ ``exp_name/demo_data``.
+ Arguments:
+ - env_num (:obj:`int`): The number of collector environments. Default to 8.
+ - save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
+ If not specified, the data will be saved in ``exp_name/demo_data``.
+ - n_sample (:obj:`int`): The number of samples to collect. Default to None. \
+ If not specified, ``n_episode`` must be specified.
+ - n_episode (:obj:`int`): The number of episodes to collect. Default to None. \
+ If not specified, ``n_sample`` must be specified.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ if n_episode is not None:
+ raise NotImplementedError
+ # define env and policy
+ env = self._setup_env_manager(env_num, context, debug, 'collector')
+ if save_data_path is None:
+ save_data_path = os.path.join(self.exp_name, 'demo_data')
+
+ # main execution task
+ with task.start(ctx=OnlineRLContext()):
+ task.use(PPOFStepCollector(self.seed, self.policy, env, n_sample))
+ task.use(offline_data_saver(save_data_path, data_type='hdf5'))
+ task.run(max_step=1)
+ logging.info(
+ f'PPOF collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
+ )
+
+ def batch_evaluate(
+ self,
+ env_num: int = 4,
+ n_evaluator_episode: int = 4,
+ context: Optional[str] = None,
+ debug: bool = False,
+ ) -> EvalReturn:
+ """
+ Overview:
+ Evaluate the agent with PPO algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
+ environments. The evaluation result will be returned.
+ The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
+ multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
+ will only create one evaluator environment to evaluate the agent and save the replay video.
+ Arguments:
+ - env_num (:obj:`int`): The number of evaluator environments. Default to 4.
+ - n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ Returns:
+ - (:obj:`EvalReturn`): The evaluation result, of which the attributions are:
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
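+        Examples:
+            A minimal usage sketch using the default evaluation settings.
+            >>> agent = PPOF(env_id='LunarLander-v2')
+            >>> result = agent.batch_evaluate(env_num=4, n_evaluator_episode=4)
+            >>> print(result.eval_value, result.eval_value_std)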
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ # define env and policy
+ env = self._setup_env_manager(env_num, context, debug, 'evaluator')
+
+ # reset first to make sure the env is in the initial state
+ # env will be reset again in the main loop
+ env.launch()
+ env.reset()
+
+ # main execution task
+ with task.start(ctx=OnlineRLContext()):
+ task.use(interaction_evaluator_ttorch(
+ self.seed,
+ self.policy,
+ env,
+ n_evaluator_episode,
+ ))
+ task.run(max_step=1)
+
+ return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
+
+ def _setup_env_manager(
+ self,
+ env_num: int,
+ context: Optional[str] = None,
+ debug: bool = False,
+ caller: str = 'collector'
+ ) -> BaseEnvManagerV2:
+ """
+ Overview:
+ Setup the environment manager. The environment manager is used to manage multiple environments.
+ Arguments:
+ - env_num (:obj:`int`): The number of environments.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ - caller (:obj:`str`): The caller of the environment manager. Default to 'collector'.
+ Returns:
+ - (:obj:`BaseEnvManagerV2`): The environment manager.
+ """
+ assert caller in ['evaluator', 'collector']
+ if debug:
+ env_cls = BaseEnvManagerV2
+ manager_cfg = env_cls.default_config()
+ else:
+ env_cls = SubprocessEnvManagerV2
+ manager_cfg = env_cls.default_config()
+ if context is not None:
+ manager_cfg.context = context
+ return env_cls([partial(self.env.clone, caller) for _ in range(env_num)], manager_cfg)
+
+ @property
+ def best(self) -> 'PPOF':
+ """
+ Overview:
+            Load the best model from the checkpoint directory, \
+            which by default is the file ``exp_name/ckpt/eval.pth.tar``. \
+            The return value is the agent with the best model loaded.
+ Returns:
+ - (:obj:`PPOF`): The agent with the best model.
+ Examples:
+ >>> agent = PPOF(env_id='LunarLander-v2')
+ >>> agent.train()
+            >>> agent = agent.best
+
+ .. note::
+            The best model is the model with the highest evaluation return. When this property is accessed, the \
+            current model will be replaced by the best model if the corresponding checkpoint exists.
+ """
+
+ best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
+ # Load best model if it exists
+ if os.path.exists(best_model_file_path):
+ policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
+ return self
diff --git a/DI-engine/ding/bonus/sac.py b/DI-engine/ding/bonus/sac.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb6046476cf026f9094265bd63b1f1100f01af8f
--- /dev/null
+++ b/DI-engine/ding/bonus/sac.py
@@ -0,0 +1,457 @@
+from typing import Optional, Union, List
+from ditk import logging
+from easydict import EasyDict
+import os
+import numpy as np
+import torch
+import treetensor.torch as ttorch
+from ding.framework import task, OnlineRLContext
+from ding.framework.middleware import CkptSaver, \
+ wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, data_pusher, \
+ OffPolicyLearner, final_ctx_saver
+from ding.envs import BaseEnv
+from ding.envs import setup_ding_env_manager
+from ding.policy import SACPolicy
+from ding.utils import set_pkg_seed
+from ding.utils import get_env_fps, render
+from ding.config import save_config_py, compile_config
+from ding.model import ContinuousQAC
+from ding.model import model_wrap
+from ding.data import DequeBuffer
+from ding.bonus.common import TrainingReturn, EvalReturn
+from ding.config.example.SAC import supported_env_cfg
+from ding.config.example.SAC import supported_env
+
+
+class SACAgent:
+ """
+ Overview:
+        Class of agent for training, evaluation and deployment of the reinforcement learning algorithm \
+        Soft Actor-Critic (SAC).
+ For more information about the system design of RL agent, please refer to \
+ .
+ Interface:
+ ``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
+ """
+ supported_env_list = list(supported_env_cfg.keys())
+ """
+ Overview:
+ List of supported envs.
+ Examples:
+ >>> from ding.bonus.sac import SACAgent
+ >>> print(SACAgent.supported_env_list)
+ """
+
+ def __init__(
+ self,
+ env_id: str = None,
+ env: BaseEnv = None,
+ seed: int = 0,
+ exp_name: str = None,
+ model: Optional[torch.nn.Module] = None,
+ cfg: Optional[Union[EasyDict, dict]] = None,
+ policy_state_dict: str = None,
+ ) -> None:
+ """
+ Overview:
+ Initialize agent for SAC algorithm.
+ Arguments:
+ - env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
+ If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
+ If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
+ ``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
+ - env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
+                If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified. \
+ ``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
+ If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
+ - seed (:obj:`int`): The random seed, which is set before running the program. \
+ Default to 0.
+ - exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
+ log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
+ - model (:obj:`torch.nn.Module`): The model of SAC algorithm, which should be an instance of class \
+ :class:`ding.model.ContinuousQAC`. \
+ If not specified, a default model will be generated according to the configuration.
+            - cfg (:obj:`Union[EasyDict, dict]`): The configuration of SAC algorithm, which is a dict. \
+ Default to None. If not specified, the default configuration will be used. \
+ The default configuration can be found in ``ding/config/example/SAC/gym_lunarlander_v2.py``.
+            - policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
+ If specified, the policy will be loaded from this file. Default to None.
+
+ .. note::
+ An RL Agent Instance can be initialized in two basic ways. \
+ For example, we have an environment with id ``LunarLanderContinuous-v2`` registered in gym, \
+ and we want to train an agent with SAC algorithm with default configuration. \
+ Then we can initialize the agent in the following ways:
+ >>> agent = SACAgent(env_id='LunarLanderContinuous-v2')
+            or, if we want to specify the ``env_id`` in the configuration:
+ >>> cfg = {'env': {'env_id': 'LunarLanderContinuous-v2'}, 'policy': ...... }
+ >>> agent = SACAgent(cfg=cfg)
+ There are also other arguments to specify the agent when initializing.
+ For example, if we want to specify the environment instance:
+ >>> env = CustomizedEnv('LunarLanderContinuous-v2')
+ >>> agent = SACAgent(cfg=cfg, env=env)
+ or, if we want to specify the model:
+ >>> model = ContinuousQAC(**cfg.policy.model)
+ >>> agent = SACAgent(cfg=cfg, model=model)
+ or, if we want to reload the policy from a saved policy state dict:
+ >>> agent = SACAgent(cfg=cfg, policy_state_dict='LunarLanderContinuous-v2.pth.tar')
+ Make sure that the configuration is consistent with the saved policy state dict.
+ """
+
+ assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
+
+ if cfg is not None and not isinstance(cfg, EasyDict):
+ cfg = EasyDict(cfg)
+
+ if env_id is not None:
+ assert env_id in SACAgent.supported_env_list, "Please use supported envs: {}".format(
+ SACAgent.supported_env_list
+ )
+ if cfg is None:
+ cfg = supported_env_cfg[env_id]
+ else:
+ assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
+ else:
+ assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
+ assert cfg.env.env_id in SACAgent.supported_env_list, "Please use supported envs: {}".format(
+ SACAgent.supported_env_list
+ )
+ default_policy_config = EasyDict({"policy": SACPolicy.default_config()})
+ default_policy_config.update(cfg)
+ cfg = default_policy_config
+
+ if exp_name is not None:
+ cfg.exp_name = exp_name
+ self.cfg = compile_config(cfg, policy=SACPolicy)
+ self.exp_name = self.cfg.exp_name
+ if env is None:
+ self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
+ else:
+ assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
+ self.env = env
+
+ logging.getLogger().setLevel(logging.INFO)
+ self.seed = seed
+ set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
+ if not os.path.exists(self.exp_name):
+ os.makedirs(self.exp_name)
+ save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
+ if model is None:
+ model = ContinuousQAC(**self.cfg.policy.model)
+ self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
+ self.policy = SACPolicy(self.cfg.policy, model=model)
+ if policy_state_dict is not None:
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
+ self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
+
+ def train(
+ self,
+ step: int = int(1e7),
+ collector_env_num: int = None,
+ evaluator_env_num: int = None,
+ n_iter_save_ckpt: int = 1000,
+ context: Optional[str] = None,
+ debug: bool = False,
+ wandb_sweep: bool = False,
+ ) -> TrainingReturn:
+ """
+ Overview:
+            Train the agent with SAC algorithm for ``step`` environment steps with ``collector_env_num`` collector \
+ environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
+ recorded and saved by wandb.
+ Arguments:
+ - step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
+ - collector_env_num (:obj:`int`): The collector environment number. Default to None. \
+ If not specified, it will be set according to the configuration.
+ - evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
+ If not specified, it will be set according to the configuration.
+            - n_iter_save_ckpt (:obj:`int`): The checkpoint saving frequency, i.e. a checkpoint is saved every \
+                ``n_iter_save_ckpt`` training iterations. Default to 1000.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ - wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
+ which is a hyper-parameter optimization process for seeking the best configurations. \
+ Default to False. If True, the wandb sweep id will be used as the experiment name.
+ Returns:
+            - (:obj:`TrainingReturn`): The training result, of which the attributes are:
+                - wandb_url (:obj:`str`): The Weights & Biases (wandb) project URL of the training experiment.
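+        Examples:
+            A minimal usage sketch; the environment id and the small ``step`` budget below are illustrative only:
+            >>> agent = SACAgent(env_id='LunarLanderContinuous-v2')
+            >>> training_return = agent.train(step=int(1e5), n_iter_save_ckpt=500)
+            >>> print(training_return.wandb_url)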
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ logging.debug(self.policy._model)
+ # define env and policy
+ collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
+ evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
+ collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
+ evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
+
+ with task.start(ctx=OnlineRLContext()):
+ task.use(
+ interaction_evaluator(
+ self.cfg,
+ self.policy.eval_mode,
+ evaluator_env,
+ render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
+ )
+ )
+ task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
+ task.use(
+ StepCollector(
+ self.cfg,
+ self.policy.collect_mode,
+ collector_env,
+ random_collect_size=self.cfg.policy.random_collect_size
+ if hasattr(self.cfg.policy, 'random_collect_size') else 0,
+ )
+ )
+ task.use(data_pusher(self.cfg, self.buffer_))
+ task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
+ task.use(
+ wandb_online_logger(
+ metric_list=self.policy._monitor_vars_learn(),
+ model=self.policy._model,
+ anonymous=True,
+ project_name=self.exp_name,
+ wandb_sweep=wandb_sweep,
+ )
+ )
+ task.use(termination_checker(max_env_step=step))
+ task.use(final_ctx_saver(name=self.exp_name))
+ task.run()
+
+ return TrainingReturn(wandb_url=task.ctx.wandb_url)
+
+ def deploy(
+ self,
+ enable_save_replay: bool = False,
+ concatenate_all_replay: bool = False,
+ replay_save_path: str = None,
+ seed: Optional[Union[int, List]] = None,
+ debug: bool = False
+ ) -> EvalReturn:
+ """
+ Overview:
+ Deploy the agent with SAC algorithm by interacting with the environment, during which the replay video \
+ can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
+ Arguments:
+ - enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
+ - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
+ Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
+ If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
+ the replay video of each episode will be saved separately.
+ - replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
+ If not specified, the video will be saved in ``exp_name/videos``.
+ - seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
+ Default to None. If not specified, ``self.seed`` will be used. \
+ If ``seed`` is an integer, the agent will be deployed once. \
+ If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ Returns:
+            - (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
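+        Examples:
+            A minimal usage sketch; the environment id and seeds below are illustrative only:
+            >>> agent = SACAgent(env_id='LunarLanderContinuous-v2')
+            >>> agent.train(step=int(1e5))
+            >>> eval_return = agent.deploy(enable_save_replay=True, seed=[0, 1, 2])
+            >>> print(eval_return.eval_value, eval_return.eval_value_std)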
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ # define env and policy
+ env = self.env.clone(caller='evaluator')
+
+ if seed is not None and isinstance(seed, int):
+ seeds = [seed]
+ elif seed is not None and isinstance(seed, list):
+ seeds = seed
+ else:
+ seeds = [self.seed]
+
+ returns = []
+ images = []
+ if enable_save_replay:
+ replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
+ env.enable_save_replay(replay_path=replay_save_path)
+ else:
+ logging.warning('No video would be generated during the deploy.')
+ if concatenate_all_replay:
+ logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
+ concatenate_all_replay = False
+
+ def single_env_forward_wrapper(forward_fn, cuda=True):
+
+ forward_fn = model_wrap(forward_fn, wrapper_name='base').forward
+
+ def _forward(obs):
+ # unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
+ obs = ttorch.as_tensor(obs).unsqueeze(0)
+ if cuda and torch.cuda.is_available():
+ obs = obs.cuda()
+ (mu, sigma) = forward_fn(obs, mode='compute_actor')['logit']
+ action = torch.tanh(mu).detach().cpu().numpy()[0] # deterministic_eval
+ return action
+
+ return _forward
+
+ forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
+
+ # reset first to make sure the env is in the initial state
+ # env will be reset again in the main loop
+ env.reset()
+
+ for seed in seeds:
+ env.seed(seed, dynamic_seed=False)
+ return_ = 0.
+ step = 0
+ obs = env.reset()
+ images.append(render(env)[None]) if concatenate_all_replay else None
+ while True:
+ action = forward_fn(obs)
+ obs, rew, done, info = env.step(action)
+ images.append(render(env)[None]) if concatenate_all_replay else None
+ return_ += rew
+ step += 1
+ if done:
+ break
+            logging.info(f'SAC deploy is finished, final episode return with {step} steps is: {return_}')
+ returns.append(return_)
+
+ env.close()
+
+ if concatenate_all_replay:
+ images = np.concatenate(images, axis=0)
+ import imageio
+ imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
+
+ return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
+
+ def collect_data(
+ self,
+ env_num: int = 8,
+ save_data_path: Optional[str] = None,
+ n_sample: Optional[int] = None,
+ n_episode: Optional[int] = None,
+ context: Optional[str] = None,
+ debug: bool = False
+ ) -> None:
+ """
+ Overview:
+            Collect ``n_sample`` samples with SAC algorithm using ``env_num`` collector environments. \
+ The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
+ ``exp_name/demo_data``.
+ Arguments:
+ - env_num (:obj:`int`): The number of collector environments. Default to 8.
+ - save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
+ If not specified, the data will be saved in ``exp_name/demo_data``.
+ - n_sample (:obj:`int`): The number of samples to collect. Default to None. \
+ If not specified, ``n_episode`` must be specified.
+            - n_episode (:obj:`int`): The number of episodes to collect. Default to None. Episode-based \
+                collection is not implemented yet, so ``n_sample`` must be specified for now.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
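+        Examples:
+            A minimal usage sketch; the environment id and ``n_sample`` value below are illustrative only:
+            >>> agent = SACAgent(env_id='LunarLanderContinuous-v2')
+            >>> agent.collect_data(env_num=8, n_sample=1000)
+            >>> # the collected transitions are saved to ``exp_name/demo_data`` in hdf5 format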
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ if n_episode is not None:
+ raise NotImplementedError
+ # define env and policy
+ env_num = env_num if env_num else self.cfg.env.collector_env_num
+ env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
+
+ if save_data_path is None:
+ save_data_path = os.path.join(self.exp_name, 'demo_data')
+
+ # main execution task
+ with task.start(ctx=OnlineRLContext()):
+ task.use(
+ StepCollector(
+ self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
+ )
+ )
+ task.use(offline_data_saver(save_data_path, data_type='hdf5'))
+ task.run(max_step=1)
+ logging.info(
+ f'SAC collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
+ )
+
+ def batch_evaluate(
+ self,
+ env_num: int = 4,
+ n_evaluator_episode: int = 4,
+ context: Optional[str] = None,
+ debug: bool = False
+ ) -> EvalReturn:
+ """
+ Overview:
+ Evaluate the agent with SAC algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
+ environments. The evaluation result will be returned.
+ The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
+ multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
+ will only create one evaluator environment to evaluate the agent and save the replay video.
+ Arguments:
+ - env_num (:obj:`int`): The number of evaluator environments. Default to 4.
+ - n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ Returns:
+            - (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
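+        Examples:
+            A minimal usage sketch; the environment id below is illustrative only:
+            >>> agent = SACAgent(env_id='LunarLanderContinuous-v2')
+            >>> agent.train(step=int(1e5))
+            >>> eval_return = agent.batch_evaluate(env_num=4, n_evaluator_episode=8)
+            >>> print(eval_return.eval_value, eval_return.eval_value_std)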
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ # define env and policy
+ env_num = env_num if env_num else self.cfg.env.evaluator_env_num
+ env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
+
+ # reset first to make sure the env is in the initial state
+ # env will be reset again in the main loop
+ env.launch()
+ env.reset()
+
+ evaluate_cfg = self.cfg
+ evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
+
+ # main execution task
+ with task.start(ctx=OnlineRLContext()):
+            task.use(interaction_evaluator(evaluate_cfg, self.policy.eval_mode, env))
+ task.run(max_step=1)
+
+ return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
+
+ @property
+ def best(self) -> 'SACAgent':
+ """
+ Overview:
+            Load the best model from the checkpoint directory, \
+            which by default is the file ``exp_name/ckpt/eval.pth.tar``. \
+            The return value is the agent with the best model loaded.
+ Returns:
+ - (:obj:`SACAgent`): The agent with the best model.
+ Examples:
+ >>> agent = SACAgent(env_id='LunarLanderContinuous-v2')
+ >>> agent.train()
+ >>> agent = agent.best
+
+ .. note::
+            The best model is the model with the highest evaluation return. When this property is accessed, the \
+            current model will be replaced by the best model if the corresponding checkpoint exists.
+ """
+
+ best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
+ # Load best model if it exists
+ if os.path.exists(best_model_file_path):
+ policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
+ return self
diff --git a/DI-engine/ding/bonus/sql.py b/DI-engine/ding/bonus/sql.py
new file mode 100644
index 0000000000000000000000000000000000000000..63d26acce2a05635827b433bbc35d3356d1d8587
--- /dev/null
+++ b/DI-engine/ding/bonus/sql.py
@@ -0,0 +1,461 @@
+from typing import Optional, Union, List
+from ditk import logging
+from easydict import EasyDict
+import os
+import numpy as np
+import torch
+import treetensor.torch as ttorch
+from ding.framework import task, OnlineRLContext
+from ding.framework.middleware import CkptSaver, \
+ wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, data_pusher, \
+ OffPolicyLearner, final_ctx_saver, nstep_reward_enhancer, eps_greedy_handler
+from ding.envs import BaseEnv
+from ding.envs import setup_ding_env_manager
+from ding.policy import SQLPolicy
+from ding.utils import set_pkg_seed
+from ding.utils import get_env_fps, render
+from ding.config import save_config_py, compile_config
+from ding.model import DQN
+from ding.model import model_wrap
+from ding.data import DequeBuffer
+from ding.bonus.common import TrainingReturn, EvalReturn
+from ding.config.example.SQL import supported_env_cfg
+from ding.config.example.SQL import supported_env
+
+
+class SQLAgent:
+ """
+ Overview:
+        Class of agent for training, evaluation and deployment of the reinforcement learning algorithm \
+        Soft Q-Learning (SQL).
+ For more information about the system design of RL agent, please refer to \
+ .
+ Interface:
+ ``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
+ """
+ supported_env_list = list(supported_env_cfg.keys())
+ """
+ Overview:
+ List of supported envs.
+ Examples:
+ >>> from ding.bonus.sql import SQLAgent
+ >>> print(SQLAgent.supported_env_list)
+ """
+
+ def __init__(
+ self,
+ env_id: str = None,
+ env: BaseEnv = None,
+ seed: int = 0,
+ exp_name: str = None,
+ model: Optional[torch.nn.Module] = None,
+ cfg: Optional[Union[EasyDict, dict]] = None,
+ policy_state_dict: str = None,
+ ) -> None:
+ """
+ Overview:
+ Initialize agent for SQL algorithm.
+ Arguments:
+ - env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
+ If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
+ If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
+ ``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
+ - env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
+                If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified. \
+ ``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
+ If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
+ - seed (:obj:`int`): The random seed, which is set before running the program. \
+ Default to 0.
+ - exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
+ log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
+ - model (:obj:`torch.nn.Module`): The model of SQL algorithm, which should be an instance of class \
+ :class:`ding.model.DQN`. \
+ If not specified, a default model will be generated according to the configuration.
+            - cfg (:obj:`Union[EasyDict, dict]`): The configuration of SQL algorithm, which is a dict. \
+ Default to None. If not specified, the default configuration will be used. \
+ The default configuration can be found in ``ding/config/example/SQL/gym_lunarlander_v2.py``.
+            - policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
+ If specified, the policy will be loaded from this file. Default to None.
+
+ .. note::
+ An RL Agent Instance can be initialized in two basic ways. \
+ For example, we have an environment with id ``LunarLander-v2`` registered in gym, \
+ and we want to train an agent with SQL algorithm with default configuration. \
+ Then we can initialize the agent in the following ways:
+ >>> agent = SQLAgent(env_id='LunarLander-v2')
+            or, if we want to specify the ``env_id`` in the configuration:
+ >>> cfg = {'env': {'env_id': 'LunarLander-v2'}, 'policy': ...... }
+ >>> agent = SQLAgent(cfg=cfg)
+ There are also other arguments to specify the agent when initializing.
+ For example, if we want to specify the environment instance:
+ >>> env = CustomizedEnv('LunarLander-v2')
+ >>> agent = SQLAgent(cfg=cfg, env=env)
+ or, if we want to specify the model:
+ >>> model = DQN(**cfg.policy.model)
+ >>> agent = SQLAgent(cfg=cfg, model=model)
+ or, if we want to reload the policy from a saved policy state dict:
+ >>> agent = SQLAgent(cfg=cfg, policy_state_dict='LunarLander-v2.pth.tar')
+ Make sure that the configuration is consistent with the saved policy state dict.
+ """
+
+ assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
+
+ if cfg is not None and not isinstance(cfg, EasyDict):
+ cfg = EasyDict(cfg)
+
+ if env_id is not None:
+ assert env_id in SQLAgent.supported_env_list, "Please use supported envs: {}".format(
+ SQLAgent.supported_env_list
+ )
+ if cfg is None:
+ cfg = supported_env_cfg[env_id]
+ else:
+ assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
+ else:
+ assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
+ assert cfg.env.env_id in SQLAgent.supported_env_list, "Please use supported envs: {}".format(
+ SQLAgent.supported_env_list
+ )
+ default_policy_config = EasyDict({"policy": SQLPolicy.default_config()})
+ default_policy_config.update(cfg)
+ cfg = default_policy_config
+
+ if exp_name is not None:
+ cfg.exp_name = exp_name
+ self.cfg = compile_config(cfg, policy=SQLPolicy)
+ self.exp_name = self.cfg.exp_name
+ if env is None:
+ self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
+ else:
+ assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
+ self.env = env
+
+ logging.getLogger().setLevel(logging.INFO)
+ self.seed = seed
+ set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
+ if not os.path.exists(self.exp_name):
+ os.makedirs(self.exp_name)
+ save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
+ if model is None:
+ model = DQN(**self.cfg.policy.model)
+ self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
+ self.policy = SQLPolicy(self.cfg.policy, model=model)
+ if policy_state_dict is not None:
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
+ self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
+
+ def train(
+ self,
+ step: int = int(1e7),
+ collector_env_num: int = None,
+ evaluator_env_num: int = None,
+ n_iter_save_ckpt: int = 1000,
+ context: Optional[str] = None,
+ debug: bool = False,
+ wandb_sweep: bool = False,
+ ) -> TrainingReturn:
+ """
+ Overview:
+            Train the agent with SQL algorithm for ``step`` environment steps with ``collector_env_num`` collector \
+ environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
+ recorded and saved by wandb.
+ Arguments:
+ - step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
+ - collector_env_num (:obj:`int`): The collector environment number. Default to None. \
+ If not specified, it will be set according to the configuration.
+ - evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
+ If not specified, it will be set according to the configuration.
+            - n_iter_save_ckpt (:obj:`int`): The checkpoint saving frequency, i.e. a checkpoint is saved every \
+                ``n_iter_save_ckpt`` training iterations. Default to 1000.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ - wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
+ which is a hyper-parameter optimization process for seeking the best configurations. \
+ Default to False. If True, the wandb sweep id will be used as the experiment name.
+ Returns:
+            - (:obj:`TrainingReturn`): The training result, of which the attributes are:
+                - wandb_url (:obj:`str`): The Weights & Biases (wandb) project URL of the training experiment.
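+        Examples:
+            A minimal usage sketch; the environment id and argument values below are illustrative only:
+            >>> agent = SQLAgent(env_id='LunarLander-v2')
+            >>> training_return = agent.train(step=int(1e5), collector_env_num=4, evaluator_env_num=4)
+            >>> print(training_return.wandb_url)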
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ logging.debug(self.policy._model)
+ # define env and policy
+ collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
+ evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
+ collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
+ evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
+
+ with task.start(ctx=OnlineRLContext()):
+ task.use(
+ interaction_evaluator(
+ self.cfg,
+ self.policy.eval_mode,
+ evaluator_env,
+ render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
+ )
+ )
+ task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
+ task.use(eps_greedy_handler(self.cfg))
+ task.use(
+ StepCollector(
+ self.cfg,
+ self.policy.collect_mode,
+ collector_env,
+ random_collect_size=self.cfg.policy.random_collect_size
+ if hasattr(self.cfg.policy, 'random_collect_size') else 0,
+ )
+ )
+ if "nstep" in self.cfg.policy and self.cfg.policy.nstep > 1:
+ task.use(nstep_reward_enhancer(self.cfg))
+ task.use(data_pusher(self.cfg, self.buffer_))
+ task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
+ task.use(
+ wandb_online_logger(
+ metric_list=self.policy._monitor_vars_learn(),
+ model=self.policy._model,
+ anonymous=True,
+ project_name=self.exp_name,
+ wandb_sweep=wandb_sweep,
+ )
+ )
+ task.use(termination_checker(max_env_step=step))
+ task.use(final_ctx_saver(name=self.exp_name))
+ task.run()
+
+ return TrainingReturn(wandb_url=task.ctx.wandb_url)
+
+ def deploy(
+ self,
+ enable_save_replay: bool = False,
+ concatenate_all_replay: bool = False,
+ replay_save_path: str = None,
+ seed: Optional[Union[int, List]] = None,
+ debug: bool = False
+ ) -> EvalReturn:
+ """
+ Overview:
+ Deploy the agent with SQL algorithm by interacting with the environment, during which the replay video \
+ can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
+ Arguments:
+ - enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
+ - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
+ Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
+ If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
+ the replay video of each episode will be saved separately.
+ - replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
+ If not specified, the video will be saved in ``exp_name/videos``.
+ - seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
+ Default to None. If not specified, ``self.seed`` will be used. \
+ If ``seed`` is an integer, the agent will be deployed once. \
+ If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ Returns:
+            - (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
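+        Examples:
+            A minimal usage sketch; the environment id below is illustrative only:
+            >>> agent = SQLAgent(env_id='LunarLander-v2')
+            >>> agent.train(step=int(1e5))
+            >>> eval_return = agent.deploy(enable_save_replay=True, concatenate_all_replay=True)
+            >>> print(eval_return.eval_value, eval_return.eval_value_std)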
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ # define env and policy
+ env = self.env.clone(caller='evaluator')
+
+ if seed is not None and isinstance(seed, int):
+ seeds = [seed]
+ elif seed is not None and isinstance(seed, list):
+ seeds = seed
+ else:
+ seeds = [self.seed]
+
+ returns = []
+ images = []
+ if enable_save_replay:
+ replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
+ env.enable_save_replay(replay_path=replay_save_path)
+ else:
+ logging.warning('No video would be generated during the deploy.')
+ if concatenate_all_replay:
+ logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
+ concatenate_all_replay = False
+
+ def single_env_forward_wrapper(forward_fn, cuda=True):
+
+ forward_fn = model_wrap(forward_fn, wrapper_name='argmax_sample').forward
+
+ def _forward(obs):
+ # unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
+ obs = ttorch.as_tensor(obs).unsqueeze(0)
+ if cuda and torch.cuda.is_available():
+ obs = obs.cuda()
+ action = forward_fn(obs)["action"]
+ # squeeze means delete batch dim, i.e. (1, A) -> (A, )
+ action = action.squeeze(0).detach().cpu().numpy()
+ return action
+
+ return _forward
+
+ forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
+
+ # reset first to make sure the env is in the initial state
+ # env will be reset again in the main loop
+ env.reset()
+
+ for seed in seeds:
+ env.seed(seed, dynamic_seed=False)
+ return_ = 0.
+ step = 0
+ obs = env.reset()
+ images.append(render(env)[None]) if concatenate_all_replay else None
+ while True:
+ action = forward_fn(obs)
+ obs, rew, done, info = env.step(action)
+ images.append(render(env)[None]) if concatenate_all_replay else None
+ return_ += rew
+ step += 1
+ if done:
+ break
+ logging.info(f'SQL deploy is finished, final episode return with {step} steps is: {return_}')
+ returns.append(return_)
+
+ env.close()
+
+ if concatenate_all_replay:
+ images = np.concatenate(images, axis=0)
+ import imageio
+ imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
+
+ return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
+
+ def collect_data(
+ self,
+ env_num: int = 8,
+ save_data_path: Optional[str] = None,
+ n_sample: Optional[int] = None,
+ n_episode: Optional[int] = None,
+ context: Optional[str] = None,
+ debug: bool = False
+ ) -> None:
+ """
+ Overview:
+            Collect ``n_sample`` samples with SQL algorithm using ``env_num`` collector environments. \
+ The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
+ ``exp_name/demo_data``.
+ Arguments:
+ - env_num (:obj:`int`): The number of collector environments. Default to 8.
+ - save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
+ If not specified, the data will be saved in ``exp_name/demo_data``.
+ - n_sample (:obj:`int`): The number of samples to collect. Default to None. \
+ If not specified, ``n_episode`` must be specified.
+            - n_episode (:obj:`int`): The number of episodes to collect. Default to None. Episode-based \
+                collection is not implemented yet, so ``n_sample`` must be specified for now.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
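+        Examples:
+            A minimal usage sketch; the environment id and paths below are illustrative only:
+            >>> agent = SQLAgent(env_id='LunarLander-v2')
+            >>> agent.collect_data(env_num=8, n_sample=1000, save_data_path='./my_demo_data')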
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ if n_episode is not None:
+ raise NotImplementedError
+ # define env and policy
+ env_num = env_num if env_num else self.cfg.env.collector_env_num
+ env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
+
+ if save_data_path is None:
+ save_data_path = os.path.join(self.exp_name, 'demo_data')
+
+ # main execution task
+ with task.start(ctx=OnlineRLContext()):
+ task.use(
+ StepCollector(
+ self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
+ )
+ )
+ task.use(offline_data_saver(save_data_path, data_type='hdf5'))
+ task.run(max_step=1)
+ logging.info(
+ f'SQL collecting is finished, more than {n_sample} samples are collected and saved in `{save_data_path}`'
+ )
+
+ def batch_evaluate(
+ self,
+ env_num: int = 4,
+ n_evaluator_episode: int = 4,
+ context: Optional[str] = None,
+ debug: bool = False
+ ) -> EvalReturn:
+ """
+ Overview:
+ Evaluate the agent with SQL algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
+ environments. The evaluation result will be returned.
+ The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
+ multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
+ will only create one evaluator environment to evaluate the agent and save the replay video.
+ Arguments:
+ - env_num (:obj:`int`): The number of evaluator environments. Default to 4.
+ - n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ Returns:
+            - (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
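+        Examples:
+            A minimal usage sketch; the environment id below is illustrative only:
+            >>> agent = SQLAgent(env_id='LunarLander-v2')
+            >>> agent.train(step=int(1e5))
+            >>> eval_return = agent.batch_evaluate(env_num=4, n_evaluator_episode=4)
+            >>> print(eval_return.eval_value, eval_return.eval_value_std)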
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ # define env and policy
+ env_num = env_num if env_num else self.cfg.env.evaluator_env_num
+ env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
+
+ # reset first to make sure the env is in the initial state
+ # env will be reset again in the main loop
+ env.launch()
+ env.reset()
+
+ evaluate_cfg = self.cfg
+ evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
+
+ # main execution task
+ with task.start(ctx=OnlineRLContext()):
+            task.use(interaction_evaluator(evaluate_cfg, self.policy.eval_mode, env))
+ task.run(max_step=1)
+
+ return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
+
+ @property
+ def best(self) -> 'SQLAgent':
+ """
+ Overview:
+            Load the best model from the checkpoint directory, \
+            which by default is the file ``exp_name/ckpt/eval.pth.tar``. \
+            The return value is the agent with the best model loaded.
+ Returns:
+ - (:obj:`SQLAgent`): The agent with the best model.
+ Examples:
+ >>> agent = SQLAgent(env_id='LunarLander-v2')
+ >>> agent.train()
+ >>> agent = agent.best
+
+ .. note::
+            The best model is the model with the highest evaluation return. When this property is accessed, the \
+            current model will be replaced by the best model if the corresponding checkpoint exists.
+ """
+
+ best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
+ # Load best model if it exists
+ if os.path.exists(best_model_file_path):
+ policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
+ return self
diff --git a/DI-engine/ding/bonus/td3.py b/DI-engine/ding/bonus/td3.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2889a370d6f546ec3d65428513cb838423b8750
--- /dev/null
+++ b/DI-engine/ding/bonus/td3.py
@@ -0,0 +1,455 @@
+from typing import Optional, Union, List
+from ditk import logging
+from easydict import EasyDict
+import os
+import numpy as np
+import torch
+import treetensor.torch as ttorch
+from ding.framework import task, OnlineRLContext
+from ding.framework.middleware import CkptSaver, \
+ wandb_online_logger, offline_data_saver, termination_checker, interaction_evaluator, StepCollector, data_pusher, \
+ OffPolicyLearner, final_ctx_saver
+from ding.envs import BaseEnv
+from ding.envs import setup_ding_env_manager
+from ding.policy import TD3Policy
+from ding.utils import set_pkg_seed
+from ding.utils import get_env_fps, render
+from ding.config import save_config_py, compile_config
+from ding.model import ContinuousQAC
+from ding.data import DequeBuffer
+from ding.bonus.common import TrainingReturn, EvalReturn
+from ding.config.example.TD3 import supported_env_cfg
+from ding.config.example.TD3 import supported_env
+
+
+class TD3Agent:
+ """
+ Overview:
+        Class of agent for training, evaluation and deployment of the reinforcement learning algorithm \
+        Twin Delayed Deep Deterministic Policy Gradient (TD3).
+ For more information about the system design of RL agent, please refer to \
+ .
+ Interface:
+ ``__init__``, ``train``, ``deploy``, ``collect_data``, ``batch_evaluate``, ``best``
+ """
+ supported_env_list = list(supported_env_cfg.keys())
+ """
+ Overview:
+ List of supported envs.
+ Examples:
+ >>> from ding.bonus.td3 import TD3Agent
+ >>> print(TD3Agent.supported_env_list)
+ """
+
+ def __init__(
+ self,
+ env_id: str = None,
+ env: BaseEnv = None,
+ seed: int = 0,
+ exp_name: str = None,
+ model: Optional[torch.nn.Module] = None,
+ cfg: Optional[Union[EasyDict, dict]] = None,
+ policy_state_dict: str = None,
+ ) -> None:
+ """
+ Overview:
+ Initialize agent for TD3 algorithm.
+ Arguments:
+ - env_id (:obj:`str`): The environment id, which is a registered environment name in gym or gymnasium. \
+ If ``env_id`` is not specified, ``env_id`` in ``cfg.env`` must be specified. \
+ If ``env_id`` is specified, ``env_id`` in ``cfg.env`` will be ignored. \
+ ``env_id`` should be one of the supported envs, which can be found in ``supported_env_list``.
+ - env (:obj:`BaseEnv`): The environment instance for training and evaluation. \
+                If ``env`` is not specified, ``env_id`` or ``cfg.env.env_id`` must be specified. \
+ ``env_id`` or ``cfg.env.env_id`` will be used to create environment instance. \
+ If ``env`` is specified, ``env_id`` and ``cfg.env.env_id`` will be ignored.
+ - seed (:obj:`int`): The random seed, which is set before running the program. \
+ Default to 0.
+ - exp_name (:obj:`str`): The name of this experiment, which will be used to create the folder to save \
+ log data. Default to None. If not specified, the folder name will be ``env_id``-``algorithm``.
+ - model (:obj:`torch.nn.Module`): The model of TD3 algorithm, which should be an instance of class \
+ :class:`ding.model.ContinuousQAC`. \
+ If not specified, a default model will be generated according to the configuration.
+            - cfg (:obj:`Union[EasyDict, dict]`): The configuration of TD3 algorithm, which is a dict. \
+ Default to None. If not specified, the default configuration will be used. \
+ The default configuration can be found in ``ding/config/example/TD3/gym_lunarlander_v2.py``.
+            - policy_state_dict (:obj:`str`): The path of the policy state dict saved by PyTorch in a local file. \
+ If specified, the policy will be loaded from this file. Default to None.
+
+ .. note::
+ An RL Agent Instance can be initialized in two basic ways. \
+ For example, we have an environment with id ``LunarLanderContinuous-v2`` registered in gym, \
+ and we want to train an agent with TD3 algorithm with default configuration. \
+ Then we can initialize the agent in the following ways:
+ >>> agent = TD3Agent(env_id='LunarLanderContinuous-v2')
+            or, if we want to specify the ``env_id`` in the configuration:
+ >>> cfg = {'env': {'env_id': 'LunarLanderContinuous-v2'}, 'policy': ...... }
+ >>> agent = TD3Agent(cfg=cfg)
+ There are also other arguments to specify the agent when initializing.
+ For example, if we want to specify the environment instance:
+ >>> env = CustomizedEnv('LunarLanderContinuous-v2')
+ >>> agent = TD3Agent(cfg=cfg, env=env)
+ or, if we want to specify the model:
+ >>> model = ContinuousQAC(**cfg.policy.model)
+ >>> agent = TD3Agent(cfg=cfg, model=model)
+ or, if we want to reload the policy from a saved policy state dict:
+ >>> agent = TD3Agent(cfg=cfg, policy_state_dict='LunarLanderContinuous-v2.pth.tar')
+ Make sure that the configuration is consistent with the saved policy state dict.
+ """
+
+ assert env_id is not None or cfg is not None, "Please specify env_id or cfg."
+
+ if cfg is not None and not isinstance(cfg, EasyDict):
+ cfg = EasyDict(cfg)
+
+ if env_id is not None:
+ assert env_id in TD3Agent.supported_env_list, "Please use supported envs: {}".format(
+ TD3Agent.supported_env_list
+ )
+ if cfg is None:
+ cfg = supported_env_cfg[env_id]
+ else:
+ assert cfg.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
+ else:
+ assert hasattr(cfg.env, "env_id"), "Please specify env_id in cfg."
+ assert cfg.env.env_id in TD3Agent.supported_env_list, "Please use supported envs: {}".format(
+ TD3Agent.supported_env_list
+ )
+ default_policy_config = EasyDict({"policy": TD3Policy.default_config()})
+ default_policy_config.update(cfg)
+ cfg = default_policy_config
+
+ if exp_name is not None:
+ cfg.exp_name = exp_name
+ self.cfg = compile_config(cfg, policy=TD3Policy)
+ self.exp_name = self.cfg.exp_name
+ if env is None:
+ self.env = supported_env[cfg.env.env_id](cfg=cfg.env)
+ else:
+ assert isinstance(env, BaseEnv), "Please use BaseEnv as env data type."
+ self.env = env
+
+ logging.getLogger().setLevel(logging.INFO)
+ self.seed = seed
+ set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
+ if not os.path.exists(self.exp_name):
+ os.makedirs(self.exp_name)
+ save_config_py(self.cfg, os.path.join(self.exp_name, 'policy_config.py'))
+ if model is None:
+ model = ContinuousQAC(**self.cfg.policy.model)
+ self.buffer_ = DequeBuffer(size=self.cfg.policy.other.replay_buffer.replay_buffer_size)
+ self.policy = TD3Policy(self.cfg.policy, model=model)
+ if policy_state_dict is not None:
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
+ self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
+
+ def train(
+ self,
+ step: int = int(1e7),
+ collector_env_num: int = None,
+ evaluator_env_num: int = None,
+ n_iter_save_ckpt: int = 1000,
+ context: Optional[str] = None,
+ debug: bool = False,
+ wandb_sweep: bool = False,
+ ) -> TrainingReturn:
+ """
+ Overview:
+            Train the agent with TD3 algorithm for ``step`` environment steps with ``collector_env_num`` collector \
+ environments and ``evaluator_env_num`` evaluator environments. Information during training will be \
+ recorded and saved by wandb.
+ Arguments:
+ - step (:obj:`int`): The total training environment steps of all collector environments. Default to 1e7.
+ - collector_env_num (:obj:`int`): The collector environment number. Default to None. \
+ If not specified, it will be set according to the configuration.
+ - evaluator_env_num (:obj:`int`): The evaluator environment number. Default to None. \
+ If not specified, it will be set according to the configuration.
+            - n_iter_save_ckpt (:obj:`int`): The checkpoint saving frequency, i.e. a checkpoint is saved every \
+                ``n_iter_save_ckpt`` training iterations. Default to 1000.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ - wandb_sweep (:obj:`bool`): Whether to use wandb sweep, \
+ which is a hyper-parameter optimization process for seeking the best configurations. \
+ Default to False. If True, the wandb sweep id will be used as the experiment name.
+ Returns:
+            - (:obj:`TrainingReturn`): The training result, of which the attributes are:
+                - wandb_url (:obj:`str`): The Weights & Biases (wandb) project URL of the training experiment.
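+        Examples:
+            A minimal usage sketch; the environment id and the small ``step`` budget below are illustrative only:
+            >>> agent = TD3Agent(env_id='LunarLanderContinuous-v2')
+            >>> training_return = agent.train(step=int(1e5))
+            >>> print(training_return.wandb_url)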
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ logging.debug(self.policy._model)
+ # define env and policy
+ collector_env_num = collector_env_num if collector_env_num else self.cfg.env.collector_env_num
+ evaluator_env_num = evaluator_env_num if evaluator_env_num else self.cfg.env.evaluator_env_num
+ collector_env = setup_ding_env_manager(self.env, collector_env_num, context, debug, 'collector')
+ evaluator_env = setup_ding_env_manager(self.env, evaluator_env_num, context, debug, 'evaluator')
+
+ with task.start(ctx=OnlineRLContext()):
+ task.use(
+ interaction_evaluator(
+ self.cfg,
+ self.policy.eval_mode,
+ evaluator_env,
+ render=self.cfg.policy.eval.render if hasattr(self.cfg.policy.eval, "render") else False
+ )
+ )
+ task.use(CkptSaver(policy=self.policy, save_dir=self.checkpoint_save_dir, train_freq=n_iter_save_ckpt))
+ task.use(
+ StepCollector(
+ self.cfg,
+ self.policy.collect_mode,
+ collector_env,
+ random_collect_size=self.cfg.policy.random_collect_size
+ if hasattr(self.cfg.policy, 'random_collect_size') else 0,
+ )
+ )
+ task.use(data_pusher(self.cfg, self.buffer_))
+ task.use(OffPolicyLearner(self.cfg, self.policy.learn_mode, self.buffer_))
+ task.use(
+ wandb_online_logger(
+ metric_list=self.policy._monitor_vars_learn(),
+ model=self.policy._model,
+ anonymous=True,
+ project_name=self.exp_name,
+ wandb_sweep=wandb_sweep,
+ )
+ )
+ task.use(termination_checker(max_env_step=step))
+ task.use(final_ctx_saver(name=self.exp_name))
+ task.run()
+
+ return TrainingReturn(wandb_url=task.ctx.wandb_url)
+
+ def deploy(
+ self,
+ enable_save_replay: bool = False,
+ concatenate_all_replay: bool = False,
+ replay_save_path: str = None,
+ seed: Optional[Union[int, List]] = None,
+ debug: bool = False
+ ) -> EvalReturn:
+ """
+ Overview:
+ Deploy the agent with TD3 algorithm by interacting with the environment, during which the replay video \
+ can be saved if ``enable_save_replay`` is True. The evaluation result will be returned.
+ Arguments:
+ - enable_save_replay (:obj:`bool`): Whether to save the replay video. Default to False.
+ - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one video. \
+ Default to False. If ``enable_save_replay`` is False, this argument will be ignored. \
+ If ``enable_save_replay`` is True and ``concatenate_all_replay`` is False, \
+ the replay video of each episode will be saved separately.
+ - replay_save_path (:obj:`str`): The path to save the replay video. Default to None. \
+ If not specified, the video will be saved in ``exp_name/videos``.
+ - seed (:obj:`Union[int, List]`): The random seed, which is set before running the program. \
+ Default to None. If not specified, ``self.seed`` will be used. \
+ If ``seed`` is an integer, the agent will be deployed once. \
+ If ``seed`` is a list of integers, the agent will be deployed once for each seed in the list.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ Returns:
+            - (:obj:`EvalReturn`): The evaluation result, of which the attributes are:
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
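+        Examples:
+            A minimal usage sketch; the environment id and seed below are illustrative only:
+            >>> agent = TD3Agent(env_id='LunarLanderContinuous-v2')
+            >>> agent.train(step=int(1e5))
+            >>> eval_return = agent.deploy(enable_save_replay=True, seed=0)
+            >>> print(eval_return.eval_value, eval_return.eval_value_std)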
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ # define env and policy
+ env = self.env.clone(caller='evaluator')
+
+ if seed is not None and isinstance(seed, int):
+ seeds = [seed]
+ elif seed is not None and isinstance(seed, list):
+ seeds = seed
+ else:
+ seeds = [self.seed]
+
+ returns = []
+ images = []
+ if enable_save_replay:
+ replay_save_path = os.path.join(self.exp_name, 'videos') if replay_save_path is None else replay_save_path
+ env.enable_save_replay(replay_path=replay_save_path)
+ else:
+ logging.warning('No video would be generated during the deploy.')
+ if concatenate_all_replay:
+ logging.warning('concatenate_all_replay is set to False because enable_save_replay is False.')
+ concatenate_all_replay = False
+
+ def single_env_forward_wrapper(forward_fn, cuda=True):
+
+ def _forward(obs):
+ # unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
+ obs = ttorch.as_tensor(obs).unsqueeze(0)
+ if cuda and torch.cuda.is_available():
+ obs = obs.cuda()
+ action = forward_fn(obs, mode='compute_actor')["action"]
+ # squeeze means delete batch dim, i.e. (1, A) -> (A, )
+ action = action.squeeze(0).detach().cpu().numpy()
+ return action
+
+ return _forward
+
+ forward_fn = single_env_forward_wrapper(self.policy._model, self.cfg.policy.cuda)
+
+ # reset first to make sure the env is in the initial state
+ # env will be reset again in the main loop
+ env.reset()
+
+ for seed in seeds:
+ env.seed(seed, dynamic_seed=False)
+ return_ = 0.
+ step = 0
+ obs = env.reset()
+ images.append(render(env)[None]) if concatenate_all_replay else None
+ while True:
+ action = forward_fn(obs)
+ obs, rew, done, info = env.step(action)
+ images.append(render(env)[None]) if concatenate_all_replay else None
+ return_ += rew
+ step += 1
+ if done:
+ break
+            logging.info(f'TD3 deploy is finished, final episode return with {step} steps is: {return_}')
+ returns.append(return_)
+
+ env.close()
+
+ if concatenate_all_replay:
+ images = np.concatenate(images, axis=0)
+ import imageio
+ imageio.mimwrite(os.path.join(replay_save_path, 'deploy.mp4'), images, fps=get_env_fps(env))
+
+ return EvalReturn(eval_value=np.mean(returns), eval_value_std=np.std(returns))
+
+ def collect_data(
+ self,
+ env_num: int = 8,
+ save_data_path: Optional[str] = None,
+ n_sample: Optional[int] = None,
+ n_episode: Optional[int] = None,
+ context: Optional[str] = None,
+ debug: bool = False
+ ) -> None:
+ """
+ Overview:
+            Collect ``n_sample`` samples with TD3 algorithm using ``env_num`` collector environments. \
+ The collected data will be saved in ``save_data_path`` if specified, otherwise it will be saved in \
+ ``exp_name/demo_data``.
+ Arguments:
+ - env_num (:obj:`int`): The number of collector environments. Default to 8.
+ - save_data_path (:obj:`str`): The path to save the collected data. Default to None. \
+ If not specified, the data will be saved in ``exp_name/demo_data``.
+ - n_sample (:obj:`int`): The number of samples to collect. Default to None. \
+ If not specified, ``n_episode`` must be specified.
+            - n_episode (:obj:`int`): The number of episodes to collect. Default to None. Episode-based \
+                collection is not implemented yet, so ``n_sample`` must be specified for now.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
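+        Examples:
+            A minimal usage sketch; the environment id and ``n_sample`` value below are illustrative only:
+            >>> agent = TD3Agent(env_id='LunarLanderContinuous-v2')
+            >>> agent.collect_data(env_num=8, n_sample=1000)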
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ if n_episode is not None:
+ raise NotImplementedError
+ # define env and policy
+ env_num = env_num if env_num else self.cfg.env.collector_env_num
+ env = setup_ding_env_manager(self.env, env_num, context, debug, 'collector')
+
+ if save_data_path is None:
+ save_data_path = os.path.join(self.exp_name, 'demo_data')
+
+ # main execution task
+ with task.start(ctx=OnlineRLContext()):
+ task.use(
+ StepCollector(
+ self.cfg, self.policy.collect_mode, env, random_collect_size=self.cfg.policy.random_collect_size
+ )
+ )
+ task.use(offline_data_saver(save_data_path, data_type='hdf5'))
+ task.run(max_step=1)
+ logging.info(
+ f'TD3 data collection is finished: more than {n_sample} samples have been collected and saved in `{save_data_path}`'
+ )
+
+ def batch_evaluate(
+ self,
+ env_num: int = 4,
+ n_evaluator_episode: int = 4,
+ context: Optional[str] = None,
+ debug: bool = False
+ ) -> EvalReturn:
+ """
+ Overview:
+ Evaluate the agent with the TD3 algorithm for ``n_evaluator_episode`` episodes with ``env_num`` evaluator \
+ environments. The evaluation result will be returned.
+ The difference between methods ``batch_evaluate`` and ``deploy`` is that ``batch_evaluate`` will create \
+ multiple evaluator environments to evaluate the agent to get an average performance, while ``deploy`` \
+ will only create one evaluator environment to evaluate the agent and save the replay video.
+ Arguments:
+ - env_num (:obj:`int`): The number of evaluator environments. Default to 4.
+ - n_evaluator_episode (:obj:`int`): The number of episodes to evaluate. Default to 4.
+ - context (:obj:`str`): The multi-process context of the environment manager. Default to None. \
+ It can be specified as ``spawn``, ``fork`` or ``forkserver``.
+ - debug (:obj:`bool`): Whether to use debug mode in the environment manager. Default to False. \
+ If set True, base environment manager will be used for easy debugging. Otherwise, \
+ subprocess environment manager will be used.
+ Returns:
+ - (:obj:`EvalReturn`): The evaluation result, whose attributes are:
+ - eval_value (:obj:`np.float32`): The mean of evaluation return.
+ - eval_value_std (:obj:`np.float32`): The standard deviation of evaluation return.
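+ Examples:
+ >>> # An illustrative sketch; mirrors the ``best`` example below, the parameter values are arbitrary.
+ >>> agent = TD3Agent(env_id='LunarLanderContinuous-v2')
+ >>> agent.train()
+ >>> result = agent.batch_evaluate(env_num=4, n_evaluator_episode=8)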
+ """
+
+ if debug:
+ logging.getLogger().setLevel(logging.DEBUG)
+ # define env and policy
+ env_num = env_num if env_num else self.cfg.env.evaluator_env_num
+ env = setup_ding_env_manager(self.env, env_num, context, debug, 'evaluator')
+
+ # reset first to make sure the env is in the initial state
+ # env will be reset again in the main loop
+ env.launch()
+ env.reset()
+
+ evaluate_cfg = self.cfg
+ evaluate_cfg.env.n_evaluator_episode = n_evaluator_episode
+
+ # main execution task
+ with task.start(ctx=OnlineRLContext()):
+ task.use(interaction_evaluator(self.cfg, self.policy.eval_mode, env))
+ task.run(max_step=1)
+
+ return EvalReturn(eval_value=task.ctx.eval_value, eval_value_std=task.ctx.eval_value_std)
+
+ @property
+ def best(self) -> 'TD3Agent':
+ """
+ Overview:
+ Load the best model from the checkpoint directory, \
+ which by default is the file ``exp_name/ckpt/eval.pth.tar``. \
+ The return value is the agent with the best model.
+ Returns:
+ - (:obj:`TD3Agent`): The agent with the best model.
+ Examples:
+ >>> agent = TD3Agent(env_id='LunarLanderContinuous-v2')
+ >>> agent.train()
+ >>> agent.best
+
+ .. note::
+ The best model is the model with the highest evaluation return. If this method is called, the current \
+ model will be replaced by the best model.
+ """
+
+ best_model_file_path = os.path.join(self.checkpoint_save_dir, "eval.pth.tar")
+ # Load best model if it exists
+ if os.path.exists(best_model_file_path):
+ policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
+ return self
diff --git a/DI-engine/ding/compatibility.py b/DI-engine/ding/compatibility.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd6b1fd0dae9c49b26962f3c1ec1190542027e4e
--- /dev/null
+++ b/DI-engine/ding/compatibility.py
@@ -0,0 +1,9 @@
+import torch
+
+
+def torch_ge_131():
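+ # Concatenate the digits of torch.__version__ (e.g. '1.3.1' -> 131) and compare against 131,
+ # i.e. a rough check that the installed torch version is at least 1.3.1.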
+ return int("".join(list(filter(str.isdigit, torch.__version__)))) >= 131
+
+
+def torch_ge_180():
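+ # Same digit-concatenation trick as above: a rough check that torch is at least version 1.8.0.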
+ return int("".join(list(filter(str.isdigit, torch.__version__)))) >= 180
diff --git a/DI-engine/ding/config/__init__.py b/DI-engine/ding/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..162fc86c86b05815e44928fef4b5ab41f2e331d8
--- /dev/null
+++ b/DI-engine/ding/config/__init__.py
@@ -0,0 +1,4 @@
+from .config import Config, read_config, save_config, compile_config, compile_config_parallel, read_config_directly, \
+ read_config_with_system, save_config_py
+from .utils import parallel_transform, parallel_transform_slurm
+from .example import A2C, C51, DDPG, DQN, PG, PPOF, PPOOffPolicy, SAC, SQL, TD3
diff --git a/DI-engine/ding/config/config.py b/DI-engine/ding/config/config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b92921a6ce337396eef3c82e13cfb02f7c86d58
--- /dev/null
+++ b/DI-engine/ding/config/config.py
@@ -0,0 +1,579 @@
+import os
+import os.path as osp
+import yaml
+import json
+import shutil
+import sys
+import time
+import tempfile
+import subprocess
+import datetime
+from importlib import import_module
+from typing import Optional, Tuple
+from easydict import EasyDict
+from copy import deepcopy
+
+from ding.utils import deep_merge_dicts, get_rank
+from ding.envs import get_env_cls, get_env_manager_cls, BaseEnvManager
+from ding.policy import get_policy_cls
+from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, Coordinator, \
+ AdvancedReplayBuffer, get_parallel_commander_cls, get_parallel_collector_cls, get_buffer_cls, \
+ get_serial_collector_cls, MetricSerialEvaluator, BattleInteractionSerialEvaluator
+from ding.reward_model import get_reward_model_cls
+from ding.world_model import get_world_model_cls
+from .utils import parallel_transform, parallel_transform_slurm, parallel_transform_k8s, save_config_formatted
+
+
+class Config(object):
+ r"""
+ Overview:
+ Base class for config.
+ Interface:
+ __init__, file_to_dict
+ Property:
+ cfg_dict
+ """
+
+ def __init__(
+ self,
+ cfg_dict: Optional[dict] = None,
+ cfg_text: Optional[str] = None,
+ filename: Optional[str] = None
+ ) -> None:
+ """
+ Overview:
+ Init method. Create config including dict type config and text type config.
+ Arguments:
+ - cfg_dict (:obj:`Optional[dict]`): dict type config
+ - cfg_text (:obj:`Optional[str]`): text type config
+ - filename (:obj:`Optional[str]`): config file name
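+ Examples:
+ >>> # A minimal sketch of direct construction (the dict values are placeholders).
+ >>> cfg = Config(cfg_dict={'a': 1})
+ >>> cfg.cfg_dict
+ {'a': 1}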
+ """
+ if cfg_dict is None:
+ cfg_dict = {}
+ if not isinstance(cfg_dict, dict):
+ raise TypeError("invalid type for cfg_dict: {}".format(type(cfg_dict)))
+ self._cfg_dict = cfg_dict
+ if cfg_text:
+ text = cfg_text
+ elif filename:
+ with open(filename, 'r') as f:
+ text = f.read()
+ else:
+ text = '.'
+ self._text = text
+ self._filename = filename
+
+ @staticmethod
+ def file_to_dict(filename: str) -> 'Config': # noqa
+ """
+ Overview:
+ Read config file and create config.
+ Arguments:
+ - filename (:obj:`str`): config file name.
+ Returns:
+ - cfg (:obj:`Config`): ``Config`` instance built from the file.
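+ Examples:
+ >>> # Hypothetical path; any importable python config file works.
+ >>> cfg = Config.file_to_dict('my_config.py')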
+ """
+ cfg_dict, cfg_text = Config._file_to_dict(filename)
+ return Config(cfg_dict, cfg_text, filename=filename)
+
+ @staticmethod
+ def _file_to_dict(filename: str) -> Tuple[dict, str]:
+ """
+ Overview:
+ Read a config file and convert it into a dict-type config and a text-type config.
+ Arguments:
+ - filename (:obj:`str`): config file name.
+ Returns:
+ - cfg_dict (:obj:`dict`): dict type config
+ - cfg_text (:obj:`str`): text type config
+ """
+ filename = osp.abspath(osp.expanduser(filename))
+ # TODO check exist
+ # TODO check suffix
+ ext_name = osp.splitext(filename)[-1]
+ with tempfile.TemporaryDirectory() as temp_config_dir:
+ temp_config_file = tempfile.NamedTemporaryFile(dir=temp_config_dir, suffix=ext_name)
+ temp_config_name = osp.basename(temp_config_file.name)
+ temp_config_file.close()
+ shutil.copyfile(filename, temp_config_file.name)
+
+ temp_module_name = osp.splitext(temp_config_name)[0]
+ sys.path.insert(0, temp_config_dir)
+ # TODO validate py syntax
+ module = import_module(temp_module_name)
+ cfg_dict = {k: v for k, v in module.__dict__.items() if not k.startswith('_')}
+ del sys.modules[temp_module_name]
+ sys.path.pop(0)
+
+ cfg_text = filename + '\n'
+ with open(filename, 'r') as f:
+ cfg_text += f.read()
+
+ return cfg_dict, cfg_text
+
+ @property
+ def cfg_dict(self) -> dict:
+ return self._cfg_dict
+
+
+def read_config_yaml(path: str) -> EasyDict:
+ """
+ Overview:
+ read configuration from path
+ Arguments:
+ - path (:obj:`str`): Path of source yaml
+ Returns:
+ - (:obj:`EasyDict`): Config data from this file with dict type
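+ Examples:
+ >>> # Hypothetical file path; the returned EasyDict supports attribute access.
+ >>> cfg = read_config_yaml('./config.yaml')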
+ """
+ with open(path, "r") as f:
+ config_ = yaml.safe_load(f)
+
+ return EasyDict(config_)
+
+
+def save_config_yaml(config_: dict, path: str) -> None:
+ """
+ Overview:
+ save configuration to path
+ Arguments:
+ - config (:obj:`dict`): Config dict
+ - path (:obj:`str`): Path of target yaml
+ """
+ config_string = json.dumps(config_)
+ with open(path, "w") as f:
+ yaml.safe_dump(json.loads(config_string), f)
+
+
+def save_config_py(config_: dict, path: str) -> None:
+ """
+ Overview:
+ save configuration to python file
+ Arguments:
+ - config (:obj:`dict`): Config dict
+ - path (:obj:`str`): Path of the target python file
+ """
+ # config_string = json.dumps(config_, indent=4)
+ config_string = str(config_)
+ from yapf.yapflib.yapf_api import FormatCode
+ config_string, _ = FormatCode(config_string)
+ config_string = config_string.replace('inf,', 'float("inf"),')
+ with open(path, "w") as f:
+ f.write('exp_config = ' + config_string)
+
+
+def read_config_directly(path: str) -> dict:
+ """
+ Overview:
+ Read configuration from a file path (currently only python files are supported) and directly return the result.
+ Arguments:
+ - path (:obj:`str`): Path of configuration file
+ Returns:
+ - cfg (:obj:`dict`): Configuration dict.
+ """
+ suffix = path.split('.')[-1]
+ if suffix == 'py':
+ return Config.file_to_dict(path).cfg_dict
+ else:
+ raise KeyError("invalid config file suffix: {}".format(suffix))
+
+
+def read_config(path: str) -> Tuple[dict, dict]:
+ """
+ Overview:
+ Read configuration from a file path (currently only python files are supported) and select the proper parts.
+ Arguments:
+ - path (:obj:`str`): Path of configuration file
+ Returns:
+ - cfg (:obj:`Tuple[dict, dict]`): A tuple of configuration dicts, divided into two parts: `main_config` \
+ and `create_cfg`.
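+ Examples:
+ >>> # Hypothetical config file declaring ``main_config`` and ``create_config``.
+ >>> main_config, create_config = read_config('./lunarlander_dqn_config.py')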
+ """
+ suffix = path.split('.')[-1]
+ if suffix == 'py':
+ cfg = Config.file_to_dict(path).cfg_dict
+ assert "main_config" in cfg, "Please make sure a 'main_config' variable is declared in config python file!"
+ assert "create_config" in cfg, "Please make sure a 'create_config' variable is declared in config python file!"
+ return cfg['main_config'], cfg['create_config']
+ else:
+ raise KeyError("invalid config file suffix: {}".format(suffix))
+
+
+def read_config_with_system(path: str) -> Tuple[dict, dict, dict]:
+ """
+ Overview:
+ Read configuration from a file path (currently only python files are supported) and select the proper parts.
+ Arguments:
+ - path (:obj:`str`): Path of configuration file
+ Returns:
+ - cfg (:obj:`Tuple[dict, dict, dict]`): A tuple of configuration dicts, divided into three parts: \
+ `main_config`, `create_cfg` and `system_config`.
+ """
+ suffix = path.split('.')[-1]
+ if suffix == 'py':
+ cfg = Config.file_to_dict(path).cfg_dict
+ assert "main_config" in cfg, "Please make sure a 'main_config' variable is declared in config python file!"
+ assert "create_config" in cfg, "Please make sure a 'create_config' variable is declared in config python file!"
+ assert "system_config" in cfg, "Please make sure a 'system_config' variable is declared in config python file!"
+ return cfg['main_config'], cfg['create_config'], cfg['system_config']
+ else:
+ raise KeyError("invalid config file suffix: {}".format(suffix))
+
+
+def save_config(config_: dict, path: str, type_: str = 'py', save_formatted: bool = False) -> None:
+ """
+ Overview:
+ save configuration to python file or yaml file
+ Arguments:
+ - config (:obj:`dict`): Config dict
+ - path (:obj:`str`): Path of target yaml or target python file
+ - type (:obj:`str`): If type is ``yaml``, save configuration to yaml file. If type is ``py``, save\
+ configuration to python file.
+ - save_formatted (:obj:`bool`): If save_formatted is true, save formatted config to path.\
+ Formatted config can be read by serial_pipeline directly.
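+ Examples:
+ >>> # A minimal sketch; the dict and path below are placeholders.
+ >>> save_config({'exp_name': 'demo'}, './total_config.py', type_='py')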
+ """
+ assert type_ in ['yaml', 'py'], type_
+ if type_ == 'yaml':
+ save_config_yaml(config_, path)
+ elif type_ == 'py':
+ save_config_py(config_, path)
+ if save_formatted:
+ formatted_path = osp.join(osp.dirname(path), 'formatted_' + osp.basename(path))
+ save_config_formatted(config_, formatted_path)
+
+
+def compile_buffer_config(policy_cfg: EasyDict, user_cfg: EasyDict, buffer_cls: 'IBuffer') -> EasyDict: # noqa
+
+ def _compile_buffer_config(policy_buffer_cfg, user_buffer_cfg, buffer_cls):
+
+ if buffer_cls is None:
+ assert 'type' in policy_buffer_cfg, "please indicate buffer type in create_cfg"
+ buffer_cls = get_buffer_cls(policy_buffer_cfg)
+ buffer_cfg = deep_merge_dicts(buffer_cls.default_config(), policy_buffer_cfg)
+ buffer_cfg = deep_merge_dicts(buffer_cfg, user_buffer_cfg)
+ return buffer_cfg
+
+ policy_multi_buffer = policy_cfg.other.replay_buffer.get('multi_buffer', False)
+ user_multi_buffer = user_cfg.policy.get('other', {}).get('replay_buffer', {}).get('multi_buffer', False)
+ assert not user_multi_buffer or user_multi_buffer == policy_multi_buffer, "For multi_buffer, \
+ user_cfg({}) and policy_cfg({}) must be in accordance".format(user_multi_buffer, policy_multi_buffer)
+ multi_buffer = policy_multi_buffer
+ if not multi_buffer:
+ policy_buffer_cfg = policy_cfg.other.replay_buffer
+ user_buffer_cfg = user_cfg.policy.get('other', {}).get('replay_buffer', {})
+ return _compile_buffer_config(policy_buffer_cfg, user_buffer_cfg, buffer_cls)
+ else:
+ return_cfg = EasyDict()
+ for buffer_name in policy_cfg.other.replay_buffer: # Only traverse keys in policy_cfg
+ if buffer_name == 'multi_buffer':
+ continue
+ policy_buffer_cfg = policy_cfg.other.replay_buffer[buffer_name]
+ user_buffer_cfg = user_cfg.policy.get('other', {}).get('replay_buffer', {}).get(buffer_name, {})
+ if buffer_cls is None:
+ return_cfg[buffer_name] = _compile_buffer_config(policy_buffer_cfg, user_buffer_cfg, None)
+ else:
+ return_cfg[buffer_name] = _compile_buffer_config(
+ policy_buffer_cfg, user_buffer_cfg, buffer_cls[buffer_name]
+ )
+ return_cfg[buffer_name].name = buffer_name
+ return return_cfg
+
+
+def compile_collector_config(
+ policy_cfg: EasyDict,
+ user_cfg: EasyDict,
+ collector_cls: 'ISerialCollector' # noqa
+) -> EasyDict:
+ policy_collector_cfg = policy_cfg.collect.collector
+ user_collector_cfg = user_cfg.policy.get('collect', {}).get('collector', {})
+ # step1: get collector class
+ # two sources: the create cfg already merged into policy_cfg, or an explicitly passed collector class (the class has higher priority)
+ if collector_cls is None:
+ assert 'type' in policy_collector_cfg, "please indicate collector type in create_cfg"
+ # use type to get collector_cls
+ collector_cls = get_serial_collector_cls(policy_collector_cfg)
+ # step2: policy collector cfg merge to collector cfg
+ collector_cfg = deep_merge_dicts(collector_cls.default_config(), policy_collector_cfg)
+ # step3: user collector cfg merge to the step2 config
+ collector_cfg = deep_merge_dicts(collector_cfg, user_collector_cfg)
+
+ return collector_cfg
+
+
+policy_config_template = dict(
+ model=dict(),
+ learn=dict(learner=dict()),
+ collect=dict(collector=dict()),
+ eval=dict(evaluator=dict()),
+ other=dict(replay_buffer=dict()),
+)
+policy_config_template = EasyDict(policy_config_template)
+env_config_template = dict(manager=dict(), stop_value=int(1e10), n_evaluator_episode=4)
+env_config_template = EasyDict(env_config_template)
+
+
+def save_project_state(exp_name: str) -> None:
+
+ def _fn(cmd: str):
+ return subprocess.run(cmd, shell=True, stdout=subprocess.PIPE).stdout.strip().decode("utf-8")
+
+ if subprocess.run("git status", shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE).returncode == 0:
+ short_sha = _fn("git describe --always")
+ log = _fn("git log --stat -n 5")
+ diff = _fn("git diff")
+ with open(os.path.join(exp_name, "git_log.txt"), "w", encoding='utf-8') as f:
+ f.write(short_sha + '\n\n' + log)
+ with open(os.path.join(exp_name, "git_diff.txt"), "w", encoding='utf-8') as f:
+ f.write(diff)
+
+
+def compile_config(
+ cfg: EasyDict,
+ env_manager: type = None,
+ policy: type = None,
+ learner: type = BaseLearner,
+ collector: type = None,
+ evaluator: type = InteractionSerialEvaluator,
+ buffer: type = None,
+ env: type = None,
+ reward_model: type = None,
+ world_model: type = None,
+ seed: int = 0,
+ auto: bool = False,
+ create_cfg: dict = None,
+ save_cfg: bool = True,
+ save_path: str = 'total_config.py',
+ renew_dir: bool = True,
+) -> EasyDict:
+ """
+ Overview:
+ Combine the input config with the other given components and compile it into a complete
+ config that can be consumed directly by the rest of the pipeline.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Input config dict which is to be used in the following pipeline
+ - env_manager (:obj:`type`): Env_manager class which is to be used in the following pipeline
+ - policy (:obj:`type`): Policy class which is to be used in the following pipeline
+ - learner (:obj:`type`): Input learner class, defaults to BaseLearner
+ - collector (:obj:`type`): Input collector class, defaults to BaseSerialCollector
+ - evaluator (:obj:`type`): Input evaluator class, defaults to InteractionSerialEvaluator
+ - buffer (:obj:`type`): Input buffer class, defaults to IBuffer
+ - env (:obj:`type`): Environment class which is to be used in the following pipeline
+ - reward_model (:obj:`type`): Reward model class which provides extra reward signals for training
+ - world_model (:obj:`type`): World model class which is to be used in the following pipeline
+ - seed (:obj:`int`): Random number seed
+ - auto (:obj:`bool`): Compile create_config dict or not
+ - create_cfg (:obj:`dict`): Input create config dict
+ - save_cfg (:obj:`bool`): Save config or not
+ - save_path (:obj:`str`): Path of saving file
+ - renew_dir (:obj:`bool`): Whether to create a new directory for saving the config.
+ Returns:
+ - cfg (:obj:`EasyDict`): Config after compiling
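+ Examples:
+ >>> # An illustrative sketch, assuming ``main_config`` and ``create_config`` come from ``read_config``.
+ >>> main_config, create_config = read_config('./lunarlander_dqn_config.py')
+ >>> cfg = compile_config(main_config, seed=0, auto=True, create_cfg=create_config, save_cfg=False)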
+ """
+ cfg, create_cfg = deepcopy(cfg), deepcopy(create_cfg)
+ if auto:
+ assert create_cfg is not None
+ # for compatibility
+ if 'collector' not in create_cfg:
+ create_cfg.collector = EasyDict(dict(type='sample'))
+ if 'replay_buffer' not in create_cfg:
+ create_cfg.replay_buffer = EasyDict(dict(type='advanced'))
+ buffer = AdvancedReplayBuffer
+ if env is None:
+ if 'env' in create_cfg:
+ env = get_env_cls(create_cfg.env)
+ else:
+ env = None
+ create_cfg.env = {'type': 'ding_env_wrapper_generated'}
+ if env_manager is None:
+ env_manager = get_env_manager_cls(create_cfg.env_manager)
+ if policy is None:
+ policy = get_policy_cls(create_cfg.policy)
+ if 'default_config' in dir(env):
+ env_config = env.default_config()
+ else:
+ env_config = EasyDict() # env does not have default_config
+ env_config = deep_merge_dicts(env_config_template, env_config)
+ env_config.update(create_cfg.env)
+ env_config.manager = deep_merge_dicts(env_manager.default_config(), env_config.manager)
+ env_config.manager.update(create_cfg.env_manager)
+ policy_config = policy.default_config()
+ policy_config = deep_merge_dicts(policy_config_template, policy_config)
+ policy_config.update(create_cfg.policy)
+ policy_config.collect.collector.update(create_cfg.collector)
+ if 'evaluator' in create_cfg:
+ policy_config.eval.evaluator.update(create_cfg.evaluator)
+ policy_config.other.replay_buffer.update(create_cfg.replay_buffer)
+
+ policy_config.other.commander = BaseSerialCommander.default_config()
+ if 'reward_model' in create_cfg:
+ reward_model = get_reward_model_cls(create_cfg.reward_model)
+ reward_model_config = reward_model.default_config()
+ else:
+ reward_model_config = EasyDict()
+ if 'world_model' in create_cfg:
+ world_model = get_world_model_cls(create_cfg.world_model)
+ world_model_config = world_model.default_config()
+ world_model_config.update(create_cfg.world_model)
+ else:
+ world_model_config = EasyDict()
+ else:
+ if 'default_config' in dir(env):
+ env_config = env.default_config()
+ else:
+ env_config = EasyDict() # env does not have default_config
+ env_config = deep_merge_dicts(env_config_template, env_config)
+ if env_manager is None:
+ env_manager = BaseEnvManager # for compatibility
+ env_config.manager = deep_merge_dicts(env_manager.default_config(), env_config.manager)
+ policy_config = policy.default_config()
+ policy_config = deep_merge_dicts(policy_config_template, policy_config)
+ if reward_model is None:
+ reward_model_config = EasyDict()
+ else:
+ reward_model_config = reward_model.default_config()
+ if world_model is None:
+ world_model_config = EasyDict()
+ else:
+ world_model_config = world_model.default_config()
+ world_model_config.update(create_cfg.world_model)
+ policy_config.learn.learner = deep_merge_dicts(
+ learner.default_config(),
+ policy_config.learn.learner,
+ )
+ if create_cfg is not None or collector is not None:
+ policy_config.collect.collector = compile_collector_config(policy_config, cfg, collector)
+ if evaluator:
+ policy_config.eval.evaluator = deep_merge_dicts(
+ evaluator.default_config(),
+ policy_config.eval.evaluator,
+ )
+ if create_cfg is not None or buffer is not None:
+ policy_config.other.replay_buffer = compile_buffer_config(policy_config, cfg, buffer)
+ default_config = EasyDict({'env': env_config, 'policy': policy_config})
+ if len(reward_model_config) > 0:
+ default_config['reward_model'] = reward_model_config
+ if len(world_model_config) > 0:
+ default_config['world_model'] = world_model_config
+ cfg = deep_merge_dicts(default_config, cfg)
+ if 'unroll_len' in cfg.policy:
+ cfg.policy.collect.unroll_len = cfg.policy.unroll_len
+ cfg.seed = seed
+ # check important key in config
+ if evaluator in [InteractionSerialEvaluator, BattleInteractionSerialEvaluator]: # env interaction evaluation
+ cfg.policy.eval.evaluator.stop_value = cfg.env.stop_value
+ cfg.policy.eval.evaluator.n_episode = cfg.env.n_evaluator_episode
+ if 'exp_name' not in cfg:
+ cfg.exp_name = 'default_experiment'
+ if save_cfg and get_rank() == 0:
+ if os.path.exists(cfg.exp_name) and renew_dir:
+ cfg.exp_name += datetime.datetime.now().strftime("_%y%m%d_%H%M%S")
+ try:
+ os.makedirs(cfg.exp_name)
+ except FileExistsError:
+ pass
+ save_project_state(cfg.exp_name)
+ save_path = os.path.join(cfg.exp_name, save_path)
+ save_config(cfg, save_path, save_formatted=True)
+ return cfg
+
+
+def compile_config_parallel(
+ cfg: EasyDict,
+ create_cfg: EasyDict,
+ system_cfg: EasyDict,
+ seed: int = 0,
+ save_cfg: bool = True,
+ save_path: str = 'total_config.py',
+ platform: str = 'local',
+ coordinator_host: Optional[str] = None,
+ learner_host: Optional[str] = None,
+ collector_host: Optional[str] = None,
+ coordinator_port: Optional[int] = None,
+ learner_port: Optional[int] = None,
+ collector_port: Optional[int] = None,
+) -> EasyDict:
+ """
+ Overview:
+ Combine the input parallel-mode configuration with the other given information and compile it into a\
+ complete config that can be consumed directly by the rest of the pipeline.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Input main config dict
+ - create_cfg (:obj:`dict`): Input create config dict, including type parameters, such as environment type
+ - system_cfg (:obj:`dict`): Input system config dict, including system parameters, such as file path,\
+ communication mode, use multiple GPUs or not
+ - seed (:obj:`int`): Random number seed
+ - save_cfg (:obj:`bool`): Save config or not
+ - save_path (:obj:`str`): Path of saving file
+ - platform (:obj:`str`): Where to run the program, 'local' or 'slurm'
+ - coordinator_host (:obj:`Optional[str]`): Input coordinator's host when platform is slurm
+ - learner_host (:obj:`Optional[str]`): Input learner's host when platform is slurm
+ - collector_host (:obj:`Optional[str]`): Input collector's host when platform is slurm
+ Returns:
+ - cfg (:obj:`EasyDict`): Config after compiling
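+ Examples:
+ >>> # An illustrative sketch, assuming the three configs come from ``read_config_with_system``.
+ >>> main_cfg, create_cfg, system_cfg = read_config_with_system('./parallel_config.py')
+ >>> cfg = compile_config_parallel(main_cfg, create_cfg=create_cfg, system_cfg=system_cfg, save_cfg=False)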
+ """
+ # for compatibility
+ if 'replay_buffer' not in create_cfg:
+ create_cfg.replay_buffer = EasyDict(dict(type='advanced'))
+ # env
+ env = get_env_cls(create_cfg.env)
+ if 'default_config' in dir(env):
+ env_config = env.default_config()
+ else:
+ env_config = EasyDict() # env does not have default_config
+ env_config = deep_merge_dicts(env_config_template, env_config)
+ env_config.update(create_cfg.env)
+
+ env_manager = get_env_manager_cls(create_cfg.env_manager)
+ env_config.manager = env_manager.default_config()
+ env_config.manager.update(create_cfg.env_manager)
+
+ # policy
+ policy = get_policy_cls(create_cfg.policy)
+ policy_config = policy.default_config()
+ policy_config = deep_merge_dicts(policy_config_template, policy_config)
+ cfg.policy.update(create_cfg.policy)
+
+ collector = get_parallel_collector_cls(create_cfg.collector)
+ policy_config.collect.collector = collector.default_config()
+ policy_config.collect.collector.update(create_cfg.collector)
+ policy_config.learn.learner = BaseLearner.default_config()
+ policy_config.learn.learner.update(create_cfg.learner)
+ commander = get_parallel_commander_cls(create_cfg.commander)
+ policy_config.other.commander = commander.default_config()
+ policy_config.other.commander.update(create_cfg.commander)
+ policy_config.other.replay_buffer.update(create_cfg.replay_buffer)
+ policy_config.other.replay_buffer = compile_buffer_config(policy_config, cfg, None)
+
+ default_config = EasyDict({'env': env_config, 'policy': policy_config})
+ cfg = deep_merge_dicts(default_config, cfg)
+
+ cfg.policy.other.commander.path_policy = system_cfg.path_policy # league may use 'path_policy'
+
+ # system
+ for k in ['comm_learner', 'comm_collector']:
+ system_cfg[k] = create_cfg[k]
+ if platform == 'local':
+ cfg = parallel_transform(EasyDict({'main': cfg, 'system': system_cfg}))
+ elif platform == 'slurm':
+ cfg = parallel_transform_slurm(
+ EasyDict({
+ 'main': cfg,
+ 'system': system_cfg
+ }), coordinator_host, learner_host, collector_host
+ )
+ elif platform == 'k8s':
+ cfg = parallel_transform_k8s(
+ EasyDict({
+ 'main': cfg,
+ 'system': system_cfg
+ }),
+ coordinator_port=coordinator_port,
+ learner_port=learner_port,
+ collector_port=collector_port
+ )
+ else:
+ raise KeyError("not support platform type: {}".format(platform))
+ cfg.system.coordinator = deep_merge_dicts(Coordinator.default_config(), cfg.system.coordinator)
+ # seed
+ cfg.seed = seed
+
+ if save_cfg:
+ save_config(cfg, save_path)
+ return cfg
diff --git a/DI-engine/ding/config/example/A2C/__init__.py b/DI-engine/ding/config/example/A2C/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..37beafc2486aa3d560ac233d654210f19070692a
--- /dev/null
+++ b/DI-engine/ding/config/example/A2C/__init__.py
@@ -0,0 +1,17 @@
+from easydict import EasyDict
+from . import gym_bipedalwalker_v3
+from . import gym_lunarlander_v2
+
+supported_env_cfg = {
+ gym_bipedalwalker_v3.cfg.env.env_id: gym_bipedalwalker_v3.cfg,
+ gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.cfg,
+}
+
+supported_env_cfg = EasyDict(supported_env_cfg)
+
+supported_env = {
+ gym_bipedalwalker_v3.cfg.env.env_id: gym_bipedalwalker_v3.env,
+ gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.env,
+}
+
+supported_env = EasyDict(supported_env)
diff --git a/DI-engine/ding/config/example/A2C/gym_bipedalwalker_v3.py b/DI-engine/ding/config/example/A2C/gym_bipedalwalker_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0f9637cec049bc13063e3a8e89b9596a5938088
--- /dev/null
+++ b/DI-engine/ding/config/example/A2C/gym_bipedalwalker_v3.py
@@ -0,0 +1,43 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='Bipedalwalker-v3-A2C',
+ seed=0,
+ env=dict(
+ env_id='BipedalWalker-v3',
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ act_scale=True,
+ rew_clip=True,
+ ),
+ policy=dict(
+ cuda=True,
+ action_space='continuous',
+ model=dict(
+ action_space='continuous',
+ obs_shape=24,
+ action_shape=4,
+ ),
+ learn=dict(
+ batch_size=64,
+ learning_rate=0.0003,
+ value_weight=0.7,
+ entropy_weight=0.0005,
+ discount_factor=0.99,
+ adv_norm=True,
+ ),
+ collect=dict(
+ n_sample=64,
+ discount_factor=0.99,
+ ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/A2C/gym_lunarlander_v2.py b/DI-engine/ding/config/example/A2C/gym_lunarlander_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..6092bb412628708a1499fbe9918f89af7095ed63
--- /dev/null
+++ b/DI-engine/ding/config/example/A2C/gym_lunarlander_v2.py
@@ -0,0 +1,38 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='LunarLander-v2-A2C',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ env_id='LunarLander-v2',
+ n_evaluator_episode=8,
+ stop_value=260,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ ),
+ learn=dict(
+ batch_size=64,
+ learning_rate=3e-4,
+ entropy_weight=0.001,
+ adv_norm=True,
+ ),
+ collect=dict(
+ n_sample=64,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/C51/__init__.py b/DI-engine/ding/config/example/C51/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2704b04c5346fb524c8faa01dd302432633eb1df
--- /dev/null
+++ b/DI-engine/ding/config/example/C51/__init__.py
@@ -0,0 +1,23 @@
+from easydict import EasyDict
+from . import gym_lunarlander_v2
+from . import gym_pongnoframeskip_v4
+from . import gym_qbertnoframeskip_v4
+from . import gym_spaceInvadersnoframeskip_v4
+
+supported_env_cfg = {
+ gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.cfg,
+ gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.cfg,
+ gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.cfg,
+ gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.cfg,
+}
+
+supported_env_cfg = EasyDict(supported_env_cfg)
+
+supported_env = {
+ gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.env,
+ gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.env,
+ gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.env,
+ gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.env,
+}
+
+supported_env = EasyDict(supported_env)
diff --git a/DI-engine/ding/config/example/C51/gym_lunarlander_v2.py b/DI-engine/ding/config/example/C51/gym_lunarlander_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f929a964fa01cec0cb2302cdd678e7488ef836f
--- /dev/null
+++ b/DI-engine/ding/config/example/C51/gym_lunarlander_v2.py
@@ -0,0 +1,52 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='lunarlander_c51',
+ seed=0,
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ env_id='LunarLander-v2',
+ n_evaluator_episode=8,
+ stop_value=260,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ encoder_hidden_size_list=[512, 64],
+ v_min=-30,
+ v_max=30,
+ n_atom=51,
+ ),
+ discount_factor=0.99,
+ nstep=3,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=64,
+ learning_rate=0.001,
+ target_update_freq=100,
+ ),
+ collect=dict(
+ n_sample=64,
+ unroll_len=1,
+ ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=50000,
+ ), replay_buffer=dict(replay_buffer_size=100000, )
+ ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/C51/gym_pongnoframeskip_v4.py b/DI-engine/ding/config/example/C51/gym_pongnoframeskip_v4.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7b62b5fa85f580615f0ee8a8185f8b1c89d436c
--- /dev/null
+++ b/DI-engine/ding/config/example/C51/gym_pongnoframeskip_v4.py
@@ -0,0 +1,54 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='PongNoFrameskip-v4-C51',
+ seed=0,
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=30,
+ env_id='PongNoFrameskip-v4',
+ frame_stack=4,
+ env_wrapper='atari_default',
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ v_min=-10,
+ v_max=10,
+ n_atom=51,
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=250000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/C51/gym_qbertnoframeskip_v4.py b/DI-engine/ding/config/example/C51/gym_qbertnoframeskip_v4.py
new file mode 100644
index 0000000000000000000000000000000000000000..21c3bcf7ea1003425015efa36ede592f53b053b4
--- /dev/null
+++ b/DI-engine/ding/config/example/C51/gym_qbertnoframeskip_v4.py
@@ -0,0 +1,54 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='QbertNoFrameskip-v4-C51',
+ seed=0,
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=30000,
+ env_id='QbertNoFrameskip-v4',
+ frame_stack=4,
+ env_wrapper='atari_default',
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ v_min=-10,
+ v_max=10,
+ n_atom=51,
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/C51/gym_spaceInvadersnoframeskip_v4.py b/DI-engine/ding/config/example/C51/gym_spaceInvadersnoframeskip_v4.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2fcf431c3a7bc82d81634b4179537523c7785af
--- /dev/null
+++ b/DI-engine/ding/config/example/C51/gym_spaceInvadersnoframeskip_v4.py
@@ -0,0 +1,54 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='SpaceInvadersNoFrameskip-v4-C51',
+ seed=0,
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='SpaceInvadersNoFrameskip-v4',
+ frame_stack=4,
+ env_wrapper='atari_default',
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ v_min=-10,
+ v_max=10,
+ n_atom=51,
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/DDPG/__init__.py b/DI-engine/ding/config/example/DDPG/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e01f29d7452ca48b57800abd0d8893703d6b4e8
--- /dev/null
+++ b/DI-engine/ding/config/example/DDPG/__init__.py
@@ -0,0 +1,29 @@
+from easydict import EasyDict
+from . import gym_bipedalwalker_v3
+from . import gym_halfcheetah_v3
+from . import gym_hopper_v3
+from . import gym_lunarlandercontinuous_v2
+from . import gym_pendulum_v1
+from . import gym_walker2d_v3
+
+supported_env_cfg = {
+ gym_bipedalwalker_v3.cfg.env.env_id: gym_bipedalwalker_v3.cfg,
+ gym_halfcheetah_v3.cfg.env.env_id: gym_halfcheetah_v3.cfg,
+ gym_hopper_v3.cfg.env.env_id: gym_hopper_v3.cfg,
+ gym_lunarlandercontinuous_v2.cfg.env.env_id: gym_lunarlandercontinuous_v2.cfg,
+ gym_pendulum_v1.cfg.env.env_id: gym_pendulum_v1.cfg,
+ gym_walker2d_v3.cfg.env.env_id: gym_walker2d_v3.cfg,
+}
+
+supported_env_cfg = EasyDict(supported_env_cfg)
+
+supported_env = {
+ gym_bipedalwalker_v3.cfg.env.env_id: gym_bipedalwalker_v3.env,
+ gym_halfcheetah_v3.cfg.env.env_id: gym_halfcheetah_v3.env,
+ gym_hopper_v3.cfg.env.env_id: gym_hopper_v3.env,
+ gym_lunarlandercontinuous_v2.cfg.env.env_id: gym_lunarlandercontinuous_v2.env,
+ gym_pendulum_v1.cfg.env.env_id: gym_pendulum_v1.env,
+ gym_walker2d_v3.cfg.env.env_id: gym_walker2d_v3.env,
+}
+
+supported_env = EasyDict(supported_env)
diff --git a/DI-engine/ding/config/example/DDPG/gym_bipedalwalker_v3.py b/DI-engine/ding/config/example/DDPG/gym_bipedalwalker_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd26d46f683ef51ffcad617409880cda3f6e1eee
--- /dev/null
+++ b/DI-engine/ding/config/example/DDPG/gym_bipedalwalker_v3.py
@@ -0,0 +1,45 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='Bipedalwalker-v3-DDPG',
+ seed=0,
+ env=dict(
+ env_id='BipedalWalker-v3',
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ act_scale=True,
+ rew_clip=True,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=24,
+ action_shape=4,
+ twin_critic=False,
+ action_space='regression',
+ actor_head_hidden_size=400,
+ critic_head_hidden_size=400,
+ ),
+ learn=dict(
+ update_per_collect=64,
+ batch_size=256,
+ learning_rate_actor=0.0003,
+ learning_rate_critic=0.0003,
+ target_theta=0.005,
+ discount_factor=0.99,
+ learner=dict(hook=dict(log_show_after_iter=1000, ))
+ ),
+ collect=dict(n_sample=64, ),
+ other=dict(replay_buffer=dict(replay_buffer_size=300000, ), ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/DDPG/gym_halfcheetah_v3.py b/DI-engine/ding/config/example/DDPG/gym_halfcheetah_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..2bbf075a03f89aaa65af1860c5cfd67a4bdef087
--- /dev/null
+++ b/DI-engine/ding/config/example/DDPG/gym_halfcheetah_v3.py
@@ -0,0 +1,53 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='HalfCheetah-v3-DDPG',
+ seed=0,
+ env=dict(
+ env_id='HalfCheetah-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=11000,
+ env_wrapper='mujoco_default',
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=25000,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ twin_critic=False,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ action_space='regression',
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_actor=1e-3,
+ learning_rate_critic=1e-3,
+ ignore_done=True,
+ target_theta=0.005,
+ discount_factor=0.99,
+ actor_update_freq=1,
+ noise=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ noise_sigma=0.1,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/DDPG/gym_hopper_v3.py b/DI-engine/ding/config/example/DDPG/gym_hopper_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd8d2538070671b75b7f04a16e5cd0c5121fd131
--- /dev/null
+++ b/DI-engine/ding/config/example/DDPG/gym_hopper_v3.py
@@ -0,0 +1,53 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='Hopper-v3-DDPG',
+ seed=0,
+ env=dict(
+ env_id='Hopper-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ env_wrapper='mujoco_default',
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=25000,
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ twin_critic=False,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ action_space='regression',
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_actor=1e-3,
+ learning_rate_critic=1e-3,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ actor_update_freq=1,
+ noise=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ noise_sigma=0.1,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/DDPG/gym_lunarlandercontinuous_v2.py b/DI-engine/ding/config/example/DDPG/gym_lunarlandercontinuous_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..89b92e8de1ebddfea93033e77992c02110d7c071
--- /dev/null
+++ b/DI-engine/ding/config/example/DDPG/gym_lunarlandercontinuous_v2.py
@@ -0,0 +1,60 @@
+from easydict import EasyDict
+from functools import partial
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='LunarLanderContinuous-V2-DDPG',
+ seed=0,
+ env=dict(
+ env_id='LunarLanderContinuous-v2',
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=260,
+ act_scale=True,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=0,
+ model=dict(
+ obs_shape=8,
+ action_shape=2,
+ twin_critic=True,
+ action_space='regression',
+ ),
+ learn=dict(
+ update_per_collect=2,
+ batch_size=128,
+ learning_rate_actor=0.001,
+ learning_rate_critic=0.001,
+ ignore_done=False, # TODO(pu)
+ # (int) When critic network updates once, how many times will actor network update.
+ # Delayed Policy Updates in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
+ # Default 1 for DDPG, 2 for TD3.
+ actor_update_freq=1,
+ # (bool) Whether to add noise on target network's action.
+ # Target Policy Smoothing Regularization in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
+ # Default True for TD3, False for DDPG.
+ noise=False,
+ noise_sigma=0.1,
+ noise_range=dict(
+ min=-0.5,
+ max=0.5,
+ ),
+ ),
+ collect=dict(
+ n_sample=48,
+ noise_sigma=0.1,
+ collector=dict(collect_print_freq=1000, ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ other=dict(replay_buffer=dict(replay_buffer_size=20000, ), ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = partial(ding.envs.gym_env.env, continuous=True)
diff --git a/DI-engine/ding/config/example/DDPG/gym_pendulum_v1.py b/DI-engine/ding/config/example/DDPG/gym_pendulum_v1.py
new file mode 100644
index 0000000000000000000000000000000000000000..63e85869e76561d35627646e9b09173389268c62
--- /dev/null
+++ b/DI-engine/ding/config/example/DDPG/gym_pendulum_v1.py
@@ -0,0 +1,52 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='Pendulum-v1-DDPG',
+ seed=0,
+ env=dict(
+ env_id='Pendulum-v1',
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=-250,
+ act_scale=True,
+ ),
+ policy=dict(
+ cuda=False,
+ priority=False,
+ random_collect_size=800,
+ model=dict(
+ obs_shape=3,
+ action_shape=1,
+ twin_critic=False,
+ action_space='regression',
+ ),
+ learn=dict(
+ update_per_collect=2,
+ batch_size=128,
+ learning_rate_actor=0.001,
+ learning_rate_critic=0.001,
+ ignore_done=True,
+ actor_update_freq=1,
+ noise=False,
+ ),
+ collect=dict(
+ n_sample=48,
+ noise_sigma=0.1,
+ collector=dict(collect_print_freq=1000, ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, )),
+ other=dict(replay_buffer=dict(
+ replay_buffer_size=20000,
+ max_use=16,
+ ), ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/DDPG/gym_walker2d_v3.py b/DI-engine/ding/config/example/DDPG/gym_walker2d_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..84e6407de4f8c062556a59609aba587e1f4f4f11
--- /dev/null
+++ b/DI-engine/ding/config/example/DDPG/gym_walker2d_v3.py
@@ -0,0 +1,53 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='Walker2d-v3-DDPG',
+ seed=0,
+ env=dict(
+ env_id='Walker2d-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ env_wrapper='mujoco_default',
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=25000,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ twin_critic=False,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ action_space='regression',
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_actor=1e-3,
+ learning_rate_critic=1e-3,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ actor_update_freq=1,
+ noise=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ noise_sigma=0.1,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/DQN/__init__.py b/DI-engine/ding/config/example/DQN/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2704b04c5346fb524c8faa01dd302432633eb1df
--- /dev/null
+++ b/DI-engine/ding/config/example/DQN/__init__.py
@@ -0,0 +1,23 @@
+from easydict import EasyDict
+from . import gym_lunarlander_v2
+from . import gym_pongnoframeskip_v4
+from . import gym_qbertnoframeskip_v4
+from . import gym_spaceInvadersnoframeskip_v4
+
+supported_env_cfg = {
+ gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.cfg,
+ gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.cfg,
+ gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.cfg,
+ gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.cfg,
+}
+
+supported_env_cfg = EasyDict(supported_env_cfg)
+
+supported_env = {
+ gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.env,
+ gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.env,
+ gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.env,
+ gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.env,
+}
+
+supported_env = EasyDict(supported_env)
diff --git a/DI-engine/ding/config/example/DQN/gym_lunarlander_v2.py b/DI-engine/ding/config/example/DQN/gym_lunarlander_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b79a4eeaafe961677a0ea178e99990748dc3ff5
--- /dev/null
+++ b/DI-engine/ding/config/example/DQN/gym_lunarlander_v2.py
@@ -0,0 +1,53 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='LunarLander-v2-DQN',
+ seed=0,
+ env=dict(
+ env_id='LunarLander-v2',
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=260,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=25000,
+ discount_factor=0.99,
+ nstep=3,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=64,
+ learning_rate=0.001,
+ # Frequency of target network update.
+ target_update_freq=100,
+ ),
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ encoder_hidden_size_list=[512, 64],
+ # Whether to use dueling head.
+ dueling=True,
+ ),
+ collect=dict(
+ n_sample=64,
+ unroll_len=1,
+ ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=50000,
+ ), replay_buffer=dict(replay_buffer_size=100000, )
+ ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/DQN/gym_pongnoframeskip_v4.py b/DI-engine/ding/config/example/DQN/gym_pongnoframeskip_v4.py
new file mode 100644
index 0000000000000000000000000000000000000000..0266783ac57f050def11f9813cccf43a4cd1b22a
--- /dev/null
+++ b/DI-engine/ding/config/example/DQN/gym_pongnoframeskip_v4.py
@@ -0,0 +1,50 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='PongNoFrameskip-v4-DQN',
+ seed=0,
+ env=dict(
+ env_id='PongNoFrameskip-v4',
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=30,
+ frame_stack=4,
+ env_wrapper='atari_default',
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ discount_factor=0.99,
+ nstep=3,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ # Frequency of target network update.
+ target_update_freq=500,
+ ),
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ collect=dict(n_sample=96, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=250000,
+ ), replay_buffer=dict(replay_buffer_size=100000, )
+ ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/DQN/gym_qbertnoframeskip_v4.py b/DI-engine/ding/config/example/DQN/gym_qbertnoframeskip_v4.py
new file mode 100644
index 0000000000000000000000000000000000000000..e782a12e9caa1424bfba2a29eab36c99fd91ae17
--- /dev/null
+++ b/DI-engine/ding/config/example/DQN/gym_qbertnoframeskip_v4.py
@@ -0,0 +1,50 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='QbertNoFrameskip-v4-DQN',
+ seed=0,
+ env=dict(
+ env_id='QbertNoFrameskip-v4',
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ frame_stack=4,
+ stop_value=30000,
+ env_wrapper='atari_default',
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ discount_factor=0.99,
+ nstep=3,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ # Frequency of target network update.
+ target_update_freq=500,
+ ),
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ collect=dict(n_sample=100, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=1000000,
+ ), replay_buffer=dict(replay_buffer_size=400000, )
+ ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/DQN/gym_spaceInvadersnoframeskip_v4.py b/DI-engine/ding/config/example/DQN/gym_spaceInvadersnoframeskip_v4.py
new file mode 100644
index 0000000000000000000000000000000000000000..f69a61e6dd73da2e32e76ac03d65ab0287941df5
--- /dev/null
+++ b/DI-engine/ding/config/example/DQN/gym_spaceInvadersnoframeskip_v4.py
@@ -0,0 +1,51 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='SpaceInvadersNoFrameskip-v4-DQN',
+ seed=0,
+ env=dict(
+ env_id='SpaceInvadersNoFrameskip-v4',
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ frame_stack=4,
+ stop_value=2000,
+ env_wrapper='atari_default',
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ discount_factor=0.99,
+ nstep=3,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ # Frequency of target network update.
+ target_update_freq=500,
+ hook=dict(save_ckpt_after_iter=1000000, )
+ ),
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ collect=dict(n_sample=100, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=1000000,
+ ), replay_buffer=dict(replay_buffer_size=400000, )
+ ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/PG/__init__.py b/DI-engine/ding/config/example/PG/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..518449884f53d38468f092ff949eaae4c4b9cc6c
--- /dev/null
+++ b/DI-engine/ding/config/example/PG/__init__.py
@@ -0,0 +1,14 @@
+from easydict import EasyDict
+from . import gym_pendulum_v1
+
+supported_env_cfg = {
+ gym_pendulum_v1.cfg.env.env_id: gym_pendulum_v1.cfg,
+}
+
+supported_env_cfg = EasyDict(supported_env_cfg)
+
+supported_env = {
+ gym_pendulum_v1.cfg.env.env_id: gym_pendulum_v1.env,
+}
+
+supported_env = EasyDict(supported_env)
diff --git a/DI-engine/ding/config/example/PG/gym_pendulum_v1.py b/DI-engine/ding/config/example/PG/gym_pendulum_v1.py
new file mode 100644
index 0000000000000000000000000000000000000000..59b3e31eb865702d3968e2097e2482df3502512a
--- /dev/null
+++ b/DI-engine/ding/config/example/PG/gym_pendulum_v1.py
@@ -0,0 +1,42 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='Pendulum-v1-PG',
+ seed=0,
+ env=dict(
+ env_id='Pendulum-v1',
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=-200,
+ act_scale=True,
+ ),
+ policy=dict(
+ cuda=False,
+ action_space='continuous',
+ model=dict(
+ action_space='continuous',
+ obs_shape=3,
+ action_shape=1,
+ ),
+ learn=dict(
+ batch_size=4000,
+ learning_rate=0.001,
+ entropy_weight=0.001,
+ ),
+ collect=dict(
+ n_episode=20,
+ unroll_len=1,
+ discount_factor=0.99,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1, ))
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/PPOF/__init__.py b/DI-engine/ding/config/example/PPOF/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2adaaf4df23fa892b117adc8694b6a1ba88dd1bd
--- /dev/null
+++ b/DI-engine/ding/config/example/PPOF/__init__.py
@@ -0,0 +1,17 @@
+from easydict import EasyDict
+from . import gym_lunarlander_v2
+from . import gym_lunarlandercontinuous_v2
+
+supported_env_cfg = {
+ gym_lunarlander_v2.cfg.env_id: gym_lunarlander_v2.cfg,
+ gym_lunarlandercontinuous_v2.cfg.env_id: gym_lunarlandercontinuous_v2.cfg,
+}
+
+supported_env_cfg = EasyDict(supported_env_cfg)
+
+supported_env = {
+ gym_lunarlander_v2.cfg.env_id: gym_lunarlander_v2.env,
+ gym_lunarlandercontinuous_v2.cfg.env_id: gym_lunarlandercontinuous_v2.env,
+}
+
+supported_env = EasyDict(supported_env)
diff --git a/DI-engine/ding/config/example/PPOF/gym_lunarlander_v2.py b/DI-engine/ding/config/example/PPOF/gym_lunarlander_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a05266277106ddce245aecc164b8d0930a9ed75
--- /dev/null
+++ b/DI-engine/ding/config/example/PPOF/gym_lunarlander_v2.py
@@ -0,0 +1,13 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='LunarLander-v2-PPO',
+ env_id='LunarLander-v2',
+ n_sample=400,
+ value_norm='popart',
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/PPOF/gym_lunarlandercontinuous_v2.py b/DI-engine/ding/config/example/PPOF/gym_lunarlandercontinuous_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..c12c88fd42c017a81592111c1523a4a102fe933b
--- /dev/null
+++ b/DI-engine/ding/config/example/PPOF/gym_lunarlandercontinuous_v2.py
@@ -0,0 +1,15 @@
+from easydict import EasyDict
+from functools import partial
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='LunarLanderContinuous-V2-PPO',
+ env_id='LunarLanderContinuous-v2',
+ action_space='continuous',
+ n_sample=400,
+ act_scale=True,
+)
+
+cfg = EasyDict(cfg)
+
+env = partial(ding.envs.gym_env.env, continuous=True)
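+
+
+# ``partial`` above pre-binds ``continuous=True`` so that every call of ``env`` requests the
+# continuous-action variant. A self-contained sketch of the same pattern with a stand-in
+# factory (``make_env`` is hypothetical and only for illustration):
+if __name__ == "__main__":
+    def make_env(env_id, continuous=False):
+        return {'env_id': env_id, 'continuous': continuous}
+
+    continuous_env = partial(make_env, continuous=True)
+    print(continuous_env('LunarLanderContinuous-v2'))  # {'env_id': ..., 'continuous': True}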
diff --git a/DI-engine/ding/config/example/PPOOffPolicy/__init__.py b/DI-engine/ding/config/example/PPOOffPolicy/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2704b04c5346fb524c8faa01dd302432633eb1df
--- /dev/null
+++ b/DI-engine/ding/config/example/PPOOffPolicy/__init__.py
@@ -0,0 +1,23 @@
+from easydict import EasyDict
+from . import gym_lunarlander_v2
+from . import gym_pongnoframeskip_v4
+from . import gym_qbertnoframeskip_v4
+from . import gym_spaceInvadersnoframeskip_v4
+
+supported_env_cfg = {
+ gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.cfg,
+ gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.cfg,
+ gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.cfg,
+ gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.cfg,
+}
+
+supported_env_cfg = EasyDict(supported_env_cfg)
+
+supported_env = {
+ gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.env,
+ gym_pongnoframeskip_v4.cfg.env.env_id: gym_pongnoframeskip_v4.env,
+ gym_qbertnoframeskip_v4.cfg.env.env_id: gym_qbertnoframeskip_v4.env,
+ gym_spaceInvadersnoframeskip_v4.cfg.env.env_id: gym_spaceInvadersnoframeskip_v4.env,
+}
+
+supported_env = EasyDict(supported_env)
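+
+
+# A demonstration helper, not part of any DI-engine API: it lists the environment ids bundled
+# with this algorithm alongside the experiment names of their example configs.
+def _print_supported_envs() -> None:
+    for env_id, env_cfg in supported_env_cfg.items():
+        print('{:<30} -> {}'.format(env_id, env_cfg.exp_name))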
diff --git a/DI-engine/ding/config/example/PPOOffPolicy/gym_lunarlander_v2.py b/DI-engine/ding/config/example/PPOOffPolicy/gym_lunarlander_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..1db3551b5011d986bf67da7a0a7ea9b502efa85e
--- /dev/null
+++ b/DI-engine/ding/config/example/PPOOffPolicy/gym_lunarlander_v2.py
@@ -0,0 +1,44 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='LunarLander-v2-PPOOffPolicy',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ env_id='LunarLander-v2',
+ n_evaluator_episode=8,
+ stop_value=260,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ ),
+ learn=dict(
+ update_per_collect=4,
+ batch_size=64,
+ learning_rate=0.001,
+ value_weight=0.5,
+ entropy_weight=0.01,
+ clip_ratio=0.2,
+ nstep=1,
+ nstep_return=False,
+ adv_norm=True,
+ ),
+ collect=dict(
+ n_sample=128,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
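+
+
+# Rough sample-reuse arithmetic for the settings above, assuming the usual off-policy loop of
+# ``update_per_collect`` gradient steps after every collection of ``n_sample`` transitions:
+# training consumes 4 * 64 = 256 samples per 128 collected transitions, i.e. about a 2x replay ratio.
+if __name__ == "__main__":
+    consumed = cfg.policy.learn.update_per_collect * cfg.policy.learn.batch_size
+    print('replay ratio ~= {:.1f}'.format(consumed / cfg.policy.collect.n_sample))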
diff --git a/DI-engine/ding/config/example/PPOOffPolicy/gym_pongnoframeskip_v4.py b/DI-engine/ding/config/example/PPOOffPolicy/gym_pongnoframeskip_v4.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f376e26516ef194f4851ed9a90a52e4aa60ae36
--- /dev/null
+++ b/DI-engine/ding/config/example/PPOOffPolicy/gym_pongnoframeskip_v4.py
@@ -0,0 +1,54 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='PongNoFrameskip-v4-PPOOffPolicy',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=30,
+ env_id='PongNoFrameskip-v4',
+ frame_stack=4,
+ env_wrapper='atari_default',
+ ),
+ policy=dict(
+ cuda=True,
+ recompute_adv=True,
+ action_space='discrete',
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ action_space='discrete',
+ encoder_hidden_size_list=[64, 64, 128],
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ ),
+ learn=dict(
+ update_per_collect=10,
+ batch_size=320,
+ learning_rate=3e-4,
+ value_weight=0.5,
+ entropy_weight=0.001,
+ clip_ratio=0.2,
+ adv_norm=True,
+ # value_norm=True,
+ ignore_done=False,
+ grad_clip_type='clip_norm',
+ grad_clip_value=0.5,
+ ),
+ collect=dict(
+ n_sample=3200,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/PPOOffPolicy/gym_qbertnoframeskip_v4.py b/DI-engine/ding/config/example/PPOOffPolicy/gym_qbertnoframeskip_v4.py
new file mode 100644
index 0000000000000000000000000000000000000000..7272ffebd686b3849a81055be98ed420e797cfb2
--- /dev/null
+++ b/DI-engine/ding/config/example/PPOOffPolicy/gym_qbertnoframeskip_v4.py
@@ -0,0 +1,48 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='QbertNoFrameskip-v4-PPOOffPolicy',
+ env=dict(
+ collector_env_num=16,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='QbertNoFrameskip-v4',
+ frame_stack=4,
+ env_wrapper='atari_default',
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[32, 64, 64, 128],
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ critic_head_layer_num=2,
+ ),
+ learn=dict(
+ update_per_collect=18,
+ batch_size=128,
+ learning_rate=0.0001,
+ value_weight=1.0,
+ entropy_weight=0.005,
+ clip_ratio=0.1,
+ adv_norm=False,
+ ),
+ collect=dict(
+ n_sample=1024,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/PPOOffPolicy/gym_spaceInvadersnoframeskip_v4.py b/DI-engine/ding/config/example/PPOOffPolicy/gym_spaceInvadersnoframeskip_v4.py
new file mode 100644
index 0000000000000000000000000000000000000000..18558553ace70f464355f99078543add450dc2f9
--- /dev/null
+++ b/DI-engine/ding/config/example/PPOOffPolicy/gym_spaceInvadersnoframeskip_v4.py
@@ -0,0 +1,48 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='SpaceInvadersNoFrameskip-v4-PPOOffPolicy',
+ env=dict(
+ collector_env_num=16,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='SpaceInvadersNoFrameskip-v4',
+ frame_stack=4,
+ env_wrapper='atari_default',
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[32, 64, 64, 128],
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ critic_head_layer_num=2,
+ ),
+ learn=dict(
+ update_per_collect=24,
+ batch_size=128,
+ learning_rate=0.0001,
+ value_weight=1.0,
+ entropy_weight=0.03,
+ clip_ratio=0.1,
+ adv_norm=False,
+ ),
+ collect=dict(
+ n_sample=1024,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/SAC/__init__.py b/DI-engine/ding/config/example/SAC/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e01f29d7452ca48b57800abd0d8893703d6b4e8
--- /dev/null
+++ b/DI-engine/ding/config/example/SAC/__init__.py
@@ -0,0 +1,29 @@
+from easydict import EasyDict
+from . import gym_bipedalwalker_v3
+from . import gym_halfcheetah_v3
+from . import gym_hopper_v3
+from . import gym_lunarlandercontinuous_v2
+from . import gym_pendulum_v1
+from . import gym_walker2d_v3
+
+supported_env_cfg = {
+ gym_bipedalwalker_v3.cfg.env.env_id: gym_bipedalwalker_v3.cfg,
+ gym_halfcheetah_v3.cfg.env.env_id: gym_halfcheetah_v3.cfg,
+ gym_hopper_v3.cfg.env.env_id: gym_hopper_v3.cfg,
+ gym_lunarlandercontinuous_v2.cfg.env.env_id: gym_lunarlandercontinuous_v2.cfg,
+ gym_pendulum_v1.cfg.env.env_id: gym_pendulum_v1.cfg,
+ gym_walker2d_v3.cfg.env.env_id: gym_walker2d_v3.cfg,
+}
+
+supported_env_cfg = EasyDict(supported_env_cfg)
+
+supported_env = {
+ gym_bipedalwalker_v3.cfg.env.env_id: gym_bipedalwalker_v3.env,
+ gym_halfcheetah_v3.cfg.env.env_id: gym_halfcheetah_v3.env,
+ gym_hopper_v3.cfg.env.env_id: gym_hopper_v3.env,
+ gym_lunarlandercontinuous_v2.cfg.env.env_id: gym_lunarlandercontinuous_v2.env,
+ gym_pendulum_v1.cfg.env.env_id: gym_pendulum_v1.env,
+ gym_walker2d_v3.cfg.env.env_id: gym_walker2d_v3.env,
+}
+
+supported_env = EasyDict(supported_env)
diff --git a/DI-engine/ding/config/example/SAC/gym_bipedalwalker_v3.py b/DI-engine/ding/config/example/SAC/gym_bipedalwalker_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..d97f131fd832541c26bcf202b6ff7c022ed36a0b
--- /dev/null
+++ b/DI-engine/ding/config/example/SAC/gym_bipedalwalker_v3.py
@@ -0,0 +1,47 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='BipedalWalker-v3-SAC',
+ seed=0,
+ env=dict(
+ env_id='BipedalWalker-v3',
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ act_scale=True,
+ rew_clip=True,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=24,
+ action_shape=4,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ ),
+ learn=dict(
+ update_per_collect=64,
+ batch_size=256,
+ learning_rate_q=0.0003,
+ learning_rate_policy=0.0003,
+ learning_rate_alpha=0.0003,
+ target_theta=0.005,
+ discount_factor=0.99,
+ auto_alpha=True,
+ learner=dict(hook=dict(log_show_after_iter=1000, ))
+ ),
+ collect=dict(n_sample=64, ),
+ other=dict(replay_buffer=dict(replay_buffer_size=300000, ), ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
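+
+
+# ``target_theta`` above is the soft-update coefficient; below is a minimal, framework-free sketch
+# of the Polyak averaging it commonly parameterises (plain floats instead of network parameters):
+if __name__ == "__main__":
+    theta = cfg.policy.learn.target_theta
+    online, target = 1.0, 0.0
+    for _ in range(3):
+        target = (1 - theta) * target + theta * online  # the target slowly tracks the online value
+    print('target after 3 soft updates:', round(target, 6))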
diff --git a/DI-engine/ding/config/example/SAC/gym_halfcheetah_v3.py b/DI-engine/ding/config/example/SAC/gym_halfcheetah_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2f0f8cc21708eed5cd74f12aabbc95b42b84445
--- /dev/null
+++ b/DI-engine/ding/config/example/SAC/gym_halfcheetah_v3.py
@@ -0,0 +1,54 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='HalfCheetah-v3-SAC',
+ seed=0,
+ env=dict(
+ env_id='HalfCheetah-v3',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=12000,
+ env_wrapper='mujoco_default',
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_q=1e-3,
+ learning_rate_policy=1e-3,
+ learning_rate_alpha=3e-4,
+ ignore_done=True,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/SAC/gym_hopper_v3.py b/DI-engine/ding/config/example/SAC/gym_hopper_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a609e8c180b341b1a1f5deed8c1fb66dabdef5d
--- /dev/null
+++ b/DI-engine/ding/config/example/SAC/gym_hopper_v3.py
@@ -0,0 +1,41 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='Hopper-v3-SAC',
+ seed=0,
+ env=dict(
+ env_id='Hopper-v3',
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ env_wrapper='mujoco_default',
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_q=1e-3,
+ learning_rate_policy=1e-3,
+ reparameterization=True,
+ auto_alpha=False,
+ ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/SAC/gym_lunarlandercontinuous_v2.py b/DI-engine/ding/config/example/SAC/gym_lunarlandercontinuous_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..37201da601dfc63dd7363bca129b30aff5b01041
--- /dev/null
+++ b/DI-engine/ding/config/example/SAC/gym_lunarlandercontinuous_v2.py
@@ -0,0 +1,44 @@
+from easydict import EasyDict
+from functools import partial
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='LunarLanderContinuous-v2-SAC',
+ seed=0,
+ env=dict(
+ env_id='LunarLanderContinuous-v2',
+ collector_env_num=4,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=260,
+ act_scale=True,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=8,
+ action_shape=2,
+ action_space='reparameterization',
+ twin_critic=True,
+ ),
+ learn=dict(
+ update_per_collect=256,
+ batch_size=128,
+ learning_rate_q=1e-3,
+ learning_rate_policy=3e-4,
+ learning_rate_alpha=3e-4,
+ auto_alpha=True,
+ ),
+ collect=dict(n_sample=256, ),
+ eval=dict(evaluator=dict(eval_freq=1000, ), ),
+ other=dict(replay_buffer=dict(replay_buffer_size=int(1e5), ), ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = partial(ding.envs.gym_env.env, continuous=True)
diff --git a/DI-engine/ding/config/example/SAC/gym_pendulum_v1.py b/DI-engine/ding/config/example/SAC/gym_pendulum_v1.py
new file mode 100644
index 0000000000000000000000000000000000000000..45f2e9ddf33f9ea0b2306dc4d0c6c8cc99f1a79f
--- /dev/null
+++ b/DI-engine/ding/config/example/SAC/gym_pendulum_v1.py
@@ -0,0 +1,49 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='Pendulum-v1-SAC',
+ seed=0,
+ env=dict(
+ env_id='Pendulum-v1',
+ collector_env_num=10,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=-250,
+ act_scale=True,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ random_collect_size=1000,
+ model=dict(
+ obs_shape=3,
+ action_shape=1,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=128,
+ learning_rate_q=0.001,
+ learning_rate_policy=0.001,
+ learning_rate_alpha=0.0003,
+ ignore_done=True,
+ target_theta=0.005,
+ discount_factor=0.99,
+ auto_alpha=True,
+ ),
+ collect=dict(n_sample=10, ),
+ eval=dict(evaluator=dict(eval_freq=100, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=100000, ), ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/SAC/gym_walker2d_v3.py b/DI-engine/ding/config/example/SAC/gym_walker2d_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..f40d68d1170b55bb41ed32c19c8ebe796cb4a196
--- /dev/null
+++ b/DI-engine/ding/config/example/SAC/gym_walker2d_v3.py
@@ -0,0 +1,54 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='Walker2d-v3-SAC',
+ seed=0,
+ env=dict(
+ env_id='Walker2d-v3',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ env_wrapper='mujoco_default',
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_q=1e-3,
+ learning_rate_policy=1e-3,
+ learning_rate_alpha=3e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/SQL/__init__.py b/DI-engine/ding/config/example/SQL/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9637366fb420b01ebfbb06713a1f655ec9c8b98e
--- /dev/null
+++ b/DI-engine/ding/config/example/SQL/__init__.py
@@ -0,0 +1,14 @@
+from easydict import EasyDict
+from . import gym_lunarlander_v2
+
+supported_env_cfg = {
+ gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.cfg,
+}
+
+supported_env_cfg = EasyDict(supported_env_cfg)
+
+supported_env = {
+ gym_lunarlander_v2.cfg.env.env_id: gym_lunarlander_v2.env,
+}
+
+supported_env = EasyDict(supported_env)
diff --git a/DI-engine/ding/config/example/SQL/gym_lunarlander_v2.py b/DI-engine/ding/config/example/SQL/gym_lunarlander_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..648564dadbc6344eaca2eca8af39be37cab7637f
--- /dev/null
+++ b/DI-engine/ding/config/example/SQL/gym_lunarlander_v2.py
@@ -0,0 +1,43 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='LunarLander-v2-SQL',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ env_id='LunarLander-v2',
+ n_evaluator_episode=8,
+ stop_value=260,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ encoder_hidden_size_list=[128, 128, 64],
+ dueling=True,
+ ),
+ nstep=1,
+ discount_factor=0.97,
+ learn=dict(batch_size=64, learning_rate=0.001, alpha=0.08),
+ collect=dict(n_sample=64),
+        eval=dict(evaluator=dict(eval_freq=50, )),  # note: evaluation is run once every eval_freq training iterations
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=20000, ),
+ ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
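+
+
+# ``other.eps`` above defines an exponential epsilon-greedy schedule. A common closed form for such
+# a schedule is sketched below (an assumption for illustration, not necessarily DI-engine's exact code):
+if __name__ == "__main__":
+    import math
+    e = cfg.policy.other.eps
+    for step in (0, 5000, 20000):
+        eps = e.end + (e.start - e.end) * math.exp(-step / e.decay)
+        print('step {:>6}: eps ~= {:.3f}'.format(step, eps))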
diff --git a/DI-engine/ding/config/example/TD3/__init__.py b/DI-engine/ding/config/example/TD3/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e01f29d7452ca48b57800abd0d8893703d6b4e8
--- /dev/null
+++ b/DI-engine/ding/config/example/TD3/__init__.py
@@ -0,0 +1,29 @@
+from easydict import EasyDict
+from . import gym_bipedalwalker_v3
+from . import gym_halfcheetah_v3
+from . import gym_hopper_v3
+from . import gym_lunarlandercontinuous_v2
+from . import gym_pendulum_v1
+from . import gym_walker2d_v3
+
+supported_env_cfg = {
+ gym_bipedalwalker_v3.cfg.env.env_id: gym_bipedalwalker_v3.cfg,
+ gym_halfcheetah_v3.cfg.env.env_id: gym_halfcheetah_v3.cfg,
+ gym_hopper_v3.cfg.env.env_id: gym_hopper_v3.cfg,
+ gym_lunarlandercontinuous_v2.cfg.env.env_id: gym_lunarlandercontinuous_v2.cfg,
+ gym_pendulum_v1.cfg.env.env_id: gym_pendulum_v1.cfg,
+ gym_walker2d_v3.cfg.env.env_id: gym_walker2d_v3.cfg,
+}
+
+supported_env_cfg = EasyDict(supported_env_cfg)
+
+supported_env = {
+ gym_bipedalwalker_v3.cfg.env.env_id: gym_bipedalwalker_v3.env,
+ gym_halfcheetah_v3.cfg.env.env_id: gym_halfcheetah_v3.env,
+ gym_hopper_v3.cfg.env.env_id: gym_hopper_v3.env,
+ gym_lunarlandercontinuous_v2.cfg.env.env_id: gym_lunarlandercontinuous_v2.env,
+ gym_pendulum_v1.cfg.env.env_id: gym_pendulum_v1.env,
+ gym_walker2d_v3.cfg.env.env_id: gym_walker2d_v3.env,
+}
+
+supported_env = EasyDict(supported_env)
diff --git a/DI-engine/ding/config/example/TD3/gym_bipedalwalker_v3.py b/DI-engine/ding/config/example/TD3/gym_bipedalwalker_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e4cc24321ec6a3a8697312c31f206da499ba34a
--- /dev/null
+++ b/DI-engine/ding/config/example/TD3/gym_bipedalwalker_v3.py
@@ -0,0 +1,52 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='Bipedalwalker-v3-TD3',
+ seed=0,
+ env=dict(
+ env_id='BipedalWalker-v3',
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ act_scale=True,
+ rew_clip=True,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=24,
+ action_shape=4,
+ twin_critic=True,
+ action_space='regression',
+ actor_head_hidden_size=400,
+ critic_head_hidden_size=400,
+ ),
+ learn=dict(
+ update_per_collect=64,
+ batch_size=256,
+ learning_rate_actor=0.0003,
+ learning_rate_critic=0.0003,
+ target_theta=0.005,
+ discount_factor=0.99,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.2,
+ noise_range=dict(
+ min=-0.5,
+ max=0.5,
+ ),
+ learner=dict(hook=dict(log_show_after_iter=1000, ))
+ ),
+ collect=dict(n_sample=64, ),
+ other=dict(replay_buffer=dict(replay_buffer_size=300000, ), ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
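+
+
+# ``noise_sigma`` and ``noise_range`` above configure the clipped Gaussian noise that TD3 adds to
+# target actions (target policy smoothing). A minimal, framework-free sketch of that clipping:
+if __name__ == "__main__":
+    import random
+    n = cfg.policy.learn
+    eps = random.gauss(0.0, n.noise_sigma)
+    eps = max(n.noise_range.min, min(n.noise_range.max, eps))  # clip the noise to [min, max]
+    print('smoothing noise sample:', round(eps, 4))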
diff --git a/DI-engine/ding/config/example/TD3/gym_halfcheetah_v3.py b/DI-engine/ding/config/example/TD3/gym_halfcheetah_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..ddd2f1a68ec02f2829c52f6cc0b48dc8ab0087ea
--- /dev/null
+++ b/DI-engine/ding/config/example/TD3/gym_halfcheetah_v3.py
@@ -0,0 +1,56 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='HalfCheetah-v3-TD3',
+ seed=0,
+ env=dict(
+ env_id='HalfCheetah-v3',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=11000,
+ env_wrapper='mujoco_default',
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=25000,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ twin_critic=True,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ action_space='regression',
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_actor=1e-3,
+ learning_rate_critic=1e-3,
+ ignore_done=True,
+ target_theta=0.005,
+ discount_factor=0.99,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.2,
+ noise_range=dict(
+ min=-0.5,
+ max=0.5,
+ ),
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ noise_sigma=0.1,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/TD3/gym_hopper_v3.py b/DI-engine/ding/config/example/TD3/gym_hopper_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..e213323fd41600b7c34ff9443d40b5dc6a565e3c
--- /dev/null
+++ b/DI-engine/ding/config/example/TD3/gym_hopper_v3.py
@@ -0,0 +1,35 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='Hopper-v3-TD3',
+ seed=0,
+ env=dict(
+ env_id='Hopper-v3',
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ env_wrapper='mujoco_default',
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=25000,
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ action_space='regression',
+ ),
+ collect=dict(n_sample=1, ),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/TD3/gym_lunarlandercontinuous_v2.py b/DI-engine/ding/config/example/TD3/gym_lunarlandercontinuous_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..9798705ff27e887a20b1d639c156aa086095f589
--- /dev/null
+++ b/DI-engine/ding/config/example/TD3/gym_lunarlandercontinuous_v2.py
@@ -0,0 +1,50 @@
+from easydict import EasyDict
+from functools import partial
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='LunarLanderContinuous-V2-TD3',
+ seed=0,
+ env=dict(
+ env_id='LunarLanderContinuous-v2',
+ collector_env_num=4,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=240,
+ act_scale=True,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=8,
+ action_shape=2,
+ action_space='regression',
+ ),
+ learn=dict(
+ update_per_collect=256,
+ batch_size=256,
+ learning_rate_actor=3e-4,
+ learning_rate_critic=1e-3,
+ noise=True,
+ noise_sigma=0.1,
+ noise_range=dict(
+ min=-0.5,
+ max=0.5,
+ ),
+ ),
+ collect=dict(
+ n_sample=256,
+ noise_sigma=0.1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, ), ),
+ other=dict(replay_buffer=dict(replay_buffer_size=100000, ), ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = partial(ding.envs.gym_env.env, continuous=True)
diff --git a/DI-engine/ding/config/example/TD3/gym_pendulum_v1.py b/DI-engine/ding/config/example/TD3/gym_pendulum_v1.py
new file mode 100644
index 0000000000000000000000000000000000000000..57ebeeae6c9f53487d4ddd5a0b29a89dbc24636e
--- /dev/null
+++ b/DI-engine/ding/config/example/TD3/gym_pendulum_v1.py
@@ -0,0 +1,54 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='Pendulum-v1-TD3',
+ seed=0,
+ env=dict(
+ env_id='Pendulum-v1',
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=-250,
+ act_scale=True,
+ ),
+ policy=dict(
+ cuda=False,
+ priority=False,
+ random_collect_size=800,
+ model=dict(
+ obs_shape=3,
+ action_shape=1,
+ twin_critic=True,
+ action_space='regression',
+ ),
+ learn=dict(
+ update_per_collect=2,
+ batch_size=128,
+ learning_rate_actor=0.001,
+ learning_rate_critic=0.001,
+ ignore_done=True,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.1,
+ noise_range=dict(
+ min=-0.5,
+ max=0.5,
+ ),
+ ),
+ collect=dict(
+ n_sample=48,
+ noise_sigma=0.1,
+ collector=dict(collect_print_freq=1000, ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ other=dict(replay_buffer=dict(replay_buffer_size=20000, ), ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/TD3/gym_walker2d_v3.py b/DI-engine/ding/config/example/TD3/gym_walker2d_v3.py
new file mode 100644
index 0000000000000000000000000000000000000000..92b88e1e0816783844954111a68b5681fca19cbc
--- /dev/null
+++ b/DI-engine/ding/config/example/TD3/gym_walker2d_v3.py
@@ -0,0 +1,58 @@
+from easydict import EasyDict
+import ding.envs.gym_env
+
+cfg = dict(
+ exp_name='Walker2d-v3-TD3',
+ seed=0,
+ env=dict(
+ env_id='Walker2d-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ env_wrapper='mujoco_default',
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=25000,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ twin_critic=True,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ action_space='regression',
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_actor=1e-3,
+ learning_rate_critic=1e-3,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.2,
+ noise_range=dict(
+ min=-0.5,
+ max=0.5,
+ ),
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ noise_sigma=0.1,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ ),
+ wandb_logger=dict(
+ gradient_logger=True, video_logger=True, plot_logger=True, action_logger=True, return_logger=False
+ ),
+)
+
+cfg = EasyDict(cfg)
+
+env = ding.envs.gym_env.env
diff --git a/DI-engine/ding/config/example/__init__.py b/DI-engine/ding/config/example/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b2c8e750c14b9fc8bda90c056ddea7c7db78476
--- /dev/null
+++ b/DI-engine/ding/config/example/__init__.py
@@ -0,0 +1,10 @@
+from . import A2C
+from . import C51
+from . import DDPG
+from . import DQN
+from . import PG
+from . import PPOF
+from . import PPOOffPolicy
+from . import SAC
+from . import SQL
+from . import TD3
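+
+
+# A demonstration helper, not part of any DI-engine API: it enumerates every bundled
+# (algorithm, env_id) pair, assuming each sub-package imported above exposes a
+# ``supported_env_cfg`` registry like the ones defined in this directory.
+def _list_supported_pairs() -> dict:
+    algorithms = (A2C, C51, DDPG, DQN, PG, PPOF, PPOOffPolicy, SAC, SQL, TD3)
+    return {algo.__name__.split('.')[-1]: sorted(algo.supported_env_cfg.keys()) for algo in algorithms}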
diff --git a/DI-engine/ding/config/tests/test_config_formatted.py b/DI-engine/ding/config/tests/test_config_formatted.py
new file mode 100644
index 0000000000000000000000000000000000000000..f906b2fc04e966b675f94aad00a3cd71589dcc53
--- /dev/null
+++ b/DI-engine/ding/config/tests/test_config_formatted.py
@@ -0,0 +1,34 @@
+import pytest
+import os
+import importlib
+from typing import Union, Optional, List, Any, Callable, Tuple
+from ding.config import read_config, compile_config
+import dizoo.classic_control.cartpole.config.cartpole_ppo_config as cppo
+import dizoo.classic_control.cartpole.config.cartpole_dqn_config as cdqn
+import dizoo.classic_control.cartpole.config.cartpole_a2c_config as ca2c
+import dizoo.classic_control.cartpole.config.cartpole_c51_config as cc51
+
+args = [
+ ['dizoo.classic_control.cartpole.config.cartpole_ppo_config', 'ppo'],
+ ['dizoo.classic_control.cartpole.config.cartpole_a2c_config', 'a2c'],
+    # TODO adapt to new buffer
+    # ['dizoo.classic_control.cartpole.config.cartpole_dqn_config', 'dqn'],
+ ['dizoo.classic_control.cartpole.config.cartpole_c51_config', 'c51'],
+]
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('config_path, name', args)
+def test_config_formatted(config_path, name):
+ module_config = importlib.import_module(config_path)
+ main_config, create_config = module_config.main_config, module_config.create_config
+ main_config.exp_name = 'test_config_formatted_' + main_config.exp_name
+ cfg = compile_config(
+ main_config, seed=0, auto=True, create_cfg=create_config, save_cfg=True, save_path='{}_config.py'.format(name)
+ )
+
+ module = importlib.import_module('test_config_formatted_cartpole_{}_seed0.formatted_{}_config'.format(name, name))
+ main_config, create_config = module.main_config, module.create_config
+ cfg_test = compile_config(main_config, seed=0, auto=True, create_cfg=create_config, save_cfg=False)
+ assert cfg == cfg_test, 'cfg_formatted_failed'
+ os.popen('rm -rf test_config_formatted_cartpole_{}_seed0'.format(name))
diff --git a/DI-engine/ding/config/utils.py b/DI-engine/ding/config/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a9a2d666433b4e25bf50ce6fe2ee917f33344bc
--- /dev/null
+++ b/DI-engine/ding/config/utils.py
@@ -0,0 +1,536 @@
+from typing import Optional, List
+import copy
+from easydict import EasyDict
+
+from ding.utils import find_free_port, find_free_port_slurm, node_to_partition, node_to_host, pretty_print, \
+ DEFAULT_K8S_COLLECTOR_PORT, DEFAULT_K8S_LEARNER_PORT, DEFAULT_K8S_COORDINATOR_PORT
+from dizoo.classic_control.cartpole.config.parallel import cartpole_dqn_config
+
+default_host = '0.0.0.0'
+default_port = 22270
+
+
+def set_host_port(cfg: EasyDict, coordinator_host: str, learner_host: str, collector_host: str) -> EasyDict:
+ cfg.coordinator.host = coordinator_host
+ if cfg.coordinator.port == 'auto':
+ cfg.coordinator.port = find_free_port(coordinator_host)
+ learner_count = 0
+ collector_count = 0
+ for k in cfg.keys():
+ if k == 'learner_aggregator':
+ raise NotImplementedError
+ if k.startswith('learner'):
+ if cfg[k].host == 'auto':
+ if isinstance(learner_host, list):
+ cfg[k].host = learner_host[learner_count]
+ learner_count += 1
+ elif isinstance(learner_host, str):
+ cfg[k].host = learner_host
+ else:
+ raise TypeError("not support learner_host type: {}".format(learner_host))
+ if cfg[k].port == 'auto':
+ cfg[k].port = find_free_port(cfg[k].host)
+ cfg[k].aggregator = False
+ if k.startswith('collector'):
+ if cfg[k].host == 'auto':
+ if isinstance(collector_host, list):
+ cfg[k].host = collector_host[collector_count]
+ collector_count += 1
+ elif isinstance(collector_host, str):
+ cfg[k].host = collector_host
+ else:
+ raise TypeError("not support collector_host type: {}".format(collector_host))
+ if cfg[k].port == 'auto':
+ cfg[k].port = find_free_port(cfg[k].host)
+ return cfg
+
+
+def set_host_port_slurm(cfg: EasyDict, coordinator_host: str, learner_node: list, collector_node: list) -> EasyDict:
+ cfg.coordinator.host = coordinator_host
+ if cfg.coordinator.port == 'auto':
+ cfg.coordinator.port = find_free_port(coordinator_host)
+ if isinstance(learner_node, str):
+ learner_node = [learner_node]
+ if isinstance(collector_node, str):
+ collector_node = [collector_node]
+ learner_count, collector_count = 0, 0
+ learner_multi = {}
+ for k in cfg.keys():
+ if learner_node is not None and k.startswith('learner'):
+ node = learner_node[learner_count % len(learner_node)]
+ cfg[k].node = node
+ cfg[k].partition = node_to_partition(node)
+ gpu_num = cfg[k].gpu_num
+ if cfg[k].host == 'auto':
+ cfg[k].host = node_to_host(node)
+ if cfg[k].port == 'auto':
+ if gpu_num == 1:
+ cfg[k].port = find_free_port_slurm(node)
+ learner_multi[k] = False
+ else:
+ cfg[k].port = [find_free_port_slurm(node) for _ in range(gpu_num)]
+ learner_multi[k] = True
+ learner_count += 1
+ if collector_node is not None and k.startswith('collector'):
+ node = collector_node[collector_count % len(collector_node)]
+ cfg[k].node = node
+ cfg[k].partition = node_to_partition(node)
+ if cfg[k].host == 'auto':
+ cfg[k].host = node_to_host(node)
+ if cfg[k].port == 'auto':
+ cfg[k].port = find_free_port_slurm(node)
+ collector_count += 1
+ for k, flag in learner_multi.items():
+ if flag:
+ host = cfg[k].host
+ learner_interaction_cfg = {str(i): [str(i), host, p] for i, p in enumerate(cfg[k].port)}
+ aggregator_cfg = dict(
+ master=dict(
+ host=host,
+ port=find_free_port_slurm(cfg[k].node),
+ ),
+ slave=dict(
+ host=host,
+ port=find_free_port_slurm(cfg[k].node),
+ ),
+ learner=learner_interaction_cfg,
+ node=cfg[k].node,
+ partition=cfg[k].partition,
+ )
+ cfg[k].aggregator = True
+ cfg['learner_aggregator' + k[7:]] = aggregator_cfg
+ else:
+ cfg[k].aggregator = False
+ return cfg
+
+
+def set_host_port_k8s(cfg: EasyDict, coordinator_port: int, learner_port: int, collector_port: int) -> EasyDict:
+ cfg.coordinator.host = default_host
+ cfg.coordinator.port = coordinator_port if coordinator_port is not None else DEFAULT_K8S_COORDINATOR_PORT
+ base_learner_cfg = None
+ base_collector_cfg = None
+ if learner_port is None:
+ learner_port = DEFAULT_K8S_LEARNER_PORT
+ if collector_port is None:
+ collector_port = DEFAULT_K8S_COLLECTOR_PORT
+ for k in cfg.keys():
+ if k.startswith('learner'):
+ # create the base learner config
+ if base_learner_cfg is None:
+ base_learner_cfg = copy.deepcopy(cfg[k])
+ base_learner_cfg.host = default_host
+ base_learner_cfg.port = learner_port
+ cfg[k].port = learner_port
+ elif k.startswith('collector'):
+ # create the base collector config
+ if base_collector_cfg is None:
+ base_collector_cfg = copy.deepcopy(cfg[k])
+ base_collector_cfg.host = default_host
+ base_collector_cfg.port = collector_port
+ cfg[k].port = collector_port
+ cfg['learner'] = base_learner_cfg
+ cfg['collector'] = base_collector_cfg
+ return cfg
+
+
+def set_learner_interaction_for_coordinator(cfg: EasyDict) -> EasyDict:
+ cfg.coordinator.learner = {}
+ for k in cfg.keys():
+ if k.startswith('learner') and not k.startswith('learner_aggregator'):
+ if cfg[k].aggregator:
+ dst_k = 'learner_aggregator' + k[7:]
+ cfg.coordinator.learner[k] = [k, cfg[dst_k].slave.host, cfg[dst_k].slave.port]
+ else:
+ dst_k = k
+ cfg.coordinator.learner[k] = [k, cfg[dst_k].host, cfg[dst_k].port]
+ return cfg
+
+
+def set_collector_interaction_for_coordinator(cfg: EasyDict) -> EasyDict:
+ cfg.coordinator.collector = {}
+ for k in cfg.keys():
+ if k.startswith('collector'):
+ cfg.coordinator.collector[k] = [k, cfg[k].host, cfg[k].port]
+ return cfg
+
+
+def set_system_cfg(cfg: EasyDict) -> EasyDict:
+ learner_num = cfg.main.policy.learn.learner.learner_num
+ collector_num = cfg.main.policy.collect.collector.collector_num
+ path_data = cfg.system.path_data
+ path_policy = cfg.system.path_policy
+ coordinator_cfg = cfg.system.coordinator
+ communication_mode = cfg.system.communication_mode
+ assert communication_mode in ['auto'], communication_mode
+ learner_gpu_num = cfg.system.learner_gpu_num
+ learner_multi_gpu = learner_gpu_num > 1
+ new_cfg = dict(coordinator=dict(
+ host='auto',
+ port='auto',
+ ))
+ new_cfg['coordinator'].update(coordinator_cfg)
+ for i in range(learner_num):
+ new_cfg[f'learner{i}'] = dict(
+ type=cfg.system.comm_learner.type,
+ import_names=cfg.system.comm_learner.import_names,
+ host='auto',
+ port='auto',
+ path_data=path_data,
+ path_policy=path_policy,
+ multi_gpu=learner_multi_gpu,
+ gpu_num=learner_gpu_num,
+ )
+ for i in range(collector_num):
+ new_cfg[f'collector{i}'] = dict(
+ type=cfg.system.comm_collector.type,
+ import_names=cfg.system.comm_collector.import_names,
+ host='auto',
+ port='auto',
+ path_data=path_data,
+ path_policy=path_policy,
+ )
+ return EasyDict(new_cfg)
+
+
+def parallel_transform(
+ cfg: dict,
+ coordinator_host: Optional[str] = None,
+ learner_host: Optional[List[str]] = None,
+ collector_host: Optional[List[str]] = None
+) -> None:
+ coordinator_host = default_host if coordinator_host is None else coordinator_host
+ collector_host = default_host if collector_host is None else collector_host
+ learner_host = default_host if learner_host is None else learner_host
+ cfg = EasyDict(cfg)
+ cfg.system = set_system_cfg(cfg)
+ cfg.system = set_host_port(cfg.system, coordinator_host, learner_host, collector_host)
+ cfg.system = set_learner_interaction_for_coordinator(cfg.system)
+ cfg.system = set_collector_interaction_for_coordinator(cfg.system)
+ return cfg
+
+
+def parallel_transform_slurm(
+ cfg: dict,
+ coordinator_host: Optional[str] = None,
+ learner_node: Optional[List[str]] = None,
+ collector_node: Optional[List[str]] = None
+) -> None:
+ cfg = EasyDict(cfg)
+ cfg.system = set_system_cfg(cfg)
+ cfg.system = set_host_port_slurm(cfg.system, coordinator_host, learner_node, collector_node)
+ cfg.system = set_learner_interaction_for_coordinator(cfg.system)
+ cfg.system = set_collector_interaction_for_coordinator(cfg.system)
+ pretty_print(cfg)
+ return cfg
+
+
+def parallel_transform_k8s(
+ cfg: dict,
+ coordinator_port: Optional[int] = None,
+ learner_port: Optional[int] = None,
+ collector_port: Optional[int] = None
+) -> None:
+ cfg = EasyDict(cfg)
+ cfg.system = set_system_cfg(cfg)
+ cfg.system = set_host_port_k8s(cfg.system, coordinator_port, learner_port, collector_port)
+    # learner/collector is created by the operator, so the following fields are placeholders
+ cfg.system.coordinator.collector = {}
+ cfg.system.coordinator.learner = {}
+ pretty_print(cfg)
+ return cfg
+
+
+def save_config_formatted(config_: dict, path: str = 'formatted_total_config.py') -> None:
+ """
+ Overview:
+        Save the formatted configuration to a Python file that can be read directly by ``serial_pipeline``.
+    Arguments:
+        - config_ (:obj:`dict`): Config dict.
+        - path (:obj:`str`): Path of the output Python file.
+ """
+ with open(path, "w") as f:
+ f.write('from easydict import EasyDict\n\n')
+ f.write('main_config = dict(\n')
+ f.write(" exp_name='{}',\n".format(config_.exp_name))
+ for k, v in config_.items():
+ if (k == 'env'):
+ f.write(' env=dict(\n')
+ for k2, v2 in v.items():
+ if (k2 != 'type' and k2 != 'import_names' and k2 != 'manager'):
+ if (isinstance(v2, str)):
+ f.write(" {}='{}',\n".format(k2, v2))
+ else:
+ f.write(" {}={},\n".format(k2, v2))
+ if (k2 == 'manager'):
+ f.write(" manager=dict(\n")
+ for k3, v3 in v2.items():
+                        if (k3 != 'cfg_type' and k3 != 'type'):
+ if (isinstance(v3, str)):
+ f.write(" {}='{}',\n".format(k3, v3))
+ elif v3 == float('inf'):
+ f.write(" {}=float('{}'),\n".format(k3, v3))
+ else:
+ f.write(" {}={},\n".format(k3, v3))
+ f.write(" ),\n")
+ f.write(" ),\n")
+ if (k == 'policy'):
+ f.write(' policy=dict(\n')
+ for k2, v2 in v.items():
+ if (k2 != 'type' and k2 != 'learn' and k2 != 'collect' and k2 != 'eval' and k2 != 'other'
+ and k2 != 'model'):
+ if (isinstance(v2, str)):
+ f.write(" {}='{}',\n".format(k2, v2))
+ else:
+ f.write(" {}={},\n".format(k2, v2))
+ elif (k2 == 'learn'):
+ f.write(" learn=dict(\n")
+ for k3, v3 in v2.items():
+ if (k3 != 'learner'):
+ if (isinstance(v3, str)):
+ f.write(" {}='{}',\n".format(k3, v3))
+ else:
+ f.write(" {}={},\n".format(k3, v3))
+ if (k3 == 'learner'):
+ f.write(" learner=dict(\n")
+ for k4, v4 in v3.items():
+ if (k4 != 'dataloader' and k4 != 'hook'):
+ if (isinstance(v4, str)):
+ f.write(" {}='{}',\n".format(k4, v4))
+ else:
+ f.write(" {}={},\n".format(k4, v4))
+ else:
+ if (k4 == 'dataloader'):
+ f.write(" dataloader=dict(\n")
+ for k5, v5 in v4.items():
+ if (isinstance(v5, str)):
+ f.write(" {}='{}',\n".format(k5, v5))
+ else:
+ f.write(" {}={},\n".format(k5, v5))
+ f.write(" ),\n")
+ if (k4 == 'hook'):
+ f.write(" hook=dict(\n")
+ for k5, v5 in v4.items():
+ if (isinstance(v5, str)):
+ f.write(" {}='{}',\n".format(k5, v5))
+ else:
+ f.write(" {}={},\n".format(k5, v5))
+ f.write(" ),\n")
+ f.write(" ),\n")
+ f.write(" ),\n")
+ elif (k2 == 'collect'):
+ f.write(" collect=dict(\n")
+ for k3, v3 in v2.items():
+ if (k3 != 'collector'):
+ if (isinstance(v3, str)):
+ f.write(" {}='{}',\n".format(k3, v3))
+ else:
+ f.write(" {}={},\n".format(k3, v3))
+ if (k3 == 'collector'):
+ f.write(" collector=dict(\n")
+ for k4, v4 in v3.items():
+ if (isinstance(v4, str)):
+ f.write(" {}='{}',\n".format(k4, v4))
+ else:
+ f.write(" {}={},\n".format(k4, v4))
+ f.write(" ),\n")
+ f.write(" ),\n")
+ elif (k2 == 'eval'):
+ f.write(" eval=dict(\n")
+ for k3, v3 in v2.items():
+ if (k3 != 'evaluator'):
+ if (isinstance(v3, str)):
+ f.write(" {}='{}',\n".format(k3, v3))
+ else:
+ f.write(" {}={},\n".format(k3, v3))
+ if (k3 == 'evaluator'):
+ f.write(" evaluator=dict(\n")
+ for k4, v4 in v3.items():
+ if (isinstance(v4, str)):
+ f.write(" {}='{}',\n".format(k4, v4))
+ else:
+ f.write(" {}={},\n".format(k4, v4))
+ f.write(" ),\n")
+ f.write(" ),\n")
+ elif (k2 == 'model'):
+ f.write(" model=dict(\n")
+ for k3, v3 in v2.items():
+ if (isinstance(v3, str)):
+ f.write(" {}='{}',\n".format(k3, v3))
+ else:
+ f.write(" {}={},\n".format(k3, v3))
+ f.write(" ),\n")
+ elif (k2 == 'other'):
+ f.write(" other=dict(\n")
+ for k3, v3 in v2.items():
+ if (k3 == 'replay_buffer'):
+ f.write(" replay_buffer=dict(\n")
+ for k4, v4 in v3.items():
+ if (k4 != 'monitor' and k4 != 'thruput_controller'):
+ if (isinstance(v4, dict)):
+ f.write(" {}=dict(\n".format(k4))
+ for k5, v5 in v4.items():
+ if (isinstance(v5, str)):
+ f.write(" {}='{}',\n".format(k5, v5))
+ elif v5 == float('inf'):
+ f.write(" {}=float('{}'),\n".format(k5, v5))
+ elif (isinstance(v5, dict)):
+ f.write(" {}=dict(\n".format(k5))
+ for k6, v6 in v5.items():
+ if (isinstance(v6, str)):
+ f.write(" {}='{}',\n".format(k6, v6))
+ elif v6 == float('inf'):
+ f.write(
+ " {}=float('{}'),\n".format(
+ k6, v6
+ )
+ )
+ elif (isinstance(v6, dict)):
+ f.write(" {}=dict(\n".format(k6))
+ for k7, v7 in v6.items():
+ if (isinstance(v7, str)):
+ f.write(
+ " {}='{}',\n".format(
+ k7, v7
+ )
+ )
+ elif v7 == float('inf'):
+ f.write(
+ " {}=float('{}'),\n".
+ format(k7, v7)
+ )
+ else:
+ f.write(
+ " {}={},\n".format(
+ k7, v7
+ )
+ )
+ f.write(" ),\n")
+ else:
+ f.write(" {}={},\n".format(k6, v6))
+ f.write(" ),\n")
+ else:
+ f.write(" {}={},\n".format(k5, v5))
+ f.write(" ),\n")
+ else:
+ if (isinstance(v4, str)):
+ f.write(" {}='{}',\n".format(k4, v4))
+ elif v4 == float('inf'):
+ f.write(" {}=float('{}'),\n".format(k4, v4))
+ else:
+ f.write(" {}={},\n".format(k4, v4))
+ else:
+ if (k4 == 'monitor'):
+ f.write(" monitor=dict(\n")
+ for k5, v5 in v4.items():
+ if (k5 == 'log_path'):
+ if (isinstance(v5, str)):
+ f.write(" {}='{}',\n".format(k5, v5))
+ else:
+ f.write(" {}={},\n".format(k5, v5))
+ else:
+ f.write(" {}=dict(\n".format(k5))
+ for k6, v6 in v5.items():
+ if (isinstance(v6, str)):
+ f.write(" {}='{}',\n".format(k6, v6))
+ else:
+ f.write(" {}={},\n".format(k6, v6))
+ f.write(" ),\n")
+ f.write(" ),\n")
+ if (k4 == 'thruput_controller'):
+ f.write(" thruput_controller=dict(\n")
+ for k5, v5 in v4.items():
+ if (isinstance(v5, dict)):
+ f.write(" {}=dict(\n".format(k5))
+ for k6, v6 in v5.items():
+ if (isinstance(v6, str)):
+ f.write(" {}='{}',\n".format(k6, v6))
+ elif v6 == float('inf'):
+ f.write(
+ " {}=float('{}'),\n".format(
+ k6, v6
+ )
+ )
+ else:
+ f.write(" {}={},\n".format(k6, v6))
+ f.write(" ),\n")
+ else:
+ if (isinstance(v5, str)):
+ f.write(" {}='{}',\n".format(k5, v5))
+ else:
+ f.write(" {}={},\n".format(k5, v5))
+ f.write(" ),\n")
+ f.write(" ),\n")
+ f.write(" ),\n")
+ f.write(" ),\n)\n")
+ f.write('main_config = EasyDict(main_config)\n')
+ f.write('main_config = main_config\n')
+ f.write('create_config = dict(\n')
+ for k, v in config_.items():
+ if (k == 'env'):
+ f.write(' env=dict(\n')
+ for k2, v2 in v.items():
+ if (k2 == 'type' or k2 == 'import_names'):
+ if isinstance(v2, str):
+ f.write(" {}='{}',\n".format(k2, v2))
+ else:
+ f.write(" {}={},\n".format(k2, v2))
+ f.write(" ),\n")
+ for k2, v2 in v.items():
+ if (k2 == 'manager'):
+ f.write(' env_manager=dict(\n')
+ for k3, v3 in v2.items():
+ if (k3 == 'cfg_type' or k3 == 'type'):
+ if (isinstance(v3, str)):
+ f.write(" {}='{}',\n".format(k3, v3))
+ else:
+ f.write(" {}={},\n".format(k3, v3))
+ f.write(" ),\n")
+ policy_type = config_.policy.type
+ if '_command' in policy_type:
+ f.write(" policy=dict(type='{}'),\n".format(policy_type[0:len(policy_type) - 8]))
+ else:
+ f.write(" policy=dict(type='{}'),\n".format(policy_type))
+ f.write(")\n")
+ f.write('create_config = EasyDict(create_config)\n')
+ f.write('create_config = create_config\n')
+
+
+parallel_test_main_config = cartpole_dqn_config
+parallel_test_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqn_command'),
+ comm_learner=dict(
+ type='flask_fs',
+ import_names=['ding.worker.learner.comm.flask_fs_learner'],
+ ),
+ comm_collector=dict(
+ type='flask_fs',
+ import_names=['ding.worker.collector.comm.flask_fs_collector'],
+ ),
+ learner=dict(
+ type='base',
+ import_names=['ding.worker.learner.base_learner'],
+ ),
+ collector=dict(
+ type='zergling',
+ import_names=['ding.worker.collector.zergling_parallel_collector'],
+ ),
+ commander=dict(
+ type='naive',
+ import_names=['ding.worker.coordinator.base_parallel_commander'],
+ ),
+)
+parallel_test_create_config = EasyDict(parallel_test_create_config)
+parallel_test_system_config = dict(
+ coordinator=dict(),
+ path_data='.',
+ path_policy='.',
+ communication_mode='auto',
+ learner_gpu_num=1,
+)
+parallel_test_system_config = EasyDict(parallel_test_system_config)
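+
+
+# A compact, self-contained sketch of how the coordinator interaction tables above are built,
+# using a hand-written minimal system config (illustrative values only):
+if __name__ == "__main__":
+    demo = EasyDict(
+        dict(
+            coordinator=dict(host='127.0.0.1', port=12345),
+            learner0=dict(host='127.0.0.1', port=12346, aggregator=False),
+            collector0=dict(host='127.0.0.1', port=12347),
+        )
+    )
+    demo = set_learner_interaction_for_coordinator(demo)
+    demo = set_collector_interaction_for_coordinator(demo)
+    print(demo.coordinator.learner)  # {'learner0': ['learner0', '127.0.0.1', 12346]}
+    print(demo.coordinator.collector)  # {'collector0': ['collector0', '127.0.0.1', 12347]}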
diff --git a/DI-engine/ding/data/__init__.py b/DI-engine/ding/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b72987cac99d2dd339cdf40f834af1ee3588894e
--- /dev/null
+++ b/DI-engine/ding/data/__init__.py
@@ -0,0 +1,7 @@
+from torch.utils.data import Dataset, DataLoader
+from ding.utils.data import create_dataset, offline_data_save_type # for compatibility
+from .buffer import *
+from .storage import *
+from .storage_loader import StorageLoader, FileStorageLoader
+from .shm_buffer import ShmBufferContainer, ShmBuffer
+from .model_loader import ModelLoader, FileModelLoader
diff --git a/DI-engine/ding/data/buffer/__init__.py b/DI-engine/ding/data/buffer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d3cf7c0b13a5fa792832cb0151f9beb80a2cf4fc
--- /dev/null
+++ b/DI-engine/ding/data/buffer/__init__.py
@@ -0,0 +1,3 @@
+from .buffer import Buffer, apply_middleware, BufferedData
+from .deque_buffer import DequeBuffer
+from .deque_buffer_wrapper import DequeBufferWrapper
diff --git a/DI-engine/ding/data/buffer/buffer.py b/DI-engine/ding/data/buffer/buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..53b3f39bd14bc1e2dc65cb65a09db6f5f0313f28
--- /dev/null
+++ b/DI-engine/ding/data/buffer/buffer.py
@@ -0,0 +1,211 @@
+from abc import abstractmethod, ABC
+from typing import Any, List, Optional, Union, Callable
+import copy
+from dataclasses import dataclass
+from functools import wraps
+from ding.utils import fastcopy
+
+
+def apply_middleware(func_name: str):
+
+ def wrap_func(base_func: Callable):
+
+ @wraps(base_func)
+ def handler(buffer, *args, **kwargs):
+ """
+ Overview:
+                The real processing starts here: the middleware are applied one by one, and
+                each middleware receives the next `chained` function, which executes the next
+                middleware in the chain. You can change the input arguments passed to the next
+                `chained` middleware and inspect its return value, so you have maximum freedom
+                to choose at which stage to implement your method.
+ """
+
+ def wrap_handler(middleware, *args, **kwargs):
+ if len(middleware) == 0:
+ return base_func(buffer, *args, **kwargs)
+
+ def chain(*args, **kwargs):
+ return wrap_handler(middleware[1:], *args, **kwargs)
+
+ func = middleware[0]
+ return func(func_name, chain, *args, **kwargs)
+
+ return wrap_handler(buffer._middleware, *args, **kwargs)
+
+ return handler
+
+ return wrap_func
+
+
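+# An illustrative middleware written against the protocol above: it is called with the name of the
+# buffer method, the next handler in the chain, and the user's arguments, and it simply logs the
+# action before delegating. It can be attached to any concrete buffer via ``buffer.use(log_middleware)``.
+def log_middleware(action: str, chain: Callable, *args, **kwargs) -> Any:
+    # pre-processing: observe which buffer method ("push", "sample", ...) is being called
+    print('buffer action:', action)
+    result = chain(*args, **kwargs)  # delegate to the next middleware or the buffer's own method
+    # post-processing could inspect or transform ``result`` here before returning it
+    return result
+
+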
+@dataclass
+class BufferedData:
+ data: Any
+ index: str
+ meta: dict
+
+
+# Register new dispatcher on fastcopy to avoid circular references
+def _copy_buffereddata(d: BufferedData) -> BufferedData:
+ return BufferedData(data=fastcopy.copy(d.data), index=d.index, meta=fastcopy.copy(d.meta))
+
+
+fastcopy.dispatch[BufferedData] = _copy_buffereddata
+
+
+class Buffer(ABC):
+ """
+    Buffer is an abstraction of device storage, third-party services, or data structures,
+    for example, a memory queue, sum-tree, Redis, or di-store.
+ """
+
+ def __init__(self, size: int) -> None:
+ self._middleware = []
+ self.size = size
+
+ @abstractmethod
+ def push(self, data: Any, meta: Optional[dict] = None) -> BufferedData:
+ """
+ Overview:
+            Push data and its meta information into the buffer.
+ Arguments:
+ - data (:obj:`Any`): The data which will be pushed into buffer.
+ - meta (:obj:`dict`): Meta information, e.g. priority, count, staleness.
+ Returns:
+ - buffered_data (:obj:`BufferedData`): The pushed data.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def sample(
+ self,
+ size: Optional[int] = None,
+ indices: Optional[List[str]] = None,
+ replace: bool = False,
+ sample_range: Optional[slice] = None,
+ ignore_insufficient: bool = False,
+ groupby: Optional[str] = None,
+ unroll_len: Optional[int] = None
+ ) -> Union[List[BufferedData], List[List[BufferedData]]]:
+ """
+ Overview:
+ Sample data with length ``size``.
+ Arguments:
+ - size (:obj:`Optional[int]`): The number of the data that will be sampled.
+ - indices (:obj:`Optional[List[str]]`): Sample with multiple indices.
+            - replace (:obj:`bool`): If ``replace`` is true, you may receive duplicated data from the buffer.
+ - sample_range (:obj:`slice`): Sample range slice.
+ - ignore_insufficient (:obj:`bool`): If ignore_insufficient is true, sampling more than buffer size
+ with no repetition will not cause an exception.
+            - groupby (:obj:`Optional[str]`): Groupby key in meta, e.g. groupby="episode".
+ - unroll_len (:obj:`Optional[int]`): Number of consecutive frames within a group.
+ Returns:
+ - sample_data (:obj:`Union[List[BufferedData], List[List[BufferedData]]]`):
+ A list of data with length ``size``, may be nested if groupby is set.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def update(self, index: str, data: Optional[Any] = None, meta: Optional[dict] = None) -> bool:
+ """
+ Overview:
+ Update data and meta by index
+ Arguments:
+ - index (:obj:`str`): Index of data.
+ - data (:obj:`any`): Pure data.
+ - meta (:obj:`dict`): Meta information.
+ Returns:
+            - success (:obj:`bool`): Success or not; if no data with this index exists in the buffer, return ``False``.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def delete(self, index: str):
+ """
+ Overview:
+ Delete one data sample by index
+ Arguments:
+ - index (:obj:`str`): Index
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def save_data(self, file_name: str):
+ """
+ Overview:
+ Save buffer data into a file.
+ Arguments:
+ - file_name (:obj:`str`): file name of buffer data
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def load_data(self, file_name: str):
+ """
+ Overview:
+ Load buffer data from a file.
+ Arguments:
+ - file_name (:obj:`str`): file name of buffer data
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def count(self) -> int:
+ raise NotImplementedError
+
+ @abstractmethod
+ def clear(self) -> None:
+ raise NotImplementedError
+
+ @abstractmethod
+ def get(self, idx: int) -> BufferedData:
+ """
+ Overview:
+ Get item by subscript index
+ Arguments:
+ - idx (:obj:`int`): Subscript index
+ Returns:
+ - buffered_data (:obj:`BufferedData`): Item from buffer
+ """
+ raise NotImplementedError
+
+ def use(self, func: Callable) -> "Buffer":
+ """
+ Overview:
+ Use algorithm middleware to modify the behavior of the buffer.
+            Every middleware should be a callable function; it will receive three argument parts:
+            1. The buffer instance, which you can use to access everything in the buffer, including the storage.
+            2. The name of the method called by the user, one of `push` , `sample` and `clear` , \
+               so you can use this name to decide which action to take.
+            3. The remaining arguments passed by the user to the original function, which are passed in `*args` .
+
+            Each middleware handler decides how the call proceeds:
+            1. Calling the received `chain` function forwards the call to the next middleware, \
+                or to the buffer's own method once the middleware chain is exhausted.
+            2. The handler's return value is propagated back to the caller, so a handler can also \
+                short-circuit the chain by returning a result without calling `chain` .
+ Arguments:
+ - func (:obj:`Callable`): The middleware handler
+ Returns:
+ - buffer (:obj:`Buffer`): The instance self
+ """
+ self._middleware.append(func)
+ return self
+
+ def view(self) -> "Buffer":
+ r"""
+ Overview:
+            A view is a new instance of the buffer, with a deep copy of every property except the storage.
+            The storage is shared among all the buffer instances.
+        Returns:
+            - buffer (:obj:`Buffer`): A new buffer instance that shares the original storage.
+ """
+ return copy.copy(self)
+
+ def __copy__(self) -> "Buffer":
+ raise NotImplementedError
+
+ def __len__(self) -> int:
+ return self.count()
+
+ def __getitem__(self, idx: int) -> BufferedData:
+ return self.get(idx)
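+
+
+# A short usage sketch of the middleware hook above; it runs only when this file is executed
+# directly and assumes the concrete ``DequeBuffer`` from the sibling module is importable.
+if __name__ == "__main__":
+    from ding.data.buffer import DequeBuffer
+    buf = DequeBuffer(size=4).use(log_middleware)
+    buf.push({'obs': 1})  # prints "buffer action: push" via the example middleware
+    print('stored items:', buf.count())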
diff --git a/DI-engine/ding/data/buffer/deque_buffer.py b/DI-engine/ding/data/buffer/deque_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..26c7cebc8ed206c64eb296c3c693148a9b93abb0
--- /dev/null
+++ b/DI-engine/ding/data/buffer/deque_buffer.py
@@ -0,0 +1,386 @@
+import os
+import itertools
+import random
+import uuid
+from ditk import logging
+import hickle
+from typing import Any, Iterable, List, Optional, Tuple, Union
+from collections import Counter
+from collections import defaultdict, deque, OrderedDict
+from ding.data.buffer import Buffer, apply_middleware, BufferedData
+from ding.utils import fastcopy
+from ding.torch_utils import get_null_data
+
+
+class BufferIndex():
+ """
+ Overview:
+ Save index string and offset in key value pair.
+ """
+
+ def __init__(self, maxlen: int, *args, **kwargs):
+ self.maxlen = maxlen
+ self.__map = OrderedDict(*args, **kwargs)
+ self._last_key = next(reversed(self.__map)) if len(self) > 0 else None
+ self._cumlen = len(self.__map)
+
+ def get(self, key: str) -> int:
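+        # ``__map`` stores a monotonically increasing insertion counter for each key; the arithmetic
+        # below converts that counter into the key's current position inside the bounded deque.
+        # Worked example with ``maxlen=3`` after appending keys a, b, c, d: the map holds
+        # {b: 1, c: 2, d: 3} and ``_cumlen`` is 4, so get('b') = 1 % 4 + min(0, 3 - 4) = 0,
+        # i.e. 'b' is the oldest element still stored in the deque.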
+ value = self.__map[key]
+ value = value % self._cumlen + min(0, (self.maxlen - self._cumlen))
+ return value
+
+ def __len__(self) -> int:
+ return len(self.__map)
+
+ def has(self, key: str) -> bool:
+ return key in self.__map
+
+ def append(self, key: str):
+ self.__map[key] = self.__map[self._last_key] + 1 if self._last_key else 0
+ self._last_key = key
+ self._cumlen += 1
+ if len(self) > self.maxlen:
+ self.__map.popitem(last=False)
+
+ def clear(self):
+ self.__map = OrderedDict()
+ self._last_key = None
+ self._cumlen = 0
+
+
+class DequeBuffer(Buffer):
+ """
+ Overview:
+ A buffer implementation based on the deque structure.
+ """
+
+ def __init__(self, size: int, sliced: bool = False) -> None:
+ """
+ Overview:
+ The initialization method of DequeBuffer.
+ Arguments:
+ - size (:obj:`int`): The maximum number of objects that the buffer can hold.
+            - sliced (:obj:`bool`): Whether to slice data by ``unroll_len`` when sampling by group.
+ """
+ super().__init__(size=size)
+ self.storage = deque(maxlen=size)
+ self.indices = BufferIndex(maxlen=size)
+ self.sliced = sliced
+ # Meta index is a dict whose values are deques
+ self.meta_index = {}
+
+ @apply_middleware("push")
+ def push(self, data: Any, meta: Optional[dict] = None) -> BufferedData:
+ """
+ Overview:
+ The method that puts the object and the related meta information into the buffer.
+ Arguments:
+ - data (:obj:`Any`): The input object which can be in any format.
+ - meta (:obj:`Optional[dict]`): A dict that helps describe data, such as\
+ category, label, priority, etc. Default to ``None``.
+ """
+ return self._push(data, meta)
+
+ @apply_middleware("sample")
+ def sample(
+ self,
+ size: Optional[int] = None,
+ indices: Optional[List[str]] = None,
+ replace: bool = False,
+ sample_range: Optional[slice] = None,
+ ignore_insufficient: bool = False,
+ groupby: Optional[str] = None,
+ unroll_len: Optional[int] = None
+ ) -> Union[List[BufferedData], List[List[BufferedData]]]:
+ """
+ Overview:
+ The method that randomly samples data from the buffer or retrieves certain data by indices.
+ Arguments:
+ - size (:obj:`Optional[int]`): The number of objects to be obtained from the buffer.
+ If ``indices`` is not specified, the ``size`` is required to randomly sample the\
+ corresponding number of objects from the buffer.
+ - indices (:obj:`Optional[List[str]]`): Only used when you want to retrieve data by indices.
+ Default to ``None``.
+ - replace (:obj:`bool`): As the sampling process is carried out one by one, this parameter\
+ determines whether the previous samples will be put back into the buffer for subsequent\
+ sampling. Default to ``False``, it means that duplicate samples will not appear in one\
+ ``sample`` call.
+ - sample_range (:obj:`Optional[slice]`): The indices range to sample data. Default to ``None``,\
+ it means no restrictions on the range of indices for the sampling process.
+ - ignore_insufficient (:obj:`bool`): Whether to ignore (log a warning instead of raising ``ValueError``)\
+ the case where the number of sampled records is smaller than the required size. Default to ``False``.
+ - groupby (:obj:`Optional[str]`): If this meta key is specified, the method will group records by the\
+ value of this key and return ``size`` groups of objects.
+ - unroll_len (:obj:`Optional[int]`): The unroll length of a trajectory, used only when the\
+ ``groupby`` is activated.
+ Returns:
+ - sampled_data (Union[List[BufferedData], List[List[BufferedData]]]): The sampling result.
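+ Examples:
+ >>> # Grouped sampling sketch, following the behaviour exercised in the unit tests: each
+ >>> # record carries an "env" meta key, and 3 groups of 4 consecutive records are returned.
+ >>> buffer = DequeBuffer(size=100)
+ >>> for i in range(10):
+ ... for env_id in list("ABC"):
+ ... buffer.push(i, {"env": env_id})
+ >>> grouped = buffer.sample(3, groupby="env", unroll_len=4)
+ >>> len(grouped), len(grouped[0])
+ (3, 4)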
+ """
+ storage = self.storage
+ if sample_range:
+ storage = list(itertools.islice(self.storage, sample_range.start, sample_range.stop, sample_range.step))
+
+ # Size and indices
+ assert size or indices, "One of size and indices must not be empty."
+ if (size and indices) and (size != len(indices)):
+ raise AssertionError("Size and indices length must be equal.")
+ if not size:
+ size = len(indices)
+ # Indices and groupby
+ assert not (indices and groupby), "Cannot use groupby and indices at the same time."
+ # Groupby and unroll_len
+ assert not unroll_len or (
+ unroll_len and groupby
+ ), "Parameter unroll_len needs to be used in conjunction with groupby."
+
+ value_error = None
+ sampled_data = []
+ if indices:
+ indices_set = set(indices)
+ hashed_data = filter(lambda item: item.index in indices_set, storage)
+ hashed_data = map(lambda item: (item.index, item), hashed_data)
+ hashed_data = dict(hashed_data)
+ # Re-sample and return in indices order
+ sampled_data = [hashed_data[index] for index in indices]
+ elif groupby:
+ sampled_data = self._sample_by_group(
+ size=size, groupby=groupby, replace=replace, unroll_len=unroll_len, storage=storage, sliced=self.sliced
+ )
+ else:
+ if replace:
+ sampled_data = random.choices(storage, k=size)
+ else:
+ try:
+ sampled_data = random.sample(storage, k=size)
+ except ValueError as e:
+ value_error = e
+
+ if value_error or len(sampled_data) != size:
+ if ignore_insufficient:
+ logging.warning(
+ "Sample operation is ignored due to data insufficient, current buffer is {} while sample is {}".
+ format(self.count(), size)
+ )
+ else:
+ raise ValueError("There are less than {} records/groups in buffer({})".format(size, self.count()))
+
+ sampled_data = self._independence(sampled_data)
+
+ return sampled_data
+
+ @apply_middleware("update")
+ def update(self, index: str, data: Optional[Any] = None, meta: Optional[dict] = None) -> bool:
+ """
+ Overview:
+ The method that updates the data and the related meta information at a certain index.
+ Arguments:
+ - index (:obj:`str`): The index of the data to be updated.
+ - data (:obj:`Any`): The data which is supposed to replace the old one. If you set it\
+ to ``None``, nothing will happen to the old record.
+ - meta (:obj:`Optional[dict]`): The new dict which is supposed to merge with the old one.
+ """
+ if not self.indices.has(index):
+ return False
+ i = self.indices.get(index)
+ item = self.storage[i]
+ if data is not None:
+ item.data = data
+ if meta is not None:
+ item.meta = meta
+ for key in self.meta_index:
+ self.meta_index[key][i] = meta[key] if key in meta else None
+ return True
+
+ @apply_middleware("delete")
+ def delete(self, indices: Union[str, Iterable[str]]) -> None:
+ """
+ Overview:
+ The method that deletes the data and the related meta information by specific indices.
+ Arguments:
+ - indices (Union[str, Iterable[str]]): The indices of the data to be deleted from the buffer.
+ """
+ if isinstance(indices, str):
+ indices = [indices]
+ del_idx = []
+ for index in indices:
+ if self.indices.has(index):
+ del_idx.append(self.indices.get(index))
+ if len(del_idx) == 0:
+ return
+ del_idx = sorted(del_idx, reverse=True)
+ for idx in del_idx:
+ del self.storage[idx]
+ remain_indices = [item.index for item in self.storage]
+ # Rebuild the index mapping for every remaining record
+ key_value_pairs = zip(remain_indices, range(len(remain_indices)))
+ self.indices = BufferIndex(self.storage.maxlen, key_value_pairs)
+
+ def save_data(self, file_name: str):
+ if not os.path.exists(os.path.dirname(file_name)):
+ # If the folder for the specified file does not exist, it will be created.
+ if os.path.dirname(file_name) != "":
+ os.makedirs(os.path.dirname(file_name))
+ hickle.dump(
+ py_obj=(
+ self.storage,
+ self.indices,
+ self.meta_index,
+ ), file_obj=file_name
+ )
+
+ def load_data(self, file_name: str):
+ self.storage, self.indices, self.meta_index = hickle.load(file_name)
+
+ def count(self) -> int:
+ """
+ Overview:
+ The method that returns the current length of the buffer.
+ """
+ return len(self.storage)
+
+ def get(self, idx: int) -> BufferedData:
+ """
+ Overview:
+ The method that returns the BufferedData object given a specific index.
+ """
+ return self.storage[idx]
+
+ @apply_middleware("clear")
+ def clear(self) -> None:
+ """
+ Overview:
+ The method that clear all data, indices, and the meta information in the buffer.
+ """
+ self.storage.clear()
+ self.indices.clear()
+ self.meta_index = {}
+
+ def _push(self, data: Any, meta: Optional[dict] = None) -> BufferedData:
+ index = uuid.uuid1().hex
+ if meta is None:
+ meta = {}
+ buffered = BufferedData(data=data, index=index, meta=meta)
+ self.storage.append(buffered)
+ self.indices.append(index)
+ # Add meta index
+ for key in self.meta_index:
+ self.meta_index[key].append(meta[key] if key in meta else None)
+
+ return buffered
+
+ def _independence(
+ self, buffered_samples: Union[List[BufferedData], List[List[BufferedData]]]
+ ) -> Union[List[BufferedData], List[List[BufferedData]]]:
+ """
+ Overview:
+ Make sure that records appearing more than once in the sampled result are distinct objects.
+ Note that this is different from clone_object: the sampled result still contains the original
+ objects stored in the buffer, so modifying a sampled record may change the data in the buffer.
+ Arguments:
+ - buffered_samples (:obj:`Union[List[BufferedData], List[List[BufferedData]]]`): Sampled data,
+ which can be nested if groupby has been set.
+ """
+ if len(buffered_samples) == 0:
+ return buffered_samples
+ occurred = defaultdict(int)
+
+ for i, buffered in enumerate(buffered_samples):
+ if isinstance(buffered, list):
+ sampled_list = buffered
+ # Loop over nested samples
+ for j, buffered in enumerate(sampled_list):
+ occurred[buffered.index] += 1
+ if occurred[buffered.index] > 1:
+ sampled_list[j] = fastcopy.copy(buffered)
+ elif isinstance(buffered, BufferedData):
+ occurred[buffered.index] += 1
+ if occurred[buffered.index] > 1:
+ buffered_samples[i] = fastcopy.copy(buffered)
+ else:
+ raise Exception("Get unexpected buffered type {}".format(type(buffered)))
+ return buffered_samples
+
+ def _sample_by_group(
+ self,
+ size: int,
+ groupby: str,
+ replace: bool = False,
+ unroll_len: Optional[int] = None,
+ storage: deque = None,
+ sliced: bool = False
+ ) -> List[List[BufferedData]]:
+ """
+ Overview:
+ Sample by group instead of by individual records. The result is a collection of `size` lists,
+ and the lengths of the individual lists may differ from one another.
+ """
+ if storage is None:
+ storage = self.storage
+ if groupby not in self.meta_index:
+ self._create_index(groupby)
+
+ def filter_by_unroll_len():
+ "Filter groups by unroll len, ensure count of items in each group is greater than unroll_len."
+ group_count = Counter(self.meta_index[groupby])
+ group_names = []
+ for key, count in group_count.items():
+ if count >= unroll_len:
+ group_names.append(key)
+ return group_names
+
+ if unroll_len and unroll_len > 1:
+ group_names = filter_by_unroll_len()
+ if len(group_names) == 0:
+ return []
+ else:
+ group_names = list(set(self.meta_index[groupby]))
+
+ sampled_groups = []
+ if replace:
+ sampled_groups = random.choices(group_names, k=size)
+ else:
+ try:
+ sampled_groups = random.sample(group_names, k=size)
+ except ValueError:
+ raise ValueError("There are less than {} groups in buffer({} groups)".format(size, len(group_names)))
+
+ # Build dict like {"group name": [records]}
+ sampled_data = defaultdict(list)
+ for buffered in storage:
+ meta_value = buffered.meta[groupby] if groupby in buffered.meta else None
+ if meta_value in sampled_groups:
+ sampled_data[buffered.meta[groupby]].append(buffered)
+
+ final_sampled_data = []
+ for group in sampled_groups:
+ seq_data = sampled_data[group]
+ # Filter records by unroll_len
+ if unroll_len:
+ # Slice data by unroll_len. Without this, duplicate data is more likely to be sampled,
+ # and the training may easily crash.
+ if sliced:
+ start_indice = random.choice(range(max(1, len(seq_data))))
+ start_indice = start_indice // unroll_len
+ if start_indice == (len(seq_data) - 1) // unroll_len:
+ seq_data = seq_data[-unroll_len:]
+ else:
+ seq_data = seq_data[start_indice * unroll_len:start_indice * unroll_len + unroll_len]
+ else:
+ start_indice = random.choice(range(max(1, len(seq_data) - unroll_len)))
+ seq_data = seq_data[start_indice:start_indice + unroll_len]
+
+ final_sampled_data.append(seq_data)
+
+ return final_sampled_data
+
+ def _create_index(self, meta_key: str):
+ self.meta_index[meta_key] = deque(maxlen=self.storage.maxlen)
+ for data in self.storage:
+ self.meta_index[meta_key].append(data.meta[meta_key] if meta_key in data.meta else None)
+
+ def __iter__(self) -> deque:
+ return iter(self.storage)
+
+ def __copy__(self) -> "DequeBuffer":
+ buffer = type(self)(size=self.storage.maxlen)
+ buffer.storage = self.storage
+ buffer.meta_index = self.meta_index
+ buffer.indices = self.indices
+ return buffer
diff --git a/DI-engine/ding/data/buffer/deque_buffer_wrapper.py b/DI-engine/ding/data/buffer/deque_buffer_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2e4945a9f2660e2bd1685d3470e917e4f052796
--- /dev/null
+++ b/DI-engine/ding/data/buffer/deque_buffer_wrapper.py
@@ -0,0 +1,121 @@
+import os
+from typing import Optional
+import copy
+from easydict import EasyDict
+import numpy as np
+import hickle
+
+from ding.data.buffer import DequeBuffer
+from ding.data.buffer.middleware import use_time_check, PriorityExperienceReplay
+from ding.utils import BUFFER_REGISTRY
+
+
+@BUFFER_REGISTRY.register('deque')
+class DequeBufferWrapper(object):
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ cfg = EasyDict(copy.deepcopy(cls.config))
+ cfg.cfg_type = cls.__name__ + 'Dict'
+ return cfg
+
+ config = dict(
+ replay_buffer_size=10000,
+ max_use=float("inf"),
+ train_iter_per_log=100,
+ priority=False,
+ priority_IS_weight=False,
+ priority_power_factor=0.6,
+ IS_weight_power_factor=0.4,
+ IS_weight_anneal_train_iter=int(1e5),
+ priority_max_limit=1000,
+ )
+
+ def __init__(
+ self,
+ cfg: EasyDict,
+ tb_logger: Optional[object] = None,
+ exp_name: str = 'default_experiement',
+ instance_name: str = 'buffer'
+ ) -> None:
+ self.cfg = cfg
+ self.priority_max_limit = cfg.priority_max_limit
+ self.name = '{}_iter'.format(instance_name)
+ self.tb_logger = tb_logger
+ self.buffer = DequeBuffer(size=cfg.replay_buffer_size)
+ self.last_log_train_iter = -1
+
+ # use_count middleware
+ if self.cfg.max_use != float("inf"):
+ self.buffer.use(use_time_check(self.buffer, max_use=self.cfg.max_use))
+ # priority middleware
+ if self.cfg.priority:
+ self.buffer.use(
+ PriorityExperienceReplay(
+ self.buffer,
+ IS_weight=self.cfg.priority_IS_weight,
+ priority_power_factor=self.cfg.priority_power_factor,
+ IS_weight_power_factor=self.cfg.IS_weight_power_factor,
+ IS_weight_anneal_train_iter=self.cfg.IS_weight_anneal_train_iter
+ )
+ )
+ self.last_sample_index = None
+ self.last_sample_meta = None
+
+ def sample(self, size: int, train_iter: int = 0):
+ output = self.buffer.sample(size=size, ignore_insufficient=True)
+ if len(output) > 0:
+ if self.last_log_train_iter == -1 or train_iter - self.last_log_train_iter >= self.cfg.train_iter_per_log:
+ meta = [o.meta for o in output]
+ if self.cfg.max_use != float("inf"):
+ use_count_avg = np.mean([m['use_count'] for m in meta])
+ self.tb_logger.add_scalar('{}/use_count_avg'.format(self.name), use_count_avg, train_iter)
+ if self.cfg.priority:
+ self.last_sample_index = [o.index for o in output]
+ self.last_sample_meta = meta
+ priority_list = [m['priority'] for m in meta]
+ priority_avg = np.mean(priority_list)
+ priority_max = np.max(priority_list)
+ self.tb_logger.add_scalar('{}/priority_avg'.format(self.name), priority_avg, train_iter)
+ self.tb_logger.add_scalar('{}/priority_max'.format(self.name), priority_max, train_iter)
+ self.tb_logger.add_scalar('{}/buffer_data_count'.format(self.name), self.buffer.count(), train_iter)
+ self.last_log_train_iter = train_iter
+
+ data = [o.data for o in output]
+ if self.cfg.priority_IS_weight:
+ IS = [o.meta['priority_IS'] for o in output]
+ for i in range(len(data)):
+ data[i]['IS'] = IS[i]
+ return data
+ else:
+ return None
+
+ def push(self, data, cur_collector_envstep: int = -1) -> None:
+ for d in data:
+ meta = {}
+ if self.cfg.priority and 'priority' in d:
+ init_priority = d.pop('priority')
+ meta['priority'] = init_priority
+ self.buffer.push(d, meta=meta)
+
+ def update(self, meta: dict) -> None:
+ if not self.cfg.priority:
+ return
+ if self.last_sample_index is None:
+ return
+ new_meta = self.last_sample_meta
+ for m, p in zip(new_meta, meta['priority']):
+ m['priority'] = min(self.priority_max_limit, p)
+ for idx, m in zip(self.last_sample_index, new_meta):
+ self.buffer.update(idx, data=None, meta=m)
+ self.last_sample_index = None
+ self.last_sample_meta = None
+
+ def count(self) -> int:
+ return self.buffer.count()
+
+ def save_data(self, file_name):
+ self.buffer.save_data(file_name)
+
+ def load_data(self, file_name: str):
+ self.buffer.load_data(file_name)
diff --git a/DI-engine/ding/data/buffer/middleware/__init__.py b/DI-engine/ding/data/buffer/middleware/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c00edfb3e0fd9b8459752dfc36a79f776e70a9ce
--- /dev/null
+++ b/DI-engine/ding/data/buffer/middleware/__init__.py
@@ -0,0 +1,7 @@
+from .clone_object import clone_object
+from .use_time_check import use_time_check
+from .staleness_check import staleness_check
+from .priority import PriorityExperienceReplay
+from .padding import padding
+from .group_sample import group_sample
+from .sample_range_view import sample_range_view
diff --git a/DI-engine/ding/data/buffer/middleware/clone_object.py b/DI-engine/ding/data/buffer/middleware/clone_object.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f1e5be06f8ee6d242bfe32aa2f11a4513c585d0
--- /dev/null
+++ b/DI-engine/ding/data/buffer/middleware/clone_object.py
@@ -0,0 +1,29 @@
+from typing import Callable, Any, List, Union
+from ding.data.buffer import BufferedData
+from ding.utils import fastcopy
+
+
+def clone_object():
+ """
+ Overview:
+ This middleware freezes the objects saved in the buffer and returns copies of them during sampling.
+ Use this middleware when you need to keep the objects in the buffer unchanged while modifying\
+ the sampled copies (usually in multiple threads).
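+
+ Examples:
+ >>> # Usage sketch mirroring the clone_object unit test: sampled objects are copies,
+ >>> # so modifying them does not touch the stored originals.
+ >>> buffer = DequeBuffer(size=10).use(clone_object())
+ >>> _ = buffer.push({"key": "v1"})
+ >>> sampled = buffer.sample(1)
+ >>> sampled[0].data["key"] = "v2"
+ >>> buffer.sample(1)[0].data["key"]
+ 'v1'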
+ """
+
+ def push(chain: Callable, data: Any, *args, **kwargs) -> BufferedData:
+ data = fastcopy.copy(data)
+ return chain(data, *args, **kwargs)
+
+ def sample(chain: Callable, *args, **kwargs) -> Union[List[BufferedData], List[List[BufferedData]]]:
+ data = chain(*args, **kwargs)
+ return fastcopy.copy(data)
+
+ def _clone_object(action: str, chain: Callable, *args, **kwargs):
+ if action == "push":
+ return push(chain, *args, **kwargs)
+ elif action == "sample":
+ return sample(chain, *args, **kwargs)
+ return chain(*args, **kwargs)
+
+ return _clone_object
diff --git a/DI-engine/ding/data/buffer/middleware/group_sample.py b/DI-engine/ding/data/buffer/middleware/group_sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..10edc2b2f608eb8b64af558f0f328d7363132bfe
--- /dev/null
+++ b/DI-engine/ding/data/buffer/middleware/group_sample.py
@@ -0,0 +1,37 @@
+import random
+from typing import Callable, List
+from ding.data.buffer.buffer import BufferedData
+
+
+def group_sample(size_in_group: int, ordered_in_group: bool = True, max_use_in_group: bool = True) -> Callable:
+ """
+ Overview:
+ The middleware is designed to process the data in each group after sampling from the buffer.
+ Arguments:
+ - size_in_group (:obj:`int`): Sample size in each group.
+ - ordered_in_group (:obj:`bool`): Whether to keep the original order of records, default is true.
+ - max_use_in_group (:obj:`bool`): Whether to use as much data in each group as possible, default is true.
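+ Examples:
+ >>> # Usage sketch based on the group_sample unit test: pad groups first, then take an
+ >>> # ordered window of 5 records from each sampled group.
+ >>> buffer = DequeBuffer(size=10)
+ >>> _ = buffer.use(padding(policy="none")).use(group_sample(size_in_group=5))
+ >>> for i in range(6):
+ ... buffer.push(i, {"episode": 1})
+ >>> sampled = buffer.sample(1, groupby="episode")
+ >>> len(sampled[0])
+ 5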
+ """
+
+ def sample(chain: Callable, *args, **kwargs) -> List[List[BufferedData]]:
+ if not kwargs.get("groupby"):
+ raise Exception("Group sample must be used when the `groupby` parameter is specified.")
+ sampled_data = chain(*args, **kwargs)
+ for i, grouped_data in enumerate(sampled_data):
+ if ordered_in_group:
+ if max_use_in_group:
+ end = max(0, len(grouped_data) - size_in_group) + 1
+ else:
+ end = len(grouped_data)
+ start_idx = random.choice(range(end))
+ sampled_data[i] = grouped_data[start_idx:start_idx + size_in_group]
+ else:
+ sampled_data[i] = random.sample(grouped_data, k=size_in_group)
+ return sampled_data
+
+ def _group_sample(action: str, chain: Callable, *args, **kwargs):
+ if action == "sample":
+ return sample(chain, *args, **kwargs)
+ return chain(*args, **kwargs)
+
+ return _group_sample
diff --git a/DI-engine/ding/data/buffer/middleware/padding.py b/DI-engine/ding/data/buffer/middleware/padding.py
new file mode 100644
index 0000000000000000000000000000000000000000..6895fb4530bf91559efe95e545d97f80d7de967a
--- /dev/null
+++ b/DI-engine/ding/data/buffer/middleware/padding.py
@@ -0,0 +1,40 @@
+import random
+from typing import Callable, Union, List
+
+from ding.data.buffer import BufferedData
+from ding.utils import fastcopy
+
+
+def padding(policy="random"):
+ """
+ Overview:
+ Pad each nested list of sampled data to the same size as the largest list.
+ The default policy `random` randomly selects existing records from the current group
+ and appends copies of them to that group's list.
+ Arguments:
+ - policy (:obj:`str`): Padding policy, supports `random`, `none`.
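+ Examples:
+ >>> # Usage sketch following the padding unit test: after a grouped sample, smaller groups
+ >>> # are padded up to the size of the largest group.
+ >>> buffer = DequeBuffer(size=10)
+ >>> _ = buffer.use(padding())
+ >>> for i in range(10):
+ ... buffer.push(i, {"group": i & 5})
+ >>> sampled = buffer.sample(4, groupby="group")
+ >>> [len(grouped) for grouped in sampled]
+ [3, 3, 3, 3]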
+ """
+
+ def sample(chain: Callable, *args, **kwargs) -> Union[List[BufferedData], List[List[BufferedData]]]:
+ sampled_data = chain(*args, **kwargs)
+ if len(sampled_data) == 0 or isinstance(sampled_data[0], BufferedData):
+ return sampled_data
+ max_len = len(max(sampled_data, key=len))
+ for i, grouped_data in enumerate(sampled_data):
+ group_len = len(grouped_data)
+ if group_len == max_len:
+ continue
+ for _ in range(max_len - group_len):
+ if policy == "random":
+ sampled_data[i].append(fastcopy.copy(random.choice(grouped_data)))
+ elif policy == "none":
+ sampled_data[i].append(BufferedData(data=None, index=None, meta=None))
+
+ return sampled_data
+
+ def _padding(action: str, chain: Callable, *args, **kwargs):
+ if action == "sample":
+ return sample(chain, *args, **kwargs)
+ return chain(*args, **kwargs)
+
+ return _padding
diff --git a/DI-engine/ding/data/buffer/middleware/priority.py b/DI-engine/ding/data/buffer/middleware/priority.py
new file mode 100644
index 0000000000000000000000000000000000000000..017b302a5fc15e15fd3235505b01de6ff3589803
--- /dev/null
+++ b/DI-engine/ding/data/buffer/middleware/priority.py
@@ -0,0 +1,154 @@
+from typing import Callable, Any, List, Dict, Optional, Union, TYPE_CHECKING
+import copy
+import numpy as np
+import torch
+from ding.utils import SumSegmentTree, MinSegmentTree
+from ding.data.buffer.buffer import BufferedData
+if TYPE_CHECKING:
+ from ding.data.buffer.buffer import Buffer
+
+
+class PriorityExperienceReplay:
+ """
+ Overview:
+ The middleware that implements priority experience replay (PER).
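+
+ Examples:
+ >>> # Usage sketch following the priority unit test: attach the middleware to a buffer,
+ >>> # push data with an initial priority, then raise that priority through update().
+ >>> buffer = DequeBuffer(size=10)
+ >>> _ = buffer.use(PriorityExperienceReplay(buffer, IS_weight=True))
+ >>> _ = buffer.push({'obs': 0}, meta={'priority': 2.0})
+ >>> item = buffer.sample(size=1)[0]
+ >>> item.meta['priority'] = 3.0
+ >>> buffer.update(item.index, item.data, item.meta)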
+ """
+
+ def __init__(
+ self,
+ buffer: 'Buffer',
+ IS_weight: bool = True,
+ priority_power_factor: float = 0.6,
+ IS_weight_power_factor: float = 0.4,
+ IS_weight_anneal_train_iter: int = int(1e5),
+ ) -> None:
+ """
+ Arguments:
+ - buffer (:obj:`Buffer`): The buffer to use PER.
+ - IS_weight (:obj:`bool`): Whether to use importance sampling weights or not.
+ - priority_power_factor (:obj:`float`): The factor that adjusts the sensitivity between\
+ the sampling probability and the priority level.
+ - IS_weight_power_factor (:obj:`float`): The factor that adjusts the sensitivity between\
+ the sample rarity and the sampling probability in importance sampling.
+ - IS_weight_anneal_train_iter (:obj:`int`): The number of training iterations over which\
+ ``IS_weight_power_factor`` is annealed towards 1.
+ """
+
+ self.buffer = buffer
+ self.buffer_idx = {}
+ self.buffer_size = buffer.size
+ self.IS_weight = IS_weight
+ self.priority_power_factor = priority_power_factor
+ self.IS_weight_power_factor = IS_weight_power_factor
+ self.IS_weight_anneal_train_iter = IS_weight_anneal_train_iter
+
+ # Max priority so far; used to initialize a record's priority if "priority" is not passed in with the data.
+ self.max_priority = 1.0
+ # Capacity needs to be the power of 2.
+ capacity = int(np.power(2, np.ceil(np.log2(self.buffer_size))))
+ self.sum_tree = SumSegmentTree(capacity)
+ if self.IS_weight:
+ self.min_tree = MinSegmentTree(capacity)
+ self.delta_anneal = (1 - self.IS_weight_power_factor) / self.IS_weight_anneal_train_iter
+ self.pivot = 0
+
+ def push(self, chain: Callable, data: Any, meta: Optional[dict] = None, *args, **kwargs) -> BufferedData:
+ if meta is None:
+ if 'priority' in data:
+ meta = {'priority': data.pop('priority')}
+ else:
+ meta = {'priority': self.max_priority}
+ else:
+ if 'priority' not in meta:
+ meta['priority'] = self.max_priority
+ meta['priority_idx'] = self.pivot
+ self._update_tree(meta['priority'], self.pivot)
+ buffered = chain(data, meta=meta, *args, **kwargs)
+ index = buffered.index
+ self.buffer_idx[self.pivot] = index
+ self.pivot = (self.pivot + 1) % self.buffer_size
+ return buffered
+
+ def sample(self, chain: Callable, size: int, *args,
+ **kwargs) -> Union[List[BufferedData], List[List[BufferedData]]]:
+ # Divide [0, 1) into size intervals on average
+ intervals = np.array([i * 1.0 / size for i in range(size)])
+ # Uniformly sample within each interval
+ mass = intervals + np.random.uniform(size=(size, )) * 1. / size
+ # Rescale to [0, S), where S is the sum of the priorities of all data (the root value of the sum tree)
+ mass *= self.sum_tree.reduce()
+ indices = [self.sum_tree.find_prefixsum_idx(m) for m in mass]
+ indices = [self.buffer_idx[i] for i in indices]
+ # Sample with indices
+ data = chain(indices=indices, *args, **kwargs)
+ if self.IS_weight:
+ # Calculate max weight for normalizing IS
+ sum_tree_root = self.sum_tree.reduce()
+ p_min = self.min_tree.reduce() / sum_tree_root
+ buffer_count = self.buffer.count()
+ max_weight = (buffer_count * p_min) ** (-self.IS_weight_power_factor)
+ for i in range(len(data)):
+ meta = data[i].meta
+ priority_idx = meta['priority_idx']
+ p_sample = self.sum_tree[priority_idx] / sum_tree_root
+ weight = (buffer_count * p_sample) ** (-self.IS_weight_power_factor)
+ meta['priority_IS'] = weight / max_weight
+ data[i].data['priority_IS'] = torch.as_tensor([meta['priority_IS']]).float() # for compatibility
+ self.IS_weight_power_factor = min(1.0, self.IS_weight_power_factor + self.delta_anneal)
+ return data
+
+ def update(self, chain: Callable, index: str, data: Any, meta: Any, *args, **kwargs) -> None:
+ update_flag = chain(index, data, meta, *args, **kwargs)
+ if update_flag: # when update succeed
+ assert meta is not None, "Please indicate dict-type meta in priority update"
+ new_priority, idx = meta['priority'], meta['priority_idx']
+ assert new_priority >= 0, "new_priority should be non-negative, but found {}".format(new_priority)
+ new_priority += 1e-5 # Add epsilon to avoid priority == 0
+ self._update_tree(new_priority, idx)
+ self.max_priority = max(self.max_priority, new_priority)
+
+ def delete(self, chain: Callable, index: str, *args, **kwargs) -> None:
+ for item in self.buffer.storage:
+ meta = item.meta
+ priority_idx = meta['priority_idx']
+ self.sum_tree[priority_idx] = self.sum_tree.neutral_element
+ self.min_tree[priority_idx] = self.min_tree.neutral_element
+ self.buffer_idx.pop(priority_idx)
+ return chain(index, *args, **kwargs)
+
+ def clear(self, chain: Callable) -> None:
+ self.max_priority = 1.0
+ capacity = int(np.power(2, np.ceil(np.log2(self.buffer_size))))
+ self.sum_tree = SumSegmentTree(capacity)
+ if self.IS_weight:
+ self.min_tree = MinSegmentTree(capacity)
+ self.buffer_idx = {}
+ self.pivot = 0
+ chain()
+
+ def _update_tree(self, priority: float, idx: int) -> None:
+ weight = priority ** self.priority_power_factor
+ self.sum_tree[idx] = weight
+ if self.IS_weight:
+ self.min_tree[idx] = weight
+
+ def state_dict(self) -> Dict:
+ return {
+ 'max_priority': self.max_priority,
+ 'IS_weight_power_factor': self.IS_weight_power_factor,
+ 'sum_tree': self.sum_tree,
+ 'min_tree': self.min_tree,
+ 'buffer_idx': self.buffer_idx,
+ }
+
+ def load_state_dict(self, _state_dict: Dict, deepcopy: bool = False) -> None:
+ for k, v in _state_dict.items():
+ if deepcopy:
+ setattr(self, '{}'.format(k), copy.deepcopy(v))
+ else:
+ setattr(self, '{}'.format(k), v)
+
+ def __call__(self, action: str, chain: Callable, *args, **kwargs) -> Any:
+ if action in ["push", "sample", "update", "delete", "clear"]:
+ return getattr(self, action)(chain, *args, **kwargs)
+ return chain(*args, **kwargs)
diff --git a/DI-engine/ding/data/buffer/middleware/sample_range_view.py b/DI-engine/ding/data/buffer/middleware/sample_range_view.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0465f90c642cfa8e89479efde90d092b3021d0d
--- /dev/null
+++ b/DI-engine/ding/data/buffer/middleware/sample_range_view.py
@@ -0,0 +1,27 @@
+from typing import Callable, Any, List, Optional, Union, TYPE_CHECKING
+from ding.data.buffer import BufferedData
+if TYPE_CHECKING:
+ from ding.data.buffer.buffer import Buffer
+
+
+def sample_range_view(buffer_: 'Buffer', start: Optional[int] = None, end: Optional[int] = None) -> Callable:
+ """
+ Overview:
+ The middleware that places restrictions on the range of indices during sampling.
+ Arguments:
+ - buffer_ (:obj:`Buffer`): The buffer instance, used to resolve negative start/end offsets.
+ - start (:obj:`Optional[int]`): The starting index; a negative value counts back from the end of the buffer.
+ - end (:obj:`Optional[int]`): One past the ending index; a negative value counts back from the end of the buffer.
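+ Examples:
+ >>> # Usage sketch following the sample_range_view unit test: restrict sampling to the
+ >>> # last two records of the buffer.
+ >>> buffer = DequeBuffer(size=10)
+ >>> for i in range(10):
+ ... buffer.push({'data': i})
+ >>> view = buffer.view()
+ >>> _ = view.use(sample_range_view(view, start=-2))
+ >>> view.sample(1)[0].data['data'] in (8, 9)
+ True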
+ """
+ assert start is not None or end is not None
+ if start and start < 0:
+ start = buffer_.size + start
+ if end and end < 0:
+ end = buffer_.size + end
+ sample_range = slice(start, end)
+
+ def _sample_range_view(action: str, chain: Callable, *args, **kwargs) -> Any:
+ if action == "sample":
+ return chain(*args, sample_range=sample_range)
+ return chain(*args, **kwargs)
+
+ return _sample_range_view
diff --git a/DI-engine/ding/data/buffer/middleware/staleness_check.py b/DI-engine/ding/data/buffer/middleware/staleness_check.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fb92ad06980f98fe3a10867c5ebf5254ffec107
--- /dev/null
+++ b/DI-engine/ding/data/buffer/middleware/staleness_check.py
@@ -0,0 +1,41 @@
+from typing import Callable, Any, List, TYPE_CHECKING
+if TYPE_CHECKING:
+ from ding.data.buffer.buffer import Buffer
+
+
+def staleness_check(buffer_: 'Buffer', max_staleness: int = float("inf")) -> Callable:
+ """
+ Overview:
+ This middleware checks data staleness before each sample operation.
+ Staleness = train_iter_sample_data - train_iter_data_collected, i.e. how old/off-policy the data is.
+ If a record's staleness is greater (>) than max_staleness, it will be removed from the buffer as soon as possible.
+ Arguments:
+ - max_staleness (:obj:`int`): The maximum legal span between the time of collecting and time of sampling.
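+ Examples:
+ >>> # Usage sketch following the staleness_check unit test: data must be pushed with a
+ >>> # 'train_iter_data_collected' meta key, and sampling requires train_iter_sample_data.
+ >>> buffer = DequeBuffer(size=10)
+ >>> _ = buffer.use(staleness_check(buffer, max_staleness=10))
+ >>> _ = buffer.push({'obs': 0}, meta={'train_iter_data_collected': 0})
+ >>> data = buffer.sample(size=1, train_iter_sample_data=9)
+ >>> len(data)
+ 1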
+ """
+
+ def push(next: Callable, data: Any, *args, **kwargs) -> Any:
+ assert 'meta' in kwargs and 'train_iter_data_collected' in kwargs[
+ 'meta'], "staleness_check middleware must push data with meta={'train_iter_data_collected': }"
+ return next(data, *args, **kwargs)
+
+ def sample(next: Callable, train_iter_sample_data: int, *args, **kwargs) -> List[Any]:
+ delete_index = []
+ for i, item in enumerate(buffer_.storage):
+ index, meta = item.index, item.meta
+ staleness = train_iter_sample_data - meta['train_iter_data_collected']
+ meta['staleness'] = staleness
+ if staleness > max_staleness:
+ delete_index.append(index)
+ for index in delete_index:
+ buffer_.delete(index)
+ data = next(*args, **kwargs)
+ return data
+
+ def _staleness_check(action: str, next: Callable, *args, **kwargs) -> Any:
+ if action == "push":
+ return push(next, *args, **kwargs)
+ elif action == "sample":
+ return sample(next, *args, **kwargs)
+ return next(*args, **kwargs)
+
+ return _staleness_check
diff --git a/DI-engine/ding/data/buffer/middleware/use_time_check.py b/DI-engine/ding/data/buffer/middleware/use_time_check.py
new file mode 100644
index 0000000000000000000000000000000000000000..522d63965db8d0a67c0d026d5fb37edd06bafe48
--- /dev/null
+++ b/DI-engine/ding/data/buffer/middleware/use_time_check.py
@@ -0,0 +1,52 @@
+from typing import Callable, Any, List, Optional, Union, TYPE_CHECKING
+from collections import defaultdict
+from ding.data.buffer import BufferedData
+if TYPE_CHECKING:
+ from ding.data.buffer.buffer import Buffer
+
+
+def use_time_check(buffer_: 'Buffer', max_use: int = float("inf")) -> Callable:
+ """
+ Overview:
+ This middleware tracks how many times each record in the buffer has been used. Once a record has been
+ used max_use times or more, it will be removed from the buffer as soon as possible.
+ Arguments:
+ - max_use (:obj:`int`): The max reused (resampled) count for any individual object.
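+ Examples:
+ >>> # Usage sketch following the use_time_check unit test: each record may be sampled at
+ >>> # most twice before it is evicted from the buffer.
+ >>> buffer = DequeBuffer(size=10)
+ >>> _ = buffer.use(use_time_check(buffer, max_use=2))
+ >>> _ = buffer.push({'obs': 0})
+ >>> len(buffer.sample(1))
+ 1
+ >>> len(buffer.sample(1))
+ 1
+ >>> buffer.count()
+ 0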
+ """
+
+ use_count = defaultdict(int)
+
+ def _need_delete(item: BufferedData) -> bool:
+ nonlocal use_count
+ idx = item.index
+ use_count[idx] += 1
+ item.meta['use_count'] = use_count[idx]
+ if use_count[idx] >= max_use:
+ return True
+ else:
+ return False
+
+ def _check_use_count(sampled_data: List[BufferedData]):
+ delete_indices = [item.index for item in filter(_need_delete, sampled_data)]
+ buffer_.delete(delete_indices)
+ for index in delete_indices:
+ del use_count[index]
+
+ def sample(chain: Callable, *args, **kwargs) -> Union[List[BufferedData], List[List[BufferedData]]]:
+ sampled_data = chain(*args, **kwargs)
+ if len(sampled_data) == 0:
+ return sampled_data
+
+ if isinstance(sampled_data[0], BufferedData):
+ _check_use_count(sampled_data)
+ else:
+ for grouped_data in sampled_data:
+ _check_use_count(grouped_data)
+ return sampled_data
+
+ def _use_time_check(action: str, chain: Callable, *args, **kwargs) -> Any:
+ if action == "sample":
+ return sample(chain, *args, **kwargs)
+ return chain(*args, **kwargs)
+
+ return _use_time_check
diff --git a/DI-engine/ding/data/buffer/tests/test_buffer.py b/DI-engine/ding/data/buffer/tests/test_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..647816d36de33f2a65b9823cbd7b47315bce4682
--- /dev/null
+++ b/DI-engine/ding/data/buffer/tests/test_buffer.py
@@ -0,0 +1,352 @@
+import os
+import pytest
+import time
+import random
+import functools
+import tempfile
+from typing import Callable
+from ding.data.buffer import DequeBuffer
+from ding.data.buffer.buffer import BufferedData
+from torch.utils.data import DataLoader
+
+
+class RateLimit:
+ r"""
+ Add rate limit threshold to push function
+ """
+
+ def __init__(self, max_rate: int = float("inf"), window_seconds: int = 30) -> None:
+ self.max_rate = max_rate
+ self.window_seconds = window_seconds
+ self.buffered = []
+
+ def __call__(self, action: str, chain: Callable, *args, **kwargs):
+ if action == "push":
+ return self.push(chain, *args, **kwargs)
+ return chain(*args, **kwargs)
+
+ def push(self, chain, data, *args, **kwargs) -> None:
+ current = time.time()
+ # Cut off stale records
+ self.buffered = [t for t in self.buffered if t > current - self.window_seconds]
+ if len(self.buffered) < self.max_rate:
+ self.buffered.append(current)
+ return chain(data, *args, **kwargs)
+ else:
+ return None
+
+
+def add_10() -> Callable:
+ """
+ Transform data on sampling
+ """
+
+ def sample(chain: Callable, size: int, replace: bool = False, *args, **kwargs):
+ sampled_data = chain(size, replace, *args, **kwargs)
+ return [BufferedData(data=item.data + 10, index=item.index, meta=item.meta) for item in sampled_data]
+
+ def _subview(action: str, chain: Callable, *args, **kwargs):
+ if action == "sample":
+ return sample(chain, *args, **kwargs)
+ return chain(*args, **kwargs)
+
+ return _subview
+
+
+@pytest.mark.unittest
+def test_naive_push_sample():
+ # Push and sample
+ buffer = DequeBuffer(size=10)
+ for i in range(20):
+ buffer.push(i)
+ assert buffer.count() == 10
+ assert 0 not in [item.data for item in buffer.sample(10)]
+
+ # Clear
+ buffer.clear()
+ assert buffer.count() == 0
+
+ # Test replace sample
+ for i in range(5):
+ buffer.push(i)
+ assert buffer.count() == 5
+ assert len(buffer.sample(10, replace=True)) == 10
+
+ # Test slicing
+ buffer.clear()
+ for i in range(10):
+ buffer.push(i)
+ assert len(buffer.sample(5, sample_range=slice(5, 10))) == 5
+ assert 0 not in [item.data for item in buffer.sample(5, sample_range=slice(5, 10))]
+
+
+@pytest.mark.unittest
+def test_rate_limit_push_sample():
+ buffer = DequeBuffer(size=10).use(RateLimit(max_rate=5))
+ for i in range(10):
+ buffer.push(i)
+ assert buffer.count() == 5
+ assert 5 not in buffer.sample(5)
+
+
+@pytest.mark.unittest
+def test_load_and_save():
+ buffer = DequeBuffer(size=10).use(RateLimit(max_rate=5))
+ buffer.meta_index = {"label": []}
+ for i in range(10):
+ buffer.push(i, meta={"label": i})
+ assert buffer.count() == 5
+ assert 5 not in buffer.sample(5)
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ test_file = os.path.join(tmpdirname, "data.hkl")
+ buffer.save_data(test_file)
+ buffer_new = DequeBuffer(size=10).use(RateLimit(max_rate=5))
+ buffer_new.load_data(test_file)
+ assert buffer_new.count() == 5
+ assert 5 not in buffer_new.sample(5)
+ assert len(buffer.meta_index["label"]) == 5
+ assert all([index < 5 for index in buffer.meta_index["label"]])
+
+
+@pytest.mark.unittest
+def test_buffer_view():
+ buf1 = DequeBuffer(size=10)
+ for i in range(1):
+ buf1.push(i)
+ assert buf1.count() == 1
+
+ buf2 = buf1.view().use(RateLimit(max_rate=5)).use(add_10())
+
+ for i in range(10):
+ buf2.push(i)
+ # With 1 record written by buf1 and 5 records written by buf2
+ assert len(buf1._middleware) == 0
+ assert buf1.count() == 6
+ # All data sampled from buf2 should be at least 10 because of `add_10`
+ assert all(d.data >= 10 for d in buf2.sample(5))
+ # But data in storage is still less than 10
+ assert all(d.data < 10 for d in buf1.sample(5))
+
+
+@pytest.mark.unittest
+def test_sample_with_index():
+ buf = DequeBuffer(size=10)
+ for i in range(10):
+ buf.push({"data": i}, {"meta": i})
+ # Random sample and get indices
+ indices = [item.index for item in buf.sample(10)]
+ assert len(indices) == 10
+ random.shuffle(indices)
+ indices = indices[:5]
+
+ # Resample by indices
+ new_indices = [item.index for item in buf.sample(indices=indices)]
+ assert len(new_indices) == len(indices)
+ for index in new_indices:
+ assert index in indices
+
+
+@pytest.mark.unittest
+def test_update():
+ buf = DequeBuffer(size=10)
+ for i in range(1):
+ buf.push({"data": i}, {"meta": i})
+
+ # Update one data
+ [item] = buf.sample(1)
+ item.data["new_prop"] = "any"
+ meta = None
+ success = buf.update(item.index, item.data, item.meta)
+ assert success
+ # Resample
+ [item] = buf.sample(1)
+ assert "new_prop" in item.data
+ assert meta is None
+ # Update object that not exists in buffer
+ success = buf.update("invalidindex", {}, None)
+ assert not success
+
+ # When exceed buffer size
+ for i in range(20):
+ buf.push({"data": i})
+ assert len(buf.indices) == 10
+ assert len(buf.storage) == 10
+ for i in range(10):
+ index = buf.storage[i].index
+ assert buf.indices.get(index) == i
+
+
+@pytest.mark.unittest
+def test_delete():
+ maxlen = 100
+ cumlen = 40
+ dellen = 20
+ buf = DequeBuffer(size=maxlen)
+ for i in range(cumlen):
+ buf.push(i)
+ # Delete data
+ del_indices = [item.index for item in buf.sample(dellen)]
+ buf.delete(del_indices)
+ # Reappend
+ for i in range(10):
+ buf.push(i)
+ remlen = min(cumlen, maxlen) - dellen + 10
+ assert len(buf.indices) == remlen
+ assert len(buf.storage) == remlen
+ for i in range(remlen):
+ index = buf.storage[i].index
+ assert buf.indices.get(index) == i
+
+
+@pytest.mark.unittest
+def test_ignore_insufficient():
+ buffer = DequeBuffer(size=10)
+ for i in range(2):
+ buffer.push(i)
+
+ with pytest.raises(ValueError):
+ buffer.sample(3, ignore_insufficient=False)
+ data = buffer.sample(3, ignore_insufficient=True)
+ assert len(data) == 0
+
+
+@pytest.mark.unittest
+def test_independence():
+ # By replace
+ buffer = DequeBuffer(size=1)
+ data = {"key": "origin"}
+ buffer.push(data)
+ sampled_data = buffer.sample(2, replace=True)
+ assert len(sampled_data) == 2
+ sampled_data[0].data["key"] = "new"
+ assert sampled_data[1].data["key"] == "origin"
+
+ # By indices
+ buffer = DequeBuffer(size=1)
+ data = {"key": "origin"}
+ buffered = buffer.push(data)
+ indices = [buffered.index, buffered.index]
+ sampled_data = buffer.sample(indices=indices)
+ assert len(sampled_data) == 2
+ sampled_data[0].data["key"] = "new"
+ assert sampled_data[1].data["key"] == "origin"
+
+
+@pytest.mark.unittest
+def test_groupby():
+ buffer = DequeBuffer(size=3)
+ buffer.push("a", {"group": 1})
+ buffer.push("b", {"group": 2})
+ buffer.push("c", {"group": 2})
+
+ sampled_data = buffer.sample(2, groupby="group")
+ assert len(sampled_data) == 2
+ group1 = sampled_data[0] if len(sampled_data[0]) == 1 else sampled_data[1]
+ group2 = sampled_data[0] if len(sampled_data[0]) == 2 else sampled_data[1]
+ # Group1 should contain a
+ assert "a" == group1[0].data
+ # Group2 should contain b and c
+ data = [buffered.data for buffered in group2] # ["b", "c"]
+ assert "b" in data
+ assert "c" in data
+
+ # Push new data and swap out "a"; the sampled result will all be in group 2
+ buffer.push("d", {"group": 2})
+ sampled_data = buffer.sample(1, groupby="group")
+ assert len(sampled_data) == 1
+ assert len(sampled_data[0]) == 3
+ data = [buffered.data for buffered in sampled_data[0]]
+ assert "d" in data
+
+ # Update meta, set first data's group to 1
+ first: BufferedData = buffer.storage[0]
+ buffer.update(first.index, first.data, {"group": 1})
+ sampled_data = buffer.sample(2, groupby="group")
+ assert len(sampled_data) == 2
+
+ # Delete last record, each group will only have one record
+ last: BufferedData = buffer.storage[-1]
+ buffer.delete(last.index)
+ sampled_data = buffer.sample(2, groupby="group")
+ assert len(sampled_data) == 2
+
+
+@pytest.mark.unittest
+def test_dataset():
+ buffer = DequeBuffer(size=10)
+ for i in range(10):
+ buffer.push(i)
+ dataloader = DataLoader(buffer, batch_size=6, shuffle=True, collate_fn=lambda batch: batch)
+ for batch in dataloader:
+ assert len(batch) in [4, 6]
+
+
+@pytest.mark.unittest
+def test_unroll_len_in_group():
+ buffer = DequeBuffer(size=100)
+ for i in range(10):
+ for env_id in list("ABC"):
+ buffer.push(i, {"env": env_id})
+
+ sampled_data = buffer.sample(3, groupby="env", unroll_len=4)
+ assert len(sampled_data) == 3
+ for grouped_data in sampled_data:
+ assert len(grouped_data) == 4
+ # Ensure each group has the same env
+ env_ids = set(map(lambda sample: sample.meta["env"], grouped_data))
+ assert len(env_ids) == 1
+ # Ensure samples in each group are continuous
+ result = functools.reduce(lambda a, b: a and a.data + 1 == b.data and b, grouped_data)
+ assert isinstance(result, BufferedData), "Not continuous"
+
+
+@pytest.mark.unittest
+def test_insufficient_unroll_len_in_group():
+ buffer = DequeBuffer(size=100)
+
+ num = 3 # Item counts in groups A, B, C are 3, 4, 5
+ for env_id in list("ABC"):
+ for i in range(num):
+ buffer.push(i, {"env": env_id})
+ num += 1
+
+ with pytest.raises(ValueError) as exc_info:
+ buffer.sample(3, groupby="env", unroll_len=4)
+ e = exc_info.value
+ assert "There are less than" in str(e)
+
+ # Sample with replace
+ sampled_data = buffer.sample(3, groupby="env", unroll_len=4, replace=True)
+ assert len(sampled_data) == 3
+ for grouped_data in sampled_data:
+ assert len(grouped_data) == 4
+ # Ensure each group has the same env
+ env_ids = set(map(lambda sample: sample.meta["env"], grouped_data))
+ assert len(env_ids) == 1
+ # Ensure samples in each group are continuous
+ result = functools.reduce(lambda a, b: a and a.data + 1 == b.data and b, grouped_data)
+ assert isinstance(result, BufferedData), "Not continuous"
+
+
+@pytest.mark.unittest
+def test_slice_unroll_len_in_group():
+ buffer = DequeBuffer(size=100, sliced=True)
+ data_len = 10
+ unroll_len = 4
+ start_index = list(range(0, data_len, unroll_len)) + [data_len - unroll_len]
+ for i in range(data_len):
+ for env_id in list("ABC"):
+ buffer.push(i, {"env": env_id})
+
+ sampled_data = buffer.sample(3, groupby="env", unroll_len=unroll_len)
+ assert len(sampled_data) == 3
+ for grouped_data in sampled_data:
+ assert len(grouped_data) == 4
+ # Ensure each group has the same env
+ env_ids = set(map(lambda sample: sample.meta["env"], grouped_data))
+ assert len(env_ids) == 1
+ # Ensure samples in each group are continuous
+ result = functools.reduce(lambda a, b: a and a.data + 1 == b.data and b, grouped_data)
+ assert isinstance(result, BufferedData), "Not continuous"
+ # Ensure data after sliced start from correct index
+ assert grouped_data[0].data in start_index
diff --git a/DI-engine/ding/data/buffer/tests/test_buffer_benchmark.py b/DI-engine/ding/data/buffer/tests/test_buffer_benchmark.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3a3594356477b9db786068df13f2cc2c10f69b6
--- /dev/null
+++ b/DI-engine/ding/data/buffer/tests/test_buffer_benchmark.py
@@ -0,0 +1,95 @@
+import sys
+import timeit
+import torch
+import random
+import pytest
+import numpy as np
+
+from ding.data.buffer import DequeBuffer
+from ding.data.buffer.middleware import clone_object, PriorityExperienceReplay
+
+# Test different buffer sizes, e.g. 1000, 10000, 100000.
+size_list = [1000, 10000]
+# Test different tensor dims, e.g. 32*32, 128*128, 512*512.
+data_dim_list = [32, 128]
+# Number of repetitions per measurement.
+repeats = 100
+
+
+class BufferBenchmark:
+
+ def __init__(self, buffer_size, data_dim, buffer_type='base') -> None:
+ self._buffer = DequeBuffer(size=buffer_size)
+ self._meta = dict()
+ if buffer_type == "clone":
+ self._buffer.use(clone_object())
+ if buffer_type == "priority":
+ self._buffer.use(PriorityExperienceReplay(self._buffer))
+ self._meta["priority"] = 2.0
+ self._data = {"obs": torch.rand(data_dim, data_dim)}
+
+ def data_storage(self) -> float:
+ return sys.getsizeof(self._data["obs"].storage()) / 1024
+
+ def count(self) -> int:
+ return self._buffer.count()
+
+ def push_op(self) -> None:
+ self._buffer.push(self._data, meta=self._meta)
+
+ def push_with_group_info(self, num_keys=256) -> None:
+ meta = self._meta.copy()
+ rand = random.random()
+ value = int(rand * num_keys)
+ meta['group'] = value
+ self._buffer.push(self._data, meta=meta)
+
+ def sample_op(self) -> None:
+ self._buffer.sample(128, replace=False)
+
+ def replace_sample_op(self) -> None:
+ self._buffer.sample(128, replace=True)
+
+ def groupby_sample_op(self) -> None:
+ self._buffer.sample(128, groupby="group")
+
+
+def get_mean_std(res):
+ # return the total time per 1000 ops
+ return np.mean(res) * 1000.0 / repeats, np.std(res) * 1000.0 / repeats
+
+
+@pytest.mark.benchmark
+@pytest.mark.parametrize('buffer_type', ['base', 'clone', 'priority'])
+def test_benchmark(buffer_type):
+ for size in size_list:
+ for dim in data_dim_list:
+ assert size >= 128, "size is too small, please set an int no less than 128!"
+
+ buffer_test = BufferBenchmark(size, dim, buffer_type)
+
+ print("exp-buffer_{}_{}-data_{:.2f}_KB".format(buffer_type, size, buffer_test.data_storage()))
+
+ # test pushing
+ mean, std = get_mean_std(timeit.repeat(buffer_test.push_op, number=repeats))
+ print("Empty Push Test: mean {:.4f} s, std {:.4f} s".format(mean, std))
+
+ # fill the buffer before sampling tests
+ for _ in range(size):
+ buffer_test.push_with_group_info()
+ assert buffer_test.count() == size, "buffer is not full when testing sampling!"
+
+ # test sampling without replace
+ mean, std = get_mean_std(timeit.repeat(buffer_test.sample_op, number=repeats))
+ print("No-Replace Sample Test: mean {:.4f} s, std {:.4f} s".format(mean, std))
+
+ # test sampling with replace
+ mean, std = get_mean_std(timeit.repeat(buffer_test.replace_sample_op, number=repeats))
+ print("Replace Sample Test: mean {:.4f} s, std {:.4f} s".format(mean, std))
+
+ # test groupby sampling
+ if buffer_type != 'priority':
+ mean, std = get_mean_std(timeit.repeat(buffer_test.groupby_sample_op, number=repeats))
+ print("Groupby Sample Test: mean {:.4f} s, std {:.4f} s".format(mean, std))
+
+ print("=" * 100)
diff --git a/DI-engine/ding/data/buffer/tests/test_middleware.py b/DI-engine/ding/data/buffer/tests/test_middleware.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc19866ee32dccbf40fb22011e0e3b1419f4cfde
--- /dev/null
+++ b/DI-engine/ding/data/buffer/tests/test_middleware.py
@@ -0,0 +1,210 @@
+import pytest
+import torch
+from ding.data.buffer import DequeBuffer
+from ding.data.buffer.middleware import clone_object, use_time_check, staleness_check, sample_range_view
+from ding.data.buffer.middleware import PriorityExperienceReplay, group_sample
+from ding.data.buffer.middleware.padding import padding
+
+
+@pytest.mark.unittest
+def test_clone_object():
+ buffer = DequeBuffer(size=10).use(clone_object())
+
+ # Store a dict, a list, a tensor
+ arr = [{"key": "v1"}, ["a"], torch.Tensor([1, 2, 3])]
+ for o in arr:
+ buffer.push(o)
+
+ # Modify it
+ for item in buffer.sample(len(arr)):
+ item = item.data
+ if isinstance(item, dict):
+ item["key"] = "v2"
+ elif isinstance(item, list):
+ item.append("b")
+ elif isinstance(item, torch.Tensor):
+ item[0] = 3
+ else:
+ raise Exception("Unexpected type")
+
+ # Resample it, and check their values
+ for item in buffer.sample(len(arr)):
+ item = item.data
+ if isinstance(item, dict):
+ assert item["key"] == "v1"
+ elif isinstance(item, list):
+ assert len(item) == 1
+ elif isinstance(item, torch.Tensor):
+ assert item[0] == 1
+ else:
+ raise Exception("Unexpected type")
+
+
+def get_data():
+ return {'obs': torch.randn(4), 'reward': torch.randn(1), 'info': 'xxx'}
+
+
+@pytest.mark.unittest
+def test_use_time_check():
+ N = 6
+ buffer = DequeBuffer(size=10)
+ buffer.use(use_time_check(buffer, max_use=2))
+
+ for _ in range(N):
+ buffer.push(get_data())
+
+ for _ in range(2):
+ data = buffer.sample(size=N, replace=False)
+ assert len(data) == N
+ with pytest.raises(ValueError):
+ buffer.sample(size=1, replace=False)
+
+
+@pytest.mark.unittest
+def test_staleness_check():
+ N = 6
+ buffer = DequeBuffer(size=10)
+ buffer.use(staleness_check(buffer, max_staleness=10))
+
+ with pytest.raises(AssertionError):
+ buffer.push(get_data())
+ for _ in range(N):
+ buffer.push(get_data(), meta={'train_iter_data_collected': 0})
+ data = buffer.sample(size=N, replace=False, train_iter_sample_data=9)
+ assert len(data) == N
+ data = buffer.sample(size=N, replace=False, train_iter_sample_data=10) # edge case
+ assert len(data) == N
+ for _ in range(2):
+ buffer.push(get_data(), meta={'train_iter_data_collected': 5})
+ assert buffer.count() == 8
+ with pytest.raises(ValueError):
+ data = buffer.sample(size=N, replace=False, train_iter_sample_data=11)
+ assert buffer.count() == 2
+
+
+@pytest.mark.unittest
+def test_priority():
+ N = 5
+ buffer = DequeBuffer(size=10)
+ buffer.use(PriorityExperienceReplay(buffer, IS_weight=True))
+ for _ in range(N):
+ buffer.push(get_data(), meta={'priority': 2.0})
+ assert buffer.count() == N
+ for _ in range(N):
+ buffer.push(get_data(), meta={'priority': 2.0})
+ assert buffer.count() == N + N
+ data = buffer.sample(size=N + N, replace=False)
+ assert len(data) == N + N
+ for item in data:
+ meta = item.meta
+ assert set(meta.keys()).issuperset(set(['priority', 'priority_idx', 'priority_IS']))
+ meta['priority'] = 3.0
+ for item in data:
+ data, index, meta = item.data, item.index, item.meta
+ buffer.update(index, data, meta)
+ data = buffer.sample(size=1)
+ assert data[0].meta['priority'] == 3.0
+ buffer.delete(data[0].index)
+ assert buffer.count() == N + N - 1
+ buffer.clear()
+ assert buffer.count() == 0
+
+
+@pytest.mark.unittest
+def test_priority_from_collector():
+ N = 5
+ buffer = DequeBuffer(size=10)
+ buffer.use(PriorityExperienceReplay(buffer, IS_weight=True))
+ for _ in range(N):
+ tmp_data = get_data()
+ tmp_data['priority'] = 2.0
+ buffer.push(tmp_data)
+ assert buffer.count() == N
+ for _ in range(N):
+ tmp_data = get_data()
+ tmp_data['priority'] = 2.0
+ buffer.push(tmp_data)
+ assert buffer.count() == N + N
+ data = buffer.sample(size=N + N, replace=False)
+ assert len(data) == N + N
+ for item in data:
+ meta = item.meta
+ assert set(meta.keys()).issuperset(set(['priority', 'priority_idx', 'priority_IS']))
+ meta['priority'] = 3.0
+ for item in data:
+ data, index, meta = item.data, item.index, item.meta
+ buffer.update(index, data, meta)
+ data = buffer.sample(size=1)
+ assert data[0].meta['priority'] == 3.0
+ buffer.delete(data[0].index)
+ assert buffer.count() == N + N - 1
+ buffer.clear()
+ assert buffer.count() == 0
+
+
+@pytest.mark.unittest
+def test_padding():
+ buffer = DequeBuffer(size=10)
+ buffer.use(padding())
+ for i in range(10):
+ buffer.push(i, {"group": i & 5}) # [3,3,2,2]
+ sampled_data = buffer.sample(4, groupby="group")
+ assert len(sampled_data) == 4
+ for grouped_data in sampled_data:
+ assert len(grouped_data) == 3
+
+
+@pytest.mark.unittest
+def test_group_sample():
+ buffer = DequeBuffer(size=10)
+ buffer.use(padding(policy="none")).use(group_sample(size_in_group=5, ordered_in_group=True, max_use_in_group=True))
+ for i in range(4):
+ buffer.push(i, {"episode": 0})
+ for i in range(6):
+ buffer.push(i, {"episode": 1})
+ sampled_data = buffer.sample(2, groupby="episode")
+ assert len(sampled_data) == 2
+
+ def check_group0(grouped_data):
+ # In group0, only the last (padded) record should have data equal to None
+ n_none = 0
+ for item in grouped_data:
+ if item.data is None:
+ n_none += 1
+ assert n_none == 1
+
+ def check_group1(grouped_data):
+ # In group1 every record should have data and meta
+ for item in grouped_data:
+ assert item.data is not None
+
+ for grouped_data in sampled_data:
+ assert len(grouped_data) == 5
+ meta = grouped_data[0].meta
+ if meta and "episode" in meta and meta["episode"] == 1:
+ check_group1(grouped_data)
+ else:
+ check_group0(grouped_data)
+
+
+@pytest.mark.unittest
+def test_sample_range_view():
+ buffer_ = DequeBuffer(size=10)
+ for i in range(5):
+ buffer_.push({'data': 'x'})
+ for i in range(5, 5 + 3):
+ buffer_.push({'data': 'y'})
+ for i in range(8, 8 + 2):
+ buffer_.push({'data': 'z'})
+
+ buffer1 = buffer_.view()
+ buffer1.use(sample_range_view(buffer1, start=-5, end=-2))
+ for _ in range(10):
+ sampled_data = buffer1.sample(1)
+ assert sampled_data[0].data['data'] == 'y'
+
+ buffer2 = buffer_.view()
+ buffer2.use(sample_range_view(buffer1, start=-2))
+ for _ in range(10):
+ sampled_data = buffer2.sample(1)
+ assert sampled_data[0].data['data'] == 'z'
diff --git a/DI-engine/ding/data/level_replay/__init__.py b/DI-engine/ding/data/level_replay/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/ding/data/level_replay/level_sampler.py b/DI-engine/ding/data/level_replay/level_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac51fc4d4d00460aae287ecbfab6fc60cfc0b03b
--- /dev/null
+++ b/DI-engine/ding/data/level_replay/level_sampler.py
@@ -0,0 +1,321 @@
+from typing import Optional, Union, Any, List
+from easydict import EasyDict
+from ding.utils import deep_merge_dicts, SequenceType
+from collections import namedtuple
+import numpy as np
+import torch
+
+
+class LevelSampler():
+ """
+ Overview:
+ Policy class of Prioritized Level Replay algorithm.
+ https://arxiv.org/pdf/2010.03934.pdf
+
+ PLR is a method for improving generalization and sample-efficiency of \
+ deep RL agents on procedurally-generated environments by adaptively updating \
+ a sampling distribution over the training levels based on a score of the learning \
+ potential of replaying each level.
+ """
+ config = dict(
+ strategy='policy_entropy',
+ replay_schedule='fixed',
+ score_transform='rank',
+ temperature=1.0,
+ eps=0.05,
+ rho=0.2,
+ nu=0.5,
+ alpha=1.0,
+ staleness_coef=0,
+ staleness_transform='power',
+ staleness_temperature=1.0,
+ )
+
+ def __init__(
+ self,
+ seeds: Optional[List[int]],
+ obs_space: Union[int, SequenceType],
+ action_space: int,
+ num_actors: int,
+ cfg: EasyDict,
+ ):
+ self.cfg = EasyDict(deep_merge_dicts(self.config, cfg))
+ self.cfg.update(cfg)
+ self.obs_space = obs_space
+ self.action_space = action_space
+ self.strategy = self.cfg.strategy
+ self.replay_schedule = self.cfg.replay_schedule
+ self.score_transform = self.cfg.score_transform
+ self.temperature = self.cfg.temperature
+ # Eps means the level replay epsilon for eps-greedy sampling
+ self.eps = self.cfg.eps
+ # Rho means the minimum size of replay set relative to total number of levels before sampling replays
+ self.rho = self.cfg.rho
+ # Nu means the probability of sampling a new level instead of a replay level
+ self.nu = self.cfg.nu
+ # Alpha means the level score EWA smoothing factor
+ self.alpha = self.cfg.alpha
+ self.staleness_coef = self.cfg.staleness_coef
+ self.staleness_transform = self.cfg.staleness_transform
+ self.staleness_temperature = self.cfg.staleness_temperature
+
+ # Track seeds and scores in np arrays backed by shared memory
+ self.seeds = np.array(seeds, dtype=np.int64)
+ self.seed2index = {seed: i for i, seed in enumerate(seeds)}
+
+ self.unseen_seed_weights = np.ones(len(seeds))
+ self.seed_scores = np.zeros(len(seeds))
+ self.partial_seed_scores = np.zeros((num_actors, len(seeds)), dtype=np.float32)
+ self.partial_seed_steps = np.zeros((num_actors, len(seeds)), dtype=np.int64)
+ self.seed_staleness = np.zeros(len(seeds))
+
+ self.next_seed_index = 0 # Only used for sequential strategy
+
+ def update_with_rollouts(self, train_data: dict, num_actors: int):
+ total_steps = train_data['reward'].shape[0]
+ if self.strategy == 'random':
+ return
+
+ if self.strategy == 'policy_entropy':
+ score_function = self._entropy
+ elif self.strategy == 'least_confidence':
+ score_function = self._least_confidence
+ elif self.strategy == 'min_margin':
+ score_function = self._min_margin
+ elif self.strategy == 'gae':
+ score_function = self._gae
+ elif self.strategy == 'value_l1':
+ score_function = self._value_l1
+ elif self.strategy == 'one_step_td_error':
+ score_function = self._one_step_td_error
+ else:
+ raise ValueError('Not supported strategy: {}'.format(self.strategy))
+
+ self._update_with_rollouts(train_data, num_actors, total_steps, score_function)
+
+ for actor_index in range(self.partial_seed_scores.shape[0]):
+ for seed_idx in range(self.partial_seed_scores.shape[1]):
+ if self.partial_seed_scores[actor_index][seed_idx] != 0:
+ self.update_seed_score(actor_index, seed_idx, 0, 0)
+ self.partial_seed_scores.fill(0)
+ self.partial_seed_steps.fill(0)
+
+ def update_seed_score(self, actor_index: int, seed_idx: int, score: float, num_steps: int):
+ score = self._partial_update_seed_score(actor_index, seed_idx, score, num_steps, done=True)
+
+ self.unseen_seed_weights[seed_idx] = 0. # No longer unseen
+
+ old_score = self.seed_scores[seed_idx]
+ self.seed_scores[seed_idx] = (1 - self.alpha) * old_score + self.alpha * score
+
+ def _partial_update_seed_score(
+ self, actor_index: int, seed_idx: int, score: float, num_steps: int, done: bool = False
+ ):
+ partial_score = self.partial_seed_scores[actor_index][seed_idx]
+ partial_num_steps = self.partial_seed_steps[actor_index][seed_idx]
+
+ running_num_steps = partial_num_steps + num_steps
+ merged_score = partial_score + (score - partial_score) * num_steps / float(running_num_steps)
+
+ if done:
+ self.partial_seed_scores[actor_index][seed_idx] = 0. # zero partial score, partial num_steps
+ self.partial_seed_steps[actor_index][seed_idx] = 0
+ else:
+ self.partial_seed_scores[actor_index][seed_idx] = merged_score
+ self.partial_seed_steps[actor_index][seed_idx] = running_num_steps
+
+ return merged_score
+
+ def _entropy(self, **kwargs):
+ episode_logits = kwargs['episode_logits']
+ num_actions = self.action_space
+ max_entropy = -(1. / num_actions) * np.log(1. / num_actions) * num_actions
+
+ return (-torch.exp(episode_logits) * episode_logits).sum(-1).mean().item() / max_entropy
+
+ def _least_confidence(self, **kwargs):
+ episode_logits = kwargs['episode_logits']
+ return (1 - torch.exp(episode_logits.max(-1, keepdim=True)[0])).mean().item()
+
+ def _min_margin(self, **kwargs):
+ episode_logits = kwargs['episode_logits']
+ top2_confidence = torch.exp(episode_logits.topk(2, dim=-1)[0])
+ return 1 - (top2_confidence[:, 0] - top2_confidence[:, 1]).mean().item()
+
+ def _gae(self, **kwargs):
+
+ advantages = kwargs['adv']
+
+ return advantages.mean().item()
+
+ def _value_l1(self, **kwargs):
+ advantages = kwargs['adv']
+ # If the absolute value of the advantage is large, the level can still significantly
+ # change the policy, i.e. it has high learning potential.
+
+ return advantages.abs().mean().item()
+
+ def _one_step_td_error(self, **kwargs):
+ rewards = kwargs['rewards']
+ value = kwargs['value']
+
+ max_t = len(rewards)
+ td_errors = (rewards[:-1] + value[:max_t - 1] - value[1:max_t]).abs()
+
+ return td_errors.mean().item() # td_errors is already non-negative (abs applied above)
+
+ def _update_with_rollouts(self, train_data: dict, num_actors: int, all_total_steps: int, score_function):
+ level_seeds = train_data['seed'].reshape(num_actors, int(all_total_steps / num_actors)).transpose(0, 1)
+ policy_logits = train_data['logit'].reshape(num_actors, int(all_total_steps / num_actors), -1).transpose(0, 1)
+ done = train_data['done'].reshape(num_actors, int(all_total_steps / num_actors)).transpose(0, 1)
+ total_steps, num_actors = policy_logits.shape[:2]
+ num_decisions = len(policy_logits)
+
+ for actor_index in range(num_actors):
+ done_steps = done[:, actor_index].nonzero()[:total_steps, 0]
+ start_t = 0
+
+ for t in done_steps:
+ if not start_t < total_steps:
+ break
+
+ if t == 0: # if t is 0, then this done step caused a full update of previous seed last cycle
+ continue
+
+ seed_t = level_seeds[start_t, actor_index].item()
+ seed_t = int(seed_t)
+ seed_idx_t = self.seed2index[seed_t]
+
+ score_function_kwargs = {}
+ episode_logits = policy_logits[start_t:t, actor_index]
+ score_function_kwargs['episode_logits'] = torch.log_softmax(episode_logits, -1)
+
+ if self.strategy in ['gae', 'value_l1', 'one_step_td_error']:
+ rewards = train_data['reward'].reshape(num_actors,
+ int(all_total_steps / num_actors)).transpose(0, 1)
+ adv = train_data['adv'].reshape(num_actors, int(all_total_steps / num_actors)).transpose(0, 1)
+ value = train_data['value'].reshape(num_actors, int(all_total_steps / num_actors)).transpose(0, 1)
+ score_function_kwargs['adv'] = adv[start_t:t, actor_index]
+ score_function_kwargs['rewards'] = rewards[start_t:t, actor_index]
+ score_function_kwargs['value'] = value[start_t:t, actor_index]
+
+ score = score_function(**score_function_kwargs)
+ num_steps = len(episode_logits)
+ self.update_seed_score(actor_index, seed_idx_t, score, num_steps)
+
+ start_t = t.item()
+
+ if start_t < total_steps:
+ seed_t = level_seeds[start_t, actor_index].item()
+ seed_idx_t = self.seed2index[seed_t]
+
+ score_function_kwargs = {}
+ episode_logits = policy_logits[start_t:, actor_index]
+ score_function_kwargs['episode_logits'] = torch.log_softmax(episode_logits, -1)
+
+ if self.strategy in ['gae', 'value_l1', 'one_step_td_error']:
+ rewards = train_data['reward'].reshape(num_actors,
+ int(all_total_steps / num_actors)).transpose(0, 1)
+ adv = train_data['adv'].reshape(num_actors, int(all_total_steps / num_actors)).transpose(0, 1)
+ value = train_data['value'].reshape(num_actors, int(all_total_steps / num_actors)).transpose(0, 1)
+ score_function_kwargs['adv'] = adv[start_t:, actor_index]
+ score_function_kwargs['rewards'] = rewards[start_t:, actor_index]
+ score_function_kwargs['value'] = value[start_t:, actor_index]
+
+ score = score_function(**score_function_kwargs)
+ num_steps = len(episode_logits)
+ self._partial_update_seed_score(actor_index, seed_idx_t, score, num_steps)
+
+ def _update_staleness(self, selected_idx: int):
+ if self.staleness_coef > 0:
+ self.seed_staleness += 1
+ self.seed_staleness[selected_idx] = 0
+
+ def _sample_replay_level(self):
+ sample_weights = self._sample_weights()
+
+ if np.isclose(np.sum(sample_weights), 0):
+ sample_weights = np.ones_like(sample_weights, dtype=np.float32) / len(sample_weights)
+
+ seed_idx = np.random.choice(range(len(self.seeds)), 1, p=sample_weights)[0]
+ seed = self.seeds[seed_idx]
+
+ self._update_staleness(seed_idx)
+
+ return int(seed)
+
+ def _sample_unseen_level(self):
+ sample_weights = self.unseen_seed_weights / self.unseen_seed_weights.sum()
+ seed_idx = np.random.choice(range(len(self.seeds)), 1, p=sample_weights)[0]
+ seed = self.seeds[seed_idx]
+
+ self._update_staleness(seed_idx)
+
+ return int(seed)
+
+ def sample(self, strategy: Optional[str] = None):
+ if not strategy:
+ strategy = self.strategy
+
+ if strategy == 'random':
+ seed_idx = np.random.choice(range(len(self.seeds)))
+ seed = self.seeds[seed_idx]
+ return int(seed)
+
+ elif strategy == 'sequential':
+ seed_idx = self.next_seed_index
+ self.next_seed_index = (self.next_seed_index + 1) % len(self.seeds)
+ seed = self.seeds[seed_idx]
+ return int(seed)
+
+ num_unseen = (self.unseen_seed_weights > 0).sum()
+ proportion_seen = (len(self.seeds) - num_unseen) / len(self.seeds)
+
+ if self.replay_schedule == 'fixed':
+ if proportion_seen >= self.rho:
+ # Sample replay level with fixed prob = 1 - nu OR if all levels seen
+ if np.random.rand() > self.nu or not proportion_seen < 1.0:
+ return self._sample_replay_level()
+
+ # Otherwise, sample a new level
+ return self._sample_unseen_level()
+
+ else: # Default to proportionate schedule
+ if proportion_seen >= self.rho and np.random.rand() < proportion_seen:
+ return self._sample_replay_level()
+ else:
+ return self._sample_unseen_level()
+
+ def _sample_weights(self):
+ weights = self._score_transform(self.score_transform, self.temperature, self.seed_scores)
+ weights = weights * (1 - self.unseen_seed_weights) # zero out unseen levels
+
+ z = np.sum(weights)
+ if z > 0:
+ weights /= z
+
+ staleness_weights = 0
+ if self.staleness_coef > 0:
+ staleness_weights = self._score_transform(
+ self.staleness_transform, self.staleness_temperature, self.seed_staleness
+ )
+ staleness_weights = staleness_weights * (1 - self.unseen_seed_weights)
+ z = np.sum(staleness_weights)
+ if z > 0:
+ staleness_weights /= z
+
+ weights = (1 - self.staleness_coef) * weights + self.staleness_coef * staleness_weights
+
+ return weights
+
+ def _score_transform(self, transform: Optional[str], temperature: float, scores: Optional[List[float]]):
+ if transform == 'rank':
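+ # Rank transform: sort levels by score in descending order (rank 1 = highest score)
+ # and weight each level by 1 / rank ** (1 / temperature).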
+ temp = np.flip(scores.argsort())
+ ranks = np.empty_like(temp)
+ ranks[temp] = np.arange(len(temp)) + 1
+ weights = 1 / ranks ** (1. / temperature)
+ elif transform == 'power':
+ eps = 0 if self.staleness_coef > 0 else 1e-3
+ weights = (np.array(scores) + eps) ** (1. / temperature)
+ else:
+ raise ValueError('Not supported score transform: {}'.format(transform))
+
+ return weights
diff --git a/DI-engine/ding/data/level_replay/tests/test_level_sampler.py b/DI-engine/ding/data/level_replay/tests/test_level_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea0c2ddfa04814744db9ad5f2e129ab8af687a9b
--- /dev/null
+++ b/DI-engine/ding/data/level_replay/tests/test_level_sampler.py
@@ -0,0 +1,38 @@
+import pytest
+import numpy as np
+import random
+import torch
+from ding.data.level_replay.level_sampler import LevelSampler
+
+
+@pytest.mark.unittest
+def test_level_sampler():
+ num_seeds = 500
+ obs_shape = [3, 64, 64]
+ action_shape = 15
+ collector_env_num = 16
+ level_replay_dict = dict(
+ strategy='min_margin',
+ score_transform='rank',
+ temperature=0.1,
+ )
+ N = 10
+ collector_sample_length = 160
+
+ train_seeds = [i for i in range(num_seeds)]
+ level_sampler = LevelSampler(train_seeds, obs_shape, action_shape, collector_env_num, level_replay_dict)
+
+ value = torch.randn(collector_sample_length)
+ reward = torch.randn(collector_sample_length)
+ adv = torch.randn(collector_sample_length)
+ done = torch.randn(collector_sample_length)
+ logit = torch.randn(collector_sample_length, N)
+ seeds = [random.randint(0, num_seeds - 1) for i in range(collector_env_num)] # randint is inclusive on both ends
+ all_seeds = torch.Tensor(
+ [seeds[i] for i in range(collector_env_num) for j in range(int(collector_sample_length / collector_env_num))]
+ )
+
+ train_data = {'value': value, 'reward': reward, 'adv': adv, 'done': done, 'logit': logit, 'seed': all_seeds}
+ level_sampler.update_with_rollouts(train_data, collector_env_num)
+ sample_seed = level_sampler.sample()
+ assert isinstance(sample_seed, int)
diff --git a/DI-engine/ding/data/model_loader.py b/DI-engine/ding/data/model_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd3182897bb3094ac98bcccce2f2349c0e84b3de
--- /dev/null
+++ b/DI-engine/ding/data/model_loader.py
@@ -0,0 +1,155 @@
+from abc import ABC, abstractmethod
+import logging
+from os import path
+import os
+from threading import Thread
+from time import sleep, time
+from typing import Callable, Optional
+import uuid
+import torch.multiprocessing as mp
+
+import torch
+from ding.data.storage.file import FileModelStorage
+from ding.data.storage.storage import Storage
+from ding.framework import Supervisor
+from ding.framework.supervisor import ChildType, SendPayload
+
+
+class ModelWorker():
+
+ def __init__(self, model: torch.nn.Module) -> None:
+ self._model = model
+
+ def save(self, storage: Storage) -> Storage:
+ storage.save(self._model.state_dict())
+ return storage
+
+
+class ModelLoader(Supervisor, ABC):
+
+ def __init__(self, model: torch.nn.Module) -> None:
+ """
+ Overview:
+ Save and send models asynchronously and load them synchronously.
+ Arguments:
+ - model (:obj:`torch.nn.Module`): Torch module.
+ """
+ if next(model.parameters()).is_cuda:
+ super().__init__(type_=ChildType.PROCESS, mp_ctx=mp.get_context("spawn"))
+ else:
+ super().__init__(type_=ChildType.PROCESS)
+ self._model = model
+ self._send_callback_loop = None
+ self._send_callbacks = {}
+ self._model_worker = ModelWorker(self._model)
+
+ def start(self):
+ if not self._running:
+ self._model.share_memory()
+ self.register(self._model_worker)
+ self.start_link()
+ self._send_callback_loop = Thread(target=self._loop_send_callback, daemon=True)
+ self._send_callback_loop.start()
+
+ def shutdown(self, timeout: Optional[float] = None) -> None:
+ super().shutdown(timeout)
+ self._send_callback_loop = None
+ self._send_callbacks = {}
+
+ def _loop_send_callback(self):
+ while True:
+ payload = self.recv(ignore_err=True)
+ if payload.err:
+ logging.warning("Got error when loading data: {}".format(payload.err))
+ if payload.req_id in self._send_callbacks:
+ del self._send_callbacks[payload.req_id]
+ else:
+ if payload.req_id in self._send_callbacks:
+ callback = self._send_callbacks.pop(payload.req_id)
+ callback(payload.data)
+
+ def load(self, storage: Storage) -> object:
+ """
+ Overview:
+ Load model synchronously.
+ Arguments:
+ - storage (:obj:`Storage`): The model should be wrapped in a storage object, e.g. FileModelStorage.
+ Returns:
+ - object (:obj:`object`): The loaded model state_dict.
+ """
+ return storage.load()
+
+ @abstractmethod
+ def save(self, callback: Callable) -> Storage:
+ """
+ Overview:
+ Save model asynchronously.
+ Arguments:
+ - callback (:obj:`Callable`): The callback function after saving model.
+ Returns:
+ - storage (:obj:`Storage`): The storage object is created synchronously, so it can be returned.
+ """
+ raise NotImplementedError
+
+
+class FileModelLoader(ModelLoader):
+
+ def __init__(self, model: torch.nn.Module, dirname: str, ttl: int = 20) -> None:
+ """
+ Overview:
+ Model loader using files as storage media.
+ Arguments:
+ - model (:obj:`torch.nn.Module`): Torch module.
+ - dirname (:obj:`str`): The directory for saving files.
+ - ttl (:obj:`int`): Files will be automatically cleaned up after ``ttl`` seconds. Note that \
+ files that have not timed out when the process stops are not cleaned up \
+ (to avoid errors when other processes are still reading them), so you may need to \
+ clean up the remaining files manually.
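+
+ Example (a minimal usage sketch based on the accompanying unit test; the directory \
+ name is hypothetical):
+
+ >>> loader = FileModelLoader(model=model, dirname='./model_pool', ttl=20)
+ >>> loader.start()
+ >>> storage = loader.save(lambda s: print('saved to', s.path)) # async, returns immediately
+ >>> sleep(0.5) # give the background worker time to finish writing
+ >>> model.load_state_dict(loader.load(storage)) # synchronous load
+ >>> loader.shutdown()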
+ """
+ super().__init__(model)
+ self._dirname = dirname
+ self._ttl = ttl
+ self._files = []
+ self._cleanup_thread = None
+
+ def _start_cleanup(self):
+ """
+ Overview:
+ Start a cleanup thread to remove files that have been on disk longer than the configured ttl.
+ """
+ if self._cleanup_thread is None:
+ self._cleanup_thread = Thread(target=self._loop_cleanup, daemon=True)
+ self._cleanup_thread.start()
+
+ def shutdown(self, timeout: Optional[float] = None) -> None:
+ super().shutdown(timeout)
+ self._cleanup_thread = None
+
+ def _loop_cleanup(self):
+ while True:
+ if len(self._files) == 0 or time() - self._files[0][0] < self._ttl:
+ sleep(1)
+ continue
+ _, file_path = self._files.pop(0)
+ if path.exists(file_path):
+ os.remove(file_path)
+
+ def save(self, callback: Callable) -> FileModelStorage:
+ if not self._running:
+ logging.warning("Please start model loader before saving model.")
+ return
+ if not path.exists(self._dirname):
+ os.mkdir(self._dirname)
+ file_path = "model_{}.pth.tar".format(uuid.uuid1())
+ file_path = path.join(self._dirname, file_path)
+ model_storage = FileModelStorage(file_path)
+ payload = SendPayload(proc_id=0, method="save", args=[model_storage])
+ self.send(payload)
+
+ def clean_callback(storage: Storage):
+ self._files.append([time(), file_path])
+ callback(storage)
+
+ self._send_callbacks[payload.req_id] = clean_callback
+ self._start_cleanup()
+ return model_storage
diff --git a/DI-engine/ding/data/shm_buffer.py b/DI-engine/ding/data/shm_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b76f5d56e9d8120e533191b566f66160f8c0a9d5
--- /dev/null
+++ b/DI-engine/ding/data/shm_buffer.py
@@ -0,0 +1,133 @@
+from typing import Any, Optional, Union, Tuple, Dict
+from multiprocessing import Array
+import ctypes
+import numpy as np
+import torch
+
+_NTYPE_TO_CTYPE = {
+ np.bool_: ctypes.c_bool,
+ np.uint8: ctypes.c_uint8,
+ np.uint16: ctypes.c_uint16,
+ np.uint32: ctypes.c_uint32,
+ np.uint64: ctypes.c_uint64,
+ np.int8: ctypes.c_int8,
+ np.int16: ctypes.c_int16,
+ np.int32: ctypes.c_int32,
+ np.int64: ctypes.c_int64,
+ np.float32: ctypes.c_float,
+ np.float64: ctypes.c_double,
+}
+
+
+class ShmBuffer():
+ """
+ Overview:
+ Shared memory buffer to store numpy array.
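+
+ Example (an illustrative sketch; the shape and dtype are arbitrary, and in real use \
+ ``fill`` and ``get`` are typically called from different processes):
+
+ >>> buf = ShmBuffer(np.float32, (4, 84, 84), copy_on_get=True)
+ >>> buf.fill(np.random.rand(4, 84, 84).astype(np.float32))
+ >>> arr = buf.get() # a copied np.ndarray with the same shape and dtype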
+ """
+
+ def __init__(
+ self,
+ dtype: Union[type, np.dtype],
+ shape: Tuple[int],
+ copy_on_get: bool = True,
+ ctype: Optional[type] = None
+ ) -> None:
+ """
+ Overview:
+ Initialize the buffer.
+ Arguments:
+ - dtype (:obj:`Union[type, np.dtype]`): The dtype of the data to limit the size of the buffer.
+ - shape (:obj:`Tuple[int]`): The shape of the data to limit the size of the buffer.
+ - copy_on_get (:obj:`bool`): Whether to copy data when calling get method.
+ - ctype (:obj:`Optional[type]`): Origin class type, e.g. np.ndarray, torch.Tensor.
+ """
+ if isinstance(dtype, np.dtype): # e.g. gym.spaces dtypes are np.dtype instances, not scalar types
+ dtype = dtype.type
+ self.buffer = Array(_NTYPE_TO_CTYPE[dtype], int(np.prod(shape)))
+ self.dtype = dtype
+ self.shape = shape
+ self.copy_on_get = copy_on_get
+ self.ctype = ctype
+
+ def fill(self, src_arr: np.ndarray) -> None:
+ """
+ Overview:
+ Fill the shared memory buffer with a numpy array. (Replace the original one.)
+ Arguments:
+ - src_arr (:obj:`np.ndarray`): array to fill the buffer.
+ """
+ assert isinstance(src_arr, np.ndarray), type(src_arr)
+ # for np.array with shape (4, 84, 84) and float32 dtype, reshape is 15~20x faster than flatten
+ # for np.array with shape (4, 84, 84) and uint8 dtype, reshape is 5~7x faster than flatten
+ # so we reshape dst_arr rather than flatten src_arr
+ dst_arr = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape)
+ np.copyto(dst_arr, src_arr)
+
+ def get(self) -> np.ndarray:
+ """
+ Overview:
+ Get the array stored in the buffer.
+ Return:
+ - data (:obj:`np.ndarray`): A copy of the data stored in the buffer.
+ """
+ data = np.frombuffer(self.buffer.get_obj(), dtype=self.dtype).reshape(self.shape)
+ if self.copy_on_get:
+ data = data.copy() # must use np.copy, torch.from_numpy and torch.as_tensor still use the same memory
+ if self.ctype is torch.Tensor:
+ data = torch.from_numpy(data)
+ return data
+
+
+class ShmBufferContainer(object):
+ """
+ Overview:
+ Support multiple shared memory buffers. Each key-value is name-buffer.
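+
+ Example (an illustrative sketch; the key name is hypothetical):
+
+ >>> container = ShmBufferContainer({'obs': np.float32}, {'obs': (4, 84, 84)})
+ >>> container.fill({'obs': np.ones((4, 84, 84), dtype=np.float32)})
+ >>> data = container.get() # {'obs': np.ndarray of shape (4, 84, 84)}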
+ """
+
+ def __init__(
+ self,
+ dtype: Union[Dict[Any, type], type, np.dtype],
+ shape: Union[Dict[Any, tuple], tuple],
+ copy_on_get: bool = True
+ ) -> None:
+ """
+ Overview:
+ Initialize the buffer container.
+ Arguments:
+ - dtype (:obj:`Union[type, np.dtype]`): The dtype of the data to limit the size of the buffer.
+ - shape (:obj:`Union[Dict[Any, tuple], tuple]`): If `Dict[Any, tuple]`, use a dict to manage \
+ multiple buffers; If `tuple`, use single buffer.
+ - copy_on_get (:obj:`bool`): Whether to copy data when calling get method.
+ """
+ if isinstance(shape, dict):
+ self._data = {k: ShmBufferContainer(dtype[k], v, copy_on_get) for k, v in shape.items()}
+ elif isinstance(shape, (tuple, list)):
+ self._data = ShmBuffer(dtype, shape, copy_on_get)
+ else:
+ raise RuntimeError("not support shape: {}".format(shape))
+ self._shape = shape
+
+ def fill(self, src_arr: Union[Dict[Any, np.ndarray], np.ndarray]) -> None:
+ """
+ Overview:
+ Fill the one or many shared memory buffer.
+ Arguments:
+ - src_arr (:obj:`Union[Dict[Any, np.ndarray], np.ndarray]`): array to fill the buffer.
+ """
+ if isinstance(self._shape, dict):
+ for k in self._shape.keys():
+ self._data[k].fill(src_arr[k])
+ elif isinstance(self._shape, (tuple, list)):
+ self._data.fill(src_arr)
+
+ def get(self) -> Union[Dict[Any, np.ndarray], np.ndarray]:
+ """
+ Overview:
+ Get the one or many arrays stored in the buffer.
+ Return:
+ - data (:obj:`np.ndarray`): The array(s) stored in the buffer.
+ """
+ if isinstance(self._shape, dict):
+ return {k: self._data[k].get() for k in self._shape.keys()}
+ elif isinstance(self._shape, (tuple, list)):
+ return self._data.get()
diff --git a/DI-engine/ding/data/storage/__init__.py b/DI-engine/ding/data/storage/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..962fbbbf18a6e526fb3c8b04ab14848a45bd6c9c
--- /dev/null
+++ b/DI-engine/ding/data/storage/__init__.py
@@ -0,0 +1,2 @@
+from .storage import Storage
+from .file import FileStorage, FileModelStorage
diff --git a/DI-engine/ding/data/storage/file.py b/DI-engine/ding/data/storage/file.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6a89910b8d4e921212daa1451f2d4d05e162da7
--- /dev/null
+++ b/DI-engine/ding/data/storage/file.py
@@ -0,0 +1,25 @@
+from typing import Any
+from ding.data.storage import Storage
+import pickle
+
+from ding.utils.file_helper import read_file, save_file
+
+
+class FileStorage(Storage):
+
+ def save(self, data: Any) -> None:
+ with open(self.path, "wb") as f:
+ pickle.dump(data, f, protocol=pickle.HIGHEST_PROTOCOL)
+
+ def load(self) -> Any:
+ with open(self.path, "rb") as f:
+ return pickle.load(f)
+
+
+class FileModelStorage(Storage):
+
+ def save(self, state_dict: object) -> None:
+ save_file(self.path, state_dict)
+
+ def load(self) -> object:
+ return read_file(self.path)
diff --git a/DI-engine/ding/data/storage/storage.py b/DI-engine/ding/data/storage/storage.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6a0dae679e07bad76ce2d3bafee8ada10735e68
--- /dev/null
+++ b/DI-engine/ding/data/storage/storage.py
@@ -0,0 +1,16 @@
+from abc import ABC, abstractmethod
+from typing import Any
+
+
+class Storage(ABC):
+
+ def __init__(self, path: str) -> None:
+ self.path = path
+
+ @abstractmethod
+ def save(self, data: Any) -> None:
+ raise NotImplementedError
+
+ @abstractmethod
+ def load(self) -> Any:
+ raise NotImplementedError
diff --git a/DI-engine/ding/data/storage/tests/test_storage.py b/DI-engine/ding/data/storage/tests/test_storage.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f6f1d2c47012b0c42cae89cc4fcec3878b36101
--- /dev/null
+++ b/DI-engine/ding/data/storage/tests/test_storage.py
@@ -0,0 +1,18 @@
+import tempfile
+import pytest
+import os
+from os import path
+from ding.data.storage import FileStorage
+
+
+@pytest.mark.unittest
+def test_file_storage():
+ path_ = path.join(tempfile.gettempdir(), "test_storage.txt")
+ try:
+ storage = FileStorage(path=path_)
+ storage.save("test")
+ content = storage.load()
+ assert content == "test"
+ finally:
+ if path.exists(path_):
+ os.remove(path_)
diff --git a/DI-engine/ding/data/storage_loader.py b/DI-engine/ding/data/storage_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..daf18e2d8277ae538d22c02e9a102a341ccbcde5
--- /dev/null
+++ b/DI-engine/ding/data/storage_loader.py
@@ -0,0 +1,305 @@
+from dataclasses import dataclass
+import os
+import torch
+import numpy as np
+import uuid
+import treetensor.torch as ttorch
+from abc import ABC, abstractmethod
+from ditk import logging
+from time import sleep, time
+from threading import Lock, Thread
+from typing import Any, Callable, Dict, List, Optional, Union
+from ding.data import FileStorage, Storage
+from os import path
+from ding.data.shm_buffer import ShmBuffer
+from ding.framework.supervisor import RecvPayload, Supervisor, ChildType, SendPayload
+
+
+@dataclass
+class ShmObject:
+ id_: ShmBuffer
+ buf: Any
+
+
+class StorageWorker:
+
+ def load(self, storage: Storage) -> Any:
+ return storage.load()
+
+
+class StorageLoader(Supervisor, ABC):
+
+ def __init__(self, worker_num: int = 3) -> None:
+ """
+ Overview:
+ Save and send data synchronously and load them asynchronously.
+ Arguments:
+ - worker_num (:obj:`int`): Subprocess worker number.
+ """
+ super().__init__(type_=ChildType.PROCESS)
+ self._load_lock = Lock() # Load (first meet) should be called one by one.
+ self._callback_map: Dict[str, Callable] = {}
+ self._shm_obj_map: Dict[int, ShmObject] = {}
+ self._worker_num = worker_num
+ self._req_count = 0
+
+ def shutdown(self, timeout: Optional[float] = None) -> None:
+ super().shutdown(timeout)
+ self._recv_loop = None
+ self._callback_map = {}
+ self._shm_obj_map = {}
+ self._req_count = 0
+
+ def start_link(self) -> None:
+ if not self._running:
+ super().start_link()
+ self._recv_loop = Thread(target=self._loop_recv, daemon=True)
+ self._recv_loop.start()
+
+ @property
+ def _next_proc_id(self):
+ return self._req_count % self._worker_num
+
+ @abstractmethod
+ def save(self, obj: Union[Dict, List]) -> Storage:
+ """
+ Overview:
+ Save data with a storage object synchronously.
+ Arguments:
+ - obj (:obj:`Union[Dict, List]`): The data (traj or episodes), can be numpy, tensor or treetensor.
+ Returns:
+ - storage (:obj:`Storage`): The storage object.
+ """
+ raise NotImplementedError
+
+ def load(self, storage: Storage, callback: Callable):
+ """
+ Overview:
+ Load data from a storage object asynchronously. \
+ This function will analyze the data structure the first time it meets a new kind of data, \
+ then allocate a shared memory buffer for each subprocess worker; these shared memory \
+ buffers are responsible for asynchronously loading data into memory.
+ Arguments:
+ - storage (:obj:`Storage`): The storage object.
+ - callback (:obj:`Callable`): Callback function after data loaded.
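+
+ Example (an illustrative sketch; ``loader`` is assumed to be a concrete subclass instance \
+ such as ``FileStorageLoader``, ``storage`` comes from a previous ``save`` call, and \
+ ``buffer`` is a hypothetical consumer):
+
+ >>> loader.load(storage, lambda data: buffer.push(data)) # callback runs once data is ready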
+ """
+ with self._load_lock:
+ if not self._running:
+ self._first_meet(storage, callback)
+ return
+
+ payload = SendPayload(proc_id=self._next_proc_id, method="load", args=[storage])
+ self._callback_map[payload.req_id] = callback
+ self.send(payload)
+ self._req_count += 1
+
+ def _first_meet(self, storage: Storage, callback: Callable):
+ """
+ Overview:
+ When first meet an object type, we'll load this object directly and analysis the structure,
+ to allocate the shared memory object and create subprocess workers.
+ Arguments:
+ - storage (:obj:`Storage`): The storage object.
+ - callback (:obj:`Callable`): Callback function after data loaded.
+ """
+ obj = storage.load()
+ # Create worker_num subprocess workers, each with its own shared memory buffer.
+ for i in range(self._worker_num):
+ shm_obj = self._create_shm_buffer(obj)
+ self._shm_obj_map[i] = shm_obj
+ self.register(StorageWorker, shm_buffer=shm_obj, shm_callback=self._shm_callback)
+ self.start_link()
+ callback(obj)
+
+ def _loop_recv(self):
+ while True:
+ payload = self.recv(ignore_err=True)
+ if payload.err:
+ logging.warning("Got error when loading data: {}".format(payload.err))
+ if payload.req_id in self._callback_map:
+ del self._callback_map[payload.req_id]
+ else:
+ self._shm_putback(payload, self._shm_obj_map[payload.proc_id])
+ if payload.req_id in self._callback_map:
+ callback = self._callback_map.pop(payload.req_id)
+ callback(payload.data)
+
+ def _create_shm_buffer(self, obj: Union[Dict, List]) -> Optional[ShmObject]:
+ """
+ Overview:
+ Create a shared memory object (id and buf) by walking through the data structure.
+ Arguments:
+ - obj (:obj:`Union[Dict, List]`): The data (traj or episodes), can be numpy, tensor or treetensor.
+ Returns:
+ - shm_buf (:obj:`Optional[ShmObject]`): The shared memory buffer.
+ """
+ max_level = 2
+
+ def to_shm(obj: Dict, level: int):
+ if level > max_level:
+ return
+ shm_buf = None
+ if isinstance(obj, Dict) or isinstance(obj, ttorch.Tensor):
+ shm_buf = {}
+ for key, val in obj.items():
+ # Only numpy array can fill into shm buffer
+ if isinstance(val, np.ndarray):
+ shm_buf[key] = ShmBuffer(val.dtype, val.shape, copy_on_get=False)
+ elif isinstance(val, torch.Tensor):
+ shm_buf[key] = ShmBuffer(
+ val.numpy().dtype, val.numpy().shape, copy_on_get=False, ctype=torch.Tensor
+ )
+ # Recursive parsing structure
+ elif isinstance(val, Dict) or isinstance(val, ttorch.Tensor) or isinstance(val, List):
+ buf = to_shm(val, level=level + 1)
+ if buf:
+ shm_buf[key] = buf
+ elif isinstance(obj, List):
+ # Double the size of buffer
+ shm_buf = [to_shm(o, level=level) for o in obj] * 2
+ if all(s is None for s in shm_buf):
+ shm_buf = []
+ return shm_buf
+
+ shm_buf = to_shm(obj, level=0)
+ if shm_buf is not None:
+ random_id = self._random_id()
+ shm_buf = ShmObject(id_=ShmBuffer(random_id.dtype, random_id.shape, copy_on_get=False), buf=shm_buf)
+ return shm_buf
+
+ def _random_id(self) -> np.ndarray:
+ return np.random.randint(1, 9e6, size=(1))
+
+ def _shm_callback(self, payload: RecvPayload, shm_obj: ShmObject):
+ """
+ Overview:
+ Called in subprocess, put payload.data into buf.
+ Arguments:
+ - payload (:obj:`RecvPayload`): The recv payload with meta info of the data.
+ - shm_obj (:obj:`ShmObject`): The shm buffer.
+ """
+ assert isinstance(payload.data, type(
+ shm_obj.buf
+ )), "Data type ({}) and buf type ({}) are not match!".format(type(payload.data), type(shm_obj.buf))
+
+ # Sleep while shm object is not ready.
+ while shm_obj.id_.get()[0] != 0:
+ sleep(0.001)
+
+ max_level = 2
+
+ def shm_callback(data: Union[Dict, List, ttorch.Tensor], buf: Union[Dict, List], level: int):
+ if level > max_level:
+ return
+
+ if isinstance(buf, List):
+ assert isinstance(data, List), "Data ({}) and buf ({}) types do not match".format(type(data), type(buf))
+ elif isinstance(buf, Dict):
+ assert isinstance(data, ttorch.Tensor) or isinstance(
+ data, Dict
+ ), "Data ({}) and buf ({}) type not match".format(type(data), type(buf))
+
+ if isinstance(data, Dict) or isinstance(data, ttorch.Tensor):
+ for key, val in data.items():
+ if isinstance(val, torch.Tensor):
+ val = val.numpy()
+ buf_val = buf.get(key)
+ if buf_val is None:
+ continue
+ if isinstance(buf_val, ShmBuffer) and isinstance(val, np.ndarray):
+ buf_val.fill(val)
+ data[key] = None
+ else:
+ shm_callback(val, buf_val, level=level + 1)
+ elif isinstance(data, List):
+ for i, data_ in enumerate(data):
+ shm_callback(data_, buf[i], level=level)
+
+ shm_callback(payload.data, buf=shm_obj.buf, level=0)
+ id_ = self._random_id()
+ shm_obj.id_.fill(id_)
+ payload.extra = id_
+
+ def _shm_putback(self, payload: RecvPayload, shm_obj: ShmObject):
+ """
+ Overview:
+ Called in main process, put buf back into payload.data.
+ Arguments:
+ - payload (:obj:`RecvPayload`): The recv payload with meta info of the data.
+ - shm_obj (:obj:`ShmObject`): The shm buffer.
+ """
+ assert isinstance(payload.data, type(
+ shm_obj.buf
+ )), "Data type ({}) and buf type ({}) are not match!".format(type(payload.data), type(shm_obj.buf))
+
+ assert shm_obj.id_.get()[0] == payload.extra[0], "Shm object and payload do not match ({} - {}).".format(
+ shm_obj.id_.get()[0], payload.extra[0]
+ )
+
+ def shm_putback(data: Union[Dict, List], buf: Union[Dict, List]):
+ if isinstance(data, Dict) or isinstance(data, ttorch.Tensor):
+ for key, val in data.items():
+ buf_val = buf.get(key)
+ if buf_val is None:
+ continue
+ if val is None and isinstance(buf_val, ShmBuffer):
+ data[key] = buf[key].get()
+ else:
+ shm_putback(val, buf_val)
+ elif isinstance(data, List):
+ for i, data_ in enumerate(data):
+ shm_putback(data_, buf[i])
+
+ shm_putback(payload.data, buf=shm_obj.buf)
+ shm_obj.id_.fill(np.array([0]))
+
+
+class FileStorageLoader(StorageLoader):
+
+ def __init__(self, dirname: str, ttl: int = 20, worker_num: int = 3) -> None:
+ """
+ Overview:
+ Dump and load object with file storage.
+ Arguments:
+ - dirname (:obj:`str`): The directory to save files.
+ - ttl (:obj:`int`): Maximum time (in seconds) to keep a file, after which it will be deleted.
+ - worker_num (:obj:`int`): Number of subprocess worker loaders.
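+
+ Example (a minimal usage sketch based on the accompanying unit test; the directory \
+ name is hypothetical):
+
+ >>> loader = FileStorageLoader(dirname='./traj_pool', ttl=20)
+ >>> storage = loader.save([{'obs': np.zeros((4, 84, 84), dtype=np.float32)}]) # synchronous dump
+ >>> loader.load(storage, lambda data: print(len(data))) # asynchronous load with callback
+ >>> loader.shutdown()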
+ """
+ super().__init__(worker_num)
+ self._dirname = dirname
+ self._files = []
+ self._cleanup_thread = None
+ self._ttl = ttl # Files older than ttl seconds will be deleted by the cleanup thread.
+
+ def save(self, obj: Union[Dict, List]) -> FileStorage:
+ if not path.exists(self._dirname):
+ os.mkdir(self._dirname)
+ filename = "{}.pkl".format(uuid.uuid1())
+ full_path = path.join(self._dirname, filename)
+ f = FileStorage(full_path)
+ f.save(obj)
+ self._files.append([time(), f.path])
+ self._start_cleanup()
+ return f
+
+ def _start_cleanup(self):
+ """
+ Overview:
+ Start a cleanup thread to remove files that have been on disk longer than the configured ttl.
+ """
+ if self._cleanup_thread is None:
+ self._cleanup_thread = Thread(target=self._loop_cleanup, daemon=True)
+ self._cleanup_thread.start()
+
+ def shutdown(self, timeout: Optional[float] = None) -> None:
+ super().shutdown(timeout)
+ self._cleanup_thread = None
+
+ def _loop_cleanup(self):
+ while True:
+ if len(self._files) == 0 or time() - self._files[0][0] < self._ttl:
+ sleep(1)
+ continue
+ _, file_path = self._files.pop(0)
+ if path.exists(file_path):
+ os.remove(file_path)
diff --git a/DI-engine/ding/data/tests/test_model_loader.py b/DI-engine/ding/data/tests/test_model_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..caf8c071869d3f75965d5f1479130fc18dba66ce
--- /dev/null
+++ b/DI-engine/ding/data/tests/test_model_loader.py
@@ -0,0 +1,74 @@
+import shutil
+import tempfile
+from time import sleep, time
+import pytest
+from ding.data.model_loader import FileModelLoader
+from ding.data.storage.file import FileModelStorage
+from ding.model import DQN
+from ding.config import compile_config
+from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config
+from os import path
+import torch
+
+
+@pytest.mark.tmp # gitlab ci and local test pass, github always fail
+def test_model_loader():
+ tempdir = path.join(tempfile.gettempdir(), "test_model_loader")
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ model = DQN(**cfg.policy.model)
+ loader = FileModelLoader(model=model, dirname=tempdir, ttl=1)
+ try:
+ loader.start()
+ model_storage = None
+
+ def save_model(storage):
+ nonlocal model_storage
+ model_storage = storage
+
+ start = time()
+ loader.save(save_model)
+ save_time = time() - start
+ print("Save time: {:.4f}s".format(save_time))
+ assert save_time < 0.1
+ sleep(0.5)
+ assert isinstance(model_storage, FileModelStorage)
+ assert len(loader._files) > 0
+
+ state_dict = loader.load(model_storage)
+ model.load_state_dict(state_dict)
+
+ sleep(2)
+ assert not path.exists(model_storage.path)
+ assert len(loader._files) == 0
+ finally:
+ if path.exists(tempdir):
+ shutil.rmtree(tempdir)
+
+
+@pytest.mark.benchmark
+def test_model_loader_benchmark():
+ model = torch.nn.Sequential(torch.nn.Linear(1024, 1024), torch.nn.Linear(1024, 100)) # 40MB
+ tempdir = path.join(tempfile.gettempdir(), "test_model_loader")
+ loader = FileModelLoader(model=model, dirname=tempdir)
+
+ try:
+ loader.start()
+ count = 0
+
+ def send_callback(_):
+ nonlocal count
+ count += 1
+
+ start = time()
+ for _ in range(5):
+ loader.save(send_callback)
+ sleep(0.2)
+
+ while count < 5:
+ sleep(0.001)
+
+ assert time() - start < 1.2
+ finally:
+ if path.exists(tempdir):
+ shutil.rmtree(tempdir)
+ loader.shutdown()
diff --git a/DI-engine/ding/data/tests/test_shm_buffer.py b/DI-engine/ding/data/tests/test_shm_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..04334b47999f69a189578fbf1af2efa4ccfcfb42
--- /dev/null
+++ b/DI-engine/ding/data/tests/test_shm_buffer.py
@@ -0,0 +1,20 @@
+import pytest
+import numpy as np
+import timeit
+from ding.data.shm_buffer import ShmBuffer
+import multiprocessing as mp
+
+
+def subprocess(shm_buf):
+ data = np.random.rand(1024, 1024).astype(np.float32)
+ res = timeit.repeat(lambda: shm_buf.fill(data), repeat=5, number=1000)
+ print("Mean: {:.4f}s, STD: {:.4f}s, Mean each call: {:.4f}ms".format(np.mean(res), np.std(res), np.mean(res)))
+
+
+@pytest.mark.benchmark
+def test_shm_buffer():
+ data = np.random.rand(1024, 1024).astype(np.float32)
+ shm_buf = ShmBuffer(data.dtype, data.shape, copy_on_get=False)
+ proc = mp.Process(target=subprocess, args=[shm_buf])
+ proc.start()
+ proc.join()
diff --git a/DI-engine/ding/data/tests/test_storage_loader.py b/DI-engine/ding/data/tests/test_storage_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ab07acd739a2c2edef1699374a0a4e7c2063c5a
--- /dev/null
+++ b/DI-engine/ding/data/tests/test_storage_loader.py
@@ -0,0 +1,176 @@
+import os
+import timeit
+import pytest
+import tempfile
+import shutil
+import numpy as np
+import torch
+import treetensor.torch as ttorch
+from ding.data.shm_buffer import ShmBuffer
+from ding.data.storage_loader import FileStorageLoader
+from time import sleep, time
+from os import path
+from ding.framework.supervisor import RecvPayload
+
+
+@pytest.mark.tmp # gitlab ci and local test pass, github always fail
+def test_file_storage_loader():
+ tempdir = path.join(tempfile.gettempdir(), "test_storage_loader")
+ loader = FileStorageLoader(dirname=tempdir)
+ try:
+ total_num = 200
+ storages = []
+ for i in range(10):
+ # 21MB
+ data = [
+ {
+ "s": "abc",
+ "obs": np.random.rand(4, 84, 84).astype(np.float32),
+ # "next_obs": np.random.rand(4, 84, 84).astype(np.float32),
+ # "obs": torch.rand(4, 84, 84, dtype=torch.float32),
+ "next_obs": torch.rand(4, 84, 84, dtype=torch.float32)
+ } for _ in range(96)
+ ]
+ storage = loader.save(data)
+ storages.append(storage)
+
+ start = time()
+ for i in range(total_num):
+ storage = storages[i % 10]
+ data = storage.load()
+ origin_time_cost = time() - start
+ print("Load time cost: {:.4f}s".format(origin_time_cost))
+
+ call_times = 0
+
+ def callback(data):
+ assert data[0]['obs'] is not None
+ nonlocal call_times
+ call_times += 1
+
+ # First initialize shared memory is very slow, discard this time cost.
+ start = time()
+ loader._first_meet(storage=storages[0], callback=callback)
+ print("Initialize shared memory time: {:.4f}s".format(time() - start))
+
+ start = time()
+ for i in range(1, total_num):
+ storage = storages[i % 10]
+ loader.load(storage, callback)
+
+ while True:
+ if call_times == total_num:
+ break
+ sleep(0.01)
+ new_time_cost = time() - start
+ print("Loader time cost: {:.4f}s".format(new_time_cost))
+
+ assert new_time_cost < origin_time_cost
+ finally:
+ if path.exists(tempdir):
+ shutil.rmtree(tempdir)
+ loader.shutdown()
+
+
+@pytest.mark.unittest
+def test_file_storage_loader_cleanup():
+ tempdir = path.join(tempfile.gettempdir(), "test_storage_loader")
+ loader = FileStorageLoader(dirname=tempdir, ttl=1)
+ try:
+ storages = []
+ for _ in range(4):
+ data = np.random.rand(4, 84, 84).astype(np.float32)
+ storage = loader.save(data)
+ storages.append(storage)
+ sleep(0.5)
+ assert len(os.listdir(tempdir)) < 4
+ finally:
+ if path.exists(tempdir):
+ shutil.rmtree(tempdir)
+ loader.shutdown()
+
+
+@pytest.mark.unittest
+def test_shared_object():
+ loader = FileStorageLoader(dirname="")
+
+ # ========== Test array ==========
+ obj = [{"obs": np.random.rand(100, 100)} for _ in range(10)]
+ shm_obj = loader._create_shm_buffer(obj)
+ assert len(shm_obj.buf) == len(obj) * 2
+ assert isinstance(shm_obj.buf[0]["obs"], ShmBuffer)
+
+ # Callback
+ payload = RecvPayload(proc_id=0, data=obj)
+ loader._shm_callback(payload=payload, shm_obj=shm_obj)
+ assert len(payload.data) == 10
+ assert [d["obs"] is None for d in payload.data]
+
+ # ========== Putback ==========
+ loader._shm_putback(payload=payload, shm_obj=shm_obj)
+ obj = payload.data
+ assert len(obj) == 10
+ for o in obj:
+ assert isinstance(o["obs"], np.ndarray)
+ assert o["obs"].shape == (100, 100)
+
+ # ========== Test dict ==========
+ obj = {"obs": torch.rand(100, 100, dtype=torch.float32)}
+ shm_obj = loader._create_shm_buffer(obj)
+ assert isinstance(shm_obj.buf["obs"], ShmBuffer)
+
+ payload = RecvPayload(proc_id=0, data=obj)
+ loader._shm_callback(payload=payload, shm_obj=shm_obj)
+ assert payload.data["obs"] is None
+
+ loader._shm_putback(payload=payload, shm_obj=shm_obj)
+ assert isinstance(payload.data["obs"], torch.Tensor)
+ assert payload.data["obs"].shape == (100, 100)
+
+ # ========== Test treetensor ==========
+ obj = {"trajectories": [ttorch.as_tensor({"obs": torch.rand(10, 10, dtype=torch.float32)}) for _ in range(10)]}
+ shm_obj = loader._create_shm_buffer(obj)
+
+ payload = RecvPayload(proc_id=0, data=obj)
+ loader._shm_callback(payload=payload, shm_obj=shm_obj)
+ assert len(payload.data["trajectories"]) == 10
+ for traj in payload.data["trajectories"]:
+ assert traj["obs"] is None
+
+ loader._shm_putback(payload=payload, shm_obj=shm_obj)
+ for traj in payload.data["trajectories"]:
+ assert isinstance(traj["obs"], torch.Tensor)
+ assert traj["obs"].shape == (10, 10)
+
+
+@pytest.mark.benchmark
+def test_shared_object_benchmark():
+ loader = FileStorageLoader(dirname="")
+ # ========== Test treetensor ==========
+ obj = {
+ "env_step": 0,
+ "trajectories": [
+ ttorch.as_tensor(
+ {
+ "done": False,
+ "reward": torch.tensor([1, 0], dtype=torch.int32),
+ "obs": torch.rand(4, 84, 84, dtype=torch.float32),
+ "next_obs": torch.rand(4, 84, 84, dtype=torch.float32),
+ "action": torch.tensor([1], dtype=torch.int32),
+ "collect_train_iter": torch.tensor([1], dtype=torch.int32),
+ "env_data_id": torch.tensor([1], dtype=torch.int32),
+ }
+ ) for _ in range(10)
+ ]
+ }
+ buf = loader._create_shm_buffer(obj)
+ payload = RecvPayload(proc_id=0, data=obj)
+ loader._shm_callback(payload=payload, shm_obj=buf)
+
+ def stmt():
+ payload.extra = buf.id_.get()
+ loader._shm_putback(payload=payload, shm_obj=buf)
+
+ res = timeit.repeat(stmt, repeat=5, number=1000)
+ print("Mean: {:.4f}s, STD: {:.4f}s, Mean each call: {:.4f}ms".format(np.mean(res), np.std(res), np.mean(res)))
+ assert np.mean(res) < 1
diff --git a/DI-engine/ding/design/dataloader-sequence.puml b/DI-engine/ding/design/dataloader-sequence.puml
new file mode 100644
index 0000000000000000000000000000000000000000..c07b4179901e812b8880ee682d208f7959c39262
--- /dev/null
+++ b/DI-engine/ding/design/dataloader-sequence.puml
@@ -0,0 +1,55 @@
+@startuml async_dataloader
+header Async Dataloader
+title Async Dataloader
+
+participant main_process
+participant async_process
+participant get_data_thread
+participant job_queue
+participant worker_process_0
+participant ...
+participant worker_process_n
+participant async_train_queue
+participant cuda_thread
+participant cuda_queue
+autonumber
+
+main_process -> async_process: Start async_process
+main_process -> get_data_thread: Start get_data_thread
+alt num_workers > 1
+ main_process -> job_queue: Init job_queue
+ main_process -> worker_process_0: Start worker_process_0
+ main_process -> ...: Start ...
+ main_process -> worker_process_n: Start worker_process_n
+end
+main_process -> async_train_queue: Init async_train_queue
+alt use_cuda
+ main_process -> cuda_thread: Start cuda_thread
+ main_process -> cuda_queue: Init cuda_queue
+end
+
+async_process -> get_data_thread: Send request "get_data"
+get_data_thread -> get_data_thread: Get data from "data_source"
+get_data_thread -> async_process: Send data (in CPU)
+
+alt num_workers <= 1
+ async_process -> async_process: Process data
+ async_process -> async_train_queue: Put data in queue
+else
+ async_process -> async_process: Chunk pre-process task into pieces
+ async_process -> job_queue: Put sub-tasks in queue
+ worker_process_0 -> job_queue: Get a sub-task from queue
+ worker_process_n -> job_queue: Get a sub-task from queue
+ worker_process_0 -> worker_process_0: Process data
+ worker_process_n -> worker_process_n: Process data
+ worker_process_0 -> async_train_queue: Put data in queue
+ worker_process_n -> async_train_queue: Put data in queue
+end
+
+alt use_cuda
+ cuda_thread -> async_train_queue: Get data (in CPU)
+ cuda_thread -> cuda_thread: Move data from CPU to GPU
+ cuda_thread -> cuda_queue: Put data(in GPU) in queue
+end
+
+@enduml
diff --git a/DI-engine/ding/design/parallel_main-sequence.puml b/DI-engine/ding/design/parallel_main-sequence.puml
new file mode 100644
index 0000000000000000000000000000000000000000..8934b80a06e84068b547d659f77c21ad068f8fbb
--- /dev/null
+++ b/DI-engine/ding/design/parallel_main-sequence.puml
@@ -0,0 +1,97 @@
+@startuml
+skinparam NoteBackgroundColor PapayaWhip
+
+autonumber
+
+participant Coordinator
+participant Learner
+participant Collector
+participant Middleware
+participant Operator
+
+group start
+Coordinator->Coordinator: start communication module
+Coordinator->Coordinator: start commander
+Coordinator->Coordinator: start replay buffer
+Coordinator->Operator: connect operator
+Operator->Coordinator: send collector/learner info
+Coordinator->Learner: create connection
+Coordinator->Collector: create connection
+end
+
+loop
+autonumber
+group learn(async)
+Coordinator->Learner: request learner start task
+note right
+policy config
+learner config
+end note
+Learner->Coordinator: return learner start info
+group learner loop
+Coordinator->Learner: request data demand task
+Learner->Coordinator: return data demand
+Coordinator->Learner: request learn task and send data(metadata)
+note right
+data path
+data priority
+end note
+Middleware->Learner: load data(stepdata)
+Learner->Learner: learn one iteration
+Learner->Middleware: send policy info
+note left
+model state_dict
+model hyper-parameter
+end note
+Learner->Coordinator: return learn info
+note right
+policy meta
+train stat
+data priority
+end note
+end
+Coordinator->Learner: request learner close task
+Learner->Coordinator: return learner close info
+note right
+save final policy
+end note
+end
+
+autonumber
+group data collection/evaluation(async)
+Coordinator->Collector: request collector start task
+note right
+policy meta
+env config
+collector config
+end note
+Collector->Coordinator: return collector start info
+Middleware->Collector: load policy info for init
+group collector loop
+Coordinator->Collector: request get data task
+Collector->Collector: policy interact with env
+Collector->Middleware: send data(stepdata)
+Collector->Coordinator: return data(metadata)
+note right
+data path
+data length(rollout length)
+end note
+Middleware->Collector: load policy info for update
+end group
+Coordinator->Collector: request collector close task
+Collector->Coordinator: return collector close info
+note right
+episode result(cumulative reward)
+collector performance
+end note
+end group
+end
+
+autonumber
+group close
+Coordinator->Learner: destroy connection
+Coordinator->Collector: destroy connection
+Coordinator->Operator: disconnect operator
+Coordinator->Coordinator: close
+end group
+@enduml
diff --git a/DI-engine/ding/design/serial_collector-activity.puml b/DI-engine/ding/design/serial_collector-activity.puml
new file mode 100644
index 0000000000000000000000000000000000000000..d53b2193e2cc956cf612be91f69abe6b2bf9f6b8
--- /dev/null
+++ b/DI-engine/ding/design/serial_collector-activity.puml
@@ -0,0 +1,43 @@
+@startuml serial_collector
+header Serial Pipeline
+title Serial Collector
+
+|#99CCCC|serial_controller|
+|#99CCFF|env_manager|
+|#CCCCFF|policy|
+|#FFCCCC|collector|
+
+|#99CCCC|serial_controller|
+start
+:init collector, set its \nenv_manager and \ncollect_mode policy;
+|#99CCFF|env_manager|
+repeat
+ |#99CCFF|env_manager|
+ :return current obs;
+ |#CCCCFF|policy|
+ :[network] forward with obs;
+ |#99CCFF|env_manager|
+ :env step with action;
+ |#CCCCFF|policy|
+ :process transition;
+ |#FFCCCC|collector|
+ :save transition in cache;
+ if (for every env: \n env_i is done? OR cache is full?) then (yes)
+ if (is sample_collector?) then (yes)
+ note right: Only sample_collector will do so, \n episode_collector will not.
+ |#CCCCFF|policy|
+ :[adder] get train_sample from cache;
+ endif
+ |#FFCCCC|collector|
+ :save sample/episode for return;
+ if (env_i is done?) then (yes)
+ |#99CCFF|env_manager|
+ :env_i reset;
+ endif
+ endif
+|#FFCCCC|collector|
+repeat while (collected samples/episodes are not enough?)
+:return sample/episode;
+stop
+
+@enduml
diff --git a/DI-engine/ding/design/serial_evaluator-activity.puml b/DI-engine/ding/design/serial_evaluator-activity.puml
new file mode 100644
index 0000000000000000000000000000000000000000..aa84c4eef5517c7b70a97f96a00e14ced90c5391
--- /dev/null
+++ b/DI-engine/ding/design/serial_evaluator-activity.puml
@@ -0,0 +1,31 @@
+@startuml serial_evaluator
+header Serial Pipeline
+title Serial Evaluator
+
+|#99CCCC|serial_controller|
+|#99CCFF|env_manager|
+|#CCCCFF|policy|
+|#FFCCCC|evaluator|
+
+|#99CCCC|serial_controller|
+start
+:init evaluator, set its \nenv_manager and \neval_mode policy;
+|#99CCFF|env_manager|
+repeat
+ :return current obs;
+ |#CCCCFF|policy|
+ :[network] forward with obs;
+ |#99CCFF|env_manager|
+ :env step with action;
+ |#FFCCCC|evaluator|
+ if (for every env: env i is done?) then (yes)
+ |#99CCFF|env_manager|
+ :env i reset;
+ |#FFCCCC|evaluator|
+ :log eval_episode_info;
+ endif
+repeat while (evaluated episodes are not enough?)
+|#FFCCCC|evaluator|
+:return eval_episode_return;
+stop
+@enduml
diff --git a/DI-engine/ding/design/serial_learner-activity.puml b/DI-engine/ding/design/serial_learner-activity.puml
new file mode 100644
index 0000000000000000000000000000000000000000..839ff94b80c8cfbb88fa119a3ce1b1e78a6fac19
--- /dev/null
+++ b/DI-engine/ding/design/serial_learner-activity.puml
@@ -0,0 +1,22 @@
+@startuml serial_learner
+header Serial Pipeline
+title Serial Learner
+
+|#99CCCC|serial_controller|
+|#CCCCFF|policy|
+|#99CCFF|learner|
+
+|#99CCCC|serial_controller|
+start
+:init learner, \nset its learn_mode policy;
+|#99CCFF|learner|
+:get data from buffer;
+|#CCCCFF|policy|
+:data forward;
+:loss backward;
+:optimizer step, gradient update;
+|#99CCFF|learner|
+:update train info(loss, value) and log;
+:update learn info(iteration, priority);
+stop
+@enduml
diff --git a/DI-engine/ding/design/serial_main.puml b/DI-engine/ding/design/serial_main.puml
new file mode 100644
index 0000000000000000000000000000000000000000..f710639dd839955e5d218c8a048101ea22bf66c8
--- /dev/null
+++ b/DI-engine/ding/design/serial_main.puml
@@ -0,0 +1,56 @@
+@startuml serial_main
+header Serial Pipeline
+title Serial Main
+
+participant controller
+participant env_manager
+participant policy
+participant learner
+participant replay_buffer
+participant collector
+participant evaluator
+participant commander
+autonumber
+
+controller -> env_manager: init collector and evaluator env_manager; set seed
+controller -> policy: init policy
+controller -> learner: init learner; set learn_mode policy
+controller -> collector: init collector; set collect_mode policy; set env_manager
+controller -> evaluator: init evaluator; set eval_mode policy; set env_manager
+controller -> commander: init commander; set command_mode policy
+controller -> replay_buffer: init replay_buffer
+alt random collect before training starts
+ collector -> collector: reset policy to random one; generate random data
+ collector -> replay_buffer: push_data
+ collector -> collector: reset policy back to the original one
+end
+learner -> learner: call before_run hook
+loop
+ commander -> commander: step
+ alt this iteration needs evaluation
+ evaluator -> evaluator: eval_performance
+ alt reach eval stop_value
+ learner -> learner: save checkpoint and exit
+ else episode_return is new highest
+ learner -> learner: save checkpoint
+ end
+ end
+ collector -> collector: generate data (steps or episodes)
+ collector -> replay_buffer: push_data
+ loop learner_train_iteration times
+ replay_buffer -> learner: sample_data
+ learner -> learner: train
+ alt replay_buffer uses prioritization
+ learner -> replay_buffer: update with priority_info
+ end
+ end
+ alt on_policy training
+ replay_buffer -> replay_buffer: clear
+ end
+end
+learner -> learner: call after_run hook
+controller -> replay_buffer: close replay_buffer
+controller -> learner: close learner
+controller -> collector: close collector
+controller -> evaluator: close evaluator
+@enduml
diff --git a/DI-engine/ding/entry/__init__.py b/DI-engine/ding/entry/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..11cccf0e131d10947d64748992e5e6a8c01e2dc7
--- /dev/null
+++ b/DI-engine/ding/entry/__init__.py
@@ -0,0 +1,28 @@
+from .cli import cli
+from .cli_ditask import cli_ditask
+from .serial_entry import serial_pipeline
+from .serial_entry_td3_vae import serial_pipeline_td3_vae
+from .serial_entry_onpolicy import serial_pipeline_onpolicy
+from .serial_entry_onpolicy_ppg import serial_pipeline_onpolicy_ppg
+from .serial_entry_offline import serial_pipeline_offline
+from .serial_entry_ngu import serial_pipeline_ngu
+from .serial_entry_reward_model_offpolicy import serial_pipeline_reward_model_offpolicy
+from .serial_entry_reward_model_onpolicy import serial_pipeline_reward_model_onpolicy
+from .serial_entry_bc import serial_pipeline_bc
+from .serial_entry_dqfd import serial_pipeline_dqfd
+from .serial_entry_r2d3 import serial_pipeline_r2d3
+from .serial_entry_sqil import serial_pipeline_sqil
+from .parallel_entry import parallel_pipeline
+from .application_entry import eval, collect_demo_data, collect_episodic_demo_data, \
+ episode_to_transitions, episode_to_transitions_filter
+from .application_entry_trex_collect_data import trex_collecting_data, collect_episodic_demo_data_for_trex
+from .serial_entry_guided_cost import serial_pipeline_guided_cost
+from .serial_entry_gail import serial_pipeline_gail
+from .utils import random_collect
+from .serial_entry_preference_based_irl \
+ import serial_pipeline_preference_based_irl
+from .serial_entry_preference_based_irl_onpolicy \
+ import serial_pipeline_preference_based_irl_onpolicy
+from .serial_entry_mbrl import serial_pipeline_dyna, serial_pipeline_dream, serial_pipeline_dreamer
+from .serial_entry_bco import serial_pipeline_bco
+from .serial_entry_pc import serial_pipeline_pc
diff --git a/DI-engine/ding/entry/application_entry.py b/DI-engine/ding/entry/application_entry.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb8fe882dfbaa03e13ecc908d10ff529b778f1b4
--- /dev/null
+++ b/DI-engine/ding/entry/application_entry.py
@@ -0,0 +1,281 @@
+from typing import Union, Optional, List, Any, Tuple
+import os
+import pickle
+import numpy as np
+import torch
+from functools import partial
+from copy import deepcopy
+
+from ding.config import compile_config, read_config
+from ding.worker import SampleSerialCollector, InteractionSerialEvaluator, EpisodeSerialCollector
+from ding.envs import create_env_manager, get_vec_env_setting
+from ding.policy import create_policy
+from ding.torch_utils import to_device, to_ndarray
+from ding.utils import set_pkg_seed
+from ding.utils.data import offline_data_save_type
+from ding.rl_utils import get_nstep_return_data
+from ding.utils.data import default_collate
+
+
+def eval(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ seed: int = 0,
+ env_setting: Optional[List[Any]] = None,
+ model: Optional[torch.nn.Module] = None,
+ state_dict: Optional[dict] = None,
+ load_path: Optional[str] = None,
+ replay_path: Optional[str] = None,
+) -> float:
+ """
+ Overview:
+ Pure policy evaluation entry. Evaluate mean episode return and save replay videos.
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
+ ``BaseEnv`` subclass, collector env config, and evaluator env config.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - state_dict (:obj:`Optional[dict]`): The state_dict of policy or model.
+ - load_path (:obj:`Optional[str]`): Path to load ckpt.
+ - replay_path (:obj:`Optional[str]`): Path to save replay.
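+ Returns:
+ - episode_return (:obj:`float`): The mean episode return of the evaluated episodes.
+
+ Example (an illustrative sketch; the checkpoint path is hypothetical):
+
+ >>> from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config
+ >>> eval((main_config, create_config), seed=0, load_path='./pong_dqn/ckpt/ckpt_best.pth.tar')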
+ """
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ else:
+ cfg, create_cfg = deepcopy(input_cfg)
+ env_fn = None if env_setting is None else env_setting[0]
+ cfg = compile_config(
+ cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, save_path='eval_config.py'
+ )
+
+ # Create components: env, policy, evaluator
+ if env_setting is None:
+ env_fn, _, evaluator_env_cfg = get_vec_env_setting(cfg.env, collect=False)
+ else:
+ env_fn, _, evaluator_env_cfg = env_setting
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+ evaluator_env.seed(seed, dynamic_seed=False)
+ if replay_path is None: # argument > config
+ replay_path = cfg.env.get('replay_path', None)
+ if replay_path:
+ evaluator_env.enable_save_replay(replay_path)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=model, enable_field=['eval'])
+ if state_dict is None:
+ if load_path is None:
+ load_path = cfg.policy.learn.learner.load_path
+ state_dict = torch.load(load_path, map_location='cpu')
+ policy.eval_mode.load_state_dict(state_dict)
+ evaluator = InteractionSerialEvaluator(cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode)
+
+ # Evaluate
+ _, episode_info = evaluator.eval()
+ episode_return = np.mean(episode_info['eval_episode_return'])
+ print('Eval is over! The performance of your RL policy is {}'.format(episode_return))
+ return episode_return
+
+
+def collect_demo_data(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ seed: int,
+ collect_count: int,
+ expert_data_path: Optional[str] = None,
+ env_setting: Optional[List[Any]] = None,
+ model: Optional[torch.nn.Module] = None,
+ state_dict: Optional[dict] = None,
+ state_dict_path: Optional[str] = None,
+) -> None:
+ r"""
+ Overview:
+ Collect demonstration data by the trained policy.
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - collect_count (:obj:`int`): The count of collected data.
+ - expert_data_path (:obj:`Optional[str]`): File path that the expert demo data will be written to.
+ - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
+ ``BaseEnv`` subclass, collector env config, and evaluator env config.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - state_dict (:obj:`Optional[dict]`): The state_dict of policy or model.
+ - state_dict_path (:obj:`Optional[str]`): The path of the state_dict of policy or model.
+ """
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ else:
+ cfg, create_cfg = deepcopy(input_cfg)
+ env_fn = None if env_setting is None else env_setting[0]
+ cfg = compile_config(
+ cfg,
+ seed=seed,
+ env=env_fn,
+ auto=True,
+ create_cfg=create_cfg,
+ save_cfg=True,
+ save_path='collect_demo_data_config.py'
+ )
+ if expert_data_path is None:
+ expert_data_path = cfg.policy.collect.save_path
+
+ # Create components: env, policy, collector
+ if env_setting is None:
+ env_fn, collector_env_cfg, _ = get_vec_env_setting(cfg.env, eval_=False)
+ else:
+ env_fn, collector_env_cfg, _ = env_setting
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ collector_env.seed(seed)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=model, enable_field=['collect', 'eval'])
+    # For policies like DQN whose collect_mode uses eps-greedy, a deterministic collect policy
+    # could instead be assembled from the eval forward function as follows:
+ # collect_demo_policy = policy.collect_function(
+ # policy._forward_eval,
+ # policy._process_transition,
+ # policy._get_train_sample,
+ # policy._reset_eval,
+ # policy._get_attribute,
+ # policy._set_attribute,
+ # policy._state_dict_collect,
+ # policy._load_state_dict_collect,
+ # )
+ collect_demo_policy = policy.collect_mode
+ if state_dict is None:
+ assert state_dict_path is not None
+ state_dict = torch.load(state_dict_path, map_location='cpu')
+ policy.collect_mode.load_state_dict(state_dict)
+ collector = SampleSerialCollector(cfg.policy.collect.collector, collector_env, collect_demo_policy)
+
+ if hasattr(cfg.policy.other, 'eps'):
+ policy_kwargs = {'eps': 0.}
+ else:
+ policy_kwargs = None
+
+ # Let's collect some expert demonstrations
+ exp_data = collector.collect(n_sample=collect_count, policy_kwargs=policy_kwargs)
+ if cfg.policy.cuda:
+ exp_data = to_device(exp_data, 'cpu')
+ # Save data transitions.
+ offline_data_save_type(exp_data, expert_data_path, data_type=cfg.policy.collect.get('data_type', 'naive'))
+ print('Collect demo data successfully')
+
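+# Hedged usage sketch: collecting 10000 expert transitions with a trained checkpoint
+# (the file names below are illustrative):
+#   collect_demo_data('cartpole_dqn_config.py', seed=0, collect_count=10000,
+#                     expert_data_path='./expert.pkl', state_dict_path='./exp/ckpt/ckpt_best.pth.tar')
+# For eps-greedy policies such as DQN, eps is forced to 0 during collection (see above).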
+
+def collect_episodic_demo_data(
+ input_cfg: Union[str, dict],
+ seed: int,
+ collect_count: int,
+ expert_data_path: str,
+ env_setting: Optional[List[Any]] = None,
+ model: Optional[torch.nn.Module] = None,
+ state_dict: Optional[dict] = None,
+ state_dict_path: Optional[str] = None,
+) -> None:
+ r"""
+ Overview:
+ Collect episodic demonstration data by the trained policy.
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - collect_count (:obj:`int`): The count of collected data.
+        - expert_data_path (:obj:`str`): File path to which the expert demo data will be written.
+ - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
+ ``BaseEnv`` subclass, collector env config, and evaluator env config.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - state_dict (:obj:`Optional[dict]`): The state_dict of policy or model.
+        - state_dict_path (:obj:`Optional[str]`): The absolute path of the state_dict of the policy or model.
+ """
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ else:
+ cfg, create_cfg = deepcopy(input_cfg)
+ env_fn = None if env_setting is None else env_setting[0]
+ cfg = compile_config(
+ cfg,
+ collector=EpisodeSerialCollector,
+ seed=seed,
+ env=env_fn,
+ auto=True,
+ create_cfg=create_cfg,
+ save_cfg=True,
+ save_path='collect_demo_data_config.py'
+ )
+
+ # Create components: env, policy, collector
+ if env_setting is None:
+ env_fn, collector_env_cfg, _ = get_vec_env_setting(cfg.env, eval_=False)
+ else:
+ env_fn, collector_env_cfg, _ = env_setting
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ collector_env.seed(seed)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=model, enable_field=['collect', 'eval'])
+ collect_demo_policy = policy.collect_mode
+ if state_dict is None:
+ assert state_dict_path is not None
+ state_dict = torch.load(state_dict_path, map_location='cpu')
+ policy.collect_mode.load_state_dict(state_dict)
+ collector = EpisodeSerialCollector(cfg.policy.collect.collector, collector_env, collect_demo_policy)
+
+ if hasattr(cfg.policy.other, 'eps'):
+ policy_kwargs = {'eps': 0.}
+ else:
+ policy_kwargs = None
+
+ # Let's collect some expert demonstrations
+ exp_data = collector.collect(n_episode=collect_count, policy_kwargs=policy_kwargs)
+ if cfg.policy.cuda:
+ exp_data = to_device(exp_data, 'cpu')
+ # Save data transitions.
+ offline_data_save_type(exp_data, expert_data_path, data_type=cfg.policy.collect.get('data_type', 'naive'))
+ print('Collect episodic demo data successfully')
+
+
+def episode_to_transitions(data_path: str, expert_data_path: str, nstep: int) -> None:
+ r"""
+ Overview:
+        Transfer episodic data into n-step transitions.
+    Arguments:
+        - data_path (:obj:`str`): Path of the pkl file that stores the episodic data.
+        - expert_data_path (:obj:`str`): File path to which the transformed transition data will be written.
+        - nstep (:obj:`int`): Number of steps n used to build each transition, i.e. {s_t, a_t, s_{t+n}}.
+
+ """
+ with open(data_path, 'rb') as f:
+ _dict = pickle.load(f) # class is list; length is cfg.reward_model.collect_count
+ post_process_data = []
+ for i in range(len(_dict)):
+ data = get_nstep_return_data(_dict[i], nstep)
+ post_process_data.extend(data)
+ offline_data_save_type(
+ post_process_data,
+ expert_data_path,
+ )
+
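+# Hedged usage sketch: converting previously collected episodic data (e.g. the output of
+# collect_episodic_demo_data above) into 3-step transitions; the paths are illustrative:
+#   episode_to_transitions(data_path='./episodes.pkl', expert_data_path='./transitions.pkl', nstep=3)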
+
+def episode_to_transitions_filter(data_path: str, expert_data_path: str, nstep: int, min_episode_return: int) -> None:
+ r"""
+ Overview:
+        Transfer episodic data into n-step transitions and only keep episodes whose total return is at least
+        ``min_episode_return``.
+    Arguments:
+        - data_path (:obj:`str`): Path of the pkl file that stores the episodic data.
+        - expert_data_path (:obj:`str`): File path to which the transformed transition data will be written.
+        - nstep (:obj:`int`): Number of steps n used to build each transition, i.e. {s_t, a_t, s_{t+n}}.
+        - min_episode_return (:obj:`int`): Minimum total episode return; episodes below this threshold are dropped.
+
+ """
+ with open(data_path, 'rb') as f:
+ _dict = pickle.load(f) # class is list; length is cfg.reward_model.collect_count
+ post_process_data = []
+ for i in range(len(_dict)):
+        episode_returns = torch.stack([_dict[i][j]['reward'] for j in range(len(_dict[i]))], dim=0)
+ if episode_returns.sum() < min_episode_return:
+ continue
+ data = get_nstep_return_data(_dict[i], nstep)
+ post_process_data.extend(data)
+ offline_data_save_type(
+ post_process_data,
+ expert_data_path,
+ )
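+
+
+# Hedged usage sketch for the filtered variant: keep only episodes whose total return is at least 100
+# before converting them into 3-step transitions (the paths and threshold are illustrative):
+#   episode_to_transitions_filter(
+#       data_path='./episodes.pkl', expert_data_path='./transitions.pkl', nstep=3, min_episode_return=100
+#   )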
diff --git a/DI-engine/ding/entry/application_entry_trex_collect_data.py b/DI-engine/ding/entry/application_entry_trex_collect_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcc5d227b4d29ea76866758d626fd35450248c4b
--- /dev/null
+++ b/DI-engine/ding/entry/application_entry_trex_collect_data.py
@@ -0,0 +1,160 @@
+import argparse
+import torch
+import os
+from typing import Union, Optional, List, Any
+from functools import partial
+from copy import deepcopy
+
+from ding.config import compile_config, read_config
+from ding.worker import EpisodeSerialCollector
+from ding.envs import create_env_manager, get_vec_env_setting
+from ding.policy import create_policy
+from ding.torch_utils import to_device
+from ding.utils import set_pkg_seed
+from ding.utils.data import offline_data_save_type
+from ding.utils.data import default_collate
+
+
+def collect_episodic_demo_data_for_trex(
+ input_cfg: Union[str, dict],
+ seed: int,
+ collect_count: int,
+ rank: int,
+ env_setting: Optional[List[Any]] = None,
+ model: Optional[torch.nn.Module] = None,
+ state_dict: Optional[dict] = None,
+ state_dict_path: Optional[str] = None,
+):
+ """
+ Overview:
+ Collect episodic demonstration data by the trained policy for trex specifically.
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - collect_count (:obj:`int`): The count of collected data.
+ - rank (:obj:`int`): The episode ranking.
+ - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
+ ``BaseEnv`` subclass, collector env config, and evaluator env config.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - state_dict (:obj:`Optional[dict]`): The state_dict of policy or model.
+        - state_dict_path (:obj:`Optional[str]`): The absolute path of the state_dict of the policy or model.
+ """
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ else:
+ cfg, create_cfg = deepcopy(input_cfg)
+ create_cfg.policy.type += '_command'
+ env_fn = None if env_setting is None else env_setting[0]
+ cfg.env.collector_env_num = 1
+ cfg = compile_config(
+ cfg,
+ collector=EpisodeSerialCollector,
+ seed=seed,
+ env=env_fn,
+ auto=True,
+ create_cfg=create_cfg,
+ save_cfg=True,
+ save_path='collect_demo_data_config.py'
+ )
+
+ # Create components: env, policy, collector
+ if env_setting is None:
+ env_fn, collector_env_cfg, _ = get_vec_env_setting(cfg.env)
+ else:
+ env_fn, collector_env_cfg, _ = env_setting
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ collector_env.seed(seed)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=model, enable_field=['collect', 'eval'])
+ collect_demo_policy = policy.collect_mode
+ if state_dict is None:
+ assert state_dict_path is not None
+ state_dict = torch.load(state_dict_path, map_location='cpu')
+ policy.collect_mode.load_state_dict(state_dict)
+ collector = EpisodeSerialCollector(
+ cfg.policy.collect.collector, collector_env, collect_demo_policy, exp_name=cfg.exp_name
+ )
+
+ policy_kwargs = None if not hasattr(cfg.policy.other, 'eps') \
+ else {'eps': cfg.policy.other.eps.get('collect', 0.2)}
+
+    # Let's collect some sub-optimal demonstrations
+ exp_data = collector.collect(n_episode=collect_count, policy_kwargs=policy_kwargs)
+ if cfg.policy.cuda:
+ exp_data = to_device(exp_data, 'cpu')
+    # Return the collected episodes; saving them is left to the caller (e.g. trex_collecting_data).
+ print('Collect {}th episodic demo data successfully'.format(rank))
+ return exp_data
+
+
+def trex_get_args():
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--cfg', type=str, default='abs path for a config')
+ parser.add_argument('--seed', type=int, default=0)
+ parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
+ args = parser.parse_known_args()[0]
+ return args
+
+
+def trex_collecting_data(args=None):
+ if args is None:
+ args = trex_get_args() # TODO(nyz) use sub-command in cli
+ if isinstance(args.cfg, str):
+ cfg, create_cfg = read_config(args.cfg)
+ else:
+ cfg, create_cfg = deepcopy(args.cfg)
+ data_path = cfg.exp_name
+ expert_model_path = cfg.reward_model.expert_model_path # directory path
+ checkpoint_min = cfg.reward_model.checkpoint_min
+ checkpoint_max = cfg.reward_model.checkpoint_max
+ checkpoint_step = cfg.reward_model.checkpoint_step
+ checkpoints = []
+ for i in range(checkpoint_min, checkpoint_max + checkpoint_step, checkpoint_step):
+ checkpoints.append(str(i))
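+    # e.g. checkpoint_min=0, checkpoint_max=1000, checkpoint_step=100 yields
+    # checkpoints = ['0', '100', ..., '1000'] (both endpoints are included).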
+ data_for_save = {}
+ learning_returns = []
+ learning_rewards = []
+ episodes_data = []
+ for checkpoint in checkpoints:
+ num_per_ckpt = 1
+ model_path = expert_model_path + \
+ '/ckpt/iteration_' + checkpoint + '.pth.tar'
+ seed = args.seed + (int(checkpoint) - int(checkpoint_min)) // int(checkpoint_step)
+ exp_data = collect_episodic_demo_data_for_trex(
+ deepcopy(args.cfg),
+ seed,
+ state_dict_path=model_path,
+ collect_count=num_per_ckpt,
+ rank=(int(checkpoint) - int(checkpoint_min)) // int(checkpoint_step) + 1
+ )
+ data_for_save[(int(checkpoint) - int(checkpoint_min)) // int(checkpoint_step)] = exp_data
+ obs = [list(default_collate(exp_data[i])['obs'].numpy()) for i in range(len(exp_data))]
+ rewards = [default_collate(exp_data[i])['reward'].tolist() for i in range(len(exp_data))]
+ sum_rewards = [torch.sum(default_collate(exp_data[i])['reward']).item() for i in range(len(exp_data))]
+
+ learning_rewards.append(rewards)
+ learning_returns.append(sum_rewards)
+ episodes_data.append(obs)
+ offline_data_save_type(
+ data_for_save, data_path + '/suboptimal_data.pkl', data_type=cfg.policy.collect.get('data_type', 'naive')
+ )
+ # if not compiled_cfg.reward_model.auto: more feature
+ offline_data_save_type(
+ episodes_data, data_path + '/episodes_data.pkl', data_type=cfg.policy.collect.get('data_type', 'naive')
+ )
+ offline_data_save_type(
+ learning_returns, data_path + '/learning_returns.pkl', data_type=cfg.policy.collect.get('data_type', 'naive')
+ )
+ offline_data_save_type(
+ learning_rewards, data_path + '/learning_rewards.pkl', data_type=cfg.policy.collect.get('data_type', 'naive')
+ )
+ offline_data_save_type(
+ checkpoints, data_path + '/checkpoints.pkl', data_type=cfg.policy.collect.get('data_type', 'naive')
+ )
+ return checkpoints, episodes_data, learning_returns, learning_rewards
+
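+
+# Hedged usage sketch: this entry is usually launched directly as a script, e.g.
+#   python application_entry_trex_collect_data.py --cfg /path/to/xxx_trex_config.py --seed 0
+# where the config (path illustrative) is assumed to define reward_model.expert_model_path and the
+# checkpoint_min / checkpoint_max / checkpoint_step fields read by trex_collecting_data above.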
+
+if __name__ == '__main__':
+ trex_collecting_data()
diff --git a/DI-engine/ding/entry/cli.py b/DI-engine/ding/entry/cli.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a335c2d76c07558ea7c8d44014887f17f807bff
--- /dev/null
+++ b/DI-engine/ding/entry/cli.py
@@ -0,0 +1,290 @@
+from typing import List, Union
+import os
+import copy
+import click
+from click.core import Context, Option
+import numpy as np
+
+from ding import __TITLE__, __VERSION__, __AUTHOR__, __AUTHOR_EMAIL__
+from ding.config import read_config
+from .predefined_config import get_predefined_config
+
+
+def print_version(ctx: Context, param: Option, value: bool) -> None:
+ if not value or ctx.resilient_parsing:
+ return
+ click.echo('{title}, version {version}.'.format(title=__TITLE__, version=__VERSION__))
+ click.echo('Developed by {author}, {email}.'.format(author=__AUTHOR__, email=__AUTHOR_EMAIL__))
+ ctx.exit()
+
+
+def print_registry(ctx: Context, param: Option, value: str):
+ if value is None:
+ return
+ from ding.utils import registries # noqa
+ if value not in registries:
+ click.echo('[ERROR]: not support registry name: {}'.format(value))
+ else:
+ registered_info = registries[value].query_details()
+ click.echo('Available {}: [{}]'.format(value, '|'.join(registered_info.keys())))
+ for alias, info in registered_info.items():
+ click.echo('\t{}: registered at {}#{}'.format(alias, info[0], info[1]))
+ ctx.exit()
+
+
+CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])
+
+
+@click.command(context_settings=CONTEXT_SETTINGS)
+@click.option(
+ '-v',
+ '--version',
+ is_flag=True,
+ callback=print_version,
+ expose_value=False,
+ is_eager=True,
+ help="Show package's version information."
+)
+@click.option(
+ '-q',
+ '--query-registry',
+ type=str,
+ callback=print_registry,
+ expose_value=False,
+ is_eager=True,
+ help='query registered module or function, show name and path'
+)
+@click.option(
+ '-m',
+ '--mode',
+ type=click.Choice(
+ [
+ 'serial',
+ 'serial_onpolicy',
+ 'serial_sqil',
+ 'serial_dqfd',
+ 'serial_trex',
+ 'serial_trex_onpolicy',
+ 'parallel',
+ 'dist',
+ 'eval',
+ 'serial_reward_model',
+ 'serial_gail',
+ 'serial_offline',
+ 'serial_ngu',
+ ]
+ ),
+ help='serial-train or parallel-train or dist-train or eval'
+)
+@click.option('-c', '--config', type=str, help='Path to DRL experiment config')
+@click.option(
+ '-s',
+ '--seed',
+ type=int,
+ default=[0],
+ multiple=True,
+    help='random generator seed (for all the relevant packages: random, numpy, torch and the user env)'
+)
+@click.option('-e', '--env', type=str, help='RL env name')
+@click.option('-p', '--policy', type=str, help='DRL policy name')
+@click.option('--exp-name', type=str, help='experiment directory name')
+@click.option('--train-iter', type=str, default='1e8', help='Maximum policy update iterations in training')
+@click.option('--env-step', type=str, default='1e8', help='Maximum collected environment steps for training')
+@click.option('--load-path', type=str, default=None, help='Path to load ckpt')
+@click.option('--replay-path', type=str, default=None, help='Path to save replay')
+# the following arguments are only applied to dist mode
+@click.option('--enable-total-log', type=bool, help='whether to enable total DI-engine system log', default=False)
+@click.option('--disable-flask-log', type=bool, help='whether to disable flask log', default=True)
+@click.option(
+ '-P', '--platform', type=click.Choice(['local', 'slurm', 'k8s']), help='local or slurm or k8s', default='local'
+)
+@click.option(
+ '-M',
+ '--module',
+ type=click.Choice(['config', 'collector', 'learner', 'coordinator', 'learner_aggregator', 'spawn_learner']),
+ help='dist module type'
+)
+@click.option('--module-name', type=str, help='dist module name')
+@click.option('-cdh', '--coordinator-host', type=str, help='coordinator host', default='0.0.0.0')
+@click.option('-cdp', '--coordinator-port', type=int, help='coordinator port')
+@click.option('-lh', '--learner-host', type=str, help='learner host', default='0.0.0.0')
+@click.option('-lp', '--learner-port', type=int, help='learner port')
+@click.option('-clh', '--collector-host', type=str, help='collector host', default='0.0.0.0')
+@click.option('-clp', '--collector-port', type=int, help='collector port')
+@click.option('-agh', '--aggregator-host', type=str, help='aggregator slave host', default='0.0.0.0')
+@click.option('-agp', '--aggregator-port', type=int, help='aggregator slave port')
+@click.option('--add', type=click.Choice(['collector', 'learner']), help='add replicas type')
+@click.option('--delete', type=click.Choice(['collector', 'learner']), help='delete replicas type')
+@click.option('--restart', type=click.Choice(['collector', 'learner']), help='restart replicas type')
+@click.option('--kubeconfig', type=str, default=None, help='the path of Kubernetes configuration file')
+@click.option('-cdn', '--coordinator-name', type=str, default=None, help='coordinator name')
+@click.option('-ns', '--namespace', type=str, default=None, help='job namespace')
+@click.option('-rs', '--replicas', type=int, default=1, help='number of replicas to add/delete/restart')
+@click.option('-rpn', '--restart-pod-name', type=str, default=None, help='restart pod name')
+@click.option('--cpus', type=int, default=0, help='The requested CPU, read the value from DIJob yaml by default')
+@click.option('--gpus', type=int, default=0, help='The requested GPU, read the value from DIJob yaml by default')
+@click.option(
+ '--memory', type=str, default=None, help='The requested Memory, read the value from DIJob yaml by default'
+)
+@click.option(
+ '--profile',
+ type=str,
+ default=None,
+ help='profile Time cost by cProfile, and save the files into the specified folder path'
+)
+def cli(
+ # serial/eval
+ mode: str,
+ config: str,
+ seed: Union[int, List],
+ exp_name: str,
+ env: str,
+ policy: str,
+ train_iter: str, # transform into int
+ env_step: str, # transform into int
+ load_path: str,
+ replay_path: str,
+ # parallel/dist
+ platform: str,
+ coordinator_host: str,
+ coordinator_port: int,
+ learner_host: str,
+ learner_port: int,
+ collector_host: str,
+ collector_port: int,
+ aggregator_host: str,
+ aggregator_port: int,
+ enable_total_log: bool,
+ disable_flask_log: bool,
+ module: str,
+ module_name: str,
+ # add/delete/restart
+ add: str,
+ delete: str,
+ restart: str,
+ kubeconfig: str,
+ coordinator_name: str,
+ namespace: str,
+ replicas: int,
+ cpus: int,
+ gpus: int,
+ memory: str,
+ restart_pod_name: str,
+ profile: str,
+):
+ if profile is not None:
+ from ..utils.profiler_helper import Profiler
+ profiler = Profiler()
+ profiler.profile(profile)
+
+ train_iter = int(float(train_iter))
+ env_step = int(float(env_step))
+
+ def run_single_pipeline(seed, config):
+ if config is None:
+ config = get_predefined_config(env, policy)
+ else:
+ config = read_config(config)
+ if exp_name is not None:
+ config[0].exp_name = exp_name
+
+ if mode == 'serial':
+ from .serial_entry import serial_pipeline
+ serial_pipeline(config, seed, max_train_iter=train_iter, max_env_step=env_step)
+ elif mode == 'serial_onpolicy':
+ from .serial_entry_onpolicy import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy(config, seed, max_train_iter=train_iter, max_env_step=env_step)
+ elif mode == 'serial_sqil':
+ from .serial_entry_sqil import serial_pipeline_sqil
+ expert_config = input("Enter the name of the config you used to generate your expert model: ")
+ serial_pipeline_sqil(config, expert_config, seed, max_train_iter=train_iter, max_env_step=env_step)
+ elif mode == 'serial_reward_model':
+ from .serial_entry_reward_model_offpolicy import serial_pipeline_reward_model_offpolicy
+ serial_pipeline_reward_model_offpolicy(config, seed, max_train_iter=train_iter, max_env_step=env_step)
+ elif mode == 'serial_gail':
+ from .serial_entry_gail import serial_pipeline_gail
+ expert_config = input("Enter the name of the config you used to generate your expert model: ")
+ serial_pipeline_gail(
+ config, expert_config, seed, max_train_iter=train_iter, max_env_step=env_step, collect_data=True
+ )
+ elif mode == 'serial_dqfd':
+ from .serial_entry_dqfd import serial_pipeline_dqfd
+ expert_config = input("Enter the name of the config you used to generate your expert model: ")
+ assert (expert_config == config[:config.find('_dqfd')] + '_dqfd_config.py'), "DQFD only supports "\
+ + "the models used in q learning now; However, one should still type the DQFD config in this "\
+ + "place, i.e., {}{}".format(config[:config.find('_dqfd')], '_dqfd_config.py')
+ serial_pipeline_dqfd(config, expert_config, seed, max_train_iter=train_iter, max_env_step=env_step)
+ elif mode == 'serial_trex':
+ from .serial_entry_trex import serial_pipeline_trex
+ serial_pipeline_trex(config, seed, max_train_iter=train_iter, max_env_step=env_step)
+ elif mode == 'serial_trex_onpolicy':
+ from .serial_entry_trex_onpolicy import serial_pipeline_trex_onpolicy
+ serial_pipeline_trex_onpolicy(config, seed, max_train_iter=train_iter, max_env_step=env_step)
+ elif mode == 'serial_offline':
+ from .serial_entry_offline import serial_pipeline_offline
+ serial_pipeline_offline(config, seed, max_train_iter=train_iter)
+ elif mode == 'serial_ngu':
+ from .serial_entry_ngu import serial_pipeline_ngu
+ serial_pipeline_ngu(config, seed, max_train_iter=train_iter)
+ elif mode == 'parallel':
+ from .parallel_entry import parallel_pipeline
+ parallel_pipeline(config, seed, enable_total_log, disable_flask_log)
+ elif mode == 'dist':
+ from .dist_entry import dist_launch_coordinator, dist_launch_collector, dist_launch_learner, \
+ dist_prepare_config, dist_launch_learner_aggregator, dist_launch_spawn_learner, \
+ dist_add_replicas, dist_delete_replicas, dist_restart_replicas
+ if module == 'config':
+ dist_prepare_config(
+ config, seed, platform, coordinator_host, learner_host, collector_host, coordinator_port,
+ learner_port, collector_port
+ )
+ elif module == 'coordinator':
+ dist_launch_coordinator(config, seed, coordinator_port, disable_flask_log)
+ elif module == 'learner_aggregator':
+ dist_launch_learner_aggregator(
+ config, seed, aggregator_host, aggregator_port, module_name, disable_flask_log
+ )
+
+ elif module == 'collector':
+ dist_launch_collector(config, seed, collector_port, module_name, disable_flask_log)
+ elif module == 'learner':
+ dist_launch_learner(config, seed, learner_port, module_name, disable_flask_log)
+ elif module == 'spawn_learner':
+ dist_launch_spawn_learner(config, seed, learner_port, module_name, disable_flask_log)
+ elif add in ['collector', 'learner']:
+ dist_add_replicas(add, kubeconfig, replicas, coordinator_name, namespace, cpus, gpus, memory)
+ elif delete in ['collector', 'learner']:
+ dist_delete_replicas(delete, kubeconfig, replicas, coordinator_name, namespace)
+ elif restart in ['collector', 'learner']:
+ dist_restart_replicas(restart, kubeconfig, coordinator_name, namespace, restart_pod_name)
+ else:
+                raise Exception("invalid dist operation, please specify --module, --add, --delete or --restart")
+ elif mode == 'eval':
+ from .application_entry import eval
+ eval(config, seed, load_path=load_path, replay_path=replay_path)
+
+ if mode is None:
+        raise RuntimeError("Please specify the running mode with the -m/--mode argument.")
+
+ if isinstance(seed, (list, tuple)):
+ assert len(seed) > 0, "Please input at least 1 seed"
+ if len(seed) == 1: # necessary
+ run_single_pipeline(seed[0], config)
+ else:
+ if exp_name is None:
+ multi_exp_root = os.path.basename(config).split('.')[0] + '_result'
+ else:
+ multi_exp_root = exp_name
+ if not os.path.exists(multi_exp_root):
+ os.makedirs(multi_exp_root)
+ abs_config_path = os.path.abspath(config)
+ origin_root = os.getcwd()
+ for s in seed:
+ seed_exp_root = os.path.join(multi_exp_root, 'seed{}'.format(s))
+ if not os.path.exists(seed_exp_root):
+ os.makedirs(seed_exp_root)
+ os.chdir(seed_exp_root)
+ run_single_pipeline(s, abs_config_path)
+ os.chdir(origin_root)
+ else:
+ raise TypeError("invalid seed type: {}".format(type(seed)))
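+
+
+# Hedged usage sketch (shell), assuming the console script ``ding`` is installed and that a config file
+# such as cartpole_dqn_config.py exists (the file name is illustrative):
+#   ding -m serial -c cartpole_dqn_config.py -s 0
+#   ding -m eval -c cartpole_dqn_config.py -s 0 --load-path ./exp/ckpt/ckpt_best.pth.tar
+#   ding -m serial -c cartpole_dqn_config.py -s 0 -s 1 -s 2  # multi-seed run, one sub-directory per seed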
diff --git a/DI-engine/ding/entry/cli_ditask.py b/DI-engine/ding/entry/cli_ditask.py
new file mode 100644
index 0000000000000000000000000000000000000000..443fe1a6b6f67bb4fec777f8123ec368d320942b
--- /dev/null
+++ b/DI-engine/ding/entry/cli_ditask.py
@@ -0,0 +1,161 @@
+import click
+import os
+import sys
+import importlib
+import importlib.util
+import json
+from click.core import Context, Option
+
+from ding import __TITLE__, __VERSION__, __AUTHOR__, __AUTHOR_EMAIL__
+from ding.framework import Parallel
+from ding.entry.cli_parsers import PLATFORM_PARSERS
+
+
+def print_version(ctx: Context, param: Option, value: bool) -> None:
+ if not value or ctx.resilient_parsing:
+ return
+ click.echo('{title}, version {version}.'.format(title=__TITLE__, version=__VERSION__))
+ click.echo('Developed by {author}, {email}.'.format(author=__AUTHOR__, email=__AUTHOR_EMAIL__))
+ ctx.exit()
+
+
+CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])
+
+
+@click.command(context_settings=CONTEXT_SETTINGS)
+@click.option(
+ '-v',
+ '--version',
+ is_flag=True,
+ callback=print_version,
+ expose_value=False,
+ is_eager=True,
+ help="Show package's version information."
+)
+@click.option('-p', '--package', type=str, help="Your code package path, could be a directory or a zip file.")
+@click.option('--parallel-workers', type=int, default=1, help="Parallel worker number, default: 1")
+@click.option(
+ '--protocol',
+ type=click.Choice(["tcp", "ipc"]),
+ default="tcp",
+ help="Network protocol in parallel mode, default: tcp"
+)
+@click.option(
+ "--ports",
+ type=str,
+    help="The port addresses that the tasks listen to, e.g. 50515,50516; default 50515 (k8s/local) or 15151 (slurm)"
+)
+@click.option("--attach-to", type=str, help="The addresses to connect to.")
+@click.option("--address", type=str, help="The address to listen to (without port).")
+@click.option("--labels", type=str, help="Labels.")
+@click.option("--node-ids", type=str, help="Candidate node ids.")
+@click.option(
+ "--topology",
+ type=click.Choice(["alone", "mesh", "star"]),
+ default="alone",
+ help="Network topology, default: alone."
+)
+@click.option("--platform-spec", type=str, help="Platform specific configure.")
+@click.option("--platform", type=str, help="Platform type: slurm, k8s.")
+@click.option("--mq-type", type=str, default="nng", help="Class type of message queue, i.e. nng, redis.")
+@click.option("--redis-host", type=str, help="Redis host.")
+@click.option("--redis-port", type=int, help="Redis port.")
+@click.option("-m", "--main", type=str, help="Main function of entry module.")
+@click.option("--startup-interval", type=int, default=1, help="Start up interval between each task.")
+@click.option("--local_rank", type=int, default=0, help="Compatibility with PyTorch DDP")
+def cli_ditask(*args, **kwargs):
+ return _cli_ditask(*args, **kwargs)
+
+
+def _parse_platform_args(platform: str, platform_spec: str, all_args: dict):
+ if platform_spec:
+ try:
+            if os.path.splitext(platform_spec)[1] == ".json":
+ with open(platform_spec) as f:
+ platform_spec = json.load(f)
+ else:
+ platform_spec = json.loads(platform_spec)
+        except Exception:
+ click.echo("platform_spec is not a valid json!")
+ exit(1)
+ if platform not in PLATFORM_PARSERS:
+ click.echo("platform type is invalid! type: {}".format(platform))
+ exit(1)
+ all_args.pop("platform")
+ all_args.pop("platform_spec")
+ try:
+ parsed_args = PLATFORM_PARSERS[platform](platform_spec, **all_args)
+ except Exception as e:
+ click.echo("error when parse platform spec configure: {}".format(e))
+ raise e
+
+ return parsed_args
+
+
+def _cli_ditask(
+ package: str,
+ main: str,
+ parallel_workers: int,
+ protocol: str,
+ ports: str,
+ attach_to: str,
+ address: str,
+ labels: str,
+ node_ids: str,
+ topology: str,
+ mq_type: str,
+ redis_host: str,
+ redis_port: int,
+ startup_interval: int,
+ local_rank: int = 0,
+ platform: str = None,
+ platform_spec: str = None,
+):
+ # Parse entry point
+ all_args = locals()
+ if platform:
+ parsed_args = _parse_platform_args(platform, platform_spec, all_args)
+ return _cli_ditask(**parsed_args)
+
+ if not package:
+ package = os.getcwd()
+ sys.path.append(package)
+ if main is None:
+ mod_name = os.path.basename(package)
+ mod_name, _ = os.path.splitext(mod_name)
+ func_name = "main"
+ else:
+ mod_name, func_name = main.rsplit(".", 1)
+ root_mod_name = mod_name.split(".", 1)[0]
+ sys.path.append(os.path.join(package, root_mod_name))
+ mod = importlib.import_module(mod_name)
+ main_func = getattr(mod, func_name)
+ # Parse arguments
+ ports = ports or 50515
+ if not isinstance(ports, int):
+ ports = ports.split(",")
+ ports = list(map(lambda i: int(i), ports))
+ ports = ports[0] if len(ports) == 1 else ports
+ if attach_to:
+ attach_to = attach_to.split(",")
+ attach_to = list(map(lambda s: s.strip(), attach_to))
+ if labels:
+ labels = labels.split(",")
+ labels = set(map(lambda s: s.strip(), labels))
+ if node_ids and not isinstance(node_ids, int):
+ node_ids = node_ids.split(",")
+ node_ids = list(map(lambda i: int(i), node_ids))
+ Parallel.runner(
+ n_parallel_workers=parallel_workers,
+ ports=ports,
+ protocol=protocol,
+ topology=topology,
+ attach_to=attach_to,
+ address=address,
+ labels=labels,
+ node_ids=node_ids,
+ mq_type=mq_type,
+ redis_host=redis_host,
+ redis_port=redis_port,
+ startup_interval=startup_interval
+ )(main_func)
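+
+
+# Hedged usage sketch (shell), assuming the console script is registered as ``ditask`` and that a module
+# ``my_main.py`` exposing a ``main()`` function exists in the current directory (both names are hypothetical):
+#   ditask --package . --main my_main.main --parallel-workers 2 --topology mesh
+# With no --ports given, the workers listen on port 50515 by default (see the parsing above).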
diff --git a/DI-engine/ding/entry/cli_parsers/__init__.py b/DI-engine/ding/entry/cli_parsers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..daa2410e0b4ae08c84f90c8b28b9b871f0532bef
--- /dev/null
+++ b/DI-engine/ding/entry/cli_parsers/__init__.py
@@ -0,0 +1,3 @@
+from .slurm_parser import slurm_parser
+from .k8s_parser import k8s_parser
+PLATFORM_PARSERS = {"slurm": slurm_parser, "k8s": k8s_parser}
diff --git a/DI-engine/ding/entry/cli_parsers/k8s_parser.py b/DI-engine/ding/entry/cli_parsers/k8s_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f2b0aebe7b3d647a176abb9f7adc8f8d6fa86b1
--- /dev/null
+++ b/DI-engine/ding/entry/cli_parsers/k8s_parser.py
@@ -0,0 +1,151 @@
+import os
+import numpy as np
+from time import sleep
+from typing import Dict, List, Optional
+
+
+class K8SParser():
+
+ def __init__(self, platform_spec: Optional[Dict] = None, **kwargs) -> None:
+ """
+ Overview:
+ Should only set global cluster properties
+ """
+ self.kwargs = kwargs
+ self.nodelist = self._parse_node_list()
+ self.ntasks = len(self.nodelist)
+ self.platform_spec = platform_spec
+ self.parallel_workers = kwargs.get("parallel_workers") or 1
+ self.topology = kwargs.get("topology") or "alone"
+ self.ports = int(kwargs.get("ports") or 50515)
+ self.tasks = {}
+
+ def parse(self) -> dict:
+ if self.kwargs.get("mq_type", "nng") != "nng":
+ return self.kwargs
+ procid = int(os.environ["DI_RANK"])
+ nodename = self.nodelist[procid]
+ task = self._get_task(procid)
+ # Validation
+ assert task["address"] == nodename
+ return {**self.kwargs, **task}
+
+ def _parse_node_list(self) -> List[str]:
+ return os.environ["DI_NODES"].split(",")
+
+ def _get_task(self, procid: int) -> dict:
+ """
+ Overview:
+            Complete the node properties. Use environment variables that describe the whole node list
+            instead of values that only exist on the current node; e.g. derive the node name from DI_NODES.
+ Arguments:
+ - procid (:obj:`int`): Proc order, starting from 0, must be set automatically by dijob.
+ Note that it is different from node_id.
+ """
+ if procid in self.tasks:
+ return self.tasks.get(procid)
+
+ if self.platform_spec:
+ task = self.platform_spec["tasks"][procid]
+ else:
+ task = {}
+ if "ports" not in task:
+ task["ports"] = self.kwargs.get("ports") or self._get_ports()
+ if "address" not in task:
+ task["address"] = self.kwargs.get("address") or self._get_address(procid)
+ if "node_ids" not in task:
+ task["node_ids"] = self.kwargs.get("node_ids") or self._get_node_id(procid)
+
+ task["attach_to"] = self.kwargs.get("attach_to") or self._get_attach_to(procid, task.get("attach_to"))
+ task["topology"] = self.topology
+ task["parallel_workers"] = self.parallel_workers
+
+ self.tasks[procid] = task
+ return task
+
+ def _get_attach_to(self, procid: int, attach_to: Optional[str] = None) -> str:
+ """
+ Overview:
+ Parse from pattern of attach_to. If attach_to is specified in the platform_spec,
+ it is formatted as a real address based on the specified address.
+            If not, the real addresses will be generated based on the globally specified topology.
+ Arguments:
+ - procid (:obj:`int`): Proc order.
+ - attach_to (:obj:`str`): The attach_to field in platform_spec for the task with current procid.
+        Returns:
+ - attach_to (:obj:`str`): The real addresses for attach_to.
+ """
+ if attach_to:
+ attach_to = [self._get_attach_to_part(part) for part in attach_to.split(",")]
+ elif procid == 0:
+ attach_to = []
+ else:
+ if self.topology == "mesh":
+ prev_tasks = [self._get_task(i) for i in range(procid)]
+ attach_to = [self._get_attach_to_from_task(task) for task in prev_tasks]
+ attach_to = list(np.concatenate(attach_to))
+ elif self.topology == "star":
+ head_task = self._get_task(0)
+ attach_to = self._get_attach_to_from_task(head_task)
+ else:
+ attach_to = []
+
+ return ",".join(attach_to)
+
+ def _get_attach_to_part(self, attach_part: str) -> str:
+ """
+ Overview:
+ Parse each part of attach_to.
+ Arguments:
+            - attach_part (:obj:`str`): The attach_to field with a specific pattern, e.g. $node.0
+        Returns:
+ - attach_to (:obj:`str`): The real address, e.g. tcp://SH-0:50000
+ """
+ if not attach_part.startswith("$node."):
+ return attach_part
+ attach_node_id = int(attach_part[6:])
+ attach_task = self._get_task(self._get_procid_from_nodeid(attach_node_id))
+ return self._get_tcp_link(attach_task["address"], attach_task["ports"])
+
+ def _get_attach_to_from_task(self, task: dict) -> List[str]:
+ """
+ Overview:
+            Get the attach_to list from a task; note that parallel_workers affects the connected processes.
+        Arguments:
+            - task (:obj:`dict`): The task object.
+        Returns:
+            - attach_to (:obj:`List[str]`): The real addresses, e.g. ['tcp://SH-0:50000'].
+ """
+ port = task.get("ports")
+ address = task.get("address")
+ ports = [int(port) + i for i in range(self.parallel_workers)]
+ attach_to = [self._get_tcp_link(address, port) for port in ports]
+ return attach_to
+
+ def _get_procid_from_nodeid(self, nodeid: int) -> int:
+ procid = None
+ for i in range(self.ntasks):
+ task = self._get_task(i)
+ if task["node_ids"] == nodeid:
+ procid = i
+ break
+ if procid is None:
+ raise Exception("Can not find procid from nodeid: {}".format(nodeid))
+ return procid
+
+    def _get_ports(self) -> int:
+ return self.ports
+
+ def _get_address(self, procid: int) -> str:
+ address = self.nodelist[procid]
+ return address
+
+ def _get_tcp_link(self, address: str, port: int) -> str:
+ return "tcp://{}:{}".format(address, port)
+
+ def _get_node_id(self, procid: int) -> int:
+ return procid * self.parallel_workers
+
+
+def k8s_parser(platform_spec: Optional[str] = None, **kwargs) -> dict:
+ return K8SParser(platform_spec, **kwargs).parse()
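+
+
+# Hedged behaviour sketch: with DI_NODES=SH-0,SH-1,SH-2, DI_RANK=2, topology "mesh" and the default
+# port 50515, the task parsed for rank 2 attaches to "tcp://SH-0:50515,tcp://SH-1:50515" (every earlier
+# task), while the rank-0 task attaches to nothing; see tests/test_k8s_parser.py for concrete cases.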
diff --git a/DI-engine/ding/entry/cli_parsers/slurm_parser.py b/DI-engine/ding/entry/cli_parsers/slurm_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..c46716438b063a05d448ccd734b38f8e42715a4c
--- /dev/null
+++ b/DI-engine/ding/entry/cli_parsers/slurm_parser.py
@@ -0,0 +1,150 @@
+import os
+import re
+from time import sleep
+import numpy as np
+from typing import Any, Dict, List, Optional
+
+
+class SlurmParser():
+
+ def __init__(self, platform_spec: Optional[Dict] = None, **kwargs) -> None:
+ """
+ Overview:
+ Should only set global cluster properties
+ """
+ self.kwargs = kwargs
+ self.ntasks = int(os.environ["SLURM_NTASKS"])
+ self.platform_spec = platform_spec
+ self.tasks = {}
+ self.ntasks_per_node = int(os.environ["SLURM_NTASKS_PER_NODE"])
+ self.nodelist = self._parse_node_list()
+ self.ports = int(kwargs.get("ports") or 15151)
+ self.parallel_workers = kwargs.get("parallel_workers") or 1
+ self.topology = kwargs.get("topology") or "alone"
+
+ def parse(self) -> dict:
+ procid = int(os.environ["SLURM_PROCID"])
+ task = self._get_task(procid)
+ # Validation
+ assert task["address"] == os.environ["SLURMD_NODENAME"]
+ return {**self.kwargs, **task}
+
+ def _get_task(self, procid: int) -> Dict[str, Any]:
+ if procid in self.tasks:
+ return self.tasks.get(procid)
+ if self.platform_spec:
+ task = self.platform_spec["tasks"][procid]
+ else:
+ task = {}
+ if "ports" not in task:
+ task["ports"] = self._get_ports(procid)
+ if "address" not in task:
+ task["address"] = self._get_address(procid)
+ if "node_ids" not in task:
+ task["node_ids"] = self._get_node_id(procid)
+
+ task["attach_to"] = self._get_attach_to(procid, task.get("attach_to"))
+ task["topology"] = self.topology
+ task["parallel_workers"] = self.parallel_workers
+
+ self.tasks[procid] = task
+ return task
+
+ def _parse_node_list(self) -> List[str]:
+ nodelist = os.environ["SLURM_NODELIST"]
+ result = re.match(r"(.*)?\[(.*)\]$", nodelist)
+ if result:
+ prefix, tails = result.groups()
+ nodelist = []
+ for tail in tails.split(","):
+ if "-" in tail:
+ start, stop = tail.split("-")
+ for number in range(int(start), int(stop) + 1):
+ nodelist.append(prefix + str(number))
+ else:
+ nodelist.append(prefix + tail)
+ elif isinstance(nodelist, str):
+ nodelist = [nodelist]
+ if self.ntasks_per_node > 1:
+ expand_nodelist = [] # Expand node for each task
+ for node in nodelist:
+ for _ in range(self.ntasks_per_node):
+ expand_nodelist.append(node)
+ nodelist = expand_nodelist
+ return nodelist
+
+ def _get_attach_to(self, procid: int, attach_to: Optional[str] = None) -> str:
+ if attach_to:
+ attach_to = [self._get_attach_to_part(part) for part in attach_to.split(",")]
+ elif procid == 0:
+ attach_to = []
+ else:
+ if self.topology == "mesh":
+ prev_tasks = [self._get_task(i) for i in range(procid)]
+ attach_to = [self._get_attach_to_from_task(task) for task in prev_tasks]
+ attach_to = list(np.concatenate(attach_to))
+ elif self.topology == "star":
+ head_task = self._get_task(0)
+ attach_to = self._get_attach_to_from_task(head_task)
+ else:
+ attach_to = []
+
+ return ",".join(attach_to)
+
+ def _get_attach_to_part(self, attach_part: str) -> str:
+ """
+ Overview:
+ Parse each part of attach_to.
+ Arguments:
+            - attach_part (:obj:`str`): The attach_to field with a specific pattern, e.g. $node.0
+        Returns:
+ - attach_to (:obj:`str`): The real address, e.g. tcp://SH-0:50000
+ """
+ if not attach_part.startswith("$node."):
+ return attach_part
+ attach_node_id = int(attach_part[6:])
+ attach_task = self._get_task(self._get_procid_from_nodeid(attach_node_id))
+ return self._get_tcp_link(attach_task["address"], attach_task["ports"])
+
+ def _get_attach_to_from_task(self, task: dict) -> List[str]:
+ """
+ Overview:
+            Get the attach_to list from a task; note that parallel_workers affects the connected processes.
+        Arguments:
+            - task (:obj:`dict`): The task object.
+        Returns:
+            - attach_to (:obj:`List[str]`): The real addresses, e.g. ['tcp://SH-0:50000'].
+ """
+ port = task.get("ports")
+ address = task.get("address")
+ ports = [int(port) + i for i in range(self.parallel_workers)]
+ attach_to = [self._get_tcp_link(address, port) for port in ports]
+ return attach_to
+
+ def _get_procid_from_nodeid(self, nodeid: int) -> int:
+ procid = None
+ for i in range(self.ntasks):
+ task = self._get_task(i)
+ if task["node_ids"] == nodeid:
+ procid = i
+ break
+ if procid is None:
+ raise Exception("Can not find procid from nodeid: {}".format(nodeid))
+ return procid
+
+ def _get_ports(self, procid) -> int:
+ return self.ports + (procid % self.ntasks_per_node) * self.parallel_workers
+
+ def _get_address(self, procid: int) -> str:
+ address = self.nodelist[procid]
+ return address
+
+ def _get_node_id(self, procid: int) -> int:
+ return procid * self.parallel_workers
+
+ def _get_tcp_link(self, address: str, port: int) -> str:
+ return "tcp://{}:{}".format(address, port)
+
+
+def slurm_parser(platform_spec: str, **kwargs) -> dict:
+ return SlurmParser(platform_spec, **kwargs).parse()
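+
+
+# Hedged expansion example for SLURM_NODELIST (see _parse_node_list above): with
+# SLURM_NODELIST='SH-IDC1-10-5-38-[190,215]' and SLURM_NTASKS_PER_NODE=3, the node list becomes
+# ['SH-IDC1-10-5-38-190'] * 3 + ['SH-IDC1-10-5-38-215'] * 3, i.e. one entry per task on each node.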
diff --git a/DI-engine/ding/entry/cli_parsers/tests/test_k8s_parser.py b/DI-engine/ding/entry/cli_parsers/tests/test_k8s_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..d8a358b8661c27266698e106fc5b7dab1b5c3eee
--- /dev/null
+++ b/DI-engine/ding/entry/cli_parsers/tests/test_k8s_parser.py
@@ -0,0 +1,81 @@
+import pytest
+import os
+from ding.entry.cli_parsers.k8s_parser import k8s_parser
+
+
+@pytest.fixture
+def set_k8s_env():
+ os.environ["DI_NODES"] = 'SH-0,SH-1,SH-2,SH-3,SH-4,SH-5' # All the nodes
+    os.environ["DI_RANK"] = '3'  # Proc order, starting from 0; cannot be modified by config
+
+ yield
+
+ del os.environ["DI_NODES"]
+ del os.environ["DI_RANK"]
+
+
+@pytest.mark.unittest
+@pytest.mark.usefixtures('set_k8s_env')
+def test_k8s_parser():
+ # With platform_spec
+ platform_spec = {
+ "tasks": [
+ {
+ "labels": "league,collect",
+ "node_ids": 10
+ }, {
+ "labels": "league,collect",
+ "node_ids": 11
+ }, {
+ "labels": "evaluate",
+ "node_ids": 20,
+ "attach_to": "$node.10,$node.11"
+ }, {
+ "labels": "learn",
+ "node_ids": 31,
+ "ports": 50000,
+ "attach_to": "$node.10,$node.11,$node.20"
+ }, {
+ "labels": "learn",
+ "node_ids": 32,
+ "attach_to": "$node.10,$node.11,$node.20"
+ }, {
+ "labels": "learn",
+ "node_ids": 33,
+ "attach_to": "$node.10,$node.11,$node.20"
+ }
+ ]
+ }
+ all_args = k8s_parser(platform_spec, mq_type="nng")
+ assert all_args["labels"] == "learn"
+ assert all_args["address"] == "SH-3"
+ assert all_args["ports"] == 50000
+ assert all_args["node_ids"] == 31
+ assert all_args["parallel_workers"] == 1
+ assert all_args[
+ "attach_to"
+ ] == "tcp://SH-0:50515," +\
+ "tcp://SH-1:50515," +\
+ "tcp://SH-2:50515"
+
+ # Without platform_spec, parse by global config
+ all_args = k8s_parser(None, topology="mesh", mq_type="nng")
+ assert all_args["address"] == "SH-3"
+ assert all_args["node_ids"] == 3
+ assert all_args["parallel_workers"] == 1
+ assert all_args[
+ "attach_to"
+ ] == "tcp://SH-0:50515," +\
+ "tcp://SH-1:50515," +\
+ "tcp://SH-2:50515"
+
+ # With multiple parallel workers
+ all_args = k8s_parser(None, topology="mesh", parallel_workers=2)
+ assert all_args["address"] == "SH-3"
+ assert all_args["node_ids"] == 6
+ assert all_args["parallel_workers"] == 2
+ assert all_args[
+ "attach_to"
+ ] == "tcp://SH-0:50515,tcp://SH-0:50516," +\
+ "tcp://SH-1:50515,tcp://SH-1:50516," +\
+ "tcp://SH-2:50515,tcp://SH-2:50516"
diff --git a/DI-engine/ding/entry/cli_parsers/tests/test_slurm_parser.py b/DI-engine/ding/entry/cli_parsers/tests/test_slurm_parser.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b817ba48a5aa5f5b1a4610d71c4031295452f1d
--- /dev/null
+++ b/DI-engine/ding/entry/cli_parsers/tests/test_slurm_parser.py
@@ -0,0 +1,84 @@
+import pytest
+import os
+from ding.entry.cli_parsers import PLATFORM_PARSERS
+from ding.entry.cli_parsers.slurm_parser import SlurmParser
+slurm_parser = PLATFORM_PARSERS["slurm"]
+
+
+@pytest.fixture
+def set_slurm_env():
+    os.environ["SLURM_NTASKS"] = '6'  # Parameter n (--ntasks), total process / task count
+    os.environ["SLURM_NTASKS_PER_NODE"] = '3'  # Parameter --ntasks-per-node, process count on each node
+    os.environ["SLURM_NODELIST"] = 'SH-IDC1-10-5-38-[190,215]'  # All the nodes
+    os.environ["SLURM_PROCID"] = '3'  # Proc order, starting from 0; may differ from the nominal order
+    os.environ["SLURMD_NODENAME"] = 'SH-IDC1-10-5-38-215'  # Name of the current node
+
+ yield
+
+ del os.environ["SLURM_NTASKS"]
+ del os.environ["SLURM_NTASKS_PER_NODE"]
+ del os.environ["SLURM_NODELIST"]
+ del os.environ["SLURM_PROCID"]
+ del os.environ["SLURMD_NODENAME"]
+
+
+@pytest.mark.unittest
+@pytest.mark.usefixtures('set_slurm_env')
+def test_slurm_parser():
+ platform_spec = {
+ "tasks": [
+ {
+ "labels": "league,collect",
+ "node_ids": 10
+ }, {
+ "labels": "league,collect",
+ "node_ids": 11
+ }, {
+ "labels": "evaluate",
+ "node_ids": 20,
+ "attach_to": "$node.10,$node.11"
+ }, {
+ "labels": "learn",
+ "node_ids": 31,
+ "attach_to": "$node.10,$node.11,$node.20"
+ }, {
+ "labels": "learn",
+ "node_ids": 32,
+ "attach_to": "$node.10,$node.11,$node.20"
+ }, {
+ "labels": "learn",
+ "node_ids": 33,
+ "attach_to": "$node.10,$node.11,$node.20"
+ }
+ ]
+ }
+ all_args = slurm_parser(platform_spec)
+ assert all_args["labels"] == "learn"
+ assert all_args["address"] == "SH-IDC1-10-5-38-215"
+ assert all_args["ports"] == 15151 # Start from 15151
+ assert all_args["node_ids"] == 31
+ assert all_args[
+ "attach_to"
+ ] == "tcp://SH-IDC1-10-5-38-190:15151," +\
+ "tcp://SH-IDC1-10-5-38-190:15152," +\
+ "tcp://SH-IDC1-10-5-38-190:15153"
+
+ # Test without platform_spec
+ all_args = slurm_parser(None, topology="mesh", mq_type="nng")
+ assert all_args["address"] == "SH-IDC1-10-5-38-215"
+ assert all_args["node_ids"] == 3
+ assert all_args["parallel_workers"] == 1
+ assert all_args[
+ "attach_to"
+ ] == "tcp://SH-IDC1-10-5-38-190:15151," +\
+ "tcp://SH-IDC1-10-5-38-190:15152," +\
+ "tcp://SH-IDC1-10-5-38-190:15153"
+
+ # Test _parse_node_list
+ sp = SlurmParser(platform_spec)
+ os.environ["SLURM_NODELIST"] = 'SH-IDC1-10-5-[38-40]'
+ nodelist = sp._parse_node_list() # Nodes * parallel_workers
+ assert nodelist == [
+ 'SH-IDC1-10-5-38', 'SH-IDC1-10-5-38', 'SH-IDC1-10-5-38', 'SH-IDC1-10-5-39', 'SH-IDC1-10-5-39',
+ 'SH-IDC1-10-5-39', 'SH-IDC1-10-5-40', 'SH-IDC1-10-5-40', 'SH-IDC1-10-5-40'
+ ]
diff --git a/DI-engine/ding/entry/dist_entry.py b/DI-engine/ding/entry/dist_entry.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d9d37fe14814eef2dfcf562401b261bdeaae53b
--- /dev/null
+++ b/DI-engine/ding/entry/dist_entry.py
@@ -0,0 +1,333 @@
+import os
+import sys
+import subprocess
+import signal
+import pickle
+from ditk import logging
+import time
+from threading import Thread
+from easydict import EasyDict
+import numpy as np
+from ding.worker import Coordinator, create_comm_collector, create_comm_learner, LearnerAggregator
+from ding.config import read_config_with_system, compile_config_parallel
+from ding.utils import set_pkg_seed, DEFAULT_K8S_AGGREGATOR_SLAVE_PORT, pod_exec_command
+
+
+def dist_prepare_config(
+ filename: str,
+ seed: int,
+ platform: str,
+ coordinator_host: str,
+ learner_host: str,
+ collector_host: str,
+ coordinator_port: int,
+ learner_port: int,
+ collector_port,
+) -> str:
+ set_pkg_seed(seed)
+ main_cfg, create_cfg, system_cfg = read_config_with_system(filename)
+ config = compile_config_parallel(
+ main_cfg,
+ create_cfg=create_cfg,
+ system_cfg=system_cfg,
+ seed=seed,
+ platform=platform,
+ coordinator_host=coordinator_host,
+ learner_host=learner_host,
+ collector_host=collector_host,
+ coordinator_port=coordinator_port,
+ learner_port=learner_port,
+ collector_port=collector_port,
+ )
+ # Pickle dump config to disk for later use.
+ real_filename = filename + '.pkl'
+ with open(real_filename, 'wb') as f:
+ pickle.dump(config, f)
+ return real_filename
+
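+
+# Hedged workflow sketch (shell): the dist entries in this module are normally driven through the ``ding`` CLI,
+# first compiling the config into a pickle and then launching each module with that pickle, e.g.
+#   ding -m dist --module config -c xxx_config.py -s 0 -P local      # writes xxx_config.py.pkl
+#   ding -m dist --module coordinator -c xxx_config.py.pkl -s 0
+#   ding -m dist --module learner -c xxx_config.py.pkl -s 0
+#   ding -m dist --module collector -c xxx_config.py.pkl -s 0
+# The config file name is illustrative; ports can also be overridden per module (e.g. -cdp/-lp/-clp).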
+
+def dist_launch_coordinator(
+ filename: str,
+ seed: int,
+ coordinator_port: int,
+ disable_flask_log: bool,
+ enable_total_log: bool = False
+) -> None:
+ set_pkg_seed(seed)
+ # Disable some part of DI-engine log
+ if not enable_total_log:
+ coordinator_log = logging.getLogger('coordinator_logger')
+ coordinator_log.disabled = True
+ if disable_flask_log:
+ log = logging.getLogger('werkzeug')
+ log.disabled = True
+ with open(filename, 'rb') as f:
+ config = pickle.load(f)
+ # CLI > ENV VARIABLE > CONFIG
+ if coordinator_port is not None:
+ config.system.coordinator.port = coordinator_port
+ elif os.environ.get('COORDINATOR_PORT', None):
+ port = os.environ['COORDINATOR_PORT']
+ if port.isdigit():
+ config.system.coordinator.port = int(port)
+ else: # use config pre-defined value
+ assert 'port' in config.system.coordinator and np.isscalar(config.system.coordinator.port)
+ coordinator = Coordinator(config)
+ coordinator.start()
+
+    # Monitor thread: Coordinator will remain running until its ``system_shutdown_flag`` is set to True.
+ def shutdown_monitor():
+ while True:
+ time.sleep(3)
+ if coordinator.system_shutdown_flag:
+ coordinator.close()
+ break
+
+ shutdown_monitor_thread = Thread(target=shutdown_monitor, args=(), daemon=True, name='shutdown_monitor')
+ shutdown_monitor_thread.start()
+ shutdown_monitor_thread.join()
+    print("[DI-engine dist pipeline] Your RL agent has converged. Refer to 'log' and 'tensorboard' for details")
+
+
+def dist_launch_learner(
+ filename: str, seed: int, learner_port: int, name: str = None, disable_flask_log: bool = True
+) -> None:
+ set_pkg_seed(seed)
+ if disable_flask_log:
+ log = logging.getLogger('werkzeug')
+ log.disabled = True
+ if name is None:
+ name = 'learner'
+ with open(filename, 'rb') as f:
+ config = pickle.load(f).system[name]
+ # CLI > ENV VARIABLE > CONFIG
+ if learner_port is not None:
+ config.port = learner_port
+ elif os.environ.get('LEARNER_PORT', None):
+ port = os.environ['LEARNER_PORT']
+ if port.isdigit():
+ config.port = int(port)
+ else: # use config pre-defined value
+ assert 'port' in config and np.isscalar(config.port)
+ learner = create_comm_learner(config)
+ learner.start()
+
+
+def dist_launch_collector(
+ filename: str, seed: int, collector_port: int, name: str = None, disable_flask_log: bool = True
+) -> None:
+ set_pkg_seed(seed)
+ if disable_flask_log:
+ log = logging.getLogger('werkzeug')
+ log.disabled = True
+ if name is None:
+ name = 'collector'
+ with open(filename, 'rb') as f:
+ config = pickle.load(f).system[name]
+ # CLI > ENV VARIABLE > CONFIG
+ if collector_port is not None:
+ config.port = collector_port
+ elif os.environ.get('COLLECTOR_PORT', None):
+ port = os.environ['COLLECTOR_PORT']
+ if port.isdigit():
+ config.port = int(port)
+ else: # use config pre-defined value
+ assert 'port' in config and np.isscalar(config.port)
+ collector = create_comm_collector(config)
+ collector.start()
+
+
+def dist_launch_learner_aggregator(
+ filename: str,
+ seed: int,
+ aggregator_host: str,
+ aggregator_port: int,
+ name: str = None,
+ disable_flask_log: bool = True
+) -> None:
+ set_pkg_seed(seed)
+ if disable_flask_log:
+ log = logging.getLogger('werkzeug')
+ log.disabled = True
+ if filename is not None:
+ if name is None:
+ name = 'learner_aggregator'
+ with open(filename, 'rb') as f:
+ config = pickle.load(f).system[name]
+ else:
+ # start without config (create a fake one)
+ host, port = aggregator_host, DEFAULT_K8S_AGGREGATOR_SLAVE_PORT
+ if aggregator_port is not None:
+ port = aggregator_port
+ elif os.environ.get('AGGREGATOR_PORT', None):
+ _port = os.environ['AGGREGATOR_PORT']
+ if _port.isdigit():
+ port = int(_port)
+ config = dict(
+ master=dict(host=host, port=port + 1),
+ slave=dict(host=host, port=port + 0),
+ learner={},
+ )
+ config = EasyDict(config)
+ learner_aggregator = LearnerAggregator(config)
+ learner_aggregator.start()
+
+
+def dist_launch_spawn_learner(
+ filename: str, seed: int, learner_port: int, name: str = None, disable_flask_log: bool = True
+) -> None:
+ current_env = os.environ.copy()
+ local_world_size = int(os.environ.get('LOCAL_WORLD_SIZE', 1))
+ processes = []
+
+ for local_rank in range(0, local_world_size):
+ dist_rank = int(os.environ.get('START_RANK', 0)) + local_rank
+ current_env["RANK"] = str(dist_rank)
+ current_env["LOCAL_RANK"] = str(local_rank)
+
+ executable = subprocess.getoutput('which ding')
+ assert len(executable) > 0, "cannot find executable \"ding\""
+
+ cmd = [executable, '-m', 'dist', '--module', 'learner']
+ if filename is not None:
+ cmd += ['-c', f'{filename}']
+ if seed is not None:
+ cmd += ['-s', f'{seed}']
+ if learner_port is not None:
+ cmd += ['-lp', f'{learner_port}']
+ if name is not None:
+ cmd += ['--module-name', f'{name}']
+ if disable_flask_log is not None:
+ cmd += ['--disable-flask-log', f'{int(disable_flask_log)}']
+
+ sig_names = {2: "SIGINT", 15: "SIGTERM"}
+ last_return_code = None
+
+ def sigkill_handler(signum, frame):
+ for process in processes:
+ print(f"Killing subprocess {process.pid}")
+ try:
+ process.kill()
+ except Exception:
+ pass
+ if last_return_code is not None:
+ raise subprocess.CalledProcessError(returncode=last_return_code, cmd=cmd)
+ if signum in sig_names:
+ print(f"Main process received {sig_names[signum]}, exiting")
+ sys.exit(1)
+
+ # pass SIGINT/SIGTERM to children if the parent is being terminated
+ signal.signal(signal.SIGINT, sigkill_handler)
+ signal.signal(signal.SIGTERM, sigkill_handler)
+
+ process = subprocess.Popen(cmd, env=current_env, stdout=None, stderr=None)
+ processes.append(process)
+
+ try:
+ alive_processes = set(processes)
+ while len(alive_processes):
+ finished_processes = []
+ for process in alive_processes:
+ if process.poll() is None:
+ # the process is still running
+ continue
+ else:
+ if process.returncode != 0:
+ last_return_code = process.returncode # for sigkill_handler
+ sigkill_handler(signal.SIGTERM, None) # not coming back
+ else:
+ # exited cleanly
+ finished_processes.append(process)
+ alive_processes = set(alive_processes) - set(finished_processes)
+
+ time.sleep(1)
+ finally:
+        # nothing extra to clean up: child stdout/stderr are inherited from the parent process
+ pass
+
+
+def dist_add_replicas(
+ replicas_type: str,
+ kubeconfig: str,
+ replicas: int,
+ coordinator_name: str,
+ namespace: str,
+ cpus: int,
+ gpus: int,
+ memory: str,
+) -> None:
+ assert coordinator_name and namespace, "Please provide --coordinator-name or --namespace"
+
+ import json
+ data = {
+ "namespace": namespace,
+ "coordinator": coordinator_name,
+ }
+ res = {"replicas": replicas}
+ if cpus > 0:
+ res['cpus'] = cpus
+ if gpus > 0:
+ res['gpus'] = gpus
+ if memory:
+ res['memory'] = memory
+ if replicas_type == 'collector':
+ data['collectors'] = res
+ elif replicas_type == 'learner':
+ data['learners'] = res
+ cmd = 'curl -X POST $KUBERNETES_SERVER_URL/v1alpha1/replicas ' \
+ '-H "content-type: application/json" ' \
+ f'-d \'{json.dumps(data)}\''
+ ret, msg = pod_exec_command(kubeconfig, coordinator_name, namespace, cmd)
+ if ret == 0:
+ print(f'{replicas_type} add successfully')
+ else:
+ print(f'Failed to add {replicas_type}, return code: {ret}, message: {msg}')
+
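+
+# Hedged payload example: adding 2 collectors with 4 CPUs each would POST a JSON body like
+#   {"namespace": "di-job", "coordinator": "xxx-coordinator", "collectors": {"replicas": 2, "cpus": 4}}
+# to $KUBERNETES_SERVER_URL/v1alpha1/replicas from inside the coordinator pod (the names are illustrative).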
+
+def dist_delete_replicas(
+ replicas_type: str, kubeconfig: str, replicas: int, coordinator_name: str, namespace: str
+) -> None:
+ assert coordinator_name and namespace, "Please provide --coordinator-name or --namespace"
+
+ import json
+ data = {
+ "namespace": namespace,
+ "coordinator": coordinator_name,
+ }
+ if replicas_type == 'collector':
+ data['collectors'] = {"replicas": replicas}
+ elif replicas_type == 'learner':
+ data['learners'] = {"replicas": replicas}
+ cmd = 'curl -X DELETE $KUBERNETES_SERVER_URL/v1alpha1/replicas ' \
+ '-H "content-type: application/json" ' \
+ f'-d \'{json.dumps(data)}\''
+ ret, msg = pod_exec_command(kubeconfig, coordinator_name, namespace, cmd)
+ if ret == 0:
+ print(f'{replicas_type} delete successfully')
+ else:
+ print(f'Failed to delete {replicas_type}, return code: {ret}, message: {msg}')
+
+
+def dist_restart_replicas(
+ replicas_type: str, kubeconfig: str, coordinator_name: str, namespace: str, restart_pod_name: str
+) -> None:
+ assert coordinator_name and namespace, "Please provide --coordinator-name or --namespace"
+
+ import json
+ data = {
+ "namespace": namespace,
+ "coordinator": coordinator_name,
+ }
+ assert restart_pod_name, "Please provide restart pod name with --restart-pod-name"
+ if replicas_type == 'collector':
+ data['collectors'] = [restart_pod_name]
+ elif replicas_type == 'learner':
+ data['learners'] = [restart_pod_name]
+ cmd = 'curl -X POST $KUBERNETES_SERVER_URL/v1alpha1/replicas/failed ' \
+ '-H "content-type: application/json" ' \
+ f'-d \'{json.dumps(data)}\''
+ ret, msg = pod_exec_command(kubeconfig, coordinator_name, namespace, cmd)
+ if ret == 0:
+        print(f'{replicas_type} restarted successfully')
+ else:
+ print(f'Failed to restart {replicas_type}, return code: {ret}, message: {msg}')
diff --git a/DI-engine/ding/entry/parallel_entry.py b/DI-engine/ding/entry/parallel_entry.py
new file mode 100644
index 0000000000000000000000000000000000000000..634a8f0ac9723549c15aef030b16ab8eaa008004
--- /dev/null
+++ b/DI-engine/ding/entry/parallel_entry.py
@@ -0,0 +1,151 @@
+from typing import Optional, Union, Tuple
+import time
+import pickle
+from ditk import logging
+from multiprocessing import Process, Event
+import threading
+from easydict import EasyDict
+
+from ding.worker import create_comm_learner, create_comm_collector, Coordinator
+from ding.config import read_config_with_system, compile_config_parallel
+from ding.utils import set_pkg_seed
+
+
+def parallel_pipeline(
+ input_cfg: Union[str, Tuple[dict, dict, dict]],
+ seed: int,
+ enable_total_log: Optional[bool] = False,
+ disable_flask_log: Optional[bool] = True,
+) -> None:
+ r"""
+ Overview:
+ Parallel pipeline entry.
+ Arguments:
+        - input_cfg (:obj:`Union[str, Tuple[dict, dict, dict]]`): Config file path, \
+            or a tuple of [main_config, create_config, system_config].
+        - seed (:obj:`int`): Random seed.
+        - enable_total_log (:obj:`Optional[bool]`): Whether to enable the total DI-engine system log.
+        - disable_flask_log (:obj:`Optional[bool]`): Whether to disable the Flask (werkzeug) log.
+ """
+ # Disable some part of DI-engine log
+ if not enable_total_log:
+ coordinator_log = logging.getLogger('coordinator_logger')
+ coordinator_log.disabled = True
+ # Disable flask logger.
+ if disable_flask_log:
+ log = logging.getLogger('werkzeug')
+ log.disabled = True
+ # Parallel job launch.
+ if isinstance(input_cfg, str):
+ main_cfg, create_cfg, system_cfg = read_config_with_system(input_cfg)
+ elif isinstance(input_cfg, tuple) or isinstance(input_cfg, list):
+ main_cfg, create_cfg, system_cfg = input_cfg
+ else:
+ raise TypeError("invalid config type: {}".format(input_cfg))
+ config = compile_config_parallel(main_cfg, create_cfg=create_cfg, system_cfg=system_cfg, seed=seed)
+ learner_handle = []
+ collector_handle = []
+ for k, v in config.system.items():
+ if 'learner' in k:
+ learner_handle.append(launch_learner(config.seed, v))
+ elif 'collector' in k:
+ collector_handle.append(launch_collector(config.seed, v))
+ launch_coordinator(config.seed, config, learner_handle=learner_handle, collector_handle=collector_handle)
+
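+# A minimal usage sketch (the config path below is hypothetical): ``parallel_pipeline`` accepts either a config
+# file path or a (main_cfg, create_cfg, system_cfg) tuple, plus a random seed:
+#
+#     from ding.entry.parallel_entry import parallel_pipeline
+#     parallel_pipeline('path/to/your_parallel_config.py', seed=0)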
+
+# The following functions are used to launch different components (learner, learner aggregator, collector,
+# coordinator). Argument ``config`` is the dict-type config. If it is None, then ``filename`` and ``name`` must be
+# passed, so that the corresponding config can be read from file.
+def run_learner(config, seed, start_learner_event, close_learner_event):
+ set_pkg_seed(seed)
+ log = logging.getLogger('werkzeug')
+ log.disabled = True
+ learner = create_comm_learner(config)
+ learner.start()
+ start_learner_event.set()
+ close_learner_event.wait()
+ learner.close()
+
+
+def launch_learner(
+ seed: int, config: Optional[dict] = None, filename: Optional[str] = None, name: Optional[str] = None
+) -> tuple:
+ if config is None:
+ with open(filename, 'rb') as f:
+ config = pickle.load(f)[name]
+ start_learner_event = Event()
+ close_learner_event = Event()
+
+    learner_process = Process(
+        target=run_learner, args=(config, seed, start_learner_event, close_learner_event), name='learner_entry_process'
+    )
+    learner_process.start()
+    return learner_process, start_learner_event, close_learner_event
+
+
+def run_collector(config, seed, start_collector_event, close_collector_event):
+ set_pkg_seed(seed)
+ log = logging.getLogger('werkzeug')
+ log.disabled = True
+ collector = create_comm_collector(config)
+ collector.start()
+ start_collector_event.set()
+ close_collector_event.wait()
+ collector.close()
+
+
+def launch_collector(
+ seed: int, config: Optional[dict] = None, filename: Optional[str] = None, name: Optional[str] = None
+) -> tuple:
+ if config is None:
+ with open(filename, 'rb') as f:
+ config = pickle.load(f)[name]
+ start_collector_event = Event()
+ close_collector_event = Event()
+
+    collector_process = Process(
+        target=run_collector,
+        args=(config, seed, start_collector_event, close_collector_event),
+        name='collector_entry_process'
+    )
+    collector_process.start()
+    return collector_process, start_collector_event, close_collector_event
+
+
+def launch_coordinator(
+ seed: int,
+ config: Optional[EasyDict] = None,
+ filename: Optional[str] = None,
+ learner_handle: Optional[list] = None,
+ collector_handle: Optional[list] = None
+) -> None:
+ set_pkg_seed(seed)
+ if config is None:
+ with open(filename, 'rb') as f:
+ config = pickle.load(f)
+ coordinator = Coordinator(config)
+ for _, start_event, _ in learner_handle:
+ start_event.wait()
+ for _, start_event, _ in collector_handle:
+ start_event.wait()
+ coordinator.start()
+ system_shutdown_event = threading.Event()
+
+    # Monitor thread: the coordinator keeps running until its ``system_shutdown_flag`` is set to True.
+ def shutdown_monitor():
+ while True:
+ time.sleep(3)
+ if coordinator.system_shutdown_flag:
+ coordinator.close()
+ for _, _, close_event in learner_handle:
+ close_event.set()
+ for _, _, close_event in collector_handle:
+ close_event.set()
+ system_shutdown_event.set()
+ break
+
+ shutdown_monitor_thread = threading.Thread(target=shutdown_monitor, args=(), daemon=True, name='shutdown_monitor')
+ shutdown_monitor_thread.start()
+ system_shutdown_event.wait()
+ print(
+        "[DI-engine parallel pipeline] Your RL agent has converged, see 'log' and 'tensorboard' for details."
+ )
diff --git a/DI-engine/ding/entry/predefined_config.py b/DI-engine/ding/entry/predefined_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..213ca4fcd4100a45ec0ebf9d2f1578fdd762854b
--- /dev/null
+++ b/DI-engine/ding/entry/predefined_config.py
@@ -0,0 +1,35 @@
+from typing import Tuple
+from easydict import EasyDict
+import sys
+import importlib
+
+env_dict = {
+ 'cartpole': 'dizoo.classic_control.cartpole.config',
+ 'pendulum': 'dizoo.classic_control.pendulum.config',
+}
+policy_dict = {
+ 'dqn': 'ding.policy.dqn',
+ 'rainbow': 'ding.policy.rainbow',
+ 'c51': 'ding.policy.c51',
+ 'qrdqn': 'ding.policy.qrdqn',
+ 'iqn': 'ding.policy.iqn',
+ 'a2c': 'ding.policy.a2c',
+ 'impala': 'ding.policy.impala',
+ 'ppo': 'ding.policy.ppo',
+ 'sqn': 'ding.policy.sqn',
+ 'r2d2': 'ding.policy.r2d2',
+ 'ddpg': 'ding.policy.ddpg',
+ 'td3': 'ding.policy.td3',
+ 'sac': 'ding.policy.sac',
+}
+
+
+def get_predefined_config(env: str, policy: str) -> Tuple[EasyDict, EasyDict]:
+ config_name = '{}_{}_config'.format(env, policy)
+ create_config_name = '{}_{}_create_config'.format(env, policy)
+ try:
+ m = importlib.import_module(env_dict[env] + '.' + config_name)
+        return getattr(m, config_name), getattr(m, create_config_name)
+    except (KeyError, ImportError):
+        print("No pre-defined config ({}), please try another env/policy combination".format(config_name))
+        sys.exit(1)
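+
+
+# Example usage (assuming the corresponding dizoo config module is installed; names follow the
+# '{env}_{policy}_config' convention used above):
+#
+#     main_config, create_config = get_predefined_config('cartpole', 'dqn')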
diff --git a/DI-engine/ding/entry/serial_entry.py b/DI-engine/ding/entry/serial_entry.py
new file mode 100644
index 0000000000000000000000000000000000000000..929c83a219cd7a65d1b67bbea53d267463054994
--- /dev/null
+++ b/DI-engine/ding/entry/serial_entry.py
@@ -0,0 +1,137 @@
+from typing import Union, Optional, List, Any, Tuple
+import os
+import torch
+from ditk import logging
+from functools import partial
+from tensorboardX import SummaryWriter
+from copy import deepcopy
+
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
+ create_serial_collector, create_serial_evaluator
+from ding.config import read_config, compile_config
+from ding.policy import create_policy
+from ding.utils import set_pkg_seed, get_rank
+from .utils import random_collect
+
+
+def serial_pipeline(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ seed: int = 0,
+ env_setting: Optional[List[Any]] = None,
+ model: Optional[torch.nn.Module] = None,
+ max_train_iter: Optional[int] = int(1e10),
+ max_env_step: Optional[int] = int(1e10),
+ dynamic_seed: Optional[bool] = True,
+) -> 'Policy': # noqa
+ """
+ Overview:
+ Serial pipeline entry for off-policy RL.
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
+ ``BaseEnv`` subclass, collector env config, and evaluator env config.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+ - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+        - dynamic_seed (:obj:`Optional[bool]`): Whether to use a dynamic seed for collector environments.
+ Returns:
+ - policy (:obj:`Policy`): Converged policy.
+ """
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ else:
+ cfg, create_cfg = deepcopy(input_cfg)
+ create_cfg.policy.type = create_cfg.policy.type + '_command'
+ env_fn = None if env_setting is None else env_setting[0]
+ cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
+ # Create main components: env, policy
+ if env_setting is None:
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ else:
+ env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+ collector_env.seed(cfg.seed, dynamic_seed=dynamic_seed)
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
+
+ # Create worker components: learner, collector, evaluator, replay buffer, commander.
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = create_serial_collector(
+ cfg.policy.collect.collector,
+ env=collector_env,
+ policy=policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name
+ )
+ evaluator = create_serial_evaluator(
+ cfg.policy.eval.evaluator,
+ env=evaluator_env,
+ policy=policy.eval_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name
+ )
+ replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
+ commander = BaseSerialCommander(
+ cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
+ )
+ # ==========
+ # Main loop
+ # ==========
+ # Learner's before_run hook.
+ learner.call_hook('before_run')
+
+ # Accumulate plenty of data at the beginning of training.
+ if cfg.policy.get('random_collect_size', 0) > 0:
+ random_collect(cfg.policy, policy, collector, collector_env, commander, replay_buffer)
+ while True:
+ collect_kwargs = commander.step()
+ # Evaluate policy performance
+ if evaluator.should_eval(learner.train_iter):
+ stop, eval_info = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Collect data by default config n_sample/n_episode
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ # Learn policy from collected data
+ for i in range(cfg.policy.learn.update_per_collect):
+ # Learner will train ``update_per_collect`` times in one iteration.
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is None:
+ # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
+ logging.warning(
+ "Replay buffer's data can only train for {} steps. ".format(i) +
+ "You can modify data collect config, e.g. increasing n_sample, n_episode."
+ )
+ break
+ learner.train(train_data, collector.envstep)
+ if learner.policy.get_attribute('priority'):
+ replay_buffer.update(learner.priority_info)
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+
+ # Learner's after_run hook.
+ learner.call_hook('after_run')
+ if get_rank() == 0:
+ import time
+ import pickle
+ import numpy as np
+ with open(os.path.join(cfg.exp_name, 'result.pkl'), 'wb') as f:
+ eval_value_raw = eval_info['eval_episode_return']
+ final_data = {
+ 'stop': stop,
+ 'env_step': collector.envstep,
+ 'train_iter': learner.train_iter,
+ 'eval_value': np.mean(eval_value_raw),
+ 'eval_value_raw': eval_value_raw,
+ 'finish_time': time.ctime(),
+ }
+ pickle.dump(final_data, f)
+ return policy
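+
+
+# A minimal usage sketch. The dizoo module and config names below are assumptions that follow the
+# '{env}_{policy}_config' convention (see ding/entry/predefined_config.py):
+#
+#     from dizoo.classic_control.cartpole.config.cartpole_dqn_config import \
+#         cartpole_dqn_config, cartpole_dqn_create_config
+#     policy = serial_pipeline((cartpole_dqn_config, cartpole_dqn_create_config), seed=0, max_env_step=int(1e5))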
diff --git a/DI-engine/ding/entry/serial_entry_bc.py b/DI-engine/ding/entry/serial_entry_bc.py
new file mode 100644
index 0000000000000000000000000000000000000000..152f9f8470cae9f618fec2d9d2738006fe745398
--- /dev/null
+++ b/DI-engine/ding/entry/serial_entry_bc.py
@@ -0,0 +1,102 @@
+from typing import Union, Optional, Tuple
+import os
+import torch
+from functools import partial
+from tensorboardX import SummaryWriter
+from copy import deepcopy
+from torch.utils.data import DataLoader
+
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.worker import BaseLearner, InteractionSerialEvaluator
+from ding.config import read_config, compile_config
+from ding.policy import create_policy
+from ding.utils import set_pkg_seed
+from ding.utils.data import NaiveRLDataset
+
+
+def serial_pipeline_bc(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ seed: int,
+ data_path: str,
+ model: Optional[torch.nn.Module] = None,
+        max_iter: int = int(1e6),
+) -> Tuple['Policy', bool]:  # noqa
+ r"""
+ Overview:
+ Serial pipeline entry of imitation learning.
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - data_path (:obj:`str`): Path of training data.
+        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+        - max_iter (:obj:`int`): Maximum number of training iterations.
+    Returns:
+        - policy (:obj:`Policy`): Converged policy.
+        - convergence (:obj:`bool`): Whether the imitation learning training has converged.
+ """
+    if isinstance(input_cfg, str):
+        cfg, create_cfg = read_config(input_cfg)
+    else:
+        cfg, create_cfg = deepcopy(input_cfg)
+    cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg)
+    # Whether the action space is continuous; read after the config is compiled, so that
+    # ``input_cfg`` may also be a config file path.
+    cont = cfg.policy.continuous
+
+ # Env, Policy
+ env_fn, _, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+ # Random seed
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'eval'])
+
+ # Main components
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ dataset = NaiveRLDataset(data_path)
+ dataloader = DataLoader(dataset[:-len(dataset) // 10], cfg.policy.learn.batch_size, collate_fn=lambda x: x)
+ eval_loader = DataLoader(
+ dataset[-len(dataset) // 10:],
+ cfg.policy.learn.batch_size,
+ )
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ # ==========
+ # Main loop
+ # ==========
+ learner.call_hook('before_run')
+    stop = False
+    reward = None  # in case no environment evaluation is triggered before ``max_iter`` is reached
+    iter_cnt = 0
+ for epoch in range(cfg.policy.learn.train_epoch):
+        # Validate on the held-out split of the dataset
+ loss_list = []
+ for _, bat in enumerate(eval_loader):
+ res = policy._forward_eval(bat['obs'])
+ if cont:
+ loss_list.append(torch.nn.L1Loss()(res['action'], bat['action'].squeeze(-1)).item())
+ else:
+ res = torch.argmax(res['logit'], dim=1)
+ loss_list.append(torch.sum(res == bat['action'].squeeze(-1)).item() / bat['action'].shape[0])
+ if cont:
+ label = 'validation_loss'
+ else:
+ label = 'validation_acc'
+ tb_logger.add_scalar(label, sum(loss_list) / len(loss_list), iter_cnt)
+ for i, train_data in enumerate(dataloader):
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter)
+ if stop:
+ break
+ learner.train(train_data)
+ iter_cnt += 1
+ if iter_cnt >= max_iter:
+ stop = True
+ break
+ if stop:
+ break
+
+ learner.call_hook('after_run')
+    print('Final reward: {}'.format(reward))
+ return policy, stop
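+
+
+# Sketch of a call (the config pair and data path are hypothetical). ``data_path`` should point to a pickled
+# transition file readable by ``NaiveRLDataset``, and the policy config must define ``policy.continuous``:
+#
+#     policy, converged = serial_pipeline_bc((main_cfg, create_cfg), seed=0, data_path='./expert_data.pkl')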
diff --git a/DI-engine/ding/entry/serial_entry_bco.py b/DI-engine/ding/entry/serial_entry_bco.py
new file mode 100644
index 0000000000000000000000000000000000000000..756cfa9f6131343813f1d44dfd45350c706c7e17
--- /dev/null
+++ b/DI-engine/ding/entry/serial_entry_bco.py
@@ -0,0 +1,199 @@
+import os
+import pickle
+import torch
+from functools import partial
+from tensorboardX import SummaryWriter
+from torch.utils.data import DataLoader
+from typing import Union, Optional, List, Any, Tuple, Dict
+
+from ding.worker import BaseLearner, BaseSerialCommander, InteractionSerialEvaluator, create_serial_collector
+from ding.config import read_config, compile_config
+from ding.utils import set_pkg_seed
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.policy.common_utils import default_preprocess_learn
+from ding.policy import create_policy
+from ding.utils.data.dataset import BCODataset
+from ding.world_model.idm import InverseDynamicsModel
+
+
+def load_expertdata(data: Dict[str, torch.Tensor]) -> BCODataset:
+ """
+    Load demonstration data, which only provides obs and next_obs as features;
+    the action used for training needs to be inferred by the Inverse Dynamics Model (IDM).
+ """
+ post_data = list()
+ for episode in range(len(data)):
+ for transition in data[episode]:
+ transition['episode_id'] = episode
+ post_data.append(transition)
+ post_data = default_preprocess_learn(post_data)
+ return BCODataset(
+ {
+ 'obs': torch.cat((post_data['obs'], post_data['next_obs']), 1),
+ 'episode_id': post_data['episode_id'],
+ 'action': post_data['action']
+ }
+ )
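+
+
+# Illustrative input layout for ``load_expertdata`` (and ``load_agentdata`` below); keys follow the code above,
+# shapes are hypothetical: ``data[episode_idx]`` is a list of transition dicts such as
+#
+#     {'obs': torch.rand(4), 'next_obs': torch.rand(4), 'action': torch.tensor([1]), ...}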
+
+
+def load_agentdata(data) -> BCODataset:
+ """
+    Load agent (policy) data, which has obs and next_obs as features and action as label.
+ """
+ post_data = list()
+ for episode in range(len(data)):
+ for transition in data[episode]:
+ transition['episode_id'] = episode
+ post_data.append(transition)
+ post_data = default_preprocess_learn(post_data)
+ return BCODataset(
+ {
+ 'obs': torch.cat((post_data['obs'], post_data['next_obs']), 1),
+ 'action': post_data['action'],
+ 'episode_id': post_data['episode_id']
+ }
+ )
+
+
+def serial_pipeline_bco(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ expert_cfg: Union[str, Tuple[dict, dict]],
+ seed: int = 0,
+ env_setting: Optional[List[Any]] = None,
+ model: Optional[torch.nn.Module] = None,
+ expert_model: Optional[torch.nn.Module] = None,
+ max_train_iter: Optional[int] = int(1e10),
+ max_env_step: Optional[int] = int(1e10),
+) -> None:
+
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ expert_cfg, expert_create_cfg = read_config(expert_cfg)
+ else:
+ cfg, create_cfg = input_cfg
+ expert_cfg, expert_create_cfg = expert_cfg
+ create_cfg.policy.type = create_cfg.policy.type + '_command'
+ expert_create_cfg.policy.type = expert_create_cfg.policy.type + '_command'
+ env_fn = None if env_setting is None else env_setting[0]
+ cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
+ expert_cfg = compile_config(
+ expert_cfg, seed=seed, env=env_fn, auto=True, create_cfg=expert_create_cfg, save_cfg=True
+ )
+ # Random seed
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+ # Create main components: env, policy
+ if env_setting is None:
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ else:
+ env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
+
+ # Generate Expert Data
+ if cfg.policy.collect.model_path is None:
+ with open(cfg.policy.collect.data_path, 'rb') as f:
+ data = pickle.load(f)
+ expert_learn_dataset = load_expertdata(data)
+ else:
+ expert_policy = create_policy(expert_cfg.policy, model=expert_model, enable_field=['collect'])
+ expert_collector_env = create_env_manager(
+ expert_cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]
+ )
+ expert_collector_env.seed(expert_cfg.seed)
+ expert_policy.collect_mode.load_state_dict(torch.load(cfg.policy.collect.model_path, map_location='cpu'))
+
+ expert_collector = create_serial_collector(
+ cfg.policy.collect.collector, # for episode collector
+ env=expert_collector_env,
+ policy=expert_policy.collect_mode,
+ exp_name=expert_cfg.exp_name
+ )
+        # If the expert policy is SAC (continuous action space), the eps kwarg is not expected.
+ if cfg.policy.continuous:
+ expert_data = expert_collector.collect(n_episode=100)
+ else:
+ policy_kwargs = {'eps': 0}
+ expert_data = expert_collector.collect(n_episode=100, policy_kwargs=policy_kwargs)
+ expert_learn_dataset = load_expertdata(expert_data)
+ expert_collector.reset_policy(expert_policy.collect_mode)
+
+ # Main components
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+ collector_env.seed(cfg.seed)
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ collector = create_serial_collector(
+ cfg.policy.collect.collector,
+ env=collector_env,
+ policy=policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ commander = BaseSerialCommander(
+ cfg.policy.other.commander, learner, collector, evaluator, None, policy=policy.command_mode
+ )
+ learned_model = InverseDynamicsModel(
+ cfg.policy.model.obs_shape, cfg.policy.model.action_shape, cfg.bco.model.idm_encoder_hidden_size_list,
+ cfg.bco.model.action_space
+ )
+ # ==========
+ # Main loop
+ # ==========
+ learner.call_hook('before_run')
+ collect_episode = int(cfg.policy.collect.n_episode * cfg.bco.alpha)
+ init_episode = True
+ while True:
+ collect_kwargs = commander.step()
+ # Evaluate policy performance
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+
+ if init_episode:
+ new_data = collector.collect(
+ n_episode=cfg.policy.collect.n_episode, train_iter=learner.train_iter, policy_kwargs=collect_kwargs
+ )
+ init_episode = False
+ else:
+ new_data = collector.collect(
+ n_episode=collect_episode, train_iter=learner.train_iter, policy_kwargs=collect_kwargs
+ )
+ learn_dataset = load_agentdata(new_data)
+ learn_dataloader = DataLoader(learn_dataset, cfg.bco.learn.idm_batch_size)
+ for i, train_data in enumerate(learn_dataloader):
+ idm_loss = learned_model.train(
+ train_data,
+ cfg.bco.learn.idm_train_epoch,
+ cfg.bco.learn.idm_learning_rate,
+ cfg.bco.learn.idm_weight_decay,
+ )
+ # tb_logger.add_scalar("learner_iter/idm_loss", idm_loss, learner.train_iter)
+ # tb_logger.add_scalar("learner_step/idm_loss", idm_loss, collector.envstep)
+ # Generate state transitions from demonstrated state trajectories by IDM
+ expert_action_data = learned_model.predict_action(expert_learn_dataset.obs)['action']
+ post_expert_dataset = BCODataset(
+ {
+ # next_obs are deleted
+ 'obs': expert_learn_dataset.obs[:, 0:int(expert_learn_dataset.obs.shape[1] // 2)],
+ 'action': expert_action_data,
+ 'expert_action': expert_learn_dataset.action
+ }
+ ) # post_expert_dataset: Only obs and action are reserved for BC. next_obs are deleted
+ expert_learn_dataloader = DataLoader(post_expert_dataset, cfg.policy.learn.batch_size)
+ # Improve policy using BC
+ for epoch in range(cfg.policy.learn.train_epoch):
+ for i, train_data in enumerate(expert_learn_dataloader):
+ learner.train(train_data, collector.envstep)
+ if cfg.policy.learn.lr_decay:
+ learner.policy.get_attribute('lr_scheduler').step()
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+
+ # Learner's after_run hook.
+ learner.call_hook('after_run')
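+
+
+# Sketch of a call (hypothetical config pairs). Either ``cfg.policy.collect.data_path`` (offline demonstrations)
+# or ``cfg.policy.collect.model_path`` (expert checkpoint for online collection) must be set:
+#
+#     serial_pipeline_bco((main_cfg, create_cfg), (expert_cfg, expert_create_cfg), seed=0)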
diff --git a/DI-engine/ding/entry/serial_entry_dqfd.py b/DI-engine/ding/entry/serial_entry_dqfd.py
new file mode 100644
index 0000000000000000000000000000000000000000..99e121ed5bf9507dbfd31f046aca437c3f8aabb2
--- /dev/null
+++ b/DI-engine/ding/entry/serial_entry_dqfd.py
@@ -0,0 +1,217 @@
+from typing import Union, Optional, List, Any, Tuple
+import os
+import torch
+import numpy as np
+from ditk import logging
+from functools import partial
+from tensorboardX import SummaryWriter
+from copy import deepcopy
+
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
+ create_serial_collector
+from ding.config import read_config, compile_config
+from ding.policy import create_policy
+from ding.utils import set_pkg_seed
+from .utils import random_collect, mark_not_expert
+
+
+def serial_pipeline_dqfd(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ expert_cfg: Union[str, Tuple[dict, dict]],
+ seed: int = 0,
+ env_setting: Optional[List[Any]] = None,
+ model: Optional[torch.nn.Module] = None,
+ expert_model: Optional[torch.nn.Module] = None,
+ max_train_iter: Optional[int] = int(1e10),
+ max_env_step: Optional[int] = int(1e10),
+) -> 'Policy': # noqa
+ """
+ Overview:
+        Serial pipeline DQfD entry: this serial pipeline implements DQfD in DI-engine. \
+        For now, the following envs are supported: CartPole, LunarLander, Pong, SpaceInvaders. \
+        The demonstration data comes from an expert model: a well-trained model is used to \
+        generate demonstration data online.
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
+ ``BaseEnv`` subclass, collector env config, and evaluator env config.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - expert_model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.\
+ The default model is DQN(**cfg.policy.model)
+ - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+ - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+ Returns:
+ - policy (:obj:`Policy`): Converged policy.
+ """
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ expert_cfg, expert_create_cfg = read_config(expert_cfg)
+ else:
+ cfg, create_cfg = deepcopy(input_cfg)
+ expert_cfg, expert_create_cfg = expert_cfg
+ create_cfg.policy.type = create_cfg.policy.type + '_command'
+ expert_create_cfg.policy.type = expert_create_cfg.policy.type + '_command'
+ env_fn = None if env_setting is None else env_setting[0]
+ cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
+ expert_cfg = compile_config(
+ expert_cfg, seed=seed, env=env_fn, auto=True, create_cfg=expert_create_cfg, save_cfg=True
+ )
+ # Create main components: env, policy
+ if env_setting is None:
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ else:
+ env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ expert_collector_env = create_env_manager(
+ expert_cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]
+ )
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+ expert_collector_env.seed(cfg.seed)
+ collector_env.seed(cfg.seed)
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ expert_policy = create_policy(expert_cfg.policy, model=expert_model, enable_field=['collect', 'command'])
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
+ expert_policy.collect_mode.load_state_dict(torch.load(cfg.policy.collect.model_path, map_location='cpu'))
+ # Create worker components: learner, collector, evaluator, replay buffer, commander.
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = create_serial_collector(
+ cfg.policy.collect.collector,
+ env=collector_env,
+ policy=policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name
+ )
+ expert_collector = create_serial_collector(
+ expert_cfg.policy.collect.collector,
+ env=expert_collector_env,
+ policy=expert_policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=expert_cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
+ commander = BaseSerialCommander(
+ cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
+ )
+ expert_commander = BaseSerialCommander(
+ expert_cfg.policy.other.commander, learner, expert_collector, evaluator, replay_buffer,
+ expert_policy.command_mode
+    )  # Created to avoid the eps issue caused by the sample collector.
+ expert_collect_kwargs = expert_commander.step()
+ if 'eps' in expert_collect_kwargs:
+ expert_collect_kwargs['eps'] = -1
+ # ==========
+ # Main loop
+ # ==========
+ # Learner's before_run hook.
+ learner.call_hook('before_run')
+ if cfg.policy.learn.expert_replay_buffer_size != 0: # for ablation study
+ dummy_variable = deepcopy(cfg.policy.other.replay_buffer)
+ dummy_variable['replay_buffer_size'] = cfg.policy.learn.expert_replay_buffer_size
+ expert_buffer = create_buffer(dummy_variable, tb_logger=tb_logger, exp_name=cfg.exp_name)
+ expert_data = expert_collector.collect(
+ n_sample=cfg.policy.learn.expert_replay_buffer_size, policy_kwargs=expert_collect_kwargs
+ )
+ for i in range(len(expert_data)):
+ expert_data[i]['is_expert'] = 1 # set is_expert flag(expert 1, agent 0)
+ expert_buffer.push(expert_data, cur_collector_envstep=0)
+ for _ in range(cfg.policy.learn.per_train_iter_k): # pretrain
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Learn policy from collected data
+ # Expert_learner will train ``update_per_collect == 1`` times in one iteration.
+ train_data = expert_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ learner.train(train_data, collector.envstep)
+ if learner.policy.get_attribute('priority'):
+ expert_buffer.update(learner.priority_info)
+ learner.priority_info = {}
+ # Accumulate plenty of data at the beginning of training.
+ if cfg.policy.get('random_collect_size', 0) > 0:
+ random_collect(
+ cfg.policy, policy, collector, collector_env, commander, replay_buffer, postprocess_data_fn=mark_not_expert
+ )
+ while True:
+ collect_kwargs = commander.step()
+ # Evaluate policy performance
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Collect data by default config n_sample/n_episode
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+ for i in range(len(new_data)):
+ new_data[i]['is_expert'] = 0 # set is_expert flag(expert 1, agent 0)
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ # Learn policy from collected data
+ for i in range(cfg.policy.learn.update_per_collect):
+ if cfg.policy.learn.expert_replay_buffer_size != 0:
+ # Learner will train ``update_per_collect`` times in one iteration.
+                # The hyperparameter pho, the demo ratio, controls the proportion of data coming
+                # from expert demonstrations versus from the agent's own experience.
+ stats = np.random.choice(
+ (learner.policy.get_attribute('batch_size')), size=(learner.policy.get_attribute('batch_size'))
+ ) < (
+ learner.policy.get_attribute('batch_size')
+ ) * cfg.policy.collect.pho # torch.rand((learner.policy.get_attribute('batch_size')))\
+ # < cfg.policy.collect.pho
+ expert_batch_size = stats[stats].shape[0]
+ demo_batch_size = (learner.policy.get_attribute('batch_size')) - expert_batch_size
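+                # e.g. batch_size=64 and pho=0.25 give roughly 16 expert samples and 48 agent samples per batch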
+ train_data = replay_buffer.sample(demo_batch_size, learner.train_iter)
+ train_data_demonstration = expert_buffer.sample(expert_batch_size, learner.train_iter)
+ if train_data is None:
+ # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
+ logging.warning(
+ "Replay buffer's data can only train for {} steps. ".format(i) +
+ "You can modify data collect config, e.g. increasing n_sample, n_episode."
+ )
+ break
+ train_data = train_data + train_data_demonstration
+ learner.train(train_data, collector.envstep)
+ if learner.policy.get_attribute('priority'):
+                    # When collecting, set replay_buffer_idx and replay_unique_id for each data item, priority = 1.
+                    # When learning, assign a priority to each data item according to its loss.
+ learner.priority_info_agent = deepcopy(learner.priority_info)
+ learner.priority_info_expert = deepcopy(learner.priority_info)
+ learner.priority_info_agent['priority'] = learner.priority_info['priority'][0:demo_batch_size]
+ learner.priority_info_agent['replay_buffer_idx'] = learner.priority_info['replay_buffer_idx'][
+ 0:demo_batch_size]
+ learner.priority_info_agent['replay_unique_id'] = learner.priority_info['replay_unique_id'][
+ 0:demo_batch_size]
+ learner.priority_info_expert['priority'] = learner.priority_info['priority'][demo_batch_size:]
+ learner.priority_info_expert['replay_buffer_idx'] = learner.priority_info['replay_buffer_idx'][
+ demo_batch_size:]
+ learner.priority_info_expert['replay_unique_id'] = learner.priority_info['replay_unique_id'][
+ demo_batch_size:]
+                    # Agent data and expert demonstration data update their priorities separately.
+ replay_buffer.update(learner.priority_info_agent)
+ expert_buffer.update(learner.priority_info_expert)
+ else:
+ # Learner will train ``update_per_collect`` times in one iteration.
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is None:
+ # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
+ logging.warning(
+ "Replay buffer's data can only train for {} steps. ".format(i) +
+ "You can modify data collect config, e.g. increasing n_sample, n_episode."
+ )
+ break
+ learner.train(train_data, collector.envstep)
+ if learner.policy.get_attribute('priority'):
+ replay_buffer.update(learner.priority_info)
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+
+ # Learner's after_run hook.
+ learner.call_hook('after_run')
+ return policy
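+
+
+# Sketch of a call (hypothetical configs). ``cfg.policy.collect.model_path`` must point to a trained expert
+# checkpoint, which is used to generate demonstration data online:
+#
+#     serial_pipeline_dqfd((main_cfg, create_cfg), (expert_cfg, expert_create_cfg), seed=0)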
diff --git a/DI-engine/ding/entry/serial_entry_gail.py b/DI-engine/ding/entry/serial_entry_gail.py
new file mode 100644
index 0000000000000000000000000000000000000000..4060291fac298cacad2f3e77d9aefe66f2ab7e2a
--- /dev/null
+++ b/DI-engine/ding/entry/serial_entry_gail.py
@@ -0,0 +1,170 @@
+from typing import Optional, Tuple
+import os
+import torch
+from ditk import logging
+from functools import partial
+from tensorboardX import SummaryWriter
+from copy import deepcopy
+import numpy as np
+
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
+ create_serial_collector
+from ding.config import read_config, compile_config
+from ding.policy import create_policy
+from ding.reward_model import create_reward_model
+from ding.utils import set_pkg_seed
+from ding.entry import collect_demo_data
+from ding.utils import save_file
+from .utils import random_collect
+
+
+def save_reward_model(path, reward_model, weights_name='best'):
+ path = os.path.join(path, 'reward_model', 'ckpt')
+ if not os.path.exists(path):
+ try:
+ os.makedirs(path)
+ except FileExistsError:
+ pass
+ path = os.path.join(path, 'ckpt_{}.pth.tar'.format(weights_name))
+ state_dict = reward_model.state_dict()
+ save_file(path, state_dict)
+ print('Saved reward model ckpt in {}'.format(path))
+
+
+def serial_pipeline_gail(
+ input_cfg: Tuple[dict, dict],
+ expert_cfg: Tuple[dict, dict],
+ seed: int = 0,
+ model: Optional[torch.nn.Module] = None,
+ max_train_iter: Optional[int] = int(1e10),
+ max_env_step: Optional[int] = int(1e10),
+ collect_data: bool = True,
+) -> 'Policy': # noqa
+ """
+ Overview:
+ Serial pipeline entry for GAIL reward model.
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - expert_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Expert config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+ - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+ - collect_data (:obj:`bool`): Collect expert data.
+ Returns:
+ - policy (:obj:`Policy`): Converged policy.
+ """
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ else:
+ cfg, create_cfg = deepcopy(input_cfg)
+ if isinstance(expert_cfg, str):
+ expert_cfg, expert_create_cfg = read_config(expert_cfg)
+ else:
+ expert_cfg, expert_create_cfg = expert_cfg
+ create_cfg.policy.type = create_cfg.policy.type + '_command'
+ cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg, save_cfg=True)
+ if 'data_path' not in cfg.reward_model:
+ cfg.reward_model.data_path = cfg.exp_name
+ # Load expert data
+ if collect_data:
+ if expert_cfg.policy.get('other', None) is not None and expert_cfg.policy.other.get('eps', None) is not None:
+ expert_cfg.policy.other.eps.collect = -1
+ if expert_cfg.policy.get('load_path', None) is None:
+ expert_cfg.policy.load_path = cfg.reward_model.expert_model_path
+ collect_demo_data(
+ (expert_cfg, expert_create_cfg),
+ seed,
+ state_dict_path=expert_cfg.policy.load_path,
+ expert_data_path=cfg.reward_model.data_path + '/expert_data.pkl',
+ collect_count=cfg.reward_model.collect_count
+ )
+ # Create main components: env, policy
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+ collector_env.seed(cfg.seed)
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
+
+ # Create worker components: learner, collector, evaluator, replay buffer, commander.
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = create_serial_collector(
+ cfg.policy.collect.collector,
+ env=collector_env,
+ policy=policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
+ commander = BaseSerialCommander(
+ cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
+ )
+ reward_model = create_reward_model(cfg.reward_model, policy.collect_mode.get_attribute('device'), tb_logger)
+
+ # ==========
+ # Main loop
+ # ==========
+ # Learner's before_run hook.
+ learner.call_hook('before_run')
+
+ # Accumulate plenty of data at the beginning of training.
+ if cfg.policy.get('random_collect_size', 0) > 0:
+ random_collect(cfg.policy, policy, collector, collector_env, commander, replay_buffer)
+ best_reward = -np.inf
+ while True:
+ collect_kwargs = commander.step()
+ # Evaluate policy performance
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ reward_mean = np.array([r['eval_episode_return'] for r in reward]).mean()
+ if reward_mean >= best_reward:
+ save_reward_model(cfg.exp_name, reward_model, 'best')
+ best_reward = reward_mean
+ if stop:
+ break
+ new_data_count, target_new_data_count = 0, cfg.reward_model.get('target_new_data_count', 1)
+ while new_data_count < target_new_data_count:
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+ new_data_count += len(new_data)
+ # collect data for reward_model training
+ reward_model.collect_data(new_data)
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ # update reward_model
+ reward_model.train()
+ reward_model.clear_data()
+ # Learn policy from collected data
+ for i in range(cfg.policy.learn.update_per_collect):
+ # Learner will train ``update_per_collect`` times in one iteration.
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is None:
+ # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
+ logging.warning(
+ "Replay buffer's data can only train for {} steps. ".format(i) +
+ "You can modify data collect config, e.g. increasing n_sample, n_episode."
+ )
+ break
+ # update train_data reward using the augmented reward
+ train_data_augmented = reward_model.estimate(train_data)
+ learner.train(train_data_augmented, collector.envstep)
+ if learner.policy.get_attribute('priority'):
+ replay_buffer.update(learner.priority_info)
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+
+ # Learner's after_run hook.
+ learner.call_hook('after_run')
+ save_reward_model(cfg.exp_name, reward_model, 'last')
+ # evaluate
+ # evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ return policy
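+
+
+# Sketch of a call (hypothetical configs). With ``collect_data=True``, ``cfg.reward_model.expert_model_path``
+# provides the expert checkpoint used to collect demonstration data:
+#
+#     serial_pipeline_gail((main_cfg, create_cfg), (expert_cfg, expert_create_cfg), seed=0, collect_data=True)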
diff --git a/DI-engine/ding/entry/serial_entry_guided_cost.py b/DI-engine/ding/entry/serial_entry_guided_cost.py
new file mode 100644
index 0000000000000000000000000000000000000000..a66f4535a2eeac8f47fd8eb3acc1b49a4c98146b
--- /dev/null
+++ b/DI-engine/ding/entry/serial_entry_guided_cost.py
@@ -0,0 +1,162 @@
+from typing import Union, Optional, List, Any, Tuple
+import os
+import copy
+import torch
+from ditk import logging
+from functools import partial
+from tensorboardX import SummaryWriter
+from copy import deepcopy
+
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
+ create_serial_collector
+from ding.config import read_config, compile_config
+from ding.policy import create_policy
+from ding.reward_model import create_reward_model
+from ding.utils import set_pkg_seed, save_file
+from .utils import random_collect
+
+
+def serial_pipeline_guided_cost(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ seed: int = 0,
+ env_setting: Optional[List[Any]] = None,
+ model: Optional[torch.nn.Module] = None,
+ expert_model: Optional[torch.nn.Module] = None,
+ max_train_iter: Optional[int] = int(1e10),
+ max_env_step: Optional[int] = int(1e10),
+) -> 'Policy': # noqa
+ """
+ Overview:
+        Serial pipeline for guided cost learning: this serial pipeline implements guided cost \
+        learning in DI-engine. For now, the following envs are supported: CartPole, LunarLander, \
+        Hopper, HalfCheetah, Walker2d. The demonstration data comes from an expert model: \
+        a well-trained model is used to generate demonstration data online.
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
+ ``BaseEnv`` subclass, collector env config, and evaluator env config.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - expert_model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.\
+ The default model is DQN(**cfg.policy.model)
+ - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+ - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+ Returns:
+ - policy (:obj:`Policy`): Converged policy.
+ """
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ else:
+ cfg, create_cfg = deepcopy(input_cfg)
+ create_cfg.policy.type = create_cfg.policy.type + '_command'
+ env_fn = None if env_setting is None else env_setting[0]
+ cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
+ # Create main components: env, policy
+ if env_setting is None:
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ else:
+ env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ expert_collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+ expert_collector_env.seed(cfg.seed)
+ collector_env.seed(cfg.seed)
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ expert_policy = create_policy(cfg.policy, model=expert_model, enable_field=['learn', 'collect'])
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
+ expert_policy.collect_mode.load_state_dict(torch.load(cfg.policy.collect.model_path, map_location='cpu'))
+ # Create worker components: learner, collector, evaluator, replay buffer, commander.
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = create_serial_collector(
+ cfg.policy.collect.collector,
+ env=collector_env,
+ policy=policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name
+ )
+ expert_collector = create_serial_collector(
+ cfg.policy.collect.collector,
+ env=expert_collector_env,
+ policy=expert_policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
+ expert_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
+ commander = BaseSerialCommander(
+ cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
+ )
+
+ reward_model = create_reward_model(cfg.reward_model, policy.collect_mode.get_attribute('device'), tb_logger)
+ # ==========
+ # Main loop
+ # ==========
+ # Learner's before_run hook.
+ learner.call_hook('before_run')
+
+ # Accumulate plenty of data at the beginning of training.
+ if cfg.policy.get('random_collect_size', 0) > 0:
+ random_collect(cfg.policy, policy, collector, collector_env, commander, replay_buffer)
+ dirname = cfg.exp_name + '/reward_model'
+ if not os.path.exists(dirname):
+ try:
+ os.makedirs(dirname)
+ except FileExistsError:
+ pass
+ while True:
+ collect_kwargs = commander.step()
+ # Evaluate policy performance
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Collect data by default config n_sample/n_episode
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+        # NOTE: deepcopying the data is very important,
+        # otherwise the data in the replay buffer will be incorrectly modified.
+        # NOTE: this deepcopy cannot be moved further down, because the data may later be modified in-place.
+ train_data = copy.deepcopy(new_data)
+ expert_data = expert_collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ expert_buffer.push(expert_data, cur_collector_envstep=expert_collector.envstep)
+ # Learn policy from collected data
+ for i in range(cfg.reward_model.update_per_collect):
+ expert_demo = expert_buffer.sample(cfg.reward_model.batch_size, learner.train_iter)
+ samp = replay_buffer.sample(cfg.reward_model.batch_size, learner.train_iter)
+ reward_model.train(expert_demo, samp, learner.train_iter, collector.envstep)
+ for i in range(cfg.policy.learn.update_per_collect):
+ # Learner will train ``update_per_collect`` times in one iteration.
+ _ = reward_model.estimate(train_data)
+ if train_data is None:
+ # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
+ logging.warning(
+ "Replay buffer's data can only train for {} steps. ".format(i) +
+ "You can modify data collect config, e.g. increasing n_sample, n_episode."
+ )
+ break
+ learner.train(train_data, collector.envstep)
+ if learner.policy.get_attribute('priority'):
+ replay_buffer.update(learner.priority_info)
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+ # save reward model
+ if learner.train_iter % cfg.reward_model.store_model_every_n_train == 0:
+ #if learner.train_iter%5000 == 0:
+ path = os.path.join(dirname, 'iteration_{}.pth.tar'.format(learner.train_iter))
+ state_dict = reward_model.state_dict_reward_model()
+ save_file(path, state_dict)
+ path = os.path.join(dirname, 'final_model.pth.tar')
+ state_dict = reward_model.state_dict_reward_model()
+ save_file(path, state_dict)
+ # Learner's after_run hook.
+ learner.call_hook('after_run')
+ return policy
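+
+
+# Sketch of a call (hypothetical config). ``cfg.policy.collect.model_path`` must point to a trained expert
+# checkpoint, which is loaded to collect expert demonstrations online:
+#
+#     serial_pipeline_guided_cost((main_cfg, create_cfg), seed=0)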
diff --git a/DI-engine/ding/entry/serial_entry_mbrl.py b/DI-engine/ding/entry/serial_entry_mbrl.py
new file mode 100644
index 0000000000000000000000000000000000000000..03d240c6ea8ee7bf06210f6b4840ea2bcd1ccbba
--- /dev/null
+++ b/DI-engine/ding/entry/serial_entry_mbrl.py
@@ -0,0 +1,325 @@
+from typing import Union, Optional, List, Any, Tuple
+import torch
+import os
+from functools import partial
+
+from tensorboardX import SummaryWriter
+from copy import deepcopy
+
+from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
+ get_buffer_cls, create_serial_collector
+from ding.world_model import WorldModel
+from ding.worker import IBuffer
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.config import read_config, compile_config
+from ding.utils import set_pkg_seed, deep_merge_dicts
+from ding.policy import create_policy
+from ding.world_model import create_world_model
+from ding.entry.utils import random_collect
+
+
+def mbrl_entry_setup(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ seed: int = 0,
+ env_setting: Optional[List[Any]] = None,
+ model: Optional[torch.nn.Module] = None,
+) -> Tuple:
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ else:
+ cfg, create_cfg = deepcopy(input_cfg)
+ create_cfg.policy.type = create_cfg.policy.type + '_command'
+ env_fn = None if env_setting is None else env_setting[0]
+ cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
+
+ if env_setting is None:
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ else:
+ env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
+
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+
+ collector_env.seed(cfg.seed)
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ # create logger
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+
+ # create world model
+ world_model = create_world_model(cfg.world_model, env_fn(cfg.env), tb_logger)
+
+ # create policy
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
+
+ # create worker
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = create_serial_collector(
+ cfg.policy.collect.collector,
+ env=collector_env,
+ policy=policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ env_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
+ commander = BaseSerialCommander(
+ cfg.policy.other.commander, learner, collector, evaluator, env_buffer, policy.command_mode
+ )
+
+ return (
+ cfg,
+ policy,
+ world_model,
+ env_buffer,
+ learner,
+ collector,
+ collector_env,
+ evaluator,
+ commander,
+ tb_logger,
+ )
+
+
+def create_img_buffer(
+ cfg: dict, input_cfg: Union[str, Tuple[dict, dict]], world_model: WorldModel, tb_logger: 'SummaryWriter'
+) -> IBuffer: # noqa
+ if isinstance(input_cfg, str):
+ _, create_cfg = read_config(input_cfg)
+ else:
+ _, create_cfg = input_cfg
+ img_buffer_cfg = cfg.world_model.other.imagination_buffer
+ img_buffer_cfg.update(create_cfg.imagination_buffer)
+ buffer_cls = get_buffer_cls(img_buffer_cfg)
+ cfg.world_model.other.imagination_buffer.update(deep_merge_dicts(buffer_cls.default_config(), img_buffer_cfg))
+ if img_buffer_cfg.type == 'elastic':
+ img_buffer_cfg.set_buffer_size = world_model.buffer_size_scheduler
+ img_buffer = create_buffer(cfg.world_model.other.imagination_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
+ return img_buffer
+
+
+def serial_pipeline_dyna(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ seed: int = 0,
+ env_setting: Optional[List[Any]] = None,
+ model: Optional[torch.nn.Module] = None,
+ max_train_iter: Optional[int] = int(1e10),
+ max_env_step: Optional[int] = int(1e10),
+) -> 'Policy': # noqa
+ """
+ Overview:
+ Serial pipeline entry for dyna-style model-based RL.
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
+ ``BaseEnv`` subclass, collector env config, and evaluator env config.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+ - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+ Returns:
+ - policy (:obj:`Policy`): Converged policy.
+ """
+ cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger = \
+ mbrl_entry_setup(input_cfg, seed, env_setting, model)
+
+ img_buffer = create_img_buffer(cfg, input_cfg, world_model, tb_logger)
+
+ learner.call_hook('before_run')
+
+ if cfg.policy.get('random_collect_size', 0) > 0:
+ random_collect(cfg.policy, policy, collector, collector_env, commander, env_buffer)
+
+ while True:
+ collect_kwargs = commander.step()
+ # eval the policy
+ if evaluator.should_eval(collector.envstep):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+
+ # fill environment buffer
+ data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+ env_buffer.push(data, cur_collector_envstep=collector.envstep)
+
+        # evaluate & train the world model and fill the imagination buffer
+ if world_model.should_eval(collector.envstep):
+ world_model.eval(env_buffer, collector.envstep, learner.train_iter)
+ if world_model.should_train(collector.envstep):
+ world_model.train(env_buffer, collector.envstep, learner.train_iter)
+ world_model.fill_img_buffer(
+ policy.collect_mode, env_buffer, img_buffer, collector.envstep, learner.train_iter
+ )
+
+ for i in range(cfg.policy.learn.update_per_collect):
+ batch_size = learner.policy.get_attribute('batch_size')
+ train_data = world_model.sample(env_buffer, img_buffer, batch_size, learner.train_iter)
+ learner.train(train_data, collector.envstep)
+
+ if cfg.policy.on_policy:
+ # On-policy algorithm must clear the replay buffer.
+ env_buffer.clear()
+ img_buffer.clear()
+
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+
+ learner.call_hook('after_run')
+
+ return policy
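+
+# A minimal usage sketch for the dyna-style pipeline (the config module name is hypothetical;
+# the config is expected to provide ``world_model`` and ``world_model.other.imagination_buffer``
+# sections in addition to the usual ``env``/``policy`` sections):
+#
+#     from my_mbpo_config import main_config, create_config  # hypothetical config module
+#     serial_pipeline_dyna((main_config, create_config), seed=0, max_env_step=int(1e6))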
+
+
+def serial_pipeline_dream(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ seed: int = 0,
+ env_setting: Optional[List[Any]] = None,
+ model: Optional[torch.nn.Module] = None,
+ max_train_iter: Optional[int] = int(1e10),
+ max_env_step: Optional[int] = int(1e10),
+) -> 'Policy': # noqa
+ """
+ Overview:
+ Serial pipeline entry for dreamer-style model-based RL.
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
+ ``BaseEnv`` subclass, collector env config, and evaluator env config.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+ - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+ Returns:
+ - policy (:obj:`Policy`): Converged policy.
+ """
+ cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger = \
+ mbrl_entry_setup(input_cfg, seed, env_setting, model)
+
+ learner.call_hook('before_run')
+
+ if cfg.policy.get('random_collect_size', 0) > 0:
+ random_collect(cfg.policy, policy, collector, collector_env, commander, env_buffer)
+
+ while True:
+ collect_kwargs = commander.step()
+ # eval the policy
+ if evaluator.should_eval(collector.envstep):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+
+ # fill environment buffer
+ data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+ env_buffer.push(data, cur_collector_envstep=collector.envstep)
+
+        # evaluate & train the world model
+ if world_model.should_eval(collector.envstep):
+ world_model.eval(env_buffer, collector.envstep, learner.train_iter)
+ if world_model.should_train(collector.envstep):
+ world_model.train(env_buffer, collector.envstep, learner.train_iter)
+
+ update_per_collect = cfg.policy.learn.update_per_collect // world_model.rollout_length_scheduler(
+ collector.envstep
+ )
+ update_per_collect = max(1, update_per_collect)
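+        # Illustrative example (hypothetical numbers): with ``cfg.policy.learn.update_per_collect = 40``
+        # and a scheduled rollout length of 8 at the current envstep, the policy is updated
+        # max(1, 40 // 8) = 5 times in this collect cycle.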
+ for i in range(update_per_collect):
+ batch_size = learner.policy.get_attribute('batch_size')
+ train_data = env_buffer.sample(batch_size, learner.train_iter)
+            # dreamer-style: train the policy with purely on-policy imagined rollouts,
+            # whose rollout length is decided by the current envstep
+ learner.train(
+ train_data, collector.envstep, policy_kwargs=dict(world_model=world_model, envstep=collector.envstep)
+ )
+
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+
+ learner.call_hook('after_run')
+
+ return policy
+
+
+def serial_pipeline_dreamer(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ seed: int = 0,
+ env_setting: Optional[List[Any]] = None,
+ model: Optional[torch.nn.Module] = None,
+ max_train_iter: Optional[int] = int(1e10),
+ max_env_step: Optional[int] = int(1e10),
+) -> 'Policy': # noqa
+ """
+ Overview:
+ Serial pipeline entry for dreamerv3.
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
+ ``BaseEnv`` subclass, collector env config, and evaluator env config.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+ - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+ Returns:
+ - policy (:obj:`Policy`): Converged policy.
+ """
+ cfg, policy, world_model, env_buffer, learner, collector, collector_env, evaluator, commander, tb_logger = \
+ mbrl_entry_setup(input_cfg, seed, env_setting, model)
+
+ learner.call_hook('before_run')
+
+ # prefill environment buffer
+ if cfg.policy.get('random_collect_size', 0) > 0:
+ cfg.policy.random_collect_size = cfg.policy.random_collect_size // cfg.policy.collect.unroll_len
+ random_collect(cfg.policy, policy, collector, collector_env, commander, env_buffer)
+
+ while True:
+ collect_kwargs = commander.step()
+ # eval the policy
+ if evaluator.should_eval(collector.envstep):
+ stop, reward = evaluator.eval(
+ learner.save_checkpoint,
+ learner.train_iter,
+ collector.envstep,
+ policy_kwargs=dict(world_model=world_model)
+ )
+ if stop:
+ break
+
+        # train the world model and the policy (on imagined rollouts from the posterior states)
+ steps = (
+ cfg.world_model.pretrain
+ if world_model.should_pretrain() else int(world_model.should_train(collector.envstep))
+ )
+ for _ in range(steps):
+ batch_size = learner.policy.get_attribute('batch_size')
+ batch_length = cfg.policy.learn.batch_length
+ post, context = world_model.train(
+ env_buffer, collector.envstep, learner.train_iter, batch_size, batch_length
+ )
+
+ start = post
+
+ learner.train(
+ start, collector.envstep, policy_kwargs=dict(world_model=world_model, envstep=collector.envstep)
+ )
+
+ # fill environment buffer
+ data = collector.collect(
+ train_iter=learner.train_iter,
+ policy_kwargs=dict(world_model=world_model, envstep=collector.envstep, **collect_kwargs)
+ )
+ env_buffer.push(data, cur_collector_envstep=collector.envstep)
+
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+
+ learner.call_hook('after_run')
+
+ return policy
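+
+# A minimal usage sketch for the dreamerv3 pipeline (the config module name is hypothetical;
+# the config is expected to define ``world_model.pretrain``, ``policy.learn.batch_length`` and
+# ``policy.collect.unroll_len``, which are all used above):
+#
+#     from my_dreamer_config import main_config, create_config  # hypothetical config module
+#     serial_pipeline_dreamer((main_config, create_config), seed=0, max_env_step=int(5e5))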
diff --git a/DI-engine/ding/entry/serial_entry_ngu.py b/DI-engine/ding/entry/serial_entry_ngu.py
new file mode 100644
index 0000000000000000000000000000000000000000..176f5558cda1995a7fae0462e22a68df9f7b20d0
--- /dev/null
+++ b/DI-engine/ding/entry/serial_entry_ngu.py
@@ -0,0 +1,171 @@
+from typing import Union, Optional, List, Any, Tuple
+import os
+import torch
+from ditk import logging
+from functools import partial
+from tensorboardX import SummaryWriter
+from copy import deepcopy
+
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
+ create_serial_collector
+from ding.config import read_config, compile_config
+from ding.policy import create_policy
+from ding.reward_model import create_reward_model
+from ding.utils import set_pkg_seed
+from .utils import random_collect
+
+
+def serial_pipeline_ngu(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ seed: int = 0,
+ env_setting: Optional[List[Any]] = None,
+ model: Optional[torch.nn.Module] = None,
+ max_train_iter: Optional[int] = int(1e10),
+ max_env_step: Optional[int] = int(1e10),
+) -> 'Policy': # noqa
+ """
+ Overview:
+ Serial pipeline entry for NGU. The corresponding paper is
+        `Never Give Up: Learning Directed Exploration Strategies`.
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
+ ``BaseEnv`` subclass, collector env config, and evaluator env config.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+ - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+ Returns:
+ - policy (:obj:`Policy`): Converged policy.
+ """
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ else:
+ cfg, create_cfg = deepcopy(input_cfg)
+ create_cfg.policy.type = create_cfg.policy.type + '_command'
+ env_fn = None if env_setting is None else env_setting[0]
+ cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
+ # Create main components: env, policy
+ if env_setting is None:
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ else:
+ env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+ # if you want to save replay, please uncomment this line
+ # evaluator_env.enable_save_replay(cfg.env.replay_path)
+
+ collector_env.seed(cfg.seed)
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
+
+ # Create worker components: learner, collector, evaluator, replay buffer, commander.
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = create_serial_collector(
+ cfg.policy.collect.collector,
+ env=collector_env,
+ policy=policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
+ commander = BaseSerialCommander(
+ cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
+ )
+ rnd_reward_model = create_reward_model(cfg.rnd_reward_model, policy.collect_mode.get_attribute('device'), tb_logger)
+ episodic_reward_model = create_reward_model(
+ cfg.episodic_reward_model, policy.collect_mode.get_attribute('device'), tb_logger
+ )
+ # ==========
+ # Main loop
+ # ==========
+ # Learner's before_run hook.
+ learner.call_hook('before_run')
+
+ # Accumulate plenty of data at the beginning of training.
+ if cfg.policy.get('random_collect_size', 0) > 0:
+ random_collect(cfg.policy, policy, collector, collector_env, commander, replay_buffer)
+
+ estimate_cnt = 0
+ iter_ = 0
+ while True:
+ """some hyper-parameters used in NGU"""
+ # index_to_eps = {i: 0.4 ** (1 + 8 * i / (self._env_num - 1)) for i in range(self._env_num)}
+ # index_to_beta = {
+ # i: 0.3 * torch.sigmoid(torch.tensor(10 * (2 * i - (collector_env_num - 2)) / (collector_env_num - 2)))
+ # for i in range(collector_env_num)
+ # }
+ # index_to_gamma = {
+ # i: 1 - torch.exp(
+ # (
+ # (collector_env_num - 1 - i) * torch.log(torch.tensor(1 - 0.997)) +
+ # i * torch.log(torch.tensor(1 - 0.99))
+ # ) / (collector_env_num - 1)
+ # )
+ # for i in range(collector_env_num)
+ # }
+ iter_ += 1
+
+ # Evaluate policy performance
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Collect data by default config n_sample/n_episode
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=None)
+
+ # collect data for reward_model training
+ rnd_reward_model.collect_data(new_data)
+ episodic_reward_model.collect_data(new_data)
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+
+ # update reward_model
+ rnd_reward_model.train()
+ if (iter_ + 1) % cfg.rnd_reward_model.clear_buffer_per_iters == 0:
+ rnd_reward_model.clear_data()
+ episodic_reward_model.train()
+ if (iter_ + 1) % cfg.episodic_reward_model.clear_buffer_per_iters == 0:
+ episodic_reward_model.clear_data()
+
+ # Learn policy from collected data
+ for i in range(cfg.policy.learn.update_per_collect):
+ # Learner will train ``update_per_collect`` times in one iteration.
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is None:
+                # It is possible that the replay buffer has too little data to train ``update_per_collect`` times
+ logging.warning(
+ "Replay buffer's data can only train for {} steps. ".format(i) +
+ "You can modify data collect config, e.g. increasing n_sample, n_episode."
+ )
+ break
+            # calculate the life-long (RND) and episodic intrinsic rewards
+ rnd_reward = rnd_reward_model.estimate(train_data)
+ episodic_reward = episodic_reward_model.estimate(train_data)
+
+ # update train_data reward using the augmented reward
+ train_data_augmented, estimate_cnt = episodic_reward_model.fusion_reward(
+ train_data,
+ rnd_reward,
+ episodic_reward,
+ nstep=cfg.policy.nstep,
+ collector_env_num=cfg.policy.collect.env_num,
+ tb_logger=tb_logger,
+ estimate_cnt=estimate_cnt
+ )
+ learner.train(train_data_augmented, collector.envstep)
+ if learner.policy.get_attribute('priority'):
+ replay_buffer.update(learner.priority_info)
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+
+ # Learner's after_run hook.
+ learner.call_hook('after_run')
+ return policy
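+
+# A minimal usage sketch (the config module name is hypothetical; besides the usual ``env``/``policy``
+# sections, the config must provide ``rnd_reward_model`` and ``episodic_reward_model`` sections,
+# including the ``clear_buffer_per_iters`` fields used above):
+#
+#     from my_ngu_config import main_config, create_config  # hypothetical config module
+#     serial_pipeline_ngu((main_config, create_config), seed=0, max_env_step=int(1e7))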
diff --git a/DI-engine/ding/entry/serial_entry_offline.py b/DI-engine/ding/entry/serial_entry_offline.py
new file mode 100755
index 0000000000000000000000000000000000000000..b92b5c7ddafb56e973ebd7259a52875916cf77f8
--- /dev/null
+++ b/DI-engine/ding/entry/serial_entry_offline.py
@@ -0,0 +1,117 @@
+from typing import Union, Optional, List, Any, Tuple
+import os
+import torch
+from functools import partial
+from tensorboardX import SummaryWriter
+from copy import deepcopy
+from torch.utils.data import DataLoader
+from torch.utils.data.distributed import DistributedSampler
+
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.worker import BaseLearner, InteractionSerialEvaluator
+from ding.config import read_config, compile_config
+from ding.policy import create_policy
+from ding.utils import set_pkg_seed, get_world_size, get_rank
+from ding.utils.data import create_dataset
+
+
+def serial_pipeline_offline(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ seed: int = 0,
+ env_setting: Optional[List[Any]] = None,
+ model: Optional[torch.nn.Module] = None,
+ max_train_iter: Optional[int] = int(1e10),
+) -> Tuple['Policy', bool]:  # noqa
+ """
+ Overview:
+        Serial pipeline entry for offline RL, training a policy from a fixed dataset.
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
+ ``BaseEnv`` subclass, collector env config, and evaluator env config.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+ Returns:
+        - policy (:obj:`Policy`): Converged policy.
+        - stop (:obj:`bool`): Whether training reached a stop condition (evaluation stop or ``max_train_iter``).
+ """
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ else:
+ cfg, create_cfg = deepcopy(input_cfg)
+ create_cfg.policy.type = create_cfg.policy.type + '_command'
+ cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg)
+
+ # Dataset
+ dataset = create_dataset(cfg)
+ sampler, shuffle = None, True
+ if get_world_size() > 1:
+ sampler, shuffle = DistributedSampler(dataset), False
+ dataloader = DataLoader(
+ dataset,
+        # Dividing by get_world_size() here simply makes the multi-GPU
+        # setting mathematically equivalent to the single-GPU setting.
+        # If training efficiency is the bottleneck, feel free to
+        # use the original batch size per GPU and increase the learning rate
+        # accordingly.
+ cfg.policy.learn.batch_size // get_world_size(),
+ # cfg.policy.learn.batch_size
+ shuffle=shuffle,
+ sampler=sampler,
+ collate_fn=lambda x: x,
+ pin_memory=cfg.policy.cuda,
+ )
+ # Env, Policy
+ try:
+ if cfg.env.norm_obs.use_norm and cfg.env.norm_obs.offline_stats.use_offline_stats:
+ cfg.env.norm_obs.offline_stats.update({'mean': dataset.mean, 'std': dataset.std})
+ except (KeyError, AttributeError):
+ pass
+ env_fn, _, evaluator_env_cfg = get_vec_env_setting(cfg.env, collect=False)
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+ # Random seed
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'eval'])
+ if cfg.policy.collect.data_type == 'diffuser_traj':
+ policy.init_data_normalizer(dataset.normalizer)
+
+ if hasattr(policy, 'set_statistic'):
+ # useful for setting action bounds for ibc
+ policy.set_statistic(dataset.statistics)
+
+    # Otherwise, directories may conflict in multi-GPU settings.
+ if get_rank() == 0:
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ else:
+ tb_logger = None
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ # ==========
+ # Main loop
+ # ==========
+ # Learner's before_run hook.
+ learner.call_hook('before_run')
+ stop = False
+
+ for epoch in range(cfg.policy.learn.train_epoch):
+ if get_world_size() > 1:
+ dataloader.sampler.set_epoch(epoch)
+ for train_data in dataloader:
+ learner.train(train_data)
+
+ # Evaluate policy at most once per epoch.
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter)
+
+ if stop or learner.train_iter >= max_train_iter:
+ stop = True
+ break
+
+ learner.call_hook('after_run')
+ print('final reward is: {}'.format(reward))
+ return policy, stop
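+
+# A minimal usage sketch (the config module name is hypothetical; the dataset type/path fields
+# consumed by ``create_dataset``, e.g. ``cfg.policy.collect.data_type``, must be set in the config):
+#
+#     from my_offline_config import main_config, create_config  # hypothetical config module
+#     policy, stop = serial_pipeline_offline((main_config, create_config), seed=0)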
diff --git a/DI-engine/ding/entry/serial_entry_onpolicy.py b/DI-engine/ding/entry/serial_entry_onpolicy.py
new file mode 100644
index 0000000000000000000000000000000000000000..22e9cf74f987809ccbbe07ea4f7b716f6e5dd72e
--- /dev/null
+++ b/DI-engine/ding/entry/serial_entry_onpolicy.py
@@ -0,0 +1,115 @@
+from typing import Union, Optional, List, Any, Tuple
+import os
+import time
+import pickle
+import torch
+import numpy as np
+from ditk import logging
+from functools import partial
+from tensorboardX import SummaryWriter
+from copy import deepcopy
+
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
+ create_serial_collector
+from ding.config import read_config, compile_config
+from ding.policy import create_policy, PolicyFactory
+from ding.reward_model import create_reward_model
+from ding.utils import set_pkg_seed
+
+
+def serial_pipeline_onpolicy(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ seed: int = 0,
+ env_setting: Optional[List[Any]] = None,
+ model: Optional[torch.nn.Module] = None,
+ max_train_iter: Optional[int] = int(1e10),
+ max_env_step: Optional[int] = int(1e10),
+) -> 'Policy': # noqa
+ """
+ Overview:
+        Serial pipeline entry for on-policy RL.
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
+ ``BaseEnv`` subclass, collector env config, and evaluator env config.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+ - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+ Returns:
+ - policy (:obj:`Policy`): Converged policy.
+ """
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ else:
+ cfg, create_cfg = deepcopy(input_cfg)
+ create_cfg.policy.type = create_cfg.policy.type + '_command'
+ env_fn = None if env_setting is None else env_setting[0]
+ cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
+ # Create main components: env, policy
+ if env_setting is None:
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ else:
+ env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+ collector_env.seed(cfg.seed)
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
+
+ # Create worker components: learner, collector, evaluator, replay buffer, commander.
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = create_serial_collector(
+ cfg.policy.collect.collector,
+ env=collector_env,
+ policy=policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ commander = BaseSerialCommander(
+ cfg.policy.other.commander, learner, collector, evaluator, None, policy.command_mode
+ )
+
+ # ==========
+ # Main loop
+ # ==========
+ # Learner's before_run hook.
+ learner.call_hook('before_run')
+
+ while True:
+ collect_kwargs = commander.step()
+ # Evaluate policy performance
+ if evaluator.should_eval(learner.train_iter):
+ stop, eval_info = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Collect data by default config n_sample/n_episode
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+
+ # Learn policy from collected data
+ learner.train(new_data, collector.envstep)
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+
+ # Learner's after_run hook.
+ learner.call_hook('after_run')
+ with open(os.path.join(cfg.exp_name, 'result.pkl'), 'wb') as f:
+ eval_value_raw = eval_info['eval_episode_return']
+ final_data = {
+ 'stop': stop,
+ 'env_step': collector.envstep,
+ 'train_iter': learner.train_iter,
+ 'eval_value': np.mean(eval_value_raw),
+ 'eval_value_raw': eval_value_raw,
+ 'finish_time': time.ctime(),
+ }
+ pickle.dump(final_data, f)
+ return policy
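+
+# A minimal usage sketch (the config module name is hypothetical; note that this entry also dumps a
+# ``result.pkl`` summary into the ``cfg.exp_name`` directory when training stops):
+#
+#     from my_ppo_config import main_config, create_config  # hypothetical config module
+#     serial_pipeline_onpolicy((main_config, create_config), seed=0, max_env_step=int(3e6))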
diff --git a/DI-engine/ding/entry/serial_entry_onpolicy_ppg.py b/DI-engine/ding/entry/serial_entry_onpolicy_ppg.py
new file mode 100644
index 0000000000000000000000000000000000000000..02c6dee30784dff53708c616f8fa34d7a25ea7e3
--- /dev/null
+++ b/DI-engine/ding/entry/serial_entry_onpolicy_ppg.py
@@ -0,0 +1,101 @@
+from typing import Union, Optional, List, Any, Tuple
+import os
+import torch
+import logging
+from functools import partial
+from tensorboardX import SummaryWriter
+from copy import deepcopy
+
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
+ create_serial_collector
+from ding.config import read_config, compile_config
+from ding.policy import create_policy, PolicyFactory
+from ding.reward_model import create_reward_model
+from ding.utils import set_pkg_seed
+
+
+def serial_pipeline_onpolicy_ppg(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ seed: int = 0,
+ env_setting: Optional[List[Any]] = None,
+ model: Optional[torch.nn.Module] = None,
+ max_train_iter: Optional[int] = int(1e10),
+ max_env_step: Optional[int] = int(1e10),
+) -> 'Policy': # noqa
+ """
+ Overview:
+        Serial pipeline entry for on-policy PPG (Phasic Policy Gradient).
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
+ ``BaseEnv`` subclass, collector env config, and evaluator env config.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+ - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+ Returns:
+ - policy (:obj:`Policy`): Converged policy.
+ """
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ else:
+ cfg, create_cfg = deepcopy(input_cfg)
+ create_cfg.policy.type = create_cfg.policy.type + '_command'
+ env_fn = None if env_setting is None else env_setting[0]
+ cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
+ # Create main components: env, policy
+ if env_setting is None:
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ else:
+ env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+ collector_env.seed(cfg.seed, dynamic_seed=False)
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
+
+ # Create worker components: learner, collector, evaluator, replay buffer, commander.
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = create_serial_collector(
+ cfg.policy.collect.collector,
+ env=collector_env,
+ policy=policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ commander = BaseSerialCommander(
+ cfg.policy.other.commander, learner, collector, evaluator, None, policy.command_mode
+ )
+
+ # ==========
+ # Main loop
+ # ==========
+ # Learner's before_run hook.
+ learner.call_hook('before_run')
+
+ while True:
+ collect_kwargs = commander.step()
+ # Evaluate policy performance
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Collect data by default config n_sample/n_episode
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+
+ # Learn policy from collected data
+ learner.train(new_data, collector.envstep)
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+
+ # Learner's after_run hook.
+ learner.call_hook('after_run')
+ return policy
diff --git a/DI-engine/ding/entry/serial_entry_pc.py b/DI-engine/ding/entry/serial_entry_pc.py
new file mode 100644
index 0000000000000000000000000000000000000000..386d6f0ec970dc59ccf8189a85b0709011206287
--- /dev/null
+++ b/DI-engine/ding/entry/serial_entry_pc.py
@@ -0,0 +1,108 @@
+from typing import Union, Optional, Tuple
+import os
+from functools import partial
+from copy import deepcopy
+
+import torch
+from tensorboardX import SummaryWriter
+from torch.utils.data import DataLoader
+
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.worker import BaseLearner, InteractionSerialEvaluator
+from ding.config import read_config, compile_config
+from ding.policy import create_policy
+from ding.utils import set_pkg_seed
+from ding.utils.data.dataset import load_bfs_datasets
+
+
+def serial_pipeline_pc(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ seed: int = 0,
+ model: Optional[torch.nn.Module] = None,
+ max_iter=int(1e6),
+) -> Tuple['Policy', bool]:  # noqa
+ r"""
+ Overview:
+        Serial pipeline entry for procedure cloning (PC), using BFS as the expert policy.
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - max_iter (:obj:`Optional[int]`): Max iteration for executing PC training.
+ Returns:
+ - policy (:obj:`Policy`): Converged policy.
+        - convergence (:obj:`bool`): Whether the training converged.
+ """
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ else:
+ cfg, create_cfg = deepcopy(input_cfg)
+ cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg)
+
+ # Env, Policy
+ env_fn, _, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+ # Random seed
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'eval'])
+
+ # Main components
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ train_data, test_data = load_bfs_datasets(train_seeds=cfg.train_seeds)
+ dataloader = DataLoader(train_data, batch_size=cfg.policy.learn.batch_size, shuffle=True)
+ test_dataloader = DataLoader(test_data, batch_size=cfg.policy.learn.batch_size, shuffle=True)
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+
+ # ==========
+ # Main loop
+ # ==========
+ learner.call_hook('before_run')
+ stop = False
+ iter_cnt = 0
+ for epoch in range(cfg.policy.learn.train_epoch):
+ # train
+ criterion = torch.nn.CrossEntropyLoss()
+ for i, train_data in enumerate(dataloader):
+ learner.train(train_data)
+ iter_cnt += 1
+ if iter_cnt >= max_iter:
+ stop = True
+ break
+ if epoch % 69 == 0:
+ policy._optimizer.param_groups[0]['lr'] /= 10
+ if stop:
+ break
+ losses = []
+ acces = []
+ # Evaluation
+ for _, test_data in enumerate(test_dataloader):
+ observations, bfs_input_maps, bfs_output_maps = test_data['obs'], test_data['bfs_in'].long(), \
+ test_data['bfs_out'].long()
+ states = observations
+ bfs_input_onehot = torch.nn.functional.one_hot(bfs_input_maps, 5).float()
+
+ bfs_states = torch.cat([
+ states,
+ bfs_input_onehot,
+ ], dim=-1).cuda()
+ logits = policy._model(bfs_states)['logit']
+ logits = logits.flatten(0, -2)
+ labels = bfs_output_maps.flatten(0, -1).cuda()
+
+ loss = criterion(logits, labels).item()
+ preds = torch.argmax(logits, dim=-1)
+ acc = torch.sum((preds == labels)) / preds.shape[0]
+
+ losses.append(loss)
+ acces.append(acc)
+ print('Test Finished! Loss: {} acc: {}'.format(sum(losses) / len(losses), sum(acces) / len(acces)))
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter)
+ learner.call_hook('after_run')
+ print('final reward is: {}'.format(reward))
+ return policy, stop
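+
+# A minimal usage sketch (the config module name is hypothetical; the config must define
+# ``train_seeds`` for ``load_bfs_datasets``, and a CUDA device is assumed by the evaluation code above):
+#
+#     from my_pc_config import main_config, create_config  # hypothetical config module
+#     policy, stop = serial_pipeline_pc((main_config, create_config), seed=0)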
diff --git a/DI-engine/ding/entry/serial_entry_plr.py b/DI-engine/ding/entry/serial_entry_plr.py
new file mode 100644
index 0000000000000000000000000000000000000000..388ef0e4f608ed63f1aa48abff16581212c1bece
--- /dev/null
+++ b/DI-engine/ding/entry/serial_entry_plr.py
@@ -0,0 +1,125 @@
+from typing import Union, Optional, List, Any, Tuple
+import os
+import torch
+import logging
+from functools import partial
+from tensorboardX import SummaryWriter
+from copy import deepcopy
+
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
+ create_serial_collector
+from ding.config import read_config, compile_config
+from ding.policy import create_policy, PolicyFactory
+from ding.reward_model import create_reward_model
+from ding.utils import set_pkg_seed
+from ding.data.level_replay.level_sampler import LevelSampler
+from ding.policy.common_utils import default_preprocess_learn
+
+
+def generate_seeds(num_seeds=500, base_seed=0):
+ return [base_seed + i for i in range(num_seeds)]
+
+
+def serial_pipeline_plr(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ seed: int = 0,
+ env_setting: Optional[List[Any]] = None,
+ model: Optional[torch.nn.Module] = None,
+ max_train_iter: Optional[int] = int(1e10),
+ max_env_step: Optional[int] = int(1e10),
+) -> 'Policy': # noqa
+ """
+ Overview:
+        Serial pipeline entry for PLR (Prioritized Level Replay).
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
+ ``BaseEnv`` subclass, collector env config, and evaluator env config.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+ - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+ Returns:
+ - policy (:obj:`Policy`): Converged policy.
+ """
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ else:
+ cfg, create_cfg = deepcopy(input_cfg)
+ create_cfg.policy.type = create_cfg.policy.type + '_command'
+ env_fn = None if env_setting is None else env_setting[0]
+ cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
+ collector_env_num = cfg.env.collector_env_num
+ # Create main components: env, policy
+ if env_setting is None:
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ else:
+ env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+ collector_env.seed(cfg.seed, dynamic_seed=False)
+ evaluator_env.seed(cfg.seed, dynamic_seed=True)
+ train_seeds = generate_seeds()
+ level_sampler = LevelSampler(
+ train_seeds, cfg.policy.model.obs_shape, cfg.policy.model.action_shape, collector_env_num, cfg.level_replay
+ )
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
+
+ # Create worker components: learner, collector, evaluator, replay buffer, commander.
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = create_serial_collector(
+ cfg.policy.collect.collector,
+ env=collector_env,
+ policy=policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ commander = BaseSerialCommander(
+ cfg.policy.other.commander, learner, collector, evaluator, None, policy.command_mode
+ )
+
+ # ==========
+ # Main loop
+ # ==========
+ # Learner's before_run hook.
+ learner.call_hook('before_run')
+
+ seeds = [int(level_sampler.sample('sequential')) for _ in range(collector_env_num)]
+    # the default_preprocess_learn function can only handle Tensor data
+ level_seeds = torch.Tensor(seeds)
+
+ collector_env.seed(seeds)
+ collector_env.reset()
+
+ while True:
+ collect_kwargs = commander.step()
+ # Evaluate policy performance
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Collect data by default config n_sample/n_episode
+ new_data = collector.collect(
+ train_iter=learner.train_iter, level_seeds=level_seeds, policy_kwargs=collect_kwargs
+ )
+ # Learn policy from collected data
+ learner.train(new_data, collector.envstep)
+ stacked_data = default_preprocess_learn(new_data, ignore_done=cfg.policy.learn.ignore_done, use_nstep=False)
+ level_sampler.update_with_rollouts(stacked_data, collector_env_num)
+ seeds = [int(level_sampler.sample()) for _ in range(collector_env_num)]
+ level_seeds = torch.Tensor(seeds)
+ collector_env.seed(seeds)
+ collector_env.reset()
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+ # Learner's after_run hook.
+ learner.call_hook('after_run')
+ return policy
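+
+# A minimal usage sketch (the config module name is hypothetical; the config must provide a
+# ``level_replay`` section for ``LevelSampler`` and set ``policy.learn.ignore_done``, both used above):
+#
+#     from my_plr_config import main_config, create_config  # hypothetical config module
+#     serial_pipeline_plr((main_config, create_config), seed=0, max_env_step=int(25e6))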
diff --git a/DI-engine/ding/entry/serial_entry_preference_based_irl.py b/DI-engine/ding/entry/serial_entry_preference_based_irl.py
new file mode 100644
index 0000000000000000000000000000000000000000..682e662baa49aabaf2c2fe89a517453578924361
--- /dev/null
+++ b/DI-engine/ding/entry/serial_entry_preference_based_irl.py
@@ -0,0 +1,133 @@
+import copy
+from typing import Union, Optional, List, Any, Tuple
+import os
+import torch
+from ditk import logging
+from functools import partial
+from tensorboardX import SummaryWriter
+from copy import deepcopy
+
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
+ create_serial_collector
+from ding.config import read_config, compile_config
+from ding.policy import create_policy, PolicyFactory
+from ding.reward_model import create_reward_model
+from ding.utils import set_pkg_seed
+
+
+def serial_pipeline_preference_based_irl(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ seed: int = 0,
+ env_setting: Optional[List[Any]] = None,
+ model: Optional[torch.nn.Module] = None,
+ max_train_iter: Optional[int] = int(1e10),
+ max_env_step: Optional[int] = int(1e10),
+) -> 'Policy': # noqa
+ """
+ Overview:
+        Serial pipeline entry for preference-based IRL (reward learning from trajectory preferences).
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
+ ``BaseEnv`` subclass, collector env config, and evaluator env config.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+        - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+        - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+ Returns:
+ - policy (:obj:`Policy`): Converged policy.
+ """
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ else:
+ cfg, create_cfg = deepcopy(input_cfg)
+ create_cfg.policy.type = create_cfg.policy.type + '_command'
+ create_cfg.reward_model = dict(type=cfg.reward_model.type)
+ env_fn = None if env_setting is None else env_setting[0]
+ cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=False)
+ cfg_bak = copy.deepcopy(cfg)
+ # Create main components: env, policy
+ if env_setting is None:
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ else:
+ env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+ collector_env.seed(cfg.seed)
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
+
+ # Create worker components: learner, collector, evaluator, replay buffer, commander.
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = create_serial_collector(
+ cfg.policy.collect.collector,
+ env=collector_env,
+ policy=policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
+ commander = BaseSerialCommander(
+ cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
+ )
+
+ reward_model = create_reward_model(cfg_bak, policy.collect_mode.get_attribute('device'), tb_logger)
+ reward_model.train()
+ # ==========
+ # Main loop
+ # ==========
+ # Learner's before_run hook.
+ learner.call_hook('before_run')
+
+ # Accumulate plenty of data at the beginning of training.
+ if cfg.policy.get('random_collect_size', 0) > 0:
+ if cfg.policy.get('transition_with_policy_data', False):
+ collector.reset_policy(policy.collect_mode)
+ else:
+ action_space = collector_env.env_info().act_space
+ random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
+ collector.reset_policy(random_policy)
+ collect_kwargs = commander.step()
+ new_data = collector.collect(n_sample=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
+ replay_buffer.push(new_data, cur_collector_envstep=0)
+ collector.reset_policy(policy.collect_mode)
+ while True:
+ collect_kwargs = commander.step()
+ # Evaluate policy performance
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Collect data by default config n_sample/n_episode
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ # Learn policy from collected data
+ for i in range(cfg.policy.learn.update_per_collect):
+ # Learner will train ``update_per_collect`` times in one iteration.
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is None:
+                # It is possible that the replay buffer has too little data to train ``update_per_collect`` times
+ logging.warning(
+ "Replay buffer's data can only train for {} steps. ".format(i) +
+ "You can modify data collect config, e.g. increasing n_sample, n_episode."
+ )
+ break
+ # update train_data reward using the augmented reward
+ train_data_augmented = reward_model.estimate(train_data)
+ learner.train(train_data_augmented, collector.envstep)
+ if learner.policy.get_attribute('priority'):
+ replay_buffer.update(learner.priority_info)
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+
+ # Learner's after_run hook.
+ learner.call_hook('after_run')
+ return policy
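+
+# A minimal usage sketch (the config module name is hypothetical; the config must include a
+# ``reward_model`` section with a ``type`` field, which is used above to build ``create_cfg.reward_model``):
+#
+#     from my_pbrl_config import main_config, create_config  # hypothetical config module
+#     serial_pipeline_preference_based_irl((main_config, create_config), seed=0)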
diff --git a/DI-engine/ding/entry/serial_entry_preference_based_irl_onpolicy.py b/DI-engine/ding/entry/serial_entry_preference_based_irl_onpolicy.py
new file mode 100644
index 0000000000000000000000000000000000000000..3941f3337e394b1cc0f696e8d4c5847b9adc744b
--- /dev/null
+++ b/DI-engine/ding/entry/serial_entry_preference_based_irl_onpolicy.py
@@ -0,0 +1,104 @@
+from typing import Union, Optional, List, Any, Tuple
+import os
+import torch
+from ditk import logging
+from functools import partial
+from tensorboardX import SummaryWriter
+from copy import deepcopy
+
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
+ create_serial_collector
+from ding.config import read_config, compile_config
+from ding.policy import create_policy, PolicyFactory
+from ding.reward_model import create_reward_model
+from ding.utils import set_pkg_seed
+
+
+def serial_pipeline_preference_based_irl_onpolicy(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ seed: int = 0,
+ env_setting: Optional[List[Any]] = None,
+ model: Optional[torch.nn.Module] = None,
+ max_train_iter: Optional[int] = int(1e10),
+ max_env_step: Optional[int] = int(1e10),
+) -> 'Policy': # noqa
+ """
+ Overview:
+        Serial pipeline entry for preference-based IRL with on-policy algorithms (such as PPO).
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
+ ``BaseEnv`` subclass, collector env config, and evaluator env config.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+ - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+ Returns:
+ - policy (:obj:`Policy`): Converged policy.
+ """
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ else:
+ cfg, create_cfg = deepcopy(input_cfg)
+ create_cfg.policy.type = create_cfg.policy.type + '_command'
+ create_cfg.reward_model = dict(type=cfg.reward_model.type)
+ env_fn = None if env_setting is None else env_setting[0]
+ cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True, renew_dir=False)
+ # Create main components: env, policy
+ if env_setting is None:
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ else:
+ env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+ collector_env.seed(cfg.seed)
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
+
+ # Create worker components: learner, collector, evaluator, replay buffer, commander.
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = create_serial_collector(
+ cfg.policy.collect.collector,
+ env=collector_env,
+ policy=policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ commander = BaseSerialCommander(
+ cfg.policy.other.commander, learner, collector, evaluator, None, policy.command_mode
+ )
+ reward_model = create_reward_model(cfg, policy.collect_mode.get_attribute('device'), tb_logger)
+ reward_model.train()
+ # ==========
+ # Main loop
+ # ==========
+ # Learner's before_run hook.
+ learner.call_hook('before_run')
+
+ while True:
+ collect_kwargs = commander.step()
+ # Evaluate policy performance
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Collect data by default config n_sample/n_episode
+ new_data = collector.collect(train_iter=learner.train_iter)
+ train_data = new_data
+ # update train_data reward using the augmented reward
+ train_data_augmented = reward_model.estimate(train_data)
+ learner.train(train_data_augmented, collector.envstep)
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+
+ # Learner's after_run hook.
+ learner.call_hook('after_run')
+ return policy
diff --git a/DI-engine/ding/entry/serial_entry_r2d3.py b/DI-engine/ding/entry/serial_entry_r2d3.py
new file mode 100644
index 0000000000000000000000000000000000000000..89f2fbb649d616051353e4f8fc5fa3c1c8f3a092
--- /dev/null
+++ b/DI-engine/ding/entry/serial_entry_r2d3.py
@@ -0,0 +1,223 @@
+from typing import Union, Optional, List, Any, Tuple
+import os
+import torch
+import numpy as np
+from ditk import logging
+from functools import partial
+from tensorboardX import SummaryWriter
+from copy import deepcopy
+
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
+ create_serial_collector
+from ding.config import read_config, compile_config
+from ding.policy import create_policy
+from ding.utils import set_pkg_seed
+from .utils import random_collect, mark_not_expert
+
+
+def serial_pipeline_r2d3(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ expert_cfg: Union[str, Tuple[dict, dict]],
+ seed: int = 0,
+ env_setting: Optional[List[Any]] = None,
+ model: Optional[torch.nn.Module] = None,
+ expert_model: Optional[torch.nn.Module] = None,
+ max_train_iter: Optional[int] = int(1e10),
+ max_env_step: Optional[int] = int(1e10),
+) -> 'Policy': # noqa
+ """
+ Overview:
+        Serial pipeline entry for R2D3. For now, the following environments are supported: \
+        LunarLander, Pong and Qbert. The demonstration data come from an expert model, i.e., \
+        a well-trained model is used to generate demonstration data online.
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+        - expert_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Expert config in dict type. \
+            ``str`` type means config file path. \
+            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+        - seed (:obj:`int`): Random seed.
+ - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
+ ``BaseEnv`` subclass, collector env config, and evaluator env config.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - expert_model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.\
+ The default model is DQN(**cfg.policy.model)
+ - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+ - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+ Returns:
+ - policy (:obj:`Policy`): Converged policy.
+ """
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ expert_cfg, expert_create_cfg = read_config(expert_cfg)
+ else:
+ cfg, create_cfg = deepcopy(input_cfg)
+ expert_cfg, expert_create_cfg = expert_cfg
+ create_cfg.policy.type = create_cfg.policy.type + '_command'
+ expert_create_cfg.policy.type = expert_create_cfg.policy.type + '_command'
+ env_fn = None if env_setting is None else env_setting[0]
+ cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
+ expert_cfg = compile_config(
+ expert_cfg, seed=seed, env=env_fn, auto=True, create_cfg=expert_create_cfg, save_cfg=True
+ )
+ # Create main components: env, policy
+ if env_setting is None:
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ else:
+ env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ expert_collector_env = create_env_manager(
+ expert_cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]
+ )
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+ expert_collector_env.seed(cfg.seed)
+ collector_env.seed(cfg.seed)
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ expert_policy = create_policy(expert_cfg.policy, model=expert_model, enable_field=['collect', 'command'])
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
+ expert_policy.collect_mode.load_state_dict(torch.load(expert_cfg.policy.collect.model_path, map_location='cpu'))
+ # Create worker components: learner, collector, evaluator, replay buffer, commander.
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = create_serial_collector(
+ cfg.policy.collect.collector,
+ env=collector_env,
+ policy=policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name
+ )
+ expert_collector = create_serial_collector(
+ expert_cfg.policy.collect.collector,
+ env=expert_collector_env,
+ policy=expert_policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=expert_cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
+ commander = BaseSerialCommander(
+ cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
+ )
+ expert_commander = BaseSerialCommander(
+ expert_cfg.policy.other.commander, learner, expert_collector, evaluator, replay_buffer,
+ expert_policy.command_mode
+    )  # we create this expert commander to obtain the eps kwarg required by the sample collector;
+    # eps is overridden below (set to -1) so that the expert policy acts greedily.
+ expert_collect_kwargs = expert_commander.step()
+ if 'eps' in expert_collect_kwargs:
+ expert_collect_kwargs['eps'] = -1
+ # ==========
+ # Main loop
+ # ==========
+ # Learner's before_run hook.
+ learner.call_hook('before_run')
+ if expert_cfg.policy.learn.expert_replay_buffer_size != 0: # for ablation study
+ expert_buffer = create_buffer(expert_cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
+ expert_data = expert_collector.collect(
+ n_sample=expert_cfg.policy.learn.expert_replay_buffer_size,
+ train_iter=learner.train_iter,
+ policy_kwargs=expert_collect_kwargs
+ )
+
+ for i in range(len(expert_data)):
+            # set the is_expert flag (expert: 1, agent: 0)
+ # expert_data[i]['is_expert'] = 1 # for transition-based alg.
+ expert_data[i]['is_expert'] = [1] * expert_cfg.policy.collect.unroll_len # for rnn/sequence-based alg.
+ expert_buffer.push(expert_data, cur_collector_envstep=0)
+ for _ in range(cfg.policy.learn.per_train_iter_k): # pretrain
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Learn policy from collected data
+ # Expert_learner will train ``update_per_collect == 1`` times in one iteration.
+ train_data = expert_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ learner.train(train_data, collector.envstep)
+ if learner.policy.get_attribute('priority'):
+ expert_buffer.update(learner.priority_info)
+ learner.priority_info = {}
+ # Accumulate plenty of data at the beginning of training.
+ if cfg.policy.get('random_collect_size', 0) > 0:
+ random_collect(
+ cfg.policy, policy, collector, collector_env, commander, replay_buffer, postprocess_data_fn=mark_not_expert
+ )
+ while True:
+ collect_kwargs = commander.step()
+ # Evaluate policy performance
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Collect data by default config n_sample/n_episode
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+
+ for i in range(len(new_data)):
+            # set the is_expert flag (expert: 1, agent: 0)
+ new_data[i]['is_expert'] = [0] * expert_cfg.policy.collect.unroll_len
+
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ # Learn policy from collected data
+ for i in range(cfg.policy.learn.update_per_collect):
+ if expert_cfg.policy.learn.expert_replay_buffer_size != 0:
+ # Learner will train ``update_per_collect`` times in one iteration.
+
+                # The hyperparameter ``pho`` (the demo ratio) controls the proportion of data coming
+                # from expert demonstrations versus the agent's own experience.
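+                # Illustrative note: with batch_size=64 and pho=0.25, the expert part of the batch is a
+                # Binomial(64, 0.25) draw, i.e. roughly 16 expert samples and 48 agent samples on average.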
+ expert_batch_size = int(
+ np.float32(np.random.rand(learner.policy.get_attribute('batch_size')) < cfg.policy.collect.pho
+ ).sum()
+ )
+ agent_batch_size = (learner.policy.get_attribute('batch_size')) - expert_batch_size
+ train_data_agent = replay_buffer.sample(agent_batch_size, learner.train_iter)
+ train_data_expert = expert_buffer.sample(expert_batch_size, learner.train_iter)
+ if train_data_agent is None:
+ # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
+ logging.warning(
+ "Replay buffer's data can only train for {} steps. ".format(i) +
+ "You can modify data collect config, e.g. increasing n_sample, n_episode."
+ )
+ break
+ train_data = train_data_agent + train_data_expert
+ learner.train(train_data, collector.envstep)
+ if learner.policy.get_attribute('priority'):
+                    # At collection time, replay_buffer_idx and replay_unique_id are set for each item
+                    # (priority = 1); at learning time, the learner assigns each item a priority from its loss.
+ learner.priority_info_agent = deepcopy(learner.priority_info)
+ learner.priority_info_expert = deepcopy(learner.priority_info)
+ learner.priority_info_agent['priority'] = learner.priority_info['priority'][0:agent_batch_size]
+ learner.priority_info_agent['replay_buffer_idx'] = learner.priority_info['replay_buffer_idx'][
+ 0:agent_batch_size]
+ learner.priority_info_agent['replay_unique_id'] = learner.priority_info['replay_unique_id'][
+ 0:agent_batch_size]
+
+ learner.priority_info_expert['priority'] = learner.priority_info['priority'][agent_batch_size:]
+ learner.priority_info_expert['replay_buffer_idx'] = learner.priority_info['replay_buffer_idx'][
+ agent_batch_size:]
+ learner.priority_info_expert['replay_unique_id'] = learner.priority_info['replay_unique_id'][
+ agent_batch_size:]
+
+                    # Agent data and expert (demo) data update their priorities separately.
+ replay_buffer.update(learner.priority_info_agent)
+ expert_buffer.update(learner.priority_info_expert)
+ else:
+ # Learner will train ``update_per_collect`` times in one iteration.
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is None:
+ # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
+ logging.warning(
+ "Replay buffer's data can only train for {} steps. ".format(i) +
+ "You can modify data collect config, e.g. increasing n_sample, n_episode."
+ )
+ break
+ learner.train(train_data, collector.envstep)
+ if learner.policy.get_attribute('priority'):
+ replay_buffer.update(learner.priority_info)
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+
+ # Learner's after_run hook.
+ learner.call_hook('after_run')
+ return policy
diff --git a/DI-engine/ding/entry/serial_entry_reward_model_offpolicy.py b/DI-engine/ding/entry/serial_entry_reward_model_offpolicy.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1b4c004b3427f1b18d5eefae5f413172b9d45c1
--- /dev/null
+++ b/DI-engine/ding/entry/serial_entry_reward_model_offpolicy.py
@@ -0,0 +1,139 @@
+from typing import Union, Optional, List, Any, Tuple
+import os
+import numpy as np
+import torch
+from ditk import logging
+from functools import partial
+from tensorboardX import SummaryWriter
+from copy import deepcopy
+
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
+ create_serial_collector
+from ding.config import read_config, compile_config
+from ding.policy import create_policy
+from ding.reward_model import create_reward_model
+from ding.utils import set_pkg_seed
+from .utils import random_collect
+
+
+def serial_pipeline_reward_model_offpolicy(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ seed: int = 0,
+ env_setting: Optional[List[Any]] = None,
+ model: Optional[torch.nn.Module] = None,
+ max_train_iter: Optional[int] = int(1e10),
+ max_env_step: Optional[int] = int(1e10),
+) -> 'Policy': # noqa
+ """
+ Overview:
+ Serial pipeline entry for off-policy RL with reward model.
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
+ ``BaseEnv`` subclass, collector env config, and evaluator env config.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+ - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+ Returns:
+ - policy (:obj:`Policy`): Converged policy.
+ """
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ else:
+ cfg, create_cfg = deepcopy(input_cfg)
+ create_cfg.policy.type = create_cfg.policy.type + '_command'
+ env_fn = None if env_setting is None else env_setting[0]
+ cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
+ # Create main components: env, policy
+ if env_setting is None:
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ else:
+ env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+ collector_env.seed(cfg.seed)
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
+
+ # Create worker components: learner, collector, evaluator, replay buffer, commander.
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = create_serial_collector(
+ cfg.policy.collect.collector,
+ env=collector_env,
+ policy=policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
+ commander = BaseSerialCommander(
+ cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
+ )
+ reward_model = create_reward_model(cfg.reward_model, policy.collect_mode.get_attribute('device'), tb_logger)
+
+ # ==========
+ # Main loop
+ # ==========
+ # Learner's before_run hook.
+ learner.call_hook('before_run')
+
+ # Accumulate plenty of data at the beginning of training.
+ if cfg.policy.get('random_collect_size', 0) > 0:
+ random_collect(cfg.policy, policy, collector, collector_env, commander, replay_buffer)
+ count = 0
+ best_return = -np.inf
+ while True:
+ collect_kwargs = commander.step()
+ # Evaluate policy performance
+ if evaluator.should_eval(learner.train_iter):
+ stop, eval_info = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ eval_return_mean = np.mean(eval_info['eval_episode_return'])
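+            # Keep a checkpoint of the reward model whenever the evaluation return reaches a new best.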
+ if eval_return_mean >= best_return:
+ reward_model.save(path=cfg.exp_name, name='best')
+ best_return = eval_return_mean
+ if stop:
+ break
+ new_data_count, target_new_data_count = 0, cfg.reward_model.get('target_new_data_count', 1)
+ while new_data_count < target_new_data_count:
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+ new_data_count += len(new_data)
+ # collect data for reward_model training
+ reward_model.collect_data(new_data)
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ # update reward_model
+ reward_model.train()
+        # Clear the reward model's collected data every ``clear_buffer_per_iters`` iterations.
+ if count % cfg.reward_model.clear_buffer_per_iters == 0:
+ reward_model.clear_data()
+ # Learn policy from collected data
+ for i in range(cfg.policy.learn.update_per_collect):
+ # Learner will train ``update_per_collect`` times in one iteration.
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is None:
+ # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
+ logging.warning(
+ "Replay buffer's data can only train for {} steps. ".format(i) +
+ "You can modify data collect config, e.g. increasing n_sample, n_episode."
+ )
+ break
+            # Replace the reward in train_data with the reward estimated by the reward model.
+ train_data_augmented = reward_model.estimate(train_data)
+ learner.train(train_data_augmented, collector.envstep)
+ if learner.policy.get_attribute('priority'):
+ replay_buffer.update(learner.priority_info)
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+ count += 1
+
+ # Learner's after_run hook.
+ learner.call_hook('after_run')
+ reward_model.save(path=cfg.exp_name, name='last')
+ return policy
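+
+
+# Illustrative usage sketch (the import path below is a placeholder; any dizoo config pair whose main
+# config provides a ``reward_model`` section can be used):
+#
+#     from some_dizoo_reward_model_config import main_config, create_config  # hypothetical
+#     serial_pipeline_reward_model_offpolicy([main_config, create_config], seed=0, max_env_step=int(1e5))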
diff --git a/DI-engine/ding/entry/serial_entry_reward_model_onpolicy.py b/DI-engine/ding/entry/serial_entry_reward_model_onpolicy.py
new file mode 100644
index 0000000000000000000000000000000000000000..b01864f98f894f2c5e4c194cb2859ce1a16ff94f
--- /dev/null
+++ b/DI-engine/ding/entry/serial_entry_reward_model_onpolicy.py
@@ -0,0 +1,137 @@
+from typing import Union, Optional, List, Any, Tuple
+import os
+import numpy as np
+import torch
+from ditk import logging
+from functools import partial
+from tensorboardX import SummaryWriter
+from copy import deepcopy
+
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
+ create_serial_collector
+from ding.config import read_config, compile_config
+from ding.policy import create_policy
+from ding.reward_model import create_reward_model
+from ding.utils import set_pkg_seed
+from .utils import random_collect
+
+
+def serial_pipeline_reward_model_onpolicy(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ seed: int = 0,
+ env_setting: Optional[List[Any]] = None,
+ model: Optional[torch.nn.Module] = None,
+ max_train_iter: Optional[int] = int(1e10),
+ max_env_step: Optional[int] = int(1e10),
+) -> 'Policy': # noqa
+ """
+ Overview:
+ Serial pipeline entry for on-policy RL with reward model.
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
+ ``BaseEnv`` subclass, collector env config, and evaluator env config.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+ - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+ Returns:
+ - policy (:obj:`Policy`): Converged policy.
+ """
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ else:
+ cfg, create_cfg = deepcopy(input_cfg)
+ create_cfg.policy.type = create_cfg.policy.type + '_command'
+ env_fn = None if env_setting is None else env_setting[0]
+ cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
+ # Create main components: env, policy
+ if env_setting is None:
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ else:
+ env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+ collector_env.seed(cfg.seed)
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
+
+ # Create worker components: learner, collector, evaluator, replay buffer, commander.
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = create_serial_collector(
+ cfg.policy.collect.collector,
+ env=collector_env,
+ policy=policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
+ commander = BaseSerialCommander(
+ cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
+ )
+ reward_model = create_reward_model(cfg.reward_model, policy.collect_mode.get_attribute('device'), tb_logger)
+
+ # ==========
+ # Main loop
+ # ==========
+ # Learner's before_run hook.
+ learner.call_hook('before_run')
+
+ # Accumulate plenty of data at the beginning of training.
+ if cfg.policy.get('random_collect_size', 0) > 0:
+ random_collect(cfg.policy, policy, collector, collector_env, commander, replay_buffer)
+ count = 0
+ best_return = -np.inf
+ while True:
+ collect_kwargs = commander.step()
+ # Evaluate policy performance
+ if evaluator.should_eval(learner.train_iter):
+ stop, eval_info = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ eval_return_mean = np.mean(eval_info['eval_episode_return'])
+ if eval_return_mean >= best_return:
+ reward_model.save(path=cfg.exp_name, name='best')
+ best_return = eval_return_mean
+ if stop:
+ break
+ new_data_count, target_new_data_count = 0, cfg.reward_model.get('target_new_data_count', 1)
+ while new_data_count < target_new_data_count:
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+ new_data_count += len(new_data)
+ # collect data for reward_model training
+ reward_model.collect_data(new_data)
+ # update reward_model
+ reward_model.train()
+ if count % cfg.reward_model.clear_buffer_per_iters == 0:
+ reward_model.clear_data()
+ # Learn policy from collected data
+ for i in range(cfg.policy.learn.update_per_collect):
+ # Learner will train ``update_per_collect`` times in one iteration.
+ train_data = new_data
+ if train_data is None:
+ # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
+ logging.warning(
+ "Replay buffer's data can only train for {} steps. ".format(i) +
+ "You can modify data collect config, e.g. increasing n_sample, n_episode."
+ )
+ break
+            # Replace the reward in train_data with the reward estimated by the reward model.
+ train_data_augmented = reward_model.estimate(train_data)
+ learner.train(train_data_augmented, collector.envstep)
+ if learner.policy.get_attribute('priority'):
+ replay_buffer.update(learner.priority_info)
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+ count += 1
+
+ # Learner's after_run hook.
+ learner.call_hook('after_run')
+ reward_model.save(path=cfg.exp_name, name='last')
+ return policy
diff --git a/DI-engine/ding/entry/serial_entry_sqil.py b/DI-engine/ding/entry/serial_entry_sqil.py
new file mode 100644
index 0000000000000000000000000000000000000000..6af9fad31caa18599f964b9b5f7234d942e2ab14
--- /dev/null
+++ b/DI-engine/ding/entry/serial_entry_sqil.py
@@ -0,0 +1,169 @@
+from typing import Union, Optional, List, Any, Tuple
+import os
+import torch
+from ditk import logging
+from functools import partial
+from tensorboardX import SummaryWriter
+from copy import deepcopy
+
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
+ create_serial_collector
+from ding.config import read_config, compile_config
+from ding.policy import create_policy
+from ding.utils import set_pkg_seed
+from .utils import random_collect
+
+
+def serial_pipeline_sqil(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ expert_cfg: Union[str, Tuple[dict, dict]],
+ seed: int = 0,
+ env_setting: Optional[List[Any]] = None,
+ model: Optional[torch.nn.Module] = None,
+ expert_model: Optional[torch.nn.Module] = None,
+ max_train_iter: Optional[int] = int(1e10),
+ max_env_step: Optional[int] = int(1e10),
+) -> 'Policy': # noqa
+ """
+ Overview:
+        Serial pipeline SQIL entry: this pipeline implements SQIL in DI-engine. For now, the\
+            following envs are supported: CartPole, LunarLander, Pong, SpaceInvaders, Qbert.\
+            The demonstration data comes from an expert model: a well-trained expert policy\
+            is used to generate demonstration data online.
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+            ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+        - expert_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Expert config in dict type, \
+            in the same format as ``input_cfg``; used to build the expert policy and collector.
+ - seed (:obj:`int`): Random seed.
+ - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
+ ``BaseEnv`` subclass, collector env config, and evaluator env config.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - expert_model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.\
+ The default model is DQN(**cfg.policy.model)
+ - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+ - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+ Returns:
+ - policy (:obj:`Policy`): Converged policy.
+ """
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ expert_cfg, expert_create_cfg = read_config(expert_cfg)
+ else:
+ cfg, create_cfg = deepcopy(input_cfg)
+ expert_cfg, expert_create_cfg = expert_cfg
+ create_cfg.policy.type = create_cfg.policy.type + '_command'
+ expert_create_cfg.policy.type = expert_create_cfg.policy.type + '_command'
+ env_fn = None if env_setting is None else env_setting[0]
+ cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
+ expert_cfg = compile_config(
+ expert_cfg, seed=seed, env=env_fn, auto=True, create_cfg=expert_create_cfg, save_cfg=True
+ )
+    # The expert config must use the same `n_sample`; the line below avoids having to edit the expert config by hand.
+ expert_cfg.policy.collect.n_sample = cfg.policy.collect.n_sample
+ # Create main components: env, policy
+ if env_setting is None:
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ else:
+ env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ expert_collector_env = create_env_manager(
+ expert_cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg]
+ )
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+ expert_collector_env.seed(cfg.seed)
+ collector_env.seed(cfg.seed)
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ expert_policy = create_policy(expert_cfg.policy, model=expert_model, enable_field=['collect', 'command'])
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
+ expert_policy.collect_mode.load_state_dict(torch.load(cfg.policy.collect.model_path, map_location='cpu'))
+ # Create worker components: learner, collector, evaluator, replay buffer, commander.
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = create_serial_collector(
+ cfg.policy.collect.collector,
+ env=collector_env,
+ policy=policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name
+ )
+ expert_collector = create_serial_collector(
+ expert_cfg.policy.collect.collector,
+ env=expert_collector_env,
+ policy=expert_policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=expert_cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
+ expert_buffer = create_buffer(expert_cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
+ commander = BaseSerialCommander(
+ cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
+ )
+ expert_commander = BaseSerialCommander(
+ expert_cfg.policy.other.commander, learner, expert_collector, evaluator, replay_buffer,
+ expert_policy.command_mode
+    ) # Created only to produce valid collect kwargs (e.g. eps) for the expert's sample collector.
+ expert_collect_kwargs = expert_commander.step()
+ if 'eps' in expert_collect_kwargs:
+ expert_collect_kwargs['eps'] = -1
+ # ==========
+ # Main loop
+ # ==========
+ # Learner's before_run hook.
+ learner.call_hook('before_run')
+
+ # Accumulate plenty of data at the beginning of training.
+ if cfg.policy.get('random_collect_size', 0) > 0:
+ random_collect(cfg.policy, policy, collector, collector_env, commander, replay_buffer)
+ if cfg.policy.get('expert_random_collect_size', 0) > 0:
+ random_collect(
+ expert_cfg.policy, expert_policy, expert_collector, expert_collector_env, expert_commander, expert_buffer
+ )
+ while True:
+ collect_kwargs = commander.step()
+ # Evaluate policy performance
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Collect data by default config n_sample/n_episode
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+ expert_data = expert_collector.collect(
+ train_iter=learner.train_iter, policy_kwargs=expert_collect_kwargs
+ ) # policy_kwargs={'eps': -1}
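+        # SQIL reward relabeling: every agent transition gets a constant reward of 0 and every expert
+        # transition a constant reward of 1 (repeated ``nstep`` times to match the n-step reward shape).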
+ for i in range(len(new_data)):
+ device_1 = new_data[i]['obs'].device
+ device_2 = expert_data[i]['obs'].device
+ new_data[i]['reward'] = torch.zeros(cfg.policy.nstep).to(device_1)
+ expert_data[i]['reward'] = torch.ones(cfg.policy.nstep).to(device_2)
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ expert_buffer.push(expert_data, cur_collector_envstep=collector.envstep)
+ # Learn policy from collected data
+ for i in range(cfg.policy.learn.update_per_collect):
+ # Learner will train ``update_per_collect`` times in one iteration.
+ train_data = replay_buffer.sample((learner.policy.get_attribute('batch_size')) // 2, learner.train_iter)
+ train_data_demonstration = expert_buffer.sample(
+ (learner.policy.get_attribute('batch_size')) // 2, learner.train_iter
+ )
+ if train_data is None:
+ # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
+ logging.warning(
+ "Replay buffer's data can only train for {} steps. ".format(i) +
+ "You can modify data collect config, e.g. increasing n_sample, n_episode."
+ )
+ break
+ train_data = train_data + train_data_demonstration
+ learner.train(train_data, collector.envstep)
+ if learner.policy.get_attribute('priority'):
+ replay_buffer.update(learner.priority_info)
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+
+ # Learner's after_run hook.
+ learner.call_hook('after_run')
+ return policy
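+
+
+# Illustrative usage sketch (config names and checkpoint path below are placeholders): SQIL takes both a
+# learner config and an expert config, and loads the expert weights from ``cfg.policy.collect.model_path``:
+#
+#     from some_dizoo_sqil_config import main_config, create_config  # hypothetical
+#     from some_dizoo_dqn_config import expert_main_config, expert_create_config  # hypothetical
+#     main_config.policy.collect.model_path = './expert_exp/ckpt/ckpt_best.pth.tar'  # placeholder path
+#     serial_pipeline_sqil([main_config, create_config], [expert_main_config, expert_create_config], seed=0)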
diff --git a/DI-engine/ding/entry/serial_entry_td3_vae.py b/DI-engine/ding/entry/serial_entry_td3_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..505b15d3e7c579e0bfec43d0bc12940f75767d21
--- /dev/null
+++ b/DI-engine/ding/entry/serial_entry_td3_vae.py
@@ -0,0 +1,207 @@
+from typing import Union, Optional, List, Any, Tuple
+import os
+import torch
+from ditk import logging
+import copy
+from functools import partial
+from tensorboardX import SummaryWriter
+from copy import deepcopy
+
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
+ create_serial_collector
+from ding.config import read_config, compile_config
+from ding.policy import create_policy
+from ding.utils import set_pkg_seed
+from .utils import random_collect, mark_not_expert, mark_warm_up
+
+
+def serial_pipeline_td3_vae(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ seed: int = 0,
+ env_setting: Optional[List[Any]] = None,
+ model: Optional[torch.nn.Module] = None,
+ max_train_iter: Optional[int] = int(1e10),
+ max_env_step: Optional[int] = int(1e10),
+) -> 'Policy': # noqa
+ """
+ Overview:
+        Serial pipeline entry for TD3 with VAE-based latent actions.
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
+ ``BaseEnv`` subclass, collector env config, and evaluator env config.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+ - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+ Returns:
+ - policy (:obj:`Policy`): Converged policy.
+ """
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ else:
+ cfg, create_cfg = deepcopy(input_cfg)
+ create_cfg.policy.type = create_cfg.policy.type + '_command'
+ env_fn = None if env_setting is None else env_setting[0]
+ cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
+ # Create main components: env, policy
+ if env_setting is None:
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ else:
+ env_fn, collector_env_cfg, evaluator_env_cfg = env_setting
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+ collector_env.seed(cfg.seed)
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
+
+ # Create worker components: learner, collector, evaluator, replay buffer, commander.
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = create_serial_collector(
+ cfg.policy.collect.collector,
+ env=collector_env,
+ policy=policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
+ replay_buffer_recent = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
+
+ commander = BaseSerialCommander(
+ cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
+ )
+ # ==========
+ # Main loop
+ # ==========
+ # Learner's before_run hook.
+ learner.call_hook('before_run')
+
+ # Accumulate plenty of data at the beginning of training.
+ if cfg.policy.get('random_collect_size', 0) > 0:
+ # backup
+ # if cfg.policy.get('transition_with_policy_data', False):
+ # collector.reset_policy(policy.collect_mode)
+ # else:
+ # action_space = collector_env.action_space
+ # random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
+ # collector.reset_policy(random_policy)
+ # collect_kwargs = commander.step()
+ # new_data = collector.collect(n_sample=cfg.policy.random_collect_size, policy_kwargs=collect_kwargs)
+ # for item in new_data:
+ # item['warm_up'] = True
+ # replay_buffer.push(new_data, cur_collector_envstep=0)
+ # collector.reset_policy(policy.collect_mode)
+ # postprocess_data_fn = lambda x: mark_warm_up(mark_not_expert(x))
+ random_collect(
+ cfg.policy,
+ policy,
+ collector,
+ collector_env,
+ commander,
+ replay_buffer,
+ postprocess_data_fn=lambda x: mark_warm_up(mark_not_expert(x)) # postprocess_data_fn
+ )
+ # warm_up
+ # Learn policy from collected data
+ for i in range(cfg.policy.learn.warm_up_update):
+        # Warm-up phase: the learner trains up to ``warm_up_update`` times before regular collection starts.
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is None:
+ # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
+ logging.warning(
+ "Replay buffer's data can only train for {} steps. ".format(i) +
+ "You can modify data collect config, e.g. increasing n_sample, n_episode."
+ )
+ break
+ learner.train(train_data, collector.envstep)
+
+ if learner.policy.get_attribute('priority'):
+ replay_buffer.update(learner.priority_info)
+ replay_buffer.clear() # NOTE
+
+    # NOTE: when collector_env_num > 1, self._traj_buffer[env_id] may still be non-empty after the random
+    # collect phase, since it is only cleared when "timestep.done or len(self._traj_buffer[env_id]) ==
+    # self._traj_len" holds. For this algorithm, data left in self._traj_buffer[env_id] (collected with
+    # latent_action=False) cannot be used in the rl_vae phase, so the collector is reset here.
+ collector.reset(policy.collect_mode)
+
+ count = 0
+ while True:
+ collect_kwargs = commander.step()
+ # Evaluate policy performance
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Collect data by default config n_sample/n_episode
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+ for item in new_data:
+ item['warm_up'] = False
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ replay_buffer_recent.push(copy.deepcopy(new_data), cur_collector_envstep=collector.envstep)
+
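+        # Scheduling note (as written): with rl_vae_update_circle = N, the rl branch below runs on every
+        # collect iteration (count % N always falls in range(0, N)), while the vae branch runs only when
+        # count % N == N - 1, i.e. once per circle of N iterations, after which replay_buffer_recent is cleared.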
+ # rl phase
+ if count % cfg.policy.learn.rl_vae_update_circle in range(0, cfg.policy.learn.rl_vae_update_circle):
+ # Learn policy from collected data
+ for i in range(cfg.policy.learn.update_per_collect_rl):
+                # Learner will train ``update_per_collect_rl`` times in one iteration.
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is not None:
+ for item in train_data:
+ item['rl_phase'] = True
+ item['vae_phase'] = False
+ if train_data is None:
+ # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
+ logging.warning(
+ "Replay buffer's data can only train for {} steps. ".format(i) +
+ "You can modify data collect config, e.g. increasing n_sample, n_episode."
+ )
+ break
+ learner.train(train_data, collector.envstep)
+ if learner.policy.get_attribute('priority'):
+ replay_buffer.update(learner.priority_info)
+
+ # vae phase
+ if count % cfg.policy.learn.rl_vae_update_circle in range(cfg.policy.learn.rl_vae_update_circle - 1,
+ cfg.policy.learn.rl_vae_update_circle):
+ for i in range(cfg.policy.learn.update_per_collect_vae):
+                # Learner will train ``update_per_collect_vae`` times in one iteration.
+ # TODO(pu): different sample style
+ train_data_history = replay_buffer.sample(
+ int(learner.policy.get_attribute('batch_size') / 2), learner.train_iter
+ )
+ train_data_recent = replay_buffer_recent.sample(
+ int(learner.policy.get_attribute('batch_size') / 2), learner.train_iter
+ )
+ train_data = train_data_history + train_data_recent
+
+ if train_data is not None:
+ for item in train_data:
+ item['rl_phase'] = False
+ item['vae_phase'] = True
+ if train_data is None:
+ # It is possible that replay buffer's data count is too few to train ``update_per_collect`` times
+ logging.warning(
+ "Replay buffer's data can only train for {} steps. ".format(i) +
+ "You can modify data collect config, e.g. increasing n_sample, n_episode."
+ )
+ break
+ learner.train(train_data, collector.envstep)
+ if learner.policy.get_attribute('priority'):
+ replay_buffer.update(learner.priority_info)
+ replay_buffer_recent.clear() # NOTE
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+ count += 1
+
+ # Learner's after_run hook.
+ learner.call_hook('after_run')
+ return policy
diff --git a/DI-engine/ding/entry/tests/config/agconfig.yaml b/DI-engine/ding/entry/tests/config/agconfig.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..ffb5327574fb0176680409d382419b924c76f536
--- /dev/null
+++ b/DI-engine/ding/entry/tests/config/agconfig.yaml
@@ -0,0 +1,27 @@
+apiVersion: diengine.opendilab.org/v1alpha1
+kind: AggregatorConfig
+metadata:
+ name: aggregator-config
+ namespace: di-system
+spec:
+ aggregator:
+ template:
+ spec:
+ containers:
+ - name: di-container
+ image: diorchestrator/ding:v0.1.1
+ imagePullPolicy: IfNotPresent
+ env:
+ - name: PYTHONUNBUFFERED
+ value: "1"
+ command: ["/bin/bash", "-c",]
+ args:
+ - |
+ # if code has been changed in the mount path, we have to reinstall cli
+ # pip install --no-cache-dir -e .;
+ # pip install --no-cache-dir -e .[common_env]
+
+ ding -m dist --module learner_aggregator
+ ports:
+ - name: di-port
+ containerPort: 22270
diff --git a/DI-engine/ding/entry/tests/config/dijob-cartpole.yaml b/DI-engine/ding/entry/tests/config/dijob-cartpole.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..f98b56517b75fbd9c8dfc4706e61f2a89175d139
--- /dev/null
+++ b/DI-engine/ding/entry/tests/config/dijob-cartpole.yaml
@@ -0,0 +1,200 @@
+apiVersion: diengine.opendilab.org/v1alpha1
+kind: DIJob
+metadata:
+ name: cartpole-dqn
+ labels:
+ run-dijob-type: test
+spec:
+ group: xxx
+ priorityClassName: ""
+ cleanPodPolicy: "Running"
+ volumes:
+ - name: cache-volume
+ emptyDir:
+ medium: Memory
+ sizeLimit: 128Mi
+ - name: work-dir
+ hostPath:
+ path: /data/di-engine
+ coordinator:
+ template:
+ spec:
+ containers:
+ - name: di-container
+ image: diorchestrator/ding:v0.1.1
+ imagePullPolicy: IfNotPresent
+ env:
+ - name: PYTHONUNBUFFERED
+ value: "1"
+ command: ["/bin/bash", "-c",]
+ args:
+ - |
+          cat << EOF > cartpole_dqn_config_k8s.py
+ from easydict import EasyDict
+
+ cartpole_dqn_config = dict(
+ exp_name='cartpole_dqn',
+ env=dict(
+ collector_env_num=8,
+ collector_episode_num=2,
+ evaluator_env_num=5,
+ evaluator_episode_num=1,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ dueling=True,
+ ),
+ nstep=3,
+ discount_factor=0.97,
+ learn=dict(
+ batch_size=32,
+ learning_rate=0.001,
+ learner=dict(
+ learner_num=1,
+ send_policy_freq=1,
+ ),
+ ),
+ collect=dict(
+ n_sample=16,
+ collector=dict(
+ collector_num=2,
+ update_policy_second=3,
+ ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=50, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=100000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=100000,
+ enable_track_used_data=False,
+ ),
+ commander=dict(
+ collector_task_space=2,
+ learner_task_space=1,
+ eval_interval=5,
+ ),
+ ),
+ ),
+ )
+ cartpole_dqn_config = EasyDict(cartpole_dqn_config)
+ main_config = cartpole_dqn_config
+
+ cartpole_dqn_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='dqn_command'),
+ learner=dict(type='base', import_names=['ding.worker.learner.base_learner']),
+ collector=dict(
+ type='zergling',
+ import_names=['ding.worker.collector.zergling_parallel_collector'],
+ ),
+ commander=dict(
+ type='solo',
+ import_names=['ding.worker.coordinator.solo_parallel_commander'],
+ ),
+ comm_learner=dict(
+ type='flask_fs',
+ import_names=['ding.worker.learner.comm.flask_fs_learner'],
+ ),
+ comm_collector=dict(
+ type='flask_fs',
+ import_names=['ding.worker.collector.comm.flask_fs_collector'],
+ ),
+ )
+ cartpole_dqn_create_config = EasyDict(cartpole_dqn_create_config)
+ create_config = cartpole_dqn_create_config
+
+ cartpole_dqn_system_config = dict(
+ coordinator=dict(
+ operator_server=dict(
+ system_addr='di-server.di-system:8080',
+ api_version='/v1alpha1',
+ init_replicas_request=dict(
+ collectors={
+ "replicas": 2,
+ },
+ learners={
+ "gpus": "0",
+ "replicas": 1,
+ },
+ ),
+ collector_target_num=2,
+ learner_target_num=1,
+ ),
+ ),
+ path_data='./{}/data'.format(main_config.exp_name),
+ path_policy='./{}/policy'.format(main_config.exp_name),
+ communication_mode='auto',
+ learner_gpu_num=1,
+ )
+ cartpole_dqn_system_config = EasyDict(cartpole_dqn_system_config)
+ system_config = cartpole_dqn_system_config
+
+ if __name__ == '__main__':
+ from ding.entry.parallel_entry import parallel_pipeline
+ parallel_pipeline([main_config, create_config, system_config], seed=9)
+ EOF
+
+ ding -m dist --module config -P k8s -c ./cartpole_dqn_config_k8s.py -s 0;
+ ding -m dist --module coordinator -c /ding/cartpole_dqn_config_k8s.py.pkl -s 0 -cdp $COORDINATOR_PORT
+ ports:
+ - name: di-port
+ containerPort: 22270
+ volumeMounts:
+ - name: work-dir
+ mountPath: /ding
+ collector:
+ template:
+ spec:
+ containers:
+ - name: di-container
+ image: diorchestrator/ding:v0.1.1
+ imagePullPolicy: IfNotPresent
+ env:
+ - name: PYTHONUNBUFFERED
+ value: "1"
+ command: ["/bin/bash", "-c",]
+ args:
+ - |
+ ding -m dist --module collector -c /ding/cartpole_dqn_config_k8s.py.pkl -s 0 -clp $COLLECTOR_PORT
+ ports:
+ - name: di-port
+ containerPort: 22270
+ volumeMounts:
+ - name: work-dir
+ mountPath: /ding
+ learner:
+ template:
+ spec:
+ containers:
+ - name: di-container
+ image: diorchestrator/ding:v0.1.1
+ imagePullPolicy: IfNotPresent
+ env:
+ - name: PYTHONUNBUFFERED
+ value: "1"
+ command: ["/bin/bash", "-c",]
+ args:
+ - |
+ ding -m dist --module spawn_learner -c /ding/cartpole_dqn_config_k8s.py.pkl -s 0 -lp $LEARNER_PORT
+ ports:
+ - name: di-port
+ containerPort: 22270
+ volumeMounts:
+ - name: cache-volume
+ mountPath: /dev/shm
+ - name: work-dir
+ mountPath: /ding
\ No newline at end of file
diff --git a/DI-engine/ding/entry/tests/config/k8s-config.yaml b/DI-engine/ding/entry/tests/config/k8s-config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c8aceb99ff54a67ff3ecac2e725b4d407dea9319
--- /dev/null
+++ b/DI-engine/ding/entry/tests/config/k8s-config.yaml
@@ -0,0 +1,6 @@
+type: k3s # k3s or local
+name: di-cluster
+servers: 1 # # of k8s masters
+agents: 0 # # of k8s nodes
+preload_images:
+- diorchestrator/ding:v0.1.1 # di-engine image for training should be preloaded
diff --git a/DI-engine/ding/entry/tests/test_application_entry.py b/DI-engine/ding/entry/tests/test_application_entry.py
new file mode 100644
index 0000000000000000000000000000000000000000..9276d9e6e5c7ab0274d453ffddaf9cdcbae5bcd1
--- /dev/null
+++ b/DI-engine/ding/entry/tests/test_application_entry.py
@@ -0,0 +1,94 @@
+from copy import deepcopy
+import pytest
+import os
+import pickle
+
+from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, \
+ cartpole_ppo_offpolicy_create_config # noqa
+from dizoo.classic_control.cartpole.config.cartpole_trex_offppo_config import cartpole_trex_offppo_config,\
+ cartpole_trex_offppo_create_config
+from dizoo.classic_control.cartpole.envs import CartPoleEnv
+from ding.entry import serial_pipeline, eval, collect_demo_data
+from ding.config import compile_config
+from ding.entry.application_entry import collect_episodic_demo_data, episode_to_transitions
+
+
+@pytest.fixture(scope='module')
+def setup_state_dict():
+ config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
+ try:
+ policy = serial_pipeline(config, seed=0)
+ except Exception:
+ assert False, 'Serial pipeline failure'
+ state_dict = {
+ 'eval': policy.eval_mode.state_dict(),
+ 'collect': policy.collect_mode.state_dict(),
+ }
+ return state_dict
+
+
+@pytest.mark.unittest
+class TestApplication:
+
+ def test_eval(self, setup_state_dict):
+ cfg_for_stop_value = compile_config(
+ cartpole_ppo_offpolicy_config, auto=True, create_cfg=cartpole_ppo_offpolicy_create_config
+ )
+ stop_value = cfg_for_stop_value.env.stop_value
+ config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
+ episode_return = eval(config, seed=0, state_dict=setup_state_dict['eval'])
+ assert episode_return >= stop_value
+ config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
+ episode_return = eval(
+ config,
+ seed=0,
+ env_setting=[CartPoleEnv, None, [{} for _ in range(5)]],
+ state_dict=setup_state_dict['eval']
+ )
+ assert episode_return >= stop_value
+
+ def test_collect_demo_data(self, setup_state_dict):
+ config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
+ collect_count = 16
+ expert_data_path = './expert.data'
+ collect_demo_data(
+ config,
+ seed=0,
+ state_dict=setup_state_dict['collect'],
+ collect_count=collect_count,
+ expert_data_path=expert_data_path
+ )
+ with open(expert_data_path, 'rb') as f:
+ exp_data = pickle.load(f)
+ assert isinstance(exp_data, list)
+ assert isinstance(exp_data[0], dict)
+
+ def test_collect_episodic_demo_data(self, setup_state_dict):
+ config = deepcopy(cartpole_trex_offppo_config), deepcopy(cartpole_trex_offppo_create_config)
+ config[0].exp_name = 'cartpole_trex_offppo_episodic'
+ collect_count = 16
+ if not os.path.exists('./test_episode'):
+ os.mkdir('./test_episode')
+ expert_data_path = './test_episode/expert.data'
+ collect_episodic_demo_data(
+ config,
+ seed=0,
+ state_dict=setup_state_dict['collect'],
+ expert_data_path=expert_data_path,
+ collect_count=collect_count,
+ )
+ with open(expert_data_path, 'rb') as f:
+ exp_data = pickle.load(f)
+ assert isinstance(exp_data, list)
+ assert isinstance(exp_data[0][0], dict)
+
+ def test_episode_to_transitions(self, setup_state_dict):
+ self.test_collect_episodic_demo_data(setup_state_dict)
+ expert_data_path = './test_episode/expert.data'
+ episode_to_transitions(data_path=expert_data_path, expert_data_path=expert_data_path, nstep=3)
+ with open(expert_data_path, 'rb') as f:
+ exp_data = pickle.load(f)
+ assert isinstance(exp_data, list)
+ assert isinstance(exp_data[0], dict)
+ os.popen('rm -rf ./test_episode/expert.data ckpt* log')
+ os.popen('rm -rf ./test_episode')
diff --git a/DI-engine/ding/entry/tests/test_application_entry_trex_collect_data.py b/DI-engine/ding/entry/tests/test_application_entry_trex_collect_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5cb3c16b96b5ce6b9f863317879327a65f2f497
--- /dev/null
+++ b/DI-engine/ding/entry/tests/test_application_entry_trex_collect_data.py
@@ -0,0 +1,65 @@
+from easydict import EasyDict
+import pytest
+from copy import deepcopy
+import os
+from itertools import product
+
+import torch
+
+from dizoo.classic_control.cartpole.config.cartpole_trex_offppo_config import cartpole_trex_offppo_config,\
+ cartpole_trex_offppo_create_config
+from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config,\
+ cartpole_ppo_offpolicy_create_config
+from ding.entry.application_entry_trex_collect_data import collect_episodic_demo_data_for_trex, trex_collecting_data
+from ding.entry import serial_pipeline
+
+
+@pytest.mark.unittest
+def test_collect_episodic_demo_data_for_trex():
+ exp_name = "test_collect_episodic_demo_data_for_trex_expert"
+ expert_policy_state_dict_path = os.path.join(exp_name, 'expert_policy.pth.tar')
+ config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
+ config[0].exp_name = exp_name
+ expert_policy = serial_pipeline(config, seed=0)
+ torch.save(expert_policy.collect_mode.state_dict(), expert_policy_state_dict_path)
+
+ exp_name = "test_collect_episodic_demo_data_for_trex_collect"
+ config = [deepcopy(cartpole_trex_offppo_config), deepcopy(cartpole_trex_offppo_create_config)]
+ config[0].exp_name = exp_name
+ exp_data = collect_episodic_demo_data_for_trex(
+ config,
+ seed=0,
+ state_dict_path=expert_policy_state_dict_path,
+ collect_count=1,
+ rank=1,
+ )
+ assert isinstance(exp_data, list)
+ assert isinstance(exp_data[0][0], dict)
+ os.popen('rm -rf test_collect_episodic_demo_data_for_trex*')
+
+
+@pytest.mark.unittest
+def test_trex_collecting_data():
+ expert_policy_dir = 'test_trex_collecting_data_expert'
+ config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
+ config[0].exp_name = expert_policy_dir
+ config[0].policy.learn.learner.hook.save_ckpt_after_iter = 100
+ serial_pipeline(config, seed=0)
+
+ args = EasyDict(
+ {
+ 'cfg': [deepcopy(cartpole_trex_offppo_config),
+ deepcopy(cartpole_trex_offppo_create_config)],
+ 'seed': 0,
+ 'device': 'cpu'
+ }
+ )
+ exp_name = 'test_trex_collecting_data_collect'
+ args.cfg[0].exp_name = exp_name
+ args.cfg[0].reward_model.reward_model_path = os.path.join(exp_name, "reward_model.pth.tar")
+ args.cfg[0].reward_model.expert_model_path = expert_policy_dir
+ args.cfg[0].reward_model.checkpoint_max = 100
+ args.cfg[0].reward_model.checkpoint_step = 100
+ args.cfg[0].reward_model.num_snippets = 100
+ trex_collecting_data(args=args)
+ os.popen('rm -rf test_trex_collecting_data*')
diff --git a/DI-engine/ding/entry/tests/test_cli_ditask.py b/DI-engine/ding/entry/tests/test_cli_ditask.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bb64e5e6e98ec2503f5a0733a159304045ac279
--- /dev/null
+++ b/DI-engine/ding/entry/tests/test_cli_ditask.py
@@ -0,0 +1,37 @@
+from time import sleep
+import pytest
+import pathlib
+import os
+from ding.entry.cli_ditask import _cli_ditask
+
+
+def cli_ditask_main():
+ sleep(0.1)
+
+
+@pytest.mark.unittest
+def test_cli_ditask():
+ kwargs = {
+ "package": os.path.dirname(pathlib.Path(__file__)),
+ "main": "test_cli_ditask.cli_ditask_main",
+ "parallel_workers": 1,
+ "topology": "mesh",
+ "platform": "k8s",
+ "protocol": "tcp",
+ "ports": 50501,
+ "attach_to": "",
+ "address": "127.0.0.1",
+ "labels": "",
+ "node_ids": 0,
+ "mq_type": "nng",
+ "redis_host": "",
+ "redis_port": "",
+ "startup_interval": 1
+ }
+ os.environ["DI_NODES"] = '127.0.0.1'
+ os.environ["DI_RANK"] = '0'
+ try:
+ _cli_ditask(**kwargs)
+ finally:
+ del os.environ["DI_NODES"]
+ del os.environ["DI_RANK"]
diff --git a/DI-engine/ding/entry/tests/test_parallel_entry.py b/DI-engine/ding/entry/tests/test_parallel_entry.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8bd26010f45603af572085b34e009ab43424fe0
--- /dev/null
+++ b/DI-engine/ding/entry/tests/test_parallel_entry.py
@@ -0,0 +1,16 @@
+import pytest
+from copy import deepcopy
+from ding.entry import parallel_pipeline
+from dizoo.classic_control.cartpole.config.parallel.cartpole_dqn_config import main_config, create_config,\
+ system_config
+
+
+# @pytest.mark.unittest
+@pytest.mark.execution_timeout(120.0, method='thread')
+def test_dqn():
+ config = tuple([deepcopy(main_config), deepcopy(create_config), deepcopy(system_config)])
+ config[0].env.stop_value = 9
+ try:
+ parallel_pipeline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
diff --git a/DI-engine/ding/entry/tests/test_random_collect.py b/DI-engine/ding/entry/tests/test_random_collect.py
new file mode 100644
index 0000000000000000000000000000000000000000..d64d2c6e573eb32da8e0b00e4e880362ff8b1ba2
--- /dev/null
+++ b/DI-engine/ding/entry/tests/test_random_collect.py
@@ -0,0 +1,108 @@
+from easydict import EasyDict
+import pytest
+from copy import deepcopy
+from typing import List
+import os
+from functools import partial
+from tensorboardX import SummaryWriter
+
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.worker import BaseSerialCommander, create_buffer, create_serial_collector
+from ding.config import compile_config
+from ding.policy import create_policy
+from ding.utils import set_pkg_seed
+from ding.entry.utils import random_collect, mark_not_expert, mark_warm_up
+from dizoo.classic_control.cartpole.config.cartpole_c51_config import cartpole_c51_config, cartpole_c51_create_config
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('collector_type', ['sample', 'episode'])
+@pytest.mark.parametrize('transition_with_policy_data', [True, False])
+@pytest.mark.parametrize('data_postprocess', [True, False])
+def test_random_collect(collector_type, transition_with_policy_data, data_postprocess):
+
+ def mark_not_expert_episode(ori_data: List[List[dict]]) -> List[List[dict]]:
+ for i in range(len(ori_data)):
+ for j in range(len(ori_data[i])):
+ # Set is_expert flag (expert 1, agent 0)
+ ori_data[i][j]['is_expert'] = 0
+ return ori_data
+
+ def mark_warm_up_episode(ori_data: List[List[dict]]) -> List[List[dict]]:
+ for i in range(len(ori_data)):
+ for j in range(len(ori_data[i])):
+ ori_data[i][j]['warm_up'] = True
+ return ori_data
+
+ RANDOM_COLLECT_SIZE = 8
+ cfg, create_cfg = deepcopy(cartpole_c51_config), deepcopy(cartpole_c51_create_config)
+ cfg.exp_name = "test_cartpole_c51_seed0"
+ create_cfg.policy.type = create_cfg.policy.type + '_command'
+ cfg.policy.random_collect_size = RANDOM_COLLECT_SIZE
+ cfg.policy.transition_with_policy_data = transition_with_policy_data
+ if collector_type == 'episode':
+ cfg.policy.collect.n_sample = None
+ cfg.policy.collect.n_episode = 1
+ create_cfg.replay_buffer = EasyDict(type=collector_type)
+ create_cfg.collector = EasyDict(type=collector_type)
+ cfg = compile_config(cfg, seed=0, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
+
+ # Create main components: env, policy
+ env_fn, collector_env_cfg, _ = get_vec_env_setting(cfg.env)
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ collector_env.seed(cfg.seed)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=None, enable_field=['learn', 'collect', 'eval', 'command'])
+
+ # Create worker components: collector, replay buffer, commander.
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = EasyDict(learn_info=dict(learner_step=10, priority_info='no_info', learner_done=False)) # Fake Learner
+ collector = create_serial_collector(
+ cfg.policy.collect.collector,
+ env=collector_env,
+ policy=policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name
+ )
+ evaluator = None # Fake Evaluator
+ replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
+ commander = BaseSerialCommander(
+ cfg.policy.other.commander, learner, collector, evaluator, replay_buffer, policy.command_mode
+ )
+
+ if data_postprocess:
+ if collector_type == 'sample':
+ postprocess_data_fn = lambda x: mark_warm_up(mark_not_expert(x))
+ else:
+ postprocess_data_fn = lambda x: mark_warm_up_episode(mark_not_expert_episode(x))
+ else:
+ postprocess_data_fn = None
+
+ # Accumulate plenty of data at the beginning of training.
+ if cfg.policy.get('random_collect_size', 0) > 0:
+ random_collect(
+ cfg.policy,
+ policy,
+ collector,
+ collector_env,
+ commander,
+ replay_buffer,
+ postprocess_data_fn=postprocess_data_fn
+ )
+ assert replay_buffer.count() == RANDOM_COLLECT_SIZE
+ if data_postprocess:
+ if collector_type == 'sample':
+ for d in replay_buffer._data[:RANDOM_COLLECT_SIZE]:
+ assert d['is_expert'] == 0
+ assert d['warm_up'] is True
+ else:
+ for e in replay_buffer._data[:RANDOM_COLLECT_SIZE]:
+ for d in e:
+ assert d['is_expert'] == 0
+ assert d['warm_up'] is True
+
+
+if __name__ == '__main__':
+    test_random_collect('sample', True, True)
diff --git a/DI-engine/ding/entry/tests/test_serial_entry.py b/DI-engine/ding/entry/tests/test_serial_entry.py
new file mode 100644
index 0000000000000000000000000000000000000000..d36f6bc7176bdf913b0df618f902abb181fb715c
--- /dev/null
+++ b/DI-engine/ding/entry/tests/test_serial_entry.py
@@ -0,0 +1,716 @@
+import pytest
+from itertools import product
+import time
+import os
+from copy import deepcopy
+
+from ding.entry import serial_pipeline, collect_demo_data, serial_pipeline_offline
+from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
+from dizoo.classic_control.cartpole.config.cartpole_dqn_stdim_config import cartpole_dqn_stdim_config, \
+ cartpole_dqn_stdim_create_config
+from dizoo.classic_control.cartpole.config.cartpole_ppo_config import cartpole_ppo_config, cartpole_ppo_create_config
+from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, \
+ cartpole_ppo_offpolicy_create_config
+from dizoo.classic_control.cartpole.config.cartpole_impala_config import cartpole_impala_config, cartpole_impala_create_config # noqa
+from dizoo.classic_control.cartpole.config.cartpole_rainbow_config import cartpole_rainbow_config, cartpole_rainbow_create_config # noqa
+from dizoo.classic_control.cartpole.config.cartpole_iqn_config import cartpole_iqn_config, cartpole_iqn_create_config # noqa
+from dizoo.classic_control.cartpole.config.cartpole_fqf_config import cartpole_fqf_config, cartpole_fqf_create_config # noqa
+from dizoo.classic_control.cartpole.config.cartpole_c51_config import cartpole_c51_config, cartpole_c51_create_config # noqa
+from dizoo.classic_control.cartpole.config.cartpole_qrdqn_config import cartpole_qrdqn_config, cartpole_qrdqn_create_config # noqa
+from dizoo.classic_control.cartpole.config.cartpole_sqn_config import cartpole_sqn_config, cartpole_sqn_create_config # noqa
+from dizoo.classic_control.cartpole.config.cartpole_ppg_config import cartpole_ppg_config, cartpole_ppg_create_config # noqa
+from dizoo.classic_control.cartpole.config.cartpole_acer_config import cartpole_acer_config, cartpole_acer_create_config # noqa
+from dizoo.classic_control.cartpole.config.cartpole_sac_config import cartpole_sac_config, cartpole_sac_create_config # noqa
+from dizoo.classic_control.cartpole.entry.cartpole_ppg_main import main as ppg_main
+from dizoo.classic_control.cartpole.entry.cartpole_ppo_main import main as ppo_main
+from dizoo.classic_control.cartpole.config.cartpole_r2d2_config import cartpole_r2d2_config, cartpole_r2d2_create_config # noqa
+from dizoo.classic_control.pendulum.config import pendulum_ddpg_config, pendulum_ddpg_create_config
+from dizoo.classic_control.pendulum.config import pendulum_td3_config, pendulum_td3_create_config
+from dizoo.classic_control.pendulum.config import pendulum_sac_config, pendulum_sac_create_config
+from dizoo.classic_control.pendulum.config import pendulum_d4pg_config, pendulum_d4pg_create_config
+from dizoo.bitflip.config import bitflip_her_dqn_config, bitflip_her_dqn_create_config
+from dizoo.bitflip.entry.bitflip_dqn_main import main as bitflip_dqn_main
+from dizoo.petting_zoo.config import ptz_simple_spread_atoc_config, ptz_simple_spread_atoc_create_config # noqa
+from dizoo.petting_zoo.config import ptz_simple_spread_collaq_config, ptz_simple_spread_collaq_create_config # noqa
+from dizoo.petting_zoo.config import ptz_simple_spread_coma_config, ptz_simple_spread_coma_create_config # noqa
+from dizoo.petting_zoo.config import ptz_simple_spread_qmix_config, ptz_simple_spread_qmix_create_config # noqa
+from dizoo.petting_zoo.config import ptz_simple_spread_qtran_config, ptz_simple_spread_qtran_create_config # noqa
+from dizoo.petting_zoo.config import ptz_simple_spread_vdn_config, ptz_simple_spread_vdn_create_config # noqa
+from dizoo.petting_zoo.config import ptz_simple_spread_wqmix_config, ptz_simple_spread_wqmix_create_config # noqa
+from dizoo.petting_zoo.config import ptz_simple_spread_madqn_config, ptz_simple_spread_madqn_create_config # noqa
+from dizoo.league_demo.league_demo_ppo_config import league_demo_ppo_config
+from dizoo.league_demo.selfplay_demo_ppo_main import main as selfplay_main
+from dizoo.league_demo.league_demo_ppo_main import main as league_main
+from dizoo.classic_control.pendulum.config.pendulum_sac_data_generation_config import pendulum_sac_data_genearation_config, pendulum_sac_data_genearation_create_config # noqa
+from dizoo.classic_control.pendulum.config.pendulum_cql_config import pendulum_cql_config, pendulum_cql_create_config # noqa
+from dizoo.classic_control.cartpole.config.cartpole_qrdqn_generation_data_config import cartpole_qrdqn_generation_data_config, cartpole_qrdqn_generation_data_create_config # noqa
+from dizoo.classic_control.cartpole.config.cartpole_cql_config import cartpole_discrete_cql_config, cartpole_discrete_cql_create_config # noqa
+from dizoo.classic_control.cartpole.config.cartpole_dt_config import cartpole_discrete_dt_config, cartpole_discrete_dt_create_config # noqa
+from dizoo.classic_control.pendulum.config.pendulum_td3_data_generation_config import pendulum_td3_generation_config, pendulum_td3_generation_create_config # noqa
+from dizoo.classic_control.pendulum.config.pendulum_td3_bc_config import pendulum_td3_bc_config, pendulum_td3_bc_create_config # noqa
+from dizoo.classic_control.pendulum.config.pendulum_ibc_config import pendulum_ibc_config, pendulum_ibc_create_config
+from dizoo.gym_hybrid.config.gym_hybrid_ddpg_config import gym_hybrid_ddpg_config, gym_hybrid_ddpg_create_config
+from dizoo.gym_hybrid.config.gym_hybrid_pdqn_config import gym_hybrid_pdqn_config, gym_hybrid_pdqn_create_config
+from dizoo.gym_hybrid.config.gym_hybrid_mpdqn_config import gym_hybrid_mpdqn_config, gym_hybrid_mpdqn_create_config
+from dizoo.classic_control.pendulum.config.pendulum_bdq_config import pendulum_bdq_config, pendulum_bdq_create_config # noqa
+from dizoo.classic_control.cartpole.config.cartpole_mdqn_config import cartpole_mdqn_config, cartpole_mdqn_create_config
+
+
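+# Each platform test below builds fresh copies of the algorithm's config, runs the
+# serial pipeline for a single training iteration, and fails the test if the
+# pipeline raises; tests that set exp_name remove their experiment directory afterwards.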
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_dqn():
+ config = [deepcopy(cartpole_dqn_config), deepcopy(cartpole_dqn_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ config[0].exp_name = 'cartpole_dqn_unittest'
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+ finally:
+ os.popen('rm -rf cartpole_dqn_unittest')
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_mdqn():
+ config = [deepcopy(cartpole_mdqn_config), deepcopy(cartpole_mdqn_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ config[0].exp_name = 'cartpole_mdqn_unittest'
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1, dynamic_seed=False)
+ except Exception:
+ assert False, "pipeline fail"
+ finally:
+ os.popen('rm -rf cartpole_mdqn_unittest')
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_bdq():
+ config = [deepcopy(pendulum_bdq_config), deepcopy(pendulum_bdq_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ config[0].exp_name = 'pendulum_bdq_unittest'
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+ finally:
+ os.popen('rm -rf pendulum_bdq_unittest')
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_ddpg():
+ config = [deepcopy(pendulum_ddpg_config), deepcopy(pendulum_ddpg_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+# @pytest.mark.platformtest
+# @pytest.mark.unittest
+def test_hybrid_ddpg():
+ config = [deepcopy(gym_hybrid_ddpg_config), deepcopy(gym_hybrid_ddpg_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+# @pytest.mark.platformtest
+# @pytest.mark.unittest
+def test_hybrid_pdqn():
+ config = [deepcopy(gym_hybrid_pdqn_config), deepcopy(gym_hybrid_pdqn_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+# @pytest.mark.platformtest
+# @pytest.mark.unittest
+def test_hybrid_mpdqn():
+ config = [deepcopy(gym_hybrid_mpdqn_config), deepcopy(gym_hybrid_mpdqn_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_dqn_stdim():
+ config = [deepcopy(cartpole_dqn_stdim_config), deepcopy(cartpole_dqn_stdim_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ config[0].exp_name = 'cartpole_dqn_stdim_unittest'
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+ finally:
+ os.popen('rm -rf cartpole_dqn_stdim_unittest')
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_td3():
+ config = [deepcopy(pendulum_td3_config), deepcopy(pendulum_td3_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_rainbow():
+ config = [deepcopy(cartpole_rainbow_config), deepcopy(cartpole_rainbow_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_iqn():
+ config = [deepcopy(cartpole_iqn_config), deepcopy(cartpole_iqn_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_fqf():
+ config = [deepcopy(cartpole_fqf_config), deepcopy(cartpole_fqf_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_c51():
+ config = [deepcopy(cartpole_c51_config), deepcopy(cartpole_c51_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_qrdqn():
+ config = [deepcopy(cartpole_qrdqn_config), deepcopy(cartpole_qrdqn_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_ppo():
+ config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ config[0].exp_name = 'ppo_offpolicy_unittest'
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_ppo_nstep_return():
+ config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ config[0].policy.nstep_return = True
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_sac():
+ config = [deepcopy(pendulum_sac_config), deepcopy(pendulum_sac_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ config[0].policy.learn.auto_alpha = False
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_sac_auto_alpha():
+ config = [deepcopy(pendulum_sac_config), deepcopy(pendulum_sac_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ config[0].policy.learn.auto_alpha = True
+ config[0].policy.learn.log_space = False
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_sac_log_space():
+ config = [deepcopy(pendulum_sac_config), deepcopy(pendulum_sac_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ config[0].policy.learn.auto_alpha = True
+ config[0].policy.learn.log_space = True
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_discrete_sac():
+ auto_alpha, log_space = True, False
+ config = [deepcopy(cartpole_sac_config), deepcopy(cartpole_sac_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ config[0].policy.learn.auto_alpha = auto_alpha
+ config[0].policy.learn.log_space = log_space
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_discrete_sac_twin_critic():
+ config = [deepcopy(cartpole_sac_config), deepcopy(cartpole_sac_create_config)]
+ config[0].cuda = True
+ config[0].policy.learn.update_per_collect = 1
+ config[0].policy.learn.auto_alpha = True
+ config[0].policy.learn.log_space = True
+ config[0].policy.model.twin_critic = False
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_r2d2():
+ config = [deepcopy(cartpole_r2d2_config), deepcopy(cartpole_r2d2_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=5)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_impala():
+ config = [deepcopy(cartpole_impala_config), deepcopy(cartpole_impala_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_her_dqn():
+ bitflip_her_dqn_config.policy.cuda = False
+ try:
+ bitflip_dqn_main(bitflip_her_dqn_config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_collaq():
+ config = [deepcopy(ptz_simple_spread_collaq_config), deepcopy(ptz_simple_spread_collaq_create_config)]
+ config[0].policy.cuda = False
+ config[0].policy.learn.update_per_collect = 1
+ config[0].policy.collect.n_sample = 100
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+ finally:
+ os.popen('rm -rf log ckpt*')
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_coma():
+ config = [deepcopy(ptz_simple_spread_coma_config), deepcopy(ptz_simple_spread_coma_create_config)]
+ config[0].policy.cuda = False
+ config[0].policy.learn.update_per_collect = 1
+ config[0].policy.collect.n_sample = 100
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+ finally:
+ os.popen('rm -rf log ckpt*')
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_qmix():
+ config = [deepcopy(ptz_simple_spread_qmix_config), deepcopy(ptz_simple_spread_qmix_create_config)]
+ config[0].policy.cuda = False
+ config[0].policy.learn.update_per_collect = 1
+ config[0].policy.collect.n_sample = 100
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+ finally:
+ os.popen('rm -rf log ckpt*')
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_wqmix():
+ config = [deepcopy(ptz_simple_spread_wqmix_config), deepcopy(ptz_simple_spread_wqmix_create_config)]
+ config[0].policy.cuda = False
+ config[0].policy.learn.update_per_collect = 1
+ config[0].policy.collect.n_sample = 100
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+ finally:
+ os.popen('rm -rf log ckpt*')
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_madqn():
+ config = [deepcopy(ptz_simple_spread_madqn_config), deepcopy(ptz_simple_spread_madqn_create_config)]
+ config[0].policy.cuda = False
+ config[0].policy.learn.update_per_collect = 1
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+ finally:
+ os.popen('rm -rf log ckpt*')
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_qtran():
+ config = [deepcopy(ptz_simple_spread_qtran_config), deepcopy(ptz_simple_spread_qtran_create_config)]
+ config[0].policy.cuda = False
+ config[0].policy.learn.update_per_collect = 1
+ config[0].policy.collect.n_sample = 100
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+ finally:
+ os.popen('rm -rf log ckpt*')
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_atoc():
+ config = [deepcopy(ptz_simple_spread_atoc_config), deepcopy(ptz_simple_spread_atoc_create_config)]
+ config[0].policy.cuda = False
+ config[0].policy.collect.n_sample = 100
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+ finally:
+ os.popen('rm -rf log ckpt*')
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_ppg():
+ cartpole_ppg_config.policy.use_cuda = False
+ try:
+ ppg_main(cartpole_ppg_config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_sqn():
+ config = [deepcopy(cartpole_sqn_config), deepcopy(cartpole_sqn_create_config)]
+ config[0].policy.learn.update_per_collect = 8
+ config[0].policy.learn.batch_size = 8
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=2)
+ except Exception:
+ assert False, "pipeline fail"
+ finally:
+ os.popen('rm -rf log ckpt*')
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_selfplay():
+ try:
+ selfplay_main(deepcopy(league_demo_ppo_config), seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_league():
+ try:
+ league_main(deepcopy(league_demo_ppo_config), seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_acer():
+ config = [deepcopy(cartpole_acer_config), deepcopy(cartpole_acer_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_cql():
+ # train expert
+ config = [deepcopy(pendulum_sac_config), deepcopy(pendulum_sac_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ config[0].exp_name = 'sac_unittest'
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+ # collect expert data
+ import torch
+ config = [deepcopy(pendulum_sac_data_genearation_config), deepcopy(pendulum_sac_data_genearation_create_config)]
+ collect_count = 1000
+ expert_data_path = config[0].policy.collect.save_path
+ state_dict = torch.load('./sac_unittest/ckpt/iteration_0.pth.tar', map_location='cpu')
+ try:
+ collect_demo_data(
+ config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict
+ )
+ except Exception:
+ assert False, "pipeline fail"
+
+ # test cql
+ config = [deepcopy(pendulum_cql_config), deepcopy(pendulum_cql_create_config)]
+ config[0].policy.learn.train_epoch = 1
+ config[0].policy.eval.evaluator.eval_freq = 1
+ try:
+ serial_pipeline_offline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_ibc():
+ # train expert
+ config = [deepcopy(pendulum_sac_config), deepcopy(pendulum_sac_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ config[0].exp_name = 'sac_unittest'
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+ # collect expert data
+ import torch
+ config = [deepcopy(pendulum_sac_data_genearation_config), deepcopy(pendulum_sac_data_genearation_create_config)]
+ collect_count = 1000
+ expert_data_path = config[0].policy.collect.save_path
+ state_dict = torch.load('./sac_unittest/ckpt/iteration_0.pth.tar', map_location='cpu')
+ try:
+ collect_demo_data(
+ config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict
+ )
+ except Exception:
+ assert False, "pipeline fail"
+
+ # test ibc
+ config = [deepcopy(pendulum_ibc_config), deepcopy(pendulum_ibc_create_config)]
+ config[0].policy.learn.train_epoch = 1
+ config[0].policy.eval.evaluator.eval_freq = 1
+ config[0].policy.model.stochastic_optim.iters = 2
+ try:
+ serial_pipeline_offline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_d4pg():
+ config = [deepcopy(pendulum_d4pg_config), deepcopy(pendulum_d4pg_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception as e:
+ print(repr(e))
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_discrete_cql():
+ # train expert
+ config = [deepcopy(cartpole_qrdqn_config), deepcopy(cartpole_qrdqn_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ config[0].exp_name = 'cql_cartpole'
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+ # collect expert data
+ import torch
+ config = [deepcopy(cartpole_qrdqn_generation_data_config), deepcopy(cartpole_qrdqn_generation_data_create_config)]
+ state_dict = torch.load('./cql_cartpole/ckpt/iteration_0.pth.tar', map_location='cpu')
+ try:
+ collect_demo_data(config, seed=0, collect_count=1000, state_dict=state_dict)
+ except Exception as e:
+ print(repr(e))
+ assert False, "pipeline fail"
+
+ # train cql
+ config = [deepcopy(cartpole_discrete_cql_config), deepcopy(cartpole_discrete_cql_create_config)]
+ config[0].policy.learn.train_epoch = 1
+ config[0].policy.eval.evaluator.eval_freq = 1
+ try:
+ serial_pipeline_offline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ finally:
+ os.popen('rm -rf cartpole cartpole_cql cql_cartpole')
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_discrete_dt():
+ # train expert
+ config = [deepcopy(cartpole_qrdqn_config), deepcopy(cartpole_qrdqn_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ config[0].exp_name = 'dt_cartpole'
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+ # collect expert data
+ import torch
+ config = [deepcopy(cartpole_qrdqn_generation_data_config), deepcopy(cartpole_qrdqn_generation_data_create_config)]
+ state_dict = torch.load('./dt_cartpole/ckpt/iteration_0.pth.tar', map_location='cpu')
+ try:
+ collect_demo_data(config, seed=0, collect_count=1000, state_dict=state_dict)
+ except Exception as e:
+ print(repr(e))
+ assert False, "pipeline fail"
+
+ # train dt
+ config = [deepcopy(cartpole_discrete_dt_config), deepcopy(cartpole_discrete_dt_create_config)]
+ config[0].policy.eval.evaluator.eval_freq = 5
+ try:
+ from ding.framework import task, ding_init
+ from ding.framework.context import OfflineRLContext
+ from ding.envs import SubprocessEnvManagerV2, BaseEnvManagerV2
+ from ding.envs.env_wrappers.env_wrappers import AllinObsWrapper
+ from dizoo.classic_control.cartpole.envs import CartPoleEnv
+ from ding.utils import set_pkg_seed
+ from ding.data import create_dataset
+ from ding.config import compile_config
+ from ding.model import DecisionTransformer
+ from ding.policy import DTPolicy
+ from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, \
+ OfflineMemoryDataFetcher, offline_logger, termination_checker
+ ding_init(config[0])
+ config = compile_config(config[0], create_cfg=config[1], auto=True)
+ with task.start(async_mode=False, ctx=OfflineRLContext()):
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: AllinObsWrapper(CartPoleEnv(config.env)) for _ in range(config.env.evaluator_env_num)],
+ cfg=config.env.manager
+ )
+
+ set_pkg_seed(config.seed, use_cuda=config.policy.cuda)
+
+ dataset = create_dataset(config)
+
+ model = DecisionTransformer(**config.policy.model)
+ policy = DTPolicy(config.policy, model=model)
+
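+ # offline RL middleware chain: stop after one train iteration, evaluate the policy,
+ # fetch batches from the offline dataset, train the Decision Transformer,
+ # save checkpoints every 100 train steps, and write offline logs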
+ task.use(termination_checker(max_train_iter=1))
+ task.use(interaction_evaluator(config, policy.eval_mode, evaluator_env))
+ task.use(OfflineMemoryDataFetcher(config, dataset))
+ task.use(trainer(config, policy.learn_mode))
+ task.use(CkptSaver(policy, config.exp_name, train_freq=100))
+ task.use(offline_logger())
+ task.run()
+ except Exception:
+ assert False, "pipeline fail"
+ finally:
+ os.popen('rm -rf cartpole cartpole_dt dt_cartpole')
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_td3_bc():
+ # train expert
+ config = [deepcopy(pendulum_td3_config), deepcopy(pendulum_td3_create_config)]
+ config[0].exp_name = 'td3'
+ config[0].policy.learn.update_per_collect = 1
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+ # collect expert data
+ import torch
+ config = [deepcopy(pendulum_td3_generation_config), deepcopy(pendulum_td3_generation_create_config)]
+ state_dict = torch.load('./td3/ckpt/iteration_0.pth.tar', map_location='cpu')
+ try:
+ collect_demo_data(config, seed=0, collect_count=1000, state_dict=state_dict)
+ except Exception:
+ assert False, "pipeline fail"
+
+ # train td3 bc
+ config = [deepcopy(pendulum_td3_bc_config), deepcopy(pendulum_td3_bc_create_config)]
+ config[0].exp_name = 'td3_bc'
+ config[0].policy.learn.train_epoch = 1
+ config[0].policy.eval.evaluator.eval_freq = 1
+ try:
+ serial_pipeline_offline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ finally:
+ os.popen('rm -rf td3 td3_bc')
diff --git a/DI-engine/ding/entry/tests/test_serial_entry_algo.py b/DI-engine/ding/entry/tests/test_serial_entry_algo.py
new file mode 100644
index 0000000000000000000000000000000000000000..640b1e800ce46da08a8f23e3c8bd16744d3316da
--- /dev/null
+++ b/DI-engine/ding/entry/tests/test_serial_entry_algo.py
@@ -0,0 +1,524 @@
+import pytest
+import time
+import os
+import torch
+import subprocess
+from copy import deepcopy
+
+from ding.entry import serial_pipeline, serial_pipeline_offline, collect_demo_data, serial_pipeline_onpolicy
+from ding.entry.serial_entry_sqil import serial_pipeline_sqil
+from dizoo.classic_control.cartpole.config.cartpole_sql_config import cartpole_sql_config, cartpole_sql_create_config
+from dizoo.classic_control.cartpole.config.cartpole_sqil_config import cartpole_sqil_config, cartpole_sqil_create_config
+from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
+from dizoo.classic_control.cartpole.config.cartpole_ppo_config import cartpole_ppo_config, cartpole_ppo_create_config
+from dizoo.classic_control.cartpole.config.cartpole_pg_config import cartpole_pg_config, cartpole_pg_create_config
+from dizoo.classic_control.cartpole.config.cartpole_a2c_config import cartpole_a2c_config, cartpole_a2c_create_config
+from dizoo.classic_control.cartpole.config.cartpole_impala_config import cartpole_impala_config, cartpole_impala_create_config # noqa
+from dizoo.classic_control.cartpole.config.cartpole_rainbow_config import cartpole_rainbow_config, cartpole_rainbow_create_config # noqa
+from dizoo.classic_control.cartpole.config.cartpole_iqn_config import cartpole_iqn_config, cartpole_iqn_create_config # noqa
+from dizoo.classic_control.cartpole.config.cartpole_c51_config import cartpole_c51_config, cartpole_c51_create_config # noqa
+from dizoo.classic_control.cartpole.config.cartpole_qrdqn_config import cartpole_qrdqn_config, cartpole_qrdqn_create_config # noqa
+from dizoo.classic_control.cartpole.config.cartpole_sqn_config import cartpole_sqn_config, cartpole_sqn_create_config # noqa
+from dizoo.classic_control.cartpole.config.cartpole_ppg_config import cartpole_ppg_config, cartpole_ppg_create_config # noqa
+from dizoo.classic_control.cartpole.config.cartpole_acer_config import cartpole_acer_config, cartpole_acer_create_config # noqa
+from dizoo.classic_control.cartpole.entry.cartpole_ppg_main import main as ppg_main
+from dizoo.classic_control.cartpole.entry.cartpole_ppo_main import main as ppo_main
+from dizoo.classic_control.cartpole.config.cartpole_r2d2_config import cartpole_r2d2_config, cartpole_r2d2_create_config # noqa
+from dizoo.classic_control.pendulum.config import pendulum_ddpg_config, pendulum_ddpg_create_config
+from dizoo.classic_control.pendulum.config import pendulum_td3_config, pendulum_td3_create_config
+from dizoo.classic_control.pendulum.config import pendulum_sac_config, pendulum_sac_create_config
+from dizoo.bitflip.config import bitflip_her_dqn_config, bitflip_her_dqn_create_config
+from dizoo.bitflip.entry.bitflip_dqn_main import main as bitflip_dqn_main
+from dizoo.league_demo.league_demo_ppo_config import league_demo_ppo_config
+from dizoo.league_demo.selfplay_demo_ppo_main import main as selfplay_main
+from dizoo.league_demo.league_demo_ppo_main import main as league_main
+from dizoo.classic_control.pendulum.config.pendulum_sac_data_generation_config import pendulum_sac_data_genearation_config, pendulum_sac_data_genearation_create_config # noqa
+from dizoo.classic_control.pendulum.config.pendulum_cql_config import pendulum_cql_config, pendulum_cql_create_config # noqa
+from dizoo.classic_control.cartpole.config.cartpole_qrdqn_generation_data_config import cartpole_qrdqn_generation_data_config, cartpole_qrdqn_generation_data_create_config # noqa
+from dizoo.classic_control.cartpole.config.cartpole_cql_config import cartpole_discrete_cql_config, cartpole_discrete_cql_create_config # noqa
+from dizoo.classic_control.pendulum.config.pendulum_td3_data_generation_config import pendulum_td3_generation_config, pendulum_td3_generation_create_config # noqa
+from dizoo.classic_control.pendulum.config.pendulum_td3_bc_config import pendulum_td3_bc_config, pendulum_td3_bc_create_config # noqa
+from dizoo.petting_zoo.config import ptz_simple_spread_atoc_config, ptz_simple_spread_atoc_create_config # noqa
+from dizoo.petting_zoo.config import ptz_simple_spread_collaq_config, ptz_simple_spread_collaq_create_config # noqa
+from dizoo.petting_zoo.config import ptz_simple_spread_coma_config, ptz_simple_spread_coma_create_config # noqa
+from dizoo.petting_zoo.config import ptz_simple_spread_qmix_config, ptz_simple_spread_qmix_create_config # noqa
+from dizoo.petting_zoo.config import ptz_simple_spread_qtran_config, ptz_simple_spread_qtran_create_config # noqa
+from dizoo.petting_zoo.config import ptz_simple_spread_vdn_config, ptz_simple_spread_vdn_create_config # noqa
+from dizoo.petting_zoo.config import ptz_simple_spread_wqmix_config, ptz_simple_spread_wqmix_create_config # noqa
+from dizoo.classic_control.cartpole.config import cartpole_mdqn_config, cartpole_mdqn_create_config
+
+with open("./algo_record.log", "w+") as f:
+ f.write("ALGO TEST STARTS\n")
+
+
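+# Each algorithm test below trains with the default config until the stop condition
+# and, on success, appends its index and name to algo_record.log.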
+@pytest.mark.algotest
+def test_dqn():
+ config = [deepcopy(cartpole_dqn_config), deepcopy(cartpole_dqn_create_config)]
+ try:
+ serial_pipeline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("1. dqn\n")
+
+
+@pytest.mark.algotest
+def test_ddpg():
+ config = [deepcopy(pendulum_ddpg_config), deepcopy(pendulum_ddpg_create_config)]
+ try:
+ serial_pipeline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("2. ddpg\n")
+
+
+@pytest.mark.algotest
+def test_td3():
+ config = [deepcopy(pendulum_td3_config), deepcopy(pendulum_td3_create_config)]
+ try:
+ serial_pipeline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("3. td3\n")
+
+
+@pytest.mark.algotest
+def test_a2c():
+ config = [deepcopy(cartpole_a2c_config), deepcopy(cartpole_a2c_create_config)]
+ try:
+ serial_pipeline_onpolicy(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("4. a2c\n")
+
+
+@pytest.mark.algotest
+def test_rainbow():
+ config = [deepcopy(cartpole_rainbow_config), deepcopy(cartpole_rainbow_create_config)]
+ try:
+ serial_pipeline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("5. rainbow\n")
+
+
+@pytest.mark.algotest
+def test_ppo():
+ config = [deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)]
+ try:
+ ppo_main(config[0], seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("6. ppo\n")
+
+
+# @pytest.mark.algotest
+def test_collaq():
+ config = [deepcopy(ptz_simple_spread_collaq_config), deepcopy(ptz_simple_spread_collaq_create_config)]
+ try:
+ serial_pipeline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("7. collaq\n")
+
+
+# @pytest.mark.algotest
+def test_coma():
+ config = [deepcopy(ptz_simple_spread_coma_config), deepcopy(ptz_simple_spread_coma_create_config)]
+ try:
+ serial_pipeline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("8. coma\n")
+
+
+@pytest.mark.algotest
+def test_sac():
+ config = [deepcopy(pendulum_sac_config), deepcopy(pendulum_sac_create_config)]
+ try:
+ serial_pipeline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("9. sac\n")
+
+
+@pytest.mark.algotest
+def test_c51():
+ config = [deepcopy(cartpole_c51_config), deepcopy(cartpole_c51_create_config)]
+ try:
+ serial_pipeline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("10. c51\n")
+
+
+@pytest.mark.algotest
+def test_r2d2():
+ config = [deepcopy(cartpole_r2d2_config), deepcopy(cartpole_r2d2_create_config)]
+ try:
+ serial_pipeline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("11. r2d2\n")
+
+
+@pytest.mark.algotest
+def test_pg():
+ config = [deepcopy(cartpole_pg_config), deepcopy(cartpole_pg_create_config)]
+ try:
+ serial_pipeline_onpolicy(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("12. pg\n")
+
+
+# @pytest.mark.algotest
+def test_atoc():
+ config = [deepcopy(ptz_simple_spread_atoc_config), deepcopy(ptz_simple_spread_atoc_create_config)]
+ try:
+ serial_pipeline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("13. atoc\n")
+
+
+# @pytest.mark.algotest
+def test_vdn():
+ config = [deepcopy(ptz_simple_spread_vdn_config), deepcopy(ptz_simple_spread_vdn_create_config)]
+ try:
+ serial_pipeline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("14. vdn\n")
+
+
+# @pytest.mark.algotest
+def test_qmix():
+ config = [deepcopy(ptz_simple_spread_qmix_config), deepcopy(ptz_simple_spread_qmix_create_config)]
+ try:
+ serial_pipeline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("15. qmix\n")
+
+
+@pytest.mark.algotest
+def test_impala():
+ config = [deepcopy(cartpole_impala_config), deepcopy(cartpole_impala_create_config)]
+ try:
+ serial_pipeline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("16. impala\n")
+
+
+@pytest.mark.algotest
+def test_iqn():
+ config = [deepcopy(cartpole_iqn_config), deepcopy(cartpole_iqn_create_config)]
+ try:
+ serial_pipeline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("17. iqn\n")
+
+
+@pytest.mark.algotest
+def test_her_dqn():
+ try:
+ bitflip_her_dqn_config.exp_name = 'bitflip5_dqn'
+ bitflip_her_dqn_config.env.n_bits = 5
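+ # 5-bit flipping task: the observation concatenates state and goal bits, hence obs_shape = 2 * n_bits = 10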
+ bitflip_her_dqn_config.policy.model.obs_shape = 10
+ bitflip_her_dqn_config.policy.model.action_shape = 5
+ bitflip_dqn_main(bitflip_her_dqn_config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("18. her dqn\n")
+
+
+@pytest.mark.algotest
+def test_ppg():
+ try:
+ ppg_main(cartpole_ppg_config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("19. ppg\n")
+
+
+@pytest.mark.algotest
+def test_sqn():
+ config = [deepcopy(cartpole_sqn_config), deepcopy(cartpole_sqn_create_config)]
+ try:
+ serial_pipeline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("20. sqn\n")
+
+
+@pytest.mark.algotest
+def test_qrdqn():
+ config = [deepcopy(cartpole_qrdqn_config), deepcopy(cartpole_qrdqn_create_config)]
+ try:
+ serial_pipeline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("21. qrdqn\n")
+
+
+@pytest.mark.algotest
+def test_acer():
+ config = [deepcopy(cartpole_acer_config), deepcopy(cartpole_acer_create_config)]
+ try:
+ serial_pipeline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("22. acer\n")
+
+
+@pytest.mark.algotest
+def test_selfplay():
+ try:
+ selfplay_main(deepcopy(league_demo_ppo_config), seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("23. selfplay\n")
+
+
+@pytest.mark.algotest
+def test_league():
+ try:
+ league_main(deepcopy(league_demo_ppo_config), seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("24. league\n")
+
+
+@pytest.mark.algotest
+def test_sqil():
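+ # train an expert SQL policy, save its collect-mode weights, and use them as the demonstrator for SQIL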
+ expert_policy_state_dict_path = './expert_policy.pth'
+ config = [deepcopy(cartpole_sql_config), deepcopy(cartpole_sql_create_config)]
+ expert_policy = serial_pipeline(config, seed=0)
+ torch.save(expert_policy.collect_mode.state_dict(), expert_policy_state_dict_path)
+
+ config = [deepcopy(cartpole_sqil_config), deepcopy(cartpole_sqil_create_config)]
+ config[0].policy.collect.model_path = expert_policy_state_dict_path
+ try:
+ serial_pipeline_sqil(config, [cartpole_sql_config, cartpole_sql_create_config], seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("25. sqil\n")
+
+
+@pytest.mark.algotest
+def test_cql():
+ # train expert
+ config = [deepcopy(pendulum_sac_config), deepcopy(pendulum_sac_create_config)]
+ config[0].exp_name = 'sac'
+ try:
+ serial_pipeline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+
+ # collect expert data
+ import torch
+ config = [deepcopy(pendulum_sac_data_genearation_config), deepcopy(pendulum_sac_data_genearation_create_config)]
+ collect_count = config[0].policy.collect.n_sample
+ expert_data_path = config[0].policy.collect.save_path
+ state_dict = torch.load('./sac/ckpt/ckpt_best.pth.tar', map_location='cpu')
+ try:
+ collect_demo_data(
+ config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict
+ )
+ except Exception:
+ assert False, "pipeline fail"
+
+ # train cql
+ config = [deepcopy(pendulum_cql_config), deepcopy(pendulum_cql_create_config)]
+ try:
+ serial_pipeline_offline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("26. cql\n")
+
+
+@pytest.mark.algotest
+def test_discrete_cql():
+ # train expert
+ config = [deepcopy(cartpole_qrdqn_config), deepcopy(cartpole_qrdqn_create_config)]
+ config[0].exp_name = 'cartpole'
+ try:
+ serial_pipeline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+
+ # collect expert data
+ import torch
+ config = [deepcopy(cartpole_qrdqn_generation_data_config), deepcopy(cartpole_qrdqn_generation_data_create_config)]
+ collect_count = config[0].policy.collect.collect_count
+ state_dict = torch.load('cartpole/ckpt/ckpt_best.pth.tar', map_location='cpu')
+ try:
+ collect_demo_data(config, seed=0, collect_count=collect_count, state_dict=state_dict)
+ except Exception:
+ assert False, "pipeline fail"
+
+ # train cql
+ config = [deepcopy(cartpole_discrete_cql_config), deepcopy(cartpole_discrete_cql_create_config)]
+ try:
+ serial_pipeline_offline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("27. discrete cql\n")
+
+
+# @pytest.mark.algotest
+def test_wqmix():
+ config = [deepcopy(ptz_simple_spread_wqmix_config), deepcopy(ptz_simple_spread_wqmix_create_config)]
+ try:
+ serial_pipeline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("28. wqmix\n")
+
+
+@pytest.mark.algotest
+def test_mdqn():
+ config = [deepcopy(cartpole_mdqn_config), deepcopy(cartpole_mdqn_create_config)]
+ try:
+ serial_pipeline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("29. mdqn\n")
+
+
+# @pytest.mark.algotest
+def test_td3_bc():
+ # train expert
+ config = [deepcopy(pendulum_td3_config), deepcopy(pendulum_td3_create_config)]
+ config[0].exp_name = 'td3'
+ try:
+ serial_pipeline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+
+ # collect expert data
+ import torch
+ config = [deepcopy(pendulum_td3_generation_config), deepcopy(pendulum_td3_generation_create_config)]
+ collect_count = config[0].policy.other.replay_buffer.replay_buffer_size
+ expert_data_path = config[0].policy.collect.save_path
+ state_dict = torch.load(config[0].policy.learn.learner.load_path, map_location='cpu')
+ try:
+ collect_demo_data(
+ config, seed=0, collect_count=collect_count, expert_data_path=expert_data_path, state_dict=state_dict
+ )
+ except Exception:
+ assert False, "pipeline fail"
+
+ # train td3 bc
+ config = [deepcopy(pendulum_td3_bc_config), deepcopy(pendulum_td3_bc_create_config)]
+ try:
+ serial_pipeline_offline(config, seed=0)
+ except Exception:
+ assert False, "pipeline fail"
+ with open("./algo_record.log", "a+") as f:
+ f.write("29. td3_bc\n")
+
+
+# @pytest.mark.algotest
+def test_running_on_orchestrator():
+ from kubernetes import config, client, dynamic
+ from ding.utils import K8sLauncher, OrchestratorLauncher
+ cluster_name = 'test-k8s-launcher'
+ config_path = os.path.join(os.path.dirname(__file__), 'config', 'k8s-config.yaml')
+ # create cluster
+ launcher = K8sLauncher(config_path)
+ launcher.name = cluster_name
+ launcher.create_cluster()
+
+ # create orchestrator
+ olauncher = OrchestratorLauncher('v0.2.0-rc.0', cluster=launcher)
+ olauncher.create_orchestrator()
+
+ # create dijob
+ namespace = 'default'
+ name = 'cartpole-dqn'
+ timeout = 20 * 60
+ file_path = os.path.dirname(__file__)
+ agconfig_path = os.path.join(file_path, 'config', 'agconfig.yaml')
+ dijob_path = os.path.join(file_path, 'config', 'dijob-cartpole.yaml')
+ create_object_from_config(agconfig_path, 'di-system')
+ create_object_from_config(dijob_path, namespace)
+
+ # watch for dijob to converge
+ config.load_kube_config()
+ dyclient = dynamic.DynamicClient(client.ApiClient(configuration=config.load_kube_config()))
+ dijobapi = dyclient.resources.get(api_version='diengine.opendilab.org/v1alpha1', kind='DIJob')
+
+ wait_for_dijob_condition(dijobapi, name, namespace, 'Succeeded', timeout)
+
+ v1 = client.CoreV1Api()
+ logs = v1.read_namespaced_pod_log(f'{name}-coordinator', namespace, tail_lines=20)
+ print(f'\ncoordinator logs:\n {logs} \n')
+
+ # delete dijob
+ dijobapi.delete(name=name, namespace=namespace, body={})
+ # delete orchestrator
+ olauncher.delete_orchestrator()
+ # delete k8s cluster
+ launcher.delete_cluster()
+
+
+def create_object_from_config(config_path: str, namespace: str = 'default'):
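+ # apply the manifest with kubectl; warnings and "already exists" errors are tolerated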
+ args = ['kubectl', 'apply', '-n', namespace, '-f', config_path]
+ proc = subprocess.Popen(args, stderr=subprocess.PIPE)
+ _, err = proc.communicate()
+ err_str = err.decode('utf-8').strip()
+ if err_str != '' and 'WARN' not in err_str and 'already exists' not in err_str:
+ raise RuntimeError(f'Failed to create object: {err_str}')
+
+
+def delete_object_from_config(config_path: str, namespace: str = 'default'):
+ args = ['kubectl', 'delete', '-n', namespace, '-f', config_path]
+ proc = subprocess.Popen(args, stderr=subprocess.PIPE)
+ _, err = proc.communicate()
+ err_str = err.decode('utf-8').strip()
+ if err_str != '' and 'WARN' not in err_str and 'NotFound' not in err_str:
+ raise RuntimeError(f'Failed to delete object: {err_str}')
+
+
+def wait_for_dijob_condition(dijobapi, name: str, namespace: str, phase: str, timeout: int = 60, interval: int = 1):
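+ # poll the DIJob resource until its status reaches the requested phase or the timeout elapses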
+ start = time.time()
+ dijob = dijobapi.get(name=name, namespace=namespace)
+ while (dijob.status is None or dijob.status.phase != phase) and time.time() - start < timeout:
+ time.sleep(interval)
+ dijob = dijobapi.get(name=name, namespace=namespace)
+
+ if dijob.status.phase == phase:
+ return
+ raise TimeoutError(f'Timeout waiting for DIJob: {name} to be {phase}')
diff --git a/DI-engine/ding/entry/tests/test_serial_entry_bc.py b/DI-engine/ding/entry/tests/test_serial_entry_bc.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2c0923ad2e5b3531f25e0d90c3ece1f3530e6cb
--- /dev/null
+++ b/DI-engine/ding/entry/tests/test_serial_entry_bc.py
@@ -0,0 +1,186 @@
+from copy import deepcopy
+import pytest
+import torch.nn.functional as F
+from typing import Tuple, List, Dict, Any
+import torch
+from collections import namedtuple
+import os
+
+from ding.torch_utils import to_device
+from ding.rl_utils import get_train_sample, get_nstep_return_data
+from ding.entry import serial_pipeline_bc, collect_demo_data, serial_pipeline
+from ding.policy import PPOOffPolicy, BehaviourCloningPolicy
+from ding.policy.common_utils import default_preprocess_learn
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_collate, default_decollate
+from dizoo.classic_control.cartpole.config import cartpole_dqn_config, cartpole_dqn_create_config, \
+ cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config
+from dizoo.classic_control.pendulum.config import pendulum_sac_config, pendulum_sac_create_config
+
+
+@POLICY_REGISTRY.register('ppo_bc')
+class PPOILPolicy(PPOOffPolicy):
+
+ def _forward_learn(self, data: dict) -> dict:
+ data = default_preprocess_learn(data, ignore_done=self._cfg.learn.get('ignore_done', False), use_nstep=False)
+ self._learn_model.train()
+ output = self._learn_model.forward(data['obs'], mode='compute_actor_critic')
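+ # imitation losses: MSE between predicted and expert values, smooth L1 between predicted and expert logits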
+ value_loss = F.mse_loss(output['value'], data['value'])
+ policy_loss = F.smooth_l1_loss(output['logit'], data['logit'])
+ total_loss = value_loss + policy_loss
+ self._optimizer.zero_grad()
+ total_loss.backward()
+ self._optimizer.step()
+ return {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': total_loss.item(),
+ 'policy_loss': policy_loss.item(),
+ 'value_loss': value_loss.item(),
+ }
+
+ def _forward_eval(self, data):
+ if isinstance(data, dict):
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ o = default_decollate(self._eval_model.forward(data, mode='compute_actor'))
+ return {i: d for i, d in zip(data_id, o)}
+ return self._model(data, mode='compute_actor')
+
+ def _monitor_vars_learn(self) -> list:
+ return super()._monitor_vars_learn() + ['policy_loss', 'value_loss']
+
+
+@pytest.mark.unittest
+def test_serial_pipeline_bc_ppo():
+ # train expert policy
+ train_config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
+ train_config[0].exp_name = 'test_serial_pipeline_bc_ppo'
+ expert_policy = serial_pipeline(train_config, seed=0)
+
+ # collect expert demo data
+ collect_count = 10000
+ expert_data_path = 'expert_data_ppo_bc.pkl'
+ state_dict = expert_policy.collect_mode.state_dict()
+ collect_config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
+ collect_config[0].exp_name = 'test_serial_pipeline_bc_ppo_collect'
+ collect_demo_data(
+ collect_config, seed=0, state_dict=state_dict, expert_data_path=expert_data_path, collect_count=collect_count
+ )
+
+ # il training 1
+ il_config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
+ il_config[0].policy.eval.evaluator.multi_gpu = False
+ il_config[0].policy.learn.train_epoch = 20
+ il_config[1].policy.type = 'ppo_bc'
+ il_config[0].policy.continuous = False
+ il_config[0].exp_name = 'test_serial_pipeline_bc_ppo_il'
+ _, converge_stop_flag = serial_pipeline_bc(il_config, seed=314, data_path=expert_data_path)
+ assert converge_stop_flag
+
+ os.popen('rm -rf ' + expert_data_path)
+
+
+@POLICY_REGISTRY.register('dqn_bc')
+class DQNILPolicy(BehaviourCloningPolicy):
+
+ def _forward_learn(self, data: dict) -> dict:
+ return super()._forward_learn(data)
+
+ def _forward_collect(self, data: dict, eps: float):
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._collect_model.eval()
+ with torch.no_grad():
+ output = self._collect_model.forward(data, eps=eps)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> Dict[str, Any]:
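+ # record the post-step observation as next_obs so n-step returns can be built from BC transitions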
+ ret = super()._process_transition(obs, model_output, timestep)
+ ret['next_obs'] = timestep.obs
+ return ret
+
+ def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ super()._get_train_sample(data)
+ data = get_nstep_return_data(data, 1, gamma=0.99)
+ return get_train_sample(data, unroll_len=1)
+
+ def _forward_eval(self, data: dict) -> dict:
+ if isinstance(data, dict):
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ o = default_decollate(self._eval_model.forward(data))
+ return {i: d for i, d in zip(data_id, o)}
+ return self._model(data)
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ return 'dqn', ['ding.model.template.q_learning']
+
+
+@pytest.mark.unittest
+def test_serial_pipeline_bc_dqn():
+ # train expert policy
+ train_config = [deepcopy(cartpole_dqn_config), deepcopy(cartpole_dqn_create_config)]
+ expert_policy = serial_pipeline(train_config, seed=0)
+
+ # collect expert demo data
+ collect_count = 10000
+ expert_data_path = 'expert_data_dqn.pkl'
+ state_dict = expert_policy.collect_mode.state_dict()
+ collect_config = [deepcopy(cartpole_dqn_config), deepcopy(cartpole_dqn_create_config)]
+ collect_config[1].policy.type = 'dqn_bc'
+ collect_config[0].policy.continuous = False
+ collect_config[0].policy.other.eps = 0
+ collect_demo_data(
+ collect_config, seed=0, state_dict=state_dict, expert_data_path=expert_data_path, collect_count=collect_count
+ )
+
+ # il training 2
+ il_config = [deepcopy(cartpole_dqn_config), deepcopy(cartpole_dqn_create_config)]
+ il_config[0].policy.learn.train_epoch = 15
+ il_config[1].policy.type = 'dqn_bc'
+ il_config[0].policy.continuous = False
+ il_config[0].env.stop_value = 50
+ il_config[0].policy.eval.evaluator.multi_gpu = False
+ _, converge_stop_flag = serial_pipeline_bc(il_config, seed=314, data_path=expert_data_path)
+ assert converge_stop_flag
+ os.popen('rm -rf ' + expert_data_path)
+
+
+@pytest.mark.unittest
+def test_serial_pipeline_bc_sac():
+ # train expert policy
+ train_config = [deepcopy(pendulum_sac_config), deepcopy(pendulum_sac_create_config)]
+ expert_policy = serial_pipeline(train_config, seed=0, max_train_iter=10)
+
+ # collect expert demo data
+ collect_count = 10000
+ expert_data_path = 'expert_data_sac.pkl'
+ state_dict = expert_policy.collect_mode.state_dict()
+ collect_config = [deepcopy(pendulum_sac_config), deepcopy(pendulum_sac_create_config)]
+ collect_demo_data(
+ collect_config, seed=0, state_dict=state_dict, expert_data_path=expert_data_path, collect_count=collect_count
+ )
+
+ # il training 2
+ il_config = [deepcopy(pendulum_sac_config), deepcopy(pendulum_sac_create_config)]
+ il_config[0].policy.learn.train_epoch = 15
+ il_config[1].policy.type = 'bc'
+ il_config[0].policy.continuous = True
+ il_config[0].env.stop_value = 50
+ il_config[0].policy.model = dict(
+ obs_shape=3,
+ action_shape=1,
+ action_space='regression',
+ actor_head_hidden_size=128,
+ )
+ il_config[0].policy.loss_type = 'l1_loss'
+ il_config[0].policy.learn.learning_rate = 1e-5
+ il_config[0].policy.eval.evaluator.multi_gpu = False
+ _, converge_stop_flag = serial_pipeline_bc(il_config, seed=314, data_path=expert_data_path, max_iter=10)
+ os.popen('rm -rf ' + expert_data_path)
diff --git a/DI-engine/ding/entry/tests/test_serial_entry_bco.py b/DI-engine/ding/entry/tests/test_serial_entry_bco.py
new file mode 100644
index 0000000000000000000000000000000000000000..45faefe0a8de7bcf522cbf6b06bbed5ff0bbaae5
--- /dev/null
+++ b/DI-engine/ding/entry/tests/test_serial_entry_bco.py
@@ -0,0 +1,25 @@
+import pytest
+import torch
+from copy import deepcopy
+from ding.entry import serial_pipeline
+from ding.entry.serial_entry_bco import serial_pipeline_bco
+from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
+from dizoo.classic_control.cartpole.config.cartpole_bco_config import cartpole_bco_config, cartpole_bco_create_config
+
+
+@pytest.mark.unittest
+def test_bco():
+ expert_policy_state_dict_path = './expert_policy.pth'
+ expert_config = [deepcopy(cartpole_dqn_config), deepcopy(cartpole_dqn_create_config)]
+ expert_policy = serial_pipeline(expert_config, seed=0)
+ torch.save(expert_policy.collect_mode.state_dict(), expert_policy_state_dict_path)
+
+ config = [deepcopy(cartpole_bco_config), deepcopy(cartpole_bco_create_config)]
+ config[0].policy.collect.model_path = expert_policy_state_dict_path
+ try:
+ serial_pipeline_bco(
+ config, [deepcopy(cartpole_dqn_config), deepcopy(cartpole_dqn_create_config)], seed=0, max_train_iter=3
+ )
+ except Exception as e:
+ print(e)
+ assert False, "pipeline fail"
diff --git a/DI-engine/ding/entry/tests/test_serial_entry_dqfd.py b/DI-engine/ding/entry/tests/test_serial_entry_dqfd.py
new file mode 100644
index 0000000000000000000000000000000000000000..2849a3edf1d6f31790d52fc0f3ff45c95c65c911
--- /dev/null
+++ b/DI-engine/ding/entry/tests/test_serial_entry_dqfd.py
@@ -0,0 +1,23 @@
+import pytest
+import torch
+from copy import deepcopy
+from ding.entry import serial_pipeline
+from ding.entry.serial_entry_dqfd import serial_pipeline_dqfd
+from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
+from dizoo.classic_control.cartpole.config.cartpole_dqfd_config import cartpole_dqfd_config, cartpole_dqfd_create_config
+
+
+@pytest.mark.unittest
+def test_dqfd():
+ expert_policy_state_dict_path = './expert_policy.pth'
+ config = [deepcopy(cartpole_dqn_config), deepcopy(cartpole_dqn_create_config)]
+ expert_policy = serial_pipeline(config, seed=0)
+ torch.save(expert_policy.collect_mode.state_dict(), expert_policy_state_dict_path)
+
+ config = [deepcopy(cartpole_dqfd_config), deepcopy(cartpole_dqfd_create_config)]
+ config[0].policy.collect.model_path = expert_policy_state_dict_path
+ config[0].policy.learn.update_per_collect = 1
+ try:
+ serial_pipeline_dqfd(config, [cartpole_dqfd_config, cartpole_dqfd_create_config], seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
diff --git a/DI-engine/ding/entry/tests/test_serial_entry_for_anytrading.py b/DI-engine/ding/entry/tests/test_serial_entry_for_anytrading.py
new file mode 100644
index 0000000000000000000000000000000000000000..d7dd71c1bd05fbcb0489d79fcdae74eff3333922
--- /dev/null
+++ b/DI-engine/ding/entry/tests/test_serial_entry_for_anytrading.py
@@ -0,0 +1,41 @@
+import os
+import pytest
+from copy import deepcopy
+import numpy as np
+import pandas as pd
+from ding.entry.serial_entry import serial_pipeline
+from dizoo.gym_anytrading.config import stocks_dqn_config, stocks_dqn_create_config
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_stocks_dqn():
+ config = [deepcopy(stocks_dqn_config), deepcopy(stocks_dqn_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ config[0].exp_name = 'stocks_dqn_unittest'
+ config[0].env.stocks_data_filename = 'STOCKS_FAKE'
+
+ # ======== generate fake data =========
+ Date = pd.bdate_range(start='2010-02-20', end='2022-02-20')
+ data = {'Date': [], 'Open': [], 'High': [], 'Low': [], 'Close': [], 'Adj Close': [], 'Volume': []}
+ for i in range(len(Date)):
+ data['Date'].append(Date[i])
+ data['Low'].append(np.random.uniform(200, 500))
+ data['High'].append(np.random.uniform(data['Low'][-1], data['Low'][-1] + 10))
+ data['Open'].append(np.random.uniform(data['Low'][-1], data['High'][-1]))
+ data['Close'].append(np.random.uniform(data['Low'][-1], data['High'][-1]))
+ data['Adj Close'].append(data['Close'][-1])
+ data['Volume'].append(np.random.randint(1000000, 7000000))
+ # =====================================
+
+ fake_data = pd.DataFrame(data)
+ data_path = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+ data_path += '/dizoo/gym_anytrading/envs/data/STOCKS_FAKE.csv'
+ fake_data.to_csv(data_path, sep=',', index=None)
+ try:
+ serial_pipeline(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+ finally:
+ os.remove(data_path)
+ os.popen('rm -rf {}'.format(os.path.abspath('./stocks_dqn_unittest')))
diff --git a/DI-engine/ding/entry/tests/test_serial_entry_guided_cost.py b/DI-engine/ding/entry/tests/test_serial_entry_guided_cost.py
new file mode 100644
index 0000000000000000000000000000000000000000..33742d4fb80fe99d5642e75d381581838cbbe57c
--- /dev/null
+++ b/DI-engine/ding/entry/tests/test_serial_entry_guided_cost.py
@@ -0,0 +1,23 @@
+import pytest
+import torch
+from copy import deepcopy
+from ding.entry import serial_pipeline_onpolicy, serial_pipeline_guided_cost
+from dizoo.classic_control.cartpole.config import cartpole_ppo_config, cartpole_ppo_create_config
+from dizoo.classic_control.cartpole.config import cartpole_gcl_ppo_onpolicy_config, \
+ cartpole_gcl_ppo_onpolicy_create_config
+
+
+@pytest.mark.unittest
+def test_guided_cost():
+ expert_policy_state_dict_path = './expert_policy.pth'
+ config = [deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)]
+ expert_policy = serial_pipeline_onpolicy(config, seed=0)
+ torch.save(expert_policy.collect_mode.state_dict(), expert_policy_state_dict_path)
+
+ config = [deepcopy(cartpole_gcl_ppo_onpolicy_config), deepcopy(cartpole_gcl_ppo_onpolicy_create_config)]
+ config[0].policy.collect.model_path = expert_policy_state_dict_path
+ config[0].policy.learn.update_per_collect = 1
+ try:
+ serial_pipeline_guided_cost(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
diff --git a/DI-engine/ding/entry/tests/test_serial_entry_mbrl.py b/DI-engine/ding/entry/tests/test_serial_entry_mbrl.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8f84c43f39b63ead075495d4c74b21ec3912a9d
--- /dev/null
+++ b/DI-engine/ding/entry/tests/test_serial_entry_mbrl.py
@@ -0,0 +1,41 @@
+import pytest
+from copy import deepcopy
+from ding.entry.serial_entry_mbrl import serial_pipeline_dyna, serial_pipeline_dream
+
+from dizoo.classic_control.pendulum.config.mbrl.pendulum_sac_mbpo_config \
+ import main_config as pendulum_sac_mbpo_main_config,\
+ create_config as pendulum_sac_mbpo_create_config
+
+from dizoo.classic_control.pendulum.config.mbrl.pendulum_mbsac_mbpo_config \
+ import main_config as pendulum_mbsac_mbpo_main_config,\
+ create_config as pendulum_mbsac_mbpo_create_config
+
+from dizoo.classic_control.pendulum.config.mbrl.pendulum_stevesac_mbpo_config \
+ import main_config as pendulum_stevesac_mbpo_main_config,\
+ create_config as pendulum_stevesac_mbpo_create_config
+
+
+@pytest.mark.unittest
+def test_dyna():
+ config = [deepcopy(pendulum_sac_mbpo_main_config), deepcopy(pendulum_sac_mbpo_create_config)]
+ config[0].world_model.model.max_epochs_since_update = 0
+ try:
+ serial_pipeline_dyna(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.unittest
+def test_dream():
+ configs = [
+ [deepcopy(pendulum_mbsac_mbpo_main_config),
+ deepcopy(pendulum_mbsac_mbpo_create_config)],
+ [deepcopy(pendulum_stevesac_mbpo_main_config),
+ deepcopy(pendulum_stevesac_mbpo_create_config)]
+ ]
+ try:
+ for config in configs:
+ config[0].world_model.model.max_epochs_since_update = 0
+ serial_pipeline_dream(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
diff --git a/DI-engine/ding/entry/tests/test_serial_entry_onpolicy.py b/DI-engine/ding/entry/tests/test_serial_entry_onpolicy.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b43f4068d78fcf19573c1abec14041f406e8d4a
--- /dev/null
+++ b/DI-engine/ding/entry/tests/test_serial_entry_onpolicy.py
@@ -0,0 +1,92 @@
+import pytest
+import time
+import os
+from copy import deepcopy
+
+from ding.entry import serial_pipeline_onpolicy
+from dizoo.classic_control.cartpole.config.cartpole_pg_config import cartpole_pg_config, cartpole_pg_create_config
+from dizoo.classic_control.cartpole.config.cartpole_ppo_config import cartpole_ppo_config, cartpole_ppo_create_config
+from dizoo.classic_control.cartpole.config.cartpole_ppopg_config import cartpole_ppopg_config, cartpole_ppopg_create_config # noqa
+from dizoo.classic_control.cartpole.config.cartpole_a2c_config import cartpole_a2c_config, cartpole_a2c_create_config
+from dizoo.petting_zoo.config import ptz_simple_spread_mappo_config, ptz_simple_spread_mappo_create_config
+from dizoo.classic_control.pendulum.config.pendulum_ppo_config import pendulum_ppo_config, pendulum_ppo_create_config
+from dizoo.classic_control.cartpole.config.cartpole_ppo_stdim_config import cartpole_ppo_stdim_config, cartpole_ppo_stdim_create_config # noqa
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_pg():
+ config = [deepcopy(cartpole_pg_config), deepcopy(cartpole_pg_create_config)]
+ try:
+ serial_pipeline_onpolicy(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_a2c():
+ config = [deepcopy(cartpole_a2c_config), deepcopy(cartpole_a2c_create_config)]
+ try:
+ serial_pipeline_onpolicy(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_onpolicy_ppo():
+ config = [deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)]
+ config[0].policy.learn.epoch_per_collect = 2
+ config[0].policy.eval.evaluator.eval_freq = 1
+ try:
+ serial_pipeline_onpolicy(config, seed=0, max_train_iter=2)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_onpolicy_ppopg():
+ config = [deepcopy(cartpole_ppopg_config), deepcopy(cartpole_ppopg_create_config)]
+ config[0].policy.learn.epoch_per_collect = 1
+ config[0].policy.eval.evaluator.eval_freq = 1
+ try:
+ serial_pipeline_onpolicy(config, seed=0, max_train_iter=2)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_mappo():
+ config = [deepcopy(ptz_simple_spread_mappo_config), deepcopy(ptz_simple_spread_mappo_create_config)]
+ config[0].policy.learn.epoch_per_collect = 1
+ config[1].env_manager.type = 'base'
+ try:
+ serial_pipeline_onpolicy(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_onpolicy_ppo_continuous():
+ config = [deepcopy(pendulum_ppo_config), deepcopy(pendulum_ppo_create_config)]
+ config[0].policy.learn.epoch_per_collect = 1
+ try:
+ serial_pipeline_onpolicy(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.platformtest
+@pytest.mark.unittest
+def test_onppo_stdim():
+ config = [deepcopy(cartpole_ppo_stdim_config), deepcopy(cartpole_ppo_stdim_create_config)]
+ config[0].policy.learn.update_per_collect = 1
+ config[0].exp_name = 'cartpole_ppo_stdim_unittest'
+ try:
+ serial_pipeline_onpolicy(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
diff --git a/DI-engine/ding/entry/tests/test_serial_entry_preference_based_irl.py b/DI-engine/ding/entry/tests/test_serial_entry_preference_based_irl.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e9198f92965b882e2927790eaf7507db2b1853e
--- /dev/null
+++ b/DI-engine/ding/entry/tests/test_serial_entry_preference_based_irl.py
@@ -0,0 +1,62 @@
+import pytest
+from copy import deepcopy
+import os
+from easydict import EasyDict
+
+import torch
+
+from ding.entry import serial_pipeline
+from ding.entry import serial_pipeline_preference_based_irl
+from dizoo.classic_control.cartpole.config.cartpole_trex_offppo_config import cartpole_trex_offppo_config,\
+ cartpole_trex_offppo_create_config
+from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config,\
+ cartpole_ppo_offpolicy_create_config
+from ding.entry.application_entry_trex_collect_data import trex_collecting_data
+from ding.reward_model.trex_reward_model import TrexConvEncoder
+from ding.torch_utils import is_differentiable
+
+
+@pytest.mark.unittest
+def test_serial_pipeline_trex():
+ exp_name = 'test_serial_pipeline_trex_expert'
+ config = [deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)]
+ config[0].policy.learn.learner.hook.save_ckpt_after_iter = 100
+ config[0].exp_name = exp_name
+ expert_policy = serial_pipeline(config, seed=0)
+
+ exp_name = 'test_serial_pipeline_trex_collect'
+ config = [deepcopy(cartpole_trex_offppo_config), deepcopy(cartpole_trex_offppo_create_config)]
+ config[0].exp_name = exp_name
+ config[0].reward_model.expert_model_path = 'test_serial_pipeline_trex_expert'
+ config[0].reward_model.checkpoint_max = 100
+ config[0].reward_model.checkpoint_step = 100
+ config[0].reward_model.num_snippets = 100
+ args = EasyDict({'cfg': deepcopy(config), 'seed': 0, 'device': 'cpu'})
+ trex_collecting_data(args=args)
+ try:
+ serial_pipeline_preference_based_irl(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+ finally:
+ os.popen('rm -rf test_serial_pipeline_trex*')
+
+
+B = 4
+C, H, W = 3, 128, 128
+
+
+@pytest.mark.unittest
+class TestEncoder:
+
+ def output_check(self, model, outputs):
+ loss = outputs.sum()
+ is_differentiable(loss, model)
+
+ def test_conv_encoder(self):
+ inputs = torch.randn(B, C, H, W)
+ model = TrexConvEncoder((C, H, W))
+ print(model)
+ outputs = model(inputs)
+ self.output_check(model, outputs)
+ print(outputs.shape)
+ assert outputs.shape == (B, 1)
diff --git a/DI-engine/ding/entry/tests/test_serial_entry_preference_based_irl_onpolicy.py b/DI-engine/ding/entry/tests/test_serial_entry_preference_based_irl_onpolicy.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffc20b9899a1e3df513ab23f3dc3555d6a5403be
--- /dev/null
+++ b/DI-engine/ding/entry/tests/test_serial_entry_preference_based_irl_onpolicy.py
@@ -0,0 +1,38 @@
+import pytest
+from copy import deepcopy
+import os
+from easydict import EasyDict
+
+import torch
+
+from ding.entry import serial_pipeline_onpolicy
+from ding.entry import serial_pipeline_preference_based_irl_onpolicy
+from dizoo.classic_control.cartpole.config import cartpole_ppo_config, cartpole_ppo_create_config
+from dizoo.classic_control.cartpole.config import cartpole_trex_ppo_onpolicy_config, \
+ cartpole_trex_ppo_onpolicy_create_config
+from ding.entry.application_entry_trex_collect_data import trex_collecting_data
+
+
+@pytest.mark.unittest
+def test_serial_pipeline_trex_onpolicy():
+ exp_name = 'trex_onpolicy_test_serial_pipeline_trex_onpolicy_expert'
+ config = [deepcopy(cartpole_ppo_config), deepcopy(cartpole_ppo_create_config)]
+ config[0].policy.learn.learner.hook.save_ckpt_after_iter = 100
+ config[0].exp_name = exp_name
+ expert_policy = serial_pipeline_onpolicy(config, seed=0)
+
+ exp_name = 'trex_onpolicy_test_serial_pipeline_trex_onpolicy_collect'
+ config = [deepcopy(cartpole_trex_ppo_onpolicy_config), deepcopy(cartpole_trex_ppo_onpolicy_create_config)]
+ config[0].exp_name = exp_name
+ config[0].reward_model.expert_model_path = 'trex_onpolicy_test_serial_pipeline_trex_onpolicy_expert'
+ config[0].reward_model.checkpoint_max = 100
+ config[0].reward_model.checkpoint_step = 100
+ config[0].reward_model.num_snippets = 100
+ args = EasyDict({'cfg': deepcopy(config), 'seed': 0, 'device': 'cpu'})
+ trex_collecting_data(args=args)
+ try:
+ serial_pipeline_preference_based_irl_onpolicy(config, seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
+ finally:
+        os.popen('rm -rf trex_onpolicy_test_serial_pipeline_trex_onpolicy*')
diff --git a/DI-engine/ding/entry/tests/test_serial_entry_reward_model.py b/DI-engine/ding/entry/tests/test_serial_entry_reward_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..404cb6d78cba158b4ab65deb3edcbe48c84685cb
--- /dev/null
+++ b/DI-engine/ding/entry/tests/test_serial_entry_reward_model.py
@@ -0,0 +1,89 @@
+import pytest
+import os
+from ditk import logging
+from easydict import EasyDict
+from copy import deepcopy
+
+from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
+from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config # noqa
+from dizoo.classic_control.cartpole.config.cartpole_rnd_onppo_config import cartpole_ppo_rnd_config, cartpole_ppo_rnd_create_config # noqa
+from dizoo.classic_control.cartpole.config.cartpole_ppo_icm_config import cartpole_ppo_icm_config, cartpole_ppo_icm_create_config # noqa
+from ding.entry import serial_pipeline, collect_demo_data, serial_pipeline_reward_model_offpolicy, \
+ serial_pipeline_reward_model_onpolicy
+
+cfg = [
+ {
+ 'type': 'pdeil',
+ "alpha": 0.5,
+ "discrete_action": False
+ },
+ {
+ 'type': 'gail',
+ 'input_size': 5,
+ 'hidden_size': 64,
+ 'batch_size': 64,
+ },
+ {
+ 'type': 'pwil',
+ 's_size': 4,
+ 'a_size': 2,
+ 'sample_size': 500,
+ },
+ {
+ 'type': 'red',
+ 'sample_size': 5000,
+ 'input_size': 5,
+ 'hidden_size': 64,
+ 'update_per_collect': 200,
+ 'batch_size': 128,
+ },
+]
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('reward_model_config', cfg)
+def test_irl(reward_model_config):
+ reward_model_config = EasyDict(reward_model_config)
+ config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
+ expert_policy = serial_pipeline(config, seed=0, max_train_iter=2)
+ # collect expert demo data
+ collect_count = 10000
+ expert_data_path = 'expert_data.pkl'
+ state_dict = expert_policy.collect_mode.state_dict()
+ config = deepcopy(cartpole_ppo_offpolicy_config), deepcopy(cartpole_ppo_offpolicy_create_config)
+ collect_demo_data(
+ config, seed=0, state_dict=state_dict, expert_data_path=expert_data_path, collect_count=collect_count
+ )
+ # irl + rl training
+ cp_cartpole_dqn_config = deepcopy(cartpole_dqn_config)
+ cp_cartpole_dqn_create_config = deepcopy(cartpole_dqn_create_config)
+ cp_cartpole_dqn_create_config.reward_model = dict(type=reward_model_config.type)
+ if reward_model_config.type == 'gail':
+ reward_model_config['data_path'] = '.'
+ else:
+ reward_model_config['expert_data_path'] = expert_data_path
+ cp_cartpole_dqn_config.reward_model = reward_model_config
+ cp_cartpole_dqn_config.policy.collect.n_sample = 128
+ serial_pipeline_reward_model_offpolicy(
+ (cp_cartpole_dqn_config, cp_cartpole_dqn_create_config), seed=0, max_train_iter=2
+ )
+
+ os.popen("rm -rf ckpt_* log expert_data.pkl")
+
+
+@pytest.mark.unittest
+def test_rnd():
+ config = [deepcopy(cartpole_ppo_rnd_config), deepcopy(cartpole_ppo_rnd_create_config)]
+ try:
+ serial_pipeline_reward_model_onpolicy(config, seed=0, max_train_iter=2)
+ except Exception:
+ assert False, "pipeline fail"
+
+
+@pytest.mark.unittest
+def test_icm():
+ config = [deepcopy(cartpole_ppo_icm_config), deepcopy(cartpole_ppo_icm_create_config)]
+ try:
+ serial_pipeline_reward_model_offpolicy(config, seed=0, max_train_iter=2)
+ except Exception:
+ assert False, "pipeline fail"
diff --git a/DI-engine/ding/entry/tests/test_serial_entry_sqil.py b/DI-engine/ding/entry/tests/test_serial_entry_sqil.py
new file mode 100644
index 0000000000000000000000000000000000000000..34e7c75e2c17f39b23dbcdaf64e9db8810b85d53
--- /dev/null
+++ b/DI-engine/ding/entry/tests/test_serial_entry_sqil.py
@@ -0,0 +1,23 @@
+import pytest
+import torch
+from copy import deepcopy
+from ding.entry import serial_pipeline
+from ding.entry.serial_entry_sqil import serial_pipeline_sqil
+from dizoo.classic_control.cartpole.config.cartpole_sql_config import cartpole_sql_config, cartpole_sql_create_config
+from dizoo.classic_control.cartpole.config.cartpole_sqil_config import cartpole_sqil_config, cartpole_sqil_create_config
+
+
+@pytest.mark.unittest
+def test_sqil():
+ expert_policy_state_dict_path = './expert_policy.pth'
+ config = [deepcopy(cartpole_sql_config), deepcopy(cartpole_sql_create_config)]
+ expert_policy = serial_pipeline(config, seed=0)
+ torch.save(expert_policy.collect_mode.state_dict(), expert_policy_state_dict_path)
+
+ config = [deepcopy(cartpole_sqil_config), deepcopy(cartpole_sqil_create_config)]
+ config[0].policy.collect.model_path = expert_policy_state_dict_path
+ config[0].policy.learn.update_per_collect = 1
+ try:
+ serial_pipeline_sqil(config, [cartpole_sql_config, cartpole_sql_create_config], seed=0, max_train_iter=1)
+ except Exception:
+ assert False, "pipeline fail"
diff --git a/DI-engine/ding/entry/utils.py b/DI-engine/ding/entry/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbfbaa83bdd8a466c0b0cbc86f06c80153d320c9
--- /dev/null
+++ b/DI-engine/ding/entry/utils.py
@@ -0,0 +1,71 @@
+from typing import Optional, Callable, List, Any
+
+from ding.policy import PolicyFactory
+from ding.worker import IMetric, MetricSerialEvaluator
+
+
+class AccMetric(IMetric):
+
+ def eval(self, inputs: Any, label: Any) -> dict:
+ return {'Acc': (inputs['logit'].sum(dim=1) == label).sum().item() / label.shape[0]}
+
+ def reduce_mean(self, inputs: List[Any]) -> Any:
+ s = 0
+ for item in inputs:
+ s += item['Acc']
+ return {'Acc': s / len(inputs)}
+
+ def gt(self, metric1: Any, metric2: Any) -> bool:
+ if metric2 is None:
+ return True
+ if isinstance(metric2, dict):
+ m2 = metric2['Acc']
+ else:
+ m2 = metric2
+ return metric1['Acc'] > m2
+
+
+def mark_not_expert(ori_data: List[dict]) -> List[dict]:
+ for i in range(len(ori_data)):
+ # Set is_expert flag (expert 1, agent 0)
+ ori_data[i]['is_expert'] = 0
+ return ori_data
+
+
+def mark_warm_up(ori_data: List[dict]) -> List[dict]:
+ # for td3_vae
+ for i in range(len(ori_data)):
+ ori_data[i]['warm_up'] = True
+ return ori_data
+
+
+def random_collect(
+ policy_cfg: 'EasyDict', # noqa
+ policy: 'Policy', # noqa
+ collector: 'ISerialCollector', # noqa
+ collector_env: 'BaseEnvManager', # noqa
+ commander: 'BaseSerialCommander', # noqa
+ replay_buffer: 'IBuffer', # noqa
+ postprocess_data_fn: Optional[Callable] = None
+) -> None: # noqa
+ assert policy_cfg.random_collect_size > 0
+ if policy_cfg.get('transition_with_policy_data', False):
+ collector.reset_policy(policy.collect_mode)
+ else:
+ action_space = collector_env.action_space
+ random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=action_space)
+ collector.reset_policy(random_policy)
+ collect_kwargs = commander.step()
+ if policy_cfg.collect.collector.type == 'episode':
+ new_data = collector.collect(n_episode=policy_cfg.random_collect_size, policy_kwargs=collect_kwargs)
+ else:
+ new_data = collector.collect(
+ n_sample=policy_cfg.random_collect_size,
+ random_collect=True,
+ record_random_collect=False,
+ policy_kwargs=collect_kwargs
+        )  # 'record_random_collect=False' means the random collection is not written to the output log
+ if postprocess_data_fn is not None:
+ new_data = postprocess_data_fn(new_data)
+ replay_buffer.push(new_data, cur_collector_envstep=0)
+ collector.reset_policy(policy.collect_mode)
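A minimal sketch of how these helpers behave, assuming the semantics shown above (values are illustrative):

    import torch
    from ding.entry.utils import AccMetric, mark_not_expert

    metric = AccMetric()
    batch = {'logit': torch.tensor([[1.0, 2.0], [0.5, 0.5]])}  # row sums: 3.0, 1.0
    label = torch.tensor([3.0, 1.0])
    assert metric.eval(batch, label) == {'Acc': 1.0}
    assert metric.reduce_mean([{'Acc': 1.0}, {'Acc': 0.5}]) == {'Acc': 0.75}

    # transitions gathered by a random policy are tagged as non-expert data
    data = mark_not_expert([{'obs': 0}, {'obs': 1}])
    assert all(d['is_expert'] == 0 for d in data)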
diff --git a/DI-engine/ding/envs/__init__.py b/DI-engine/ding/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc1a3fed7b08cf6a048d7a52bbaeccc62e019447
--- /dev/null
+++ b/DI-engine/ding/envs/__init__.py
@@ -0,0 +1,5 @@
+from .env import *
+from .env_wrappers import *
+from .env_manager import *
+from .env_manager.ding_env_manager import setup_ding_env_manager
+from . import gym_env
diff --git a/DI-engine/ding/envs/common/__init__.py b/DI-engine/ding/envs/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4b3a2e011dde25731bb160dc4f3424e2cc85ef2
--- /dev/null
+++ b/DI-engine/ding/envs/common/__init__.py
@@ -0,0 +1,5 @@
+from .common_function import num_first_one_hot, sqrt_one_hot, div_one_hot, div_func, clip_one_hot, \
+ reorder_one_hot, reorder_one_hot_array, reorder_boolean_vector, affine_transform, \
+ batch_binary_encode, get_postion_vector, save_frames_as_gif
+from .env_element import EnvElement, EnvElementInfo
+from .env_element_runner import EnvElementRunner
diff --git a/DI-engine/ding/envs/common/common_function.py b/DI-engine/ding/envs/common/common_function.py
new file mode 100644
index 0000000000000000000000000000000000000000..71db31728024930fe745a36e78da022989601f65
--- /dev/null
+++ b/DI-engine/ding/envs/common/common_function.py
@@ -0,0 +1,291 @@
+import math
+from functools import partial, lru_cache
+from typing import Optional, Dict, Any
+
+import numpy as np
+import torch
+
+from ding.compatibility import torch_ge_180
+from ding.torch_utils import one_hot
+
+num_first_one_hot = partial(one_hot, num_first=True)
+
+
+def sqrt_one_hot(v: torch.Tensor, max_val: int) -> torch.Tensor:
+ """
+ Overview:
+ Sqrt the input value ``v`` and transform it into one-hot.
+ Arguments:
+ - v (:obj:`torch.Tensor`): the value to be processed with `sqrt` and `one-hot`
+ - max_val (:obj:`int`): the input ``v``'s estimated max value, used to calculate one-hot bit number. \
+ ``v`` would be clamped by (0, max_val).
+ Returns:
+ - ret (:obj:`torch.Tensor`): the value processed after `sqrt` and `one-hot`
+ """
+ num = int(math.sqrt(max_val)) + 1
+ v = v.float()
+ v = torch.floor(torch.sqrt(torch.clamp(v, 0, max_val))).long()
+ return one_hot(v, num)
+
+
+def div_one_hot(v: torch.Tensor, max_val: int, ratio: int) -> torch.Tensor:
+ """
+ Overview:
+ Divide the input value ``v`` by ``ratio`` and transform it into one-hot.
+ Arguments:
+ - v (:obj:`torch.Tensor`): the value to be processed with `divide` and `one-hot`
+ - max_val (:obj:`int`): the input ``v``'s estimated max value, used to calculate one-hot bit number. \
+ ``v`` would be clamped by (0, ``max_val``).
+ - ratio (:obj:`int`): input ``v`` would be divided by ``ratio``
+ Returns:
+ - ret (:obj:`torch.Tensor`): the value processed after `divide` and `one-hot`
+ """
+ num = int(max_val / ratio) + 1
+ v = v.float()
+ v = torch.floor(torch.clamp(v, 0, max_val) / ratio).long()
+ return one_hot(v, num)
+
+
+def div_func(inputs: torch.Tensor, other: float, unsqueeze_dim: int = 1):
+ """
+ Overview:
+ Divide ``inputs`` by ``other`` and unsqueeze if needed.
+ Arguments:
+ - inputs (:obj:`torch.Tensor`): the value to be unsqueezed and divided
+ - other (:obj:`float`): input would be divided by ``other``
+ - unsqueeze_dim (:obj:`int`): the dim to implement unsqueeze
+ Returns:
+ - ret (:obj:`torch.Tensor`): the value processed after `unsqueeze` and `divide`
+ """
+ inputs = inputs.float()
+ if unsqueeze_dim is not None:
+ inputs = inputs.unsqueeze(unsqueeze_dim)
+ return torch.div(inputs, other)
+
+
+def clip_one_hot(v: torch.Tensor, num: int) -> torch.Tensor:
+ """
+ Overview:
+ Clamp the input ``v`` in (0, num-1) and make one-hot mapping.
+ Arguments:
+ - v (:obj:`torch.Tensor`): the value to be processed with `clamp` and `one-hot`
+ - num (:obj:`int`): number of one-hot bits
+ Returns:
+ - ret (:obj:`torch.Tensor`): the value processed after `clamp` and `one-hot`
+ """
+ v = v.clamp(0, num - 1)
+ return one_hot(v, num)
+
+
+def reorder_one_hot(
+ v: torch.LongTensor,
+ dictionary: Dict[int, int],
+ num: int,
+ transform: Optional[np.ndarray] = None
+) -> torch.Tensor:
+ """
+ Overview:
+ Reorder each value in input ``v`` according to reorder dict ``dictionary``, then make one-hot mapping
+ Arguments:
+ - v (:obj:`torch.LongTensor`): the original value to be processed with `reorder` and `one-hot`
+ - dictionary (:obj:`Dict[int, int]`): a reorder lookup dict, \
+ map original value to new reordered index starting from 0
+ - num (:obj:`int`): number of one-hot bits
+        - transform (:obj:`np.ndarray`): an array to first transform the original action to the general action
+ Returns:
+ - ret (:obj:`torch.Tensor`): one-hot data indicating reordered index
+ """
+ assert (len(v.shape) == 1)
+ assert (isinstance(v, torch.Tensor))
+ new_v = torch.zeros_like(v)
+ for idx in range(v.shape[0]):
+ if transform is None:
+ val = v[idx].item()
+ else:
+ val = transform[v[idx].item()]
+ new_v[idx] = dictionary[val]
+ return one_hot(new_v, num)
+
+
+def reorder_one_hot_array(
+ v: torch.LongTensor, array: np.ndarray, num: int, transform: Optional[np.ndarray] = None
+) -> torch.Tensor:
+ """
+ Overview:
+        Reorder each value in input ``v`` according to the reorder lookup array ``array``, then make one-hot mapping.
+ The difference between this function and ``reorder_one_hot`` is
+ whether the type of reorder lookup data structure is `np.ndarray` or `dict`.
+ Arguments:
+ - v (:obj:`torch.LongTensor`): the value to be processed with `reorder` and `one-hot`
+ - array (:obj:`np.ndarray`): a reorder lookup array, map original value to new reordered index starting from 0
+ - num (:obj:`int`): number of one-hot bits
+ - transform (:obj:`np.ndarray`): an array to firstly transform the original action to general action
+ Returns:
+ - ret (:obj:`torch.Tensor`): one-hot data indicating reordered index
+ """
+ v = v.numpy()
+ if transform is None:
+ val = array[v]
+ else:
+ val = array[transform[v]]
+ return one_hot(torch.LongTensor(val), num)
+
+
+def reorder_boolean_vector(
+ v: torch.LongTensor,
+ dictionary: Dict[int, int],
+ num: int,
+ transform: Optional[np.ndarray] = None
+) -> torch.Tensor:
+ """
+ Overview:
+ Reorder each value in input ``v`` to new index according to reorder dict ``dictionary``,
+ then set corresponding position in return tensor to 1.
+ Arguments:
+ - v (:obj:`torch.LongTensor`): the value to be processed with `reorder`
+ - dictionary (:obj:`Dict[int, int]`): a reorder lookup dict, \
+ map original value to new reordered index starting from 0
+        - num (:obj:`int`): total number of items, should equal max index + 1
+ - transform (:obj:`np.ndarray`): an array to firstly transform the original action to general action
+ Returns:
+ - ret (:obj:`torch.Tensor`): boolean data containing only 0 and 1, \
+ indicating whether corresponding original value exists in input ``v``
+ """
+ ret = torch.zeros(num)
+ for item in v:
+ try:
+ if transform is None:
+ val = item.item()
+ else:
+ val = transform[item.item()]
+ idx = dictionary[val]
+ except KeyError as e:
+ # print(dictionary)
+ raise KeyError('{}_{}_'.format(num, e))
+ ret[idx] = 1
+ return ret
+
+
+@lru_cache(maxsize=32)
+def get_to_and(num_bits: int) -> np.ndarray:
+ """
+ Overview:
+        Get an np.ndarray with ``num_bits`` elements, each equal to :math:`2^n` (n decreases from num_bits-1 to 0).
+ Used by ``batch_binary_encode`` to make bit-wise `and`.
+ Arguments:
+ - num_bits (:obj:`int`): length of the generating array
+ Returns:
+ - to_and (:obj:`np.ndarray`): an array with ``num_bits`` elements, \
+            each equal to :math:`2^n` (n decreases from num_bits-1 to 0)
+ """
+ return 2 ** np.arange(num_bits - 1, -1, -1).reshape([1, num_bits])
+
+
+def batch_binary_encode(x: torch.Tensor, bit_num: int) -> torch.Tensor:
+ """
+ Overview:
+        Big-endian binary encode ``x`` into a float tensor
+ Arguments:
+ - x (:obj:`torch.Tensor`): the value to be unsqueezed and divided
+ - bit_num (:obj:`int`): number of bits, should satisfy :math:`2^{bit num} > max(x)`
+ Example:
+ >>> batch_binary_encode(torch.tensor([131,71]), 10)
+ tensor([[0., 0., 1., 0., 0., 0., 0., 0., 1., 1.],
+ [0., 0., 0., 1., 0., 0., 0., 1., 1., 1.]])
+ Returns:
+ - ret (:obj:`torch.Tensor`): the binary encoded tensor, containing only `0` and `1`
+ """
+ x = x.numpy()
+ xshape = list(x.shape)
+ x = x.reshape([-1, 1])
+ to_and = get_to_and(bit_num)
+ return torch.FloatTensor((x & to_and).astype(bool).astype(float).reshape(xshape + [bit_num]))
+
+
+def compute_denominator(x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+        Compute the denominator used in ``get_postion_vector``. \
+        1 is divided by the result at the last step, so it can be used directly as a multiplier.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Input tensor, which is generated from torch.arange(0, d_model).
+ Returns:
+ - ret (:obj:`torch.Tensor`): Denominator result tensor.
+ """
+ if torch_ge_180():
+ x = torch.div(x, 2, rounding_mode='trunc') * 2
+ else:
+ x = torch.div(x, 2) * 2
+ x = torch.div(x, 64.)
+ x = torch.pow(10000., x)
+ x = torch.div(1., x)
+ return x
+
+
+def get_postion_vector(x: list) -> torch.Tensor:
+ """
+ Overview:
+        Get the position embedding used in `Transformer`; the even and odd multipliers :math:`\alpha` are stored in ``POSITION_ARRAY``
+ Arguments:
+ - x (:obj:`list`): original position index, whose length should be 32
+ Returns:
+ - v (:obj:`torch.Tensor`): position embedding tensor in 64 dims
+ """
+ # TODO use lru_cache to optimize it
+ POSITION_ARRAY = compute_denominator(torch.arange(0, 64, dtype=torch.float)) # d_model = 64
+ v = torch.zeros(64, dtype=torch.float)
+ x = torch.FloatTensor(x)
+ v[0::2] = torch.sin(x * POSITION_ARRAY[0::2]) # even
+ v[1::2] = torch.cos(x * POSITION_ARRAY[1::2]) # odd
+ return v
+
+
+def affine_transform(
+ data: Any,
+ action_clip: Optional[bool] = True,
+ alpha: Optional[float] = None,
+ beta: Optional[float] = None,
+ min_val: Optional[float] = None,
+ max_val: Optional[float] = None
+) -> Any:
+ """
+ Overview:
+        Do an affine transform for data in range [-1, 1]: :math:`\alpha \times data + \beta`
+ Arguments:
+ - data (:obj:`Any`): the input data
+ - action_clip (:obj:`bool`): whether to do action clip operation ([-1, 1])
+ - alpha (:obj:`float`): affine transform weight
+ - beta (:obj:`float`): affine transform bias
+ - min_val (:obj:`float`): min value, if `min_val` and `max_val` are indicated, scale input data\
+ to [min_val, max_val]
+ - max_val (:obj:`float`): max value
+ Returns:
+ - transformed_data (:obj:`Any`): affine transformed data
+ """
+ if action_clip:
+ data = np.clip(data, -1, 1)
+ if min_val is not None:
+ assert max_val is not None
+ alpha = (max_val - min_val) / 2
+ beta = (max_val + min_val) / 2
+ assert alpha is not None
+ beta = beta if beta is not None else 0.
+ return data * alpha + beta
+
+
+def save_frames_as_gif(frames: list, path: str) -> None:
+ """
+ Overview:
+        Save frames as a gif file to the specified path.
+ Arguments:
+ - frames (:obj:`List`): list of frames
+ - path (:obj:`str`): the path to save gif
+ """
+ try:
+ import imageio
+ except ImportError:
+ from ditk import logging
+ import sys
+ logging.warning("Please install imageio first.")
+ sys.exit(1)
+ imageio.mimsave(path, frames, fps=20)
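A short usage sketch for a few of the functions above (shapes and values are illustrative assumptions):

    import numpy as np
    import torch
    from ding.envs.common.common_function import batch_binary_encode, div_one_hot, affine_transform

    # big-endian binary encoding into 10 bits, e.g. 131 -> 0010000011
    bits = batch_binary_encode(torch.tensor([131, 71]), 10)
    assert bits.shape == (2, 10)

    # bucket values in [0, 6] into bins of width 2, i.e. 4 one-hot bits
    buckets = div_one_hot(torch.tensor([[3, 4, 5]]), max_val=6, ratio=2)
    assert buckets.shape == (1, 3, 4)

    # rescale actions from [-1, 1] to an env's own bounds, here [-2, 2]
    scaled = affine_transform(np.array([-1.0, 0.0, 1.0]), min_val=-2, max_val=2)
    assert np.allclose(scaled, [-2.0, 0.0, 2.0])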
diff --git a/DI-engine/ding/envs/common/env_element.py b/DI-engine/ding/envs/common/env_element.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b835e2d8a2d0e31827286ff9fd2b732f04bcdab
--- /dev/null
+++ b/DI-engine/ding/envs/common/env_element.py
@@ -0,0 +1,60 @@
+from abc import ABC, abstractmethod
+from collections import namedtuple
+from typing import Any
+
+EnvElementInfo = namedtuple('EnvElementInfo', ['shape', 'value'])
+
+
+class IEnvElement(ABC):
+
+ @abstractmethod
+ def __repr__(self) -> str:
+ raise NotImplementedError
+
+ @property
+ @abstractmethod
+ def info(self) -> Any:
+ raise NotImplementedError
+
+
+class EnvElement(IEnvElement):
+ _instance = None
+ _name = 'EnvElement'
+
+ def __init__(self, *args, **kwargs) -> None:
+ # placeholder
+ # self._shape = None
+ # self._value = None
+ # self._to_agent_processor = None
+ # self._from_agent_processor = None
+ self._init(*args, **kwargs)
+ self._check()
+
+ @abstractmethod
+ def _init(*args, **kwargs) -> None:
+ raise NotImplementedError
+
+ def __repr__(self) -> str:
+ return '{}: {}'.format(self._name, self._details())
+
+ @abstractmethod
+ def _details(self) -> str:
+ raise NotImplementedError
+
+ def _check(self) -> None:
+ flag = [
+ hasattr(self, '_shape'),
+ hasattr(self, '_value'),
+ # hasattr(self, '_to_agent_processor'),
+ # hasattr(self, '_from_agent_processor'),
+ ]
+ assert all(flag), 'this class {} is not a legal subclass of EnvElement({})'.format(self.__class__, flag)
+
+ @property
+ def info(self) -> 'EnvElementInfo':
+ return EnvElementInfo(
+ shape=self._shape,
+ value=self._value,
+ # to_agent_processor=self._to_agent_processor,
+ # from_agent_processor=self._from_agent_processor
+ )
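A hypothetical minimal subclass, only to illustrate the ``_init``/``_check`` contract (``ScoreElement`` and its fields are invented for this sketch):

    from ding.envs.common.env_element import EnvElement

    class ScoreElement(EnvElement):
        _name = 'ScoreElement'

        def _init(self, max_score: int) -> None:
            # ``_shape`` and ``_value`` must be set here, otherwise ``_check()`` raises
            self._shape = (1, )
            self._value = {'min': 0, 'max': max_score, 'dtype': int}

        def _details(self) -> str:
            return 'score in [0, {}]'.format(self._value['max'])

    elem = ScoreElement(100)
    print(elem)       # ScoreElement: score in [0, 100]
    print(elem.info)  # EnvElementInfo(shape=(1,), value={'min': 0, 'max': 100, 'dtype': <class 'int'>})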
diff --git a/DI-engine/ding/envs/common/env_element_runner.py b/DI-engine/ding/envs/common/env_element_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4b49f591c2d7af353ca533302ee2e011e35b537
--- /dev/null
+++ b/DI-engine/ding/envs/common/env_element_runner.py
@@ -0,0 +1,39 @@
+from abc import abstractmethod
+from typing import Any
+
+from .env_element import EnvElement, IEnvElement, EnvElementInfo
+from ..env.base_env import BaseEnv
+
+
+class IEnvElementRunner(IEnvElement):
+
+ @abstractmethod
+ def get(self, engine: BaseEnv) -> Any:
+ raise NotImplementedError
+
+ @abstractmethod
+ def reset(self, *args, **kwargs) -> None:
+ raise NotImplementedError
+
+
+class EnvElementRunner(IEnvElementRunner):
+
+ def __init__(self, *args, **kwargs) -> None:
+ self._init(*args, **kwargs)
+ self._check()
+
+ @abstractmethod
+ def _init(self, *args, **kwargs) -> None:
+ # set self._core and other state variable
+ raise NotImplementedError
+
+ def _check(self) -> None:
+ flag = [hasattr(self, '_core'), isinstance(self._core, EnvElement)]
+ assert all(flag), flag
+
+ def __repr__(self) -> str:
+ return repr(self._core)
+
+ @property
+ def info(self) -> 'EnvElementInfo':
+ return self._core.info
diff --git a/DI-engine/ding/envs/common/tests/test_common_function.py b/DI-engine/ding/envs/common/tests/test_common_function.py
new file mode 100644
index 0000000000000000000000000000000000000000..11712331bbf6a3c4e7e7e6d99f6e829810094a14
--- /dev/null
+++ b/DI-engine/ding/envs/common/tests/test_common_function.py
@@ -0,0 +1,129 @@
+import os
+import random
+import shutil
+
+import numpy as np
+import pytest
+import torch
+from ding.envs.common.common_function import sqrt_one_hot, div_one_hot, div_func, clip_one_hot, \
+ reorder_one_hot, reorder_one_hot_array, reorder_boolean_vector, \
+ batch_binary_encode, get_postion_vector, \
+ affine_transform, save_frames_as_gif
+
+VALUES = [2, 3, 5, 7, 11]
+
+
+@pytest.fixture(scope="function")
+def setup_reorder_array():
+ ret = np.full((12), -1)
+ for i, v in enumerate(VALUES):
+ ret[v] = i
+ return ret
+
+
+@pytest.fixture(scope="function")
+def setup_reorder_dict():
+ return {v: i for i, v in enumerate(VALUES)}
+
+
+def generate_data():
+ ret = {
+ 'obs': np.random.randn(4),
+ }
+ p_weight = np.random.uniform()
+ if p_weight < 1. / 3:
+ pass # no key 'priority'
+ elif p_weight < 2. / 3:
+ ret['priority'] = None
+ else:
+ ret['priority'] = np.random.uniform()
+
+ return ret
+
+
+@pytest.mark.unittest
+class TestEnvCommonFunc:
+
+ def test_one_hot(self):
+ a = torch.Tensor([[3, 4, 5], [1, 2, 6]])
+
+ a_sqrt = sqrt_one_hot(a, 6)
+ assert a_sqrt.max().item() == 1
+ assert [j.sum().item() for i in a_sqrt for j in i] == [1 for _ in range(6)]
+ sqrt_dim = 3
+ assert a_sqrt.shape == (2, 3, sqrt_dim)
+
+ a_div = div_one_hot(a, 6, 2)
+ assert a_div.max().item() == 1
+ assert [j.sum().item() for i in a_div for j in i] == [1 for _ in range(6)]
+ div_dim = 4
+ assert a_div.shape == (2, 3, div_dim)
+
+ a_di = div_func(a, 2)
+ assert a_di.shape == (2, 1, 3)
+ assert torch.eq(a_di.squeeze() * 2, a).all()
+
+ a_clip = clip_one_hot(a.long(), 4)
+ assert a_clip.max().item() == 1
+ assert [j.sum().item() for i in a_clip for j in i] == [1 for _ in range(6)]
+ clip_dim = 4
+ assert a_clip.shape == (2, 3, clip_dim)
+
+ def test_reorder(self, setup_reorder_array, setup_reorder_dict):
+ a = torch.LongTensor([2, 7]) # VALUES = [2, 3, 5, 7, 11]
+
+ a_array = reorder_one_hot_array(a, setup_reorder_array, 5)
+ a_dict = reorder_one_hot(a, setup_reorder_dict, 5)
+ assert torch.eq(a_array, a_dict).all()
+ assert a_array.max().item() == 1
+ assert [j.sum().item() for j in a_array] == [1 for _ in range(2)]
+ reorder_dim = 5
+ assert a_array.shape == (2, reorder_dim)
+
+ a_bool = reorder_boolean_vector(a, setup_reorder_dict, 5)
+ assert a_array.max().item() == 1
+ assert torch.eq(a_bool, sum([_ for _ in a_array])).all()
+
+ def test_binary(self):
+ a = torch.LongTensor([445, 1023])
+ a_binary = batch_binary_encode(a, 10)
+ ans = []
+ for number in a:
+ one = [int(_) for _ in list(bin(number))[2:]]
+ for _ in range(10 - len(one)):
+ one.insert(0, 0)
+ ans.append(one)
+ ans = torch.Tensor(ans)
+ assert torch.eq(a_binary, ans).all()
+
+ def test_position(self):
+ a = [random.randint(0, 5000) for _ in range(32)]
+ a_position = get_postion_vector(a)
+ assert a_position.shape == (64, )
+
+ def test_affine_transform(self):
+ a = torch.rand(4, 3)
+ a = (a - a.min()) / (a.max() - a.min())
+ a = a * 2 - 1
+ ans = affine_transform(a, min_val=-2, max_val=2)
+ assert ans.shape == (4, 3)
+ assert ans.min() == -2 and ans.max() == 2
+ a = np.random.rand(3, 5)
+ a = (a - a.min()) / (a.max() - a.min())
+ a = a * 2 - 1
+ ans = affine_transform(a, alpha=4, beta=1)
+ assert ans.shape == (3, 5)
+ assert ans.min() == -3 and ans.max() == 5
+
+
+@pytest.mark.other
+def test_save_frames_as_gif():
+ frames = [np.random.randint(0, 255, [84, 84, 3]) for _ in range(100)]
+ replay_path_gif = './replay_path_gif'
+ env_id = 'test'
+ save_replay_count = 1
+ if not os.path.exists(replay_path_gif):
+ os.makedirs(replay_path_gif)
+ path = os.path.join(replay_path_gif, '{}_episode_{}.gif'.format(env_id, save_replay_count))
+ save_frames_as_gif(frames, path)
+ shutil.rmtree(replay_path_gif)
diff --git a/DI-engine/ding/envs/env/__init__.py b/DI-engine/ding/envs/env/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec0a93602cc8c076c9e546da6c02efe5c3a25a19
--- /dev/null
+++ b/DI-engine/ding/envs/env/__init__.py
@@ -0,0 +1,5 @@
+from .base_env import BaseEnv, get_vec_env_setting, BaseEnvTimestep, get_env_cls, create_model_env
+from .ding_env_wrapper import DingEnvWrapper
+from .default_wrapper import get_default_wrappers
+from .env_implementation_check import check_space_dtype, check_array_space, check_reset, check_step, \
+ check_different_memory, check_obs_deepcopy, check_all, demonstrate_correct_procedure
diff --git a/DI-engine/ding/envs/env/base_env.py b/DI-engine/ding/envs/env/base_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b575a85e52573748779b57fc4e25b99e4ce62f3
--- /dev/null
+++ b/DI-engine/ding/envs/env/base_env.py
@@ -0,0 +1,185 @@
+from abc import ABC, abstractmethod
+from typing import Any, List, Tuple
+import gym
+import copy
+from easydict import EasyDict
+from collections import namedtuple
+from ding.utils import import_module, ENV_REGISTRY
+
+BaseEnvTimestep = namedtuple('BaseEnvTimestep', ['obs', 'reward', 'done', 'info'])
+
+
+# for solving multiple inheritance metaclass conflict between gym and ABC
+class FinalMeta(type(ABC), type(gym.Env)):
+ pass
+
+
+class BaseEnv(gym.Env, ABC, metaclass=FinalMeta):
+ """
+ Overview:
+ Basic environment class, extended from ``gym.Env``
+ Interface:
+ ``__init__``, ``reset``, ``close``, ``step``, ``random_action``, ``create_collector_env_cfg``, \
+ ``create_evaluator_env_cfg``, ``enable_save_replay``
+ """
+
+ @abstractmethod
+ def __init__(self, cfg: dict) -> None:
+ """
+ Overview:
+ Lazy init, only related arguments will be initialized in ``__init__`` method, and the concrete \
+ env will be initialized the first time ``reset`` method is called.
+ Arguments:
+ - cfg (:obj:`dict`): Environment configuration in dict type.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def reset(self) -> Any:
+ """
+ Overview:
+ Reset the env to an initial state and returns an initial observation.
+ Returns:
+ - obs (:obj:`Any`): Initial observation after reset.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def close(self) -> None:
+ """
+ Overview:
+ Close env and all the related resources, it should be called after the usage of env instance.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def step(self, action: Any) -> 'BaseEnv.timestep':
+ """
+ Overview:
+ Run one timestep of the environment's dynamics/simulation.
+ Arguments:
+ - action (:obj:`Any`): The ``action`` input to step with.
+ Returns:
+ - timestep (:obj:`BaseEnv.timestep`): The result timestep of env executing one step.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def seed(self, seed: int) -> None:
+ """
+ Overview:
+ Set the seed for this env's random number generator(s).
+ Arguments:
+ - seed (:obj:`Any`): Random seed.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def __repr__(self) -> str:
+ """
+ Overview:
+ Return the information string of this env instance.
+ Returns:
+ - info (:obj:`str`): Information of this env instance, like type and arguments.
+ """
+ raise NotImplementedError
+
+ @staticmethod
+ def create_collector_env_cfg(cfg: dict) -> List[dict]:
+ """
+ Overview:
+            Return a list of env configs derived from the input config, used by the env manager \
+            (a series of vectorized envs). This method is responsible for the envs that collect data.
+        Arguments:
+            - cfg (:obj:`dict`): Original input env config, which is transformed into the config actually used to \
+                create env instances, replicated the corresponding number of times.
+        Returns:
+            - env_cfg_list (:obj:`List[dict]`): List of ``cfg`` for all collector envs.
+
+        .. note::
+            Elements (env config) in collector_env_cfg/evaluator_env_cfg can be different, such as server ip and port.
+ """
+ collector_env_num = cfg.pop('collector_env_num')
+ return [cfg for _ in range(collector_env_num)]
+
+ @staticmethod
+ def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
+ """
+ Overview:
+            Return a list of env configs derived from the input config, used by the env manager \
+            (a series of vectorized envs). This method is responsible for the envs that evaluate performance.
+        Arguments:
+            - cfg (:obj:`dict`): Original input env config, which is transformed into the config actually used to \
+                create env instances, replicated the corresponding number of times.
+        Returns:
+            - env_cfg_list (:obj:`List[dict]`): List of ``cfg`` for all evaluator envs.
+ """
+ evaluator_env_num = cfg.pop('evaluator_env_num')
+ return [cfg for _ in range(evaluator_env_num)]
+
+ # optional method
+ def enable_save_replay(self, replay_path: str) -> None:
+ """
+ Overview:
+            Save the replay file to the given path. This method needs to be implemented by each concrete env class.
+ Arguments:
+ - replay_path (:obj:`str`): The path to save replay file.
+ """
+ raise NotImplementedError
+
+ # optional method
+ def random_action(self) -> Any:
+ """
+ Overview:
+            Return a random action generated from the original action space; usually convenient for testing.
+ Returns:
+ - random_action (:obj:`Any`): Action generated randomly.
+ """
+ pass
+
+
+def get_vec_env_setting(cfg: dict, collect: bool = True, eval_: bool = True) -> Tuple[type, List[dict], List[dict]]:
+ """
+ Overview:
+ Get vectorized env setting (env_fn, collector_env_cfg, evaluator_env_cfg).
+ Arguments:
+ - cfg (:obj:`dict`): Original input env config in user config, such as ``cfg.env``.
+ Returns:
+ - env_fn (:obj:`type`): Callable object, call it with proper arguments and then get a new env instance.
+ - collector_env_cfg (:obj:`List[dict]`): A list contains the config of collecting data envs.
+ - evaluator_env_cfg (:obj:`List[dict]`): A list contains the config of evaluation envs.
+
+ .. note::
+ Elements (env config) in collector_env_cfg/evaluator_env_cfg can be different, such as server ip and port.
+
+ """
+ import_module(cfg.get('import_names', []))
+ env_fn = ENV_REGISTRY.get(cfg.type)
+ collector_env_cfg = env_fn.create_collector_env_cfg(cfg) if collect else None
+ evaluator_env_cfg = env_fn.create_evaluator_env_cfg(cfg) if eval_ else None
+ return env_fn, collector_env_cfg, evaluator_env_cfg
+
+
+def get_env_cls(cfg: EasyDict) -> type:
+ """
+ Overview:
+        Get the env class from the corresponding module of ``cfg`` and return the callable class.
+ Arguments:
+ - cfg (:obj:`dict`): Original input env config in user config, such as ``cfg.env``.
+ Returns:
+ - env_cls_type (:obj:`type`): Env module as the corresponding callable class type.
+ """
+ import_module(cfg.get('import_names', []))
+ return ENV_REGISTRY.get(cfg.type)
+
+
+def create_model_env(cfg: EasyDict) -> Any:
+ """
+ Overview:
+ Create model env, which is used in model-based RL.
+ """
+ cfg = copy.deepcopy(cfg)
+ model_env_fn = get_env_cls(cfg)
+ cfg.pop('import_names')
+ cfg.pop('type')
+ return model_env_fn(**cfg)
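A sketch of how ``get_vec_env_setting`` is typically consumed; the ``cartpole`` registry key and import path are assumptions for illustration:

    from easydict import EasyDict
    from ding.envs.env.base_env import get_vec_env_setting

    env_cfg = EasyDict(
        type='cartpole',  # assumed registry key
        import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],  # assumed module path
        collector_env_num=4,
        evaluator_env_num=2,
    )
    env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(env_cfg)
    # with the default BaseEnv static methods, the config is simply replicated
    assert len(collector_env_cfg) == 4 and len(evaluator_env_cfg) == 2
    env = env_fn(collector_env_cfg[0])  # lazy init; the concrete env is created on the first reset()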
diff --git a/DI-engine/ding/envs/env/default_wrapper.py b/DI-engine/ding/envs/env/default_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0e1401b4621d942cae762cc612461d35f8d1f15
--- /dev/null
+++ b/DI-engine/ding/envs/env/default_wrapper.py
@@ -0,0 +1,51 @@
+from easydict import EasyDict
+from typing import Optional, List
+import copy
+
+eval_episode_return_wrapper = EasyDict(type='eval_episode_return')
+
+
+def get_default_wrappers(env_wrapper_name: str, env_id: Optional[str] = None, caller: str = 'collector') -> List[dict]:
+ """
+ Overview:
+ Get default wrappers for different environments used in ``DingEnvWrapper``.
+ Arguments:
+ - env_wrapper_name (:obj:`str`): The name of the environment wrapper.
+ - env_id (:obj:`Optional[str]`): The id of the specific environment, such as ``PongNoFrameskip-v4``.
+ - caller (:obj:`str`): The caller of the environment, including ``collector`` or ``evaluator``. Different \
+ caller may need different wrappers.
+ Returns:
+ - wrapper_list (:obj:`List[dict]`): The list of wrappers, each element is a config of the concrete wrapper.
+ Raises:
+ - NotImplementedError: ``env_wrapper_name`` is not in ``['mujoco_default', 'atari_default', \
+ 'gym_hybrid_default', 'default']``
+ """
+    assert caller in ('collector', 'evaluator'), caller
+ if env_wrapper_name == 'mujoco_default':
+ return [
+ copy.deepcopy(eval_episode_return_wrapper),
+ ]
+ elif env_wrapper_name == 'atari_default':
+ wrapper_list = []
+ wrapper_list.append(EasyDict(type='noop_reset', kwargs=dict(noop_max=30)))
+ wrapper_list.append(EasyDict(type='max_and_skip', kwargs=dict(skip=4)))
+ wrapper_list.append(EasyDict(type='episodic_life'))
+ if env_id is not None:
+ if 'Pong' in env_id or 'Qbert' in env_id or 'SpaceInvader' in env_id or 'Montezuma' in env_id:
+ wrapper_list.append(EasyDict(type='fire_reset'))
+ wrapper_list.append(EasyDict(type='warp_frame'))
+ wrapper_list.append(EasyDict(type='scaled_float_frame'))
+ if caller == 'collector':
+ wrapper_list.append(EasyDict(type='clip_reward'))
+ wrapper_list.append(EasyDict(type='frame_stack', kwargs=dict(n_frames=4)))
+ wrapper_list.append(copy.deepcopy(eval_episode_return_wrapper))
+ return wrapper_list
+ elif env_wrapper_name == 'gym_hybrid_default':
+ return [
+ EasyDict(type='gym_hybrid_dict_action'),
+ copy.deepcopy(eval_episode_return_wrapper),
+ ]
+ elif env_wrapper_name == 'default':
+ return [copy.deepcopy(eval_episode_return_wrapper)]
+ else:
+ raise NotImplementedError("not supported env_wrapper_name: {}".format(env_wrapper_name))
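For instance, under this preset the Atari collector and evaluator wrapper lists differ only in reward clipping; a quick check (the env id is just an example):

    from ding.envs.env.default_wrapper import get_default_wrappers

    collect = get_default_wrappers('atari_default', env_id='PongNoFrameskip-v4', caller='collector')
    evaluate = get_default_wrappers('atari_default', env_id='PongNoFrameskip-v4', caller='evaluator')
    assert any(w.type == 'clip_reward' for w in collect)
    assert not any(w.type == 'clip_reward' for w in evaluate)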
diff --git a/DI-engine/ding/envs/env/ding_env_wrapper.py b/DI-engine/ding/envs/env/ding_env_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc67e826bd1ba1bf67ec88e1256dfb489b046b47
--- /dev/null
+++ b/DI-engine/ding/envs/env/ding_env_wrapper.py
@@ -0,0 +1,365 @@
+from typing import List, Optional, Union, Dict
+from easydict import EasyDict
+import gym
+import gymnasium
+import copy
+import numpy as np
+import treetensor.numpy as tnp
+
+from ding.envs.common.common_function import affine_transform
+from ding.envs.env_wrappers import create_env_wrapper
+from ding.torch_utils import to_ndarray
+from ding.utils import CloudPickleWrapper
+from .base_env import BaseEnv, BaseEnvTimestep
+from .default_wrapper import get_default_wrappers
+
+
+class DingEnvWrapper(BaseEnv):
+ """
+ Overview:
+ This is a wrapper for the BaseEnv class, used to provide a consistent environment interface.
+ Interfaces:
+ __init__, reset, step, close, seed, random_action, _wrap_env, __repr__, create_collector_env_cfg,
+ create_evaluator_env_cfg, enable_save_replay, observation_space, action_space, reward_space, clone
+ """
+
+ def __init__(self, env: gym.Env = None, cfg: dict = None, seed_api: bool = True, caller: str = 'collector') -> None:
+ """
+ Overview:
+            Initialize the DingEnvWrapper. Either an environment instance or a config to create the environment \
+            instance should be passed in. For the former, pass the instance via the `env` parameter; this mode does \
+            not support the subprocess env manager and is therefore usually used in simple environments. For the \
+            latter, the `cfg` parameter must contain `env_id`.
+ Arguments:
+ - env (:obj:`gym.Env`): An environment instance to be wrapped.
+ - cfg (:obj:`dict`): The configuration dictionary to create an environment instance.
+ - seed_api (:obj:`bool`): Whether to use seed API. Defaults to True.
+ - caller (:obj:`str`): A string representing the caller of this method, including ``collector`` or \
+ ``evaluator``. Different caller may need different wrappers. Default is 'collector'.
+ """
+ self._env = None
+ self._raw_env = env
+ self._cfg = cfg
+ self._seed_api = seed_api # some env may disable `env.seed` api
+ self._caller = caller
+ if self._cfg is None:
+ self._cfg = {}
+ self._cfg = EasyDict(self._cfg)
+ if 'act_scale' not in self._cfg:
+ self._cfg.act_scale = False
+ if 'rew_clip' not in self._cfg:
+ self._cfg.rew_clip = False
+ if 'env_wrapper' not in self._cfg:
+ self._cfg.env_wrapper = 'default'
+ if 'env_id' not in self._cfg:
+ self._cfg.env_id = None
+ if env is not None:
+ self._env = env
+ self._wrap_env(caller)
+ self._observation_space = self._env.observation_space
+ self._action_space = self._env.action_space
+ self._action_space.seed(0) # default seed
+ self._reward_space = gym.spaces.Box(
+ low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
+ )
+ self._init_flag = True
+ else:
+ assert 'env_id' in self._cfg
+ self._init_flag = False
+ self._observation_space = None
+ self._action_space = None
+ self._reward_space = None
+        # Only if the user specifies the replay_path will the video be saved, so its initial value is None.
+ self._replay_path = None
+
+ # override
+ def reset(self) -> np.ndarray:
+ """
+ Overview:
+ Resets the state of the environment. If the environment is not initialized, it will be created first.
+ Returns:
+ - obs (:obj:`Dict`): The new observation after reset.
+ """
+ if not self._init_flag:
+ self._env = gym.make(self._cfg.env_id)
+ self._wrap_env(self._caller)
+ self._observation_space = self._env.observation_space
+ self._action_space = self._env.action_space
+ self._reward_space = gym.spaces.Box(
+ low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
+ )
+ self._init_flag = True
+ if self._replay_path is not None:
+ self._env = gym.wrappers.RecordVideo(
+ self._env,
+ video_folder=self._replay_path,
+ episode_trigger=lambda episode_id: True,
+ name_prefix='rl-video-{}'.format(id(self))
+ )
+ self._replay_path = None
+ if isinstance(self._env, gym.Env):
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ if self._seed_api:
+ self._env.seed(self._seed + np_seed)
+ self._action_space.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ if self._seed_api:
+ self._env.seed(self._seed)
+ self._action_space.seed(self._seed)
+ obs = self._env.reset()
+ elif isinstance(self._env, gymnasium.Env):
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._action_space.seed(self._seed + np_seed)
+ obs = self._env.reset(seed=self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._action_space.seed(self._seed)
+ obs = self._env.reset(seed=self._seed)
+ else:
+ obs = self._env.reset()
+ else:
+ raise RuntimeError("not support env type: {}".format(type(self._env)))
+ if self.observation_space.dtype == np.float32:
+ obs = to_ndarray(obs, dtype=np.float32)
+ else:
+ obs = to_ndarray(obs)
+ return obs
+
+ # override
+ def close(self) -> None:
+ """
+ Overview:
+ Clean up the environment by closing and deleting it.
+ This method should be called when the environment is no longer needed.
+ Failing to call this method can lead to memory leaks.
+ """
+ try:
+ self._env.close()
+ del self._env
+ except: # noqa
+ pass
+
+ # override
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ """
+ Overview:
+ Set the seed for the environment.
+ Arguments:
+ - seed (:obj:`int`): The seed to set.
+ - dynamic_seed (:obj:`bool`): Whether to use dynamic seed, default is True.
+ """
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ # override
+ def step(self, action: Union[np.int64, np.ndarray]) -> BaseEnvTimestep:
+ """
+ Overview:
+ Execute the given action in the environment, and return the timestep (observation, reward, done, info).
+ Arguments:
+ - action (:obj:`Union[np.int64, np.ndarray]`): The action to execute in the environment.
+ Returns:
+ - timestep (:obj:`BaseEnvTimestep`): The timestep after the action execution.
+ """
+ action = self._judge_action_type(action)
+ if self._cfg.act_scale:
+ action = affine_transform(action, min_val=self._env.action_space.low, max_val=self._env.action_space.high)
+ obs, rew, done, info = self._env.step(action)
+ if self._cfg.rew_clip:
+ rew = max(-10, rew)
+ rew = np.float32(rew)
+ if self.observation_space.dtype == np.float32:
+ obs = to_ndarray(obs, dtype=np.float32)
+ else:
+ obs = to_ndarray(obs)
+ rew = to_ndarray([rew], np.float32)
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def _judge_action_type(self, action: Union[np.ndarray, dict]) -> Union[np.ndarray, dict]:
+ """
+ Overview:
+ Ensure the action taken by the agent is of the correct type.
+ This method is used to standardize different action types to a common format.
+ Arguments:
+ - action (Union[np.ndarray, dict]): The action taken by the agent.
+ Returns:
+ - action (Union[np.ndarray, dict]): The formatted action.
+ """
+ if isinstance(action, int):
+ return action
+ elif isinstance(action, np.int64):
+ return int(action)
+ elif isinstance(action, np.ndarray):
+ if action.shape == ():
+ action = action.item()
+ elif action.shape == (1, ) and action.dtype == np.int64:
+ action = action.item()
+ return action
+ elif isinstance(action, dict):
+ for k, v in action.items():
+ action[k] = self._judge_action_type(v)
+ return action
+ elif isinstance(action, tnp.ndarray):
+ return self._judge_action_type(action.json())
+ else:
+ raise TypeError(
+ '`action` should be either int/np.ndarray or dict of int/np.ndarray, but get {}: {}'.format(
+ type(action), action
+ )
+ )
+
+ def random_action(self) -> np.ndarray:
+ """
+ Overview:
+ Return a random action from the action space of the environment.
+ Returns:
+ - action (:obj:`np.ndarray`): The random action.
+ """
+ random_action = self.action_space.sample()
+ if isinstance(random_action, np.ndarray):
+ pass
+ elif isinstance(random_action, int):
+ random_action = to_ndarray([random_action], dtype=np.int64)
+ elif isinstance(random_action, dict):
+ random_action = to_ndarray(random_action)
+ else:
+ raise TypeError(
+ '`random_action` should be either int/np.ndarray or dict of int/np.ndarray, but get {}: {}'.format(
+ type(random_action), random_action
+ )
+ )
+ return random_action
+
+ def _wrap_env(self, caller: str = 'collector') -> None:
+ """
+ Overview:
+ Wrap the environment according to the configuration.
+ Arguments:
+ - caller (:obj:`str`): The caller of the environment, including ``collector`` or ``evaluator``. \
+ Different caller may need different wrappers. Default is 'collector'.
+ """
+ # wrapper_cfgs: Union[str, List]
+ wrapper_cfgs = self._cfg.env_wrapper
+ if isinstance(wrapper_cfgs, str):
+ wrapper_cfgs = get_default_wrappers(wrapper_cfgs, self._cfg.env_id, caller)
+ # self._wrapper_cfgs: List[Union[Callable, Dict]]
+ self._wrapper_cfgs = wrapper_cfgs
+ for wrapper in self._wrapper_cfgs:
+ # wrapper: Union[Callable, Dict]
+ if isinstance(wrapper, Dict):
+ self._env = create_env_wrapper(self._env, wrapper)
+ else: # Callable, such as lambda anonymous function
+ self._env = wrapper(self._env)
+
+ def __repr__(self) -> str:
+ """
+ Overview:
+ Return the string representation of the instance.
+ Returns:
+ - str (:obj:`str`): The string representation of the instance.
+ """
+ return "DI-engine Env({}), generated by DingEnvWrapper".format(self._cfg.env_id)
+
+ @staticmethod
+ def create_collector_env_cfg(cfg: dict) -> List[dict]:
+ """
+ Overview:
+ Create a list of environment configuration for collectors based on the input configuration.
+ Arguments:
+ - cfg (:obj:`dict`): The input configuration dictionary.
+ Returns:
+ - env_cfgs (:obj:`List[dict]`): The list of environment configurations for collectors.
+ """
+ actor_env_num = cfg.pop('collector_env_num')
+ cfg = copy.deepcopy(cfg)
+ cfg.is_train = True
+ return [cfg for _ in range(actor_env_num)]
+
+ @staticmethod
+ def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
+ """
+ Overview:
+ Create a list of environment configuration for evaluators based on the input configuration.
+ Arguments:
+ - cfg (:obj:`dict`): The input configuration dictionary.
+ Returns:
+ - env_cfgs (:obj:`List[dict]`): The list of environment configurations for evaluators.
+ """
+ evaluator_env_num = cfg.pop('evaluator_env_num')
+ cfg = copy.deepcopy(cfg)
+ cfg.is_train = False
+ return [cfg for _ in range(evaluator_env_num)]
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ """
+ Overview:
+ Enable the save replay functionality. The replay will be saved at the specified path.
+ Arguments:
+ - replay_path (:obj:`Optional[str]`): The path to save the replay, default is None.
+ """
+ if replay_path is None:
+ replay_path = './video'
+ self._replay_path = replay_path
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ """
+ Overview:
+ Return the observation space of the wrapped environment.
+ The observation space represents the range and shape of possible observations
+ that the environment can provide to the agent.
+ Note:
+ If the data type of the observation space is float64, it's converted to float32
+ for better compatibility with most machine learning libraries.
+ Returns:
+ - observation_space (gym.spaces.Space): The observation space of the environment.
+ """
+ if self._observation_space.dtype == np.float64:
+ self._observation_space.dtype = np.float32
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ """
+ Overview:
+ Return the action space of the wrapped environment.
+ The action space represents the range and shape of possible actions
+ that the agent can take in the environment.
+ Returns:
+ - action_space (gym.spaces.Space): The action space of the environment.
+ """
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ """
+ Overview:
+ Return the reward space of the wrapped environment.
+ The reward space represents the range and shape of possible rewards
+ that the agent can receive as a result of its actions.
+ Returns:
+ - reward_space (gym.spaces.Space): The reward space of the environment.
+ """
+ return self._reward_space
+
+ def clone(self, caller: str = 'collector') -> BaseEnv:
+ """
+ Overview:
+ Clone the current environment wrapper, creating a new environment with the same settings.
+ Arguments:
+ - caller (str): A string representing the caller of this method, either ``collector`` or ``evaluator``. \
+ Different callers may need different wrappers. Default is 'collector'.
+ Returns:
+ - DingEnvWrapper: A new instance of the environment with the same settings.
+ """
+ try:
+ spec = copy.deepcopy(self._raw_env.spec)
+ raw_env = CloudPickleWrapper(self._raw_env)
+ raw_env = copy.deepcopy(raw_env).data
+ raw_env.__setattr__('spec', spec)
+ except Exception:
+ raw_env = self._raw_env
+ return DingEnvWrapper(raw_env, self._cfg, self._seed_api, caller)
diff --git a/DI-engine/ding/envs/env/env_implementation_check.py b/DI-engine/ding/envs/env/env_implementation_check.py
new file mode 100644
index 0000000000000000000000000000000000000000..18b24edc816dd9c1d6a0de88d1dcf288e5a82355
--- /dev/null
+++ b/DI-engine/ding/envs/env/env_implementation_check.py
@@ -0,0 +1,187 @@
+from typing import Any, Callable, List, Tuple
+import numpy as np
+from collections.abc import Sequence
+from easydict import EasyDict
+
+from ding.envs.env import BaseEnv, BaseEnvTimestep
+from ding.envs.env.tests import DemoEnv
+# from dizoo.atari.envs import AtariEnv
+
+
+def check_space_dtype(env: BaseEnv) -> None:
+ print("== 0. Test obs/act/rew space's dtype")
+ env.reset()
+ for name, space in zip(['obs', 'act', 'rew'], [env.observation_space, env.action_space, env.reward_space]):
+ if 'float' in repr(space.dtype):
+ assert space.dtype == np.float32, "If float, then must be np.float32, but got {} for {} space".format(
+ space.dtype, name
+ )
+ if 'int' in repr(space.dtype):
+ assert space.dtype == np.int64, "If int, then must be np.int64, but got {} for {} space".format(
+ space.dtype, name
+ )
+
+
+# Util function
+def check_array_space(ndarray, space, name) -> None:
+ if isinstance(ndarray, np.ndarray):
+ # print("{}'s type should be np.ndarray".format(name))
+ assert ndarray.dtype == space.dtype, "{}'s dtype is {}, but requires {}".format(
+ name, ndarray.dtype, space.dtype
+ )
+ assert ndarray.shape == space.shape, "{}'s shape is {}, but requires {}".format(
+ name, ndarray.shape, space.shape
+ )
+ assert (space.low <= ndarray).all() and (ndarray <= space.high).all(
+ ), "{}'s value is {}, but requires in range ({},{})".format(name, ndarray, space.low, space.high)
+ elif isinstance(ndarray, Sequence):
+ for i in range(len(ndarray)):
+ try:
+ check_array_space(ndarray[i], space[i], name)
+ except AssertionError as e:
+ print("The following error happens at {}-th index".format(i))
+ raise e
+ elif isinstance(ndarray, dict):
+ for k in ndarray.keys():
+ try:
+ check_array_space(ndarray[k], space[k], name)
+ except AssertionError as e:
+ print("The following error happens at key {}".format(k))
+ raise e
+ else:
+ raise TypeError(
+ "Input array should be np.ndarray or sequence/dict of np.ndarray, but found {}".format(type(ndarray))
+ )
+
+
+def check_reset(env: BaseEnv) -> None:
+ print('== 1. Test reset method')
+ obs = env.reset()
+ check_array_space(obs, env.observation_space, 'obs')
+
+
+def check_step(env: BaseEnv) -> None:
+ done_times = 0
+ print('== 2. Test step method')
+ _ = env.reset()
+ if hasattr(env, "random_action"):
+ random_action = env.random_action()
+ else:
+ random_action = env.action_space.sample()
+ while True:
+ obs, rew, done, info = env.step(random_action)
+ for ndarray, space, name in zip([obs, rew], [env.observation_space, env.reward_space], ['obs', 'rew']):
+ check_array_space(ndarray, space, name)
+ if done:
+ assert 'eval_episode_return' in info, "info dict should have 'eval_episode_return' key."
+ done_times += 1
+ _ = env.reset()
+ if done_times == 3:
+ break
+
+
+# Util function
+def check_different_memory(array1, array2, step_times) -> None:
+ assert type(array1) == type(
+ array2
+ ), "In step times {}, obs_last_frame({}) and obs_this_frame({}) are not of the same type".format(
+ step_times, type(array1), type(array2)
+ )
+ if isinstance(array1, np.ndarray):
+ assert id(array1) != id(
+ array2
+ ), "In step times {}, obs_last_frame and obs_this_frame are the same np.ndarray".format(step_times)
+ elif isinstance(array1, Sequence):
+ assert len(array1) == len(
+ array2
+ ), "In step times {}, obs_last_frame({}) and obs_this_frame({}) have different sequence lengths".format(
+ step_times, len(array1), len(array2)
+ )
+ for i in range(len(array1)):
+ try:
+ check_different_memory(array1[i], array2[i], step_times)
+ except AssertionError as e:
+ print("The following error happens at {}-th index".format(i))
+ raise e
+ elif isinstance(array1, dict):
+ assert array1.keys() == array2.keys(), "In step times {}, obs_last_frame({}) and obs_this_frame({}) have \
+ different dict keys".format(step_times, array1.keys(), array2.keys())
+ for k in array1.keys():
+ try:
+ check_different_memory(array1[k], array2[k], step_times)
+ except AssertionError as e:
+ print("The following error happens at key {}".format(k))
+ raise e
+ else:
+ raise TypeError(
+ "Input array should be np.ndarray or list/dict of np.ndarray, but found {} and {}".format(
+ type(array1), type(array2)
+ )
+ )
+
+
+def check_obs_deepcopy(env: BaseEnv) -> None:
+
+ step_times = 0
+ print('== 3. Test observation deepcopy')
+ obs_1 = env.reset()
+ if hasattr(env, "random_action"):
+ random_action = env.random_action()
+ else:
+ random_action = env.action_space.sample()
+ while True:
+ step_times += 1
+ obs_2, _, done, _ = env.step(random_action)
+ check_different_memory(obs_1, obs_2, step_times)
+ obs_1 = obs_2
+ if done:
+ break
+
+
+def check_all(env: BaseEnv) -> None:
+ check_space_dtype(env)
+ check_reset(env)
+ check_step(env)
+ check_obs_deepcopy(env)
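+
+# A minimal usage sketch (assuming ``MyEnv`` is a user-implemented subclass of ``BaseEnv``):
+#   my_env = MyEnv(EasyDict(env_id='MyEnv-v0'))
+#   check_all(my_env)  # raises AssertionError/TypeError if any of the above checks fails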
+
+
+def demonstrate_correct_procedure(env_fn: Callable) -> None:
+ print('== 4. Demonstrate the correct procedures')
+ done_times = 0
+ # Init the env.
+ env = env_fn({})
+ # Lazy init. The real env is not initialized until `reset` method is called
+ assert not hasattr(env, "_env")
+ # Must set seed before `reset` method is called.
+ env.seed(4)
+ assert env._seed == 4
+ # Reset the env. The real env is initialized here.
+ obs = env.reset()
+ while True:
+ # Using the policy to get the action from obs. But here we use `random_action` instead.
+ action = env.random_action()
+ obs, rew, done, info = env.step(action)
+ if done:
+ assert 'eval_episode_return' in info
+ done_times += 1
+ obs = env.reset()
+ # Seed will not change unless `seed` method is called again.
+ assert env._seed == 4
+ if done_times == 3:
+ break
+
+
+if __name__ == "__main__":
+ '''
+ # Methods `check_*` are for users to check whether their implemented env obeys DI-engine's rules.
+ # You can replace `AtariEnv` with your own env.
+ atari_env = AtariEnv(EasyDict(env_id='PongNoFrameskip-v4', frame_stack=4, is_train=False))
+ check_reset(atari_env)
+ check_step(atari_env)
+ check_obs_deepcopy(atari_env)
+ '''
+ # Method `demonstrate_correct_procedure` demonstrates the correct procedure to
+ # use an env to generate trajectories.
+ # You can check whether your env's design is similar to `DemoEnv`.
+ demonstrate_correct_procedure(DemoEnv)
diff --git a/DI-engine/ding/envs/env/tests/__init__.py b/DI-engine/ding/envs/env/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3accee4fb88df4ff046d6f163f92e0f498dd9c8
--- /dev/null
+++ b/DI-engine/ding/envs/env/tests/__init__.py
@@ -0,0 +1 @@
+from .demo_env import DemoEnv
diff --git a/DI-engine/ding/envs/env/tests/demo_env.py b/DI-engine/ding/envs/env/tests/demo_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..4867667f10a7910a761bc3c5afc4be632574d91d
--- /dev/null
+++ b/DI-engine/ding/envs/env/tests/demo_env.py
@@ -0,0 +1,72 @@
+from typing import Any, Union
+import gym
+import numpy as np
+
+from ding.envs.env import BaseEnv, BaseEnvTimestep
+
+
+class DemoEnv(BaseEnv):
+
+ def __init__(self, cfg: dict) -> None:
+ self._closed = True
+ # It is highly recommended to implement these three spaces
+ self._observation_space = gym.spaces.Dict(
+ {
+ "demo_dict": gym.spaces.Tuple(
+ [
+ gym.spaces.Box(low=-10., high=10., shape=(4, ), dtype=np.float32),
+ gym.spaces.Box(low=-100., high=100., shape=(1, ), dtype=np.float32)
+ ]
+ )
+ }
+ )
+ self._action_space = gym.spaces.Discrete(5)
+ self._reward_space = gym.spaces.Box(low=0.0, high=1.0, shape=(1, ), dtype=np.float32)
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ def reset(self) -> Any:
+ """
+ Overview:
+ Resets the env to an initial state and returns an initial observation. Abstract Method from ``gym.Env``.
+ """
+ self._step_count = 0
+ self._env = "A real environment"
+ self._closed = False
+ return self.observation_space.sample()
+
+ def close(self) -> None:
+ self._closed = True
+
+ def step(self, action: Any) -> 'BaseEnv.timestep':
+ self._step_count += 1
+ obs = self.observation_space.sample()
+ rew = self.reward_space.sample()
+ if self._step_count == 30:
+ self._step_count = 0
+ done = True
+ else:
+ done = False
+ info = {}
+ if done:
+ info['eval_episode_return'] = self.reward_space.sample() * 30
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def seed(self, seed: int) -> None:
+ self._seed = seed
+
+ def random_action(self) -> Union[np.ndarray, int]:
+ return self.action_space.sample()
+
+ def __repr__(self) -> str:
+ return "Demo Env for env_implementation_test.py"
diff --git a/DI-engine/ding/envs/env/tests/test_ding_env_wrapper.py b/DI-engine/ding/envs/env/tests/test_ding_env_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d53adbfd3fecb7373fa2d27ae221593548135a2
--- /dev/null
+++ b/DI-engine/ding/envs/env/tests/test_ding_env_wrapper.py
@@ -0,0 +1,201 @@
+import gym
+import numpy as np
+import pytest
+from easydict import EasyDict
+
+from ding.torch_utils import to_ndarray
+from ding.envs.env import DingEnvWrapper
+
+
+class FakeEnvForTest(gym.Env):
+
+ def __init__(self):
+ self.observation_space = gym.spaces.Box(low=-1., high=1., shape=(10, ), dtype=np.float32)
+ self.action_space = gym.spaces.Tuple(
+ (
+ gym.spaces.Discrete(3),
+ gym.spaces.Box(low=np.array([0., -1.]), high=np.array([1., 1.]), shape=(2, ), dtype=np.float32)
+ )
+ )
+
+ def step(self, action):
+ assert self.action_space.contains(action)
+ self._step_count += 1
+ obs = self.observation_space.sample()
+ obs = to_ndarray(obs).astype(np.float32)
+ done = True if self._step_count == 100 else False
+ return (obs, 0.5, done, {})
+
+ def reset(self):
+ self._step_count = 0
+ obs = self.observation_space.sample()
+ obs = to_ndarray(obs).astype(np.float32)
+ return obs
+
+ def render(self, mode="human", close=False):
+ pass
+
+ def close(self):
+ pass
+
+
+gym.envs.registration.register(
+ id="FakeHybridForTest-v0",
+ entry_point="ding.envs.env.tests.test_ding_env_wrapper:FakeEnvForTest",
+)
+
+
+class TestDingEnvWrapper:
+
+ @pytest.mark.unittest
+ @pytest.mark.parametrize('env_id', ['CartPole-v0', 'Pendulum-v1'])
+ def test_cartpole_pendulum(self, env_id):
+ env = gym.make(env_id)
+ ding_env = DingEnvWrapper(env=env)
+ print(ding_env.observation_space, ding_env.action_space, ding_env.reward_space)
+ cfg = EasyDict(dict(
+ collector_env_num=16,
+ evaluator_env_num=3,
+ is_train=True,
+ ))
+ l1 = ding_env.create_collector_env_cfg(cfg)
+ assert isinstance(l1, list)
+ l1 = ding_env.create_evaluator_env_cfg(cfg)
+ assert isinstance(l1, list)
+ obs = ding_env.reset()
+ assert isinstance(obs, np.ndarray)
+ action = ding_env.random_action()
+ # assert isinstance(action, np.ndarray)
+ print('random_action: {}, action_space: {}'.format(action.shape, ding_env.action_space))
+
+ @pytest.mark.envtest
+ def test_mujoco(self):
+ env_cfg = EasyDict(
+ env_id='Ant-v3',
+ env_wrapper='mujoco_default',
+ )
+ ding_env_mujoco = DingEnvWrapper(cfg=env_cfg)
+ obs = ding_env_mujoco.reset()
+ assert isinstance(obs, np.ndarray)
+ # action_dim = ding_env_mujoco.action_space.shape # n
+ while True:
+ # action = np.random.random(size=action_dim) # Continuous Action
+ action = ding_env_mujoco.random_action()
+ timestep = ding_env_mujoco.step(action)
+ # print(_, timestep.reward)
+ assert timestep.reward.shape == (1, ), timestep.reward.shape
+ if timestep.done:
+ assert 'eval_episode_return' in timestep.info, timestep.info
+ break
+ print(ding_env_mujoco.observation_space, ding_env_mujoco.action_space, ding_env_mujoco.reward_space)
+ action = ding_env_mujoco.random_action()
+ # assert isinstance(action, np.ndarray)
+ assert action.shape == ding_env_mujoco.action_space.shape
+
+ @pytest.mark.envtest
+ @pytest.mark.parametrize('atari_env_id', ['Pong-v4', 'MontezumaRevenge-v4'])
+ def test_atari(self, atari_env_id):
+ env_cfg = EasyDict(
+ env_id=atari_env_id,
+ env_wrapper='atari_default',
+ )
+ ding_env_atari = DingEnvWrapper(cfg=env_cfg)
+
+ ding_env_atari.enable_save_replay('atari_path/')
+ obs = ding_env_atari.reset()
+ assert isinstance(obs, np.ndarray)
+ assert obs.shape == ding_env_atari.observation_space.shape # (4, 84, 84)
+ # action_dim = ding_env_atari.action_space.n
+ while True:
+ # action = np.random.choice(range(action_dim), size=(1, )) # Discrete Action
+ action = ding_env_atari.random_action()
+ timestep = ding_env_atari.step(action)
+ # print(timestep.reward)
+ assert timestep.reward.shape == ding_env_atari.reward_space.shape, timestep.reward.shape # (1, )
+ if timestep.done:
+ assert 'eval_episode_return' in timestep.info, timestep.info
+ break
+ print(ding_env_atari.observation_space, ding_env_atari.action_space, ding_env_atari.reward_space)
+ action = ding_env_atari.random_action()
+ # assert isinstance(action, np.ndarray)
+ assert action.shape == (1, )
+
+ @pytest.mark.unittest
+ @pytest.mark.parametrize('lun_bip_env_id', ['LunarLander-v2', 'LunarLanderContinuous-v2', 'BipedalWalker-v3'])
+ def test_lunarlander_bipedalwalker(self, lun_bip_env_id):
+ env_cfg = EasyDict(
+ env_id=lun_bip_env_id,
+ env_wrapper='default',
+ )
+ ding_env_lun_bip = DingEnvWrapper(cfg=env_cfg)
+
+ obs = ding_env_lun_bip.reset()
+ assert isinstance(obs, np.ndarray)
+ assert obs.shape == ding_env_lun_bip.observation_space.shape
+ # action_space = ding_env_lun_bip.action_space
+ # if lun_bip_env_id in ['LunarLanderContinuous-v2', 'BipedalWalker-v3']:
+ # action_dim = action_space.shape
+ # else:
+ # action_dim = action_space.n
+ while True:
+ # if lun_bip_env_id in ['LunarLanderContinuous-v2', 'BipedalWalker-v3']:
+ # action = np.random.random(size=action_dim) # Continuous Action
+ # else:
+ # action = np.random.choice(range(action_dim), size=(1, )) # Discrete Action
+ action = ding_env_lun_bip.random_action()
+ timestep = ding_env_lun_bip.step(action)
+ # print(timestep.reward)
+ assert timestep.reward.shape == ding_env_lun_bip.reward_space.shape, timestep.reward.shape # (1, )
+ if timestep.done:
+ assert 'eval_episode_return' in timestep.info, timestep.info
+ break
+ print(ding_env_lun_bip.observation_space, ding_env_lun_bip.action_space, ding_env_lun_bip.reward_space)
+ action = ding_env_lun_bip.random_action()
+ # assert isinstance(action, np.ndarray)
+ print('random_action: {}, action_space: {}'.format(action.shape, ding_env_lun_bip.action_space))
+
+ @pytest.mark.unittest
+ def test_hybrid(self):
+ env_cfg = EasyDict(env_id='FakeHybridForTest-v0', env_wrapper='gym_hybrid_default')
+ ding_env_hybrid = DingEnvWrapper(cfg=env_cfg)
+
+ obs = ding_env_hybrid.reset()
+ assert isinstance(obs, np.ndarray)
+ assert obs.shape == ding_env_hybrid.observation_space.shape
+ while True:
+ action = ding_env_hybrid.random_action()
+ # print('random_action:', action)
+ for k, v in action.items():
+ if isinstance(v, int):
+ continue
+ # print('before: {}, after: {}'.format(v.shape, ding_env_hybrid.action_space[k].shape))
+ v.shape = ding_env_hybrid.action_space[k].shape
+ timestep = ding_env_hybrid.step(action)
+ # print(timestep.reward)
+ assert timestep.reward.shape == ding_env_hybrid.reward_space.shape, timestep.reward.shape # (1, )
+ if timestep.done:
+ assert 'eval_episode_return' in timestep.info, timestep.info
+ break
+ print(ding_env_hybrid.observation_space, ding_env_hybrid.action_space, ding_env_hybrid.reward_space)
+ action = ding_env_hybrid.random_action()
+ print('random_action', action)
+ assert isinstance(action, dict)
+
+ @pytest.mark.envtest
+ def test_AllinObsWrapper(self):
+ env_cfg = EasyDict(env_id='PongNoFrameskip-v4', env_wrapper='reward_in_obs')
+ ding_env_aio = DingEnvWrapper(cfg=env_cfg)
+
+ data = ding_env_aio.reset()
+ assert isinstance(data, dict)
+ assert 'obs' in data.keys() and 'reward' in data.keys()
+ assert data['obs'].shape == ding_env_aio.observation_space.shape
+ while True:
+ action = ding_env_aio.random_action()
+ timestep = ding_env_aio.step(action)
+ # print(timestep.reward)
+ assert isinstance(timestep.obs, dict)
+ if timestep.done:
+ assert 'eval_episode_return' in timestep.info, timestep.info
+ break
+ print(ding_env_aio.observation_space, ding_env_aio.action_space, ding_env_aio.reward_space)
diff --git a/DI-engine/ding/envs/env/tests/test_env_implementation_check.py b/DI-engine/ding/envs/env/tests/test_env_implementation_check.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb413304ce562fca6f9892396c8901821a208e1e
--- /dev/null
+++ b/DI-engine/ding/envs/env/tests/test_env_implementation_check.py
@@ -0,0 +1,51 @@
+import pytest
+from easydict import EasyDict
+import numpy as np
+import gym
+from copy import deepcopy
+
+from ding.envs.env import check_array_space, check_different_memory, check_all, demonstrate_correct_procedure
+from ding.envs.env.tests import DemoEnv
+
+
+@pytest.mark.unittest
+def test_an_implemented_env():
+ demo_env = DemoEnv({})
+ check_all(demo_env)
+ demonstrate_correct_procedure(DemoEnv)
+
+
+@pytest.mark.unittest
+def test_check_array_space():
+ seq_array = (np.array([1, 2, 3], dtype=np.int64), np.array([4., 5., 6.], dtype=np.float32))
+ seq_space = [gym.spaces.Box(low=0, high=10, shape=(3, ), dtype=np.int64) for _ in range(2)]
+ with pytest.raises(AssertionError):
+ check_array_space(seq_array, seq_space, 'test_sequence')
+
+ dict_array = {'a': np.array([1, 2, 3], dtype=np.int64), 'b': np.array([4., 5., 6.], dtype=np.float32)}
+ int_box = gym.spaces.Box(low=0, high=10, shape=(3, ), dtype=np.int64)
+ dict_space = {'a': deepcopy(int_box), 'b': deepcopy(int_box)}
+ with pytest.raises(AssertionError):
+ check_array_space(dict_array, dict_space, 'test_dict')
+
+ with pytest.raises(TypeError):
+ check_array_space(1, dict_space, 'test_type_error')
+
+
+@pytest.mark.unittest
+def test_check_different_memory():
+ int_seq = np.array([1, 2, 3], dtype=np.int64)
+ seq_array1 = (int_seq, np.array([4., 5., 6.], dtype=np.float32))
+ seq_array2 = (int_seq, np.array([4., 5., 6.], dtype=np.float32))
+ with pytest.raises(AssertionError):
+ check_different_memory(seq_array1, seq_array2, -1)
+
+ dict_array1 = {'a': np.array([4., 5., 6.], dtype=np.float32), 'b': int_seq}
+ dict_array2 = {'a': np.array([4., 5., 6.], dtype=np.float32), 'b': int_seq}
+ with pytest.raises(AssertionError):
+ check_different_memory(dict_array1, dict_array2, -1)
+
+ with pytest.raises(AssertionError):
+ check_different_memory(1, dict_array1, -1)
+ with pytest.raises(TypeError):
+ check_different_memory(1, 2, -1)
diff --git a/DI-engine/ding/envs/env_manager/__init__.py b/DI-engine/ding/envs/env_manager/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..62d45baf27fb4ec0743aa1f96935a03b4e48b562
--- /dev/null
+++ b/DI-engine/ding/envs/env_manager/__init__.py
@@ -0,0 +1,5 @@
+from .base_env_manager import BaseEnvManager, BaseEnvManagerV2, create_env_manager, get_env_manager_cls
+from .subprocess_env_manager import AsyncSubprocessEnvManager, SyncSubprocessEnvManager, SubprocessEnvManagerV2
+from .gym_vector_env_manager import GymVectorEnvManager
+# Do not import PoolEnvManager here, because it depends on installation of `envpool`
+from .env_supervisor import EnvSupervisor
diff --git a/DI-engine/ding/envs/env_manager/base_env_manager.py b/DI-engine/ding/envs/env_manager/base_env_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..291390896c82942784b4209ae703bfd046e3dbe5
--- /dev/null
+++ b/DI-engine/ding/envs/env_manager/base_env_manager.py
@@ -0,0 +1,685 @@
+from types import MethodType
+from typing import Union, Any, List, Callable, Dict, Optional, Tuple
+from functools import partial, wraps
+from easydict import EasyDict
+from ditk import logging
+import copy
+import platform
+import numbers
+import enum
+import time
+import treetensor.numpy as tnp
+from ding.utils import ENV_MANAGER_REGISTRY, import_module, one_time_warning, make_key_as_identifier, WatchDog, \
+ remove_illegal_item
+from ding.envs import BaseEnv, BaseEnvTimestep
+
+global space_log_flag
+space_log_flag = True
+
+
+class EnvState(enum.IntEnum):
+ VOID = 0
+ INIT = 1
+ RUN = 2
+ RESET = 3
+ DONE = 4
+ ERROR = 5
+ NEED_RESET = 6
+
+
+def timeout_wrapper(func: Callable = None, timeout: Optional[int] = None) -> Callable:
+ """
+ Overview:
+ Watch the function that must be finished within a period of time. If it times out, raise the captured error.
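+ Example:
+ >>> # A minimal sketch: wrap a call so that it raises if it does not finish within 5 seconds
+ >>> # (``env`` here stands for an arbitrary environment instance).
+ >>> @timeout_wrapper(timeout=5)
+ >>> def reset_fn():
+ >>> return env.reset()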
+ """
+ if func is None:
+ return partial(timeout_wrapper, timeout=timeout)
+ if timeout is None:
+ return func
+
+ windows_flag = platform.system().lower() == 'windows'
+ if windows_flag:
+ one_time_warning("Timeout wrapper is not implemented on the Windows platform, so it is ignored by default")
+ return func
+
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ watchdog = WatchDog(timeout)
+ try:
+ watchdog.start()
+ except ValueError as e:
+ # watchdog invalid case
+ return func(*args, **kwargs)
+ try:
+ return func(*args, **kwargs)
+ except BaseException as e:
+ raise e
+ finally:
+ watchdog.stop()
+
+ return wrapper
+
+
+@ENV_MANAGER_REGISTRY.register('base')
+class BaseEnvManager(object):
+ """
+ Overview:
+ The basic class of env manager to manage multiple vectorized environments. BaseEnvManager defines all the
+ necessary interfaces, and derived classes must extend this basic class.
+
+ The class is implemented with a pseudo-parallel (i.e. serial) mechanism, therefore, it is only
+ used in some tiny environments and for debugging purposes.
+ Interfaces:
+ reset, step, seed, close, enable_save_replay, launch, default_config, reward_shaping, enable_save_figure
+ Properties:
+ env_num, env_ref, ready_obs, ready_obs_id, ready_imgs, done, closed, method_name_list, observation_space, \
+ action_space, reward_space
+ """
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ """
+ Overview:
+ Return the deep-copied default config of env manager.
+ Returns:
+ - cfg (:obj:`EasyDict`): The default config of env manager.
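+ Example:
+ >>> # A minimal sketch: take the default config and override fields before constructing the manager.
+ >>> cfg = BaseEnvManager.default_config()
+ >>> cfg.max_retry = 5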
+ """
+ cfg = EasyDict(copy.deepcopy(cls.config))
+ cfg.cfg_type = cls.__name__ + 'Dict'
+ return cfg
+
+ config = dict(
+ # (int) The total episode number to be executed, defaults to inf, which means no episode limits.
+ episode_num=float("inf"),
+ # (int) The maximum retry times when the env is in error state, defaults to 1, i.e. no retry.
+ max_retry=1,
+ # (str) The retry type when the env is in error state, including ['reset', 'renew'], defaults to 'reset'.
+ # The former is to reset the env to the last reset state, while the latter is to create a new env.
+ retry_type='reset',
+ # (bool) Whether to automatically reset sub-environments when they are done, defaults to True.
+ auto_reset=True,
+ # (float) WatchDog timeout (second) for ``step`` method, defaults to None, which means no timeout.
+ step_timeout=None,
+ # (float) WatchDog timeout (second) for ``reset`` method, defaults to None, which means no timeout.
+ reset_timeout=None,
+ # (float) The interval waiting time for automatically retry mechanism, defaults to 0.1.
+ retry_waiting_time=0.1,
+ )
+
+ def __init__(
+ self,
+ env_fn: List[Callable],
+ cfg: EasyDict = EasyDict({}),
+ ) -> None:
+ """
+ Overview:
+ Initialize the base env manager with the callable env functions and the EasyDict-type config. Here we use
+ ``env_fn`` to ensure the lazy initialization of sub-environments, which is beneficial to resource
+ allocation and parallelism. ``cfg`` is the merged result between the default config of this class
+ and user's config.
+ This construction function is in lazy-initialization mode, the actual initialization is in ``launch``.
+ Arguments:
+ - env_fn (:obj:`List[Callable]`): A list of functions to create ``env_num`` sub-environments.
+ - cfg (:obj:`EasyDict`): Final merged config.
+
+ .. note::
+ For more details about how to merge config, please refer to the system document of DI-engine \
+ (`en link <../03_system/config.html>`_).
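+
+ Example:
+ >>> # A minimal sketch, assuming gym and DingEnvWrapper are available (as in the accompanying unit tests):
+ >>> env_fn = [lambda: DingEnvWrapper(gym.make('CartPole-v0')) for _ in range(4)]
+ >>> env_manager = BaseEnvManager(env_fn=env_fn, cfg=BaseEnvManager.default_config())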
+ """
+ self._cfg = cfg
+ self._env_fn = env_fn
+ self._env_num = len(self._env_fn)
+ self._closed = True
+ self._env_replay_path = None
+ # env_ref is used to acquire some common attributes of env, like obs_shape and act_shape
+ self._env_ref = self._env_fn[0]()
+ try:
+ self._observation_space = self._env_ref.observation_space
+ self._action_space = self._env_ref.action_space
+ self._reward_space = self._env_ref.reward_space
+ except:
+ # For some environments,
+ # we have to reset before getting the observation description.
+ # However, for dmc-mujoco, we should not reset the env in the main thread
+ # when using subprocess mode, since that would cause OpenGL rendering bugs,
+ # leading to unresponsive subprocesses.
+ self._env_ref.reset()
+ self._observation_space = self._env_ref.observation_space
+ self._action_space = self._env_ref.action_space
+ self._reward_space = self._env_ref.reward_space
+ self._env_ref.close()
+ self._env_states = {i: EnvState.VOID for i in range(self._env_num)}
+ self._env_seed = {i: None for i in range(self._env_num)}
+ self._episode_num = self._cfg.episode_num
+ self._max_retry = max(self._cfg.max_retry, 1)
+ self._auto_reset = self._cfg.auto_reset
+ self._retry_type = self._cfg.retry_type
+ assert self._retry_type in ['reset', 'renew'], self._retry_type
+ self._step_timeout = self._cfg.step_timeout
+ self._reset_timeout = self._cfg.reset_timeout
+ self._retry_waiting_time = self._cfg.retry_waiting_time
+
+ @property
+ def env_num(self) -> int:
+ """
+ Overview:
+ ``env_num`` is the number of sub-environments in env manager.
+ Returns:
+ - env_num (:obj:`int`): The number of sub-environments.
+ """
+ return self._env_num
+
+ @property
+ def env_ref(self) -> 'BaseEnv':
+ """
+ Overview:
+ ``env_ref`` is used to acquire some common attributes of env, like obs_shape and act_shape.
+ Returns:
+ - env_ref (:obj:`BaseEnv`): The reference of sub-environment.
+ """
+ return self._env_ref
+
+ @property
+ def observation_space(self) -> 'gym.spaces.Space': # noqa
+ """
+ Overview:
+ ``observation_space`` is the observation space of sub-environment, following the format of gym.spaces.
+ Returns:
+ - observation_space (:obj:`gym.spaces.Space`): The observation space of sub-environment.
+ """
+ return self._observation_space
+
+ @property
+ def action_space(self) -> 'gym.spaces.Space': # noqa
+ """
+ Overview:
+ ``action_space`` is the action space of sub-environment, following the format of gym.spaces.
+ Returns:
+ - action_space (:obj:`gym.spaces.Space`): The action space of sub-environment.
+ """
+ return self._action_space
+
+ @property
+ def reward_space(self) -> 'gym.spaces.Space': # noqa
+ """
+ Overview:
+ ``reward_space`` is the reward space of sub-environment, following the format of gym.spaces.
+ Returns:
+ - reward_space (:obj:`gym.spaces.Space`): The reward space of sub-environment.
+ """
+ return self._reward_space
+
+ @property
+ def ready_obs(self) -> Dict[int, Any]:
+ """
+ Overview:
+ Get the ready (next) observation, which is a special design to unify both async/sync env managers.
+ For each interaction between policy and env, the policy will input the ready_obs and output the action.
+ Then the env_manager will ``step`` with the action and prepare the next ready_obs.
+ Returns:
+ - ready_obs (:obj:`Dict[int, Any]`): A dict with env_id keys and observation values.
+ Example:
+ >>> obs = env_manager.ready_obs
+ >>> stacked_obs = np.concatenate(list(obs.values()))
+ >>> action = policy(obs) # here policy inputs np obs and outputs np action
+ >>> action = {env_id: a for env_id, a in zip(obs.keys(), action)}
+ >>> timesteps = env_manager.step(action)
+ """
+ active_env = [i for i, s in self._env_states.items() if s == EnvState.RUN]
+ return {i: self._ready_obs[i] for i in active_env}
+
+ @property
+ def ready_obs_id(self) -> List[int]:
+ """
+ Overview:
+ Get the ready (next) observation id, which is a special design to unify both async/sync env managers.
+ Returns:
+ - ready_obs_id (:obj:`List[int]`): A list of env_ids for ready observations.
+ """
+ # In BaseEnvManager, if env_episode_count equals episode_num, this env is done.
+ return [i for i, s in self._env_states.items() if s == EnvState.RUN]
+
+ @property
+ def ready_imgs(self, render_mode: Optional[str] = 'rgb_array') -> Dict[int, Any]:
+ """
+ Overview:
+ Sometimes we need to render the envs, and this function is used to get the next ready rendered frames and \
+ corresponding env ids.
+ Arguments:
+ - render_mode (:obj:`Optional[str]`): The render mode, can be 'rgb_array' or 'depth_array', which follows \
+ the definition in the ``render`` function of ``ding.utils`` .
+ Returns:
+ - ready_imgs (:obj:`Dict[int, np.ndarray]`): A dict with env_id keys and rendered frames.
+ """
+ from ding.utils import render
+ assert render_mode in ['rgb_array', 'depth_array'], render_mode
+ return {i: render(self._envs[i], render_mode) for i in self.ready_obs_id}
+
+ @property
+ def done(self) -> bool:
+ """
+ Overview:
+ ``done`` is a flag to indicate whether env manager is done, i.e., whether all sub-environments have \
+ executed enough episodes.
+ Returns:
+ - done (:obj:`bool`): Whether env manager is done.
+ """
+ return all([s == EnvState.DONE for s in self._env_states.values()])
+
+ @property
+ def method_name_list(self) -> list:
+ """
+ Overview:
+ The public methods list of sub-environments that can be directly called from the env manager level. Other \
+ methods and attributes will be accessed with the ``__getattr__`` method.
+ Methods defined in this list can be regarded as the vectorized extension of methods in sub-environments.
+ Sub-class of ``BaseEnvManager`` can override this method to add more methods.
+ Returns:
+ - method_name_list (:obj:`list`): The public methods list of sub-environments.
+ """
+ return [
+ 'reset', 'step', 'seed', 'close', 'enable_save_replay', 'render', 'reward_shaping', 'enable_save_figure'
+ ]
+
+ def env_state_done(self, env_id: int) -> bool:
+ return self._env_states[env_id] == EnvState.DONE
+
+ def __getattr__(self, key: str) -> Any:
+ """
+ Note:
+ If a python object doesn't have the attribute whose name is `key`, it will call this method.
+ We suppose that all envs have the same attributes.
+ If you need different envs, please implement other env managers.
+ """
+ if not hasattr(self._env_ref, key):
+ raise AttributeError("env `{}` doesn't have the attribute `{}`".format(type(self._env_ref), key))
+ if isinstance(getattr(self._env_ref, key), MethodType) and key not in self.method_name_list:
+ raise RuntimeError("env getattr doesn't support method({}), please override method_name_list".format(key))
+ self._check_closed()
+ return [getattr(env, key) if hasattr(env, key) else None for env in self._envs]
+
+ def _check_closed(self):
+ """
+ Overview:
+ Check whether the env manager is closed. Will be called in ``__getattr__`` and ``step``.
+ """
+ assert not self._closed, "env manager is closed, please use the alive env manager"
+
+ def launch(self, reset_param: Optional[Dict] = None) -> None:
+ """
+ Overview:
+ Launch the env manager: instantiate the sub-environments and set up their parameters.
+ Arguments:
+ - reset_param (:obj:`Optional[Dict]`): A dict of reset parameters for each environment, key is the env_id, \
+ value is the corresponding reset parameter, defaults to None.
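+ Example:
+ >>> # A typical calling sequence (sketch): set seeds first, then launch to create and reset all sub-envs.
+ >>> env_manager.seed(0)
+ >>> env_manager.launch()
+ >>> obs = env_manager.ready_obs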
+ """
+ assert self._closed, "Please first close the env manager"
+ try:
+ global space_log_flag
+ if space_log_flag:
+ logging.info("Env Space Information:")
+ logging.info("\tObservation Space: {}".format(self._observation_space))
+ logging.info("\tAction Space: {}".format(self._action_space))
+ logging.info("\tReward Space: {}".format(self._reward_space))
+ space_log_flag = False
+ except:
+ pass
+ if reset_param is not None:
+ assert len(reset_param) == len(self._env_fn)
+ self._create_state()
+ self.reset(reset_param)
+
+ def _create_state(self) -> None:
+ self._env_episode_count = {i: 0 for i in range(self.env_num)}
+ self._ready_obs = {i: None for i in range(self.env_num)}
+ self._envs = [e() for e in self._env_fn]
+ assert len(self._envs) == self._env_num
+ self._reset_param = {i: {} for i in range(self.env_num)}
+ self._env_states = {i: EnvState.INIT for i in range(self.env_num)}
+ if self._env_replay_path is not None:
+ for e, s in zip(self._envs, self._env_replay_path):
+ e.enable_save_replay(s)
+ self._closed = False
+
+ def reset(self, reset_param: Optional[Dict] = None) -> None:
+ """
+ Overview:
+ Forcibly reset the sub-environments with their corresponding parameters. Because in the env manager all the \
+ sub-environments are usually reset automatically as soon as they are done, this method is only called when \
+ the caller must forcibly reset all the sub-environments, such as in evaluation.
+ Arguments:
+ - reset_param (:obj:`Optional[Dict]`): A dict of reset parameters for each environment, key is the env_id, \
+ value is the corresponding reset parameters.
+ """
+ self._check_closed()
+ # set seed if necessary
+ env_ids = list(range(self._env_num)) if reset_param is None else list(reset_param.keys())
+ for i, env_id in enumerate(env_ids): # loop-type is necessary
+ if self._env_seed[env_id] is not None:
+ if self._env_dynamic_seed is not None:
+ self._envs[env_id].seed(self._env_seed[env_id], self._env_dynamic_seed)
+ else:
+ self._envs[env_id].seed(self._env_seed[env_id])
+ self._env_seed[env_id] = None # seed only use once
+ # reset env
+ if reset_param is None:
+ env_range = range(self.env_num)
+ else:
+ for env_id in reset_param:
+ self._reset_param[env_id] = reset_param[env_id]
+ env_range = reset_param.keys()
+ for env_id in env_range:
+ if self._env_replay_path is not None and self._env_states[env_id] == EnvState.RUN:
+ logging.warning("please don't reset a unfinished env when you enable save replay, we just skip it")
+ continue
+ self._reset(env_id)
+
+ def _reset(self, env_id: int) -> None:
+
+ @timeout_wrapper(timeout=self._reset_timeout)
+ def reset_fn():
+ # if self._reset_param[env_id] is None, just reset specific env, not pass reset param
+ if self._reset_param[env_id] is not None:
+ assert isinstance(self._reset_param[env_id], dict), type(self._reset_param[env_id])
+ return self._envs[env_id].reset(**self._reset_param[env_id])
+ else:
+ return self._envs[env_id].reset()
+
+ exceptions = []
+ for _ in range(self._max_retry):
+ try:
+ self._env_states[env_id] = EnvState.RESET
+ obs = reset_fn()
+ self._ready_obs[env_id] = obs
+ self._env_states[env_id] = EnvState.RUN
+ return
+ except BaseException as e:
+ if self._retry_type == 'renew':
+ err_env = self._envs[env_id]
+ err_env.close()
+ self._envs[env_id] = self._env_fn[env_id]()
+ exceptions.append(e)
+ time.sleep(self._retry_waiting_time)
+ continue
+
+ self._env_states[env_id] = EnvState.ERROR
+ self.close()
+ logging.error("Env {} reset has exceeded max retries({})".format(env_id, self._max_retry))
+ runtime_error = RuntimeError(
+ "Env {} reset has exceeded max retries({}), and the latest exception is: {}".format(
+ env_id, self._max_retry, str(exceptions[-1])
+ )
+ )
+ runtime_error.__traceback__ = exceptions[-1].__traceback__
+ raise runtime_error
+
+ def step(self, actions: Dict[int, Any]) -> Dict[int, BaseEnvTimestep]:
+ """
+ Overview:
+ Execute env step according to input actions. If some sub-environments are done after this execution, \
+ they will be reset automatically when ``self._auto_reset`` is True, otherwise they need to be reset when \
+ the caller uses the ``reset`` method of env manager.
+ Arguments:
+ - actions (:obj:`Dict[int, Any]`): A dict of actions, key is the env_id, value is corresponding action. \
+ action can be any type, it depends on the env, and the env will handle it. Usually, the action is \
+ a dict of numpy array, and the value is generated by the outer caller like ``policy``.
+ Returns:
+ - timesteps (:obj:`Dict[int, BaseEnvTimestep]`): Each timestep is a ``BaseEnvTimestep`` object, \
+ usually including observation, reward, done, info. Some special customized environments will have \
+ the special timestep definition. The length of timesteps is the same as the length of actions in \
+ synchronous env manager.
+ Example:
+ >>> timesteps = env_manager.step(action)
+ >>> for env_id, timestep in timesteps.items():
+ >>> if timestep.done:
+ >>> print('Env {} is done'.format(env_id))
+ """
+ self._check_closed()
+ timesteps = {}
+ for env_id, act in actions.items():
+ timesteps[env_id] = self._step(env_id, act)
+ if timesteps[env_id].done:
+ self._env_episode_count[env_id] += 1
+ if self._env_episode_count[env_id] < self._episode_num:
+ if self._auto_reset:
+ self._reset(env_id)
+ else:
+ self._env_states[env_id] = EnvState.NEED_RESET
+ else:
+ self._env_states[env_id] = EnvState.DONE
+ else:
+ self._ready_obs[env_id] = timesteps[env_id].obs
+ return timesteps
+
+ def _step(self, env_id: int, act: Any) -> BaseEnvTimestep:
+
+ @timeout_wrapper(timeout=self._step_timeout)
+ def step_fn():
+ return self._envs[env_id].step(act)
+
+ exceptions = []
+ for _ in range(self._max_retry):
+ try:
+ return step_fn()
+ except BaseException as e:
+ exceptions.append(e)
+ self._env_states[env_id] = EnvState.ERROR
+ logging.error("Env {} step has exceeded max retries({})".format(env_id, self._max_retry))
+ runtime_error = RuntimeError(
+ "Env {} step has exceeded max retries({}), and the latest exception is: {}".format(
+ env_id, self._max_retry, str(exceptions[-1])
+ )
+ )
+ runtime_error.__traceback__ = exceptions[-1].__traceback__
+ raise runtime_error
+
+ def seed(self, seed: Union[Dict[int, int], List[int], int], dynamic_seed: bool = None) -> None:
+ """
+ Overview:
+ Set the random seed for each environment.
+ Arguments:
+ - seed (:obj:`Union[Dict[int, int], List[int], int]`): Dict or List of seeds for each environment; \
+ If a single integer seed is provided, sub-environment ``i`` will be seeded with ``seed + i``.
+ - dynamic_seed (:obj:`bool`): Whether to use dynamic seed.
+
+ .. note::
+ For more details about ``dynamic_seed``, please refer to the best practice document of DI-engine \
+ (`en link <../04_best_practice/random_seed.html>`_).
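+
+ Example:
+ >>> # A minimal sketch: a single integer seed is expanded so that sub-environment ``i`` uses ``seed + i``.
+ >>> env_manager.seed(314, dynamic_seed=False)
+ >>> # or assign every sub-environment its own seed (the list length must equal ``env_num``)
+ >>> env_manager.seed([314, 315, 316, 317])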
+ """
+ if isinstance(seed, numbers.Integral):
+ seed = [seed + i for i in range(self.env_num)]
+ self._env_seed = seed
+ elif isinstance(seed, list):
+ assert len(seed) == self._env_num, "len(seed) {:d} != env_num {:d}".format(len(seed), self._env_num)
+ self._env_seed = seed
+ elif isinstance(seed, dict):
+ if not hasattr(self, '_env_seed'):
+ raise RuntimeError("please indicate all the seed of each env in the beginning")
+ for env_id, s in seed.items():
+ self._env_seed[env_id] = s
+ else:
+ raise TypeError("invalid seed arguments type: {}".format(type(seed)))
+ self._env_dynamic_seed = dynamic_seed
+ try:
+ self._action_space.seed(seed[0])
+ except Exception: # TODO(nyz) deal with nested action_space like SMAC
+ pass
+
+ def enable_save_replay(self, replay_path: Union[List[str], str]) -> None:
+ """
+ Overview:
+ Enable all environments to save replay video after each episode terminates.
+ Arguments:
+ - replay_path (:obj:`Union[List[str], str]`): List of paths for each environment; \
+ Or one path for all environments.
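+ Example:
+ >>> # A minimal sketch: a single path is broadcast to every sub-environment.
+ >>> env_manager.enable_save_replay('./video')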
+ """
+ if isinstance(replay_path, str):
+ replay_path = [replay_path] * self.env_num
+ self._env_replay_path = replay_path
+
+ def enable_save_figure(self, env_id: int, figure_path: str) -> None:
+ """
+ Overview:
+ Enable a specific env to save figure (e.g. environment statistics or episode return curve).
+ Arguments:
+ - env_id (:obj:`int`): The id of the specific environment.
+ - figure_path (:obj:`str`): The file directory path for the environment to save figures to.
+ """
+ assert figure_path is not None
+ self._envs[env_id].enable_save_figure(figure_path)
+
+ def close(self) -> None:
+ """
+ Overview:
+ Close the env manager and release all the environment resources.
+ """
+ if self._closed:
+ return
+ for env in self._envs:
+ env.close()
+ for i in range(self._env_num):
+ self._env_states[i] = EnvState.VOID
+ self._closed = True
+
+ def reward_shaping(self, env_id: int, transitions: List[dict]) -> List[dict]:
+ """
+ Overview:
+ Execute reward shaping for a specific environment, which is often called when an episode terminates.
+ Arguments:
+ - env_id (:obj:`int`): The id of the environment to be shaped.
+ - transitions (:obj:`List[dict]`): The transition data list of the environment to be shaped.
+ Returns:
+ - transitions (:obj:`List[dict]`): The shaped transition data list.
+ """
+ return self._envs[env_id].reward_shaping(transitions)
+
+ @property
+ def closed(self) -> bool:
+ """
+ Overview:
+ ``closed`` is a property that returns whether the env manager is closed.
+ Returns:
+ - closed (:obj:`bool`): Whether the env manager is closed.
+ """
+ return self._closed
+
+ def random_action(self) -> Dict:
+ return {env_id: self._env_ref.action_space.sample() for env_id in self.ready_obs_id}
+
+
+@ENV_MANAGER_REGISTRY.register('base_v2')
+class BaseEnvManagerV2(BaseEnvManager):
+ """
+ Overview:
+ The basic class of env manager to manage multiple vectorized environments. BaseEnvManager defines all the
+ necessary interfaces, and derived classes must extend this basic class.
+
+ The class is implemented with a pseudo-parallel (i.e. serial) mechanism, therefore, it is only
+ used in some tiny environments and for debugging purposes.
+
+ ``V2`` means this env manager is designed for the new task pipeline and interfaces coupled with treetensor.
+
+ .. note::
+ For more details about new task pipeline, please refer to the system document of DI-engine \
+ (`system en link <../03_system/index.html>`_).
+
+ Interfaces:
+ reset, step, seed, close, enable_save_replay, launch, default_config, reward_shaping, enable_save_figure
+ Properties:
+ env_num, env_ref, ready_obs, ready_obs_id, ready_imgs, done, closed, method_name_list, observation_space, \
+ action_space, reward_space
+ """
+
+ @property
+ def ready_obs(self) -> tnp.array:
+ """
+ Overview:
+ Get the ready (next) observation, which is a special design to unify both async/sync env managers.
+ For each interaction between policy and env, the policy will input the ready_obs and output the action.
+ Then the env_manager will ``step`` with the action and prepare the next ready_obs.
+ For ``V2`` version, the observation is transformed and packed up into ``tnp.array`` type, which allows
+ more convenient operations.
+ Return:
+ - ready_obs (:obj:`tnp.array`): A stacked treenumpy-type observation data.
+ Example:
+ >>> obs = env_manager.ready_obs
+ >>> action = policy(obs) # here the policy inputs treenp obs and outputs np action
+ >>> timesteps = env_manager.step(action)
+ """
+ active_env = [i for i, s in self._env_states.items() if s == EnvState.RUN]
+ obs = [self._ready_obs[i] for i in active_env]
+ if isinstance(obs[0], dict): # transform each element to treenumpy array
+ obs = [tnp.array(o) for o in obs]
+ return tnp.stack(obs)
+
+ def step(self, actions: List[tnp.ndarray]) -> List[tnp.ndarray]:
+ """
+ Overview:
+ Execute env step according to input actions. If some sub-environments are done after this execution, \
+ they will be reset automatically by default.
+ Arguments:
+ - actions (:obj:`List[tnp.ndarray]`): A list of treenumpy-type actions, the value is generated by the \
+ outer caller like ``policy``.
+ Returns:
+ - timesteps (:obj:`List[tnp.ndarray]`): A list of timesteps. Each timestep is a ``tnp.ndarray`` object, \
+ usually including observation, reward, done, info, env_id. Some special environments will have \
+ the special timestep definition. The length of timesteps is the same as the length of actions in \
+ synchronous env manager. For the compatibility of treenumpy, here we use ``make_key_as_identifier`` \
+ and ``remove_illegal_item`` functions to modify the original timestep.
+ Example:
+ >>> timesteps = env_manager.step(action)
+ >>> for timestep in timesteps:
+ >>> if timestep.done:
+ >>> print('Env {} is done'.format(timestep.env_id))
+ """
+ actions = {env_id: a for env_id, a in zip(self.ready_obs_id, actions)}
+ timesteps = super().step(actions)
+ new_data = []
+ for env_id, timestep in timesteps.items():
+ obs, reward, done, info = timestep
+ # make the type and content of each key a valid identifier,
+ # so that they can be accessed as attributes (e.g. timestep.xxx), such as ``TimeLimit.truncated`` in cartpole info
+ info = make_key_as_identifier(info)
+ info = remove_illegal_item(info)
+ new_data.append(tnp.array({'obs': obs, 'reward': reward, 'done': done, 'info': info, 'env_id': env_id}))
+ return new_data
+
+
+def create_env_manager(manager_cfg: EasyDict, env_fn: List[Callable]) -> BaseEnvManager:
+ """
+ Overview:
+ Create an env manager according to ``manager_cfg`` and env functions.
+ Arguments:
+ - manager_cfg (:obj:`EasyDict`): Final merged env manager config.
+ - env_fn (:obj:`List[Callable]`): A list of functions to create ``env_num`` sub-environments.
+ ArgumentsKeys:
+ - type (:obj:`str`): Env manager type set in ``ENV_MANAGER_REGISTRY.register`` , such as ``base`` .
+ - import_names (:obj:`List[str]`): A list of module names (paths) to import before creating env manager, such \
+ as ``ding.envs.env_manager.base_env_manager`` .
+ Returns:
+ - env_manager (:obj:`BaseEnvManager`): The created env manager.
+
+ .. tip::
+ This method will not modify the ``manager_cfg`` , it will deepcopy the ``manager_cfg`` and then modify it.
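+
+ Example:
+ >>> # A minimal sketch, reusing ``DemoEnv`` from ``ding.envs.env.tests`` as the sub-environment:
+ >>> manager_cfg = BaseEnvManager.default_config()
+ >>> manager_cfg.type = 'base'
+ >>> env_manager = create_env_manager(manager_cfg, [lambda: DemoEnv({}) for _ in range(4)])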
+ """
+ manager_cfg = copy.deepcopy(manager_cfg)
+ if 'import_names' in manager_cfg:
+ import_module(manager_cfg.pop('import_names'))
+ manager_type = manager_cfg.pop('type')
+ return ENV_MANAGER_REGISTRY.build(manager_type, env_fn=env_fn, cfg=manager_cfg)
+
+
+def get_env_manager_cls(cfg: EasyDict) -> type:
+ """
+ Overview:
+ Get the env manager class according to config, which is used to access related class variables/methods.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Final merged env manager config.
+ ArgumentsKeys:
+ - type (:obj:`str`): Env manager type set in ``ENV_MANAGER_REGISTRY.register`` , such as ``base`` .
+ - import_names (:obj:`List[str]`): A list of module names (paths) to import before creating env manager, such \
+ as ``ding.envs.env_manager.base_env_manager`` .
+ Returns:
+ - env_manager_cls (:obj:`type`): The corresponding env manager class.
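+ Example:
+ >>> # Sketch: fetch the class registered as 'base' and inspect its default config.
+ >>> env_manager_cls = get_env_manager_cls(EasyDict(type='base'))
+ >>> default_cfg = env_manager_cls.default_config()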
+ """
+ import_module(cfg.get('import_names', []))
+ return ENV_MANAGER_REGISTRY.get(cfg.type)
diff --git a/DI-engine/ding/envs/env_manager/ding_env_manager.py b/DI-engine/ding/envs/env_manager/ding_env_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..12a3ecf88f15aa5523c7c1c6a4456e6ac7dc4c5c
--- /dev/null
+++ b/DI-engine/ding/envs/env_manager/ding_env_manager.py
@@ -0,0 +1,23 @@
+from . import BaseEnvManagerV2, SubprocessEnvManagerV2
+from ..env import DingEnvWrapper
+from typing import Optional
+from functools import partial
+
+
+def setup_ding_env_manager(
+ env: DingEnvWrapper,
+ env_num: int,
+ context: Optional[str] = None,
+ debug: bool = False,
+ caller: str = 'collector'
+) -> BaseEnvManagerV2:
+ assert caller in ['evaluator', 'collector']
+ if debug:
+ env_cls = BaseEnvManagerV2
+ manager_cfg = env_cls.default_config()
+ else:
+ env_cls = SubprocessEnvManagerV2
+ manager_cfg = env_cls.default_config()
+ if context is not None:
+ manager_cfg.context = context
+ return env_cls([partial(env.clone, caller) for _ in range(env_num)], manager_cfg)
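+
+
+# A minimal usage sketch (assuming ``ding_env`` is an already constructed DingEnvWrapper instance):
+#   collector_env = setup_ding_env_manager(ding_env, env_num=4, debug=True, caller='collector')
+#   collector_env.launch()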
diff --git a/DI-engine/ding/envs/env_manager/env_supervisor.py b/DI-engine/ding/envs/env_manager/env_supervisor.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec5e29beab19680bd2f9770ea986d118f1778894
--- /dev/null
+++ b/DI-engine/ding/envs/env_manager/env_supervisor.py
@@ -0,0 +1,558 @@
+from collections import defaultdict
+import math
+import queue
+from time import sleep, time
+import gym
+from ding.framework import Supervisor
+from typing import TYPE_CHECKING, Any, List, Union, Dict, Optional, Callable
+from ding.framework.supervisor import ChildType, RecvPayload, SendPayload
+from ding.utils import make_key_as_identifier
+from ditk import logging
+from ding.data import ShmBufferContainer
+import enum
+import treetensor.numpy as tnp
+import numbers
+if TYPE_CHECKING:
+ from gym.spaces import Space
+
+
+class EnvState(enum.IntEnum):
+ """
+ VOID -> RUN -> DONE
+ """
+ VOID = 0
+ INIT = 1
+ RUN = 2
+ RESET = 3
+ DONE = 4
+ ERROR = 5
+ NEED_RESET = 6
+
+
+class EnvRetryType(str, enum.Enum):
+ RESET = "reset"
+ RENEW = "renew"
+
+
+class EnvSupervisor(Supervisor):
+ """
+ Manage multiple envs with supervisor.
+
+ New features (compared to env manager):
+ - Consistent interface in multi-process and multi-threaded mode.
+ - Add asynchronous features and recommend using asynchronous methods.
+ - Reset is performed after an error is encountered in the step method.
+
+ Breaking changes (compared to env manager):
+ - Some internal states of the original env manager are not kept.
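+
+ Example:
+ >>> # A minimal sketch, assuming ``env_fns`` is a list of callables that each create a BaseEnv instance:
+ >>> supervisor = EnvSupervisor(type_=ChildType.PROCESS, env_fn=env_fns)
+ >>> supervisor.launch()
+ >>> actions = {env_id: supervisor.action_space.sample() for env_id in supervisor.ready_obs_id}
+ >>> timesteps = supervisor.step(actions)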
+ """
+
+ def __init__(
+ self,
+ type_: ChildType = ChildType.PROCESS,
+ env_fn: List[Callable] = None,
+ retry_type: EnvRetryType = EnvRetryType.RESET,
+ max_try: Optional[int] = None,
+ max_retry: Optional[int] = None,
+ auto_reset: bool = True,
+ reset_timeout: Optional[int] = None,
+ step_timeout: Optional[int] = None,
+ retry_waiting_time: Optional[int] = None,
+ episode_num: int = float("inf"),
+ shared_memory: bool = True,
+ copy_on_get: bool = True,
+ **kwargs
+ ) -> None:
+ """
+ Overview:
+ Supervisor that manage a group of envs.
+ Arguments:
+ - type_ (:obj:`ChildType`): Type of child process.
+ - env_fn (:obj:`List[Callable]`): The functions to create environments.
+ - retry_type (:obj:`EnvRetryType`): Retry reset or renew env.
+ - max_try (:obj:`Optional[int]`): Max try times for reset or step action.
+ - max_retry (:obj:`Optional[int]`): Alias of max_try.
+ - auto_reset (:obj:`bool`): Auto reset env if reach done.
+ - reset_timeout (:obj:`Optional[int]`): Timeout in seconds for reset.
+ - step_timeout (:obj:`Optional[int]`): Timeout in seconds for step.
+ - retry_waiting_time (:obj:`Optional[float]`): Wait time on each retry.
+ - shared_memory (:obj:`bool`): Use shared memory in multiprocessing.
+ - copy_on_get (:obj:`bool`): Use copy on get in multiprocessing.
+ """
+ if kwargs:
+ logging.warning("Unknown parameters on env supervisor: {}".format(kwargs))
+ super().__init__(type_=type_)
+ if type_ is not ChildType.PROCESS and (shared_memory or copy_on_get):
+ logging.warning("shared_memory and copy_on_get only works in process mode.")
+ self._shared_memory = type_ is ChildType.PROCESS and shared_memory
+ self._copy_on_get = type_ is ChildType.PROCESS and copy_on_get
+ self._env_fn = env_fn
+ self._create_env_ref()
+ self._obs_buffers = None
+ if env_fn:
+ if self._shared_memory:
+ obs_space = self._observation_space
+ if isinstance(obs_space, gym.spaces.Dict):
+ # For the multi-agent case, such as multiagent_mujoco and petting_zoo mpe.
+ # Now this only covers the case where each agent in the team has the same obs structure
+ # and corresponding shape.
+ shape = {k: v.shape for k, v in obs_space.spaces.items()}
+ dtype = {k: v.dtype for k, v in obs_space.spaces.items()}
+ else:
+ shape = obs_space.shape
+ dtype = obs_space.dtype
+ self._obs_buffers = {
+ env_id: ShmBufferContainer(dtype, shape, copy_on_get=self._copy_on_get)
+ for env_id in range(len(self._env_fn))
+ }
+ for env_init in env_fn:
+ self.register(env_init, shm_buffer=self._obs_buffers, shm_callback=self._shm_callback)
+ else:
+ for env_init in env_fn:
+ self.register(env_init)
+ self._retry_type = retry_type
+ self._auto_reset = auto_reset
+ if max_retry:
+ logging.warning("The `max_retry` is going to be deprecated, use `max_try` instead!")
+ self._max_try = max_try or max_retry or 1
+ self._reset_timeout = reset_timeout
+ self._step_timeout = step_timeout
+ self._retry_waiting_time = retry_waiting_time
+ self._env_replay_path = None
+ self._episode_num = episode_num
+ self._init_states()
+
+ def _init_states(self):
+ self._env_seed = {}
+ self._env_dynamic_seed = None
+ self._env_replay_path = None
+ self._env_states = {}
+ self._reset_param = {}
+ self._ready_obs = {}
+ self._env_episode_count = {i: 0 for i in range(self.env_num)}
+ self._retry_times = defaultdict(lambda: 0)
+ self._last_called = defaultdict(lambda: {"step": math.inf, "reset": math.inf})
+
+ def _shm_callback(self, payload: RecvPayload, obs_buffers: Any):
+ """
+ Overview:
+ This method will be called in the child worker, so we can put large data into shared memory
+ and replace the original payload data with None, which reduces the serialization/deserialization cost.
+ """
+ if payload.method == "reset" and payload.data is not None:
+ obs_buffers[payload.proc_id].fill(payload.data)
+ payload.data = None
+ elif payload.method == "step" and payload.data is not None:
+ obs_buffers[payload.proc_id].fill(payload.data.obs)
+ payload.data = payload.data._replace(obs=None)
+
+ def _create_env_ref(self):
+ # env_ref is used to acquire some common attributes of env, like obs_shape and act_shape
+ self._env_ref = self._env_fn[0]()
+ self._env_ref.reset()
+ self._observation_space = self._env_ref.observation_space
+ self._action_space = self._env_ref.action_space
+ self._reward_space = self._env_ref.reward_space
+ self._env_ref.close()
+
+ def step(self, actions: Union[Dict[int, List[Any]], List[Any]], block: bool = True) -> Optional[List[tnp.ndarray]]:
+ """
+ Overview:
+ Execute env step according to input actions, and reset an env if it is done.
+ Arguments:
+ - actions (:obj:`List[tnp.ndarray]`): Actions coming from an outer caller such as the policy, \
+ in the structure of {env_id: actions}.
+ - block (:obj:`bool`): If block is True, return the timesteps; otherwise return None.
+ Returns:
+ - timesteps (:obj:`List[tnp.ndarray]`): Each timestep is a tnp.array with observation, reward, done, \
+ info, env_id.
+ """
+ assert not self.closed, "Env supervisor has closed."
+ if isinstance(actions, List):
+ actions = {i: p for i, p in enumerate(actions)}
+ assert actions, "Action is empty!"
+
+ send_payloads = []
+
+ for env_id, act in actions.items():
+ payload = SendPayload(proc_id=env_id, method="step", args=[act])
+ send_payloads.append(payload)
+ self.send(payload)
+
+ if not block:
+ # Retrieve the data for these steps from the recv method
+ return
+
+ # Wait for all steps returns
+ recv_payloads = self.recv_all(
+ send_payloads, ignore_err=True, callback=self._recv_callback, timeout=self._step_timeout
+ )
+ return [payload.data for payload in recv_payloads]
+
+ def recv(self, ignore_err: bool = False) -> RecvPayload:
+ """
+ Overview:
+ Wait for recv payload, this function will block the thread.
+ Arguments:
+ - ignore_err (:obj:`bool`): If ignore_err is True, payloads with an error object will be discarded. \
+ This option will not catch the exception.
+ Returns:
+ - recv_payload (:obj:`RecvPayload`): Recv payload.
+ """
+ self._detect_timeout()
+ try:
+ payload = super().recv(ignore_err=True, timeout=0.1)
+ payload = self._recv_callback(payload=payload)
+ if payload.err:
+ return self.recv(ignore_err=ignore_err)
+ else:
+ return payload
+ except queue.Empty:
+ return self.recv(ignore_err=ignore_err)
+
+ def _detect_timeout(self):
+ """
+ Overview:
+ Try to restart all environments whose ``step`` or ``reset`` calls are detected to have timed out.
+ """
+ for env_id in self._last_called:
+ if self._step_timeout and time() - self._last_called[env_id]["step"] > self._step_timeout:
+ payload = RecvPayload(
+ proc_id=env_id, method="step", err=TimeoutError("Step timeout on env {}".format(env_id))
+ )
+ self._recv_queue.put(payload)
+ continue
+ if self._reset_timeout and time() - self._last_called[env_id]["reset"] > self._reset_timeout:
+ payload = RecvPayload(
+ proc_id=env_id, method="reset", err=TimeoutError("Step timeout on env {}".format(env_id))
+ )
+ self._recv_queue.put(payload)
+ continue
+
+ @property
+ def env_num(self) -> int:
+ return len(self._children)
+
+ @property
+ def observation_space(self) -> 'Space':
+ return self._observation_space
+
+ @property
+ def action_space(self) -> 'Space':
+ return self._action_space
+
+ @property
+ def reward_space(self) -> 'Space':
+ return self._reward_space
+
+ @property
+ def ready_obs(self) -> tnp.array:
+ """
+ Overview:
+ Get the ready (next) observation in ``tnp.array`` type, which is uniform for both async/sync scenarios.
+ Return:
+ - ready_obs (:obj:`tnp.array`): A stacked treenumpy-type observation data.
+ Example:
+ >>> obs = env_manager.ready_obs
+ >>> action = model(obs) # model input np obs and output np action
+ >>> timesteps = env_manager.step(action)
+ """
+ active_env = [i for i, s in self._env_states.items() if s == EnvState.RUN]
+ active_env.sort()
+ obs = [self._ready_obs.get(i) for i in active_env]
+ if len(obs) == 0:
+ return tnp.array([])
+ return tnp.stack(obs)
+
+ @property
+ def ready_obs_id(self) -> List[int]:
+ return [i for i, s in self.env_states.items() if s == EnvState.RUN]
+
+ @property
+ def done(self) -> bool:
+ return all([s == EnvState.DONE for s in self.env_states.values()])
+
+ @property
+ def method_name_list(self) -> List[str]:
+ return ['reset', 'step', 'seed', 'close', 'enable_save_replay']
+
+ @property
+ def env_states(self) -> Dict[int, EnvState]:
+ return {env_id: self._env_states.get(env_id) or EnvState.VOID for env_id in range(self.env_num)}
+
+ def env_state_done(self, env_id: int) -> bool:
+ return self.env_states[env_id] == EnvState.DONE
+
+ def launch(self, reset_param: Optional[Dict] = None, block: bool = True) -> None:
+ """
+ Overview:
+ Set up the environments and their parameters.
+ Arguments:
+ - reset_param (:obj:`Optional[Dict]`): Dict of reset parameters for each environment, key is the env_id, \
+ value is the corresponding reset parameters.
+ - block (:obj:`bool`): Whether to block the process and wait for reset results.
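+ Example:
+ >>> # Illustrative sketch: launch with empty reset kwargs for each env; block=True waits
+ >>> # until every env has finished resetting.
+ >>> env_supervisor.launch(reset_param={i: {} for i in range(env_supervisor.env_num)})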
+ """
+ assert self.closed, "Please first close the env supervisor before launch it"
+ if reset_param is not None:
+ assert len(reset_param) == self.env_num
+ self.start_link()
+ self._send_seed(self._env_seed, self._env_dynamic_seed, block=block)
+ self.reset(reset_param, block=block)
+ self._enable_env_replay()
+
+ def reset(self, reset_param: Optional[Dict[int, List[Any]]] = None, block: bool = True) -> None:
+ """
+ Overview:
+ Reset an environment.
+ Arguments:
+ - reset_param (:obj:`Optional[Dict[int, List[Any]]]`): Dict of reset parameters for each environment, \
+ key is the env_id, value is the corresponding reset parameters.
+ - block (:obj:`bool`): Whether to block the process and wait for reset results.
+ """
+ if not reset_param:
+ reset_param = {i: {} for i in range(self.env_num)}
+ elif isinstance(reset_param, List):
+ reset_param = {i: p for i, p in enumerate(reset_param)}
+
+ send_payloads = []
+
+ for env_id, kw_param in reset_param.items():
+ self._reset_param[env_id] = kw_param # For auto reset
+ send_payloads += self._reset(env_id, kw_param=kw_param)
+
+ if not block:
+ return
+
+ self.recv_all(send_payloads, ignore_err=True, callback=self._recv_callback, timeout=self._reset_timeout)
+
+ def _recv_callback(
+ self, payload: RecvPayload, remain_payloads: Optional[Dict[str, SendPayload]] = None
+ ) -> RecvPayload:
+ """
+ Overview:
+ The callback function for each received payload, within this method will modify the state of \
+ each environment, replace objects in shared memory, and determine if a retry is needed due to an error.
+ Arguments:
+ - payload (:obj:`RecvPayload`): The received payload.
+ - remain_payloads (:obj:`Optional[Dict[str, SendPayload]]`): The callback may be called many times \
+ until remain_payloads is cleared; you can append new payloads to remain_payloads to have this \
+ callback invoked for them as well.
+ """
+ self._set_shared_obs(payload=payload)
+ self.change_state(payload=payload)
+ if payload.method == "reset":
+ return self._recv_reset_callback(payload=payload, remain_payloads=remain_payloads)
+ elif payload.method == "step":
+ return self._recv_step_callback(payload=payload, remain_payloads=remain_payloads)
+ return payload
+
+ def _set_shared_obs(self, payload: RecvPayload):
+ if self._obs_buffers is None:
+ return
+ if payload.method == "reset" and payload.err is None:
+ payload.data = self._obs_buffers[payload.proc_id].get()
+ elif payload.method == "step" and payload.err is None:
+ payload.data = payload.data._replace(obs=self._obs_buffers[payload.proc_id].get())  # assign back: _replace does not mutate in place
+
+ def _recv_reset_callback(
+ self, payload: RecvPayload, remain_payloads: Optional[Dict[str, SendPayload]] = None
+ ) -> RecvPayload:
+ assert payload.method == "reset", "Recv error callback({}) in reset callback!".format(payload.method)
+ if remain_payloads is None:
+ remain_payloads = {}
+ env_id = payload.proc_id
+ if payload.err:
+ self._retry_times[env_id] += 1
+ if self._retry_times[env_id] > self._max_try - 1:
+ self.shutdown(5)
+ raise RuntimeError(
+ "Env {} reset has exceeded max_try({}), and the latest exception is: {}".format(
+ env_id, self._max_try, payload.err
+ )
+ )
+ if self._retry_waiting_time:
+ sleep(self._retry_waiting_time)
+ if self._retry_type == EnvRetryType.RENEW:
+ self._children[env_id].restart()
+ send_payloads = self._reset(env_id)
+ for p in send_payloads:
+ remain_payloads[p.req_id] = p
+ else:
+ self._retry_times[env_id] = 0
+ self._ready_obs[env_id] = payload.data
+ return payload
+
+ def _recv_step_callback(
+ self, payload: RecvPayload, remain_payloads: Optional[Dict[str, SendPayload]] = None
+ ) -> RecvPayload:
+ assert payload.method == "step", "Recv error callback({}) in step callback!".format(payload.method)
+ if remain_payloads is None:
+ remain_payloads = {}
+ if payload.err:
+ send_payloads = self._reset(payload.proc_id)
+ for p in send_payloads:
+ remain_payloads[p.req_id] = p
+ info = {"abnormal": True, "err": payload.err}
+ payload.data = tnp.array(
+ {
+ 'obs': None,
+ 'reward': None,
+ 'done': None,
+ 'info': info,
+ 'env_id': payload.proc_id
+ }
+ )
+ else:
+ obs, reward, done, info, *_ = payload.data
+ if done:
+ self._env_episode_count[payload.proc_id] += 1
+ if self._env_episode_count[payload.proc_id] < self._episode_num and self._auto_reset:
+ send_payloads = self._reset(payload.proc_id)
+ for p in send_payloads:
+ remain_payloads[p.req_id] = p
+ # make the type and content of key as similar as identifier,
+ # in order to call them as attribute (e.g. timestep.xxx), such as ``TimeLimit.truncated`` in cartpole info
+ info = make_key_as_identifier(info)
+ payload.data = tnp.array(
+ {
+ 'obs': obs,
+ 'reward': reward,
+ 'done': done,
+ 'info': info,
+ 'env_id': payload.proc_id
+ }
+ )
+ self._ready_obs[payload.proc_id] = obs
+ return payload
+
+ def _reset(self, env_id: int, kw_param: Optional[Dict[str, Any]] = None) -> List[SendPayload]:
+ """
+ Overview:
+ Reset an environment. This method does not wait for the result to be returned.
+ Arguments:
+ - env_id (:obj:`int`): Environment id.
+ - kw_param (:obj:`Optional[Dict[str, Any]]`): Reset parameters for the environment.
+ Returns:
+ - send_payloads (:obj:`List[SendPayload]`): The request payloads for seed and reset actions.
+ """
+ assert not self.closed, "Env supervisor has closed."
+ send_payloads = []
+ kw_param = kw_param or self._reset_param[env_id]
+
+ if self._env_replay_path is not None and self.env_states[env_id] == EnvState.RUN:
+ logging.warning("Please don't reset an unfinished env when you enable save replay, we just skip it")
+ return send_payloads
+
+ # Reset env
+ payload = SendPayload(proc_id=env_id, method="reset", kwargs=kw_param)
+ send_payloads.append(payload)
+ self.send(payload)
+
+ return send_payloads
+
+ def _send_seed(self, env_seed: Dict[int, int], env_dynamic_seed: Optional[bool] = None, block: bool = True) -> None:
+ send_payloads = []
+ for env_id, seed in env_seed.items():
+ if seed is None:
+ continue
+ args = [seed]
+ if env_dynamic_seed is not None:
+ args.append(env_dynamic_seed)
+ payload = SendPayload(proc_id=env_id, method="seed", args=args)
+ send_payloads.append(payload)
+ self.send(payload)
+ if not block or not send_payloads:
+ return
+ self.recv_all(send_payloads, ignore_err=True, callback=self._recv_callback, timeout=self._reset_timeout)
+
+ def change_state(self, payload: RecvPayload):
+ self._last_called[payload.proc_id][payload.method] = math.inf # Have received
+ if payload.err:
+ self._env_states[payload.proc_id] = EnvState.ERROR
+ elif payload.method == "reset":
+ self._env_states[payload.proc_id] = EnvState.RUN
+ elif payload.method == "step":
+ if payload.data[2]:
+ self._env_states[payload.proc_id] = EnvState.DONE
+
+ def send(self, payload: SendPayload) -> None:
+ self._last_called[payload.proc_id][payload.method] = time()
+ return super().send(payload)
+
+ def seed(self, seed: Union[Dict[int, int], List[int], int], dynamic_seed: Optional[bool] = None) -> None:
+ """
+ Overview:
+ Set the seed for each environment. The seed function will not be called until supervisor.launch \
+ was called.
+ Arguments:
+ - seed (:obj:`Union[Dict[int, int], List[int], int]`): List of seeds for each environment; \
+ or one seed for the first environment, with the other seeds generated automatically. \
+ Note that in threading mode, no matter how many seeds are given, only the last one will take effect, \
+ because the execution in the thread is asynchronous, so the results of each experiment \
+ differ even if a fixed seed is used.
+ - dynamic_seed (:obj:`Optional[bool]`): Dynamic seed is used in the training environment, \
+ trying to make the random seed of each episode different. The seeds are all generated in the reset \
+ method by a random generator 100 * np.random.randint(1, 1000) (but the seed of this random \
+ number generator is fixed by the environment's seed method, guaranteeing the reproducibility \
+ of the experiment). To enable it, either omit the dynamic_seed parameter in the seed method or pass \
+ it as True.
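+ Example:
+ >>> # Illustrative sketch (assuming env_num == 3): a single int seeds env i with seed + i,
+ >>> # while a list/dict gives per-env seeds; seeds are only sent on the next ``launch`` call.
+ >>> env_supervisor.seed([314, 315, 316], dynamic_seed=False)
+ >>> env_supervisor.launch()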
+ """
+ self._env_seed = {}
+ if isinstance(seed, numbers.Integral):
+ self._env_seed = {i: seed + i for i in range(self.env_num)}
+ elif isinstance(seed, list):
+ assert len(seed) == self.env_num, "len(seed) {:d} != env_num {:d}".format(len(seed), self.env_num)
+ self._env_seed = {i: _seed for i, _seed in enumerate(seed)}
+ elif isinstance(seed, dict):
+ self._env_seed = {env_id: s for env_id, s in seed.items()}
+ else:
+ raise TypeError("Invalid seed arguments type: {}".format(type(seed)))
+ self._env_dynamic_seed = dynamic_seed
+
+ def enable_save_replay(self, replay_path: Union[List[str], str]) -> None:
+ """
+ Overview:
+ Set each env's replay save path.
+ Arguments:
+ - replay_path (:obj:`Union[List[str], str]`): List of paths for each environment; \
+ Or one path for all environments.
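+ Example:
+ >>> # Illustrative sketch: one path is broadcast to every env; call this before ``launch``
+ >>> # so the replay option is forwarded when the subprocesses start.
+ >>> env_supervisor.enable_save_replay('./replay_videos')
+ >>> env_supervisor.launch()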
+ """
+ if isinstance(replay_path, str):
+ replay_path = [replay_path] * self.env_num
+ self._env_replay_path = replay_path
+
+ def _enable_env_replay(self):
+ if self._env_replay_path is None:
+ return
+ send_payloads = []
+ for env_id, s in enumerate(self._env_replay_path):
+ payload = SendPayload(proc_id=env_id, method="enable_save_replay", args=[s])
+ send_payloads.append(payload)
+ self.send(payload)
+ self.recv_all(send_payloads=send_payloads)
+
+ def __getattr__(self, key: str) -> List[Any]:
+ if not hasattr(self._env_ref, key):
+ raise AttributeError("env `{}` doesn't have the attribute `{}`".format(type(self._env_ref), key))
+ return super().__getattr__(key)
+
+ def close(self, timeout: Optional[float] = None) -> None:
+ """
+ In order to be compatible with BaseEnvManager, the new version can use `shutdown` directly.
+ """
+ self.shutdown(timeout=timeout)
+
+ def shutdown(self, timeout: Optional[float] = None) -> None:
+ if self._running:
+ send_payloads = []
+ for env_id in range(self.env_num):
+ payload = SendPayload(proc_id=env_id, method="close")
+ send_payloads.append(payload)
+ self.send(payload)
+ self.recv_all(send_payloads=send_payloads, ignore_err=True, timeout=timeout)
+ super().shutdown(timeout=timeout)
+ self._init_states()
+
+ @property
+ def closed(self) -> bool:
+ return not self._running
diff --git a/DI-engine/ding/envs/env_manager/envpool_env_manager.py b/DI-engine/ding/envs/env_manager/envpool_env_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8d1a4ae03886755e67f74e5162c35de750b5872
--- /dev/null
+++ b/DI-engine/ding/envs/env_manager/envpool_env_manager.py
@@ -0,0 +1,126 @@
+import gym
+from easydict import EasyDict
+from copy import deepcopy
+import numpy as np
+from collections import namedtuple
+from typing import Any, Union, List, Tuple, Dict, Callable, Optional
+from ditk import logging
+try:
+ import envpool
+except ImportError:
+ import sys
+ logging.warning("Please install envpool first, use 'pip install envpool'")
+ envpool = None
+
+from ding.envs import BaseEnvTimestep
+from ding.utils import ENV_MANAGER_REGISTRY, deep_merge_dicts
+from ding.torch_utils import to_ndarray
+
+
+@ENV_MANAGER_REGISTRY.register('env_pool')
+class PoolEnvManager:
+ '''
+ Overview:
+ Envpool now supports Atari, Classic Control, Toy Text, ViZDoom.
+ Here we list some commonly used env_ids as follows.
+ For more examples, you can refer to the envpool documentation.
+
+ - Atari: "Pong-v5", "SpaceInvaders-v5", "Qbert-v5"
+ - Classic Control: "CartPole-v0", "CartPole-v1", "Pendulum-v1"
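+
+ Example:
+ >>> # Illustrative sketch: the keys below mirror what ``launch`` reads from the config;
+ >>> # the concrete values are placeholders rather than recommended settings.
+ >>> cfg = EasyDict(env_id='Pong-v5', env_num=8, batch_size=8, episodic_life=False,
+ ... reward_clip=False, stack_num=4, gray_scale=True, frame_skip=4)
+ >>> env_manager = PoolEnvManager(cfg)
+ >>> env_manager.launch()
+ >>> obs = env_manager.ready_obs # {env_id: np.ndarray}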
+ '''
+
+ @classmethod
+ def default_config(cls) -> EasyDict:
+ return EasyDict(deepcopy(cls.config))
+
+ config = dict(
+ type='envpool',
+ # Sync mode: batch_size == env_num
+ # Async mode: batch_size < env_num
+ env_num=8,
+ batch_size=8,
+ )
+
+ def __init__(self, cfg: EasyDict) -> None:
+ self._cfg = cfg
+ self._env_num = cfg.env_num
+ self._batch_size = cfg.batch_size
+ self._ready_obs = {}
+ self._closed = True
+ self._seed = None
+
+ def launch(self) -> None:
+ assert self._closed, "Please first close the env manager"
+ if self._seed is None:
+ seed = 0
+ else:
+ seed = self._seed
+ self._envs = envpool.make(
+ task_id=self._cfg.env_id,
+ env_type="gym",
+ num_envs=self._env_num,
+ batch_size=self._batch_size,
+ seed=seed,
+ episodic_life=self._cfg.episodic_life,
+ reward_clip=self._cfg.reward_clip,
+ stack_num=self._cfg.stack_num,
+ gray_scale=self._cfg.gray_scale,
+ frame_skip=self._cfg.frame_skip
+ )
+ self._closed = False
+ self.reset()
+
+ def reset(self) -> None:
+ self._ready_obs = {}
+ self._envs.async_reset()
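+ # Note (added for clarity): envpool's async API returns observations in batches from ``recv``,
+ # so keep polling until every one of the ``env_num`` environments has reported its first observation.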
+ while True:
+ obs, _, _, info = self._envs.recv()
+ env_id = info['env_id']
+ obs = obs.astype(np.float32)
+ self._ready_obs = deep_merge_dicts({i: o for i, o in zip(env_id, obs)}, self._ready_obs)
+ if len(self._ready_obs) == self._env_num:
+ break
+ self._eval_episode_return = [0. for _ in range(self._env_num)]
+
+ def step(self, action: dict) -> Dict[int, namedtuple]:
+ env_id = np.array(list(action.keys()))
+ action = np.array(list(action.values()))
+ if len(action.shape) == 2:
+ action = action.squeeze(1)
+ self._envs.send(action, env_id)
+
+ obs, rew, done, info = self._envs.recv()
+ obs = obs.astype(np.float32)
+ rew = rew.astype(np.float32)
+ env_id = info['env_id']
+ timesteps = {}
+ self._ready_obs = {}
+ for i in range(len(env_id)):
+ d = bool(done[i])
+ r = to_ndarray([rew[i]])
+ self._eval_episode_return[env_id[i]] += r
+ timesteps[env_id[i]] = BaseEnvTimestep(obs[i], r, d, info={'env_id': i})
+ if d:
+ timesteps[env_id[i]].info['eval_episode_return'] = self._eval_episode_return[env_id[i]]
+ self._eval_episode_return[env_id[i]] = 0.
+ self._ready_obs[env_id[i]] = obs[i]
+ return timesteps
+
+ def close(self) -> None:
+ if self._closed:
+ return
+ # Envpool has no `close` API
+ self._closed = True
+
+ def seed(self, seed: int, dynamic_seed=False) -> None:
+ # The i-th environment seed in Envpool will be set with i+seed, so we don't do extra transformation here
+ self._seed = seed
+ logging.warning("envpool doesn't support dynamic_seed in different episode")
+
+ @property
+ def env_num(self) -> int:
+ return self._env_num
+
+ @property
+ def ready_obs(self) -> Dict[int, Any]:
+ return self._ready_obs
diff --git a/DI-engine/ding/envs/env_manager/gym_vector_env_manager.py b/DI-engine/ding/envs/env_manager/gym_vector_env_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..46bd8c076900b29aa464349ff88c82c34bd69246
--- /dev/null
+++ b/DI-engine/ding/envs/env_manager/gym_vector_env_manager.py
@@ -0,0 +1,137 @@
+from typing import Any, Union, List, Tuple, Dict, Callable, Optional
+from ditk import logging
+import numpy as np
+from easydict import EasyDict
+from collections import namedtuple
+import gym
+from gym.vector.async_vector_env import AsyncVectorEnv
+
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.torch_utils import to_ndarray, to_list
+from ding.utils import PropagatingThread, LockContextType, LockContext, ENV_MANAGER_REGISTRY
+from .base_env_manager import BaseEnvManager
+from .base_env_manager import EnvState
+
+
+@ENV_MANAGER_REGISTRY.register('gym_vector')
+class GymVectorEnvManager(BaseEnvManager):
+ """
+ Overview:
+ Create a GymVectorEnvManager to manage multiple environments.
+ Each Environment is run by a respective subprocess.
+ Interfaces:
+ seed, ready_obs, step, reset, close
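+ Example:
+ >>> # Illustrative sketch (gym.make and CartPole are placeholders): env_fn must build plain
+ >>> # gym envs, and actions are passed as an {env_id: np.ndarray} dict.
+ >>> env_manager = GymVectorEnvManager(
+ ... env_fn=[lambda: gym.make('CartPole-v1') for _ in range(3)], cfg=EasyDict(GymVectorEnvManager.config))
+ >>> env_manager.reset()
+ >>> actions = {i: np.array([env_manager._env_ref.action_space.sample()]) for i in range(env_manager.env_num)}
+ >>> timesteps = env_manager.step(actions)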
+ """
+ config = dict(shared_memory=False, episode_num=float("inf"))
+
+ def __init__(self, env_fn: List[Callable], cfg: EasyDict) -> None:
+ """
+ .. note::
+ ``env_fn`` must create gym-type environment instances, which may differ from DI-engine environments.
+ """
+ self._cfg = cfg
+ self._env_fn = env_fn
+ self._env_num = len(self._env_fn)
+ self._closed = True
+ self._env_replay_path = None
+ # env_ref is used to acquire some common attributes of env, like obs_shape and act_shape
+ self._env_ref = self._env_fn[0]()
+ self._env_states = {i: EnvState.VOID for i in range(self._env_num)}
+
+ self._episode_num = self._cfg.episode_num
+ self._env_episode_count = {i: 0 for i in range(self.env_num)}
+
+ self._env_manager = AsyncVectorEnv(
+ env_fns=self._env_fn,
+ # observation_space=observation_space,
+ # action_space=action_space,
+ shared_memory=cfg.shared_memory,
+ )
+ self._env_states = {i: EnvState.INIT for i in range(self._env_num)}
+ self._eval_episode_return = [0. for _ in range(self._env_num)]
+
+ def reset(self, reset_param: Optional[Dict] = None) -> None:
+ assert reset_param is None
+ self._closed = False
+ for env_id in range(self.env_num):
+ self._env_states[env_id] = EnvState.RESET
+ self._ready_obs = self._env_manager.reset()
+ for env_id in range(self.env_num):
+ self._env_states[env_id] = EnvState.RUN
+ self._eval_episode_return = [0. for _ in range(self._env_num)]
+
+ def step(self, actions: Dict[int, Any]) -> Dict[int, namedtuple]:
+ assert isinstance(actions, Dict), type(actions)
+
+ env_ids_given = list(actions.keys())
+ for env_id in range(self.env_num):
+ if env_id not in actions.keys():
+ actions[env_id] = self._env_ref.random_action()
+ """actions should be sorted by keys, since the original implementation
+ of the step method in gym accepts list-type actions"""
+ actions = dict(sorted(actions.items()))
+
+ actions = list(actions.values())
+ elem = actions[0]
+ if not isinstance(elem, np.ndarray):
+ raise Exception('DI-engine only accept np.ndarray-type action!')
+ if elem.shape == (1, ):
+ actions = [v.item() for v in actions]
+
+ timestep = self._env_manager.step(actions)
+ timestep_collate_result = {}
+ for i in range(self.env_num):
+ if i in env_ids_given:
+ # Fix the compatibility of API for both gym>=0.24.0 and gym<0.24.0
+ # https://github.com/openai/gym/pull/2773
+ if gym.version.VERSION >= '0.24.0':
+ timestepinfo = {}
+ for k, v in timestep[3].items():
+ timestepinfo[k] = v[i]
+ timestep_collate_result[i] = BaseEnvTimestep(
+ timestep[0][i], timestep[1][i], timestep[2][i], timestepinfo
+ )
+ else:
+ timestep_collate_result[i] = BaseEnvTimestep(
+ timestep[0][i], timestep[1][i], timestep[2][i], timestep[3][i]
+ )
+ self._eval_episode_return[i] += timestep_collate_result[i].reward
+ if timestep_collate_result[i].done:
+ timestep_collate_result[i].info['eval_episode_return'] = self._eval_episode_return[i]
+ self._eval_episode_return[i] = 0
+ self._env_episode_count[i] += 1
+ if self._env_episode_count[i] >= self._episode_num:
+ self._env_states[i] = EnvState.DONE
+ else:
+ self._env_states[i] = EnvState.RESET
+ if all([self._env_states[i] == EnvState.RESET for i in range(self.env_num)]):
+ self.reset()
+ else:
+ self._ready_obs[i] = timestep_collate_result[i].obs
+
+ return timestep_collate_result
+
+ @property
+ def ready_obs(self) -> Dict[int, Any]:
+ return {
+ i: self._ready_obs[i]
+ for i in range(len(self._ready_obs)) if self._env_episode_count[i] < self._episode_num
+ }
+
+ def seed(self, seed: Union[Dict[int, int], List[int], int], dynamic_seed: bool = None) -> None:
+ self._env_manager.seed(seed)
+ # TODO dynamic_seed
+ logging.warning("gym env doesn't support dynamic_seed in different episode")
+
+ def close(self) -> None:
+ """
+ Overview:
+ Release the environment resources.
+ Since super().__init__ is not called, there is no need to release BaseEnvManager's resources.
+ """
+ if self._closed:
+ return
+ self._closed = True
+ self._env_ref.close()
+ self._env_manager.close()
+ self._env_manager.close_extras(terminate=True)
diff --git a/DI-engine/ding/envs/env_manager/subprocess_env_manager.py b/DI-engine/ding/envs/env_manager/subprocess_env_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..b30fe1039451f3cce7bcf40044b4ae81e40cf697
--- /dev/null
+++ b/DI-engine/ding/envs/env_manager/subprocess_env_manager.py
@@ -0,0 +1,834 @@
+from typing import Any, Union, List, Tuple, Dict, Callable, Optional
+from multiprocessing import connection, get_context
+from collections import namedtuple
+from ditk import logging
+import platform
+import time
+import copy
+import gymnasium
+import gym
+import traceback
+import torch
+import pickle
+import numpy as np
+import treetensor.numpy as tnp
+from easydict import EasyDict
+from types import MethodType
+from ding.data import ShmBufferContainer, ShmBuffer
+
+from ding.envs.env import BaseEnvTimestep
+from ding.utils import PropagatingThread, LockContextType, LockContext, ENV_MANAGER_REGISTRY, make_key_as_identifier, \
+ remove_illegal_item, CloudPickleWrapper
+from .base_env_manager import BaseEnvManager, EnvState, timeout_wrapper
+
+
+def is_abnormal_timestep(timestep: namedtuple) -> bool:
+ if isinstance(timestep.info, dict):
+ return timestep.info.get('abnormal', False)
+ elif isinstance(timestep.info, list) or isinstance(timestep.info, tuple):
+ return timestep.info[0].get('abnormal', False) or timestep.info[1].get('abnormal', False)
+ else:
+ raise TypeError("invalid env timestep type: {}".format(type(timestep.info)))
+
+
+@ENV_MANAGER_REGISTRY.register('async_subprocess')
+class AsyncSubprocessEnvManager(BaseEnvManager):
+ """
+ Overview:
+ Create an AsyncSubprocessEnvManager to manage multiple environments.
+ Each Environment is run by a respective subprocess.
+ Interfaces:
+ seed, launch, ready_obs, step, reset, active_env
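+ Example:
+ >>> # Illustrative sketch; ``my_env_fn`` is a placeholder for any callable that builds a
+ >>> # DI-engine env, and the default config is used as-is.
+ >>> env_manager = AsyncSubprocessEnvManager(
+ ... env_fn=[my_env_fn for _ in range(4)], cfg=AsyncSubprocessEnvManager.default_config())
+ >>> env_manager.seed([0, 1, 2, 3])
+ >>> env_manager.launch()
+ >>> obs = env_manager.ready_obs # {env_id: obs} for envs that are ready to step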
+ """
+
+ config = dict(
+ episode_num=float("inf"),
+ max_retry=1,
+ step_timeout=None,
+ auto_reset=True,
+ retry_type='reset',
+ reset_timeout=None,
+ retry_waiting_time=0.1,
+ # subprocess specified args
+ shared_memory=True,
+ copy_on_get=True,
+ context='spawn' if platform.system().lower() == 'windows' else 'fork',
+ wait_num=2,
+ step_wait_timeout=0.01,
+ connect_timeout=60,
+ reset_inplace=False,
+ )
+
+ def __init__(
+ self,
+ env_fn: List[Callable],
+ cfg: EasyDict = EasyDict({}),
+ ) -> None:
+ """
+ Overview:
+ Initialize the AsyncSubprocessEnvManager.
+ Arguments:
+ - env_fn (:obj:`List[Callable]`): The function to create environment
+ - cfg (:obj:`EasyDict`): Config
+
+ .. note::
+
+ - wait_num: for each step, the minimum number of env returns to gather before continuing
+ - step_wait_timeout: the timeout (in seconds) for gathering env returns in each step
+ """
+ super().__init__(env_fn, cfg)
+ self._shared_memory = self._cfg.shared_memory
+ self._copy_on_get = self._cfg.copy_on_get
+ self._context = self._cfg.context
+ self._wait_num = self._cfg.wait_num
+ self._step_wait_timeout = self._cfg.step_wait_timeout
+
+ self._lock = LockContext(LockContextType.THREAD_LOCK)
+ self._connect_timeout = self._cfg.connect_timeout
+ self._async_args = {
+ 'step': {
+ 'wait_num': min(self._wait_num, self._env_num),
+ 'timeout': self._step_wait_timeout
+ }
+ }
+ self._reset_inplace = self._cfg.reset_inplace
+ if not self._auto_reset:
+ assert not self._reset_inplace, "reset_inplace is unavailable when auto_reset=False."
+
+ def _create_state(self) -> None:
+ r"""
+ Overview:
+ Fork/spawn sub-processes(Call ``_create_env_subprocess``) and create pipes to transfer the data.
+ """
+ self._env_episode_count = {env_id: 0 for env_id in range(self.env_num)}
+ self._ready_obs = {env_id: None for env_id in range(self.env_num)}
+ self._reset_param = {i: {} for i in range(self.env_num)}
+ if self._shared_memory:
+ obs_space = self._observation_space
+ if isinstance(obs_space, (gym.spaces.Dict, gymnasium.spaces.Dict)):
+ # For multi_agent case, such as multiagent_mujoco and petting_zoo mpe.
+ # Now only for the case that each agent in the team have the same obs structure
+ # and corresponding shape.
+ shape = {k: v.shape for k, v in obs_space.spaces.items()}
+ dtype = {k: v.dtype for k, v in obs_space.spaces.items()}
+ else:
+ shape = obs_space.shape
+ dtype = obs_space.dtype
+ self._obs_buffers = {
+ env_id: ShmBufferContainer(dtype, shape, copy_on_get=self._copy_on_get)
+ for env_id in range(self.env_num)
+ }
+ else:
+ self._obs_buffers = {env_id: None for env_id in range(self.env_num)}
+ self._pipe_parents, self._pipe_children = {}, {}
+ self._subprocesses = {}
+ for env_id in range(self.env_num):
+ self._create_env_subprocess(env_id)
+ self._waiting_env = {'step': set()}
+ self._closed = False
+
+ def _create_env_subprocess(self, env_id):
+ # start a new one
+ ctx = get_context(self._context)
+ self._pipe_parents[env_id], self._pipe_children[env_id] = ctx.Pipe()
+ self._subprocesses[env_id] = ctx.Process(
+ # target=self.worker_fn,
+ target=self.worker_fn_robust,
+ args=(
+ self._pipe_parents[env_id],
+ self._pipe_children[env_id],
+ CloudPickleWrapper(self._env_fn[env_id]),
+ self._obs_buffers[env_id],
+ self.method_name_list,
+ self._reset_timeout,
+ self._step_timeout,
+ self._reset_inplace,
+ ),
+ daemon=True,
+ name='subprocess_env_manager{}_{}'.format(env_id, time.time())
+ )
+ self._subprocesses[env_id].start()
+ self._pipe_children[env_id].close()
+ self._env_states[env_id] = EnvState.INIT
+
+ if self._env_replay_path is not None:
+ self._pipe_parents[env_id].send(['enable_save_replay', [self._env_replay_path[env_id]], {}])
+ self._pipe_parents[env_id].recv()
+
+ @property
+ def ready_env(self) -> List[int]:
+ active_env = [i for i, s in self._env_states.items() if s == EnvState.RUN]
+ return [i for i in active_env if i not in self._waiting_env['step']]
+
+ @property
+ def ready_obs(self) -> Dict[int, Any]:
+ """
+ Overview:
+ Get the next observations.
+ Return:
+ A dictionary with observations and their environment IDs.
+ Note:
+ The observations are returned in np.ndarray.
+ Example:
+ >>> obs_dict = env_manager.ready_obs
+ >>> actions_dict = {env_id: model.forward(obs) for env_id, obs in obs_dict.items()}
+ """
+ no_done_env_idx = [i for i, s in self._env_states.items() if s != EnvState.DONE]
+ sleep_count = 0
+ while not any([self._env_states[i] == EnvState.RUN for i in no_done_env_idx]):
+ if sleep_count != 0 and sleep_count % 10000 == 0:
+ logging.warning(
+ 'VEC_ENV_MANAGER: all the not done envs are resetting, sleep {} times'.format(sleep_count)
+ )
+ time.sleep(0.001)
+ sleep_count += 1
+ return {i: self._ready_obs[i] for i in self.ready_env}
+
+ @property
+ def ready_imgs(self, render_mode: Optional[str] = 'rgb_array') -> Dict[int, Any]:
+ """
+ Overview:
+ Get the next rendered frames.
+ Return:
+ A dictionary with rendered frames and their environment IDs.
+ Note:
+ The rendered frames are returned in np.ndarray.
+ """
+ for i in self.ready_env:
+ self._pipe_parents[i].send(['render', None, {'render_mode': render_mode}])
+ data = {i: self._pipe_parents[i].recv() for i in self.ready_env}
+ self._check_data(data)
+ return data
+
+ def launch(self, reset_param: Optional[Dict] = None) -> None:
+ """
+ Overview:
+ Set up the environments and their parameters.
+ Arguments:
+ - reset_param (:obj:`Optional[Dict]`): Dict of reset parameters for each environment, key is the env_id, \
+ value is the corresponding reset parameters.
+ """
+ assert self._closed, "please first close the env manager"
+ if reset_param is not None:
+ assert len(reset_param) == len(self._env_fn)
+ self._create_state()
+ self.reset(reset_param)
+
+ def reset(self, reset_param: Optional[Dict] = None) -> None:
+ """
+ Overview:
+ Reset the environments with their parameters.
+ Arguments:
+ - reset_param (:obj:`List`): Dict of reset parameters for each environment, key is the env_id, \
+ value is the corresponding reset parameters.
+ """
+ self._check_closed()
+
+ if reset_param is None:
+ reset_env_list = [env_id for env_id in range(self._env_num)]
+ else:
+ reset_env_list = reset_param.keys()
+ for env_id in reset_param:
+ self._reset_param[env_id] = reset_param[env_id]
+
+ # clear previous info
+ for env_id in reset_env_list:
+ if env_id in self._waiting_env['step']:
+ self._pipe_parents[env_id].recv()
+ self._waiting_env['step'].remove(env_id)
+
+ sleep_count = 0
+ while any([self._env_states[i] == EnvState.RESET for i in reset_env_list]):
+ if sleep_count != 0 and sleep_count % 10000 == 0:
+ logging.warning(
+ 'VEC_ENV_MANAGER: not all the envs finish resetting, sleep {} times'.format(sleep_count)
+ )
+ time.sleep(0.001)
+ sleep_count += 1
+
+ # reset env
+ reset_thread_list = []
+ for i, env_id in enumerate(reset_env_list):
+ # set seed
+ if self._env_seed[env_id] is not None:
+ try:
+ if self._env_dynamic_seed is not None:
+ self._pipe_parents[env_id].send(['seed', [self._env_seed[env_id], self._env_dynamic_seed], {}])
+ else:
+ self._pipe_parents[env_id].send(['seed', [self._env_seed[env_id]], {}])
+ ret = self._pipe_parents[env_id].recv()
+ self._check_data({env_id: ret})
+ self._env_seed[env_id] = None # seed only use once
+ except BaseException as e:
+ logging.warning(
+ "subprocess reset set seed failed, ignore and continue... \n subprocess exception traceback: \n"
+ + traceback.format_exc()
+ )
+ self._env_states[env_id] = EnvState.RESET
+ reset_thread = PropagatingThread(target=self._reset, args=(env_id, ))
+ reset_thread.daemon = True
+ reset_thread_list.append(reset_thread)
+
+ for t in reset_thread_list:
+ t.start()
+ for t in reset_thread_list:
+ t.join()
+
+ def _reset(self, env_id: int) -> None:
+
+ def reset_fn():
+ if self._pipe_parents[env_id].poll():
+ recv_data = self._pipe_parents[env_id].recv()
+ raise RuntimeError("unread data left before sending to the pipe: {}".format(repr(recv_data)))
+ # if self._reset_param[env_id] is None, just reset specific env, not pass reset param
+ if self._reset_param[env_id] is not None:
+ assert isinstance(self._reset_param[env_id], dict), type(self._reset_param[env_id])
+ self._pipe_parents[env_id].send(['reset', [], self._reset_param[env_id]])
+ else:
+ self._pipe_parents[env_id].send(['reset', [], None])
+
+ if not self._pipe_parents[env_id].poll(self._connect_timeout):
+ raise ConnectionError("env reset connection timeout") # Leave it to try again
+
+ obs = self._pipe_parents[env_id].recv()
+ self._check_data({env_id: obs}, close=False)
+ if self._shared_memory:
+ obs = self._obs_buffers[env_id].get()
+ # it is necessary to add lock for the updates of env_state
+ with self._lock:
+ self._env_states[env_id] = EnvState.RUN
+ self._ready_obs[env_id] = obs
+
+ exceptions = []
+ for _ in range(self._max_retry):
+ try:
+ reset_fn()
+ return
+ except BaseException as e:
+ logging.info("subprocess exception traceback: \n" + traceback.format_exc())
+ if self._retry_type == 'renew' or isinstance(e, pickle.UnpicklingError):
+ self._pipe_parents[env_id].close()
+ if self._subprocesses[env_id].is_alive():
+ self._subprocesses[env_id].terminate()
+ self._create_env_subprocess(env_id)
+ exceptions.append(e)
+ time.sleep(self._retry_waiting_time)
+
+ logging.error("Env {} reset has exceeded max retries({})".format(env_id, self._max_retry))
+ runtime_error = RuntimeError(
+ "Env {} reset has exceeded max retries({}), and the latest exception is: {}".format(
+ env_id, self._max_retry, str(exceptions[-1])
+ )
+ )
+ runtime_error.__traceback__ = exceptions[-1].__traceback__
+ if self._closed: # exception caused by main thread closing parent_remote
+ return
+ else:
+ self.close()
+ raise runtime_error
+
+ def step(self, actions: Dict[int, Any]) -> Dict[int, namedtuple]:
+ """
+ Overview:
+ Step all environments. Reset an env if done.
+ Arguments:
+ - actions (:obj:`Dict[int, Any]`): {env_id: action}
+ Returns:
+ - timesteps (:obj:`Dict[int, namedtuple]`): {env_id: timestep}. Timestep is a \
+ ``BaseEnvTimestep`` tuple with observation, reward, done, env_info.
+ Example:
+ >>> actions_dict = {env_id: model.forward(obs) for env_id, obs in obs_dict.items()}
+ >>> timesteps = env_manager.step(actions_dict)
+ >>> for env_id, timestep in timesteps.items():
+ >>> pass
+
+ .. note::
+
+ - The env_id that appears in ``actions`` will also be returned in ``timesteps``.
+ - Each environment is run by a subprocess separately. Once an environment is done, it is reset immediately.
+ - Async subprocess env manager use ``connection.wait`` to poll.
+ """
+ self._check_closed()
+ env_ids = list(actions.keys())
+ assert all([self._env_states[env_id] == EnvState.RUN for env_id in env_ids]
+ ), 'current env state are: {}, please check whether the requested env is in reset or done'.format(
+ {env_id: self._env_states[env_id]
+ for env_id in env_ids}
+ )
+
+ for env_id, act in actions.items():
+ self._pipe_parents[env_id].send(['step', [act], None])
+
+ timesteps = {}
+ step_args = self._async_args['step']
+ wait_num, timeout = min(step_args['wait_num'], len(env_ids)), step_args['timeout']
+ rest_env_ids = list(set(env_ids).union(self._waiting_env['step']))
+ ready_env_ids = []
+ cur_rest_env_ids = copy.deepcopy(rest_env_ids)
+ while True:
+ rest_conn = [self._pipe_parents[env_id] for env_id in cur_rest_env_ids]
+ ready_conn, ready_ids = AsyncSubprocessEnvManager.wait(rest_conn, min(wait_num, len(rest_conn)), timeout)
+ cur_ready_env_ids = [cur_rest_env_ids[env_id] for env_id in ready_ids]
+ assert len(cur_ready_env_ids) == len(ready_conn)
+ # timesteps.update({env_id: p.recv() for env_id, p in zip(cur_ready_env_ids, ready_conn)})
+ for env_id, p in zip(cur_ready_env_ids, ready_conn):
+ try:
+ timesteps.update({env_id: p.recv()})
+ except pickle.UnpicklingError as e:
+ timestep = BaseEnvTimestep(None, None, None, {'abnormal': True})
+ timesteps.update({env_id: timestep})
+ self._pipe_parents[env_id].close()
+ if self._subprocesses[env_id].is_alive():
+ self._subprocesses[env_id].terminate()
+ self._create_env_subprocess(env_id)
+ self._check_data(timesteps)
+ ready_env_ids += cur_ready_env_ids
+ cur_rest_env_ids = list(set(cur_rest_env_ids).difference(set(cur_ready_env_ids)))
+ # At least one not done env timestep, or all envs' steps are finished
+ if any([not t.done for t in timesteps.values()]) or len(ready_conn) == len(rest_conn):
+ break
+ self._waiting_env['step']: set
+ for env_id in rest_env_ids:
+ if env_id in ready_env_ids:
+ if env_id in self._waiting_env['step']:
+ self._waiting_env['step'].remove(env_id)
+ else:
+ self._waiting_env['step'].add(env_id)
+
+ if self._shared_memory:
+ for i, (env_id, timestep) in enumerate(timesteps.items()):
+ timesteps[env_id] = timestep._replace(obs=self._obs_buffers[env_id].get())
+
+ for env_id, timestep in timesteps.items():
+ if is_abnormal_timestep(timestep):
+ self._env_states[env_id] = EnvState.ERROR
+ continue
+ if timestep.done:
+ self._env_episode_count[env_id] += 1
+ if self._env_episode_count[env_id] < self._episode_num:
+ if self._auto_reset:
+ if self._reset_inplace: # reset in subprocess at once
+ self._env_states[env_id] = EnvState.RUN
+ self._ready_obs[env_id] = timestep.obs
+ else:
+ # in this case, ready_obs is updated in ``self._reset``
+ self._env_states[env_id] = EnvState.RESET
+ reset_thread = PropagatingThread(target=self._reset, args=(env_id, ), name='regular_reset')
+ reset_thread.daemon = True
+ reset_thread.start()
+ else:
+ # in the case that auto_reset=False, caller should call ``env_manager.reset`` manually
+ self._env_states[env_id] = EnvState.NEED_RESET
+ else:
+ self._env_states[env_id] = EnvState.DONE
+ else:
+ self._ready_obs[env_id] = timestep.obs
+ return timesteps
+
+ # This method must be a staticmethod, otherwise there will be some resource conflicts (e.g. port or file)
+ # Env must be created in worker, which is a trick of avoiding env pickle errors.
+ # A more robust version is used by default. But this one is also preserved.
+ @staticmethod
+ def worker_fn(
+ p: connection.Connection,
+ c: connection.Connection,
+ env_fn_wrapper: 'CloudPickleWrapper',
+ obs_buffer: ShmBuffer,
+ method_name_list: list,
+ reset_inplace: bool = False,
+ ) -> None: # noqa
+ """
+ Overview:
+ Subprocess's target function to run.
+ """
+ torch.set_num_threads(1)
+ env_fn = env_fn_wrapper.data
+ env = env_fn()
+ p.close()
+ try:
+ while True:
+ try:
+ cmd, args, kwargs = c.recv()
+ except EOFError: # for the case when the pipe has been closed
+ c.close()
+ break
+ try:
+ if cmd == 'getattr':
+ ret = getattr(env, args[0])
+ elif cmd in method_name_list:
+ if cmd == 'step':
+ timestep = env.step(*args, **kwargs)
+ if is_abnormal_timestep(timestep):
+ ret = timestep
+ else:
+ if reset_inplace and timestep.done:
+ obs = env.reset()
+ timestep = timestep._replace(obs=obs)
+ if obs_buffer is not None:
+ obs_buffer.fill(timestep.obs)
+ timestep = timestep._replace(obs=None)
+ ret = timestep
+ elif cmd == 'reset':
+ ret = env.reset(*args, **kwargs) # obs
+ if obs_buffer is not None:
+ obs_buffer.fill(ret)
+ ret = None
+ elif args is None and kwargs is None:
+ ret = getattr(env, cmd)()
+ else:
+ ret = getattr(env, cmd)(*args, **kwargs)
+ else:
+ raise KeyError("not support env cmd: {}".format(cmd))
+ c.send(ret)
+ except Exception as e:
+ # when there are some errors in env, worker_fn will send the errors to env manager
+ # directly sending the error to another process would lose the stack trace, so we create a new Exception
+ logging.warning("subprocess exception traceback: \n" + traceback.format_exc())
+ c.send(
+ e.__class__(
+ '\nEnv Process Exception:\n' + ''.join(traceback.format_tb(e.__traceback__)) + repr(e)
+ )
+ )
+ if cmd == 'close':
+ c.close()
+ break
+ except KeyboardInterrupt:
+ c.close()
+
+ @staticmethod
+ def worker_fn_robust(
+ parent,
+ child,
+ env_fn_wrapper,
+ obs_buffer,
+ method_name_list,
+ reset_timeout=None,
+ step_timeout=None,
+ reset_inplace=False,
+ ) -> None:
+ """
+ Overview:
+ A more robust version of subprocess's target function to run. Used by default.
+ """
+ torch.set_num_threads(1)
+ env_fn = env_fn_wrapper.data
+ env = env_fn()
+ parent.close()
+
+ @timeout_wrapper(timeout=step_timeout)
+ def step_fn(*args, **kwargs):
+ timestep = env.step(*args, **kwargs)
+ if is_abnormal_timestep(timestep):
+ ret = timestep
+ else:
+ if reset_inplace and timestep.done:
+ obs = env.reset()
+ timestep = timestep._replace(obs=obs)
+ if obs_buffer is not None:
+ obs_buffer.fill(timestep.obs)
+ timestep = timestep._replace(obs=None)
+ ret = timestep
+ return ret
+
+ @timeout_wrapper(timeout=reset_timeout)
+ def reset_fn(*args, **kwargs):
+ try:
+ ret = env.reset(*args, **kwargs)
+ if obs_buffer is not None:
+ obs_buffer.fill(ret)
+ ret = None
+ return ret
+ except BaseException as e:
+ logging.warning("subprocess exception traceback: \n" + traceback.format_exc())
+ env.close()
+ raise e
+
+ while True:
+ try:
+ cmd, args, kwargs = child.recv()
+ except EOFError: # for the case when the pipe has been closed
+ child.close()
+ break
+ try:
+ if cmd == 'getattr':
+ ret = getattr(env, args[0])
+ elif cmd in method_name_list:
+ if cmd == 'step':
+ ret = step_fn(*args)
+ elif cmd == 'reset':
+ if kwargs is None:
+ kwargs = {}
+ ret = reset_fn(*args, **kwargs)
+ elif cmd == 'render':
+ from ding.utils import render
+ ret = render(env, **kwargs)
+ elif args is None and kwargs is None:
+ ret = getattr(env, cmd)()
+ else:
+ ret = getattr(env, cmd)(*args, **kwargs)
+ else:
+ raise KeyError("not support env cmd: {}".format(cmd))
+ child.send(ret)
+ except BaseException as e:
+ logging.debug("Sub env '{}' error when executing {}".format(str(env), cmd))
+ # when there are some errors in env, worker_fn will send the errors to env manager
+ # directly sending the error to another process would lose the stack trace, so we create a new Exception
+ logging.warning("subprocess exception traceback: \n" + traceback.format_exc())
+ child.send(
+ e.__class__('\nEnv Process Exception:\n' + ''.join(traceback.format_tb(e.__traceback__)) + repr(e))
+ )
+ if cmd == 'close':
+ child.close()
+ break
+
+ def _check_data(self, data: Dict, close: bool = True) -> None:
+ exceptions = []
+ for i, d in data.items():
+ if isinstance(d, BaseException):
+ self._env_states[i] = EnvState.ERROR
+ exceptions.append(d)
+ # when receiving env Exception, env manager will safely close and raise this Exception to caller
+ if len(exceptions) > 0:
+ if close:
+ self.close()
+ raise exceptions[0]
+
+ # override
+ def __getattr__(self, key: str) -> Any:
+ self._check_closed()
+ # we suppose that all the envs have the same attributes; if you need different envs, please
+ # create different env managers.
+ if not hasattr(self._env_ref, key):
+ raise AttributeError("env `{}` doesn't have the attribute `{}`".format(type(self._env_ref), key))
+ if isinstance(getattr(self._env_ref, key), MethodType) and key not in self.method_name_list:
+ raise RuntimeError("env getattr doesn't supports method({}), please override method_name_list".format(key))
+ for _, p in self._pipe_parents.items():
+ p.send(['getattr', [key], {}])
+ data = {i: p.recv() for i, p in self._pipe_parents.items()}
+ self._check_data(data)
+ ret = [data[i] for i in self._pipe_parents.keys()]
+ return ret
+
+ # override
+ def enable_save_replay(self, replay_path: Union[List[str], str]) -> None:
+ """
+ Overview:
+ Set each env's replay save path.
+ Arguments:
+ - replay_path (:obj:`Union[List[str], str]`): List of paths for each environment; \
+ Or one path for all environments.
+ """
+ if isinstance(replay_path, str):
+ replay_path = [replay_path] * self.env_num
+ self._env_replay_path = replay_path
+
+ # override
+ def close(self) -> None:
+ """
+ Overview:
+ Close the env manager and release all related resources.
+ """
+ if self._closed:
+ return
+ self._closed = True
+ for _, p in self._pipe_parents.items():
+ p.send(['close', None, None])
+ for env_id, p in self._pipe_parents.items():
+ if not p.poll(5):
+ continue
+ p.recv()
+ for i in range(self._env_num):
+ self._env_states[i] = EnvState.VOID
+ # disable process join for avoiding hang
+ # for p in self._subprocesses:
+ # p.join()
+ for _, p in self._subprocesses.items():
+ p.terminate()
+ for _, p in self._pipe_parents.items():
+ p.close()
+
+ @staticmethod
+ def wait(rest_conn: list, wait_num: int, timeout: Optional[float] = None) -> Tuple[list, list]:
+ """
+ Overview:
+ Wait until at least ``wait_num`` connections (len(ready_conn) >= wait_num) are ready, subject to the timeout constraint.
+ If timeout is None and wait_num == len(rest_conn), this behaves as sync mode;
+ If timeout is not None, this returns once len(ready_conn) >= wait_num and
+ the method has spent more than timeout seconds.
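+ Example:
+ >>> # Illustrative sketch: collect pipe connections until at least 2 are ready
+ >>> # and roughly ``timeout`` seconds have elapsed.
+ >>> ready_conn, ready_ids = AsyncSubprocessEnvManager.wait(rest_conn, 2, timeout=0.01)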
+ """
+ assert 1 <= wait_num <= len(rest_conn), 'please indicate a proper wait_num: {}/{}'.format(wait_num, len(rest_conn))
+ rest_conn_set = set(rest_conn)
+ ready_conn = set()
+ start_time = time.time()
+ while len(rest_conn_set) > 0:
+ if len(ready_conn) >= wait_num and timeout:
+ if (time.time() - start_time) >= timeout:
+ break
+ finish_conn = set(connection.wait(rest_conn_set, timeout=timeout))
+ ready_conn = ready_conn.union(finish_conn)
+ rest_conn_set = rest_conn_set.difference(finish_conn)
+ ready_ids = [rest_conn.index(c) for c in ready_conn]
+ return list(ready_conn), ready_ids
+
+
+@ENV_MANAGER_REGISTRY.register('subprocess')
+class SyncSubprocessEnvManager(AsyncSubprocessEnvManager):
+ config = dict(
+ episode_num=float("inf"),
+ max_retry=1,
+ step_timeout=None,
+ auto_reset=True,
+ reset_timeout=None,
+ retry_type='reset',
+ retry_waiting_time=0.1,
+ # subprocess specified args
+ shared_memory=True,
+ copy_on_get=True,
+ context='spawn' if platform.system().lower() == 'windows' else 'fork',
+ wait_num=float("inf"), # inf mean all the environments
+ step_wait_timeout=None,
+ connect_timeout=60,
+ reset_inplace=False, # if reset_inplace=True in SyncSubprocessEnvManager, the interaction can be reproducible.
+ )
+
+ def step(self, actions: Dict[int, Any]) -> Dict[int, namedtuple]:
+ """
+ Overview:
+ Step all environments. Reset an env if done.
+ Arguments:
+ - actions (:obj:`Dict[int, Any]`): {env_id: action}
+ Returns:
+ - timesteps (:obj:`Dict[int, namedtuple]`): {env_id: timestep}. Timestep is a \
+ ``BaseEnvTimestep`` tuple with observation, reward, done, env_info.
+ Example:
+ >>> actions_dict = {env_id: model.forward(obs) for env_id, obs in obs_dict.items()}
+ >>> timesteps = env_manager.step(actions_dict)
+ >>> for env_id, timestep in timesteps.items():
+ >>> pass
+
+ .. note::
+
+ - The env_id that appears in ``actions`` will also be returned in ``timesteps``.
+ - Each environment is run by a subprocess separately. Once an environment is done, it is reset immediately.
+ """
+ self._check_closed()
+ env_ids = list(actions.keys())
+ assert all([self._env_states[env_id] == EnvState.RUN for env_id in env_ids]
+ ), 'current env state are: {}, please check whether the requested env is in reset or done'.format(
+ {env_id: self._env_states[env_id]
+ for env_id in env_ids}
+ )
+ for env_id, act in actions.items():
+ # it is necessary to set kwargs to None to save the serialization cost for some envs like cartpole,
+ # and the step method never uses kwargs in known envs.
+ self._pipe_parents[env_id].send(['step', [act], None])
+
+ # === This part is different from async one. ===
+ # === Because operating in this way is more efficient. ===
+ timesteps = {}
+ ready_conn = [self._pipe_parents[env_id] for env_id in env_ids]
+ # timesteps.update({env_id: p.recv() for env_id, p in zip(env_ids, ready_conn)})
+ for env_id, p in zip(env_ids, ready_conn):
+ try:
+ timesteps.update({env_id: p.recv()})
+ except pickle.UnpicklingError as e:
+ timestep = BaseEnvTimestep(None, None, None, {'abnormal': True})
+ timesteps.update({env_id: timestep})
+ self._pipe_parents[env_id].close()
+ if self._subprocesses[env_id].is_alive():
+ self._subprocesses[env_id].terminate()
+ self._create_env_subprocess(env_id)
+ self._check_data(timesteps)
+ # ======================================================
+
+ if self._shared_memory:
+ # TODO(nyz) optimize sync shm
+ for i, (env_id, timestep) in enumerate(timesteps.items()):
+ timesteps[env_id] = timestep._replace(obs=self._obs_buffers[env_id].get())
+ for env_id, timestep in timesteps.items():
+ if is_abnormal_timestep(timestep):
+ self._env_states[env_id] = EnvState.ERROR
+ continue
+ if timestep.done:
+ self._env_episode_count[env_id] += 1
+ if self._env_episode_count[env_id] < self._episode_num:
+ if self._auto_reset:
+ if self._reset_inplace: # reset in subprocess at once
+ self._env_states[env_id] = EnvState.RUN
+ self._ready_obs[env_id] = timestep.obs
+ else:
+ # in this case, ready_obs is updated in ``self._reset``
+ self._env_states[env_id] = EnvState.RESET
+ reset_thread = PropagatingThread(target=self._reset, args=(env_id, ), name='regular_reset')
+ reset_thread.daemon = True
+ reset_thread.start()
+ else:
+ # in the case that auto_reset=False, caller should call ``env_manager.reset`` manually
+ self._env_states[env_id] = EnvState.NEED_RESET
+ else:
+ self._env_states[env_id] = EnvState.DONE
+ else:
+ self._ready_obs[env_id] = timestep.obs
+ return timesteps
+
+
+@ENV_MANAGER_REGISTRY.register('subprocess_v2')
+class SubprocessEnvManagerV2(SyncSubprocessEnvManager):
+ """
+ Overview:
+ SyncSubprocessEnvManager for new task pipeline and interfaces coupled with treetensor.
+ """
+
+ @property
+ def ready_obs(self) -> tnp.array:
+ """
+ Overview:
+ Get the ready (next) observation in ``tnp.array`` type, which is uniform for both async/sync scenarios.
+ Return:
+ - ready_obs (:obj:`tnp.array`): A stacked treenumpy-type observation data.
+ Example:
+ >>> obs = env_manager.ready_obs
+ >>> action = model(obs) # model input np obs and output np action
+ >>> timesteps = env_manager.step(action)
+ """
+ no_done_env_idx = [i for i, s in self._env_states.items() if s != EnvState.DONE]
+ sleep_count = 0
+ while not any([self._env_states[i] == EnvState.RUN for i in no_done_env_idx]):
+ if sleep_count != 0 and sleep_count % 10000 == 0:
+ logging.warning(
+ 'VEC_ENV_MANAGER: all the not done envs are resetting, sleep {} times'.format(sleep_count)
+ )
+ time.sleep(0.001)
+ sleep_count += 1
+ return tnp.stack([tnp.array(self._ready_obs[i]) for i in self.ready_env])
+
+ def step(self, actions: Union[List[tnp.ndarray], tnp.ndarray]) -> List[tnp.ndarray]:
+ """
+ Overview:
+ Execute env step according to input actions. And reset an env if done.
+ Arguments:
+ - actions (:obj:`Union[List[tnp.ndarray], tnp.ndarray]`): Actions coming from an outer caller like the policy.
+ Returns:
+ - timesteps (:obj:`List[tnp.ndarray]`): Each timestep is a tnp.array with observation, reward, done, \
+ info, env_id.
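+ Example:
+ >>> # Illustrative sketch: actions are matched to running envs in ``ready_obs_id`` order;
+ >>> # ``policy`` is a placeholder that maps stacked obs to a batch of np actions.
+ >>> obs = env_manager.ready_obs
+ >>> actions = policy(obs)
+ >>> timesteps = env_manager.step(actions)
+ >>> done_env_ids = [t.env_id for t in timesteps if t.done]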
+ """
+ if isinstance(actions, tnp.ndarray):
+ # zip operation will lead to wrong behaviour if not split data
+ split_action = tnp.split(actions, actions.shape[0])
+ split_action = [s.squeeze(0) for s in split_action]
+ else:
+ split_action = actions
+ actions = {env_id: a for env_id, a in zip(self.ready_obs_id, split_action)}
+ timesteps = super().step(actions)
+ new_data = []
+ for env_id, timestep in timesteps.items():
+ obs, reward, done, info = timestep
+ # make the type and content of key as similar as identifier,
+ # in order to call them as attribute (e.g. timestep.xxx), such as ``TimeLimit.truncated`` in cartpole info
+ info = make_key_as_identifier(info)
+ info = remove_illegal_item(info)
+ new_data.append(tnp.array({'obs': obs, 'reward': reward, 'done': done, 'info': info, 'env_id': env_id}))
+ return new_data
diff --git a/DI-engine/ding/envs/env_manager/tests/__init__.py b/DI-engine/ding/envs/env_manager/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/ding/envs/env_manager/tests/conftest.py b/DI-engine/ding/envs/env_manager/tests/conftest.py
new file mode 100644
index 0000000000000000000000000000000000000000..f824899a0de331a13d3f7c1bfdffc0ca5506805b
--- /dev/null
+++ b/DI-engine/ding/envs/env_manager/tests/conftest.py
@@ -0,0 +1,254 @@
+import random
+import time
+from collections import namedtuple
+import pytest
+import torch
+import numpy as np
+from easydict import EasyDict
+from functools import partial
+import gym
+
+from ding.envs.env.base_env import BaseEnvTimestep
+from ding.envs.env_manager.base_env_manager import EnvState
+from ding.envs.env_manager import BaseEnvManager, SyncSubprocessEnvManager, AsyncSubprocessEnvManager
+from ding.torch_utils import to_tensor, to_ndarray, to_list
+from ding.utils import deep_merge_dicts
+
+
+class FakeEnv(object):
+
+ def __init__(self, cfg):
+ self._scale = cfg.scale
+ self._target_time = random.randint(3, 6) * self._scale
+ self._current_time = 0
+ self._name = cfg['name']
+ self._id = time.time()
+ self._stat = None
+ self._seed = 0
+ self._data_count = 0
+ self.timeout_flag = False
+ self._launched = False
+ self._state = EnvState.INIT
+ self._dead_once = False
+ self.observation_space = gym.spaces.Box(
+ low=np.array([-1.0, -1.0, -8.0]), high=np.array([1.0, 1.0, 8.0]), shape=(3, ), dtype=np.float32
+ )
+ self.action_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(1, ), dtype=np.float32)
+ self.reward_space = gym.spaces.Box(
+ low=-1 * (3.14 * 3.14 + 0.1 * 8 * 8 + 0.001 * 2 * 2), high=0.0, shape=(1, ), dtype=np.float32
+ )
+
+ def reset(self, stat=None):
+ if isinstance(stat, str) and stat == 'error':
+ self.dead()
+ if isinstance(stat, str) and stat == 'error_once':
+ # Die on every two reset with error_once stat.
+ if self._dead_once:
+ self._dead_once = False
+ self.dead()
+ else:
+ self._dead_once = True
+ if isinstance(stat, str) and stat == "wait":
+ if self.timeout_flag: # after step(), the reset can stall with status of timeout
+ time.sleep(5)
+ if isinstance(stat, str) and stat == "block":
+ self.block()
+
+ self._launched = True
+ self._current_time = 0
+ self._stat = stat
+ self._state = EnvState.RUN
+ return to_ndarray(torch.randn(3))
+
+ def step(self, action):
+ assert self._launched
+ assert not self._state == EnvState.ERROR
+ self.timeout_flag = True # after one step, enable timeout flag
+ if isinstance(action, str) and action == 'error':
+ self.dead()
+ if isinstance(action, str) and action == 'catched_error':
+ return BaseEnvTimestep(None, None, True, {'abnormal': True})
+ if isinstance(action, str) and action == "wait":
+ if self.timeout_flag: # after step(), the reset can stall with status of timeout
+ time.sleep(3)
+ if isinstance(action, str) and action == 'block':
+ self.block()
+ obs = to_ndarray(torch.randn(3))
+ reward = to_ndarray(torch.randint(0, 2, size=[1]).numpy())
+ done = self._current_time >= self._target_time
+ if done:
+ self._state = EnvState.DONE
+ simulation_time = random.uniform(0.5, 1) * self._scale
+ info = {'name': self._name, 'time': simulation_time, 'tgt': self._target_time, 'cur': self._current_time}
+ time.sleep(simulation_time)
+ self._current_time += simulation_time
+ self._data_count += 1
+ return BaseEnvTimestep(obs, reward, done, info)
+
+ def dead(self):
+ self._state = EnvState.ERROR
+ raise RuntimeError("env error, current time {}".format(self._current_time))
+
+ def block(self):
+ self._state = EnvState.ERROR
+ time.sleep(1000)
+
+ def close(self):
+ self._launched = False
+ self._state = EnvState.INIT
+
+ def seed(self, seed):
+ self._seed = seed
+
+ @property
+ def name(self):
+ return self._name
+
+ @property
+ def time_id(self):
+ return self._id
+
+ def user_defined(self):
+ pass
+
+ def __repr__(self):
+ return self._name
+
+
+class FakeAsyncEnv(FakeEnv):
+
+ def reset(self, stat=None):
+ super().reset(stat)
+ time.sleep(random.randint(1, 3) * self._scale)
+ return to_ndarray(torch.randn(3))
+
+
+class FakeGymEnv(FakeEnv):
+
+ def __init__(self, cfg):
+ super().__init__(cfg)
+ self.metadata = "fake metadata"
+ self.action_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(4, ), dtype=np.float32)
+
+ def random_action(self) -> np.ndarray:
+ random_action = self.action_space.sample()
+ if isinstance(random_action, np.ndarray):
+ pass
+ elif isinstance(random_action, int):
+ random_action = to_ndarray([random_action], dtype=np.int64)
+ elif isinstance(random_action, dict):
+ random_action = to_ndarray(random_action)
+ else:
+ raise TypeError(
+ '`random_action` should be either int/np.ndarray or dict of int/np.ndarray, but got {}: {}'.format(
+ type(random_action), random_action
+ )
+ )
+ return random_action
+
+
+class FakeModel(object):
+
+ def forward(self, obs):
+ if random.random() > 0.5:
+ return {k: [] for k in obs}
+ else:
+ env_num = len(obs)
+ exec_env = random.randint(1, env_num + 1)
+ keys = list(obs.keys())[:exec_env]
+ return {k: [] for k in keys}
+
+
+@pytest.fixture(scope='class')
+def setup_model_type():
+ return FakeModel
+
+
+def get_base_manager_cfg(env_num=3):
+ manager_cfg = {
+ 'env_cfg': [{
+ 'name': 'name{}'.format(i),
+ 'scale': 1.0,
+ } for i in range(env_num)],
+ 'episode_num': 2,
+ 'reset_timeout': 10,
+ 'step_timeout': 8,
+ 'max_retry': 5,
+ }
+ return EasyDict(manager_cfg)
+
+
+def get_subprecess_manager_cfg(env_num=3):
+ manager_cfg = {
+ 'env_cfg': [{
+ 'name': 'name{}'.format(i),
+ 'scale': 1.0,
+ } for i in range(env_num)],
+ 'episode_num': 2,
+ #'step_timeout': 8,
+ #'reset_timeout': 10,
+ 'connect_timeout': 8,
+ 'step_timeout': 5,
+ 'max_retry': 2,
+ }
+ return EasyDict(manager_cfg)
+
+
+def get_gym_vector_manager_cfg(env_num=3):
+ manager_cfg = {
+ 'env_cfg': [{
+ 'name': 'name{}'.format(i),
+ } for i in range(env_num)],
+ 'episode_num': 2,
+ 'connect_timeout': 8,
+ 'step_timeout': 5,
+ 'max_retry': 2,
+ 'share_memory': True
+ }
+ return EasyDict(manager_cfg)
+
+
+@pytest.fixture(scope='function')
+def setup_base_manager_cfg():
+ manager_cfg = get_base_manager_cfg(3)
+ env_cfg = manager_cfg.pop('env_cfg')
+ manager_cfg['env_fn'] = [partial(FakeEnv, cfg=c) for c in env_cfg]
+ return deep_merge_dicts(BaseEnvManager.default_config(), EasyDict(manager_cfg))
+
+
+@pytest.fixture(scope='function')
+def setup_fast_base_manager_cfg():
+ manager_cfg = get_base_manager_cfg(3)
+ env_cfg = manager_cfg.pop('env_cfg')
+ for e in env_cfg:
+ e['scale'] = 0.1
+ manager_cfg['env_fn'] = [partial(FakeEnv, cfg=c) for c in env_cfg]
+ return deep_merge_dicts(BaseEnvManager.default_config(), EasyDict(manager_cfg))
+
+
+@pytest.fixture(scope='function')
+def setup_sync_manager_cfg():
+ manager_cfg = get_subprecess_manager_cfg(3)
+ env_cfg = manager_cfg.pop('env_cfg')
+ # TODO(nyz) test fails when shared_memory = True
+ manager_cfg['shared_memory'] = False
+ manager_cfg['env_fn'] = [partial(FakeEnv, cfg=c) for c in env_cfg]
+ return deep_merge_dicts(SyncSubprocessEnvManager.default_config(), EasyDict(manager_cfg))
+
+
+@pytest.fixture(scope='function')
+def setup_async_manager_cfg():
+ manager_cfg = get_subprecess_manager_cfg(3)
+ env_cfg = manager_cfg.pop('env_cfg')
+ manager_cfg['env_fn'] = [partial(FakeAsyncEnv, cfg=c) for c in env_cfg]
+ manager_cfg['shared_memory'] = False
+ return deep_merge_dicts(AsyncSubprocessEnvManager.default_config(), EasyDict(manager_cfg))
+
+
+@pytest.fixture(scope='function')
+def setup_gym_vector_manager_cfg():
+ manager_cfg = get_subprecess_manager_cfg(3)
+ env_cfg = manager_cfg.pop('env_cfg')
+ manager_cfg['env_fn'] = [partial(FakeGymEnv, cfg=c) for c in env_cfg]
+ manager_cfg['shared_memory'] = False
+ return EasyDict(manager_cfg)
diff --git a/DI-engine/ding/envs/env_manager/tests/test_base_env_manager.py b/DI-engine/ding/envs/env_manager/tests/test_base_env_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..973a0a1291dbe61f6e445adf0650c3241887a8c3
--- /dev/null
+++ b/DI-engine/ding/envs/env_manager/tests/test_base_env_manager.py
@@ -0,0 +1,202 @@
+import time
+import signal
+import pytest
+import torch
+import numpy as np
+
+from ..base_env_manager import BaseEnvManagerV2, EnvState
+
+
+@pytest.mark.unittest
+class TestBaseEnvManagerV2:
+
+ def test_naive(self, setup_base_manager_cfg):
+ env_fn = setup_base_manager_cfg.pop('env_fn')
+ env_manager = BaseEnvManagerV2(env_fn, setup_base_manager_cfg)
+ env_manager.seed([314 for _ in range(env_manager.env_num)])
+ assert env_manager._closed
+ obs = env_manager.launch(reset_param={i: {'stat': 'stat_test'} for i in range(env_manager.env_num)})
+ assert all([env_manager._env_states[env_id] == EnvState.RUN for env_id in range(env_manager.env_num)])
+ # Test basic
+ name = env_manager._name
+ assert len(name) == env_manager.env_num
+ assert all([isinstance(n, str) for n in name])
+ assert env_manager._max_retry == 5
+ assert env_manager._reset_timeout == 10
+ assert all([s == 314 for s in env_manager._seed])
+ assert all([s == 'stat_test' for s in env_manager._stat])
+ # Test attribute access
+ with pytest.raises(AttributeError):
+ _ = env_manager.xxx
+ with pytest.raises(RuntimeError):
+ env_manager.user_defined()
+ # Test step
+ count = 1
+ start_time = time.time()
+ while not env_manager.done:
+ env_id = env_manager.ready_obs_id
+ action = {i: np.random.randn(4) for i in env_id}
+ timestep = env_manager.step(action)
+ assert len(timestep) == len(env_id)
+ print('Count {}'.format(count))
+ print([v.info for v in timestep])
+ print([v.done for v in timestep])
+ count += 1
+ end_time = time.time()
+ print('total step time: {}'.format(end_time - start_time))
+ assert all([env_manager._env_states[env_id] == EnvState.DONE for env_id in range(env_manager.env_num)])
+ assert all([c == setup_base_manager_cfg.episode_num for c in env_manager._env_episode_count.values()])
+ # Test close
+ env_manager.close()
+ assert env_manager._closed
+ assert all([not env_manager._envs[env_id]._launched for env_id in range(env_manager.env_num)])
+ assert all([env_manager._env_states[env_id] == EnvState.VOID for env_id in range(env_manager.env_num)])
+ with pytest.raises(AssertionError):
+ env_manager.reset([])
+ with pytest.raises(AssertionError):
+ env_manager.step([])
+
+ def test_error(self, setup_base_manager_cfg):
+ env_fn = setup_base_manager_cfg.pop('env_fn')
+ env_manager = BaseEnvManagerV2(env_fn, setup_base_manager_cfg)
+ # Test reset error
+ with pytest.raises(RuntimeError):
+ reset_param = {i: {'stat': 'error'} for i in range(env_manager.env_num)}
+ obs = env_manager.launch(reset_param=reset_param)
+ assert env_manager._closed
+ reset_param = {i: {'stat': 'stat_test'} for i in range(env_manager.env_num)}
+ obs = env_manager.launch(reset_param=reset_param)
+ assert not env_manager._closed
+
+ timestep = env_manager.step({i: np.random.randn(4) for i in range(env_manager.env_num)})
+ assert len(timestep) == env_manager.env_num
+ # Test reset error once
+ reset_param = {i: {'stat': 'stat_test'} for i in range(env_manager.env_num)}
+ assert env_manager._retry_type == 'reset'
+ env_id_0 = env_manager.time_id[0]
+ reset_param[0] = {'stat': 'error_once'}
+ env_manager.reset(reset_param)
+ env_manager.reset(reset_param)
+ assert not env_manager._closed
+ assert env_manager.time_id[0] == env_id_0
+ env_manager._retry_type = 'renew'
+ env_id_0 = env_manager.time_id[0]
+ reset_param[0] = {'stat': 'error_once'}
+ env_manager.reset(reset_param)
+ assert not env_manager._closed
+ assert env_manager.time_id[0] != env_id_0
+
+ # Test step with a caught error
+ action = [np.random.randn(4) for i in range(env_manager.env_num)]
+ action[0] = 'catched_error'
+ timestep = env_manager.step(action)
+ assert timestep[0].info.abnormal
+ assert all(['abnormal' not in timestep[i].info for i in range(1, env_manager.env_num)])
+ assert all([env_manager._env_states[i] == EnvState.RUN for i in range(env_manager.env_num)])
+ assert len(env_manager.ready_obs) == 3
+ timestep = env_manager.step({i: np.random.randn(4) for i in range(env_manager.env_num)})
+ # Test step error
+ action[0] = 'error'
+ with pytest.raises(RuntimeError):
+ timestep = env_manager.step(action)
+ assert env_manager._env_states[0] == EnvState.ERROR
+ assert all([env_manager._env_states[i] == EnvState.RUN for i in range(1, env_manager.env_num)])
+ obs = env_manager.reset(reset_param)
+ assert all([env_manager._env_states[i] == EnvState.RUN for i in range(env_manager.env_num)])
+ assert len(env_manager.ready_obs) == 3
+ timestep = env_manager.step({i: np.random.randn(4) for i in range(env_manager.env_num)})
+
+ env_manager.close()
+
+ @pytest.mark.timeout(60)
+ def test_block(self, setup_base_manager_cfg):
+ env_fn = setup_base_manager_cfg.pop('env_fn')
+ setup_base_manager_cfg['max_retry'] = 1
+ env_manager = BaseEnvManagerV2(env_fn, setup_base_manager_cfg)
+ assert env_manager._max_retry == 1
+ # Test reset timeout
+ with pytest.raises(RuntimeError):
+ reset_param = {i: {'stat': 'block'} for i in range(env_manager.env_num)}
+ obs = env_manager.launch(reset_param=reset_param)
+ assert env_manager._closed
+ reset_param = {i: {'stat': 'stat_test'} for i in range(env_manager.env_num)}
+ reset_param[0]['stat'] = 'wait'
+
+ obs = env_manager.launch(reset_param=reset_param)
+ assert not env_manager._closed
+
+ timestep = env_manager.step({i: np.random.randn(4) for i in range(env_manager.env_num)})
+ assert len(timestep) == env_manager.env_num
+ # Test step timeout
+ action = [np.random.randn(4) for i in range(env_manager.env_num)]
+ action[0] = 'block'
+ with pytest.raises(RuntimeError):
+ timestep = env_manager.step(action)
+ assert all([env_manager._env_states[i] == EnvState.RUN for i in range(1, env_manager.env_num)])
+
+ obs = env_manager.reset(reset_param)
+ action[0] = 'wait'
+ timestep = env_manager.step(action)
+ assert len(timestep) == env_manager.env_num
+
+ env_manager.close()
+
+ def test_reset(self, setup_fast_base_manager_cfg, setup_model_type):
+ assert setup_fast_base_manager_cfg['episode_num'] > 1
+ env_fn = setup_fast_base_manager_cfg.pop('env_fn')
+ model = setup_model_type()
+ # auto_reset = True
+ env_manager = BaseEnvManagerV2(env_fn, setup_fast_base_manager_cfg)
+ env_manager.launch()
+ while True:
+ obs = env_manager.ready_obs
+ env_id = env_manager.ready_obs_id
+ obs = {i: o for i, o in zip(env_id, obs)}
+ action = model.forward(obs)
+ timestep = env_manager.step(action)
+ if env_manager.done:
+ break
+ assert all(
+ env_manager._env_episode_count[i] == setup_fast_base_manager_cfg['episode_num']
+ for i in range(env_manager.env_num)
+ )
+ assert all(env_manager._env_states[i] == EnvState.DONE for i in range(env_manager.env_num))
+
+ # auto_reset = False
+ setup_fast_base_manager_cfg['auto_reset'] = False
+ env_manager = BaseEnvManagerV2(env_fn, setup_fast_base_manager_cfg)
+ env_manager.launch()
+
+ while True:
+ obs = env_manager.ready_obs
+ env_id = env_manager.ready_obs_id
+ obs = {i: o for i, o in zip(env_id, obs)}
+ action = model.forward(obs)
+ timestep = env_manager.step(action)
+ if env_manager.done:
+ break
+ if all(env_manager._env_states[i] == EnvState.NEED_RESET for i in range(env_manager.env_num)):
+ env_manager.reset()
+ assert all(env_manager._env_episode_count[i] == 2 for i in range(env_manager.env_num))
+ assert all(env_manager._env_states[i] == EnvState.DONE for i in range(env_manager.env_num))
+ # auto_reset = False and reset each env independently
+ env_manager = BaseEnvManagerV2(env_fn, setup_fast_base_manager_cfg)
+ env_manager.launch()
+
+ while True:
+ obs = env_manager.ready_obs
+ env_id = env_manager.ready_obs_id
+ obs = {i: o for i, o in zip(env_id, obs)}
+ action = model.forward(obs)
+ timestep = env_manager.step(action)
+ if env_manager.done:
+ break
+ for t in timestep:
+ env_id = t.env_id.item()
+ if t.done and not env_manager.env_state_done(env_id):
+ env_manager.reset({env_id: {}})
+ assert all(
+ env_manager._env_episode_count[i] == setup_fast_base_manager_cfg['episode_num']
+ for i in range(env_manager.env_num)
+ )
+ assert all(env_manager._env_states[i] == EnvState.DONE for i in range(env_manager.env_num))
diff --git a/DI-engine/ding/envs/env_manager/tests/test_env_supervisor.py b/DI-engine/ding/envs/env_manager/tests/test_env_supervisor.py
new file mode 100644
index 0000000000000000000000000000000000000000..c16a7d1c6d6916a3bab4d7ad21479b48223af6b6
--- /dev/null
+++ b/DI-engine/ding/envs/env_manager/tests/test_env_supervisor.py
@@ -0,0 +1,423 @@
+import time
+import pytest
+import numpy as np
+import treetensor.numpy as tnp
+from ding.envs.env_manager import EnvSupervisor
+from ding.envs.env_manager.env_supervisor import EnvState
+from ding.framework.supervisor import ChildType
+from gym.spaces import Space
+
+
+class TestEnvSupervisorCompatible:
+ "Test compatibility with base env manager."
+
+ @pytest.mark.unittest
+ @pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
+ def test_naive(self, setup_base_manager_cfg, type_):
+ """
+ To be compatible with the original env_manager, this test uses the original configuration and blocking methods.
+ {
+ 'env_cfg': [{
+ 'name': 'name{}'.format(i),
+ 'scale': 1.0,
+ } for i in range(env_num)],
+ 'episode_num': 2,
+ 'reset_timeout': 10,
+ 'step_timeout': 8,
+ 'max_retry': 5,
+ }
+ """
+ env_fn = setup_base_manager_cfg.pop('env_fn')
+ env_supervisor = EnvSupervisor(type_=type_, env_fn=env_fn, **{**setup_base_manager_cfg, "auto_reset": False})
+ try:
+ env_supervisor.seed([314 for _ in range(env_supervisor.env_num)])
+ assert env_supervisor.closed
+ env_supervisor.launch(reset_param={i: {'stat': 'stat_test'} for i in range(env_supervisor.env_num)})
+
+ # Test basic
+ assert all([s == 314 for s in env_supervisor._env_seed.values()])
+
+ # Test step
+ count = 1
+ start_time = time.time()
+
+ # Loop over each env until done
+ while not env_supervisor.done:
+ env_id = env_supervisor.ready_obs_id
+ action = {i: np.random.randn(4) for i in env_id}
+ timestep = env_supervisor.step(action)
+ assert len(timestep) == len(env_id)
+ print('Count {}'.format(count))
+ count += 1
+
+ end_time = time.time()
+ print('Total step time: {}'.format(end_time - start_time))
+
+ assert all([env_supervisor.env_states[env_id] == EnvState.DONE for env_id in range(env_supervisor.env_num)])
+
+ finally:
+ # Test close
+ env_supervisor.close()
+
+ assert env_supervisor.closed
+ assert all([env_supervisor.env_states[env_id] == EnvState.VOID for env_id in range(env_supervisor.env_num)])
+ with pytest.raises(AssertionError):
+ env_supervisor.reset([])
+ with pytest.raises(AssertionError):
+ env_supervisor.step([])
+
+ @pytest.mark.unittest
+ @pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
+ def test_reset_error(self, setup_base_manager_cfg, type_):
+ env_fn = setup_base_manager_cfg.pop('env_fn')
+ env_supervisor = EnvSupervisor(type_=type_, env_fn=env_fn, **setup_base_manager_cfg)
+ # Test reset error
+ with pytest.raises(RuntimeError):
+ reset_param = {i: {'stat': 'error'} for i in range(env_supervisor.env_num)}
+ env_supervisor.launch(reset_param=reset_param)
+ assert env_supervisor.closed
+
+ @pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
+ def test_reset_error_once(self, setup_base_manager_cfg, type_):
+ env_fn = setup_base_manager_cfg.pop('env_fn')
+ env_supervisor = EnvSupervisor(type_=type_, env_fn=env_fn, **setup_base_manager_cfg)
+ # Normal launch
+ reset_param = {i: {'stat': 'stat_test'} for i in range(env_supervisor.env_num)}
+ env_supervisor.launch(reset_param=reset_param)
+
+ env_id_0 = env_supervisor.time_id[0]
+ # Normal step
+ timestep = env_supervisor.step({i: np.random.randn(4) for i in range(env_supervisor.env_num)})
+ assert len(timestep) == env_supervisor.env_num
+
+ # Test a single reset error; the manager should still recover correctly.
+ reset_param = {i: {'stat': 'stat_test'} for i in range(env_supervisor.env_num)}
+ assert env_supervisor._retry_type == 'reset'
+ reset_param[0] = {'stat': 'error_once'}
+ env_supervisor.reset(reset_param)
+ env_supervisor.reset(reset_param)
+
+ # If retry type is reset, time id should be equal
+ assert env_supervisor.time_id[0] == env_id_0
+ assert all([state == EnvState.RUN for state in env_supervisor.env_states.values()])
+ env_supervisor.close()
+
+ @pytest.mark.unittest
+ @pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
+ def test_renew_error(self, setup_base_manager_cfg, type_):
+ env_fn = setup_base_manager_cfg.pop('env_fn')
+ env_supervisor = EnvSupervisor(type_=type_, env_fn=env_fn, **{**setup_base_manager_cfg, "retry_type": "renew"})
+ reset_param = {i: {'stat': 'stat_test'} for i in range(env_supervisor.env_num)}
+ env_supervisor.launch(reset_param=reset_param)
+
+ assert env_supervisor._retry_type == "renew"
+ env_id_0 = env_supervisor.time_id[0]
+
+ reset_param[0] = {'stat': 'error_once'}
+ env_supervisor.reset(reset_param)
+ env_supervisor.reset(reset_param)
+ assert not env_supervisor.closed
+ # If retry type is renew, time id should not be equal
+ assert env_supervisor.time_id[0] != env_id_0
+ assert len(env_supervisor.ready_obs) == 3
+ for i, obs in enumerate(env_supervisor.ready_obs):
+ assert all(x == y for x, y in zip(obs, env_supervisor._ready_obs.get(i)))
+
+ # Test step with a caught error
+ action = [np.random.randn(4) for i in range(env_supervisor.env_num)]
+ action[0] = 'catched_error'
+ timestep = env_supervisor.step(action)
+ assert timestep[0].info.abnormal
+
+ assert all(['abnormal' not in timestep[i].info for i in range(1, env_supervisor.env_num)])
+ # With auto_reset, abnormal timestep with done==True will be auto reset.
+ assert all([env_supervisor.env_states[i] == EnvState.RUN for i in range(env_supervisor.env_num)])
+ assert len(env_supervisor.ready_obs) == 3
+ env_supervisor.close()
+
+ @pytest.mark.tmp # passes in GitLab CI and locally, but always fails on GitHub CI
+ @pytest.mark.timeout(60)
+ @pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
+ def test_block_launch(self, setup_base_manager_cfg, type_):
+ env_fn = setup_base_manager_cfg.pop('env_fn')
+ setup_base_manager_cfg['max_retry'] = 1
+ setup_base_manager_cfg['reset_timeout'] = 7
+
+ env_supervisor = EnvSupervisor(type_=type_, env_fn=env_fn, **setup_base_manager_cfg)
+ with pytest.raises(RuntimeError):
+ reset_param = {i: {'stat': 'block'} for i in range(env_supervisor.env_num)}
+ env_supervisor.launch(reset_param=reset_param)
+ assert env_supervisor.closed
+
+ reset_param = {i: {'stat': 'stat_test'} for i in range(env_supervisor.env_num)}
+ reset_param[0]['stat'] = 'wait'
+
+ env_supervisor.launch(reset_param=reset_param)
+ assert not env_supervisor.closed
+
+ env_supervisor.close(1)
+
+ @pytest.mark.tmp # passes in GitLab CI and locally, but always fails on GitHub CI
+ @pytest.mark.timeout(60)
+ @pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
+ def test_block_step(self, setup_base_manager_cfg, type_):
+ env_fn = setup_base_manager_cfg.pop('env_fn')
+ setup_base_manager_cfg['max_retry'] = 1
+ setup_base_manager_cfg['reset_timeout'] = 7
+
+ env_supervisor = EnvSupervisor(type_=type_, env_fn=env_fn, **setup_base_manager_cfg)
+
+ reset_param = {i: {'stat': 'stat_test'} for i in range(env_supervisor.env_num)}
+ env_supervisor.launch(reset_param=reset_param)
+
+ timestep = env_supervisor.step({i: np.random.randn(4) for i in range(env_supervisor.env_num)})
+ assert len(timestep) == env_supervisor.env_num
+
+ # A blocked step triggers an env reset, which then raises a RuntimeError
+ env_supervisor._reset_param[0] = {"stat": "block"}
+ # Test step timeout
+ action = [np.random.randn(4) for i in range(env_supervisor.env_num)]
+ action[0] = 'block'
+
+ with pytest.raises(RuntimeError):
+ timestep = env_supervisor.step(action)
+ assert env_supervisor.closed
+
+ env_supervisor.launch(reset_param)
+ action[0] = 'wait'
+ timestep = env_supervisor.step(action)
+ assert len(timestep) == env_supervisor.env_num
+
+ env_supervisor.close(1)
+
+ @pytest.mark.unittest
+ @pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
+ def test_properties(self, setup_base_manager_cfg, type_):
+ env_fn = setup_base_manager_cfg.pop('env_fn')
+ env_supervisor = EnvSupervisor(type_=type_, env_fn=env_fn, **setup_base_manager_cfg)
+ reset_param = {i: {'stat': 'stat_test'} for i in range(env_supervisor.env_num)}
+ env_supervisor.launch(reset_param=reset_param)
+
+ assert isinstance(env_supervisor.action_space, Space)
+ assert isinstance(env_supervisor.reward_space, Space)
+ assert isinstance(env_supervisor.observation_space, Space)
+ env_supervisor.close()
+
+ @pytest.mark.unittest
+ @pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
+ def test_auto_reset(self, setup_base_manager_cfg, type_):
+ env_fn = setup_base_manager_cfg.pop('env_fn')
+ env_supervisor = EnvSupervisor(
+ type_=type_, env_fn=env_fn, **{
+ **setup_base_manager_cfg, "auto_reset": True,
+ "episode_num": 1000
+ }
+ )
+ env_supervisor.launch(reset_param={i: {'stat': 'stat_test'} for i in range(env_supervisor.env_num)})
+
+ assert len(env_supervisor.ready_obs) == 3
+ assert len(env_supervisor.ready_obs_id) == 3
+
+ timesteps = []
+
+ for _ in range(10):
+ action = {i: np.random.randn(4) for i in range(env_supervisor.env_num)}
+ timesteps.append(env_supervisor.step(action))
+ assert len(env_supervisor.ready_obs) == 3
+ time.sleep(1)
+ timesteps = tnp.stack(timesteps).reshape(-1)
+ assert len(timesteps.done) == 30
+ assert any(done for done in timesteps.done)
+ assert all([env_supervisor.env_states[env_id] == EnvState.RUN for env_id in range(env_supervisor.env_num)])
+ env_supervisor.close()
+
+
+class TestEnvSupervisor:
+ """
+ Test async usage
+ """
+
+ @pytest.mark.unittest
+ @pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
+ def test_normal(self, setup_base_manager_cfg, type_):
+ env_fn = setup_base_manager_cfg.pop('env_fn')
+ setup_base_manager_cfg["auto_reset"] = False
+ env_supervisor = EnvSupervisor(type_=type_, env_fn=env_fn, **setup_base_manager_cfg)
+ env_supervisor.seed([314 for _ in range(env_supervisor.env_num)])
+ env_supervisor.launch(
+ reset_param={i: {
+ 'stat': 'stat_test'
+ }
+ for i in range(env_supervisor.env_num)}, block=False
+ )
+
+ count = 0
+ start_time = time.time()
+ while not env_supervisor.done:
+ recv_payload = env_supervisor.recv()
+ if recv_payload.method == "reset": # Recv reset obs
+ assert len(recv_payload.data) == 3
+ elif recv_payload.method == "step":
+ assert isinstance(recv_payload.data, tnp.ndarray)
+ if env_supervisor.env_states[recv_payload.proc_id] != EnvState.DONE:
+ action = {recv_payload.proc_id: np.random.randn(4)}
+ env_supervisor.step(action, block=False)
+ count += 1
+ print("Count", count)
+
+ end_time = time.time()
+ print("Total step time: {}".format(end_time - start_time))
+
+ env_supervisor.close()
+ assert env_supervisor.closed
+
+ @pytest.mark.unittest
+ @pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
+ def test_reset_error(self, setup_base_manager_cfg, type_):
+ env_fn = setup_base_manager_cfg.pop('env_fn')
+ env_supervisor = EnvSupervisor(type_=type_, env_fn=env_fn, **setup_base_manager_cfg)
+ with pytest.raises(RuntimeError):
+ reset_param = {i: {'stat': 'error'} for i in range(env_supervisor.env_num)}
+ env_supervisor.launch(reset_param=reset_param, block=False)
+ while True:
+ env_supervisor.recv()
+ env_supervisor.close()
+
+ @pytest.mark.unittest
+ @pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
+ def test_reset_error_once(self, setup_base_manager_cfg, type_):
+ env_fn = setup_base_manager_cfg.pop('env_fn')
+ env_supervisor = EnvSupervisor(type_=type_, env_fn=env_fn, **setup_base_manager_cfg)
+ # Normal launch
+ reset_param = {i: {'stat': 'stat_test'} for i in range(env_supervisor.env_num)}
+ env_supervisor.launch(reset_param=reset_param)
+
+ env_id_0 = env_supervisor.time_id[0]
+
+ # Normal step
+ env_supervisor.step({i: np.random.randn(4) for i in range(env_supervisor.env_num)}, block=False)
+ timestep = []
+ while len(timestep) != 3:
+ payload = env_supervisor.recv()
+ if payload.method == "step":
+ timestep.append(payload.data)
+ assert len(timestep) == env_supervisor.env_num
+
+ # Test a single reset error; the manager should still recover correctly.
+ reset_param = {i: {'stat': 'stat_test'} for i in range(env_supervisor.env_num)}
+ assert env_supervisor._retry_type == 'reset'
+ reset_param[0] = {'stat': 'error_once'}
+ env_supervisor.reset(reset_param, block=False) # First try, success
+ env_supervisor.reset(reset_param, block=False) # Second try, error and recover
+
+ reset_obs = []
+ while len(reset_obs) != 6:
+ reset_obs.append(env_supervisor.recv(ignore_err=True))
+ assert env_supervisor.time_id[0] == env_id_0
+ assert all([state == EnvState.RUN for state in env_supervisor.env_states.values()])
+ env_supervisor.close()
+
+ @pytest.mark.unittest
+ @pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
+ def test_renew_error_once(self, setup_base_manager_cfg, type_):
+ env_fn = setup_base_manager_cfg.pop('env_fn')
+ setup_base_manager_cfg["retry_type"] = "renew"
+ setup_base_manager_cfg["shared_memory"] = False
+ env_supervisor = EnvSupervisor(type_=type_, env_fn=env_fn, **setup_base_manager_cfg)
+ # Normal launch
+ reset_param = {i: {'stat': 'stat_test'} for i in range(env_supervisor.env_num)}
+ env_supervisor.launch(reset_param=reset_param)
+
+ env_id_0 = env_supervisor.time_id[0]
+ reset_param[0] = {'stat': 'error_once'}
+ env_supervisor.reset(reset_param, block=False)
+ env_supervisor.reset(reset_param, block=False)
+
+ reset_obs = []
+ while len(reset_obs) != 6:
+ reset_obs.append(env_supervisor.recv(ignore_err=True))
+
+ assert env_supervisor.time_id[0] != env_id_0
+ assert len(env_supervisor.ready_obs) == 3
+
+ # Test step with a caught error
+ action = [np.random.randn(4) for i in range(env_supervisor.env_num)]
+ action[0] = 'catched_error'
+ env_supervisor.step(action, block=False)
+
+ timestep = {}
+ while len(timestep) != 3:
+ payload = env_supervisor.recv()
+ if payload.method == "step":
+ timestep[payload.proc_id] = payload.data
+ assert len(timestep) == env_supervisor.env_num
+ assert timestep[0].info.abnormal
+
+ assert all(['abnormal' not in timestep[i].info for i in range(1, env_supervisor.env_num)])
+ env_supervisor.close()
+
+ @pytest.mark.tmp # passes in GitLab CI and locally, but always fails on GitHub CI
+ @pytest.mark.timeout(60)
+ @pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
+ def test_block_launch(self, setup_base_manager_cfg, type_):
+ env_fn = setup_base_manager_cfg.pop('env_fn')
+ setup_base_manager_cfg["retry_type"] = "renew"
+ setup_base_manager_cfg['max_retry'] = 1
+ setup_base_manager_cfg['reset_timeout'] = 7
+
+ env_supervisor = EnvSupervisor(type_=type_, env_fn=env_fn, **setup_base_manager_cfg)
+ with pytest.raises(RuntimeError):
+ reset_param = {i: {'stat': 'block'} for i in range(env_supervisor.env_num)}
+ env_supervisor.launch(reset_param=reset_param, block=False)
+ while True:
+ payload = env_supervisor.recv()
+ assert env_supervisor.closed
+
+ reset_param = {i: {'stat': 'stat_test'} for i in range(env_supervisor.env_num)}
+ reset_param[0]['stat'] = 'wait'
+
+ env_supervisor.launch(reset_param=reset_param, block=False)
+
+ reset_obs = []
+ while len(reset_obs) != 4:
+ payload = env_supervisor.recv(ignore_err=True)
+ if payload.method == "reset":
+ reset_obs.append(payload.data)
+
+ env_supervisor.close(1)
+
+ @pytest.mark.tmp # passes in GitLab CI and locally, but always fails on GitHub CI
+ @pytest.mark.timeout(60)
+ @pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
+ def test_block_step(self, setup_base_manager_cfg, type_):
+ env_fn = setup_base_manager_cfg.pop('env_fn')
+ setup_base_manager_cfg["retry_type"] = "renew"
+ setup_base_manager_cfg['max_retry'] = 1
+ setup_base_manager_cfg['reset_timeout'] = 7
+
+ env_supervisor = EnvSupervisor(type_=type_, env_fn=env_fn, **setup_base_manager_cfg)
+ reset_param = {i: {'stat': 'stat_test'} for i in range(env_supervisor.env_num)}
+ env_supervisor.launch(reset_param=reset_param)
+
+ # A blocked step triggers an env reset, which then raises a RuntimeError
+ env_supervisor._reset_param[0] = {"stat": "block"}
+ # Test step timeout
+ action = [np.random.randn(4) for i in range(env_supervisor.env_num)]
+ action[0] = 'block'
+
+ with pytest.raises(RuntimeError):
+ env_supervisor.step(action, block=False)
+ while True:
+ env_supervisor.recv()
+ assert env_supervisor.closed
+
+ env_supervisor.launch(reset_param)
+ action[0] = 'wait'
+ env_supervisor.step(action, block=False)
+ timestep = []
+ while len(timestep) != 4:
+ payload = env_supervisor.recv(ignore_err=True)
+ if payload.method == "step":
+ timestep.append(payload.data)
+
+ env_supervisor.close(1)
diff --git a/DI-engine/ding/envs/env_manager/tests/test_envpool_env_manager.py b/DI-engine/ding/envs/env_manager/tests/test_envpool_env_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ac77307736a1231ed3aa154a53aeea0db11d930
--- /dev/null
+++ b/DI-engine/ding/envs/env_manager/tests/test_envpool_env_manager.py
@@ -0,0 +1,46 @@
+import time
+import pytest
+import numpy as np
+from easydict import EasyDict
+
+from ..envpool_env_manager import PoolEnvManager
+
+env_num_args = [[16, 8], [8, 8]]
+
+
+@pytest.mark.envpooltest
+@pytest.mark.parametrize('env_num, batch_size', env_num_args)
+class TestPoolEnvManager:
+
+ def test_naive(self, env_num, batch_size):
+ env_manager_cfg = EasyDict(
+ {
+ 'env_id': 'Pong-v5',
+ 'env_num': env_num,
+ 'batch_size': batch_size,
+ 'seed': 3,
+ # env wrappers
+ 'episodic_life': False,
+ 'reward_clip': False,
+ 'gray_scale': True,
+ 'stack_num': 4,
+ 'frame_skip': 4,
+ }
+ )
+ env_manager = PoolEnvManager(env_manager_cfg)
+ assert env_manager._closed
+ env_manager.launch()
+ # Test step
+ start_time = time.time()
+ for count in range(20):
+ env_id = env_manager.ready_obs.keys()
+ action = {i: np.random.randint(4) for i in env_id}
+ timestep = env_manager.step(action)
+ assert len(timestep) == env_manager_cfg.batch_size
+ print('Count {}'.format(count))
+ print([v.info for v in timestep.values()])
+ end_time = time.time()
+ print('total step time: {}'.format(end_time - start_time))
+ # Test close
+ env_manager.close()
+ assert env_manager._closed
diff --git a/DI-engine/ding/envs/env_manager/tests/test_gym_vector_env_manager.py b/DI-engine/ding/envs/env_manager/tests/test_gym_vector_env_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ea79f2f47938f2d535b1d4cc8b1e14641094c04
--- /dev/null
+++ b/DI-engine/ding/envs/env_manager/tests/test_gym_vector_env_manager.py
@@ -0,0 +1,56 @@
+import time
+import signal
+import pytest
+import torch
+import numpy as np
+
+from ding.envs.env_manager.base_env_manager import BaseEnvManager, EnvState
+from ding.envs.env_manager.gym_vector_env_manager import GymVectorEnvManager
+from gym.vector.async_vector_env import AsyncState
+
+
+@pytest.mark.tmp
+# @pytest.mark.unittest
+class TestGymVectorEnvManager:
+
+ def test_naive(self, setup_gym_vector_manager_cfg):
+ env_fn = setup_gym_vector_manager_cfg.pop('env_fn')
+ env_manager = GymVectorEnvManager(env_fn, setup_gym_vector_manager_cfg)
+ env_manager.seed([314 for _ in range(env_manager.env_num)])
+ # Test reset
+ obs = env_manager.reset()
+ assert not env_manager._closed
+ assert env_manager._env_manager._state == AsyncState.DEFAULT
+ # Test attribute access
+ with pytest.raises(AttributeError):
+ _ = env_manager.xxx
+ with pytest.raises(RuntimeError):
+ env_manager.user_defined()
+ # Test step
+ count = 1
+ start_time = time.time()
+ while not env_manager.done:
+ env_id = env_manager.ready_obs.keys()
+ assert all(env_manager._env_episode_count[i] < env_manager._episode_num for i in env_id)
+ action = {i: np.random.randn(3) for i in env_id}
+ timestep = env_manager.step(action)
+ assert len(timestep) == len(env_id)
+ print('Count {}'.format(count))
+ print([v.info for v in timestep.values()])
+ print([v.done for v in timestep.values()])
+ count += 1
+ end_time = time.time()
+ print('total step time: {}'.format(end_time - start_time))
+ assert all(env_manager._env_episode_count[i] == env_manager._episode_num for i in env_id)
+
+ # Test close
+ assert not env_manager._closed
+ env_manager.close()
+ assert env_manager._closed
+ assert env_manager._env_ref._state == EnvState.INIT
+ # assert all([not env_manager._envs[env_id]._launched for env_id in range(env_manager.env_num)])
+ # assert all([env_manager._env_states[env_id] == EnvState.VOID for env_id in range(env_manager.env_num)])
+ with pytest.raises(AssertionError):
+ env_manager.reset([])
+ with pytest.raises(AssertionError):
+ env_manager.step([])
diff --git a/DI-engine/ding/envs/env_manager/tests/test_shm.py b/DI-engine/ding/envs/env_manager/tests/test_shm.py
new file mode 100644
index 0000000000000000000000000000000000000000..37647ca8c7073ca1b439afd8d748b7a908ba733c
--- /dev/null
+++ b/DI-engine/ding/envs/env_manager/tests/test_shm.py
@@ -0,0 +1,37 @@
+import pytest
+import time
+import numpy as np
+import torch
+from multiprocessing import Process
+
+from ding.envs.env_manager.subprocess_env_manager import ShmBuffer
+
+
+def writer(shm):
+ while True:
+ shm.fill(np.random.random(size=(4, 84, 84)).astype(np.float32))
+ time.sleep(1)
+
+
+@pytest.mark.unittest
+def test_shm():
+
+ shm = ShmBuffer(dtype=np.float32, shape=(4, 84, 84), copy_on_get=False)
+ writer_process = Process(target=writer, args=(shm, ))
+ writer_process.start()
+
+ time.sleep(0.1)
+
+ data1 = shm.get()
+ time.sleep(1)
+ data2 = shm.get()
+ # same memory
+ assert (data1 == data2).all()
+
+ time.sleep(1)
+ data3 = shm.get().copy()
+ time.sleep(1)
+ data4 = shm.get()
+ assert (data3 != data4).all()
+
+ writer_process.terminate()
diff --git a/DI-engine/ding/envs/env_manager/tests/test_subprocess_env_manager.py b/DI-engine/ding/envs/env_manager/tests/test_subprocess_env_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..218b87e383434f1ec2a0387d0ea9cf850e1876b3
--- /dev/null
+++ b/DI-engine/ding/envs/env_manager/tests/test_subprocess_env_manager.py
@@ -0,0 +1,220 @@
+import time
+import signal
+import pytest
+import torch
+import numpy as np
+
+from ..base_env_manager import EnvState
+from ..subprocess_env_manager import AsyncSubprocessEnvManager, SyncSubprocessEnvManager
+
+
+class TestSubprocessEnvManager:
+
+ @pytest.mark.unittest
+ def test_naive(self, setup_async_manager_cfg, setup_model_type):
+ env_fn = setup_async_manager_cfg.pop('env_fn')
+ env_manager = AsyncSubprocessEnvManager(env_fn, setup_async_manager_cfg)
+ model = setup_model_type()
+
+ env_manager.seed([314 for _ in range(env_manager.env_num)])
+ env_manager.launch(reset_param={i: {'stat': 'stat_test'} for i in range(env_manager.env_num)})
+ assert all([s == 314 for s in env_manager._seed])
+ assert all([s == 'stat_test' for s in env_manager._stat])
+ # Test basic
+ name = env_manager._name
+ for i in range(env_manager.env_num):
+ assert name[i] == 'name{}'.format(i)
+ assert len(name) == env_manager.env_num
+ assert all([isinstance(n, str) for n in name])
+ name = env_manager.name
+ assert len(name) == env_manager.env_num
+ assert all([isinstance(n, str) for n in name])
+ assert env_manager._max_retry == 2
+ assert env_manager._connect_timeout == 8
+ assert env_manager._step_timeout == 5
+ # Test attribute access
+ with pytest.raises(AttributeError):
+ data = env_manager.xxx
+ env_manager._env_ref.user_defined()
+ with pytest.raises(RuntimeError):
+ env_manager.user_defined()
+ # Test step
+ env_count = [0 for _ in range(env_manager.env_num)]
+ data_count = 0
+ start_time = time.time()
+ while not env_manager.done:
+ obs = env_manager.ready_obs
+ print('obs', obs.keys(), env_manager._env_states)
+ action = model.forward(obs)
+ assert 1 <= len(action) <= len(obs)
+ print('act', action.keys())
+ timestep = env_manager.step(action)
+ data_count += len(timestep)
+ assert len(timestep) >= 1
+ print('timestep', timestep.keys(), timestep, len(timestep))
+ for k, t in timestep.items():
+ if t.done:
+ print('env{} finish episode{}'.format(k, env_count[k]))
+ env_count[k] += 1
+ assert all([c == setup_async_manager_cfg.episode_num for c in env_count])
+ assert data_count == sum(env_manager._data_count)
+ assert all([env_manager._env_states[env_id] == EnvState.DONE for env_id in range(env_manager.env_num)])
+ end_time = time.time()
+ print('total step time: {}'.format(end_time - start_time))
+
+ # Test close
+ env_manager.close()
+ assert env_manager._closed
+ with pytest.raises(AssertionError):
+ env_manager.reset([])
+ with pytest.raises(AssertionError):
+ env_manager.step([])
+
+ @pytest.mark.unittest
+ def test_error(self, setup_sync_manager_cfg):
+ env_fn = setup_sync_manager_cfg.pop('env_fn')
+ env_manager = SyncSubprocessEnvManager(env_fn, setup_sync_manager_cfg)
+ # Test reset error
+ with pytest.raises(AssertionError):
+ env_manager.reset(reset_param={i: {'stat': 'stat_test'} for i in range(env_manager.env_num)})
+ with pytest.raises(RuntimeError):
+ env_manager.launch(reset_param={i: {'stat': 'error'} for i in range(env_manager.env_num)})
+ assert env_manager._closed
+ time.sleep(0.5) # necessary time interval
+ env_manager.launch(reset_param={i: {'stat': 'stat_test'} for i in range(env_manager.env_num)})
+ assert not env_manager._closed
+
+ timestep = env_manager.step({i: np.random.randn(4) for i in range(env_manager.env_num)})
+ assert len(timestep) == env_manager.env_num
+
+ # Test reset error once
+ reset_param = {i: {'stat': 'stat_test'} for i in range(env_manager.env_num)}
+ assert env_manager._retry_type == 'reset'
+ env_id_0 = env_manager.time_id[0]
+ reset_param[0] = {'stat': 'error_once'}
+ env_manager.reset(reset_param)
+ assert not env_manager._closed
+ assert env_manager.time_id[0] == env_id_0
+ env_manager._retry_type = 'renew'
+ env_id_0 = env_manager.time_id[0]
+ reset_param[0] = {'stat': 'error_once'}
+ env_manager.reset(reset_param)
+ assert not env_manager._closed
+ assert env_manager.time_id[0] != env_id_0
+
+ # Test step with a caught error
+ action = {i: np.random.randn(4) for i in range(env_manager.env_num)}
+ action[0] = 'catched_error'
+ assert not env_manager._closed
+ timestep = env_manager.step(action)
+ assert not env_manager._closed
+
+ assert timestep[0].info['abnormal']
+ assert all(['abnormal' not in timestep[i].info for i in range(1, env_manager.env_num)])
+ assert env_manager._env_states[0] == EnvState.ERROR
+ assert len(env_manager.ready_obs) == 2
+ # wait for reset
+ env_manager.reset({0: {'stat': 'stat_test'}})
+ while not len(env_manager.ready_obs) == env_manager.env_num:
+ time.sleep(0.1)
+ assert env_manager._env_states[0] == EnvState.RUN
+ assert len(env_manager.ready_obs) == 3
+ timestep = env_manager.step({i: np.random.randn(4) for i in range(env_manager.env_num)})
+
+ # Test step error
+ action[0] = 'error'
+ with pytest.raises(RuntimeError):
+ timestep = env_manager.step(action)
+ assert env_manager._closed
+
+ env_manager.close()
+ with pytest.raises(AssertionError): # Assert env manager is not closed
+ env_manager.reset([])
+ with pytest.raises(AssertionError): # Assert env manager is not closed
+ env_manager.step([])
+
+ @pytest.mark.tmp # passes in GitLab CI and locally, but always fails on GitHub CI
+ @pytest.mark.timeout(100)
+ def test_block(self, setup_async_manager_cfg, setup_model_type):
+ env_fn = setup_async_manager_cfg.pop('env_fn')
+ env_manager = AsyncSubprocessEnvManager(env_fn, setup_async_manager_cfg)
+ model = setup_model_type()
+ # Test connect timeout
+ with pytest.raises(RuntimeError):
+ reset_param = {i: {'stat': 'block'} for i in range(env_manager.env_num)}
+ obs = env_manager.launch(reset_param=reset_param)
+ assert env_manager._closed
+ time.sleep(0.5)
+ reset_param = {i: {'stat': 'stat_test'} for i in range(env_manager.env_num)}
+ reset_param[0]['stat'] = 'wait'
+ env_manager.launch(reset_param=reset_param)
+ time.sleep(0.5)
+ assert not env_manager._closed
+
+ timestep = env_manager.step({i: np.random.randn(4) for i in range(env_manager.env_num)})
+ obs = env_manager.ready_obs
+ assert len(obs) >= 1
+
+ # Test reset timeout
+ env_manager._connect_timeout = 30
+ env_manager._reset_timeout = 8
+ with pytest.raises(RuntimeError):
+ reset_param = {i: {'stat': 'block'} for i in range(env_manager.env_num)}
+ obs = env_manager.reset(reset_param=reset_param)
+ assert env_manager._closed
+ time.sleep(0.5)
+ reset_param = {i: {'stat': 'stat_test'} for i in range(env_manager.env_num)}
+ reset_param[0]['stat'] = 'wait'
+ env_manager.launch(reset_param=reset_param)
+ time.sleep(0.5)
+ assert not env_manager._closed
+
+ # Test step timeout
+ env_manager._step_timeout = 5
+ obs = env_manager.reset({i: {'stat': 'stat_test'} for i in range(env_manager.env_num)})
+ action = {i: np.random.randn(4) for i in range(env_manager.env_num)}
+ action[0] = 'block'
+ with pytest.raises(TimeoutError):
+ timestep = env_manager.step(action)
+ obs = env_manager.ready_obs
+ while 0 not in obs:
+ action = model.forward(obs)
+ timestep = env_manager.step(action)
+ obs = env_manager.ready_obs
+ time.sleep(0.5)
+
+ obs = env_manager.launch(reset_param={i: {'stat': 'stat_test'} for i in range(env_manager.env_num)})
+ time.sleep(1)
+ action[0] = 'wait'
+ timestep = env_manager.step(action)
+ obs = env_manager.ready_obs
+ while 0 not in obs:
+ action = model.forward(obs)
+ timestep = env_manager.step(action)
+ obs = env_manager.ready_obs
+ assert len(obs) >= 1
+
+ env_manager.close()
+
+ @pytest.mark.unittest
+ def test_reset(self, setup_async_manager_cfg, setup_model_type):
+ env_fn = setup_async_manager_cfg.pop('env_fn')
+ setup_async_manager_cfg['auto_reset'] = False
+ env_manager = AsyncSubprocessEnvManager(env_fn, setup_async_manager_cfg)
+ model = setup_model_type()
+ reset_param = {i: {'stat': 'stat_test'} for i in range(env_manager.env_num)}
+ obs = env_manager.launch(reset_param=reset_param)
+ while True:
+ obs = env_manager.ready_obs
+ action = model.forward(obs)
+ timestep = env_manager.step(action)
+ if env_manager.done:
+ break
+ for env_id, t in timestep.items():
+ if t.done and not env_manager.env_state_done(env_id):
+ env_manager.reset({env_id: None})
+ assert all(
+ env_manager._env_episode_count[i] == setup_async_manager_cfg['episode_num']
+ for i in range(env_manager.env_num)
+ )
+ assert all(env_manager._env_states[i] == EnvState.DONE for i in range(env_manager.env_num))
diff --git a/DI-engine/ding/envs/env_wrappers/__init__.py b/DI-engine/ding/envs/env_wrappers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd8907b35c3c04a4c391b91faca29d0e0f04888e
--- /dev/null
+++ b/DI-engine/ding/envs/env_wrappers/__init__.py
@@ -0,0 +1 @@
+from .env_wrappers import *
diff --git a/DI-engine/ding/envs/env_wrappers/env_wrappers.py b/DI-engine/ding/envs/env_wrappers/env_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..08b1ce4eb19c7a454023a96b446d9103ba173ea4
--- /dev/null
+++ b/DI-engine/ding/envs/env_wrappers/env_wrappers.py
@@ -0,0 +1,1579 @@
+"""
+This code is adapted from OpenAI Baselines:
+ https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
+
+List of Environment Wrappers:
+- NoopResetWrapper: This wrapper facilitates the sampling of initial states by executing a random number of
+ no-operation actions upon environment reset.
+- MaxAndSkipWrapper: Incorporates max pooling across time steps, a method that reduces the temporal dimension by taking
+ the maximum value over specified time intervals.
+- WarpFrameWrapper: Implements frame warping by resizing the images to 84x84, a common preprocessing step in
+ reinforcement learning on visual data, as described in the DeepMind Nature paper and subsequent works.
+- ScaledFloatFrameWrapper: Normalizes observations to a range of 0 to 1, which is a common requirement for neural
+ network inputs.
+- ClipRewardWrapper: Clips the reward to {-1, 0, +1} based on its sign. This simplifies the reward structure and
+ can make learning more stable in environments with high variance in rewards.
+- DelayRewardWrapper: Returns cumulative reward at defined intervals, and at all other times, returns a reward of 0.
+ This can be useful for sparse reward problems.
+- FrameStackWrapper: Stacks the latest 'n' frames as a single observation. This allows the agent to have a sense of
+ dynamics and motion from the stacked frames.
+- ObsTransposeWrapper: Transposes the observation to bring the channel to the first dimension, a common requirement
+ for convolutional neural networks.
+- ObsNormWrapper: Normalizes observations based on a running mean and standard deviation. This can help to standardize
+ inputs for the agent and speed up learning.
+- RewardNormWrapper: Normalizes reward based on a running standard deviation, which can stabilize learning in
+ environments with high variance in rewards.
+- RamWrapper: Wraps a RAM-based environment into an image-like environment. This can be useful for applying
+ image-based algorithms to RAM-based Atari games.
+- EpisodicLifeWrapper: Treats end of life as the end of an episode, but only resets on true game over. This can help
+ the agent better differentiate between losing a life and losing the game.
+- FireResetWrapper: Executes the 'fire' action upon environment reset. This is specific to certain Atari games where
+ the 'fire' action starts the game.
+- GymHybridDictActionWrapper: Transforms the original `gym.spaces.Tuple` action space into a `gym.spaces.Dict`.
+- FlatObsWrapper: Flattens image and language observations into a single vector, which can be helpful for input into
+ certain types of models.
+- StaticObsNormWrapper: Provides functionality for normalizing observations according to a static mean and
+ standard deviation.
+- EvalEpisodeReturnWrapper: Evaluates the return over an episode during evaluation, providing a more comprehensive
+ view of the agent's performance.
+- GymToGymnasiumWrapper: Adapts environments from the Gym library to be compatible with the Gymnasium library.
+- AllinObsWrapper: Consolidates all information into the observation, useful for environments where the agent's
+ observation should include additional information such as the current score or time remaining.
+- ObsPlusPrevActRewWrapper: This wrapper is used in the NGU policy. It sets a dict as the new wrapped observation,
+ which includes the current observation, previous action and previous reward.
+"""
+
+import copy
+import operator
+from collections import deque
+from functools import reduce
+from typing import Union, Any, Tuple, Dict, List
+
+import gym
+import gymnasium
+import numpy as np
+from easydict import EasyDict
+
+from ding.torch_utils import to_ndarray
+from ding.utils import ENV_WRAPPER_REGISTRY, import_module
+
+
+@ENV_WRAPPER_REGISTRY.register('noop_reset')
+class NoopResetWrapper(gym.Wrapper):
+ """
+ Overview:
+ Sample initial states by taking a random number of no-ops on reset. No-op is assumed to be action 0.
+ Interfaces:
+ __init__, reset
+ Properties:
+ - env (:obj:`gym.Env`): the environment to wrap.
+ - noop_max (:obj:`int`): the maximum value of no-ops to run.
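+ Examples:
+ Note: illustrative sketch only; assumes an Atari env (e.g. 'PongNoFrameskip-v4') whose action 0 is NOOP.
+ >>> import gym
+ >>> env = NoopResetWrapper(gym.make('PongNoFrameskip-v4'), noop_max=30)
+ >>> obs = env.reset()  # between 1 and 30 random no-op steps are executed before obs is returned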
+ """
+
+ def __init__(self, env: gym.Env, noop_max: int = 30):
+ """
+ Overview:
+ Initialize the NoopResetWrapper.
+ Arguments:
+ - env (:obj:`gym.Env`): the environment to wrap.
+ - noop_max (:obj:`int`): the maximum value of no-ops to run. Defaults to 30.
+ """
+ super().__init__(env)
+ self.noop_max = noop_max
+ self.noop_action = 0
+ assert env.unwrapped.get_action_meanings()[0] == 'NOOP'
+
+ def reset(self) -> np.ndarray:
+ """
+ Overview:
+ Resets the state of the environment and returns an initial observation,
+ after taking a random number of no-ops.
+ Returns:
+ - observation (:obj:`Any`): The initial observation after no-ops.
+ """
+ self.env.reset()
+ noops = np.random.randint(1, self.noop_max + 1)
+ for _ in range(noops):
+ obs, _, done, _ = self.env.step(self.noop_action)
+ if done:
+ obs = self.env.reset()
+ return obs
+
+
+@ENV_WRAPPER_REGISTRY.register('max_and_skip')
+class MaxAndSkipWrapper(gym.Wrapper):
+ """
+ Overview:
+ Wraps the environment to return only every ``skip``-th frame (frameskipping) \
+ using most recent raw observations (for max pooling across time steps).
+ Interfaces:
+ __init__, step
+ Properties:
+ - env (:obj:`gym.Env`): The environment to wrap.
+ - skip (:obj:`int`): The number of frames to skip, i.e. return only every ``skip``-th frame. Defaults to 4.
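+ Examples:
+ Note: illustrative sketch only; assumes an Atari-style env such as 'PongNoFrameskip-v4' is available.
+ >>> import gym
+ >>> env = MaxAndSkipWrapper(gym.make('PongNoFrameskip-v4'), skip=4)
+ >>> obs = env.reset()
+ >>> obs, reward, done, info = env.step(0)  # the action is repeated 4 times, rewards are summed,
+ >>> # and the element-wise max over the last two frames is returned as the observation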
+ """
+
+ def __init__(self, env: gym.Env, skip: int = 4):
+ """
+ Overview:
+ Initialize the MaxAndSkipWrapper.
+ Arguments:
+ - env (:obj:`gym.Env`): The environment to wrap.
+ - skip (:obj:`int`): The number of frames to skip, i.e. return only every ``skip``-th frame. Defaults to 4.
+ """
+ super().__init__(env)
+ self._skip = skip
+
+ def step(self, action: Union[int, np.ndarray]) -> tuple:
+ """
+ Overview:
+ Take the given action and repeat it for a specified number of steps. \
+ The rewards are summed up and the maximum frame over the last observations is returned.
+ Arguments:
+ - action (:obj:`Any`): The action to repeat.
+ Returns:
+ - max_frame (:obj:`np.array`): Max over last observations
+ - total_reward (:obj:`Any`): Sum of rewards after previous action.
+ - done (:obj:`Bool`): Whether the episode has ended.
+ - info (:obj:`Dict`): Contains auxiliary diagnostic information (helpful for \
+ debugging, and sometimes learning)
+ """
+ obs_list, total_reward, done = [], 0., False
+ for i in range(self._skip):
+ obs, reward, done, info = self.env.step(action)
+ obs_list.append(obs)
+ total_reward += reward
+ if done:
+ break
+ max_frame = np.max(obs_list[-2:], axis=0)
+ return max_frame, total_reward, done, info
+
+
+@ENV_WRAPPER_REGISTRY.register('warp_frame')
+class WarpFrameWrapper(gym.ObservationWrapper):
+ """
+ Overview:
+ The WarpFrameWrapper class is a gym observation wrapper that resizes
+ the frame of an environment observation to a specified size (default is 84x84).
+ This is often used in the preprocessing pipeline of observations in reinforcement learning,
+ especially for visual observations from Atari environments.
+ Interfaces:
+ __init__, observation
+ Properties:
+ - env (:obj:`gym.Env`): the environment to wrap.
+ - size (:obj:`int`): the size to which the frames are to be resized.
+ - observation_space (:obj:`gym.Space`): the observation space of the wrapped environment.
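+ Examples:
+ Note: illustrative sketch only; assumes opencv-python and an Atari env with (210, 160, 3) RGB frames.
+ >>> import gym
+ >>> env = WarpFrameWrapper(gym.make('PongNoFrameskip-v4'), size=84)
+ >>> obs = env.reset()  # each RGB frame is converted to grayscale and resized, so obs has shape (84, 84)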
+ """
+
+ def __init__(self, env: gym.Env, size: int = 84):
+ """
+ Overview:
+ Constructor for WarpFrameWrapper class, initializes the environment and the size.
+ Arguments:
+ - env (:obj:`gym.Env`): the environment to wrap.
+ - size (:obj:`int`): the size to which the frames are to be resized. Default is 84.
+ """
+ super().__init__(env)
+ self.size = size
+ obs_space = env.observation_space
+ if not isinstance(obs_space, gym.spaces.tuple.Tuple):
+ obs_space = (obs_space, )
+ self.observation_space = gym.spaces.tuple.Tuple(
+ [
+ gym.spaces.Box(
+ low=np.min(obs_space[0].low),
+ high=np.max(obs_space[0].high),
+ shape=(self.size, self.size),
+ dtype=obs_space[0].dtype
+ ) for _ in range(len(obs_space))
+ ]
+ )
+ if len(self.observation_space) == 1:
+ self.observation_space = self.observation_space[0]
+
+ def observation(self, frame: np.ndarray) -> np.ndarray:
+ """
+ Overview:
+ Resize the frame (observation) to the desired size.
+ Arguments:
+ - frame (:obj:`np.ndarray`): the frame to be resized.
+ Returns:
+ - frame (:obj:`np.ndarray`): the resized frame.
+ """
+ try:
+ import cv2
+ except ImportError:
+ from ditk import logging
+ import sys
+ logging.warning("Please install opencv-python first.")
+ sys.exit(1)
+ # deal with the `channel_first` case
+ if frame.shape[0] < 10:
+ frame = frame.transpose(1, 2, 0)
+ frame = cv2.resize(frame, (self.size, self.size), interpolation=cv2.INTER_AREA)
+ frame = frame.transpose(2, 0, 1)
+ else:
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
+ frame = cv2.resize(frame, (self.size, self.size), interpolation=cv2.INTER_AREA)
+
+ return frame
+
+
+@ENV_WRAPPER_REGISTRY.register('scaled_float_frame')
+class ScaledFloatFrameWrapper(gym.ObservationWrapper):
+ """
+ Overview:
+ The ScaledFloatFrameWrapper normalizes observations to between 0 and 1.
+ Interfaces:
+ __init__, observation
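+ Examples:
+ Note: illustrative sketch only; assumes an env with uint8 pixel observations in [0, 255].
+ >>> import gym
+ >>> env = ScaledFloatFrameWrapper(gym.make('PongNoFrameskip-v4'))
+ >>> obs = env.reset()  # observations are shifted and scaled to float32 values in [0, 1]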
+ """
+
+ def __init__(self, env: gym.Env):
+ """
+ Overview:
+ Initialize the ScaledFloatFrameWrapper, setting the scale and bias for normalization.
+ Arguments:
+ - env (:obj:`gym.Env`): the environment to wrap.
+ """
+ super().__init__(env)
+ low = np.min(env.observation_space.low)
+ high = np.max(env.observation_space.high)
+ self.bias = low
+ self.scale = high - low
+ self.observation_space = gym.spaces.Box(low=0., high=1., shape=env.observation_space.shape, dtype=np.float32)
+
+ def observation(self, observation: np.ndarray) -> np.ndarray:
+ """
+ Overview:
+ Scale the observation to be within the range [0, 1].
+ Arguments:
+ - observation (:obj:`np.ndarray`): the original observation.
+ Returns:
+ - scaled_observation (:obj:`np.ndarray`): the scaled observation.
+ """
+ return ((observation - self.bias) / self.scale).astype('float32')
+
+
+@ENV_WRAPPER_REGISTRY.register('clip_reward')
+class ClipRewardWrapper(gym.RewardWrapper):
+ """
+ Overview:
+ The ClipRewardWrapper class is a gym reward wrapper that clips the reward to {-1, 0, +1} based on its sign.
+ This can be used to normalize the scale of the rewards in reinforcement learning algorithms.
+ Interfaces:
+ __init__, reward
+ Properties:
+ - env (:obj:`gym.Env`): the environment to wrap.
+ - reward_range (:obj:`Tuple[int, int]`): the range of the reward values after clipping.
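+ Examples:
+ Note: illustrative sketch only; any gym env can be wrapped, since only the reward is transformed.
+ >>> import gym
+ >>> env = ClipRewardWrapper(gym.make('CartPole-v1'))
+ >>> r = env.reward(2.5)  # np.sign(2.5) -> 1.0; negative rewards map to -1.0 and zero stays 0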
+ """
+
+ def __init__(self, env: gym.Env):
+ """
+ Overview:
+ Initialize the ClipRewardWrapper class.
+ Arguments:
+ - env (:obj:`gym.Env`): the environment to wrap.
+ """
+ super().__init__(env)
+ self.reward_range = (-1, 1)
+
+ def reward(self, reward: float) -> float:
+ """
+ Overview:
+ Clip the reward to {-1, 0, +1} based on its sign. Note: np.sign(0) == 0.
+ Arguments:
+ - reward (:obj:`float`): the original reward.
+ Returns:
+ - reward (:obj:`float`): the clipped reward.
+ """
+ return np.sign(reward)
+
+
+@ENV_WRAPPER_REGISTRY.register('action_repeat')
+class ActionRepeatWrapper(gym.Wrapper):
+ """
+ Overview:
+ The ActionRepeatWrapper class is a gym wrapper that repeats the same action for a number of steps.
+ This wrapper is particularly useful in environments where the desired effect is achieved by maintaining
+ the same action across multiple time steps. For instance, some physical environments like motion control
+ tasks might require consistent force input to produce a significant state change.
+
+ Using this wrapper can reduce the temporal complexity of the problem, as it allows the agent to perform
+ multiple actions within a single time step. This can speed up learning, as the agent has fewer decisions
+ to make within a time step. However, it may also sacrifice some level of decision-making precision, as the
+ agent cannot change its action across successive time steps.
+
+ Note that the use of the ActionRepeatWrapper may not be suitable for all types of environments. Specifically,
+ it may not be the best choice for environments where new decisions must be made at each time step, or where
+ the time sequence of actions has a significant impact on the outcome.
+ Interfaces:
+ __init__, step
+ Properties:
+ - env (:obj:`gym.Env`): the environment to wrap.
+ - action_repeat (:obj:`int`): the number of times to repeat the action.
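+ Examples:
+ Note: illustrative sketch only; assumes a continuous-control env such as 'Pendulum-v1' is available.
+ >>> import gym
+ >>> env = ActionRepeatWrapper(gym.make('Pendulum-v1'), action_repeat=2)
+ >>> obs = env.reset()
+ >>> obs, reward, done, info = env.step(env.action_space.sample())
+ >>> # the sampled action is applied twice and the two step rewards are summed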
+ """
+
+ def __init__(self, env: gym.Env, action_repeat: int = 1):
+ """
+ Overview:
+ Initialize the ActionRepeatWrapper class.
+ Arguments:
+ - env (:obj:`gym.Env`): the environment to wrap.
+ - action_repeat (:obj:`int`): the number of times to repeat the action. Default is 1.
+ """
+ super().__init__(env)
+ self.action_repeat = action_repeat
+
+ def step(self, action: Union[int, np.ndarray]) -> tuple:
+ """
+ Overview:
+ Take the given action and repeat it for a specified number of steps. The rewards are summed up.
+ Arguments:
+ - action (:obj:`Union[int, np.ndarray]`): The action to repeat.
+ Returns:
+ - obs (:obj:`np.ndarray`): The observation after repeating the action.
+ - reward (:obj:`float`): The sum of rewards after repeating the action.
+ - done (:obj:`bool`): Whether the episode has ended.
+ - info (:obj:`Dict`): Contains auxiliary diagnostic information.
+ """
+ reward = 0
+ for _ in range(self.action_repeat):
+ obs, rew, done, info = self.env.step(action)
+ reward += rew or 0
+ if done:
+ break
+ return obs, reward, done, info
+
+
+@ENV_WRAPPER_REGISTRY.register('delay_reward')
+class DelayRewardWrapper(gym.Wrapper):
+ """
+ Overview:
+ The DelayRewardWrapper class is a gym wrapper that delays the reward. It cumulates the reward over a
+ predefined number of steps and returns the cumulated reward only at the end of this interval.
+ At other times, it returns a reward of 0.
+
+ This wrapper is particularly useful in environments where the impact of an action is not immediately
+ observable, but rather delayed over several steps. For instance, in strategic games or planning tasks,
+ the effect of an action may not be directly noticeable, but it contributes to a sequence of actions that
+ leads to a reward. In these cases, delaying the reward to match the action-effect delay can make the
+ learning process more consistent with the problem's nature.
+
+ However, using this wrapper may increase the difficulty of learning, as the agent needs to associate its
+ actions with delayed outcomes. It also introduces a non-standard reward structure, which could limit the
+ applicability of certain reinforcement learning algorithms.
+
+ Note that the use of the DelayRewardWrapper may not be suitable for all types of environments. Specifically,
+ it may not be the best choice for environments where the effect of actions is immediately observable and the
+ reward should be assigned accordingly.
+ Interfaces:
+ __init__, reset, step
+ Properties:
+ - env (:obj:`gym.Env`): the environment to wrap.
+ - delay_reward_step (:obj:`int`): the number of steps over which to delay and cumulate the reward.
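+ Examples:
+ Note: illustrative sketch only, using 'CartPole-v1'; any gym env can be wrapped.
+ >>> import gym
+ >>> env = DelayRewardWrapper(gym.make('CartPole-v1'), delay_reward_step=10)
+ >>> obs = env.reset()
+ >>> obs, reward, done, info = env.step(env.action_space.sample())
+ >>> # reward stays 0. until 10 steps have accumulated (or the episode ends), then the summed reward is returned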
+ """
+
+ def __init__(self, env: gym.Env, delay_reward_step: int = 0):
+ """
+ Overview:
+ Initialize the DelayRewardWrapper class.
+ Arguments:
+ - env (:obj:`gym.Env`): the environment to wrap.
+ - delay_reward_step (:obj:`int`): the number of steps over which to delay and cumulate the reward.
+ """
+ super().__init__(env)
+ self._delay_reward_step = delay_reward_step
+
+ def reset(self) -> np.ndarray:
+ """
+ Overview:
+ Resets the state of the environment and resets the delay reward duration and current delay reward.
+ Returns:
+ - obs (:obj:`np.ndarray`): the initial observation of the environment.
+ """
+ self._delay_reward_duration = 0
+ self._current_delay_reward = 0.
+ obs = self.env.reset()
+ return obs
+
+ def step(self, action: Union[int, np.ndarray]) -> tuple:
+ """
+ Overview:
+ Take the given action and repeat it for a specified number of steps. The rewards are summed up.
+ If the number of steps equals the delay reward step, return the cumulated reward and reset the
+ delay reward duration and current delay reward. Otherwise, return a reward of 0.
+ Arguments:
+ - action (:obj:`Union[int, np.ndarray]`): the action to take in the step.
+ Returns:
+ - obs (:obj:`np.ndarray`): The observation after the step.
+ - reward (:obj:`float`): The cumulated reward after the delay reward step or 0.
+ - done (:obj:`bool`): Whether the episode has ended.
+ - info (:obj:`Dict`): Contains auxiliary diagnostic information.
+ """
+ obs, reward, done, info = self.env.step(action)
+ self._current_delay_reward += reward
+ self._delay_reward_duration += 1
+ if done or self._delay_reward_duration >= self._delay_reward_step:
+ reward = self._current_delay_reward
+ self._current_delay_reward = 0.
+ self._delay_reward_duration = 0
+ else:
+ reward = 0.
+ return obs, reward, done, info
+
+
+@ENV_WRAPPER_REGISTRY.register('eval_episode_return')
+class EvalEpisodeReturnWrapper(gym.Wrapper):
+ """
+ Overview:
+ A wrapper for a gym environment that accumulates rewards at every timestep, and returns the total reward at the
+ end of the episode in `info`. This is used for evaluation purposes.
+ Interfaces:
+ __init__, reset, step
+ Properties:
+ - env (:obj:`gym.Env`): the environment to wrap.
+ """
+
+ def __init__(self, env: gym.Env):
+ """
+ Overview:
+ Initialize the EvalEpisodeReturnWrapper. This involves setting up the environment to wrap.
+ Arguments:
+ - env (:obj:`gym.Env`): The environment to wrap.
+ """
+ super().__init__(env)
+
+ def reset(self) -> np.ndarray:
+ """
+ Overview:
+ Reset the environment and initialize the accumulated reward to zero.
+ Returns:
+ - obs (:obj:`np.ndarray`): The initial observation from the environment.
+ """
+ self._eval_episode_return = 0.
+ return self.env.reset()
+
+ def step(self, action: Any) -> tuple:
+ """
+ Overview:
+ Step the environment with the provided action, accumulate the returned reward, and add the total reward to
+ `info` if the episode is done.
+ Arguments:
+ - action (:obj:`Any`): The action to take in the environment.
+ Returns:
+ - obs (:obj:`np.ndarray`): The next observation from the environment.
+ - reward (:obj:`float`): The reward from taking the action.
+ - done (:obj:`bool`): Whether the episode is done.
+ - info (:obj:`Dict[str, Any]`): A dictionary of extra information, which includes 'eval_episode_return' if
+ the episode is done.
+ Examples:
+ >>> env = gym.make("CartPole-v1")
+ >>> env = EvalEpisodeReturnWrapper(env)
+ >>> obs = env.reset()
+ >>> done = False
+ >>> while not done:
+ ... action = env.action_space.sample() # Replace with your own policy
+ ... obs, reward, done, info = env.step(action)
+ ... if done:
+ ... print("Total episode reward:", info['eval_episode_return'])
+ """
+ obs, reward, done, info = self.env.step(action)
+ self._eval_episode_return += reward
+ if done:
+ info['eval_episode_return'] = to_ndarray([self._eval_episode_return], dtype=np.float32)
+ return obs, reward, done, info
+
+
+@ENV_WRAPPER_REGISTRY.register('frame_stack')
+class FrameStackWrapper(gym.Wrapper):
+ """
+ Overview:
+ FrameStackWrapper is a gym environment wrapper that stacks the latest n frames (generally 4 in Atari)
+ as a single observation. It is commonly used in environments where the observation is an image,
+ and consecutive frames provide useful temporal information for the agent.
+ Interfaces:
+ __init__, reset, step, _get_ob
+ Properties:
+ - env (:obj:`gym.Env`): The environment to wrap.
+ - n_frames (:obj:`int`): The number of frames to stack.
+ - frames (:obj:`collections.deque`): A queue that holds the most recent frames.
+ - observation_space (:obj:`gym.Space`): The space of the stacked observations.
+ """
+
+ def __init__(self, env: gym.Env, n_frames: int = 4) -> None:
+ """
+ Overview:
+ Initialize the FrameStackWrapper. This process includes setting up the environment to wrap,
+ the number of frames to stack, and the observation space.
+ Arguments:
+ - env (:obj:`gym.Env`): The environment to wrap.
+ - n_frames (:obj:`int`): The number of frames to stack.
+ """
+ super().__init__(env)
+ self.n_frames = n_frames
+ self.frames = deque([], maxlen=n_frames)
+ obs_space = env.observation_space
+ if not isinstance(obs_space, gym.spaces.tuple.Tuple):
+ obs_space = (obs_space, )
+ shape = (n_frames, ) + obs_space[0].shape
+ self.observation_space = gym.spaces.tuple.Tuple(
+ [
+ gym.spaces.Box(
+ low=np.min(obs_space[0].low), high=np.max(obs_space[0].high), shape=shape, dtype=obs_space[0].dtype
+ ) for _ in range(len(obs_space))
+ ]
+ )
+ if len(self.observation_space) == 1:
+ self.observation_space = self.observation_space[0]
+
+ def reset(self) -> np.ndarray:
+ """
+ Overview:
+ Reset the environment and initialize frames with the initial observation.
+ Returns:
+ - init_obs (:obj:`np.ndarray`): The stacked initial observations.
+ """
+ obs = self.env.reset()
+ for _ in range(self.n_frames):
+ self.frames.append(obs)
+ return self._get_ob()
+
+ def step(self, action: Any) -> Tuple[np.ndarray, float, bool, Dict[str, Any]]:
+ """
+ Overview:
+ Perform a step in the environment with the given action, append the returned observation
+ to frames, and return the stacked observations.
+ Arguments:
+ - action (:obj:`Any`): The action to perform a step with.
+ Returns:
+ - self._get_ob() (:obj:`np.ndarray`): The stacked observations.
+ - reward (:obj:`float`): The amount of reward returned after the previous action.
+ - done (:obj:`bool`): Whether the episode has ended, in which case further step() calls will return
+ undefined results.
+ - info (:obj:`Dict[str, Any]`): Contains auxiliary diagnostic information (helpful for debugging,
+ and sometimes learning).
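+ Examples:
+ An illustrative sketch only; the Atari environment id "PongNoFrameskip-v4" is an assumption:
+ >>> env = FrameStackWrapper(gym.make("PongNoFrameskip-v4"), n_frames=4)
+ >>> obs = env.reset()
+ >>> obs, reward, done, info = env.step(env.action_space.sample())
+ >>> assert obs.shape[0] == 4 # the stacked-frame dimension comes first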
+ """
+ obs, reward, done, info = self.env.step(action)
+ self.frames.append(obs)
+ return self._get_ob(), reward, done, info
+
+ def _get_ob(self) -> np.ndarray:
+ """
+ Overview:
+ Return the stacked frames. The original wrapper used `LazyFrames`, but since a plain numpy buffer
+ is used here, the frames are simply stacked with ``np.stack``.
+ Returns:
+ - stacked_frames (:obj:`np.ndarray`): The stacked frames.
+ """
+ return np.stack(self.frames, axis=0)
+
+
+@ENV_WRAPPER_REGISTRY.register('obs_transpose')
+class ObsTransposeWrapper(gym.ObservationWrapper):
+ """
+ Overview:
+ The ObsTransposeWrapper class is a gym wrapper that transposes the observation to put the channel dimension
+ first. This can be helpful for certain types of neural networks that expect the channel dimension to be
+ the first dimension.
+ Interfaces:
+ __init__, observation
+ Properties:
+ - env (:obj:`gym.Env`): The environment to wrap.
+ - observation_space (:obj:`gym.spaces.Box`): The transformed observation space.
+ """
+
+ def __init__(self, env: gym.Env):
+ """
+ Overview:
+ Initialize the ObsTransposeWrapper class and update the observation space according to the environment's
+ observation space.
+ Arguments:
+ - env (:obj:`gym.Env`): The environment to wrap.
+ """
+ super().__init__(env)
+ obs_space = env.observation_space
+ if isinstance(obs_space, gym.spaces.tuple.Tuple):
+ self.observation_space = gym.spaces.Box(
+ low=np.min(obs_space[0].low),
+ high=np.max(obs_space[0].high),
+ shape=(len(obs_space), obs_space[0].shape[2], obs_space[0].shape[0], obs_space[0].shape[1]),
+ dtype=obs_space[0].dtype
+ )
+ else:
+ self.observation_space = gym.spaces.Box(
+ low=np.min(obs_space.low),
+ high=np.max(obs_space.high),
+ shape=(obs_space.shape[2], obs_space.shape[0], obs_space.shape[1]),
+ dtype=obs_space.dtype
+ )
+
+ def observation(self, obs: Union[tuple, np.ndarray]) -> Union[tuple, np.ndarray]:
+ """
+ Overview:
+ Transpose the observation to put the channel dimension first. If the observation is a tuple, each element
+ in the tuple is transposed independently.
+ Arguments:
+ - obs (:obj:`Union[tuple, np.ndarray]`): The original observation.
+ Returns:
+ - obs (:obj:`Union[tuple, np.ndarray]`): The transposed observation.
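+ Examples:
+ A minimal sketch; the (84, 84, 3) image shape and the `wrapper` instance are only assumptions:
+ >>> obs = np.zeros((84, 84, 3), dtype=np.float32)
+ >>> transposed = wrapper.observation(obs) # `wrapper` is an assumed ObsTransposeWrapper instance
+ >>> transposed.shape
+ (3, 84, 84)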
+ """
+ if isinstance(obs, tuple):
+ new_obs = []
+ for i in range(len(obs)):
+ new_obs.append(obs[i].transpose(2, 0, 1))
+ obs = np.stack(new_obs)
+ else:
+ obs = obs.transpose(2, 0, 1)
+ return obs
+
+
+class RunningMeanStd(object):
+ """
+ Overview:
+ The RunningMeanStd class is a utility that maintains a running mean and standard deviation calculation over
+ a stream of data.
+ Interfaces:
+ __init__, update, reset, mean, std
+ Properties:
+ - mean (:obj:`np.ndarray`): The running mean.
+ - std (:obj:`np.ndarray`): The running standard deviation.
+ - _epsilon (:obj:`float`): A small number to prevent division by zero when calculating standard deviation.
+ - _shape (:obj:`tuple`): The shape of the data stream.
+ - _mean (:obj:`np.ndarray`): The current mean of the data stream.
+ - _var (:obj:`np.ndarray`): The current variance of the data stream.
+ - _count (:obj:`float`): The number of data points processed.
+ """
+
+ def __init__(self, epsilon: float = 1e-4, shape: tuple = ()):
+ """
+ Overview:
+ Initialize the RunningMeanStd object.
+ Arguments:
+ - epsilon (:obj:`float`, optional): A small number to prevent division by zero when calculating standard
+ deviation. Default is 1e-4.
+ - shape (:obj:`tuple`, optional): The shape of the data stream. Default is an empty tuple, which
+ corresponds to scalars.
+ """
+ self._epsilon = epsilon
+ self._shape = shape
+ self.reset()
+
+ def update(self, x: np.array):
+ """
+ Overview:
+ Update the running statistics with a new batch of data.
+ Arguments:
+ - x (:obj:`np.array`): A batch of data.
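+ Examples:
+ A minimal sketch of maintaining running statistics over random batches (the shapes are assumptions):
+ >>> rms = RunningMeanStd(shape=(3, ))
+ >>> rms.update(np.random.randn(16, 3)) # a batch of 16 samples with 3 features each
+ >>> normalized = (np.random.randn(3) - rms.mean) / rms.std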
+ """
+ batch_mean = np.mean(x, axis=0)
+ batch_var = np.var(x, axis=0)
+ batch_count = x.shape[0]
+
+ new_count = batch_count + self._count
+ mean_delta = batch_mean - self._mean
+ new_mean = self._mean + mean_delta * batch_count / new_count
+ # this method for calculating the new variance might be numerically unstable
+ m_a = self._var * self._count
+ m_b = batch_var * batch_count
+ m2 = m_a + m_b + np.square(mean_delta) * self._count * batch_count / new_count
+ new_var = m2 / new_count
+ self._mean = new_mean
+ self._var = new_var
+ self._count = new_count
+
+ def reset(self):
+ """
+ Overview:
+ Reset the running statistics to their initial values: \
+ ``_mean``, ``_var``, ``_count``.
+ """
+ self._mean = np.zeros(self._shape, 'float64')
+ self._var = np.ones(self._shape, 'float64')
+ self._count = self._epsilon
+
+ @property
+ def mean(self) -> np.ndarray:
+ """
+ Overview:
+ Get the current running mean.
+ Returns:
+ The current running mean.
+ """
+ return self._mean
+
+ @property
+ def std(self) -> np.ndarray:
+ """
+ Overview:
+ Get the current running standard deviation.
+ Returns:
+ The current running standard deviation (with a small epsilon added for numerical stability).
+ """
+ return np.sqrt(self._var) + self._epsilon
+
+
+@ENV_WRAPPER_REGISTRY.register('obs_norm')
+class ObsNormWrapper(gym.ObservationWrapper):
+ """
+ Overview:
+ The ObsNormWrapper class is a gym observation wrapper that normalizes
+ observations according to running mean and standard deviation (std).
+ Interfaces:
+ __init__, step, reset, observation
+ Properties:
+ - env (:obj:`gym.Env`): the environment to wrap.
+ - data_count (:obj:`int`): the count of data points observed so far.
+ - clip_range (:obj:`Tuple[int, int]`): the range to clip the normalized observation.
+ - rms (:obj:`RunningMeanStd`): running mean and standard deviation of the observations.
+ """
+
+ def __init__(self, env: gym.Env):
+ """
+ Overview:
+ Initialize the ObsNormWrapper class.
+ Arguments:
+ - env (:obj:`gym.Env`): the environment to wrap.
+ """
+ super().__init__(env)
+ self.data_count = 0
+ self.clip_range = (-3, 3)
+ self.rms = RunningMeanStd(shape=env.observation_space.shape)
+
+ def step(self, action: Union[int, np.ndarray]):
+ """
+ Overview:
+ Take an action in the environment, update the running mean and std,
+ and return the normalized observation.
+ Arguments:
+ - action (:obj:`Union[int, np.ndarray]`): the action to take in the environment.
+ Returns:
+ - obs (:obj:`np.ndarray`): the normalized observation after the action.
+ - reward (:obj:`float`): the reward after the action.
+ - done (:obj:`bool`): whether the episode has ended.
+ - info (:obj:`Dict`): contains auxiliary diagnostic information.
+ """
+ self.data_count += 1
+ observation, reward, done, info = self.env.step(action)
+ self.rms.update(observation)
+ return self.observation(observation), reward, done, info
+
+ def observation(self, observation: np.ndarray) -> np.ndarray:
+ """
+ Overview:
+ Normalize the observation using the current running mean and std.
+ If less than 30 data points have been observed, return the original observation.
+ Arguments:
+ - observation (:obj:`np.ndarray`): the original observation.
+ Returns:
+ - observation (:obj:`np.ndarray`): the normalized observation.
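+ Examples:
+ An illustrative sketch; the environment id "Pendulum-v1" is only an assumption:
+ >>> env = ObsNormWrapper(gym.make("Pendulum-v1"))
+ >>> obs = env.reset()
+ >>> for _ in range(40):
+ ... obs, reward, done, info = env.step(env.action_space.sample())
+ >>> # after more than 30 observed steps, obs is normalized and clipped to [-3, 3]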
+ """
+ if self.data_count > 30:
+ return np.clip((observation - self.rms.mean) / self.rms.std, self.clip_range[0], self.clip_range[1])
+ else:
+ return observation
+
+ def reset(self, **kwargs):
+ """
+ Overview:
+ Reset the environment and the properties related to the running mean and std.
+ Arguments:
+ - kwargs (:obj:`Dict`): keyword arguments to be passed to the environment's reset function.
+ Returns:
+ - observation (:obj:`np.ndarray`): the initial observation of the environment.
+ """
+ self.data_count = 0
+ self.rms.reset()
+ observation = self.env.reset(**kwargs)
+ return self.observation(observation)
+
+
+@ENV_WRAPPER_REGISTRY.register('static_obs_norm')
+class StaticObsNormWrapper(gym.ObservationWrapper):
+ """
+ Overview:
+ The StaticObsNormWrapper class is a gym observation wrapper that normalizes
+ observations according to a precomputed mean and standard deviation (std) from a fixed dataset.
+ Interfaces:
+ __init__, observation
+ Properties:
+ - env (:obj:`gym.Env`): the environment to wrap.
+ - mean (:obj:`numpy.ndarray`): the mean of the observations in the fixed dataset.
+ - std (:obj:`numpy.ndarray`): the standard deviation of the observations in the fixed dataset.
+ - clip_range (:obj:`Tuple[int, int]`): the range to clip the normalized observation.
+ """
+
+ def __init__(self, env: gym.Env, mean: np.ndarray, std: np.ndarray):
+ """
+ Overview:
+ Initialize the StaticObsNormWrapper class.
+ Arguments:
+ - env (:obj:`gym.Env`): the environment to wrap.
+ - mean (:obj:`numpy.ndarray`): the mean of the observations in the fixed dataset.
+ - std (:obj:`numpy.ndarray`): the standard deviation of the observations in the fixed dataset.
+ """
+ super().__init__(env)
+ self.mean = mean
+ self.std = std
+ self.clip_range = (-3, 3)
+
+ def observation(self, observation: np.ndarray) -> np.ndarray:
+ """
+ Overview:
+ Normalize the given observation using the precomputed mean and std.
+ The normalized observation is then clipped within the specified range.
+ Arguments:
+ - observation (:obj:`np.ndarray`): the original observation.
+ Returns:
+ - observation (:obj:`np.ndarray`): the normalized and clipped observation.
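+ Examples:
+ A minimal sketch; the statistics below are placeholders that would normally come from a fixed dataset,
+ and the environment id is only an assumption:
+ >>> mean, std = np.zeros(3), np.ones(3)
+ >>> env = StaticObsNormWrapper(gym.make("Pendulum-v1"), mean, std)
+ >>> obs = env.reset() # observations are normalized with the fixed statistics and clipped to [-3, 3]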
+ """
+ return np.clip((observation - self.mean) / self.std, self.clip_range[0], self.clip_range[1])
+
+
+@ENV_WRAPPER_REGISTRY.register('reward_norm')
+class RewardNormWrapper(gym.RewardWrapper):
+ """
+ Overview:
+ This wrapper class normalizes the reward according to running std. It extends the `gym.RewardWrapper`.
+ Interfaces:
+ __init__, step, reward, reset
+ Properties:
+ - env (:obj:`gym.Env`): The environment to wrap.
+ - cum_reward (:obj:`numpy.ndarray`): The cumulated reward, initialized as zero and updated in `step` method.
+ - reward_discount (:obj:`float`): The discount factor for reward.
+ - data_count (:obj:`int`): A counter for data, incremented in each `step` call.
+ - rms (:obj:`RunningMeanStd`): An instance of RunningMeanStd to compute the running mean and std of reward.
+ """
+
+ def __init__(self, env: gym.Env, reward_discount: float) -> None:
+ """
+ Overview:
+ Initialize the RewardNormWrapper and set up the properties related to the running mean and std.
+ Arguments:
+ - env (:obj:`gym.Env`): The environment to wrap.
+ - reward_discount (:obj:`float`): The discount factor for reward.
+ """
+ super().__init__(env)
+ self.cum_reward = np.zeros((1, ), 'float64')
+ self.reward_discount = reward_discount
+ self.data_count = 0
+ self.rms = RunningMeanStd(shape=(1, ))
+
+ def step(self, action: Any) -> Tuple[np.ndarray, float, bool, Dict]:
+ """
+ Overview:
+ Step the environment with the given action, update properties and return the new observation, reward,
+ done status and info.
+ Arguments:
+ - action (:obj:`Any`): The action to execute in the environment.
+ Returns:
+ - observation (:obj:`np.ndarray`): The observation after executing the action (unchanged by this wrapper).
+ - reward (:obj:`float`): The normalized reward after the action execution; `self.cum_reward` and
+ `self.rms` are updated as a side effect.
+ - done (:obj:`bool`): Whether the episode has ended, in which case further step() calls will return
+ undefined results.
+ - info (:obj:`Dict`): Contains auxiliary diagnostic information (helpful for debugging, and sometimes
+ learning).
+ """
+ self.data_count += 1
+ observation, reward, done, info = self.env.step(action)
+ reward = np.array([reward], 'float64')
+ self.cum_reward = self.cum_reward * self.reward_discount + reward
+ self.rms.update(self.cum_reward)
+ return observation, self.reward(reward), done, info
+
+ def reward(self, reward: float) -> float:
+ """
+ Overview:
+ Normalize reward if `data_count` is more than 30.
+ Arguments:
+ - reward (:obj:`float`): The raw reward.
+ Returns:
+ - reward (:obj:`float`): Normalized reward.
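+ Examples:
+ An illustrative sketch; the environment id "CartPole-v1" is only an assumption:
+ >>> env = RewardNormWrapper(gym.make("CartPole-v1"), reward_discount=0.99)
+ >>> obs = env.reset()
+ >>> obs, reward, done, info = env.step(env.action_space.sample())
+ >>> # after more than 30 steps, reward is divided by the running std of the discounted cumulative reward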
+ """
+ if self.data_count > 30:
+ return float(reward / self.rms.std)
+ else:
+ return float(reward)
+
+ def reset(self, **kwargs):
+ """
+ Overview:
+ Resets the state of the environment and resets the wrapper's properties: \
+ ``cum_reward`` and ``data_count`` are set to 0, and ``self.rms`` is reset.
+ Arguments:
+ - kwargs (:obj:`Dict`): Keyword arguments passed to the environment's reset function.
+ """
+ self.cum_reward = 0.
+ self.data_count = 0
+ self.rms.reset()
+ return self.env.reset(**kwargs)
+
+
+@ENV_WRAPPER_REGISTRY.register('ram')
+class RamWrapper(gym.Wrapper):
+ """
+ Overview:
+ This wrapper class wraps a RAM environment into an image-like environment. It extends the `gym.Wrapper`.
+ Interfaces:
+ __init__, reset, step
+ Properties:
+ - env (:obj:`gym.Env`): The environment to wrap.
+ - observation_space (:obj:`gym.spaces.Box`): The observation space of the wrapped environment.
+ """
+
+ def __init__(self, env: gym.Env, render: bool = False) -> None:
+ """
+ Overview:
+ Initialize the RamWrapper and set up the observation space to wrap the RAM environment.
+ Arguments:
+ - env (:obj:`gym.Env`): The environment to wrap.
+ - render (:obj:`bool`): Whether to render the environment, default is False.
+ """
+ super().__init__(env)
+ shape = env.observation_space.shape + (1, 1)
+ self.observation_space = gym.spaces.Box(
+ low=np.min(env.observation_space.low),
+ high=np.max(env.observation_space.high),
+ shape=shape,
+ dtype=np.float32
+ )
+
+ def reset(self) -> np.ndarray:
+ """
+ Overview:
+ Resets the state of the environment and returns a reshaped observation.
+ Returns:
+ - observation (:obj:`np.ndarray`): New observation after reset and reshaped.
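+ Examples:
+ A minimal sketch; the Atari RAM environment id "Pong-ramNoFrameskip-v4" is only an assumption:
+ >>> env = RamWrapper(gym.make("Pong-ramNoFrameskip-v4"))
+ >>> obs = env.reset()
+ >>> obs.shape # the 128-byte RAM vector is reshaped into an image-like array
+ (128, 1, 1)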
+ """
+ obs = self.env.reset()
+ return obs.reshape(128, 1, 1).astype(np.float32)
+
+ def step(self, action: Any) -> Tuple[np.ndarray, Any, bool, Dict]:
+ """
+ Overview:
+ Execute one step within the environment with the given action and reshape the returned RAM
+ observation into an image-like array.
+ Arguments:
+ - action (:obj:`Any`): The action to take in the environment.
+ Returns:
+ - observation (:obj:`np.ndarray`): Reshaped observation after step with type restriction.
+ - reward (:obj:`Any`): Amount of reward returned after previous action.
+ - done (:obj:`bool`): Whether the episode has ended, in which case further step() calls will return
+ undefined results.
+ - info (:obj:`Dict`): Contains auxiliary diagnostic information (helpful for debugging, and sometimes
+ learning).
+ """
+ obs, reward, done, info = self.env.step(action)
+ return obs.reshape(128, 1, 1).astype(np.float32), reward, done, info
+
+
+@ENV_WRAPPER_REGISTRY.register('episodic_life')
+class EpisodicLifeWrapper(gym.Wrapper):
+ """
+ Overview:
+ This wrapper makes end-of-life equivalent to end-of-episode, but only resets on
+ true game over. This helps in better value estimation.
+ Interfaces:
+ __init__, step, reset
+ Properties:
+ - env (:obj:`gym.Env`): The environment to wrap.
+ - lives (:obj:`int`): The current number of lives.
+ - was_real_done (:obj:`bool`): Whether the last episode was ended due to game over.
+ """
+
+ def __init__(self, env: gym.Env) -> None:
+ """
+ Overview:
+ Initialize the EpisodicLifeWrapper, setting lives to 0 and was_real_done to True.
+ Arguments:
+ - env (:obj:`gym.Env`): The environment to wrap.
+ """
+ super().__init__(env)
+ self.lives = 0
+ self.was_real_done = True
+
+ def step(self, action: Any) -> Tuple[np.ndarray, float, bool, Dict]:
+ """
+ Overview:
+ Execute the given action in the environment, update properties based on the new
+ state and return the new observation, reward, done status and info.
+ Arguments:
+ - action (:obj:`Any`): The action to execute in the environment.
+ Returns:
+ - observation (:obj:`np.ndarray`): The observation after the action execution.
+ - reward (:obj:`float`): Amount of reward returned after the action execution.
+ - done (:obj:`bool`): Whether the episode has ended, in which case further step() calls will return
+ undefined results.
+ - info (:obj:`Dict`): Contains auxiliary diagnostic information (helpful for debugging, and
+ sometimes learning).
+ """
+ obs, reward, done, info = self.env.step(action)
+ self.was_real_done = done
+ # check current lives, make loss of life terminal, then update lives to
+ # handle bonus lives
+ lives = self.env.unwrapped.ale.lives()
+ if 0 < lives < self.lives:
+ # For Qbert sometimes we stay in lives == 0 condition for a few frames,
+ # so it is important to keep lives > 0, so that we only reset
+ # once the environment is actually done.
+ done = True
+ self.lives = lives
+ return obs, reward, done, info
+
+ def reset(self) -> np.ndarray:
+ """
+ Overview:
+ Reset the underlying environment only when lives are exhausted (true game over); otherwise take a
+ no-op step to advance from the terminal/lost-life state. Then update the number of lives. This way
+ all states are still reachable even though lives are episodic, and the learner need not know about
+ any of this behind the scenes.
+ Returns:
+ - observation (:obj:`np.ndarray`): New observation after reset with no-op step to advance from
+ terminal/lost life state.
+ """
+ if self.was_real_done:
+ obs = self.env.reset()
+ else:
+ # no-op step to advance from terminal/lost life state
+ obs = self.env.step(0)[0]
+ self.lives = self.env.unwrapped.ale.lives()
+ return obs
+
+
+@ENV_WRAPPER_REGISTRY.register('fire_reset')
+class FireResetWrapper(gym.Wrapper):
+ """
+ Overview:
+ This wrapper takes a fire action at environment reset.
+ Related discussion: https://github.com/openai/baselines/issues/240
+ Interfaces:
+ __init__, reset
+ Properties:
+ - env (:obj:`gym.Env`): The environment to wrap.
+ """
+
+ def __init__(self, env: gym.Env) -> None:
+ """
+ Overview:
+ Initialize the FireResetWrapper. Assume that the second action of the environment
+ is 'FIRE' and there are at least three actions.
+ Arguments:
+ - env (:obj:`gym.Env`): The environment to wrap.
+ """
+ super().__init__(env)
+ assert env.unwrapped.get_action_meanings()[1] == 'FIRE'
+ assert len(env.unwrapped.get_action_meanings()) >= 3
+
+ def reset(self) -> np.ndarray:
+ """
+ Overview:
+ Resets the state of the environment and executes a fire action, i.e. reset with action 1.
+ Returns:
+ - observation (:obj:`np.ndarray`): New observation after reset and fire action.
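+ Examples:
+ An illustrative sketch; the environment id "BreakoutNoFrameskip-v4" (which exposes a FIRE action)
+ is only an assumption:
+ >>> env = FireResetWrapper(gym.make("BreakoutNoFrameskip-v4"))
+ >>> obs = env.reset() # internally resets and then presses FIRE (action 1) to start the game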
+ """
+ self.env.reset()
+ return self.env.step(1)[0]
+
+
+@ENV_WRAPPER_REGISTRY.register('gym_hybrid_dict_action')
+class GymHybridDictActionWrapper(gym.ActionWrapper):
+ """
+ Overview:
+ Transform Gym-Hybrid's original `gym.spaces.Tuple` action space to `gym.spaces.Dict`.
+ Interfaces:
+ __init__, step
+ Properties:
+ - env (:obj:`gym.Env`): The environment to wrap.
+ - action_space (:obj:`gym.spaces.Dict`): The new action space.
+ """
+
+ def __init__(self, env: gym.Env) -> None:
+ """
+ Overview:
+ Initialize the GymHybridDictActionWrapper, setting up the new action space.
+ Arguments:
+ - env (:obj:`gym.Env`): The environment to wrap.
+ """
+ super().__init__(env)
+ self.action_space = gym.spaces.Dict(
+ {
+ 'type': gym.spaces.Discrete(3),
+ # shape = (2, ) 0 is for acceleration; 1 is for rotation
+ 'mask': gym.spaces.Box(low=0, high=1, shape=(2, ), dtype=np.int64),
+ 'args': gym.spaces.Box(
+ low=np.array([0., -1.], dtype=np.float32),
+ high=np.array([1., 1.], dtype=np.float32),
+ shape=(2, ),
+ dtype=np.float32
+ ),
+ }
+ )
+
+ def step(self, action: Dict) -> Tuple[Dict, float, bool, Dict]:
+ """
+ Overview:
+ Execute the given action in the environment, transform the action from Dict to Tuple,
+ and return the new observation, reward, done status and info.
+ Arguments:
+ - action (:obj:`Dict`): The action to execute in the environment, structured as a dictionary.
+ Returns:
+ - observation (:obj:`Dict`): The new observation after the action execution.
+ - reward (:obj:`float`): Amount of reward returned after the action execution.
+ - done (:obj:`bool`): Whether the episode has ended, in which case further step() calls will return
+ undefined results.
+ - info (:obj:`Dict`): Contains auxiliary diagnostic information (helpful for debugging, and
+ sometimes learning).
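+ Examples:
+ A minimal sketch of a dict-style action; `env` is an assumed wrapped Gym-Hybrid instance and the
+ concrete values are placeholders:
+ >>> action = {
+ ... 'type': 0, # e.g. the acceleration action type
+ ... 'mask': np.array([1, 0], dtype=np.int64),
+ ... 'args': np.array([0.5, 0.], dtype=np.float32),
+ ... }
+ >>> obs, reward, done, info = env.step(action)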
+ """
+ # # From Dict to Tuple
+ # action_type = action[0]
+ # if action_type == 0:
+ # action_mask = np.array([1, 0], dtype=np.int64)
+ # action_args = np.array([action[1][0], 0], dtype=np.float32)
+ # elif action_type == 1:
+ # action_mask = np.array([0, 1], dtype=np.int64)
+ # action_args = np.array([0, action[1][1]], dtype=np.float32)
+ # elif action_type == 2:
+ # action_mask = np.array([0, 0], dtype=np.int64)
+ # action_args = np.array([0, 0], dtype=np.float32)
+
+ # From Dict to Tuple
+ action_type, action_mask, action_args = action['type'], action['mask'], action['args']
+ return self.env.step((action_type, action_args))
+
+
+@ENV_WRAPPER_REGISTRY.register('obs_plus_prev_action_reward')
+class ObsPlusPrevActRewWrapper(gym.Wrapper):
+ """
+ Overview:
+ This wrapper is used in the NGU policy. It sets a dict as the new wrapped observation,
+ which includes the current observation, the previous action and the previous reward.
+ Interfaces:
+ __init__, reset, step
+ Properties:
+ - env (:obj:`gym.Env`): The environment to wrap.
+ - prev_action (:obj:`int`): The previous action.
+ - prev_reward_extrinsic (:obj:`float`): The previous reward.
+ """
+
+ def __init__(self, env: gym.Env) -> None:
+ """
+ Overview:
+ Initialize the ObsPlusPrevActRewWrapper, setting up the previous action and reward.
+ Arguments:
+ - env (:obj:`gym.Env`): The environment to wrap.
+ """
+ super().__init__(env)
+ self.observation_space = gym.spaces.Dict(
+ {
+ 'obs': env.observation_space,
+ 'prev_action': env.action_space,
+ 'prev_reward_extrinsic': gym.spaces.Box(
+ low=env.reward_range[0], high=env.reward_range[1], shape=(1, ), dtype=np.float32
+ )
+ }
+ )
+ self.prev_action = -1 # null action
+ self.prev_reward_extrinsic = 0 # null reward
+
+ def reset(self) -> Dict:
+ """
+ Overview:
+ Resets the state of the environment, and returns the wrapped observation.
+ Returns:
+ - observation (:obj:`Dict`): The wrapped observation, which includes the current observation,
+ previous action and previous reward.
+ """
+ obs = self.env.reset()
+ obs = {'obs': obs, 'prev_action': self.prev_action, 'prev_reward_extrinsic': self.prev_reward_extrinsic}
+ return obs
+
+ def step(self, action: Any) -> Tuple[Dict, float, bool, Dict]:
+ """
+ Overview:
+ Execute the given action in the environment, save the previous action and reward
+ to be used in the next observation, and return the new observation, reward,
+ done status and info.
+ Arguments:
+ - action (:obj:`Any`): The action to execute in the environment.
+ Returns:
+ - observation (:obj:`Dict`): The wrapped observation, which includes the current observation,
+ previous action and previous reward.
+ - reward (:obj:`float`): Amount of reward returned after the action execution.
+ - done (:obj:`bool`): Whether the episode has ended, in which case further step() calls will return
+ undefined results.
+ - info (:obj:`Dict`): Contains auxiliary diagnostic information (helpful for debugging, and sometimes
+ learning).
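+ Examples:
+ An illustrative sketch; the environment id "CartPole-v1" is only an assumption:
+ >>> env = ObsPlusPrevActRewWrapper(gym.make("CartPole-v1"))
+ >>> obs = env.reset()
+ >>> obs, reward, done, info = env.step(env.action_space.sample())
+ >>> sorted(obs.keys())
+ ['obs', 'prev_action', 'prev_reward_extrinsic']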
+ """
+ obs, reward, done, info = self.env.step(action)
+ obs = {'obs': obs, 'prev_action': self.prev_action, 'prev_reward_extrinsic': self.prev_reward_extrinsic}
+ self.prev_action = action
+ self.prev_reward_extrinsic = reward
+ return obs, reward, done, info
+
+
+class TransposeWrapper(gym.Wrapper):
+ """
+ Overview:
+ This class is used to transpose the observation space of the environment.
+
+ Interfaces:
+ __init__, _process_obs, step, reset
+ """
+
+ def __init__(self, env: gym.Env) -> None:
+ """
+ Overview:
+ Initialize the TransposeWrapper, setting up the new observation space.
+ Arguments:
+ - env (:obj:`gym.Env`): The environment to wrap.
+ """
+ super().__init__(env)
+ old_space = copy.deepcopy(env.observation_space)
+ new_shape = (old_space.shape[-1], *old_space.shape[:-1])
+ self._observation_space = gym.spaces.Box(
+ low=old_space.low.min(), high=old_space.high.max(), shape=new_shape, dtype=old_space.dtype
+ )
+
+ def _process_obs(self, obs: np.ndarray) -> np.ndarray:
+ """
+ Overview:
+ Transpose the observation into the format (channels, height, width).
+ Arguments:
+ - obs (:obj:`np.ndarray`): The observation to transform.
+ Returns:
+ - obs (:obj:`np.ndarray`): The transposed observation.
+ """
+ obs = to_ndarray(obs)
+ obs = np.transpose(obs, (2, 0, 1))
+ return obs
+
+ def step(self, action: Any) -> Tuple[np.ndarray, float, bool, Dict]:
+ """
+ Overview:
+ Execute the given action in the environment, process the observation and return
+ the new observation, reward, done status, and info.
+ Arguments:
+ - action (:obj:`Any`): The action to execute in the environment.
+ Returns:
+ - observation (:obj:`np.ndarray`): The processed observation after the action execution.
+ - reward (:obj:`float`): Amount of reward returned after the action execution.
+ - done (:obj:`bool`): Whether the episode has ended, in which case further step() calls will return
+ undefined results.
+ - info (:obj:`Dict`): Contains auxiliary diagnostic information (helpful for debugging, and sometimes
+ learning).
+ """
+ obs, reward, done, info = self.env.step(action)
+ return self._process_obs(obs), reward, done, info
+
+ def reset(self) -> np.ndarray:
+ """
+ Overview:
+ Resets the state of the environment and returns the processed observation.
+ Returns:
+ - observation (:obj:`np.ndarray`): The processed observation after reset.
+ """
+ obs = self.env.reset()
+ return self._process_obs(obs)
+
+
+class TimeLimitWrapper(gym.Wrapper):
+ """
+ Overview:
+ This class is used to enforce a time limit on the environment.
+ Interfaces:
+ __init__, reset, step
+ """
+
+ def __init__(self, env: gym.Env, max_limit: int) -> None:
+ """
+ Overview:
+ Initialize the TimeLimitWrapper, setting up the maximum limit of time steps.
+ Arguments:
+ - env (:obj:`gym.Env`): The environment to wrap.
+ - max_limit (:obj:`int`): The maximum limit of time steps.
+ """
+ super().__init__(env)
+ self.max_limit = max_limit
+
+ def reset(self) -> np.ndarray:
+ """
+ Overview:
+ Resets the state of the environment and the time counter.
+ Returns:
+ - observation (:obj:`np.ndarray`): The new observation after reset.
+ """
+ self.time_count = 0
+ return self.env.reset()
+
+ def step(self, action: Any) -> Tuple[np.ndarray, float, bool, Dict]:
+ """
+ Overview:
+ Execute the given action in the environment, update the time counter, and
+ return the new observation, reward, done status and info.
+ Arguments:
+ - action (:obj:`Any`): The action to execute in the environment.
+ Returns:
+ - observation (:obj:`np.ndarray`): The new observation after the action execution.
+ - reward (:obj:`float`): Amount of reward returned after the action execution.
+ - done (:obj:`bool`): Whether the episode has ended, in which case further step() calls will return
+ undefined results.
+ - info (:obj:`Dict`): Contains auxiliary diagnostic information (helpful for debugging, and sometimes
+ learning).
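+ Examples:
+ A minimal sketch; the environment id "CartPole-v1" and the limit value are only assumptions:
+ >>> env = TimeLimitWrapper(gym.make("CartPole-v1"), max_limit=100)
+ >>> obs = env.reset()
+ >>> obs, reward, done, info = env.step(env.action_space.sample())
+ >>> info['time_count'], info['time_limit']
+ (1, False)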
+ """
+ obs, reward, done, info = self.env.step(action)
+ self.time_count += 1
+ if self.time_count >= self.max_limit:
+ done = True
+ info['time_limit'] = True
+ else:
+ info['time_limit'] = False
+ info['time_count'] = self.time_count
+ return obs, reward, done, info
+
+
+class FlatObsWrapper(gym.Wrapper):
+ """
+ Overview:
+ This class is used to flatten the observation space of the environment.
+ Note: only suitable for environments like minigrid.
+ Interfaces:
+ __init__, observation, reset, step
+ """
+
+ def __init__(self, env: gym.Env, maxStrLen: int = 96) -> None:
+ """
+ Overview:
+ Initialize the FlatObsWrapper and set up the new observation space.
+ Arguments:
+ - env (:obj:`gym.Env`): The environment to wrap.
+ - maxStrLen (:obj:`int`): The maximum length of mission string, default is 96.
+ """
+ super().__init__(env)
+
+ self.maxStrLen = maxStrLen
+ self.numCharCodes = 28
+
+ imgSpace = env.observation_space.spaces["image"]
+ imgSize = reduce(operator.mul, imgSpace.shape, 1)
+
+ self.observation_space = gym.spaces.Box(
+ low=0,
+ high=255,
+ shape=(imgSize + self.numCharCodes * self.maxStrLen, ),
+ dtype="float32",
+ )
+
+ self.cachedStr: str = None
+
+ def observation(self, obs: Union[np.ndarray, Tuple]) -> np.ndarray:
+ """
+ Overview:
+ Process the observation, convert the mission into one-hot encoding and concatenate
+ it with the image data.
+ Arguments:
+ - obs (:obj:`Union[np.ndarray, Tuple]`): The raw observation to process.
+ Returns:
+ - obs (:obj:`np.ndarray`): The processed observation.
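+ Examples:
+ An illustrative sketch; the MiniGrid environment id is only an assumption:
+ >>> env = FlatObsWrapper(gym.make("MiniGrid-Empty-8x8-v0"))
+ >>> obs = env.reset() # a 1D float array: flattened image concatenated with the one-hot mission string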
+ """
+ if isinstance(obs, tuple): # for compatibility of gymnasium
+ obs = obs[0]
+ image = obs["image"]
+ mission = obs["mission"]
+
+ # Cache the last-encoded mission string
+ if mission != self.cachedStr:
+ assert (len(mission) <= self.maxStrLen), f"mission string too long ({len(mission)} chars)"
+ mission = mission.lower()
+
+ strArray = np.zeros(shape=(self.maxStrLen, self.numCharCodes), dtype="float32")
+
+ for idx, ch in enumerate(mission):
+ if ch >= "a" and ch <= "z":
+ chNo = ord(ch) - ord("a")
+ elif ch == " ":
+ chNo = ord("z") - ord("a") + 1
+ elif ch == ",":
+ chNo = ord("z") - ord("a") + 2
+ else:
+ raise ValueError(f"Character {ch} is not available in mission string.")
+ assert chNo < self.numCharCodes, "%s : %d" % (ch, chNo)
+ strArray[idx, chNo] = 1
+
+ self.cachedStr = mission
+ self.cachedArray = strArray
+
+ obs = np.concatenate((image.flatten(), self.cachedArray.flatten()))
+
+ return obs
+
+ def reset(self, *args, **kwargs) -> np.ndarray:
+ """
+ Overview:
+ Resets the state of the environment and returns the processed observation.
+ Returns:
+ - observation (:obj:`np.ndarray`): The processed observation after reset.
+ """
+ obs = self.env.reset(*args, **kwargs)
+ return self.observation(obs)
+
+ def step(self, *args, **kwargs) -> Tuple[np.ndarray, float, bool, Dict]:
+ """
+ Overview:
+ Execute the given action in the environment, and return the processed observation,
+ reward, done status, and info.
+ Returns:
+ - observation (:obj:`np.ndarray`): The processed observation after the action execution.
+ - reward (:obj:`float`): Amount of reward returned after the action execution.
+ - done (:obj:`bool`): Whether the episode has ended, in which case further step() calls will return
+ undefined results.
+ - info (:obj:`Dict`): Contains auxiliary diagnostic information (helpful for debugging, and sometimes
+ learning).
+ """
+ o, r, d, i = self.env.step(*args, **kwargs)
+ o = self.observation(o)
+ return o, r, d, i
+
+
+class GymToGymnasiumWrapper(gym.Wrapper):
+ """
+ Overview:
+ This class is used to wrap a gymnasium environment so that it can be used through the gym interface.
+ Interfaces:
+ __init__, seed, reset
+ """
+
+ def __init__(self, env: gymnasium.Env) -> None:
+ """
+ Overview:
+ Initialize the GymToGymnasiumWrapper.
+ Arguments:
+ - env (:obj:`gymnasium.Env`): The gymnasium environment to wrap.
+ """
+ assert isinstance(env, gymnasium.Env), type(env)
+ super().__init__(env)
+ self._seed = None
+
+ def seed(self, seed: int) -> None:
+ """
+ Overview:
+ Set the seed for the environment.
+ Arguments:
+ - seed (:obj:`int`): The seed to set.
+ """
+ self._seed = seed
+
+ def reset(self) -> np.ndarray:
+ """
+ Overview:
+ Resets the state of the environment and returns the new observation. If a seed
+ was set, use it in the reset.
+ Returns:
+ - observation (:obj:`np.ndarray`): The new observation after reset.
+ """
+ if self._seed is not None:
+ return self.env.reset(seed=self._seed)
+ else:
+ return self.env.reset()
+
+
+@ENV_WRAPPER_REGISTRY.register('reward_in_obs')
+class AllinObsWrapper(gym.Wrapper):
+ """
+ Overview:
+ This wrapper is used in the ``Decision Transformer`` policy, which is proposed in the paper
+ https://arxiv.org/abs/2106.01345. It sets a dict {'obs': obs, 'reward': reward}
+ as the new wrapped observation, which includes the current observation and the previous reward.
+ Interfaces:
+ __init__, reset, step, seed
+ Properties:
+ - env (:obj:`gym.Env`): The environment to wrap.
+ """
+
+ def __init__(self, env: gym.Env) -> None:
+ """
+ Overview:
+ Initialize the AllinObsWrapper.
+ Arguments:
+ - env (:obj:`gym.Env`): The environment to wrap.
+ """
+ super().__init__(env)
+
+ def reset(self) -> Dict:
+ """
+ Overview:
+ Resets the state of the environment and returns the new observation.
+ Returns:
+ - observation (:obj:`Dict`): The new observation after reset, includes the current observation and reward.
+ """
+ ret = {'obs': self.env.reset(), 'reward': np.array([0])}
+ self._observation_space = gym.spaces.Dict(
+ {
+ 'obs': self.env.observation_space,
+ 'reward': gym.spaces.Box(low=-np.inf, high=np.inf, dtype=np.float32, shape=(1, ))
+ }
+ )
+ return ret
+
+ def step(self, action: Any):
+ """
+ Overview:
+ Execute the given action in the environment, and return the new observation,
+ reward, done status, and info.
+ Arguments:
+ - action (:obj:`Any`): The action to execute in the environment.
+ Returns:
+ - timestep (:obj:`BaseEnvTimestep`): The timestep after the action execution.
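+ Examples:
+ A minimal sketch; the environment id "CartPole-v1" is only an assumption:
+ >>> env = AllinObsWrapper(gym.make("CartPole-v1"))
+ >>> obs = env.reset() # {'obs': ..., 'reward': np.array([0])}
+ >>> timestep = env.step(env.action_space.sample())
+ >>> sorted(timestep.obs.keys())
+ ['obs', 'reward']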
+ """
+ obs, reward, done, info = self.env.step(action)
+ obs = {'obs': obs, 'reward': reward}
+ from ding.envs import BaseEnvTimestep
+ return BaseEnvTimestep(obs, reward, done, info)
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ """
+ Overview:
+ Set the seed for the environment.
+ Arguments:
+ - seed (:obj:`int`): The seed to set.
+ - dynamic_seed (:obj:`bool`): Whether to use dynamic seed, default is True.
+ """
+ self.env.seed(seed, dynamic_seed)
+
+
+def update_shape(obs_shape: Any, act_shape: Any, rew_shape: Any, wrapper_names: List[str]) -> Tuple[Any, Any, Any]:
+ """
+ Overview:
+ Get new shapes of observation, action, and reward given the wrapper.
+ Arguments:
+ - obs_shape (:obj:`Any`): The original shape of observation.
+ - act_shape (:obj:`Any`): The original shape of action.
+ - rew_shape (:obj:`Any`): The original shape of reward.
+ - wrapper_names (:obj:`List[str]`): The names of the wrappers.
+ Returns:
+ - obs_shape (:obj:`Any`): The new shape of observation.
+ - act_shape (:obj:`Any`): The new shape of action.
+ - rew_shape (:obj:`Any`): The new shape of reward.
+ """
+ for wrapper_name in wrapper_names:
+ if wrapper_name:
+ try:
+ obs_shape, act_shape, rew_shape = eval(wrapper_name).new_shape(obs_shape, act_shape, rew_shape)
+ except Exception:
+ continue
+ return obs_shape, act_shape, rew_shape
+
+
+def create_env_wrapper(env: gym.Env, env_wrapper_cfg: EasyDict) -> gym.Wrapper:
+ """
+ Overview:
+ Create an environment wrapper according to the environment wrapper configuration and the environment instance.
+ Arguments:
+ - env (:obj:`gym.Env`): The environment instance to be wrapped.
+ - env_wrapper_cfg (:obj:`EasyDict`): The configuration for the environment wrapper.
+ Returns:
+ - env (:obj:`gym.Wrapper`): The wrapped environment instance.
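+ Examples:
+ An illustrative sketch; the environment id "CartPole-v1" is only an assumption, and any registered
+ wrapper type (e.g. ``eval_episode_return``) can be used:
+ >>> from easydict import EasyDict
+ >>> env_wrapper_cfg = EasyDict({'type': 'eval_episode_return'})
+ >>> wrapped_env = create_env_wrapper(gym.make("CartPole-v1"), env_wrapper_cfg)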
+ """
+ env_wrapper_cfg = copy.deepcopy(env_wrapper_cfg)
+ if 'import_names' in env_wrapper_cfg:
+ import_module(env_wrapper_cfg.pop('import_names'))
+ env_wrapper_type = env_wrapper_cfg.pop('type')
+ return ENV_WRAPPER_REGISTRY.build(env_wrapper_type, env, **env_wrapper_cfg.get('kwargs', {}))
diff --git a/DI-engine/ding/envs/gym_env.py b/DI-engine/ding/envs/gym_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0f3e34dd2aab604d4bb84b10e4e88b3e22cb0e9
--- /dev/null
+++ b/DI-engine/ding/envs/gym_env.py
@@ -0,0 +1,6 @@
+from ding.envs import BaseEnv, DingEnvWrapper
+
+
+def env(cfg, seed_api=True, caller='collector', **kwargs) -> BaseEnv:
+ import gym
+ return DingEnvWrapper(gym.make(cfg.env_id, **kwargs), cfg=cfg, seed_api=seed_api, caller=caller)
diff --git a/DI-engine/ding/example/__init__.py b/DI-engine/ding/example/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/ding/example/bcq.py b/DI-engine/ding/example/bcq.py
new file mode 100755
index 0000000000000000000000000000000000000000..4bd1385c3fbe93a2a194707100ac885d032dfa32
--- /dev/null
+++ b/DI-engine/ding/example/bcq.py
@@ -0,0 +1,42 @@
+import gym
+from ditk import logging
+from ding.model import BCQ
+from ding.policy import BCQPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import create_dataset
+from ding.config import compile_config
+from ding.framework import task, ding_init
+from ding.framework.context import OfflineRLContext
+from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher, offline_logger
+from ding.utils import set_pkg_seed
+from dizoo.d4rl.envs import D4RLEnv
+from dizoo.d4rl.config.halfcheetah_medium_bcq_config import main_config, create_config
+
+
+def main():
+ # If you don't have offline data, you need to prepare it first and set the data_path in the config.
+ # For demonstration, we can also train an RL policy (e.g. SAC) and collect some data.
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ ding_init(cfg)
+ with task.start(async_mode=False, ctx=OfflineRLContext()):
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: D4RLEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ dataset = create_dataset(cfg)
+ model = BCQ(**cfg.policy.model)
+ policy = BCQPolicy(cfg.policy, model=model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(offline_data_fetcher(cfg, dataset))
+ task.use(trainer(cfg, policy.learn_mode))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=10000000))
+ task.use(offline_logger())
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/c51_nstep.py b/DI-engine/ding/example/c51_nstep.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b98ece213f2a3ebc72890fafd3d8849afb19ef8
--- /dev/null
+++ b/DI-engine/ding/example/c51_nstep.py
@@ -0,0 +1,48 @@
+import gym
+from ditk import logging
+from ding.model import C51DQN
+from ding.policy import C51Policy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
+ eps_greedy_handler, CkptSaver, nstep_reward_enhancer
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.cartpole.config.cartpole_c51_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ main_config.exp_name = 'cartpole_c51_nstep'
+ main_config.policy.nstep = 3
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)],
+ cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)],
+ cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = C51DQN(**cfg.policy.model)
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ policy = C51Policy(cfg.policy, model=model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(eps_greedy_handler(cfg))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(nstep_reward_enhancer(cfg))
+ task.use(data_pusher(cfg, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/collect_demo_data.py b/DI-engine/ding/example/collect_demo_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..53e37b928c8349604687627e2b0e71119153f889
--- /dev/null
+++ b/DI-engine/ding/example/collect_demo_data.py
@@ -0,0 +1,36 @@
+import gym
+from ditk import logging
+import torch
+from ding.model import ContinuousQAC
+from ding.policy import SACPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import offline_data_save_type
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import StepCollector, offline_data_saver
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.pendulum.envs.pendulum_env import PendulumEnv
+from dizoo.classic_control.pendulum.config.pendulum_sac_data_generation_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True, evaluator=None)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(10)], cfg=cfg.env.manager)
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = ContinuousQAC(**cfg.policy.model)
+ policy = SACPolicy(cfg.policy, model=model, enable_field=['collect'])
+ state_dict = torch.load(cfg.policy.collect.state_dict_path, map_location='cpu')
+ policy.collect_mode.load_state_dict(state_dict)
+
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(offline_data_saver(cfg.policy.collect.save_path, data_type='hdf5'))
+ task.run(max_step=1)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/cql.py b/DI-engine/ding/example/cql.py
new file mode 100644
index 0000000000000000000000000000000000000000..5af78dabd30506cae6698586710d592abf13112a
--- /dev/null
+++ b/DI-engine/ding/example/cql.py
@@ -0,0 +1,42 @@
+import gym
+from ditk import logging
+from ding.model import QAC
+from ding.policy import CQLPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import create_dataset
+from ding.config import compile_config
+from ding.framework import task, ding_init
+from ding.framework.context import OfflineRLContext
+from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher, offline_logger
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.pendulum.envs.pendulum_env import PendulumEnv
+from dizoo.classic_control.pendulum.config.pendulum_cql_config import main_config, create_config
+
+
+def main():
+ # If you don't have offline data, you need to prepare it first and set the data_path in the config.
+ # For demonstration, we can also train an RL policy (e.g. SAC) and collect some data.
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ ding_init(cfg)
+ with task.start(async_mode=False, ctx=OfflineRLContext()):
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ dataset = create_dataset(cfg)
+ model = QAC(**cfg.policy.model)
+ policy = CQLPolicy(cfg.policy, model=model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(offline_data_fetcher(cfg, dataset))
+ task.use(trainer(cfg, policy.learn_mode))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.use(offline_logger())
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/d4pg.py b/DI-engine/ding/example/d4pg.py
new file mode 100644
index 0000000000000000000000000000000000000000..39806f166dd9154326794f846266e4a5738dea97
--- /dev/null
+++ b/DI-engine/ding/example/d4pg.py
@@ -0,0 +1,48 @@
+import gym
+from ditk import logging
+from ding.model.template.qac_dist import QACDIST
+from ding.policy import D4PGPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.data.buffer.middleware import PriorityExperienceReplay
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
+ CkptSaver, nstep_reward_enhancer
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.pendulum.envs.pendulum_env import PendulumEnv
+from dizoo.classic_control.pendulum.config.pendulum_d4pg_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = QACDIST(**cfg.policy.model)
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ buffer_.use(PriorityExperienceReplay(buffer_, IS_weight=True))
+ policy = D4PGPolicy(cfg.policy, model=model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(
+ StepCollector(cfg, policy.collect_mode, collector_env, random_collect_size=cfg.policy.random_collect_size)
+ )
+ task.use(nstep_reward_enhancer(cfg))
+ task.use(data_pusher(cfg, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/ddpg.py b/DI-engine/ding/example/ddpg.py
new file mode 100644
index 0000000000000000000000000000000000000000..3fa9c18db2492e6dd74506034ff5b4c3eb5f14a1
--- /dev/null
+++ b/DI-engine/ding/example/ddpg.py
@@ -0,0 +1,46 @@
+import gym
+from ditk import logging
+from ding.model.template.qac import ContinuousQAC
+from ding.policy import DDPGPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
+ CkptSaver, termination_checker
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.pendulum.envs.pendulum_env import PendulumEnv
+from dizoo.classic_control.pendulum.config.pendulum_ddpg_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = ContinuousQAC(**cfg.policy.model)
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ policy = DDPGPolicy(cfg.policy, model=model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(
+ StepCollector(cfg, policy.collect_mode, collector_env, random_collect_size=cfg.policy.random_collect_size)
+ )
+ task.use(data_pusher(cfg, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.use(termination_checker(max_train_iter=10000))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/dqn.py b/DI-engine/ding/example/dqn.py
new file mode 100644
index 0000000000000000000000000000000000000000..0959b3ab22ccbc2441d3bf3ea17e2c2bba260032
--- /dev/null
+++ b/DI-engine/ding/example/dqn.py
@@ -0,0 +1,102 @@
+"""
+# Example of DQN pipeline
+
+Use the pipeline on a single process:
+
+> python3 -u ding/example/dqn.py
+
+Use the pipeline on multiple processes:
+
+We suppose there are N processes (workers) = 1 learner + 1 evaluator + (N-2) collectors
+
+## First Example: Execute on one machine with multiple processes.
+
+Execute 4 processes with 1 learner + 1 evaluator + 2 collectors
+Remember to keep them connected by mesh to ensure that they can exchange information with each other.
+
+> ditask --package . --main ding.example.dqn.main --parallel-workers 4 --topology mesh
+
+## Second Example: Execute on multiple machines.
+
+1. Execute 1 learner + 1 evaluator on one machine.
+
+> ditask --package . --main ding.example.dqn.main --parallel-workers 2 --topology mesh --node-ids 0 --ports 50515
+
+2. Execute 2 collectors on another machine. (Suppose the ip of the first machine is 127.0.0.1).
+ Here we use the `alone` topology instead of `mesh` because the collectors do not need to communicate with each other.
+ Remember that the `node_ids` must not duplicate those of the learner and evaluator processes.
+ And remember to set the `ports` (should not conflict with others) and `attach_to` parameters.
+ The value of the `attach_to` parameter should be obtained from the log of the
+ process started earlier (e.g. 'NNG listen on tcp://10.0.0.4:50515').
+
+> ditask --package . --main ding.example.dqn.main --parallel-workers 2 --topology alone --node-ids 2 \
+ --ports 50517 --attach-to tcp://10.0.0.4:50515,tcp://127.0.0.1:50516
+
+3. You can repeat step 2 to start more collectors on other machines.
+"""
+import gym
+from ditk import logging
+from ding.data.model_loader import FileModelLoader
+from ding.data.storage_loader import FileStorageLoader
+from ding.model import DQN
+from ding.policy import DQNPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task, ding_init
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
+ eps_greedy_handler, CkptSaver, ContextExchanger, ModelExchanger, online_logger
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.cartpole.config.cartpole_dqn_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True, save_cfg=task.router.node_id == 0)
+ ding_init(cfg)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)],
+ cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)],
+ cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = DQN(**cfg.policy.model)
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ policy = DQNPolicy(cfg.policy, model=model)
+
+ # Consider the case with multiple processes
+ if task.router.is_active:
+ # You can use labels to distinguish between workers with different roles,
+ # here we use node_id to distinguish.
+ if task.router.node_id == 0:
+ task.add_role(task.role.LEARNER)
+ elif task.router.node_id == 1:
+ task.add_role(task.role.EVALUATOR)
+ else:
+ task.add_role(task.role.COLLECTOR)
+
+ # Sync their context and model between each worker.
+ task.use(ContextExchanger(skip_n_iter=1))
+ task.use(ModelExchanger(model))
+
+ # Here is the part of single process pipeline.
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(eps_greedy_handler(cfg))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(data_pusher(cfg, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ task.use(online_logger(train_show_freq=10))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/dqn_eval.py b/DI-engine/ding/example/dqn_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..296d8b0b8f72a7f8b97ce2738ee5bee4c071f333
--- /dev/null
+++ b/DI-engine/ding/example/dqn_eval.py
@@ -0,0 +1,42 @@
+import gym
+import torch
+from ditk import logging
+from ding.model import DQN
+from ding.policy import DQNPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import interaction_evaluator
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.cartpole.config.cartpole_dqn_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ main_config.exp_name = 'cartpole_dqn_eval'
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)],
+ cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+ model = DQN(**cfg.policy.model)
+
+ # Load the pretrained weights.
+        # First, you need to obtain pretrained network weights.
+        # For example, you can run ``python3 -u ding/example/dqn.py`` to train them.
+ pretrained_state_dict = torch.load('cartpole_dqn_seed0/ckpt/final.pth.tar', map_location='cpu')['model']
+ model.load_state_dict(pretrained_state_dict)
+
+ policy = DQNPolicy(cfg.policy, model=model)
+
+ # Define the evaluator middleware.
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.run(max_step=1)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/dqn_her.py b/DI-engine/ding/example/dqn_her.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88458aa33bd9065941cc634ec724853114f6dfe
--- /dev/null
+++ b/DI-engine/ding/example/dqn_her.py
@@ -0,0 +1,46 @@
+import gym
+from ditk import logging
+from ding.model import DQN
+from ding.policy import DQNPolicy
+from ding.reward_model import HerRewardModel
+from ding.envs import BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import HERLearner, EpisodeCollector, interaction_evaluator, data_pusher, \
+ eps_greedy_handler, CkptSaver
+from ding.utils import set_pkg_seed
+from dizoo.bitflip.envs import BitFlipEnv
+from dizoo.bitflip.config.bitflip_her_dqn_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: BitFlipEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: BitFlipEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = DQN(**cfg.policy.model)
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ policy = DQNPolicy(cfg.policy, model=model)
+ her_reward_model = HerRewardModel(cfg.policy.other.her, cfg.policy.cuda)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(eps_greedy_handler(cfg))
+ task.use(EpisodeCollector(cfg, policy.collect_mode, collector_env))
+ task.use(data_pusher(cfg, buffer_))
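+        # HERLearner samples whole episodes from the buffer and relabels them with hindsight goals
+        # (via her_reward_model) before running the DQN update.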
+ task.use(HERLearner(cfg, policy.learn_mode, buffer_, her_reward_model))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/dqn_new_env.py b/DI-engine/ding/example/dqn_new_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..e43a9a81874e721cc187f9ed5b38f691d4abd674
--- /dev/null
+++ b/DI-engine/ding/example/dqn_new_env.py
@@ -0,0 +1,48 @@
+import gym
+from ditk import logging
+from ding.framework.supervisor import ChildType
+from ding.model import DQN
+from ding.policy import DQNPolicy
+from ding.envs import DingEnvWrapper, EnvSupervisor
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
+ eps_greedy_handler, CkptSaver
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.cartpole.config.cartpole_dqn_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = EnvSupervisor(
+ type_=ChildType.THREAD,
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)],
+ **cfg.env.manager
+ )
+ evaluator_env = EnvSupervisor(
+ type_=ChildType.THREAD,
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)],
+ **cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = DQN(**cfg.policy.model)
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ policy = DQNPolicy(cfg.policy, model=model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(eps_greedy_handler(cfg))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(data_pusher(cfg, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/dqn_nstep.py b/DI-engine/ding/example/dqn_nstep.py
new file mode 100644
index 0000000000000000000000000000000000000000..09dc786d22432cb66c6dee91f3c4afe91f01fb40
--- /dev/null
+++ b/DI-engine/ding/example/dqn_nstep.py
@@ -0,0 +1,49 @@
+import gym
+from ditk import logging
+from ding.model import DQN
+from ding.policy import DQNPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
+ eps_greedy_handler, CkptSaver, nstep_reward_enhancer, final_ctx_saver
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.cartpole.config.cartpole_dqn_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ main_config.exp_name = 'cartpole_dqn_nstep'
+ main_config.policy.nstep = 3
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)],
+ cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)],
+ cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = DQN(**cfg.policy.model)
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ policy = DQNPolicy(cfg.policy, model=model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(eps_greedy_handler(cfg))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
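+        # Aggregate the collected transitions into n-step transitions (n = cfg.policy.nstep = 3)
+        # before they are pushed into the replay buffer.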
+ task.use(nstep_reward_enhancer(cfg))
+ task.use(data_pusher(cfg, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.use(final_ctx_saver(cfg.exp_name))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/dqn_per.py b/DI-engine/ding/example/dqn_per.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd6d736f8bd6a47a3aea431d5263364390dd4a76
--- /dev/null
+++ b/DI-engine/ding/example/dqn_per.py
@@ -0,0 +1,50 @@
+import gym
+from ditk import logging
+from ding.model import DQN
+from ding.policy import DQNPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.data.buffer.middleware import PriorityExperienceReplay
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
+ eps_greedy_handler, CkptSaver
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.cartpole.config.cartpole_dqn_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ main_config.exp_name = 'cartpole_dqn_per'
+ main_config.policy.priority = True
+ main_config.policy.priority_IS_weight = True
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)],
+ cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)],
+ cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = DQN(**cfg.policy.model)
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
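+        # Wrap the buffer with prioritized experience replay; IS_weight=True makes sampling also return
+        # importance-sampling weights to correct the bias introduced by prioritization.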
+ buffer_.use(PriorityExperienceReplay(buffer_, IS_weight=True))
+ policy = DQNPolicy(cfg.policy, model=model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(eps_greedy_handler(cfg))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(data_pusher(cfg, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/dqn_rnd.py b/DI-engine/ding/example/dqn_rnd.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d5e1b93c3e0f3843bd35fe26cc8f58d7d60c9c1
--- /dev/null
+++ b/DI-engine/ding/example/dqn_rnd.py
@@ -0,0 +1,48 @@
+import gym
+from ditk import logging
+from ding.model import DQN
+from ding.policy import DQNPolicy
+from ding.reward_model import RndRewardModel
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, trainer, \
+ eps_greedy_handler, CkptSaver
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.cartpole.config.cartpole_dqn_rnd_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)],
+ cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)],
+ cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = DQN(**cfg.policy.model)
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ policy = DQNPolicy(cfg.policy, model=model)
+ reward_model = RndRewardModel(cfg.reward_model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(eps_greedy_handler(cfg))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
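+        # The RND reward model is updated by the trainer middleware below and is used by the learner
+        # to provide intrinsic exploration rewards.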
+ task.use(trainer(cfg, reward_model))
+ task.use(data_pusher(cfg, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_, reward_model=reward_model))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/dt.py b/DI-engine/ding/example/dt.py
new file mode 100644
index 0000000000000000000000000000000000000000..407ea01d6b029351a33289638f97535fbede6ba8
--- /dev/null
+++ b/DI-engine/ding/example/dt.py
@@ -0,0 +1,47 @@
+import gym
+from ditk import logging
+from ding.model import DecisionTransformer
+from ding.policy import DTPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManager, BaseEnvManagerV2
+from ding.envs.env_wrappers.env_wrappers import AllinObsWrapper
+from ding.data import create_dataset
+from ding.config import compile_config
+from ding.framework import task, ding_init
+from ding.framework.context import OfflineRLContext
+from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, \
+ offline_data_fetcher, offline_logger, termination_checker, final_ctx_saver
+from ding.utils import set_pkg_seed
+from dizoo.box2d.lunarlander.envs.lunarlander_env import LunarLanderEnv
+from dizoo.box2d.lunarlander.config.lunarlander_dt_config import main_config, create_config
+
+
+def main():
+    # If you don't have offline data, you need to prepare it first and set the data_path in the config.
+    # For demonstration, you can also train an RL policy (e.g. SAC) and collect some data with it.
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ ding_init(cfg)
+ with task.start(async_mode=False, ctx=OfflineRLContext()):
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: AllinObsWrapper(LunarLanderEnv(cfg.env)) for _ in range(cfg.env.evaluator_env_num)],
+ cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ dataset = create_dataset(cfg)
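+        # Pass the dataset's state mean/std to the policy so observations are normalized consistently
+        # with the offline data.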
+ cfg.policy.state_mean, cfg.policy.state_std = dataset.get_state_stats()
+ model = DecisionTransformer(**cfg.policy.model)
+ policy = DTPolicy(cfg.policy, model=model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(offline_data_fetcher(cfg, dataset))
+ task.use(trainer(cfg, policy.learn_mode))
+ task.use(termination_checker(max_train_iter=1e5))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.use(offline_logger())
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/edac.py b/DI-engine/ding/example/edac.py
new file mode 100755
index 0000000000000000000000000000000000000000..40230f3008fef2ab5d4fb6e27aa58bdf26172f5d
--- /dev/null
+++ b/DI-engine/ding/example/edac.py
@@ -0,0 +1,42 @@
+import gym
+from ditk import logging
+from ding.model import QACEnsemble
+from ding.policy import EDACPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import create_dataset
+from ding.config import compile_config
+from ding.framework import task, ding_init
+from ding.framework.context import OfflineRLContext
+from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_data_fetcher, offline_logger
+from ding.utils import set_pkg_seed
+from dizoo.d4rl.envs import D4RLEnv
+from dizoo.d4rl.config.halfcheetah_medium_edac_config import main_config, create_config
+
+
+def main():
+    # If you don't have offline data, you need to prepare it first and set the data_path in the config.
+    # For demonstration, you can also train an RL policy (e.g. SAC) and collect some data with it.
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ ding_init(cfg)
+ with task.start(async_mode=False, ctx=OfflineRLContext()):
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: D4RLEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ dataset = create_dataset(cfg)
+ model = QACEnsemble(**cfg.policy.model)
+ policy = EDACPolicy(cfg.policy, model=model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(offline_data_fetcher(cfg, dataset))
+ task.use(trainer(cfg, policy.learn_mode))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=1e4))
+ task.use(offline_logger())
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/impala.py b/DI-engine/ding/example/impala.py
new file mode 100644
index 0000000000000000000000000000000000000000..11602012af8d35d6a4e0651c4667f2782e6d4949
--- /dev/null
+++ b/DI-engine/ding/example/impala.py
@@ -0,0 +1,47 @@
+import gym
+from ditk import logging
+from ding.model import VAC
+from ding.policy import IMPALAPolicy
+from ding.envs import SubprocessEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task, ding_init
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
+ CkptSaver, online_logger, termination_checker
+from ding.utils import set_pkg_seed
+from dizoo.box2d.lunarlander.config.lunarlander_impala_config import main_config, create_config
+from dizoo.box2d.lunarlander.envs import LunarLanderEnv
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ ding_init(cfg)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = SubprocessEnvManagerV2(
+ env_fn=[lambda: LunarLanderEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = SubprocessEnvManagerV2(
+ env_fn=[lambda: LunarLanderEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
+ )
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = VAC(**cfg.policy.model)
+ buffer_ = DequeBuffer(
+ size=cfg.policy.other.replay_buffer.replay_buffer_size, sliced=cfg.policy.other.replay_buffer.sliced
+ )
+ policy = IMPALAPolicy(cfg.policy, model=model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env, random_collect_size=1024))
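+        # group_by_env stores transitions grouped by environment id, so samples keep the trajectory
+        # structure that IMPALA's sequence-based (V-trace) update needs.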
+ task.use(data_pusher(cfg, buffer_, group_by_env=True))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ task.use(online_logger(train_show_freq=300))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=10000))
+ task.use(termination_checker(max_env_step=2e6))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/iqn_nstep.py b/DI-engine/ding/example/iqn_nstep.py
new file mode 100644
index 0000000000000000000000000000000000000000..eff6df85bfa6a45a57f8b515ab185fb9559f2112
--- /dev/null
+++ b/DI-engine/ding/example/iqn_nstep.py
@@ -0,0 +1,48 @@
+import gym
+from ditk import logging
+from ding.model import IQN
+from ding.policy import IQNPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
+ eps_greedy_handler, CkptSaver, nstep_reward_enhancer
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.cartpole.config.cartpole_iqn_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ main_config.exp_name = 'cartpole_iqn_nstep'
+ main_config.policy.nstep = 3
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)],
+ cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)],
+ cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = IQN(**cfg.policy.model)
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ policy = IQNPolicy(cfg.policy, model=model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(eps_greedy_handler(cfg))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(nstep_reward_enhancer(cfg))
+ task.use(data_pusher(cfg, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/mappo.py b/DI-engine/ding/example/mappo.py
new file mode 100644
index 0000000000000000000000000000000000000000..53ca5dff3c18f3b63a66e6ed34c9504a9d2baf91
--- /dev/null
+++ b/DI-engine/ding/example/mappo.py
@@ -0,0 +1,45 @@
+import gym
+from ditk import logging
+from ding.model import MAVAC
+from ding.policy import PPOPolicy
+from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task, ding_init
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \
+ gae_estimator, online_logger, termination_checker
+from ding.utils import set_pkg_seed
+from dizoo.petting_zoo.config.ptz_simple_spread_mappo_config import main_config, create_config
+from dizoo.petting_zoo.envs.petting_zoo_simple_spread_env import PettingZooEnv
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ ding_init(cfg)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = SubprocessEnvManagerV2(
+ env_fn=[lambda: PettingZooEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = SubprocessEnvManagerV2(
+ env_fn=[lambda: PettingZooEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = MAVAC(**cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(gae_estimator(cfg, policy.collect_mode))
+ task.use(multistep_trainer(policy.learn_mode, log_freq=100))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
+ task.use(online_logger(train_show_freq=10))
+ task.use(termination_checker(max_env_step=int(1e6)))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/masac.py b/DI-engine/ding/example/masac.py
new file mode 100644
index 0000000000000000000000000000000000000000..a268c7366b11b35db64fe33205362fb8044d429b
--- /dev/null
+++ b/DI-engine/ding/example/masac.py
@@ -0,0 +1,49 @@
+import gym
+from ditk import logging
+from ding.model import MAQAC
+from ding.policy import SACDiscretePolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task, ding_init
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, CkptSaver, \
+ data_pusher, online_logger, termination_checker, eps_greedy_handler
+from ding.utils import set_pkg_seed
+from dizoo.petting_zoo.config.ptz_simple_spread_masac_config import main_config, create_config
+from dizoo.petting_zoo.envs.petting_zoo_simple_spread_env import PettingZooEnv
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ ding_init(cfg)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: PettingZooEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: PettingZooEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = MAQAC(**cfg.policy.model)
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ policy = SACDiscretePolicy(cfg.policy, model=model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(eps_greedy_handler(cfg))
+ task.use(
+ StepCollector(cfg, policy.collect_mode, collector_env, random_collect_size=cfg.policy.random_collect_size)
+ )
+ task.use(data_pusher(cfg, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_, log_freq=100))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
+ task.use(online_logger(train_show_freq=10))
+ task.use(termination_checker(max_env_step=int(1e6)))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/pdqn.py b/DI-engine/ding/example/pdqn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bc173d83c8e8412894323b37af882179b6561f9
--- /dev/null
+++ b/DI-engine/ding/example/pdqn.py
@@ -0,0 +1,45 @@
+import gym
+from ditk import logging
+from ding.model import PDQN
+from ding.policy import PDQNPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
+ eps_greedy_handler, CkptSaver
+from ding.utils import set_pkg_seed
+from dizoo.gym_hybrid.config.gym_hybrid_pdqn_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make(cfg.env.env_id)) for _ in range(cfg.env.collector_env_num)],
+ cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make(cfg.env.env_id)) for _ in range(cfg.env.evaluator_env_num)],
+ cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = PDQN(**cfg.policy.model)
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ policy = PDQNPolicy(cfg.policy, model=model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(eps_greedy_handler(cfg))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(data_pusher(cfg, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/ppg_offpolicy.py b/DI-engine/ding/example/ppg_offpolicy.py
new file mode 100644
index 0000000000000000000000000000000000000000..70bd211cd56c4d9d73d52f5af5d2c798d3e7298b
--- /dev/null
+++ b/DI-engine/ding/example/ppg_offpolicy.py
@@ -0,0 +1,53 @@
+import gym
+from ditk import logging
+from ding.model import PPG
+from ding.policy import PPGOffPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.data.buffer.middleware import use_time_check, sample_range_view
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
+ CkptSaver, gae_estimator
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.cartpole.config.cartpole_ppg_offpolicy_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)],
+ cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)],
+ cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = PPG(**cfg.policy.model)
+ buffer_cfg = cfg.policy.other.replay_buffer
+ max_size = max(buffer_cfg.policy.replay_buffer_size, buffer_cfg.value.replay_buffer_size)
+ buffer_ = DequeBuffer(size=max_size)
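+        # Create two views over the same underlying deque: each view shares the stored data but applies
+        # its own reuse limit (max_use) and sampling range, one for the policy phase and one for the value phase.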
+ policy_buffer = buffer_.view() # shallow copy
+ policy_buffer.use(use_time_check(policy_buffer, max_use=buffer_cfg.policy.max_use))
+ policy_buffer.use(sample_range_view(policy_buffer, start=-buffer_cfg.policy.replay_buffer_size))
+ value_buffer = buffer_.view()
+ value_buffer.use(use_time_check(value_buffer, max_use=buffer_cfg.value.max_use))
+ value_buffer.use(sample_range_view(value_buffer, start=-buffer_cfg.value.replay_buffer_size))
+ policy = PPGOffPolicy(cfg.policy, model=model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(gae_estimator(cfg, policy.collect_mode, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, {'policy': policy_buffer, 'value': value_buffer}))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/ppo.py b/DI-engine/ding/example/ppo.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9807d1c415ee651c8d0e794258c2f1c1e213d38
--- /dev/null
+++ b/DI-engine/ding/example/ppo.py
@@ -0,0 +1,78 @@
+"""
+# Example of PPO pipeline
+
+Use the pipeline on a single process:
+
+> python3 -u ding/example/ppo.py
+
+Use the pipeline on multiple processes:
+
+We suppose there are N processes (workers): 1 learner + 1 evaluator + (N-2) collectors.
+
+## First Example: Execute on one machine with multiple processes.
+
+Execute 4 processes: 1 learner + 1 evaluator + 2 collectors.
+Remember to keep them connected with the `mesh` topology so that they can exchange information with each other.
+
+> ditask --package . --main ding.example.ppo.main --parallel-workers 4 --topology mesh
+"""
+import gym
+from ditk import logging
+from ding.model import VAC
+from ding.policy import PPOPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task, ding_init
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \
+ gae_estimator, online_logger, ContextExchanger, ModelExchanger
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.cartpole.config.cartpole_ppo_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True, save_cfg=task.router.node_id == 0)
+ ding_init(cfg)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)],
+ cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)],
+ cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = VAC(**cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
+
+ # Consider the case with multiple processes
+ if task.router.is_active:
+ # You can use labels to distinguish between workers with different roles,
+ # here we use node_id to distinguish.
+ if task.router.node_id == 0:
+ task.add_role(task.role.LEARNER)
+ elif task.router.node_id == 1:
+ task.add_role(task.role.EVALUATOR)
+ else:
+ task.add_role(task.role.COLLECTOR)
+
+        # Sync the context and model between the workers.
+ task.use(ContextExchanger(skip_n_iter=1))
+ task.use(ModelExchanger(model))
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(gae_estimator(cfg, policy.collect_mode))
+ task.use(multistep_trainer(policy.learn_mode, log_freq=50))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.use(online_logger(train_show_freq=3))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/ppo_lunarlander.py b/DI-engine/ding/example/ppo_lunarlander.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2e60fe7d6a2b6ca813b501bae510056c9c1829d
--- /dev/null
+++ b/DI-engine/ding/example/ppo_lunarlander.py
@@ -0,0 +1,45 @@
+import gym
+from ditk import logging
+from ding.model import VAC
+from ding.policy import PPOPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task, ding_init
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \
+ gae_estimator, online_logger
+from ding.utils import set_pkg_seed
+from dizoo.box2d.lunarlander.config.lunarlander_ppo_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ ding_init(cfg)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.collector_env_num)],
+ cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.evaluator_env_num)],
+ cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = VAC(**cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(gae_estimator(cfg, policy.collect_mode))
+ task.use(multistep_trainer(policy.learn_mode, log_freq=50))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.use(online_logger(train_show_freq=3))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/ppo_offpolicy.py b/DI-engine/ding/example/ppo_offpolicy.py
new file mode 100644
index 0000000000000000000000000000000000000000..738b27f23060832c79a260a327df7a310064f003
--- /dev/null
+++ b/DI-engine/ding/example/ppo_offpolicy.py
@@ -0,0 +1,45 @@
+import gym
+from ditk import logging
+from ding.model import VAC
+from ding.policy import PPOOffPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.data.buffer.middleware import use_time_check, sample_range_view
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, \
+ CkptSaver, gae_estimator
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.cartpole.config.cartpole_ppo_offpolicy_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)],
+ cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)],
+ cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = VAC(**cfg.policy.model)
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ policy = PPOOffPolicy(cfg.policy, model=model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(gae_estimator(cfg, policy.collect_mode, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/ppo_with_complex_obs.py b/DI-engine/ding/example/ppo_with_complex_obs.py
new file mode 100644
index 0000000000000000000000000000000000000000..a05875ba29b599cb1c5e9ee63d82a85dd937e362
--- /dev/null
+++ b/DI-engine/ding/example/ppo_with_complex_obs.py
@@ -0,0 +1,200 @@
+from typing import Dict
+import os
+import torch
+import torch.nn as nn
+import numpy as np
+import gym
+from gym import spaces
+from ditk import logging
+from ding.envs import DingEnvWrapper, EvalEpisodeReturnWrapper, \
+ BaseEnvManagerV2
+from ding.config import compile_config
+from ding.policy import PPOPolicy
+from ding.utils import set_pkg_seed
+from ding.model import VAC
+from ding.framework import task, ding_init
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \
+ gae_estimator, online_logger
+from easydict import EasyDict
+
+my_env_ppo_config = dict(
+ exp_name='my_env_ppo_seed0',
+ env=dict(
+ collector_env_num=4,
+ evaluator_env_num=4,
+ n_evaluator_episode=4,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=True,
+ action_space='discrete',
+ model=dict(
+ obs_shape=None,
+ action_shape=2,
+ action_space='discrete',
+ critic_head_hidden_size=138,
+ actor_head_hidden_size=138,
+ ),
+ learn=dict(
+ epoch_per_collect=2,
+ batch_size=64,
+ learning_rate=0.001,
+ value_weight=0.5,
+ entropy_weight=0.01,
+ clip_ratio=0.2,
+ learner=dict(hook=dict(save_ckpt_after_iter=100)),
+ ),
+ collect=dict(
+ n_sample=256, unroll_len=1, discount_factor=0.9, gae_lambda=0.95, collector=dict(transform_obs=True, )
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ ),
+)
+my_env_ppo_config = EasyDict(my_env_ppo_config)
+main_config = my_env_ppo_config
+my_env_ppo_create_config = dict(
+ env_manager=dict(type='base'),
+ policy=dict(type='ppo'),
+)
+my_env_ppo_create_config = EasyDict(my_env_ppo_create_config)
+create_config = my_env_ppo_create_config
+
+
+class MyEnv(gym.Env):
+
+ def __init__(self, seq_len=5, feature_dim=10, image_size=(10, 10, 3)):
+ super().__init__()
+
+ # Define the action space
+ self.action_space = spaces.Discrete(2)
+
+ # Define the observation space
+ self.observation_space = spaces.Dict(
+ (
+ {
+ 'key_0': spaces.Dict(
+ {
+ 'k1': spaces.Box(low=0, high=np.inf, shape=(1, ), dtype=np.float32),
+ 'k2': spaces.Box(low=-1, high=1, shape=(1, ), dtype=np.float32),
+ }
+ ),
+ 'key_1': spaces.Box(low=-np.inf, high=np.inf, shape=(seq_len, feature_dim), dtype=np.float32),
+ 'key_2': spaces.Box(low=0, high=255, shape=image_size, dtype=np.uint8),
+ 'key_3': spaces.Box(low=0, high=np.array([np.inf, 3]), shape=(2, ), dtype=np.float32)
+ }
+ )
+ )
+
+ def reset(self):
+ # Generate a random initial state
+ return self.observation_space.sample()
+
+ def step(self, action):
+        # Compute a random reward and done flag (their values are not meaningful in this example)
+ reward = np.random.uniform(low=0.0, high=1.0)
+
+ done = False
+ if np.random.uniform(low=0.0, high=1.0) > 0.7:
+ done = True
+
+ info = {}
+
+ # Return the next state, reward, and done flag
+ return self.observation_space.sample(), reward, done, info
+
+
+def ding_env_maker():
+ return DingEnvWrapper(
+ MyEnv(), cfg={'env_wrapper': [
+ lambda env: EvalEpisodeReturnWrapper(env),
+ ]}
+ )
+
+
+class Encoder(nn.Module):
+
+ def __init__(self, feature_dim: int):
+ super(Encoder, self).__init__()
+
+ # Define the networks for each input type
+ self.fc_net_1_k1 = nn.Sequential(nn.Linear(1, 8), nn.ReLU())
+ self.fc_net_1_k2 = nn.Sequential(nn.Linear(1, 8), nn.ReLU())
+ self.fc_net_1 = nn.Sequential(nn.Linear(16, 32), nn.ReLU())
+ """
+ Implementation of transformer_encoder refers to Vision Transformer (ViT) code:
+ https://arxiv.org/abs/2010.11929
+ https://pytorch.org/vision/main/_modules/torchvision/models/vision_transformer.html
+ """
+ self.class_token = nn.Parameter(torch.zeros(1, 1, feature_dim))
+ self.encoder_layer = nn.TransformerEncoderLayer(d_model=feature_dim, nhead=2, batch_first=True)
+ self.transformer_encoder = nn.TransformerEncoder(self.encoder_layer, num_layers=1)
+
+ self.conv_net = nn.Sequential(
+ nn.Conv2d(3, 16, kernel_size=3, padding=1), nn.ReLU(), nn.Conv2d(16, 32, kernel_size=3, padding=1),
+ nn.ReLU()
+ )
+ self.conv_fc_net = nn.Sequential(nn.Flatten(), nn.Linear(3200, 64), nn.ReLU())
+
+ self.fc_net_2 = nn.Sequential(nn.Linear(2, 16), nn.ReLU(), nn.Linear(16, 32), nn.ReLU(), nn.Flatten())
+
+ def forward(self, inputs: Dict[str, torch.Tensor]) -> torch.Tensor:
+ # Unpack the input tuple
+ dict_input = inputs['key_0'] # dict{key:(B)}
+ transformer_input = inputs['key_1'] # (B, seq_len, feature_dim)
+ conv_input = inputs['key_2'] # (B, H, W, 3)
+ fc_input = inputs['key_3'] # (B, X)
+
+ B = fc_input.shape[0]
+
+ # Pass each input through its corresponding network
+ dict_output = self.fc_net_1(
+ torch.cat(
+ [self.fc_net_1_k1(dict_input['k1'].unsqueeze(-1)),
+ self.fc_net_1_k2(dict_input['k2'].unsqueeze(-1))],
+ dim=1
+ )
+ )
+
+ batch_class_token = self.class_token.expand(B, -1, -1)
+ transformer_output = self.transformer_encoder(torch.cat([batch_class_token, transformer_input], dim=1))
+ transformer_output = transformer_output[:, 0]
+
+ conv_output = self.conv_fc_net(self.conv_net(conv_input.permute(0, 3, 1, 2)))
+ fc_output = self.fc_net_2(fc_input)
+
+ # Concatenate the outputs along the feature dimension
+ encoded_output = torch.cat([dict_output, transformer_output, conv_output, fc_output], dim=1)
+
+ return encoded_output
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ ding_init(cfg)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[ding_env_maker for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[ding_env_maker for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ encoder = Encoder(feature_dim=10)
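+        # The custom encoder fuses the four observation parts into a 138-dim feature (32 + 10 + 64 + 32),
+        # which matches actor_head_hidden_size / critic_head_hidden_size in the config above.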
+ model = VAC(encoder=encoder, **cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(gae_estimator(cfg, policy.collect_mode))
+ task.use(multistep_trainer(policy.learn_mode, log_freq=50))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.use(online_logger(train_show_freq=3))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/qrdqn_nstep.py b/DI-engine/ding/example/qrdqn_nstep.py
new file mode 100644
index 0000000000000000000000000000000000000000..352828cf357b2ae7a348262a00b4bb0f19dece53
--- /dev/null
+++ b/DI-engine/ding/example/qrdqn_nstep.py
@@ -0,0 +1,48 @@
+import gym
+from ditk import logging
+from ding.model import QRDQN
+from ding.policy import QRDQNPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
+ eps_greedy_handler, CkptSaver, nstep_reward_enhancer
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.cartpole.config.cartpole_qrdqn_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ main_config.exp_name = 'cartpole_qrdqn_nstep'
+ main_config.policy.nstep = 3
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)],
+ cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)],
+ cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = QRDQN(**cfg.policy.model)
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ policy = QRDQNPolicy(cfg.policy, model=model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(eps_greedy_handler(cfg))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(nstep_reward_enhancer(cfg))
+ task.use(data_pusher(cfg, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/r2d2.py b/DI-engine/ding/example/r2d2.py
new file mode 100644
index 0000000000000000000000000000000000000000..83fc6175637cac70d7a4904133fdb25453c18404
--- /dev/null
+++ b/DI-engine/ding/example/r2d2.py
@@ -0,0 +1,46 @@
+import gym
+from ditk import logging
+from ding.model import DRQN
+from ding.policy import R2D2Policy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
+ eps_greedy_handler, CkptSaver, nstep_reward_enhancer
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.cartpole.config.cartpole_r2d2_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)],
+ cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)],
+ cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = DRQN(**cfg.policy.model)
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ policy = R2D2Policy(cfg.policy, model=model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(eps_greedy_handler(cfg))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(nstep_reward_enhancer(cfg))
+ task.use(data_pusher(cfg, buffer_, group_by_env=True))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/sac.py b/DI-engine/ding/example/sac.py
new file mode 100644
index 0000000000000000000000000000000000000000..d83e552050a9a58b399fc1a2c8ff7a4ef28358d1
--- /dev/null
+++ b/DI-engine/ding/example/sac.py
@@ -0,0 +1,47 @@
+from ditk import logging
+from ding.model import ContinuousQAC
+from ding.policy import SACPolicy
+from ding.envs import BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task, ding_init
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import data_pusher, StepCollector, interaction_evaluator, \
+ CkptSaver, OffPolicyLearner, termination_checker, online_logger
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.pendulum.envs.pendulum_env import PendulumEnv
+from dizoo.classic_control.pendulum.config.pendulum_sac_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ ding_init(cfg)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = ContinuousQAC(**cfg.policy.model)
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ policy = SACPolicy(cfg.policy, model=model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
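+        # Warm up the replay buffer with cfg.policy.random_collect_size env steps gathered with random
+        # actions before SAC updates begin.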
+ task.use(
+ StepCollector(cfg, policy.collect_mode, collector_env, random_collect_size=cfg.policy.random_collect_size)
+ )
+ task.use(data_pusher(cfg, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.use(termination_checker(max_train_iter=10000))
+ task.use(online_logger())
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/sqil.py b/DI-engine/ding/example/sqil.py
new file mode 100644
index 0000000000000000000000000000000000000000..6df54a5724699aab2d19650495350c706d58c8d7
--- /dev/null
+++ b/DI-engine/ding/example/sqil.py
@@ -0,0 +1,65 @@
+import gym
+from ditk import logging
+import torch
+from ding.model import DQN
+from ding.policy import SQLPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, \
+ eps_greedy_handler, CkptSaver, eps_greedy_masker, sqil_data_pusher
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.cartpole.config.cartpole_sql_config import main_config as ex_main_config
+from dizoo.classic_control.cartpole.config.cartpole_sql_config import create_config as ex_create_config
+from dizoo.classic_control.cartpole.config.cartpole_sqil_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ expert_cfg = compile_config(ex_main_config, create_cfg=ex_create_config, auto=True)
+    # The expert config must use the same `n_sample`. The line below ensures we do not need to modify the expert config.
+ expert_cfg.policy.collect.n_sample = cfg.policy.collect.n_sample
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)],
+ cfg=cfg.env.manager
+ )
+ expert_collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)],
+ cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)],
+ cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = DQN(**cfg.policy.model)
+ expert_model = DQN(**cfg.policy.model)
+
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ expert_buffer = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+
+ policy = SQLPolicy(cfg.policy, model=model)
+ expert_policy = SQLPolicy(expert_cfg.policy, model=expert_model)
+ state_dict = torch.load(cfg.policy.collect.model_path, map_location='cpu')
+ expert_policy.collect_mode.load_state_dict(state_dict)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(eps_greedy_handler(cfg))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env)) # agent data collector
+ task.use(sqil_data_pusher(cfg, buffer_, expert=False))
+ task.use(eps_greedy_masker())
+ task.use(StepCollector(cfg, expert_policy.collect_mode, expert_collector_env)) # expert data collector
+ task.use(sqil_data_pusher(cfg, expert_buffer, expert=True))
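+        # SQIL: sqil_data_pusher labels expert transitions with reward 1 and agent transitions with reward 0;
+        # the learner then samples the two buffers in equal proportion (0.5 / 0.5).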
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, [(buffer_, 0.5), (expert_buffer, 0.5)]))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/sqil_continuous.py b/DI-engine/ding/example/sqil_continuous.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee3d36c9f3ab4fb358ead8f1d31aaf3d5e2631ec
--- /dev/null
+++ b/DI-engine/ding/example/sqil_continuous.py
@@ -0,0 +1,70 @@
+from ditk import logging
+import torch
+from ding.model import ContinuousQAC
+from ding.policy import SQILSACPolicy
+from ding.envs import BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, \
+ CkptSaver, sqil_data_pusher, termination_checker
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.pendulum.envs.pendulum_env import PendulumEnv
+from dizoo.classic_control.pendulum.config.pendulum_sac_config import main_config as ex_main_config
+from dizoo.classic_control.pendulum.config.pendulum_sac_config import create_config as ex_create_config
+from dizoo.classic_control.pendulum.config.pendulum_sqil_sac_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ expert_cfg = compile_config(ex_main_config, create_cfg=ex_create_config, auto=True)
+    # The expert config must use the same `n_sample`. The line below ensures we do not need to modify the expert config.
+ expert_cfg.policy.collect.n_sample = cfg.policy.collect.n_sample
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
+ )
+ expert_collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = ContinuousQAC(**cfg.policy.model)
+ expert_model = ContinuousQAC(**cfg.policy.model)
+
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ expert_buffer = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+
+ policy = SQILSACPolicy(cfg.policy, model=model)
+ expert_policy = SQILSACPolicy(expert_cfg.policy, model=expert_model)
+ state_dict = torch.load(cfg.policy.collect.model_path, map_location='cpu')
+ expert_policy.collect_mode.load_state_dict(state_dict)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(
+ StepCollector(cfg, policy.collect_mode, collector_env, random_collect_size=cfg.policy.random_collect_size)
+ ) # agent data collector
+ task.use(sqil_data_pusher(cfg, buffer_, expert=False))
+ task.use(
+ StepCollector(
+ cfg,
+ expert_policy.collect_mode,
+ expert_collector_env,
+ random_collect_size=cfg.policy.expert_random_collect_size
+ )
+ ) # expert data collector
+ task.use(sqil_data_pusher(cfg, expert_buffer, expert=True))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, [(buffer_, 0.5), (expert_buffer, 0.5)]))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.use(termination_checker(max_train_iter=10000))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/sql.py b/DI-engine/ding/example/sql.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c2a968082751665219b2b40b4d4c2e5eae4987c
--- /dev/null
+++ b/DI-engine/ding/example/sql.py
@@ -0,0 +1,45 @@
+import gym
+from ditk import logging
+from ding.model import DQN
+from ding.policy import SQLPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
+ eps_greedy_handler, CkptSaver
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.cartpole.config.cartpole_sql_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)],
+ cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)],
+ cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = DQN(**cfg.policy.model)
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ policy = SQLPolicy(cfg.policy, model=model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(eps_greedy_handler(cfg))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(data_pusher(cfg, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/td3.py b/DI-engine/ding/example/td3.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d6508dd6fff3e6494dedc110853558749763087
--- /dev/null
+++ b/DI-engine/ding/example/td3.py
@@ -0,0 +1,47 @@
+from ditk import logging
+from ding.model import ContinuousQAC
+from ding.policy import TD3Policy
+from ding.envs import BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task, ding_init
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import data_pusher, StepCollector, interaction_evaluator, \
+ CkptSaver, OffPolicyLearner, termination_checker, online_logger
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.pendulum.envs.pendulum_env import PendulumEnv
+from dizoo.classic_control.pendulum.config.pendulum_td3_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ ding_init(cfg)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = ContinuousQAC(**cfg.policy.model)
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ policy = TD3Policy(cfg.policy, model=model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(
+ StepCollector(cfg, policy.collect_mode, collector_env, random_collect_size=cfg.policy.random_collect_size)
+ )
+ task.use(data_pusher(cfg, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.use(termination_checker(max_train_iter=10000))
+ task.use(online_logger())
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/example/trex.py b/DI-engine/ding/example/trex.py
new file mode 100644
index 0000000000000000000000000000000000000000..97611ba6c245505c99291000dadbb14be3f611be
--- /dev/null
+++ b/DI-engine/ding/example/trex.py
@@ -0,0 +1,59 @@
+import gym
+from tensorboardX import SummaryWriter
+import copy
+import easydict
+import os
+from ditk import logging
+
+from ding.model import DQN
+from ding.policy import DQNPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, \
+ eps_greedy_handler, CkptSaver, eps_greedy_masker, sqil_data_pusher, data_pusher
+from ding.utils import set_pkg_seed
+from ding.entry import trex_collecting_data
+from ding.reward_model import create_reward_model
+from dizoo.classic_control.cartpole.config.cartpole_trex_dqn_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ demo_arg = easydict.EasyDict({'cfg': [main_config, create_config], 'seed': 0})
+ trex_collecting_data(demo_arg)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True, renew_dir=False)
+
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.collector_env_num)],
+ cfg=cfg.env.manager
+ )
+
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("CartPole-v0")) for _ in range(cfg.env.evaluator_env_num)],
+ cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+ model = DQN(**cfg.policy.model)
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ policy = DQNPolicy(cfg.policy, model=model)
+
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ reward_model = create_reward_model(copy.deepcopy(cfg), policy.collect_mode.get_attribute('device'), tb_logger)
+ reward_model.train()
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(eps_greedy_handler(cfg))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(data_pusher(cfg, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_, reward_model))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/ding/framework/__init__.py b/DI-engine/ding/framework/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..72c23d0475eac1a927c3b26faa41066e9b48659d
--- /dev/null
+++ b/DI-engine/ding/framework/__init__.py
@@ -0,0 +1,11 @@
+from .context import Context, OnlineRLContext, OfflineRLContext
+from .task import Task, task, VoidMiddleware
+from .parallel import Parallel
+from .event_loop import EventLoop
+from .supervisor import Supervisor
+from easydict import EasyDict
+from ding.utils import DistributedWriter
+
+
+def ding_init(cfg: EasyDict):
+ DistributedWriter.get_instance(cfg.exp_name)
diff --git a/DI-engine/ding/framework/context.py b/DI-engine/ding/framework/context.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fb35eec13921332bf687ccdd7b3a725c6c559ba
--- /dev/null
+++ b/DI-engine/ding/framework/context.py
@@ -0,0 +1,102 @@
+import numpy as np
+import dataclasses
+import treetensor.torch as ttorch
+from typing import Union, Dict, List
+
+
+@dataclasses.dataclass
+class Context:
+ """
+ Overview:
+ Context is an object that passes contextual data between middlewares; its life cycle
+ is a single training iteration. You can set any attributes on it as you wish.
+ Note that the initial value of each attribute must evaluate to False
+ (e.g. 0, None or an empty container).
+ """
+ _kept_keys: set = dataclasses.field(default_factory=set)
+ total_step: int = 0
+
+ def renew(self) -> 'Context': # noqa
+ """
+ Overview:
+ Create a new context from self, increase total_step by one and carry the kept properties over to the new instance.
+ """
+ total_step = self.total_step
+ ctx = type(self)()
+ for key in self._kept_keys:
+ if self.has_attr(key):
+ setattr(ctx, key, getattr(self, key))
+ ctx.total_step = total_step + 1
+ return ctx
+
+ def keep(self, *keys: str) -> None:
+ """
+ Overview:
+ Keep the given key(s) so that they are carried over to the next iteration.
+ """
+ for key in keys:
+ self._kept_keys.add(key)
+
+ def has_attr(self, key):
+ return hasattr(self, key)
+
+
+# TODO: Restrict data to specific types
+@dataclasses.dataclass
+class OnlineRLContext(Context):
+
+ # common
+ total_step: int = 0
+ env_step: int = 0
+ env_episode: int = 0
+ train_iter: int = 0
+ train_data: Union[Dict, List] = None
+ train_output: Union[Dict, List[Dict]] = None
+ # collect
+ collect_kwargs: Dict = dataclasses.field(default_factory=dict)
+ obs: ttorch.Tensor = None
+ action: List = None
+ inference_output: Dict[int, Dict] = None
+ trajectories: List = None
+ episodes: List = None
+ trajectory_end_idx: List = dataclasses.field(default_factory=list)
+ # eval
+ eval_value: float = -np.inf
+ last_eval_iter: int = -1
+ last_eval_value: float = -np.inf
+ eval_output: List = dataclasses.field(default_factory=dict)
+ # wandb
+ wandb_url: str = ""
+
+ def __post_init__(self):
+ # This method is called right after __init__, i.e. right after the dataclass
+ # has initialized its fields.
+ # We use it here to mark the fields that should be kept across iterations.
+ self.keep('env_step', 'env_episode', 'train_iter', 'last_eval_iter', 'last_eval_value', 'wandb_url')
+
+
+@dataclasses.dataclass
+class OfflineRLContext(Context):
+
+ # common
+ total_step: int = 0
+ trained_env_step: int = 0
+ train_epoch: int = 0
+ train_iter: int = 0
+ train_data: Union[Dict, List] = None
+ train_output: Union[Dict, List[Dict]] = None
+ # eval
+ eval_value: float = -np.inf
+ last_eval_iter: int = -1
+ last_eval_value: float = -np.inf
+ eval_output: List = dataclasses.field(default_factory=dict)
+ # wandb
+ wandb_url: str = ""
+
+ def __post_init__(self):
+ # This method is called right after __init__, i.e. right after the dataclass
+ # has initialized its fields.
+ # We use it here to mark the fields that should be kept across iterations.
+ self.keep('trained_env_step', 'train_iter', 'last_eval_iter', 'last_eval_value', 'wandb_url')
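+
+
+# A minimal usage sketch (illustrative only, not part of the public API): it shows how
+# `keep()` and `renew()` interact. `env_step` is a kept field of OnlineRLContext above,
+# while `trajectories` is a per-iteration field that falls back to its default.
+if __name__ == "__main__":
+    ctx = OnlineRLContext()
+    ctx.env_step = 128                   # e.g. updated by a collector middleware
+    ctx.trajectories = ["fake_transition"]
+    new_ctx = ctx.renew()                # the task loop calls this at the end of each iteration
+    assert new_ctx.total_step == ctx.total_step + 1
+    assert new_ctx.env_step == 128       # kept via self.keep(...) in __post_init__
+    assert new_ctx.trajectories is None  # non-kept fields are reset to their defaults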
diff --git a/DI-engine/ding/framework/event_loop.py b/DI-engine/ding/framework/event_loop.py
new file mode 100644
index 0000000000000000000000000000000000000000..6641d07adb8da75384d44894e88ccd8d5a871f92
--- /dev/null
+++ b/DI-engine/ding/framework/event_loop.py
@@ -0,0 +1,126 @@
+from collections import defaultdict
+from typing import Callable, Optional
+from concurrent.futures import ThreadPoolExecutor
+from copy import copy
+import fnmatch
+from ditk import logging
+
+
+class EventLoop:
+ loops = {}
+
+ def __init__(self, name: str = "default") -> None:
+ self._name = name
+ self._listeners = defaultdict(list)
+ self._thread_pool = ThreadPoolExecutor(max_workers=2)
+ self._exception = None
+ self._active = True
+
+ def on(self, event: str, fn: Callable) -> None:
+ """
+ Overview:
+ Subscribe to an event, execute this function every time the event is emitted.
+ Arguments:
+ - event (:obj:`str`): Event name.
+ - fn (:obj:`Callable`): The function.
+ """
+ self._listeners[event].append(fn)
+
+ def off(self, event: str, fn: Optional[Callable] = None) -> None:
+ """
+ Overview:
+ Unsubscribe an event, or a specific function in the event.
+ Arguments:
+ - event (:obj:`str`): Event name.
+ - fn (:obj:`Optional[Callable]`): The function.
+ """
+ for e in fnmatch.filter(self._listeners.keys(), event):
+ if fn:
+ try:
+ self._listeners[e].remove(fn)
+ except ValueError:
+ # fn was not registered under this event, nothing to remove
+ pass
+ else:
+ self._listeners[e] = []
+
+ def once(self, event: str, fn: Callable) -> None:
+ """
+ Overview:
+ Subscribe to an event, execute this function only once when the event is emitted.
+ Arguments:
+ - event (:obj:`str`): Event name.
+ - fn (:obj:`Callable`): The function.
+ """
+
+ def once_callback(*args, **kwargs):
+ self.off(event, once_callback)
+ fn(*args, **kwargs)
+
+ self.on(event, once_callback)
+
+ def emit(self, event: str, *args, **kwargs) -> None:
+ """
+ Overview:
+ Emit an event, call listeners.
+ If there is an unhandled error in this event loop, calling emit will raise an exception,
+ which will cause the process to exit.
+ Arguments:
+ - event (:obj:`str`): Event name.
+ """
+ if self._exception:
+ raise self._exception
+ if self._active:
+ self._thread_pool.submit(self._trigger, event, *args, **kwargs)
+
+ def _trigger(self, event: str, *args, **kwargs) -> None:
+ """
+ Overview:
+ Execute the callbacks under the event. If any callback raises an exception,
+ we will save the traceback and ignore the exception.
+ Arguments:
+ - event (:obj:`str`): Event name.
+ """
+ if event not in self._listeners:
+ logging.debug("Event {} is not registered in the callbacks of {}!".format(event, self._name))
+ return
+ for fn in copy(self._listeners[event]):
+ try:
+ fn(*args, **kwargs)
+ except Exception as e:
+ self._exception = e
+
+ def listened(self, event: str) -> bool:
+ """
+ Overview:
+ Check if the event has been listened to.
+ Arguments:
+ - event (:obj:`str`): Event name
+ Returns:
+ - listened (:obj:`bool`): Whether this event has been listened to.
+ """
+ return event in self._listeners
+
+ @classmethod
+ def get_event_loop(cls: type, name: str = "default") -> "EventLoop":
+ """
+ Overview:
+ Create a new event loop if the name does not exist, otherwise return the existing instance.
+ Arguments:
+ - name (:obj:`str`): Name of event loop.
+ """
+ if name in cls.loops:
+ return cls.loops[name]
+ cls.loops[name] = loop = cls(name)
+ return loop
+
+ def stop(self) -> None:
+ self._active = False
+ self._listeners = defaultdict(list)
+ self._exception = None
+ self._thread_pool.shutdown()
+ if self._name in EventLoop.loops:
+ del EventLoop.loops[self._name]
+
+ def __del__(self) -> None:
+ if self._active:
+ self.stop()
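+
+
+# A minimal usage sketch. `emit` dispatches callbacks on a background thread pool,
+# so the short sleep below only gives the callback time to run before the loop is
+# stopped; real code would usually react to further events instead of sleeping.
+if __name__ == "__main__":
+    from time import sleep
+
+    loop = EventLoop.get_event_loop("demo")
+    received = []
+    loop.on("greet", lambda name: received.append(name))
+    loop.emit("greet", "world")  # non-blocking, handled by the thread pool
+    sleep(0.1)
+    assert received == ["world"]
+    loop.stop()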
diff --git a/DI-engine/ding/framework/message_queue/__init__.py b/DI-engine/ding/framework/message_queue/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cbbbcd93c7d8c8943be7c047c7ec003d63350c8
--- /dev/null
+++ b/DI-engine/ding/framework/message_queue/__init__.py
@@ -0,0 +1,3 @@
+from .mq import MQ
+from .redis import RedisMQ
+from .nng import NNGMQ
diff --git a/DI-engine/ding/framework/message_queue/mq.py b/DI-engine/ding/framework/message_queue/mq.py
new file mode 100644
index 0000000000000000000000000000000000000000..4386882020557a53c84263db0edcac6c3fe182b2
--- /dev/null
+++ b/DI-engine/ding/framework/message_queue/mq.py
@@ -0,0 +1,66 @@
+from typing import Tuple
+
+
+class MQ:
+ """
+ Overview:
+ Abstract basic mq class.
+ """
+
+ def __init__(self, *args, **kwargs) -> None:
+ """
+ Overview:
+ The __init__ method of any subclass must accept extra keyword arguments.
+ """
+ pass
+
+ def listen(self) -> None:
+ """
+ Overview:
+ Bind to local socket or connect to third party components.
+ """
+ raise NotImplementedError
+
+ def publish(self, topic: str, data: bytes) -> None:
+ """
+ Overview:
+ Send data to mq.
+ Arguments:
+ - topic (:obj:`str`): Topic.
+ - data (:obj:`bytes`): Payload data.
+ """
+ raise NotImplementedError
+
+ def subscribe(self, topic: str) -> None:
+ """
+ Overview:
+ Subscribe to the topic.
+ Arguments:
+ - topic (:obj:`str`): Topic
+ """
+ raise NotImplementedError
+
+ def unsubscribe(self, topic: str) -> None:
+ """
+ Overview:
+ Unsubscribe from the topic.
+ Arguments:
+ - topic (:obj:`str`): Topic
+ """
+ raise NotImplementedError
+
+ def recv(self) -> Tuple[str, bytes]:
+ """
+ Overview:
+ Wait for an incoming message; this function blocks the current thread.
+ Returns:
+ - data (:obj:`Tuple[str, bytes]`): The topic and the sent payload.
+ """
+ raise NotImplementedError
+
+ def stop(self) -> None:
+ """
+ Overview:
+ Unsubscribe from all topics and stop the connection to the message queue server.
+ """
+ return
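+
+
+# A toy in-process implementation, given only to illustrate the contract of the abstract
+# interface above; it is not part of DI-engine and is not usable across processes.
+# `publish` enqueues locally and `recv` blocks until a message on a subscribed topic arrives.
+if __name__ == "__main__":
+    import queue
+
+    class LocalMQ(MQ):
+
+        def __init__(self, **kwargs) -> None:
+            super().__init__(**kwargs)
+            self._queue = queue.Queue()
+            self._topics = set()
+
+        def listen(self) -> None:
+            pass  # nothing to bind for an in-process queue
+
+        def publish(self, topic: str, data: bytes) -> None:
+            self._queue.put((topic, data))
+
+        def subscribe(self, topic: str) -> None:
+            self._topics.add(topic)
+
+        def unsubscribe(self, topic: str) -> None:
+            self._topics.discard(topic)
+
+        def recv(self) -> Tuple[str, bytes]:
+            while True:
+                topic, data = self._queue.get()
+                if topic in self._topics:
+                    return topic, data
+
+    mq = LocalMQ()
+    mq.listen()
+    mq.subscribe("t")
+    mq.publish("t", b"data")
+    assert mq.recv() == ("t", b"data")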
diff --git a/DI-engine/ding/framework/message_queue/nng.py b/DI-engine/ding/framework/message_queue/nng.py
new file mode 100644
index 0000000000000000000000000000000000000000..379601b0ed8fc2694192e3d7fa64424244a7567a
--- /dev/null
+++ b/DI-engine/ding/framework/message_queue/nng.py
@@ -0,0 +1,73 @@
+import pynng
+from ditk import logging
+from typing import List, Optional, Tuple
+from pynng import Bus0
+from time import sleep
+
+from ding.framework.message_queue.mq import MQ
+from ding.utils import MQ_REGISTRY
+
+
+@MQ_REGISTRY.register("nng")
+class NNGMQ(MQ):
+
+ def __init__(self, listen_to: str, attach_to: Optional[List[str]] = None, **kwargs) -> None:
+ """
+ Overview:
+ Connect distributed processes with nng.
+ Arguments:
+ - listen_to (:obj:`str`): The address this node listens on.
+ - attach_to (:obj:`Optional[List[str]]`): The addresses of the nodes you want to attach to.
+ """
+ self.listen_to = listen_to
+ self.attach_to = attach_to or []
+ self._sock: Bus0 = None
+ self._running = False
+
+ def listen(self) -> None:
+ self._sock = sock = Bus0()
+ sock.listen(self.listen_to)
+ sleep(0.1) # Wait for peers to bind
+ for contact in self.attach_to:
+ sock.dial(contact)
+ logging.info("NNG listen on {}, attach to {}".format(self.listen_to, self.attach_to))
+ self._running = True
+
+ def publish(self, topic: str, data: bytes) -> None:
+ if self._running:
+ topic += "::"
+ data = topic.encode() + data
+ self._sock.send(data)
+
+ def subscribe(self, topic: str) -> None:
+ return
+
+ def unsubscribe(self, topic: str) -> None:
+ return
+
+ def recv(self) -> Tuple[str, bytes]:
+ while True:
+ try:
+ if not self._running:
+ break
+ msg = self._sock.recv()
+ # Use topic at the beginning of the message, so we don't need to call pickle.loads
+ # when the current process is not subscribed to the topic.
+ topic, payload = msg.split(b"::", maxsplit=1)
+ return topic.decode(), payload
+ except pynng.Timeout:
+ logging.warning("Timeout on node {} when waiting for message from bus".format(self.listen_to))
+ except pynng.Closed:
+ if self._running:
+ logging.error("The socket was not closed under normal circumstances!")
+ except Exception as e:
+ logging.error("Meet exception when listening for new messages", e)
+
+ def stop(self) -> None:
+ if self._running:
+ self._running = False
+ self._sock.close()
+ self._sock = None
+
+ def __del__(self) -> None:
+ self.stop()
diff --git a/DI-engine/ding/framework/message_queue/redis.py b/DI-engine/ding/framework/message_queue/redis.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cbf10e8a6b7ad9f386ed82175462c1b790e1832
--- /dev/null
+++ b/DI-engine/ding/framework/message_queue/redis.py
@@ -0,0 +1,71 @@
+import uuid
+from ditk import logging
+from time import sleep
+from typing import Tuple
+
+import redis
+from ding.framework.message_queue.mq import MQ
+from ding.utils import MQ_REGISTRY
+
+
+@MQ_REGISTRY.register("redis")
+class RedisMQ(MQ):
+
+ def __init__(self, redis_host: str, redis_port: int, **kwargs) -> None:
+ """
+ Overview:
+ Connect distributed processes with redis
+ Arguments:
+ - redis_host (:obj:`str`): Redis server host.
+ - redis_port (:obj:`int`): Redis server port.
+ """
+ self.host = redis_host
+ self.port = redis_port if isinstance(redis_port, int) else int(redis_port)
+ self.db = 0
+ self._running = False
+ self._id = uuid.uuid4().hex.encode()
+
+ def listen(self) -> None:
+ self._client = client = redis.Redis(host=self.host, port=self.port, db=self.db)
+ self._sub = client.pubsub()
+ self._running = True
+
+ def publish(self, topic: str, data: bytes) -> None:
+ data = self._id + b"::" + data
+ self._client.publish(topic, data)
+
+ def subscribe(self, topic: str) -> None:
+ self._sub.subscribe(topic)
+
+ def unsubscribe(self, topic: str) -> None:
+ self._sub.unsubscribe(topic)
+
+ def recv(self) -> Tuple[str, bytes]:
+ while True:
+ if not self._running:
+ raise RuntimeError("Redis MQ was not running!")
+ try:
+ msg = self._sub.get_message(ignore_subscribe_messages=True)
+ if msg is None:
+ sleep(0.001)
+ continue
+ topic = msg["channel"].decode()
+ data = msg["data"].split(b"::", maxsplit=1)
+ if len(data) != 2 or len(data[0]) != 32:
+ logging.warn("Got invalid message from topic: {}".format(topic))
+ continue
+ node_id, data = data
+ if node_id == self._id: # Discard message sent by self
+ continue
+ return topic, data
+ except Exception as e:
+ logging.error("Met exception when listening for new messages: %s", e)
+
+ def stop(self) -> None:
+ if self._running:
+ self._running = False
+ self._sub.close()
+ self._client.close()
+
+ def __del__(self) -> None:
+ self.stop()
diff --git a/DI-engine/ding/framework/message_queue/tests/test_nng.py b/DI-engine/ding/framework/message_queue/tests/test_nng.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab9cf4e0b31bc9044146d5ebfb414e8f4887a462
--- /dev/null
+++ b/DI-engine/ding/framework/message_queue/tests/test_nng.py
@@ -0,0 +1,32 @@
+from time import sleep
+import pytest
+
+import multiprocessing as mp
+from ding.framework.message_queue.nng import NNGMQ
+
+
+def nng_main(i):
+ if i == 0:
+ listen_to = "tcp://127.0.0.1:50515"
+ attach_to = None
+ mq = NNGMQ(listen_to=listen_to, attach_to=attach_to)
+ mq.listen()
+ for _ in range(10):
+ mq.publish("t", b"data")
+ sleep(0.1)
+ else:
+ listen_to = "tcp://127.0.0.1:50516"
+ attach_to = ["tcp://127.0.0.1:50515"]
+ mq = NNGMQ(listen_to=listen_to, attach_to=attach_to)
+ mq.listen()
+ topic, msg = mq.recv()
+ assert topic == "t"
+ assert msg == b"data"
+
+
+@pytest.mark.unittest
+@pytest.mark.execution_timeout(10)
+def test_nng():
+ ctx = mp.get_context("spawn")
+ with ctx.Pool(processes=2) as pool:
+ pool.map(nng_main, range(2))
diff --git a/DI-engine/ding/framework/message_queue/tests/test_redis.py b/DI-engine/ding/framework/message_queue/tests/test_redis.py
new file mode 100644
index 0000000000000000000000000000000000000000..56f44b5af59f4027987875f231cbdd7eb2e9f119
--- /dev/null
+++ b/DI-engine/ding/framework/message_queue/tests/test_redis.py
@@ -0,0 +1,71 @@
+from time import sleep
+import uuid
+import pytest
+
+from multiprocessing import Pool
+from unittest.mock import Mock, patch
+from threading import Thread
+from ding.utils import WatchDog
+
+from ding.framework.message_queue.redis import RedisMQ
+
+
+def redis_main(i):
+ node_id0 = uuid.uuid4().hex.encode()
+
+ class MockRedis(Mock):
+
+ def publish(self, topic, data):
+ assert topic == "t"
+ assert b"::" in data
+
+ def pubsub(self):
+ return MockPubSub()
+
+ class MockPubSub(Mock):
+
+ def get_message(self, **kwargs):
+ return {"channel": b"t", "data": node_id0 + b"::data"}
+
+ with patch("redis.Redis", MockRedis):
+ host = "127.0.0.1"
+ port = 6379
+ mq = RedisMQ(redis_host=host, redis_port=port)
+ mq.listen()
+ if i == 0:
+ mq._id = node_id0
+
+ def send_message():
+ for _ in range(5):
+ mq.publish("t", b"data")
+ sleep(0.1)
+
+ def recv_message():
+ # Should not receive any message
+ mq.subscribe("t")
+ print("RECV", mq.recv())
+
+ send_thread = Thread(target=send_message, daemon=True)
+ recv_thread = Thread(target=recv_message, daemon=True)
+ send_thread.start()
+ recv_thread.start()
+
+ send_thread.join()
+
+ watchdog = WatchDog(1)
+ with pytest.raises(TimeoutError):
+ watchdog.start()
+ recv_thread.join()
+ watchdog.stop()
+ else:
+ mq.subscribe("t")
+ topic, msg = mq.recv()
+ assert topic == "t"
+ assert msg == b"data"
+
+
+@pytest.mark.unittest
+@pytest.mark.execution_timeout(10)
+def test_redis():
+ with Pool(processes=2) as pool:
+ pool.map(redis_main, range(2))
diff --git a/DI-engine/ding/framework/middleware/__init__.py b/DI-engine/ding/framework/middleware/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9e3c5005d040092b6f45c84933aa9e4acdf72d2
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/__init__.py
@@ -0,0 +1,7 @@
+from .functional import *
+from .collector import StepCollector, EpisodeCollector, PPOFStepCollector
+from .learner import OffPolicyLearner, HERLearner
+from .ckpt_handler import CkptSaver
+from .distributer import ContextExchanger, ModelExchanger, PeriodicalModelExchanger
+from .barrier import Barrier, BarrierRuntime
+from .data_fetcher import OfflineMemoryDataFetcher
diff --git a/DI-engine/ding/framework/middleware/barrier.py b/DI-engine/ding/framework/middleware/barrier.py
new file mode 100644
index 0000000000000000000000000000000000000000..f958c079aec893409402f41c7d7a924e970ff242
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/barrier.py
@@ -0,0 +1,227 @@
+from time import sleep, time
+from ditk import logging
+from ding.framework import task
+from ding.utils.lock_helper import LockContext, LockContextType
+from ding.utils.design_helper import SingletonMetaclass
+
+
+class BarrierRuntime(metaclass=SingletonMetaclass):
+
+ def __init__(self, node_id: int, max_world_size: int = 100):
+ """
+ Overview:
+ 'BarrierRuntime' is a singleton class. In addition, it must be initialized before the
+ class 'Parallel' starts MQ, otherwise the messages sent by other nodes may be lost after
+ the detection is completed. We don't have a message retransmission mechanism, and losing
+ a message means deadlock.
+ Arguments:
+ - node_id (int): Process ID.
+ - max_world_size (int, optional): The maximum total number of processes that can be
+ synchronized, the default value is 100.
+ """
+ self.node_id = node_id
+ self._has_detected = False
+ self._range_len = len(str(max_world_size)) + 1
+
+ self._barrier_epoch = 0
+ self._barrier_recv_peers_buff = dict()
+ self._barrier_recv_peers = dict()
+ self._barrier_ack_peers = []
+ self._barrier_lock = LockContext(LockContextType.THREAD_LOCK)
+
+ self.mq_type = task.router.mq_type
+ self._connected_peers = dict()
+ self._connected_peers_lock = LockContext(LockContextType.THREAD_LOCK)
+ self._keep_alive_daemon = False
+
+ self._event_name_detect = "b_det"
+ self.event_name_req = "b_req"
+ self.event_name_ack = "b_ack"
+
+ def _alive_msg_handler(self, peer_id):
+ with self._connected_peers_lock:
+ self._connected_peers[peer_id] = time()
+
+ def _add_barrier_req(self, msg):
+ peer, epoch = self._unpickle_barrier_tag(msg)
+ logging.debug("Node:[{}] recv barrier request from node:{}, epoch:{}".format(self.node_id, peer, epoch))
+ with self._barrier_lock:
+ if peer not in self._barrier_recv_peers:
+ self._barrier_recv_peers[peer] = []
+ self._barrier_recv_peers[peer].append(epoch)
+
+ def _add_barrier_ack(self, peer):
+ logging.debug("Node:[{}] recv barrier ack from node:{}".format(self.node_id, peer))
+ with self._barrier_lock:
+ self._barrier_ack_peers.append(peer)
+
+ def _unpickle_barrier_tag(self, msg):
+ return msg % self._range_len, msg // self._range_len
+
+ def pickle_barrier_tag(self):
+ return int(self._barrier_epoch * self._range_len + self.node_id)
+
+ def reset_all_peers(self):
+ with self._barrier_lock:
+ for peer, q in self._barrier_recv_peers.items():
+ if len(q) != 0:
+ assert q.pop(0) == self._barrier_epoch
+ self._barrier_ack_peers = []
+ self._barrier_epoch += 1
+
+ def get_recv_num(self):
+ count = 0
+ with self._barrier_lock:
+ if len(self._barrier_recv_peers) > 0:
+ for _, q in self._barrier_recv_peers.items():
+ if len(q) > 0 and q[0] == self._barrier_epoch:
+ count += 1
+ return count
+
+ def get_ack_num(self):
+ with self._barrier_lock:
+ return len(self._barrier_ack_peers)
+
+ def detect_alive(self, expected, timeout):
+ # The barrier can only block other nodes within the visible range of the current node.
+ # If the 'attach_to' list of a node is empty, it does not know how many nodes will attach to it,
+ # so we cannot specify the effective range of a barrier in advance.
+ assert task._running
+ task.on(self._event_name_detect, self._alive_msg_handler)
+ task.on(self.event_name_req, self._add_barrier_req)
+ task.on(self.event_name_ack, self._add_barrier_ack)
+ start = time()
+ while True:
+ sleep(0.1)
+ task.emit(self._event_name_detect, self.node_id, only_remote=True)
+ # In case the other node has not had time to receive our detect message,
+ # we will send an additional round.
+ if self._has_detected:
+ break
+ with self._connected_peers_lock:
+ if len(self._connected_peers) == expected:
+ self._has_detected = True
+
+ if time() - start > timeout:
+ raise TimeoutError("Node-[{}] timeout when waiting barrier! ".format(task.router.node_id))
+
+ task.off(self._event_name_detect)
+ logging.info(
+ "Barrier detect node done, node-[{}] has connected with {} active nodes!".format(self.node_id, expected)
+ )
+
+
+class BarrierContext:
+
+ def __init__(self, runtime: BarrierRuntime, detect_timeout, expected_peer_num: int = 0):
+ self._runtime = runtime
+ self._expected_peer_num = expected_peer_num
+ self._timeout = detect_timeout
+
+ def __enter__(self):
+ if not self._runtime._has_detected:
+ self._runtime.detect_alive(self._expected_peer_num, self._timeout)
+
+ def __exit__(self, exc_type, exc_value, tb):
+ if exc_type is not None:
+ import traceback
+ traceback.print_exception(exc_type, exc_value, tb)
+ self._runtime.reset_all_peers()
+
+
+class Barrier:
+
+ def __init__(self, attch_from_nums: int, timeout: int = 60):
+ """
+ Overview:
+ Barrier() is a middleware for debugging or profiling. It can synchronize the task step of each
+ process within the scope of all visible processes. When using Barrier(), you need to pay
+ attention to the following points:
+
+ 1. All processes must call the same number of Barrier(), otherwise a deadlock occurs.
+
+ 2. 'attch_from_nums' is a very important variable. This value indicates the number of times
+ the current process will be attached to by other processes (the number of connections
+ established).
+ For example:
+ Node0: address: 127.0.0.1:12345, attach_to = []
+ Node1: address: 127.0.0.1:12346, attach_to = ["tcp://127.0.0.1:12345"]
+ For Node0, the 'attch_from_nums' value is 1. (It will be attached to by Node1)
+ For Node1, the 'attch_from_nums' value is 0. (No one will attach to Node1)
+ Please note that this value must be given correctly, otherwise, for a node whose 'attach_to'
+ list is empty, it cannot perceive how many processes will establish connections with it,
+ resulting in any form of synchronization cannot be performed.
+
+ 3. Barrier() is thread-safe, but it is not recommended to use barrier in multithreading. You need
+ to carefully calculate the number of times each thread calls Barrier() to avoid deadlock.
+
+ 4. In normal training tasks, please do not use Barrier(): it forces step synchronization
+ between processes and thus greatly hurts training efficiency. In addition, if your
+ training task has dynamic processes, do not use Barrier() to prevent deadlock.
+
+ Arguments:
+ - attch_from_nums (int): The number of processes that will attach to the current process.
+ - timeout (int, optional): The timeout for successful detection of 'expected_peer_num'
+ number of nodes, the default value is 60 seconds.
+ """
+ self.node_id = task.router.node_id
+ self.timeout = timeout
+ self._runtime: BarrierRuntime = task.router.barrier_runtime
+ self._barrier_peers_nums = task.get_attch_to_len() + attch_from_nums
+
+ logging.info(
+ "Node:[{}], attach to num is:{}, attach from num is:{}".format(
+ self.node_id, task.get_attch_to_len(), attch_from_nums
+ )
+ )
+
+ def __call__(self, ctx):
+ self._wait_barrier(ctx)
+ yield
+ self._wait_barrier(ctx)
+
+ def _wait_barrier(self, ctx):
+ self_ready = False
+ with BarrierContext(self._runtime, self.timeout, self._barrier_peers_nums):
+ logging.debug("Node:[{}] enter barrier".format(self.node_id))
+ # Step1: Notifies all the attached nodes that we have reached the barrier.
+ task.emit(self._runtime.event_name_req, self._runtime.pickle_barrier_tag(), only_remote=True)
+ logging.debug("Node:[{}] sended barrier request".format(self.node_id))
+
+ # Step2: We check the number of flags we have received.
+ # In the current CI design of DI-engine, there will always be a node whose 'attach_to' list is empty,
+ # so there will always be a node that will send ACK unconditionally, so deadlock will not occur.
+ if self._runtime.get_recv_num() == self._barrier_peers_nums:
+ self_ready = True
+
+ # Step3: Waiting for our own to be ready.
+ # Even if the current process has reached the barrier, we will not send an ack immediately,
+ # we need to wait for the slowest directly connected or indirectly connected peer to
+ # reach the barrier.
+ start = time()
+ if not self_ready:
+ while True:
+ if time() - start > self.timeout:
+ raise TimeoutError("Node-[{}] timeout when waiting barrier! ".format(task.router.node_id))
+
+ if self._runtime.get_recv_num() != self._barrier_peers_nums:
+ sleep(0.1)
+ else:
+ break
+
+ # Step4: Notifies all attached nodes that we are ready.
+ task.emit(self._runtime.event_name_ack, self.node_id, only_remote=True)
+ logging.debug("Node:[{}] sended barrier ack".format(self.node_id))
+
+ # Step5: Wait until all directly or indirectly connected nodes are ready.
+ start = time()
+ while True:
+ if time() - start > self.timeout:
+ raise TimeoutError("Node-[{}] timeout when waiting barrier! ".format(task.router.node_id))
+
+ if self._runtime.get_ack_num() != self._barrier_peers_nums:
+ sleep(0.1)
+ else:
+ break
+
+ logging.info("Node-[{}] env_step:[{}] barrier finish".format(self.node_id, ctx.env_step))
diff --git a/DI-engine/ding/framework/middleware/ckpt_handler.py b/DI-engine/ding/framework/middleware/ckpt_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca75f1661807443d470dd2b62f00b2b62a559ad2
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/ckpt_handler.py
@@ -0,0 +1,74 @@
+from typing import TYPE_CHECKING, Optional, Union
+from easydict import EasyDict
+import os
+import numpy as np
+
+from ding.utils import save_file
+from ding.policy import Policy
+from ding.framework import task
+
+if TYPE_CHECKING:
+ from ding.framework import OnlineRLContext, OfflineRLContext
+
+
+class CkptSaver:
+ """
+ Overview:
+ The class used to save checkpoint data.
+ """
+
+ def __new__(cls, *args, **kwargs):
+ if task.router.is_active and not (task.has_role(task.role.LEARNER) or task.has_role(task.role.EVALUATOR)):
+ return task.void()
+ return super(CkptSaver, cls).__new__(cls)
+
+ def __init__(self, policy: Policy, save_dir: str, train_freq: Optional[int] = None, save_finish: bool = True):
+ """
+ Overview:
+ Initialize the `CkptSaver`.
+ Arguments:
+ - policy (:obj:`Policy`): Policy used to save the checkpoint.
+ - save_dir (:obj:`str`): The directory path to save ckpt.
+ - train_freq (:obj:`int`): Number of training iterations between each saving checkpoint data.
+ - save_finish (:obj:`bool`): Whether save final ckpt when ``task.finish = True``.
+ """
+ self.policy = policy
+ self.train_freq = train_freq
+ if str(os.path.basename(os.path.normpath(save_dir))) != "ckpt":
+ self.prefix = '{}/ckpt'.format(os.path.normpath(save_dir))
+ else:
+ self.prefix = '{}/'.format(os.path.normpath(save_dir))
+ if not os.path.exists(self.prefix):
+ os.makedirs(self.prefix)
+ self.last_save_iter = 0
+ self.max_eval_value = -np.inf
+ self.save_finish = save_finish
+
+ def __call__(self, ctx: Union["OnlineRLContext", "OfflineRLContext"]) -> None:
+ """
+ Overview:
+ The method used to save checkpoint data. \
+ The checkpoint data will be saved in a file in the following 3 cases: \
+ - When a multiple of `self.train_freq` iterations have elapsed since the beginning of training; \
+ - When the evaluation episode return is the best so far; \
+ - When `task.finish` is True.
+ Input of ctx:
+ - train_iter (:obj:`int`): Number of training iterations, i.e. the number of policy network updates.
+ - eval_value (:obj:`float`): The episode return of the current iteration.
+ """
+ # train enough iteration
+ if self.train_freq:
+ if ctx.train_iter == 0 or ctx.train_iter - self.last_save_iter >= self.train_freq:
+ save_file(
+ "{}/iteration_{}.pth.tar".format(self.prefix, ctx.train_iter), self.policy.learn_mode.state_dict()
+ )
+ self.last_save_iter = ctx.train_iter
+
+ # best episode return so far
+ if ctx.eval_value is not None and ctx.eval_value > self.max_eval_value:
+ save_file("{}/eval.pth.tar".format(self.prefix), self.policy.learn_mode.state_dict())
+ self.max_eval_value = ctx.eval_value
+
+ # finish
+ if task.finish and self.save_finish:
+ save_file("{}/final.pth.tar".format(self.prefix), self.policy.learn_mode.state_dict())
diff --git a/DI-engine/ding/framework/middleware/collector.py b/DI-engine/ding/framework/middleware/collector.py
new file mode 100644
index 0000000000000000000000000000000000000000..beb4894ad99c13a2bae273c6b91ac3247ffd14ff
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/collector.py
@@ -0,0 +1,193 @@
+from typing import TYPE_CHECKING
+from easydict import EasyDict
+import treetensor.torch as ttorch
+
+from ding.policy import get_random_policy
+from ding.envs import BaseEnvManager
+from ding.framework import task
+from .functional import inferencer, rolloutor, TransitionList
+
+if TYPE_CHECKING:
+ from ding.framework import OnlineRLContext
+
+
+class StepCollector:
+ """
+ Overview:
+ The class of the collector running by steps, including model inference and transition \
+ process. Use the `__call__` method to execute the whole collection process.
+ """
+
+ def __new__(cls, *args, **kwargs):
+ if task.router.is_active and not task.has_role(task.role.COLLECTOR):
+ return task.void()
+ return super(StepCollector, cls).__new__(cls)
+
+ def __init__(self, cfg: EasyDict, policy, env: BaseEnvManager, random_collect_size: int = 0) -> None:
+ """
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config.
+ - policy (:obj:`Policy`): The policy to be collected.
+ - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \
+ its derivatives are supported.
+ - random_collect_size (:obj:`int`): The count of samples that will be collected randomly, \
+ typically used in initial runs.
+ """
+ self.cfg = cfg
+ self.env = env
+ self.policy = policy
+ self.random_collect_size = random_collect_size
+ self._transitions = TransitionList(self.env.env_num)
+ self._inferencer = task.wrap(inferencer(cfg.seed, policy, env))
+ self._rolloutor = task.wrap(rolloutor(policy, env, self._transitions))
+
+ def __call__(self, ctx: "OnlineRLContext") -> None:
+ """
+ Overview:
+ An encapsulation of inference and rollout middleware. Stop when completing \
+ the target number of steps.
+ Input of ctx:
+ - env_step (:obj:`int`): The env steps which will increase during collection.
+ """
+ old = ctx.env_step
+ if self.random_collect_size > 0 and old < self.random_collect_size:
+ target_size = self.random_collect_size - old
+ random_policy = get_random_policy(self.cfg, self.policy, self.env)
+ current_inferencer = task.wrap(inferencer(self.cfg.seed, random_policy, self.env))
+ else:
+ # compatible with old config, a train sample = unroll_len step
+ target_size = self.cfg.policy.collect.n_sample * self.cfg.policy.collect.unroll_len
+ current_inferencer = self._inferencer
+
+ while True:
+ current_inferencer(ctx)
+ self._rolloutor(ctx)
+ if ctx.env_step - old >= target_size:
+ ctx.trajectories, ctx.trajectory_end_idx = self._transitions.to_trajectories()
+ self._transitions.clear()
+ break
+
+
+class PPOFStepCollector:
+ """
+ Overview:
+ The class of the collector running by steps, including model inference and transition \
+ process. Use the `__call__` method to execute the whole collection process.
+ """
+
+ def __new__(cls, *args, **kwargs):
+ if task.router.is_active and not task.has_role(task.role.COLLECTOR):
+ return task.void()
+ return super(PPOFStepCollector, cls).__new__(cls)
+
+ def __init__(self, seed: int, policy, env: BaseEnvManager, n_sample: int, unroll_len: int = 1) -> None:
+ """
+ Arguments:
+ - seed (:obj:`int`): Random seed.
+ - policy (:obj:`Policy`): The policy to be collected.
+ - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \
+ its derivatives are supported.
+ """
+ self.env = env
+ self.env.seed(seed)
+ self.policy = policy
+ self.n_sample = n_sample
+ self.unroll_len = unroll_len
+ self._transitions = TransitionList(self.env.env_num)
+ self._env_episode_id = [_ for _ in range(env.env_num)]
+ self._current_id = env.env_num
+
+ def __call__(self, ctx: "OnlineRLContext") -> None:
+ """
+ Overview:
+ An encapsulation of inference and rollout middleware. Stop when completing \
+ the target number of steps.
+ Input of ctx:
+ - env_step (:obj:`int`): The env steps which will increase during collection.
+ """
+ device = self.policy._device
+ old = ctx.env_step
+ target_size = self.n_sample * self.unroll_len
+
+ if self.env.closed:
+ self.env.launch()
+
+ while True:
+ obs = ttorch.as_tensor(self.env.ready_obs).to(dtype=ttorch.float32)
+ obs = obs.to(device)
+ inference_output = self.policy.collect(obs, **ctx.collect_kwargs)
+ inference_output = inference_output.cpu()
+ action = inference_output.action.numpy()
+ timesteps = self.env.step(action)
+ ctx.env_step += len(timesteps)
+
+ obs = obs.cpu()
+ for i, timestep in enumerate(timesteps):
+ transition = self.policy.process_transition(obs[i], inference_output[i], timestep)
+ transition.collect_train_iter = ttorch.as_tensor([ctx.train_iter])
+ transition.env_data_id = ttorch.as_tensor([self._env_episode_id[timestep.env_id]])
+ self._transitions.append(timestep.env_id, transition)
+ if timestep.done:
+ self.policy.reset([timestep.env_id])
+ self._env_episode_id[timestep.env_id] = self._current_id
+ self._current_id += 1
+ ctx.env_episode += 1
+
+ if ctx.env_step - old >= target_size:
+ ctx.trajectories, ctx.trajectory_end_idx = self._transitions.to_trajectories()
+ self._transitions.clear()
+ break
+
+
+class EpisodeCollector:
+ """
+ Overview:
+ The class of the collector running by episodes, including model inference and transition \
+ process. Use the `__call__` method to execute the whole collection process.
+ """
+
+ def __init__(self, cfg: EasyDict, policy, env: BaseEnvManager, random_collect_size: int = 0) -> None:
+ """
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config.
+ - policy (:obj:`Policy`): The policy to be collected.
+ - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \
+ its derivatives are supported.
+ - random_collect_size (:obj:`int`): The count of samples that will be collected randomly, \
+ typically used in initial runs.
+ """
+ self.cfg = cfg
+ self.env = env
+ self.policy = policy
+ self.random_collect_size = random_collect_size
+ self._transitions = TransitionList(self.env.env_num)
+ self._inferencer = task.wrap(inferencer(cfg.seed, policy, env))
+ self._rolloutor = task.wrap(rolloutor(policy, env, self._transitions))
+
+ def __call__(self, ctx: "OnlineRLContext") -> None:
+ """
+ Overview:
+ An encapsulation of inference and rollout middleware. Stop when completing the \
+ target number of episodes.
+ Input of ctx:
+ - env_episode (:obj:`int`): The number of finished env episodes, which will increase during collection.
+ """
+ old = ctx.env_episode
+ if self.random_collect_size > 0 and old < self.random_collect_size:
+ target_size = self.random_collect_size - old
+ random_policy = get_random_policy(self.cfg, self.policy, self.env)
+ current_inferencer = task.wrap(inferencer(self.cfg.seed, random_policy, self.env))
+ else:
+ target_size = self.cfg.policy.collect.n_episode
+ current_inferencer = self._inferencer
+
+ while True:
+ current_inferencer(ctx)
+ self._rolloutor(ctx)
+ if ctx.env_episode - old >= target_size:
+ ctx.episodes = self._transitions.to_episodes()
+ self._transitions.clear()
+ break
+
+
+# TODO battle collector
diff --git a/DI-engine/ding/framework/middleware/data_fetcher.py b/DI-engine/ding/framework/middleware/data_fetcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..2103a8668d2df8cc1fc622dececda6dbb58b0d4f
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/data_fetcher.py
@@ -0,0 +1,99 @@
+from typing import TYPE_CHECKING
+from threading import Thread, Event
+from queue import Queue
+import time
+import numpy as np
+import torch
+from easydict import EasyDict
+from ding.framework import task
+from ding.data import Dataset, DataLoader
+from ding.utils import get_rank, get_world_size
+
+if TYPE_CHECKING:
+ from ding.framework import OfflineRLContext
+
+
+class OfflineMemoryDataFetcher:
+
+ def __new__(cls, *args, **kwargs):
+ if task.router.is_active and not task.has_role(task.role.FETCHER):
+ return task.void()
+ return super(OfflineMemoryDataFetcher, cls).__new__(cls)
+
+ def __init__(self, cfg: EasyDict, dataset: Dataset):
+ device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu'
+ if device != 'cpu':
+ stream = torch.cuda.Stream()
+
+ def producer(queue, dataset, batch_size, device, event):
+ torch.set_num_threads(4)
+ if device != 'cpu':
+ nonlocal stream
+ sbatch_size = batch_size * get_world_size()
+ rank = get_rank()
+ idx_list = np.random.permutation(len(dataset))
+ temp_idx_list = []
+ for i in range(len(dataset) // sbatch_size):
+ temp_idx_list.extend(idx_list[i + rank * batch_size:i + (rank + 1) * batch_size])
+ idx_iter = iter(temp_idx_list)
+
+ if device != 'cpu':
+ with torch.cuda.stream(stream):
+ while True:
+ if queue.full():
+ time.sleep(0.1)
+ else:
+ data = []
+ for _ in range(batch_size):
+ try:
+ data.append(dataset.__getitem__(next(idx_iter)))
+ except StopIteration:
+ del idx_iter
+ idx_list = np.random.permutation(len(dataset))
+ idx_iter = iter(idx_list)
+ data.append(dataset.__getitem__(next(idx_iter)))
+ data = [[i[j] for i in data] for j in range(len(data[0]))]
+ data = [torch.stack(x).to(device) for x in data]
+ queue.put(data)
+ if event.is_set():
+ break
+ else:
+ while True:
+ if queue.full():
+ time.sleep(0.1)
+ else:
+ data = []
+ for _ in range(batch_size):
+ try:
+ data.append(dataset.__getitem__(next(idx_iter)))
+ except StopIteration:
+ del idx_iter
+ idx_list = np.random.permutation(len(dataset))
+ idx_iter = iter(idx_list)
+ data.append(dataset.__getitem__(next(idx_iter)))
+ data = [[i[j] for i in data] for j in range(len(data[0]))]
+ data = [torch.stack(x) for x in data]
+ queue.put(data)
+ if event.is_set():
+ break
+
+ self.queue = Queue(maxsize=50)
+ self.event = Event()
+ self.producer_thread = Thread(
+ target=producer,
+ args=(self.queue, dataset, cfg.policy.batch_size, device, self.event),
+ name='cuda_fetcher_producer'
+ )
+
+ def __call__(self, ctx: "OfflineRLContext"):
+ if not self.producer_thread.is_alive():
+ time.sleep(5)
+ self.producer_thread.start()
+ while self.queue.empty():
+ time.sleep(0.001)
+ ctx.train_data = self.queue.get()
+
+ def __del__(self):
+ if self.producer_thread.is_alive():
+ self.event.set()
+ del self.queue
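+
+
+# A minimal wiring sketch, assuming a plain single-process offline pipeline. The tiny
+# `EasyDict` config and `TensorDataset` are illustrative only; a real config would come
+# from `compile_config` as in the scripts under ding/example, and `task.run()` (omitted
+# here) would drive the middleware chain.
+if __name__ == "__main__":
+    from easydict import EasyDict
+    from torch.utils.data import TensorDataset
+    from ding.framework import task, OfflineRLContext
+
+    cfg = EasyDict({'policy': {'cuda': False, 'batch_size': 4}})
+    toy_dataset = TensorDataset(torch.randn(64, 8), torch.randn(64, 1))  # (obs, target) pairs
+
+    with task.start(ctx=OfflineRLContext()):
+        # Each call of the fetcher puts one batch into ctx.train_data as a list of
+        # stacked tensors (here [obs_batch, target_batch]) for the next middleware.
+        task.use(OfflineMemoryDataFetcher(cfg, toy_dataset))
+        # ... a learner middleware consuming ctx.train_data would follow here ...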
diff --git a/DI-engine/ding/framework/middleware/distributer.py b/DI-engine/ding/framework/middleware/distributer.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2f5e36402cee1cf68beb6e77a8de6709a001c5b
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/distributer.py
@@ -0,0 +1,415 @@
+import numpy as np
+from time import sleep, time
+from dataclasses import fields
+from typing import TYPE_CHECKING, List, Dict, Any, Optional, Union
+from ditk import logging
+from ding.framework import task
+from ding.data import StorageLoader, Storage, ModelLoader
+if TYPE_CHECKING:
+ from ding.framework.context import Context
+ from torch.nn import Module
+
+
+class ContextExchanger:
+
+ def __init__(self, skip_n_iter: int = 1, storage_loader: Optional[StorageLoader] = None) -> None:
+ """
+ Overview:
+ Exchange context between processes,
+ support properties: trajectories, episodes, env_step, env_episode, train_iter
+ Arguments:
+ - skip_n_iter (:obj:`int`): For collectors, it may be necessary to skip waiting \
+ for the first n iterations to collect data for the learner to learn. This parameter \
+ has no effect on the learner.
+ - storage_loader (:obj:`Optional[StorageLoader]`): Turn data into storage class to reduce \
+ the network overhead.
+ """
+ if not task.router.is_active:
+ raise RuntimeError("ContextHandler should be used in parallel mode!")
+ self._state = {}
+ self._local_state = {} # just save local state, not send to remote node
+ if task.has_role(task.role.COLLECTOR):
+ self._local_state['env_step'] = 0
+ self._local_state['env_episode'] = 0
+ self._event_name = "context_exchanger_{role}"
+ self._skip_n_iter = skip_n_iter
+ self._storage_loader = storage_loader
+ for role in task.role: # Only subscribe to other roles
+ if not task.has_role(role):
+ task.on(self._event_name.format(role=role), self.put)
+ if storage_loader:
+ task.once("finish", lambda _: storage_loader.shutdown())
+
+ def __new__(cls, *args, **kwargs):
+ if not task.router.is_active:
+ return task.void()
+
+ if len(task.roles) == 0:
+ logging.warning("The task does not have any roles defined, the ContextExchanger will not work.")
+ return task.void()
+
+ if len(task.roles) > 1:
+ logging.warning(
+ "Use multiple roles in one exchanger may lead to unexpected result, please check your code."
+ )
+
+ return super(ContextExchanger, cls).__new__(cls)
+
+ def __call__(self, ctx: "Context"):
+ self.merge(ctx)
+ yield
+ payload = self.fetch(ctx)
+ if payload:
+ if self._storage_loader and task.has_role(task.role.COLLECTOR):
+ payload = self._storage_loader.save(payload)
+ for role in task.roles:
+ task.emit(self._event_name.format(role=role), payload, only_remote=True)
+
+ def __del__(self):
+ if self._storage_loader:
+ self._storage_loader.shutdown()
+
+ def put(self, payload: Union[Dict, Storage]):
+ """
+ Overview:
+ Get attributes from ctx on the callback of event.
+ Each attribute should have a standalone put handler, which named `_put_{key}`
+ """
+
+ def callback(payload: Dict):
+ for key, item in payload.items():
+ fn_name = "_put_{}".format(key)
+ if hasattr(self, fn_name):
+ getattr(self, fn_name)(item)
+ else:
+ logging.warning("Receive unexpected key ({}) in context exchanger".format(key))
+
+ if isinstance(payload, Storage):
+ assert self._storage_loader is not None, "Storage loader is not defined when data is a storage object."
+ self._storage_loader.load(payload, callback)
+ else:
+ callback(payload)
+
+ def fetch(self, ctx: "Context") -> Dict[str, Any]:
+ """
+ Overview:
+ Fetch attributes from ctx before emit them to the event bus.
+ Each attribute should have a standalone fetch handler, which named `_fetch_{key}`
+ """
+ payload = {}
+ for field in fields(ctx):
+ key, item = field.name, getattr(ctx, field.name)
+ fn_name = "_fetch_{}".format(key)
+ if hasattr(self, fn_name):
+ value = getattr(self, fn_name)(item)
+ if value is not None:
+ payload[key] = value
+ return payload
+
+ def merge(self, ctx: "Context"):
+ if task.has_role(task.role.LEARNER):
+ # Learner should always wait for trajs.
+ # TODO: Automatically wait based on properties, not roles.
+ while len(self._state) == 0:
+ sleep(0.01)
+ elif ctx.total_step >= self._skip_n_iter:
+ start = time()
+ while len(self._state) == 0:
+ if time() - start > 60:
+ logging.warning("Timeout when waiting for new context! Node id: {}".format(task.router.node_id))
+ break
+ sleep(0.01)
+
+ for k, v in self._state.items():
+ if not task.has_role(task.role.COLLECTOR) and k.startswith('increment_'):
+ pure_k = k.split('increment_')[-1]
+ setattr(ctx, pure_k, getattr(ctx, pure_k) + v)
+ else:
+ setattr(ctx, k, v)
+ self._state = {}
+
+ # Handle each attribute of the context
+ def _put_trajectories(self, traj: List[Any]):
+ if not task.has_role(task.role.LEARNER):
+ return
+ if "trajectories" not in self._state:
+ self._state["trajectories"] = []
+ self._state["trajectories"].extend(traj)
+
+ def _fetch_trajectories(self, traj: List[Any]):
+ if task.has_role(task.role.COLLECTOR):
+ return traj
+
+ def _put_episodes(self, episodes: List[Any]):
+ if not task.has_role(task.role.LEARNER):
+ return
+ if "episodes" not in self._state:
+ self._state["episodes"] = []
+ self._state["episodes"].extend(episodes)
+
+ def _fetch_episodes(self, episodes: List[Any]):
+ if task.has_role(task.role.COLLECTOR):
+ return episodes
+
+ def _put_trajectory_end_idx(self, trajectory_end_idx: List[str]):
+ if not task.has_role(task.role.LEARNER):
+ return
+ if "trajectory_end_idx" not in self._state:
+ self._state["trajectory_end_idx"] = []
+ self._state["trajectory_end_idx"].extend(trajectory_end_idx)
+
+ def _fetch_trajectory_end_idx(self, trajectory_end_idx: List[str]):
+ if task.has_role(task.role.COLLECTOR):
+ return trajectory_end_idx
+
+ def _put_env_step(self, increment_env_step: int):
+ if not task.has_role(task.role.COLLECTOR):
+ if 'increment_env_step' not in self._state:
+ self._state['increment_env_step'] = 0
+ self._state["increment_env_step"] += increment_env_step
+
+ def _fetch_env_step(self, env_step: int):
+ if task.has_role(task.role.COLLECTOR):
+ increment_env_step = env_step - self._local_state['env_step']
+ self._local_state['env_step'] = env_step
+ return increment_env_step
+
+ def _put_env_episode(self, increment_env_episode: int):
+ if not task.has_role(task.role.COLLECTOR):
+ if 'increment_env_episode' not in self._state:
+ self._state['increment_env_episode'] = 0
+ self._state["increment_env_episode"] += increment_env_episode
+
+ def _fetch_env_episode(self, env_episode: int):
+ if task.has_role(task.role.COLLECTOR):
+ increment_env_episode = env_episode - self._local_state['env_episode']
+ self._local_state['env_episode'] = env_episode
+ return increment_env_episode
+
+ def _put_train_iter(self, train_iter: int):
+ if not task.has_role(task.role.LEARNER):
+ self._state["train_iter"] = train_iter
+
+ def _fetch_train_iter(self, train_iter: int):
+ if task.has_role(task.role.LEARNER):
+ return train_iter
+
+
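+# A minimal extension sketch: following the `_put_{key}` / `_fetch_{key}` naming
+# convention documented in `put()` and `fetch()` above, a subclass can exchange one more
+# context field. `collect_kwargs` is a real OnlineRLContext field, but this handler pair
+# is purely illustrative and not part of DI-engine.
+class _CollectKwargsExchangerSketch(ContextExchanger):
+
+    def _fetch_collect_kwargs(self, collect_kwargs: Dict):
+        # Only the learner side sends the kwargs to be used for the next collection.
+        if task.has_role(task.role.LEARNER):
+            return collect_kwargs
+
+    def _put_collect_kwargs(self, collect_kwargs: Dict):
+        # Collectors adopt the received kwargs; merge() will copy them onto the context.
+        if task.has_role(task.role.COLLECTOR):
+            self._state['collect_kwargs'] = collect_kwargs
+
+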
+class ModelExchanger:
+
+ def __init__(self, model: "Module", model_loader: Optional[ModelLoader] = None) -> None:
+ """
+ Overview:
+ Exchange model between processes, only the learner will send the model,
+ otherwise the model will only be received.
+ If you are using a shared model on a single host, there is no need to use this middleware.
+ Arguments:
+ - model (:obj:`torch.nn.Module`): Pytorch module.
+ - model_loader (:obj:`ModelLoader`): Encode model in subprocess.
+ """
+ self._model = model
+ self._model_loader = model_loader
+ self._event_name = "model_exchanger"
+ self._state_dict_cache: Optional[Union[object, Storage]] = None
+ self._is_learner = task.has_role(task.role.LEARNER)
+ if not self._is_learner:
+ task.on(self._event_name, self._cache_state_dict)
+ if model_loader:
+ task.once("finish", lambda _: model_loader.shutdown())
+
+ def _cache_state_dict(self, state_dict: Union[object, Storage]):
+ self._state_dict_cache = state_dict
+
+ def __new__(cls, *args, **kwargs):
+ if not task.router.is_active:
+ return task.void()
+
+ if len(task.roles) == 0:
+ logging.warning("The task does not have any roles defined, the ModelExchanger will not work.")
+ return task.void()
+
+ if len(task.roles) > 1:
+ logging.warning(
+ "Use multiple roles in one exchanger may lead to unexpected result, please check your code."
+ )
+
+ return super(ModelExchanger, cls).__new__(cls)
+
+ def __call__(self, ctx: "Context") -> Any:
+ if self._model_loader:
+ self._model_loader.start()
+
+ if not self._is_learner:
+ if ctx.total_step != 0: # Skip first iteration
+ self._update_model()
+ else:
+ yield
+ self._send_model()
+
+ def _update_model(self):
+ start = time()
+ while True:
+ if task.finish:
+ return
+ if time() - start > 60:
+ logging.warning("Timeout when waiting for new model! Node id: {}".format(task.router.node_id))
+ break
+ if self._state_dict_cache is None:
+ sleep(0.01)
+ else:
+ if isinstance(self._state_dict_cache, Storage) and self._model_loader is not None:
+ try:
+ self._model.load_state_dict(self._model_loader.load(self._state_dict_cache))
+ self._state_dict_cache = None
+ break
+ except FileNotFoundError as e:
+ logging.warning(
+ "Model file has been deleted on node {}, maybe you can increase the ttl.".format(
+ task.router.node_id
+ )
+ )
+ self._state_dict_cache = None
+ continue
+ else:
+ self._model.load_state_dict(self._state_dict_cache)
+ self._state_dict_cache = None
+ break
+
+ def _send_model(self):
+ if self._model_loader:
+ self._model_loader.save(self._send_callback)
+ else:
+ task.emit(self._event_name, self._model.state_dict(), only_remote=True)
+
+ def _send_callback(self, storage: Storage):
+ if task.running:
+ task.emit(self._event_name, storage, only_remote=True)
+
+ def __del__(self):
+ if self._model_loader:
+ self._model_loader.shutdown()
+
+
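+# A minimal placement sketch (illustrative, the surrounding pipeline is assumed): every
+# process of a parallel pipeline adds the exchanger with the same model object, typically
+# right after the model is built; whether a process sends or receives is decided
+# internally from its task role.
+def _model_exchanger_placement_sketch(model: "Module"):
+    task.use(ModelExchanger(model))
+    # ... collector middleware on collector nodes, learner middleware on the learner ...
+
+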
+class PeriodicalModelExchanger:
+
+ def __init__(
+ self,
+ model: "Module",
+ mode: str,
+ period: int = 1,
+ delay_toleration: float = np.inf,
+ stale_toleration: int = 1,
+ event_name: str = "model_exchanger",
+ model_loader: Optional[ModelLoader] = None
+ ) -> None:
+ """
+ Overview:
+            Exchange the model between processes; set mode to "send" or "receive" to specify the role of the process.
+ If you are using a shared model on a single host, there is no need to use this middleware.
+ Arguments:
+ - model (:obj:`torch.nn.Module`): Pytorch module.
+ - mode (:obj:`str`): "send" or "receive".
+ - period (:obj:`int`): The period of model exchange.
+ - delay_toleration (:obj:`float`): The permitted time interval for receiving model after being sent.
+ - stale_toleration (:obj:`int`): The permitted number of iterations for receiving model after being sent.
+ - event_name (:obj:`str`): The event name for model exchange.
+ - model_loader (:obj:`ModelLoader`): ModelLoader for this PeriodicalModelExchanger to use.
+ """
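+        # Usage sketch (added note): a sender process would typically register
+        # ``PeriodicalModelExchanger(model, mode="send", period=10)`` and a receiver
+        # ``PeriodicalModelExchanger(model, mode="receive")`` via ``task.use``;
+        # the concrete period/toleration values here are illustrative assumptions.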
+ self._model = model
+ self._model_loader = model_loader
+ self._event_name = event_name
+ self._period = period
+ self._mode = mode
+ if self._mode == "receive":
+ self._id_counter = -1
+ self._model_id = -1
+ else:
+ self._id_counter = 0
+ self._stale_toleration = stale_toleration
+ self._model_stale = stale_toleration
+ self._delay_toleration = delay_toleration
+ self._state_dict_cache: Optional[Union[object, Storage]] = None
+
+ if self._mode == "receive":
+ task.on(self._event_name, self._cache_state_dict)
+ if model_loader:
+ task.once("finish", lambda _: model_loader.shutdown())
+
+ def _cache_state_dict(self, msg: Dict[str, Any]):
+ if msg['id'] % self._period == 0:
+ self._state_dict_cache = msg['model']
+ self._id_counter = msg['id']
+ self._time = msg['time']
+
+ def __new__(cls, *args, **kwargs):
+ return super(PeriodicalModelExchanger, cls).__new__(cls)
+
+ def __call__(self, ctx: "Context") -> Any:
+ if self._model_loader:
+ self._model_loader.start()
+
+ if self._mode == "receive":
+ if ctx.total_step != 0: # Skip first iteration
+ self._update_model()
+ elif self._mode == "send":
+ yield
+ if self._id_counter % self._period == 0:
+ self._send_model(id=self._id_counter)
+ self._id_counter += 1
+ else:
+ raise NotImplementedError
+
+ def _update_model(self):
+ start = time()
+ while True:
+ if task.finish:
+ return
+ if time() - start > 60:
+ logging.warning("Timeout when waiting for new model! Node id: {}".format(task.router.node_id))
+ self._model_stale += 1
+ break
+ if self._state_dict_cache is None:
+ if self._model_stale < self._stale_toleration and time() - self._time < self._delay_toleration:
+ self._model_stale += 1
+ break
+ else:
+ sleep(0.01)
+ else:
+ if self._id_counter > self._model_id and time() - self._time < self._delay_toleration:
+ if isinstance(self._state_dict_cache, Storage) and self._model_loader is not None:
+ try:
+ self._model.load_state_dict(self._model_loader.load(self._state_dict_cache))
+ self._state_dict_cache = None
+ self._model_id = self._id_counter
+ self._model_stale = 1
+ break
+ except FileNotFoundError as e:
+ logging.warning(
+ "Model file has been deleted on node {}, maybe you can increase the ttl.".format(
+ task.router.node_id
+ )
+ )
+ self._state_dict_cache = None
+ continue
+ else:
+ self._model.load_state_dict(self._state_dict_cache)
+ self._state_dict_cache = None
+ self._model_id = self._id_counter
+ self._model_stale = 1
+ break
+ else:
+ self._model_stale += 1
+
+ def _send_model(self, id: int):
+ if self._model_loader:
+ self._model_loader.save(self._send_callback)
+ else:
+ task.emit(self._event_name, {'id': id, 'model': self._model.state_dict(), 'time': time()}, only_remote=True)
+
+ def _send_callback(self, storage: Storage):
+ if task.running:
+ task.emit(self._event_name, storage, only_remote=True)
+
+ def __del__(self):
+ if self._model_loader:
+ self._model_loader.shutdown()
diff --git a/DI-engine/ding/framework/middleware/functional/__init__.py b/DI-engine/ding/framework/middleware/functional/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8474f2626e01ededc1bcfc6b7f88e8bb096c352a
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/functional/__init__.py
@@ -0,0 +1,15 @@
+from .trainer import trainer, multistep_trainer
+from .data_processor import offpolicy_data_fetcher, data_pusher, offline_data_fetcher, offline_data_saver, \
+ offline_data_fetcher_from_mem, sqil_data_pusher, buffer_saver
+from .collector import inferencer, rolloutor, TransitionList
+from .evaluator import interaction_evaluator, interaction_evaluator_ttorch
+from .termination_checker import termination_checker, ddp_termination_checker
+from .logger import online_logger, offline_logger, wandb_online_logger, wandb_offline_logger
+from .ctx_helper import final_ctx_saver
+
+# algorithm
+from .explorer import eps_greedy_handler, eps_greedy_masker
+from .advantage_estimator import gae_estimator, ppof_adv_estimator, montecarlo_return_estimator
+from .enhancer import reward_estimator, her_data_enhancer, nstep_reward_enhancer
+from .priority import priority_calculator
+from .timer import epoch_timer
diff --git a/DI-engine/ding/framework/middleware/functional/advantage_estimator.py b/DI-engine/ding/framework/middleware/functional/advantage_estimator.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb80089fe2451fbc231ea90170b6e3970a5725fe
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/functional/advantage_estimator.py
@@ -0,0 +1,156 @@
+from typing import TYPE_CHECKING, Callable, Optional
+from easydict import EasyDict
+from ditk import logging
+import torch
+import treetensor.torch as ttorch
+from ding.policy import Policy
+from ding.data import Buffer
+from ding.rl_utils import gae, gae_data, get_train_sample
+from ding.framework import task
+from ding.utils.data import ttorch_collate
+from ding.torch_utils import to_device
+
+if TYPE_CHECKING:
+ from ding.framework import OnlineRLContext
+
+
+def gae_estimator(cfg: EasyDict, policy: Policy, buffer_: Optional[Buffer] = None) -> Callable:
+ """
+ Overview:
+        Calculate values for the input observations, then call the function `gae` to get advantages. \
+ The processed data will be pushed into `buffer_` if `buffer_` is not None, \
+ otherwise it will be assigned to `ctx.train_data`.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config which should contain the following keys: \
+ `cfg.policy.collect.discount_factor`, `cfg.policy.collect.gae_lambda`.
+ - policy (:obj:`Policy`): Policy in `policy.collect_mode`, used to get model to calculate value.
+ - buffer\_ (:obj:`Optional[Buffer]`): The `buffer_` to push the processed data in if `buffer_` is not None.
+ """
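+    # Usage sketch (added note): in an on-policy pipeline this middleware is usually
+    # registered between collection and training, e.g.
+    #   task.use(gae_estimator(cfg, policy.collect_mode))
+    # where ``cfg`` and ``policy`` follow the conventions documented above; the exact
+    # wiring of the entry script is an assumption here.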
+ if task.router.is_active and not task.has_role(task.role.LEARNER):
+ return task.void()
+
+ model = policy.get_attribute('model')
+ # Unify the shape of obs and action
+ obs_shape = cfg['policy']['model']['obs_shape']
+ obs_shape = torch.Size(torch.tensor(obs_shape)) if isinstance(obs_shape, list) \
+ else torch.Size(torch.tensor(obs_shape).unsqueeze(0))
+ action_shape = cfg['policy']['model']['action_shape']
+ action_shape = torch.Size(torch.tensor(action_shape)) if isinstance(action_shape, list) \
+ else torch.Size(torch.tensor(action_shape).unsqueeze(0))
+
+ def _gae(ctx: "OnlineRLContext"):
+ """
+ Input of ctx:
+ - trajectories (:obj:`List[treetensor.torch.Tensor]`): The data to be processed.\
+ Each element should contain the following keys: `obs`, `next_obs`, `reward`, `done`.
+ - trajectory_end_idx: (:obj:`treetensor.torch.IntTensor`):
+                The indices that mark the end of each trajectory; \
+                each index should be smaller than the length of `ctx.trajectories`.
+ Output of ctx:
+ - train_data (:obj:`List[treetensor.torch.Tensor]`): The processed data if `buffer_` is None.
+ """
+ cuda = cfg.policy.cuda and torch.cuda.is_available()
+
+        # action shape (B,) for discrete action, (B, D,) for continuous action
+ # reward shape (B,) done shape (B,) value shape (B,)
+ data = ttorch_collate(ctx.trajectories, cat_1dim=True)
+ if data['action'].dtype in [torch.float16, torch.float32, torch.double] \
+ and data['action'].dim() == 1:
+ # action shape
+ data['action'] = data['action'].unsqueeze(-1)
+
+ with torch.no_grad():
+ if cuda:
+ data = data.cuda()
+ value = model.forward(data.obs.to(dtype=ttorch.float32), mode='compute_critic')['value']
+ next_value = model.forward(data.next_obs.to(dtype=ttorch.float32), mode='compute_critic')['value']
+ data.value = value
+
+ traj_flag = data.done.clone()
+ traj_flag[ctx.trajectory_end_idx] = True
+ data.traj_flag = traj_flag
+
+ # done is bool type when acquired from env.step
+ data_ = gae_data(data.value, next_value, data.reward, data.done.float(), traj_flag.float())
+ data.adv = gae(data_, cfg.policy.collect.discount_factor, cfg.policy.collect.gae_lambda)
+ if buffer_ is None:
+ ctx.train_data = data
+ else:
+ data = data.cpu()
+ data = ttorch.split(data, 1)
+ # To ensure the shape of obs is same as config
+ if data[0]['obs'].shape == obs_shape:
+ pass
+ elif data[0]['obs'].shape[0] == 1 and data[0]['obs'].shape[1:] == obs_shape:
+ for d in data:
+ d['obs'] = d['obs'].squeeze(0)
+ d['next_obs'] = d['next_obs'].squeeze(0)
+ if 'logit' in data[0]:
+ for d in data:
+ d['logit'] = d['logit'].squeeze(0)
+ if 'log_prob' in data[0]:
+ for d in data:
+ d['log_prob'] = d['log_prob'].squeeze(0)
+ else:
+ raise RuntimeError("The shape of obs is {}, which is not same as config.".format(data[0]['obs'].shape))
+
+ if data[0]['action'].dtype in [torch.float16, torch.float32, torch.double] \
+ and data[0]['action'].dim() == 2:
+ for d in data:
+ d['action'] = d['action'].squeeze(0)
+ for d in data:
+ buffer_.push(d)
+ ctx.trajectories = None
+
+ return _gae
+
+
+def ppof_adv_estimator(policy: Policy) -> Callable:
+ if task.router.is_active and not task.has_role(task.role.LEARNER):
+ return task.void()
+
+ def _estimator(ctx: "OnlineRLContext"):
+ data = ttorch_collate(ctx.trajectories, cat_1dim=True)
+ if data['action'].dtype == torch.float32 and data['action'].dim() == 1:
+ data['action'] = data['action'].unsqueeze(-1)
+ traj_flag = data.done.clone()
+ traj_flag[ctx.trajectory_end_idx] = True
+ data.traj_flag = traj_flag
+ ctx.train_data = data
+
+ return _estimator
+
+
+def montecarlo_return_estimator(policy: Policy) -> Callable:
+ if task.router.is_active and not task.has_role(task.role.LEARNER):
+ return task.void()
+
+ def pg_policy_get_train_sample(data):
+        assert data[-1]['done'], "PG needs a complete episode"
+
+ if policy._cfg.learn.ignore_done:
+ raise NotImplementedError
+
+ R = 0.
+ if isinstance(data, ttorch.Tensor):
+ data_size = data['done'].shape[0]
+ data['return'] = ttorch.Tensor([0.0 for i in range(data_size)])
+ for i in reversed(range(data_size)):
+ R = policy._gamma * R + data['reward'][i]
+ data['return'][i] = R
+ return get_train_sample(data, policy._unroll_len)
+ else:
+ raise ValueError
+
+ def _estimator(ctx: "OnlineRLContext"):
+ train_data = []
+ for episode in ctx.episodes:
+ data = ttorch_collate(episode, cat_1dim=True)
+ if data['action'].dtype in [torch.float16, torch.float32, torch.double] \
+ and data['action'].dim() == 1:
+ data['action'] = data['action'].unsqueeze(-1)
+ data = pg_policy_get_train_sample(data)
+ train_data.append(data)
+ ctx.train_data = ttorch.cat(train_data, dim=0)
+
+ return _estimator
diff --git a/DI-engine/ding/framework/middleware/functional/collector.py b/DI-engine/ding/framework/middleware/functional/collector.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2fb4483b9e9ac43dc59b455dff76c805c59e179
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/functional/collector.py
@@ -0,0 +1,213 @@
+from typing import TYPE_CHECKING, Callable, List, Tuple, Any
+from functools import reduce
+import treetensor.torch as ttorch
+import numpy as np
+from ditk import logging
+from ding.utils import EasyTimer
+from ding.envs import BaseEnvManager
+from ding.policy import Policy
+from ding.torch_utils import to_ndarray, get_shape0
+
+if TYPE_CHECKING:
+ from ding.framework import OnlineRLContext
+
+
+class TransitionList:
+
+ def __init__(self, env_num: int) -> None:
+ self.env_num = env_num
+ self._transitions = [[] for _ in range(env_num)]
+ self._done_idx = [[] for _ in range(env_num)]
+
+ def append(self, env_id: int, transition: Any) -> None:
+ self._transitions[env_id].append(transition)
+ if transition.done:
+ self._done_idx[env_id].append(len(self._transitions[env_id]))
+
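+    # Worked example (added note): with per-env transition lists of lengths [3, 2],
+    # ``to_trajectories`` returns the 5 concatenated transitions together with
+    # trajectory_end_idx == [2, 4], i.e. the last index of each env's segment.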
+ def to_trajectories(self) -> Tuple[List[Any], List[int]]:
+ trajectories = sum(self._transitions, [])
+ lengths = [len(t) for t in self._transitions]
+ trajectory_end_idx = [reduce(lambda x, y: x + y, lengths[:i + 1]) for i in range(len(lengths))]
+ trajectory_end_idx = [t - 1 for t in trajectory_end_idx]
+ return trajectories, trajectory_end_idx
+
+ def to_episodes(self) -> List[List[Any]]:
+ episodes = []
+ for env_id in range(self.env_num):
+ last_idx = 0
+ for done_idx in self._done_idx[env_id]:
+ episodes.append(self._transitions[env_id][last_idx:done_idx])
+ last_idx = done_idx
+ return episodes
+
+ def clear(self):
+ for item in self._transitions:
+ item.clear()
+ for item in self._done_idx:
+ item.clear()
+
+
+def inferencer(seed: int, policy: Policy, env: BaseEnvManager) -> Callable:
+ """
+ Overview:
+ The middleware that executes the inference process.
+ Arguments:
+ - seed (:obj:`int`): Random seed.
+ - policy (:obj:`Policy`): The policy to be inferred.
+ - env (:obj:`BaseEnvManager`): The env where the inference process is performed. \
+ The env.ready_obs (:obj:`tnp.array`) will be used as model input.
+ """
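+    # Usage sketch (added note): task.use(inferencer(cfg.seed, policy.collect_mode, env)),
+    # typically followed by ``rolloutor`` in the same pipeline; the argument names here
+    # are illustrative and follow the docstring above.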
+
+ env.seed(seed)
+
+ def _inference(ctx: "OnlineRLContext"):
+ """
+ Output of ctx:
+ - obs (:obj:`Union[torch.Tensor, Dict[torch.Tensor]]`): The input observations collected \
+ from all collector environments.
+ - action: (:obj:`List[np.ndarray]`): The inferred actions listed by env_id.
+ - inference_output (:obj:`Dict[int, Dict]`): The dict of which the key is env_id (int), \
+ and the value is inference result (Dict).
+ """
+
+ if env.closed:
+ env.launch()
+
+ obs = ttorch.as_tensor(env.ready_obs)
+ ctx.obs = obs
+ obs = obs.to(dtype=ttorch.float32)
+ # TODO mask necessary rollout
+
+ obs = {i: obs[i] for i in range(get_shape0(obs))} # TBD
+ inference_output = policy.forward(obs, **ctx.collect_kwargs)
+ ctx.action = [to_ndarray(v['action']) for v in inference_output.values()] # TBD
+ ctx.inference_output = inference_output
+
+ return _inference
+
+
+def rolloutor(
+ policy: Policy,
+ env: BaseEnvManager,
+ transitions: TransitionList,
+ collect_print_freq=100,
+) -> Callable:
+ """
+ Overview:
+ The middleware that executes the transition process in the env.
+ Arguments:
+ - policy (:obj:`Policy`): The policy to be used during transition.
+ - env (:obj:`BaseEnvManager`): The env for the collection, the BaseEnvManager object or \
+ its derivatives are supported.
+ - transitions (:obj:`TransitionList`): The transition information which will be filled \
+ in this process, including `obs`, `next_obs`, `action`, `logit`, `value`, `reward` \
+ and `done`.
+ """
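+    # Usage sketch (added note): registered right after ``inferencer``, e.g.
+    #   task.use(rolloutor(policy.collect_mode, env, transitions))
+    # where ``transitions`` is the shared TransitionList instance described above.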
+
+ env_episode_id = [_ for _ in range(env.env_num)]
+ current_id = env.env_num
+ timer = EasyTimer()
+ last_train_iter = 0
+ total_envstep_count = 0
+ total_episode_count = 0
+ total_train_sample_count = 0
+ env_info = {env_id: {'time': 0., 'step': 0, 'train_sample': 0} for env_id in range(env.env_num)}
+ episode_info = []
+
+ def _rollout(ctx: "OnlineRLContext"):
+ """
+ Input of ctx:
+ - action: (:obj:`List[np.ndarray]`): The inferred actions from previous inference process.
+ - obs (:obj:`Dict[Tensor]`): The states fed into the transition dict.
+ - inference_output (:obj:`Dict[int, Dict]`): The inference results to be fed into the \
+ transition dict.
+ - train_iter (:obj:`int`): The train iteration count to be fed into the transition dict.
+ - env_step (:obj:`int`): The count of env step, which will increase by 1 for a single \
+ transition call.
+ - env_episode (:obj:`int`): The count of env episode, which will increase by 1 if the \
+ trajectory stops.
+ """
+
+ nonlocal current_id, env_info, episode_info, timer, \
+ total_episode_count, total_envstep_count, total_train_sample_count, last_train_iter
+ timesteps = env.step(ctx.action)
+ ctx.env_step += len(timesteps)
+ timesteps = [t.tensor() for t in timesteps]
+
+ collected_sample = 0
+ collected_step = 0
+ collected_episode = 0
+ interaction_duration = timer.value / len(timesteps)
+ for i, timestep in enumerate(timesteps):
+ with timer:
+ transition = policy.process_transition(ctx.obs[i], ctx.inference_output[i], timestep)
+ transition = ttorch.as_tensor(transition)
+ transition.collect_train_iter = ttorch.as_tensor([ctx.train_iter])
+ transition.env_data_id = ttorch.as_tensor([env_episode_id[timestep.env_id]])
+ transitions.append(timestep.env_id, transition)
+
+ collected_step += 1
+ collected_sample += len(transition.obs)
+ env_info[timestep.env_id.item()]['step'] += 1
+ env_info[timestep.env_id.item()]['train_sample'] += len(transition.obs)
+
+ env_info[timestep.env_id.item()]['time'] += timer.value + interaction_duration
+ if timestep.done:
+ info = {
+ 'reward': timestep.info['eval_episode_return'],
+ 'time': env_info[timestep.env_id.item()]['time'],
+ 'step': env_info[timestep.env_id.item()]['step'],
+ 'train_sample': env_info[timestep.env_id.item()]['train_sample'],
+ }
+
+ episode_info.append(info)
+ policy.reset([timestep.env_id.item()])
+ env_episode_id[timestep.env_id.item()] = current_id
+ collected_episode += 1
+ current_id += 1
+ ctx.env_episode += 1
+
+ total_envstep_count += collected_step
+ total_episode_count += collected_episode
+ total_train_sample_count += collected_sample
+
+ if (ctx.train_iter - last_train_iter) >= collect_print_freq and len(episode_info) > 0:
+ output_log(episode_info, total_episode_count, total_envstep_count, total_train_sample_count)
+ last_train_iter = ctx.train_iter
+
+ return _rollout
+
+
+def output_log(episode_info, total_episode_count, total_envstep_count, total_train_sample_count) -> None:
+ """
+ Overview:
+        Print the output log information. You can refer to the docs of `Best Practice` to understand \
+        the generated training logs and tensorboards.
+    Arguments:
+        - episode_info (:obj:`list`): The information of each finished episode.
+        - total_episode_count (:obj:`int`): The total number of episodes collected so far.
+        - total_envstep_count (:obj:`int`): The total number of env steps collected so far.
+        - total_train_sample_count (:obj:`int`): The total number of train samples collected so far.
+ """
+ episode_count = len(episode_info)
+ envstep_count = sum([d['step'] for d in episode_info])
+ train_sample_count = sum([d['train_sample'] for d in episode_info])
+ duration = sum([d['time'] for d in episode_info])
+ episode_return = [d['reward'].item() for d in episode_info]
+ info = {
+ 'episode_count': episode_count,
+ 'envstep_count': envstep_count,
+ 'train_sample_count': train_sample_count,
+ 'avg_envstep_per_episode': envstep_count / episode_count,
+ 'avg_sample_per_episode': train_sample_count / episode_count,
+ 'avg_envstep_per_sec': envstep_count / duration,
+ 'avg_train_sample_per_sec': train_sample_count / duration,
+ 'avg_episode_per_sec': episode_count / duration,
+ 'reward_mean': np.mean(episode_return),
+ 'reward_std': np.std(episode_return),
+ 'reward_max': np.max(episode_return),
+ 'reward_min': np.min(episode_return),
+ 'total_envstep_count': total_envstep_count,
+ 'total_train_sample_count': total_train_sample_count,
+ 'total_episode_count': total_episode_count,
+ # 'each_reward': episode_return,
+ }
+ episode_info.clear()
+ logging.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])))
diff --git a/DI-engine/ding/framework/middleware/functional/ctx_helper.py b/DI-engine/ding/framework/middleware/functional/ctx_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c3254079b14630081931ce368fd284a000d685b
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/functional/ctx_helper.py
@@ -0,0 +1,29 @@
+from typing import TYPE_CHECKING, Callable
+import os
+import pickle
+import dataclasses
+from ding.framework import task
+if TYPE_CHECKING:
+ from ding.framework import Context
+
+
+def final_ctx_saver(name: str) -> Callable:
+
+ def _save(ctx: "Context"):
+ if task.finish:
+ # make sure the items to be recorded are all kept in the context
+ with open(os.path.join(name, 'result.pkl'), 'wb') as f:
+ final_data = {
+ 'total_step': ctx.total_step,
+ 'train_iter': ctx.train_iter,
+ 'last_eval_iter': ctx.last_eval_iter,
+ 'eval_value': ctx.last_eval_value,
+ }
+ if ctx.has_attr('env_step'):
+ final_data['env_step'] = ctx.env_step
+ final_data['env_episode'] = ctx.env_episode
+ if ctx.has_attr('trained_env_step'):
+ final_data['trained_env_step'] = ctx.trained_env_step
+ pickle.dump(final_data, f)
+
+ return _save
diff --git a/DI-engine/ding/framework/middleware/functional/data_processor.py b/DI-engine/ding/framework/middleware/functional/data_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbcc39e7a2b9dd4bc7aa6b3c2e9f04943c1defd4
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/functional/data_processor.py
@@ -0,0 +1,321 @@
+import os
+from typing import TYPE_CHECKING, Callable, List, Union, Tuple, Dict, Optional
+from easydict import EasyDict
+from ditk import logging
+import torch
+from ding.data import Buffer, Dataset, DataLoader, offline_data_save_type
+from ding.data.buffer.middleware import PriorityExperienceReplay
+from ding.framework import task
+from ding.utils import get_rank
+
+if TYPE_CHECKING:
+ from ding.framework import OnlineRLContext, OfflineRLContext
+
+
+def data_pusher(cfg: EasyDict, buffer_: Buffer, group_by_env: Optional[bool] = None):
+ """
+ Overview:
+ Push episodes or trajectories into the buffer.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config.
+ - buffer (:obj:`Buffer`): Buffer to push the data in.
+ """
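+    # Usage sketch (added note): ``task.use(data_pusher(cfg, buffer_))`` placed after the
+    # collector middleware, so freshly collected trajectories or episodes are pushed into
+    # the buffer before training.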
+ if task.router.is_active and not task.has_role(task.role.LEARNER):
+ return task.void()
+
+ def _push(ctx: "OnlineRLContext"):
+ """
+ Overview:
+ In ctx, either `ctx.trajectories` or `ctx.episodes` should not be None.
+ Input of ctx:
+ - trajectories (:obj:`List[Dict]`): Trajectories.
+ - episodes (:obj:`List[Dict]`): Episodes.
+ """
+
+ if ctx.trajectories is not None: # each data in buffer is a transition
+ if group_by_env:
+ for i, t in enumerate(ctx.trajectories):
+ buffer_.push(t, {'env': t.env_data_id.item()})
+ else:
+ for t in ctx.trajectories:
+ buffer_.push(t)
+ ctx.trajectories = None
+        elif ctx.episodes is not None:  # each data in buffer is an episode
+ for t in ctx.episodes:
+ buffer_.push(t)
+ ctx.episodes = None
+ else:
+ raise RuntimeError("Either ctx.trajectories or ctx.episodes should be not None.")
+
+ return _push
+
+
+def buffer_saver(cfg: EasyDict, buffer_: Buffer, every_envstep: int = 1000, replace: bool = False):
+ """
+ Overview:
+ Save current buffer data.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config.
+ - buffer (:obj:`Buffer`): Buffer to push the data in.
+        - every_envstep (:obj:`int`): Save the buffer every `every_envstep` env steps.
+        - replace (:obj:`bool`): Whether to overwrite a single latest file instead of saving a new file each time.
+ """
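+    # Usage sketch (added note): ``task.use(buffer_saver(cfg, buffer_, every_envstep=10000))``
+    # dumps the buffer roughly every 10000 env steps; the concrete interval is illustrative.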
+
+ buffer_saver_env_counter = -every_envstep
+
+ def _save(ctx: "OnlineRLContext"):
+ """
+ Overview:
+ In ctx, `ctx.env_step` should not be None.
+ Input of ctx:
+ - env_step (:obj:`int`): env step.
+ """
+ nonlocal buffer_saver_env_counter
+ if ctx.env_step is not None:
+ if ctx.env_step >= every_envstep + buffer_saver_env_counter:
+ buffer_saver_env_counter = ctx.env_step
+ if replace:
+ buffer_.save_data(os.path.join(cfg.exp_name, "replaybuffer", "data_latest.hkl"))
+ else:
+ buffer_.save_data(
+ os.path.join(cfg.exp_name, "replaybuffer", "data_envstep_{}.hkl".format(ctx.env_step))
+ )
+ else:
+ raise RuntimeError("buffer_saver only supports collecting data by step rather than episode.")
+
+ return _save
+
+
+def offpolicy_data_fetcher(
+ cfg: EasyDict,
+ buffer_: Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]],
+ data_shortage_warning: bool = False,
+) -> Callable:
+ """
+ Overview:
+        The returned function is a generator which mainly fetches a batch of data from a buffer, \
+ a list of buffers, or a dict of buffers.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config which should contain the following keys: `cfg.policy.learn.batch_size`.
+ - buffer (:obj:`Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]]`): \
+ The buffer where the data is fetched from. \
+ ``Buffer`` type means a buffer.\
+            ``List[Tuple[Buffer, float]]`` type means a list of tuples, each containing a buffer and a float. \
+            The float defines the fraction of `batch_size` \
+            that is sampled from the corresponding buffer.\
+ ``Dict[str, Buffer]`` type means a dict in which the value of each element is a buffer. \
+ For each key-value pair of dict, batch_size of data will be sampled from the corresponding buffer \
+ and assigned to the same key of `ctx.train_data`.
+ - data_shortage_warning (:obj:`bool`): Whether to output warning when data shortage occurs in fetching.
+ """
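+    # Usage sketch (added note): besides a single buffer, a weighted list such as
+    #   offpolicy_data_fetcher(cfg, [(agent_buffer, 0.5), (expert_buffer, 0.5)])
+    # samples half of ``batch_size`` from each buffer, matching the description above;
+    # the buffer names are illustrative.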
+
+ def _fetch(ctx: "OnlineRLContext"):
+ """
+ Input of ctx:
+ - train_output (:obj:`Union[Dict, Deque[Dict]]`): This attribute should exist \
+ if `buffer_` is of type Buffer and if `buffer_` use the middleware `PriorityExperienceReplay`. \
+ The meta data `priority` of the sampled data in the `buffer_` will be updated \
+ to the `priority` attribute of `ctx.train_output` if `ctx.train_output` is a dict, \
+ or the `priority` attribute of `ctx.train_output`'s popped element \
+ if `ctx.train_output` is a deque of dicts.
+ Output of ctx:
+ - train_data (:obj:`Union[List[Dict], Dict[str, List[Dict]]]`): The fetched data. \
+ ``List[Dict]`` type means a list of data.
+ `train_data` is of this type if the type of `buffer_` is Buffer or List.
+ ``Dict[str, List[Dict]]]`` type means a dict, in which the value of each key-value pair
+ is a list of data. `train_data` is of this type if the type of `buffer_` is Dict.
+ """
+ try:
+ unroll_len = cfg.policy.collect.unroll_len
+ if isinstance(buffer_, Buffer):
+ if unroll_len > 1:
+ buffered_data = buffer_.sample(
+ cfg.policy.learn.batch_size, groupby="env", unroll_len=unroll_len, replace=True
+ )
+ ctx.train_data = [[t.data for t in d] for d in buffered_data] # B, unroll_len
+ else:
+ buffered_data = buffer_.sample(cfg.policy.learn.batch_size)
+ ctx.train_data = [d.data for d in buffered_data]
+ elif isinstance(buffer_, List): # like sqil, r2d3
+ assert unroll_len == 1, "not support"
+ buffered_data = []
+ for buffer_elem, p in buffer_:
+ data_elem = buffer_elem.sample(int(cfg.policy.learn.batch_size * p))
+ assert data_elem is not None
+ buffered_data.append(data_elem)
+ buffered_data = sum(buffered_data, [])
+ ctx.train_data = [d.data for d in buffered_data]
+ elif isinstance(buffer_, Dict): # like ppg_offpolicy
+ assert unroll_len == 1, "not support"
+ buffered_data = {k: v.sample(cfg.policy.learn.batch_size) for k, v in buffer_.items()}
+ ctx.train_data = {k: [d.data for d in v] for k, v in buffered_data.items()}
+ else:
+ raise TypeError("not support buffer argument type: {}".format(type(buffer_)))
+
+ assert buffered_data is not None
+ except (ValueError, AssertionError):
+ if data_shortage_warning:
+ # You can modify data collect config to avoid this warning, e.g. increasing n_sample, n_episode.
+                # Fetcher will skip this attempt.
+ logging.warning(
+                    "Replay buffer's data is not enough to support training, so skip training to wait for more data."
+ )
+ ctx.train_data = None
+ return
+
+ yield
+
+ if isinstance(buffer_, Buffer):
+ if any([isinstance(m, PriorityExperienceReplay) for m in buffer_._middleware]):
+ index = [d.index for d in buffered_data]
+ meta = [d.meta for d in buffered_data]
+ # such as priority
+ if isinstance(ctx.train_output, List):
+ priority = ctx.train_output.pop()['priority']
+ else:
+ priority = ctx.train_output['priority']
+ for idx, m, p in zip(index, meta, priority):
+ m['priority'] = p
+ buffer_.update(index=idx, data=None, meta=m)
+
+ return _fetch
+
+
+def offline_data_fetcher_from_mem(cfg: EasyDict, dataset: Dataset) -> Callable:
+
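+    # Added note: this fetcher runs a background producer thread that pre-stacks batches
+    # and moves them to the target device within a dedicated CUDA stream, so the training
+    # loop only pops ready batches from the queue.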
+ from threading import Thread
+ from queue import Queue
+ import time
+ stream = torch.cuda.Stream()
+
+ def producer(queue, dataset, batch_size, device):
+ torch.set_num_threads(4)
+ nonlocal stream
+ idx_iter = iter(range(len(dataset) - batch_size))
+
+ if len(dataset) < batch_size:
+            logging.warning('batch_size is larger than the dataset size, please decrease batch_size!')
+ with torch.cuda.stream(stream):
+ while True:
+ if queue.full():
+ time.sleep(0.1)
+ else:
+ try:
+ start_idx = next(idx_iter)
+ except StopIteration:
+ del idx_iter
+ idx_iter = iter(range(len(dataset) - batch_size))
+ start_idx = next(idx_iter)
+ data = [dataset.__getitem__(idx) for idx in range(start_idx, start_idx + batch_size)]
+ data = [[i[j] for i in data] for j in range(len(data[0]))]
+ data = [torch.stack(x).to(device) for x in data]
+ queue.put(data)
+
+ queue = Queue(maxsize=50)
+ device = 'cuda:{}'.format(get_rank() % torch.cuda.device_count()) if cfg.policy.cuda else 'cpu'
+ producer_thread = Thread(
+ target=producer, args=(queue, dataset, cfg.policy.learn.batch_size, device), name='cuda_fetcher_producer'
+ )
+
+ def _fetch(ctx: "OfflineRLContext"):
+ nonlocal queue, producer_thread
+ if not producer_thread.is_alive():
+ time.sleep(5)
+ producer_thread.start()
+ while queue.empty():
+ time.sleep(0.001)
+ ctx.train_data = queue.get()
+
+ return _fetch
+
+
+def offline_data_fetcher(cfg: EasyDict, dataset: Dataset) -> Callable:
+ """
+ Overview:
+        The outer function wraps a PyTorch `Dataset` into a `DataLoader`. \
+        The returned function is a generator which fetches a batch of data from that `DataLoader` on each call.\
+ Please refer to the link https://pytorch.org/tutorials/beginner/basics/data_tutorial.html \
+ and https://pytorch.org/docs/stable/data.html for more details.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config which should contain the following keys: `cfg.policy.learn.batch_size`.
+ - dataset (:obj:`Dataset`): The dataset of type `torch.utils.data.Dataset` which stores the data.
+ """
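+    # Usage sketch (added note): ``task.use(offline_data_fetcher(cfg, dataset))`` in an
+    # offline RL pipeline, followed by a trainer middleware; the wiring is illustrative.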
+ # collate_fn is executed in policy now
+ dataloader = DataLoader(dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x)
+ dataloader = iter(dataloader)
+
+ def _fetch(ctx: "OfflineRLContext"):
+ """
+ Overview:
+ Every time this generator is iterated, the fetched data will be assigned to ctx.train_data. \
+ After the dataloader is empty, the attribute `ctx.train_epoch` will be incremented by 1.
+ Input of ctx:
+ - train_epoch (:obj:`int`): Number of `train_epoch`.
+ Output of ctx:
+ - train_data (:obj:`List[Tensor]`): The fetched data batch.
+ """
+ nonlocal dataloader
+ try:
+ ctx.train_data = next(dataloader) # noqa
+ except StopIteration:
+ ctx.train_epoch += 1
+ del dataloader
+ dataloader = DataLoader(
+ dataset, batch_size=cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x
+ )
+ dataloader = iter(dataloader)
+ ctx.train_data = next(dataloader)
+ # TODO apply data update (e.g. priority) in offline setting when necessary
+ ctx.trained_env_step += len(ctx.train_data)
+
+ return _fetch
+
+
+def offline_data_saver(data_path: str, data_type: str = 'hdf5') -> Callable:
+ """
+ Overview:
+ Save the expert data of offline RL in a directory.
+ Arguments:
+        - data_path (:obj:`str`): File path where the expert data will be written, which is usually './expert.pkl'.
+ - data_type (:obj:`str`): Define the type of the saved data. \
+ The type of saved data is pkl if `data_type == 'naive'`. \
+ The type of saved data is hdf5 if `data_type == 'hdf5'`.
+ """
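+    # Usage sketch (added note): ``task.use(offline_data_saver('./expert.pkl', data_type='naive'))``
+    # saves pkl data, while ``data_type='hdf5'`` saves hdf5 data, as described above.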
+
+ def _save(ctx: "OnlineRLContext"):
+ """
+ Input of ctx:
+ - trajectories (:obj:`List[Tensor]`): The expert data to be saved.
+ """
+ data = ctx.trajectories
+ offline_data_save_type(data, data_path, data_type)
+ ctx.trajectories = None
+
+ return _save
+
+
+def sqil_data_pusher(cfg: EasyDict, buffer_: Buffer, expert: bool) -> Callable:
+ """
+ Overview:
+ Push trajectories into the buffer in sqil learning pipeline.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config.
+ - buffer (:obj:`Buffer`): Buffer to push the data in.
+ - expert (:obj:`bool`): Whether the pushed data is expert data or not. \
+ In each element of the pushed data, the reward will be set to 1 if this attribute is `True`, otherwise 0.
+ """
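+    # Usage sketch (added note): SQIL typically registers two pushers, e.g.
+    #   task.use(sqil_data_pusher(cfg, agent_buffer, expert=False))
+    #   task.use(sqil_data_pusher(cfg, expert_buffer, expert=True))
+    # so agent data gets reward 0 and expert data gets reward 1; the buffer names are illustrative.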
+
+ def _pusher(ctx: "OnlineRLContext"):
+ """
+ Input of ctx:
+ - trajectories (:obj:`List[Dict]`): The trajectories to be pushed.
+ """
+ for t in ctx.trajectories:
+ if expert:
+ t.reward = torch.ones_like(t.reward)
+ else:
+ t.reward = torch.zeros_like(t.reward)
+ buffer_.push(t)
+ ctx.trajectories = None
+
+ return _pusher
diff --git a/DI-engine/ding/framework/middleware/functional/enhancer.py b/DI-engine/ding/framework/middleware/functional/enhancer.py
new file mode 100644
index 0000000000000000000000000000000000000000..597a086850f9a9594a544e95383ce02ae88affda
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/functional/enhancer.py
@@ -0,0 +1,107 @@
+from typing import TYPE_CHECKING, Callable
+from easydict import EasyDict
+from ditk import logging
+import torch
+from ding.framework import task
+if TYPE_CHECKING:
+ from ding.framework import OnlineRLContext
+ from ding.reward_model import BaseRewardModel, HerRewardModel
+ from ding.data import Buffer
+
+
+def reward_estimator(cfg: EasyDict, reward_model: "BaseRewardModel") -> Callable:
+ """
+ Overview:
+ Estimate the reward of `train_data` using `reward_model`.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config.
+ - reward_model (:obj:`BaseRewardModel`): Reward model.
+ """
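+    # Usage sketch (added note): ``task.use(reward_estimator(cfg, reward_model))`` placed
+    # before the trainer, so the rewards in ``ctx.train_data`` are re-estimated in place.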
+ if task.router.is_active and not task.has_role(task.role.LEARNER):
+ return task.void()
+
+ def _enhance(ctx: "OnlineRLContext"):
+ """
+ Input of ctx:
+ - train_data (:obj:`List`): The list of data used for estimation.
+ """
+ reward_model.estimate(ctx.train_data) # inplace modification
+
+ return _enhance
+
+
+def her_data_enhancer(cfg: EasyDict, buffer_: "Buffer", her_reward_model: "HerRewardModel") -> Callable:
+ """
+ Overview:
+ Fetch a batch of data/episode from `buffer_`, \
+ then use `her_reward_model` to get HER processed episodes from original episodes.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config which should contain the following keys \
+ if her_reward_model.episode_size is None: `cfg.policy.learn.batch_size`.
+ - buffer\_ (:obj:`Buffer`): Buffer to sample data from.
+ - her_reward_model (:obj:`HerRewardModel`): Hindsight Experience Replay (HER) model \
+ which is used to process episodes.
+ """
+ if task.router.is_active and not task.has_role(task.role.LEARNER):
+ return task.void()
+
+ def _fetch_and_enhance(ctx: "OnlineRLContext"):
+ """
+ Output of ctx:
+ - train_data (:obj:`List[treetensor.torch.Tensor]`): The HER processed episodes.
+ """
+ if her_reward_model.episode_size is None:
+ size = cfg.policy.learn.batch_size
+ else:
+ size = her_reward_model.episode_size
+ try:
+ buffered_episode = buffer_.sample(size)
+ train_episode = [d.data for d in buffered_episode]
+ except (ValueError, AssertionError):
+ # You can modify data collect config to avoid this warning, e.g. increasing n_sample, n_episode.
+ logging.warning(
+                "Replay buffer's data is not enough to support training, so skip this training to wait for more data."
+ )
+ ctx.train_data = None
+ return
+
+ her_episode = sum([her_reward_model.estimate(e) for e in train_episode], [])
+ ctx.train_data = sum(her_episode, [])
+
+ return _fetch_and_enhance
+
+
+def nstep_reward_enhancer(cfg: EasyDict) -> Callable:
+
+ if task.router.is_active and (not task.has_role(task.role.LEARNER) and not task.has_role(task.role.COLLECTOR)):
+ return task.void()
+
+ def _enhance(ctx: "OnlineRLContext"):
+ nstep = cfg.policy.nstep
+ gamma = cfg.policy.discount_factor
+ L = len(ctx.trajectories)
+ reward_template = ctx.trajectories[0].reward
+ nstep_rewards = []
+ value_gamma = []
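+        # Added note: for each transition i, accumulate rewards over the next ``valid`` steps
+        # (truncated at episode boundaries and zero-padded up to ``nstep``), and record
+        # gamma ** valid as ``value_gamma``, the discount typically applied to the bootstrap value.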
+ for i in range(L):
+ valid = min(nstep, L - i)
+ for j in range(1, valid):
+ if ctx.trajectories[j + i].done:
+ valid = j
+ break
+ value_gamma.append(torch.FloatTensor([gamma ** valid]))
+ nstep_reward = [ctx.trajectories[j].reward for j in range(i, i + valid)]
+ if nstep > valid:
+ nstep_reward.extend([torch.zeros_like(reward_template) for j in range(nstep - valid)])
+ nstep_reward = torch.cat(nstep_reward) # (nstep, )
+ nstep_rewards.append(nstep_reward)
+ for i in range(L):
+ ctx.trajectories[i].reward = nstep_rewards[i]
+ ctx.trajectories[i].value_gamma = value_gamma[i]
+
+ return _enhance
+
+
+# TODO MBPO
+# TODO SIL
+# TODO TD3 VAE
diff --git a/DI-engine/ding/framework/middleware/functional/evaluator.py b/DI-engine/ding/framework/middleware/functional/evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..611bbcdea69c95526a3e9dc60279ea9f4bab9470
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/functional/evaluator.py
@@ -0,0 +1,436 @@
+from typing import Callable, Any, List, Union, Optional
+from abc import ABC, abstractmethod
+from collections import deque
+from ditk import logging
+import numpy as np
+import torch
+import treetensor.numpy as tnp
+import treetensor.torch as ttorch
+from easydict import EasyDict
+from ding.envs import BaseEnvManager
+from ding.framework.context import Context, OfflineRLContext, OnlineRLContext
+from ding.policy import Policy
+from ding.data import Dataset, DataLoader
+from ding.framework import task
+from ding.torch_utils import to_ndarray, get_shape0
+from ding.utils import lists_to_dicts
+
+
+class IMetric(ABC):
+
+ @abstractmethod
+ def eval(self, inputs: Any, label: Any) -> dict:
+ raise NotImplementedError
+
+ @abstractmethod
+ def reduce_mean(self, inputs: List[Any]) -> Any:
+ raise NotImplementedError
+
+ @abstractmethod
+ def gt(self, metric1: Any, metric2: Any) -> bool:
+ """
+ Overview:
+            Whether metric1 is greater than or equal to metric2 (>=).
+
+ .. note::
+ If metric2 is None, return True
+ """
+ raise NotImplementedError
+
+
+class VectorEvalMonitor(object):
+ """
+ Overview:
+        In some cases, different environments in the evaluator may collect episodes of different lengths. \
+        For example, suppose we want to collect 12 episodes in the evaluator but only have 5 environments; \
+        if we did nothing about it, it is likely that we would get more short episodes than long ones. \
+        The average reward would therefore be biased and inaccurate. We use VectorEvalMonitor to solve this problem.
+ Interfaces:
+ __init__, is_finished, update_info, update_reward, get_episode_return, get_latest_reward, get_current_episode,\
+ get_episode_info, update_video, get_episode_video
+ """
+
+ def __init__(self, env_num: int, n_episode: int) -> None:
+ """
+ Overview:
+ Init method. According to the number of episodes and the number of environments, determine how many \
+ episodes need to be opened for each environment, and initialize the reward, info and other \
+ information
+ Arguments:
+            - env_num (:obj:`int`): the number of evaluation environments
+            - n_episode (:obj:`int`): the total number of episodes to be collected
+ """
+ assert n_episode >= env_num, "n_episode < env_num, please decrease the number of eval env"
+ self._env_num = env_num
+ self._n_episode = n_episode
+ each_env_episode = [n_episode // env_num for _ in range(env_num)]
+ for i in range(n_episode % env_num):
+ each_env_episode[i] += 1
+ self._reward = {env_id: deque(maxlen=maxlen) for env_id, maxlen in enumerate(each_env_episode)}
+ self._info = {env_id: deque(maxlen=maxlen) for env_id, maxlen in enumerate(each_env_episode)}
+ self._video = {
+ env_id: deque([[] for _ in range(maxlen)], maxlen=maxlen)
+ for env_id, maxlen in enumerate(each_env_episode)
+ }
+ self._output = {
+ env_id: deque([[] for _ in range(maxlen)], maxlen=maxlen)
+ for env_id, maxlen in enumerate(each_env_episode)
+ }
+
+ def is_finished(self) -> bool:
+ """
+ Overview:
+ Determine whether the evaluator has completed the work.
+ Return:
+ - result: (:obj:`bool`): whether the evaluator has completed the work
+ """
+ return all([len(v) == v.maxlen for v in self._reward.values()])
+
+ def update_info(self, env_id: int, info: Any) -> None:
+ """
+ Overview:
+ Update the information of the environment indicated by env_id.
+ Arguments:
+ - env_id: (:obj:`int`): the id of the environment we need to update information
+ - info: (:obj:`Any`): the information we need to update
+ """
+ self._info[env_id].append(info)
+
+ def update_reward(self, env_id: Union[int, np.ndarray], reward: Any) -> None:
+ """
+ Overview:
+ Update the reward indicated by env_id.
+ Arguments:
+ - env_id: (:obj:`int`): the id of the environment we need to update the reward
+ - reward: (:obj:`Any`): the reward we need to update
+ """
+ if isinstance(reward, torch.Tensor):
+ reward = reward.item()
+ if isinstance(env_id, np.ndarray):
+ env_id = env_id.item()
+ self._reward[env_id].append(reward)
+
+ def get_episode_return(self) -> list:
+ """
+ Overview:
+            Gather the episode returns of all evaluation environments into a single list.
+ """
+ return sum([list(v) for v in self._reward.values()], []) # sum(iterable, start)
+
+ def get_latest_reward(self, env_id: int) -> int:
+ """
+ Overview:
+ Get the latest reward of a certain environment.
+ Arguments:
+ - env_id: (:obj:`int`): the id of the environment we need to get reward.
+ """
+ return self._reward[env_id][-1]
+
+ def get_current_episode(self) -> int:
+ """
+ Overview:
+ Get the current episode. We can know which episode our evaluator is executing now.
+ """
+ return sum([len(v) for v in self._reward.values()])
+
+ def get_episode_info(self) -> dict:
+ """
+ Overview:
+ Get all episode information, such as total return of one episode.
+ """
+ if len(self._info[0]) == 0:
+ return None
+ else:
+ # sum among all envs
+ total_info = sum([list(v) for v in self._info.values()], [])
+ if isinstance(total_info[0], tnp.ndarray):
+ total_info = [t.json() for t in total_info]
+ total_info = lists_to_dicts(total_info)
+ new_dict = {}
+ for k in total_info.keys():
+ try:
+ if np.isscalar(total_info[k][0].item()):
+ new_dict[k + '_mean'] = np.mean(total_info[k])
+ except: # noqa
+ pass
+ return new_dict
+
+ def _select_idx(self):
+ reward = [t.item() for t in self.get_episode_return()]
+ sortarg = np.argsort(reward)
+ # worst, median(s), best
+ if len(sortarg) == 1:
+ idxs = [sortarg[0]]
+ elif len(sortarg) == 2:
+ idxs = [sortarg[0], sortarg[-1]]
+ elif len(sortarg) == 3:
+ idxs = [sortarg[0], sortarg[len(sortarg) // 2], sortarg[-1]]
+ else:
+ # TensorboardX pad the number of videos to even numbers with black frames,
+ # therefore providing even number of videos prevents black frames being rendered.
+ idxs = [sortarg[0], sortarg[len(sortarg) // 2 - 1], sortarg[len(sortarg) // 2], sortarg[-1]]
+ return idxs
+
+ def update_video(self, imgs):
+ for env_id, img in imgs.items():
+ if len(self._reward[env_id]) == self._reward[env_id].maxlen:
+ continue
+ self._video[env_id][len(self._reward[env_id])].append(img)
+
+ def get_episode_video(self):
+ """
+ Overview:
+ Convert list of videos into [N, T, C, H, W] tensor, containing
+ worst, median, best evaluation trajectories for video logging.
+ """
+ videos = sum([list(v) for v in self._video.values()], [])
+ videos = [np.transpose(np.stack(video, 0), [0, 3, 1, 2]) for video in videos]
+ idxs = self._select_idx()
+ videos = [videos[idx] for idx in idxs]
+ # pad videos to the same length with last frames
+ max_length = max(video.shape[0] for video in videos)
+ for i in range(len(videos)):
+ if videos[i].shape[0] < max_length:
+ padding = np.tile([videos[i][-1]], (max_length - videos[i].shape[0], 1, 1, 1))
+ videos[i] = np.concatenate([videos[i], padding], 0)
+ videos = np.stack(videos, 0)
+ assert len(videos.shape) == 5, 'Need [N, T, C, H, W] input tensor for video logging!'
+ return videos
+
+ def update_output(self, output):
+ for env_id, o in output.items():
+ if len(self._reward[env_id]) == self._reward[env_id].maxlen:
+ continue
+ self._output[env_id][len(self._reward[env_id])].append(to_ndarray(o))
+
+ def get_episode_output(self):
+ output = sum([list(v) for v in self._output.values()], [])
+ idxs = self._select_idx()
+ output = [output[idx] for idx in idxs]
+ return output
+
+
+def interaction_evaluator(cfg: EasyDict, policy: Policy, env: BaseEnvManager, render: bool = False) -> Callable:
+ """
+ Overview:
+ The middleware that executes the evaluation.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config.
+ - policy (:obj:`Policy`): The policy to be evaluated.
+ - env (:obj:`BaseEnvManager`): The env for the evaluation.
+ - render (:obj:`bool`): Whether to render env images and policy logits.
+ """
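+    # Usage sketch (added note): task.use(interaction_evaluator(cfg, policy.eval_mode, env));
+    # ``policy.eval_mode`` is the usual DI-engine naming and is an assumption here, not
+    # something defined in this file.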
+ if task.router.is_active and not task.has_role(task.role.EVALUATOR):
+ return task.void()
+
+ env.seed(cfg.seed, dynamic_seed=False)
+
+ def _evaluate(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
+ """
+ Overview:
+            - The evaluation will be executed when the task begins or when enough train iterations \
+                have passed since the last evaluation.
+ Input of ctx:
+ - last_eval_iter (:obj:`int`): Last evaluation iteration.
+ - train_iter (:obj:`int`): Current train iteration.
+ Output of ctx:
+ - eval_value (:obj:`float`): The average reward in the current evaluation.
+ """
+
+ # evaluation will be executed if the task begins or enough train_iter after last evaluation
+ if ctx.last_eval_iter != -1 and \
+ (ctx.train_iter - ctx.last_eval_iter < cfg.policy.eval.evaluator.eval_freq):
+ return
+
+ if env.closed:
+ env.launch()
+ else:
+ env.reset()
+ policy.reset()
+ eval_monitor = VectorEvalMonitor(env.env_num, cfg.env.n_evaluator_episode)
+
+ while not eval_monitor.is_finished():
+ obs = ttorch.as_tensor(env.ready_obs).to(dtype=ttorch.float32)
+ obs = {i: obs[i] for i in range(get_shape0(obs))} # TBD
+ inference_output = policy.forward(obs)
+ if render:
+ eval_monitor.update_video(env.ready_imgs)
+ eval_monitor.update_output(inference_output)
+ output = [v for v in inference_output.values()]
+ action = [to_ndarray(v['action']) for v in output] # TBD
+ timesteps = env.step(action)
+ for timestep in timesteps:
+ env_id = timestep.env_id.item()
+ if timestep.done:
+ policy.reset([env_id])
+ reward = timestep.info.eval_episode_return
+ eval_monitor.update_reward(env_id, reward)
+ if 'episode_info' in timestep.info:
+ eval_monitor.update_info(env_id, timestep.info.episode_info)
+ episode_return = eval_monitor.get_episode_return()
+ episode_return_min = np.min(episode_return)
+ episode_return_max = np.max(episode_return)
+ episode_return_std = np.std(episode_return)
+ episode_return = np.mean(episode_return)
+ stop_flag = episode_return >= cfg.env.stop_value and ctx.train_iter > 0
+ if isinstance(ctx, OnlineRLContext):
+ logging.info(
+ 'Evaluation: Train Iter({})\tEnv Step({})\tEpisode Return({:.3f})'.format(
+ ctx.train_iter, ctx.env_step, episode_return
+ )
+ )
+ elif isinstance(ctx, OfflineRLContext):
+ logging.info('Evaluation: Train Iter({})\tEval Reward({:.3f})'.format(ctx.train_iter, episode_return))
+ else:
+ raise TypeError("not supported ctx type: {}".format(type(ctx)))
+ ctx.last_eval_iter = ctx.train_iter
+ ctx.eval_value = episode_return
+ ctx.eval_value_min = episode_return_min
+ ctx.eval_value_max = episode_return_max
+ ctx.eval_value_std = episode_return_std
+ ctx.last_eval_value = ctx.eval_value
+ ctx.eval_output = {'episode_return': episode_return}
+ episode_info = eval_monitor.get_episode_info()
+ if episode_info is not None:
+ ctx.eval_output['episode_info'] = episode_info
+ if render:
+ ctx.eval_output['replay_video'] = eval_monitor.get_episode_video()
+ ctx.eval_output['output'] = eval_monitor.get_episode_output()
+ else:
+ ctx.eval_output['output'] = output # for compatibility
+
+ if stop_flag:
+ task.finish = True
+
+ return _evaluate
+
+
+def interaction_evaluator_ttorch(
+ seed: int,
+ policy: Policy,
+ env: BaseEnvManager,
+ n_evaluator_episode: Optional[int] = None,
+ stop_value: float = np.inf,
+ eval_freq: int = 1000,
+ render: bool = False,
+) -> Callable:
+ """
+ Overview:
+ The middleware that executes the evaluation with ttorch data.
+ Arguments:
+ - policy (:obj:`Policy`): The policy to be evaluated.
+ - env (:obj:`BaseEnvManager`): The env for the evaluation.
+ - render (:obj:`bool`): Whether to render env images and policy logits.
+ """
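+    # Usage sketch (added note): task.use(interaction_evaluator_ttorch(seed, policy, env, stop_value=200));
+    # the concrete ``stop_value`` depends on the environment and is illustrative here.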
+ if task.router.is_active and not task.has_role(task.role.EVALUATOR):
+ return task.void()
+
+ env.seed(seed, dynamic_seed=False)
+ if n_evaluator_episode is None:
+ n_evaluator_episode = env.env_num
+
+ def _evaluate(ctx: "OnlineRLContext"):
+ """
+ Overview:
+            - The evaluation will be executed when the task begins or when enough train iterations \
+                have passed since the last evaluation.
+ Input of ctx:
+ - last_eval_iter (:obj:`int`): Last evaluation iteration.
+ - train_iter (:obj:`int`): Current train iteration.
+ Output of ctx:
+ - eval_value (:obj:`float`): The average reward in the current evaluation.
+ """
+
+ # evaluation will be executed if the task begins or enough train_iter after last evaluation
+ if ctx.last_eval_iter != -1 and (ctx.train_iter - ctx.last_eval_iter < eval_freq):
+ return
+
+ if env.closed:
+ env.launch()
+ else:
+ env.reset()
+ policy.reset()
+ device = policy._device
+ eval_monitor = VectorEvalMonitor(env.env_num, n_evaluator_episode)
+
+ while not eval_monitor.is_finished():
+ obs = ttorch.as_tensor(env.ready_obs).to(dtype=ttorch.float32)
+ obs = obs.to(device)
+ inference_output = policy.eval(obs)
+ inference_output = inference_output.cpu()
+ if render:
+ eval_monitor.update_video(env.ready_imgs)
+ # eval_monitor.update_output(inference_output)
+ action = inference_output.action.numpy()
+ timesteps = env.step(action)
+ for timestep in timesteps:
+ env_id = timestep.env_id.item()
+ if timestep.done:
+ policy.reset([env_id])
+ reward = timestep.info.eval_episode_return
+ eval_monitor.update_reward(env_id, reward)
+ if 'episode_info' in timestep.info:
+ eval_monitor.update_info(env_id, timestep.info.episode_info)
+ episode_return = eval_monitor.get_episode_return()
+ episode_return_std = np.std(episode_return)
+ episode_return_mean = np.mean(episode_return)
+ stop_flag = episode_return_mean >= stop_value and ctx.train_iter > 0
+ logging.info(
+ 'Evaluation: Train Iter({})\tEnv Step({})\tMean Episode Return({:.3f})'.format(
+ ctx.train_iter, ctx.env_step, episode_return_mean
+ )
+ )
+ ctx.last_eval_iter = ctx.train_iter
+ ctx.eval_value = episode_return_mean
+ ctx.eval_value_std = episode_return_std
+ ctx.last_eval_value = ctx.eval_value
+ ctx.eval_output = {'episode_return': episode_return}
+ episode_info = eval_monitor.get_episode_info()
+ if episode_info is not None:
+ ctx.eval_output['episode_info'] = episode_info
+ if render:
+ ctx.eval_output['replay_video'] = eval_monitor.get_episode_video()
+ ctx.eval_output['output'] = eval_monitor.get_episode_output()
+ else:
+ ctx.eval_output['output'] = inference_output.numpy() # for compatibility
+
+ if stop_flag:
+ task.finish = True
+
+ return _evaluate
+
+
+def metric_evaluator(cfg: EasyDict, policy: Policy, dataset: Dataset, metric: IMetric) -> Callable:
+ dataloader = DataLoader(dataset, batch_size=cfg.policy.eval.batch_size)
+
+ def _evaluate(ctx: "Context"):
+ # evaluation will be executed if the task begins or enough train_iter after last evaluation
+ if ctx.last_eval_iter != -1 and \
+ (ctx.train_iter - ctx.last_eval_iter < cfg.policy.eval.evaluator.eval_freq):
+ return
+
+ policy.reset()
+ eval_output = []
+
+ for batch_idx, batch_data in enumerate(dataloader):
+ inputs, label = batch_data
+ inference_output = policy.forward(inputs)
+ eval_output.append(metric.eval(inference_output, label))
+ # TODO reduce avg_eval_output among different gpus
+ avg_eval_output = metric.reduce_mean(eval_output)
+ stop_flag = metric.gt(avg_eval_output, cfg.env.stop_value) and ctx.train_iter > 0
+ logging.info(
+ 'Evaluation: Train Iter({})\tEnv Step({})\tEpisode Return({:.3f})'.format(
+ ctx.train_iter, ctx.env_step, avg_eval_output
+ )
+ )
+ ctx.last_eval_iter = ctx.train_iter
+ ctx.eval_value = avg_eval_output
+
+ if stop_flag:
+ task.finish = True
+
+ return _evaluate
+
+
+# TODO battle evaluator
diff --git a/DI-engine/ding/framework/middleware/functional/explorer.py b/DI-engine/ding/framework/middleware/functional/explorer.py
new file mode 100644
index 0000000000000000000000000000000000000000..45aa9bd24a19712351042c58012c780c5d3176ee
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/functional/explorer.py
@@ -0,0 +1,56 @@
+from typing import TYPE_CHECKING, Callable
+from easydict import EasyDict
+from ding.rl_utils import get_epsilon_greedy_fn
+from ding.framework import task
+
+if TYPE_CHECKING:
+ from ding.framework import OnlineRLContext
+
+
+def eps_greedy_handler(cfg: EasyDict) -> Callable:
+ """
+ Overview:
+ The middleware that computes epsilon value according to the env_step.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config.
+ """
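+    # Usage sketch (added note): requires ``cfg.policy.other.eps`` with ``start``, ``end``,
+    # ``decay`` and ``type`` fields (read below); typically registered before the collector
+    # via task.use(eps_greedy_handler(cfg)).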
+ if task.router.is_active and not task.has_role(task.role.COLLECTOR):
+ return task.void()
+
+ eps_cfg = cfg.policy.other.eps
+ handle = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
+
+ def _eps_greedy(ctx: "OnlineRLContext"):
+ """
+ Input of ctx:
+ - env_step (:obj:`int`): The env steps count.
+ Output of ctx:
+ - collect_kwargs['eps'] (:obj:`float`): The eps conditioned on env_step and cfg.
+ """
+
+ ctx.collect_kwargs['eps'] = handle(ctx.env_step)
+ yield
+ try:
+ ctx.collect_kwargs.pop('eps')
+ except: # noqa
+ pass
+
+ return _eps_greedy
+
+
+def eps_greedy_masker():
+ """
+ Overview:
+        The middleware that sets a masked epsilon value (-1) and stops the \
+        epsilon-greedy method from generating exploratory actions.
+ """
+
+ def _masker(ctx: "OnlineRLContext"):
+ """
+ Output of ctx:
+ - collect_kwargs['eps'] (:obj:`float`): The masked eps value, default to -1.
+ """
+
+ ctx.collect_kwargs['eps'] = -1
+
+ return _masker
diff --git a/DI-engine/ding/framework/middleware/functional/logger.py b/DI-engine/ding/framework/middleware/functional/logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f62e2f429ff43fe59d8d882c76e0ea04fb0f6de
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/functional/logger.py
@@ -0,0 +1,707 @@
+from typing import TYPE_CHECKING, Optional, Callable, Dict, List, Union
+from ditk import logging
+from easydict import EasyDict
+from matplotlib import pyplot as plt
+from matplotlib import animation
+import os
+import numpy as np
+import torch
+import wandb
+import pickle
+import treetensor.numpy as tnp
+from ding.framework import task
+from ding.envs import BaseEnvManagerV2
+from ding.utils import DistributedWriter
+from ding.torch_utils import to_ndarray
+from ding.utils.default_helper import one_time_warning
+
+if TYPE_CHECKING:
+ from ding.framework import OnlineRLContext, OfflineRLContext
+
+
+def online_logger(record_train_iter: bool = False, train_show_freq: int = 100) -> Callable:
+ """
+ Overview:
+ Create an online RL tensorboard logger for recording training and evaluation metrics.
+ Arguments:
+ - record_train_iter (:obj:`bool`): Whether to record training iteration. Default is False.
+ - train_show_freq (:obj:`int`): Frequency of showing training logs. Default is 100.
+ Returns:
+ - _logger (:obj:`Callable`): A logger function that takes an OnlineRLContext object as input.
+ Raises:
+ - RuntimeError: If writer is None.
+ - NotImplementedError: If the key of train_output is not supported, such as "scalars".
+
+ Examples:
+ >>> task.use(online_logger(record_train_iter=False, train_show_freq=1000))
+ """
+ if task.router.is_active and not task.has_role(task.role.LEARNER):
+ return task.void()
+ writer = DistributedWriter.get_instance()
+ if writer is None:
+ raise RuntimeError("logger writer is None, you should call `ding_init(cfg)` at the beginning of training.")
+ last_train_show_iter = -1
+
+ def _logger(ctx: "OnlineRLContext"):
+ if task.finish:
+ writer.close()
+ nonlocal last_train_show_iter
+
+ if not np.isinf(ctx.eval_value):
+ if record_train_iter:
+ writer.add_scalar('basic/eval_episode_return_mean-env_step', ctx.eval_value, ctx.env_step)
+ writer.add_scalar('basic/eval_episode_return_mean-train_iter', ctx.eval_value, ctx.train_iter)
+ else:
+ writer.add_scalar('basic/eval_episode_return_mean', ctx.eval_value, ctx.env_step)
+ if ctx.train_output is not None and ctx.train_iter - last_train_show_iter >= train_show_freq:
+ last_train_show_iter = ctx.train_iter
+ if isinstance(ctx.train_output, List):
+ output = ctx.train_output.pop() # only use latest output for some algorithms, like PPO
+ else:
+ output = ctx.train_output
+ for k, v in output.items():
+ if k in ['priority', 'td_error_priority']:
+ continue
+ if "[scalars]" in k:
+ new_k = k.split(']')[-1]
+ raise NotImplementedError
+ elif "[histogram]" in k:
+ new_k = k.split(']')[-1]
+ writer.add_histogram(new_k, v, ctx.env_step)
+ if record_train_iter:
+ writer.add_histogram(new_k, v, ctx.train_iter)
+ else:
+ if record_train_iter:
+ writer.add_scalar('basic/train_{}-train_iter'.format(k), v, ctx.train_iter)
+ writer.add_scalar('basic/train_{}-env_step'.format(k), v, ctx.env_step)
+ else:
+ writer.add_scalar('basic/train_{}'.format(k), v, ctx.env_step)
+
+ return _logger
+
+
+def offline_logger(train_show_freq: int = 100) -> Callable:
+ """
+ Overview:
+ Create an offline RL tensorboard logger for recording training and evaluation metrics.
+ Arguments:
+ - train_show_freq (:obj:`int`): Frequency of showing training logs. Defaults to 100.
+ Returns:
+ - _logger (:obj:`Callable`): A logger function that takes an OfflineRLContext object as input.
+ Raises:
+ - RuntimeError: If writer is None.
+ - NotImplementedError: If the key of train_output is not supported, such as "scalars".
+
+ Examples:
+ >>> task.use(offline_logger(train_show_freq=1000))
+ """
+ if task.router.is_active and not task.has_role(task.role.LEARNER):
+ return task.void()
+ writer = DistributedWriter.get_instance()
+ if writer is None:
+ raise RuntimeError("logger writer is None, you should call `ding_init(cfg)` at the beginning of training.")
+ last_train_show_iter = -1
+
+ def _logger(ctx: "OfflineRLContext"):
+ nonlocal last_train_show_iter
+ if task.finish:
+ writer.close()
+ if not np.isinf(ctx.eval_value):
+ writer.add_scalar('basic/eval_episode_return_mean-train_iter', ctx.eval_value, ctx.train_iter)
+ if ctx.train_output is not None and ctx.train_iter - last_train_show_iter >= train_show_freq:
+ last_train_show_iter = ctx.train_iter
+ output = ctx.train_output
+ for k, v in output.items():
+ if k in ['priority']:
+ continue
+ if "[scalars]" in k:
+ new_k = k.split(']')[-1]
+ raise NotImplementedError
+ elif "[histogram]" in k:
+ new_k = k.split(']')[-1]
+ writer.add_histogram(new_k, v, ctx.train_iter)
+ else:
+ writer.add_scalar('basic/train_{}-train_iter'.format(k), v, ctx.train_iter)
+
+ return _logger
+
+
+# four utility functions for wandb logger
+def softmax(logit: np.ndarray) -> np.ndarray:
+    # subtract the per-row maximum before exponentiating for numerical stability
+    v = np.exp(logit - logit.max(axis=-1, keepdims=True))
+    return v / v.sum(axis=-1, keepdims=True)
+
+
+def action_prob(num, action_prob, ln):
+ ax = plt.gca()
+ ax.set_ylim([0, 1])
+ for rect, x in zip(ln, action_prob[num]):
+ rect.set_height(x)
+ return ln
+
+
+def return_prob(num, return_prob, ln):
+ return ln
+
+
+def return_distribution(episode_return):
+ num = len(episode_return)
+ max_return = max(episode_return)
+ min_return = min(episode_return)
+ hist, bins = np.histogram(episode_return, bins=np.linspace(min_return - 50, max_return + 50, 6))
+ gap = (max_return - min_return + 100) / 5
+ x_dim = ['{:.1f}'.format(min_return - 50 + gap * x) for x in range(5)]
+ return hist / num, x_dim
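+
+
+# A minimal sketch of how the helpers above behave (illustrative values only; `np` is the
+# module-level numpy import):
+# >>> probs = softmax(np.array([1.0, 2.0, 0.5]))            # per-action probabilities, sums to 1
+# >>> hist, x_dim = return_distribution([90., 120., 100., 95.])
+# >>> assert np.isclose(hist.sum(), 1.0) and len(hist) == len(x_dim) == 5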
+
+
+def wandb_online_logger(
+ record_path: str = None,
+ cfg: Union[dict, EasyDict] = None,
+ exp_config: Union[dict, EasyDict] = None,
+ metric_list: Optional[List[str]] = None,
+ env: Optional[BaseEnvManagerV2] = None,
+ model: Optional[torch.nn.Module] = None,
+ anonymous: bool = False,
+ project_name: str = 'default-project',
+ run_name: str = None,
+ wandb_sweep: bool = False,
+) -> Callable:
+ """
+ Overview:
+ Wandb visualizer to track the experiment.
+ Arguments:
+ - record_path (:obj:`str`): The path to save the replay of simulation.
+ - cfg (:obj:`Union[dict, EasyDict]`): Config, a dict of following settings:
+ - gradient_logger: boolean. Whether to track the gradient.
+ - plot_logger: boolean. Whether to track the metrics like reward and loss.
+ - video_logger: boolean. Whether to upload the rendering video replay.
+            - action_logger: boolean. Whether to track the action visualization (q_value or action probability).
+ - return_logger: boolean. Whether to track the return value.
+ - metric_list (:obj:`Optional[List[str]]`): Logged metric list, specialized by different policies.
+ - env (:obj:`BaseEnvManagerV2`): Evaluator environment.
+ - model (:obj:`nn.Module`): Policy neural network model.
+        - anonymous (:obj:`bool`): Whether to open the anonymous mode of wandb. The anonymous mode allows \
+            visualization of data without a wandb account.
+ - project_name (:obj:`str`): The name of wandb project.
+ - run_name (:obj:`str`): The name of wandb run.
+ - wandb_sweep (:obj:`bool`): Whether to use wandb sweep.
+ Returns:
+ - _plot (:obj:`Callable`): A logger function that takes an OnlineRLContext object as input.
+ """
+ if task.router.is_active and not task.has_role(task.role.LEARNER):
+ return task.void()
+ color_list = ["orange", "red", "blue", "purple", "green", "darkcyan"]
+ if metric_list is None:
+ metric_list = ["q_value", "target q_value", "loss", "lr", "entropy", "target_q_value", "td_error"]
+ # Initialize wandb with default settings
+    # These settings can be overridden by calling wandb.init() at the top of the entry script
+    # Only pass config / reinit / name / anonymous when the corresponding option is set.
+    init_kwargs = {'project': project_name}
+    if exp_config:
+        init_kwargs['config'] = exp_config
+    if not wandb_sweep:
+        init_kwargs['reinit'] = True
+    if run_name is not None:
+        init_kwargs['name'] = run_name
+    if anonymous:
+        init_kwargs['anonymous'] = "must"
+    wandb.init(**init_kwargs)
+ plt.switch_backend('agg')
+ if cfg is None:
+ cfg = EasyDict(
+ dict(
+ gradient_logger=False,
+ plot_logger=True,
+ video_logger=False,
+ action_logger=False,
+ return_logger=False,
+ )
+ )
+ else:
+ if not isinstance(cfg, EasyDict):
+ cfg = EasyDict(cfg)
+ for key in ["gradient_logger", "plot_logger", "video_logger", "action_logger", "return_logger", "vis_dataset"]:
+ if key not in cfg.keys():
+ cfg[key] = False
+
+ # The visualizer is called to save the replay of the simulation
+ # which will be uploaded to wandb later
+ if env is not None and cfg.video_logger is True and record_path is not None:
+ env.enable_save_replay(replay_path=record_path)
+ if cfg.gradient_logger:
+ wandb.watch(model, log="all", log_freq=100, log_graph=True)
+ else:
+ one_time_warning(
+ "If you want to use wandb to visualize the gradient, please set gradient_logger = True in the config."
+ )
+
+ first_plot = True
+
+ def _plot(ctx: "OnlineRLContext"):
+ nonlocal first_plot
+ if first_plot:
+ first_plot = False
+ ctx.wandb_url = wandb.run.get_project_url()
+
+ info_for_logging = {}
+
+ if cfg.plot_logger:
+ for metric in metric_list:
+ if isinstance(ctx.train_output, Dict) and metric in ctx.train_output:
+ if isinstance(ctx.train_output[metric], torch.Tensor):
+ info_for_logging.update({metric: ctx.train_output[metric].cpu().detach().numpy()})
+ else:
+ info_for_logging.update({metric: ctx.train_output[metric]})
+ elif isinstance(ctx.train_output, List) and len(ctx.train_output) > 0 and metric in ctx.train_output[0]:
+ metric_value_list = []
+ for item in ctx.train_output:
+ if isinstance(item[metric], torch.Tensor):
+ metric_value_list.append(item[metric].cpu().detach().numpy())
+ else:
+ metric_value_list.append(item[metric])
+ metric_value = np.mean(metric_value_list)
+ info_for_logging.update({metric: metric_value})
+ else:
+ one_time_warning(
+ "If you want to use wandb to visualize the result, please set plot_logger = True in the config."
+ )
+
+ if ctx.eval_value != -np.inf:
+ if hasattr(ctx, "eval_value_min"):
+ info_for_logging.update({
+ "episode return min": ctx.eval_value_min,
+ })
+ if hasattr(ctx, "eval_value_max"):
+ info_for_logging.update({
+ "episode return max": ctx.eval_value_max,
+ })
+ if hasattr(ctx, "eval_value_std"):
+ info_for_logging.update({
+ "episode return std": ctx.eval_value_std,
+ })
+ if hasattr(ctx, "eval_value"):
+ info_for_logging.update({
+ "episode return mean": ctx.eval_value,
+ })
+ if hasattr(ctx, "train_iter"):
+ info_for_logging.update({
+ "train iter": ctx.train_iter,
+ })
+ if hasattr(ctx, "env_step"):
+ info_for_logging.update({
+ "env step": ctx.env_step,
+ })
+
+ eval_output = ctx.eval_output['output']
+ episode_return = ctx.eval_output['episode_return']
+ episode_return = np.array(episode_return)
+ if len(episode_return.shape) == 2:
+ episode_return = episode_return.squeeze(1)
+
+ if cfg.video_logger:
+ if 'replay_video' in ctx.eval_output:
+                # Log the replay video array to wandb. The numpy tensor must be either 4-dimensional or
+                # 5-dimensional, ordered as (time, channel, height, width) or
+                # (batch, time, channel, height, width), e.g. shape (N, T, 3, 224, 320).
+ video_images = ctx.eval_output['replay_video']
+ video_images = video_images.astype(np.uint8)
+ info_for_logging.update({"replay_video": wandb.Video(video_images, fps=60)})
+ elif record_path is not None:
+ file_list = []
+ for p in os.listdir(record_path):
+ if os.path.splitext(p)[-1] == ".mp4":
+ file_list.append(p)
+ file_list.sort(key=lambda fn: os.path.getmtime(os.path.join(record_path, fn)))
+ video_path = os.path.join(record_path, file_list[-2])
+ info_for_logging.update({"video": wandb.Video(video_path, format="mp4")})
+
+ if cfg.action_logger:
+ action_path = os.path.join(record_path, (str(ctx.env_step) + "_action.gif"))
+ if all(['logit' in v for v in eval_output]) or hasattr(eval_output, "logit"):
+ if isinstance(eval_output, tnp.ndarray):
+                    action_probs = softmax(eval_output.logit)
+                else:
+                    action_probs = [softmax(to_ndarray(v['logit'])) for v in eval_output]
+                fig, ax = plt.subplots()
+                plt.ylim([-1, 1])
+                action_dim = len(action_probs[1])
+                x_range = [str(x + 1) for x in range(action_dim)]
+                ln = ax.bar(x_range, [0 for x in range(action_dim)], color=color_list[:action_dim])
+                # name the data `action_probs` so that it does not shadow the module-level
+                # `action_prob` update function passed to FuncAnimation below
+                ani = animation.FuncAnimation(
+                    fig, action_prob, fargs=(action_probs, ln), blit=True, save_count=len(action_probs)
+ )
+ ani.save(action_path, writer='pillow')
+ info_for_logging.update({"action": wandb.Video(action_path, format="gif")})
+
+ elif all(['action' in v for v in eval_output[0]]):
+ for i, action_trajectory in enumerate(eval_output):
+ fig, ax = plt.subplots()
+ fig_data = np.array([[i + 1, *v['action']] for i, v in enumerate(action_trajectory)])
+ steps = fig_data[:, 0]
+ actions = fig_data[:, 1:]
+ plt.ylim([-1, 1])
+ for j in range(actions.shape[1]):
+ ax.scatter(steps, actions[:, j])
+ info_for_logging.update({"actions_of_trajectory_{}".format(i): fig})
+
+ if cfg.return_logger:
+ return_path = os.path.join(record_path, (str(ctx.env_step) + "_return.gif"))
+ fig, ax = plt.subplots()
+ ax = plt.gca()
+ ax.set_ylim([0, 1])
+ hist, x_dim = return_distribution(episode_return)
+ assert len(hist) == len(x_dim)
+ ln_return = ax.bar(x_dim, hist, width=1, color='r', linewidth=0.7)
+ ani = animation.FuncAnimation(fig, return_prob, fargs=(hist, ln_return), blit=True, save_count=1)
+ ani.save(return_path, writer='pillow')
+ info_for_logging.update({"return distribution": wandb.Video(return_path, format="gif")})
+
+ if bool(info_for_logging):
+ wandb.log(data=info_for_logging, step=ctx.env_step)
+ plt.clf()
+
+ return _plot
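+
+
+# A minimal usage sketch of `wandb_online_logger` inside a task pipeline. `main_config`, `evaluator_env`
+# and `model` are placeholders for objects built in a training entry file, not symbols of this module:
+# >>> logger_cfg = EasyDict(dict(gradient_logger=False, plot_logger=True, video_logger=False,
+# ...                            action_logger=False, return_logger=False))
+# >>> task.use(wandb_online_logger(record_path='./video', cfg=logger_cfg, exp_config=main_config,
+# ...                              env=evaluator_env, model=model, project_name='my-project'))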
+
+
+def wandb_offline_logger(
+ record_path: str = None,
+ cfg: Union[dict, EasyDict] = None,
+ exp_config: Union[dict, EasyDict] = None,
+ metric_list: Optional[List[str]] = None,
+ env: Optional[BaseEnvManagerV2] = None,
+ model: Optional[torch.nn.Module] = None,
+ anonymous: bool = False,
+ project_name: str = 'default-project',
+ run_name: str = None,
+ wandb_sweep: bool = False,
+) -> Callable:
+ """
+ Overview:
+ Wandb visualizer to track the experiment.
+ Arguments:
+ - record_path (:obj:`str`): The path to save the replay of simulation.
+ - cfg (:obj:`Union[dict, EasyDict]`): Config, a dict of following settings:
+ - gradient_logger: boolean. Whether to track the gradient.
+ - plot_logger: boolean. Whether to track the metrics like reward and loss.
+ - video_logger: boolean. Whether to upload the rendering video replay.
+            - action_logger: boolean. Whether to track the action visualization (q_value or action probability).
+ - return_logger: boolean. Whether to track the return value.
+ - vis_dataset: boolean. Whether to visualize the dataset.
+ - metric_list (:obj:`Optional[List[str]]`): Logged metric list, specialized by different policies.
+ - env (:obj:`BaseEnvManagerV2`): Evaluator environment.
+ - model (:obj:`nn.Module`): Policy neural network model.
+        - anonymous (:obj:`bool`): Whether to open the anonymous mode of wandb. The anonymous mode allows \
+            visualization of data without a wandb account.
+ - project_name (:obj:`str`): The name of wandb project.
+ - run_name (:obj:`str`): The name of wandb run.
+ - wandb_sweep (:obj:`bool`): Whether to use wandb sweep.
+ Returns:
+ - _plot (:obj:`Callable`): A logger function that takes an OfflineRLContext object as input.
+ """
+ if task.router.is_active and not task.has_role(task.role.LEARNER):
+ return task.void()
+ color_list = ["orange", "red", "blue", "purple", "green", "darkcyan"]
+ if metric_list is None:
+ metric_list = ["q_value", "target q_value", "loss", "lr", "entropy", "target_q_value", "td_error"]
+ # Initialize wandb with default settings
+    # These settings can be overridden by calling wandb.init() at the top of the entry script
+    # Only pass config / reinit / name / anonymous when the corresponding option is set.
+    init_kwargs = {'project': project_name}
+    if exp_config:
+        init_kwargs['config'] = exp_config
+    if not wandb_sweep:
+        init_kwargs['reinit'] = True
+    if run_name is not None:
+        init_kwargs['name'] = run_name
+    if anonymous:
+        init_kwargs['anonymous'] = "must"
+    wandb.init(**init_kwargs)
+ plt.switch_backend('agg')
+ if cfg is None:
+ cfg = EasyDict(
+ dict(
+ gradient_logger=False,
+ plot_logger=True,
+ video_logger=False,
+ action_logger=False,
+ return_logger=False,
+ vis_dataset=True,
+ )
+ )
+ else:
+ if not isinstance(cfg, EasyDict):
+ cfg = EasyDict(cfg)
+ for key in ["gradient_logger", "plot_logger", "video_logger", "action_logger", "return_logger", "vis_dataset"]:
+ if key not in cfg.keys():
+ cfg[key] = False
+
+ # The visualizer is called to save the replay of the simulation
+ # which will be uploaded to wandb later
+ if env is not None and cfg.video_logger is True and record_path is not None:
+ env.enable_save_replay(replay_path=record_path)
+ if cfg.gradient_logger:
+ wandb.watch(model, log="all", log_freq=100, log_graph=True)
+ else:
+ one_time_warning(
+ "If you want to use wandb to visualize the gradient, please set gradient_logger = True in the config."
+ )
+
+ first_plot = True
+
+ def _vis_dataset(datasetpath: str):
+ try:
+ from sklearn.manifold import TSNE
+ except ImportError:
+ import sys
+ logging.warning("Please install sklearn first, such as `pip3 install scikit-learn`.")
+ sys.exit(1)
+ try:
+ import h5py
+ except ImportError:
+ import sys
+ logging.warning("Please install h5py first, such as `pip3 install h5py`.")
+ sys.exit(1)
+ assert os.path.splitext(datasetpath)[-1] in ['.pkl', '.h5', '.hdf5']
+ if os.path.splitext(datasetpath)[-1] == '.pkl':
+ with open(datasetpath, 'rb') as f:
+ data = pickle.load(f)
+ obs = []
+ action = []
+ reward = []
+ for i in range(len(data)):
+ obs.extend(data[i]['observations'])
+ action.extend(data[i]['actions'])
+ reward.extend(data[i]['rewards'])
+ elif os.path.splitext(datasetpath)[-1] in ['.h5', '.hdf5']:
+ with h5py.File(datasetpath, 'r') as f:
+ obs = f['obs'][()]
+ action = f['action'][()]
+ reward = f['reward'][()]
+
+ cmap = plt.cm.hsv
+ obs = np.array(obs)
+ reward = np.array(reward)
+ obs_action = np.hstack((obs, np.array(action)))
+ reward = reward / (max(reward) - min(reward))
+
+ embedded_obs = TSNE(n_components=2).fit_transform(obs)
+ embedded_obs_action = TSNE(n_components=2).fit_transform(obs_action)
+ x_min, x_max = np.min(embedded_obs, 0), np.max(embedded_obs, 0)
+ embedded_obs = embedded_obs / (x_max - x_min)
+
+ x_min, x_max = np.min(embedded_obs_action, 0), np.max(embedded_obs_action, 0)
+ embedded_obs_action = embedded_obs_action / (x_max - x_min)
+
+ fig = plt.figure()
+ f, axes = plt.subplots(nrows=1, ncols=3)
+
+ axes[0].scatter(embedded_obs[:, 0], embedded_obs[:, 1], c=cmap(reward))
+ axes[1].scatter(embedded_obs[:, 0], embedded_obs[:, 1], c=cmap(action))
+ axes[2].scatter(embedded_obs_action[:, 0], embedded_obs_action[:, 1], c=cmap(reward))
+ axes[0].set_title('state-reward')
+ axes[1].set_title('state-action')
+ axes[2].set_title('stateAction-reward')
+ plt.savefig('dataset.png')
+
+ wandb.log({"dataset": wandb.Image("dataset.png")})
+
+ if cfg.vis_dataset is True:
+ _vis_dataset(exp_config.dataset_path)
+
+ def _plot(ctx: "OfflineRLContext"):
+ nonlocal first_plot
+ if first_plot:
+ first_plot = False
+ ctx.wandb_url = wandb.run.get_project_url()
+
+ info_for_logging = {}
+
+ if cfg.plot_logger:
+ for metric in metric_list:
+ if isinstance(ctx.train_output, Dict) and metric in ctx.train_output:
+ if isinstance(ctx.train_output[metric], torch.Tensor):
+ info_for_logging.update({metric: ctx.train_output[metric].cpu().detach().numpy()})
+ else:
+ info_for_logging.update({metric: ctx.train_output[metric]})
+ elif isinstance(ctx.train_output, List) and len(ctx.train_output) > 0 and metric in ctx.train_output[0]:
+ metric_value_list = []
+ for item in ctx.train_output:
+ if isinstance(item[metric], torch.Tensor):
+ metric_value_list.append(item[metric].cpu().detach().numpy())
+ else:
+ metric_value_list.append(item[metric])
+ metric_value = np.mean(metric_value_list)
+ info_for_logging.update({metric: metric_value})
+ else:
+ one_time_warning(
+ "If you want to use wandb to visualize the result, please set plot_logger = True in the config."
+ )
+
+ if ctx.eval_value != -np.inf:
+ if hasattr(ctx, "eval_value_min"):
+ info_for_logging.update({
+ "episode return min": ctx.eval_value_min,
+ })
+ if hasattr(ctx, "eval_value_max"):
+ info_for_logging.update({
+ "episode return max": ctx.eval_value_max,
+ })
+ if hasattr(ctx, "eval_value_std"):
+ info_for_logging.update({
+ "episode return std": ctx.eval_value_std,
+ })
+ if hasattr(ctx, "eval_value"):
+ info_for_logging.update({
+ "episode return mean": ctx.eval_value,
+ })
+ if hasattr(ctx, "train_iter"):
+ info_for_logging.update({
+ "train iter": ctx.train_iter,
+ })
+ if hasattr(ctx, "train_epoch"):
+ info_for_logging.update({
+ "train_epoch": ctx.train_epoch,
+ })
+
+ eval_output = ctx.eval_output['output']
+ episode_return = ctx.eval_output['episode_return']
+ episode_return = np.array(episode_return)
+ if len(episode_return.shape) == 2:
+ episode_return = episode_return.squeeze(1)
+
+ if cfg.video_logger:
+ if 'replay_video' in ctx.eval_output:
+                # Log the replay video array to wandb. The numpy tensor must be either 4-dimensional or
+                # 5-dimensional, ordered as (time, channel, height, width) or
+                # (batch, time, channel, height, width), e.g. shape (N, T, 3, 224, 320).
+ video_images = ctx.eval_output['replay_video']
+ video_images = video_images.astype(np.uint8)
+ info_for_logging.update({"replay_video": wandb.Video(video_images, fps=60)})
+ elif record_path is not None:
+ file_list = []
+ for p in os.listdir(record_path):
+ if os.path.splitext(p)[-1] == ".mp4":
+ file_list.append(p)
+ file_list.sort(key=lambda fn: os.path.getmtime(os.path.join(record_path, fn)))
+ video_path = os.path.join(record_path, file_list[-2])
+ info_for_logging.update({"video": wandb.Video(video_path, format="mp4")})
+
+ if cfg.action_logger:
+ action_path = os.path.join(record_path, (str(ctx.trained_env_step) + "_action.gif"))
+ if all(['logit' in v for v in eval_output]) or hasattr(eval_output, "logit"):
+ if isinstance(eval_output, tnp.ndarray):
+                    action_probs = softmax(eval_output.logit)
+                else:
+                    action_probs = [softmax(to_ndarray(v['logit'])) for v in eval_output]
+                fig, ax = plt.subplots()
+                plt.ylim([-1, 1])
+                action_dim = len(action_probs[1])
+                x_range = [str(x + 1) for x in range(action_dim)]
+                ln = ax.bar(x_range, [0 for x in range(action_dim)], color=color_list[:action_dim])
+                # name the data `action_probs` so that it does not shadow the module-level
+                # `action_prob` update function passed to FuncAnimation below
+                ani = animation.FuncAnimation(
+                    fig, action_prob, fargs=(action_probs, ln), blit=True, save_count=len(action_probs)
+ )
+ ani.save(action_path, writer='pillow')
+ info_for_logging.update({"action": wandb.Video(action_path, format="gif")})
+
+ elif all(['action' in v for v in eval_output[0]]):
+ for i, action_trajectory in enumerate(eval_output):
+ fig, ax = plt.subplots()
+ fig_data = np.array([[i + 1, *v['action']] for i, v in enumerate(action_trajectory)])
+ steps = fig_data[:, 0]
+ actions = fig_data[:, 1:]
+ plt.ylim([-1, 1])
+ for j in range(actions.shape[1]):
+ ax.scatter(steps, actions[:, j])
+ info_for_logging.update({"actions_of_trajectory_{}".format(i): fig})
+
+ if cfg.return_logger:
+ return_path = os.path.join(record_path, (str(ctx.trained_env_step) + "_return.gif"))
+ fig, ax = plt.subplots()
+ ax = plt.gca()
+ ax.set_ylim([0, 1])
+ hist, x_dim = return_distribution(episode_return)
+ assert len(hist) == len(x_dim)
+ ln_return = ax.bar(x_dim, hist, width=1, color='r', linewidth=0.7)
+ ani = animation.FuncAnimation(fig, return_prob, fargs=(hist, ln_return), blit=True, save_count=1)
+ ani.save(return_path, writer='pillow')
+ info_for_logging.update({"return distribution": wandb.Video(return_path, format="gif")})
+
+ if bool(info_for_logging):
+ wandb.log(data=info_for_logging, step=ctx.trained_env_step)
+ plt.clf()
+
+ return _plot
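+
+
+# A minimal usage sketch of `wandb_offline_logger`. With `vis_dataset=True`, `exp_config` is expected to
+# carry a `dataset_path` pointing to a `.pkl`/`.h5`/`.hdf5` offline dataset, and scikit-learn/h5py must be
+# installed. `main_config` and `model` are placeholders for objects built in a training entry file:
+# >>> logger_cfg = EasyDict(dict(plot_logger=True, vis_dataset=True))
+# >>> task.use(wandb_offline_logger(cfg=logger_cfg, exp_config=main_config, model=model,
+# ...                               project_name='my-offline-project'))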
diff --git a/DI-engine/ding/framework/middleware/functional/priority.py b/DI-engine/ding/framework/middleware/functional/priority.py
new file mode 100644
index 0000000000000000000000000000000000000000..e62afbb5c978cdedff1300b609395924867760f9
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/functional/priority.py
@@ -0,0 +1,24 @@
+from typing import TYPE_CHECKING, Callable
+from ding.framework import task
+if TYPE_CHECKING:
+ from ding.framework import OnlineRLContext
+
+
+def priority_calculator(priority_calculation_fn: Callable) -> Callable:
+ """
+ Overview:
+ The middleware that calculates the priority of the collected data.
+ Arguments:
+ - priority_calculation_fn (:obj:`Callable`): The function that calculates the priority of the collected data.
+ """
+
+ if task.router.is_active and not task.has_role(task.role.COLLECTOR):
+ return task.void()
+
+ def _priority_calculator(ctx: "OnlineRLContext") -> None:
+
+ priority = priority_calculation_fn(ctx.trajectories)
+ for i in range(len(priority)):
+ ctx.trajectories[i]['priority'] = priority[i]
+
+ return _priority_calculator
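+
+
+# A minimal usage sketch; `compute_priority` below is a hypothetical user-supplied function that maps
+# the collected trajectories to a list of priorities (e.g. absolute TD-errors), not part of this module:
+# >>> def compute_priority(trajectories):
+# ...     return [abs(t['td_error']) for t in trajectories]
+# >>> task.use(priority_calculator(priority_calculation_fn=compute_priority))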
diff --git a/DI-engine/ding/framework/middleware/functional/termination_checker.py b/DI-engine/ding/framework/middleware/functional/termination_checker.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e0ed518873d9624b1ebd012e42f851fdd46acb0
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/functional/termination_checker.py
@@ -0,0 +1,54 @@
+from typing import TYPE_CHECKING, Union, Callable, Optional
+from ditk import logging
+import numpy as np
+import torch
+from ding.utils import broadcast
+from ding.framework import task
+
+if TYPE_CHECKING:
+ from ding.framework import OnlineRLContext, OfflineRLContext
+
+
+def termination_checker(max_env_step: Optional[int] = None, max_train_iter: Optional[int] = None) -> Callable:
+ if max_env_step is None:
+ max_env_step = np.inf
+ if max_train_iter is None:
+ max_train_iter = np.inf
+
+ def _check(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
+ # ">" is better than ">=" when taking logger result into consideration
+ assert hasattr(ctx, "env_step") or hasattr(ctx, "train_iter"), "Context must have env_step or train_iter"
+ if hasattr(ctx, "env_step") and ctx.env_step > max_env_step:
+ task.finish = True
+ logging.info('Exceeded maximum number of env_step({}), program is terminated'.format(ctx.env_step))
+ elif hasattr(ctx, "train_iter") and ctx.train_iter > max_train_iter:
+ task.finish = True
+ logging.info('Exceeded maximum number of train_iter({}), program is terminated'.format(ctx.train_iter))
+
+ return _check
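+
+
+# A minimal usage sketch: register the checker after the collector/trainer middleware so that
+# `ctx.env_step` and `ctx.train_iter` are already updated when the check runs:
+# >>> task.use(termination_checker(max_env_step=int(1e6), max_train_iter=int(1e5)))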
+
+
+def ddp_termination_checker(max_env_step=None, max_train_iter=None, rank=0):
+ if rank == 0:
+ if max_env_step is None:
+ max_env_step = np.inf
+ if max_train_iter is None:
+ max_train_iter = np.inf
+
+ def _check(ctx):
+ if rank == 0:
+ if ctx.env_step > max_env_step:
+ finish = torch.ones(1).long().cuda()
+ logging.info('Exceeded maximum number of env_step({}), program is terminated'.format(ctx.env_step))
+ elif ctx.train_iter > max_train_iter:
+ finish = torch.ones(1).long().cuda()
+ logging.info('Exceeded maximum number of train_iter({}), program is terminated'.format(ctx.train_iter))
+ else:
+ finish = torch.LongTensor([task.finish]).cuda()
+ else:
+ finish = torch.zeros(1).long().cuda()
+ # broadcast finish result to other DDP workers
+ broadcast(finish, 0)
+ task.finish = finish.cpu().bool().item()
+
+ return _check
diff --git a/DI-engine/ding/framework/middleware/functional/timer.py b/DI-engine/ding/framework/middleware/functional/timer.py
new file mode 100644
index 0000000000000000000000000000000000000000..db8a2c00562ea14781a65bb8e4daf383cf96bf0d
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/functional/timer.py
@@ -0,0 +1,35 @@
+import numpy as np
+from collections import deque
+from ditk import logging
+from time import time
+
+from ding.framework import task
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+ from ding.framework.context import Context
+
+
+def epoch_timer(print_per: int = 1, smooth_window: int = 10):
+ """
+ Overview:
+ Print time cost of each epoch.
+ Arguments:
+        - print_per (:obj:`int`): Print the log once every N epochs.
+        - smooth_window (:obj:`int`): The window size used to smooth the mean time cost.
+ """
+ records = deque(maxlen=print_per * smooth_window)
+
+ def _epoch_timer(ctx: "Context"):
+ start = time()
+ yield
+ time_cost = time() - start
+ records.append(time_cost)
+ if ctx.total_step % print_per == 0:
+ logging.info(
+ "[Epoch Timer][Node:{:>2}]: Cost: {:.2f}ms, Mean: {:.2f}ms".format(
+ task.router.node_id or 0, time_cost * 1000,
+ np.mean(records) * 1000
+ )
+ )
+
+ return _epoch_timer
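+
+
+# A minimal usage sketch: the timer should be registered before the middleware it is meant to measure,
+# since the `yield` above times everything that runs after it in the same pass (`my_middleware` is an
+# illustrative placeholder):
+# >>> task.use(epoch_timer(print_per=10))
+# >>> task.use(my_middleware)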
diff --git a/DI-engine/ding/framework/middleware/functional/trainer.py b/DI-engine/ding/framework/middleware/functional/trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..11c281c1a16993eaf86b7001221a56efb6febce0
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/functional/trainer.py
@@ -0,0 +1,105 @@
+from typing import TYPE_CHECKING, Callable, Union
+from easydict import EasyDict
+import treetensor.torch as ttorch
+from ditk import logging
+import numpy as np
+from ding.policy import Policy
+from ding.framework import task, OfflineRLContext, OnlineRLContext
+
+
+def trainer(cfg: EasyDict, policy: Policy, log_freq: int = 100) -> Callable:
+ """
+ Overview:
+ The middleware that executes a single training process.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config.
+ - policy (:obj:`Policy`): The policy to be trained in step-by-step mode.
+ - log_freq (:obj:`int`): The frequency (iteration) of showing log.
+ """
+ if task.router.is_active and not task.has_role(task.role.LEARNER):
+ return task.void()
+
+ def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
+ """
+ Input of ctx:
+ - train_data (:obj:`Dict`): The data used to update the network. It will train only if \
+ the data is not empty.
+            - train_iter (:obj:`int`): The training iteration count. The log is printed once \
+                every `log_freq` iterations.
+ Output of ctx:
+ - train_output (:obj:`Dict`): The training output in the Dict format, including loss info.
+ """
+
+ if ctx.train_data is None:
+ return
+ train_output = policy.forward(ctx.train_data)
+ if ctx.train_iter % log_freq == 0:
+ if isinstance(train_output, list):
+ train_output_loss = np.mean([item['total_loss'] for item in train_output])
+ else:
+ train_output_loss = train_output['total_loss']
+ if isinstance(ctx, OnlineRLContext):
+ logging.info(
+ 'Training: Train Iter({})\tEnv Step({})\tLoss({:.3f})'.format(
+ ctx.train_iter, ctx.env_step, train_output_loss
+ )
+ )
+ elif isinstance(ctx, OfflineRLContext):
+ logging.info('Training: Train Iter({})\tLoss({:.3f})'.format(ctx.train_iter, train_output_loss))
+ else:
+ raise TypeError("not supported ctx type: {}".format(type(ctx)))
+ ctx.train_iter += 1
+ ctx.train_output = train_output
+
+ return _train
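+
+
+# A minimal usage sketch of `trainer` combined with the off-policy data fetcher from the sibling
+# `data_processor` module; `cfg`, `policy` and `buffer_` are placeholders built in a training entry file:
+# >>> task.use(offpolicy_data_fetcher(cfg, buffer_))
+# >>> task.use(trainer(cfg, policy, log_freq=100))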
+
+
+def multistep_trainer(policy: Policy, log_freq: int = 100) -> Callable:
+ """
+ Overview:
+        The middleware that executes training for a target number of steps.
+ Arguments:
+ - policy (:obj:`Policy`): The policy specialized for multi-step training.
+ - log_freq (:obj:`int`): The frequency (iteration) of showing log.
+ """
+ if task.router.is_active and not task.has_role(task.role.LEARNER):
+ return task.void()
+ last_log_iter = -1
+
+ def _train(ctx: Union["OnlineRLContext", "OfflineRLContext"]):
+ """
+ Input of ctx:
+            - train_data: The data used to update the network.
+                It will train only if the data is not empty.
+            - train_iter (:obj:`int`): The training iteration count.
+                The log is printed once every `log_freq` iterations.
+ Output of ctx:
+ - train_output (:obj:`List[Dict]`): The training output listed by steps.
+ """
+
+        if ctx.train_data is None:  # not enough data from the data fetcher
+ return
+ if hasattr(policy, "_device"): # For ppof policy
+ data = ctx.train_data.to(policy._device)
+        elif hasattr(policy, "get_attribute"):  # For other policies
+            data = ctx.train_data.to(policy.get_attribute("device"))
+        else:
+            raise AttributeError("Policy should have the attribute '_device' or the method 'get_attribute'.")
+ train_output = policy.forward(data)
+ nonlocal last_log_iter
+ if ctx.train_iter - last_log_iter >= log_freq:
+ loss = np.mean([o['total_loss'] for o in train_output])
+ if isinstance(ctx, OfflineRLContext):
+ logging.info('Training: Train Iter({})\tLoss({:.3f})'.format(ctx.train_iter, loss))
+ else:
+ logging.info(
+ 'Training: Train Iter({})\tEnv Step({})\tLoss({:.3f})'.format(ctx.train_iter, ctx.env_step, loss)
+ )
+ last_log_iter = ctx.train_iter
+ ctx.train_iter += len(train_output)
+ ctx.train_output = train_output
+
+ return _train
+
+
+# TODO reward model
diff --git a/DI-engine/ding/framework/middleware/learner.py b/DI-engine/ding/framework/middleware/learner.py
new file mode 100644
index 0000000000000000000000000000000000000000..9abf88e9b378c1da96f4c6025d30fcc3085e6e43
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/learner.py
@@ -0,0 +1,107 @@
+from typing import TYPE_CHECKING, Callable, List, Tuple, Union, Dict, Optional
+from easydict import EasyDict
+from collections import deque
+
+from ding.framework import task
+from ding.data import Buffer
+from .functional import trainer, offpolicy_data_fetcher, reward_estimator, her_data_enhancer
+
+if TYPE_CHECKING:
+ from ding.framework import Context, OnlineRLContext
+ from ding.policy import Policy
+ from ding.reward_model import BaseRewardModel
+
+
+class OffPolicyLearner:
+ """
+ Overview:
+ The class of the off-policy learner, including data fetching and model training. Use \
+ the `__call__` method to execute the whole learning process.
+ """
+
+ def __new__(cls, *args, **kwargs):
+ if task.router.is_active and not task.has_role(task.role.LEARNER):
+ return task.void()
+ return super(OffPolicyLearner, cls).__new__(cls)
+
+ def __init__(
+ self,
+ cfg: EasyDict,
+ policy: 'Policy',
+ buffer_: Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]],
+ reward_model: Optional['BaseRewardModel'] = None,
+ log_freq: int = 100,
+ ) -> None:
+ """
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config.
+ - policy (:obj:`Policy`): The policy to be trained.
+            - buffer\_ (:obj:`Buffer`): The replay buffer to store the data for training.
+            - reward_model (:obj:`BaseRewardModel`): An additional reward estimator such as RND or ICM. \
+                Defaults to None.
+ - log_freq (:obj:`int`): The frequency (iteration) of showing log.
+ """
+ self.cfg = cfg
+ self._fetcher = task.wrap(offpolicy_data_fetcher(cfg, buffer_))
+ self._trainer = task.wrap(trainer(cfg, policy, log_freq=log_freq))
+ if reward_model is not None:
+ self._reward_estimator = task.wrap(reward_estimator(cfg, reward_model))
+ else:
+ self._reward_estimator = None
+
+ def __call__(self, ctx: "OnlineRLContext") -> None:
+ """
+ Output of ctx:
+ - train_output (:obj:`Deque`): The training output in deque.
+ """
+ train_output_queue = []
+ for _ in range(self.cfg.policy.learn.update_per_collect):
+ self._fetcher(ctx)
+ if ctx.train_data is None:
+ break
+ if self._reward_estimator:
+ self._reward_estimator(ctx)
+ self._trainer(ctx)
+ train_output_queue.append(ctx.train_output)
+ ctx.train_output = train_output_queue
+
+
+class HERLearner:
+ """
+ Overview:
+ The class of the learner with the Hindsight Experience Replay (HER). \
+        Use the `__call__` method to execute the data fetching and training \
+ process.
+ """
+
+ def __init__(
+ self,
+ cfg: EasyDict,
+ policy,
+ buffer_: Union[Buffer, List[Tuple[Buffer, float]], Dict[str, Buffer]],
+ her_reward_model,
+ ) -> None:
+ """
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config.
+ - policy (:obj:`Policy`): The policy to be trained.
+ - buffer\_ (:obj:`Buffer`): The replay buffer to store the data for training.
+ - her_reward_model (:obj:`HerRewardModel`): HER reward model.
+ """
+ self.cfg = cfg
+ self._fetcher = task.wrap(her_data_enhancer(cfg, buffer_, her_reward_model))
+ self._trainer = task.wrap(trainer(cfg, policy))
+
+ def __call__(self, ctx: "OnlineRLContext") -> None:
+ """
+ Output of ctx:
+ - train_output (:obj:`Deque`): The deque of training output.
+ """
+ train_output_queue = []
+ for _ in range(self.cfg.policy.learn.update_per_collect):
+ self._fetcher(ctx)
+ if ctx.train_data is None:
+ break
+ self._trainer(ctx)
+ train_output_queue.append(ctx.train_output)
+ ctx.train_output = train_output_queue
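+
+
+# A minimal usage sketch of `OffPolicyLearner` in a DQN-style pipeline. `cfg`, `policy`, `collector`
+# and `evaluator` are placeholders for objects built in a training entry file:
+# >>> buffer_ = DequeBuffer(size=int(1e5))
+# >>> task.use(collector)
+# >>> task.use(data_pusher(cfg, buffer_))
+# >>> task.use(OffPolicyLearner(cfg, policy, buffer_))
+# >>> task.use(evaluator)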
diff --git a/DI-engine/ding/framework/middleware/tests/__init__.py b/DI-engine/ding/framework/middleware/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bb84e7fe2d3c4739ce33bee84f0a230fd25d889
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/tests/__init__.py
@@ -0,0 +1 @@
+from .mock_for_test import MockEnv, MockPolicy, MockHerRewardModel, CONFIG
diff --git a/DI-engine/ding/framework/middleware/tests/mock_for_test.py b/DI-engine/ding/framework/middleware/tests/mock_for_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ad88909a554a3d070f727724987781cfd329b38
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/tests/mock_for_test.py
@@ -0,0 +1,118 @@
+from typing import Union, Any, List, Callable, Dict, Optional
+from collections import namedtuple
+import torch
+import treetensor.numpy as tnp
+from easydict import EasyDict
+from unittest.mock import Mock
+
+obs_dim = [2, 2]
+action_space = 1
+env_num = 2
+
+CONFIG = dict(
+ seed=0,
+ policy=dict(
+ learn=dict(
+ update_per_collect=4,
+ batch_size=8,
+ learner=dict(hook=dict(log_show_after_iter=10), ),
+ ),
+ collect=dict(
+ n_sample=16,
+ unroll_len=1,
+ n_episode=16,
+ ),
+ eval=dict(evaluator=dict(eval_freq=10), ),
+ other=dict(eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ), ),
+ ),
+ env=dict(
+ n_evaluator_episode=5,
+ stop_value=2.0,
+ ),
+)
+CONFIG = EasyDict(CONFIG)
+
+
+class MockPolicy(Mock):
+
+ def __init__(self) -> None:
+ super(MockPolicy, self).__init__()
+ self.action_space = action_space
+ self.obs_dim = obs_dim
+
+ def reset(self, data_id: Optional[List[int]] = None) -> None:
+ return
+
+ def forward(self, data: dict, **kwargs) -> dict:
+ res = {}
+ for i, v in data.items():
+ res[i] = {'action': torch.sum(v)}
+ return res
+
+ def process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
+ transition = {
+ 'obs': torch.rand(self.obs_dim),
+ 'next_obs': torch.rand(self.obs_dim),
+ 'action': torch.zeros(self.action_space),
+ 'logit': 1.0,
+ 'value': 2.0,
+ 'reward': 0.1,
+ 'done': True,
+ }
+ return transition
+
+
+class MockEnv(Mock):
+
+ def __init__(self) -> None:
+ super(MockEnv, self).__init__()
+ self.env_num = env_num
+ self.obs_dim = obs_dim
+ self.closed = False
+ self._reward_grow_indicator = 1
+
+ @property
+ def ready_obs(self) -> tnp.array:
+ return tnp.stack([
+ torch.zeros(self.obs_dim),
+ torch.ones(self.obs_dim),
+ ])
+
+ def seed(self, seed: Union[Dict[int, int], List[int], int], dynamic_seed: bool = None) -> None:
+ return
+
+ def launch(self, reset_param: Optional[Dict] = None) -> None:
+ return
+
+ def reset(self, reset_param: Optional[Dict] = None) -> None:
+ return
+
+ def step(self, actions: tnp.ndarray) -> List[tnp.ndarray]:
+ timesteps = []
+ for i in range(self.env_num):
+ timestep = dict(
+ obs=torch.rand(self.obs_dim),
+ reward=1.0,
+ done=True,
+ info={'eval_episode_return': self._reward_grow_indicator * 1.0},
+ env_id=i,
+ )
+ timesteps.append(tnp.array(timestep))
+ self._reward_grow_indicator += 1 # eval_episode_return will increase as step method is called
+ return timesteps
+
+
+class MockHerRewardModel(Mock):
+
+ def __init__(self) -> None:
+ super(MockHerRewardModel, self).__init__()
+ self.episode_size = 8
+ self.episode_element_size = 4
+
+ def estimate(self, episode: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ return [[episode[0] for _ in range(self.episode_element_size)]]
diff --git a/DI-engine/ding/framework/middleware/tests/test_advantage_estimator.py b/DI-engine/ding/framework/middleware/tests/test_advantage_estimator.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd2a12c19b29535eee9c7464def16133e707c93b
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/tests/test_advantage_estimator.py
@@ -0,0 +1,160 @@
+import pytest
+
+from ding.data.buffer import DequeBuffer
+from ding.data import Buffer
+from easydict import EasyDict
+from ding.framework import OnlineRLContext
+import treetensor
+import torch
+import copy
+
+from ding.framework.middleware.functional.advantage_estimator import gae_estimator
+from ding.framework.middleware.functional.advantage_estimator import montecarlo_return_estimator
+from ding.utils.data import ttorch_collate
+
+from typing import Any, List, Dict, Optional
+
+from unittest.mock import Mock, patch
+
+
+class TheModelClass:
+
+ def forward(self, obs: Dict, mode: str) -> Dict:
+ return {'value': torch.distributions.uniform.Uniform(0, 4).sample([len(obs.data)])}
+
+
+class MockPolicy(Mock):
+
+ def __init__(self, model) -> None:
+ super(MockPolicy, self).__init__()
+ self._model = model
+
+ def get_attribute(self, name: str) -> Any:
+ return self._model
+
+
+def call_gae_estimator(batch_size: int = 32, trajectory_end_idx_size: int = 5, buffer: Optional[Buffer] = None):
+ cfg = EasyDict(
+ {
+ 'policy': {
+ 'model': {
+ 'obs_shape': 4,
+ 'action_shape': 2,
+ },
+ 'collect': {
+ 'discount_factor': 0.9,
+ 'gae_lambda': 0.95
+ },
+ 'cuda': False
+ }
+ }
+ )
+
+ ctx = OnlineRLContext()
+ assert trajectory_end_idx_size <= batch_size
+
+ ctx.trajectory_end_idx = treetensor.torch.randint(low=0, high=batch_size, size=(trajectory_end_idx_size, ))
+ ctx.trajectories = [
+ treetensor.torch.Tensor(
+ {
+ 'action': treetensor.torch.randint(low=0, high=2, size=(1, )),
+ 'collect_train_iter': [0],
+ 'done': False,
+ 'logit': treetensor.torch.randn(2),
+ 'next_obs': treetensor.torch.randn(4),
+ 'obs': treetensor.torch.randn(4),
+ 'reward': [1.0],
+ 'value': torch.distributions.uniform.Uniform(0, 4).sample([1])
+ }
+ ) for _ in range(batch_size)
+ ]
+ ctx.trajectories_copy = ttorch_collate(copy.deepcopy(ctx.trajectories), cat_1dim=True)
+ traj_flag = ctx.trajectories_copy.done.clone()
+ traj_flag[ctx.trajectory_end_idx] = True
+ ctx.trajectories_copy.traj_flag = traj_flag
+
+ with patch("ding.policy.Policy", MockPolicy):
+ gae_estimator(cfg, MockPolicy(TheModelClass()), buffer)(ctx)
+
+ if buffer is not None:
+ train_data = [d.data for d in list(buffer.storage)]
+ for d in train_data:
+ d.logit = d.logit
+ d.next_obs = d.next_obs
+ d.obs = d.obs
+ ctx.train_data = ttorch_collate(train_data, cat_1dim=True)
+
+ assert ctx.trajectories is None
+ assert torch.equal(ctx.trajectories_copy.action, ctx.train_data.action)
+ assert torch.equal(ctx.trajectories_copy.collect_train_iter, ctx.train_data.collect_train_iter)
+ assert torch.equal(ctx.trajectories_copy.logit, ctx.train_data.logit)
+ assert torch.equal(ctx.trajectories_copy.next_obs, ctx.train_data.next_obs)
+ assert torch.equal(ctx.trajectories_copy.obs, ctx.train_data.obs)
+ assert torch.equal(ctx.trajectories_copy.reward, ctx.train_data.reward)
+ assert torch.equal(ctx.trajectories_copy.traj_flag, ctx.train_data.traj_flag)
+
+
+@pytest.mark.unittest
+def test_gae_estimator():
+ batch_size = 32
+ trajectory_end_idx_size = 5
+ call_gae_estimator(batch_size, trajectory_end_idx_size)
+ call_gae_estimator(batch_size, trajectory_end_idx_size, DequeBuffer(size=batch_size))
+
+
+class MockPGPolicy(Mock):
+
+ def __init__(self, cfg) -> None:
+ super(MockPGPolicy, self).__init__()
+ self._cfg = EasyDict(cfg)
+ self._gamma = self._cfg.collect.discount_factor
+ self._unroll_len = self._cfg.collect.unroll_len
+
+ def get_attribute(self, name: str) -> Any:
+ return self._model
+
+
+def call_montecarlo_return_estimator(batch_size: int = 32):
+
+ cfg = dict(
+ learn=dict(ignore_done=False, ),
+ collect=dict(
+ unroll_len=1,
+ discount_factor=0.9,
+ ),
+ )
+ ctx = OnlineRLContext()
+ ctx.episodes = [
+ [
+ treetensor.torch.Tensor(
+ {
+ 'action': treetensor.torch.randint(low=0, high=2, size=(1, )),
+ 'collect_train_iter': [0],
+ 'done': False if i != batch_size - 1 else True,
+ 'logit': treetensor.torch.randn(2),
+ 'next_obs': treetensor.torch.randn(4),
+ 'obs': treetensor.torch.randn(4),
+ 'reward': [1.0],
+ 'value': torch.distributions.uniform.Uniform(0, 4).sample([1])
+ }
+ ) for i in range(batch_size)
+ ]
+ ]
+ ctx.episodes_copy = treetensor.torch.concat(
+ [ttorch_collate(copy.deepcopy(episode), cat_1dim=True) for episode in ctx.episodes], dim=0
+ )
+ with patch("ding.policy.Policy", MockPGPolicy):
+ montecarlo_return_estimator(MockPGPolicy(cfg))(ctx)
+
+ assert torch.equal(ctx.episodes_copy.action, ctx.train_data.action)
+ assert torch.equal(ctx.episodes_copy.collect_train_iter, ctx.train_data.collect_train_iter)
+ assert torch.equal(ctx.episodes_copy.logit, ctx.train_data.logit)
+ assert torch.equal(ctx.episodes_copy.next_obs, ctx.train_data.next_obs)
+ assert torch.equal(ctx.episodes_copy.obs, ctx.train_data.obs)
+ assert torch.equal(ctx.episodes_copy.reward, ctx.train_data.reward)
+
+
+@pytest.mark.unittest
+def test_montecarlo_return_estimator():
+ batch_size = 32
+ call_montecarlo_return_estimator(batch_size)
diff --git a/DI-engine/ding/framework/middleware/tests/test_barrier.py b/DI-engine/ding/framework/middleware/tests/test_barrier.py
new file mode 100644
index 0000000000000000000000000000000000000000..de176eda2dbc1752ba43acb241da73cc5bb0a302
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/tests/test_barrier.py
@@ -0,0 +1,144 @@
+import random
+import time
+import socket
+import pytest
+import multiprocessing as mp
+from ditk import logging
+from ding.framework import task
+from ding.framework.parallel import Parallel
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware.barrier import Barrier
+
+PORTS_LIST = ["1235", "1236", "1237"]
+
+
+class EnvStepMiddleware:
+
+ def __call__(self, ctx):
+ yield
+ ctx.env_step += 1
+
+
+class SleepMiddleware:
+
+ def __init__(self, node_id):
+ self.node_id = node_id
+
+    def random_sleep(self, direction, step):
+        random.seed(self.node_id + step)
+        sleep_second = random.randint(1, 5)
+        logging.info("Node:[{}] env_step:[{}]-{} will sleep:{}s".format(self.node_id, step, direction, sleep_second))
+        for i in range(sleep_second):
+            time.sleep(1)
+            print("Node:[{}] sleeping...".format(self.node_id))
+        logging.info("Node:[{}] env_step:[{}]-{} wake up!".format(self.node_id, step, direction))
+
+ def __call__(self, ctx):
+ self.random_sleep("forward", ctx.env_step)
+ yield
+ self.random_sleep("backward", ctx.env_step)
+
+
+def star_barrier():
+ with task.start(ctx=OnlineRLContext()):
+ node_id = task.router.node_id
+ if node_id == 0:
+            attach_from_nums = 3
+        else:
+            attach_from_nums = 0
+        barrier = Barrier(attach_from_nums)
+ task.use(barrier, lock=False)
+ task.use(SleepMiddleware(node_id), lock=False)
+ task.use(barrier, lock=False)
+ task.use(EnvStepMiddleware(), lock=False)
+ try:
+ task.run(2)
+ except Exception as e:
+ logging.error(e)
+ assert False
+
+
+def mesh_barrier():
+ with task.start(ctx=OnlineRLContext()):
+ node_id = task.router.node_id
+        attach_from_nums = 3 - task.router.node_id
+        barrier = Barrier(attach_from_nums)
+ task.use(barrier, lock=False)
+ task.use(SleepMiddleware(node_id), lock=False)
+ task.use(barrier, lock=False)
+ task.use(EnvStepMiddleware(), lock=False)
+ try:
+ task.run(2)
+ except Exception as e:
+ logging.error(e)
+ assert False
+
+
+def unmatch_barrier():
+ with task.start(ctx=OnlineRLContext()):
+ node_id = task.router.node_id
+        attach_from_nums = 3 - task.router.node_id
+        task.use(Barrier(attach_from_nums, 5), lock=False)
+        if node_id != 2:
+            task.use(Barrier(attach_from_nums, 5), lock=False)
+ try:
+ task.run(2)
+ except TimeoutError as e:
+ assert node_id != 2
+ logging.info("Node:[{}] timeout with barrier".format(node_id))
+ else:
+ time.sleep(5)
+ assert node_id == 2
+ logging.info("Node:[{}] finish barrier".format(node_id))
+
+
+def launch_barrier(args):
+ i, topo, fn, test_id = args
+ address = socket.gethostbyname(socket.gethostname())
+ topology = "alone"
+ attach_to = []
+ port_base = PORTS_LIST[test_id]
+ port = port_base + str(i)
+ if topo == 'star':
+ if i != 0:
+ attach_to = ['tcp://{}:{}{}'.format(address, port_base, 0)]
+ elif topo == 'mesh':
+ for j in range(i):
+ attach_to.append('tcp://{}:{}{}'.format(address, port_base, j))
+
+ Parallel.runner(
+ node_ids=i,
+ ports=int(port),
+ attach_to=attach_to,
+ topology=topology,
+ protocol="tcp",
+ n_parallel_workers=1,
+ startup_interval=0
+ )(fn)
+
+
+@pytest.mark.unittest
+def test_star_topology_barrier():
+ ctx = mp.get_context("spawn")
+ with ctx.Pool(processes=4) as pool:
+ pool.map(launch_barrier, [[i, 'star', star_barrier, 0] for i in range(4)])
+ pool.close()
+ pool.join()
+
+
+@pytest.mark.unittest
+def test_mesh_topology_barrier():
+ ctx = mp.get_context("spawn")
+ with ctx.Pool(processes=4) as pool:
+ pool.map(launch_barrier, [[i, 'mesh', mesh_barrier, 1] for i in range(4)])
+ pool.close()
+ pool.join()
+
+
+@pytest.mark.unittest
+def test_unmatch_barrier():
+ ctx = mp.get_context("spawn")
+ with ctx.Pool(processes=4) as pool:
+ pool.map(launch_barrier, [[i, 'mesh', unmatch_barrier, 2] for i in range(4)])
+ pool.close()
+ pool.join()
diff --git a/DI-engine/ding/framework/middleware/tests/test_ckpt_handler.py b/DI-engine/ding/framework/middleware/tests/test_ckpt_handler.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0d81f0a33588c08c01250fa9f62b142b62735fb
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/tests/test_ckpt_handler.py
@@ -0,0 +1,75 @@
+import pytest
+
+from easydict import EasyDict
+from ding.framework import OnlineRLContext
+from ding.framework.middleware.ckpt_handler import CkptSaver
+
+import torch.nn as nn
+import torch.optim as optim
+import os
+import shutil
+
+from unittest.mock import Mock, patch
+from ding.framework import task
+from ding.policy.base_policy import Policy
+
+
+class TheModelClass(nn.Module):
+
+ def state_dict(self):
+ return 'fake_state_dict'
+
+
+class MockPolicy(Mock):
+
+ def __init__(self, model, **kwargs) -> None:
+ super(MockPolicy, self).__init__(model)
+ self.learn_mode = model
+
+ @property
+ def eval_mode(self):
+ return EasyDict({"state_dict": lambda: {}})
+
+
+@pytest.mark.unittest
+def test_ckpt_saver():
+ exp_name = 'test_ckpt_saver_exp'
+
+ ctx = OnlineRLContext()
+
+ train_freq = 100
+ model = TheModelClass()
+
+ if not os.path.exists(exp_name):
+ os.makedirs(exp_name)
+
+ prefix = '{}/ckpt'.format(exp_name)
+
+ with patch("ding.policy.Policy", MockPolicy), task.start():
+ policy = MockPolicy(model)
+
+ def mock_save_file(path, data, fs_type=None, use_lock=False):
+ assert path == "{}/eval.pth.tar".format(prefix)
+
+ with patch("ding.framework.middleware.ckpt_handler.save_file", mock_save_file):
+ ctx.train_iter = 1
+ ctx.eval_value = 9.4
+ ckpt_saver = CkptSaver(policy, exp_name, train_freq)
+ ckpt_saver(ctx)
+
+ def mock_save_file(path, data, fs_type=None, use_lock=False):
+ assert path == "{}/iteration_{}.pth.tar".format(prefix, ctx.train_iter)
+
+ with patch("ding.framework.middleware.ckpt_handler.save_file", mock_save_file):
+ ctx.train_iter = 100
+ ctx.eval_value = 1
+ ckpt_saver(ctx)
+
+ def mock_save_file(path, data, fs_type=None, use_lock=False):
+ assert path == "{}/final.pth.tar".format(prefix)
+
+ with patch("ding.framework.middleware.ckpt_handler.save_file", mock_save_file):
+ task.finish = True
+ ckpt_saver(ctx)
+
+ shutil.rmtree(exp_name)
diff --git a/DI-engine/ding/framework/middleware/tests/test_collector.py b/DI-engine/ding/framework/middleware/tests/test_collector.py
new file mode 100644
index 0000000000000000000000000000000000000000..13d45c3c3d4dd3a1b02113588e5d049ba062ee5f
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/tests/test_collector.py
@@ -0,0 +1,84 @@
+import pytest
+import torch
+import copy
+from unittest.mock import patch
+from ding.framework import OnlineRLContext, task
+from ding.framework.middleware import TransitionList, inferencer, rolloutor
+from ding.framework.middleware import StepCollector, EpisodeCollector
+from ding.framework.middleware.tests import MockPolicy, MockEnv, CONFIG
+
+
+@pytest.mark.unittest
+def test_inferencer():
+ ctx = OnlineRLContext()
+ with patch("ding.policy.Policy", MockPolicy), patch("ding.envs.BaseEnvManagerV2", MockEnv):
+ policy = MockPolicy()
+ env = MockEnv()
+ inferencer(0, policy, env)(ctx)
+ assert isinstance(ctx.inference_output, dict)
+ assert ctx.inference_output[0] == {'action': torch.Tensor([0.])} # sum of zeros([2, 2])
+ assert ctx.inference_output[1] == {'action': torch.Tensor([4.])} # sum of ones([2, 2])
+
+
+@pytest.mark.unittest
+def test_rolloutor():
+ ctx = OnlineRLContext()
+ transitions = TransitionList(2)
+ with patch("ding.policy.Policy", MockPolicy), patch("ding.envs.BaseEnvManagerV2", MockEnv):
+ policy = MockPolicy()
+ env = MockEnv()
+ for _ in range(10):
+ inferencer(0, policy, env)(ctx)
+ rolloutor(policy, env, transitions)(ctx)
+ assert ctx.env_episode == 20 # 10 * env_num
+ assert ctx.env_step == 20 # 10 * env_num
+
+
+@pytest.mark.unittest
+def test_step_collector():
+ cfg = copy.deepcopy(CONFIG)
+ ctx = OnlineRLContext()
+
+ # test no random_collect_size
+ with patch("ding.policy.Policy", MockPolicy), patch("ding.envs.BaseEnvManagerV2", MockEnv):
+ with task.start():
+ policy = MockPolicy()
+ env = MockEnv()
+ collector = StepCollector(cfg, policy, env)
+ collector(ctx)
+ assert len(ctx.trajectories) == 16
+ assert ctx.trajectory_end_idx == [7, 15]
+
+ # test with random_collect_size
+ with patch("ding.policy.Policy", MockPolicy), patch("ding.envs.BaseEnvManagerV2", MockEnv):
+ with task.start():
+ policy = MockPolicy()
+ env = MockEnv()
+ collector = StepCollector(cfg, policy, env, random_collect_size=8)
+ collector(ctx)
+ assert len(ctx.trajectories) == 16
+ assert ctx.trajectory_end_idx == [7, 15]
+
+
+@pytest.mark.unittest
+def test_episode_collector():
+ cfg = copy.deepcopy(CONFIG)
+ ctx = OnlineRLContext()
+
+ # test no random_collect_size
+ with patch("ding.policy.Policy", MockPolicy), patch("ding.envs.BaseEnvManagerV2", MockEnv):
+ with task.start():
+ policy = MockPolicy()
+ env = MockEnv()
+ collector = EpisodeCollector(cfg, policy, env)
+ collector(ctx)
+ assert len(ctx.episodes) == 16
+
+ # test with random_collect_size
+ with patch("ding.policy.Policy", MockPolicy), patch("ding.envs.BaseEnvManagerV2", MockEnv):
+ with task.start():
+ policy = MockPolicy()
+ env = MockEnv()
+ collector = EpisodeCollector(cfg, policy, env, random_collect_size=8)
+ collector(ctx)
+ assert len(ctx.episodes) == 16
diff --git a/DI-engine/ding/framework/middleware/tests/test_data_processor.py b/DI-engine/ding/framework/middleware/tests/test_data_processor.py
new file mode 100644
index 0000000000000000000000000000000000000000..d63d392943376906c394a38592cf9f8e14eadd4d
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/tests/test_data_processor.py
@@ -0,0 +1,260 @@
+import tempfile
+import pytest
+
+from ding.data.buffer import DequeBuffer
+
+from ding.framework import Context, OnlineRLContext, OfflineRLContext
+from ding.framework.middleware.functional.data_processor import \
+ data_pusher, offpolicy_data_fetcher, offline_data_fetcher, offline_data_saver, sqil_data_pusher, buffer_saver
+
+from ding.data.buffer.middleware import PriorityExperienceReplay
+
+from easydict import EasyDict
+from ding.data import Dataset
+from collections import deque
+import torch
+import math
+import os
+import copy
+
+from unittest.mock import patch
+
+
+@pytest.mark.unittest
+def test_data_pusher():
+ buffer_ = DequeBuffer(size=10)
+ ctx = OnlineRLContext()
+ ctx.trajectories = [i for i in range(5)]
+ data_pusher(cfg=None, buffer_=buffer_)(ctx)
+ assert buffer_.count() == 5
+
+ buffer_ = DequeBuffer(size=10)
+ ctx = OnlineRLContext()
+ ctx.episodes = [i for i in range(5)]
+ data_pusher(cfg=None, buffer_=buffer_)(ctx)
+ assert buffer_.count() == 5
+
+ buffer_ = DequeBuffer(size=10)
+ ctx = OnlineRLContext()
+ with pytest.raises(RuntimeError) as exc_info:
+ data_pusher(cfg=None, buffer_=buffer_)(ctx)
+ assert str(exc_info.value) == "Either ctx.trajectories or ctx.episodes should be not None."
+
+
+def offpolicy_data_fetcher_type_buffer_helper(priority=0.5, use_list=True):
+ cfg = EasyDict({'policy': {'learn': {'batch_size': 20}, 'collect': {'unroll_len': 1}}})
+ buffer = DequeBuffer(size=20)
+ buffer.use(PriorityExperienceReplay(buffer=buffer))
+ for i in range(20):
+ buffer.push({'obs': i, 'reward': 1, 'info': 'xxx'})
+ ctx = OnlineRLContext()
+
+ if use_list:
+ ctx.train_output = [{'priority': [priority for _ in range(20)]}]
+ else:
+ ctx.train_output = {'priority': [priority for _ in range(20)]}
+
+ func_generator = offpolicy_data_fetcher(cfg=cfg, buffer_=buffer)(ctx)
+ next(func_generator)
+ assert len(ctx.train_data) == cfg.policy.learn.batch_size
+    assert all(d['obs'] >= 0 and d['obs'] < 20 and isinstance(d['obs'], int) for d in ctx.train_data)
+ assert [d['obs'] for d in ctx.train_data] == [i for i in range(20)]
+ assert [d['reward'] for d in ctx.train_data] == [1 for i in range(20)]
+ assert [d['info'] for d in ctx.train_data] == ['xxx' for i in range(20)]
+ assert [d['priority_IS'] for d in ctx.train_data] == [torch.tensor([1]) for i in range(20)]
+ assert list(buffer.storage)[0].meta['priority'] == 1.0
+ # assert sorted(ctx.train_data) == [i for i in range(20)]
+
+ try:
+ next(func_generator)
+ except StopIteration:
+ pass
+ assert list(buffer.storage)[0].meta['priority'] == priority
+
+
+def call_offpolicy_data_fetcher_type_buffer():
+ # if isinstance(buffer_, Buffer):
+ offpolicy_data_fetcher_type_buffer_helper(priority=0.5, use_list=True)
+ offpolicy_data_fetcher_type_buffer_helper(priority=0.3, use_list=False)
+
+
+def call_offpolicy_data_fetcher_type_list():
+ #elif isinstance(buffer_, List)
+ cfg = EasyDict({'policy': {'learn': {'batch_size': 5}, 'collect': {'unroll_len': 1}}})
+ buffer = DequeBuffer(size=20)
+ for i in range(20):
+ buffer.push(i)
+ ctx = OnlineRLContext()
+ buffer1 = copy.deepcopy(buffer)
+ buffer2 = copy.deepcopy(buffer)
+ buffer3 = copy.deepcopy(buffer)
+ buffer_list = [(buffer1, 1), (buffer2, 2), (buffer3, 3)]
+
+ next(offpolicy_data_fetcher(cfg=cfg, buffer_=buffer_list)(ctx))
+ assert len(ctx.train_data) == cfg.policy.learn.batch_size * (1 + 2 + 3)
+ assert all(i >= 0 and i < 20 and isinstance(i, int) for i in ctx.train_data)
+
+
+def call_offpolicy_data_fetcher_type_dict():
+ #elif isinstance(buffer_, Dict)
+ cfg = EasyDict({'policy': {'learn': {'batch_size': 5}, 'collect': {'unroll_len': 1}}})
+ buffer = DequeBuffer(size=20)
+ for i in range(20):
+ buffer.push(i)
+ ctx = OnlineRLContext()
+ buffer1 = copy.deepcopy(buffer)
+ buffer2 = copy.deepcopy(buffer)
+ buffer3 = copy.deepcopy(buffer)
+ buffer_dict = {'key1': buffer1, 'key2': buffer2, 'key3': buffer3}
+
+ next(offpolicy_data_fetcher(cfg=cfg, buffer_=buffer_dict)(ctx))
+ assert all(len(v) == cfg.policy.learn.batch_size for k, v in ctx.train_data.items())
+ assert all(all(i >= 0 and i < 20 and isinstance(i, int) for i in v) for k, v in ctx.train_data.items())
+
+
+def call_offpolicy_data_fetcher_type_int():
+ # else catch TypeError
+ cfg = EasyDict({'policy': {'learn': {'batch_size': 5}, 'collect': {'unroll_len': 1}}})
+ ctx = OnlineRLContext()
+ with pytest.raises(TypeError) as exc_info:
+ next(offpolicy_data_fetcher(cfg=cfg, buffer_=1)(ctx))
+ assert str(exc_info.value) == "not support buffer argument type: {}".format(type(1))
+
+
+@pytest.mark.unittest
+def test_offpolicy_data_fetcher():
+ call_offpolicy_data_fetcher_type_buffer()
+ call_offpolicy_data_fetcher_type_list()
+ call_offpolicy_data_fetcher_type_dict()
+ call_offpolicy_data_fetcher_type_int()
+
+
+@pytest.mark.unittest
+def test_offline_data_fetcher():
+ cfg = EasyDict({'policy': {'learn': {'batch_size': 5}}})
+ dataset_size = 10
+ num_batch = math.ceil(dataset_size / cfg.policy.learn.batch_size)
+ data = torch.linspace(11, 20, dataset_size)
+ data_list = list(data)
+
+ class MyDataset(Dataset):
+
+ def __init__(self):
+ self.x = data
+ self.len = len(self.x)
+
+ def __getitem__(self, index):
+ return self.x[index]
+
+ def __len__(self):
+ return self.len
+
+ ctx = OfflineRLContext()
+ ctx.train_epoch = 0
+
+ data_tmp = []
+ fetch = offline_data_fetcher(cfg, MyDataset())
+    for i in range(num_batch * 5):
+ fetch(ctx)
+ assert i // num_batch == ctx.train_epoch
+ data_tmp.extend(ctx.train_data)
+
+ if i % num_batch == num_batch - 1:
+ assert sorted(data_tmp) == data_list
+ data_tmp = []
+ if i >= num_batch * 5 - 1:
+ break
+
+
+@pytest.mark.unittest
+def test_offline_data_saver():
+ transition = {}
+ transition['obs'] = torch.zeros((3, 1))
+ transition['next_obs'] = torch.zeros((3, 1))
+ transition['action'] = torch.zeros((1, 1))
+ transition['reward'] = torch.tensor((1, ))
+ transition['done'] = False
+ transition['collect_iter'] = 0
+
+ fake_data = [transition for i in range(32)]
+
+ ctx = OnlineRLContext()
+ ctx.trajectories = fake_data
+ data_path_ = './expert.pkl'
+
+ def mock_offline_data_save_type(exp_data, expert_data_path, data_type):
+ assert exp_data == fake_data
+ assert expert_data_path == data_path_
+ assert data_type == 'naive'
+
+ with patch("ding.framework.middleware.functional.data_processor.offline_data_save_type",
+ mock_offline_data_save_type):
+ offline_data_saver(data_path=data_path_, data_type='naive')(ctx)
+
+ assert ctx.trajectories is None
+
+ ctx = OnlineRLContext()
+ ctx.trajectories = fake_data
+
+ def mock_offline_data_save_type(exp_data, expert_data_path, data_type):
+ assert exp_data == fake_data
+ assert expert_data_path == data_path_
+ assert data_type == 'hdf5'
+
+ with patch("ding.framework.middleware.functional.data_processor.offline_data_save_type",
+ mock_offline_data_save_type):
+ offline_data_saver(data_path=data_path_, data_type='hdf5')(ctx)
+
+ assert ctx.trajectories is None
+
+
+@pytest.mark.unittest
+def test_sqil_data_pusher():
+ transition = {}
+ transition['obs'] = torch.zeros((3, 1))
+ transition['next_obs'] = torch.zeros((3, 1))
+ transition['action'] = torch.zeros((1, 1))
+ transition['reward'] = torch.tensor((2, ))
+ transition['done'] = False
+ transition['collect_iter'] = 0
+ transition = EasyDict(transition)
+
+ fake_data = [transition for i in range(5)]
+
+ # expert = True
+ ctx = OnlineRLContext()
+ ctx.trajectories = copy.deepcopy(fake_data)
+ buffer = DequeBuffer(size=10)
+ sqil_data_pusher(cfg=None, buffer_=buffer, expert=True)(ctx)
+ assert buffer.count() == 5
+ assert all(t.data.reward == 1 for t in list(buffer.storage))
+
+ # expert = False
+ ctx = OnlineRLContext()
+ ctx.trajectories = copy.deepcopy(fake_data)
+ buffer = DequeBuffer(size=10)
+ sqil_data_pusher(cfg=None, buffer_=buffer, expert=False)(ctx)
+ assert buffer.count() == 5
+ assert all(t.data.reward == 0 for t in list(buffer.storage))
+
+
+@pytest.mark.unittest
+def test_buffer_saver():
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ test_folder = os.path.join(tmpdirname, "test_buffer_saver")
+ cfg = EasyDict({"exp_name": test_folder})
+ os.makedirs(test_folder)
+ buffer_ = DequeBuffer(size=10)
+ ctx = OnlineRLContext()
+ ctx.trajectories = [i for i in range(5)]
+ ctx.env_step = 0
+ data_pusher(cfg=cfg, buffer_=buffer_)(ctx)
+ assert buffer_.count() == 5
+ buffer_saver(cfg=cfg, buffer_=buffer_, replace=False)(ctx)
+ buffer_saver(cfg=cfg, buffer_=buffer_, replace=True)(ctx)
+ buffer_1 = DequeBuffer(size=10)
+ buffer_1.load_data(os.path.join(test_folder, "replaybuffer", "data_latest.hkl"))
+ assert buffer_1.count() == 5
+ buffer_2 = DequeBuffer(size=10)
+ buffer_2.load_data(os.path.join(test_folder, "replaybuffer", "data_envstep_0.hkl"))
+ assert buffer_2.count() == 5
diff --git a/DI-engine/ding/framework/middleware/tests/test_distributer.py b/DI-engine/ding/framework/middleware/tests/test_distributer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7651e66ec79df1870aac62f1ab37bd5409e27678
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/tests/test_distributer.py
@@ -0,0 +1,269 @@
+import shutil
+from time import sleep
+import pytest
+import numpy as np
+import tempfile
+
+import torch
+from ding.data.model_loader import FileModelLoader
+from ding.data.storage_loader import FileStorageLoader
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware.distributer import ContextExchanger, ModelExchanger, PeriodicalModelExchanger
+from ding.framework.parallel import Parallel
+from ding.utils.default_helper import set_pkg_seed
+from os import path
+
+
+def context_exchanger_main():
+ with task.start(ctx=OnlineRLContext()):
+ if task.router.node_id == 0:
+ task.add_role(task.role.LEARNER)
+ elif task.router.node_id == 1:
+ task.add_role(task.role.COLLECTOR)
+
+ task.use(ContextExchanger(skip_n_iter=1))
+
+ if task.has_role(task.role.LEARNER):
+
+ def learner_context(ctx: OnlineRLContext):
+ assert len(ctx.trajectories) == 2
+ assert len(ctx.trajectory_end_idx) == 4
+ assert len(ctx.episodes) == 8
+ assert ctx.env_step > 0
+ assert ctx.env_episode > 0
+ yield
+ ctx.train_iter += 1
+
+ task.use(learner_context)
+ elif task.has_role(task.role.COLLECTOR):
+
+ def collector_context(ctx: OnlineRLContext):
+ if ctx.total_step > 0:
+ assert ctx.train_iter > 0
+ yield
+ ctx.trajectories = [np.random.rand(10, 10) for _ in range(2)]
+ ctx.trajectory_end_idx = [1 for _ in range(4)]
+ ctx.episodes = [np.random.rand(10, 10) for _ in range(8)]
+ ctx.env_step += 1
+ ctx.env_episode += 1
+
+ task.use(collector_context)
+
+ task.run(max_step=3)
+
+
+@pytest.mark.tmp
+def test_context_exchanger():
+ Parallel.runner(n_parallel_workers=2)(context_exchanger_main)
+
+
+def context_exchanger_with_storage_loader_main():
+ with task.start(ctx=OnlineRLContext()):
+ if task.router.node_id == 0:
+ task.add_role(task.role.LEARNER)
+ elif task.router.node_id == 1:
+ task.add_role(task.role.COLLECTOR)
+
+ tempdir = path.join(tempfile.gettempdir(), "test_storage_loader")
+ storage_loader = FileStorageLoader(dirname=tempdir)
+ try:
+ task.use(ContextExchanger(skip_n_iter=1, storage_loader=storage_loader))
+
+ if task.has_role(task.role.LEARNER):
+
+ def learner_context(ctx: OnlineRLContext):
+ assert len(ctx.trajectories) == 2
+ assert len(ctx.trajectory_end_idx) == 4
+ assert len(ctx.episodes) == 8
+ assert ctx.env_step > 0
+ assert ctx.env_episode > 0
+ yield
+ ctx.train_iter += 1
+
+ task.use(learner_context)
+ elif task.has_role(task.role.COLLECTOR):
+
+ def collector_context(ctx: OnlineRLContext):
+ if ctx.total_step > 0:
+ assert ctx.train_iter > 0
+ yield
+ ctx.trajectories = [np.random.rand(10, 10) for _ in range(2)]
+ ctx.trajectory_end_idx = [1 for _ in range(4)]
+ ctx.episodes = [np.random.rand(10, 10) for _ in range(8)]
+ ctx.env_step += 1
+ ctx.env_episode += 1
+
+ task.use(collector_context)
+
+ task.run(max_step=3)
+ finally:
+ storage_loader.shutdown()
+ sleep(1)
+ if path.exists(tempdir):
+ shutil.rmtree(tempdir)
+
+
+@pytest.mark.tmp
+def test_context_exchanger_with_storage_loader():
+ Parallel.runner(n_parallel_workers=2)(context_exchanger_with_storage_loader_main)
+
+
+class MockPolicy:
+
+ def __init__(self) -> None:
+ self._model = self._get_model(10, 10)
+
+ def _get_model(self, X_shape, y_shape) -> torch.nn.Module:
+ return torch.nn.Sequential(
+ torch.nn.Linear(X_shape, 24), torch.nn.ReLU(), torch.nn.Linear(24, 24), torch.nn.ReLU(),
+ torch.nn.Linear(24, y_shape)
+ )
+
+ def train(self, X, y):
+ loss_fn = torch.nn.MSELoss(reduction="mean")
+ optimizer = torch.optim.Adam(self._model.parameters(), lr=0.01)
+ y_pred = self._model(X)
+ loss = loss_fn(y_pred, y)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+
+ def predict(self, X):
+ with torch.no_grad():
+ return self._model(X)
+
+
+def model_exchanger_main():
+ with task.start(ctx=OnlineRLContext()):
+ set_pkg_seed(0, use_cuda=False)
+ policy = MockPolicy()
+ X = torch.rand(10)
+ y = torch.rand(10)
+
+ if task.router.node_id == 0:
+ task.add_role(task.role.LEARNER)
+ else:
+ task.add_role(task.role.COLLECTOR)
+
+ task.use(ModelExchanger(policy._model))
+
+ if task.has_role(task.role.LEARNER):
+
+ def train(ctx):
+ policy.train(X, y)
+ sleep(0.3)
+
+ task.use(train)
+ else:
+ y_pred1 = policy.predict(X)
+
+ def pred(ctx):
+ if ctx.total_step > 0:
+ y_pred2 = policy.predict(X)
+ # Ensure model is upgraded
+ assert any(y_pred1 != y_pred2)
+ sleep(0.3)
+
+ task.use(pred)
+
+ task.run(2)
+
+
+@pytest.mark.tmp
+def test_model_exchanger():
+ Parallel.runner(n_parallel_workers=2, startup_interval=0)(model_exchanger_main)
+
+
+def model_exchanger_main_with_model_loader():
+ with task.start(ctx=OnlineRLContext()):
+ set_pkg_seed(0, use_cuda=False)
+ policy = MockPolicy()
+ X = torch.rand(10)
+ y = torch.rand(10)
+
+ if task.router.node_id == 0:
+ task.add_role(task.role.LEARNER)
+ else:
+ task.add_role(task.role.COLLECTOR)
+
+ tempdir = path.join(tempfile.gettempdir(), "test_model_loader")
+ model_loader = FileModelLoader(policy._model, dirname=tempdir)
+ task.use(ModelExchanger(policy._model, model_loader=model_loader))
+
+ try:
+ if task.has_role(task.role.LEARNER):
+
+ def train(ctx):
+ policy.train(X, y)
+ sleep(0.3)
+
+ task.use(train)
+ else:
+ y_pred1 = policy.predict(X)
+
+ def pred(ctx):
+ if ctx.total_step > 0:
+ y_pred2 = policy.predict(X)
+ # Ensure model is upgraded
+ assert any(y_pred1 != y_pred2)
+ sleep(0.3)
+
+ task.use(pred)
+ task.run(2)
+ finally:
+ model_loader.shutdown()
+ sleep(0.3)
+ if path.exists(tempdir):
+ shutil.rmtree(tempdir)
+
+
+@pytest.mark.tmp
+def test_model_exchanger_with_model_loader():
+ Parallel.runner(n_parallel_workers=2, startup_interval=0)(model_exchanger_main_with_model_loader)
+
+
+def periodical_model_exchanger_main():
+ with task.start(ctx=OnlineRLContext()):
+ set_pkg_seed(0, use_cuda=False)
+ policy = MockPolicy()
+ X = torch.rand(10)
+ y = torch.rand(10)
+
+ if task.router.node_id == 0:
+ task.add_role(task.role.LEARNER)
+ task.use(PeriodicalModelExchanger(policy._model, mode="send", period=3))
+ else:
+ task.add_role(task.role.COLLECTOR)
+ task.use(PeriodicalModelExchanger(policy._model, mode="receive", period=1, stale_toleration=3))
+
+ if task.has_role(task.role.LEARNER):
+
+ def train(ctx):
+ policy.train(X, y)
+ sleep(0.3)
+
+ task.use(train)
+ else:
+ y_pred1 = policy.predict(X)
+ print("y_pred1: ", y_pred1)
+ stale = 1
+
+ def pred(ctx):
+ nonlocal stale
+ y_pred2 = policy.predict(X)
+ print("y_pred2: ", y_pred2)
+ stale += 1
+ assert stale <= 3 or all(y_pred1 == y_pred2)
+ if any(y_pred1 != y_pred2):
+ stale = 1
+
+ sleep(0.3)
+
+ task.use(pred)
+ task.run(8)
+
+
+@pytest.mark.tmp
+def test_periodical_model_exchanger():
+ Parallel.runner(n_parallel_workers=2, startup_interval=0)(periodical_model_exchanger_main)
diff --git a/DI-engine/ding/framework/middleware/tests/test_enhancer.py b/DI-engine/ding/framework/middleware/tests/test_enhancer.py
new file mode 100644
index 0000000000000000000000000000000000000000..10d34b264f753c7e1a3b37de61a0ea8fd6b26697
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/tests/test_enhancer.py
@@ -0,0 +1,65 @@
+import pytest
+import torch
+from ding.framework import OnlineRLContext
+from ding.data.buffer import DequeBuffer
+from typing import Any
+import numpy as np
+import copy
+from ding.framework.middleware.functional.enhancer import reward_estimator, her_data_enhancer
+from unittest.mock import Mock, patch
+from ding.framework.middleware.tests import MockHerRewardModel, CONFIG
+
+DATA = [{'obs': torch.rand(2, 2), 'next_obs': torch.rand(2, 2)} for _ in range(20)]
+
+
+class MockRewardModel(Mock):
+
+ def estimate(self, data: list) -> Any:
+ assert len(data) == len(DATA)
+ assert torch.equal(data[0]['obs'], DATA[0]['obs'])
+
+
+@pytest.mark.unittest
+def test_reward_estimator():
+ ctx = OnlineRLContext()
+ ctx.train_data = copy.deepcopy(DATA)
+ with patch("ding.reward_model.HerRewardModel", MockHerRewardModel):
+ reward_estimator(cfg=None, reward_model=MockRewardModel())(ctx)
+
+
+@pytest.mark.unittest
+def test_her_data_enhancer():
+ cfg = copy.deepcopy(CONFIG)
+ ctx = OnlineRLContext()
+
+ with patch("ding.reward_model.HerRewardModel", MockHerRewardModel):
+ mock_her_reward_model = MockHerRewardModel()
+ buffer = DequeBuffer(mock_her_reward_model.episode_size)
+
+ train_data = [
+ [
+ {
+ 'action': torch.randint(low=0, high=5, size=(1, )),
+ 'collect_train_iter': torch.tensor([0]),
+ 'done': torch.tensor(False),
+ 'next_obs': torch.randint(low=0, high=2, size=(10, ), dtype=torch.float32),
+ 'obs': torch.randint(low=0, high=2, size=(10, ), dtype=torch.float32),
+ 'reward': torch.randint(low=0, high=2, size=(1, ), dtype=torch.float32),
+ } for _ in range(np.random.choice([1, 4, 5], size=1)[0])
+ ] for _ in range(mock_her_reward_model.episode_size)
+ ]
+
+ for d in train_data:
+ buffer.push(d)
+
+ her_data_enhancer(cfg=cfg, buffer_=buffer, her_reward_model=MockHerRewardModel())(ctx)
+ assert len(ctx.train_data) == mock_her_reward_model.episode_size * mock_her_reward_model.episode_element_size
+ assert len(ctx.train_data[0]) == 6
+
+ buffer = DequeBuffer(cfg.policy.learn.batch_size)
+ for d in train_data:
+ buffer.push(d)
+ mock_her_reward_model.episode_size = None
+ her_data_enhancer(cfg=cfg, buffer_=buffer, her_reward_model=MockHerRewardModel())(ctx)
+ assert len(ctx.train_data) == cfg.policy.learn.batch_size * mock_her_reward_model.episode_element_size
+ assert len(ctx.train_data[0]) == 6
diff --git a/DI-engine/ding/framework/middleware/tests/test_evaluator.py b/DI-engine/ding/framework/middleware/tests/test_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e78b150bc9cfdd40efab6fa228e9cfe6144cc3e
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/tests/test_evaluator.py
@@ -0,0 +1,27 @@
+import pytest
+import torch
+import copy
+from unittest.mock import patch
+from ding.framework import OnlineRLContext, task
+from ding.framework.middleware import interaction_evaluator
+from ding.framework.middleware.tests import MockPolicy, MockEnv, CONFIG
+
+
+@pytest.mark.unittest
+def test_interaction_evaluator():
+ cfg = copy.deepcopy(CONFIG)
+ ctx = OnlineRLContext()
+ with patch("ding.policy.Policy", MockPolicy), patch("ding.envs.BaseEnvManagerV2", MockEnv):
+ with task.start():
+ policy = MockPolicy()
+ env = MockEnv()
+ for i in range(30):
+ ctx.train_iter += 1
+ interaction_evaluator(cfg, policy, env)(ctx)
+ # interaction_evaluator will run every 10 train_iter in the test
+ assert ctx.last_eval_iter == i // 10 * 10 + 1
+ # the reward will increase 1.0 each step.
+ # there are 2 env_num and 5 episodes in the test.
+ # so when interaction_evaluator runs the first time, reward is [[1, 2, 3], [2, 3]] and the avg = 2.2
+ # the second time, reward is [[4, 5, 6], [5, 6]] . . .
+ assert ctx.eval_value == 2.2 + i // 10 * 3.0
diff --git a/DI-engine/ding/framework/middleware/tests/test_explorer.py b/DI-engine/ding/framework/middleware/tests/test_explorer.py
new file mode 100644
index 0000000000000000000000000000000000000000..89eb3b22f6c4fc1d48bce703d4f48f7b1aba7870
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/tests/test_explorer.py
@@ -0,0 +1,27 @@
+import pytest
+import copy
+from ding.framework import OnlineRLContext
+from ding.framework.middleware import eps_greedy_handler, eps_greedy_masker
+from ding.framework.middleware.tests import MockPolicy, MockEnv, CONFIG
+
+
+@pytest.mark.unittest
+def test_eps_greedy_handler():
+ cfg = copy.deepcopy(CONFIG)
+ ctx = OnlineRLContext()
+
+ ctx.env_step = 0
+ next(eps_greedy_handler(cfg)(ctx))
+ assert ctx.collect_kwargs['eps'] == 0.95
+
+ ctx.env_step = 1000000
+ next(eps_greedy_handler(cfg)(ctx))
+ assert ctx.collect_kwargs['eps'] == 0.1
+
+
+@pytest.mark.unittest
+def test_eps_greedy_masker():
+ ctx = OnlineRLContext()
+ for _ in range(10):
+ eps_greedy_masker()(ctx)
+ assert ctx.collect_kwargs['eps'] == -1
diff --git a/DI-engine/ding/framework/middleware/tests/test_logger.py b/DI-engine/ding/framework/middleware/tests/test_logger.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c742a1c2d8c52375ed1764d2ac52d3ca4a2e283
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/tests/test_logger.py
@@ -0,0 +1,295 @@
+from os import path
+import os
+import copy
+from easydict import EasyDict
+from collections import deque
+import pytest
+import shutil
+import wandb
+import h5py
+import torch.nn as nn
+from unittest.mock import MagicMock
+from unittest.mock import Mock, patch
+
+from ding.utils import DistributedWriter
+from ding.framework.middleware.tests import MockPolicy, CONFIG
+from ding.framework import OnlineRLContext, OfflineRLContext
+from ding.framework.middleware.functional import online_logger, offline_logger, wandb_online_logger, \
+ wandb_offline_logger
+
+test_folder = "test_exp"
+test_path = path.join(os.getcwd(), test_folder)
+cfg = EasyDict({"exp_name": "test_exp"})
+
+
+def get_online_ctx():
+ ctx = OnlineRLContext()
+ ctx.eval_value = -10000
+ ctx.train_iter = 34
+ ctx.env_step = 78
+ ctx.train_output = {'priority': [107], '[histogram]test_histogram': [1, 2, 3, 4, 5, 6], 'td_error': 15}
+ return ctx
+
+
+@pytest.fixture(scope='function')
+def online_ctx_output_dict():
+ ctx = get_online_ctx()
+ return ctx
+
+
+@pytest.fixture(scope='function')
+def online_ctx_output_deque():
+ ctx = get_online_ctx()
+ ctx.train_output = deque([ctx.train_output])
+ return ctx
+
+
+@pytest.fixture(scope='function')
+def online_ctx_output_list():
+ ctx = get_online_ctx()
+ ctx.train_output = [ctx.train_output]
+ return ctx
+
+
+@pytest.fixture(scope='function')
+def online_scalar_ctx():
+ ctx = get_online_ctx()
+ ctx.train_output = {'[scalars]': 1}
+ return ctx
+
+
+class MockOnlineWriter:
+
+ def __init__(self):
+ self.ctx = get_online_ctx()
+
+ def add_scalar(self, tag, scalar_value, global_step):
+ if tag in ['basic/eval_episode_return_mean-env_step', 'basic/eval_episode_return_mean']:
+ assert scalar_value == self.ctx.eval_value
+ assert global_step == self.ctx.env_step
+ elif tag == 'basic/eval_episode_return_mean-train_iter':
+ assert scalar_value == self.ctx.eval_value
+ assert global_step == self.ctx.train_iter
+ elif tag in ['basic/train_td_error-env_step', 'basic/train_td_error']:
+ assert scalar_value == self.ctx.train_output['td_error']
+ assert global_step == self.ctx.env_step
+ elif tag == 'basic/train_td_error-train_iter':
+ assert scalar_value == self.ctx.train_output['td_error']
+ assert global_step == self.ctx.train_iter
+ else:
+ raise NotImplementedError('tag should be in the tags defined')
+
+ def add_histogram(self, tag, values, global_step):
+ assert tag == 'test_histogram'
+ assert values == [1, 2, 3, 4, 5, 6]
+ assert global_step in [self.ctx.train_iter, self.ctx.env_step]
+
+ def close(self):
+ pass
+
+
+def mock_get_online_instance():
+ return MockOnlineWriter()
+
+
+@pytest.mark.unittest
+class TestOnlineLogger:
+
+ def test_online_logger_output_dict(self, online_ctx_output_dict):
+ with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance):
+ online_logger()(online_ctx_output_dict)
+
+ def test_online_logger_record_output_dict(self, online_ctx_output_dict):
+ with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance):
+ online_logger(record_train_iter=True)(online_ctx_output_dict)
+
+ def test_online_logger_record_output_deque(self, online_ctx_output_deque):
+ with patch.object(DistributedWriter, 'get_instance', new=mock_get_online_instance):
+ online_logger()(online_ctx_output_deque)
+
+
+def get_offline_ctx():
+ ctx = OfflineRLContext()
+ ctx.eval_value = -10000000000
+ ctx.train_iter = 3333
+ ctx.train_output = {'priority': [107], '[histogram]test_histogram': [1, 2, 3, 4, 5, 6], 'td_error': 15}
+ return ctx
+
+
+@pytest.fixture(scope='function')
+def offline_ctx_output_dict():
+ ctx = get_offline_ctx()
+ return ctx
+
+
+@pytest.fixture(scope='function')
+def offline_scalar_ctx():
+ ctx = get_offline_ctx()
+ ctx.train_output = {'[scalars]': 1}
+ return ctx
+
+
+class MockOfflineWriter:
+
+ def __init__(self):
+ self.ctx = get_offline_ctx()
+
+ def add_scalar(self, tag, scalar_value, global_step):
+ assert global_step == self.ctx.train_iter
+ if tag == 'basic/eval_episode_return_mean-train_iter':
+ assert scalar_value == self.ctx.eval_value
+ elif tag == 'basic/train_td_error-train_iter':
+ assert scalar_value == self.ctx.train_output['td_error']
+ else:
+ raise NotImplementedError('tag should be in the tags defined')
+
+ def add_histogram(self, tag, values, global_step):
+ assert tag == 'test_histogram'
+ assert values == [1, 2, 3, 4, 5, 6]
+ assert global_step == self.ctx.train_iter
+
+ def close(self):
+ pass
+
+
+def mock_get_offline_instance():
+ return MockOfflineWriter()
+
+
+class TestOfflineLogger:
+
+ def test_offline_logger_no_scalars(self, offline_ctx_output_dict):
+ with patch.object(DistributedWriter, 'get_instance', new=mock_get_offline_instance):
+ offline_logger()(offline_ctx_output_dict)
+
+ def test_offline_logger_scalars(self, offline_scalar_ctx):
+ with patch.object(DistributedWriter, 'get_instance', new=mock_get_offline_instance):
+ with pytest.raises(NotImplementedError) as exc_info:
+ offline_logger()(offline_scalar_ctx)
+
+
+class TheModelClass(nn.Module):
+
+ def state_dict(self):
+ return 'fake_state_dict'
+
+
+class TheEnvClass(Mock):
+
+ def enable_save_replay(self, replay_path):
+ return
+
+
+class TheObsDataClass(Mock):
+
+ def __getitem__(self, index):
+ return [[1, 1, 1]] * 50
+
+
+class The1DDataClass(Mock):
+
+ def __getitem__(self, index):
+ return [[1]] * 50
+
+
+@pytest.mark.unittest
+def test_wandb_online_logger():
+ record_path = './video_qbert_dqn'
+ cfg = EasyDict(
+ dict(
+ gradient_logger=True,
+ plot_logger=True,
+ action_logger=True,
+ return_logger=True,
+ video_logger=True,
+ )
+ )
+ env = TheEnvClass()
+ ctx = OnlineRLContext()
+ ctx.train_output = [{'reward': 1, 'q_value': [1.0]}]
+ model = TheModelClass()
+ wandb.init(config=cfg, anonymous="must")
+
+ def mock_metric_logger(data, step):
+ metric_list = [
+ "q_value",
+ "target q_value",
+ "loss",
+ "lr",
+ "entropy",
+ "reward",
+ "q value",
+ "video",
+ "q value distribution",
+ "train iter",
+ "episode return mean",
+ "env step",
+ "action",
+ "actions_of_trajectory_0",
+ "actions_of_trajectory_1",
+ "actions_of_trajectory_2",
+ "actions_of_trajectory_3",
+ "return distribution",
+ ]
+ assert set(data.keys()) <= set(metric_list)
+
+ def mock_gradient_logger(input_model, log, log_freq, log_graph):
+ assert input_model == model
+
+ def test_wandb_online_logger_metric():
+ with patch.object(wandb, 'log', new=mock_metric_logger):
+ wandb_online_logger(record_path, cfg, env=env, model=model, anonymous=True)(ctx)
+
+ def test_wandb_online_logger_gradient():
+ with patch.object(wandb, 'watch', new=mock_gradient_logger):
+ wandb_online_logger(record_path, cfg, env=env, model=model, anonymous=True)(ctx)
+
+ test_wandb_online_logger_metric()
+ test_wandb_online_logger_gradient()
+
+
+@pytest.mark.tmp
+def test_wandb_offline_logger():
+ record_path = './video_pendulum_cql'
+ cfg = EasyDict(dict(gradient_logger=True, plot_logger=True, action_logger=True, vis_dataset=True))
+ env = TheEnvClass()
+ ctx = OfflineRLContext()
+ ctx.train_output = [{'reward': 1, 'q_value': [1.0]}]
+ model = TheModelClass()
+ wandb.init(config=cfg, anonymous="must")
+ exp_config = EasyDict(dict(dataset_path='dataset.h5'))
+
+ def mock_metric_logger(data, step=None):
+ metric_list = [
+ "q_value", "target q_value", "loss", "lr", "entropy", "reward", "q value", "video", "q value distribution",
+ "train iter", 'dataset'
+ ]
+ assert set(data.keys()) < set(metric_list)
+
+ def mock_gradient_logger(input_model, log, log_freq, log_graph):
+ assert input_model == model
+
+ def mock_image_logger(imagepath):
+ assert os.path.splitext(imagepath)[-1] == '.png'
+
+ def test_wandb_offline_logger_gradient():
+ cfg.vis_dataset = False
+ print(cfg)
+ with patch.object(wandb, 'watch', new=mock_gradient_logger):
+ wandb_offline_logger(
+ record_path=record_path, cfg=cfg, exp_config=exp_config, env=env, model=model, anonymous=True
+ )(ctx)
+
+ def test_wandb_offline_logger_dataset():
+ cfg.vis_dataset = True
+ m = MagicMock()
+ m.__enter__.return_value = {'obs': TheObsDataClass(), 'action': The1DDataClass(), 'reward': The1DDataClass()}
+ with patch.object(wandb, 'log', new=mock_metric_logger):
+ with patch.object(wandb, 'Image', new=mock_image_logger):
+ with patch('h5py.File', return_value=m):
+ wandb_offline_logger(
+ record_path=record_path, cfg=cfg, exp_config=exp_config, env=env, model=model, anonymous=True
+ )(ctx)
+
+ test_wandb_offline_logger_gradient()
+ test_wandb_offline_logger_dataset()
diff --git a/DI-engine/ding/framework/middleware/tests/test_priority.py b/DI-engine/ding/framework/middleware/tests/test_priority.py
new file mode 100644
index 0000000000000000000000000000000000000000..19261213d6e3ea02350fd92059549ea3d6812bc3
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/tests/test_priority.py
@@ -0,0 +1,33 @@
+# Unit test for priority_calculator
+
+import unittest
+import pytest
+import numpy as np
+from unittest.mock import Mock, patch
+from ding.framework import OnlineRLContext, OfflineRLContext
+from ding.framework import task, Parallel
+from ding.framework.middleware.functional import priority_calculator
+
+
+class MockPolicy(Mock):
+
+ def priority_fun(self, data):
+ return np.random.rand(len(data))
+
+
+@pytest.mark.unittest
+def test_priority_calculator():
+ policy = MockPolicy()
+ ctx = OnlineRLContext()
+ ctx.trajectories = [
+ {
+ 'obs': np.random.rand(2, 2),
+ 'next_obs': np.random.rand(2, 2),
+ 'reward': np.random.rand(1),
+ 'info': {}
+ } for _ in range(10)
+ ]
+ priority_calculator_middleware = priority_calculator(priority_calculation_fn=policy.priority_fun)
+ priority_calculator_middleware(ctx)
+ assert len(ctx.trajectories) == 10
+ assert all([isinstance(traj['priority'], float) for traj in ctx.trajectories])
diff --git a/DI-engine/ding/framework/middleware/tests/test_trainer.py b/DI-engine/ding/framework/middleware/tests/test_trainer.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9dbf9f55cc20397e1f62ce8156888dc1223b9ad
--- /dev/null
+++ b/DI-engine/ding/framework/middleware/tests/test_trainer.py
@@ -0,0 +1,116 @@
+import pytest
+import random
+import copy
+import torch
+import treetensor.torch as ttorch
+from unittest.mock import Mock, patch
+from ding.data.buffer import DequeBuffer
+from ding.framework import OnlineRLContext, task
+from ding.framework.middleware import trainer, multistep_trainer, OffPolicyLearner, HERLearner
+from ding.framework.middleware.tests import MockHerRewardModel, CONFIG
+
+
+class MockPolicy(Mock):
+ _device = 'cpu'
+
+ # MockPolicy class for train mode
+ def forward(self, train_data, **kwargs):
+ res = {
+ 'total_loss': 0.1,
+ }
+ return res
+
+
+class MultiStepMockPolicy(Mock):
+ _device = 'cpu'
+
+ # MockPolicy class for multi-step train mode
+ def forward(self, train_data, **kwargs):
+ res = [
+ {
+ 'total_loss': 0.1,
+ },
+ {
+ 'total_loss': 1.0,
+ },
+ ]
+ return res
+
+
+def get_mock_train_input():
+ data = {'obs': torch.rand(2, 2), 'next_obs': torch.rand(2, 2), 'reward': random.random(), 'info': {}}
+ return ttorch.as_tensor(data)
+
+
+@pytest.mark.unittest
+def test_trainer():
+ cfg = copy.deepcopy(CONFIG)
+ ctx = OnlineRLContext()
+
+ ctx.train_data = None
+ with patch("ding.policy.Policy", MockPolicy):
+ policy = MockPolicy()
+ for _ in range(10):
+ trainer(cfg, policy)(ctx)
+ assert ctx.train_iter == 0
+
+ ctx.train_data = get_mock_train_input()
+ with patch("ding.policy.Policy", MockPolicy):
+ policy = MockPolicy()
+ for _ in range(30):
+ trainer(cfg, policy)(ctx)
+ assert ctx.train_iter == 30
+ assert ctx.train_output["total_loss"] == 0.1
+
+
+@pytest.mark.unittest
+def test_multistep_trainer():
+ cfg = copy.deepcopy(CONFIG)
+ ctx = OnlineRLContext()
+
+ ctx.train_data = None
+ with patch("ding.policy.Policy", MockPolicy):
+ policy = MockPolicy()
+ for _ in range(10):
+ trainer(cfg, policy)(ctx)
+ assert ctx.train_iter == 0
+
+ ctx.train_data = get_mock_train_input()
+ with patch("ding.policy.Policy", MultiStepMockPolicy):
+ policy = MultiStepMockPolicy()
+ for _ in range(30):
+ multistep_trainer(policy, 10)(ctx)
+ assert ctx.train_iter == 60
+ assert ctx.train_output[0]["total_loss"] == 0.1
+ assert ctx.train_output[1]["total_loss"] == 1.0
+
+
+@pytest.mark.unittest
+def test_offpolicy_learner():
+ cfg = copy.deepcopy(CONFIG)
+ ctx = OnlineRLContext()
+ buffer = DequeBuffer(size=10)
+ for _ in range(10):
+ buffer.push(get_mock_train_input())
+ with patch("ding.policy.Policy", MockPolicy):
+ with task.start():
+ policy = MockPolicy()
+ learner = OffPolicyLearner(cfg, policy, buffer)
+ learner(ctx)
+ assert len(ctx.train_output) == 4
+
+
+@pytest.mark.unittest
+def test_her_learner():
+ cfg = copy.deepcopy(CONFIG)
+ ctx = OnlineRLContext()
+ buffer = DequeBuffer(size=10)
+ for _ in range(10):
+ buffer.push([get_mock_train_input(), get_mock_train_input()])
+ with patch("ding.policy.Policy", MockPolicy), patch("ding.reward_model.HerRewardModel", MockHerRewardModel):
+ with task.start():
+ policy = MockPolicy()
+ her_reward_model = MockHerRewardModel()
+ learner = HERLearner(cfg, policy, buffer, her_reward_model)
+ learner(ctx)
+ assert len(ctx.train_output) == 4
diff --git a/DI-engine/ding/framework/parallel.py b/DI-engine/ding/framework/parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..df7b430a8f111fba11b9fa1a80d49e545d889993
--- /dev/null
+++ b/DI-engine/ding/framework/parallel.py
@@ -0,0 +1,412 @@
+import atexit
+import os
+import random
+import time
+import traceback
+import pickle
+from mpire.pool import WorkerPool
+from ditk import logging
+import tempfile
+import socket
+from os import path
+from typing import Callable, Dict, List, Optional, Tuple, Union, Set
+from threading import Thread
+from ding.framework.event_loop import EventLoop
+from ding.utils.design_helper import SingletonMetaclass
+from ding.framework.message_queue import *
+from ding.utils.registry_factory import MQ_REGISTRY
+
+# Avoid ipc address conflicts: this dedicated Random instance always uses a random seed
+random = random.Random()
+
+
+class Parallel(metaclass=SingletonMetaclass):
+
+ def __init__(self) -> None:
+ # Init will only be called once in a process
+ self._listener = None
+ self.is_active = False
+ self.node_id = None
+ self.local_id = None
+ self.labels = set()
+ self._event_loop = EventLoop("parallel_{}".format(id(self)))
+ self._retries = 0 # Retries in auto recovery
+
+ def _run(
+ self,
+ node_id: int,
+ local_id: int,
+ n_parallel_workers: int,
+ labels: Optional[Set[str]] = None,
+ auto_recover: bool = False,
+ max_retries: int = float("inf"),
+ mq_type: str = "nng",
+ startup_interval: int = 1,
+ **kwargs
+ ) -> None:
+ self.node_id = node_id
+ self.local_id = local_id
+ self.startup_interval = startup_interval
+ self.n_parallel_workers = n_parallel_workers
+ self.labels = labels or set()
+ self.auto_recover = auto_recover
+ self.max_retries = max_retries
+ self._mq = MQ_REGISTRY.get(mq_type)(**kwargs)
+ time.sleep(self.local_id * self.startup_interval)
+ self._listener = Thread(target=self.listen, name="mq_listener", daemon=True)
+ self._listener.start()
+
+ self.mq_type = mq_type
+ self.barrier_runtime = Parallel.get_barrier_runtime()(self.node_id)
+
+ @classmethod
+ def runner(
+ cls,
+ n_parallel_workers: int,
+ mq_type: str = "nng",
+ attach_to: Optional[List[str]] = None,
+ protocol: str = "ipc",
+ address: Optional[str] = None,
+ ports: Optional[Union[List[int], int]] = None,
+ topology: str = "mesh",
+ labels: Optional[Set[str]] = None,
+ node_ids: Optional[Union[List[int], int]] = None,
+ auto_recover: bool = False,
+ max_retries: int = float("inf"),
+ redis_host: Optional[str] = None,
+ redis_port: Optional[int] = None,
+ startup_interval: int = 1
+ ) -> Callable:
+ """
+ Overview:
+            Configure the parallel parameters. This method runs in the parent process, before any workers are spawned.
+ Arguments:
+ - n_parallel_workers (:obj:`int`): Workers to spawn.
+ - mq_type (:obj:`str`): Embedded message queue type, i.e. nng, redis.
+ - attach_to (:obj:`Optional[List[str]]`): The node's addresses you want to attach to.
+ - protocol (:obj:`str`): Network protocol.
+ - address (:obj:`Optional[str]`): Bind address, ip or file path.
+ - ports (:obj:`Optional[List[int]]`): Candidate ports.
+ - topology (:obj:`str`): Network topology, includes:
+ `mesh` (default): fully connected between each other;
+ `star`: only connect to the first node;
+ `alone`: do not connect to any node, except the node attached to;
+ - labels (:obj:`Optional[Set[str]]`): Labels.
+ - node_ids (:obj:`Optional[List[int]]`): Candidate node ids.
+ - auto_recover (:obj:`bool`): Auto recover from uncaught exceptions from main.
+ - max_retries (:obj:`int`): Max retries for auto recover.
+ - redis_host (:obj:`str`): Redis server host.
+ - redis_port (:obj:`int`): Redis server port.
+ - startup_interval (:obj:`int`): Start up interval between each task.
+ Returns:
+ - _runner (:obj:`Callable`): The wrapper function for main.
+ """
+ all_args = locals()
+ del all_args["cls"]
+ args_parsers = {"nng": cls._nng_args_parser, "redis": cls._redis_args_parser}
+
+        assert n_parallel_workers > 0, "Parallel worker number should be bigger than 0"
+
+ def _runner(main_process: Callable, *args, **kwargs) -> None:
+ """
+ Overview:
+ Prepare to run in subprocess.
+ Arguments:
+ - main_process (:obj:`Callable`): The main function, your program start from here.
+ """
+ runner_params = args_parsers[mq_type](**all_args)
+ params_group = []
+ for i, runner_kwargs in enumerate(runner_params):
+ runner_kwargs["local_id"] = i
+ params_group.append([runner_kwargs, (main_process, args, kwargs)])
+
+ if n_parallel_workers == 1:
+ cls._subprocess_runner(*params_group[0])
+ else:
+ with WorkerPool(n_jobs=n_parallel_workers, start_method="spawn", daemon=False) as pool:
+ # Cleanup the pool just in case the program crashes.
+ atexit.register(pool.__exit__)
+ pool.map(cls._subprocess_runner, params_group)
+
+ return _runner
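+
+    # Illustrative usage (a minimal sketch, not an official example): wrap the program
+    # entry point with ``Parallel.runner`` so that each spawned worker runs ``main``
+    # with its own router instance, e.g.
+    #
+    #     def main():
+    #         router = Parallel()
+    #         print("Running on node", router.node_id)
+    #
+    #     if __name__ == "__main__":
+    #         Parallel.runner(n_parallel_workers=2, protocol="ipc", topology="mesh")(main)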
+
+ @classmethod
+ def _nng_args_parser(
+ cls,
+ n_parallel_workers: int,
+ attach_to: Optional[List[str]] = None,
+ protocol: str = "ipc",
+ address: Optional[str] = None,
+ ports: Optional[Union[List[int], int]] = None,
+ topology: str = "mesh",
+ node_ids: Optional[Union[List[int], int]] = None,
+ **kwargs
+ ) -> Dict[str, dict]:
+ attach_to = attach_to or []
+ nodes = cls.get_node_addrs(n_parallel_workers, protocol=protocol, address=address, ports=ports)
+
+ def cleanup_nodes():
+ for node in nodes:
+ protocol, file_path = node.split("://")
+ if protocol == "ipc" and path.exists(file_path):
+ os.remove(file_path)
+
+ atexit.register(cleanup_nodes)
+
+ def topology_network(i: int) -> List[str]:
+ if topology == "mesh":
+ return nodes[:i] + attach_to
+ elif topology == "star":
+ return nodes[:min(1, i)] + attach_to
+ elif topology == "alone":
+ return attach_to
+ else:
+ raise ValueError("Unknown topology: {}".format(topology))
+
+ runner_params = []
+ candidate_node_ids = cls.padding_param(node_ids, n_parallel_workers, 0)
+ for i in range(n_parallel_workers):
+ runner_kwargs = {
+ **kwargs,
+ "node_id": candidate_node_ids[i],
+ "listen_to": nodes[i],
+ "attach_to": topology_network(i),
+ "n_parallel_workers": n_parallel_workers,
+ }
+ runner_params.append(runner_kwargs)
+
+ return runner_params
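+
+    # For reference, with n_parallel_workers=3 and no extra ``attach_to`` addresses, the
+    # ``topology`` option roughly yields the following attach lists (a sketch of
+    # ``topology_network`` above):
+    #   - mesh:  node0 -> [], node1 -> [node0], node2 -> [node0, node1]
+    #   - star:  node0 -> [], node1 -> [node0], node2 -> [node0]
+    #   - alone: every node -> []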
+
+ @classmethod
+ def _redis_args_parser(cls, n_parallel_workers: int, node_ids: Optional[Union[List[int], int]] = None, **kwargs):
+ runner_params = []
+ candidate_node_ids = cls.padding_param(node_ids, n_parallel_workers, 0)
+ for i in range(n_parallel_workers):
+ runner_kwargs = {**kwargs, "n_parallel_workers": n_parallel_workers, "node_id": candidate_node_ids[i]}
+ runner_params.append(runner_kwargs)
+ return runner_params
+
+ @classmethod
+ def _subprocess_runner(cls, runner_kwargs: dict, main_params: Tuple[Union[List, Dict]]) -> None:
+ """
+ Overview:
+            The actual routine executed in the subprocess.
+        Arguments:
+            - runner_kwargs (:obj:`dict`): Kwargs for the parallel runner (forwarded to ``_run``).
+ - main_params (:obj:`Tuple[Union[List, Dict]]`): Args and kwargs for main function.
+ """
+ logging.getLogger().setLevel(logging.INFO)
+ main_process, args, kwargs = main_params
+
+ with Parallel() as router:
+ router.is_active = True
+ router._run(**runner_kwargs)
+ time.sleep(0.3) # Waiting for network pairing
+ router._supervised_runner(main_process, *args, **kwargs)
+
+ def _supervised_runner(self, main: Callable, *args, **kwargs) -> None:
+ """
+ Overview:
+ Run in supervised mode.
+ Arguments:
+ - main (:obj:`Callable`): Main function.
+ """
+ if self.auto_recover:
+ while True:
+ try:
+ main(*args, **kwargs)
+ break
+ except Exception as e:
+ if self._retries < self.max_retries:
+ logging.warning(
+ "Auto recover from exception: {}, node: {}, retries: {}".format(
+ e, self.node_id, self._retries
+ )
+ )
+ logging.warning(traceback.format_exc())
+ self._retries += 1
+ else:
+ logging.warning(
+ "Exceed the max retries, node: {}, retries: {}, max_retries: {}".format(
+ self.node_id, self._retries, self.max_retries
+ )
+ )
+ raise e
+ else:
+ main(*args, **kwargs)
+
+ @classmethod
+ def get_node_addrs(
+ cls,
+ n_workers: int,
+ protocol: str = "ipc",
+ address: Optional[str] = None,
+ ports: Optional[Union[List[int], int]] = None
+ ) -> None:
+ if protocol == "ipc":
+ node_name = "".join(random.choices("abcdefghijklmnopqrstuvwxyz0123456789", k=4))
+ tmp_dir = tempfile.gettempdir()
+ nodes = ["ipc://{}/ditask_{}_{}.ipc".format(tmp_dir, node_name, i) for i in range(n_workers)]
+ elif protocol == "tcp":
+ address = address or cls.get_ip()
+ ports = cls.padding_param(ports, n_workers, 50515)
+ assert len(ports) == n_workers, "The number of ports must be the same as the number of workers, \
+now there are {} ports and {} workers".format(len(ports), n_workers)
+ nodes = ["tcp://{}:{}".format(address, port) for port in ports]
+ else:
+ raise Exception("Unknown protocol {}".format(protocol))
+ return nodes
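+
+    # Examples of generated addresses (illustrative values only):
+    #   protocol="ipc", n_workers=2 -> ["ipc:///tmp/ditask_ab12_0.ipc", "ipc:///tmp/ditask_ab12_1.ipc"]
+    #   protocol="tcp", address="127.0.0.1", ports=50515, n_workers=2
+    #       -> ["tcp://127.0.0.1:50515", "tcp://127.0.0.1:50516"]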
+
+ @classmethod
+ def padding_param(cls, int_or_list: Optional[Union[List[int], int]], n_max: int, start_value: int) -> List[int]:
+ """
+ Overview:
+            Pad an int or list parameter to a sequence of length n_max.
+ Arguments:
+ - int_or_list (:obj:`Optional[Union[List[int], int]]`): Int or list typed value.
+ - n_max (:obj:`int`): Max length.
+ - start_value (:obj:`int`): Start from value.
+ """
+ param = int_or_list
+ if isinstance(param, List) and len(param) == 1:
+ param = param[0] # List with only 1 element is equal to int
+
+ if isinstance(param, int):
+ param = range(param, param + n_max)
+ else:
+ param = param or range(start_value, start_value + n_max)
+ return param
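+
+    # Behaviour sketch (illustrative values):
+    #   padding_param(None, 3, 0)      -> range(0, 3)   # fall back to start_value
+    #   padding_param(4, 3, 0)         -> range(4, 7)   # an int is expanded to n_max values
+    #   padding_param([4], 3, 0)       -> range(4, 7)   # a single-element list behaves like an int
+    #   padding_param([4, 5, 6], 3, 0) -> [4, 5, 6]     # a full list is returned unchanged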
+
+ def listen(self):
+ self._mq.listen()
+ while True:
+ if not self._mq:
+ break
+ msg = self._mq.recv()
+            # A None msg means the message queue is no longer being listened to,
+            # e.g. because the queue has already been closed.
+ if not msg:
+ break
+ topic, msg = msg
+ self._handle_message(topic, msg)
+
+ def on(self, event: str, fn: Callable) -> None:
+ """
+ Overview:
+            Register a remote event on the parallel instance; the function will be executed \
+                when a remote process emits this event over the network.
+ Arguments:
+ - event (:obj:`str`): Event name.
+ - fn (:obj:`Callable`): Function body.
+ """
+ if self.is_active:
+ self._mq.subscribe(event)
+ self._event_loop.on(event, fn)
+
+ def once(self, event: str, fn: Callable) -> None:
+ """
+ Overview:
+            Register a remote event that will only be handled once on the parallel instance;
+            the function will be executed when a remote process emits this event over the network.
+ Arguments:
+ - event (:obj:`str`): Event name.
+ - fn (:obj:`Callable`): Function body.
+ """
+ if self.is_active:
+ self._mq.subscribe(event)
+ self._event_loop.once(event, fn)
+
+ def off(self, event: str) -> None:
+ """
+ Overview:
+ Unregister an event.
+ Arguments:
+ - event (:obj:`str`): Event name.
+ """
+ if self.is_active:
+ self._mq.unsubscribe(event)
+ self._event_loop.off(event)
+
+ def emit(self, event: str, *args, **kwargs) -> None:
+ """
+ Overview:
+            Send a remote event over the network to subscribed processes.
+ Arguments:
+ - event (:obj:`str`): Event name.
+ """
+ if self.is_active:
+ payload = {"a": args, "k": kwargs}
+ try:
+ data = pickle.dumps(payload, protocol=pickle.HIGHEST_PROTOCOL)
+ except AttributeError as e:
+                logging.error("Arguments are not picklable! Event: {}, Args: {}".format(event, args))
+ raise e
+ self._mq.publish(event, data)
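+
+    # Minimal event sketch (assuming two connected workers started via ``Parallel.runner``):
+    #
+    #     router = Parallel()
+    #     router.on("greeting", lambda msg: print("got", msg))   # subscriber side
+    #     router.emit("greeting", "hello")                       # publisher side
+    #
+    # ``emit`` pickles the positional and keyword arguments and publishes them on the
+    # message queue; only processes that registered the event via ``on``/``once`` handle it.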
+
+ def _handle_message(self, topic: str, msg: bytes) -> None:
+ """
+ Overview:
+            Receive and parse a payload from another process, then call the local handlers.
+        Arguments:
+            - topic (:obj:`str`): Received topic.
+            - msg (:obj:`bytes`): Received message.
+ """
+ event = topic
+ if not self._event_loop.listened(event):
+ logging.debug("Event {} was not listened in parallel {}".format(event, self.node_id))
+ return
+ try:
+ payload = pickle.loads(msg)
+ except Exception as e:
+ logging.error("Error when unpacking message on node {}, msg: {}".format(self.node_id, e))
+ return
+ self._event_loop.emit(event, *payload["a"], **payload["k"])
+
+ @classmethod
+ def get_ip(cls):
+ s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ try:
+ # doesn't even have to be reachable
+ s.connect(('10.255.255.255', 1))
+ ip = s.getsockname()[0]
+ except Exception:
+ ip = '127.0.0.1'
+ finally:
+ s.close()
+ return ip
+
+ def get_attch_to_len(self) -> int:
+ """
+ Overview:
+            Get the length of the 'attach_to' list of the message queue.
+        Returns:
+            int: the length of self._mq.attach_to, or 0 if self._mq is not initialized.
+ """
+ if self._mq:
+ if hasattr(self._mq, 'attach_to'):
+ return len(self._mq.attach_to)
+ return 0
+
+ def __enter__(self) -> "Parallel":
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.stop()
+
+ def stop(self):
+ logging.info("Stopping parallel worker on node: {}".format(self.node_id))
+ self.is_active = False
+ time.sleep(0.03)
+ if self._mq:
+ self._mq.stop()
+ self._mq = None
+ if self._listener:
+ self._listener.join(timeout=1)
+ self._listener = None
+ self._event_loop.stop()
+
+ @classmethod
+ def get_barrier_runtime(cls):
+ # We get the BarrierRuntime object in the closure to avoid circular import.
+ from ding.framework.middleware.barrier import BarrierRuntime
+ return BarrierRuntime
diff --git a/DI-engine/ding/framework/supervisor.py b/DI-engine/ding/framework/supervisor.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d385c12c6502d23fd3455d12d15b3370abf77d8
--- /dev/null
+++ b/DI-engine/ding/framework/supervisor.py
@@ -0,0 +1,387 @@
+from abc import ABC, abstractmethod
+import functools
+import torch.multiprocessing as mp
+from multiprocessing.context import BaseContext
+import threading
+import queue
+import platform
+import traceback
+import uuid
+import time
+from ditk import logging
+from dataclasses import dataclass, field
+from typing import Any, Callable, Dict, List, Optional, Union
+from enum import Enum
+
+
+@functools.lru_cache(maxsize=1)
+def get_mp_ctx() -> BaseContext:
+ context = 'spawn' if platform.system().lower() == 'windows' else 'fork'
+ mp_ctx = mp.get_context(context)
+ return mp_ctx
+
+
+@dataclass
+class SendPayload:
+ proc_id: int
+ # Use uuid1 here to include the timestamp
+ req_id: str = field(default_factory=lambda: uuid.uuid1().hex)
+ method: str = None
+ args: List = field(default_factory=list)
+ kwargs: Dict = field(default_factory=dict)
+
+
+@dataclass
+class RecvPayload:
+ proc_id: int
+ req_id: str = None
+ method: str = None
+ data: Any = None
+ err: Exception = None
+ extra: Any = None
+
+
+class ReserveMethod(Enum):
+ SHUTDOWN = "_shutdown"
+ GETATTR = "_getattr"
+
+
+class ChildType(Enum):
+ PROCESS = "process"
+ THREAD = "thread"
+
+
+class Child(ABC):
+ """
+ Abstract class of child process/thread.
+ """
+
+ def __init__(self, proc_id: int, init: Union[Callable, object], **kwargs) -> None:
+ self._proc_id = proc_id
+ self._init = init
+ self._recv_queue = None
+ self._send_queue = None
+
+ @abstractmethod
+ def start(self, recv_queue: Union[mp.Queue, queue.Queue]):
+ raise NotImplementedError
+
+ def restart(self):
+ self.shutdown()
+ self.start(self._recv_queue)
+
+ @abstractmethod
+ def shutdown(self, timeout: Optional[float] = None):
+ raise NotImplementedError
+
+ @abstractmethod
+ def send(self, payload: SendPayload):
+ raise NotImplementedError
+
+ def _target(
+ self,
+ proc_id: int,
+ init: Union[Callable, object],
+ send_queue: Union[mp.Queue, queue.Queue],
+ recv_queue: Union[mp.Queue, queue.Queue],
+ shm_buffer: Optional[Any] = None,
+ shm_callback: Optional[Callable] = None
+ ):
+ send_payload = SendPayload(proc_id=proc_id)
+ if isinstance(init, Callable):
+ child_ins = init()
+ else:
+ child_ins = init
+ while True:
+ try:
+ send_payload: SendPayload = send_queue.get()
+ if send_payload.method == ReserveMethod.SHUTDOWN:
+ break
+ if send_payload.method == ReserveMethod.GETATTR:
+ data = getattr(child_ins, send_payload.args[0])
+ else:
+ data = getattr(child_ins, send_payload.method)(*send_payload.args, **send_payload.kwargs)
+ recv_payload = RecvPayload(
+ proc_id=proc_id, req_id=send_payload.req_id, method=send_payload.method, data=data
+ )
+ if shm_callback is not None and shm_buffer is not None:
+ shm_callback(recv_payload, shm_buffer)
+ recv_queue.put(recv_payload)
+ except Exception as e:
+ logging.warning(traceback.format_exc())
+ logging.warning("Error in child process! id: {}, error: {}".format(self._proc_id, e))
+ recv_payload = RecvPayload(
+ proc_id=proc_id, req_id=send_payload.req_id, method=send_payload.method, err=e
+ )
+ recv_queue.put(recv_payload)
+
+ def __del__(self):
+ self.shutdown()
+
+
+class ChildProcess(Child):
+
+ def __init__(
+ self,
+ proc_id: int,
+ init: Union[Callable, object],
+ shm_buffer: Optional[Any] = None,
+ shm_callback: Optional[Callable] = None,
+ mp_ctx: Optional[BaseContext] = None,
+ **kwargs
+ ) -> None:
+ super().__init__(proc_id, init, **kwargs)
+ self._proc = None
+ self._mp_ctx = mp_ctx
+ self._shm_buffer = shm_buffer
+ self._shm_callback = shm_callback
+
+ def start(self, recv_queue: mp.Queue):
+ if self._proc is None:
+ self._recv_queue = recv_queue
+ ctx = self._mp_ctx or get_mp_ctx()
+ self._send_queue = ctx.Queue()
+ proc = ctx.Process(
+ target=self._target,
+ args=(
+ self._proc_id, self._init, self._send_queue, self._recv_queue, self._shm_buffer, self._shm_callback
+ ),
+ name="supervisor_child_{}_{}".format(self._proc_id, time.time()),
+ daemon=True
+ )
+ proc.start()
+ self._proc = proc
+
+ def shutdown(self, timeout: Optional[float] = None):
+ if self._proc:
+ self._send_queue.put(SendPayload(proc_id=self._proc_id, method=ReserveMethod.SHUTDOWN))
+ self._proc.terminate()
+ self._proc.join(timeout=timeout)
+ if hasattr(self._proc, "close"): # Compatible with 3.6
+ self._proc.close()
+ self._proc = None
+ self._send_queue.close()
+ self._send_queue.join_thread()
+ self._send_queue = None
+
+ def send(self, payload: SendPayload):
+ if self._send_queue is None:
+ logging.warning("Child worker has been terminated or not started.")
+ return
+ self._send_queue.put(payload)
+
+
+class ChildThread(Child):
+
+ def __init__(self, proc_id: int, init: Union[Callable, object], *args, **kwargs) -> None:
+ super().__init__(proc_id, init, *args, **kwargs)
+ self._thread = None
+
+ def start(self, recv_queue: queue.Queue):
+ if self._thread is None:
+ self._recv_queue = recv_queue
+ self._send_queue = queue.Queue()
+ thread = threading.Thread(
+ target=self._target,
+ args=(self._proc_id, self._init, self._send_queue, self._recv_queue),
+ name="supervisor_child_{}_{}".format(self._proc_id, time.time()),
+ daemon=True
+ )
+ thread.start()
+ self._thread = thread
+
+ def shutdown(self, timeout: Optional[float] = None):
+ if self._thread:
+ self._send_queue.put(SendPayload(proc_id=self._proc_id, method=ReserveMethod.SHUTDOWN))
+ self._thread.join(timeout=timeout)
+ self._thread = None
+ self._send_queue = None
+
+ def send(self, payload: SendPayload):
+ if self._send_queue is None:
+ logging.warning("Child worker has been terminated or not started.")
+ return
+ self._send_queue.put(payload)
+
+
+class Supervisor:
+
+ TYPE_MAPPING = {ChildType.PROCESS: ChildProcess, ChildType.THREAD: ChildThread}
+
+ def __init__(self, type_: ChildType, mp_ctx: Optional[BaseContext] = None) -> None:
+ self._children: List[Child] = []
+ self._type = type_
+ self._child_class = self.TYPE_MAPPING[self._type]
+ self._running = False
+ self.__queue = None
+ self._mp_ctx = mp_ctx or get_mp_ctx()
+
+ def register(
+ self,
+ init: Union[Callable, object],
+ shm_buffer: Optional[Any] = None,
+ shm_callback: Optional[Callable] = None
+ ) -> None:
+ proc_id = len(self._children)
+ self._children.append(
+ self._child_class(proc_id, init, shm_buffer=shm_buffer, shm_callback=shm_callback, mp_ctx=self._mp_ctx)
+ )
+
+ @property
+ def _recv_queue(self) -> Union[queue.Queue, mp.Queue]:
+ if not self.__queue:
+ if self._type is ChildType.PROCESS:
+ self.__queue = self._mp_ctx.Queue()
+ elif self._type is ChildType.THREAD:
+ self.__queue = queue.Queue()
+ return self.__queue
+
+ @_recv_queue.setter
+ def _recv_queue(self, queue: Union[queue.Queue, mp.Queue]):
+ self.__queue = queue
+
+ def start_link(self) -> None:
+ if not self._running:
+ for child in self._children:
+ child.start(recv_queue=self._recv_queue)
+ self._running = True
+
+ def send(self, payload: SendPayload) -> None:
+ """
+ Overview:
+ Send message to child process.
+ Arguments:
+ - payload (:obj:`SendPayload`): Send payload.
+ """
+ if not self._running:
+ logging.warning("Please call start_link before sending any payload to child process.")
+ return
+ self._children[payload.proc_id].send(payload)
+
+ def recv(self, ignore_err: bool = False, timeout: float = None) -> RecvPayload:
+ """
+ Overview:
+ Wait for message from child process
+ Arguments:
+            - ignore_err (:obj:`bool`): If ignore_err is True, the error is stored in recv_payload.err. \
+ Otherwise, an exception will be raised.
+ - timeout (:obj:`float`): Timeout for queue.get, will raise an Empty exception if timeout.
+ Returns:
+ - recv_payload (:obj:`RecvPayload`): Recv payload.
+ """
+ recv_payload: RecvPayload = self._recv_queue.get(timeout=timeout)
+ if recv_payload.err and not ignore_err:
+ raise recv_payload.err
+ return recv_payload
+
+ def recv_all(
+ self,
+ send_payloads: List[SendPayload],
+ ignore_err: bool = False,
+ callback: Callable = None,
+ timeout: Optional[float] = None
+ ) -> List[RecvPayload]:
+ """
+ Overview:
+ Wait for messages with specific req ids until all ids are fulfilled.
+ Arguments:
+ - send_payloads (:obj:`List[SendPayload]`): Request payloads.
+ - ignore_err (:obj:`bool`): If ignore_err is True, \
+                the error is stored in recv_payload.err. Otherwise, an exception will be raised. \
+                This option also suppresses timeout errors.
+ - callback (:obj:`Callable`): Callback for each recv payload.
+ - timeout (:obj:`Optional[float]`): Timeout when wait for responses.
+ Returns:
+ - recv_payload (:obj:`List[RecvPayload]`): Recv payload, may contain timeout error.
+ """
+ assert send_payloads, "Req payload is empty!"
+ recv_payloads = {}
+ remain_payloads = {payload.req_id: payload for payload in send_payloads}
+ unrelated_payloads = []
+ try:
+ while remain_payloads:
+ try:
+ recv_payload: RecvPayload = self._recv_queue.get(block=True, timeout=timeout)
+ if recv_payload.req_id in remain_payloads:
+ del remain_payloads[recv_payload.req_id]
+ recv_payloads[recv_payload.req_id] = recv_payload
+ if recv_payload.err and not ignore_err:
+ raise recv_payload.err
+ if callback:
+ callback(recv_payload, remain_payloads)
+ else:
+ unrelated_payloads.append(recv_payload)
+ except queue.Empty:
+ if ignore_err:
+ req_ids = list(remain_payloads.keys())
+                        logging.warning("Timeout ({}s) when receiving payloads! Req ids: {}".format(timeout, req_ids))
+ for req_id in req_ids:
+ send_payload = remain_payloads.pop(req_id)
+                            # If a timeout occurs during timeout recovery, the matching send_payload
+                            # may no longer be present in the original indexed payloads.
+ recv_payload = RecvPayload(
+ proc_id=send_payload.proc_id,
+ req_id=send_payload.req_id,
+ method=send_payload.method,
+ err=TimeoutError("Timeout on req_id ({})".format(req_id))
+ )
+ recv_payloads[req_id] = recv_payload
+ if callback:
+ callback(recv_payload, remain_payloads)
+ else:
+                        raise TimeoutError("Timeout ({}s) when receiving payloads!".format(timeout))
+ finally:
+ # Put back the unrelated payload.
+ for payload in unrelated_payloads:
+ self._recv_queue.put(payload)
+
+ # Keep the original order of requests.
+ return [recv_payloads[p.req_id] for p in send_payloads]
+
+ def shutdown(self, timeout: Optional[float] = None) -> None:
+ if self._running:
+ for child in self._children:
+ child.shutdown(timeout=timeout)
+ self._cleanup_queue()
+ self._running = False
+
+ def _cleanup_queue(self):
+ while True:
+ while not self._recv_queue.empty():
+ self._recv_queue.get()
+ time.sleep(0.1) # mp.Queue is not reliable.
+ if self._recv_queue.empty():
+ break
+ if hasattr(self._recv_queue, "close"):
+ self._recv_queue.close()
+ self._recv_queue.join_thread()
+ self._recv_queue = None
+
+ def __getattr__(self, key: str) -> List[Any]:
+ assert self._running, "Supervisor is not running, please call start_link first!"
+ send_payloads = []
+ for i, child in enumerate(self._children):
+ payload = SendPayload(proc_id=i, method=ReserveMethod.GETATTR, args=[key])
+ send_payloads.append(payload)
+ child.send(payload)
+ return [payload.data for payload in self.recv_all(send_payloads)]
+
+ def get_child_attr(self, proc_id: str, key: str) -> Any:
+ """
+ Overview:
+ Get attr of one child process instance.
+ Arguments:
+ - proc_id (:obj:`str`): Proc id.
+ - key (:obj:`str`): Attribute key.
+ Returns:
+ - attr (:obj:`Any`): Attribute of child.
+ """
+ assert self._running, "Supervisor is not running, please call start_link first!"
+ payload = SendPayload(proc_id=proc_id, method=ReserveMethod.GETATTR, args=[key])
+ self._children[proc_id].send(payload)
+ payloads = self.recv_all([payload])
+ return payloads[0].data
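+
+    # Illustrative usage (a sketch; ``Counter`` is a hypothetical child class, not part of DI-engine):
+    #
+    #     class Counter:
+    #         def __init__(self):
+    #             self.n = 0
+    #         def step(self, inc):
+    #             self.n += inc
+    #             return self.n
+    #
+    #     sv = Supervisor(type_=ChildType.THREAD)
+    #     sv.register(Counter)
+    #     sv.start_link()
+    #     sv.send(SendPayload(proc_id=0, method="step", args=[1]))
+    #     print(sv.recv().data)  # -> 1
+    #     sv.shutdown()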
+
+ def __del__(self) -> None:
+ self.shutdown(timeout=5)
+ self._children.clear()
diff --git a/DI-engine/ding/framework/task.py b/DI-engine/ding/framework/task.py
new file mode 100644
index 0000000000000000000000000000000000000000..e5f8e7d9f70e553d4b67467c5d1a5b4e03f098b5
--- /dev/null
+++ b/DI-engine/ding/framework/task.py
@@ -0,0 +1,553 @@
+from asyncio import InvalidStateError
+from asyncio.tasks import FIRST_EXCEPTION
+from collections import OrderedDict
+from threading import Lock
+import time
+import asyncio
+import concurrent.futures
+import fnmatch
+import math
+import enum
+from types import GeneratorType
+from typing import Any, Awaitable, Callable, Dict, Generator, Iterable, List, Optional, Set, Union
+import inspect
+
+from ding.framework.context import Context
+from ding.framework.parallel import Parallel
+from ding.framework.event_loop import EventLoop
+from functools import wraps
+
+
+def enable_async(func: Callable) -> Callable:
+ """
+ Overview:
+ Empower the function with async ability.
+ Arguments:
+ - func (:obj:`Callable`): The original function.
+ Returns:
+ - runtime_handler (:obj:`Callable`): The wrap function.
+ """
+
+ @wraps(func)
+ def runtime_handler(task: "Task", *args, async_mode: Optional[bool] = None, **kwargs) -> "Task":
+ """
+ Overview:
+            If the task's async mode is enabled, execute the step asynchronously in the current loop executor;
+            otherwise execute it synchronously.
+ Arguments:
+ - task (:obj:`Task`): The task instance.
+ - async_mode (:obj:`Optional[bool]`): Whether using async mode.
+ Returns:
+            - result (:obj:`Union[Any, Awaitable]`): The middleware result in sync mode, or the task itself in async mode.
+ """
+ if async_mode is None:
+ async_mode = task.async_mode
+ if async_mode:
+ assert not kwargs, "Should not use kwargs in async_mode, use position parameters, kwargs: {}".format(kwargs)
+ t = task._async_loop.run_in_executor(task._thread_pool, func, task, *args, **kwargs)
+ task._async_stack.append(t)
+ return task
+ else:
+ return func(task, *args, **kwargs)
+
+ return runtime_handler
+
+
+class Role(str, enum.Enum):
+ LEARNER = "learner"
+ COLLECTOR = "collector"
+ EVALUATOR = "evaluator"
+ FETCHER = 'fetcher'
+
+
+class VoidMiddleware:
+
+ def __call__(self, _):
+ return
+
+
+class Task:
+ """
+ Task will manage the execution order of the entire pipeline, register new middleware,
+ and generate new context objects.
+ """
+ role = Role
+
+ def __init__(self) -> None:
+ self.router = Parallel()
+ self._finish = False
+
+ def start(
+ self,
+ async_mode: bool = False,
+ n_async_workers: int = 3,
+ ctx: Optional[Context] = None,
+ labels: Optional[Set[str]] = None
+ ) -> "Task":
+ # This flag can be modified by external or associated processes
+ self._finish = False
+        # This flag can only be modified inside the class; it will be set to False at the end of stop
+ self._running = True
+ self._middleware = []
+ self._wrappers = []
+ self.ctx = ctx or Context()
+ self._backward_stack = OrderedDict()
+ self._roles = set()
+ # Bind event loop functions
+ self._event_loop = EventLoop("task_{}".format(id(self)))
+
+ # Async segment
+ self.async_mode = async_mode
+ self.n_async_workers = n_async_workers
+ self._async_stack = []
+ self._async_loop = None
+ self._thread_pool = None
+ self._exception = None
+ self._thread_lock = Lock()
+ self.labels = labels or set()
+
+ # Parallel segment
+ if async_mode or self.router.is_active:
+ self._activate_async()
+
+ if self.router.is_active:
+
+ def sync_finish(value):
+ self._finish = value
+
+ self.on("finish", sync_finish)
+
+ self.init_labels()
+ return self
+
+ def add_role(self, role: Role):
+ self._roles.add(role)
+
+ def has_role(self, role: Role) -> bool:
+ if len(self._roles) == 0:
+ return True
+ return role in self._roles
+
+ @property
+ def roles(self) -> Set[Role]:
+ return self._roles
+
+ def void(self):
+ return VoidMiddleware()
+
+ def init_labels(self):
+ if self.async_mode:
+ self.labels.add("async")
+ if self.router.is_active:
+ self.labels.add("distributed")
+ self.labels.add("node.{}".format(self.router.node_id))
+ for label in self.router.labels:
+ self.labels.add(label)
+ else:
+ self.labels.add("standalone")
+
+ def use(self, fn: Callable, lock: Union[bool, Lock] = False) -> 'Task':
+ """
+ Overview:
+            Register middleware to the task. The middleware will be executed in its registration order.
+ Arguments:
+ - fn (:obj:`Callable`): A middleware is a function with only one argument: ctx.
+ - lock (:obj:`Union[bool, Lock]`): There can only be one middleware execution under lock at any one time.
+ Returns:
+ - task (:obj:`Task`): The task.
+ """
+ assert isinstance(fn, Callable), "Middleware function should be a callable object, current fn {}".format(fn)
+ if isinstance(fn, VoidMiddleware): # Skip void function
+ return self
+ for wrapper in self._wrappers:
+ fn = wrapper(fn)
+ self._middleware.append(self.wrap(fn, lock=lock))
+ return self
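+
+    # Usage sketch (illustrative only; `my_step` is a hypothetical middleware):
+    #
+    #     def my_step(ctx):
+    #         ...  # read from / write to ctx here
+    #
+    #     with task.start():
+    #         task.use(my_step)
+    #         task.run(max_step=10)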
+
+ def use_wrapper(self, fn: Callable) -> 'Task':
+ """
+ Overview:
+            Register a wrapper to the task. A wrapper works like a decorator, but the task will apply this \
+                decorator on top of each middleware.
+ Arguments:
+ - fn (:obj:`Callable`): A wrapper is a decorator, so the first argument is a callable function.
+ Returns:
+ - task (:obj:`Task`): The task.
+ """
+        # Wrap existing middleware
+ for i, middleware in enumerate(self._middleware):
+ self._middleware[i] = fn(middleware)
+ self._wrappers.append(fn)
+ return self
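+
+    # Usage sketch (illustrative only, mirroring test_step_timer in this diff):
+    # a registered wrapper is applied to every middleware, whether it was added
+    # before or after the wrapper. `my_step` is a hypothetical middleware.
+    #
+    #     step_timer = StepTimer()
+    #     with task.start():
+    #         task.use_wrapper(step_timer)
+    #         task.use(my_step)
+    #         task.run(3)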
+
+ def match_labels(self, patterns: Union[Iterable[str], str]) -> bool:
+ """
+ Overview:
+            Check whether any of the given glob patterns matches the task's labels.
+        Arguments:
+            - patterns (:obj:`Union[Iterable[str], str]`): Glob-like pattern(s), e.g. node.1, node.*.
+ """
+ if isinstance(patterns, str):
+ patterns = [patterns]
+ return any([fnmatch.filter(self.labels, p) for p in patterns])
+
+ def run(self, max_step: int = int(1e12)) -> None:
+ """
+ Overview:
+            Execute the iterations; the loop will break when max_step is reached
+            or task.finish is True.
+ Arguments:
+ - max_step (:obj:`int`): Max step of iterations.
+ """
+        assert self._running, "Please make sure the task is running before calling this method, see task.start"
+ if len(self._middleware) == 0:
+ return
+ for i in range(max_step):
+ for fn in self._middleware:
+ self.forward(fn)
+ # Sync should be called before backward, otherwise it is possible
+ # that some generators have not been pushed to backward_stack.
+ self.sync()
+ self.backward()
+ self.sync()
+ if i == max_step - 1:
+ self.finish = True
+ if self.finish:
+ break
+ self.renew()
+
+ def wrap(self, fn: Callable, lock: Union[bool, Lock] = False) -> Callable:
+ """
+ Overview:
+            Wrap the middleware so that it can be called directly in other middleware.
+ Arguments:
+ - fn (:obj:`Callable`): The middleware.
+ - lock (:obj:`Union[bool, Lock]`): There can only be one middleware execution under lock at any one time.
+ Returns:
+            - fn_back (:obj:`Callable`): A backward function that will call the part of the middleware
+                after the yield. If this backward function is not called, that part of the middleware
+                will be executed in the global backward step.
+ """
+ if lock is True:
+ lock = self._thread_lock
+
+ def forward(ctx: Context):
+ if lock:
+ with lock:
+ g = self.forward(fn, ctx, async_mode=False)
+ else:
+ g = self.forward(fn, ctx, async_mode=False)
+
+ def backward():
+ backward_stack = OrderedDict()
+ key = id(g)
+ backward_stack[key] = self._backward_stack.pop(key)
+ if lock:
+ with lock:
+ self.backward(backward_stack, async_mode=False)
+ else:
+ self.backward(backward_stack, async_mode=False)
+
+ return backward
+
+ if hasattr(fn, "__name__"):
+ forward = wraps(fn)(forward)
+ else:
+ forward = wraps(fn.__class__)(forward)
+
+ return forward
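+
+    # Usage sketch (illustrative only, mirroring test_nested_middleware in this diff):
+    # wrap a generator-style middleware so its post-yield part can be triggered
+    # explicitly from another middleware. `child_middleware` is hypothetical.
+    #
+    #     wrapped_child = task.wrap(child_middleware)
+    #
+    #     def mother(ctx):
+    #         child_back = wrapped_child(ctx)  # runs child_middleware up to its yield
+    #         ...                              # do the mother's own work here
+    #         child_back()                     # runs the part of child_middleware after the yield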
+
+ @enable_async
+ def forward(self, fn: Callable, ctx: Optional[Context] = None) -> Optional[Generator]:
+ """
+ Overview:
+            This function will execute the middleware until the first yield statement,
+ or the end of the middleware.
+ Arguments:
+            - fn (:obj:`Callable`): A middleware function that takes the ctx argument.
+ - ctx (:obj:`Optional[Context]`): Replace global ctx with a customized ctx.
+ Returns:
+ - g (:obj:`Optional[Generator]`): The generator if the return value of fn is a generator.
+ """
+        assert self._running, "Please make sure the task is running before calling this method, see task.start"
+ if not ctx:
+ ctx = self.ctx
+ g = fn(ctx)
+ if isinstance(g, GeneratorType):
+ try:
+ next(g)
+ self._backward_stack[id(g)] = g
+ return g
+ except StopIteration:
+ pass
+
+ @enable_async
+ def backward(self, backward_stack: Optional[Dict[str, Generator]] = None) -> None:
+ """
+ Overview:
+            Execute the remaining parts of the middleware, in reverse order of registration.
+ Arguments:
+ - backward_stack (:obj:`Optional[Dict[str, Generator]]`): Replace global backward_stack with a customized \
+ stack.
+ """
+        assert self._running, "Please make sure the task is running before calling this method, see task.start"
+ if not backward_stack:
+ backward_stack = self._backward_stack
+ while backward_stack:
+ # FILO
+ _, g = backward_stack.popitem()
+ try:
+ next(g)
+ except StopIteration:
+ continue
+
+ @property
+ def running(self):
+ return self._running
+
+ def serial(self, *fns: List[Callable]) -> Callable:
+ """
+ Overview:
+            Wrap functions and keep them running in serial, usually to avoid dependency confusion
+            in async mode.
+        Arguments:
+            - fns (:obj:`List[Callable]`): A series of middleware to chain and wrap into one middleware function.
+ """
+
+ def _serial(ctx):
+ backward_keys = []
+ for fn in fns:
+ g = self.forward(fn, ctx, async_mode=False)
+ if isinstance(g, GeneratorType):
+ backward_keys.append(id(g))
+ yield
+ backward_stack = OrderedDict()
+ for k in backward_keys:
+ backward_stack[k] = self._backward_stack.pop(k)
+ self.backward(backward_stack=backward_stack, async_mode=False)
+
+ name = ",".join([fn.__name__ for fn in fns])
+ _serial.__name__ = "serial<{}>".format(name)
+ return _serial
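+
+    # Usage sketch (illustrative only, mirroring test_serial_in_parallel in this diff):
+    # even in async mode, `slow` is guaranteed to finish before `fast` starts.
+    # `slow` and `fast` are hypothetical middleware functions.
+    #
+    #     with task.start(async_mode=True):
+    #         task.use(task.serial(slow, fast))
+    #         task.run(max_step=1)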
+
+ def parallel(self, *fns: List[Callable]) -> Callable:
+ """
+ Overview:
+            Wrap functions and keep them running in parallel; this function should not be used in async mode.
+        Arguments:
+            - fns (:obj:`List[Callable]`): Middleware to run in parallel, wrapped into one middleware function.
+ """
+ self._activate_async()
+
+ def _parallel(ctx):
+ backward_keys = []
+ for fn in fns:
+ g = self.forward(fn, ctx, async_mode=True)
+ if isinstance(g, GeneratorType):
+ backward_keys.append(id(g))
+ self.sync()
+ yield
+ backward_stack = OrderedDict()
+ for k in backward_keys:
+ backward_stack[k] = self._backward_stack.pop(k)
+ self.backward(backward_stack, async_mode=True)
+ self.sync()
+
+ name = ",".join([fn.__name__ for fn in fns])
+ _parallel.__name__ = "parallel<{}>".format(name)
+ return _parallel
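+
+    # Usage sketch (illustrative only, mirroring test_parallel_in_sequencial in this diff):
+    # in a synchronous pipeline, `slow` and `fast` run concurrently and the wrapped
+    # middleware returns only after both have finished. Both are hypothetical middleware functions.
+    #
+    #     with task.start():
+    #         task.use(task.parallel(slow, fast))
+    #         task.run(max_step=1)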
+
+ def renew(self) -> 'Task':
+ """
+ Overview:
+            Renew the context instance; this function should be called after backward at the end of each iteration.
+ """
+        assert self._running, "Please make sure the task is running before calling this method, see task.start"
+ self.ctx = self.ctx.renew()
+ return self
+
+ def __enter__(self) -> "Task":
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.stop()
+
+ def stop(self) -> None:
+ """
+ Overview:
+            Stop and clean up everything in the task runtime.
+ """
+ if self.router.is_active:
+ self.emit("finish", True)
+ if self._thread_pool:
+ self._thread_pool.shutdown()
+ self._event_loop.stop()
+ self.router.off(self._wrap_event_name("*"))
+ if self._async_loop:
+ self._async_loop.stop()
+ self._async_loop.close()
+        # The middleware and listeners may contain methods that reference the task.
+        # If we do not clear them after the task exits, gc will not be able to clean up the task object.
+ self._middleware.clear()
+ self._wrappers.clear()
+ self._backward_stack.clear()
+ self._async_stack.clear()
+ self._running = False
+
+ def sync(self) -> 'Task':
+ if self._async_loop:
+ self._async_loop.run_until_complete(self.sync_tasks())
+ return self
+
+ async def sync_tasks(self) -> Awaitable[None]:
+ if self._async_stack:
+ await asyncio.wait(self._async_stack, return_when=FIRST_EXCEPTION)
+ while self._async_stack:
+ t = self._async_stack.pop(0)
+ try:
+ e = t.exception()
+ if e:
+ self._exception = e
+ raise e
+ except InvalidStateError:
+ # Not finished. https://docs.python.org/3/library/asyncio-task.html#asyncio.Task.exception
+ pass
+
+ def async_executor(self, fn: Callable, *args, **kwargs) -> None:
+ """
+ Overview:
+            Execute the function in the background, then append the future instance to _async_stack.
+        Arguments:
+            - fn (:obj:`Callable`): A synchronous function.
+ """
+ if not self._async_loop:
+ raise Exception("Event loop was not initialized, please call this function in async or parallel mode")
+ t = self._async_loop.run_in_executor(self._thread_pool, fn, *args, **kwargs)
+ self._async_stack.append(t)
+
+ def emit(self, event: str, *args, only_remote: bool = False, only_local: bool = False, **kwargs) -> None:
+ """
+ Overview:
+ Emit an event, call listeners.
+ Arguments:
+ - event (:obj:`str`): Event name.
+ - only_remote (:obj:`bool`): Only broadcast the event to the connected nodes, default is False.
+ - only_local (:obj:`bool`): Only emit local event, default is False.
+            - args (:obj:`any`): Remaining arguments for listeners.
+        """
+        # Check whether the event needs to be broadcast to connected nodes (default is True)
+        assert self._running, "Please make sure the task is running before calling this method, see task.start"
+ if only_local:
+ self._event_loop.emit(event, *args, **kwargs)
+ elif only_remote:
+ if self.router.is_active:
+ self.async_executor(self.router.emit, self._wrap_event_name(event), event, *args, **kwargs)
+ else:
+ if self.router.is_active:
+ self.async_executor(self.router.emit, self._wrap_event_name(event), event, *args, **kwargs)
+ self._event_loop.emit(event, *args, **kwargs)
+
+ def on(self, event: str, fn: Callable) -> None:
+ """
+ Overview:
+ Subscribe to an event, execute this function every time the event is emitted.
+ Arguments:
+ - event (:obj:`str`): Event name.
+ - fn (:obj:`Callable`): The function.
+ """
+ self._event_loop.on(event, fn)
+ if self.router.is_active:
+ self.router.on(self._wrap_event_name(event), self._event_loop.emit)
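+
+    # Usage sketch (illustrative only, mirroring test_emit in this diff): subscribe
+    # first, then emit; with an active router the event is also broadcast to remote nodes.
+    #
+    #     greets = []
+    #     task.on("Greeting", lambda msg: greets.append(msg))
+    #     task.emit("Greeting", "Hi")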
+
+ def once(self, event: str, fn: Callable) -> None:
+ """
+ Overview:
+ Subscribe to an event, execute this function only once when the event is emitted.
+ Arguments:
+ - event (:obj:`str`): Event name.
+ - fn (:obj:`Callable`): The function.
+ """
+ self._event_loop.once(event, fn)
+ if self.router.is_active:
+ self.router.on(self._wrap_event_name(event), self._event_loop.emit)
+
+ def off(self, event: str, fn: Optional[Callable] = None) -> None:
+ """
+ Overview:
+            Unsubscribe from an event.
+ Arguments:
+ - event (:obj:`str`): Event name.
+ - fn (:obj:`Callable`): The function.
+ """
+ self._event_loop.off(event, fn)
+ if self.router.is_active:
+ self.router.off(self._wrap_event_name(event))
+
+ def wait_for(self, event: str, timeout: float = math.inf, ignore_timeout_exception: bool = True) -> Any:
+ """
+ Overview:
+ Wait for an event and block the thread.
+ Arguments:
+ - event (:obj:`str`): Event name.
+ - timeout (:obj:`float`): Timeout in seconds.
+            - ignore_timeout_exception (:obj:`bool`): If this is False, a TimeoutError will be raised on timeout.
+ """
+        assert self._running, "Please make sure the task is running before calling this method, see task.start"
+ received = False
+ result = None
+
+ def _receive_event(*args, **kwargs):
+ nonlocal result, received
+ result = (args, kwargs)
+ received = True
+
+ self.once(event, _receive_event)
+
+ start = time.time()
+ while time.time() - start < timeout:
+ if received or self._exception:
+ return result
+ time.sleep(0.01)
+
+ if ignore_timeout_exception:
+ return result
+ else:
+ raise TimeoutError("Timeout when waiting for event: {}".format(event))
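+
+    # Usage sketch (illustrative only, mirroring test_wait_for in this diff): block
+    # until some other middleware or node emits the event, then unpack (args, kwargs).
+    # Assumes the event is emitted before the timeout; otherwise None is returned
+    # when ignore_timeout_exception is True.
+    #
+    #     args, kwargs = task.wait_for("Greeting")
+    #     first_arg = args[0]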
+
+ @property
+ def finish(self):
+ return self._finish
+
+ @finish.setter
+ def finish(self, value: bool):
+ self._finish = value
+
+ def _wrap_event_name(self, event: str) -> str:
+ """
+ Overview:
+ Wrap the event name sent to the router.
+ Arguments:
+ - event (:obj:`str`): Event name
+ """
+ return "task.{}".format(event)
+
+ def _activate_async(self):
+ if not self._thread_pool:
+ self._thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=self.n_async_workers)
+ if not self._async_loop:
+ self._async_loop = asyncio.new_event_loop()
+
+ def get_attch_to_len(self) -> int:
+ """
+ Overview:
+ Get the length of the 'attach_to' list in Parallel._mq.
+ Returns:
+            - attach_to_len (:obj:`int`): The length of the 'attach_to' list in Parallel._mq.
+ """
+ if self.router.is_active:
+ return self.router.get_attch_to_len()
+ else:
+            raise RuntimeError("The router is inactive; failed to obtain the length of the 'attach_to' list.")
+
+
+task = Task()
diff --git a/DI-engine/ding/framework/tests/context_fake_data.py b/DI-engine/ding/framework/tests/context_fake_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee65048e6cae0e614074bce5e97b2b2b99729735
--- /dev/null
+++ b/DI-engine/ding/framework/tests/context_fake_data.py
@@ -0,0 +1,76 @@
+from ding.framework import Context, OnlineRLContext, OfflineRLContext
+import random
+import numpy as np
+import treetensor.torch as ttorch
+import torch
+
+batch_size = 64
+n_sample = 8
+action_dim = 1
+obs_dim = 4
+logit_dim = 2
+
+n_episodes = 2
+n_episode_length = 16
+update_per_collect = 4
+collector_env_num = 8
+
+
+# The value ranges here are meaningless and only used for testing
+def fake_train_data():
+ train_data = ttorch.as_tensor(
+ {
+ 'action': torch.randint(0, 2, size=(action_dim, )),
+ 'collect_train_iter': torch.randint(0, 100, size=(1, )),
+ 'done': torch.tensor(False),
+ 'env_data_id': torch.tensor([2]),
+ 'next_obs': torch.randn(obs_dim),
+ 'obs': torch.randn(obs_dim),
+ 'reward': torch.randint(0, 2, size=(1, )),
+ }
+ )
+ return train_data
+
+
+def fake_online_rl_context():
+ ctx = OnlineRLContext(
+ env_step=random.randint(0, 100),
+ env_episode=random.randint(0, 100),
+ train_iter=random.randint(0, 100),
+ train_data=[fake_train_data() for _ in range(batch_size)],
+ train_output=[{
+ 'cur_lr': 0.001,
+ 'total_loss': random.uniform(0, 2)
+ } for _ in range(update_per_collect)],
+ obs=torch.randn(collector_env_num, obs_dim),
+ action=[np.random.randint(low=0, high=1, size=(action_dim), dtype=np.int64) for _ in range(collector_env_num)],
+ inference_output={
+ env_id: {
+ 'logit': torch.randn(logit_dim),
+ 'action': torch.randint(0, 2, size=(action_dim, ))
+ }
+ for env_id in range(collector_env_num)
+ },
+ collect_kwargs={'eps': random.uniform(0, 1)},
+ trajectories=[fake_train_data() for _ in range(n_sample)],
+ episodes=[[fake_train_data() for _ in range(n_episode_length)] for _ in range(n_episodes)],
+ trajectory_end_idx=[i for i in range(n_sample)],
+ eval_value=random.uniform(-1.0, 1.0),
+ last_eval_iter=random.randint(0, 100),
+ )
+ return ctx
+
+
+def fake_offline_rl_context():
+ ctx = OfflineRLContext(
+ train_epoch=random.randint(0, 100),
+ train_iter=random.randint(0, 100),
+ train_data=[fake_train_data() for _ in range(batch_size)],
+ train_output=[{
+ 'cur_lr': 0.001,
+ 'total_loss': random.uniform(0, 2)
+ } for _ in range(update_per_collect)],
+ eval_value=random.uniform(-1.0, 1.0),
+ last_eval_iter=random.randint(0, 100),
+ )
+ return ctx
diff --git a/DI-engine/ding/framework/tests/test_context.py b/DI-engine/ding/framework/tests/test_context.py
new file mode 100644
index 0000000000000000000000000000000000000000..c20efd85d27a575105ce75a348936627e9ff5e56
--- /dev/null
+++ b/DI-engine/ding/framework/tests/test_context.py
@@ -0,0 +1,59 @@
+import pytest
+import pickle
+import numpy as np
+from ding.framework import Context, OnlineRLContext, OfflineRLContext
+from dataclasses import dataclass
+
+
+@dataclass
+class MockContext(Context):
+ hello: str = "world"
+ keep_me: int = 0
+ not_keep_me: int = 0
+
+
+@pytest.mark.unittest
+def test_pickable():
+ ctx = MockContext()
+ ctx.keep("keep_me")
+ _ctx = pickle.loads(pickle.dumps(ctx))
+ assert _ctx.hello == "world"
+
+ ctx.keep_me += 1
+ ctx.not_keep_me += 1
+
+ _ctx = ctx.renew()
+ assert _ctx.keep_me == 1
+ assert _ctx.not_keep_me == 0
+
+
+@pytest.mark.unittest
+def test_online():
+ ctx = OnlineRLContext()
+ assert ctx.env_step == 0
+ assert ctx.eval_value == -np.inf
+
+ ctx.env_step += 1
+ ctx.eval_value = 1
+ assert ctx.env_step == 1
+ assert ctx.eval_value == 1
+
+ _ctx = ctx.renew()
+ assert _ctx.env_step == 1
+ assert _ctx.eval_value == -np.inf
+
+
+@pytest.mark.unittest
+def test_offline():
+ ctx = OfflineRLContext()
+ assert ctx.train_iter == 0
+ assert ctx.eval_value == -np.inf
+
+ ctx.train_iter += 1
+ ctx.eval_value = 1
+ assert ctx.train_iter == 1
+ assert ctx.eval_value == 1
+
+ _ctx = ctx.renew()
+ assert _ctx.train_iter == 1
+ assert _ctx.eval_value == -np.inf
diff --git a/DI-engine/ding/framework/tests/test_event_loop.py b/DI-engine/ding/framework/tests/test_event_loop.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f3545f3f54b13820dcb220d77770c9c710ad1bd
--- /dev/null
+++ b/DI-engine/ding/framework/tests/test_event_loop.py
@@ -0,0 +1,54 @@
+from time import sleep
+import pytest
+from ding.framework import EventLoop
+from threading import Lock
+
+
+@pytest.mark.unittest
+def test_event_loop():
+ loop = EventLoop.get_event_loop("test")
+ try:
+ counter = 0
+ lock = Lock()
+
+ def callback(n, lock):
+ nonlocal counter
+ with lock:
+ counter += n
+
+ # Test on
+ loop.on("count", callback)
+
+ for i in range(5):
+ loop.emit("count", i, lock)
+ sleep(0.1)
+ assert counter == 10
+
+ # Test off
+ loop.off("count")
+ loop.emit("count", 10, lock)
+ sleep(0.1)
+ assert counter == 10
+
+ # Test once
+ counter = 0
+ loop.once("count", callback)
+ loop.once("count", callback)
+ loop.emit("count", 10, lock)
+ sleep(0.1)
+ assert counter == 20
+ loop.emit("count", 10, lock)
+ assert counter == 20
+
+ # Test exception
+ def except_callback():
+ raise Exception("error")
+
+ loop.on("error", except_callback)
+ loop.emit("error")
+ sleep(0.1)
+ assert loop._exception is not None
+ with pytest.raises(Exception):
+ loop.emit("error")
+ finally:
+ loop.stop()
diff --git a/DI-engine/ding/framework/tests/test_parallel.py b/DI-engine/ding/framework/tests/test_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bdf6ea343741ac1bc7c31536b33506510deffc9
--- /dev/null
+++ b/DI-engine/ding/framework/tests/test_parallel.py
@@ -0,0 +1,156 @@
+from collections import defaultdict
+import pytest
+import time
+from ding.framework import Parallel
+
+
+def parallel_main():
+ msg = defaultdict(bool)
+
+ def test_callback(key):
+ msg[key] = True
+
+ router = Parallel()
+ router.on("test_callback", test_callback)
+ # Wait for nodes to bind
+ time.sleep(0.7)
+ for _ in range(30):
+ router.emit("test_callback", "ping")
+ if msg["ping"]:
+ break
+ time.sleep(0.03)
+ assert msg["ping"]
+    # Avoid failing to receive messages from each other after exiting parallel
+ time.sleep(0.7)
+
+
+@pytest.mark.tmp
+def test_parallel_run():
+ Parallel.runner(n_parallel_workers=2, startup_interval=0.1)(parallel_main)
+ Parallel.runner(n_parallel_workers=2, protocol="tcp", startup_interval=0.1)(parallel_main)
+
+
+def uncaught_exception_main():
+ router = Parallel()
+ if router.node_id == 0:
+ time.sleep(0.1)
+ raise Exception("uncaught exception")
+ else:
+ time.sleep(0.2)
+
+
+@pytest.mark.tmp
+def test_uncaught_exception():
+    # Make one process crash; the parent process will also crash and output the stack trace of the failed process.
+ with pytest.raises(Exception) as exc_info:
+ Parallel.runner(n_parallel_workers=2, topology="mesh", startup_interval=0.1)(uncaught_exception_main)
+ e = exc_info._excinfo[1]
+ assert "uncaught exception" in str(e)
+
+
+def disconnected_main():
+ router = Parallel()
+
+ if router.node_id == 0:
+ time.sleep(0.1)
+        # Receive one greeting then exit
+ greets = []
+ router.on("greeting", lambda: greets.append("."))
+ for _ in range(10):
+ if len(greets) == 1:
+ break
+ else:
+ time.sleep(0.1)
+ assert len(greets) > 0
+ else:
+        # Send 10 greetings even if the target process has exited
+ for i in range(10):
+ router.emit("greeting")
+ time.sleep(0.1)
+ assert i == 9
+
+
+@pytest.mark.tmp
+def test_disconnected():
+    # Make one process exit normally; the rest will still run, even if their network requests
+    # are not received by the exited process.
+ Parallel.runner(n_parallel_workers=2, topology="mesh", startup_interval=0.1)(disconnected_main)
+
+
+class AutoRecover:
+
+ @classmethod
+ def main_p0(cls):
+ # Wait for p1's message and recovered message from p1
+ greets = []
+ router = Parallel()
+ router.on("greeting_0", lambda msg: greets.append(msg))
+ for _ in range(50):
+ if greets and greets[-1] == "recovered_p1":
+ break
+ time.sleep(0.1)
+ assert greets and greets[-1] == "recovered_p1"
+
+ @classmethod
+ def main_p1(cls):
+        # Send an empty message to p0
+        # After recovering from the exception, send recovered_p1 to p0
+        # Listen for messages from p2
+ greets = []
+ router = Parallel()
+ router.on("greeting_1", lambda msg: greets.append(msg))
+
+ # Test sending message to p0
+ if router._retries == 0:
+ for _ in range(10):
+ router.emit("greeting_0", "")
+ time.sleep(0.1)
+ raise Exception("P1 Error")
+ elif router._retries == 1:
+ for _ in range(10):
+ router.emit("greeting_0", "recovered_p1")
+ time.sleep(0.1)
+ else:
+ raise Exception("Failed too many times")
+
+        # Test recovery and receiving messages from p2
+ for _ in range(20):
+ if greets:
+ break
+ time.sleep(0.1)
+ assert len(greets) > 0
+
+ @classmethod
+ def main_p2(cls):
+ # Simply send message to p1
+ router = Parallel()
+ for _ in range(20):
+ router.emit("greeting_1", "")
+ time.sleep(0.1)
+
+ @classmethod
+ def main(cls):
+ router = Parallel()
+ if router.node_id == 0:
+ cls.main_p0()
+ elif router.node_id == 1:
+ cls.main_p1()
+ elif router.node_id == 2:
+ cls.main_p2()
+ else:
+ raise Exception("Invalid node id")
+
+
+@pytest.mark.tmp
+def test_auto_recover():
+ # With max_retries=1
+ Parallel.runner(
+ n_parallel_workers=3, topology="mesh", auto_recover=True, max_retries=1, startup_interval=0.1
+ )(AutoRecover.main)
+ # With max_retries=0
+ with pytest.raises(Exception) as exc_info:
+ Parallel.runner(
+ n_parallel_workers=3, topology="mesh", auto_recover=True, max_retries=0, startup_interval=0.1
+ )(AutoRecover.main)
+ e = exc_info._excinfo[1]
+ assert "P1 Error" in str(e)
diff --git a/DI-engine/ding/framework/tests/test_supervisor.py b/DI-engine/ding/framework/tests/test_supervisor.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6f4c646fae2709e82db7755eb296f4754e7801b
--- /dev/null
+++ b/DI-engine/ding/framework/tests/test_supervisor.py
@@ -0,0 +1,310 @@
+import multiprocessing as mp
+import ctypes
+from time import sleep, time
+from typing import Any, Dict, List
+import pytest
+from ding.framework.supervisor import RecvPayload, SendPayload, Supervisor, ChildType
+
+
+class MockEnv():
+
+ def __init__(self, _) -> None:
+ self._counter = 0
+
+ def step(self, _):
+ self._counter += 1
+ return self._counter
+
+ @property
+ def action_space(self):
+ return 3
+
+ def block(self):
+ sleep(10)
+
+ def block_reset(self):
+ sleep(10)
+
+ def sleep1(self):
+ sleep(1)
+
+
+@pytest.mark.tmp
+@pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
+def test_supervisor(type_):
+ sv = Supervisor(type_=type_)
+ for _ in range(3):
+ sv.register(lambda: MockEnv("AnyArgs"))
+ sv.start_link()
+
+ for env_id in range(len(sv._children)):
+ sv.send(SendPayload(proc_id=env_id, method="step", args=["any action"]))
+
+ recv_states: List[RecvPayload] = []
+ for _ in range(3):
+ recv_states.append(sv.recv())
+
+ assert sum([payload.proc_id for payload in recv_states]) == 3
+ assert all([payload.data == 1 for payload in recv_states])
+
+ # Test recv_all
+ send_payloads = []
+ for env_id in range(len(sv._children)):
+ payload = SendPayload(
+ proc_id=env_id,
+ method="step",
+ args=["any action"],
+ )
+ send_payloads.append(payload)
+ sv.send(payload)
+
+ req_ids = [payload.req_id for payload in send_payloads]
+ # Only wait for last two messages, keep the first one in the queue.
+ recv_payloads = sv.recv_all(send_payloads[1:])
+ assert len(recv_payloads) == 2
+ for req_id, payload in zip(req_ids[1:], recv_payloads):
+ assert req_id == payload.req_id
+
+ recv_payload = sv.recv()
+ assert recv_payload.req_id == req_ids[0]
+
+ assert len(sv.action_space) == 3
+ assert all(a == 3 for a in sv.action_space)
+
+ sv.shutdown()
+
+
+@pytest.mark.tmp
+def test_supervisor_spawn():
+ sv = Supervisor(type_=ChildType.PROCESS, mp_ctx=mp.get_context("spawn"))
+ for _ in range(3):
+ sv.register(MockEnv("AnyArgs"))
+ sv.start_link()
+
+ for env_id in range(len(sv._children)):
+ sv.send(SendPayload(proc_id=env_id, method="step", args=["any action"]))
+
+ recv_states: List[RecvPayload] = []
+ for _ in range(3):
+ recv_states.append(sv.recv())
+
+ assert sum([payload.proc_id for payload in recv_states]) == 3
+ assert all([payload.data == 1 for payload in recv_states])
+ sv.shutdown()
+
+
+class MockCrashEnv(MockEnv):
+
+ def step(self, _):
+ super().step(_)
+ if self._counter == 2:
+ raise Exception("Ohh")
+
+ return self._counter
+
+
+@pytest.mark.tmp
+@pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
+def test_crash_supervisor(type_):
+ sv = Supervisor(type_=type_)
+ for _ in range(2):
+ sv.register(lambda: MockEnv("AnyArgs"))
+ sv.register(lambda: MockCrashEnv("AnyArgs"))
+ sv.start_link()
+
+    # Send 6 messages, which will cause the third subprocess to crash
+ for env_id in range(len(sv._children)):
+ for _ in range(2):
+ sv.send(SendPayload(proc_id=env_id, method="step", args=["any action"]))
+
+    # Find the error message
+ recv_states: List[RecvPayload] = []
+ for _ in range(6):
+ recv_payload = sv.recv(ignore_err=True)
+ if recv_payload.err:
+ sv._children[recv_payload.proc_id].restart()
+ recv_states.append(recv_payload)
+ assert any([isinstance(payload.err, Exception) for payload in recv_states])
+
+ # Resume
+ for env_id in range(len(sv._children)):
+ sv.send(SendPayload(proc_id=env_id, method="step", args=["any action"]))
+ recv_states: List[RecvPayload] = []
+ for _ in range(3):
+ recv_states.append(sv.recv())
+
+ # 3 + 3 + 1
+ assert sum([p.data for p in recv_states]) == 7
+
+ with pytest.raises(Exception):
+ sv.send(SendPayload(proc_id=2, method="step", args=["any action"]))
+ sv.recv(ignore_err=False)
+
+ sv.shutdown()
+
+
+@pytest.mark.tmp
+@pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
+def test_recv_all(type_):
+ sv = Supervisor(type_=type_)
+ for _ in range(3):
+ sv.register(lambda: MockEnv("AnyArgs"))
+ sv.start_link()
+
+ # Test recv_all
+ send_payloads = []
+ for env_id in range(len(sv._children)):
+ payload = SendPayload(
+ proc_id=env_id,
+ method="step",
+ args=["any action"],
+ )
+ send_payloads.append(payload)
+ sv.send(payload)
+
+ retry_times = {env_id: 0 for env_id in range(len(sv._children))}
+
+ def recv_callback(recv_payload: RecvPayload, remain_payloads: Dict[str, SendPayload]):
+ if retry_times[recv_payload.proc_id] == 2:
+ return
+ retry_times[recv_payload.proc_id] += 1
+ payload = SendPayload(proc_id=recv_payload.proc_id, method="step", args={"action"})
+ sv.send(payload)
+ remain_payloads[payload.req_id] = payload
+
+ recv_payloads = sv.recv_all(send_payloads=send_payloads, callback=recv_callback)
+ assert len(recv_payloads) == 3
+ assert all([v == 2 for v in retry_times.values()])
+
+ sv.shutdown()
+
+
+@pytest.mark.timeout(60)
+@pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
+def test_timeout(type_):
+ sv = Supervisor(type_=type_)
+ for _ in range(3):
+ sv.register(lambda: MockEnv("AnyArgs"))
+ sv.start_link()
+
+ send_payloads = []
+ for env_id in range(len(sv._children)):
+ payload = SendPayload(proc_id=env_id, method="block")
+ send_payloads.append(payload)
+ sv.send(payload)
+
+ # Test timeout exception
+ with pytest.raises(TimeoutError):
+ sv.recv_all(send_payloads=send_payloads, timeout=1)
+ sv.shutdown(timeout=1)
+
+ # Test timeout with ignore error
+ sv.start_link()
+ send_payloads = []
+
+ # 0 is block
+ payload = SendPayload(proc_id=0, method="block")
+ send_payloads.append(payload)
+ sv.send(payload)
+
+ # 1 is step
+ payload = SendPayload(proc_id=1, method="step", args=[""])
+ send_payloads.append(payload)
+ sv.send(payload)
+
+ payloads = sv.recv_all(send_payloads=send_payloads, timeout=1, ignore_err=True)
+ assert isinstance(payloads[0].err, TimeoutError)
+ assert payloads[1].err is None
+
+ sv.shutdown(timeout=1)
+
+
+@pytest.mark.timeout(60)
+@pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
+def test_timeout_with_callback(type_):
+ sv = Supervisor(type_=type_)
+ for _ in range(3):
+ sv.register(lambda: MockEnv("AnyArgs"))
+ sv.start_link()
+ send_payloads = []
+
+ # 0 is block
+ payload = SendPayload(proc_id=0, method="block")
+ send_payloads.append(payload)
+ sv.send(payload)
+
+ # 1 is step
+ payload = SendPayload(proc_id=1, method="step", args=[""])
+ send_payloads.append(payload)
+ sv.send(payload)
+
+ block_reset_callback = False
+
+ # 1. Add another send payload in the callback
+ # 2. Recv this send payload and check for the method
+ def recv_callback(recv_payload: RecvPayload, remain_payloads: Dict[str, SendPayload]):
+ if recv_payload.method == "block" and recv_payload.err:
+ new_send_payload = SendPayload(proc_id=recv_payload.proc_id, method="block_reset")
+ remain_payloads[new_send_payload.req_id] = new_send_payload
+ return
+
+ if recv_payload.method == "block_reset" and recv_payload.err:
+ nonlocal block_reset_callback
+ block_reset_callback = True
+ return
+
+ payloads = sv.recv_all(send_payloads=send_payloads, timeout=1, ignore_err=True, callback=recv_callback)
+ assert block_reset_callback
+ assert isinstance(payloads[0].err, TimeoutError)
+ assert payloads[1].err is None
+
+ sv.shutdown(timeout=1)
+
+
+@pytest.mark.tmp # gitlab ci and local tests pass, github always fails
+def test_shared_memory():
+ sv = Supervisor(type_=ChildType.PROCESS)
+
+ def shm_callback(payload: RecvPayload, shm: Any):
+ shm[payload.proc_id] = payload.req_id
+ payload.data = 0
+
+ shm = mp.Array(ctypes.c_uint8, 3)
+ for i in range(3):
+ sv.register(lambda: MockEnv("AnyArgs"), shm_buffer=shm, shm_callback=shm_callback)
+ sv.start_link()
+
+ # Send init request
+ for env_id in range(len(sv._children)):
+ sv.send(SendPayload(proc_id=env_id, req_id=env_id, method="sleep1", args=[]))
+
+ start = time()
+ for i in range(6):
+ payload = sv.recv()
+ assert payload.data == 0
+ assert shm[payload.proc_id] == payload.req_id
+ sv.send(SendPayload(proc_id=payload.proc_id, req_id=i, method="sleep1", args=[]))
+
+ # Non blocking
+ assert time() - start < 3
+
+ sv.shutdown()
+
+
+@pytest.mark.benchmark
+@pytest.mark.parametrize("type_", [ChildType.PROCESS, ChildType.THREAD])
+def test_supervisor_benchmark(type_):
+ sv = Supervisor(type_=type_)
+ for _ in range(3):
+ sv.register(lambda: MockEnv("AnyArgs"))
+ sv.start_link()
+
+ for env_id in range(len(sv._children)):
+ sv.send(SendPayload(proc_id=env_id, method="step", args=[""]))
+
+ start = time()
+ for _ in range(1000):
+ payload = sv.recv()
+ sv.send(SendPayload(proc_id=payload.proc_id, method="step", args=[""]))
+
+ assert time() - start < 1
diff --git a/DI-engine/ding/framework/tests/test_task.py b/DI-engine/ding/framework/tests/test_task.py
new file mode 100644
index 0000000000000000000000000000000000000000..67f3dc34c77b05494881d5d388cf67dbe075117b
--- /dev/null
+++ b/DI-engine/ding/framework/tests/test_task.py
@@ -0,0 +1,382 @@
+import multiprocessing as mp
+import pytest
+from threading import Lock
+from time import sleep, time
+import random
+import dataclasses
+from ding.framework import task, Context, Parallel
+
+
+@dataclasses.dataclass
+class TestContext(Context):
+ pipeline: list = dataclasses.field(default_factory=list)
+
+
+@pytest.mark.unittest
+def test_serial_pipeline():
+
+ def step0(ctx):
+ ctx.pipeline.append(0)
+
+ def step1(ctx):
+ ctx.pipeline.append(1)
+
+ # Execute step1, step2 twice
+ with task.start(ctx=TestContext()):
+ for _ in range(2):
+ task.forward(step0)
+ task.forward(step1)
+ assert task.ctx.pipeline == [0, 1, 0, 1]
+
+ # Renew and execute step1, step2
+ task.renew()
+ assert task.ctx.total_step == 1
+ task.forward(step0)
+ task.forward(step1)
+ assert task.ctx.pipeline == [0, 1]
+
+ # Test context inheritance
+ task.renew()
+
+
+@pytest.mark.unittest
+def test_serial_yield_pipeline():
+
+ def step0(ctx):
+ ctx.pipeline.append(0)
+ yield
+ ctx.pipeline.append(0)
+
+ def step1(ctx):
+ ctx.pipeline.append(1)
+
+ with task.start(ctx=TestContext()):
+ task.forward(step0)
+ task.forward(step1)
+ task.backward()
+ assert task.ctx.pipeline == [0, 1, 0]
+ assert len(task._backward_stack) == 0
+
+
+@pytest.mark.unittest
+def test_async_pipeline():
+
+ def step0(ctx):
+ ctx.pipeline.append(0)
+
+ def step1(ctx):
+ ctx.pipeline.append(1)
+
+ # Execute step1, step2 twice
+ with task.start(async_mode=True, ctx=TestContext()):
+ for _ in range(2):
+ task.forward(step0)
+ sleep(0.1)
+ task.forward(step1)
+ sleep(0.1)
+ task.backward()
+ assert task.ctx.pipeline == [0, 1, 0, 1]
+ task.renew()
+ assert task.ctx.total_step == 1
+
+
+@pytest.mark.unittest
+def test_async_yield_pipeline():
+
+ def step0(ctx):
+ sleep(0.1)
+ ctx.pipeline.append(0)
+ yield
+ ctx.pipeline.append(0)
+
+ def step1(ctx):
+ sleep(0.2)
+ ctx.pipeline.append(1)
+
+ with task.start(async_mode=True, ctx=TestContext()):
+ task.forward(step0)
+ task.forward(step1)
+ sleep(0.3)
+ task.backward().sync()
+ assert task.ctx.pipeline == [0, 1, 0]
+ assert len(task._backward_stack) == 0
+
+
+def parallel_main():
+ sync_count = 0
+
+ def on_count():
+ nonlocal sync_count
+ sync_count += 1
+
+ def counter(task):
+
+ def _counter(ctx):
+ sleep(0.2 + random.random() / 10)
+ task.emit("count", only_remote=True)
+
+ return _counter
+
+ with task.start():
+ task.on("count", on_count)
+ task.use(counter(task))
+ task.run(max_step=10)
+ assert sync_count > 0
+
+
+@pytest.mark.tmp
+def test_parallel_pipeline():
+ Parallel.runner(n_parallel_workers=2, startup_interval=0.1)(parallel_main)
+
+
+@pytest.mark.tmp
+def test_emit():
+ with task.start():
+ greets = []
+ task.on("Greeting", lambda msg: greets.append(msg))
+
+ def step1(ctx):
+ task.emit("Greeting", "Hi")
+
+ task.use(step1)
+ task.run(max_step=10)
+ sleep(0.1)
+ assert len(greets) == 10
+
+
+def emit_remote_main():
+ with task.start():
+ greets = []
+ if task.router.node_id == 0:
+ task.on("Greeting", lambda msg: greets.append(msg))
+ for _ in range(20):
+ if greets:
+ break
+ sleep(0.1)
+ assert len(greets) > 0
+ else:
+ for _ in range(20):
+ task.emit("Greeting", "Hi", only_remote=True)
+ sleep(0.1)
+ assert len(greets) == 0
+
+
+@pytest.mark.tmp
+def test_emit_remote():
+ Parallel.runner(n_parallel_workers=2, startup_interval=0.1)(emit_remote_main)
+
+
+@pytest.mark.tmp
+def test_wait_for():
+    # wait_for will only work in async or parallel mode
+ with task.start(async_mode=True, n_async_workers=2):
+ greets = []
+
+ def step1(_):
+ hi = task.wait_for("Greeting")[0][0]
+ if hi:
+ greets.append(hi)
+
+ def step2(_):
+ task.emit("Greeting", "Hi")
+
+ task.use(step1)
+ task.use(step2)
+ task.run(max_step=10)
+
+ assert len(greets) == 10
+ assert all(map(lambda hi: hi == "Hi", greets))
+
+ # Test timeout exception
+ with task.start(async_mode=True, n_async_workers=2):
+
+ def step1(_):
+ task.wait_for("Greeting", timeout=0.3, ignore_timeout_exception=False)
+
+ task.use(step1)
+ with pytest.raises(TimeoutError):
+ task.run(max_step=1)
+
+
+@pytest.mark.tmp
+def test_async_exception():
+ with task.start(async_mode=True, n_async_workers=2):
+
+ def step1(_):
+ task.wait_for("any_event") # Never end
+
+ def step2(_):
+ sleep(0.3)
+ raise Exception("Oh")
+
+ task.use(step1)
+ task.use(step2)
+ with pytest.raises(Exception):
+ task.run(max_step=2)
+
+ assert task.ctx.total_step == 0
+
+
+def early_stop_main():
+ with task.start():
+ task.use(lambda _: sleep(0.5))
+ if task.match_labels("node.0"):
+ task.run(max_step=10)
+ else:
+ task.run(max_step=2)
+ assert task.ctx.total_step < 7
+
+
+@pytest.mark.tmp
+def test_early_stop():
+ Parallel.runner(n_parallel_workers=2, startup_interval=0.1)(early_stop_main)
+
+
+@pytest.mark.tmp
+def test_parallel_in_sequencial():
+ result = []
+
+ def fast(_):
+ result.append("fast")
+
+ def slow(_):
+ sleep(0.1)
+ result.append("slow")
+
+ with task.start():
+ task.use(lambda _: result.append("begin"))
+ task.use(task.parallel(slow, fast))
+ task.run(max_step=1)
+ assert result == ["begin", "fast", "slow"]
+
+
+@pytest.mark.tmp
+def test_serial_in_parallel():
+ result = []
+
+ def fast(_):
+ result.append("fast")
+
+ def slow(_):
+ sleep(0.1)
+ result.append("slow")
+
+ with task.start(async_mode=True):
+ task.use(lambda _: result.append("begin"))
+ task.use(task.serial(slow, fast))
+ task.run(max_step=1)
+
+ assert result == ["begin", "slow", "fast"]
+
+
+@pytest.mark.unittest
+def test_nested_middleware():
+ """
+    When there is a yield in a middleware, calling that middleware directly
+    from another one will lead to an unexpected result.
+    Using task.forward or task.wrap fixes this problem.
+ """
+ result = []
+
+ def child():
+
+ def _child(ctx: Context):
+ result.append(3 * ctx.total_step)
+ yield
+ result.append(2 + 3 * ctx.total_step)
+
+ return _child
+
+ def mother():
+ _child = task.wrap(child())
+
+ def _mother(ctx: Context):
+ child_back = _child(ctx)
+ result.append(1 + 3 * ctx.total_step)
+ child_back()
+
+ return _mother
+
+ with task.start():
+ task.use(mother())
+ task.run(2)
+ assert result == [0, 1, 2, 3, 4, 5]
+
+
+@pytest.mark.unittest
+def test_use_lock():
+
+ def slow(ctx):
+ sleep(0.1)
+ ctx.result = "slow"
+
+ def fast(ctx):
+ ctx.result = "fast"
+
+ with task.start(async_mode=True):
+ # The lock will turn async middleware into serial
+ task.use(slow, lock=True)
+ task.use(fast, lock=True)
+ task.run(1)
+ assert task.ctx.result == "fast"
+
+    # A custom lock will not affect the task's internal lock
+ lock = Lock()
+
+ def slowest(ctx):
+ sleep(0.3)
+ ctx.result = "slowest"
+
+ with task.start(async_mode=True):
+ task.use(slow, lock=lock)
+        # If slowest shared the custom lock with the others, it would not be the last one to finish execution
+ task.use(slowest, lock=True)
+ task.use(fast, lock=lock)
+ task.run(1)
+ assert task.ctx.result == "slowest"
+
+
+def broadcast_finish_main():
+ with task.start():
+
+ def tick(ctx: Context):
+ if task.router.node_id == 1 and ctx.total_step == 1:
+ task.finish = True
+ sleep(1)
+
+ task.use(tick)
+ task.run(20)
+
+
+def broadcast_main_target():
+ Parallel.runner(
+ n_parallel_workers=1, protocol="tcp", address="127.0.0.1", topology="mesh", ports=50555, startup_interval=0.1
+ )(broadcast_finish_main)
+
+
+def broadcast_secondary_target():
+ "Start two standalone processes and connect to the main process."
+ Parallel.runner(
+ n_parallel_workers=2,
+ protocol="tcp",
+ address="127.0.0.1",
+ topology="alone",
+ ports=50556,
+ attach_to=["tcp://127.0.0.1:50555"],
+ node_ids=[1, 2],
+ startup_interval=0.1
+ )(broadcast_finish_main)
+
+
+@pytest.mark.tmp # gitlab ci and local tests pass, github always fails
+@pytest.mark.timeout(10)
+def test_broadcast_finish():
+ start = time()
+ ctx = mp.get_context("spawn")
+ main_process = ctx.Process(target=broadcast_main_target)
+ secondary_process = ctx.Process(target=broadcast_secondary_target)
+ main_process.start()
+ secondary_process.start()
+ main_process.join()
+ secondary_process.join()
+ assert (time() - start) < 10
diff --git a/DI-engine/ding/framework/tests/test_wrapper.py b/DI-engine/ding/framework/tests/test_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a1e0c453d0bf564234cc8fe2833a9629a903aef
--- /dev/null
+++ b/DI-engine/ding/framework/tests/test_wrapper.py
@@ -0,0 +1,52 @@
+# In use mode
+# In forward mode
+# Wrapper in wrapper
+
+import pytest
+from ding.framework import task
+from ding.framework.wrapper import StepTimer
+
+
+@pytest.mark.unittest
+def test_step_timer():
+
+ def step1(_):
+ 1
+
+ def step2(_):
+ 2
+
+ def step3(_):
+ 3
+
+ def step4(_):
+ 4
+
+ step_timer = StepTimer()
+ with task.start(async_mode=True):
+ task.use_wrapper(step_timer)
+ task.use(step1)
+ task.use(step2)
+ task.use(task.serial(step3, step4))
+ assert len(task._middleware) == 3
+ task.run(3)
+
+ assert len(step_timer.records) == 3
+ for records in step_timer.records.values():
+ assert len(records) == 3
+
+ # Wrapper in wrapper
+ step_timer1 = StepTimer()
+ step_timer2 = StepTimer()
+ with task.start():
+ task.use_wrapper(step_timer1)
+ task.use_wrapper(step_timer2)
+ task.use(step1)
+ task.use(step2)
+ assert len(task._middleware) == 2
+ task.run(3)
+
+ for records in step_timer1.records.values():
+ assert len(records) == 3
+ for records in step_timer2.records.values():
+ assert len(records) == 3
diff --git a/DI-engine/ding/framework/wrapper/__init__.py b/DI-engine/ding/framework/wrapper/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4883440435c131c4e7bf65eab526e1cf36edc49b
--- /dev/null
+++ b/DI-engine/ding/framework/wrapper/__init__.py
@@ -0,0 +1 @@
+from .step_timer import StepTimer
diff --git a/DI-engine/ding/framework/wrapper/step_timer.py b/DI-engine/ding/framework/wrapper/step_timer.py
new file mode 100644
index 0000000000000000000000000000000000000000..dfabdd1476099299704fe982a26f75d04c296eea
--- /dev/null
+++ b/DI-engine/ding/framework/wrapper/step_timer.py
@@ -0,0 +1,57 @@
+from collections import deque, defaultdict
+from functools import wraps
+from types import GeneratorType
+from typing import Callable
+import numpy as np
+import time
+from ditk import logging
+from ding.framework import task
+
+
+class StepTimer:
+
+ def __init__(self, print_per_step: int = 1, smooth_window: int = 10) -> None:
+ """
+ Overview:
+            Print the time cost of each step (executing one middleware).
+        Arguments:
+            - print_per_step (:obj:`int`): Print every N steps.
+ - smooth_window (:obj:`int`): The window size to smooth the mean.
+ """
+
+ self.print_per_step = print_per_step
+ self.records = defaultdict(lambda: deque(maxlen=print_per_step * smooth_window))
+
+ def __call__(self, fn: Callable) -> Callable:
+ step_name = getattr(fn, "__name__", type(fn).__name__)
+
+ @wraps(fn)
+ def executor(ctx):
+ start_time = time.time()
+ time_cost = 0
+ g = fn(ctx)
+ if isinstance(g, GeneratorType):
+ try:
+ next(g)
+ except StopIteration:
+ pass
+ time_cost = time.time() - start_time
+ yield
+ start_time = time.time()
+ try:
+ next(g)
+ except StopIteration:
+ pass
+ time_cost += time.time() - start_time
+ else:
+ time_cost = time.time() - start_time
+ self.records[step_name].append(time_cost)
+ if ctx.total_step % self.print_per_step == 0:
+ logging.info(
+ "[Step Timer][Node:{:>2}] {}: Cost: {:.2f}ms, Mean: {:.2f}ms".format(
+ task.router.node_id or 0, step_name, time_cost * 1000,
+ np.mean(self.records[step_name]) * 1000
+ )
+ )
+
+ return executor
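+
+    # Usage sketch (illustrative only, mirroring test_step_timer in this diff):
+    # register StepTimer as a task wrapper so every middleware gets timed.
+    # `my_step` is a hypothetical middleware.
+    #
+    #     step_timer = StepTimer(print_per_step=1)
+    #     with task.start():
+    #         task.use_wrapper(step_timer)
+    #         task.use(my_step)
+    #         task.run(3)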
diff --git a/DI-engine/ding/hpc_rl/README.md b/DI-engine/ding/hpc_rl/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..92de0090bb92ea3c82e75f92c1ecb3df3774beff
--- /dev/null
+++ b/DI-engine/ding/hpc_rl/README.md
@@ -0,0 +1,12 @@
+Step 0. Clean up the old version:
+rm ~/.local/lib/python3.6/site-packages/hpc_*.so
+rm ~/.local/lib/python3.6/site-packages/hpc_rl* -rf
+rm ~/.local/lib/python3.6/site-packages/di_hpc_rl* -rf
+
+Step 1. Install the wheel and check the installed packages:
+pip install di_hpc_rll-0.0.1-cp36-cp36m-linux_x86_64.whl --user
+ls ~/.local/lib/python3.6/site-packages/di_hpc_rl*
+ls ~/.local/lib/python3.6/site-packages/hpc_rl*
+
+Step 2. Run a test to verify the installation:
+python3 tests/test_gae.py
\ No newline at end of file
diff --git a/DI-engine/ding/hpc_rl/__init__.py b/DI-engine/ding/hpc_rl/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d359e89e067c990115314bcbe8dcbf3f4bdf5af
--- /dev/null
+++ b/DI-engine/ding/hpc_rl/__init__.py
@@ -0,0 +1 @@
+from .wrapper import hpc_wrapper
diff --git a/DI-engine/ding/hpc_rl/tests/test_dntd.py b/DI-engine/ding/hpc_rl/tests/test_dntd.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbfa95070e670f5d3b17b7041496f8da659c3818
--- /dev/null
+++ b/DI-engine/ding/hpc_rl/tests/test_dntd.py
@@ -0,0 +1,160 @@
+import time
+import torch
+from hpc_rll.origin.td import dist_nstep_td_error, dist_nstep_td_data
+from hpc_rll.rl_utils.td import DistNStepTD
+from testbase import mean_relative_error, times
+
+assert torch.cuda.is_available()
+use_cuda = True
+
+T = 128
+B = 128
+N = 128
+gamma = 0.95
+v_min = -10.0
+v_max = 10.0
+n_atom = 51
+
+
+def dntd_val():
+ ori_dist = torch.randn(B, N, n_atom).abs()
+ ori_next_n_dist = torch.randn(B, N, n_atom).abs()
+ ori_action = torch.randint(0, N, size=(B, ))
+ ori_next_n_action = torch.randint(0, N, size=(B, ))
+ ori_reward = torch.randn(T, B)
+ ori_done = torch.randn(B)
+ ori_weight = torch.randn(B)
+
+ hpc_dist = ori_dist.clone().detach()
+ hpc_next_n_dist = ori_next_n_dist.clone().detach()
+ hpc_action = ori_action.clone().detach()
+ hpc_next_n_action = ori_next_n_action.clone().detach()
+ hpc_reward = ori_reward.clone().detach()
+ hpc_done = ori_done.clone().detach()
+ hpc_weight = ori_weight.clone().detach()
+ hpc_dntd = DistNStepTD(T, B, N, n_atom)
+
+ if use_cuda:
+ ori_dist = ori_dist.cuda()
+ ori_next_n_dist = ori_next_n_dist.cuda()
+ ori_action = ori_action.cuda()
+ ori_next_n_action = ori_next_n_action.cuda()
+ ori_reward = ori_reward.cuda()
+ ori_done = ori_done.cuda()
+ ori_weight = ori_weight.cuda()
+
+ hpc_dist = hpc_dist.cuda()
+ hpc_next_n_dist = hpc_next_n_dist.cuda()
+ hpc_action = hpc_action.cuda()
+ hpc_next_n_action = hpc_next_n_action.cuda()
+ hpc_reward = hpc_reward.cuda()
+ hpc_done = hpc_done.cuda()
+ hpc_weight = hpc_weight.cuda()
+ hpc_dntd = hpc_dntd.cuda()
+
+ ori_dist.requires_grad_(True)
+ ori_loss, ori_td_err = dist_nstep_td_error(
+ dist_nstep_td_data(ori_dist, ori_next_n_dist, ori_action, ori_next_n_action, ori_reward, ori_done, ori_weight),
+ gamma, v_min, v_max, n_atom, T
+ )
+ ori_loss = ori_loss.mean()
+ ori_loss.backward()
+
+ hpc_dist.requires_grad_(True)
+ hpc_loss, hpc_td_err = hpc_dntd(
+ hpc_dist, hpc_next_n_dist, hpc_action, hpc_next_n_action, hpc_reward, hpc_done, hpc_weight, gamma, v_min, v_max
+ )
+ hpc_loss = hpc_loss.mean()
+ hpc_loss.backward()
+
+ mre = mean_relative_error(
+ torch.flatten(ori_loss).cpu().detach().numpy(),
+ torch.flatten(hpc_loss).cpu().detach().numpy()
+ )
+ print("dntd fp mean_relative_error: " + str(mre))
+ mre = mean_relative_error(
+ torch.flatten(ori_td_err).cpu().detach().numpy(),
+ torch.flatten(hpc_td_err).cpu().detach().numpy()
+ )
+ print("dntd fp td_err mean_relative_error: " + str(mre))
+ mre = mean_relative_error(
+ torch.flatten(ori_dist.grad).cpu().detach().numpy(),
+ torch.flatten(hpc_dist.grad).cpu().detach().numpy()
+ )
+ print("dntd bp mean_relative_error: " + str(mre))
+
+
+def dntd_perf():
+ ori_dist = torch.randn(B, N, n_atom).abs()
+ ori_next_n_dist = torch.randn(B, N, n_atom).abs()
+ ori_action = torch.randint(0, N, size=(B, ))
+ ori_next_n_action = torch.randint(0, N, size=(B, ))
+ ori_reward = torch.randn(T, B)
+ ori_done = torch.randn(B)
+ ori_weight = torch.randn(B)
+
+ hpc_dist = ori_dist.clone().detach()
+ hpc_next_n_dist = ori_next_n_dist.clone().detach()
+ hpc_action = ori_action.clone().detach()
+ hpc_next_n_action = ori_next_n_action.clone().detach()
+ hpc_reward = ori_reward.clone().detach()
+ hpc_done = ori_done.clone().detach()
+ hpc_weight = ori_weight.clone().detach()
+ hpc_dntd = DistNStepTD(T, B, N, n_atom)
+
+ if use_cuda:
+ ori_dist = ori_dist.cuda()
+ ori_next_n_dist = ori_next_n_dist.cuda()
+ ori_action = ori_action.cuda()
+ ori_next_n_action = ori_next_n_action.cuda()
+ ori_reward = ori_reward.cuda()
+ ori_done = ori_done.cuda()
+ ori_weight = ori_weight.cuda()
+
+ hpc_dist = hpc_dist.cuda()
+ hpc_next_n_dist = hpc_next_n_dist.cuda()
+ hpc_action = hpc_action.cuda()
+ hpc_next_n_action = hpc_next_n_action.cuda()
+ hpc_reward = hpc_reward.cuda()
+ hpc_done = hpc_done.cuda()
+ hpc_weight = hpc_weight.cuda()
+ hpc_dntd = hpc_dntd.cuda()
+
+ ori_dist.requires_grad_(True)
+ for i in range(times):
+ t = time.time()
+ ori_loss, ori_td_err = dist_nstep_td_error(
+ dist_nstep_td_data(
+ ori_dist, ori_next_n_dist, ori_action, ori_next_n_action, ori_reward, ori_done, ori_weight
+ ), gamma, v_min, v_max, n_atom, T
+ )
+ ori_loss = ori_loss.mean()
+ ori_loss.backward()
+ if use_cuda:
+ torch.cuda.synchronize()
+ print('epoch: {}, origin dntd cost time: {}'.format(i, time.time() - t))
+
+ hpc_dist.requires_grad_(True)
+ for i in range(times):
+ t = time.time()
+ hpc_loss, hpc_td_err = hpc_dntd(
+ hpc_dist, hpc_next_n_dist, hpc_action, hpc_next_n_action, hpc_reward, hpc_done, hpc_weight, gamma, v_min,
+ v_max
+ )
+ hpc_loss = hpc_loss.mean()
+ hpc_loss.backward()
+ if use_cuda:
+ torch.cuda.synchronize()
+ print('epoch: {}, hpc dntd cost time: {}'.format(i, time.time() - t))
+
+
+if __name__ == '__main__':
+ print(
+ "target problem: T = {}, B = {}, N = {}, gamma = {}, v_min = {}, v_max = {}, n_atom = {}".format(
+ T, B, N, gamma, v_min, v_max, n_atom
+ )
+ )
+ print("================run dntd validation test================")
+ dntd_val()
+ print("================run dntd performance test================")
+ dntd_perf()
diff --git a/DI-engine/ding/hpc_rl/tests/test_gae.py b/DI-engine/ding/hpc_rl/tests/test_gae.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bc6f351825537e46bde5c200d15b0263b9e260c
--- /dev/null
+++ b/DI-engine/ding/hpc_rl/tests/test_gae.py
@@ -0,0 +1,65 @@
+import time
+import torch
+from hpc_rll.origin.gae import gae, gae_data
+from hpc_rll.rl_utils.gae import GAE
+from testbase import mean_relative_error, times
+
+assert torch.cuda.is_available()
+use_cuda = True
+
+T = 1024
+B = 64
+
+
+def gae_val():
+ value = torch.randn(T + 1, B)
+ reward = torch.randn(T, B)
+
+ hpc_gae = GAE(T, B)
+
+ if use_cuda:
+ value = value.cuda()
+ reward = reward.cuda()
+ hpc_gae = hpc_gae.cuda()
+ ori_adv = gae(gae_data(value, reward))
+ hpc_adv = hpc_gae(value, reward)
+ if use_cuda:
+ torch.cuda.synchronize()
+
+ mre = mean_relative_error(
+ torch.flatten(ori_adv).cpu().detach().numpy(),
+ torch.flatten(hpc_adv).cpu().detach().numpy()
+ )
+ print("gae mean_relative_error: " + str(mre))
+
+
+def gae_perf():
+ value = torch.randn(T + 1, B)
+ reward = torch.randn(T, B)
+
+ hpc_gae = GAE(T, B)
+
+ if use_cuda:
+ value = value.cuda()
+ reward = reward.cuda()
+ hpc_gae = hpc_gae.cuda()
+ for i in range(times):
+ t = time.time()
+ adv = gae(gae_data(value, reward))
+ if use_cuda:
+ torch.cuda.synchronize()
+ print('epoch: {}, original gae cost time: {}'.format(i, time.time() - t))
+ for i in range(times):
+ t = time.time()
+ hpc_adv = hpc_gae(value, reward)
+ if use_cuda:
+ torch.cuda.synchronize()
+ print('epoch: {}, hpc gae cost time: {}'.format(i, time.time() - t))
+
+
+if __name__ == '__main__':
+ print("target problem: T = {}, B = {}".format(T, B))
+ print("================run gae validation test================")
+ gae_val()
+ print("================run gae performance test================")
+ gae_perf()
diff --git a/DI-engine/ding/hpc_rl/tests/test_lstm.py b/DI-engine/ding/hpc_rl/tests/test_lstm.py
new file mode 100644
index 0000000000000000000000000000000000000000..f752abd3e7e97e5788172f1002cf0feae4a6c745
--- /dev/null
+++ b/DI-engine/ding/hpc_rl/tests/test_lstm.py
@@ -0,0 +1,140 @@
+import time
+import torch
+from hpc_rll.origin.rnn import get_lstm
+from hpc_rll.torch_utils.network.rnn import LSTM
+from testbase import mean_relative_error, times
+
+assert torch.cuda.is_available()
+use_cuda = True
+
+seq_len = 64
+batch_size = 3
+input_size = 1792
+hidden_size = 384
+num_layers = 3
+norm_type = 'LN'
+dropout = 0 # 0.1
+
+
+# Note: load_params needs to be enabled for hpc_lstm validation
+# Note: only applies to the case of num_layers = 3
+def lstm_val():
+ ori_lstm = get_lstm('normal', input_size, hidden_size, num_layers, norm_type, dropout)
+ hpc_lstm = LSTM(seq_len, batch_size, input_size, hidden_size, num_layers, norm_type, dropout)
+
+ ori_x = torch.randn(seq_len, batch_size, input_size)
+ ori_h0 = torch.randn(num_layers, batch_size, hidden_size)
+ ori_c0 = torch.randn(num_layers, batch_size, hidden_size)
+
+ if use_cuda:
+ ori_x = ori_x.cuda()
+ ori_h0 = ori_h0.cuda()
+ ori_c0 = ori_c0.cuda()
+ ori_lstm = ori_lstm.cuda()
+ hpc_lstm = hpc_lstm.cuda()
+
+ ori_x.requires_grad_(True)
+ ori_output, ori_next_state = ori_lstm(ori_x, [ori_h0, ori_c0])
+ ori_loss = ori_output.mean()
+ ori_loss.backward()
+
+ hpc_x = ori_x.clone().detach()
+ hpc_h0 = ori_h0.clone().detach()
+ hpc_c0 = ori_c0.clone().detach()
+ hpc_x.requires_grad_(True)
+ hpc_output, hpc_next_state = hpc_lstm(hpc_x, [hpc_h0, hpc_c0])
+ hpc_loss = hpc_output.mean()
+ hpc_loss.backward()
+ torch.cuda.synchronize()
+
+ mre = mean_relative_error(
+ torch.flatten(ori_loss).cpu().detach().numpy(),
+ torch.flatten(hpc_loss).cpu().detach().numpy()
+ )
+ print("lstm fp mean_relative_error: " + str(mre))
+ mre = mean_relative_error(
+ torch.flatten(ori_x.grad).cpu().detach().numpy(),
+ torch.flatten(hpc_x.grad).cpu().detach().numpy()
+ )
+ print("lstm bp mean_relative_error: " + str(mre))
+
+ ori_wx_grad = torch.cat((ori_lstm.wx[0].grad, ori_lstm.wx[1].grad, ori_lstm.wx[2].grad))
+ hpc_wx_grad = hpc_lstm.wx.grad
+ mre = mean_relative_error(torch.flatten(ori_wx_grad).cpu().numpy(), torch.flatten(hpc_wx_grad).cpu().numpy())
+ print("wx grad mean_relative_error: " + str(mre))
+
+ ori_wh_grad = torch.cat((ori_lstm.wh[0].grad, ori_lstm.wh[1].grad, ori_lstm.wh[2].grad))
+ hpc_wh_grad = hpc_lstm.wh.grad
+ mre = mean_relative_error(torch.flatten(ori_wh_grad).cpu().numpy(), torch.flatten(hpc_wh_grad).cpu().numpy())
+ print("wh grad mean_relative_error: " + str(mre))
+
+ ori_bias_grad = ori_lstm.bias.grad
+ hpc_bias_grad = hpc_lstm.bias.grad
+ mre = mean_relative_error(torch.flatten(ori_bias_grad).cpu().numpy(), torch.flatten(hpc_bias_grad).cpu().numpy())
+ print("bias grad mean_relative_error: " + str(mre))
+
+ params = list(ori_lstm.parameters())
+ gamma_0_x = params[1]
+ beta_0_x = params[2]
+ gamma_0_h = params[3]
+ beta_0_h = params[4]
+ gamma_1_x = params[5]
+ beta_1_x = params[6]
+ gamma_1_h = params[7]
+ beta_1_h = params[8]
+ gamma_2_x = params[9]
+ beta_2_x = params[10]
+ gamma_2_h = params[11]
+ beta_2_h = params[12]
+ ori_gamma_grad = torch.cat(
+ (gamma_0_x.grad, gamma_0_h.grad, gamma_1_x.grad, gamma_1_h.grad, gamma_2_x.grad, gamma_2_h.grad)
+ )
+ ori_beta_grad = torch.cat(
+ (beta_0_x.grad, beta_0_h.grad, beta_1_x.grad, beta_1_h.grad, beta_2_x.grad, beta_2_h.grad)
+ )
+ hpc_gamma_grad = hpc_lstm.ln_gamma.grad
+ hpc_beta_grad = hpc_lstm.ln_beta.grad
+ mre = mean_relative_error(torch.flatten(ori_gamma_grad).cpu().numpy(), torch.flatten(hpc_gamma_grad).cpu().numpy())
+ print("ln gamma grad mean_relative_error: " + str(mre))
+ mre = mean_relative_error(torch.flatten(ori_beta_grad).cpu().numpy(), torch.flatten(hpc_beta_grad).cpu().numpy())
+ print("ln beta grad mean_relative_error: " + str(mre))
+
+
+def lstm_perf():
+ ori_lstm = get_lstm('normal', input_size, hidden_size, num_layers, norm_type, dropout)
+ hpc_lstm = LSTM(seq_len, batch_size, input_size, hidden_size, num_layers, norm_type, dropout)
+
+ lstms = {'normal': ori_lstm, 'hpc': hpc_lstm}
+
+ for lstm_type, lstm in lstms.items():
+ x = torch.rand(seq_len, batch_size, input_size)
+ h0 = torch.randn(num_layers, batch_size, hidden_size)
+ c0 = torch.randn(num_layers, batch_size, hidden_size)
+ if use_cuda:
+ x = x.cuda()
+ h0 = h0.cuda()
+ c0 = c0.cuda()
+ lstm = lstm.cuda()
+
+ prev_state = [h0, c0]
+ x.requires_grad_(True)
+ for i in range(times):
+ t = time.time()
+ output, _ = lstm(x, prev_state)
+ loss = output.mean()
+ loss.backward()
+ if use_cuda:
+ torch.cuda.synchronize()
+ print('epoch: {}, {} lstm cost time: {}'.format(i, lstm_type, time.time() - t))
+
+
+if __name__ == '__main__':
+ print(
+ "target problem: seq_len = {}, batch_size = {}, input_size = {}, hidden_size = {}, num_layers = {}, norm_type = {}, dropout = {}" # noqa
+ .format(seq_len, batch_size, input_size, hidden_size, num_layers, norm_type, dropout)
+ )
+ print("==============lstm has no validation test================")
+ #print("===============run lstm validation test==================")
+ #lstm_val()
+ print("===============run lstm performance test=================")
+ lstm_perf()
diff --git a/DI-engine/ding/hpc_rl/tests/test_ppo.py b/DI-engine/ding/hpc_rl/tests/test_ppo.py
new file mode 100644
index 0000000000000000000000000000000000000000..a1c5557ce12277566390464237f81baa463bd861
--- /dev/null
+++ b/DI-engine/ding/hpc_rl/tests/test_ppo.py
@@ -0,0 +1,176 @@
+import time
+import torch
+import torch.nn.functional as F
+from hpc_rll.origin.ppo import ppo_error, ppo_data
+from hpc_rll.rl_utils.ppo import PPO
+from testbase import mean_relative_error, times
+
+assert torch.cuda.is_available()
+use_cuda = True
+
+B = 128
+N = 128
+clip_ratio = 0.2
+use_value_clip = True
+dual_clip = None
+
+
+def ppo_val():
+ ori_logits_new = torch.randn(B, N)
+ ori_logits_old = torch.randn(B, N)
+ ori_action = torch.randint(0, N, size=(B, ))
+ ori_value_new = torch.randn(B)
+ ori_value_old = torch.randn(B)
+ ori_adv = torch.randn(B)
+ ori_return = torch.randn(B)
+ ori_weight = torch.randn(B)
+
+ hpc_logits_new = ori_logits_new.clone().detach()
+ hpc_logits_old = ori_logits_old.clone().detach()
+ hpc_action = ori_action.clone().detach()
+ hpc_value_new = ori_value_new.clone().detach()
+ hpc_value_old = ori_value_old.clone().detach()
+ hpc_adv = ori_adv.clone().detach()
+ hpc_return = ori_return.clone().detach()
+ hpc_weight = ori_weight.clone().detach()
+ hpc_ppo = PPO(B, N)
+
+ if use_cuda:
+ ori_logits_new = ori_logits_new.cuda()
+ ori_logits_old = ori_logits_old.cuda()
+ ori_action = ori_action.cuda()
+ ori_value_new = ori_value_new.cuda()
+ ori_value_old = ori_value_old.cuda()
+ ori_adv = ori_adv.cuda()
+ ori_return = ori_return.cuda()
+ ori_weight = ori_weight.cuda()
+
+ hpc_logits_new = hpc_logits_new.cuda()
+ hpc_logits_old = hpc_logits_old.cuda()
+ hpc_action = hpc_action.cuda()
+ hpc_value_new = hpc_value_new.cuda()
+ hpc_value_old = hpc_value_old.cuda()
+ hpc_adv = hpc_adv.cuda()
+ hpc_return = hpc_return.cuda()
+ hpc_weight = hpc_weight.cuda()
+ hpc_ppo = hpc_ppo.cuda()
+
+ ori_logits_new.requires_grad_(True)
+ ori_value_new.requires_grad_(True)
+ ori_loss, ori_info = ppo_error(
+ ppo_data(
+ ori_logits_new, ori_logits_old, ori_action, ori_value_new, ori_value_old, ori_adv, ori_return, ori_weight
+ ), clip_ratio, use_value_clip, dual_clip
+ )
+ ori_loss = sum(ori_loss)
+ ori_loss.backward()
+
+ hpc_logits_new.requires_grad_(True)
+ hpc_value_new.requires_grad_(True)
+ hpc_loss, hpc_info = hpc_ppo(
+ hpc_logits_new, hpc_logits_old, hpc_action, hpc_value_new, hpc_value_old, hpc_adv, hpc_return, hpc_weight,
+ clip_ratio, use_value_clip, dual_clip
+ )
+ hpc_loss = sum(hpc_loss)
+ hpc_loss.backward()
+
+ print("ori_info: " + str(ori_info))
+ print("hpc_info: " + str(hpc_info))
+ mre = mean_relative_error(
+ torch.flatten(ori_loss).cpu().detach().numpy(),
+ torch.flatten(hpc_loss).cpu().detach().numpy()
+ )
+ print("ppo fp loss mean_relative_error: " + str(mre))
+ mre = mean_relative_error(
+ torch.flatten(ori_logits_new.grad).cpu().detach().numpy(),
+ torch.flatten(hpc_logits_new.grad).cpu().detach().numpy()
+ )
+ print("ppo bp logits_new mean_relative_error: " + str(mre))
+ mre = mean_relative_error(
+ torch.flatten(ori_value_new.grad).cpu().detach().numpy(),
+ torch.flatten(hpc_value_new.grad).cpu().detach().numpy()
+ )
+ print("ppo bp value_new mean_relative_error: " + str(mre))
+
+
+def ppo_perf():
+ ori_logits_new = torch.randn(B, N)
+ ori_logits_old = torch.randn(B, N)
+ ori_action = torch.randint(0, N, size=(B, ))
+ ori_value_new = torch.randn(B)
+ ori_value_old = torch.randn(B)
+ ori_adv = torch.randn(B)
+ ori_return = torch.randn(B)
+ ori_weight = torch.randn(B)
+
+ hpc_logits_new = ori_logits_new.clone().detach()
+ hpc_logits_old = ori_logits_old.clone().detach()
+ hpc_action = ori_action.clone().detach()
+ hpc_value_new = ori_value_new.clone().detach()
+ hpc_value_old = ori_value_old.clone().detach()
+ hpc_adv = ori_adv.clone().detach()
+ hpc_return = ori_return.clone().detach()
+ hpc_weight = ori_weight.clone().detach()
+ hpc_ppo = PPO(B, N)
+
+ if use_cuda:
+ ori_logits_new = ori_logits_new.cuda()
+ ori_logits_old = ori_logits_old.cuda()
+ ori_action = ori_action.cuda()
+ ori_value_new = ori_value_new.cuda()
+ ori_value_old = ori_value_old.cuda()
+ ori_adv = ori_adv.cuda()
+ ori_return = ori_return.cuda()
+ ori_weight = ori_weight.cuda()
+
+ hpc_logits_new = hpc_logits_new.cuda()
+ hpc_logits_old = hpc_logits_old.cuda()
+ hpc_action = hpc_action.cuda()
+ hpc_value_new = hpc_value_new.cuda()
+ hpc_value_old = hpc_value_old.cuda()
+ hpc_adv = hpc_adv.cuda()
+ hpc_return = hpc_return.cuda()
+ hpc_weight = hpc_weight.cuda()
+ hpc_ppo = hpc_ppo.cuda()
+
+ ori_logits_new.requires_grad_(True)
+ ori_value_new.requires_grad_(True)
+ for i in range(times):
+ t = time.time()
+ ori_loss, ori_info = ppo_error(
+ ppo_data(
+ ori_logits_new, ori_logits_old, ori_action, ori_value_new, ori_value_old, ori_adv, ori_return,
+ ori_weight
+ ), clip_ratio, use_value_clip, dual_clip
+ )
+ ori_loss = sum(ori_loss)
+ ori_loss.backward()
+ if use_cuda:
+ torch.cuda.synchronize()
+ print('epoch: {}, origin ppo cost time: {}'.format(i, time.time() - t))
+
+ hpc_logits_new.requires_grad_(True)
+ hpc_value_new.requires_grad_(True)
+ for i in range(times):
+ t = time.time()
+ hpc_loss, hpc_info = hpc_ppo(
+ hpc_logits_new, hpc_logits_old, hpc_action, hpc_value_new, hpc_value_old, hpc_adv, hpc_return, hpc_weight,
+ clip_ratio, use_value_clip, dual_clip
+ )
+ hpc_loss = sum(hpc_loss)
+ hpc_loss.backward()
+ if use_cuda:
+ torch.cuda.synchronize()
+ print('epoch: {}, hpc ppo cost time: {}'.format(i, time.time() - t))
+
+
+if __name__ == '__main__':
+ print(
+ "target problem: B = {}, N = {}, clip_ratio = {}, use_value_clip = {}, dual_clip = {}".format(
+ B, N, clip_ratio, use_value_clip, dual_clip
+ )
+ )
+ print("================run ppo validation test================")
+ ppo_val()
+ print("================run ppo performance test================")
+ ppo_perf()
diff --git a/DI-engine/ding/hpc_rl/tests/test_qntd.py b/DI-engine/ding/hpc_rl/tests/test_qntd.py
new file mode 100644
index 0000000000000000000000000000000000000000..6943c7ff594db86af078bfa9ae805c12c793868c
--- /dev/null
+++ b/DI-engine/ding/hpc_rl/tests/test_qntd.py
@@ -0,0 +1,147 @@
+import time
+import torch
+from hpc_rll.origin.td import q_nstep_td_error, q_nstep_td_data
+from hpc_rll.rl_utils.td import QNStepTD
+from testbase import mean_relative_error, times
+
+assert torch.cuda.is_available()
+use_cuda = True
+
+T = 1024
+B = 64
+N = 64
+gamma = 0.95
+
+
+def qntd_val():
+ ori_q = torch.randn(B, N)
+ ori_next_n_q = torch.randn(B, N)
+ ori_action = torch.randint(0, N, size=(B, ))
+ ori_next_n_action = torch.randint(0, N, size=(B, ))
+ ori_reward = torch.randn(T, B)
+ ori_done = torch.randn(B)
+ ori_weight = torch.randn(B)
+
+ hpc_q = ori_q.clone().detach()
+ hpc_next_n_q = ori_next_n_q.clone().detach()
+ hpc_action = ori_action.clone().detach()
+ hpc_next_n_action = ori_next_n_action.clone().detach()
+ hpc_reward = ori_reward.clone().detach()
+ hpc_done = ori_done.clone().detach()
+ hpc_weight = ori_weight.clone().detach()
+ hpc_qntd = QNStepTD(T, B, N)
+
+ if use_cuda:
+ ori_q = ori_q.cuda()
+ ori_next_n_q = ori_next_n_q.cuda()
+ ori_action = ori_action.cuda()
+ ori_next_n_action = ori_next_n_action.cuda()
+ ori_reward = ori_reward.cuda()
+ ori_done = ori_done.cuda()
+ ori_weight = ori_weight.cuda()
+
+ hpc_q = hpc_q.cuda()
+ hpc_next_n_q = hpc_next_n_q.cuda()
+ hpc_action = hpc_action.cuda()
+ hpc_next_n_action = hpc_next_n_action.cuda()
+ hpc_reward = hpc_reward.cuda()
+ hpc_done = hpc_done.cuda()
+ hpc_weight = hpc_weight.cuda()
+ hpc_qntd = hpc_qntd.cuda()
+
+ ori_q.requires_grad_(True)
+ ori_loss, _ = q_nstep_td_error(
+ q_nstep_td_data(ori_q, ori_next_n_q, ori_action, ori_next_n_action, ori_reward, ori_done, ori_weight), gamma, T
+ )
+ ori_loss = ori_loss.mean()
+ ori_loss.backward()
+ if use_cuda:
+ torch.cuda.synchronize()
+
+ hpc_q.requires_grad_(True)
+ hpc_loss, _ = hpc_qntd(hpc_q, hpc_next_n_q, hpc_action, hpc_next_n_action, hpc_reward, hpc_done, hpc_weight, gamma)
+ hpc_loss = hpc_loss.mean()
+ hpc_loss.backward()
+ if use_cuda:
+ torch.cuda.synchronize()
+
+ mre = mean_relative_error(
+ torch.flatten(ori_loss).cpu().detach().numpy(),
+ torch.flatten(hpc_loss).cpu().detach().numpy()
+ )
+ print("qntd fp mean_relative_error: " + str(mre))
+ mre = mean_relative_error(
+ torch.flatten(ori_q.grad).cpu().detach().numpy(),
+ torch.flatten(hpc_q.grad).cpu().detach().numpy()
+ )
+ print("qntd bp mean_relative_error: " + str(mre))
+
+
+def qntd_perf():
+ ori_q = torch.randn(B, N)
+ ori_next_n_q = torch.randn(B, N)
+ ori_action = torch.randint(0, N, size=(B, ))
+ ori_next_n_action = torch.randint(0, N, size=(B, ))
+ ori_reward = torch.randn(T, B)
+ ori_done = torch.randn(B)
+ ori_weight = torch.randn(B)
+
+ hpc_q = ori_q.clone().detach()
+ hpc_next_n_q = ori_next_n_q.clone().detach()
+ hpc_action = ori_action.clone().detach()
+ hpc_next_n_action = ori_next_n_action.clone().detach()
+ hpc_reward = ori_reward.clone().detach()
+ hpc_done = ori_done.clone().detach()
+ hpc_weight = ori_weight.clone().detach()
+ hpc_qntd = QNStepTD(T, B, N)
+
+ if use_cuda:
+ ori_q = ori_q.cuda()
+ ori_next_n_q = ori_next_n_q.cuda()
+ ori_action = ori_action.cuda()
+ ori_next_n_action = ori_next_n_action.cuda()
+ ori_reward = ori_reward.cuda()
+ ori_done = ori_done.cuda()
+ ori_weight = ori_weight.cuda()
+
+ hpc_q = hpc_q.cuda()
+ hpc_next_n_q = hpc_next_n_q.cuda()
+ hpc_action = hpc_action.cuda()
+ hpc_next_n_action = hpc_next_n_action.cuda()
+ hpc_reward = hpc_reward.cuda()
+ hpc_done = hpc_done.cuda()
+ hpc_weight = hpc_weight.cuda()
+ hpc_qntd = hpc_qntd.cuda()
+
+ ori_q.requires_grad_(True)
+ for i in range(times):
+ t = time.time()
+ ori_loss, _ = q_nstep_td_error(
+ q_nstep_td_data(ori_q, ori_next_n_q, ori_action, ori_next_n_action, ori_reward, ori_done, ori_weight),
+ gamma, T
+ )
+ ori_loss = ori_loss.mean()
+ ori_loss.backward()
+ if use_cuda:
+ torch.cuda.synchronize()
+ print('epoch: {}, original qntd cost time: {}'.format(i, time.time() - t))
+
+ hpc_q.requires_grad_(True)
+ for i in range(times):
+ t = time.time()
+ hpc_loss, _ = hpc_qntd(
+ hpc_q, hpc_next_n_q, hpc_action, hpc_next_n_action, hpc_reward, hpc_done, hpc_weight, gamma
+ )
+ hpc_loss = hpc_loss.mean()
+ hpc_loss.backward()
+ if use_cuda:
+ torch.cuda.synchronize()
+ print('epoch: {}, hpc qntd cost time: {}'.format(i, time.time() - t))
+
+
+if __name__ == '__main__':
+ print("target problem: T = {}, B = {}, N = {}, gamma = {}".format(T, B, N, gamma))
+ print("================run qntd validation test================")
+ qntd_val()
+ print("================run qntd performance test================")
+ qntd_perf()
diff --git a/DI-engine/ding/hpc_rl/tests/test_qntd_rescale.py b/DI-engine/ding/hpc_rl/tests/test_qntd_rescale.py
new file mode 100644
index 0000000000000000000000000000000000000000..076281113f0eccbe786e1e95a1b68f4cd64a15f0
--- /dev/null
+++ b/DI-engine/ding/hpc_rl/tests/test_qntd_rescale.py
@@ -0,0 +1,149 @@
+import time
+import torch
+from hpc_rll.origin.td import q_nstep_td_error_with_rescale, q_nstep_td_data
+from hpc_rll.rl_utils.td import QNStepTDRescale
+from testbase import mean_relative_error, times
+
+assert torch.cuda.is_available()
+use_cuda = True
+
+T = 1024
+B = 64
+N = 64
+gamma = 0.95
+
+
+def qntd_rescale_val():
+ ori_q = torch.randn(B, N)
+ ori_next_n_q = torch.randn(B, N)
+ ori_action = torch.randint(0, N, size=(B, ))
+ ori_next_n_action = torch.randint(0, N, size=(B, ))
+ ori_reward = torch.randn(T, B)
+ ori_done = torch.randn(B)
+ ori_weight = torch.randn(B)
+
+ hpc_q = ori_q.clone().detach()
+ hpc_next_n_q = ori_next_n_q.clone().detach()
+ hpc_action = ori_action.clone().detach()
+ hpc_next_n_action = ori_next_n_action.clone().detach()
+ hpc_reward = ori_reward.clone().detach()
+ hpc_done = ori_done.clone().detach()
+ hpc_weight = ori_weight.clone().detach()
+ hpc_qntd_rescale = QNStepTDRescale(T, B, N)
+
+ if use_cuda:
+ ori_q = ori_q.cuda()
+ ori_next_n_q = ori_next_n_q.cuda()
+ ori_action = ori_action.cuda()
+ ori_next_n_action = ori_next_n_action.cuda()
+ ori_reward = ori_reward.cuda()
+ ori_done = ori_done.cuda()
+ ori_weight = ori_weight.cuda()
+
+ hpc_q = hpc_q.cuda()
+ hpc_next_n_q = hpc_next_n_q.cuda()
+ hpc_action = hpc_action.cuda()
+ hpc_next_n_action = hpc_next_n_action.cuda()
+ hpc_reward = hpc_reward.cuda()
+ hpc_done = hpc_done.cuda()
+ hpc_weight = hpc_weight.cuda()
+ hpc_qntd_rescale = hpc_qntd_rescale.cuda()
+
+ ori_q.requires_grad_(True)
+ ori_loss, _ = q_nstep_td_error_with_rescale(
+ q_nstep_td_data(ori_q, ori_next_n_q, ori_action, ori_next_n_action, ori_reward, ori_done, ori_weight), gamma, T
+ )
+ ori_loss = ori_loss.mean()
+ ori_loss.backward()
+ if use_cuda:
+ torch.cuda.synchronize()
+
+ hpc_q.requires_grad_(True)
+ hpc_loss, _ = hpc_qntd_rescale(
+ hpc_q, hpc_next_n_q, hpc_action, hpc_next_n_action, hpc_reward, hpc_done, hpc_weight, gamma
+ )
+ hpc_loss = hpc_loss.mean()
+ hpc_loss.backward()
+ if use_cuda:
+ torch.cuda.synchronize()
+
+ mre = mean_relative_error(
+ torch.flatten(ori_loss).cpu().detach().numpy(),
+ torch.flatten(hpc_loss).cpu().detach().numpy()
+ )
+ print("qntd rescale fp mean_relative_error: " + str(mre))
+ mre = mean_relative_error(
+ torch.flatten(ori_q.grad).cpu().detach().numpy(),
+ torch.flatten(hpc_q.grad).cpu().detach().numpy()
+ )
+ print("qntd rescale bp mean_relative_error: " + str(mre))
+
+
+def qntd_rescale_perf():
+ ori_q = torch.randn(B, N)
+ ori_next_n_q = torch.randn(B, N)
+ ori_action = torch.randint(0, N, size=(B, ))
+ ori_next_n_action = torch.randint(0, N, size=(B, ))
+ ori_reward = torch.randn(T, B)
+ ori_done = torch.randn(B)
+ ori_weight = torch.randn(B)
+
+ hpc_q = ori_q.clone().detach()
+ hpc_next_n_q = ori_next_n_q.clone().detach()
+ hpc_action = ori_action.clone().detach()
+ hpc_next_n_action = ori_next_n_action.clone().detach()
+ hpc_reward = ori_reward.clone().detach()
+ hpc_done = ori_done.clone().detach()
+ hpc_weight = ori_weight.clone().detach()
+ hpc_qntd_rescale = QNStepTDRescale(T, B, N)
+
+ if use_cuda:
+ ori_q = ori_q.cuda()
+ ori_next_n_q = ori_next_n_q.cuda()
+ ori_action = ori_action.cuda()
+ ori_next_n_action = ori_next_n_action.cuda()
+ ori_reward = ori_reward.cuda()
+ ori_done = ori_done.cuda()
+ ori_weight = ori_weight.cuda()
+
+ hpc_q = hpc_q.cuda()
+ hpc_next_n_q = hpc_next_n_q.cuda()
+ hpc_action = hpc_action.cuda()
+ hpc_next_n_action = hpc_next_n_action.cuda()
+ hpc_reward = hpc_reward.cuda()
+ hpc_done = hpc_done.cuda()
+ hpc_weight = hpc_weight.cuda()
+ hpc_qntd_rescale = hpc_qntd_rescale.cuda()
+
+ ori_q.requires_grad_(True)
+ for i in range(times):
+ t = time.time()
+ ori_loss, _ = q_nstep_td_error_with_rescale(
+ q_nstep_td_data(ori_q, ori_next_n_q, ori_action, ori_next_n_action, ori_reward, ori_done, ori_weight),
+ gamma, T
+ )
+ ori_loss = ori_loss.mean()
+ ori_loss.backward()
+ if use_cuda:
+ torch.cuda.synchronize()
+ print('epoch: {}, original qntd rescale cost time: {}'.format(i, time.time() - t))
+
+ hpc_q.requires_grad_(True)
+ for i in range(times):
+ t = time.time()
+ hpc_loss, _ = hpc_qntd_rescale(
+ hpc_q, hpc_next_n_q, hpc_action, hpc_next_n_action, hpc_reward, hpc_done, hpc_weight, gamma
+ )
+ hpc_loss = hpc_loss.mean()
+ hpc_loss.backward()
+ if use_cuda:
+ torch.cuda.synchronize()
+ print('epoch: {}, hpc qntd rescale cost time: {}'.format(i, time.time() - t))
+
+
+if __name__ == '__main__':
+ print("target problem: T = {}, B = {}, N = {}, gamma = {}".format(T, B, N, gamma))
+ print("================run qntd rescale validation test================")
+ qntd_rescale_val()
+ print("================run qntd rescale performance test================")
+ qntd_rescale_perf()
diff --git a/DI-engine/ding/hpc_rl/tests/test_scatter.py b/DI-engine/ding/hpc_rl/tests/test_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..cbdd5a705226a4e6c19a8071b4d680b87cdc8dae
--- /dev/null
+++ b/DI-engine/ding/hpc_rl/tests/test_scatter.py
@@ -0,0 +1,138 @@
+import time
+import torch
+from typing import Tuple
+from hpc_rll.origin.scatter_connection import ScatterConnection
+from hpc_rll.torch_utils.network.scatter_connection import ScatterConnection as HPCScatterConnection
+from testbase import mean_relative_error, times
+
+assert torch.cuda.is_available()
+use_cuda = True
+
+B = 256
+M = 256
+N = 256
+H = 16
+W = 16
+
+
+# Note: the origin gpu version of cover mode is non-deterministic, so the validation test uses the origin cpu version instead
+def scatter_val():
+ for scatter_type in ['add', 'cover']:
+ ori_input = torch.randn(B, M, N)
+ h = torch.randint(
+ low=0, high=H, size=(
+ B,
+ M,
+ )
+ ).unsqueeze(dim=2)
+ w = torch.randint(
+ low=0, high=W, size=(
+ B,
+ M,
+ )
+ ).unsqueeze(dim=2)
+ ori_location = torch.cat([h, w], dim=2)
+ ori_scatter = ScatterConnection(scatter_type)
+
+ hpc_input = ori_input.clone().detach()
+ hpc_location = ori_location.clone().detach()
+ hpc_scatter = HPCScatterConnection(B, M, N, H, W, scatter_type)
+
+ if use_cuda:
+ #ori_input = ori_input.cuda()
+ #ori_location = ori_location.cuda()
+ #ori_scatter = ori_scatter.cuda()
+
+ hpc_input = hpc_input.cuda()
+ hpc_location = hpc_location.cuda()
+ hpc_scatter = hpc_scatter.cuda()
+
+ ori_input.requires_grad_(True)
+ ori_output = ori_scatter(ori_input, (H, W), ori_location)
+ ori_loss = ori_output * ori_output
+ ori_loss = ori_loss.mean()
+ ori_loss.backward()
+ if use_cuda:
+ torch.cuda.synchronize()
+
+ hpc_input.requires_grad_(True)
+ hpc_output = hpc_scatter(hpc_input, hpc_location)
+ hpc_loss = hpc_output * hpc_output
+ hpc_loss = hpc_loss.mean()
+ hpc_loss.backward()
+ if use_cuda:
+ torch.cuda.synchronize()
+
+ mre = mean_relative_error(
+ torch.flatten(ori_loss).cpu().detach().numpy(),
+ torch.flatten(hpc_loss).cpu().detach().numpy()
+ )
+ print("scatter type {} fp mean_relative_error: {}".format(scatter_type, str(mre)))
+ mre = mean_relative_error(
+ torch.flatten(ori_input.grad).cpu().detach().numpy(),
+ torch.flatten(hpc_input.grad).cpu().detach().numpy()
+ )
+ print("scatter type {} bp mean_relative_error: {}".format(scatter_type, str(mre)))
+
+
+# Note: the performance test uses the origin gpu version
+def scatter_perf():
+ for scatter_type in ['add', 'cover']:
+ ori_input = torch.randn(B, M, N)
+ h = torch.randint(
+ low=0, high=H, size=(
+ B,
+ M,
+ )
+ ).unsqueeze(dim=2)
+ w = torch.randint(
+ low=0, high=W, size=(
+ B,
+ M,
+ )
+ ).unsqueeze(dim=2)
+ ori_location = torch.cat([h, w], dim=2)
+ ori_scatter = ScatterConnection(scatter_type)
+
+ hpc_input = ori_input.clone().detach()
+ hpc_location = ori_location.clone().detach()
+ hpc_scatter = HPCScatterConnection(B, M, N, H, W, scatter_type)
+
+ if use_cuda:
+ ori_input = ori_input.cuda()
+ ori_location = ori_location.cuda()
+ ori_scatter = ori_scatter.cuda()
+
+ hpc_input = hpc_input.cuda()
+ hpc_location = hpc_location.cuda()
+ hpc_scatter = hpc_scatter.cuda()
+
+ for i in range(times):
+ t = time.time()
+ ori_input.requires_grad_(True)
+ ori_output = ori_scatter(ori_input, (H, W), ori_location)
+ ori_loss = ori_output * ori_output
+ ori_loss = ori_loss.mean()
+ ori_loss.backward()
+ if use_cuda:
+ torch.cuda.synchronize()
+ print('epoch: {}, original scatter type {} cost time: {}'.format(i, scatter_type, time.time() - t))
+
+ for i in range(times):
+ t = time.time()
+ hpc_input.requires_grad_(True)
+ hpc_output = hpc_scatter(hpc_input, hpc_location)
+ hpc_loss = hpc_output * hpc_output
+ hpc_loss = hpc_loss.mean()
+ hpc_loss.backward()
+ if use_cuda:
+ torch.cuda.synchronize()
+ print('epoch: {}, hpc scatter type {} cost time: {}'.format(i, scatter_type, time.time() - t))
+
+
+if __name__ == '__main__':
+ print("target problem: B = {}, M = {}, N = {}, H = {}, W = {}".format(B, M, N, H, W))
+ print("================run scatter validation test================")
+ scatter_val()
+ print("================run scatter performance test================")
+ scatter_perf()
diff --git a/DI-engine/ding/hpc_rl/tests/test_tdlambda.py b/DI-engine/ding/hpc_rl/tests/test_tdlambda.py
new file mode 100644
index 0000000000000000000000000000000000000000..90ee659af3ffc9fa7bcf54ebe2ab46ab41516159
--- /dev/null
+++ b/DI-engine/ding/hpc_rl/tests/test_tdlambda.py
@@ -0,0 +1,106 @@
+import time
+import torch
+from hpc_rll.origin.td import td_lambda_error, td_lambda_data
+from hpc_rll.rl_utils.td import TDLambda
+from testbase import mean_relative_error, times
+
+assert torch.cuda.is_available()
+use_cuda = True
+
+T = 1024
+B = 64
+
+
+def td_val():
+ ori_value = torch.randn(T + 1, B)
+ ori_reward = torch.randn(T, B)
+ ori_weight = torch.randn(T, B)
+
+ hpc_value = ori_value.clone().detach()
+ hpc_reward = ori_reward.clone().detach()
+ hpc_weight = ori_weight.clone().detach()
+ hpc_td = TDLambda(T, B)
+
+ if use_cuda:
+ ori_value = ori_value.cuda()
+ ori_reward = ori_reward.cuda()
+ ori_weight = ori_weight.cuda()
+
+ hpc_value = hpc_value.cuda()
+ hpc_reward = hpc_reward.cuda()
+ hpc_weight = hpc_weight.cuda()
+ hpc_td = hpc_td.cuda()
+
+ ori_value.requires_grad_(True)
+ ori_loss = td_lambda_error(td_lambda_data(ori_value, ori_reward, ori_weight))
+ ori_loss = ori_loss.mean()
+ ori_loss.backward()
+ if use_cuda:
+ torch.cuda.synchronize()
+
+ hpc_value.requires_grad_(True)
+ hpc_loss = hpc_td(hpc_value, hpc_reward, hpc_weight)
+ hpc_loss = hpc_loss.mean()
+ hpc_loss.backward()
+ if use_cuda:
+ torch.cuda.synchronize()
+
+ mre = mean_relative_error(
+ torch.flatten(ori_loss).cpu().detach().numpy(),
+ torch.flatten(hpc_loss).cpu().detach().numpy()
+ )
+ print("td fp mean_relative_error: " + str(mre))
+ mre = mean_relative_error(
+ torch.flatten(ori_value.grad).cpu().detach().numpy(),
+ torch.flatten(hpc_value.grad).cpu().detach().numpy()
+ )
+ print("td bp mean_relative_error: " + str(mre))
+
+
+def td_perf():
+ ori_value = torch.randn(T + 1, B)
+ ori_reward = torch.randn(T, B)
+ ori_weight = torch.randn(T, B)
+
+ hpc_value = ori_value.clone().detach()
+ hpc_reward = ori_reward.clone().detach()
+ hpc_weight = ori_weight.clone().detach()
+ hpc_td = TDLambda(T, B)
+
+ if use_cuda:
+ ori_value = ori_value.cuda()
+ ori_reward = ori_reward.cuda()
+ ori_weight = ori_weight.cuda()
+
+ hpc_value = hpc_value.cuda()
+ hpc_reward = hpc_reward.cuda()
+ hpc_weight = hpc_weight.cuda()
+ hpc_td = hpc_td.cuda()
+
+ ori_value.requires_grad_(True)
+ for i in range(times):
+ t = time.time()
+ ori_loss = td_lambda_error(td_lambda_data(ori_value, ori_reward, ori_weight))
+ ori_loss = ori_loss.mean()
+ ori_loss.backward()
+ if use_cuda:
+ torch.cuda.synchronize()
+ print('epoch: {}, original td cost time: {}'.format(i, time.time() - t))
+
+ hpc_value.requires_grad_(True)
+ for i in range(times):
+ t = time.time()
+ hpc_loss = hpc_td(hpc_value, hpc_reward, hpc_weight)
+ hpc_loss = hpc_loss.mean()
+ hpc_loss.backward()
+ if use_cuda:
+ torch.cuda.synchronize()
+ print('epoch: {}, hpc td cost time: {}'.format(i, time.time() - t))
+
+
+if __name__ == '__main__':
+ print("target problem: T = {}, B = {}".format(T, B))
+ print("================run td validation test================")
+ td_val()
+ print("================run td performance test================")
+ td_perf()
diff --git a/DI-engine/ding/hpc_rl/tests/test_upgo.py b/DI-engine/ding/hpc_rl/tests/test_upgo.py
new file mode 100644
index 0000000000000000000000000000000000000000..c61f8df618ed4dfc90263623b51a7bf2bb64a15e
--- /dev/null
+++ b/DI-engine/ding/hpc_rl/tests/test_upgo.py
@@ -0,0 +1,133 @@
+import time
+import torch
+from hpc_rll.origin.upgo import upgo_loss
+from hpc_rll.rl_utils.upgo import UPGO
+from testbase import mean_relative_error, times
+
+assert torch.cuda.is_available()
+use_cuda = True
+
+T = 256
+B = 256
+N = 256
+
+
+def upgo_val():
+ ori_target_output = torch.randn(T, B, N)
+ ori_rhos = torch.randn(T, B)
+ ori_action = torch.randint(
+ 0, N, size=(
+ T,
+ B,
+ )
+ )
+ ori_rewards = torch.randn(T, B)
+ ori_bootstrap_values = torch.randn(T + 1, B)
+
+ hpc_target_output = ori_target_output.clone().detach()
+ hpc_rhos = ori_rhos.clone().detach()
+ hpc_action = ori_action.clone().detach()
+ hpc_rewards = ori_rewards.clone().detach()
+ hpc_bootstrap_values = ori_bootstrap_values.clone().detach()
+ hpc_upgo = UPGO(T, B, N)
+
+ if use_cuda:
+ ori_target_output = ori_target_output.cuda()
+ ori_rhos = ori_rhos.cuda()
+ ori_action = ori_action.cuda()
+ ori_rewards = ori_rewards.cuda()
+ ori_bootstrap_values = ori_bootstrap_values.cuda()
+
+ hpc_target_output = hpc_target_output.cuda()
+ hpc_rhos = hpc_rhos.cuda()
+ hpc_action = hpc_action.cuda()
+ hpc_rewards = hpc_rewards.cuda()
+ hpc_bootstrap_values = hpc_bootstrap_values.cuda()
+ hpc_upgo = hpc_upgo.cuda()
+
+ ori_target_output.requires_grad_(True)
+ ori_loss = upgo_loss(ori_target_output, ori_rhos, ori_action, ori_rewards, ori_bootstrap_values)
+ ori_loss = ori_loss.mean()
+ ori_loss.backward()
+ if use_cuda:
+ torch.cuda.synchronize()
+
+ hpc_target_output.requires_grad_(True)
+ hpc_loss = hpc_upgo(hpc_target_output, hpc_rhos, hpc_action, hpc_rewards, hpc_bootstrap_values)
+ hpc_loss = hpc_loss.mean()
+ hpc_loss.backward()
+ if use_cuda:
+ torch.cuda.synchronize()
+
+ mre = mean_relative_error(
+ torch.flatten(ori_loss).cpu().detach().numpy(),
+ torch.flatten(hpc_loss).cpu().detach().numpy()
+ )
+ print("upgo fp mean_relative_error: " + str(mre))
+ mre = mean_relative_error(
+ torch.flatten(ori_target_output.grad).cpu().detach().numpy(),
+ torch.flatten(hpc_target_output.grad).cpu().detach().numpy()
+ )
+ print("upgo bp mean_relative_error: " + str(mre))
+
+
+def upgo_perf():
+ ori_target_output = torch.randn(T, B, N)
+ ori_rhos = torch.randn(T, B)
+ ori_action = torch.randint(
+ 0, N, size=(
+ T,
+ B,
+ )
+ )
+ ori_rewards = torch.randn(T, B)
+ ori_bootstrap_values = torch.randn(T + 1, B)
+
+ hpc_target_output = ori_target_output.clone().detach()
+ hpc_rhos = ori_rhos.clone().detach()
+ hpc_action = ori_action.clone().detach()
+ hpc_rewards = ori_rewards.clone().detach()
+ hpc_bootstrap_values = ori_bootstrap_values.clone().detach()
+ hpc_upgo = UPGO(T, B, N)
+
+ if use_cuda:
+ ori_target_output = ori_target_output.cuda()
+ ori_rhos = ori_rhos.cuda()
+ ori_action = ori_action.cuda()
+ ori_rewards = ori_rewards.cuda()
+ ori_bootstrap_values = ori_bootstrap_values.cuda()
+
+ hpc_target_output = hpc_target_output.cuda()
+ hpc_rhos = hpc_rhos.cuda()
+ hpc_action = hpc_action.cuda()
+ hpc_rewards = hpc_rewards.cuda()
+ hpc_bootstrap_values = hpc_bootstrap_values.cuda()
+ hpc_upgo = hpc_upgo.cuda()
+
+ ori_target_output.requires_grad_(True)
+ for i in range(times):
+ t = time.time()
+ ori_loss = upgo_loss(ori_target_output, ori_rhos, ori_action, ori_rewards, ori_bootstrap_values)
+ ori_loss = ori_loss.mean()
+ ori_loss.backward()
+ if use_cuda:
+ torch.cuda.synchronize()
+ print('epoch: {}, original upgo cost time: {}'.format(i, time.time() - t))
+
+ hpc_target_output.requires_grad_(True)
+ for i in range(times):
+ t = time.time()
+ hpc_loss = hpc_upgo(hpc_target_output, hpc_rhos, hpc_action, hpc_rewards, hpc_bootstrap_values)
+ hpc_loss = hpc_loss.mean()
+ hpc_loss.backward()
+ if use_cuda:
+ torch.cuda.synchronize()
+ print('epoch: {}, hpc upgo cost time: {}'.format(i, time.time() - t))
+
+
+if __name__ == '__main__':
+ print("target problem: T = {}, B = {}, N = {}".format(T, B, N))
+ print("================run upgo validation test================")
+ upgo_val()
+ print("================run upgo performance test================")
+ upgo_perf()
diff --git a/DI-engine/ding/hpc_rl/tests/test_vtrace.py b/DI-engine/ding/hpc_rl/tests/test_vtrace.py
new file mode 100644
index 0000000000000000000000000000000000000000..c26ab4f407d3e0f73267531eba7b24dcd7ce1f22
--- /dev/null
+++ b/DI-engine/ding/hpc_rl/tests/test_vtrace.py
@@ -0,0 +1,143 @@
+import time
+import torch
+import torch.nn.functional as F
+from hpc_rll.origin.vtrace import vtrace_error_discrete_action, vtrace_data
+from hpc_rll.rl_utils.vtrace import VTrace
+from testbase import mean_relative_error, times
+
+assert torch.cuda.is_available()
+use_cuda = True
+
+T = 128
+B = 128
+N = 128
+
+
+def vtrace_val():
+ ori_target_output = torch.randn(T, B, N)
+ ori_behaviour_output = torch.randn(T, B, N)
+ ori_action = torch.randint(
+ 0, N, size=(
+ T,
+ B,
+ )
+ )
+ ori_value = torch.randn(T + 1, B)
+ ori_reward = torch.randn(T, B)
+
+ hpc_target_output = ori_target_output.clone().detach()
+ hpc_behaviour_output = ori_behaviour_output.clone().detach()
+ hpc_action = ori_action.clone().detach()
+ hpc_value = ori_value.clone().detach()
+ hpc_reward = ori_reward.clone().detach()
+ hpc_vtrace = VTrace(T, B, N)
+
+ if use_cuda:
+ ori_target_output = ori_target_output.cuda()
+ ori_behaviour_output = ori_behaviour_output.cuda()
+ ori_action = ori_action.cuda()
+ ori_value = ori_value.cuda()
+ ori_reward = ori_reward.cuda()
+
+ hpc_target_output = hpc_target_output.cuda()
+ hpc_behaviour_output = hpc_behaviour_output.cuda()
+ hpc_action = hpc_action.cuda()
+ hpc_value = hpc_value.cuda()
+ hpc_reward = hpc_reward.cuda()
+ hpc_vtrace = hpc_vtrace.cuda()
+
+ ori_target_output.requires_grad_(True)
+ ori_value.requires_grad_(True)
+ ori_loss = vtrace_error_discrete_action(
+ vtrace_data(ori_target_output, ori_behaviour_output, ori_action, ori_value, ori_reward, None)
+ )
+ ori_loss = sum(ori_loss)
+ ori_loss.backward()
+
+ hpc_target_output.requires_grad_(True)
+ hpc_value.requires_grad_(True)
+ hpc_loss = hpc_vtrace(hpc_target_output, hpc_behaviour_output, hpc_action, hpc_value, hpc_reward)
+ hpc_loss = sum(hpc_loss)
+ hpc_loss.backward()
+
+ mre = mean_relative_error(
+ torch.flatten(ori_loss).cpu().detach().numpy(),
+ torch.flatten(hpc_loss).cpu().detach().numpy()
+ )
+ print("vtrace fp mean_relative_error: " + str(mre))
+ mre = mean_relative_error(
+ torch.flatten(ori_target_output.grad).cpu().detach().numpy(),
+ torch.flatten(hpc_target_output.grad).cpu().detach().numpy()
+ )
+ print("vtrace bp target_output mean_relative_error: " + str(mre))
+ mre = mean_relative_error(
+ torch.flatten(ori_value.grad).cpu().detach().numpy(),
+ torch.flatten(hpc_value.grad).cpu().detach().numpy()
+ )
+ print("vtrace bp value mean_relative_error: " + str(mre))
+
+
+def vtrace_perf():
+ ori_target_output = torch.randn(T, B, N)
+ ori_behaviour_output = torch.randn(T, B, N)
+ ori_action = torch.randint(
+ 0, N, size=(
+ T,
+ B,
+ )
+ )
+ ori_value = torch.randn(T + 1, B)
+ ori_reward = torch.randn(T, B)
+
+ hpc_target_output = ori_target_output.clone().detach()
+ hpc_behaviour_output = ori_behaviour_output.clone().detach()
+ hpc_action = ori_action.clone().detach()
+ hpc_value = ori_value.clone().detach()
+ hpc_reward = ori_reward.clone().detach()
+ hpc_vtrace = VTrace(T, B, N)
+
+ if use_cuda:
+ ori_target_output = ori_target_output.cuda()
+ ori_behaviour_output = ori_behaviour_output.cuda()
+ ori_action = ori_action.cuda()
+ ori_value = ori_value.cuda()
+ ori_reward = ori_reward.cuda()
+
+ hpc_target_output = hpc_target_output.cuda()
+ hpc_behaviour_output = hpc_behaviour_output.cuda()
+ hpc_action = hpc_action.cuda()
+ hpc_value = hpc_value.cuda()
+ hpc_reward = hpc_reward.cuda()
+ hpc_vtrace = hpc_vtrace.cuda()
+
+ ori_target_output.requires_grad_(True)
+ ori_value.requires_grad_(True)
+ for i in range(times):
+ t = time.time()
+ ori_loss = vtrace_error_discrete_action(
+ vtrace_data(ori_target_output, ori_behaviour_output, ori_action, ori_value, ori_reward, None)
+ )
+ ori_loss = sum(ori_loss)
+ ori_loss.backward()
+ if use_cuda:
+ torch.cuda.synchronize()
+ print('epoch: {}, original vtrace cost time: {}'.format(i, time.time() - t))
+
+ hpc_target_output.requires_grad_(True)
+ hpc_value.requires_grad_(True)
+ for i in range(times):
+ t = time.time()
+ hpc_loss = hpc_vtrace(hpc_target_output, hpc_behaviour_output, hpc_action, hpc_value, hpc_reward)
+ hpc_loss = sum(hpc_loss)
+ hpc_loss.backward()
+ if use_cuda:
+ torch.cuda.synchronize()
+ print('epoch: {}, hpc vtrace cost time: {}'.format(i, time.time() - t))
+
+
+if __name__ == '__main__':
+ print("target problem: T = {}, B = {}, N = {}".format(T, B, N))
+ print("================run vtrace validation test================")
+ vtrace_val()
+ print("================run vtrace performance test================")
+ vtrace_perf()
diff --git a/DI-engine/ding/hpc_rl/tests/testbase.py b/DI-engine/ding/hpc_rl/tests/testbase.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dc09b499f852de0d809f8d09d953d2461dbf1aa
--- /dev/null
+++ b/DI-engine/ding/hpc_rl/tests/testbase.py
@@ -0,0 +1,12 @@
+import torch
+import numpy as np
+
+torch.set_printoptions(precision=6)
+
+times = 6
+
+
+def mean_relative_error(y_true, y_pred):
+ # eps keeps the denominator non-zero when entries of y_true are exactly 0
+ eps = 1e-5
+ relative_error = np.average(np.abs(y_true - y_pred) / (y_true + eps), axis=0)
+ return relative_error
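+
+
+# Illustrative usage (a sketch for readers; not exercised by the test scripts themselves):
+# >>> import numpy as np
+# >>> mean_relative_error(np.array([1.0, 2.0, 4.0]), np.array([1.0, 2.0, 4.0]))
+# 0.0
+# >>> mean_relative_error(np.array([1.0, 2.0]), np.array([1.1, 2.2]))  # ~0.1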
diff --git a/DI-engine/ding/hpc_rl/wrapper.py b/DI-engine/ding/hpc_rl/wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e4ac0bf9523aa4ba10e88d50b0a44f4b2ef89f2
--- /dev/null
+++ b/DI-engine/ding/hpc_rl/wrapper.py
@@ -0,0 +1,133 @@
+import importlib
+from ditk import logging
+from collections import OrderedDict
+from functools import wraps
+import ding
+'''
+Overview:
+ `hpc_wrapper` is the wrapper for functions that have an hpc implementation. If a function is wrapped by it, we
+ will look up its hpc counterpart and call the hpc implementation instead.
+ We will use the following code as a sample to introduce `hpc_wrapper`:
+ ```
+ @hpc_wrapper(shape_fn=shape_fn_dntd, namedtuple_data=True, include_args=[0,1,2,3],
+ include_kwargs=['data', 'gamma', 'v_min', 'v_max'], is_cls_method=False)
+ def dist_nstep_td_error(
+ data: namedtuple,
+ gamma: float,
+ v_min: float,
+ v_max: float,
+ n_atom: int,
+ nstep: int = 1,
+ ) -> torch.Tensor:
+ ...
+ ```
+Parameters:
+ - shape_fn (:obj:`function`): a function which returns the shapes needed by the hpc function. In fact, it returns
+ all args that the hpc function needs.
+ - namedtuple_data (:obj:`bool`): If True, when the hpc function is called, the namedtuple data will be unpacked and
+ passed as hpc_function(*namedtuple). If False, the data keeps its `namedtuple` type.
+ - include_args (:obj:`list`): a list of indexes of the positional args that need to be passed to the hpc function.
+ As shown in the sample, include_args=[0,1,2,3] means `data`, `gamma`, `v_min` and `v_max` will be passed to the
+ hpc function.
+ - include_kwargs (:obj:`list`): a list of keys of the kwargs that need to be passed to the hpc function. As shown
+ in the sample, include_kwargs=['data', 'gamma', 'v_min', 'v_max'] means `data`, `gamma`, `v_min` and `v_max`
+ will be passed to the hpc function.
+ - is_cls_method (:obj:`bool`): If True, the wrapped function is a method of a class, so `self` appears in args; the
+ wrapper strips `self` from args and uses the class name as fn_name. If False, the wrapped function is a plain
+ function.
+Q&A:
+ - Q: Do `include_args` and `include_kwargs` need to be set at the same time?
+ - A: Yes. Together, `include_args` and `include_kwargs` can deal with every calling convention, such as
+ (data, gamma, v_min=v_min, v_max=v_max) and (data, gamma, v_min, v_max).
+ - Q: What is `hpc_fns`?
+ - A: Here we show a normal `hpc_fns`:
+ ```
+ hpc_fns = {
+ 'fn_name1': {
+ 'runtime_name1': hpc_fn1,
+ 'runtime_name2': hpc_fn2,
+ ...
+ },
+ ...
+ }
+ ```
+ Besides, `per_fn_limit` is the maximum length of `hpc_fns[fn_name]`. When a new function is registered, the
+ oldest function is popped from `hpc_fns[fn_name]`.
+'''
+
+hpc_fns = {}
+per_fn_limit = 3
+
+
+def register_runtime_fn(fn_name, runtime_name, shape):
+ fn_name_mapping = {
+ 'gae': ['hpc_rll.rl_utils.gae', 'GAE'],
+ 'dist_nstep_td_error': ['hpc_rll.rl_utils.td', 'DistNStepTD'],
+ 'LSTM': ['hpc_rll.torch_utils.network.rnn', 'LSTM'],
+ 'ppo_error': ['hpc_rll.rl_utils.ppo', 'PPO'],
+ 'q_nstep_td_error': ['hpc_rll.rl_utils.td', 'QNStepTD'],
+ 'q_nstep_td_error_with_rescale': ['hpc_rll.rl_utils.td', 'QNStepTDRescale'],
+ 'ScatterConnection': ['hpc_rll.torch_utils.network.scatter_connection', 'ScatterConnection'],
+ 'td_lambda_error': ['hpc_rll.rl_utils.td', 'TDLambda'],
+ 'upgo_loss': ['hpc_rll.rl_utils.upgo', 'UPGO'],
+ 'vtrace_error_discrete_action': ['hpc_rll.rl_utils.vtrace', 'VTrace'],
+ }
+ fn_str = fn_name_mapping[fn_name]
+ cls = getattr(importlib.import_module(fn_str[0]), fn_str[1])
+ hpc_fn = cls(*shape).cuda()
+ if fn_name not in hpc_fns:
+ hpc_fns[fn_name] = OrderedDict()
+ hpc_fns[fn_name][runtime_name] = hpc_fn
+ while len(hpc_fns[fn_name]) > per_fn_limit:
+ hpc_fns[fn_name].popitem(last=False)
+ # print(hpc_fns)
+ return hpc_fn
+
+
+def hpc_wrapper(shape_fn=None, namedtuple_data=False, include_args=[], include_kwargs=[], is_cls_method=False):
+
+ def decorate(fn):
+
+ @wraps(fn)
+ def wrapper(*args, **kwargs):
+ if ding.enable_hpc_rl:
+ shape = shape_fn(args, kwargs)
+ if is_cls_method:
+ fn_name = args[0].__class__.__name__
+ else:
+ fn_name = fn.__name__
+ runtime_name = '_'.join([fn_name] + [str(s) for s in shape])
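+ # e.g. a GAE call whose shape_fn returns (16, 16) yields runtime_name 'gae_16_16' (illustrative values)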
+ if fn_name not in hpc_fns or runtime_name not in hpc_fns[fn_name]:
+ hpc_fn = register_runtime_fn(fn_name, runtime_name, shape)
+ else:
+ hpc_fn = hpc_fns[fn_name][runtime_name]
+ if is_cls_method:
+ args = args[1:]
+ clean_args = []
+ for i in include_args:
+ if i < len(args):
+ clean_args.append(args[i])
+ nouse_args = list(set(list(range(len(args)))).difference(set(include_args)))
+ clean_kwargs = {}
+ for k, v in kwargs.items():
+ if k in include_kwargs:
+ if k == 'lambda_':
+ k = 'lambda'
+ clean_kwargs[k] = v
+ nouse_kwargs = list(set(kwargs.keys()).difference(set(include_kwargs)))
+ if len(nouse_args) > 0 or len(nouse_kwargs) > 0:
+ logging.warn(
+ 'in {}, index {} of args are dropped, and keys {} of kwargs are dropped.'.format(
+ runtime_name, nouse_args, nouse_kwargs
+ )
+ )
+ if namedtuple_data:
+ data = args[0] # args[0] is a namedtuple
+ return hpc_fn(*data, *clean_args[1:], **clean_kwargs)
+ else:
+ return hpc_fn(*clean_args, **clean_kwargs)
+ else:
+ return fn(*args, **kwargs)
+
+ return wrapper
+
+ return decorate
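+
+
+# Illustrative dispatch flow (a sketch only; the shapes and the wrapped `gae` function below are examples):
+# ding.enable_hpc_rl = True  # wrapped functions now dispatch to cached hpc kernels
+# adv = gae(gae_data(value, reward))  # the first call for a new shape registers and caches the hpc GAE op
+# adv = gae(gae_data(value, reward))  # later calls with the same shape reuse the cached kernel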
diff --git a/DI-engine/ding/interaction/__init__.py b/DI-engine/ding/interaction/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c57d32ca013fb4cb6963ebf2b1a7901c13d60819
--- /dev/null
+++ b/DI-engine/ding/interaction/__init__.py
@@ -0,0 +1,2 @@
+from .master import *
+from .slave import *
diff --git a/DI-engine/ding/interaction/base/__init__.py b/DI-engine/ding/interaction/base/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..87275366e28b9f5a5c2149332309b0607377b819
--- /dev/null
+++ b/DI-engine/ding/interaction/base/__init__.py
@@ -0,0 +1,5 @@
+from .app import CommonErrorCode, success_response, failure_response, get_values_from_response, flask_response, \
+ ResponsibleException, responsible
+from .common import random_token, translate_dict_func, ControllableService, ControllableContext, default_func
+from .network import get_host_ip, get_http_engine_class, HttpEngine, split_http_address
+from .threading import DblEvent
diff --git a/DI-engine/ding/interaction/base/app.py b/DI-engine/ding/interaction/base/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf53c356f2922dc7a78a80e580b0793f5fafa711
--- /dev/null
+++ b/DI-engine/ding/interaction/base/app.py
@@ -0,0 +1,102 @@
+import json
+from enum import IntEnum, unique
+from functools import wraps
+from typing import Mapping, Any, Type, Optional, Tuple, Union, Iterable, Callable
+
+import flask
+import requests
+from flask import jsonify
+
+
+@unique
+class CommonErrorCode(IntEnum):
+ SUCCESS = 0
+ COMMON_FAILURE = 1
+
+
+def flask_response(
+ success: bool,
+ data: Optional[Mapping[str, Any]] = None,
+ message: Optional[str] = None,
+ code: Optional[int] = None,
+):
+ return jsonify(
+ {
+ 'success': success,
+ 'code': 0 if success else (code or CommonErrorCode.COMMON_FAILURE),
+ 'message': (message or 'Success.') if success else (message or 'Failed.'),
+ 'data': data,
+ }
+ )
+
+
+def success_response(data: Optional[Mapping[str, Any]] = None, message: Optional[str] = None):
+ return flask_response(
+ success=True,
+ code=CommonErrorCode.SUCCESS,
+ message=message,
+ data=data,
+ )
+
+
+def failure_response(
+ code: Optional[int] = None, message: Optional[str] = None, data: Optional[Mapping[str, Any]] = None
+):
+ return flask_response(
+ success=False,
+ code=code or CommonErrorCode.COMMON_FAILURE,
+ message=message,
+ data=data,
+ )
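+
+
+# Illustrative response payloads produced by the helpers above (a sketch; field values are examples):
+# success_response(data={'x': 1}) -> {"success": true, "code": 0, "message": "Success.", "data": {"x": 1}}
+# failure_response(code=1, message='Oops.') -> {"success": false, "code": 1, "message": "Oops.", "data": null}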
+
+
+_RESPONSE_VALUE_TYPE = Tuple[int, bool, int, str, Mapping[str, Any]]
+
+
+def get_values_from_response(response: Union[requests.Response, flask.Response]) -> _RESPONSE_VALUE_TYPE:
+ status_code = response.status_code
+
+ _content = response.content if hasattr(response, 'content') else response.data
+ _json = json.loads(_content.decode())
+ success, code, message, data = _json['success'], _json['code'], _json.get('message', ''), _json.get('data', {})
+
+ return status_code, success, code, message, data
+
+
+class ResponsibleException(Exception):
+
+ def __init__(
+ self,
+ code: int = CommonErrorCode.COMMON_FAILURE,
+ message: Optional[str] = None,
+ data: Optional[Mapping[str, Any]] = None,
+ status_code: int = 400
+ ):
+ Exception.__init__(self, message)
+ self.__code = code
+ self.__message = message
+ self.__data = data or {}
+ self.__status_code = status_code
+
+ def get_response(self):
+ return failure_response(self.__code, self.__message, self.__data), self.__status_code
+
+
+def responsible(classes: Iterable[Type[ResponsibleException]] = None):
+ if classes is None:
+ classes = (ResponsibleException, )
+
+ def _decorator(func: Callable[..., Any]) -> Callable[..., Any]:
+
+ @wraps(func)
+ def _func(*args, **kwargs):
+ try:
+ ret = func(*args, **kwargs)
+ except tuple(classes) as err:
+ return err.get_response()
+ else:
+ return ret
+
+ return _func
+
+ return _decorator
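+
+
+# Illustrative usage of `responsible` with `ResponsibleException` (a sketch only; the route and error message
+# below are made-up examples, not part of this module):
+#
+# @app.route('/task', methods=['POST'])
+# @responsible()
+# def create_task():
+#     if not flask.request.json:
+#         raise ResponsibleException(code=CommonErrorCode.COMMON_FAILURE, message='Empty body', status_code=400)
+#     return success_response(data={'accepted': True})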
diff --git a/DI-engine/ding/interaction/base/common.py b/DI-engine/ding/interaction/base/common.py
new file mode 100644
index 0000000000000000000000000000000000000000..afc3407d30a75d25ed3a191b6ff2b33b4059bc35
--- /dev/null
+++ b/DI-engine/ding/interaction/base/common.py
@@ -0,0 +1,184 @@
+import random
+import string
+from abc import ABCMeta, abstractmethod
+from typing import Optional, Callable, Mapping, Any, Dict
+
+_LENGTH_OF_RANDOM_TOKEN = 64
+
+
+def random_token(length: Optional[int] = None) -> str:
+ """
+ Overview:
+ Generate random hex token
+ Arguments:
+ - length (:obj:`Optional[int]`): Length of the random token (`None` means `64`)
+ Returns:
+ - token (:obj:`str`): Generated random token
+ Example:
+ >>> random_token() # '4eAbd5218e3d0da5e7AAFcBF48Ea0Df2dadED1bdDF0B8724FdE1569AA78F24A7'
+ >>> random_token(24) # 'Cd1CdD98caAb8602ac6501aC'
+ """
+ return ''.join([random.choice(string.hexdigits) for _ in range(length or _LENGTH_OF_RANDOM_TOKEN)])
+
+
+class ControllableContext(metaclass=ABCMeta):
+ """
+ Overview:
+ Basic context-supported class structure
+ Example:
+ - Common usage
+
+ >>> c = MyControllableContext() # One of the superclasses is ControllableContext
+ >>> c.start()
+ >>> try:
+ >>> pass # do anything you like
+ >>> finally:
+ >>> c.close()
+
+ - Use the `with` keyword (equivalent to the code above)
+
+ >>> c = MyControllableContext() # One of the superclasses is ControllableContext
+ >>> with c as cc: # cc is c, they share the same id
+ >>> pass # do anything you like
+ """
+
+ @abstractmethod
+ def start(self):
+ """
+ Overview:
+ Start the context
+ """
+ raise NotImplementedError # pragma: no cover
+
+ @abstractmethod
+ def close(self):
+ """
+ Overview:
+ Close the context
+ """
+ raise NotImplementedError # pragma: no cover
+
+ def __enter__(self):
+ """
+ Overview:
+ Enter the context
+ Returns:
+ - self (:obj:`ControllableContext`): Context object itself
+ """
+ self.start()
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ """
+ Overview:
+ Exit the context
+ """
+ self.close()
+
+
+class ControllableService(ControllableContext, metaclass=ABCMeta):
+ """
+ Overview:
+ Controllable service with context support, usually has concurrent feature.
+ Example:
+ - A common usage
+
+ >>> c = MyControllableService() # One of its superclasses is ControllableService
+ >>> c.start()
+ >>> try:
+ >>> pass # do anything you like
+ >>> finally:
+ >>> c.shutdown() # shutdown the service
+ >>> c.join() # wait until service is down
+
+ - Use the `with` keyword (equivalent to the code above)
+
+ >>> c = MyControllableService() # One of its superclasses is ControllableService
+ >>> with c as cc: # cc is c, they share the same id
+ >>> pass # do anything you like
+ """
+
+ @abstractmethod
+ def start(self):
+ """
+ Overview:
+ Start the service
+ """
+ raise NotImplementedError # pragma: no cover
+
+ @abstractmethod
+ def shutdown(self):
+ """
+ Overview:
+ Shutdown the service (but service will not down immediately)
+ """
+ raise NotImplementedError # pragma: no cover
+
+ @abstractmethod
+ def join(self):
+ """
+ Overview:
+ Wait until the service is completely down
+ """
+ raise NotImplementedError # pragma: no cover
+
+ def close(self):
+ """
+ Overview:
+ Close the service, wait until the service is down.
+ """
+ self.shutdown()
+ self.join()
+
+
+def translate_dict_func(d: Mapping[str, Callable[..., Any]]) -> Callable[..., Dict[str, Any]]:
+ """
+ Overview:
+ Transform a dict of functions into a single function that generates a dict.
+ Arguments:
+ - d (:obj:`Mapping[str, Callable[..., Any]]`): Dict of functions
+ Returns:
+ - func (:obj:`Callable[..., Dict[str, Any]]`): Function that generates a dict
+ Example:
+ >>> f1 = lambda x, y: x + y
+ >>> f2 = lambda x, y: x - y
+ >>> f3 = lambda x, y: x * y
+ >>> fx = translate_dict_func({'a': f1, 'b': f2, 'c': f3})
+ >>> fx(2, 3) # {'a': 5, 'b': -1, 'c': 6}
+ >>> fx(5, 11) # {'a': 16, 'b': -6, 'c': 55}
+ """
+
+ def _func(*args, **kwargs) -> Dict[str, Any]:
+ return {k: f(*args, **kwargs) for k, f in d.items()}
+
+ return _func
+
+
+def default_func(return_value=None) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
+ """
+ Overview:
+ Transform an optional function (which may be `None`) into a function with a default return value
+ Argument:
+ - return_value (:obj:): Return value of the fallback function
+ Returns:
+ - decorator (:obj:`Callable[[Callable[..., Any]], Callable[..., Any]]`): A decorator function \
+ that turns an optional function into a real function (guaranteed not to be None)
+ Example:
+ >>> f1 = None
+ >>> f2 = lambda x, y: x + y
+ >>> ff1 = default_func()(f1)
+ >>> ft1 = default_func(0)(f1)
+ >>> ff2 = default_func()(f2)
+ >>> ff1(2, 3) # None
+ >>> ft1(2, 3) # 0
+ >>> ff2(2, 3) # 5
+ """
+
+ def _decorator(func: Callable[..., Any]) -> Callable[..., Any]:
+ # noinspection PyUnusedLocal
+ def _func(*args, **kwargs):
+ return return_value
+
+ return func or _func
+
+ return _decorator
diff --git a/DI-engine/ding/interaction/base/network.py b/DI-engine/ding/interaction/base/network.py
new file mode 100644
index 0000000000000000000000000000000000000000..474ce5fa05cfe8079128b8a14099b3ae90624513
--- /dev/null
+++ b/DI-engine/ding/interaction/base/network.py
@@ -0,0 +1,152 @@
+import json
+import socket
+import time
+from typing import Optional, Any, Mapping, Callable, Type, Tuple
+
+import requests
+from requests import HTTPError
+from urlobject import URLObject
+from urlobject.path import URLPath
+
+from .common import translate_dict_func
+
+
+def get_host_ip() -> Optional[str]:
+ s = None
+ try:
+ s = socket.socket(socket.AF_INET, socket.SOCK_DGRAM)
+ s.connect(('8.8.8.8', 80))
+ ip = s.getsockname()[0]
+ finally:
+ if s is not None:
+ s.close()
+ return ip
+
+
+_DEFAULT_HTTP_PORT = 80
+_DEFAULT_HTTPS_PORT = 443
+
+
+def split_http_address(address: str, default_port: Optional[int] = None) -> Tuple[str, int, bool, str]:
+ _url = URLObject(address)
+
+ _host = _url.hostname
+ _https = (_url.scheme.lower()) == 'https'
+ _port = _url.port or default_port or (_DEFAULT_HTTPS_PORT if _https else _DEFAULT_HTTP_PORT)
+ _path = str(_url.path) or ''
+
+ return _host, _port, _https, _path
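+
+
+# Illustrative behaviour (example addresses only):
+# split_http_address('https://example.com/api') -> ('example.com', 443, True, '/api')
+# split_http_address('http://127.0.0.1:7235')   -> ('127.0.0.1', 7235, False, '')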
+
+
+DEFAULT_RETRIES = 5
+DEFAULT_RETRY_WAITING = 1.0
+
+
+class HttpEngine:
+
+ def __init__(self, host: str, port: int, https: bool = False, path: str = None):
+ self.__base_url = URLObject().with_scheme('https' if https else 'http') \
+ .with_hostname(host).with_port(port).add_path(path or '')
+ self.__session = requests.session()
+ self.__session.trust_env = False
+
+ # noinspection PyMethodMayBeStatic
+ def _data_process(self, data: Optional[Mapping[str, Any]] = None) -> Mapping[str, Any]:
+ return data or {}
+
+ # noinspection PyMethodMayBeStatic
+ def _base_headers(self) -> Mapping[str, Any]:
+ return {}
+
+ def _error_handler(self, err: Exception):
+ raise err
+
+ def get_url(self, path: str = None):
+ original_segments = self.__base_url.path.segments
+ path_segments = URLPath().add(path or '').segments
+ return str(self.__base_url.with_path(URLPath.join_segments(original_segments + path_segments)))
+
+ def __single_request(
+ self,
+ method: str,
+ path: str,
+ data: Optional[Mapping[str, Any]] = None,
+ headers: Optional[Mapping[str, Any]] = None,
+ params: Optional[Mapping[str, Any]] = None,
+ raise_for_status: bool = True
+ ):
+ response = self.__session.request(
+ method=method,
+ url=self.get_url(path),
+ data=json.dumps(self._data_process(data) or {}),
+ headers=headers,
+ params=params or {},
+ )
+ if raise_for_status:
+ response.raise_for_status()
+
+ return response
+
+ def request(
+ self,
+ method: str,
+ path: str,
+ data: Optional[Mapping[str, Any]] = None,
+ headers: Optional[Mapping[str, Any]] = None,
+ params: Optional[Mapping[str, Any]] = None,
+ raise_for_status: bool = True,
+ retries: Optional[int] = None,
+ retry_waiting: Optional[float] = None,
+ ) -> requests.Response:
+ _headers = dict(self._base_headers())
+ _headers.update(headers or {})
+
+ retries = retries or DEFAULT_RETRIES
+ retry_waiting = retry_waiting or DEFAULT_RETRY_WAITING
+
+ try:
+ _current_retries = 0
+ while True:
+ try:
+ response = self.__single_request(method, path, data, _headers, params, raise_for_status)
+ except requests.exceptions.HTTPError as err:
+ raise err
+ except requests.exceptions.RequestException as err:
+ _current_retries += 1
+ if _current_retries > retries:
+ raise err
+ else:
+ time.sleep(retry_waiting)
+ else:
+ break
+ except Exception as e:
+ self._error_handler(e)
+ else:
+ return response
+
+
+def get_http_engine_class(
+ headers: Mapping[str, Callable[..., Any]],
+ data_processor: Optional[Callable[[Mapping[str, Any]], Mapping[str, Any]]] = None,
+ http_error_gene: Optional[Callable[[HTTPError], Exception]] = None,
+) -> Callable[..., Type[HttpEngine]]:
+
+ def _func(*args, **kwargs) -> Type[HttpEngine]:
+
+ class _HttpEngine(HttpEngine):
+
+ def _data_process(self, data: Optional[Mapping[str, Any]] = None) -> Mapping[str, Any]:
+ return (data_processor or (lambda d: d or {}))(data or {})
+
+ def _base_headers(self) -> Mapping[str, Any]:
+ return translate_dict_func(headers)(*args, **kwargs)
+
+ def _error_handler(self, err: Exception):
+ if http_error_gene is not None and isinstance(err, HTTPError):
+ raise http_error_gene(err)
+ else:
+ raise err
+
+ return _HttpEngine
+
+ return _func
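+
+
+# Illustrative usage of get_http_engine_class (a minimal sketch; the 'Token' header and values are made-up examples):
+# EngineClass = get_http_engine_class(headers={'Token': lambda token: token})(token='abc123')
+# engine = EngineClass(host='127.0.0.1', port=7235)
+# response = engine.request('GET', '/ping')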
diff --git a/DI-engine/ding/interaction/base/threading.py b/DI-engine/ding/interaction/base/threading.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b9275bbc637f158f4bb9bf1c27fc9c503d39785
--- /dev/null
+++ b/DI-engine/ding/interaction/base/threading.py
@@ -0,0 +1,82 @@
+from threading import Event, Lock
+from typing import Optional
+
+
+class DblEvent:
+ """
+ Overview:
+ A double event object which can be opened and closed.
+ Based on 2 event objects.
+ """
+
+ def __init__(self, opened: bool = False):
+ """
+ Overview:
+ Constructor of `DblEvent`
+ Arguments:
+ - opened (:obj:`bool`): Initial status (`True` means open, `False` means close, default is `False`)
+ """
+ self.__open_event = Event()
+ self.__close_event = Event()
+ self.__lock = Lock()
+
+ if opened:
+ self.__open_event.set()
+ else:
+ self.__close_event.set()
+
+ def wait_for_open(self, timeout: Optional[float] = None):
+ """
+ Overview:
+ Wait until the event is opened
+ Arguments:
+ - timeout (:obj:`Optional[float]`): Waiting timeout in seconds
+ """
+ self.__open_event.wait(timeout=timeout)
+
+ def wait_for_close(self, timeout: Optional[float] = None):
+ """
+ Overview:
+ Wait until the event is closed
+ Arguments:
+ - timeout (:obj:`Optional[float]`): Waiting time out in seconds
+ """
+ self.__close_event.wait(timeout=timeout)
+
+ def open(self):
+ """
+ Overview:
+ Open this event
+ """
+ with self.__lock:
+ self.__open_event.set()
+ self.__close_event.clear()
+
+ def close(self):
+ """
+ Overview:
+ Close this event
+ """
+ with self.__lock:
+ self.__close_event.set()
+ self.__open_event.clear()
+
+ def is_open(self) -> bool:
+ """
+ Overview:
+ Get if the event is opened
+ Returns:
+ - opened (:obj:`bool`): The event is opened or not
+ """
+ with self.__lock:
+ return self.__open_event.is_set()
+
+ def is_close(self) -> bool:
+ """
+ Overview:
+ Get if the event is closed
+ Returns:
+ - closed (:obj:`bool`): The event is closed or not
+ """
+ with self.__lock:
+ return self.__close_event.is_set()
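+
+
+# Illustrative usage (a sketch only, not exercised elsewhere in this module):
+# event = DblEvent()  # starts closed
+# assert event.is_close()
+# event.open()
+# assert event.is_open()
+# event.wait_for_open(1.0)  # returns immediately because the event is already open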
diff --git a/DI-engine/ding/interaction/config/__init__.py b/DI-engine/ding/interaction/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..98491582a22e312ca5e5ad16f65390f9e5407298
--- /dev/null
+++ b/DI-engine/ding/interaction/config/__init__.py
@@ -0,0 +1,3 @@
+from .base import MIN_HEARTBEAT_CHECK_SPAN, MIN_HEARTBEAT_SPAN, DEFAULT_MASTER_PORT, DEFAULT_SLAVE_PORT, \
+ DEFAULT_CHANNEL, DEFAULT_HEARTBEAT_CHECK_SPAN, DEFAULT_HEARTBEAT_TOLERANCE, DEFAULT_HEARTBEAT_SPAN, LOCAL_HOST, \
+ GLOBAL_HOST, DEFAULT_REQUEST_RETRIES, DEFAULT_REQUEST_RETRY_WAITING
diff --git a/DI-engine/ding/interaction/config/base.py b/DI-engine/ding/interaction/config/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..446e260203ff8119d34f20748b6d29698b85da7d
--- /dev/null
+++ b/DI-engine/ding/interaction/config/base.py
@@ -0,0 +1,21 @@
+# System configs
+GLOBAL_HOST = '0.0.0.0'
+LOCAL_HOST = '127.0.0.1'
+
+# General request
+DEFAULT_REQUEST_RETRIES = 5
+DEFAULT_REQUEST_RETRY_WAITING = 1.0
+
+# Slave configs
+MIN_HEARTBEAT_SPAN = 0.2
+DEFAULT_HEARTBEAT_SPAN = 3.0
+DEFAULT_SLAVE_PORT = 7236
+
+# Master configs
+MIN_HEARTBEAT_CHECK_SPAN = 0.1
+DEFAULT_HEARTBEAT_CHECK_SPAN = 1.0
+DEFAULT_HEARTBEAT_TOLERANCE = 17.0
+DEFAULT_MASTER_PORT = 7235
+
+# Two-side configs
+DEFAULT_CHANNEL = 0
diff --git a/DI-engine/ding/interaction/exception/__init__.py b/DI-engine/ding/interaction/exception/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..100d2144649b5d3f48e4d8fa2998dcbf785a5e8f
--- /dev/null
+++ b/DI-engine/ding/interaction/exception/__init__.py
@@ -0,0 +1,9 @@
+from .base import ResponseException
+from .master import MasterErrorCode, get_master_exception_by_error, MasterResponseException, MasterSuccess, \
+ MasterChannelInvalid, MasterChannelNotGiven, MasterMasterTokenInvalid, MasterMasterTokenNotGiven, \
+ MasterSelfTokenInvalid, MasterSelfTokenNotGiven, MasterSlaveTokenInvalid, MasterSlaveTokenNotGiven, \
+ MasterSystemShuttingDown, MasterTaskDataInvalid
+from .slave import SlaveErrorCode, get_slave_exception_by_error, SlaveResponseException, SlaveSuccess, \
+ SlaveChannelInvalid, SlaveChannelNotFound, SlaveSelfTokenInvalid, SlaveTaskAlreadyExist, SlaveTaskRefused, \
+ SlaveMasterTokenInvalid, SlaveMasterTokenNotFound, SlaveSelfTokenNotFound, SlaveSlaveAlreadyConnected, \
+ SlaveSlaveConnectionRefused, SlaveSlaveDisconnectionRefused, SlaveSlaveNotConnected, SlaveSystemShuttingDown
diff --git a/DI-engine/ding/interaction/exception/base.py b/DI-engine/ding/interaction/exception/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e8537b118be9748e8f64e4f07208652c7079ca9
--- /dev/null
+++ b/DI-engine/ding/interaction/exception/base.py
@@ -0,0 +1,99 @@
+from abc import ABCMeta
+from typing import Mapping, Any
+
+from requests.exceptions import HTTPError
+
+from ..base import get_values_from_response
+
+
+class _IResponseInformation(metaclass=ABCMeta):
+ """
+ Overview:
+ Basic interface describing the structure of response information
+ """
+
+ @property
+ def success(self) -> bool:
+ """
+ Overview:
+ Get response success or not
+ Returns:
+ - success (:obj:`bool`): Response success or not
+ """
+ raise NotImplementedError
+
+ @property
+ def code(self) -> int:
+ """
+ Overview:
+ Get response error code (`0` means success)
+ Returns:
+ - code (:obj:`int`): Response error code
+ """
+ raise NotImplementedError
+
+ @property
+ def message(self) -> str:
+ """
+ Overview:
+ Get response message
+ Returns:
+ - message (:obj:`str`): Response message
+ """
+ raise NotImplementedError
+
+ @property
+ def data(self) -> Mapping[str, Any]:
+ """
+ Overview:
+ Get response data
+ Returns:
+ - data (:obj:`Mapping[str, Any]`): Response data
+ """
+ raise NotImplementedError
+
+
+# exception class for processing response
+class ResponseException(Exception, _IResponseInformation, metaclass=ABCMeta):
+ """
+ Overview:
+ Response exception, which can be directly raised in methods to create a failed HTTP response.
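+ Example:
+ - A minimal handling sketch (`do_request` below is a hypothetical call that raises a `ResponseException` subclass):
+ >>> try:
+ >>>     do_request()
+ >>> except ResponseException as err:
+ >>>     print(err.status_code, err.code, err.message)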
+ """
+
+ def __init__(self, error: HTTPError):
+ """
+ Overview:
+ Constructor of `ResponseException`
+ Arguments:
+ - error (:obj:`HTTPError`): Original http exception object
+ """
+ self.__error = error
+ self.__status_code, self.__success, self.__code, self.__message, self.__data = \
+ get_values_from_response(error.response)
+ Exception.__init__(self, self.__message)
+
+ @property
+ def status_code(self) -> int:
+ """
+ Overview:
+ Get http status code of response
+ Returns:
+ - status_code (:obj:`int`): Http status code
+ """
+ return self.__status_code
+
+ @property
+ def success(self) -> bool:
+ return self.__success
+
+ @property
+ def code(self) -> int:
+ return self.__code
+
+ @property
+ def message(self) -> str:
+ return self.__message
+
+ @property
+ def data(self) -> Mapping[str, Any]:
+ return self.__data
diff --git a/DI-engine/ding/interaction/exception/master.py b/DI-engine/ding/interaction/exception/master.py
new file mode 100644
index 0000000000000000000000000000000000000000..071ae8c1dca67294a7b4a6a32ae05048bba88b31
--- /dev/null
+++ b/DI-engine/ding/interaction/exception/master.py
@@ -0,0 +1,126 @@
+from abc import ABCMeta
+from enum import unique, IntEnum
+from typing import Type
+
+import enum_tools
+from requests import HTTPError
+
+from .base import ResponseException
+from ..base import get_values_from_response
+
+
+@enum_tools.documentation.document_enum
+@unique
+class MasterErrorCode(IntEnum):
+ """
+ Overview:
+ Error codes for master end
+ """
+ SUCCESS = 0 # doc: Master request success
+
+ SYSTEM_SHUTTING_DOWN = 101 # doc: Master end is shutting down
+
+ CHANNEL_NOT_GIVEN = 201 # doc: No channel id given in request
+ CHANNEL_INVALID = 202 # doc: Given channel id does not match the master end
+
+ MASTER_TOKEN_NOT_GIVEN = 301 # doc: No master token found in connection request from slave
+ MASTER_TOKEN_INVALID = 302 # doc: Master token auth failed in master end
+
+ SELF_TOKEN_NOT_GIVEN = 401 # doc: No self token given in self request (such as ping, shutdown)
+ SELF_TOKEN_INVALID = 402 # doc: Self token auth failed in master end itself (such as ping, shutdown)
+
+ SLAVE_TOKEN_NOT_GIVEN = 501 # doc: No slave token given in service request from slave
+ SLAVE_TOKEN_INVALID = 502 # doc: Slave token not found in master end
+
+ TASK_DATA_INVALID = 601 # doc: Task data is invalid
+
+
+# noinspection DuplicatedCode
+class MasterResponseException(ResponseException, metaclass=ABCMeta):
+ """
+ Overview:
+ Response exception for master client
+ """
+
+ def __init__(self, error: HTTPError):
+ """
+ Overview:
+ Constructor
+ Arguments:
+ - error (:obj:`HTTPError`): Original http exception object
+ """
+ ResponseException.__init__(self, error)
+
+
+class MasterSuccess(MasterResponseException):
+ pass
+
+
+class MasterSystemShuttingDown(MasterResponseException):
+ pass
+
+
+class MasterChannelNotGiven(MasterResponseException):
+ pass
+
+
+class MasterChannelInvalid(MasterResponseException):
+ pass
+
+
+class MasterMasterTokenNotGiven(MasterResponseException):
+ pass
+
+
+class MasterMasterTokenInvalid(MasterResponseException):
+ pass
+
+
+class MasterSelfTokenNotGiven(MasterResponseException):
+ pass
+
+
+class MasterSelfTokenInvalid(MasterResponseException):
+ pass
+
+
+class MasterSlaveTokenNotGiven(MasterResponseException):
+ pass
+
+
+class MasterSlaveTokenInvalid(MasterResponseException):
+ pass
+
+
+class MasterTaskDataInvalid(MasterResponseException):
+ pass
+
+
+_PREFIX = ['master']
+
+
+def get_master_exception_class_by_error_code(error_code: MasterErrorCode) -> Type[MasterResponseException]:
+ """
+ Overview:
+ Transform from master error code to `MasterResponseException` class
+ Arguments:
+ - error_code (:obj:`MasterErrorCode`): Master error code
+ Returns:
+ - exception_class (:obj:`Type[MasterResponseException]`): Master response exception class
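+ Example:
+ - A quick sketch of the name mapping performed below:
+ >>> get_master_exception_class_by_error_code(MasterErrorCode.CHANNEL_INVALID).__name__
+ 'MasterChannelInvalid'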
+ """
+ class_name = ''.join([word.lower().capitalize() for word in (_PREFIX + error_code.name.split('_'))])
+ return eval(class_name)
+
+
+def get_master_exception_by_error(error: HTTPError) -> MasterResponseException:
+ """
+ Overview:
+ Auto transform http error object to master response exception object.
+ Arguments:
+ - error (:obj:`HTTPError`): Http error object
+ Returns:
+ - exception (:obj:`MasterResponseException`): Master response exception object
+ """
+ _, _, code, _, _ = get_values_from_response(error.response)
+ error_code = {v.value: v for k, v in MasterErrorCode.__members__.items()}[code]
+ return get_master_exception_class_by_error_code(error_code)(error)
diff --git a/DI-engine/ding/interaction/exception/slave.py b/DI-engine/ding/interaction/exception/slave.py
new file mode 100644
index 0000000000000000000000000000000000000000..b534243a8c34a00d5a863ec409995b6682a82663
--- /dev/null
+++ b/DI-engine/ding/interaction/exception/slave.py
@@ -0,0 +1,141 @@
+from abc import ABCMeta
+from enum import unique, IntEnum
+from typing import Type
+
+import enum_tools
+from requests import HTTPError
+
+from .base import ResponseException
+from ..base import get_values_from_response
+
+
+@enum_tools.documentation.document_enum
+@unique
+class SlaveErrorCode(IntEnum):
+ """
+ Overview:
+ Error code for slave end
+ """
+ SUCCESS = 0 # doc: Slave request success
+
+ SYSTEM_SHUTTING_DOWN = 101 # doc: Slave end is shutting down
+
+ CHANNEL_NOT_FOUND = 201 # doc: No channel id given in request
+ CHANNEL_INVALID = 202 # doc: Given channel id does not match the slave end
+
+ MASTER_TOKEN_NOT_FOUND = 301 # doc: No master token found in connection request from master
+ MASTER_TOKEN_INVALID = 302 # doc: Master token auth failed in slave end
+
+ SELF_TOKEN_NOT_FOUND = 401 # doc: No self token given in self request (such as ping, shutdown)
+ SELF_TOKEN_INVALID = 402 # doc: Self token auth failed in slave end itself (such as ping, shutdown)
+
+ SLAVE_ALREADY_CONNECTED = 501 # doc: Slave end has already connected to another master end
+ SLAVE_NOT_CONNECTED = 502 # doc: Slave end not connected with master end yet
+ SLAVE_CONNECTION_REFUSED = 503 # doc: Connection to slave end refused
+ SLAVE_DISCONNECTION_REFUSED = 504 # doc: Disconnection to slave end refused
+
+ TASK_ALREADY_EXIST = 601 # doc: Slave end is processing another task
+ TASK_REFUSED = 602 # doc: Task for slave end refused
+
+
+# noinspection DuplicatedCode
+class SlaveResponseException(ResponseException, metaclass=ABCMeta):
+ """
+ Overview:
+ Response exception for slave client
+ """
+
+ def __init__(self, error: HTTPError):
+ """
+ Overview:
+ Constructor
+ Arguments:
+ - error (:obj:`HTTPError`): Original http exception object
+ """
+ ResponseException.__init__(self, error)
+
+
+class SlaveSuccess(SlaveResponseException):
+ pass
+
+
+class SlaveSystemShuttingDown(SlaveResponseException):
+ pass
+
+
+class SlaveChannelNotFound(SlaveResponseException):
+ pass
+
+
+class SlaveChannelInvalid(SlaveResponseException):
+ pass
+
+
+class SlaveMasterTokenNotFound(SlaveResponseException):
+ pass
+
+
+class SlaveMasterTokenInvalid(SlaveResponseException):
+ pass
+
+
+class SlaveSelfTokenNotFound(SlaveResponseException):
+ pass
+
+
+class SlaveSelfTokenInvalid(SlaveResponseException):
+ pass
+
+
+class SlaveSlaveAlreadyConnected(SlaveResponseException):
+ pass
+
+
+class SlaveSlaveNotConnected(SlaveResponseException):
+ pass
+
+
+class SlaveSlaveConnectionRefused(SlaveResponseException):
+ pass
+
+
+class SlaveSlaveDisconnectionRefused(SlaveResponseException):
+ pass
+
+
+class SlaveTaskAlreadyExist(SlaveResponseException):
+ pass
+
+
+class SlaveTaskRefused(SlaveResponseException):
+ pass
+
+
+_PREFIX = ['slave']
+
+
+def get_slave_exception_class_by_error_code(error_code: SlaveErrorCode) -> Type[SlaveResponseException]:
+ """
+ Overview:
+ Transform from slave error code to `SlaveResponseException` class
+ Arguments:
+ - error_code (:obj:`SlaveErrorCode`): Slave error code
+ Returns:
+ - exception_class (:obj:`Type[SlaveResponseException]`): Slave response exception class
+ """
+ class_name = ''.join([word.lower().capitalize() for word in (_PREFIX + error_code.name.split('_'))])
+ return eval(class_name)
+
+
+def get_slave_exception_by_error(error: HTTPError) -> SlaveResponseException:
+ """
+ Overview:
+ Auto transform http error object to slave response exception object.
+ Arguments:
+ - error (:obj:`HTTPError`): Http error object
+ Returns:
+ - exception (:obj:`SlaveResponseException`): Slave response exception object
+ """
+ _, _, code, _, _ = get_values_from_response(error.response)
+ error_code = {v.value: v for k, v in SlaveErrorCode.__members__.items()}[code]
+ return get_slave_exception_class_by_error_code(error_code)(error)
diff --git a/DI-engine/ding/interaction/master/__init__.py b/DI-engine/ding/interaction/master/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb00dbca36b4558621a11637f812e370dcf8a90f
--- /dev/null
+++ b/DI-engine/ding/interaction/master/__init__.py
@@ -0,0 +1 @@
+from .master import Master
diff --git a/DI-engine/ding/interaction/master/base.py b/DI-engine/ding/interaction/master/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..752deb61d22c34dcd8e88fe87fe4ac6ea147346d
--- /dev/null
+++ b/DI-engine/ding/interaction/master/base.py
@@ -0,0 +1,7 @@
+from typing import Callable, Mapping, Any, Optional
+
+from requests import RequestException
+
+_BEFORE_HOOK_TYPE = Callable[..., Mapping[str, Any]]
+_AFTER_HOOK_TYPE = Callable[[int, bool, int, Optional[str], Optional[Mapping[str, Any]]], Any]
+_ERROR_HOOK_TYPE = Callable[[RequestException], Any]
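+
+# For reference, a hook matching _AFTER_HOOK_TYPE would look like the following sketch
+# (the name and body are illustrative only):
+#   def after_hook(status_code: int, success: bool, code: int,
+#                  message: Optional[str], data: Optional[Mapping[str, Any]]) -> Any:
+#       ...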
diff --git a/DI-engine/ding/interaction/master/connection.py b/DI-engine/ding/interaction/master/connection.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ba64424494d31a2d399eefc6c5a77bfa11b32fc
--- /dev/null
+++ b/DI-engine/ding/interaction/master/connection.py
@@ -0,0 +1,450 @@
+from abc import ABCMeta, abstractmethod
+from functools import wraps
+from threading import Lock
+from typing import Optional, Any, Mapping, Type, Callable
+from uuid import uuid4, UUID
+
+import requests
+from requests.exceptions import RequestException
+
+from .base import _BEFORE_HOOK_TYPE, _AFTER_HOOK_TYPE, _ERROR_HOOK_TYPE
+from .task import Task, _task_complete, _task_fail
+from ..base import random_token, ControllableContext, get_http_engine_class, get_values_from_response
+from ..config import DEFAULT_CHANNEL, DEFAULT_SLAVE_PORT, DEFAULT_REQUEST_RETRIES, DEFAULT_REQUEST_RETRY_WAITING
+from ..exception import get_slave_exception_by_error
+
+_COMPLETE_TRIGGER_NAME = '__TASK_COMPLETE__'
+_FAIL_TRIGGER_NAME = '__TASK_FAIL__'
+
+
+class _ISlaveConnection(ControllableContext, metaclass=ABCMeta):
+ """
+ Overview:
+ Basic model of the connection classes, such as `SlaveConnection` and `SlaveConnectionProxy`, \
+ which are widely used in the interaction module.
+ Example:
+ - The following code shows how to use a slave connection correctly
+ >>> connection = master.new_connection('cnn1', '127.0.0.1', 2333)
+ >>> connection.connect()
+ >>> try:
+ >>> pass # do anything you like
+ >>> finally:
+ >>> connection.disconnect()
+
+ - A simpler equivalent of the code above
+ >>> with master.new_connection('cnn1', '127.0.0.1', 2333) as connection:
+ >>> pass # do anything you like, connect and disconnect will be done automatically
+ """
+
+ @abstractmethod
+ def connect(self):
+ """
+ Overview:
+ Connect to slave end.
+ """
+ raise NotImplementedError # pragma: no cover
+
+ @abstractmethod
+ def disconnect(self):
+ """
+ Overview:
+ Disconnect from slave end.
+ """
+ raise NotImplementedError # pragma: no cover
+
+ @abstractmethod
+ def new_task(self, data: Optional[Mapping[str, Any]] = None):
+ """
+ Overview:
+ Send new task to slave end and receive task result from it.
+ Arguments:
+ - data (:obj:`Optional[Mapping[str, Any]]`): Data of the new task
+ Returns:
+ - result (:obj:`Mapping[str, Any]`): Result of the task processed by slave end
+ """
+ raise NotImplementedError # pragma: no cover
+
+ def start(self):
+ """
+ Overview:
+ Alias for `connect`, for supporting context manager.
+ """
+ self.connect()
+
+ def close(self):
+ """
+ Overview:
+ Alias for `disconnect`, for supporting context manager.
+ """
+ self.disconnect()
+
+
+class SlaveConnection(_ISlaveConnection, metaclass=ABCMeta):
+ """
+ Overview:
+ Slave connection object, which needs to directly interact with the slave end.
+ """
+
+ def __init__(
+ self,
+ host: str,
+ port: Optional[int] = None,
+ https: bool = False,
+ channel: Optional[int] = None,
+ my_address: Optional[str] = None,
+ token: Optional[str] = None,
+ request_retries: Optional[int] = None,
+ request_retry_waiting: Optional[float] = None,
+ ):
+ """
+ Overview:
+ Constructor of `SlaveConnection`
+ Arguments:
+ - host (:obj:`str`): Host of the slave server
+ - port (:obj:`Optional[int]`): Port of the slave server (None means `7236`)
+ - https (:obj:`bool`): Use https or not
+ - channel (:obj:`Optional[int]`): Channel id for the slave client.
+ - my_address (:obj:`Optional[str]`): The address of the current server (None means the local ip will be \
+ detected automatically). This address is used when connecting to the slave, and the slave's requests \
+ will be sent to this address, **so please make sure the address is reachable by the slave**.
+ - token (:obj:`Optional[str]`): Token of this connection, used to authenticate the \
+ connection (`None` means the token will be randomly generated)
+ - request_retries (:obj:`Optional[int]`): Max times for request retries (None means `5`)
+ - request_retry_waiting (:obj:`Optional[float]`): Sleep time before requests' retrying (None means `1.0`, \
+ unit: second)
+ """
+ # meta info part
+ self.__channel = channel or DEFAULT_CHANNEL
+ self.__my_address = my_address
+ self.__token = token or random_token()
+
+ # request part
+ self.__http_engine = get_http_engine_class(
+ headers={
+ 'Channel': lambda: str(self.__channel),
+ 'Token': lambda: self.__token,
+ },
+ http_error_gene=get_slave_exception_by_error,
+ )()(host, port or DEFAULT_SLAVE_PORT, https)
+ self.__request_retries = max(request_retries or DEFAULT_REQUEST_RETRIES, 0)
+ self.__request_retry_waiting = max(request_retry_waiting or DEFAULT_REQUEST_RETRY_WAITING, 0.0)
+
+ # threading part
+ self.__lock = Lock()
+ self.__is_connected = False
+
+ # task part
+ self.__tasks = {}
+
+ self.__init_triggers()
+
+ def __request(self, method: str, path: str, data: Optional[Mapping[str, Any]] = None) -> requests.Response:
+ return self.__http_engine.request(
+ method,
+ path,
+ data,
+ retries=self.__request_retries,
+ retry_waiting=self.__request_retry_waiting,
+ )
+
+ @property
+ def is_connected(self) -> bool:
+ """
+ Overview:
+ Check connection status
+ Returns:
+ - connected (:obj:`bool`): Whether this connection is still alive
+ """
+ with self.__lock:
+ return self.__is_connected
+
+ def _before_connect(self) -> Mapping[str, Any]:
+ pass # pragma: no cover
+
+ def _after_connect(
+ self, status_code: int, success: bool, code: int, message: Optional[str], data: Optional[Mapping[str, Any]]
+ ) -> Any:
+ pass # pragma: no cover
+
+ def _error_connect(self, error: RequestException) -> Any:
+ raise error # pragma: no cover
+
+ def __connect(self):
+ try:
+ response = self.__request(
+ 'POST', '/connect', {
+ 'master': {
+ 'address': self.__my_address,
+ },
+ 'data': (self._before_connect() or {})
+ }
+ )
+ except RequestException as err:
+ return self._error_connect(err)
+ else:
+ self.__is_connected = True
+ return self._after_connect(*get_values_from_response(response))
+
+ def connect(self):
+ with self.__lock:
+ return self.__connect()
+
+ def _before_disconnect(self) -> Mapping[str, Any]:
+ pass # pragma: no cover
+
+ def _after_disconnect(
+ self, status_code: int, success: bool, code: int, message: Optional[str], data: Optional[Mapping[str, Any]]
+ ) -> Any:
+ pass # pragma: no cover
+
+ def _error_disconnect(self, error: RequestException) -> Any:
+ raise error # pragma: no cover
+
+ def __disconnect(self):
+ try:
+ response = self.__request('DELETE', '/disconnect', {
+ 'data': self._before_disconnect() or {},
+ })
+ except RequestException as err:
+ return self._error_disconnect(err)
+ else:
+ self.__is_connected = False
+ return self._after_disconnect(*get_values_from_response(response))
+
+ def disconnect(self):
+ with self.__lock:
+ return self.__disconnect()
+
+ def _before_new_task(self, data: Optional[Mapping[str, Any]] = None) -> Mapping[str, Any]:
+ return data # pragma: no cover
+
+ def _after_new_task(
+ self, status_code: int, success: bool, code: int, message: Optional[str], data: Optional[Mapping[str, Any]]
+ ) -> Any:
+ pass # pragma: no cover
+
+ def _error_new_task(self, error: RequestException) -> Any:
+ raise error # pragma: no cover
+
+ def new_task(self, data: Optional[Mapping[str, Any]] = None) -> Task:
+ with self.__lock:
+ _uuid = uuid4()
+ _task = Task(
+ http_engine=self.__http_engine,
+ data=data,
+ task_id=_uuid,
+ before_task_start=self._before_new_task,
+ after_task_start=self._after_new_task,
+ error_task_start=self._error_new_task,
+ )
+
+ self.__tasks[_uuid] = _task
+ return _task
+
+ def __task_complete(self, task_id: UUID, task_result: Mapping[str, Any]):
+ _task = self.__tasks[task_id]
+ _task_complete(_task, task_result)
+ del self.__tasks[task_id]
+
+ def __task_fail(self, task_id: UUID, task_result: Mapping[str, Any]):
+ _task = self.__tasks[task_id]
+ _task_fail(_task, task_result)
+ del self.__tasks[task_id]
+
+ def __task_complete_trigger(self, task_id: UUID, task_result: Mapping[str, Any]):
+ with self.__lock:
+ if task_id in self.__tasks.keys():
+ return self.__task_complete(task_id, task_result)
+ else:
+ raise KeyError("Task {uuid} not found in this connection.".format(uuid=repr(str(task_id))))
+
+ def __task_fail_trigger(self, task_id: UUID, task_result: Mapping[str, Any]):
+ with self.__lock:
+ if task_id in self.__tasks.keys():
+ return self.__task_fail(task_id, task_result)
+ else:
+ raise KeyError("Task {uuid} not found in this connection.".format(uuid=repr(str(task_id))))
+
+ def __init_triggers(self):
+ setattr(self, _COMPLETE_TRIGGER_NAME, self.__task_complete_trigger)
+ setattr(self, _FAIL_TRIGGER_NAME, self.__task_fail_trigger)
+
+
+def _connection_task_complete(connection: SlaveConnection, task_id: UUID, task_result: Mapping[str, Any]):
+ return getattr(connection, _COMPLETE_TRIGGER_NAME)(task_id, task_result)
+
+
+def _connection_task_fail(connection: SlaveConnection, task_id: UUID, task_result: Mapping[str, Any]):
+ return getattr(connection, _FAIL_TRIGGER_NAME)(task_id, task_result)
+
+
+class SlaveConnectionProxy(_ISlaveConnection):
+ """
+ Overview:
+ Proxy class for `SlaveConnection` class, which wraps the original methods.
+ """
+
+ def __init__(
+ self,
+ connection: SlaveConnection,
+ after_connect: Optional[Callable] = None,
+ after_disconnect: Optional[Callable] = None
+ ):
+ """
+ Overview:
+ Constructor of `SlaveConnectionProxy`
+ Arguments:
+ - connection (:obj:`SlaveConnection`): Slave connection object
+ - after_connect (:obj:`Optional[Callable]`): Behaviour going to be executed after connection established
+ - after_disconnect (:obj:`Optional[Callable]`): Behaviour going to be executed after connection killed
+ """
+ self.__connection = connection
+ self.__lock = Lock()
+ self.__after_connect = after_connect
+ self.__after_disconnect = after_disconnect
+
+ self.__init_triggers()
+
+ @property
+ def is_connected(self) -> bool:
+ """
+ Overview:
+ Check connection status
+ Returns:
+ - connected (:obj:`bool`): Whether this connection is still alive
+ """
+ with self.__lock:
+ return self.__connection.is_connected
+
+ def connect(self):
+ with self.__lock:
+ result = self.__connection.connect()
+ if self.__after_connect is not None:
+ self.__after_connect(connection=self)
+ return result
+
+ def disconnect(self):
+ with self.__lock:
+ result = self.__connection.disconnect()
+ if self.__after_disconnect is not None:
+ self.__after_disconnect(connection=self)
+ return result
+
+ def new_task(self, data: Optional[Mapping[str, Any]] = None):
+ with self.__lock:
+ return self.__connection.new_task(data)
+
+ def __task_complete_trigger(self, task_id: UUID, task_result: Mapping[str, Any]):
+ with self.__lock:
+ return _connection_task_complete(self.__connection, task_id, task_result)
+
+ def __task_fail_trigger(self, task_id: UUID, task_result: Mapping[str, Any]):
+ with self.__lock:
+ return _connection_task_fail(self.__connection, task_id, task_result)
+
+ def __init_triggers(self):
+ setattr(self, _COMPLETE_TRIGGER_NAME, self.__task_complete_trigger)
+ setattr(self, _FAIL_TRIGGER_NAME, self.__task_fail_trigger)
+
+
+def _proxy_task_complete(proxy: SlaveConnectionProxy, task_id: UUID, task_result: Mapping[str, Any]):
+ return getattr(proxy, _COMPLETE_TRIGGER_NAME)(task_id, task_result)
+
+
+def _proxy_task_fail(proxy: SlaveConnectionProxy, task_id: UUID, task_result: Mapping[str, Any]):
+ return getattr(proxy, _FAIL_TRIGGER_NAME)(task_id, task_result)
+
+
+def _slave_task_complete(connection: _ISlaveConnection, task_id: UUID, task_result: Mapping[str, Any]):
+ if isinstance(connection, SlaveConnection):
+ return _connection_task_complete(connection, task_id, task_result)
+ elif isinstance(connection, SlaveConnectionProxy):
+ return _proxy_task_complete(connection, task_id, task_result)
+ else:
+ raise TypeError(
+ "{expect1} or {expect2} expected, but {actual} found.".format(
+ expect1=SlaveConnection.__name__,
+ expect2=SlaveConnectionProxy.__name__,
+ actual=type(connection).__name__,
+ )
+ )
+
+
+def _slave_task_fail(connection: _ISlaveConnection, task_id: UUID, task_result: Mapping[str, Any]):
+ if isinstance(connection, SlaveConnection):
+ return _connection_task_fail(connection, task_id, task_result)
+ elif isinstance(connection, SlaveConnectionProxy):
+ return _proxy_task_fail(connection, task_id, task_result)
+ else:
+ raise TypeError(
+ "{expect1} or {expect2} expected, but {actual} found.".format(
+ expect1=SlaveConnection.__name__,
+ expect2=SlaveConnectionProxy.__name__,
+ actual=type(connection).__name__,
+ )
+ )
+
+
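+# Wraps an optional hook so that a missing (None) hook behaves as a no-op returning None.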
+def _default_wrap(func: Callable) -> Callable:
+
+ @wraps(func)
+ def _new_func(*args, **kwargs):
+ if func:
+ return func(*args, **kwargs)
+ else:
+ return None
+
+ return _new_func
+
+
+def _get_connection_class(
+ before_new_task: Optional[_BEFORE_HOOK_TYPE] = None,
+ after_new_task: Optional[_AFTER_HOOK_TYPE] = None,
+ error_new_task: Optional[_ERROR_HOOK_TYPE] = None,
+ before_connect: Optional[_BEFORE_HOOK_TYPE] = None,
+ after_connect: Optional[_AFTER_HOOK_TYPE] = None,
+ error_connect: Optional[_ERROR_HOOK_TYPE] = None,
+ before_disconnect: Optional[_BEFORE_HOOK_TYPE] = None,
+ after_disconnect: Optional[_AFTER_HOOK_TYPE] = None,
+ error_disconnect: Optional[_ERROR_HOOK_TYPE] = None,
+) -> Type[SlaveConnection]:
+
+ class _Connection(SlaveConnection):
+
+ def _before_connect(self) -> Mapping[str, Any]:
+ return _default_wrap(before_connect)() or {}
+
+ def _after_connect(
+ self, status_code: int, success: bool, code: int, message: Optional[str], data: Optional[Mapping[str,
+ Any]]
+ ) -> Any:
+ return _default_wrap(after_connect)(status_code, success, code, message, data)
+
+ def _error_connect(self, error: RequestException) -> Any:
+ return _default_wrap(error_connect)(error)
+
+ def _before_disconnect(self) -> Mapping[str, Any]:
+ return _default_wrap(before_disconnect)() or {}
+
+ def _after_disconnect(
+ self, status_code: int, success: bool, code: int, message: Optional[str], data: Optional[Mapping[str,
+ Any]]
+ ) -> Any:
+ return _default_wrap(after_disconnect)(status_code, success, code, message, data)
+
+ def _error_disconnect(self, error: RequestException) -> Any:
+ return _default_wrap(error_disconnect)(error)
+
+ def _before_new_task(self, data: Optional[Mapping[str, Any]] = None) -> Mapping[str, Any]:
+ return _default_wrap(before_new_task)(data) or {}
+
+ def _after_new_task(
+ self, status_code: int, success: bool, code: int, message: Optional[str], data: Optional[Mapping[str,
+ Any]]
+ ) -> Any:
+ return _default_wrap(after_new_task)(status_code, success, code, message, data)
+
+ def _error_new_task(self, error: RequestException) -> Any:
+ return _default_wrap(error_new_task)(error)
+
+ return _Connection
diff --git a/DI-engine/ding/interaction/master/master.py b/DI-engine/ding/interaction/master/master.py
new file mode 100644
index 0000000000000000000000000000000000000000..555ec6cbc8f30bf344d3ad2d7a62701915dc9959
--- /dev/null
+++ b/DI-engine/ding/interaction/master/master.py
@@ -0,0 +1,653 @@
+import json
+import time
+from functools import wraps, partial
+from queue import Queue, Empty
+from threading import Lock, Thread, Event
+from typing import Optional, Any, Mapping, Type, Callable
+from uuid import UUID
+
+import requests
+from flask import Flask, request
+from requests.exceptions import RequestException
+from urlobject import URLObject
+
+from .connection import SlaveConnectionProxy, SlaveConnection, _ISlaveConnection, _get_connection_class, \
+ _slave_task_complete, _slave_task_fail
+from .task import TaskResultType
+from ..base import random_token, ControllableService, failure_response, success_response, get_host_ip, \
+ get_http_engine_class
+from ..config import GLOBAL_HOST, DEFAULT_MASTER_PORT, DEFAULT_CHANNEL, MIN_HEARTBEAT_SPAN, \
+ DEFAULT_HEARTBEAT_TOLERANCE, MIN_HEARTBEAT_CHECK_SPAN, DEFAULT_HEARTBEAT_CHECK_SPAN, DEFAULT_REQUEST_RETRIES, \
+ DEFAULT_REQUEST_RETRY_WAITING
+from ..exception import MasterErrorCode, get_master_exception_by_error
+
+
+class Master(ControllableService):
+ """
+ Overview:
+ Interaction master end
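+ Example:
+ - A minimal usage sketch (the host/port/name values below are illustrative only, and a slave \
+ end must be listening on the target address for the task to actually run):
+ >>> master = Master('0.0.0.0', 7235)
+ >>> master.start()
+ >>> with master.new_connection('conn1', '127.0.0.1', 7236) as connection:
+ >>>     task = connection.new_task({'data': 233}).start().join()
+ >>>     print(task.result)
+ >>> master.shutdown()
+ >>> master.join()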
+ """
+
+ def __init__(
+ self,
+ host: Optional[str] = None,
+ port: Optional[int] = None,
+ heartbeat_tolerance: Optional[float] = None,
+ heartbeat_check_span: Optional[float] = None,
+ request_retries: Optional[int] = None,
+ request_retry_waiting: Optional[float] = None,
+ channel: Optional[int] = None,
+ my_address: Optional[str] = None
+ ):
+ """
+ Overview:
+ Constructor of Master
+ Arguments:
+ - host (:obj:`Optional[str]`): Host of the master server, based on flask (None means `0.0.0.0`)
+ - port (:obj:`Optional[int]`): Port of the master server, based on flask (None means `7235`)
+ - heartbeat_tolerance (:obj:`Optional[float]`): Max time tolerance of the heartbeat missing (None means \
+ `17.0`, minimum is `0.2`, unit: second)
+ - heartbeat_check_span (:obj:`Optional[float]`): Timespan between the heartbeat status checks (None means \
+ `1.0`, minimum is `0.1`, unit: second)
+ - request_retries (:obj:`Optional[int]`): Max times for request retries (None means `5`)
+ - request_retry_waiting (:obj:`Optional[float]`): Sleep time before requests' retrying (None means `1.0`, \
+ unit: second)
+ - channel (:obj:`Optional[int]`): Channel id for the master client; please make sure the channel id is \
+ equal to the slave client's channel id, or the connection cannot be established. (None means `0`, \
+ but channel 0 is not recommended for production use)
+ - my_address (:obj:`Optional[str]`): The address of the current server (None means the local ip will be \
+ detected automatically). This address is used when connecting to the slave, and the slave's requests \
+ will be sent to this address, **so please make sure the address is reachable by the slave**.
+ """
+ # server part
+ self.__host = host or GLOBAL_HOST
+ self.__port = port or DEFAULT_MASTER_PORT
+ self.__flask_app_value = None
+ self.__run_app_thread = Thread(target=self.__run_app, name='master_run_app')
+
+ # heartbeat part
+ self.__heartbeat_tolerance = max(heartbeat_tolerance or DEFAULT_HEARTBEAT_TOLERANCE, MIN_HEARTBEAT_SPAN)
+ self.__heartbeat_check_span = max(
+ heartbeat_check_span or DEFAULT_HEARTBEAT_CHECK_SPAN, MIN_HEARTBEAT_CHECK_SPAN
+ )
+ self.__heartbeat_check_thread = Thread(target=self.__heartbeat_check, name='master_heartbeat')
+ self.__request_retries = max(request_retries or DEFAULT_REQUEST_RETRIES, 0)
+ self.__request_retry_waiting = max(request_retry_waiting or DEFAULT_REQUEST_RETRY_WAITING, 0.0)
+
+ # self-connection part
+ self.__self_http_engine = get_http_engine_class(
+ headers={
+ 'Token': lambda: self.__self_token,
+ },
+ http_error_gene=get_master_exception_by_error,
+ # )()('localhost', self.__port, False)
+ )()(self.__host, self.__port, False) # TODO: Confirm how to ping itself
+ self.__self_token = random_token()
+
+ # slave-connection part
+ self.__channel = channel or DEFAULT_CHANNEL
+ self.__my_address = my_address or str(
+ URLObject().with_scheme('http').with_hostname(get_host_ip()).with_port(self.__port)
+ )
+
+ # slaves part
+ self.__slaves = {} # name --> (token, slave_connection)
+ self.__token_slaves = {} # token --> (name, slave_connection)
+ self.__slave_last_heartbeat = {} # name --> last_heartbeat
+ self.__slave_lock = Lock()
+
+ # task part
+ self.__task_result_queue = Queue()
+ self.__task_result_process_thread = Thread(target=self.__task_result_process, name='master_task_result')
+
+ # global part
+ self.__shutdown_event = Event()
+ self.__lock = Lock()
+
+ # slave connection
+ def __connection_open(self, name: str, token: str, connection: SlaveConnectionProxy):
+ with self.__slave_lock:
+ self.__slaves[name] = (token, connection)
+ self.__token_slaves[token] = (name, connection)
+ self.__slave_last_heartbeat[name] = time.time()
+
+ # noinspection PyUnusedLocal
+ def __connection_close(self, name: str, connection: Optional[SlaveConnectionProxy] = None):
+ with self.__slave_lock:
+ token, _conn = self.__slaves[name]
+ connection = connection or _conn
+ del self.__slaves[name]
+ del self.__token_slaves[token]
+ del self.__slave_last_heartbeat[name]
+
+ # server part
+ def __generate_app(self):
+ app = Flask(__name__)
+
+ # self apis
+ app.route('/ping', methods=['GET'])(self.__check_self_request(self.__self_ping))
+ app.route('/shutdown', methods=['DELETE'])(self.__check_self_request(self.__self_shutdown))
+
+ # slave apis
+ app.route('/slave/heartbeat', methods=['GET'])(self.__check_slave_request(self.__heartbeat))
+ app.route(
+ '/slave/task/complete', methods=['PUT']
+ )(self.__check_slave_request(self.__check_task_info(self.__task_complete)))
+ app.route(
+ '/slave/task/fail', methods=['PUT']
+ )(self.__check_slave_request(self.__check_task_info(self.__task_fail)))
+
+ return app
+
+ def __flask_app(self) -> Flask:
+ return self.__flask_app_value or self.__generate_app()
+
+ def __run_app(self):
+ self.__flask_app().run(
+ host=self.__host,
+ port=self.__port,
+ )
+
+ # both method checkers
+ def __check_shutdown(self, func: Callable[[], Any]) -> Callable[[], Any]:
+
+ @wraps(func)
+ def _func():
+ if self.__shutdown_event.is_set():
+ return failure_response(
+ code=MasterErrorCode.SYSTEM_SHUTTING_DOWN, message='System has already been shutting down.'
+ ), 401
+ else:
+ return func()
+
+ return _func
+
+ # server method checkers (self)
+ # noinspection DuplicatedCode
+ def __check_self_request(self, func: Callable[[], Any]) -> Callable[[], Any]:
+ return self.__check_shutdown(self.__check_master_token(func))
+
+ def __check_master_token(self, func: Callable[[], Any]) -> Callable[[], Any]:
+
+ @wraps(func)
+ def _func():
+ master_token = request.headers.get('Token', None)
+
+ if master_token is None:
+ return failure_response(
+ code=MasterErrorCode.SELF_TOKEN_NOT_GIVEN, message='Master token not found.'
+ ), 400
+ elif master_token != self.__self_token:
+ return failure_response(
+ code=MasterErrorCode.SELF_TOKEN_INVALID, message='Master token not match with this endpoint.'
+ ), 403
+ else:
+ return func()
+
+ return _func
+
+ # server method checkers (slave)
+ def __check_slave_request(self, func: Callable[[str, _ISlaveConnection], Any]) -> Callable[[], Any]:
+ return self.__check_shutdown(self.__check_channel(self.__check_slave_token(func)))
+
+ # noinspection DuplicatedCode
+ def __check_channel(self, func: Callable[[], Any]) -> Callable[[], Any]:
+
+ @wraps(func)
+ def _func():
+ channel = request.headers.get('Channel', None)
+ channel = int(channel) if channel else None
+
+ if channel is None:
+ return failure_response(code=MasterErrorCode.CHANNEL_NOT_GIVEN, message='Channel not found.'), 400
+ elif channel != self.__channel:
+ return failure_response(
+ code=MasterErrorCode.CHANNEL_INVALID, message='Channel not match with this endpoint.'
+ ), 403
+ else:
+ return func()
+
+ return _func
+
+ def __check_slave_token(self, func: Callable[[str, _ISlaveConnection], Any]) -> Callable[[], Any]:
+
+ @wraps(func)
+ def _func():
+ slave_token = request.headers.get('Token', None)
+
+ if slave_token is None:
+ return failure_response(
+ code=MasterErrorCode.SLAVE_TOKEN_NOT_GIVEN, message='Slave token not found.'
+ ), 400
+ elif slave_token not in self.__token_slaves.keys():
+ return failure_response(
+ code=MasterErrorCode.SLAVE_TOKEN_INVALID, message='No matching slave token found in this endpoint.'
+ ), 403
+ else:
+ name, connection = self.__token_slaves[slave_token]
+ return func(name, connection)
+
+ return _func
+
+ # noinspection PyMethodMayBeStatic
+ def __get_request_data(self, func: Callable[[str, _ISlaveConnection, Mapping[str, Any]], Any]) \
+ -> Callable[[str, _ISlaveConnection], Any]:
+
+ @wraps(func)
+ def _func(name: str, connection: _ISlaveConnection):
+ _data = json.loads(request.data.decode())
+ return func(name, connection, _data)
+
+ return _func
+
+ def __check_task_info(self, func: Callable[[str, _ISlaveConnection, UUID, Mapping[str, Any]], Any]) \
+ -> Callable[[str, _ISlaveConnection], Any]:
+
+ @wraps(func)
+ @self.__get_request_data
+ def _func(name: str, connection: _ISlaveConnection, data: Mapping[str, Any]):
+ if 'task' not in data.keys():
+ return failure_response(
+ code=MasterErrorCode.TASK_DATA_INVALID,
+ message='Task information not found.',
+ )
+ _task_info, _task_result = data['task'], data['result']
+
+ if 'id' not in _task_info.keys():
+ return failure_response(code=MasterErrorCode.TASK_DATA_INVALID, message='Task ID not found.')
+ _task_id = UUID(_task_info['id'])
+
+ return func(name, connection, _task_id, _task_result)
+
+ return _func
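+
+ # The '/slave/task/complete' and '/slave/task/fail' endpoints above expect a JSON body of the
+ # shape below, as inferred from the parsing in __check_task_info (shown only as a reference sketch):
+ #   {"task": {"id": "<uuid>"}, "result": {...}}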
+
+ # server methods (self)
+ # noinspection PyMethodMayBeStatic
+ def __self_ping(self):
+ return success_response(message='PONG!')
+
+ def __self_shutdown(self):
+ _shutdown_func = request.environ.get('werkzeug.server.shutdown')
+ if _shutdown_func is None:
+ raise RuntimeError('Not running with the Werkzeug Server')
+
+ self.__shutdown_event.set()
+ _shutdown_func()
+
+ return success_response(message='Shutdown request received, this server will be down later.')
+
+ # server methods (slave)
+ # noinspection PyMethodMayBeStatic,PyUnusedLocal
+ def __heartbeat(self, name: str, connection: _ISlaveConnection):
+ self.__slave_last_heartbeat[name] = time.time()
+ return success_response(message='Received!')
+
+ # noinspection PyUnusedLocal
+ def __task_complete(self, name: str, connection: _ISlaveConnection, task_id: UUID, task_result: Mapping[str, Any]):
+ self.__task_result_queue.put((TaskResultType.COMPLETED, (connection, task_id, task_result)))
+ return success_response(message='Result received!')
+
+ # noinspection PyUnusedLocal
+ def __task_fail(self, name: str, connection: _ISlaveConnection, task_id: UUID, task_result: Mapping[str, Any]):
+ self.__task_result_queue.put((TaskResultType.FAILED, (connection, task_id, task_result)))
+ return success_response(message='Result received!')
+
+ # self request
+ def __self_request(self, method: Optional[str] = 'GET', path: Optional[str] = None) -> requests.Response:
+ return self.__self_http_engine.request(
+ method,
+ path,
+ retries=self.__request_retries,
+ retry_waiting=self.__request_retry_waiting,
+ )
+
+ def __ping_once(self):
+ return self.__self_request('GET', '/ping')
+
+ def __ping_until_started(self):
+ while True:
+ try:
+ self.__ping_once()
+ except (requests.exceptions.BaseHTTPError, requests.exceptions.RequestException):
+ time.sleep(0.2)
+ else:
+ break
+
+ def __shutdown(self):
+ self.__self_request('DELETE', '/shutdown')
+
+ # heartbeat part
+ def __heartbeat_check(self):
+ _last_time = time.time()
+ while not self.__shutdown_event.is_set():
+ _current_time = time.time()
+
+ _common_names = set(self.__slaves.keys()) & set(self.__slave_last_heartbeat.keys())
+ for name in _common_names:
+ _, connection = self.__slaves[name]
+ last_heartbeat = self.__slave_last_heartbeat[name]
+ if _current_time - last_heartbeat > self.__heartbeat_tolerance:
+ self.__connection_close(name, connection)
+
+ _last_time += self.__heartbeat_check_span
+ time.sleep(max(_last_time - time.time(), 0))
+
+ # task process part
+ def __task_result_process(self):
+ while not self.__task_result_queue.empty() or not self.__shutdown_event.is_set():
+ try:
+ _result = self.__task_result_queue.get(timeout=3.0)
+ except Empty:
+ continue
+ else:
+ _type, (_connection, _task_id, _task_result) = _result
+ _trigger_func = _slave_task_complete if _type == TaskResultType.COMPLETED else _slave_task_fail
+ _trigger_func(_connection, _task_id, _task_result)
+
+ # connection part
+ def __get_connection_class(self) -> Type[SlaveConnection]:
+ return _get_connection_class(
+ before_new_task=self._before_new_task,
+ after_new_task=self._after_new_task,
+ error_new_task=self._error_new_task,
+ before_connect=self._before_connect,
+ after_connect=self._after_connect,
+ error_connect=self._error_connect,
+ before_disconnect=self._before_disconnect,
+ after_disconnect=self._after_disconnect,
+ error_disconnect=self._error_disconnect,
+ )
+
+ def __get_new_connection(
+ self, name: str, host: str, port: Optional[int] = None, https: bool = False
+ ) -> SlaveConnectionProxy:
+ if name in self.__slaves.keys():
+ raise KeyError('Connection {name} already exist.'.format(name=repr(name)))
+ else:
+ slave_token = random_token()
+ connection = self.__get_connection_class()(
+ host=host,
+ port=port,
+ https=https,
+ channel=self.__channel,
+ my_address=self.__my_address,
+ token=slave_token,
+ )
+
+ return SlaveConnectionProxy(
+ connection=connection,
+ after_connect=partial(self.__connection_open, name=name, token=slave_token),
+ after_disconnect=partial(self.__connection_close, name=name),
+ )
+
+ # public properties
+ @property
+ def my_address(self) -> str:
+ """
+ Overview:
+ Get my address property of current master client.
+ Returns:
+ - output (:obj:`str`): My address which can be used to establish connection from slave end to here.
+ """
+ with self.__lock:
+ return self.__my_address
+
+ # public methods
+ def ping(self) -> bool:
+ """
+ Overview:
+ Ping the current http server, check if it still runs properly.
+ Returns:
+ - output (:obj:`bool`): Whether the http server runs properly. \
+ `True` means it runs properly, otherwise `False`.
+ """
+ with self.__lock:
+ try:
+ self.__ping_once()
+ except (requests.exceptions.BaseHTTPError, requests.exceptions.RequestException):
+ return False
+ else:
+ return True
+
+ def new_connection(
+ self, name: str, host: str, port: Optional[int] = None, https: bool = False
+ ) -> SlaveConnectionProxy:
+ """
+ Overview:
+ Create a new connection object to the slave end (note that **the connection will not be established** \
+ until the `connect` method of the connection object is called).
+ Arguments:
+ - name (:obj:`str`): Name of the connection (this name is a unique label used in this master client)
+ - host (:obj:`str`): Host of the slave end
+ - port (:obj:`Optional[int]`): Port of the slave end (None means `7236`)
+ - https (:obj:`bool`): Use https to connect or not (Default is `False`)
+ Returns:
+ - output (:obj:`SlaveConnectionProxy`): A connection object represents the connection from here to the \
+ slave end. More actions can be operated by this connection object.
+ """
+ with self.__lock:
+ return self.__get_new_connection(name, host, port, https)
+
+ def __contains__(self, name: str):
+ """
+ Overview:
+ Check if an active connection with the given name exists in this master client.
+ Only connections still alive can be found here.
+ Arguments:
+ - name (:obj:`str`): Name of the connection
+ Returns:
+ - output (:obj:`bool`): Whether a connection with the given name exists.
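+ Example:
+ - A small sketch ('conn1' is a hypothetical connection that has already been connected):
+ >>> 'conn1' in master
+ True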
+ """
+ with self.__lock:
+ return name in self.__slaves.keys()
+
+ def __getitem__(self, name: str):
+ """
+ Overview:
+ Try to get the active connection with the given name.
+ Only connections still alive can be found here.
+ Arguments:
+ - name (:obj:`str`): Name of the connection
+ Returns:
+ - output (:obj:`SlaveConnectionProxy`): Connection object with the given name.
+ """
+ with self.__lock:
+ if name in self.__slaves.keys():
+ _token, _connection = self.__slaves[name]
+ return _connection
+ else:
+ raise KeyError('Connection {name} not found.'.format(name=repr(name)))
+
+ def __delitem__(self, name: str):
+ """
+ Overview:
+ Delete connection from this master client, and the deleted connection will be killed as well.
+ Only connections still alive can be found here.
+ Arguments:
+ - name (:obj:`str`): Name of the connection
+ """
+ with self.__lock:
+ if name in self.__slaves.keys():
+ _token, _connection = self.__slaves[name]
+ _connection.disconnect()
+ else:
+ raise KeyError('Connection {name} not found.'.format(name=repr(name)))
+
+ def start(self):
+ """
+ Overview:
+ Start current master client
+ Here are the steps executed inside in order:
+ 1. Start the result-processing thread
+ 2. Start the heartbeat check thread
+ 3. Start the http server thread
+ 4. Wait until the http server is online (can be pinged)
+ """
+ with self.__lock:
+ self.__task_result_process_thread.start()
+ self.__heartbeat_check_thread.start()
+ self.__run_app_thread.start()
+
+ self.__ping_until_started()
+
+ def shutdown(self):
+ """
+ Overview:
+ Shutdown current master client.
+ A shutdown request will be sent to the http server, and the shutdown signal will be applied to the \
+ threads, so the server will be down soon (you can use the `join` method to wait until then).
+ """
+ with self.__lock:
+ self.__shutdown()
+
+ def join(self):
+ """
+ Overview:
+ Wait until the current master client is down completely.
+ Here are the steps executed inside in order:
+ 1. Wait until the http server thread down
+ 2. Wait until the heartbeat check thread down
+ 3. Wait until the result-processing thread down
+ """
+ with self.__lock:
+ self.__run_app_thread.join()
+ self.__heartbeat_check_thread.join()
+ self.__task_result_process_thread.join()
+
+ # inherit methods
+ def _before_connect(self) -> Mapping[str, Any]:
+ """
+ Overview:
+ Behaviours executed before trying to establish the connection; the connection data is generated here as well.
+ The default behaviour is to do nothing and return `None`; you can override this method to change its behaviour.
+ If an exception is raised in this method, the connection will be canceled.
+ Returns:
+ - output (:obj:`Mapping[str, Any]`): Connection data
+ """
+ pass
+
+ def _after_connect(
+ self, status_code: int, success: bool, code: int, message: Optional[str], data: Optional[Mapping[str, Any]]
+ ) -> Any:
+ """
+ Overview:
+ Behaviours executed after trying to establish connection.
+ The default behaviour is to do nothing and return `None`; you can override this method to change its behaviour.
+ Arguments:
+ - status_code (:obj:`int`): Status code of the connection request
+ - success (:obj:`bool`): Connect success or not
+ - code (:obj:`int`): Error code of the connection (`0` means no error, \
+ other code can be found in `SlaveErrorCode`)
+ - message (:obj:`Optional[str]`): Connection message of the connection
+ - data (:obj:`Optional[Mapping[str, Any]]`): Connection data of the connection (returned by slave end)
+ Returns:
+ - output (:obj:`Any`): Any return data, \
+ this data will be returned in `connect` method in connection object.
+ """
+ pass
+
+ def _error_connect(self, error: RequestException) -> Any:
+ """
+ Overview:
+ Behaviours executed after a web error occurred in the connection request.
+ The default behaviour is to raise the `error` exception; you can override this method to change its behaviour, \
+ such as returning a proper value like `None`.
+ Arguments:
+ - error (:obj:`RequestException`): Error raised from requests
+ Returns:
+ - output (:obj:`Any`): Any data, this data will be returned in `connect` method in connection object
+ """
+ raise error
+
+ def _before_disconnect(self) -> Mapping[str, Any]:
+ """
+ Overview:
+ Behaviours executed before trying to end the connection; the disconnection data is generated here as well.
+ The default behaviour is to do nothing and return `None`; you can override this method to change its behaviour.
+ If an exception is raised in this method, the disconnection will be canceled.
+ Returns:
+ - output (:obj:`Mapping[str, Any]`): Disconnection data
+ """
+ pass
+
+ def _after_disconnect(
+ self, status_code: int, success: bool, code: int, message: Optional[str], data: Optional[Mapping[str, Any]]
+ ) -> Any:
+ """
+ Overview:
+ Behaviours executed after trying to end connection.
+ The default behaviour is to do nothing and return `None`; you can override this method to change its behaviour.
+ Arguments:
+ - status_code (:obj:`int`): Status code of the disconnection request
+ - success (:obj:`bool`): Disconnect success or not
+ - code (:obj:`int`): Error code of the disconnection (`0` means no error, \
+ other code can be found in `SlaveErrorCode`)
+ - message (:obj:`Optional[str]`): Disconnection message of the disconnection
+ - data (:obj:`Optional[Mapping[str, Any]]`): Disconnection data of the disconnection (returned by slave end)
+ Returns:
+ - output (:obj:`Any`): Any return data, \
+ this data will be returned in `disconnect` method in connection object.
+ """
+ pass
+
+ def _error_disconnect(self, error: RequestException):
+ """
+ Overview:
+ Behaviours executed after a web error occurred in the disconnection request.
+ The default behaviour is to raise the `error` exception; you can override this method to change its behaviour, \
+ such as returning a proper value like `None`.
+ Arguments:
+ - error (:obj:`RequestException`): Error raised from requests
+ Returns:
+ - output (:obj:`Any`): Any data, this data will be returned in `disconnect` method in connection object
+ """
+ raise error
+
+ # noinspection PyMethodMayBeStatic
+ def _before_new_task(self, data: Optional[Mapping[str, Any]] = None) -> Mapping[str, Any]:
+ """
+ Overview:
+ Behaviours executed before trying to create task.
+ The default behaviour is to do nothing and return the original task data; \
+ you can override this method to change its behaviour, such as preprocessing the task data.
+ If an exception is raised in this method, the task request will be canceled.
+ Arguments:
+ - data (:obj:`Optional[Mapping[str, Any]]`): Original task data
+ Returns:
+ - output (:obj:`Mapping[str, Any]`): Final task data, which will be send to slave end
+ """
+ return data or {}
+
+ def _after_new_task(
+ self, status_code: int, success: bool, code: int, message: Optional[str], data: Optional[Mapping[str, Any]]
+ ) -> Any:
+ """
+ Overview:
+ Behaviours executed after trying to create task.
+ The default behaviour is to do nothing and return `None`; \
+ you can override this method to change its behaviour, such as returning the new task data.
+ Arguments:
+ - status_code (:obj:`int`): Status code of the task request
+ - success (:obj:`bool`): Task request success or not
+ - code (:obj:`int`): Error code of the task request (`0` means no error, \
+ other code can be found in `SlaveErrorCode`)
+ - message (:obj:`Optional[str]`): Task message of the task request
+ - data (:obj:`Optional[Mapping[str, Any]]`): Task data of the task request (returned by slave end)
+ Returns:
+ - output (:obj:`Any`): Any return data, \
+ this data will be returned in `start` method in task object.
+ """
+ pass
+
+ def _error_new_task(self, error: RequestException):
+ """
+ Overview:
+ Behaviours executed after a web error occurred in the task request.
+ The default behaviour is to raise the `error` exception; you can override this method to change its behaviour, \
+ such as returning a proper value like `None`.
+ Arguments:
+ - error (:obj:`RequestException`): Error raised from requests
+ Returns:
+ - output (:obj:`Any`): Any data, this data will be returned in `start` method in task object
+ """
+ raise error
diff --git a/DI-engine/ding/interaction/master/task.py b/DI-engine/ding/interaction/master/task.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9b82396570ba1a37e4e79eec3d9339fb65e4f01
--- /dev/null
+++ b/DI-engine/ding/interaction/master/task.py
@@ -0,0 +1,263 @@
+from enum import unique, IntEnum
+from threading import Lock
+from typing import Mapping, Any, Optional, Callable
+from uuid import UUID, uuid4
+
+import enum_tools
+import requests
+from requests import RequestException
+
+from .base import _BEFORE_HOOK_TYPE, _AFTER_HOOK_TYPE, _ERROR_HOOK_TYPE
+from ..base import HttpEngine, get_values_from_response, default_func
+
+
+@enum_tools.documentation.document_enum
+@unique
+class TaskResultType(IntEnum):
+ """
+ Overview:
+ Types of the task result
+ """
+ COMPLETED = 1 # doc: Task complete without error
+ FAILED = 2 # doc: Task end with error
+
+
+@enum_tools.documentation.document_enum
+@unique
+class TaskStatus(IntEnum):
+ """
+ Overview:
+ Status of a task
+ """
+ IDLE = 0x00 # doc: Task not started yet, waiting to be started
+
+ STARTING = 0x11 # doc: Task is starting, but initialization is not completed.
+ STARTED = 0x12 # doc: Task started, initialization is completed.
+ START_FAILED = 0x13 # doc: Task start failed, error occurred when initializing.
+
+ COMPLETED = 0x21 # doc: Task completed without error
+ FAILED = 0x22 # doc: Task ended with error
+
+
+_COMPLETE_TRIGGER_NAME = '__TASK_COMPLETE__'
+_FAIL_TRIGGER_NAME = '__TASK_FAIL__'
+
+
+class Task:
+ """
+ Overview:
+ Task object of the connections.
+ Chained calls are fully supported.
+ Example:
+ >>> with master.new_connection('cnn1', '127.0.0.1', 2333) as connection:
+ >>> task = connection.new_task({'data': 233})
+ >>> # task is not sent yet
+ >>>
+ >>> task = task.on_complete(func1).on_fail(func2).on_complete(func3).start().join()
+ >>> # task is completed or failed after this line
+ >>> # when task completed : func1(result) --> func3(result)
+ >>> # when task failed : func2(result)
+ """
+
+ def __init__(
+ self,
+ http_engine: HttpEngine,
+ data: Mapping[str, Any],
+ task_id: Optional[UUID] = None,
+ before_task_start: Optional[_BEFORE_HOOK_TYPE] = None,
+ after_task_start: Optional[_AFTER_HOOK_TYPE] = None,
+ error_task_start: Optional[_ERROR_HOOK_TYPE] = None
+ ):
+ """
+ Overview:
+ Constructor of `Task`
+ Arguments:
+ - http_engine (:obj:`HttpEngine`): Http engine object used by the task
+ - data (:obj:`Mapping[str, Any]`): Task data of the task
+ - task_id (:obj:`Optional[UUID]`): Id of the task
+ - before_task_start (:obj:`Optional[_BEFORE_HOOK_TYPE]`): Callback to be executed before task start \
+ (`None` means do nothing)
+ - after_task_start (:obj:`Optional[_AFTER_HOOK_TYPE]`): Callback to be executed after task start \
+ (`None` means do nothing)
+ - error_task_start (:obj:`Optional[_ERROR_HOOK_TYPE]`): Callback to be executed when task start failed \
+ (`None` means do nothing)
+ """
+ self.__http_engine = http_engine
+ self.__lock = Lock()
+
+ self.__task_id = task_id or uuid4()
+ self.__task_data = data
+ self.__task_result = None
+ self.__task_status = TaskStatus.IDLE
+ self.__task_lock = Lock()
+
+ self.__before_task_start = before_task_start or (lambda d: d)
+ self.__after_task_start = default_func(None)(after_task_start)
+ self.__error_task_start = default_func(None)(error_task_start)
+ self.__after_task_completed_callbacks = []
+ self.__after_task_failed_callbacks = []
+
+ self.__init_triggers()
+
+ def __request(self, method: str, path: str, data: Optional[Mapping[str, Any]] = None) -> requests.Response:
+ return self.__http_engine.request(method, path, data)
+
+ def __task_start(self):
+ try:
+ self.__task_status = TaskStatus.STARTING
+ response = self.__request(
+ 'POST', '/task/new', {
+ 'task': {
+ 'id': str(self.__task_id)
+ },
+ 'data': self.__before_task_start(self.__task_data) or {}
+ }
+ )
+ except RequestException as err:
+ self.__task_status = TaskStatus.START_FAILED
+ return self.__error_task_start(err)
+ else:
+ self.__task_status = TaskStatus.STARTED
+ ret = self.__after_task_start(*get_values_from_response(response))
+ self.__task_lock.acquire()
+ return ret
+
+ def __task_complete(self, result: Mapping[str, Any]):
+ self.__task_status = TaskStatus.COMPLETED
+ self.__task_result = result
+ for _callback in self.__after_task_completed_callbacks:
+ _callback(self.__task_data, result)
+ self.__task_lock.release()
+
+ def __task_fail(self, result: Mapping[str, Any]):
+ self.__task_status = TaskStatus.FAILED
+ self.__task_result = result
+ for _callback in self.__after_task_failed_callbacks:
+ _callback(self.__task_data, result)
+ self.__task_lock.release()
+
+ # trigger methods
+ def __task_complete_trigger(self, result: Mapping[str, Any]):
+ with self.__lock:
+ if self.__task_status == TaskStatus.STARTED:
+ self.__task_complete(result)
+ else:
+ raise ValueError(
+ "Only task with {expect} status can be completed, but {actual} found.".format(
+ expect=repr(TaskStatus.STARTED.name),
+ actual=repr(self.__task_status.name),
+ )
+ )
+
+ def __task_fail_trigger(self, result: Mapping[str, Any]):
+ with self.__lock:
+ if self.__task_status == TaskStatus.STARTED:
+ self.__task_fail(result)
+ else:
+ raise ValueError(
+ "Only task with {expect} status can be failed, but {actual} found.".format(
+ expect=repr(TaskStatus.STARTED.name),
+ actual=repr(self.__task_status.name),
+ )
+ )
+
+ def __init_triggers(self):
+ setattr(self, _COMPLETE_TRIGGER_NAME, self.__task_complete_trigger)
+ setattr(self, _FAIL_TRIGGER_NAME, self.__task_fail_trigger)
+
+ # public properties
+ @property
+ def status(self) -> TaskStatus:
+ """
+ Overview:
+ Get status of the current task
+ Returns:
+ - status (:obj:`TaskStatus`): Task status
+ """
+ return self.__task_status
+
+ @property
+ def task(self) -> Mapping[str, Any]:
+ """
+ Overview:
+ Get task data of the current task
+ Returns:
+ - data (:obj:`Mapping[str, Any]`): Task data
+ """
+ return self.__task_data
+
+ @property
+ def result(self) -> Optional[Mapping[str, Any]]:
+ """
+ Overview:
+ Get task result of the current task, return `None` if task is not completed or failed.
+ Returns:
+ - result (:obj:`Optional[Mapping[str, Any]]`): Task result (`None` when not completed or failed)
+ """
+ return self.__task_result
+
+ # public methods
+ def start(self) -> 'Task':
+ """
+ Overview:
+            Start the current task.
+        Returns:
+            - task (:obj:`Task`): Self object, supporting method chaining
+ """
+ with self.__lock:
+ if self.__task_status == TaskStatus.IDLE:
+ self.__task_start()
+ return self
+ else:
+ raise ValueError(
+                    "Only a task with {expect} status can be started, but {actual} was found.".format(
+ expect=repr(TaskStatus.IDLE.name),
+ actual=repr(self.__task_status.name),
+ )
+ )
+
+ def join(self) -> 'Task':
+ """
+ Overview:
+ Wait until the task is completed or failed.
+ Returns:
+            - task (:obj:`Task`): Self object, supporting method chaining
+ """
+ with self.__task_lock:
+ return self
+
+ def on_complete(self, callback: Callable[[Mapping[str, Any], Mapping[str, Any]], Any]) -> 'Task':
+ """
+ Overview:
+            Execute the callback when the task is completed. Multiple callbacks are supported via method chaining.
+        Arguments:
+            - callback (:obj:`Callable[[Mapping[str, Any], Mapping[str, Any]], Any]`): Function to be executed when \
+                the task is completed.
+        Returns:
+            - task (:obj:`Task`): Self object, supporting method chaining
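+        Example:
+            - An illustrative chained call; the print callback is an assumption, not part of the API:
+
+            >>> task.on_complete(lambda data, result: print(result)).start().join()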
+ """
+ with self.__lock:
+ self.__after_task_completed_callbacks.append(callback)
+ return self
+
+ def on_fail(self, callback: Callable[[Mapping[str, Any], Mapping[str, Any]], Any]) -> 'Task':
+ """
+ Overview:
+            Execute the callback when the task fails. Multiple callbacks are supported via method chaining.
+        Arguments:
+            - callback (:obj:`Callable[[Mapping[str, Any], Mapping[str, Any]], Any]`): Function to be executed when \
+                the task fails.
+        Returns:
+            - task (:obj:`Task`): Self object, supporting method chaining
+ """
+ with self.__lock:
+ self.__after_task_failed_callbacks.append(callback)
+ return self
+
+
+def _task_complete(task: Task, result: Mapping[str, Any]):
+ getattr(task, _COMPLETE_TRIGGER_NAME)(result)
+
+
+def _task_fail(task: Task, result: Mapping[str, Any]):
+ getattr(task, _FAIL_TRIGGER_NAME)(result)
diff --git a/DI-engine/ding/interaction/slave/__init__.py b/DI-engine/ding/interaction/slave/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4d113964d98c14df61ff5c6cb086529ff1f2dcf
--- /dev/null
+++ b/DI-engine/ding/interaction/slave/__init__.py
@@ -0,0 +1,2 @@
+from .action import TaskRefuse, DisconnectionRefuse, ConnectionRefuse, TaskFail
+from .slave import Slave
diff --git a/DI-engine/ding/interaction/slave/action.py b/DI-engine/ding/interaction/slave/action.py
new file mode 100644
index 0000000000000000000000000000000000000000..c6ed70597229fba8d988b05b46f04762237c8a5c
--- /dev/null
+++ b/DI-engine/ding/interaction/slave/action.py
@@ -0,0 +1,138 @@
+from typing import Optional, Any, Mapping
+
+from ..base import ResponsibleException
+from ..exception import SlaveErrorCode
+
+
+class ConnectionRefuse(ResponsibleException):
+ """
+ Overview:
+        Exception representing that the slave refuses a connection from the master,
+        which can be raised in method `_before_connection`.
+ Example:
+ - Without data
+
+ >>> raise ConnectionRefuse
+
+ - With refuse data
+
+ >>> raise ConnectionRefuse({'data': 233})
+ """
+
+ def __init__(self, data: Optional[Mapping[str, Any]] = None):
+ """
+ Overview:
+ Constructor of ConnectionRefuse
+ Arguments:
+ - data (:obj:`Optional[Mapping[str, Any]]`): Key-value-formed refuse data
+ """
+ ResponsibleException.__init__(
+ self,
+ SlaveErrorCode.SLAVE_CONNECTION_REFUSED,
+ message='Connection refused!',
+ data=data or {},
+ status_code=403,
+ )
+
+
+class DisconnectionRefuse(ResponsibleException):
+ """
+ Overview:
+        Exception representing that the slave refuses a disconnection request from the master,
+        which can be raised in method `_before_disconnection`.
+ Example:
+ - Without data
+
+ >>> raise DisconnectionRefuse
+
+ - With refuse data
+
+ >>> raise DisconnectionRefuse({'data': 233})
+ """
+
+ def __init__(self, data: Optional[Mapping[str, Any]] = None):
+ """
+ Overview:
+ Constructor of DisconnectionRefuse
+ Arguments:
+ - data (:obj:`Optional[Mapping[str, Any]]`): Key-value-formed refuse data
+ """
+ ResponsibleException.__init__(
+ self,
+ SlaveErrorCode.SLAVE_DISCONNECTION_REFUSED,
+ message='Disconnection refused!',
+ data=data or {},
+ status_code=403,
+ )
+
+
+class TaskRefuse(ResponsibleException):
+ """
+ Overview:
+        Exception representing that the slave refuses a task, which can be raised in method `_before_task`.
+ Example:
+ - Without data
+
+ >>> raise TaskRefuse
+
+ - With refuse data
+
+ >>> raise TaskRefuse({'data': 233})
+ """
+
+ def __init__(self, data: Optional[Mapping[str, Any]] = None):
+ """
+ Overview:
+ Constructor of TaskRefuse
+ Arguments:
+ - data (:obj:`Optional[Mapping[str, Any]]`): Key-value-formed refuse data
+ """
+ ResponsibleException.__init__(
+ self,
+ SlaveErrorCode.TASK_REFUSED,
+ message='Task refused!',
+ data=data or {},
+ status_code=403,
+ )
+
+
+class TaskFail(Exception):
+ """
+ Overview:
+        Exception representing a task failure, which can be raised in method `_process_task`.
+ Example:
+ - Without data
+
+ >>> raise TaskFail
+
+ - With failure data
+
+ >>> raise TaskFail({'data': 233})
+
+ - With both data and message
+
+ >>> raise TaskFail({'data': 233}, 'this is message')
+ """
+
+ def __init__(self, result: Optional[Mapping[str, Any]] = None, message: Optional[str] = None):
+ """
+ Overview:
+ Constructor of TaskFail
+ Arguments:
+ - result (:obj:`Optional[Mapping[str, Any]]`): Result of task failure
+ - message (:obj:`Optional[str]`): Message of task failure
+ """
+ if message:
+ Exception.__init__(self, 'Task process failed - {message}.'.format(message=message))
+ else:
+ Exception.__init__(self, 'Task process failed.')
+ self.__result = result or {}
+
+ @property
+ def result(self) -> Mapping[str, Any]:
+ """
+ Overview:
+ Get the result of task failure.
+ Returns:
+            - result (:obj:`Mapping[str, Any]`): Result of the task failure
+ """
+ return self.__result
diff --git a/DI-engine/ding/interaction/slave/slave.py b/DI-engine/ding/interaction/slave/slave.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bc7125717fae960217bf28a574dfab918db9dbd
--- /dev/null
+++ b/DI-engine/ding/interaction/slave/slave.py
@@ -0,0 +1,520 @@
+import json
+import sys
+import time
+import traceback
+from abc import abstractmethod
+from functools import wraps
+from threading import Thread, Event, Lock
+from typing import Optional, Callable, Any, Mapping
+from uuid import UUID
+
+import requests
+from flask import Flask, request
+
+from .action import ConnectionRefuse, DisconnectionRefuse, TaskRefuse, TaskFail
+from ..base import random_token, ControllableService, get_http_engine_class, split_http_address, success_response, \
+ failure_response, DblEvent
+from ..config import DEFAULT_SLAVE_PORT, DEFAULT_CHANNEL, GLOBAL_HOST, DEFAULT_HEARTBEAT_SPAN, MIN_HEARTBEAT_SPAN, \
+ DEFAULT_REQUEST_RETRIES, DEFAULT_REQUEST_RETRY_WAITING
+from ..exception import SlaveErrorCode, get_slave_exception_by_error, get_master_exception_by_error
+
+
+class Slave(ControllableService):
+ r"""
+ Overview:
+ Interaction slave client
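+    Example:
+        - A minimal usage sketch; the subclass, port and channel values below are illustrative assumptions:
+
+        >>> class EchoSlave(Slave):
+        >>>     def _process_task(self, task):
+        >>>         return {'echo': dict(task)}
+        >>>
+        >>> slave = EchoSlave('0.0.0.0', 8080, channel=233)
+        >>> slave.start()  # serve tasks from the connected master until shutdown() is called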
+ """
+
+ def __init__(
+ self,
+ host: Optional[str] = None,
+ port: Optional[int] = None,
+ heartbeat_span: Optional[float] = None,
+ request_retries: Optional[int] = None,
+ request_retry_waiting: Optional[float] = None,
+ channel: Optional[int] = None
+ ):
+ """
+ Overview:
+ Constructor of Slave class
+ Arguments:
+            - host (:obj:`Optional[str]`): Host of the slave server, based on Flask (None means `0.0.0.0`)
+            - port (:obj:`Optional[int]`): Port of the slave server, based on Flask (None means `7236`)
+            - heartbeat_span (:obj:`Optional[float]`): Time interval between heartbeat packets in seconds \
+                (None means `3.0`, minimum is `0.2`)
+            - request_retries (:obj:`Optional[int]`): Max number of request retries (None means `5`)
+            - request_retry_waiting (:obj:`Optional[float]`): Waiting time before retrying a request (None means `1.0`)
+            - channel (:obj:`Optional[int]`): Channel id of the slave client; make sure it is equal to the master \
+                client's channel id, otherwise the connection cannot be established (None means `0`, but channel 0 \
+                is not recommended for production use)
+ """
+ # server part
+ self.__host = host or GLOBAL_HOST
+ self.__port = port or DEFAULT_SLAVE_PORT
+ self.__flask_app_value = None
+ self.__run_app_thread = Thread(target=self.__run_app, name='slave_run_app')
+
+ # heartbeat part
+ self.__heartbeat_span = max(heartbeat_span or DEFAULT_HEARTBEAT_SPAN, MIN_HEARTBEAT_SPAN)
+ self.__heartbeat_thread = Thread(target=self.__heartbeat, name='slave_heartbeat')
+ self.__request_retries = max(request_retries or DEFAULT_REQUEST_RETRIES, 0)
+ self.__request_retry_waiting = max(request_retry_waiting or DEFAULT_REQUEST_RETRY_WAITING, 0.0)
+
+ # task part
+ self.__has_task = DblEvent()
+ self.__task_lock = Lock()
+ self.__task_id = None
+ self.__task_data = None
+ self.__task_thread = Thread(target=self.__task, name='slave_task')
+
+ # self-connection part
+ self.__self_http_engine = get_http_engine_class(
+ headers={
+ 'Token': lambda: self.__self_token,
+ },
+ http_error_gene=get_slave_exception_by_error,
+ # )()('localhost', self.__port, False)
+ )()(self.__host, self.__port, False) # TODO: Confirm how to ping itself
+ self.__self_token = random_token()
+
+ # master-connection part
+ self.__channel = channel or DEFAULT_CHANNEL
+ self.__connected = DblEvent()
+ self.__master_token = None
+ self.__master_address = None
+ self.__master_http_engine = None
+
+ # global part
+ self.__shutdown_event = Event()
+ self.__lock = Lock()
+
+ # master connection
+ def __register_master(self, token: str, address: str):
+ self.__master_token = token
+ self.__master_address = address
+ self.__master_http_engine = get_http_engine_class(
+ headers={
+ 'Channel': lambda: str(self.__channel),
+ 'Token': lambda: self.__master_token,
+ },
+ http_error_gene=get_master_exception_by_error,
+ )()(*split_http_address(self.__master_address))
+
+ def __unregister_master(self):
+ self.__master_token = None
+ self.__master_address = None
+ self.__master_http_engine = None
+
+ def __open_master_connection(self, token: str, address: str):
+ self.__register_master(token, address)
+ self.__connected.open()
+
+ def __close_master_connection(self):
+ self.__unregister_master()
+ self.__connected.close()
+
+ # server part
+ def __generate_app(self):
+ app = Flask(__name__)
+
+ # master apis
+ app.route('/connect', methods=['POST'])(self.__check_master_request(self.__connect, False))
+ app.route('/disconnect', methods=['DELETE'])(self.__check_master_request(self.__disconnect, True))
+ app.route('/task/new', methods=['POST'])(self.__check_master_request(self.__new_task, True))
+
+ # self apis
+ app.route('/ping', methods=['GET'])(self.__check_self_request(self.__self_ping))
+ app.route('/shutdown', methods=['DELETE'])(self.__check_self_request(self.__self_shutdown))
+
+ return app
+
+ def __flask_app(self) -> Flask:
+ return self.__flask_app_value or self.__generate_app()
+
+ def __run_app(self):
+ self.__flask_app().run(
+ host=self.__host,
+ port=self.__port,
+ )
+
+ # both method checkers
+ def __check_shutdown(self, func: Callable[[], Any]) -> Callable[[], Any]:
+
+ @wraps(func)
+ def _func():
+ if self.__shutdown_event.is_set():
+ return failure_response(
+                    code=SlaveErrorCode.SYSTEM_SHUTTING_DOWN, message='System is already shutting down.'
+ ), 401
+ else:
+ return func()
+
+ return _func
+
+ # server method checkers (master)
+ def __check_master_request(self,
+ func: Callable[[str, Mapping[str, Any]], Any],
+ need_match: bool = True) -> Callable[[], Any]:
+ return self.__check_shutdown(self.__check_channel(self.__check_master_token(func, need_match)))
+
+ # noinspection DuplicatedCode
+ def __check_channel(self, func: Callable[[], Any]) -> Callable[[], Any]:
+
+ @wraps(func)
+ def _func():
+ channel = request.headers.get('Channel', None)
+ channel = int(channel) if channel else None
+
+ if channel is None:
+ return failure_response(code=SlaveErrorCode.CHANNEL_NOT_FOUND, message='Channel not found.'), 400
+ elif channel != self.__channel:
+ return failure_response(
+                    code=SlaveErrorCode.CHANNEL_INVALID, message='Channel does not match this endpoint.'
+ ), 403
+ else:
+ return func()
+
+ return _func
+
+ def __check_master_token(self,
+ func: Callable[[str, Mapping[str, Any]], Any],
+ need_match: bool = True) -> Callable[[], Any]:
+
+ @wraps(func)
+ def _func():
+ master_token = request.headers.get('Token', None)
+ if master_token is None:
+ return failure_response(
+ code=SlaveErrorCode.MASTER_TOKEN_NOT_FOUND, message='Master token not found.'
+ ), 400
+ elif need_match and (master_token != self.__master_token):
+ return failure_response(
+                    code=SlaveErrorCode.MASTER_TOKEN_INVALID, message='Master token does not match this endpoint.'
+ ), 403
+ else:
+ return func(master_token, json.loads(request.data.decode()))
+
+ return _func
+
+ # server method checkers (self)
+ # noinspection DuplicatedCode
+ def __check_self_request(self, func: Callable[[], Any]) -> Callable[[], Any]:
+ return self.__check_shutdown(self.__check_slave_token(func))
+
+ def __check_slave_token(self, func: Callable[[], Any]) -> Callable[[], Any]:
+
+ @wraps(func)
+ def _func():
+ slave_token = request.headers.get('Token', None)
+
+ if slave_token is None:
+ return failure_response(code=SlaveErrorCode.SELF_TOKEN_NOT_FOUND, message='Slave token not found.'), 400
+ elif slave_token != self.__self_token:
+ return failure_response(
+                    code=SlaveErrorCode.SELF_TOKEN_INVALID, message='Slave token does not match this endpoint.'
+ ), 403
+ else:
+ return func()
+
+ return _func
+
+ # server methods (self)
+ # noinspection PyMethodMayBeStatic
+ def __self_ping(self):
+ return success_response(message='PONG!')
+
+ def __self_shutdown(self):
+ _shutdown_func = request.environ.get('werkzeug.server.shutdown')
+ if _shutdown_func is None:
+ raise RuntimeError('Not running with the Werkzeug Server')
+
+ self.__shutdown_event.set()
+ _shutdown_func()
+
+ return success_response(message='Shutdown request received, this server will be down later.')
+
+ # server methods (master)
+ # noinspection PyUnusedLocal
+ def __connect(self, token: str, data: Mapping[str, Any]):
+ if self.__connected.is_open():
+ return failure_response(
+                code=SlaveErrorCode.SLAVE_ALREADY_CONNECTED, message='This slave is already connected.'
+ ), 400
+ else:
+ _master_info, _connection_data = data['master'], data['data']
+
+ try:
+ self._before_connection(_connection_data)
+ except ConnectionRefuse as err:
+ return err.get_response()
+ else:
+ self.__open_master_connection(token, _master_info['address'])
+ return success_response(message='Connect success.')
+
+ # noinspection PyUnusedLocal
+ def __new_task(self, token: str, data: Mapping[str, Any]):
+ with self.__task_lock:
+ if self.__has_task.is_open():
+ return failure_response(code=SlaveErrorCode.TASK_ALREADY_EXIST, message='Already has a task.'), 400
+ else:
+ _task_info, _task_data = data['task'], data['data']
+ _task_id = _task_info['id']
+
+ try:
+ self._before_task(_task_data)
+ except TaskRefuse as err:
+ return err.get_response()
+ else:
+ self.__task_id = UUID(_task_id)
+ self.__task_data = _task_data
+ self.__has_task.open()
+ return success_response(message='Task received!')
+
+ # noinspection PyUnusedLocal
+ def __disconnect(self, token: str, data: Mapping[str, Any]):
+ if self.__connected.is_close():
+ return failure_response(
+                code=SlaveErrorCode.SLAVE_NOT_CONNECTED, message='This slave is not connected yet.'
+ ), 400
+ else:
+ _disconnection_data = data['data']
+
+ try:
+ self._before_disconnection(_disconnection_data)
+ except DisconnectionRefuse as err:
+ return err.get_response()
+ else:
+ self.__close_master_connection()
+ return success_response(message='Disconnect success.')
+
+ # heartbeat part
+ def __heartbeat(self):
+ _last_time = time.time()
+ while not self.__shutdown_event.is_set():
+ if self.__connected.is_open():
+ try:
+ self.__master_heartbeat()
+ except requests.exceptions.RequestException as err:
+ self._lost_connection(self.__master_address, err)
+ self.__close_master_connection()
+ traceback.print_exception(*sys.exc_info(), file=sys.stderr)
+
+ _last_time += self.__heartbeat_span
+ time.sleep(max(_last_time - time.time(), 0))
+
+ # task part
+ def __task(self):
+ while not self.__shutdown_event.is_set():
+ self.__has_task.wait_for_open(timeout=1.0)
+ if self.__has_task.is_open():
+ # noinspection PyBroadException
+ try:
+ result = self._process_task(self.__task_data)
+ except TaskFail as fail:
+ self.__has_task.close()
+ self.__master_task_fail(fail.result)
+ except Exception:
+ self.__has_task.close()
+ traceback.print_exception(*sys.exc_info(), file=sys.stderr)
+ else:
+ self.__has_task.close()
+ self.__master_task_complete(result)
+
+ # self request operations
+ def __self_request(self, method: Optional[str] = 'GET', path: Optional[str] = None) -> requests.Response:
+ return self.__self_http_engine.request(
+ method,
+ path,
+ retries=self.__request_retries,
+ retry_waiting=self.__request_retry_waiting,
+ )
+
+ def __ping_once(self):
+ return self.__self_request('GET', '/ping')
+
+ def __ping_until_started(self):
+ while True:
+ try:
+ self.__ping_once()
+ except (requests.exceptions.BaseHTTPError, requests.exceptions.RequestException):
+ time.sleep(0.2)
+ else:
+ break
+
+ def __shutdown(self):
+ self.__self_request('DELETE', '/shutdown')
+
+ # master request operations
+ def __master_request(
+ self,
+ method: Optional[str] = 'GET',
+ path: Optional[str] = None,
+ data: Optional[Mapping[str, Any]] = None
+ ) -> requests.Response:
+ return self.__master_http_engine.request(
+ method,
+ path,
+ data,
+ retries=self.__request_retries,
+ retry_waiting=self.__request_retry_waiting,
+ )
+
+ def __master_heartbeat(self):
+ return self.__master_request('GET', '/slave/heartbeat')
+
+ def __master_task_complete(self, result: Mapping[str, Any]):
+ return self.__master_request(
+ 'PUT', '/slave/task/complete', data={
+ 'task': {
+ 'id': str(self.__task_id)
+ },
+ 'result': result or {},
+ }
+ )
+
+ def __master_task_fail(self, result: Mapping[str, Any]):
+ return self.__master_request(
+ 'PUT', '/slave/task/fail', data={
+ 'task': {
+ 'id': str(self.__task_id)
+ },
+ 'result': result or {},
+ }
+ )
+
+ # public methods
+ def ping(self) -> bool:
+ """
+ Overview:
+            Ping the current http server to check whether it is still running properly.
+        Returns:
+            - output (:obj:`bool`): Whether the http server is running properly. \
+                `True` means it runs properly, otherwise `False`.
+ """
+ with self.__lock:
+ try:
+ self.__ping_once()
+ except (requests.exceptions.BaseHTTPError, requests.exceptions.RequestException):
+ return False
+ else:
+ return True
+
+ def start(self):
+ """
+ Overview:
+            Start the current slave client.
+            The following steps are executed in order:
+
+ 1. Start the task-processing thread
+ 2. Start the heartbeat thread
+ 3. Start the http server thread
+ 4. Wait until the http server is online (can be pinged)
+ """
+ with self.__lock:
+ self.__task_thread.start()
+ self.__heartbeat_thread.start()
+ self.__run_app_thread.start()
+
+ self.__ping_until_started()
+
+ def shutdown(self):
+ """
+ Overview:
+            Shut down the current slave client.
+            A shutdown request will be sent to the http server, and the shutdown signal will be applied to the \
+            threads; the server will be down soon (you can use the `join` method to wait until then).
+ """
+ with self.__lock:
+ self.__shutdown()
+
+ def join(self):
+ """
+ Overview:
+            Wait until the current slave client is completely down.
+            The following steps are executed in order:
+
+            1. Wait until the http server thread is down
+            2. Wait until the heartbeat thread is down
+            3. Wait until the task-processing thread is down
+ """
+ with self.__lock:
+ self.__run_app_thread.join()
+ self.__heartbeat_thread.join()
+ self.__task_thread.join()
+
+ # inherit method
+ def _before_connection(self, data: Mapping[str, Any]):
+ """
+ Overview:
+            Behaviours that will be executed before a connection is established.
+        Arguments:
+            - data (:obj:`Mapping[str, Any]`): Connection data sent from the master when connecting to this slave.
+        Raises:
+            - `ConnectionRefuse`: If raised, the connection from the master end will be refused and \
+                no new connection will be established.
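+        Example:
+            - An illustrative override sketch; the key check is an assumption, not a required protocol:
+
+            >>> def _before_connection(self, data):
+            >>>     if data.get('key') != 'my-secret-key':
+            >>>         raise ConnectionRefuse({'reason': 'invalid key'})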
+ """
+ pass
+
+ def _before_disconnection(self, data: Mapping[str, Any]):
+ """
+ Overview:
+            Behaviours that will be executed before a disconnection is processed.
+        Arguments:
+            - data (:obj:`Mapping[str, Any]`): Disconnection data sent from the master when disconnecting \
+                from this slave.
+        Raises:
+            - `DisconnectionRefuse`: If raised, the disconnection request will be refused and \
+                the current connection will remain.
+ """
+ pass
+
+ def _before_task(self, data: Mapping[str, Any]):
+ """
+ Overview:
+            Behaviours that will be executed before a task is executed.
+        Arguments:
+            - data (:obj:`Mapping[str, Any]`): Data of the task
+        Raises:
+            - `TaskRefuse`: If raised, the new task will be refused.
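+        Example:
+            - An illustrative override sketch; the field check is an assumption, not a required protocol:
+
+            >>> def _before_task(self, data):
+            >>>     if 'a' not in data:
+            >>>         raise TaskRefuse({'reason': 'field a is missing'})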
+ """
+ pass
+
+ def _lost_connection(self, master_address: str, err: requests.exceptions.RequestException):
+ """
+ Overview:
+            Behaviours that will be executed after the connection is lost.
+        Arguments:
+            - master_address (:obj:`str`): String address of the master end
+            - err (:obj:`requests.exceptions.RequestException`): Http exception of this connection loss
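+        Example:
+            - An illustrative override sketch; the logging behaviour is an assumption:
+
+            >>> def _lost_connection(self, master_address, err):
+            >>>     print('Lost connection to {address}: {error}'.format(address=master_address, error=err))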
+ """
+ pass
+
+ @abstractmethod
+ def _process_task(self, task: Mapping[str, Any]):
+ """
+ Overview:
+            Execute the task; this protected method must be implemented in the subclass.
+        Arguments:
+            - task (:obj:`Mapping[str, Any]`): Data of the task
+        Raises:
+            - `TaskFail`: If raised, this task will be recognized as failed and \
+                the master will receive the failure signal.
+        Example:
+            - A successful task with a return value (the return value will be received at the master end)
+
+ >>> def _process_task(self, task):
+ >>> print('this is task data :', task)
+ >>> return str(task)
+
+            - A failed task with data (the data will be received at the master end)
+
+ >>> def _process_task(self, task):
+ >>> print('this is task data :', task)
+ >>> raise TaskFail(task) # this is a failed task
+
+            - A failed task with data and message (both will be received at the master end)
+
+ >>> def _process_task(self, task):
+ >>> print('this is task data :', task)
+ >>> raise TaskFail(task, 'this is message') # this is a failed task with message
+ """
+ raise NotImplementedError
diff --git a/DI-engine/ding/interaction/tests/__init__.py b/DI-engine/ding/interaction/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dff8351b7fce1cac8d21db6147e0e9812819b3a8
--- /dev/null
+++ b/DI-engine/ding/interaction/tests/__init__.py
@@ -0,0 +1,4 @@
+from .base import *
+from .config import *
+from .exception import *
+from .interaction import *
diff --git a/DI-engine/ding/interaction/tests/base/__init__.py b/DI-engine/ding/interaction/tests/base/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..597a8197d45647531babb39232abf7ec77f22990
--- /dev/null
+++ b/DI-engine/ding/interaction/tests/base/__init__.py
@@ -0,0 +1,4 @@
+from .test_app import TestInteractionBaseApp, TestInteractionBaseResponsibleException
+from .test_common import TestInteractionBaseCommon, TestInteractionBaseControllableService
+from .test_network import TestInteractionBaseHttpEngine, TestInteractionBaseNetwork
+from .test_threading import TestInteractionBaseThreading
diff --git a/DI-engine/ding/interaction/tests/base/test_app.py b/DI-engine/ding/interaction/tests/base/test_app.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3707c19a8001adea92bc85befe7ee04210d967f
--- /dev/null
+++ b/DI-engine/ding/interaction/tests/base/test_app.py
@@ -0,0 +1,227 @@
+import json
+
+import pytest
+from flask import Flask
+
+from ...base import success_response, failure_response, get_values_from_response, ResponsibleException, responsible
+
+
+@pytest.mark.unittest
+class TestInteractionBaseApp:
+
+ def test_success_response(self):
+ app = Flask('_test_success_response')
+
+ @app.route('/success', methods=['GET'])
+ def success_method():
+ return success_response(
+ data={
+ 'a': 1,
+ 'b': 2,
+ 'sum': 3,
+ },
+ message='This is success message.',
+ )
+
+ client = app.test_client()
+
+ response = client.get('/success')
+ assert response.status_code == 200
+ assert json.loads(response.data.decode()) == {
+ 'success': True,
+ 'code': 0,
+ 'data': {
+ 'a': 1,
+ 'b': 2,
+ 'sum': 3,
+ },
+ 'message': 'This is success message.',
+ }
+
+ # noinspection DuplicatedCode
+ def test_failure_response(self):
+ app = Flask('_test_failure_response')
+
+ @app.route('/fail', methods=['GET'])
+ def fail_method():
+ return failure_response(
+ code=233,
+ message='This is failure message.',
+ data={
+ 'a': 2,
+ 'b': 3,
+ 'sum': 5,
+ },
+ ), 404
+
+ client = app.test_client()
+
+ response = client.get('/fail')
+ assert response.status_code == 404
+ assert json.loads(response.data.decode()) == {
+ 'success': False,
+ 'code': 233,
+ 'data': {
+ 'a': 2,
+ 'b': 3,
+ 'sum': 5,
+ },
+ 'message': 'This is failure message.',
+ }
+
+ def test_get_values_from_response(self):
+ app = Flask('_test_get_values_from_response')
+
+ @app.route('/success', methods=['GET'])
+ def success_method():
+ return success_response(
+ data={
+ 'a': 1,
+ 'b': 2,
+ 'sum': 3,
+ },
+ message='This is success message.',
+ )
+
+ @app.route('/fail', methods=['GET'])
+ def fail_method():
+ return failure_response(
+ code=233,
+ message='This is failure message.',
+ data={
+ 'a': 2,
+ 'b': 3,
+ 'sum': 5,
+ },
+ ), 404
+
+ client = app.test_client()
+
+ response = client.get('/success')
+ assert response.status_code == 200
+ assert get_values_from_response(response) == (
+ 200,
+ True,
+ 0,
+ 'This is success message.',
+ {
+ 'a': 1,
+ 'b': 2,
+ 'sum': 3,
+ },
+ )
+
+ response = client.get('/fail')
+ assert response.status_code == 404
+ assert get_values_from_response(response) == (
+ 404,
+ False,
+ 233,
+ 'This is failure message.',
+ {
+ 'a': 2,
+ 'b': 3,
+ 'sum': 5,
+ },
+ )
+
+
+@pytest.mark.unittest
+class TestInteractionBaseResponsibleException:
+ # noinspection DuplicatedCode
+ def test_it(self):
+
+ class NotFound(ResponsibleException):
+
+ def __init__(self):
+ ResponsibleException.__init__(
+ self=self,
+ status_code=404,
+ code=233,
+ message='This is failure message.',
+ data={
+ 'a': 2,
+ 'b': 3,
+ 'sum': 5,
+ }
+ )
+
+ class AccessDenied(ResponsibleException):
+
+ def __init__(self):
+ ResponsibleException.__init__(
+ self=self,
+ status_code=403,
+ code=322,
+ message='This is another failure message.',
+ data={
+ 'a': 2,
+ 'b': 3,
+ 'sum': 7,
+ }
+ )
+
+ app = Flask('_test_failure_response')
+
+ @app.route('/fail', methods=['GET'])
+ @responsible(classes=(NotFound, ))
+ def fail_method():
+ raise NotFound
+
+ @app.route('/403', methods=['GET'])
+ @responsible()
+ def denied_method():
+ raise AccessDenied
+
+ @app.route('/success', methods=['GET'])
+ @responsible()
+ def success_method():
+ return success_response(
+ data={
+ 'a': 1,
+ 'b': 2,
+ 'sum': 3,
+ },
+ message='This is success message.',
+ )
+
+ client = app.test_client()
+
+ response = client.get('/fail')
+ assert response.status_code == 404
+ assert json.loads(response.data.decode()) == {
+ 'success': False,
+ 'code': 233,
+ 'data': {
+ 'a': 2,
+ 'b': 3,
+ 'sum': 5,
+ },
+ 'message': 'This is failure message.',
+ }
+
+ response = client.get('/403')
+ assert response.status_code == 403
+ assert json.loads(response.data.decode()) == {
+ 'success': False,
+ 'code': 322,
+ 'data': {
+ 'a': 2,
+ 'b': 3,
+ 'sum': 7,
+ },
+ 'message': 'This is another failure message.',
+ }
+
+ response = client.get('/success')
+ assert response.status_code == 200
+ assert json.loads(response.data.decode()) == {
+ 'success': True,
+ 'code': 0,
+ 'data': {
+ 'a': 1,
+ 'b': 2,
+ 'sum': 3,
+ },
+ 'message': 'This is success message.',
+ }
diff --git a/DI-engine/ding/interaction/tests/base/test_common.py b/DI-engine/ding/interaction/tests/base/test_common.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1eff39d9a18e82f09aea4d7fec300f5a8b59b5e
--- /dev/null
+++ b/DI-engine/ding/interaction/tests/base/test_common.py
@@ -0,0 +1,75 @@
+import string
+import time
+from typing import Any, Callable
+
+import pytest
+
+from ...base import random_token, translate_dict_func, default_func, ControllableService
+
+
+@pytest.mark.unittest
+class TestInteractionBaseCommon:
+
+ def test_random_token(self):
+ assert len(random_token()) == 64
+ assert len(random_token(32)) == 32
+ assert set(random_token()) - set(string.hexdigits) == set()
+
+ def test_translate_dict_func(self):
+ assert translate_dict_func({
+ 'a': lambda: 2,
+ 'b': lambda: 3,
+ 'sum': lambda: 5,
+ })() == {
+ 'a': 2,
+ 'b': 3,
+ 'sum': 5
+ }
+ assert translate_dict_func(
+ {
+ 'a': lambda ax, bx: 2 + ax,
+ 'b': lambda ax, bx: 3 + bx,
+ 'sum': lambda ax, bx: 5 + ax + bx,
+ }
+ )(4, 5) == {
+ 'a': 6,
+ 'b': 8,
+ 'sum': 14
+ }
+
+ def test_default_func(self):
+
+ def _calculate(a: int, b: int, callback: Callable[..., Any] = None):
+ return default_func(233)(callback)(a, b)
+
+ assert _calculate(1, 2) == 233
+ assert _calculate(1, 2, lambda a, b: a + b) == 3
+ assert _calculate(1, 2, lambda a, b: a * b) == 2
+
+
+@pytest.mark.unittest
+class TestInteractionBaseControllableService:
+
+ def test_it(self):
+ _start, _shutdown, _finished = False, False, False
+
+ class _Service(ControllableService):
+
+ def start(self):
+ nonlocal _start
+ _start = True
+
+ def shutdown(self):
+ nonlocal _shutdown
+ _shutdown = True
+
+ def join(self):
+ time.sleep(1.0)
+ nonlocal _finished
+ _finished = True
+
+ assert (_start, _shutdown, _finished) == (False, False, False)
+ with _Service():
+ assert (_start, _shutdown, _finished) == (True, False, False)
+
+ assert (_start, _shutdown, _finished) == (True, True, True)
diff --git a/DI-engine/ding/interaction/tests/base/test_network.py b/DI-engine/ding/interaction/tests/base/test_network.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7407e387908383b4ac4b2a5c54c8bfbdfb1be5a
--- /dev/null
+++ b/DI-engine/ding/interaction/tests/base/test_network.py
@@ -0,0 +1,171 @@
+import json
+import time
+from contextlib import contextmanager
+from multiprocessing import Process
+
+import pytest
+import requests
+import responses
+from flask import Flask, request
+from requests import HTTPError
+from urlobject import URLObject
+
+from ..test_utils import silence
+from ...base import get_host_ip, success_response, get_values_from_response, split_http_address, HttpEngine, \
+ get_http_engine_class
+
+app = Flask('_test_get_host_ip')
+
+
+@app.route('/ping', methods=['GET'])
+def ping_method():
+ return success_response(message='PONG!')
+
+
+@app.route('/shutdown', methods=['DELETE'])
+def shutdown_method():
+ _shutdown_func = request.environ.get('werkzeug.server.shutdown')
+ if _shutdown_func is None:
+ raise RuntimeError('Not running with the Werkzeug Server')
+
+ _shutdown_func()
+ return success_response(message='Shutdown request received, this server will be down later.')
+
+
+_APP_PORT = 17503
+
+
+def run_test_app():
+ with silence():
+ app.run(host='0.0.0.0', port=_APP_PORT)
+
+
+@pytest.mark.unittest
+class TestInteractionBaseNetwork:
+
+ @pytest.mark.execution_timeout(5.0, method='thread')
+ def test_get_host_ip(self):
+ app_process = Process(target=run_test_app)
+ app_process.start()
+
+ _local_ip = get_host_ip()
+ _local_server_host = URLObject().with_scheme('http').with_hostname(_local_ip).with_port(_APP_PORT)
+
+ try:
+ _start_time = time.time()
+ _start_complete = False
+ while not _start_complete and time.time() - _start_time < 5.0:
+ try:
+ response = requests.get(_local_server_host.add_path('/ping'))
+ if response.ok:
+ _start_complete = True
+ break
+ time.sleep(0.2)
+ except (requests.exceptions.BaseHTTPError, requests.exceptions.RequestException):
+ time.sleep(0.2)
+
+ if not _start_complete:
+ pytest.fail('Test server start failed.')
+
+ assert get_values_from_response(response) == (
+ 200,
+ True,
+ 0,
+ 'PONG!',
+ None,
+ )
+ finally:
+ try:
+ requests.delete(_local_server_host.add_path('/shutdown'))
+ finally:
+ app_process.join()
+
+ def test_split_http_address(self):
+ assert split_http_address('http://1.2.3.4') == ('1.2.3.4', 80, False, '')
+ assert split_http_address('https://1.2.3.4') == ('1.2.3.4', 443, True, '')
+ assert split_http_address('http://1.2.3.4:8888') == ('1.2.3.4', 8888, False, '')
+ assert split_http_address('https://1.2.3.4:8787/this/is/path') == ('1.2.3.4', 8787, True, '/this/is/path')
+
+
+@pytest.mark.unittest
+class TestInteractionBaseHttpEngine:
+
+ @contextmanager
+ def __yield_http_engine(self):
+ with responses.RequestsMock(assert_all_requests_are_fired=False) as rsp:
+ rsp.add(
+ **{
+ 'method': responses.GET,
+ 'url': 'http://example.com:7777/this/is/404',
+ 'body': json.dumps({"exception": "reason"}),
+ 'status': 404,
+ 'content_type': 'application/json',
+ }
+ )
+ rsp.add(
+ **{
+ 'method': responses.GET,
+ 'url': 'http://example.com:7777/this/is/200',
+ 'body': json.dumps({"success": True}),
+ 'status': 200,
+ 'content_type': 'application/json',
+ }
+ )
+
+ yield
+
+ @responses.activate
+ def test_http_engine_basic(self):
+ with self.__yield_http_engine():
+ engine = HttpEngine(host='example.com', port=7777)
+ response = engine.request('GET', '/this/is/200')
+ assert response.status_code == 200
+ assert json.loads(response.content.decode()) == {"success": True}
+
+ with pytest.raises(HTTPError) as ei:
+ engine.request('GET', '/this/is/404')
+
+ err = ei.value
+ assert err.response.status_code == 404
+ assert json.loads(err.response.content.decode()) == {'exception': 'reason'}
+
+ @responses.activate
+ def test_http_engine_with_path(self):
+ with self.__yield_http_engine():
+ engine = HttpEngine(host='example.com', port=7777, path='/this/is')
+ response = engine.request('GET', '200')
+ assert response.status_code == 200
+ assert json.loads(response.content.decode()) == {"success": True}
+
+ with pytest.raises(HTTPError) as ei:
+ engine.request('GET', '404')
+
+ err = ei.value
+ assert err.response.status_code == 404
+ assert json.loads(err.response.content.decode()) == {'exception': 'reason'}
+
+ @responses.activate
+ def test_get_http_engine_class(self):
+ with self.__yield_http_engine():
+ _token = '233'
+
+ _http_engine_class = get_http_engine_class(
+ headers={'Token': lambda: _token},
+ data_processor=(lambda d: {
+ 'data': json.dumps(d)
+ }),
+ http_error_gene=lambda e: RuntimeError('This is {status}'.format(status=e.response.status_code))
+ )()
+ engine = _http_engine_class(host='example.com', port=7777, path='/this/is')
+
+ response = engine.request('GET', '200', {'a': 'skdjgflksdj'})
+ assert response.status_code == 200
+ assert json.loads(response.content.decode()) == {"success": True}
+ assert response.request.headers['Token'] == '233'
+ assert json.loads(response.request.body) == {'data': json.dumps({'a': 'skdjgflksdj'})}
+
+ with pytest.raises(RuntimeError) as ei:
+ engine.request('GET', '404', {'a': 'skdjgflksdj'})
+
+ err = ei.value
+ assert 'This is 404' in str(err)
diff --git a/DI-engine/ding/interaction/tests/base/test_threading.py b/DI-engine/ding/interaction/tests/base/test_threading.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e85a3465d1094f2eef054bd4678d2520d1d2f2c
--- /dev/null
+++ b/DI-engine/ding/interaction/tests/base/test_threading.py
@@ -0,0 +1,128 @@
+import time
+from threading import Thread
+
+import pytest
+
+from ...base import DblEvent
+
+
+@pytest.mark.unittest
+class TestInteractionBaseThreading:
+ # noinspection DuplicatedCode
+ @pytest.mark.execution_timeout(5.0, method='thread')
+ def test_dbl_event_open(self):
+ event = DblEvent()
+ assert event.is_close()
+ assert not event.is_open()
+
+ # Opening test
+ _time_1, _time_2 = 0.0, 0.0
+
+ def _run_1_wait_for_open():
+ nonlocal _time_1
+ event.wait_for_open()
+ _time_1 = time.time()
+
+ def _run_2_wait_for_open():
+ nonlocal _time_2
+ event.wait_for_open()
+ _time_2 = time.time()
+
+ _thread_1 = Thread(target=_run_1_wait_for_open)
+ _thread_2 = Thread(target=_run_2_wait_for_open)
+
+ _thread_1.start()
+ _thread_2.start()
+
+ time.sleep(0.2)
+ assert event.is_close()
+ assert not event.is_open()
+ assert _time_1 == 0.0
+ assert _time_2 == 0.0
+
+ time.sleep(0.8)
+ event.open()
+ _thread_1.join()
+ _thread_2.join()
+
+ assert abs(time.time() - _time_1) < 0.3
+ assert abs(time.time() - _time_2) < 0.3
+ assert not event.is_close()
+ assert event.is_open()
+
+ # Closing test
+ _time_1, _time_2 = 0.0, 0.0
+
+ def _run_1_wait_for_close():
+ nonlocal _time_1
+ event.wait_for_close()
+ _time_1 = time.time()
+
+ def _run_2_wait_for_close():
+ nonlocal _time_2
+ event.wait_for_close()
+ _time_2 = time.time()
+
+ _thread_1 = Thread(target=_run_1_wait_for_close)
+ _thread_2 = Thread(target=_run_2_wait_for_close)
+
+ _thread_1.start()
+ _thread_2.start()
+
+ time.sleep(0.2)
+ assert not event.is_close()
+ assert event.is_open()
+ assert _time_1 == 0.0
+ assert _time_2 == 0.0
+
+ time.sleep(0.8)
+ event.close()
+ _thread_1.join()
+ _thread_2.join()
+
+ assert abs(time.time() - _time_1) < 0.3
+ assert abs(time.time() - _time_2) < 0.3
+ assert event.is_close()
+ assert not event.is_open()
+
+ # noinspection DuplicatedCode
+ @pytest.mark.execution_timeout(5.0, method='thread')
+ def test_dbl_event_close(self):
+ event = DblEvent(True)
+ assert not event.is_close()
+ assert event.is_open()
+
+ # Closing test
+ _time_1, _time_2 = 0.0, 0.0
+
+ def _run_1_wait_for_close():
+ nonlocal _time_1
+ event.wait_for_close()
+ _time_1 = time.time()
+
+ def _run_2_wait_for_close():
+ nonlocal _time_2
+ event.wait_for_close()
+ _time_2 = time.time()
+
+ _thread_1 = Thread(target=_run_1_wait_for_close)
+ _thread_2 = Thread(target=_run_2_wait_for_close)
+
+ _thread_1.start()
+ _thread_2.start()
+
+ time.sleep(0.2)
+ assert not event.is_close()
+ assert event.is_open()
+ assert _time_1 == 0.0
+ assert _time_2 == 0.0
+
+ time.sleep(0.8)
+ event.close()
+ _thread_1.join()
+ _thread_2.join()
+
+ assert abs(time.time() - _time_1) < 0.3
+ assert abs(time.time() - _time_2) < 0.3
+ assert event.is_close()
+ assert not event.is_open()
diff --git a/DI-engine/ding/interaction/tests/config/__init__.py b/DI-engine/ding/interaction/tests/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cbf8c160ff968d2e33e1ea839a7215af9dc58bb
--- /dev/null
+++ b/DI-engine/ding/interaction/tests/config/__init__.py
@@ -0,0 +1 @@
+from .test_base import TestInteractionConfig
diff --git a/DI-engine/ding/interaction/tests/config/test_base.py b/DI-engine/ding/interaction/tests/config/test_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b87d1bda444d9c5d6f4f915ea96c347aec5985f
--- /dev/null
+++ b/DI-engine/ding/interaction/tests/config/test_base.py
@@ -0,0 +1,11 @@
+import pytest
+
+from ...config import GLOBAL_HOST, LOCAL_HOST
+
+
+@pytest.mark.unittest
+class TestInteractionConfig:
+
+ def test_base_host(self):
+ assert GLOBAL_HOST == '0.0.0.0'
+ assert LOCAL_HOST == '127.0.0.1'
diff --git a/DI-engine/ding/interaction/tests/exception/__init__.py b/DI-engine/ding/interaction/tests/exception/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..98a6f19efe4e5a987f2f085769e439bd6a2cfe4e
--- /dev/null
+++ b/DI-engine/ding/interaction/tests/exception/__init__.py
@@ -0,0 +1,2 @@
+from .test_master import TestInteractionExceptionMaster
+from .test_slave import TestInteractionExceptionSlave
diff --git a/DI-engine/ding/interaction/tests/exception/test_base.py b/DI-engine/ding/interaction/tests/exception/test_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7c7a6ba00d17a5f22b27b535641bd0be9ffeffa
--- /dev/null
+++ b/DI-engine/ding/interaction/tests/exception/test_base.py
@@ -0,0 +1,51 @@
+import json
+from contextlib import contextmanager
+from typing import Optional, Mapping, Any
+
+import pytest
+import requests
+import responses
+from requests import HTTPError
+
+
+class _HTTPErrorGenerator:
+
+ @classmethod
+ def _generate_exception(
+ cls, code: int, message: str, data: Optional[Mapping[str, Any]] = None, success: bool = False
+ ):
+
+ @contextmanager
+ def _yield_func():
+ with responses.RequestsMock(assert_all_requests_are_fired=False) as rsp:
+ rsp.add(
+ **{
+ 'method': responses.GET,
+ 'url': 'http://example.com/path',
+ 'body': json.dumps(
+ {
+ "success": not not success,
+ "code": int(code),
+ "message": str(message),
+ "data": data or {},
+ }
+ ),
+ 'status': 400,
+ 'content_type': 'application/json',
+ }
+ )
+
+ yield
+
+ @responses.activate
+ def _get_exception():
+ try:
+ with _yield_func():
+ response = requests.get('http://example.com/path')
+ response.raise_for_status()
+ except HTTPError as err:
+ return err
+ else:
+ pytest.fail('Should not reach here.')
+
+ return _get_exception()
diff --git a/DI-engine/ding/interaction/tests/exception/test_master.py b/DI-engine/ding/interaction/tests/exception/test_master.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcd89ccb5f85d47ac270f84d11942e81ecec674a
--- /dev/null
+++ b/DI-engine/ding/interaction/tests/exception/test_master.py
@@ -0,0 +1,60 @@
+import pytest
+
+from .test_base import _HTTPErrorGenerator
+from ...exception.master import MasterErrorCode, \
+ get_master_exception_class_by_error_code, get_master_exception_by_error, MasterSuccess, \
+ MasterSystemShuttingDown, MasterTaskDataInvalid, MasterSlaveTokenNotGiven, MasterSlaveTokenInvalid, \
+ MasterSelfTokenNotGiven, MasterSelfTokenInvalid, MasterChannelInvalid, \
+ MasterChannelNotGiven, MasterMasterTokenInvalid, MasterMasterTokenNotGiven
+
+
+@pytest.mark.unittest
+class TestInteractionExceptionMaster(_HTTPErrorGenerator):
+
+ def test_error_code(self):
+ assert len(MasterErrorCode.__members__) == 11
+ assert MasterErrorCode.SUCCESS == 0
+
+ def test_exception_class(self):
+ assert get_master_exception_class_by_error_code(MasterErrorCode.SUCCESS) == MasterSuccess
+
+ assert get_master_exception_class_by_error_code(
+ MasterErrorCode.SYSTEM_SHUTTING_DOWN
+ ) == MasterSystemShuttingDown
+
+ assert get_master_exception_class_by_error_code(MasterErrorCode.CHANNEL_NOT_GIVEN) == MasterChannelNotGiven
+ assert get_master_exception_class_by_error_code(MasterErrorCode.CHANNEL_INVALID) == MasterChannelInvalid
+
+ assert get_master_exception_class_by_error_code(
+ MasterErrorCode.MASTER_TOKEN_NOT_GIVEN
+ ) == MasterMasterTokenNotGiven
+ assert get_master_exception_class_by_error_code(
+ MasterErrorCode.MASTER_TOKEN_INVALID
+ ) == MasterMasterTokenInvalid
+
+ assert get_master_exception_class_by_error_code(MasterErrorCode.SELF_TOKEN_NOT_GIVEN) == MasterSelfTokenNotGiven
+ assert get_master_exception_class_by_error_code(MasterErrorCode.SELF_TOKEN_INVALID) == MasterSelfTokenInvalid
+
+ assert get_master_exception_class_by_error_code(
+ MasterErrorCode.SLAVE_TOKEN_NOT_GIVEN
+ ) == MasterSlaveTokenNotGiven
+ assert get_master_exception_class_by_error_code(MasterErrorCode.SLAVE_TOKEN_INVALID) == MasterSlaveTokenInvalid
+
+ assert get_master_exception_class_by_error_code(MasterErrorCode.TASK_DATA_INVALID) == MasterTaskDataInvalid
+
+ def test_get_master_exception_by_error(self):
+ err = get_master_exception_by_error(self._generate_exception(101, 'This is system shutting down.'))
+ assert isinstance(err, MasterSystemShuttingDown)
+ assert not err.success
+ assert err.status_code == 400
+ assert err.code == 101
+ assert err.message == 'This is system shutting down.'
+ assert err.data == {}
+
+ err = get_master_exception_by_error(self._generate_exception(601, 'Task data invalid.', data={'value': 233}))
+ assert isinstance(err, MasterTaskDataInvalid)
+ assert not err.success
+ assert err.status_code == 400
+ assert err.code == 601
+ assert err.message == 'Task data invalid.'
+ assert err.data == {'value': 233}
diff --git a/DI-engine/ding/interaction/tests/exception/test_slave.py b/DI-engine/ding/interaction/tests/exception/test_slave.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac4fd261416a4daec734f1b69ed113bf04bd1d74
--- /dev/null
+++ b/DI-engine/ding/interaction/tests/exception/test_slave.py
@@ -0,0 +1,64 @@
+import pytest
+
+from .test_base import _HTTPErrorGenerator
+from ...exception.slave import SlaveErrorCode, \
+ get_slave_exception_class_by_error_code, get_slave_exception_by_error, SlaveSystemShuttingDown, \
+ SlaveSlaveConnectionRefused, SlaveSlaveDisconnectionRefused, SlaveSlaveNotConnected, SlaveSlaveAlreadyConnected, \
+ SlaveTaskRefused, SlaveMasterTokenInvalid, SlaveMasterTokenNotFound, SlaveSelfTokenNotFound, \
+ SlaveTaskAlreadyExist, SlaveSelfTokenInvalid, SlaveChannelNotFound, SlaveChannelInvalid, SlaveSuccess
+
+
+@pytest.mark.unittest
+class TestInteractionExceptionSlave(_HTTPErrorGenerator):
+
+ def test_error_code(self):
+ assert len(SlaveErrorCode.__members__) == 14
+ assert SlaveErrorCode.SUCCESS == 0
+
+ # noinspection DuplicatedCode
+ def test_exception_class(self):
+ assert get_slave_exception_class_by_error_code(SlaveErrorCode.SUCCESS) == SlaveSuccess
+
+ assert get_slave_exception_class_by_error_code(SlaveErrorCode.SYSTEM_SHUTTING_DOWN) == SlaveSystemShuttingDown
+
+ assert get_slave_exception_class_by_error_code(SlaveErrorCode.CHANNEL_NOT_FOUND) == SlaveChannelNotFound
+ assert get_slave_exception_class_by_error_code(SlaveErrorCode.CHANNEL_INVALID) == SlaveChannelInvalid
+
+ assert get_slave_exception_class_by_error_code(
+ SlaveErrorCode.MASTER_TOKEN_NOT_FOUND
+ ) == SlaveMasterTokenNotFound
+ assert get_slave_exception_class_by_error_code(SlaveErrorCode.MASTER_TOKEN_INVALID) == SlaveMasterTokenInvalid
+
+ assert get_slave_exception_class_by_error_code(SlaveErrorCode.SELF_TOKEN_NOT_FOUND) == SlaveSelfTokenNotFound
+ assert get_slave_exception_class_by_error_code(SlaveErrorCode.SELF_TOKEN_INVALID) == SlaveSelfTokenInvalid
+
+ assert get_slave_exception_class_by_error_code(
+ SlaveErrorCode.SLAVE_ALREADY_CONNECTED
+ ) == SlaveSlaveAlreadyConnected
+ assert get_slave_exception_class_by_error_code(SlaveErrorCode.SLAVE_NOT_CONNECTED) == SlaveSlaveNotConnected
+ assert get_slave_exception_class_by_error_code(
+ SlaveErrorCode.SLAVE_CONNECTION_REFUSED
+ ) == SlaveSlaveConnectionRefused
+ assert get_slave_exception_class_by_error_code(
+ SlaveErrorCode.SLAVE_DISCONNECTION_REFUSED
+ ) == SlaveSlaveDisconnectionRefused
+
+ assert get_slave_exception_class_by_error_code(SlaveErrorCode.TASK_ALREADY_EXIST) == SlaveTaskAlreadyExist
+ assert get_slave_exception_class_by_error_code(SlaveErrorCode.TASK_REFUSED) == SlaveTaskRefused
+
+ def test_get_slave_exception_by_error(self):
+ err = get_slave_exception_by_error(self._generate_exception(101, 'This is slave shutting down.'))
+ assert isinstance(err, SlaveSystemShuttingDown)
+ assert not err.success
+ assert err.status_code == 400
+ assert err.code == 101
+ assert err.message == 'This is slave shutting down.'
+ assert err.data == {}
+
+ err = get_slave_exception_by_error(self._generate_exception(602, 'Task refused.', data={'value': 233}))
+ assert isinstance(err, SlaveTaskRefused)
+ assert not err.success
+ assert err.status_code == 400
+ assert err.code == 602
+ assert err.message == 'Task refused.'
+ assert err.data == {'value': 233}
diff --git a/DI-engine/ding/interaction/tests/interaction/__init__.py b/DI-engine/ding/interaction/tests/interaction/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3afe68e3d6e04c6f3434a1baf5672a50c3fea86e
--- /dev/null
+++ b/DI-engine/ding/interaction/tests/interaction/__init__.py
@@ -0,0 +1,2 @@
+from .test_errors import TestInteractionErrors
+from .test_simple import TestInteractionSimple
diff --git a/DI-engine/ding/interaction/tests/interaction/bases.py b/DI-engine/ding/interaction/tests/interaction/bases.py
new file mode 100644
index 0000000000000000000000000000000000000000..d312a969c3292436540824a4c07e079825c9b67f
--- /dev/null
+++ b/DI-engine/ding/interaction/tests/interaction/bases.py
@@ -0,0 +1,50 @@
+from functools import partial
+from multiprocessing import Event, Process
+from typing import Mapping, Any, Tuple
+
+from ..test_utils import silence_function, random_channel, random_port
+from ...master import Master
+from ...slave import Slave, TaskFail
+
+
+class MySlave(Slave):
+
+ def _process_task(self, task: Mapping[str, Any]):
+ if 'a' in task.keys() and 'b' in task.keys():
+ return {'sum': task['a'] + task['b']}
+ else:
+ raise TaskFail(result={'message': 'ab not found'}, message='A or B not found in task data.')
+
+
+def _run_slave(port, channel, open_slave_event, close_slave_event):
+ with MySlave('0.0.0.0', port, channel=channel):
+ open_slave_event.set()
+ close_slave_event.wait()
+
+
+def _slave_endpoint(port: int, channel: int, silence: bool = False):
+ open_slave_event = Event()
+ close_slave_event = Event()
+
+ _run = partial(_run_slave, port, channel, open_slave_event, close_slave_event)
+ if silence:
+ _run = silence_function()(_run)
+ slave_process = Process(target=_run)
+
+ return slave_process, open_slave_event, close_slave_event
+
+
+class _MyMaster(Master):
+ pass
+
+
+def _get_master_endpoint(port: int, channel: int):
+ return _MyMaster('0.0.0.0', port, channel=channel)
+
+
+def _random_slave_channel_and_port() -> Tuple[int, int]:
+ return random_port(), random_channel()
+
+
+class _TestInteractionBase:
+ pass
diff --git a/DI-engine/ding/interaction/tests/interaction/test_errors.py b/DI-engine/ding/interaction/tests/interaction/test_errors.py
new file mode 100644
index 0000000000000000000000000000000000000000..98a391daa7307f873a9a3f72cae97effaa510f55
--- /dev/null
+++ b/DI-engine/ding/interaction/tests/interaction/test_errors.py
@@ -0,0 +1,38 @@
+import pytest
+
+from .bases import _TestInteractionBase, _random_slave_channel_and_port, _slave_endpoint, _get_master_endpoint
+from ..test_utils import random_port, random_channel
+from ...exception import SlaveErrorCode, SlaveChannelInvalid
+
+
+@pytest.mark.unittest
+class TestInteractionErrors(_TestInteractionBase):
+
+ @pytest.mark.execution_timeout(20.0, method='thread')
+ def test_slave_simple_connection(self):
+ _slave_port, _slave_channel = _random_slave_channel_and_port()
+ slave_thread, open_slave_event, close_slave_event = _slave_endpoint(_slave_port, _slave_channel)
+
+ slave_thread.start()
+ open_slave_event.wait()
+
+ try:
+ _master_port = random_port()
+ _master_channel = random_channel(excludes=[_slave_channel])
+ master = _get_master_endpoint(_master_port, _master_channel)
+ with master:
+ assert master.ping()
+
+ with pytest.raises(SlaveChannelInvalid) as ei:
+ with master.new_connection('conn', '127.0.0.1', _slave_port):
+ pytest.fail('Should not reach here!')
+
+ err = ei.value
+ assert not err.success
+ assert err.status_code == 403
+ assert err.code == SlaveErrorCode.CHANNEL_INVALID
+
+ assert 'conn' not in master
+ finally:
+ close_slave_event.set()
+ slave_thread.join()
diff --git a/DI-engine/ding/interaction/tests/interaction/test_simple.py b/DI-engine/ding/interaction/tests/interaction/test_simple.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcb89e9a780001e83deb03b52774e268f6a36cad
--- /dev/null
+++ b/DI-engine/ding/interaction/tests/interaction/test_simple.py
@@ -0,0 +1,130 @@
+import pytest
+from requests import HTTPError
+
+from .bases import _TestInteractionBase, _random_slave_channel_and_port, _slave_endpoint, _get_master_endpoint
+from ..test_utils import random_port
+from ...master.task import TaskStatus
+
+
+@pytest.mark.unittest
+class TestInteractionSimple(_TestInteractionBase):
+
+ @pytest.mark.execution_timeout(10.0, method='thread')
+ def test_slave_launch(self):
+ _slave_port, _channel = _random_slave_channel_and_port()
+ slave_thread, open_slave_event, close_slave_event = _slave_endpoint(_slave_port, _channel)
+
+ slave_thread.start()
+ open_slave_event.wait()
+
+ close_slave_event.set()
+ slave_thread.join()
+
+ @pytest.mark.execution_timeout(20.0, method='thread')
+ def test_slave_simple_connection(self):
+ _slave_port, _channel = _random_slave_channel_and_port()
+ slave_thread, open_slave_event, close_slave_event = _slave_endpoint(_slave_port, _channel)
+
+ slave_thread.start()
+ open_slave_event.wait()
+
+ try:
+ _master_port = random_port()
+ master = _get_master_endpoint(_master_port, _channel)
+ with master:
+ assert master.ping()
+
+ with master.new_connection('conn', '127.0.0.1', _slave_port) as conn:
+ assert conn.is_connected
+ assert 'conn' in master
+ assert master['conn'] == conn
+
+ assert not conn.is_connected
+ assert 'conn' not in master
+
+ conn = master.new_connection('conn', '127.0.0.1', _slave_port)
+ conn.connect()
+ assert conn.is_connected
+ assert 'conn' in master
+ assert master['conn'] == conn
+ conn.disconnect()
+ assert not conn.is_connected
+ assert 'conn' not in master
+
+ conn = master.new_connection('conn', '127.0.0.1', _slave_port)
+ conn.connect()
+ assert conn.is_connected
+ assert 'conn' in master
+ assert master['conn'] == conn
+ del master['conn']
+ assert not conn.is_connected
+ assert 'conn' not in master
+
+ finally:
+ close_slave_event.set()
+ slave_thread.join()
+
+ @pytest.mark.execution_timeout(20.0, method='thread')
+ def test_slave_simple_task(self):
+ _slave_port, _channel = _random_slave_channel_and_port()
+ slave_thread, open_slave_event, close_slave_event = _slave_endpoint(_slave_port, _channel)
+
+ slave_thread.start()
+ open_slave_event.wait()
+
+ try:
+ _master_port = random_port()
+ master = _get_master_endpoint(_master_port, _channel)
+ with master:
+ with master.new_connection('conn', '127.0.0.1', _slave_port) as conn:
+ task = conn.new_task({'a': 2, 'b': 3})
+ task.start().join()
+
+ assert task.result == {'sum': 5}
+ assert task.status == TaskStatus.COMPLETED
+
+ _res_1, _res_2, _res_3 = None, None, None
+
+ def _set_res_1(t, r):
+ nonlocal _res_1
+ _res_1 = r['sum']
+
+ def _set_res_2(t, r):
+ nonlocal _res_2
+ _res_2 = r
+
+ def _set_res_3(t, r):
+ nonlocal _res_3
+ _res_3 = r
+
+ task = conn.new_task({'a': 2, 'b': 3}) \
+ .on_complete(_set_res_1).on_complete(_set_res_2) \
+ .on_fail(_set_res_3)
+ task.start().join()
+
+ assert task.result == {'sum': 5}
+ assert task.status == TaskStatus.COMPLETED
+ assert _res_1 == 5
+ assert _res_2 == {'sum': 5}
+ assert _res_3 is None
+
+ _res_1, _res_2, _res_3 = None, None, None
+ task = conn.new_task({'a': 2, 'bb': 3}) \
+ .on_complete(_set_res_1).on_complete(_set_res_2) \
+ .on_fail(_set_res_3)
+ task.start().join()
+
+ assert task.result == {'message': 'ab not found'}
+ assert task.status == TaskStatus.FAILED
+ assert _res_1 is None
+ assert _res_2 is None
+ assert _res_3 == {'message': 'ab not found'}
+ except HTTPError as err:
+ print(err.response)
+ print(err.response.content)
+ print(err.request)
+
+ raise err
+ finally:
+ close_slave_event.set()
+ slave_thread.join()
diff --git a/DI-engine/ding/interaction/tests/test_utils/__init__.py b/DI-engine/ding/interaction/tests/test_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b34b8955ad8c3d94c45bc434962fb88a99b3aff0
--- /dev/null
+++ b/DI-engine/ding/interaction/tests/test_utils/__init__.py
@@ -0,0 +1,2 @@
+from .random import random_port, random_channel
+from .stream import silence, silence_function
diff --git a/DI-engine/ding/interaction/tests/test_utils/random.py b/DI-engine/ding/interaction/tests/test_utils/random.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b308970d7c14b9ed0d0cdae3caf128b06f3c4f9
--- /dev/null
+++ b/DI-engine/ding/interaction/tests/test_utils/random.py
@@ -0,0 +1,14 @@
+import random
+from typing import Iterable
+
+
+def random_port(excludes: Iterable[int] = None) -> int:
+ return random.choice(list(set(range(10000, 20000)) - set(excludes or [])))
+
+
+def random_channel(excludes: Iterable[int] = None) -> int:
+ excludes = set(list(excludes or []))
+ while True:
+ _channel = random.randint(1000, (1 << 31) - 1)
+ if _channel not in excludes:
+ return _channel
diff --git a/DI-engine/ding/interaction/tests/test_utils/stream.py b/DI-engine/ding/interaction/tests/test_utils/stream.py
new file mode 100644
index 0000000000000000000000000000000000000000..0fc4657a88c73564a1f2b682f14f0f5be2b0ca2d
--- /dev/null
+++ b/DI-engine/ding/interaction/tests/test_utils/stream.py
@@ -0,0 +1,40 @@
+import os
+import sys
+from contextlib import contextmanager
+from functools import wraps
+from threading import Lock
+from typing import Callable, Any
+
+_global_no_output_lock = Lock()
+
+
+@contextmanager
+def silence(no_stdout: bool = True, no_stderr: bool = True):
+ with _global_no_output_lock:
+ if no_stdout:
+ # Don't use `wb` mode here, otherwise it will cause all streaming methods to crash
+ _real_stdout, sys.stdout = sys.stdout, open(os.devnull, 'w')
+ if no_stderr:
+ _real_stderr, sys.stderr = sys.stderr, open(os.devnull, 'w')
+
+ try:
+ yield
+ finally:
+ if no_stdout:
+ sys.stdout = _real_stdout
+ if no_stderr:
+ sys.stderr = _real_stderr
+
+
+def silence_function(no_stdout: bool = True, no_stderr: bool = True):
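+    """
+    Overview:
+        Decorator version of `silence`, suppressing stdout and/or stderr of the decorated function.
+    Example:
+        - An illustrative sketch; the decorated function below is an assumption:
+
+        >>> @silence_function()
+        >>> def noisy():
+        >>>     print('this output will be suppressed')
+    """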
+
+ def _decorator(func: Callable[..., Any]) -> Callable[..., Any]:
+
+ @wraps(func)
+ def _func(*args, **kwargs):
+ with silence(no_stdout, no_stderr):
+ return func(*args, **kwargs)
+
+ return _func
+
+ return _decorator
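+
+
+if __name__ == '__main__':
+    # Quick illustration of the helpers above: the decorated function prints nothing,
+    # because stdout is temporarily redirected to os.devnull inside ``silence``.
+    @silence_function()
+    def _noisy() -> None:
+        print('this line is swallowed')
+
+    _noisy()
+    with silence(no_stdout=True, no_stderr=False):
+        print('this one is swallowed too')
+    print('back to normal output')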
diff --git a/DI-engine/ding/league/__init__.py b/DI-engine/ding/league/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d87ef83193038970b8658b056ea90b83a304da8d
--- /dev/null
+++ b/DI-engine/ding/league/__init__.py
@@ -0,0 +1,6 @@
+from .base_league import BaseLeague, create_league
+from .one_vs_one_league import OneVsOneLeague
+from .player import Player, ActivePlayer, HistoricalPlayer, create_player
+from .starcraft_player import MainPlayer, MainExploiter, LeagueExploiter
+from .shared_payoff import create_payoff
+from .metric import get_elo, get_elo_array, LeagueMetricEnv
diff --git a/DI-engine/ding/league/algorithm.py b/DI-engine/ding/league/algorithm.py
new file mode 100644
index 0000000000000000000000000000000000000000..09cb23b3527f195133851ed2c2be9a6a574a0067
--- /dev/null
+++ b/DI-engine/ding/league/algorithm.py
@@ -0,0 +1,43 @@
+import numpy as np
+
+
+def pfsp(win_rates: np.ndarray, weighting: str) -> np.ndarray:
+ """
+ Overview:
+ Prioritized Fictitious Self-Play algorithm.
+ Process win_rates with a weighting function to get priority, then calculate the selection probability of each.
+ Arguments:
+ - win_rates (:obj:`np.ndarray`): a numpy ndarray of win rates between one player and N opponents, shape(N)
+ - weighting (:obj:`str`): pfsp weighting function type, refer to ``weighting_func`` below
+ Returns:
+ - probs (:obj:`np.ndarray`): a numpy ndarray of probability at which one element is selected, shape(N)
+ """
+ weighting_func = {
+ 'squared': lambda x: (1 - x) ** 2,
+ 'variance': lambda x: x * (1 - x),
+ }
+ if weighting in weighting_func.keys():
+ fn = weighting_func[weighting]
+ else:
+ raise KeyError("invalid weighting arg: {} in pfsp".format(weighting))
+
+ assert isinstance(win_rates, np.ndarray)
+ assert win_rates.shape[0] >= 1, win_rates.shape
+ # all zero win rates case, return uniform selection prob
+ if win_rates.sum() < 1e-8:
+ return np.full_like(win_rates, 1.0 / len(win_rates))
+ fn_win_rates = fn(win_rates)
+ probs = fn_win_rates / fn_win_rates.sum()
+ return probs
+
+
+def uniform(win_rates: np.ndarray) -> np.ndarray:
+ """
+ Overview:
+ Uniform opponent selection algorithm. Select an opponent uniformly, regardless of historical win rates.
+ Arguments:
+ - win_rates (:obj:`np.ndarray`): a numpy ndarray of win rates between one player and N opponents, shape(N)
+ Returns:
+ - probs (:obj:`np.ndarray`): a numpy ndarray of uniform probability, shape(N)
+ """
+ return np.full_like(win_rates, 1.0 / len(win_rates))
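+
+
+if __name__ == '__main__':
+    # Illustrative sketch (the win-rate vector below is made up): 'squared' concentrates
+    # probability on opponents we currently lose to, 'variance' favours opponents with a
+    # ~50% win rate, and the league then samples an opponent index with these probabilities
+    # (cf. ``Player._get_opponent``).
+    win_rates = np.array([0.9, 0.5, 0.1])
+    print(pfsp(win_rates, weighting='squared'))   # most mass on the 0.1 win-rate opponent
+    print(pfsp(win_rates, weighting='variance'))  # most mass on the 0.5 win-rate opponent
+    print(uniform(win_rates))                     # [1/3, 1/3, 1/3]
+    print(np.random.choice(len(win_rates), p=pfsp(win_rates, weighting='squared')))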
diff --git a/DI-engine/ding/league/base_league.py b/DI-engine/ding/league/base_league.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a5ed333e5ecc994cd277b0ca8659f8a00c6da14
--- /dev/null
+++ b/DI-engine/ding/league/base_league.py
@@ -0,0 +1,302 @@
+from typing import Union, Dict
+import uuid
+import copy
+import os
+import os.path as osp
+from abc import abstractmethod
+from easydict import EasyDict
+from tabulate import tabulate
+
+from ding.league.player import ActivePlayer, HistoricalPlayer, create_player
+from ding.league.shared_payoff import create_payoff
+from ding.utils import import_module, read_file, save_file, LockContext, LockContextType, LEAGUE_REGISTRY, \
+ deep_merge_dicts
+from .metric import LeagueMetricEnv
+
+
+class BaseLeague:
+ """
+ Overview:
+ League manager, following the league training scheme proposed in DeepMind's AlphaStar. Can manage multiple players in one league.
+ Interface:
+ get_job_info, judge_snapshot, update_active_player, finish_job, save_checkpoint
+
+ .. note::
+ In the ``__init__`` method, the league also initializes its players (in the ``_init_players`` method).
+ """
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ cfg = EasyDict(copy.deepcopy(cls.config))
+ cfg.cfg_type = cls.__name__ + 'Dict'
+ return cfg
+
+ config = dict(
+ league_type='base',
+ import_names=["ding.league.base_league"],
+ # ---player----
+ # "player_category" is just a name. Depends on the env.
+ # For example, in StarCraft, this can be ['zerg', 'terran', 'protoss'].
+ player_category=['default'],
+ # Support different types of active players for solo and battle league.
+ # For solo league, supports ['solo_active_player'].
+ # For battle league, supports ['battle_active_player', 'main_player', 'main_exploiter', 'league_exploiter'].
+ # active_players=dict(),
+ # "use_pretrain" means whether to use pretrain model to initialize active player.
+ use_pretrain=False,
+ # "use_pretrain_init_historical" means whether to use pretrain model to initialize historical player.
+ # "pretrain_checkpoint_path" is the pretrain checkpoint path used in "use_pretrain" and
+ # "use_pretrain_init_historical". If both are False, "pretrain_checkpoint_path" can be omitted as well.
+ # Otherwise, "pretrain_checkpoint_path" should list paths of all player categories.
+ use_pretrain_init_historical=False,
+ pretrain_checkpoint_path=dict(default='default_cate_pretrain.pth', ),
+ # ---payoff---
+ payoff=dict(
+ # Supports ['battle']
+ type='battle',
+ decay=0.99,
+ min_win_rate_games=8,
+ ),
+ metric=dict(
+ mu=0,
+ sigma=25 / 3,
+ beta=25 / 3 / 2,
+ tau=0.0,
+ draw_probability=0.02,
+ ),
+ )
+
+ def __init__(self, cfg: EasyDict) -> None:
+ """
+ Overview:
+ Initialization method.
+ Arguments:
+ - cfg (:obj:`EasyDict`): League config.
+ """
+ self.cfg = deep_merge_dicts(self.default_config(), cfg)
+ self.path_policy = cfg.path_policy
+ if not osp.exists(self.path_policy):
+ os.mkdir(self.path_policy)
+
+ self.league_uid = str(uuid.uuid1())
+ # TODO dict players
+ self.active_players = []
+ self.historical_players = []
+ self.player_path = "./league"
+ self.payoff = create_payoff(self.cfg.payoff)
+ metric_cfg = self.cfg.metric
+ self.metric_env = LeagueMetricEnv(metric_cfg.mu, metric_cfg.sigma, metric_cfg.tau, metric_cfg.draw_probability)
+ self._active_players_lock = LockContext(type_=LockContextType.THREAD_LOCK)
+ self._init_players()
+
+ def _init_players(self) -> None:
+ """
+ Overview:
+ Initialize players (active & historical) in the league.
+ """
+ # Add different types of active players for each player category, according to ``cfg.active_players``.
+ for cate in self.cfg.player_category: # Player's category (Depends on the env)
+ for k, n in self.cfg.active_players.items(): # Active player's type
+ for i in range(n): # This type's active player number
+ name = '{}_{}_{}'.format(k, cate, i)
+ ckpt_path = osp.join(self.path_policy, '{}_ckpt.pth'.format(name))
+ player = create_player(
+ self.cfg, k, self.cfg[k], cate, self.payoff, ckpt_path, name, 0, self.metric_env.create_rating()
+ )
+ if self.cfg.use_pretrain:
+ self.save_checkpoint(self.cfg.pretrain_checkpoint_path[cate], ckpt_path)
+ self.active_players.append(player)
+ self.payoff.add_player(player)
+
+ # Add pretrain player as the initial HistoricalPlayer for each player category.
+ if self.cfg.use_pretrain_init_historical:
+ for cate in self.cfg.player_category:
+ main_player_name = [k for k in self.cfg.keys() if 'main_player' in k]
+ assert len(main_player_name) == 1, main_player_name
+ main_player_name = main_player_name[0]
+ name = '{}_{}_0_pretrain_historical'.format(main_player_name, cate)
+ parent_name = '{}_{}_0'.format(main_player_name, cate)
+ hp = HistoricalPlayer(
+ self.cfg.get(main_player_name),
+ cate,
+ self.payoff,
+ self.cfg.pretrain_checkpoint_path[cate],
+ name,
+ 0,
+ self.metric_env.create_rating(),
+ parent_id=parent_name
+ )
+ self.historical_players.append(hp)
+ self.payoff.add_player(hp)
+
+ # Save active players' ``player_id`` & ``player_ckpt``.
+ self.active_players_ids = [p.player_id for p in self.active_players]
+ self.active_players_ckpts = [p.checkpoint_path for p in self.active_players]
+ # Validate active players are unique by ``player_id``.
+ assert len(self.active_players_ids) == len(set(self.active_players_ids))
+
+ def get_job_info(self, player_id: str = None, eval_flag: bool = False) -> dict:
+ """
+ Overview:
+ Get info dict of the job which is to be launched to an active player.
+ Arguments:
+ - player_id (:obj:`str`): The active player's id.
+ - eval_flag (:obj:`bool`): Whether this is an evaluation job.
+ Returns:
+ - job_info (:obj:`dict`): Job info.
+ ReturnsKeys:
+ - necessary: ``launch_player`` (the active player)
+ """
+ if player_id is None:
+ player_id = self.active_players_ids[0]
+ with self._active_players_lock:
+ idx = self.active_players_ids.index(player_id)
+ player = self.active_players[idx]
+ job_info = self._get_job_info(player, eval_flag)
+ assert 'launch_player' in job_info.keys() and job_info['launch_player'] == player.player_id
+ return job_info
+
+ @abstractmethod
+ def _get_job_info(self, player: ActivePlayer, eval_flag: bool = False) -> dict:
+ """
+ Overview:
+ Real `get_job` method. Called by ``_launch_job``.
+ Arguments:
+ - player (:obj:`ActivePlayer`): The active player to be launched a job.
+ - eval_flag (:obj:`bool`): Whether this is an evaluation job.
+ Returns:
+ - job_info (:obj:`dict`): Job info. Should include keys ['launch_player'].
+ """
+ raise NotImplementedError
+
+ def judge_snapshot(self, player_id: str, force: bool = False) -> bool:
+ """
+ Overview:
+ Judge whether a player is trained enough for a snapshot. If yes, call the player's ``snapshot``, create a
+ historical player (prepare the checkpoint and add it to the shared payoff), then mutate it, and return True.
+ Otherwise, return False.
+ Arguments:
+ - player_id (:obj:`str`): The active player's id.
+ Returns:
+ - snapshot_or_not (:obj:`bool`): Whether the active player is snapshotted.
+ """
+ with self._active_players_lock:
+ idx = self.active_players_ids.index(player_id)
+ player = self.active_players[idx]
+ if force or player.is_trained_enough():
+ # Snapshot
+ hp = player.snapshot(self.metric_env)
+ self.save_checkpoint(player.checkpoint_path, hp.checkpoint_path)
+ self.historical_players.append(hp)
+ self.payoff.add_player(hp)
+ # Mutate
+ self._mutate_player(player)
+ return True
+ else:
+ return False
+
+ @abstractmethod
+ def _mutate_player(self, player: ActivePlayer) -> None:
+ """
+ Overview:
+ Players have the probability to mutate, e.g. Reset network parameters.
+ Called by ``self.judge_snapshot``.
+ Arguments:
+ - player (:obj:`ActivePlayer`): The active player that may mutate.
+ """
+ raise NotImplementedError
+
+ def update_active_player(self, player_info: dict) -> None:
+ """
+ Overview:
+ Update an active player's info.
+ Arguments:
+ - player_info (:obj:`dict`): Info dict of the player which is to be updated.
+ ArgumentsKeys:
+ - necessary: `player_id`, `train_iteration`
+ """
+ try:
+ idx = self.active_players_ids.index(player_info['player_id'])
+ player = self.active_players[idx]
+ return self._update_player(player, player_info)
+ except ValueError as e:
+ print(e)
+
+ @abstractmethod
+ def _update_player(self, player: ActivePlayer, player_info: dict) -> None:
+ """
+ Overview:
+ Update an active player. Called by ``self.update_active_player``.
+ Arguments:
+ - player (:obj:`ActivePlayer`): The active player that will be updated.
+ - player_info (:obj:`dict`): Info dict of the active player which is to be updated.
+ """
+ raise NotImplementedError
+
+ def finish_job(self, job_info: dict) -> None:
+ """
+ Overview:
+ Finish current job. Update shared payoff to record the game results.
+ Arguments:
+ - job_info (:obj:`dict`): A dict containing job result information.
+ """
+ # TODO(nyz) more fine-grained job info
+ self.payoff.update(job_info)
+ if 'eval_flag' in job_info and job_info['eval_flag']:
+ home_id, away_id = job_info['player_id']
+ home_player, away_player = self.get_player_by_id(home_id), self.get_player_by_id(away_id)
+ job_info_result = job_info['result']
+ if isinstance(job_info_result[0], list):
+ job_info_result = sum(job_info_result, [])
+ home_player.rating, away_player.rating = self.metric_env.rate_1vs1(
+ home_player.rating, away_player.rating, result=job_info_result
+ )
+
+ def get_player_by_id(self, player_id: str) -> 'Player': # noqa
+ if 'historical' in player_id:
+ return [p for p in self.historical_players if p.player_id == player_id][0]
+ else:
+ return [p for p in self.active_players if p.player_id == player_id][0]
+
+ @staticmethod
+ def save_checkpoint(src_checkpoint, dst_checkpoint) -> None:
+ '''
+ Overview:
+ Copy a checkpoint from path ``src_checkpoint`` to path ``dst_checkpoint``.
+ Arguments:
+ - src_checkpoint (:obj:`str`): Source checkpoint's path, e.g. s3://alphastar_fake_data/ckpt.pth
+ - dst_checkpoint (:obj:`str`): Destination checkpoint's path, e.g. s3://alphastar_fake_data/ckpt.pth
+ '''
+ checkpoint = read_file(src_checkpoint)
+ save_file(dst_checkpoint, checkpoint)
+
+ def player_rank(self, string: bool = False) -> Union[str, Dict[str, float]]:
+ rank = {}
+ for p in self.active_players + self.historical_players:
+ name = p.player_id
+ rank[name] = p.rating.exposure
+ if string:
+ headers = ["Player ID", "Rank (TrueSkill)"]
+ data = []
+ for k, v in rank.items():
+ data.append([k, "{:.2f}".format(v)])
+ s = "\n" + tabulate(data, headers=headers, tablefmt='pipe')
+ return s
+ else:
+ return rank
+
+
+def create_league(cfg: EasyDict, *args) -> BaseLeague:
+ """
+ Overview:
+ Given the key (league_type), create a new league instance if the key is in league_mapping's values,
+ or raise a KeyError. In other words, a derived league must first register itself and then call ``create_league``
+ to get the instance object.
+ Arguments:
+ - cfg (:obj:`EasyDict`): league config, necessary keys: [league_type, import_names]
+ Returns:
+ - league (:obj:`BaseLeague`): the created new league, should be an instance of one of \
+ league_mapping's values
+ """
+ import_module(cfg.get('import_names', []))
+ return LEAGUE_REGISTRY.build(cfg.league_type, cfg=cfg, *args)
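+
+
+if __name__ == '__main__':
+    # Minimal end-to-end sketch of the league API (get_job_info -> finish_job -> judge_snapshot).
+    # Assumptions: ``./league_demo`` is a scratch directory that may be created in the current
+    # working directory, and the default ``OneVsOneLeague`` config is sufficient. This only
+    # illustrates the API surface; real training is driven by collector/learner workers.
+    from ding.league import OneVsOneLeague
+
+    os.makedirs('./league_demo', exist_ok=True)
+    demo_cfg = OneVsOneLeague.default_config()
+    demo_cfg.path_policy = './league_demo'
+    league = create_league(demo_cfg)
+
+    job = league.get_job_info(league.active_players_ids[0])
+    print(job['launch_player'], job['player_id'])
+    # Report a finished (here: self-play) job back to the league, then check whether a snapshot is due.
+    league.finish_job({'launch_player': job['launch_player'], 'player_id': job['player_id'], 'result': [['wins']]})
+    print(league.judge_snapshot(job['launch_player']))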
diff --git a/DI-engine/ding/league/metric.py b/DI-engine/ding/league/metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..be675898f4d8479dcb4d3b7db7c813719dca978f
--- /dev/null
+++ b/DI-engine/ding/league/metric.py
@@ -0,0 +1,218 @@
+from typing import Tuple, Union, List
+import math
+import numpy as np
+from trueskill import TrueSkill, Rating, rate_1vs1
+
+
+class EloCalculator(object):
+ """
+ Overview:
+ A class that calculates Elo ratings for players based on game results.
+
+ Attributes:
+ - score (:obj:`dict`): A dictionary that maps game results to scores.
+
+ Interfaces:
+ ``__init__``, ``get_new_rating``, ``get_new_rating_array``.
+ """
+
+ score = {
+ 1: 1.0, # win
+ 0: 0.5, # draw
+ -1: 0.0, # lose
+ }
+
+ @classmethod
+ def get_new_rating(cls,
+ rating_a: int,
+ rating_b: int,
+ result: int,
+ k_factor: int = 32,
+ beta: int = 200) -> Tuple[int, int]:
+ """
+ Overview:
+ Calculates the new ratings for two players based on their current ratings and game result.
+
+ Arguments:
+ - rating_a (:obj:`int`): The current rating of player A.
+ - rating_b (:obj:`int`): The current rating of player B.
+ - result (:obj:`int`): The result of the game: 1 for player A win, 0 for draw, -1 for player B win.
+ - k_factor (:obj:`int`): The K-factor used in the Elo rating system. Defaults to 32.
+ - beta (:obj:`int`): The beta value used in the Elo rating system. Defaults to 200.
+
+ Returns:
+ - ret (:obj:`Tuple[int, int]`): The new ratings for player A and player B, respectively.
+ """
+ assert result in [1, 0, -1]
+ expect_a = 1. / (1. + math.pow(10, (rating_b - rating_a) / (2. * beta)))
+ expect_b = 1. / (1. + math.pow(10, (rating_a - rating_b) / (2. * beta)))
+ new_rating_a = rating_a + k_factor * (EloCalculator.score[result] - expect_a)
+ new_rating_b = rating_b + k_factor * (1 - EloCalculator.score[result] - expect_b)
+ return round(new_rating_a), round(new_rating_b)
+
+ @classmethod
+ def get_new_rating_array(
+ cls,
+ rating: np.ndarray,
+ result: np.ndarray,
+ game_count: np.ndarray,
+ k_factor: int = 32,
+ beta: int = 200
+ ) -> np.ndarray:
+ """
+ Overview:
+ Calculates the new ratings for multiple players based on their current ratings, game results, \
+ and game counts.
+
+ Arguments:
+ - rating (:obj:`np.ndarray`): An array of current ratings for each player.
+ - result (:obj:`np.ndarray`): An array of game results, where 1 represents a win, 0 represents a draw, \
+ and -1 represents a loss.
+ - game_count (:obj:`np.ndarray`): An array of game counts for each player.
+ - k_factor (:obj:`int`): The K-factor used in the Elo rating system. Defaults to 32.
+ - beta (:obj:`int`): The beta value used in the Elo rating system. Defaults to 200.
+
+ Returns:
+ - ret (:obj:`np.ndarray`): An array of new ratings for each player.
+
+ Shapes:
+ - rating (:obj:`np.ndarray`): :math:`(N, )`, where N is the number of players
+ - result (:obj:`np.ndarray`): :math:`(N, N)`
+ - game_count (:obj:`np.ndarray`): :math:`(N, N)`
+ """
+ rating_diff = np.expand_dims(rating, 0) - np.expand_dims(rating, 1)
+ expect = 1. / (1. + np.power(10, rating_diff / (2. * beta))) * game_count
+ delta = ((result + 1.) / 2 - expect) * (game_count > 0)
+ delta = delta.sum(axis=1)
+ return np.round(rating + k_factor * delta).astype(np.int64)
+
+
+class PlayerRating(Rating):
+ """
+ Overview:
+ Represents the rating of a player.
+
+ Interfaces:
+ ``__init__``, ``__repr__``.
+ """
+
+ def __init__(self, mu: float = None, sigma: float = None, elo_init: int = None) -> None:
+ super(PlayerRating, self).__init__(mu, sigma)
+ self.elo = elo_init
+
+ def __repr__(self) -> str:
+ c = type(self)
+ args = ('.'.join([c.__module__, c.__name__]), self.mu, self.sigma, self.exposure, self.elo)
+ return '%s(mu=%.3f, sigma=%.3f, exposure=%.3f, elo=%d)' % args
+
+
+class LeagueMetricEnv(TrueSkill):
+ """
+ Overview:
+ A class that represents a TrueSkill rating system for game players. Inherits from the TrueSkill class. \
+ For more details, please refer to https://trueskill.org/.
+
+ Interfaces:
+ ``__init__``, ``create_rating``, ``rate_1vs1``, ``rate_1vsC``.
+ """
+
+ def __init__(self, *args, elo_init: int = 1200, **kwargs) -> None:
+ super(LeagueMetricEnv, self).__init__(*args, **kwargs)
+ self.elo_init = elo_init
+
+ def create_rating(self, mu: float = None, sigma: float = None, elo_init: int = None) -> PlayerRating:
+ """
+ Overview:
+ Creates a new player rating object with the specified mean, standard deviation, and Elo rating.
+
+ Arguments:
+ - mu (:obj:`float`): The mean value of the player's skill rating. If not provided, the default \
+ TrueSkill mean is used.
+ - sigma (:obj:`float`): The standard deviation of the player's skill rating. If not provided, \
+ the default TrueSkill sigma is used.
+ - elo_init (:obj:`int`): The initial Elo rating value for the player. If not provided, the default \
+ elo_init value of the LeagueMetricEnv class is used.
+
+ Returns:
+ - PlayerRating: A player rating object with the specified mean, standard deviation, and Elo rating.
+ """
+ if mu is None:
+ mu = self.mu
+ if sigma is None:
+ sigma = self.sigma
+ if elo_init is None:
+ elo_init = self.elo_init
+ return PlayerRating(mu, sigma, elo_init)
+
+ @staticmethod
+ def _rate_1vs1(t1, t2, **kwargs):
+ t1_elo, t2_elo = t1.elo, t2.elo
+ t1, t2 = rate_1vs1(t1, t2, **kwargs)
+ if 'drawn' in kwargs:
+ result = 0
+ else:
+ result = 1
+ t1_elo, t2_elo = EloCalculator.get_new_rating(t1_elo, t2_elo, result)
+ t1 = PlayerRating(t1.mu, t1.sigma, t1_elo)
+ t2 = PlayerRating(t2.mu, t2.sigma, t2_elo)
+ return t1, t2
+
+ def rate_1vs1(self, team1: PlayerRating, team2: PlayerRating, result: List[str] = None, **kwargs) \
+ -> Tuple[PlayerRating, PlayerRating]:
+ """
+ Overview:
+ Rates two teams of players against each other in a 1 vs 1 match and returns the updated ratings \
+ for both teams.
+
+ Arguments:
+ - team1 (:obj:`PlayerRating`): The rating object representing the first team of players.
+ - team2 (:obj:`PlayerRating`): The rating object representing the second team of players.
+ - result (:obj:`List[str]`): The result of the match. Can be 'wins', 'draws', or 'losses'. If \
+ not provided, the default behavior is to rate the match as a win for team1.
+
+ Returns:
+ - ret (:obj:`Tuple[PlayerRating, PlayerRating]`): A tuple containing the updated ratings for team1 \
+ and team2.
+ """
+ if result is None:
+ return self._rate_1vs1(team1, team2, **kwargs)
+ else:
+ for r in result:
+ if r == 'wins':
+ team1, team2 = self._rate_1vs1(team1, team2)
+ elif r == 'draws':
+ team1, team2 = self._rate_1vs1(team1, team2, drawn=True)
+ elif r == 'losses':
+ team2, team1 = self._rate_1vs1(team2, team1)
+ else:
+ raise RuntimeError("invalid result: {}".format(r))
+ return team1, team2
+
+ def rate_1vsC(self, team1: PlayerRating, team2: PlayerRating, result: List[str]) -> PlayerRating:
+ """
+ Overview:
+ Rates a team of players against a single player in a 1 vs C match and returns the updated rating \
+ for the team.
+
+ Arguments:
+ - team1 (:obj:`PlayerRating`): The rating object representing the team of players.
+ - team2 (:obj:`PlayerRating`): The rating object representing the single player.
+ - result (:obj:`List[str]`): The result of the match. Can be 'wins', 'draws', or 'losses'.
+
+ Returns:
+ - PlayerRating: The updated rating for the team of players.
+ """
+ for r in result:
+ if r == 'wins':
+ team1, _ = self._rate_1vs1(team1, team2)
+ elif r == 'draws':
+ team1, _ = self._rate_1vs1(team1, team2, drawn=True)
+ elif r == 'losses':
+ _, team1 = self._rate_1vs1(team2, team1)
+ else:
+ raise RuntimeError("invalid result: {}".format(r))
+ return team1
+
+
+get_elo = EloCalculator.get_new_rating
+get_elo_array = EloCalculator.get_new_rating_array
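+
+
+if __name__ == '__main__':
+    # Illustrative sketch of the metrics above; the numbers mirror the unit test in
+    # ``ding/league/tests/test_league_metric.py``. With K=32 and beta=200, a 1613-rated
+    # player losing to a 1573-rated one drops to 1595, while the winner rises to 1591.
+    print(get_elo(1613, 1573, -1))  # (1595, 1591)
+
+    env = LeagueMetricEnv(mu=0, sigma=25 / 3, beta=25 / 6, tau=0.0, draw_probability=0.02, elo_init=1000)
+    r1, r2 = env.create_rating(elo_init=1613), env.create_rating(elo_init=1573)
+    r1, r2 = env.rate_1vs1(r1, r2, drawn=True)  # a draw keeps the mus equal and nudges Elo towards parity
+    print(r1, r2)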
diff --git a/DI-engine/ding/league/one_vs_one_league.py b/DI-engine/ding/league/one_vs_one_league.py
new file mode 100644
index 0000000000000000000000000000000000000000..2555aa1d9130f5d871cd3014b40f3324c0291509
--- /dev/null
+++ b/DI-engine/ding/league/one_vs_one_league.py
@@ -0,0 +1,127 @@
+from easydict import EasyDict
+from typing import Optional
+
+from ding.utils import LEAGUE_REGISTRY
+from .base_league import BaseLeague
+from .player import ActivePlayer
+
+
+@LEAGUE_REGISTRY.register('one_vs_one')
+class OneVsOneLeague(BaseLeague):
+ """
+ Overview:
+ One vs One battle game league.
+ Decide which two players will play against each other.
+ Interface:
+ __init__, run, close, finish_job, update_active_player
+ """
+ config = dict(
+ league_type='one_vs_one',
+ import_names=["ding.league"],
+ # ---player----
+ # "player_category" is just a name. Depends on the env.
+ # For example, in StarCraft, this can be ['zerg', 'terran', 'protoss'].
+ player_category=['default'],
+ # Support different types of active players for solo and battle league.
+ # For solo league, supports ['solo_active_player'].
+ # For battle league, supports ['battle_active_player', 'main_player', 'main_exploiter', 'league_exploiter'].
+ active_players=dict(
+ naive_sp_player=1, # {player_type: player_num}
+ ),
+ naive_sp_player=dict(
+ # There should be keys ['one_phase_step', 'branch_probs', 'strong_win_rate'].
+ # Specifically for 'main_exploiter' of StarCraft, there should be an additional key ['min_valid_win_rate'].
+ one_phase_step=10,
+ branch_probs=dict(
+ pfsp=0.5,
+ sp=0.5,
+ ),
+ strong_win_rate=0.7,
+ ),
+ # "use_pretrain" means whether to use pretrain model to initialize active player.
+ use_pretrain=False,
+ # "use_pretrain_init_historical" means whether to use pretrain model to initialize historical player.
+ # "pretrain_checkpoint_path" is the pretrain checkpoint path used in "use_pretrain" and
+ # "use_pretrain_init_historical". If both are False, "pretrain_checkpoint_path" can be omitted as well.
+ # Otherwise, "pretrain_checkpoint_path" should list paths of all player categories.
+ use_pretrain_init_historical=False,
+ pretrain_checkpoint_path=dict(default='default_cate_pretrain.pth', ),
+ # ---payoff---
+ payoff=dict(
+ # Supports ['battle']
+ type='battle',
+ decay=0.99,
+ min_win_rate_games=8,
+ ),
+ metric=dict(
+ mu=0,
+ sigma=25 / 3,
+ beta=25 / 3 / 2,
+ tau=0.0,
+ draw_probability=0.02,
+ ),
+ )
+
+ # override
+ def _get_job_info(self, player: ActivePlayer, eval_flag: bool = False) -> dict:
+ """
+ Overview:
+ Get player's job related info, called by ``_launch_job``.
+ Arguments:
+ - player (:obj:`ActivePlayer`): The active player that will be assigned a job.
+ """
+ assert isinstance(player, ActivePlayer), player.__class__
+ player_job_info = EasyDict(player.get_job(eval_flag))
+ if eval_flag:
+ return {
+ 'agent_num': 1,
+ 'launch_player': player.player_id,
+ 'player_id': [player.player_id],
+ 'checkpoint_path': [player.checkpoint_path],
+ 'player_active_flag': [isinstance(player, ActivePlayer)],
+ 'eval_opponent': player_job_info.opponent,
+ }
+ else:
+ return {
+ 'agent_num': 2,
+ 'launch_player': player.player_id,
+ 'player_id': [player.player_id, player_job_info.opponent.player_id],
+ 'checkpoint_path': [player.checkpoint_path, player_job_info.opponent.checkpoint_path],
+ 'player_active_flag': [isinstance(p, ActivePlayer) for p in [player, player_job_info.opponent]],
+ }
+
+ # override
+ def _mutate_player(self, player: ActivePlayer):
+ """
+ Overview:
+ Players have the probability to be reset to supervised learning model parameters.
+ Arguments:
+ - player (:obj:`ActivePlayer`): The active player that may mutate.
+ """
+ pass
+
+ # override
+ def _update_player(self, player: ActivePlayer, player_info: dict) -> Optional[bool]:
+ """
+ Overview:
+ Update an active player, called by ``self.update_active_player``.
+ Arguments:
+ - player (:obj:`ActivePlayer`): The active player that will be updated.
+ - player_info (:obj:`dict`): An info dict of the active player which is to be updated.
+ Returns:
+ - increment_eval_difficulty (:obj:`bool`): Only return this when evaluator calls this method. \
+ Return True if difficulty is incremented; Otherwise return False (difficulty will not increment \
+ when it is already the most difficult or evaluator loses)
+ """
+ assert isinstance(player, ActivePlayer)
+ if 'train_iteration' in player_info:
+ # Update info from learner
+ player.total_agent_step = player_info['train_iteration']
+ return False
+ elif 'eval_win' in player_info:
+ if player_info['eval_win']:
+ # Update info from evaluator
+ increment_eval_difficulty = player.increment_eval_difficulty()
+ return increment_eval_difficulty
+ else:
+ return False
diff --git a/DI-engine/ding/league/player.py b/DI-engine/ding/league/player.py
new file mode 100644
index 0000000000000000000000000000000000000000..e253c0bdadfe460818ebbe241ba494c354e204bb
--- /dev/null
+++ b/DI-engine/ding/league/player.py
@@ -0,0 +1,343 @@
+from typing import Callable, Optional, List
+from collections import namedtuple
+import numpy as np
+from easydict import EasyDict
+
+from ding.utils import import_module, PLAYER_REGISTRY
+from .algorithm import pfsp
+
+
+class Player:
+ """
+ Overview:
+ Base player class, player is the basic member of a league
+ Interfaces:
+ __init__
+ Property:
+ race, payoff, checkpoint_path, player_id, total_agent_step
+ """
+ _name = "BasePlayer" # override this variable for sub-class player
+
+ def __init__(
+ self,
+ cfg: EasyDict,
+ category: str,
+ init_payoff: 'BattleSharedPayoff', # noqa
+ checkpoint_path: str,
+ player_id: str,
+ total_agent_step: int,
+ rating: 'PlayerRating', # noqa
+ ) -> None:
+ """
+ Overview:
+ Initialize base player metadata
+ Arguments:
+ - cfg (:obj:`EasyDict`): Player config dict.
+ - category (:obj:`str`): Player category, depending on the game, \
+ e.g. StarCraft has 3 races ['terran', 'protoss', 'zerg'].
+ - init_payoff (:obj:`Union[BattleSharedPayoff, SoloSharedPayoff]`): Payoff shared by all players.
+ - checkpoint_path (:obj:`str`): The path to load player checkpoint.
+ - player_id (:obj:`str`): Player id in string format.
+ - total_agent_step (:obj:`int`): For active player, it should be 0; \
+ For historical player, it should be parent player's ``_total_agent_step`` when ``snapshot``.
+ - rating (:obj:`PlayerRating`): player rating information in total league
+ """
+ self._cfg = cfg
+ self._category = category
+ self._payoff = init_payoff
+ self._checkpoint_path = checkpoint_path
+ assert isinstance(player_id, str)
+ self._player_id = player_id
+ assert isinstance(total_agent_step, int), (total_agent_step, type(total_agent_step))
+ self._total_agent_step = total_agent_step
+ self._rating = rating
+
+ @property
+ def category(self) -> str:
+ return self._category
+
+ @property
+ def payoff(self) -> 'BattleSharedPayoff': # noqa
+ return self._payoff
+
+ @property
+ def checkpoint_path(self) -> str:
+ return self._checkpoint_path
+
+ @property
+ def player_id(self) -> str:
+ return self._player_id
+
+ @property
+ def total_agent_step(self) -> int:
+ return self._total_agent_step
+
+ @total_agent_step.setter
+ def total_agent_step(self, step: int) -> None:
+ self._total_agent_step = step
+
+ @property
+ def rating(self) -> 'PlayerRating': # noqa
+ return self._rating
+
+ @rating.setter
+ def rating(self, _rating: 'PlayerRating') -> None: # noqa
+ self._rating = _rating
+
+
+@PLAYER_REGISTRY.register('historical_player')
+class HistoricalPlayer(Player):
+ """
+ Overview:
+ Historical player which is snapshotted from an active player, and is fixed with the checkpoint.
+ Have a unique attribute ``parent_id``.
+ Property:
+ race, payoff, checkpoint_path, player_id, total_agent_step, parent_id
+ """
+ _name = "HistoricalPlayer"
+
+ def __init__(self, *args, parent_id: str) -> None:
+ """
+ Overview:
+ Initialize ``_parent_id`` additionally
+ Arguments:
+ - parent_id (:obj:`str`): id of historical player's parent, should be an active player
+ """
+ super().__init__(*args)
+ self._parent_id = parent_id
+
+ @property
+ def parent_id(self) -> str:
+ return self._parent_id
+
+
+class ActivePlayer(Player):
+ """
+ Overview:
+ Active player can be updated, or snapshotted to a historical player in the league training.
+ Interface:
+ __init__, is_trained_enough, snapshot, mutate, get_job
+ Property:
+ race, payoff, checkpoint_path, player_id, total_agent_step
+ """
+ _name = "ActivePlayer"
+ BRANCH = namedtuple("BRANCH", ['name', 'prob'])
+
+ def __init__(self, *args, **kwargs) -> None:
+ """
+ Overview:
+ Initialize player metadata, depending on the game
+ Note:
+ - one_phase_step (:obj:`int`): An active player will be considered trained enough for snapshot \
+ after two phase steps.
+ - last_enough_step (:obj:`int`): Player's last step number that satisfies ``_is_trained_enough``.
+ - strong_win_rate (:obj:`float`): If the win rates between this player and all of its opponents are greater than
+ this value, this player can be regarded as strong enough against these opponents. \
+ If it has also been trained for at least one phase step, it can be regarded as trained enough for a snapshot.
+ - branch_probs (:obj:`namedtuple`): A namedtuple of probabilities of selecting different opponent branch.
+ """
+ super().__init__(*args)
+ self._one_phase_step = int(float(self._cfg.one_phase_step)) # ``one_phase_step`` is like 1e9
+ self._last_enough_step = 0
+ self._strong_win_rate = self._cfg.strong_win_rate
+ assert isinstance(self._cfg.branch_probs, dict)
+ self._branch_probs = [self.BRANCH(k, v) for k, v in self._cfg.branch_probs.items()]
+ # self._eval_opponent_difficulty = ["WEAK", "MEDIUM", "STRONG"]
+ self._eval_opponent_difficulty = ["RULE_BASED"]
+ self._eval_opponent_index = 0
+
+ def is_trained_enough(self, select_fn: Optional[Callable] = None) -> bool:
+ """
+ Overview:
+ Judge whether this player is trained enough for further operations(e.g. snapshot, mutate...)
+ according to past step count and overall win rates against opponents.
+ If yes, set ``self._last_enough_step`` to ``self._total_agent_step`` and return True; otherwise return False.
+ Arguments:
+ - select_fn (:obj:`function`): The function to select opponent players.
+ Returns:
+ - flag (:obj:`bool`): Whether this player is trained enough
+ """
+ if select_fn is None:
+ select_fn = lambda x: isinstance(x, HistoricalPlayer) # noqa
+ step_passed = self._total_agent_step - self._last_enough_step
+ if step_passed < self._one_phase_step:
+ return False
+ elif step_passed >= 2 * self._one_phase_step:
+ # ``step_passed`` is at least twice ``self._one_phase_step``, regarded as trained enough
+ self._last_enough_step = self._total_agent_step
+ return True
+ else:
+ # Get payoff against specific opponents (Different players have different type of opponent players)
+ # If min win rate is larger than ``self._strong_win_rate``, then is judged trained enough
+ selected_players = self._get_players(select_fn)
+ if len(selected_players) == 0: # No such player, therefore no past game
+ return False
+ win_rates = self._payoff[self, selected_players]
+ if win_rates.min() > self._strong_win_rate:
+ self._last_enough_step = self._total_agent_step
+ return True
+ else:
+ return False
+
+ def snapshot(self, metric_env: 'LeagueMetricEnv') -> HistoricalPlayer: # noqa
+ """
+ Overview:
+ Generate a snapshot historical player from the current player, called in league's ``_snapshot``.
+ Argument:
+ - metric_env (:obj:`LeagueMetricEnv`): player rating environment, one league one env
+ Returns:
+ - snapshot_player (:obj:`HistoricalPlayer`): new instantiated historical player
+
+ .. note::
+ This method only generates a historical player object, but without saving the checkpoint, which should be
+ done by league.
+ """
+ path = self.checkpoint_path.split('.pth')[0] + '_{}'.format(self._total_agent_step) + '.pth'
+ return HistoricalPlayer(
+ self._cfg,
+ self.category,
+ self.payoff,
+ path,
+ self.player_id + '_{}_historical'.format(int(self._total_agent_step)),
+ self._total_agent_step,
+ metric_env.create_rating(mu=self.rating.mu),
+ parent_id=self.player_id
+ )
+
+ def mutate(self, info: dict) -> Optional[str]:
+ """
+ Overview:
+ Mutate the current player, called in league's ``_mutate_player``.
+ Arguments:
+ - info (:obj:`dict`): related information for the mutation
+ Returns:
+ - mutation_result (:obj:`str`): if the player does the mutation operation then returns the
+ corresponding model path, otherwise returns None
+ """
+ pass
+
+ def get_job(self, eval_flag: bool = False) -> dict:
+ """
+ Overview:
+ Get a dict containing some info about the job to be launched, e.g. the selected opponent.
+ Arguments:
+ - eval_flag (:obj:`bool`): Whether to select an opponent for evaluator task.
+ Returns:
+ - ret (:obj:`dict`): The returned dict. Should contain key ['opponent'].
+ """
+ if eval_flag:
+ # eval opponent is a str.
+ opponent = self._eval_opponent_difficulty[self._eval_opponent_index]
+ else:
+ # collect opponent is a Player.
+ opponent = self._get_collect_opponent()
+ return {
+ 'opponent': opponent,
+ }
+
+ def _get_collect_opponent(self) -> Player:
+ """
+ Overview:
+ Select an opponent according to the player's ``branch_probs``.
+ Returns:
+ - opponent (:obj:`Player`): Selected opponent.
+ """
+ p = np.random.uniform()
+ L = len(self._branch_probs)
+ cum_p = [0.] + [sum([j.prob for j in self._branch_probs[:i + 1]]) for i in range(L)]
+ idx = [cum_p[i] <= p < cum_p[i + 1] for i in range(L)].index(True)
+ branch_name = '_{}_branch'.format(self._branch_probs[idx].name)
+ opponent = getattr(self, branch_name)()
+ return opponent
+
+ def _get_players(self, select_fn: Callable) -> List[Player]:
+ """
+ Overview:
+ Get a list of players in the league (shared_payoff), selected by ``select_fn`` .
+ Arguments:
+ - select_fn (:obj:`function`): players in the returned list must satisfy this function
+ Returns:
+ - players (:obj:`list`): a list of players that satisfies ``select_fn``
+ """
+ return [player for player in self._payoff.players if select_fn(player)]
+
+ def _get_opponent(self, players: list, p: Optional[np.ndarray] = None) -> Player:
+ """
+ Overview:
+ Get one opponent player from list ``players`` according to probability ``p``.
+ Arguments:
+ - players (:obj:`list`): a list of players that can select opponent from
+ - p (:obj:`np.ndarray`): the selection probability of each player, should have the same size as \
+ ``players``. If you don't need it and set None, it would select uniformly by default.
+ Returns:
+ - opponent_player (:obj:`Player`): a random chosen opponent player according to probability
+ """
+ idx = np.random.choice(len(players), p=p)
+ return players[idx]
+
+ def increment_eval_difficulty(self) -> bool:
+ """
+ Overview:
+ When evaluating, active player will choose a specific builtin opponent difficulty.
+ This method is used to increment the difficulty.
+ It is usually called after the easier builtin bot has already been beaten by this player.
+ Returns:
+ - increment_or_not (:obj:`bool`): True means difficulty is incremented; \
+ False means difficulty is already the hardest.
+ """
+ if self._eval_opponent_index < len(self._eval_opponent_difficulty) - 1:
+ self._eval_opponent_index += 1
+ return True
+ else:
+ return False
+
+ @property
+ def checkpoint_path(self) -> str:
+ return self._checkpoint_path
+
+ @checkpoint_path.setter
+ def checkpoint_path(self, path: str) -> None:
+ self._checkpoint_path = path
+
+
+@PLAYER_REGISTRY.register('naive_sp_player')
+class NaiveSpPlayer(ActivePlayer):
+
+ def _pfsp_branch(self) -> HistoricalPlayer:
+ """
+ Overview:
+ Select prioritized fictitious self-play opponent, should be a historical player.
+ Returns:
+ - player (:obj:`HistoricalPlayer`): The selected historical player.
+ """
+ historical = self._get_players(lambda p: isinstance(p, HistoricalPlayer))
+ win_rates = self._payoff[self, historical]
+ # Normal self-play if no historical players
+ if win_rates.shape == (0, ):
+ return self
+ p = pfsp(win_rates, weighting='squared')
+ return self._get_opponent(historical, p)
+
+ def _sp_branch(self) -> ActivePlayer:
+ """
+ Overview:
+ Select normal self-play opponent
+ """
+ return self
+
+
+def create_player(cfg: EasyDict, player_type: str, *args, **kwargs) -> Player:
+ """
+ Overview:
+ Given the key (player_type), create a new player instance if the key is in player_mapping's values,
+ or raise a KeyError. In other words, a derived player must first register itself and then call ``create_player``
+ to get the instance object.
+ Arguments:
+ - cfg (:obj:`EasyDict`): player config, necessary keys: [import_names]
+ - player_type (:obj:`str`): the type of player to be created
+ Returns:
+ - player (:obj:`Player`): the created new player, should be an instance of one of \
+ player_mapping's values
+ """
+ import_module(cfg.get('import_names', []))
+ return PLAYER_REGISTRY.build(player_type, *args, **kwargs)
diff --git a/DI-engine/ding/league/shared_payoff.py b/DI-engine/ding/league/shared_payoff.py
new file mode 100644
index 0000000000000000000000000000000000000000..7576d441c0323812eef882bf46a9d0861498287a
--- /dev/null
+++ b/DI-engine/ding/league/shared_payoff.py
@@ -0,0 +1,261 @@
+import copy
+from collections import defaultdict
+from typing import Tuple, Optional
+from easydict import EasyDict
+from tabulate import tabulate
+import numpy as np
+
+from ding.utils import LockContext, LockContextType
+from .player import Player
+
+
+class BattleRecordDict(dict):
+ """
+ Overview:
+ A dict which is used to record battle game result.
+ Initializes four fixed keys: `wins`, `draws`, `losses`, `games`, each with value 0.
+ Interfaces:
+ __mul__
+ """
+ data_keys = ['wins', 'draws', 'losses', 'games']
+
+ def __init__(self) -> None:
+ """
+ Overview:
+ Initialize four fixed keys ['wins', 'draws', 'losses', 'games'] and set value to 0
+ """
+ super(BattleRecordDict, self).__init__()
+ for k in self.data_keys:
+ self[k] = 0
+
+ def __mul__(self, decay: float) -> dict:
+ """
+ Overview:
+ Multiply each key's value with the input multiplier ``decay``
+ Arguments:
+ - decay (:obj:`float`): The multiplier.
+ Returns:
+ - obj (:obj:`dict`): A deep-copied BattleRecordDict whose values have been multiplied by ``decay``.
+ """
+ obj = copy.deepcopy(self)
+ for k in obj.keys():
+ obj[k] *= decay
+ return obj
+
+
+class BattleSharedPayoff:
+ """
+ Overview:
+ Payoff data structure to record historical match result, this payoff is shared among all the players.
+ Use LockContext to ensure thread safety, since all players from all threads can access and modify it.
+ Interface:
+ __getitem__, add_player, update, get_key
+ Property:
+ players
+ """
+
+ # TODO(nyz) whether ensures the thread-safe
+
+ def __init__(self, cfg: EasyDict):
+ """
+ Overview:
+ Initialize battle payoff
+ Arguments:
+ - cfg (:obj:`dict`): config(contains {decay, min_win_rate_games})
+ """
+ # ``_players`` is a list containing references (shallow copies) of all players,
+ # while ``_players_ids`` is a list of their id strings.
+ self._players = []
+ self._players_ids = []
+ # ``_data`` is a defaultdict. If a key doesn't exist when queried, an instance of BattleRecordDict is returned.
+ # Key is a '[player_id]-[player_id]' string, value is the payoff record of the two players.
+ self._data = defaultdict(BattleRecordDict)
+ # ``_decay`` controls how past game info (win, draw, loss) decays.
+ self._decay = cfg.decay
+ # ``_min_win_rate_games`` is used in the ``self._win_rate`` method for calculating win rate between two players.
+ self._min_win_rate_games = cfg.get('min_win_rate_games', 8)
+ # Thread lock.
+ self._lock = LockContext(type_=LockContextType.THREAD_LOCK)
+
+ def __repr__(self) -> str:
+ headers = ["Home Player", "Away Player", "Wins", "Draws", "Losses", "Naive Win Rate"]
+ data = []
+ for k, v in self._data.items():
+ k1 = k.split('-')
+ # k is the format of '{}-{}'.format(name1, name2), and each HistoricalPlayer has `historical` suffix
+ if 'historical' in k1[0]:
+ # reverse representation
+ naive_win_rate = (v['losses'] + v['draws'] / 2) / (v['wins'] + v['losses'] + v['draws'] + 1e-8)
+ data.append([k1[1], k1[0], v['losses'], v['draws'], v['wins'], naive_win_rate])
+ else:
+ naive_win_rate = (v['wins'] + v['draws'] / 2) / (v['wins'] + v['losses'] + v['draws'] + 1e-8)
+ data.append([k1[0], k1[1], v['wins'], v['draws'], v['losses'], naive_win_rate])
+ data = sorted(data, key=lambda x: x[0])
+ s = tabulate(data, headers=headers, tablefmt='pipe')
+ return s
+
+ def __getitem__(self, players: tuple) -> np.ndarray:
+ """
+ Overview:
+ Get win rates between home players and away players one by one
+ Arguments:
+ - players (:obj:`tuple`): A tuple of (home, away), each one is a player or a player list.
+ Returns:
+ - win_rates (:obj:`np.ndarray`): Win rate (squeezed, see Shape for more details) \
+ between each player from home and each player from away.
+ Shape:
+ - win_rates: Assume there are m home players and n away players.(m,n > 0)
+
+ - m != 1 and n != 1: shape is (m, n)
+ - m == 1: shape is (n)
+ - n == 1: shape is (m)
+ """
+ with self._lock:
+ home, away = players
+ assert isinstance(home, list) or isinstance(home, Player)
+ assert isinstance(away, list) or isinstance(away, Player)
+ if isinstance(home, Player):
+ home = [home]
+ if isinstance(away, Player):
+ away = [away]
+ win_rates = np.array([[self._win_rate(h.player_id, a.player_id) for a in away] for h in home])
+ if len(home) == 1 or len(away) == 1:
+ win_rates = win_rates.reshape(-1)
+ return win_rates
+
+ def _win_rate(self, home: str, away: str) -> float:
+ """
+ Overview:
+ Calculate win rate of one `home player` vs one `away player`
+ Arguments:
+ - home (:obj:`str`): home player id to access win rate
+ - away (:obj:`str`): away player id to access win rate
+ Returns:
+ - win rate (:obj:`float`): float win rate value. \
+ Only when total games is no less than ``self._min_win_rate_games``, \
+ can the win rate be calculated by (wins + draws/2) / games, or return 0.5 by default.
+ """
+ key, reverse = self.get_key(home, away)
+ handle = self._data[key]
+ # Not enough game records.
+ if handle['games'] < self._min_win_rate_games:
+ return 0.5
+ # should use reverse here
+ wins = handle['wins'] if not reverse else handle['losses']
+ return (wins + 0.5 * handle['draws']) / (handle['games'])
+
+ @property
+ def players(self):
+ """
+ Overview:
+ Get all the players
+ Returns:
+ - players (:obj:`list`): players list
+ """
+ with self._lock:
+ return self._players
+
+ def add_player(self, player: Player) -> None:
+ """
+ Overview:
+ Add a player to the shared payoff.
+ Arguments:
+ - player (:obj:`Player`): The player to be added. Usually is a new one to the league as well.
+ """
+ with self._lock:
+ self._players.append(player)
+ self._players_ids.append(player.player_id)
+
+ def update(self, job_info: dict) -> bool:
+ """
+ Overview:
+ Update payoff with job_info when a job is to be finished.
+ If update succeeds, return True; If raises an exception when updating, resolve it and return False.
+ Arguments:
+ - job_info (:obj:`dict`): A dict containing job result information.
+ Returns:
+ - result (:obj:`bool`): Whether update is successful.
+
+ .. note::
+ job_info has at least 5 keys ['launch_player', 'player_id', 'env_num', 'episode_num', 'result'].
+ Key ``player_id`` 's value is a tuple of (home_id, away_id).
+ Key ``result`` 's value is a two-layer list with the length of (episode_num, env_num).
+ """
+
+ def _win_loss_reverse(result_: str, reverse_: bool) -> str:
+ if result_ == 'draws' or not reverse_:
+ return result_
+ reverse_dict = {'wins': 'losses', 'losses': 'wins'}
+ return reverse_dict[result_]
+
+ with self._lock:
+ home_id, away_id = job_info['player_id']
+ job_info_result = job_info['result']
+ # for compatibility of one-layer list
+ if not isinstance(job_info_result[0], list):
+ job_info_result = [job_info_result]
+ try:
+ assert home_id in self._players_ids, "home_id error"
+ assert away_id in self._players_ids, "away_id error"
+ # Assert all results are in ['wins', 'losses', 'draws']
+ assert all([i in BattleRecordDict.data_keys[:3] for j in job_info_result for i in j]), "results error"
+ except Exception as e:
+ print("[ERROR] invalid job_info: {}\n\tError reason is: {}".format(job_info, e))
+ return False
+ if home_id == away_id: # self-play
+ key, reverse = self.get_key(home_id, away_id)
+ self._data[key]['draws'] += 1 # self-play defaults to draws
+ self._data[key]['games'] += 1
+ else:
+ key, reverse = self.get_key(home_id, away_id)
+ # Update with decay
+ # job_info_result is a two-layer list, including total NxM episodes of M envs,
+ # the first(outer) layer is episode dimension and the second(inner) layer is env dimension.
+ for one_episode_result in job_info_result:
+ for one_episode_result_per_env in one_episode_result:
+ # All categories should decay
+ self._data[key] *= self._decay
+ self._data[key]['games'] += 1
+ result = _win_loss_reverse(one_episode_result_per_env, reverse)
+ self._data[key][result] += 1
+ return True
+
+ def get_key(self, home: str, away: str) -> Tuple[str, bool]:
+ """
+ Overview:
+ Join home player id and away player id in alphabetical order.
+ Arguments:
+ - home (:obj:`str`): Home player id
+ - away (:obj:`str`): Away player id
+ Returns:
+ - key (:obj:`str`): Two ids sorted in alphabetical order, and joined by '-'.
+ - reverse (:obj:`bool`): Whether the two player ids are reordered.
+ """
+ assert isinstance(home, str)
+ assert isinstance(away, str)
+ reverse = False
+ if home <= away:
+ tmp = [home, away]
+ else:
+ tmp = [away, home]
+ reverse = True
+ return '-'.join(tmp), reverse
+
+
+def create_payoff(cfg: EasyDict) -> Optional[BattleSharedPayoff]:
+ """
+ Overview:
+ Given the key (payoff type), now supports keys ['solo', 'battle'],
+ create a new payoff instance if in payoff_mapping's values, or raise an KeyError.
+ Arguments:
+ - cfg (:obj:`EasyDict`): payoff config containing at least one key 'type'
+ Returns:
+ - payoff (:obj:`BattleSharedPayoff` or :obj:`SoloSharedPayoff`): the created new payoff, \
+ should be an instance of one of payoff_mapping's values
+ """
+ payoff_mapping = {'battle': BattleSharedPayoff}
+ payoff_type = cfg.type
+ if payoff_type not in payoff_mapping.keys():
+ raise KeyError("not support payoff type: {}".format(payoff_type))
+ else:
+ return payoff_mapping[payoff_type](cfg)
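+
+
+if __name__ == '__main__':
+    # Small sketch of the bookkeeping primitives above (the player ids are made up):
+    # keys are stored in alphabetical order, and ``reverse`` tells the caller whether the
+    # stored 'wins'/'losses' must be swapped to read the record from its own point of view.
+    payoff = create_payoff(EasyDict(dict(type='battle', decay=0.99, min_win_rate_games=8)))
+    key, reverse = payoff.get_key('main_player_zerg_0', 'main_exploiter_zerg_0')
+    print(key, reverse)  # main_exploiter_zerg_0-main_player_zerg_0 True
+
+    # Every recorded game first decays all counters, so old results gradually lose weight.
+    record = BattleRecordDict()
+    record['wins'] += 1
+    print((record * 0.99)['wins'])  # 0.99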
diff --git a/DI-engine/ding/league/starcraft_player.py b/DI-engine/ding/league/starcraft_player.py
new file mode 100644
index 0000000000000000000000000000000000000000..81d53e73bd2e0c54b353e1d5f5421745a2de92b5
--- /dev/null
+++ b/DI-engine/ding/league/starcraft_player.py
@@ -0,0 +1,234 @@
+from typing import Optional, Union
+import numpy as np
+
+from ding.utils import PLAYER_REGISTRY
+from .player import ActivePlayer, HistoricalPlayer
+from .algorithm import pfsp
+
+
+@PLAYER_REGISTRY.register('main_player')
+class MainPlayer(ActivePlayer):
+ """
+ Overview:
+ Main player in league training.
+ Default branch (0.5 pfsp, 0.35 sp, 0.15 veri).
+ Default snapshot every 2e9 steps.
+ Default mutate prob = 0 (never mutate).
+ Interface:
+ __init__, is_trained_enough, snapshot, mutate, get_job
+ Property:
+ race, payoff, checkpoint_path, player_id, train_iteration
+ """
+ _name = "MainPlayer"
+
+ def _pfsp_branch(self) -> HistoricalPlayer:
+ """
+ Overview:
+ Select prioritized fictitious self-play opponent, should be a historical player.
+ Returns:
+ - player (:obj:`HistoricalPlayer`): the selected historical player
+ """
+ historical = self._get_players(lambda p: isinstance(p, HistoricalPlayer))
+ win_rates = self._payoff[self, historical]
+ p = pfsp(win_rates, weighting='squared')
+ return self._get_opponent(historical, p)
+
+ def _sp_branch(self):
+ """
+ Overview:
+ Select normal self-play opponent
+ """
+ main_players = self._get_players(lambda p: isinstance(p, MainPlayer))
+ main_opponent = self._get_opponent(main_players)
+
+ # TODO(nyz) if only one main_player, self-play win_rates are constantly equal to 0.5
+ # main_opponent is not too strong
+ if self._payoff[self, main_opponent] > 1 - self._strong_win_rate:
+ return main_opponent
+
+ # if the main_opponent is too strong, select a past alternative
+ historical = self._get_players(
+ lambda p: isinstance(p, HistoricalPlayer) and p.parent_id == main_opponent.player_id
+ )
+ win_rates = self._payoff[self, historical]
+ p = pfsp(win_rates, weighting='variance')
+ return self._get_opponent(historical, p)
+
+ def _verification_branch(self):
+ """
+ Overview:
+ Verify no strong historical main exploiter and no forgotten historical past main player
+ """
+ # check exploitation
+ main_exploiters = self._get_players(lambda p: isinstance(p, MainExploiter))
+ exp_historical = self._get_players(
+ lambda p: isinstance(p, HistoricalPlayer) and any([p.parent_id == m.player_id for m in main_exploiters])
+ )
+ win_rates = self._payoff[self, exp_historical]
+ # TODO(nyz) why min win_rates 0.3
+ if len(win_rates) and win_rates.min() < 1 - self._strong_win_rate:
+ p = pfsp(win_rates, weighting='squared')
+ return self._get_opponent(exp_historical, p)
+
+ # check forgotten
+ main_players = self._get_players(lambda p: isinstance(p, MainPlayer))
+ main_opponent = self._get_opponent(main_players) # only one main player
+ main_historical = self._get_players(
+ lambda p: isinstance(p, HistoricalPlayer) and p.parent_id == main_opponent.player_id
+ )
+ win_rates = self._payoff[self, main_historical]
+ # TODO(nyz) whether the method `_get_players` should return players with some sequence(such as step)
+ # win_rates, historical = self._remove_monotonic_suffix(win_rates, historical)
+ if len(win_rates) and win_rates.min() < self._strong_win_rate:
+ p = pfsp(win_rates, weighting='squared')
+ return self._get_opponent(main_historical, p)
+
+ # no forgotten main players or strong main exploiters, use self-play instead
+ return self._sp_branch()
+
+ # def _remove_monotonic_suffix(self, win_rates, players):
+ # if not len(win_rates):
+ # return win_rates, players
+ # for i in range(len(win_rates) - 1, 0, -1):
+ # if win_rates[i - 1] < win_rates[i]:
+ # return win_rates[:i + 1], players[:i + 1]
+ # return np.array([]), []
+
+ # override
+ def is_trained_enough(self) -> bool:
+ # ``_pfsp_branch`` and ``_verification_branch`` are played against historical players
+ return super().is_trained_enough(select_fn=lambda p: isinstance(p, HistoricalPlayer))
+
+ # override
+ def mutate(self, info: dict) -> None:
+ """
+ Overview:
+ MainPlayer does not mutate
+ """
+ pass
+
+
+@PLAYER_REGISTRY.register('main_exploiter')
+class MainExploiter(ActivePlayer):
+ """
+ Overview:
+ Main exploiter in league training. Can identify weaknesses of main agents, and consequently make them
+ more robust.
+ Default branch (1.0 main_players).
+ Default snapshot when defeating all 3 main players in the league in more than 70% of games,
+ or timeout of 4e9 steps.
+ Default mutate prob = 1 (must mutate).
+ Interface:
+ __init__, is_trained_enough, snapshot, mutate, get_job
+ Property:
+ race, payoff, checkpoint_path, player_id, train_iteration
+ """
+ _name = "MainExploiter"
+
+ def __init__(self, *args, **kwargs):
+ """
+ Overview:
+ Initialize ``min_valid_win_rate`` additionally
+ Note:
+ - min_valid_win_rate (:obj:`float`): only when this exploiter's win rate against the main player is no less \
+ than this value is the main player regarded as able to produce valid training signals and thus selected
+ """
+ super(MainExploiter, self).__init__(*args, **kwargs)
+ self._min_valid_win_rate = self._cfg.min_valid_win_rate
+
+ def _main_players_branch(self):
+ """
+ Overview:
+ Select main player or historical player snapshot from main player as opponent
+ Returns:
+ - player (:obj:`Player`): the selected main player (active/historical)
+ """
+ # get the main player (only one)
+ main_players = self._get_players(lambda p: isinstance(p, MainPlayer))
+ main_opponent = self._get_opponent(main_players)
+ # if this main_opponent can produce valid training signals
+ if self._payoff[self, main_opponent] >= self._min_valid_win_rate:
+ return main_opponent
+ # otherwise, curriculum learning, select a historical version
+ historical = self._get_players(
+ lambda p: isinstance(p, HistoricalPlayer) and p.parent_id == main_opponent.player_id
+ )
+ win_rates = self._payoff[self, historical]
+ p = pfsp(win_rates, weighting='variance')
+ return self._get_opponent(historical, p)
+
+ # override
+ def is_trained_enough(self):
+ # would play against main player, or historical main player (if main player is too strong)
+ return super().is_trained_enough(select_fn=lambda p: isinstance(p, MainPlayer))
+
+ # override
+ def mutate(self, info: dict) -> str:
+ """
+ Overview:
+ Main exploiter is sure to mutate(reset) to the supervised learning player
+ Returns:
+ - mutate_ckpt_path (:obj:`str`): mutation target checkpoint path
+ """
+ return info['reset_checkpoint_path']
+
+
+@PLAYER_REGISTRY.register('league_exploiter')
+class LeagueExploiter(ActivePlayer):
+ """
+ Overview:
+ League exploiter in league training. Can identify global blind spots in the league (strategies that no player
+ in the league can beat, but that are not necessarily robust themselves).
+ Default branch (1.0 pfsp).
+ Default snapshot when defeating all players in the league in more than 70% of games, or timeout of 2e9 steps.
+ Default mutate prob = 0.25.
+ Interface:
+ __init__, is_trained_enough, snapshot, mutate, get_job
+ Property:
+ race, payoff, checkpoint_path, player_id, train_iteration
+ """
+ _name = "LeagueExploiter"
+
+ def __init__(self, *args, **kwargs) -> None:
+ """
+ Overview:
+ Initialize ``mutate_prob`` additionally
+ Note:
+ - mutate_prob (:obj:`float`): the mutation probability of league exploiter. should be in [0, 1]
+ """
+ super(LeagueExploiter, self).__init__(*args, **kwargs)
+ assert 0 <= self._cfg.mutate_prob <= 1
+ self.mutate_prob = self._cfg.mutate_prob
+
+ def _pfsp_branch(self) -> HistoricalPlayer:
+ """
+ Overview:
+ Select prioritized fictitious self-play opponent
+ Returns:
+ - player (:obj:`HistoricalPlayer`): the selected historical player
+ Note:
+ This branch is the same as the pfsp branch in MainPlayer
+ """
+ historical = self._get_players(lambda p: isinstance(p, HistoricalPlayer))
+ win_rates = self._payoff[self, historical]
+ p = pfsp(win_rates, weighting='squared')
+ return self._get_opponent(historical, p)
+
+ # override
+ def is_trained_enough(self) -> bool:
+ # will only play against historical players
+ return super().is_trained_enough(select_fn=lambda p: isinstance(p, HistoricalPlayer))
+
+ # override
+ def mutate(self, info) -> Union[str, None]:
+ """
+ Overview:
+ League exploiter can mutate to the supervised learning player with 0.25 prob
+ Returns:
+ - ckpt_path (:obj:`Union[str, None]`): with probability ``mutate_prob`` returns the pretrained model's ckpt path, \
+ with the remaining probability 1 - ``mutate_prob`` returns None, which means no mutation
+ """
+ p = np.random.uniform()
+ if p < self.mutate_prob:
+ return info['reset_checkpoint_path']
+ return None
diff --git a/DI-engine/ding/league/tests/conftest.py b/DI-engine/ding/league/tests/conftest.py
new file mode 100644
index 0000000000000000000000000000000000000000..d00a62f0fc8032e37d8de9e4a29b47af818622fe
--- /dev/null
+++ b/DI-engine/ding/league/tests/conftest.py
@@ -0,0 +1,22 @@
+import numpy as np
+import pytest
+
+
+@pytest.fixture(scope='session')
+def random_job_result():
+
+ def fn():
+ p = np.random.uniform()
+ if p < 1. / 3:
+ return "wins"
+ elif p < 2. / 3:
+ return "draws"
+ else:
+ return "losses"
+
+ return fn
+
+
+@pytest.fixture(scope='session')
+def get_job_result_categories():
+ return ["wins", 'draws', 'losses']
diff --git a/DI-engine/ding/league/tests/league_test_default_config.py b/DI-engine/ding/league/tests/league_test_default_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e698ddb5c8b868ac4b23c47a4e5a4dc6743d1073
--- /dev/null
+++ b/DI-engine/ding/league/tests/league_test_default_config.py
@@ -0,0 +1,63 @@
+from easydict import EasyDict
+
+league_test_config = dict(
+ league=dict(
+ # league_type='fake',
+ import_names=['ding.league'],
+ # ---player----
+ player_category=['zerg', 'terran', 'protoss'],
+ active_players=dict(
+ main_player=1,
+ main_exploiter=1,
+ league_exploiter=2,
+ ),
+ main_player=dict(
+ branch_probs=dict(
+ pfsp=0.5,
+ sp=0.35,
+ verification=0.15,
+ ),
+ strong_win_rate=0.7,
+ one_phase_step=2000,
+ ),
+ main_exploiter=dict(
+ branch_probs=dict(main_players=1.0, ),
+ strong_win_rate=0.7,
+ one_phase_step=2000,
+ min_valid_win_rate=0.2,
+ ),
+ league_exploiter=dict(
+ branch_probs=dict(pfsp=1.0, ),
+ strong_win_rate=0.7,
+ one_phase_step=2000,
+ mutate_prob=0.25,
+ ),
+ # solo_active_player:
+ # one_phase_step=2000
+ # forward_kwargs:
+ # exploration_type=[]
+ # env_kwargs:
+ # env_num=8
+ # episode_num=2
+ # adder_kwargs:
+ # use_gae=False
+ # data_push_length=128
+ # job:
+ # agent_update_freq=30 # second
+ # compressor='none'
+ use_pretrain=True,
+ use_pretrain_init_historical=True,
+ pretrain_checkpoint_path=dict(
+ zerg='pretrain_checkpoint_zerg.pth',
+ terran='pretrain_checkpoint_terran.pth',
+ protoss='pretrain_checkpoint_protoss.pth',
+ ),
+ # ---payoff---
+ payoff=dict(
+ type='battle',
+ decay=0.99,
+ min_win_rate_games=8,
+ ),
+ ),
+)
+league_test_config = EasyDict(league_test_config)
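+# Illustrative attribute-style access enabled by EasyDict (not part of the config itself), e.g.:
+#   league_test_config.league.active_players.league_exploiter  -> 2
+#   league_test_config.league.main_player.branch_probs.sp      -> 0.35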
diff --git a/DI-engine/ding/league/tests/test_league_metric.py b/DI-engine/ding/league/tests/test_league_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfbb8bc547012d830c37541d620d2dd51a589458
--- /dev/null
+++ b/DI-engine/ding/league/tests/test_league_metric.py
@@ -0,0 +1,57 @@
+import numpy as np
+import pytest
+
+from ding.league import get_elo, get_elo_array, LeagueMetricEnv
+
+
+@pytest.mark.unittest
+def test_elo_calculator():
+ game_count = np.array([[0, 1, 2], [1, 0, 0], [2, 0, 0]])
+ rating = np.array([1613, 1573, 1601])
+ result = np.array([[0, -1, -1 + 1], [1, 0, 0], [1 + (-1), 0, 0]])
+ new_rating0, new_rating1 = get_elo(rating[0], rating[1], result[0][1])
+ assert new_rating0 == 1595
+ assert new_rating1 == 1591
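+    # sanity check of the numbers above against the standard Elo update (the K factor is assumed to be 32):
+    #   E_home = 1 / (1 + 10 ** ((1573 - 1613) / 400)) ≈ 0.557
+    #   new_home = round(1613 + 32 * (0 - 0.557)) = 1595, new_away = round(1573 + 32 * (1 - 0.443)) = 1591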
+
+ old_rating = np.copy(rating)
+ new_rating = get_elo_array(rating, result, game_count)
+ assert (rating == old_rating).all() # no inplace modification
+ assert new_rating.dtype == np.int64
+ assert new_rating[0] == 1578
+ assert new_rating[1] == 1591
+ assert new_rating[2] == 1586
+
+
+@pytest.mark.unittest
+def test_league_metric():
+ sigma = 25 / 3
+ env = LeagueMetricEnv(mu=0, sigma=sigma, beta=sigma / 2, tau=0.0, draw_probability=0.02, elo_init=1000)
+ r1 = env.create_rating(elo_init=1613)
+ r2 = env.create_rating(elo_init=1573)
+ assert r1.mu == 0
+ assert r2.mu == 0
+    assert r1.sigma == sigma
+ assert r2.sigma == sigma
+ assert r1.elo == 1613
+ assert r2.elo == 1573
+ # r1 draw r2
+ r1, r2 = env.rate_1vs1(r1, r2, drawn=True)
+ assert r1.mu == r2.mu
+ assert r1.elo == 1611
+ assert r2.elo == 1575
+ # r1 win r2
+ new_r1, new_r2 = env.rate_1vs1(r1, r2)
+ assert new_r1.mu > r1.mu
+ assert new_r2.mu < r2.mu
+ assert new_r1.mu + new_r2.mu == 0
+ assert pytest.approx(new_r1.mu, abs=1e-4) == 3.230
+ assert pytest.approx(new_r2.mu, abs=1e-4) == -3.230
+ assert new_r1.elo == 1625
+ assert new_r2.elo == 1561
+ # multi result
+ new_r1, new_r2 = env.rate_1vs1(r1, r2, result=['wins', 'wins', 'losses'])
+ assert new_r1.elo > 1611
+ # 1vsConstant
+ new_r1 = env.rate_1vsC(r1, env.create_rating(elo_init=1800), result=['losses', 'losses'])
+ assert new_r1.elo < 1611
+ print('final rating is: ', new_r1)
diff --git a/DI-engine/ding/league/tests/test_one_vs_one_league.py b/DI-engine/ding/league/tests/test_one_vs_one_league.py
new file mode 100644
index 0000000000000000000000000000000000000000..43a993e0b2ff46d9dcbe05eb000c7269ef18e04b
--- /dev/null
+++ b/DI-engine/ding/league/tests/test_one_vs_one_league.py
@@ -0,0 +1,183 @@
+import os
+import random
+
+import pytest
+import copy
+from easydict import EasyDict
+import torch
+
+from ding.league import create_league
+
+one_vs_one_league_default_config = dict(
+ league=dict(
+ league_type='one_vs_one',
+ import_names=["ding.league"],
+ # ---player----
+ # "player_category" is just a name. Depends on the env.
+ # For example, in StarCraft, this can be ['zerg', 'terran', 'protoss'].
+ player_category=['default'],
+ # Support different types of active players for solo and battle league.
+ # For solo league, supports ['solo_active_player'].
+ # For battle league, supports ['battle_active_player', 'main_player', 'main_exploiter', 'league_exploiter'].
+ active_players=dict(
+ naive_sp_player=1, # {player_type: player_num}
+ ),
+ naive_sp_player=dict(
+ # There should be keys ['one_phase_step', 'branch_probs', 'strong_win_rate'].
+ # Specifically for 'main_exploiter' of StarCraft, there should be an additional key ['min_valid_win_rate'].
+ one_phase_step=10,
+ branch_probs=dict(
+ pfsp=0.5,
+ sp=0.5,
+ ),
+ strong_win_rate=0.7,
+ ),
+ # "use_pretrain" means whether to use pretrain model to initialize active player.
+ use_pretrain=False,
+ # "use_pretrain_init_historical" means whether to use pretrain model to initialize historical player.
+ # "pretrain_checkpoint_path" is the pretrain checkpoint path used in "use_pretrain" and
+ # "use_pretrain_init_historical". If both are False, "pretrain_checkpoint_path" can be omitted as well.
+ # Otherwise, "pretrain_checkpoint_path" should list paths of all player categories.
+ use_pretrain_init_historical=False,
+ pretrain_checkpoint_path=dict(default='default_cate_pretrain.pth', ),
+ # ---payoff---
+ payoff=dict(
+ # Supports ['battle']
+ type='battle',
+ decay=0.99,
+ min_win_rate_games=8,
+ ),
+ path_policy='./league',
+ ),
+)
+one_vs_one_league_default_config = EasyDict(one_vs_one_league_default_config)
+
+
+def get_random_result():
+ ran = random.random()
+ if ran < 1. / 3:
+ return "wins"
+ elif ran < 1. / 2:
+ return "losses"
+ else:
+ return "draws"
+
+
+@pytest.mark.unittest
+class TestOneVsOneLeague:
+
+ def test_naive(self):
+ league = create_league(one_vs_one_league_default_config.league)
+ assert (len(league.active_players) == 1)
+ assert (len(league.historical_players) == 0)
+ active_player_ids = [p.player_id for p in league.active_players]
+ assert set(active_player_ids) == set(league.active_players_ids)
+ active_player_id = active_player_ids[0]
+
+ active_player_ckpt = league.active_players[0].checkpoint_path
+ tmp = torch.tensor([1, 2, 3])
+ path_policy = one_vs_one_league_default_config.league.path_policy
+ torch.save(tmp, active_player_ckpt)
+
+ # judge_snapshot & update_active_player
+ assert not league.judge_snapshot(active_player_id)
+ player_update_dict = {
+ 'player_id': active_player_id,
+ 'train_iteration': one_vs_one_league_default_config.league.naive_sp_player.one_phase_step * 2,
+ }
+ league.update_active_player(player_update_dict)
+ assert league.judge_snapshot(active_player_id)
+ historical_player_ids = [p.player_id for p in league.historical_players]
+ assert len(historical_player_ids) == 1
+ historical_player_id = historical_player_ids[0]
+
+ # get_job_info, eval_flag=False
+ vs_active = False
+ vs_historical = False
+ while True:
+ collect_job_info = league.get_job_info(active_player_id, eval_flag=False)
+ assert collect_job_info['agent_num'] == 2
+ assert len(collect_job_info['checkpoint_path']) == 2
+ assert collect_job_info['launch_player'] == active_player_id
+ assert collect_job_info['player_id'][0] == active_player_id
+ if collect_job_info['player_active_flag'][1]:
+ assert collect_job_info['player_id'][1] == collect_job_info['player_id'][0]
+ vs_active = True
+ else:
+ assert collect_job_info['player_id'][1] == historical_player_id
+ vs_historical = True
+ if vs_active and vs_historical:
+ break
+
+        # get_job_info, eval_flag=True
+ eval_job_info = league.get_job_info(active_player_id, eval_flag=True)
+ assert eval_job_info['agent_num'] == 1
+ assert len(eval_job_info['checkpoint_path']) == 1
+ assert eval_job_info['launch_player'] == active_player_id
+ assert eval_job_info['player_id'][0] == active_player_id
+ assert len(eval_job_info['player_id']) == 1
+ assert len(eval_job_info['player_active_flag']) == 1
+ assert eval_job_info['eval_opponent'] in league.active_players[0]._eval_opponent_difficulty
+
+ # finish_job
+
+ episode_num = 5
+ env_num = 8
+ player_id = [active_player_id, historical_player_id]
+        result = [[get_random_result() for __ in range(env_num)] for _ in range(episode_num)]
+ payoff_update_info = {
+ 'launch_player': active_player_id,
+ 'player_id': player_id,
+ 'episode_num': episode_num,
+ 'env_num': env_num,
+ 'result': result,
+ }
+ league.finish_job(payoff_update_info)
+ wins = 0
+ games = episode_num * env_num
+ for i in result:
+ for j in i:
+ if j == 'wins':
+ wins += 1
+        # the payoff is decay-weighted, so it only approximately equals this batch's empirical win rate
+        win_rate = league.payoff[league.active_players[0], league.historical_players[0]]
+        assert win_rate[0] == pytest.approx(wins / games, abs=0.1)
+
+ os.popen("rm -rf {}".format(path_policy))
+ print("Finish!")
+
+ def test_league_info(self):
+ cfg = copy.deepcopy(one_vs_one_league_default_config.league)
+ cfg.path_policy = 'test_league_info'
+ league = create_league(cfg)
+ active_player_id = [p.player_id for p in league.active_players][0]
+ active_player_ckpt = [p.checkpoint_path for p in league.active_players][0]
+ tmp = torch.tensor([1, 2, 3])
+ torch.save(tmp, active_player_ckpt)
+ assert (len(league.active_players) == 1)
+ assert (len(league.historical_players) == 0)
+ print('\n')
+ print(repr(league.payoff))
+ print(league.player_rank(string=True))
+ league.judge_snapshot(active_player_id, force=True)
+ for i in range(10):
+ job = league.get_job_info(active_player_id, eval_flag=False)
+ payoff_update_info = {
+ 'launch_player': active_player_id,
+ 'player_id': job['player_id'],
+ 'episode_num': 2,
+ 'env_num': 4,
+ 'result': [[get_random_result() for __ in range(4)] for _ in range(2)]
+ }
+ league.finish_job(payoff_update_info)
+ # if not self-play
+ if job['player_id'][0] != job['player_id'][1]:
+ win_loss_result = sum(payoff_update_info['result'], [])
+ home = league.get_player_by_id(job['player_id'][0])
+ away = league.get_player_by_id(job['player_id'][1])
+ home.rating, away.rating = league.metric_env.rate_1vs1(home.rating, away.rating, win_loss_result)
+ print(repr(league.payoff))
+ print(league.player_rank(string=True))
+ os.popen("rm -rf {}".format(cfg.path_policy))
+
+
+if __name__ == '__main__':
+ pytest.main(["-sv", os.path.basename(__file__)])
diff --git a/DI-engine/ding/league/tests/test_payoff.py b/DI-engine/ding/league/tests/test_payoff.py
new file mode 100644
index 0000000000000000000000000000000000000000..fef32778b24185dc002a46dd618f872defe47640
--- /dev/null
+++ b/DI-engine/ding/league/tests/test_payoff.py
@@ -0,0 +1,166 @@
+import os
+from collections import defaultdict
+from copy import deepcopy
+
+import numpy as np
+import pytest
+from easydict import EasyDict
+
+from ding.league.player import Player
+from ding.league.shared_payoff import BattleRecordDict, create_payoff
+from ding.league.metric import LeagueMetricEnv
+
+env = LeagueMetricEnv()
+
+
+@pytest.mark.unittest
+class TestBattleRecordDict:
+
+ def test_init(self):
+ data1 = defaultdict(BattleRecordDict)
+ data1['test_player_0-test_player_1'] *= 1
+ assert data1['test_player_0-test_player_1']['wins'] == 0
+ assert data1['test_player_0-test_player_1']['draws'] == 0
+ assert data1['test_player_0-test_player_1']['losses'] == 0
+ assert data1['test_player_0-test_player_1']['games'] == 0
+ with pytest.raises(KeyError):
+ tmp = data1['test_player_0-test_player_1']['xxx']
+
+
+@pytest.fixture(scope='function')
+def setup_battle_shared_payoff():
+ cfg = EasyDict({'type': 'battle', 'decay': 0.99})
+ return create_payoff(cfg)
+
+
+global sp_player_count
+sp_player_count = 0
+
+
+def get_shared_payoff_player(payoff):
+ global sp_player_count
+ player = Player(
+ cfg=EasyDict(),
+ category='zerg',
+ init_payoff=payoff,
+ checkpoint_path='sp_ckpt_{}.pth'.format(sp_player_count),
+ player_id='sp_player_{}'.format(sp_player_count),
+ total_agent_step=0,
+ rating=env.create_rating(),
+ )
+ sp_player_count += 1
+ return player
+
+
+def _win_loss_reverse(result_: str, reverse_: bool) -> str:
+ if result_ == 'draws' or not reverse_:
+ return result_
+ reverse_dict = {'wins': 'losses', 'losses': 'wins'}
+ return reverse_dict[result_]
+
+
+@pytest.mark.unittest
+class TestBattleSharedPayoff:
+
+ def test_update(self, setup_battle_shared_payoff, random_job_result, get_job_result_categories):
+ N = 10
+ games_per_player = 4
+ player_list = [get_shared_payoff_player(setup_battle_shared_payoff) for _ in range(N)]
+ for p in player_list:
+ setup_battle_shared_payoff.add_player(p)
+
+ # test update exception
+ job_info = {
+ 'player_id': [player_list[0].player_id, player_list[1].player_id],
+ 'episode_num': 1,
+ 'env_num': 1,
+ 'result': [["error"]]
+ }
+ assert not setup_battle_shared_payoff.update(job_info)
+
+ for home in player_list:
+ for away in player_list:
+ if home == away:
+ continue # ignore self-play case
+ for i in range(games_per_player):
+ episode_num = 2
+ env_num = 4
+ job_result = [[random_job_result() for _ in range(env_num)] for _ in range(episode_num)]
+ job_info = {
+ 'player_id': [home.player_id, away.player_id],
+ 'episode_num': episode_num,
+ 'env_num': env_num,
+ 'result': job_result
+ }
+ key, reverse = setup_battle_shared_payoff.get_key(home.player_id, away.player_id)
+ old = deepcopy(setup_battle_shared_payoff._data[key])
+ assert setup_battle_shared_payoff.update(job_info)
+
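+                    # expected update rule: every new game result first decays all counters by ``decay``,
+                    # then increments the counter matching the (possibly reversed) result by 1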
+ decay = setup_battle_shared_payoff._decay
+ for j in job_result:
+ for i in j:
+ for k in get_job_result_categories:
+ old[k] *= decay
+ result = _win_loss_reverse(i, reverse)
+ old[result] += 1
+
+ for t in get_job_result_categories:
+ assert old[t] == setup_battle_shared_payoff._data[key][t], t
+
+ # test shared payoff
+ for p in player_list:
+ assert id(p.payoff) == id(setup_battle_shared_payoff)
+
+ def test_getitem(self, setup_battle_shared_payoff, random_job_result):
+ N = 10
+ games_per_player = 4
+ player_list = [get_shared_payoff_player(setup_battle_shared_payoff) for _ in range(N)]
+ for p in player_list:
+ setup_battle_shared_payoff.add_player(p)
+
+ # test key not in setup_battle_shared_payoff._data
+ home = player_list[0]
+ away = player_list[0]
+ key, reverse = setup_battle_shared_payoff.get_key(home.player_id, away.player_id)
+ assert key not in setup_battle_shared_payoff._data.keys()
+ win_rate = setup_battle_shared_payoff[home, away]
+        assert key in setup_battle_shared_payoff._data.keys()  # lazy access creates the key in ``_data``
+        assert len(win_rate.shape) == 1
+        assert win_rate[0] == pytest.approx(0.5)  # not enough game results, return 0.5 by default
+
+ # test players list
+ for i in range(314):
+ home = np.random.choice(setup_battle_shared_payoff.players)
+ away = np.random.choice(setup_battle_shared_payoff.players)
+ env_num = 1
+ episode_num = 1
+ job_result = [[random_job_result() for _ in range(env_num)] for _ in range(episode_num)]
+ job_info = {
+ 'player_id': [home.player_id, away.player_id],
+ 'episode_num': episode_num,
+ 'env_num': env_num,
+ 'result': job_result
+ }
+ assert setup_battle_shared_payoff.update(job_info)
+ for i in range(314):
+ home_num = np.random.randint(1, N + 1)
+ home = np.random.choice(setup_battle_shared_payoff.players, home_num).tolist()
+ away_num = np.random.randint(1, N + 1)
+ away = np.random.choice(setup_battle_shared_payoff.players, away_num).tolist()
+ win_rates = setup_battle_shared_payoff[home, away]
+ assert isinstance(win_rates, np.ndarray)
+ if home_num == 1 or away_num == 1:
+ assert len(win_rates.shape) == 1
+ else:
+ assert len(win_rates.shape) == 2
+ assert win_rates.shape == (home_num, away_num)
+ assert win_rates.max() <= 1.
+ assert win_rates.min() >= 0.
+
+ # test shared payoff
+ for p in player_list:
+ assert id(p.payoff) == id(setup_battle_shared_payoff)
+
+
+if __name__ == '__main__':
+ pytest.main(["-sv", os.path.basename(__file__)])
diff --git a/DI-engine/ding/league/tests/test_player.py b/DI-engine/ding/league/tests/test_player.py
new file mode 100644
index 0000000000000000000000000000000000000000..6cfb03dd6f5469cd72491067fe27c5cc9a6344b8
--- /dev/null
+++ b/DI-engine/ding/league/tests/test_player.py
@@ -0,0 +1,312 @@
+import os
+
+import numpy as np
+import pytest
+from easydict import EasyDict
+
+from ding.league.player import Player, HistoricalPlayer, ActivePlayer, create_player
+from ding.league.shared_payoff import create_payoff
+from ding.league.starcraft_player import MainPlayer, MainExploiter, LeagueExploiter
+from ding.league.tests.league_test_default_config import league_test_config
+from ding.league.metric import LeagueMetricEnv
+
+ONE_PHASE_STEP = 2000
+env = LeagueMetricEnv()
+
+
+@pytest.fixture(scope='function')
+def setup_payoff():
+ cfg = EasyDict({'type': 'battle', 'decay': 0.99})
+ return create_payoff(cfg)
+
+
+@pytest.fixture(scope='function')
+def setup_league(setup_payoff):
+ players = []
+ for category in ['zerg', 'terran', 'protoss']:
+ # main_player
+ main_player_name = '{}_{}'.format('MainPlayer', category)
+ players.append(
+ create_player(
+ league_test_config.league, 'main_player', league_test_config.league.main_player, category, setup_payoff,
+ 'ckpt_{}.pth'.format(main_player_name), main_player_name, 0, env.create_rating()
+ )
+ )
+        # main_exploiter
+ main_exploiter_name = '{}_{}'.format('MainExploiter', category)
+ players.append(
+ create_player(
+ league_test_config.league, 'main_exploiter', league_test_config.league.main_exploiter, category,
+ setup_payoff, 'ckpt_{}.pth'.format(main_exploiter_name), main_exploiter_name, 0, env.create_rating()
+ )
+ )
+ # league_exploiter
+ league_exploiter_name = '{}_{}'.format('LeagueExploiter', category)
+ for i in range(2):
+ players.append(
+ create_player(
+ league_test_config.league,
+ 'league_exploiter',
+ league_test_config.league.league_exploiter,
+ category,
+ setup_payoff,
+ 'ckpt_{}.pth'.format(league_exploiter_name),
+ league_exploiter_name,
+ 0,
+ env.create_rating(),
+ )
+ )
+ # historical player: sl player is used as initial HistoricalPlayer
+ sl_hp_name = '{}_{}_sl'.format('MainPlayer', category)
+ players.append(
+ create_player(
+ league_test_config.league,
+ 'historical_player',
+ EasyDict(),
+ category,
+ setup_payoff,
+ 'ckpt_sl_{}'.format(sl_hp_name),
+ sl_hp_name,
+ 0,
+ env.create_rating(),
+ parent_id=main_player_name,
+ )
+ )
+ for p in players:
+ setup_payoff.add_player(p)
+ return players
+
+
+@pytest.mark.unittest
+class TestMainPlayer:
+
+ def test_get_job(self, setup_league, setup_payoff):
+ N = 10
+ # no indicated p
+ # test get_job
+ for p in setup_league:
+ if isinstance(p, MainPlayer):
+ for i in range(N):
+ job_dict = p.get_job()
+ assert isinstance(job_dict, dict)
+ opponent = job_dict['opponent']
+ assert isinstance(opponent, Player)
+ assert opponent in setup_league
+
+ # payoff = setup_league[np.random.randint(0, len(setup_league))].payoff # random select reference
+ hp_list = []
+ for p in setup_league:
+ if isinstance(p, ActivePlayer):
+ p.total_agent_step = 2 * ONE_PHASE_STEP
+ hp = p.snapshot(env)
+ hp_list.append(hp)
+ setup_payoff.add_player(hp)
+        setup_league += hp_list  # 15 original players (12 active + 3 historical) + 12 new snapshots
+
+ # test get_job with branch prob
+ pfsp, sp, veri = False, False, False
+ for p in setup_league:
+ if isinstance(p, MainPlayer):
+ while True:
+ job_dict = p.get_job()
+ opponent = job_dict['opponent']
+ if isinstance(opponent, HistoricalPlayer) and 'MainPlayer' in opponent.parent_id:
+ veri = True
+ elif isinstance(opponent, HistoricalPlayer):
+ pfsp = True
+ elif isinstance(opponent, MainPlayer):
+ sp = True
+ else:
+                        raise Exception("Main Player selects a wrong opponent {}".format(type(opponent)))
+ if veri and pfsp and sp:
+ break
+
+ def test_snapshot(self, setup_league, setup_payoff):
+ N = 10
+ for p in setup_league:
+ for i in range(N):
+ if isinstance(p, ActivePlayer):
+ hp = p.snapshot(env)
+ assert isinstance(hp, HistoricalPlayer)
+ assert id(hp.payoff) == id(p.payoff)
+ assert hp.parent_id == p.player_id
+
+ def test_is_trained_enough(self, setup_league, setup_payoff):
+ for p in setup_league:
+ if isinstance(p, ActivePlayer):
+ assert not p.is_trained_enough()
+ assert p._last_enough_step == 0
+ # step_passed < ONE_PHASE_STEP
+ p.total_agent_step = ONE_PHASE_STEP * 0.99
+ assert not p.is_trained_enough()
+ assert p._last_enough_step == 0
+ # ONE_PHASE_STEP < step_passed < 2*ONE_PHASE_STEP, but low win rate
+ p.total_agent_step = ONE_PHASE_STEP + 1
+ assert not p.is_trained_enough()
+ assert p._last_enough_step == 0
+
+ # prepare HistoricalPlayer
+ # payoff = setup_league[np.random.randint(0, len(setup_league))].payoff # random select reference
+ hp_list = []
+ for p in setup_league:
+ if isinstance(p, MainPlayer):
+ hp = p.snapshot(env)
+ setup_payoff.add_player(hp)
+ hp_list.append(hp)
+ setup_league += hp_list
+
+ # update 10 wins against all historical players, should be trained enough
+ N = 10
+ assert isinstance(setup_league[0], MainPlayer)
+ for n in range(N):
+ for hp in [p for p in setup_league if isinstance(p, HistoricalPlayer)]:
+ match_info = {
+ 'player_id': [setup_league[0].player_id, hp.player_id],
+ 'result': [['wins']],
+ }
+ result = setup_payoff.update(match_info)
+ assert result
+ assert setup_league[0]._total_agent_step > ONE_PHASE_STEP
+ assert setup_league[0]._last_enough_step == 0
+ assert setup_league[0]._last_enough_step != setup_league[0]._total_agent_step
+ assert setup_league[0].is_trained_enough()
+ assert setup_league[0]._last_enough_step == setup_league[0]._total_agent_step
+
+ # update 10 draws against all historical players, should be not trained enough;
+ # then update ``total_agent_step`` to 2*ONE_PHASE_STEP, should be trained enough
+ assert isinstance(setup_league[5], MainPlayer)
+ for n in range(N):
+ for hp in hp_list:
+ match_info = {
+ 'player_id': [setup_league[5].player_id, hp.player_id],
+ 'result': [['draws']],
+ }
+ result = setup_payoff.update(match_info)
+ assert result
+ assert setup_league[5]._total_agent_step > ONE_PHASE_STEP
+ assert not setup_league[5].is_trained_enough()
+ setup_league[5].total_agent_step = 2 * ONE_PHASE_STEP
+ assert setup_league[5].is_trained_enough()
+
+ def test_mutate(self, setup_league, setup_payoff):
+ # main players do not mutate
+ assert isinstance(setup_league[0], MainPlayer)
+ for _ in range(10):
+ assert setup_league[0].mutate({}) is None
+
+ def test_sp_historical(self, setup_league, setup_payoff):
+ N = 10
+ main1 = setup_league[0] # 'zerg'
+ main2 = setup_league[5] # 'terran'
+ assert isinstance(main1, MainPlayer)
+ assert isinstance(main2, MainPlayer)
+ for n in range(N):
+ match_info = {
+ 'player_id': [main1.player_id, main2.player_id],
+ 'result': [['wins']],
+ }
+ result = setup_payoff.update(match_info)
+ assert result
+ for _ in range(200):
+ opponent = main2._sp_branch()
+ condition1 = opponent.category == 'terran' or opponent.category == 'protoss'
+            # condition2 means: the zerg main opponent is too strong, so a weaker historical version must be chosen
+ condition2 = opponent.category == 'zerg' and isinstance(
+ opponent, HistoricalPlayer
+ ) and opponent.parent_id == main1.player_id
+ assert condition1 or condition2, (condition1, condition2)
+
+
+@pytest.mark.unittest
+class TestMainExploiter:
+
+ def test_get_job(self, setup_league, random_job_result, setup_payoff):
+ assert isinstance(setup_league[1], MainExploiter)
+ job_dict = setup_league[1].get_job()
+ opponent = job_dict['opponent']
+ assert isinstance(opponent, MainPlayer)
+
+ N = 10
+ # payoff = setup_league[np.random.randint(0, len(setup_league))].payoff # random select reference
+ for n in range(N):
+ for p in setup_league:
+ if isinstance(p, MainPlayer):
+ match_info = {
+ 'player_id': [setup_league[1].player_id, p.player_id],
+ 'result': [['losses']],
+ }
+ assert setup_payoff.update(match_info)
+
+ job_dict = setup_league[1].get_job()
+ opponent = job_dict['opponent']
+        # as long as the opponent is a main player, both active and historical versions are ok
+ assert (isinstance(opponent, HistoricalPlayer)
+ and 'MainPlayer' in opponent.parent_id) or isinstance(opponent, MainPlayer)
+ hp_list = []
+ for i in range(3):
+ for p in setup_league:
+ if isinstance(p, MainPlayer):
+ p.total_agent_step = (i + 1) * 2 * ONE_PHASE_STEP
+ hp = p.snapshot(env)
+ setup_payoff.add_player(hp)
+ hp_list.append(hp)
+ setup_league += hp_list
+
+ no_main_player_league = [p for p in setup_league if not isinstance(p, MainPlayer)]
+ for i in range(10000):
+ home = np.random.choice(no_main_player_league)
+ away = np.random.choice(no_main_player_league)
+ result = random_job_result()
+ match_info = {
+ 'player_id': [home.player_id, away.player_id],
+ 'result': [[result]],
+ }
+ assert setup_payoff.update(match_info)
+
+ for i in range(10):
+ job_dict = setup_league[1].get_job()
+ opponent = job_dict['opponent']
+            # as long as the opponent is a main player, both active and historical versions are ok
+ assert (isinstance(opponent, HistoricalPlayer)
+ and 'MainPlayer' in opponent.parent_id) or isinstance(opponent, MainPlayer)
+
+ def test_is_trained_enough(self, setup_league):
+ # only a few differences from `is_trained_enough` of MainPlayer
+ pass
+
+ def test_mutate(self, setup_league):
+ assert isinstance(setup_league[1], MainExploiter)
+ info = {'reset_checkpoint_path': 'pretrain_checkpoint.pth'}
+ for _ in range(10):
+ assert setup_league[1].mutate(info) == info['reset_checkpoint_path']
+
+
+@pytest.mark.unittest
+class TestLeagueExploiter:
+
+ def test_get_job(self, setup_league):
+ assert isinstance(setup_league[2], LeagueExploiter)
+ job_dict = setup_league[2].get_job()
+ opponent = job_dict['opponent']
+ assert isinstance(opponent, HistoricalPlayer)
+ assert isinstance(setup_league[3], LeagueExploiter)
+ job_dict = setup_league[3].get_job()
+ opponent = job_dict['opponent']
+ assert isinstance(opponent, HistoricalPlayer)
+
+ def test_is_trained_enough(self, setup_league):
+ # this function is the same as `is_trained_enough` of MainPlayer
+ pass
+
+ def test_mutate(self, setup_league):
+ assert isinstance(setup_league[2], LeagueExploiter)
+ info = {'reset_checkpoint_path': 'pretrain_checkpoint.pth'}
+ results = []
+ for _ in range(1000):
+ results.append(setup_league[2].mutate(info))
+ freq = len([t for t in results if t]) * 1.0 / len(results)
+ assert 0.2 <= freq <= 0.3 # approximate
+
+
+if __name__ == '__main__':
+ pytest.main(["-sv", os.path.basename(__file__)])
diff --git a/DI-engine/ding/model/__init__.py b/DI-engine/ding/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..002554cf3180d497c1824bedcf39767522b197ca
--- /dev/null
+++ b/DI-engine/ding/model/__init__.py
@@ -0,0 +1,3 @@
+from .common import *
+from .template import *
+from .wrapper import *
diff --git a/DI-engine/ding/model/common/__init__.py b/DI-engine/ding/model/common/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..4bf7d8be5aec022234fa1226f2e91ef5592d9320
--- /dev/null
+++ b/DI-engine/ding/model/common/__init__.py
@@ -0,0 +1,5 @@
+from .head import DiscreteHead, DuelingHead, DistributionHead, RainbowHead, QRDQNHead, StochasticDuelingHead, \
+ QuantileHead, FQFHead, RegressionHead, ReparameterizationHead, MultiHead, BranchingHead, head_cls_map, \
+ independent_normal_dist, AttentionPolicyHead, PopArtVHead, EnsembleHead
+from .encoder import ConvEncoder, FCEncoder, IMPALAConvEncoder
+from .utils import create_model
diff --git a/DI-engine/ding/model/common/encoder.py b/DI-engine/ding/model/common/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..82dab4808a039f4a355373d618f53f6633962dd3
--- /dev/null
+++ b/DI-engine/ding/model/common/encoder.py
@@ -0,0 +1,472 @@
+from typing import Optional, Dict, Union, List
+from functools import reduce
+import operator
+import math
+import torch
+import torch.nn as nn
+from torch.nn import functional as F
+
+from ding.torch_utils import ResFCBlock, ResBlock, Flatten, normed_linear, normed_conv2d
+from ding.torch_utils.network.dreamer import Conv2dSame, DreamerLayerNorm
+from ding.utils import SequenceType
+
+
+def prod(iterable):
+ """
+ Overview:
+        Product of all elements. (To be deprecated soon.) This function definition exists to support Python \
+        versions below 3.8. On Python 3.8 and later, ``math.prod()`` is recommended.
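+        Examples:
+            >>> prod([2, 3, 4])  # equivalent to math.prod([2, 3, 4]) on Python >= 3.8
+            24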
+ """
+ return reduce(operator.mul, iterable, 1)
+
+
+class ConvEncoder(nn.Module):
+ """
+ Overview:
+ The Convolution Encoder is used to encode 2-dim image observations.
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(
+ self,
+ obs_shape: SequenceType,
+ hidden_size_list: SequenceType = [32, 64, 64, 128],
+ activation: Optional[nn.Module] = nn.ReLU(),
+ kernel_size: SequenceType = [8, 4, 3],
+ stride: SequenceType = [4, 2, 1],
+ padding: Optional[SequenceType] = None,
+ layer_norm: Optional[bool] = False,
+ norm_type: Optional[str] = None
+ ) -> None:
+ """
+ Overview:
+ Initialize the ``Convolution Encoder`` according to the provided arguments.
+ Arguments:
+ - obs_shape (:obj:`SequenceType`): Sequence of ``in_channel``, plus one or more ``input size``.
+ - hidden_size_list (:obj:`SequenceType`): Sequence of ``hidden_size`` of subsequent conv layers \
+ and the final dense layer.
+ - activation (:obj:`nn.Module`): Type of activation to use in the conv ``layers`` and ``ResBlock``. \
+ Default is ``nn.ReLU()``.
+ - kernel_size (:obj:`SequenceType`): Sequence of ``kernel_size`` of subsequent conv layers.
+ - stride (:obj:`SequenceType`): Sequence of ``stride`` of subsequent conv layers.
+ - padding (:obj:`SequenceType`): Padding added to all four sides of the input for each conv layer. \
+ See ``nn.Conv2d`` for more details. Default is ``None``.
+            - layer_norm (:obj:`bool`): Whether to use ``DreamerLayerNorm``, which is a special trick \
+                proposed in DreamerV3.
+ - norm_type (:obj:`str`): Type of normalization to use. See ``ding.torch_utils.network.ResBlock`` \
+ for more details. Default is ``None``.
+ """
+ super(ConvEncoder, self).__init__()
+ self.obs_shape = obs_shape
+ self.act = activation
+ self.hidden_size_list = hidden_size_list
+ if padding is None:
+ padding = [0 for _ in range(len(kernel_size))]
+
+ layers = []
+ input_size = obs_shape[0] # in_channel
+ for i in range(len(kernel_size)):
+ if layer_norm:
+ layers.append(
+ Conv2dSame(
+ in_channels=input_size,
+ out_channels=hidden_size_list[i],
+ kernel_size=(kernel_size[i], kernel_size[i]),
+ stride=(2, 2),
+ bias=False,
+ )
+ )
+ layers.append(DreamerLayerNorm(hidden_size_list[i]))
+ layers.append(self.act)
+ else:
+ layers.append(nn.Conv2d(input_size, hidden_size_list[i], kernel_size[i], stride[i], padding[i]))
+ layers.append(self.act)
+ input_size = hidden_size_list[i]
+ if len(self.hidden_size_list) >= len(kernel_size) + 2:
+ assert self.hidden_size_list[len(kernel_size) - 1] == self.hidden_size_list[
+ len(kernel_size)], "Please indicate the same hidden size between conv and res block"
+ assert len(
+ set(hidden_size_list[len(kernel_size):-1])
+ ) <= 1, "Please indicate the same hidden size for res block parts"
+ for i in range(len(kernel_size), len(self.hidden_size_list) - 1):
+ layers.append(ResBlock(self.hidden_size_list[i - 1], activation=self.act, norm_type=norm_type))
+ layers.append(Flatten())
+ self.main = nn.Sequential(*layers)
+
+ flatten_size = self._get_flatten_size()
+ self.output_size = hidden_size_list[-1] # outside to use
+ self.mid = nn.Linear(flatten_size, hidden_size_list[-1])
+
+ def _get_flatten_size(self) -> int:
+ """
+ Overview:
+ Get the encoding size after ``self.main`` to get the number of ``in-features`` to feed to ``nn.Linear``.
+ Returns:
+ - outputs (:obj:`torch.Tensor`): Size ``int`` Tensor representing the number of ``in-features``.
+ Shapes:
+ - outputs: :math:`(1,)`.
+ Examples:
+ >>> conv = ConvEncoder(
+ >>> obs_shape=(4, 84, 84),
+ >>> hidden_size_list=[32, 64, 64, 128],
+ >>> activation=nn.ReLU(),
+ >>> kernel_size=[8, 4, 3],
+ >>> stride=[4, 2, 1],
+ >>> padding=None,
+ >>> layer_norm=False,
+ >>> norm_type=None
+ >>> )
+ >>> flatten_size = conv._get_flatten_size()
+ """
+ test_data = torch.randn(1, *self.obs_shape)
+ with torch.no_grad():
+ output = self.main(test_data)
+ return output.shape[1]
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Return output 1D embedding tensor of the env's 2D image observation.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Raw 2D observation of the environment.
+ Returns:
+ - outputs (:obj:`torch.Tensor`): Output embedding tensor.
+ Shapes:
+ - x : :math:`(B, C, H, W)`, where ``B`` is batch size, ``C`` is channel, ``H`` is height, ``W`` is width.
+ - outputs: :math:`(B, N)`, where ``N = hidden_size_list[-1]`` .
+ Examples:
+ >>> conv = ConvEncoder(
+ >>> obs_shape=(4, 84, 84),
+ >>> hidden_size_list=[32, 64, 64, 128],
+ >>> activation=nn.ReLU(),
+ >>> kernel_size=[8, 4, 3],
+ >>> stride=[4, 2, 1],
+ >>> padding=None,
+ >>> layer_norm=False,
+ >>> norm_type=None
+ >>> )
+ >>> x = torch.randn(1, 4, 84, 84)
+ >>> output = conv(x)
+ """
+ x = self.main(x)
+ x = self.mid(x)
+ return x
+
+
+class FCEncoder(nn.Module):
+ """
+ Overview:
+        The fully connected encoder is used to encode 1-dim input variables.
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(
+ self,
+ obs_shape: int,
+ hidden_size_list: SequenceType,
+ res_block: bool = False,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ dropout: Optional[float] = None
+ ) -> None:
+ """
+ Overview:
+ Initialize the FC Encoder according to arguments.
+ Arguments:
+ - obs_shape (:obj:`int`): Observation shape.
+ - hidden_size_list (:obj:`SequenceType`): Sequence of ``hidden_size`` of subsequent FC layers.
+ - res_block (:obj:`bool`): Whether use ``res_block``. Default is ``False``.
+ - activation (:obj:`nn.Module`): Type of activation to use in ``ResFCBlock``. Default is ``nn.ReLU()``.
+ - norm_type (:obj:`str`): Type of normalization to use. See ``ding.torch_utils.network.ResFCBlock`` \
+ for more details. Default is ``None``.
+ - dropout (:obj:`float`): Dropout rate of the dropout layer. If ``None`` then default no dropout layer.
+ """
+ super(FCEncoder, self).__init__()
+ self.obs_shape = obs_shape
+ self.act = activation
+ self.init = nn.Linear(obs_shape, hidden_size_list[0])
+
+ if res_block:
+ assert len(set(hidden_size_list)) == 1, "Please indicate the same hidden size for res block parts"
+ if len(hidden_size_list) == 1:
+ self.main = ResFCBlock(hidden_size_list[0], activation=self.act, norm_type=norm_type, dropout=dropout)
+ else:
+ layers = []
+ for i in range(len(hidden_size_list)):
+ layers.append(
+ ResFCBlock(hidden_size_list[0], activation=self.act, norm_type=norm_type, dropout=dropout)
+ )
+ self.main = nn.Sequential(*layers)
+ else:
+ layers = []
+ for i in range(len(hidden_size_list) - 1):
+ layers.append(nn.Linear(hidden_size_list[i], hidden_size_list[i + 1]))
+ layers.append(self.act)
+ if dropout is not None:
+ layers.append(nn.Dropout(dropout))
+ self.main = nn.Sequential(*layers)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Return output embedding tensor of the env observation.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Env raw observation.
+ Returns:
+ - outputs (:obj:`torch.Tensor`): Output embedding tensor.
+ Shapes:
+ - x : :math:`(B, M)`, where ``M = obs_shape``.
+ - outputs: :math:`(B, N)`, where ``N = hidden_size_list[-1]``.
+ Examples:
+ >>> fc = FCEncoder(
+ >>> obs_shape=4,
+ >>> hidden_size_list=[32, 64, 64, 128],
+ >>> activation=nn.ReLU(),
+ >>> norm_type=None,
+ >>> dropout=None
+ >>> )
+ >>> x = torch.randn(1, 4)
+ >>> output = fc(x)
+ """
+ x = self.act(self.init(x))
+ x = self.main(x)
+ return x
+
+
+class StructEncoder(nn.Module):
+
+ def __init__(self, obs_shape: Dict[str, Union[int, List[int]]]) -> None:
+ super(StructEncoder, self).__init__()
+ # TODO concrete implementation
+ raise NotImplementedError
+
+
+class IMPALACnnResidualBlock(nn.Module):
+ """
+ Overview:
+        This residual block is the basic CNN block used in the IMPALA algorithm,
+        which preserves the channel number and spatial shape.
+ IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures
+ https://arxiv.org/pdf/1802.01561.pdf
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(self, in_channnel: int, scale: float = 1, batch_norm: bool = False):
+ """
+ Overview:
+ Initialize the IMPALA CNN residual block according to arguments.
+ Arguments:
+ - in_channnel (:obj:`int`): Channel number of input features.
+ - scale (:obj:`float`): Scale of module, defaults to 1.
+ - batch_norm (:obj:`bool`): Whether use batch normalization, defaults to False.
+ """
+ super().__init__()
+ self.in_channnel = in_channnel
+ self.batch_norm = batch_norm
+ s = math.sqrt(scale)
+ self.conv0 = normed_conv2d(self.in_channnel, self.in_channnel, 3, padding=1, scale=s)
+ self.conv1 = normed_conv2d(self.in_channnel, self.in_channnel, 3, padding=1, scale=s)
+ if self.batch_norm:
+ self.bn0 = nn.BatchNorm2d(self.in_channnel)
+ self.bn1 = nn.BatchNorm2d(self.in_channnel)
+
+ def residual(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Return output tensor of the residual block, keep the shape and channel number unchanged.
+ The inplace of activation function should be False for the first relu,
+ so that it does not change the origin input tensor of the residual block.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Input tensor.
+ Returns:
+ - output (:obj:`torch.Tensor`): Output tensor.
+ """
+ if self.batch_norm:
+ x = self.bn0(x)
+ x = F.relu(x, inplace=False)
+ x = self.conv0(x)
+ if self.batch_norm:
+ x = self.bn1(x)
+ x = F.relu(x, inplace=True)
+ x = self.conv1(x)
+ return x
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Return output tensor of the residual block, keep the shape and channel number unchanged.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Input tensor.
+ Returns:
+ - output (:obj:`torch.Tensor`): Output tensor.
+ Examples:
+ >>> block = IMPALACnnResidualBlock(16)
+ >>> x = torch.randn(1, 16, 84, 84)
+ >>> output = block(x)
+ """
+ return x + self.residual(x)
+
+
+class IMPALACnnDownStack(nn.Module):
+ """
+ Overview:
+        Downsampling stack of the CNN encoder used in the IMPALA algorithm.
+        Every IMPALACnnDownStack consists of ``nblock`` IMPALACnnResidualBlock layers
+        and reduces the spatial size by 2 with max pooling.
+ IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures
+ https://arxiv.org/pdf/1802.01561.pdf
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(self, in_channnel, nblock, out_channel, scale=1, pool=True, **kwargs):
+ """
+ Overview:
+ Initialize every impala cnn block of the Impala Cnn Encoder.
+ Arguments:
+ - in_channnel (:obj:`int`): Channel number of input features.
+ - nblock (:obj:`int`): Residual Block number in each block.
+ - out_channel (:obj:`int`): Channel number of output features.
+ - scale (:obj:`float`): Scale of the module.
+            - pool (:obj:`bool`): Whether to use max pooling after the first conv layer.
+ """
+ super().__init__()
+ self.in_channnel = in_channnel
+ self.out_channel = out_channel
+ self.pool = pool
+ self.firstconv = normed_conv2d(in_channnel, out_channel, 3, padding=1)
+ s = scale / math.sqrt(nblock)
+ self.blocks = nn.ModuleList([IMPALACnnResidualBlock(out_channel, scale=s, **kwargs) for _ in range(nblock)])
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+            Return the output tensor of the downsampling stack. The output shape differs from the input shape; \
+                refer to the ``output_shape`` method to compute it.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Input tensor.
+ Returns:
+ - output (:obj:`torch.Tensor`): Output tensor.
+ Examples:
+ >>> stack = IMPALACnnDownStack(16, 2, 32)
+ >>> x = torch.randn(1, 16, 84, 84)
+ >>> output = stack(x)
+ """
+ x = self.firstconv(x)
+ if self.pool:
+ x = F.max_pool2d(x, kernel_size=3, stride=2, padding=1)
+ for block in self.blocks:
+ x = block(x)
+ return x
+
+ def output_shape(self, inshape: tuple) -> tuple:
+ """
+ Overview:
+ Calculate the output shape of the downsampling stack according to input shape and related arguments.
+ Arguments:
+ - inshape (:obj:`tuple`): Input shape.
+ Returns:
+ - output_shape (:obj:`tuple`): Output shape.
+ Shapes:
+ - inshape (:obj:`tuple`): :math:`(C, H, W)`, where C is channel number, H is height and W is width.
+ - output_shape (:obj:`tuple`): :math:`(C, H, W)`, where C is channel number, H is height and W is width.
+ Examples:
+ >>> stack = IMPALACnnDownStack(16, 2, 32)
+ >>> inshape = (16, 84, 84)
+ >>> output_shape = stack.output_shape(inshape)
+ """
+ c, h, w = inshape
+ assert c == self.in_channnel
+ if self.pool:
+ return (self.out_channel, (h + 1) // 2, (w + 1) // 2)
+ else:
+ return (self.out_channel, h, w)
+
+
+class IMPALAConvEncoder(nn.Module):
+ """
+ Overview:
+        IMPALA CNN encoder, which is used in the IMPALA algorithm.
+        IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner Architectures, \
+        https://arxiv.org/pdf/1802.01561.pdf.
+ Interface:
+ ``__init__``, ``forward``, ``output_shape``.
+ """
+ name = "IMPALAConvEncoder" # put it here to preserve pickle compat
+
+ def __init__(
+ self,
+ obs_shape: SequenceType,
+ channels: SequenceType = (16, 32, 32),
+ outsize: int = 256,
+ scale_ob: float = 255.0,
+ nblock: int = 2,
+ final_relu: bool = True,
+ **kwargs
+ ) -> None:
+ """
+ Overview:
+ Initialize the IMPALA CNN encoder according to arguments.
+ Arguments:
+ - obs_shape (:obj:`SequenceType`): 2D image observation shape.
+            - channels (:obj:`SequenceType`): The channel numbers of a series of IMPALA CNN blocks. \
+                Each element of the sequence is the output channel number of an IMPALA CNN block.
+            - outsize (:obj:`int`): The output size of the final linear layer, i.e. the dimension of the \
+                1D embedding vector.
+            - scale_ob (:obj:`float`): The scale of the input observation, which is used to normalize the input \
+                observation, such as dividing by 255.0 for raw image observations.
+ - nblock (:obj:`int`): The number of Residual Block in each block.
+ - final_relu (:obj:`bool`): Whether to use ReLU activation in the final output of encoder.
+ - kwargs (:obj:`Dict[str, Any]`): Other arguments for ``IMPALACnnDownStack``.
+ """
+ super().__init__()
+ self.scale_ob = scale_ob
+ c, h, w = obs_shape
+ curshape = (c, h, w)
+ s = 1 / math.sqrt(len(channels)) # per stack scale
+ self.stacks = nn.ModuleList()
+ for out_channel in channels:
+ stack = IMPALACnnDownStack(curshape[0], nblock=nblock, out_channel=out_channel, scale=s, **kwargs)
+ self.stacks.append(stack)
+ curshape = stack.output_shape(curshape)
+ self.dense = normed_linear(prod(curshape), outsize, scale=1.4)
+ self.outsize = outsize
+ self.final_relu = final_relu
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Return the 1D embedding vector of the input 2D observation.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Input 2D observation tensor.
+ Returns:
+ - output (:obj:`torch.Tensor`): Output 1D embedding vector.
+ Shapes:
+ - x (:obj:`torch.Tensor`): :math:`(B, C, H, W)`, where B is batch size, C is channel number, H is height \
+ and W is width.
+ - output (:obj:`torch.Tensor`): :math:`(B, outsize)`, where B is batch size.
+ Examples:
+ >>> encoder = IMPALAConvEncoder(
+ >>> obs_shape=(4, 84, 84),
+ >>> channels=(16, 32, 32),
+ >>> outsize=256,
+ >>> scale_ob=255.0,
+ >>> nblock=2,
+ >>> final_relu=True,
+ >>> )
+ >>> x = torch.randn(1, 4, 84, 84)
+ >>> output = encoder(x)
+ """
+ x = x / self.scale_ob
+ for (i, layer) in enumerate(self.stacks):
+ x = layer(x)
+        *batch_shape, c, h, w = x.shape
+        x = x.reshape((*batch_shape, c * h * w))
+ x = F.relu(x)
+ x = self.dense(x)
+ if self.final_relu:
+ x = torch.relu(x)
+ return x
diff --git a/DI-engine/ding/model/common/head.py b/DI-engine/ding/model/common/head.py
new file mode 100755
index 0000000000000000000000000000000000000000..1131e8a2e8bbc3d73849ef697981961e54d3950c
--- /dev/null
+++ b/DI-engine/ding/model/common/head.py
@@ -0,0 +1,1486 @@
+from typing import Optional, Dict, Union, List
+
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.distributions import Normal, Independent
+
+from ding.torch_utils import fc_block, noise_block, NoiseLinearLayer, MLP, PopArt, conv1d_block
+from ding.rl_utils import beta_function_map
+from ding.utils import lists_to_dicts, SequenceType
+
+
+class DiscreteHead(nn.Module):
+ """
+ Overview:
+ The ``DiscreteHead`` is used to generate discrete actions logit or Q-value logit, \
+ which is often used in q-learning algorithms or actor-critic algorithms for discrete action space.
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ output_size: int,
+ layer_num: int = 1,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ dropout: Optional[float] = None,
+ noise: Optional[bool] = False,
+ ) -> None:
+ """
+ Overview:
+ Init the ``DiscreteHead`` layers according to the provided arguments.
+ Arguments:
+ - hidden_size (:obj:`int`): The ``hidden_size`` of the MLP connected to ``DiscreteHead``.
+ - output_size (:obj:`int`): The number of outputs.
+ - layer_num (:obj:`int`): The number of layers used in the network to compute Q value output.
+ - activation (:obj:`nn.Module`): The type of activation function to use in MLP. \
+ If ``None``, then default set activation to ``nn.ReLU()``. Default ``None``.
+ - norm_type (:obj:`str`): The type of normalization to use. See ``ding.torch_utils.network.fc_block`` \
+ for more details. Default ``None``.
+ - dropout (:obj:`float`): The dropout rate, default set to None.
+ - noise (:obj:`bool`): Whether use ``NoiseLinearLayer`` as ``layer_fn`` in Q networks' MLP. \
+ Default ``False``.
+ """
+ super(DiscreteHead, self).__init__()
+ layer = NoiseLinearLayer if noise else nn.Linear
+ block = noise_block if noise else fc_block
+ self.Q = nn.Sequential(
+ MLP(
+ hidden_size,
+ hidden_size,
+ hidden_size,
+ layer_num,
+ layer_fn=layer,
+ activation=activation,
+ use_dropout=dropout is not None,
+ dropout_probability=dropout,
+ norm_type=norm_type
+ ), block(hidden_size, output_size)
+ )
+
+ def forward(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+ Use encoded embedding tensor to run MLP with ``DiscreteHead`` and return the prediction dictionary.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Tensor containing input embedding.
+ Returns:
+ - outputs (:obj:`Dict`): Dict containing keyword ``logit`` (:obj:`torch.Tensor`).
+ Shapes:
+ - x: :math:`(B, N)`, where ``B = batch_size`` and ``N = hidden_size``.
+ - logit: :math:`(B, M)`, where ``M = output_size``.
+ Examples:
+ >>> head = DiscreteHead(64, 64)
+ >>> inputs = torch.randn(4, 64)
+ >>> outputs = head(inputs)
+ >>> assert isinstance(outputs, dict) and outputs['logit'].shape == torch.Size([4, 64])
+ """
+ logit = self.Q(x)
+ return {'logit': logit}
+
+
+class DistributionHead(nn.Module):
+ """
+ Overview:
+ The ``DistributionHead`` is used to generate distribution for Q-value.
+ This module is used in C51 algorithm.
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ output_size: int,
+ layer_num: int = 1,
+ n_atom: int = 51,
+ v_min: float = -10,
+ v_max: float = 10,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ noise: Optional[bool] = False,
+ eps: Optional[float] = 1e-6,
+ ) -> None:
+ """
+ Overview:
+ Init the ``DistributionHead`` layers according to the provided arguments.
+ Arguments:
+ - hidden_size (:obj:`int`): The ``hidden_size`` of the MLP connected to ``DistributionHead``.
+ - output_size (:obj:`int`): The number of outputs.
+ - layer_num (:obj:`int`): The number of layers used in the network to compute Q value distribution.
+ - n_atom (:obj:`int`): The number of atoms (discrete supports). Default is ``51``.
+ - v_min (:obj:`int`): Min value of atoms. Default is ``-10``.
+ - v_max (:obj:`int`): Max value of atoms. Default is ``10``.
+ - activation (:obj:`nn.Module`): The type of activation function to use in MLP. \
+ If ``None``, then default set activation to ``nn.ReLU()``. Default ``None``.
+ - norm_type (:obj:`str`): The type of normalization to use. See ``ding.torch_utils.network.fc_block`` \
+ for more details. Default ``None``.
+ - noise (:obj:`bool`): Whether use ``NoiseLinearLayer`` as ``layer_fn`` in Q networks' MLP. \
+ Default ``False``.
+ - eps (:obj:`float`): Small constant used for numerical stability.
+ """
+ super(DistributionHead, self).__init__()
+ layer = NoiseLinearLayer if noise else nn.Linear
+ block = noise_block if noise else fc_block
+ self.Q = nn.Sequential(
+ MLP(
+ hidden_size,
+ hidden_size,
+ hidden_size,
+ layer_num,
+ layer_fn=layer,
+ activation=activation,
+ norm_type=norm_type
+ ), block(hidden_size, output_size * n_atom)
+ )
+ self.output_size = output_size
+ self.n_atom = n_atom
+ self.v_min = v_min
+ self.v_max = v_max
+ self.eps = eps # for numerical stability
+
+ def forward(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+ Use encoded embedding tensor to run MLP with ``DistributionHead`` and return the prediction dictionary.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Tensor containing input embedding.
+ Returns:
+ - outputs (:obj:`Dict`): Dict containing keywords ``logit`` (:obj:`torch.Tensor`) and \
+ ``distribution`` (:obj:`torch.Tensor`).
+ Shapes:
+ - x: :math:`(B, N)`, where ``B = batch_size`` and ``N = hidden_size``.
+ - logit: :math:`(B, M)`, where ``M = output_size``.
+ - distribution: :math:`(B, M, n_atom)`.
+ Examples:
+ >>> head = DistributionHead(64, 64)
+ >>> inputs = torch.randn(4, 64)
+ >>> outputs = head(inputs)
+ >>> assert isinstance(outputs, dict)
+ >>> assert outputs['logit'].shape == torch.Size([4, 64])
+ >>> # default n_atom is 51
+ >>> assert outputs['distribution'].shape == torch.Size([4, 64, 51])
+ """
+ q = self.Q(x)
+ q = q.view(*q.shape[:-1], self.output_size, self.n_atom)
+ dist = torch.softmax(q, dim=-1) + self.eps
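+        # expected Q value per action: weight each atom's support value (linspace over [v_min, v_max]) by its
+        # probability and sum over the atom dimension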
+ q = dist * torch.linspace(self.v_min, self.v_max, self.n_atom).to(x)
+ q = q.sum(-1)
+ return {'logit': q, 'distribution': dist}
+
+
+class BranchingHead(nn.Module):
+ """
+ Overview:
+ The ``BranchingHead`` is used to generate Q-value with different branches.
+ This module is used in Branch DQN.
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ num_branches: int = 0,
+ action_bins_per_branch: int = 2,
+ layer_num: int = 1,
+ a_layer_num: Optional[int] = None,
+ v_layer_num: Optional[int] = None,
+ norm_type: Optional[str] = None,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ noise: Optional[bool] = False,
+ ) -> None:
+ """
+ Overview:
+ Init the ``BranchingHead`` layers according to the provided arguments. \
+ This head achieves a linear increase of the number of network outputs \
+ with the number of degrees of freedom by allowing a level of independence for each individual action.
+            Therefore, this head is suitable for high-dimensional action spaces.
+ Arguments:
+ - hidden_size (:obj:`int`): The ``hidden_size`` of the MLP connected to ``BranchingHead``.
+ - num_branches (:obj:`int`): The number of branches, which is equivalent to the action dimension.
+            - action_bins_per_branch (:obj:`int`): The number of action bins in each dimension.
+ - layer_num (:obj:`int`): The number of layers used in the network to compute Advantage and Value output.
+ - a_layer_num (:obj:`int`): The number of layers used in the network to compute Advantage output.
+ - v_layer_num (:obj:`int`): The number of layers used in the network to compute Value output.
+ - norm_type (:obj:`str`): The type of normalization to use. See ``ding.torch_utils.network.fc_block`` \
+ for more details. Default ``None``.
+ - activation (:obj:`nn.Module`): The type of activation function to use in MLP. \
+ If ``None``, then default set activation to ``nn.ReLU()``. Default ``None``.
+ - noise (:obj:`bool`): Whether use ``NoiseLinearLayer`` as ``layer_fn`` in Q networks' MLP. \
+ Default ``False``.
+ """
+ super(BranchingHead, self).__init__()
+ if a_layer_num is None:
+ a_layer_num = layer_num
+ if v_layer_num is None:
+ v_layer_num = layer_num
+ self.num_branches = num_branches
+ self.action_bins_per_branch = action_bins_per_branch
+
+ layer = NoiseLinearLayer if noise else nn.Linear
+ block = noise_block if noise else fc_block
+ # value network
+
+ self.V = nn.Sequential(
+ MLP(
+ hidden_size,
+ hidden_size,
+ hidden_size,
+ v_layer_num,
+ layer_fn=layer,
+ activation=activation,
+ norm_type=norm_type
+ ), block(hidden_size, 1)
+ )
+ # action branching network
+ action_output_dim = action_bins_per_branch
+ self.branches = nn.ModuleList(
+ [
+ nn.Sequential(
+ MLP(
+ hidden_size,
+ hidden_size,
+ hidden_size,
+ a_layer_num,
+ layer_fn=layer,
+ activation=activation,
+ norm_type=norm_type
+ ), block(hidden_size, action_output_dim)
+ ) for _ in range(self.num_branches)
+ ]
+ )
+
+ def forward(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+ Use encoded embedding tensor to run MLP with ``BranchingHead`` and return the prediction dictionary.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Tensor containing input embedding.
+ Returns:
+ - outputs (:obj:`Dict`): Dict containing keyword ``logit`` (:obj:`torch.Tensor`).
+ Shapes:
+ - x: :math:`(B, N)`, where ``B = batch_size`` and ``N = hidden_size``.
+            - logit: :math:`(B, M, N)`, where ``M = num_branches`` and ``N = action_bins_per_branch``.
+ Examples:
+ >>> head = BranchingHead(64, 5, 2)
+ >>> inputs = torch.randn(4, 64)
+ >>> outputs = head(inputs)
+ >>> assert isinstance(outputs, dict) and outputs['logit'].shape == torch.Size([4, 5, 2])
+ """
+ value_out = self.V(x)
+ value_out = torch.unsqueeze(value_out, 1)
+ action_out = []
+ for b in self.branches:
+ action_out.append(b(x))
+ action_scores = torch.stack(action_out, 1)
+ # From the paper, this implementation performs better than both the naive alternative (Q = V + A) \
+ # and the local maximum reduction method (Q = V + max(A)).
+ action_scores = action_scores - torch.mean(action_scores, 2, keepdim=True)
+ logits = value_out + action_scores
+ return {'logit': logits}
+
+
+class RainbowHead(nn.Module):
+ """
+ Overview:
+ The ``RainbowHead`` is used to generate distribution of Q-value.
+ This module is used in Rainbow DQN.
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ output_size: int,
+ layer_num: int = 1,
+ n_atom: int = 51,
+ v_min: float = -10,
+ v_max: float = 10,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ noise: Optional[bool] = True,
+ eps: Optional[float] = 1e-6,
+ ) -> None:
+ """
+ Overview:
+ Init the ``RainbowHead`` layers according to the provided arguments.
+ Arguments:
+ - hidden_size (:obj:`int`): The ``hidden_size`` of the MLP connected to ``RainbowHead``.
+ - output_size (:obj:`int`): The number of outputs.
+ - layer_num (:obj:`int`): The number of layers used in the network to compute Q value output.
+ - n_atom (:obj:`int`): The number of atoms (discrete supports). Default is ``51``.
+ - v_min (:obj:`int`): Min value of atoms. Default is ``-10``.
+ - v_max (:obj:`int`): Max value of atoms. Default is ``10``.
+ - activation (:obj:`nn.Module`): The type of activation function to use in MLP. \
+ If ``None``, then default set activation to ``nn.ReLU()``. Default ``None``.
+ - norm_type (:obj:`str`): The type of normalization to use. See ``ding.torch_utils.network.fc_block`` \
+ for more details. Default ``None``.
+ - noise (:obj:`bool`): Whether use ``NoiseLinearLayer`` as ``layer_fn`` in Q networks' MLP. \
+                Default ``True``.
+ - eps (:obj:`float`): Small constant used for numerical stability.
+ """
+ super(RainbowHead, self).__init__()
+ layer = NoiseLinearLayer if noise else nn.Linear
+ block = noise_block if noise else fc_block
+ self.A = nn.Sequential(
+ MLP(
+ hidden_size,
+ hidden_size,
+ hidden_size,
+ layer_num,
+ layer_fn=layer,
+ activation=activation,
+ norm_type=norm_type
+ ), block(hidden_size, output_size * n_atom)
+ )
+ self.Q = nn.Sequential(
+ MLP(
+ hidden_size,
+ hidden_size,
+ hidden_size,
+ layer_num,
+ layer_fn=layer,
+ activation=activation,
+ norm_type=norm_type
+ ), block(hidden_size, n_atom)
+ )
+ self.output_size = output_size
+ self.n_atom = n_atom
+ self.v_min = v_min
+ self.v_max = v_max
+ self.eps = eps
+
+ def forward(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+ Use encoded embedding tensor to run MLP with ``RainbowHead`` and return the prediction dictionary.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Tensor containing input embedding.
+ Returns:
+ - outputs (:obj:`Dict`): Dict containing keywords ``logit`` (:obj:`torch.Tensor`) and \
+ ``distribution`` (:obj:`torch.Tensor`).
+ Shapes:
+ - x: :math:`(B, N)`, where ``B = batch_size`` and ``N = hidden_size``.
+ - logit: :math:`(B, M)`, where ``M = output_size``.
+ - distribution: :math:`(B, M, n_atom)`.
+ Examples:
+ >>> head = RainbowHead(64, 64)
+ >>> inputs = torch.randn(4, 64)
+ >>> outputs = head(inputs)
+ >>> assert isinstance(outputs, dict)
+ >>> assert outputs['logit'].shape == torch.Size([4, 64])
+ >>> # default n_atom is 51
+ >>> assert outputs['distribution'].shape == torch.Size([4, 64, 51])
+ """
+ a = self.A(x)
+ q = self.Q(x)
+ a = a.view(*a.shape[:-1], self.output_size, self.n_atom)
+ q = q.view(*q.shape[:-1], 1, self.n_atom)
+ q = q + a - a.mean(dim=-2, keepdim=True)
+ dist = torch.softmax(q, dim=-1) + self.eps
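+ # Expected Q-value under the categorical distribution: logit[..., m] = sum_i dist[..., m, i] * z_i,
+ # where the atom support is z = linspace(v_min, v_max, n_atom).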
+ q = dist * torch.linspace(self.v_min, self.v_max, self.n_atom).to(x)
+ q = q.sum(-1)
+ return {'logit': q, 'distribution': dist}
+
+
+class QRDQNHead(nn.Module):
+ """
+ Overview:
+ The ``QRDQNHead`` (Quantile Regression DQN) is used to output action quantiles.
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ output_size: int,
+ layer_num: int = 1,
+ num_quantiles: int = 32,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ noise: Optional[bool] = False,
+ ) -> None:
+ """
+ Overview:
+ Init the ``QRDQNHead`` layers according to the provided arguments.
+ Arguments:
+ - hidden_size (:obj:`int`): The ``hidden_size`` of the MLP connected to ``QRDQNHead``.
+ - output_size (:obj:`int`): The number of outputs.
+ - layer_num (:obj:`int`): The number of layers used in the network to compute Q value output.
+ - num_quantiles (:obj:`int`): The number of quantiles. Default is ``32``.
+ - activation (:obj:`nn.Module`): The type of activation function to use in MLP. \
+ Default is ``nn.ReLU()``.
+ - norm_type (:obj:`str`): The type of normalization to use. See ``ding.torch_utils.network.fc_block`` \
+ for more details. Default ``None``.
+ - noise (:obj:`bool`): Whether to use ``NoiseLinearLayer`` as ``layer_fn`` in Q networks' MLP. \
+ Default ``False``.
+ """
+ super(QRDQNHead, self).__init__()
+ layer = NoiseLinearLayer if noise else nn.Linear
+ block = noise_block if noise else fc_block
+ self.Q = nn.Sequential(
+ MLP(
+ hidden_size,
+ hidden_size,
+ hidden_size,
+ layer_num,
+ layer_fn=layer,
+ activation=activation,
+ norm_type=norm_type
+ ), block(hidden_size, output_size * num_quantiles)
+ )
+ self.num_quantiles = num_quantiles
+ self.output_size = output_size
+
+ def forward(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+ Use encoded embedding tensor to run MLP with ``QRDQNHead`` and return the prediction dictionary.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Tensor containing input embedding.
+ Returns:
+ - outputs (:obj:`Dict`): Dict containing keywords ``logit`` (:obj:`torch.Tensor`), \
+ ``q`` (:obj:`torch.Tensor`), and ``tau`` (:obj:`torch.Tensor`).
+ Shapes:
+ - x: :math:`(B, N)`, where ``B = batch_size`` and ``N = hidden_size``.
+ - logit: :math:`(B, M)`, where ``M = output_size``.
+ - q: :math:`(B, M, num_quantiles)`.
+ - tau: :math:`(B, num_quantiles, 1)`.
+ Examples:
+ >>> head = QRDQNHead(64, 64)
+ >>> inputs = torch.randn(4, 64)
+ >>> outputs = head(inputs)
+ >>> assert isinstance(outputs, dict)
+ >>> assert outputs['logit'].shape == torch.Size([4, 64])
+ >>> # default num_quantiles is 32
+ >>> assert outputs['q'].shape == torch.Size([4, 64, 32])
+ >>> assert outputs['tau'].shape == torch.Size([4, 32, 1])
+ """
+ q = self.Q(x)
+ q = q.view(*q.shape[:-1], self.output_size, self.num_quantiles)
+
+ logit = q.mean(-1)
+ tau = torch.linspace(0, 1, self.num_quantiles + 1)
+ tau = ((tau[:-1] + tau[1:]) / 2).view(1, -1, 1).repeat(q.shape[0], 1, 1).to(q)
+ return {'logit': logit, 'q': q, 'tau': tau}
+
+
+class QuantileHead(nn.Module):
+ """
+ Overview:
+ The ``QuantileHead`` is used to output action quantiles.
+ This module is used in IQN.
+ Interfaces:
+ ``__init__``, ``forward``, ``quantile_net``.
+
+ .. note::
+ The difference between ``QuantileHead`` and ``QRDQNHead`` is that ``QuantileHead`` models the \
+ state-action quantile function as a mapping from state-actions and samples from some base distribution \
+ while ``QRDQNHead`` approximates random returns by a uniform mixture of Dirac functions.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ output_size: int,
+ layer_num: int = 1,
+ num_quantiles: int = 32,
+ quantile_embedding_size: int = 128,
+ beta_function_type: Optional[str] = 'uniform',
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ noise: Optional[bool] = False,
+ ) -> None:
+ """
+ Overview:
+ Init the ``QuantileHead`` layers according to the provided arguments.
+ Arguments:
+ - hidden_size (:obj:`int`): The ``hidden_size`` of the MLP connected to ``QuantileHead``.
+ - output_size (:obj:`int`): The number of outputs.
+ - layer_num (:obj:`int`): The number of layers used in the network to compute Q value output.
+ - num_quantiles (:obj:`int`): The number of quantiles.
+ - quantile_embedding_size (:obj:`int`): The embedding size of a quantile.
+ - beta_function_type (:obj:`str`): Type of beta function. See ``ding.rl_utils.beta_function.py`` \
+ for more details. Default is ``uniform``.
+ - activation (:obj:`nn.Module`): The type of activation function to use in MLP. \
+ Default is ``nn.ReLU()``.
+ - norm_type (:obj:`str`): The type of normalization to use. See ``ding.torch_utils.network.fc_block`` \
+ for more details. Default ``None``.
+ - noise (:obj:`bool`): Whether to use ``NoiseLinearLayer`` as ``layer_fn`` in Q networks' MLP. \
+ Default ``False``.
+ """
+ super(QuantileHead, self).__init__()
+ layer = NoiseLinearLayer if noise else nn.Linear
+ block = noise_block if noise else fc_block
+ self.Q = nn.Sequential(
+ MLP(
+ hidden_size,
+ hidden_size,
+ hidden_size,
+ layer_num,
+ layer_fn=layer,
+ activation=activation,
+ norm_type=norm_type
+ ), block(hidden_size, output_size)
+ )
+ self.num_quantiles = num_quantiles
+ self.quantile_embedding_size = quantile_embedding_size
+ self.output_size = output_size
+ self.iqn_fc = nn.Linear(self.quantile_embedding_size, hidden_size)
+ self.beta_function = beta_function_map[beta_function_type]
+
+ def quantile_net(self, quantiles: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Deterministic parametric function trained to reparameterize samples from a base distribution. \
+ By repeated Bellman update iterations of Q-learning, the optimal action-value function is estimated.
+ Arguments:
+ - quantiles (:obj:`torch.Tensor`): The sampled quantile tensor to be embedded.
+ Returns:
+ - quantile_net (:obj:`torch.Tensor`): Quantile network output tensor after reparameterization.
+ Shapes:
+ - quantile_net: :math:`(P, N)`, where ``P`` is the number of sampled quantiles (``quantiles.shape[0]``) \
+ and ``N = hidden_size``.
+ Examples:
+ >>> head = QuantileHead(64, 64)
+ >>> quantiles = torch.randn(128,1)
+ >>> qn_output = head.quantile_net(quantiles)
+ >>> assert isinstance(qn_output, torch.Tensor)
+ >>> # default quantile_embedding_size: int = 128,
+ >>> assert qn_output.shape == torch.Size([128, 64])
+ """
+ quantile_net = quantiles.repeat([1, self.quantile_embedding_size])
+ quantile_net = torch.cos(
+ torch.arange(1, self.quantile_embedding_size + 1, 1).to(quantiles) * math.pi * quantile_net
+ )
+ quantile_net = self.iqn_fc(quantile_net)
+ quantile_net = F.relu(quantile_net)
+ return quantile_net
+
+ def forward(self, x: torch.Tensor, num_quantiles: Optional[int] = None) -> Dict:
+ """
+ Overview:
+ Use encoded embedding tensor to run MLP with ``QuantileHead`` and return the prediction dictionary.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Tensor containing input embedding.
+ Returns:
+ - outputs (:obj:`Dict`): Dict containing keywords ``logit`` (:obj:`torch.Tensor`), \
+ ``q`` (:obj:`torch.Tensor`), and ``quantiles`` (:obj:`torch.Tensor`).
+ Shapes:
+ - x: :math:`(B, N)`, where ``B = batch_size`` and ``N = hidden_size``.
+ - logit: :math:`(B, M)`, where ``M = output_size``.
+ - q: :math:`(num_quantiles, B, M)`.
+ - quantiles: :math:`(num_quantiles * B, 1)`.
+ Examples:
+ >>> head = QuantileHead(64, 64)
+ >>> inputs = torch.randn(4, 64)
+ >>> outputs = head(inputs)
+ >>> assert isinstance(outputs, dict)
+ >>> assert outputs['logit'].shape == torch.Size([4, 64])
+ >>> # default num_quantiles is 32
+ >>> assert outputs['q'].shape == torch.Size([32, 4, 64])
+ >>> assert outputs['quantiles'].shape == torch.Size([128, 1])
+ """
+
+ if num_quantiles is None:
+ num_quantiles = self.num_quantiles
+ batch_size = x.shape[0]
+
+ q_quantiles = torch.FloatTensor(num_quantiles * batch_size, 1).uniform_(0, 1).to(x)
+ logit_quantiles = torch.FloatTensor(num_quantiles * batch_size, 1).uniform_(0, 1).to(x)
+ logit_quantiles = self.beta_function(logit_quantiles)
+ q_quantile_net = self.quantile_net(q_quantiles)
+ logit_quantile_net = self.quantile_net(logit_quantiles)
+
+ x = x.repeat(num_quantiles, 1)
+ q_x = x * q_quantile_net # (num_quantiles * B, hidden_size), e.g. (32 * 4, 64)
+ logit_x = x * logit_quantile_net
+
+ q = self.Q(q_x).reshape(num_quantiles, batch_size, -1)
+ logit = self.Q(logit_x).reshape(num_quantiles, batch_size, -1).mean(0)
+
+ return {'logit': logit, 'q': q, 'quantiles': q_quantiles}
+
+
+class FQFHead(nn.Module):
+ """
+ Overview:
+ The ``FQFHead`` is used to output action quantiles.
+ This module is used in FQF.
+ Interfaces:
+ ``__init__``, ``forward``, ``quantile_net``.
+
+ .. note::
+ The implementation of FQFHead is based on the paper https://arxiv.org/abs/1911.02140.
+ The difference between FQFHead and QuantileHead is that, in FQF, \
+ N adjustable quantile values for N adjustable quantile fractions are estimated to approximate \
+ the quantile function. The distribution of the return is approximated by a weighted mixture of N \
+ Dirac functions. While in IQN, the state-action quantile function is modeled as a mapping from \
+ state-actions and samples from some base distribution.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ output_size: int,
+ layer_num: int = 1,
+ num_quantiles: int = 32,
+ quantile_embedding_size: int = 128,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ noise: Optional[bool] = False,
+ ) -> None:
+ """
+ Overview:
+ Init the ``FQFHead`` layers according to the provided arguments.
+ Arguments:
+ - hidden_size (:obj:`int`): The ``hidden_size`` of the MLP connected to ``FQFHead``.
+ - output_size (:obj:`int`): The number of outputs.
+ - layer_num (:obj:`int`): The number of layers used in the network to compute Q value output.
+ - num_quantiles (:obj:`int`): The number of quantiles.
+ - quantile_embedding_size (:obj:`int`): The embedding size of a quantile.
+ - activation (:obj:`nn.Module`): The type of activation function to use in MLP. \
+ Default is ``nn.ReLU()``.
+ - norm_type (:obj:`str`): The type of normalization to use. See ``ding.torch_utils.network.fc_block`` \
+ for more details. Default ``None``.
+ - noise (:obj:`bool`): Whether to use ``NoiseLinearLayer`` as ``layer_fn`` in Q networks' MLP. \
+ Default ``False``.
+ """
+ super(FQFHead, self).__init__()
+ layer = NoiseLinearLayer if noise else nn.Linear
+ block = noise_block if noise else fc_block
+ self.Q = nn.Sequential(
+ MLP(
+ hidden_size,
+ hidden_size,
+ hidden_size,
+ layer_num,
+ layer_fn=layer,
+ activation=activation,
+ norm_type=norm_type
+ ), block(hidden_size, output_size)
+ )
+ self.num_quantiles = num_quantiles
+ self.quantile_embedding_size = quantile_embedding_size
+ self.output_size = output_size
+ self.fqf_fc = nn.Sequential(nn.Linear(self.quantile_embedding_size, hidden_size), nn.ReLU())
+ self.register_buffer(
+ 'sigma_pi',
+ torch.arange(1, self.quantile_embedding_size + 1, 1).view(1, 1, self.quantile_embedding_size) * math.pi
+ )
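+ # ``sigma_pi`` stores the frequencies i * pi for i = 1..quantile_embedding_size; it is the
+ # frequency term of the cosine quantile embedding cos(i * pi * tau) used in ``quantile_net``.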
+ # initialize weights_xavier of quantiles_proposal network
+ # NOTE(rjy): quantiles_proposal network mean fraction proposal network
+ quantiles_proposal_fc = nn.Linear(hidden_size, num_quantiles)
+ torch.nn.init.xavier_uniform_(quantiles_proposal_fc.weight, gain=0.01)
+ torch.nn.init.constant_(quantiles_proposal_fc.bias, 0)
+ self.quantiles_proposal = nn.Sequential(quantiles_proposal_fc, nn.LogSoftmax(dim=1))
+
+ def quantile_net(self, quantiles: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Deterministic parametric function trained to reparameterize samples from the quantiles_proposal network. \
+ By repeated Bellman update iterations of Q-learning, the optimal action-value function is estimated.
+ Arguments:
+ - quantiles (:obj:`torch.Tensor`): The quantile fraction tensor to be embedded.
+ Returns:
+ - quantile_net (:obj:`torch.Tensor`): Quantile network output tensor after reparameterization.
+ Examples:
+ >>> head = FQFHead(64, 64)
+ >>> quantiles = torch.randn(4,32)
+ >>> qn_output = head.quantile_net(quantiles)
+ >>> assert isinstance(qn_output, torch.Tensor)
+ >>> # default quantile_embedding_size: int = 128,
+ >>> assert qn_output.shape == torch.Size([4, 32, 64])
+ """
+ batch_size, num_quantiles = quantiles.shape[:2]
+ quantile_net = torch.cos(self.sigma_pi.to(quantiles) * quantiles.view(batch_size, num_quantiles, 1))
+ quantile_net = self.fqf_fc(quantile_net) # (batch_size, num_quantiles, hidden_size)
+ return quantile_net
+
+ def forward(self, x: torch.Tensor, num_quantiles: Optional[int] = None) -> Dict:
+ """
+ Overview:
+ Use encoded embedding tensor to run MLP with ``FQFHead`` and return the prediction dictionary.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Tensor containing input embedding.
+ Returns:
+ - outputs (:obj:`Dict`): Dict containing keywords ``logit`` (:obj:`torch.Tensor`), \
+ ``q`` (:obj:`torch.Tensor`), ``quantiles`` (:obj:`torch.Tensor`), \
+ ``quantiles_hats`` (:obj:`torch.Tensor`), \
+ ``q_tau_i`` (:obj:`torch.Tensor`), ``entropies`` (:obj:`torch.Tensor`).
+ Shapes:
+ - x: :math:`(B, N)`, where ``B = batch_size`` and ``N = hidden_size``.
+ - logit: :math:`(B, M)`, where ``M = output_size``.
+ - q: :math:`(B, num_quantiles, M)`.
+ - quantiles: :math:`(B, num_quantiles + 1)`.
+ - quantiles_hats: :math:`(B, num_quantiles)`.
+ - q_tau_i: :math:`(B, num_quantiles - 1, M)`.
+ - entropies: :math:`(B, 1)`.
+ Examples:
+ >>> head = FQFHead(64, 64)
+ >>> inputs = torch.randn(4, 64)
+ >>> outputs = head(inputs)
+ >>> assert isinstance(outputs, dict)
+ >>> assert outputs['logit'].shape == torch.Size([4, 64])
+ >>> # default num_quantiles is 32
+ >>> assert outputs['q'].shape == torch.Size([4, 32, 64])
+ >>> assert outputs['quantiles'].shape == torch.Size([4, 33])
+ >>> assert outputs['quantiles_hats'].shape == torch.Size([4, 32])
+ >>> assert outputs['q_tau_i'].shape == torch.Size([4, 31, 64])
+ >>> assert outputs['entropies'].shape == torch.Size([4, 1])
+ """
+
+ if num_quantiles is None:
+ num_quantiles = self.num_quantiles
+ batch_size = x.shape[0]
+
+ log_q_quantiles = self.quantiles_proposal(
+ x.detach()
+ ) # (batch_size, num_quantiles), not to update encoder when learning w1_loss(fraction loss)
+ q_quantiles = log_q_quantiles.exp() # NOTE(rjy): e^log_q = q
+
+ # Calculate entropies of value distributions.
+ entropies = -(log_q_quantiles * q_quantiles).sum(dim=-1, keepdim=True) # (batch_size, 1)
+ assert entropies.shape == (batch_size, 1)
+
+ # cumulative sum over the softmax output
+ # NOTE(rjy): because quantiles are still expressed in the form of their respective proportions,
+ # e.g. [0.33, 0.33, 0.33] => [0.33, 0.66, 0.99]
+ q_quantiles = torch.cumsum(q_quantiles, dim=1)
+
+ # quantile_hats: find the optimal condition for τ to minimize W1(Z, τ)
+ tau_0 = torch.zeros((batch_size, 1)).to(x)
+ q_quantiles = torch.cat((tau_0, q_quantiles), dim=1) # [batch_size, num_quantiles+1]
+
+ # NOTE(rjy): theta_i = F^(-1)_Z((tau_i + tau_{i+1}) / 2), tau_hat_i = (tau_i + tau_{i+1}) / 2, q_quantiles_hats is tau_hat
+ q_quantiles_hats = (q_quantiles[:, 1:] + q_quantiles[:, :-1]).detach() / 2. # (batch_size, num_quantiles)
+
+ # NOTE(rjy): reparameterize q_quantiles_hats
+ q_quantile_net = self.quantile_net(q_quantiles_hats) # [batch_size, num_quantiles, hidden_size(64)]
+ # x.view[batch_size, 1, hidden_size(64)]
+ q_x = (x.view(batch_size, 1, -1) * q_quantile_net) # [batch_size, num_quantiles, hidden_size(64)]
+
+ q = self.Q(q_x) # [batch_size, num_quantiles, action_dim(64)]
+
+ logit = q.mean(1)
+ with torch.no_grad():
+ q_tau_i_net = self.quantile_net(
+ q_quantiles[:, 1:-1].detach()
+ ) # [batch_size, num_quantiles-1, hidden_size(64)]
+ q_tau_i_x = (x.view(batch_size, 1, -1) * q_tau_i_net) # [batch_size, (num_quantiles-1), hidden_size(64)]
+
+ q_tau_i = self.Q(q_tau_i_x) # [batch_size, num_quantiles-1, action_dim]
+
+ return {
+ 'logit': logit,
+ 'q': q,
+ 'quantiles': q_quantiles,
+ 'quantiles_hats': q_quantiles_hats,
+ 'q_tau_i': q_tau_i,
+ 'entropies': entropies
+ }
+
+
+class DuelingHead(nn.Module):
+ """
+ Overview:
+ The ``DuelingHead`` is used to output discrete actions logit.
+ This module is used in Dueling DQN.
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ output_size: int,
+ layer_num: int = 1,
+ a_layer_num: Optional[int] = None,
+ v_layer_num: Optional[int] = None,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ dropout: Optional[float] = None,
+ noise: Optional[bool] = False,
+ ) -> None:
+ """
+ Overview:
+ Init the ``DuelingHead`` layers according to the provided arguments.
+ Arguments:
+ - hidden_size (:obj:`int`): The ``hidden_size`` of the MLP connected to ``DuelingHead``.
+ - output_size (:obj:`int`): The number of outputs.
+ - layer_num (:obj:`int`): The default number of layers used in the action and value networks \
+ when ``a_layer_num`` / ``v_layer_num`` are not specified.
+ - a_layer_num (:obj:`int`): The number of layers used in the network to compute action output. \
+ Default is ``layer_num``.
+ - v_layer_num (:obj:`int`): The number of layers used in the network to compute value output. \
+ Default is ``layer_num``.
+ - activation (:obj:`nn.Module`): The type of activation function to use in MLP. \
+ Default is ``nn.ReLU()``.
+ - norm_type (:obj:`str`): The type of normalization to use. See ``ding.torch_utils.network.fc_block`` \
+ for more details. Default ``None``.
+ - dropout (:obj:`float`): The dropout rate of the dropout layer. Default ``None`` (no dropout).
+ - noise (:obj:`bool`): Whether to use ``NoiseLinearLayer`` as ``layer_fn`` in Q networks' MLP. \
+ Default ``False``.
+ """
+ super(DuelingHead, self).__init__()
+ if a_layer_num is None:
+ a_layer_num = layer_num
+ if v_layer_num is None:
+ v_layer_num = layer_num
+ layer = NoiseLinearLayer if noise else nn.Linear
+ block = noise_block if noise else fc_block
+ self.A = nn.Sequential(
+ MLP(
+ hidden_size,
+ hidden_size,
+ hidden_size,
+ a_layer_num,
+ layer_fn=layer,
+ activation=activation,
+ use_dropout=dropout is not None,
+ dropout_probability=dropout,
+ norm_type=norm_type
+ ), block(hidden_size, output_size)
+ )
+ self.V = nn.Sequential(
+ MLP(
+ hidden_size,
+ hidden_size,
+ hidden_size,
+ v_layer_num,
+ layer_fn=layer,
+ activation=activation,
+ use_dropout=dropout is not None,
+ dropout_probability=dropout,
+ norm_type=norm_type
+ ), block(hidden_size, 1)
+ )
+
+ def forward(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+ Use encoded embedding tensor to run MLP with ``DuelingHead`` and return the prediction dictionary.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Tensor containing input embedding.
+ Returns:
+ - outputs (:obj:`Dict`): Dict containing keyword ``logit`` (:obj:`torch.Tensor`).
+ Shapes:
+ - x: :math:`(B, N)`, where ``B = batch_size`` and ``N = hidden_size``.
+ - logit: :math:`(B, M)`, where ``M = output_size``.
+ Examples:
+ >>> head = DuelingHead(64, 64)
+ >>> inputs = torch.randn(4, 64)
+ >>> outputs = head(inputs)
+ >>> assert isinstance(outputs, dict)
+ >>> assert outputs['logit'].shape == torch.Size([4, 64])
+ """
+ a = self.A(x)
+ v = self.V(x)
+ q_value = a - a.mean(dim=-1, keepdim=True) + v
+ return {'logit': q_value}
+
+
+class StochasticDuelingHead(nn.Module):
+ """
+ Overview:
+ The ``StochasticDuelingHead`` implements the Stochastic Dueling Network proposed in the ACER paper \
+ (arxiv 1611.01224), i.e., a dueling network architecture for continuous action spaces.
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ action_shape: int,
+ layer_num: int = 1,
+ a_layer_num: Optional[int] = None,
+ v_layer_num: Optional[int] = None,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ noise: Optional[bool] = False,
+ last_tanh: Optional[bool] = True,
+ ) -> None:
+ """
+ Overview:
+ Init the ``Stochastic DuelingHead`` layers according to the provided arguments.
+ Arguments:
+ - hidden_size (:obj:`int`): The ``hidden_size`` of the MLP connected to ``StochasticDuelingHead``.
+ - action_shape (:obj:`int`): The number of continuous action shape, usually integer value.
+ - layer_num (:obj:`int`): The number of default layers used in the network to compute action and value \
+ output.
+ - a_layer_num (:obj:`int`): The number of layers used in the network to compute action output. Default is \
+ ``layer_num``.
+ - v_layer_num (:obj:`int`): The number of layers used in the network to compute value output. Default is \
+ ``layer_num``.
+ - activation (:obj:`nn.Module`): The type of activation function to use in MLP. \
+ Default is ``nn.ReLU()``.
+ - norm_type (:obj:`str`): The type of normalization to use. See ``ding.torch_utils.network.fc_block`` \
+ for more details. Default ``None``.
+ - noise (:obj:`bool`): Whether to use ``NoiseLinearLayer`` as ``layer_fn`` in Q networks' MLP. \
+ Default ``False``.
+ - last_tanh (:obj:`bool`): If ``True``, apply ``tanh`` to the sampled actions. Default ``True``.
+ """
+ super(StochasticDuelingHead, self).__init__()
+ if a_layer_num is None:
+ a_layer_num = layer_num
+ if v_layer_num is None:
+ v_layer_num = layer_num
+ layer = NoiseLinearLayer if noise else nn.Linear
+ block = noise_block if noise else fc_block
+ self.A = nn.Sequential(
+ MLP(
+ hidden_size + action_shape,
+ hidden_size,
+ hidden_size,
+ a_layer_num,
+ layer_fn=layer,
+ activation=activation,
+ norm_type=norm_type
+ ), block(hidden_size, 1)
+ )
+ self.V = nn.Sequential(
+ MLP(
+ hidden_size,
+ hidden_size,
+ hidden_size,
+ v_layer_num,
+ layer_fn=layer,
+ activation=activation,
+ norm_type=norm_type
+ ), block(hidden_size, 1)
+ )
+ if last_tanh:
+ self.tanh = nn.Tanh()
+ else:
+ self.tanh = None
+
+ def forward(
+ self,
+ s: torch.Tensor,
+ a: torch.Tensor,
+ mu: torch.Tensor,
+ sigma: torch.Tensor,
+ sample_size: int = 10,
+ ) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ Use encoded embedding tensor to run MLP with ``StochasticDuelingHead`` and return the prediction dictionary.
+ Arguments:
+ - s (:obj:`torch.Tensor`): Tensor containing input embedding.
+ - a (:obj:`torch.Tensor`): The original continuous behaviour action.
+ - mu (:obj:`torch.Tensor`): The ``mu`` gaussian reparameterization output of actor head at current \
+ timestep.
+ - sigma (:obj:`torch.Tensor`): The ``sigma`` gaussian reparameterization output of actor head at \
+ current timestep.
+ - sample_size (:obj:`int`): The number of samples for continuous action when computing the Q value.
+ Returns:
+ - outputs (:obj:`Dict`): Dict containing keywords \
+ ``q_value`` (:obj:`torch.Tensor`) and ``v_value`` (:obj:`torch.Tensor`).
+ Shapes:
+ - s: :math:`(B, N)`, where ``B = batch_size`` and ``N = hidden_size``.
+ - a: :math:`(B, A)`, where ``A = action_size``.
+ - mu: :math:`(B, A)`.
+ - sigma: :math:`(B, A)`.
+ - q_value: :math:`(B, 1)`.
+ - v_value: :math:`(B, 1)`.
+ Examples:
+ >>> head = StochasticDuelingHead(64, 64)
+ >>> inputs = torch.randn(4, 64)
+ >>> a = torch.randn(4, 64)
+ >>> mu = torch.randn(4, 64)
+ >>> sigma = torch.ones(4, 64)
+ >>> outputs = head(inputs, a, mu, sigma)
+ >>> assert isinstance(outputs, dict)
+ >>> assert outputs['q_value'].shape == torch.Size([4, 1])
+ >>> assert outputs['v_value'].shape == torch.Size([4, 1])
+ """
+
+ batch_size = s.shape[0] # batch_size or batch_size * T
+ hidden_size = s.shape[1]
+ action_size = a.shape[1]
+ state_cat_action = torch.cat((s, a), dim=1) # size (B, action_size + state_size)
+ a_value = self.A(state_cat_action) # size (B, 1)
+ v_value = self.V(s) # size (B, 1)
+ # size (B, sample_size, hidden_size)
+ expand_s = (torch.unsqueeze(s, 1)).expand((batch_size, sample_size, hidden_size))
+
+ # use rsample so that gradients can be back-propagated through the sampled actions
+ dist = Independent(Normal(mu, sigma), 1)
+ action_sample = dist.rsample(sample_shape=(sample_size, ))
+ if self.tanh:
+ action_sample = self.tanh(action_sample)
+ # (sample_size, B, action_size)->(B, sample_size, action_size)
+ action_sample = action_sample.permute(1, 0, 2)
+
+ # size (B, sample_size, action_size + hidden_size)
+ state_cat_action_sample = torch.cat((expand_s, action_sample), dim=-1)
+ a_val_sample = self.A(state_cat_action_sample) # size (B, sample_size, 1)
+ q_value = v_value + a_value - a_val_sample.mean(dim=1) # size (B, 1)
+
+ return {'q_value': q_value, 'v_value': v_value}
+
+
+class RegressionHead(nn.Module):
+ """
+ Overview:
+ The ``RegressionHead`` is used to regress continuous variables.
+ This module is used for generating Q-value (DDPG critic) of continuous actions, \
+ or state value (A2C/PPO), or directly predicting continuous action (DDPG actor).
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int,
+ layer_num: int = 2,
+ final_tanh: Optional[bool] = False,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ hidden_size: Optional[int] = None,
+ ) -> None:
+ """
+ Overview:
+ Init the ``RegressionHead`` layers according to the provided arguments.
+ Arguments:
+ - input_size (:obj:`int`): The ``input_size`` of the MLP connected to ``RegressionHead``.
+ - output_size (:obj:`int`): The number of outputs.
+ - layer_num (:obj:`int`): The number of layers used in the network to compute Q value output.
+ - final_tanh (:obj:`bool`): If ``True`` apply ``tanh`` to output. Default ``False``.
+ - activation (:obj:`nn.Module`): The type of activation function to use in MLP. \
+ Default is ``nn.ReLU()``.
+ - norm_type (:obj:`str`): The type of normalization to use. See ``ding.torch_utils.network.fc_block`` \
+ for more details. Default ``None``.
+ - hidden_size (:obj:`int`): The ``hidden_size`` of the MLP. If ``None``, it is set to ``input_size``. \
+ Default ``None``.
+ """
+ super(RegressionHead, self).__init__()
+ if hidden_size is None:
+ hidden_size = input_size
+ self.main = MLP(input_size, hidden_size, hidden_size, layer_num, activation=activation, norm_type=norm_type)
+ self.last = nn.Linear(hidden_size, output_size) # for convenience of special initialization
+ self.final_tanh = final_tanh
+ if self.final_tanh:
+ self.tanh = nn.Tanh()
+
+ def forward(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+ Use encoded embedding tensor to run MLP with ``RegressionHead`` and return the prediction dictionary.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Tensor containing input embedding.
+ Returns:
+ - outputs (:obj:`Dict`): Dict containing keyword ``pred`` (:obj:`torch.Tensor`).
+ Shapes:
+ - x: :math:`(B, N)`, where ``B = batch_size`` and ``N = hidden_size``.
+ - pred: :math:`(B, M)`, where ``M = output_size``.
+ Examples:
+ >>> head = RegressionHead(64, 64)
+ >>> inputs = torch.randn(4, 64)
+ >>> outputs = head(inputs)
+ >>> assert isinstance(outputs, dict)
+ >>> assert outputs['pred'].shape == torch.Size([4, 64])
+ """
+ x = self.main(x)
+ x = self.last(x)
+ if self.final_tanh:
+ x = self.tanh(x)
+ if x.shape[-1] == 1 and len(x.shape) > 1:
+ x = x.squeeze(-1)
+ return {'pred': x}
+
+
+class ReparameterizationHead(nn.Module):
+ """
+ Overview:
+ The ``ReparameterizationHead`` is used to generate Gaussian distribution of continuous variable, \
+ which is parameterized by ``mu`` and ``sigma``.
+ This module is often used in stochastic policies, such as PPO and SAC.
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+ # The "happo" type here is to align with the sigma initialization method of the network in the original HAPPO \
+ # paper. The code here needs to be optimized later.
+ default_sigma_type = ['fixed', 'independent', 'conditioned', 'happo']
+ default_bound_type = ['tanh', None]
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int,
+ layer_num: int = 2,
+ sigma_type: Optional[str] = None,
+ fixed_sigma_value: Optional[float] = 1.0,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ bound_type: Optional[str] = None,
+ hidden_size: Optional[int] = None
+ ) -> None:
+ """
+ Overview:
+ Init the ``ReparameterizationHead`` layers according to the provided arguments.
+ Arguments:
+ - input_size (:obj:`int`): The ``input_size`` of the MLP connected to ``ReparameterizationHead``.
+ - output_size (:obj:`int`): The number of outputs.
+ - layer_num (:obj:`int`): The number of layers used in the network to compute Q value output.
+ - sigma_type (:obj:`str`): Sigma type used. Choose among \
+ ``['fixed', 'independent', 'conditioned', 'happo']``. Default is ``None``; a concrete type must be specified.
+ - fixed_sigma_value (:obj:`float`): When choosing ``fixed`` type, the tensor ``output['sigma']`` \
+ is filled with this input value. Default is ``1.0``.
+ - activation (:obj:`nn.Module`): The type of activation function to use in MLP. \
+ Default is ``nn.ReLU()``.
+ - norm_type (:obj:`str`): The type of normalization to use. See ``ding.torch_utils.network.fc_block`` \
+ for more details. Default ``None``.
+ - bound_type (:obj:`str`): Bound type to apply to output ``mu``. Choose among ``['tanh', None]``. \
+ Default is ``None``.
+ - hidden_size (:obj:`int`): The ``hidden_size`` of the MLP. If ``None``, it is set to ``input_size``. \
+ Default ``None``.
+ """
+ super(ReparameterizationHead, self).__init__()
+ if hidden_size is None:
+ hidden_size = input_size
+ self.sigma_type = sigma_type
+ assert sigma_type in self.default_sigma_type, "Please indicate sigma_type as one of {}".format(
+ self.default_sigma_type
+ )
+ self.bound_type = bound_type
+ assert bound_type in self.default_bound_type, "Please indicate bound_type as one of {}".format(
+ self.default_bound_type
+ )
+ self.main = MLP(input_size, hidden_size, hidden_size, layer_num, activation=activation, norm_type=norm_type)
+ self.mu = nn.Linear(hidden_size, output_size)
+ if self.sigma_type == 'fixed':
+ self.sigma = torch.full((1, output_size), fixed_sigma_value)
+ elif self.sigma_type == 'independent': # independent parameter
+ self.log_sigma_param = nn.Parameter(torch.zeros(1, output_size))
+ elif self.sigma_type == 'conditioned':
+ self.log_sigma_layer = nn.Linear(hidden_size, output_size)
+ elif self.sigma_type == 'happo':
+ self.sigma_x_coef = 1.
+ self.sigma_y_coef = 0.5
+ # This parameter (x_coef, y_coef) refers to the HAPPO paper http://arxiv.org/abs/2109.11251.
+ self.log_sigma_param = nn.Parameter(torch.ones(1, output_size) * self.sigma_x_coef)
+
+ def forward(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+ Use encoded embedding tensor to run MLP with ``ReparameterizationHead`` and return the prediction \
+ dictionary.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Tensor containing input embedding.
+ Returns:
+ - outputs (:obj:`Dict`): Dict containing keywords ``mu`` (:obj:`torch.Tensor`) and ``sigma`` \
+ (:obj:`torch.Tensor`).
+ Shapes:
+ - x: :math:`(B, N)`, where ``B = batch_size`` and ``N = hidden_size``.
+ - mu: :math:`(B, M)`, where ``M = output_size``.
+ - sigma: :math:`(B, M)`.
+ Examples:
+ >>> head = ReparameterizationHead(64, 64, sigma_type='fixed')
+ >>> inputs = torch.randn(4, 64)
+ >>> outputs = head(inputs)
+ >>> assert isinstance(outputs, dict)
+ >>> assert outputs['mu'].shape == torch.Size([4, 64])
+ >>> assert outputs['sigma'].shape == torch.Size([4, 64])
+ """
+ x = self.main(x)
+ mu = self.mu(x)
+ if self.bound_type == 'tanh':
+ mu = torch.tanh(mu)
+ if self.sigma_type == 'fixed':
+ sigma = self.sigma.to(mu.device) + torch.zeros_like(mu) # addition aims to broadcast shape
+ elif self.sigma_type == 'independent':
+ log_sigma = self.log_sigma_param + torch.zeros_like(mu) # addition aims to broadcast shape
+ sigma = torch.exp(log_sigma)
+ elif self.sigma_type == 'conditioned':
+ log_sigma = self.log_sigma_layer(x)
+ sigma = torch.exp(torch.clamp(log_sigma, -20, 2))
+ elif self.sigma_type == 'happo':
+ log_sigma = self.log_sigma_param + torch.zeros_like(mu)
+ sigma = torch.sigmoid(log_sigma / self.sigma_x_coef) * self.sigma_y_coef
+ return {'mu': mu, 'sigma': sigma}
+
+
+class PopArtVHead(nn.Module):
+ """
+ Overview:
+ The ``PopArtVHead`` is used to generate adaptive normalized state value. More information can be found in \
+ paper Multi-task Deep Reinforcement Learning with PopArt. \
+ https://arxiv.org/abs/1809.04474 \
+ This module is used in PPO or IMPALA.
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(
+ self,
+ hidden_size: int,
+ output_size: int,
+ layer_num: int = 1,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ ) -> None:
+ """
+ Overview:
+ Init the ``PopArtVHead`` layers according to the provided arguments.
+ Arguments:
+ - hidden_size (:obj:`int`): The ``hidden_size`` of the MLP connected to ``PopArtVHead``.
+ - output_size (:obj:`int`): The number of outputs.
+ - layer_num (:obj:`int`): The number of layers used in the network to compute Q value output.
+ - activation (:obj:`nn.Module`): The type of activation function to use in MLP. \
+ Default is ``nn.ReLU()``.
+ - norm_type (:obj:`str`): The type of normalization to use. See ``ding.torch_utils.network.fc_block`` \
+ for more details. Default ``None``.
+ """
+ super(PopArtVHead, self).__init__()
+ self.popart = PopArt(hidden_size, output_size)
+ self.Q = nn.Sequential(
+ MLP(
+ hidden_size,
+ hidden_size,
+ hidden_size,
+ layer_num,
+ layer_fn=nn.Linear,
+ activation=activation,
+ norm_type=norm_type
+ ), self.popart
+ )
+
+ def forward(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+ Use encoded embedding tensor to run MLP with ``PopArtVHead`` and return the normalized prediction and \
+ the unnormalized prediction dictionary.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Tensor containing input embedding.
+ Returns:
+ - outputs (:obj:`Dict`): Dict containing keyword ``pred`` (:obj:`torch.Tensor`) \
+ and ``unnormalized_pred`` (:obj:`torch.Tensor`).
+ Shapes:
+ - x: :math:`(B, N)`, where ``B = batch_size`` and ``N = hidden_size``.
+ - pred & unnormalized_pred: :math:`(B, M)`, where ``M = output_size``.
+ Examples:
+ >>> head = PopArtVHead(64, 64)
+ >>> inputs = torch.randn(4, 64)
+ >>> outputs = head(inputs)
+ >>> assert isinstance(outputs, dict) and outputs['pred'].shape == torch.Size([4, 64]) and \
+ outputs['unnormalized_pred'].shape == torch.Size([4, 64])
+ """
+ x = self.Q(x)
+ return x
+
+
+class AttentionPolicyHead(nn.Module):
+ """
+ Overview:
+ Cross-attention-type discrete action policy head, which is often used in variable-sized discrete action spaces.
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(self) -> None:
+ super(AttentionPolicyHead, self).__init__()
+
+ def forward(self, key: torch.Tensor, query: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Use attention-like mechanism to combine key and query tensor to output discrete action logit.
+ Arguments:
+ - key (:obj:`torch.Tensor`): Tensor containing key embedding.
+ - query (:obj:`torch.Tensor`): Tensor containing query embedding.
+ Returns:
+ - logit (:obj:`torch.Tensor`): Tensor containing output discrete action logit.
+ Shapes:
+ - key: :math:`(B, N, K)`, where ``B = batch_size``, ``N = possible discrete action choices`` and \
+ ``K = hidden_size``.
+ - query: :math:`(B, K)`.
+ - logit: :math:`(B, N)`.
+ Examples:
+ >>> head = AttentionPolicyHead()
+ >>> key = torch.randn(4, 5, 64)
+ >>> query = torch.randn(4, 64)
+ >>> logit = head(key, query)
+ >>> assert logit.shape == torch.Size([4, 5])
+
+ .. note::
+ In this head, we assume that the ``key`` and ``query`` tensors are both normalized.
+ """
+ if len(query.shape) == 2 and len(key.shape) == 3:
+ query = query.unsqueeze(1)
+ logit = (key * query).sum(-1)
+ return logit
+
+
+class MultiHead(nn.Module):
+ """
+ Overview:
+ The ``MultiHead`` is used to generate multiple similar results.
+ For example, we can combine ``DistributionHead`` and ``MultiHead`` to generate multi-discrete action space logits.
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(self, head_cls: type, hidden_size: int, output_size_list: SequenceType, **head_kwargs) -> None:
+ """
+ Overview:
+ Init the ``MultiHead`` layers according to the provided arguments.
+ Arguments:
+ - head_cls (:obj:`type`): The class of head, chosen among [``DuelingHead``, ``DistributionHead``, \
+ ``QuantileHead``, ...].
+ - hidden_size (:obj:`int`): The ``hidden_size`` of the MLP connected to the ``Head``.
+ - output_size_list (:obj:`SequenceType`): Sequence of ``output_size`` for multi discrete action, \
+ e.g. ``[2, 3, 5]``.
+ - head_kwargs (:obj:`dict`): Dict containing class-specific arguments.
+ """
+ super(MultiHead, self).__init__()
+ self.pred = nn.ModuleList()
+ for size in output_size_list:
+ self.pred.append(head_cls(hidden_size, size, **head_kwargs))
+
+ def forward(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+ Use encoded embedding tensor to run MLP with ``MultiHead`` and return the prediction dictionary.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Tensor containing input embedding.
+ Returns:
+ - outputs (:obj:`Dict`): Dict containing the keyword ``logit`` (:obj:`List[torch.Tensor]`), \
+ where the logit of the i-th output is accessed at ``['logit'][i]``.
+ Shapes:
+ - x: :math:`(B, N)`, where ``B = batch_size`` and ``N = hidden_size``.
+ - logit: :math:`(B, Mi)`, where ``Mi = output_size`` corresponding to output ``i``.
+ Examples:
+ >>> head = MultiHead(DuelingHead, 64, [2, 3, 5], v_layer_num=2)
+ >>> inputs = torch.randn(4, 64)
+ >>> outputs = head(inputs)
+ >>> assert isinstance(outputs, dict)
+ >>> # output_size_list is [2, 3, 5] as set
+ >>> # Therefore each dim of logit is as follows
+ >>> outputs['logit'][0].shape
+ >>> torch.Size([4, 2])
+ >>> outputs['logit'][1].shape
+ >>> torch.Size([4, 3])
+ >>> outputs['logit'][2].shape
+ >>> torch.Size([4, 5])
+ """
+ return lists_to_dicts([m(x) for m in self.pred])
+
+
+class EnsembleHead(nn.Module):
+ """
+ Overview:
+ The ``EnsembleHead`` is used to generate Q-value for Q-ensemble in model-based RL algorithms.
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int,
+ hidden_size: int,
+ layer_num: int,
+ ensemble_num: int,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None
+ ) -> None:
+ super(EnsembleHead, self).__init__()
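+ # Each ensemble member is realized as one group of a 1x1 grouped Conv1d: with ``groups=ensemble_num``,
+ # the i-th slice of channels only sees the i-th member's input, so one Sequential evaluates all
+ # members in parallel and independently.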
+ d = input_size
+ layers = []
+ for _ in range(layer_num):
+ layers.append(
+ conv1d_block(
+ d * ensemble_num,
+ hidden_size * ensemble_num,
+ kernel_size=1,
+ stride=1,
+ groups=ensemble_num,
+ activation=activation,
+ norm_type=norm_type
+ )
+ )
+ d = hidden_size
+
+ # Adding an activation to the last layer makes training fail, so the final block has no activation.
+ layers.append(
+ conv1d_block(
+ hidden_size * ensemble_num,
+ output_size * ensemble_num,
+ kernel_size=1,
+ stride=1,
+ groups=ensemble_num,
+ activation=None,
+ norm_type=None
+ )
+ )
+ self.pred = nn.Sequential(*layers)
+
+ def forward(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+ Use encoded embedding tensor to run MLP with ``EnsembleHead`` and return the prediction dictionary.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Tensor containing input embedding.
+ Returns:
+ - outputs (:obj:`Dict`): Dict containing keyword ``pred`` (:obj:`torch.Tensor`).
+ Shapes:
+ - x: :math:`(B, N * ensemble_num, 1)`, where ``B = batch_size`` and ``N = input_size``.
+ - pred: :math:`(B, M * ensemble_num)`, where ``M = output_size``.
+ Examples:
+ >>> head = EnsembleHead(64, 64, 128, 2, 10)
+ >>> inputs = torch.randn(4, 64 * 10, 1)
+ >>> outputs = head(inputs)
+ >>> assert isinstance(outputs, dict)
+ >>> assert outputs['pred'].shape == torch.Size([4, 64 * 10])
+ """
+ x = self.pred(x).squeeze(-1)
+ return {'pred': x}
+
+
+def independent_normal_dist(logits: Union[List, Dict]) -> torch.distributions.Distribution:
+ """
+ Overview:
+ Convert logits of different types into an independent normal distribution.
+ Arguments:
+ - logits (:obj:`Union[List, Dict]`): The logits to be converted.
+ Returns:
+ - dist (:obj:`torch.distributions.Distribution`): The converted normal distribution.
+ Examples:
+ >>> logits = [torch.randn(4, 5), torch.ones(4, 5)]
+ >>> dist = independent_normal_dist(logits)
+ >>> assert isinstance(dist, torch.distributions.Independent)
+ >>> assert isinstance(dist.base_dist, torch.distributions.Normal)
+ >>> assert dist.base_dist.loc.shape == torch.Size([4, 5])
+ >>> assert dist.base_dist.scale.shape == torch.Size([4, 5])
+ Raises:
+ - TypeError: If the type of logits is not ``list`` or ``dict``.
+ """
+ if isinstance(logits, (list, tuple)):
+ return Independent(Normal(*logits), 1)
+ elif isinstance(logits, dict):
+ return Independent(Normal(logits['mu'], logits['sigma']), 1)
+ else:
+ raise TypeError("invalid logits type: {}".format(type(logits)))
+
+
+head_cls_map = {
+ # discrete
+ 'discrete': DiscreteHead,
+ 'dueling': DuelingHead,
+ 'distribution': DistributionHead,
+ 'rainbow': RainbowHead,
+ 'qrdqn': QRDQNHead,
+ 'quantile': QuantileHead,
+ 'fqf': FQFHead,
+ 'branch': BranchingHead,
+ 'attention_policy': AttentionPolicyHead,
+ # continuous
+ 'regression': RegressionHead,
+ 'reparameterization': ReparameterizationHead,
+ 'popart': PopArtVHead,
+ 'sdn': StochasticDuelingHead,
+ # multi
+ 'multi': MultiHead,
+ 'ensemble': EnsembleHead,
+}
diff --git a/DI-engine/ding/model/common/tests/test_encoder.py b/DI-engine/ding/model/common/tests/test_encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd8a5bf752aa702faf59bb6f7e7fb7df21c3c72e
--- /dev/null
+++ b/DI-engine/ding/model/common/tests/test_encoder.py
@@ -0,0 +1,63 @@
+import torch
+import numpy as np
+import pytest
+
+from ding.model import ConvEncoder, FCEncoder, IMPALAConvEncoder
+from ding.torch_utils import is_differentiable
+
+B = 4
+C, H, W = 3, 128, 128
+
+
+@pytest.mark.unittest
+class TestEncoder:
+
+ def output_check(self, model, outputs):
+ loss = outputs.sum()
+ is_differentiable(loss, model)
+
+ def test_conv_encoder(self):
+ inputs = torch.randn(B, C, H, W)
+ model = ConvEncoder((C, H, W), hidden_size_list=[32, 48, 64, 64, 128], activation=torch.nn.Tanh())
+ print(model)
+ outputs = model(inputs)
+ self.output_check(model, outputs)
+ assert outputs.shape == (B, 128)
+
+ def test_dreamer_conv_encoder(self):
+ inputs = torch.randn(B, C, H, W)
+ model = ConvEncoder(
+ (C, H, W),
+ hidden_size_list=[32, 64, 128, 256, 128],
+ activation=torch.nn.SiLU(),
+ kernel_size=[4, 4, 4, 4],
+ layer_norm=True
+ )
+ print(model)
+ outputs = model(inputs)
+ self.output_check(model, outputs)
+ assert outputs.shape == (B, 128)
+
+ def test_fc_encoder(self):
+ inputs = torch.randn(B, 32)
+ hidden_size_list = [128 for _ in range(3)]
+ model = FCEncoder(32, hidden_size_list, res_block=True, activation=torch.nn.Tanh())
+ print(model)
+ outputs = model(inputs)
+ self.output_check(model, outputs)
+ assert outputs.shape == (B, hidden_size_list[-1])
+
+ hidden_size_list = [64, 128, 256]
+ model = FCEncoder(32, hidden_size_list, res_block=False, activation=torch.nn.Tanh())
+ print(model)
+ outputs = model(inputs)
+ self.output_check(model, outputs)
+ assert outputs.shape == (B, hidden_size_list[-1])
+
+ def test_impalaconv_encoder(self):
+ inputs = torch.randn(B, 3, 64, 64)
+ model = IMPALAConvEncoder(obs_shape=(3, 64, 64))
+ print(model)
+ outputs = model(inputs)
+ self.output_check(model, outputs)
+ assert outputs.shape == (B, 256)
diff --git a/DI-engine/ding/model/common/tests/test_head.py b/DI-engine/ding/model/common/tests/test_head.py
new file mode 100644
index 0000000000000000000000000000000000000000..044ff75cbd897e763c30ddea137395de6c5bc2de
--- /dev/null
+++ b/DI-engine/ding/model/common/tests/test_head.py
@@ -0,0 +1,93 @@
+import torch
+import numpy as np
+import pytest
+
+from ding.model.common.head import DuelingHead, ReparameterizationHead, MultiHead, StochasticDuelingHead, EnsembleHead
+from ding.torch_utils import is_differentiable
+
+B = 4
+T = 6
+embedding_dim = 64
+action_shape = 12
+
+
+@pytest.mark.unittest
+class TestHead:
+
+ def output_check(self, model, outputs):
+ if isinstance(outputs, torch.Tensor):
+ loss = outputs.sum()
+ elif isinstance(outputs, list):
+ loss = sum([t.sum() for t in outputs])
+ elif isinstance(outputs, dict):
+ loss = sum([v.sum() for v in outputs.values()])
+ is_differentiable(loss, model)
+
+ def test_dueling(self):
+ inputs = torch.randn(B, embedding_dim)
+ model = DuelingHead(embedding_dim, action_shape, 3, 3)
+ outputs = model(inputs)['logit']
+ self.output_check(model, outputs)
+ assert outputs.shape == (B, action_shape)
+
+ @pytest.mark.parametrize('action_shape', [1, 8])
+ def test_reparameterization(self, action_shape):
+ inputs = torch.randn(B, embedding_dim)
+ for sigma_type in ['fixed', 'independent', 'conditioned']:
+ if sigma_type == 'fixed':
+ model = ReparameterizationHead(
+ embedding_dim, action_shape, sigma_type=sigma_type, fixed_sigma_value=0.5
+ )
+ outputs = model(inputs)
+ mu, sigma = outputs['mu'], outputs['sigma']
+ assert mu.shape == (B, action_shape) and sigma.shape == (B, action_shape)
+ assert sigma.eq(torch.full((B, action_shape), 0.5)).all()
+ self.output_check(model, outputs)
+ elif sigma_type == 'independent':
+ model = ReparameterizationHead(embedding_dim, action_shape, sigma_type=sigma_type)
+ outputs = model(inputs)
+ mu, sigma = outputs['mu'], outputs['sigma']
+ assert mu.shape == (B, action_shape) and sigma.shape == (B, action_shape)
+ self.output_check(model, outputs)
+ assert model.log_sigma_param.grad is not None
+ elif sigma_type == 'conditioned':
+ model = ReparameterizationHead(embedding_dim, action_shape, sigma_type=sigma_type)
+ outputs = model(inputs)
+ mu, sigma = outputs['mu'], outputs['sigma']
+ assert mu.shape == (B, action_shape) and sigma.shape == (B, action_shape)
+ self.output_check(model, outputs)
+
+ def test_multi_head(self):
+ output_size_list = [2, 3, 7]
+ head = MultiHead(DuelingHead, embedding_dim, output_size_list, activation=torch.nn.Tanh())
+ print(head)
+ inputs = torch.randn(B, embedding_dim)
+ outputs = head(inputs)
+ assert isinstance(outputs, dict)
+ self.output_check(head, outputs['logit'])
+ for i, d in enumerate(output_size_list):
+ assert outputs['logit'][i].shape == (B, d)
+
+ @pytest.mark.tmp
+ def test_stochastic_dueling(self):
+ obs = torch.randn(B, embedding_dim)
+ behaviour_action = torch.randn(B, action_shape).clamp(-1, 1)
+ mu = torch.randn(B, action_shape).requires_grad_(True)
+ sigma = torch.rand(B, action_shape).requires_grad_(True)
+ model = StochasticDuelingHead(embedding_dim, action_shape, 3, 3)
+
+ assert mu.grad is None and sigma.grad is None
+ outputs = model(obs, behaviour_action, mu, sigma)
+ self.output_check(model, outputs['q_value'])
+ assert isinstance(mu.grad, torch.Tensor)
+ print(mu.grad)
+ assert isinstance(sigma.grad, torch.Tensor)
+ assert outputs['q_value'].shape == (B, 1)
+ assert outputs['v_value'].shape == (B, 1)
+
+ def test_ensemble(self):
+ inputs = torch.randn(B, embedding_dim * 3, 1)
+ model = EnsembleHead(embedding_dim, action_shape, 3, 3, 3)
+ outputs = model(inputs)['pred']
+ self.output_check(model, outputs)
+ assert outputs.shape == (B, action_shape * 3)
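+
+ # A possible additional check (sketch only, not part of the original suite): RainbowHead's
+ # ``distribution`` output should be a valid probability distribution over the atom dimension.
+ # It assumes RainbowHead is importable from ding.model.common.head.
+ # from ding.model.common.head import RainbowHead
+ # model = RainbowHead(embedding_dim, action_shape)
+ # dist = model(torch.randn(B, embedding_dim))['distribution']
+ # assert dist.shape == (B, action_shape, 51)
+ # assert torch.allclose(dist.sum(-1), torch.ones(B, action_shape), atol=1e-4)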
diff --git a/DI-engine/ding/model/common/utils.py b/DI-engine/ding/model/common/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..f74a17996284b8d5a8b11477707e56235368b358
--- /dev/null
+++ b/DI-engine/ding/model/common/utils.py
@@ -0,0 +1,31 @@
+import copy
+import torch
+from easydict import EasyDict
+from ding.utils import import_module, MODEL_REGISTRY
+
+
+def create_model(cfg: EasyDict) -> torch.nn.Module:
+ """
+ Overview:
+ Create a neural network model according to the given EasyDict-type ``cfg``.
+ Arguments:
+ - cfg: (:obj:`EasyDict`): User's model config. The key ``import_name`` is \
+ used to import modules, and the key ``type`` is used to indicate the model.
+ Returns:
+ - (:obj:`torch.nn.Module`): The created neural network model.
+ Examples:
+ >>> cfg = EasyDict({
+ >>> 'import_names': ['ding.model.template.q_learning'],
+ >>> 'type': 'dqn',
+ >>> 'obs_shape': 4,
+ >>> 'action_shape': 2,
+ >>> })
+ >>> model = create_model(cfg)
+
+ .. tip::
+ This method will not modify the passed-in ``cfg``; it deep-copies ``cfg`` and modifies the copy.
+ """
+ cfg = copy.deepcopy(cfg)
+ import_module(cfg.pop('import_names', []))
+ # here we must use the pop operation to ensure compatibility
+ return MODEL_REGISTRY.build(cfg.pop("type"), **cfg)
diff --git a/DI-engine/ding/model/template/__init__.py b/DI-engine/ding/model/template/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..c9dc17791c647832d95b70dcc9399827acbb6c5e
--- /dev/null
+++ b/DI-engine/ding/model/template/__init__.py
@@ -0,0 +1,30 @@
+# general
+from .q_learning import DQN, RainbowDQN, QRDQN, IQN, FQF, DRQN, C51DQN, BDQ, GTrXLDQN
+from .qac import DiscreteQAC, ContinuousQAC
+from .pdqn import PDQN
+from .vac import VAC, DREAMERVAC
+from .bc import DiscreteBC, ContinuousBC
+from .language_transformer import LanguageTransformer
+# algorithm-specific
+from .pg import PG
+from .ppg import PPG
+from .qmix import Mixer, QMix
+from .collaq import CollaQ
+from .wqmix import WQMix
+from .coma import COMA
+from .atoc import ATOC
+from .sqn import SQN
+from .acer import ACER
+from .qtran import QTran
+from .mavac import MAVAC
+from .ngu import NGU
+from .qac_dist import QACDIST
+from .maqac import DiscreteMAQAC, ContinuousMAQAC
+from .madqn import MADQN
+from .vae import VanillaVAE
+from .decision_transformer import DecisionTransformer
+from .procedure_cloning import ProcedureCloningMCTS, ProcedureCloningBFS
+from .bcq import BCQ
+from .edac import EDAC
+from .ebm import EBM, AutoregressiveEBM
+from .havac import HAVAC
diff --git a/DI-engine/ding/model/template/acer.py b/DI-engine/ding/model/template/acer.py
new file mode 100644
index 0000000000000000000000000000000000000000..44bb386cbad57ab065532ca0b2a1c1b29e38de77
--- /dev/null
+++ b/DI-engine/ding/model/template/acer.py
@@ -0,0 +1,155 @@
+from typing import Union, Dict, Optional
+import torch
+import torch.nn as nn
+
+from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
+from ..common import ReparameterizationHead, RegressionHead, DiscreteHead, MultiHead, \
+ FCEncoder, ConvEncoder
+
+
+@MODEL_REGISTRY.register('acer')
+class ACER(nn.Module):
+ """
+ Overview:
+ The neural network model of the ACER (Actor-Critic with Experience Replay) algorithm, proposed in
+ "Sample Efficient Actor-Critic with Experience Replay".
+ https://arxiv.org/abs/1611.01224
+ Interfaces:
+ ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
+ """
+ mode = ['compute_actor', 'compute_critic']
+
+ def __init__(
+ self,
+ obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType],
+ encoder_hidden_size_list: SequenceType = [128, 128, 64],
+ actor_head_hidden_size: int = 64,
+ actor_head_layer_num: int = 1,
+ critic_head_hidden_size: int = 64,
+ critic_head_layer_num: int = 1,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ ) -> None:
+ """
+ Overview:
+ Init the ACER Model according to arguments.
+ Arguments:
+ - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape.
+ - action_shape (:obj:`Union[int, SequenceType]`): Action space shape.
+ - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to the actor \
+ and critic encoders.
+ - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor-nn's ``Head``.
+ - actor_head_layer_num (:obj:`int`):
+ The num of layers used in the network to compute Q value output for actor's nn.
+ - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic-nn's ``Head``.
+ - critic_head_layer_num (:obj:`int`):
+ The num of layers used in the network to compute Q value output for critic's nn.
+ - activation (:obj:`Optional[nn.Module]`):
+ The type of activation function to use in ``MLP`` after each ``layer_fn``. Default is ``nn.ReLU()``.
+ - norm_type (:obj:`Optional[str]`):
+ The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details.
+ """
+ super(ACER, self).__init__()
+ obs_shape: int = squeeze(obs_shape)
+ action_shape: int = squeeze(action_shape)
+ if isinstance(obs_shape, int) or len(obs_shape) == 1:
+ encoder_cls = FCEncoder
+ elif len(obs_shape) == 3:
+ encoder_cls = ConvEncoder
+ else:
+ raise RuntimeError(
+ "not support obs_shape for pre-defined encoder: {}, please customize your own DQN".format(obs_shape)
+ )
+
+ self.actor_encoder = encoder_cls(
+ obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
+ )
+ self.critic_encoder = encoder_cls(
+ obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
+ )
+
+ self.critic_head = RegressionHead(
+ critic_head_hidden_size, action_shape, critic_head_layer_num, activation=activation, norm_type=norm_type
+ )
+ self.actor_head = DiscreteHead(
+ actor_head_hidden_size, action_shape, actor_head_layer_num, activation=activation, norm_type=norm_type
+ )
+ self.actor = [self.actor_encoder, self.actor_head]
+ self.critic = [self.critic_encoder, self.critic_head]
+ self.actor = nn.ModuleList(self.actor)
+ self.critic = nn.ModuleList(self.critic)
+
+ def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
+ """
+ Overview:
+ Use observation to predict output, dispatching to ``compute_actor`` or ``compute_critic`` \
+ according to ``mode``.
+ Arguments:
+ - mode (:obj:`str`): Name of the forward mode.
+ Returns:
+ - outputs (:obj:`Dict`): Outputs of network forward.
+ Shapes (Actor):
+ - obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``obs_shape``
+ - logit (:obj:`torch.FloatTensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape``
+ Shapes (Critic):
+ - inputs (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size and N1 corresponds to ``obs_shape``
+ - q_value (:obj:`torch.FloatTensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape``
+ """
+ assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
+ return getattr(self, mode)(inputs)
+
+ def compute_actor(self, inputs: torch.Tensor) -> Dict:
+ """
+ Overview:
+ Use the actor part (encoder and head) to predict the action logit from the observation.
+ Arguments:
+ - inputs (:obj:`torch.Tensor`): The input observation tensor.
+ Returns:
+ - outputs (:obj:`Dict`): Output dict of the actor's forward pass (encoder and head).
+ ReturnsKeys:
+ - logit (:obj:`torch.FloatTensor`): :math:`(B, N1)`, where B is batch size and N1 is ``action_shape``
+ Shapes:
+ - inputs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``obs_shape``
+ - logit (:obj:`torch.FloatTensor`): :math:`(B, N1)`, where B is batch size and N1 is ``action_shape``
+ Examples:
+ >>> # compute_actor mode
+ >>> model = ACER(64, 64)
+ >>> inputs = torch.randn(4, 64)
+ >>> actor_outputs = model(inputs, 'compute_actor')
+ >>> assert actor_outputs['logit'].shape == torch.Size([4, 64])
+ """
+ x = self.actor_encoder(inputs)
+ x = self.actor_head(x)
+
+ return x
+
+ def compute_critic(self, inputs: torch.Tensor) -> Dict:
+ """
+ Overview:
+ Use the critic part (encoder and head) to predict the Q value from the observation.
+ Arguments:
+ - inputs (:obj:`torch.Tensor`): The input observation tensor.
+ Returns:
+ - outputs (:obj:`Dict`): Q-value output.
+ ReturnKeys:
+ - q_value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``obs_shape``
+ - q_value (:obj:`torch.FloatTensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape``.
+ Examples:
+ >>> inputs = torch.randn(4, N)
+ >>> model = ACER(obs_shape=(N, ),action_shape=5)
+ >>> model(inputs, mode='compute_critic')['q_value']
+ """
+
+ obs = inputs
+ x = self.critic_encoder(obs)
+ x = self.critic_head(x)
+ return {"q_value": x['pred']}
diff --git a/DI-engine/ding/model/template/atoc.py b/DI-engine/ding/model/template/atoc.py
new file mode 100644
index 0000000000000000000000000000000000000000..a06f536aefd34b269442692486791075fc525c8f
--- /dev/null
+++ b/DI-engine/ding/model/template/atoc.py
@@ -0,0 +1,582 @@
+from typing import Union, Dict, Optional, Tuple
+
+import torch
+import torch.nn as nn
+
+from ding.utils import squeeze, MODEL_REGISTRY, SequenceType
+from ding.torch_utils import MLP
+from ding.model.common import RegressionHead
+
+
+class ATOCAttentionUnit(nn.Module):
+ """
+ Overview:
+ The attention unit of the ATOC network, implemented as a small MLP following the original paper.
+ Interface:
+ ``__init__``, ``forward``
+
+ .. note::
+ "ATOC paper: We use two-layer MLP to implement the attention unit but it is also can be realized by RNN."
+ """
+
+ def __init__(self, thought_size: int, embedding_size: int) -> None:
+ """
+ Overview:
+ Initialize the attention unit according to the size of input arguments.
+ Arguments:
+ - thought_size (:obj:`int`): the size of input thought
+ - embedding_size (:obj:`int`): the size of hidden layers
+ """
+ super(ATOCAttentionUnit, self).__init__()
+ self._thought_size = thought_size
+ self._hidden_size = embedding_size
+ self._output_size = 1
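+ # An MLP with ReLU hidden activations and a final sigmoid, so the output can be read as
+ # the probability of each agent being an initiator.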
+ self._act1 = nn.ReLU()
+ self._fc1 = nn.Linear(self._thought_size, self._hidden_size, bias=True)
+ self._fc2 = nn.Linear(self._hidden_size, self._hidden_size, bias=True)
+ self._fc3 = nn.Linear(self._hidden_size, self._output_size, bias=True)
+ self._act2 = nn.Sigmoid()
+
+ def forward(self, data: Union[Dict, torch.Tensor]) -> torch.Tensor:
+ """
+ Overview:
+ Take the thoughts of agents as input and output the probability of each agent being an initiator.
+ Arguments:
+ - data (:obj:`Union[Dict, torch.Tensor]`): the input thought tensor, or a dict containing it \
+ under the key ``thought``
+ Returns:
+ - ret (:obj:`torch.Tensor`): the output initiator probability
+ Shapes:
+ - data['thought']: :math:`(M, B, N)`, where M is the num of thoughts to integrate, \
+ B is batch_size and N is thought size
+ Examples:
+ >>> attention_unit = ATOCAttentionUnit(64, 64)
+ >>> thought = torch.randn(2, 3, 64)
+ >>> attention_unit(thought)
+ """
+ x = data
+ if isinstance(data, Dict):
+ x = data['thought']
+ x = self._fc1(x)
+ x = self._act1(x)
+ x = self._fc2(x)
+ x = self._act1(x)
+ x = self._fc3(x)
+ x = self._act2(x)
+ return x.squeeze(-1)
+
+
+class ATOCCommunicationNet(nn.Module):
+ """
+ Overview:
+ This ATOC communication net is a bidirectional LSTM, so it can integrate all the thoughts in the group.
+ Interface:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(self, thought_size: int) -> None:
+ """
+ Overview:
+ Initialize the communication network according to the size of input arguments.
+ Arguments:
+ - thought_size (:obj:`int`): the size of input thought
+
+ .. note::
+
+ The communication hidden size is half of the thought size because the LSTM is bidirectional.
+ """
+ super(ATOCCommunicationNet, self).__init__()
+ assert thought_size % 2 == 0
+ self._thought_size = thought_size
+ self._comm_hidden_size = thought_size // 2
+ self._bi_lstm = nn.LSTM(self._thought_size, self._comm_hidden_size, bidirectional=True)
+
+ def forward(self, data: Union[Dict, torch.Tensor]):
+ """
+ Overview:
+ The forward of ATOCCommunicationNet integrates thoughts in the group.
+ Arguments:
+ - data (:obj:`Union[Dict, torch.Tensor]`): the input thought tensor, or a dict containing it \
+ under the key ``thoughts``
+ Returns:
+ - out (:obj:`torch.Tensor`): the integrated thoughts
+ Shapes:
+ - data['thoughts']: :math:`(M, B, N)`, M is the num of thoughts to integrate,\
+ B is batch_size and N is thought size
+ Examples:
+ >>> comm_net = ATOCCommunicationNet(64)
+ >>> thoughts = torch.randn(2, 3, 64)
+ >>> comm_net(thoughts)
+ """
+ self._bi_lstm.flatten_parameters()
+ x = data
+ if isinstance(data, Dict):
+ x = data['thoughts']
+ out, _ = self._bi_lstm(x)
+ return out
+
+
+class ATOCActorNet(nn.Module):
+ """
+ Overview:
+ The actor network of ATOC.
+ Interface:
+ ``__init__``, ``forward``
+
+ .. note::
+ "ATOC paper: The neural networks use ReLU and batch normalization for some hidden layers."
+ """
+
+ def __init__(
+ self,
+ obs_shape: Union[Tuple, int],
+ thought_size: int,
+ action_shape: int,
+ n_agent: int,
+ communication: bool = True,
+ agent_per_group: int = 2,
+ initiator_threshold: float = 0.5,
+ attention_embedding_size: int = 64,
+ actor_1_embedding_size: Union[int, None] = None,
+ actor_2_embedding_size: Union[int, None] = None,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ ):
+ """
+ Overview:
+ Initialize the actor network of ATOC
+ Arguments:
+ - obs_shape (:obj:`Union[Tuple, int]`): the observation size
+ - thought_size (:obj:`int`): the size of thoughts
+ - action_shape (:obj:`int`): the action size
+ - n_agent (:obj:`int`): the num of agents
+ - communication (:obj:`bool`): whether to use the communication module, default set to True
+ - agent_per_group (:obj:`int`): the num of agents in each group
+ - initiator_threshold (:obj:`float`): the threshold of becoming an initiator, default set to 0.5
+ - attention_embedding_size (:obj:`int`): the embedding size of the attention unit, default set to 64
+ - actor_1_embedding_size (:obj:`Union[int, None]`): the embedding size of actor network part1, \
+ if None, then default set to thought size
+ - actor_2_embedding_size (:obj:`Union[int, None]`): the embedding size of actor network part2, \
+ if None, then default set to thought size
+ - activation (:obj:`Optional[nn.Module]`): the activation function in networks, default set to nn.ReLU()
+ - norm_type (:obj:`Optional[str]`): the type of normalization in networks, default set to None
+ """
+ super(ATOCActorNet, self).__init__()
+ # now only support obs_shape of shape (O_dim, )
+ self._obs_shape = squeeze(obs_shape)
+ self._thought_size = thought_size
+ self._act_shape = action_shape
+ self._n_agent = n_agent
+ self._communication = communication
+ self._agent_per_group = agent_per_group
+ self._initiator_threshold = initiator_threshold
+ if not actor_1_embedding_size:
+ actor_1_embedding_size = self._thought_size
+ if not actor_2_embedding_size:
+ actor_2_embedding_size = self._thought_size
+
+ # Actor Net(I)
+ self.actor_1 = MLP(
+ self._obs_shape,
+ actor_1_embedding_size,
+ self._thought_size,
+ layer_num=2,
+ activation=activation,
+ norm_type=norm_type
+ )
+
+ # Actor Net(II)
+ self.actor_2 = nn.Sequential(
+ nn.Linear(self._thought_size * 2, actor_2_embedding_size), activation,
+ RegressionHead(
+ actor_2_embedding_size, self._act_shape, 2, final_tanh=True, activation=activation, norm_type=norm_type
+ )
+ )
+
+ # Communication
+ if self._communication:
+ self.attention = ATOCAttentionUnit(self._thought_size, attention_embedding_size)
+ self.comm_net = ATOCCommunicationNet(self._thought_size)
+
+ def forward(self, obs: torch.Tensor) -> Dict:
+ """
+ Overview:
+ Take the input obs, and calculate the corresponding action, group, initiator_prob, thoughts, etc.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): the observation tensor of each agent
+ Returns:
+ - ret (:obj:`Dict`): the returned output, including action, group, initiator_prob, is_initiator, \
+ new_thoughts and old_thoughts
+ ReturnsKeys:
+ - necessary: ``action``
+ - optional: ``group``, ``initiator_prob``, ``is_initiator``, ``new_thoughts``, ``old_thoughts``
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, A, N)`, where B is batch size, A is agent num, N is obs size
+ - action (:obj:`torch.Tensor`): :math:`(B, A, M)`, where M is action size
+ - group (:obj:`torch.Tensor`): :math:`(B, A, A)`
+ - initiator_prob (:obj:`torch.Tensor`): :math:`(B, A)`
+ - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)`
+ - new_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)`
+ - old_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)`
+ Examples:
+ >>> actor_net = ATOCActorNet(64, 64, 64, 3)
+ >>> obs = torch.randn(2, 3, 64)
+ >>> actor_net(obs)
+ """
+ assert len(obs.shape) == 3
+ self._cur_batch_size = obs.shape[0]
+ B, A, N = obs.shape
+ assert A == self._n_agent
+ assert N == self._obs_shape
+
+ current_thoughts = self.actor_1(obs) # B, A, thought size
+
+ if self._communication:
+ old_thoughts = current_thoughts.clone().detach()
+ init_prob, is_initiator, group = self._get_initiate_group(old_thoughts)
+
+ new_thoughts = self._get_new_thoughts(current_thoughts, group, is_initiator)
+ else:
+ new_thoughts = current_thoughts
+ action = self.actor_2(torch.cat([current_thoughts, new_thoughts], dim=-1))['pred']
+
+ if self._communication:
+ return {
+ 'action': action,
+ 'group': group,
+ 'initiator_prob': init_prob,
+ 'is_initiator': is_initiator,
+ 'new_thoughts': new_thoughts,
+ 'old_thoughts': old_thoughts,
+ }
+ else:
+ return {'action': action}
+
+ def _get_initiate_group(self, current_thoughts):
+ """
+ Overview:
+ Calculate the initiator probability, group and is_initiator
+ Arguments:
+ - current_thoughts (:obj:`torch.Tensor`): tensor of current thoughts
+ Returns:
+ - init_prob (:obj:`torch.Tensor`): tensor of initiator probability
+ - is_initiator (:obj:`torch.Tensor`): boolean tensor indicating whether each agent is an initiator
+ - group (:obj:`torch.Tensor`): tensor of group
+ Shapes:
+ - current_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)`, where M is thought size
+ - init_prob (:obj:`torch.Tensor`): :math:`(B, A)`
+ - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)`
+ - group (:obj:`torch.Tensor`): :math:`(B, A, A)`
+ Examples:
+ >>> actor_net = ATOCActorNet(64, 64, 64, 3)
+ >>> current_thoughts = torch.randn(2, 3, 64)
+ >>> actor_net._get_initiate_group(current_thoughts)
+ """
+ if not self._communication:
+ raise NotImplementedError
+ init_prob = self.attention(current_thoughts) # B, A
+ is_initiator = (init_prob > self._initiator_threshold)
+ B, A = init_prob.shape[:2]
+
+ thoughts_pair_dot = current_thoughts.bmm(current_thoughts.transpose(1, 2))
+ thoughts_square = thoughts_pair_dot.diagonal(0, 1, 2)
+ curr_thought_dists = thoughts_square.unsqueeze(1) - 2 * thoughts_pair_dot + thoughts_square.unsqueeze(2)
+
+ group = torch.zeros(B, A, A).to(init_prob.device)
+
+ # "considers the agents in its observable field"
+ # "initiator first chooses collaborators from agents who have not been selected,
+ # then from agents selected by other initiators,
+ # finally from other initiators"
+ # "all based on proximity"
+
+ # roughly choose m closest as group
+ for b in range(B):
+ for i in range(A):
+ if is_initiator[b][i]:
+ index_seq = curr_thought_dists[b][i].argsort()
+ index_seq = index_seq[:self._agent_per_group]
+ group[b][i][index_seq] = 1
+ return init_prob, is_initiator, group
+
+ def _get_new_thoughts(self, current_thoughts, group, is_initiator):
+ """
+ Overview:
+ Calculate the new thoughts according to current thoughts, group and is_initiator
+ Arguments:
+ - current_thoughts (:obj:`torch.Tensor`): tensor of current thoughts
+ - group (:obj:`torch.Tensor`): tensor of group
+ - is_initiator (:obj:`torch.Tensor`): boolean tensor indicating whether each agent is an initiator
+ Returns:
+ - new_thoughts (:obj:`torch.Tensor`): tensor of new thoughts
+ Shapes:
+ - current_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)`, where M is thought size
+ - group: (:obj:`torch.Tensor`): :math:`(B, A, A)`
+ - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)`
+ - new_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)`
+ Examples:
+ >>> actor_net = ATOCActorNet(64, 64, 64, 3)
+ >>> current_thoughts = torch.randn(2, 3, 64)
+ >>> group = torch.randn(2, 3, 3)
+ >>> is_initiator = (torch.randn(2, 3) > 0)
+ >>> actor_net._get_new_thoughts(current_thoughts, group, is_initiator)
+ """
+ if not self._communication:
+ raise NotImplementedError
+ B, A = current_thoughts.shape[:2]
+ new_thoughts = current_thoughts.detach().clone()
+ if len(torch.nonzero(is_initiator)) == 0:
+ return new_thoughts
+
+ # TODO(nyz) execute communication serially for shared agent in different group
+ thoughts_to_commute = []
+ for b in range(B):
+ for i in range(A):
+ if is_initiator[b][i]:
+ tmp = []
+ for j in range(A):
+ if group[b][i][j]:
+ tmp.append(new_thoughts[b][j])
+ thoughts_to_commute.append(torch.stack(tmp, dim=0))
+ thoughts_to_commute = torch.stack(thoughts_to_commute, dim=1) # agent_per_group, B_, N
+ integrated_thoughts = self.comm_net(thoughts_to_commute)
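+ # Write the integrated thoughts back to the corresponding members of each initiator's group.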
+ b_count = 0
+ for b in range(B):
+ for i in range(A):
+ if is_initiator[b][i]:
+ j_count = 0
+ for j in range(A):
+ if group[b][i][j]:
+ new_thoughts[b][j] = integrated_thoughts[j_count][b_count]
+ j_count += 1
+ b_count += 1
+ return new_thoughts
+
+
+@MODEL_REGISTRY.register('atoc')
+class ATOC(nn.Module):
+ """
+ Overview:
+ The QAC network of ATOC, an extension of DDPG for multi-agent RL (MARL).
+ Learning Attentional Communication for Multi-Agent Cooperation
+ https://arxiv.org/abs/1805.07733
+ Interface:
+ ``__init__``, ``forward``, ``compute_critic``, ``compute_actor``, ``optimize_actor_attention``
+ """
+ mode = ['compute_actor', 'compute_critic', 'optimize_actor_attention']
+
+ def __init__(
+ self,
+ obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType],
+ thought_size: int,
+ n_agent: int,
+ communication: bool = True,
+ agent_per_group: int = 2,
+ actor_1_embedding_size: Union[int, None] = None,
+ actor_2_embedding_size: Union[int, None] = None,
+ critic_head_hidden_size: int = 64,
+ critic_head_layer_num: int = 2,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ ) -> None:
+ """
+ Overview:
+ Initialize the ATOC QAC network
+ Arguments:
+ - obs_shape (:obj:`Union[int, SequenceType]`): the observation space shape
+ - action_shape (:obj:`Union[int, SequenceType]`): the action space shape
+ - thought_size (:obj:`int`): the size of thoughts
+ - n_agent (:obj:`int`): the num of agents
+ - communication (:obj:`bool`): whether to use the communication module, default set to True
+ - agent_per_group (:obj:`int`): the num of agents in each group
+ - actor_1_embedding_size (:obj:`Union[int, None]`): the embedding size of actor network part1
+ - actor_2_embedding_size (:obj:`Union[int, None]`): the embedding size of actor network part2
+ - critic_head_hidden_size (:obj:`int`): the hidden size of the critic head, default set to 64
+ - critic_head_layer_num (:obj:`int`): the num of layers in the critic head, default set to 2
+ - activation (:obj:`Optional[nn.Module]`): the activation function in networks, default set to nn.ReLU()
+ - norm_type (:obj:`Optional[str]`): the type of normalization in networks, default set to None
+ """
+ super(ATOC, self).__init__()
+ self._communication = communication
+
+ self.actor = ATOCActorNet(
+ obs_shape,
+ thought_size,
+ action_shape,
+ n_agent,
+ communication,
+ agent_per_group,
+ actor_1_embedding_size=actor_1_embedding_size,
+ actor_2_embedding_size=actor_2_embedding_size
+ )
+ self.critic = nn.Sequential(
+ nn.Linear(obs_shape + action_shape, critic_head_hidden_size), activation,
+ RegressionHead(
+ critic_head_hidden_size,
+ 1,
+ critic_head_layer_num,
+ final_tanh=False,
+ activation=activation,
+ norm_type=norm_type,
+ )
+ )
+
+ def _compute_delta_q(self, obs: torch.Tensor, actor_outputs: Dict) -> torch.Tensor:
+ """
+ Overview:
+ Calculate the delta_q according to obs and actor_outputs.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): the observations
+ - actor_outputs (:obj:`dict`): the output of the actor network
+ Returns:
+ - delta_q (:obj:`torch.Tensor`): the calculated delta_q
+ ArgumentsKeys:
+ - necessary: ``new_thoughts``, ``old_thoughts``, ``group``, ``is_initiator``
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, A, N)`, where B is batch size, A is agent num, N is obs size
+ - actor_outputs (:obj:`Dict`): the output of actor network, including ``action``, ``new_thoughts``, \
+ ``old_thoughts``, ``group``, ``initiator_prob``, ``is_initiator``
+ - action (:obj:`torch.Tensor`): :math:`(B, A, M)` where M is action size
+ - new_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)` where M is thought size
+ - old_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)` where M is thought size
+ - group (:obj:`torch.Tensor`): :math:`(B, A, A)`
+ - initiator_prob (:obj:`torch.Tensor`): :math:`(B, A)`
+ - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)`
+ - delta_q (:obj:`torch.Tensor`): :math:`(B, A)`
+ Examples:
+ >>> net = ATOC(64, 64, 64, 3)
+ >>> obs = torch.randn(2, 3, 64)
+ >>> actor_outputs = net.compute_actor(obs)
+ >>> net._compute_delta_q(obs, actor_outputs)
+ """
+ if not self._communication:
+ raise NotImplementedError
+ assert len(obs.shape) == 3
+ new_thoughts, old_thoughts, group, is_initiator = actor_outputs['new_thoughts'], actor_outputs[
+ 'old_thoughts'], actor_outputs['group'], actor_outputs['is_initiator']
+ B, A = new_thoughts.shape[:2]
+ curr_delta_q = torch.zeros(B, A).to(new_thoughts.device)
+ with torch.no_grad():
+ for b in range(B):
+ for i in range(A):
+ if not is_initiator[b][i]:
+ continue
+ q_group = []
+ actual_q_group = []
+ for j in range(A):
+ if not group[b][i][j]:
+ continue
+ before_update_action_j = self.actor.actor_2(
+ torch.cat([old_thoughts[b][j], old_thoughts[b][j]], dim=-1)
+ )
+ after_update_action_j = self.actor.actor_2(
+ torch.cat([old_thoughts[b][j], new_thoughts[b][j]], dim=-1)
+ )
+ before_update_input = torch.cat([obs[b][j], before_update_action_j['pred']], dim=-1)
+ before_update_Q_j = self.critic(before_update_input)['pred']
+ after_update_input = torch.cat([obs[b][j], after_update_action_j['pred']], dim=-1)
+ after_update_Q_j = self.critic(after_update_input)['pred']
+ q_group.append(before_update_Q_j)
+ actual_q_group.append(after_update_Q_j)
+ q_group = torch.stack(q_group)
+ actual_q_group = torch.stack(actual_q_group)
+ curr_delta_q[b][i] = actual_q_group.mean() - q_group.mean()
+ return curr_delta_q
+
+ def compute_actor(self, obs: torch.Tensor, get_delta_q: bool = False) -> Dict[str, torch.Tensor]:
+ '''
+ Overview:
+ Compute the action according to the observation and, if required, call ``_compute_delta_q`` \
+ to compute delta_q.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): observation
+ - get_delta_q (:obj:`bool`): whether to compute delta_q
+ Returns:
+ - outputs (:obj:`Dict`): the output of actor network and delta_q
+ ReturnsKeys:
+ - necessary: ``action``
+ - optional: ``group``, ``initiator_prob``, ``is_initiator``, ``new_thoughts``, ``old_thoughts``, ``delta_q``
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, A, N)`, where B is batch size, A is agent num, N is obs size
+ - action (:obj:`torch.Tensor`): :math:`(B, A, M)`, where M is action size
+ - group (:obj:`torch.Tensor`): :math:`(B, A, A)`
+ - initiator_prob (:obj:`torch.Tensor`): :math:`(B, A)`
+ - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)`
+ - new_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)`
+ - old_thoughts (:obj:`torch.Tensor`): :math:`(B, A, M)`
+ - delta_q (:obj:`torch.Tensor`): :math:`(B, A)`
+ Examples:
+ >>> net = ATOC(64, 64, 64, 3)
+ >>> obs = torch.randn(2, 3, 64)
+ >>> net.compute_actor(obs)
+ '''
+ outputs = self.actor(obs)
+ if get_delta_q and self._communication:
+ delta_q = self._compute_delta_q(obs, outputs)
+ outputs['delta_q'] = delta_q
+ return outputs
+
+ def compute_critic(self, inputs: Dict) -> Dict:
+ """
+ Overview:
+ Compute the q_value according to the input obs and action.
+ Arguments:
+ - inputs (:obj:`Dict`): the inputs contain the obs and action
+ Returns:
+ - outputs (:obj:`Dict`): the output of critic network
+ ArgumentsKeys:
+ - necessary: ``obs``, ``action``
+ ReturnsKeys:
+ - necessary: ``q_value``
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, A, N)`, where B is batch size, A is agent num, N is obs size
+ - action (:obj:`torch.Tensor`): :math:`(B, A, M)`, where M is action size
+ - q_value (:obj:`torch.Tensor`): :math:`(B, A)`
+ Examples:
+ >>> net = ATOC(64, 64, 64, 3)
+ >>> obs = torch.randn(2, 3, 64)
+ >>> action = torch.randn(2, 3, 64)
+ >>> net.compute_critic({'obs': obs, 'action': action})
+ """
+ obs, action = inputs['obs'], inputs['action']
+ if len(action.shape) == 2: # (B, A) -> (B, A, 1)
+ action = action.unsqueeze(2)
+ x = torch.cat([obs, action], dim=-1)
+ x = self.critic(x)['pred']
+ return {'q_value': x}
+
+ def optimize_actor_attention(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ Return the actor attention loss.
+ Arguments:
+ - inputs (:obj:`Dict`): the inputs contain the delta_q, initiator_prob, and is_initiator
+ Returns:
+ - loss (:obj:`Dict`): the loss of actor attention unit
+ ArgumentsKeys:
+ - necessary: ``delta_q``, ``initiator_prob``, ``is_initiator``
+ ReturnsKeys:
+ - necessary: ``loss``
+ Shapes:
+ - delta_q (:obj:`torch.Tensor`): :math:`(B, A)`
+ - initiator_prob (:obj:`torch.Tensor`): :math:`(B, A)`
+ - is_initiator (:obj:`torch.Tensor`): :math:`(B, A)`
+ - loss (:obj:`torch.Tensor`): :math:`(1)`
+ Examples:
+ >>> net = ATOC(64, 64, 64, 3)
+ >>> delta_q = torch.randn(2, 3)
+ >>> initiator_prob = torch.randn(2, 3)
+ >>> is_initiator = torch.randn(2, 3)
+ >>> net.optimize_actor_attention(
+ >>> {'delta_q': delta_q,
+ >>> 'initiator_prob': initiator_prob,
+ >>> 'is_initiator': is_initiator})
+ """
+ if not self._communication:
+ raise NotImplementedError
+ delta_q = inputs['delta_q'].reshape(-1)
+ init_prob = inputs['initiator_prob'].reshape(-1)
+ is_init = inputs['is_initiator'].reshape(-1)
+ delta_q = delta_q[is_init.nonzero()]
+ init_prob = init_prob[is_init.nonzero()]
+ init_prob = 0.9 * init_prob + 0.05
+
+ # judge to avoid nan
+ if init_prob.shape == (0, 1):
+ actor_attention_loss = torch.FloatTensor([-0.0]).to(delta_q.device)
+ actor_attention_loss.requires_grad = True
+ else:
+ actor_attention_loss = -delta_q * \
+ torch.log(init_prob) - (1 - delta_q) * torch.log(1 - init_prob)
+ return {'loss': actor_attention_loss.mean()}
+
+ def forward(self, inputs: Union[torch.Tensor, Dict], mode: str, **kwargs) -> Dict:
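+ """
+ Overview:
+ The unified forward entry of the ATOC network: dispatch ``inputs`` to the method indicated by ``mode``.
+ Arguments:
+ - inputs (:obj:`Union[torch.Tensor, Dict]`): the input data of the indicated mode
+ - mode (:obj:`str`): the forward mode, which should be one of ``self.mode``
+ - kwargs: extra keyword arguments passed through to the indicated mode, \
+ e.g. ``get_delta_q`` for ``compute_actor``
+ Returns:
+ - outputs (:obj:`Dict`): the output of the indicated mode
+ Examples:
+ >>> # A minimal dispatch sketch that reuses the arguments of the mode-specific examples above.
+ >>> net = ATOC(64, 64, 64, 3)
+ >>> obs = torch.randn(2, 3, 64)
+ >>> outputs = net(obs, mode='compute_actor')
+ >>> assert outputs['action'].shape == torch.Size([2, 3, 64])
+ """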
+ assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
+ return getattr(self, mode)(inputs, **kwargs)
diff --git a/DI-engine/ding/model/template/bc.py b/DI-engine/ding/model/template/bc.py
new file mode 100644
index 0000000000000000000000000000000000000000..5348c750a6fc82fd0a6531e88d054cbf6d24940e
--- /dev/null
+++ b/DI-engine/ding/model/template/bc.py
@@ -0,0 +1,217 @@
+from typing import Union, Optional, Dict
+import torch
+import torch.nn as nn
+from easydict import EasyDict
+
+from ding.utils import MODEL_REGISTRY, SequenceType, squeeze
+from ..common import FCEncoder, ConvEncoder, DiscreteHead, DuelingHead, \
+ MultiHead, RegressionHead, ReparameterizationHead
+
+
+@MODEL_REGISTRY.register('discrete_bc')
+class DiscreteBC(nn.Module):
+ """
+ Overview:
+ The DiscreteBC network.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(
+ self,
+ obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType],
+ encoder_hidden_size_list: SequenceType = [128, 128, 64],
+ dueling: bool = True,
+ head_hidden_size: Optional[int] = None,
+ head_layer_num: int = 1,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ strides: Optional[list] = None,
+ ) -> None:
+ """
+ Overview:
+ Init the DiscreteBC (encoder + head) Model according to input arguments.
+ Arguments:
+ - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84].
+ - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3].
+ - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
+ the last element must match ``head_hidden_size``.
+ - dueling (:obj:`bool`): Whether to use ``DuelingHead`` (default) or ``DiscreteHead``.
+ - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network.
+ - head_layer_num (:obj:`int`): The number of layers used in the head network to compute Q value output
+ - activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \
+ if ``None`` then default set it to ``nn.ReLU()``.
+ - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
+ ``ding.torch_utils.fc_block`` for more details.
+ - strides (:obj:`Optional[list]`): The strides for each convolution layer, such as [2, 2, 2]. The length \
+ of this argument should be the same as ``encoder_hidden_size_list``.
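+ Examples:
+ >>> # A hypothetical image-observation sketch (not part of the original docstring); it assumes the
+ >>> # ConvEncoder defaults used elsewhere in ding.model for 84x84 inputs.
+ >>> model = DiscreteBC([4, 84, 84], 6, encoder_hidden_size_list=[32, 64, 64])
+ >>> inputs = torch.randn(2, 4, 84, 84)
+ >>> outputs = model(inputs)
+ >>> assert outputs['logit'].shape == torch.Size([2, 6])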
+ """
+ super(DiscreteBC, self).__init__()
+ # For compatibility: 1, (1, ), [4, 32, 32]
+ obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape)
+ if head_hidden_size is None:
+ head_hidden_size = encoder_hidden_size_list[-1]
+ # FC Encoder
+ if isinstance(obs_shape, int) or len(obs_shape) == 1:
+ self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
+ # Conv Encoder
+ elif len(obs_shape) == 3:
+ if not strides:
+ self.encoder = ConvEncoder(
+ obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
+ )
+ else:
+ self.encoder = ConvEncoder(
+ obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type, stride=strides
+ )
+ else:
+ raise RuntimeError(
+ "not support obs_shape for pre-defined encoder: {}, please customize your own BC".format(obs_shape)
+ )
+ # Head Type
+ if dueling:
+ head_cls = DuelingHead
+ else:
+ head_cls = DiscreteHead
+ multi_head = not isinstance(action_shape, int)
+ if multi_head:
+ self.head = MultiHead(
+ head_cls,
+ head_hidden_size,
+ action_shape,
+ layer_num=head_layer_num,
+ activation=activation,
+ norm_type=norm_type
+ )
+ else:
+ self.head = head_cls(
+ head_hidden_size, action_shape, head_layer_num, activation=activation, norm_type=norm_type
+ )
+
+ def forward(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+ DiscreteBC forward computation graph, input observation tensor to predict q_value.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Observation inputs
+ Returns:
+ - outputs (:obj:`Dict`): DiscreteBC forward outputs, such as q_value.
+ ReturnsKeys:
+ - logit (:obj:`torch.Tensor`): Discrete Q-value output of each action dimension.
+ Shapes:
+ - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``
+ - logit (:obj:`torch.FloatTensor`): :math:`(B, M)`, where B is batch size and M is ``action_shape``
+ Examples:
+ >>> model = DiscreteBC(32, 6) # arguments: 'obs_shape' and 'action_shape'
+ >>> inputs = torch.randn(4, 32)
+ >>> outputs = model(inputs)
+ >>> assert isinstance(outputs, dict) and outputs['logit'].shape == torch.Size([4, 6])
+ """
+ x = self.encoder(x)
+ x = self.head(x)
+ return x
+
+
+@MODEL_REGISTRY.register('continuous_bc')
+class ContinuousBC(nn.Module):
+ """
+ Overview:
+ The ContinuousBC network.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(
+ self,
+ obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType, EasyDict],
+ action_space: str,
+ actor_head_hidden_size: int = 64,
+ actor_head_layer_num: int = 1,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ ) -> None:
+ """
+ Overview:
+ Initialize the ContinuousBC Model according to input arguments.
+ Arguments:
+ - obs_shape (:obj:`Union[int, SequenceType]`): Observation's shape, such as 128, (156, ).
+ - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's shape, such as 4, (3, ), \
+ EasyDict({'action_type_shape': 3, 'action_args_shape': 4}).
+ - action_space (:obj:`str`): The type of action space, \
+ including [``regression``, ``reparameterization``].
+ - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor head.
+ - actor_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output \
+ for actor head.
+ - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` \
+ after each FC layer, if ``None`` then default set to ``nn.ReLU()``.
+ - norm_type (:obj:`Optional[str]`): The type of normalization to after network layer (FC, Conv), \
+ see ``ding.torch_utils.network`` for more details.
+ """
+ super(ContinuousBC, self).__init__()
+ obs_shape: int = squeeze(obs_shape)
+ action_shape = squeeze(action_shape)
+ self.action_shape = action_shape
+ self.action_space = action_space
+ assert self.action_space in ['regression', 'reparameterization']
+ if self.action_space == 'regression':
+ self.actor = nn.Sequential(
+ nn.Linear(obs_shape, actor_head_hidden_size), activation,
+ RegressionHead(
+ actor_head_hidden_size,
+ action_shape,
+ actor_head_layer_num,
+ final_tanh=True,
+ activation=activation,
+ norm_type=norm_type
+ )
+ )
+ elif self.action_space == 'reparameterization':
+ self.actor = nn.Sequential(
+ nn.Linear(obs_shape, actor_head_hidden_size), activation,
+ ReparameterizationHead(
+ actor_head_hidden_size,
+ action_shape,
+ actor_head_layer_num,
+ sigma_type='conditioned',
+ activation=activation,
+ norm_type=norm_type
+ )
+ )
+
+ def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]]) -> Dict:
+ """
+ Overview:
+ The unique execution (forward) method of ContinuousBC.
+ Arguments:
+ - inputs (:obj:`torch.Tensor`): Observation data, i.e. the input observation tensor.
+ Returns:
+ - output (:obj:`Dict`): Output dict data, including different key-values among distinct action_space.
+ ReturnsKeys:
+ - action (:obj:`torch.Tensor`): action output of actor network, \
+ with shape :math:`(B, action_shape)`.
+ - logit (:obj:`List[torch.Tensor]`): reparameterized action output of actor network, \
+ with shape :math:`(B, action_shape)`.
+ Shapes:
+ - inputs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``
+ - action (:obj:`torch.FloatTensor`): :math:`(B, M)`, where B is batch size and M is ``action_shape``
+ - logit (:obj:`List[torch.FloatTensor]`): :math:`(B, M)`, where B is batch size and M is ``action_shape``
+ Examples (Regression):
+ >>> model = ContinuousBC(32, 6, action_space='regression')
+ >>> inputs = torch.randn(4, 32)
+ >>> outputs = model(inputs)
+ >>> assert isinstance(outputs, dict) and outputs['action'].shape == torch.Size([4, 6])
+ Examples (Reparameterization):
+ >>> model = ContinuousBC(32, 6, action_space='reparameterization')
+ >>> inputs = torch.randn(4, 32)
+ >>> outputs = model(inputs)
+ >>> assert isinstance(outputs, dict) and outputs['logit'][0].shape == torch.Size([4, 6])
+ >>> assert outputs['logit'][1].shape == torch.Size([4, 6])
+ """
+ if self.action_space == 'regression':
+ x = self.actor(inputs)
+ return {'action': x['pred']}
+ elif self.action_space == 'reparameterization':
+ x = self.actor(inputs)
+ return {'logit': [x['mu'], x['sigma']]}
diff --git a/DI-engine/ding/model/template/bcq.py b/DI-engine/ding/model/template/bcq.py
new file mode 100755
index 0000000000000000000000000000000000000000..0e72927a765cf2a3709d945fba11cf118db8b04d
--- /dev/null
+++ b/DI-engine/ding/model/template/bcq.py
@@ -0,0 +1,210 @@
+from typing import Union, Dict, Optional, List
+from easydict import EasyDict
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
+from ..common import RegressionHead, ReparameterizationHead
+from .vae import VanillaVAE
+
+
+@MODEL_REGISTRY.register('bcq')
+class BCQ(nn.Module):
+ """
+ Overview:
+ Model of BCQ (Batch-Constrained deep Q-learning).
+ Off-Policy Deep Reinforcement Learning without Exploration.
+ https://arxiv.org/abs/1812.02900
+ Interface:
+ ``forward``, ``compute_actor``, ``compute_critic``, ``compute_vae``, ``compute_eval``
+ Property:
+ ``mode``
+ """
+
+ mode = ['compute_actor', 'compute_critic', 'compute_vae', 'compute_eval']
+
+ def __init__(
+ self,
+ obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType, EasyDict],
+ actor_head_hidden_size: List = [400, 300],
+ critic_head_hidden_size: List = [400, 300],
+ activation: Optional[nn.Module] = nn.ReLU(),
+ vae_hidden_dims: List = [750, 750],
+ phi: float = 0.05
+ ) -> None:
+ """
+ Overview:
+ Initialize the BCQ neural network, i.e. the Q networks, the perturbation actor and the VAE.
+ Arguments:
+ - obs_shape (:obj:`int`): the dimension of observation state
+ - action_shape (:obj:`int`): the dimension of action shape
+ - actor_head_hidden_size (:obj:`list`): the list of hidden sizes of the actor network
+ - critic_head_hidden_size (:obj:`list`): the list of hidden sizes of the critic networks
+ - activation (:obj:`nn.Module`): Activation function in network, defaults to nn.ReLU().
+ - vae_hidden_dims (:obj:`list`): the list of hidden sizes of the VAE
+ - phi (:obj:`float`): the scale of the actor's perturbation, defaults to 0.05
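+ Examples:
+ >>> # A construction sketch using the default hidden sizes documented above.
+ >>> model = BCQ(32, 6)
+ >>> # The model holds two Q networks, one perturbation actor and a VAE.
+ >>> assert isinstance(model.critic, torch.nn.ModuleList) and len(model.critic) == 2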
+ """
+ super(BCQ, self).__init__()
+ obs_shape: int = squeeze(obs_shape)
+ action_shape = squeeze(action_shape)
+ self.action_shape = action_shape
+ self.input_size = obs_shape
+ self.phi = phi
+
+ critic_input_size = self.input_size + action_shape
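+ # BCQ keeps a pair of Q networks; each one scores the concatenated (obs, action) input.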
+ self.critic = nn.ModuleList()
+ for _ in range(2):
+ net = []
+ d = critic_input_size
+ for dim in critic_head_hidden_size:
+ net.append(nn.Linear(d, dim))
+ net.append(activation)
+ d = dim
+ net.append(nn.Linear(d, 1))
+ self.critic.append(nn.Sequential(*net))
+
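+ # Perturbation actor: maps the concatenated (obs, action) to a bounded correction,
+ # which compute_actor scales by phi and adds to the sampled action.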
+ net = []
+ d = critic_input_size
+ for dim in actor_head_hidden_size:
+ net.append(nn.Linear(d, dim))
+ net.append(activation)
+ d = dim
+ net.append(nn.Linear(d, 1))
+ self.actor = nn.Sequential(*net)
+
+ self.vae = VanillaVAE(action_shape, obs_shape, action_shape * 2, vae_hidden_dims)
+
+ def forward(self, inputs: Dict[str, torch.Tensor], mode: str) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ The unique execution (forward) method of BCQ. One can indicate different modes to implement \
+ different computation graphs, including ``compute_actor``, ``compute_critic``, ``compute_vae`` \
+ and ``compute_eval``.
+ Mode compute_actor:
+ Arguments:
+ - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
+ Returns:
+ - output (:obj:`Dict`): Output dict data, including action tensor.
+ Mode compute_critic:
+ Arguments:
+ - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
+ Returns:
+ - output (:obj:`Dict`): Output dict data, including q_value tensor.
+ Mode compute_vae:
+ Arguments:
+ - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
+ Returns:
+ - outputs (:obj:`Dict`): Dict containing keywords ``recons_action`` \
+ (:obj:`torch.Tensor`), ``prediction_residual`` (:obj:`torch.Tensor`), \
+ ``input`` (:obj:`torch.Tensor`), ``mu`` (:obj:`torch.Tensor`), \
+ ``log_var`` (:obj:`torch.Tensor`) and ``z`` (:obj:`torch.Tensor`).
+ Mode compute_eval:
+ Arguments:
+ - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
+ Returns:
+ - output (:obj:`Dict`): Output dict data, including action tensor.
+ Examples:
+ >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)}
+ >>> model = BCQ(32, 6)
+ >>> outputs = model(inputs, mode='compute_actor')
+ >>> outputs = model(inputs, mode='compute_critic')
+ >>> outputs = model(inputs, mode='compute_vae')
+ >>> outputs = model(inputs, mode='compute_eval')
+
+ .. note::
+ For specific examples, one can refer to API doc of ``compute_actor`` and ``compute_critic`` respectively.
+ """
+ assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
+ return getattr(self, mode)(inputs)
+
+ def compute_critic(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ Use critic network to compute q value.
+ Arguments:
+ - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
+ Returns:
+ - outputs (:obj:`Dict`): Dict containing keywords ``q_value`` (:obj:`torch.Tensor`).
+ Shapes:
+ - inputs (:obj:`Dict`): :math:`(B, N, D)`, where B is batch size, N is sample number, D is input dimension.
+ - outputs (:obj:`Dict`): :math:`(B, N)`.
+ Examples:
+ >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)}
+ >>> model = BCQ(32, 6)
+ >>> outputs = model.compute_critic(inputs)
+ """
+ obs, action = inputs['obs'], inputs['action']
+ if len(action.shape) == 1: # (B, ) -> (B, 1)
+ action = action.unsqueeze(1)
+ x = torch.cat([obs, action], dim=-1)
+ x = [m(x).squeeze() for m in self.critic]
+ return {'q_value': x}
+
+ def compute_actor(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
+ """
+ Overview:
+ Use actor network to compute action.
+ Arguments:
+ - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
+ Returns:
+ - outputs (:obj:`Dict`): Dict containing keywords ``action`` (:obj:`torch.Tensor`).
+ Shapes:
+ - inputs (:obj:`Dict`): :math:`(B, N, D)`, where B is batch size, N is sample number, D is input dimension.
+ - outputs (:obj:`Dict`): :math:`(B, N)`.
+ Examples:
+ >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)}
+ >>> model = BCQ(32, 6)
+ >>> outputs = model.compute_actor(inputs)
+ """
+ input = torch.cat([inputs['obs'], inputs['action']], -1)
+ x = self.actor(input)
+ action = self.phi * 1 * torch.tanh(x)
+ action = (action + inputs['action']).clamp(-1, 1)
+ return {'action': action}
+
+ def compute_vae(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ Use the VAE network to reconstruct the action conditioned on the observation.
+ Arguments:
+ - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
+ Returns:
+ - outputs (:obj:`Dict`): Dict containing keywords ``recons_action`` (:obj:`torch.Tensor`), \
+ ``prediction_residual`` (:obj:`torch.Tensor`), ``input`` (:obj:`torch.Tensor`), \
+ ``mu`` (:obj:`torch.Tensor`), ``log_var`` (:obj:`torch.Tensor`) and ``z`` (:obj:`torch.Tensor`).
+ Shapes:
+ - inputs (:obj:`Dict`): :math:`(B, N, D)`, where B is batch size, N is sample number, D is input dimension.
+ - outputs (:obj:`Dict`): :math:`(B, N)`.
+ Examples:
+ >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)}
+ >>> model = BCQ(32, 6)
+ >>> outputs = model.compute_vae(inputs)
+ """
+ return self.vae.forward(inputs)
+
+ def compute_eval(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ Sample candidate actions with the VAE, perturb them with the actor, and select the best by Q value.
+ Arguments:
+ - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
+ Returns:
+ - outputs (:obj:`Dict`): Dict containing keywords ``action`` (:obj:`torch.Tensor`).
+ Shapes:
+ - inputs (:obj:`Dict`): :math:`(B, N, D)`, where B is batch size, N is sample number, D is input dimension.
+ - outputs (:obj:`Dict`): :math:`(B, N)`.
+ Examples:
+ >>> inputs = {'obs': torch.randn(4, 32), 'action': torch.randn(4, 6)}
+ >>> model = BCQ(32, 6)
+ >>> outputs = model.compute_eval(inputs)
+ """
+ obs = inputs['obs']
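+ # Evaluation: tile each observation 100 times, decode candidate actions from clamped latent
+ # samples with the VAE, perturb them with the actor, and keep the candidate with the highest Q.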
+ obs_rep = obs.clone().unsqueeze(0).repeat_interleave(100, dim=0)
+ z = torch.randn((obs_rep.shape[0], obs_rep.shape[1], self.action_shape * 2)).to(obs.device).clamp(-0.5, 0.5)
+ sample_action = self.vae.decode_with_obs(z, obs_rep)['reconstruction_action']
+ action = self.compute_actor({'obs': obs_rep, 'action': sample_action})['action']
+ q = self.compute_critic({'obs': obs_rep, 'action': action})['q_value'][0]
+ idx = q.argmax(dim=0).unsqueeze(0).unsqueeze(-1)
+ idx = idx.repeat_interleave(action.shape[-1], dim=-1)
+ action = action.gather(0, idx).squeeze()
+ return {'action': action}
diff --git a/DI-engine/ding/model/template/collaq.py b/DI-engine/ding/model/template/collaq.py
new file mode 100644
index 0000000000000000000000000000000000000000..9872d0684a9e3856fa93a8b5d914e1ebb0710c58
--- /dev/null
+++ b/DI-engine/ding/model/template/collaq.py
@@ -0,0 +1,494 @@
+from typing import Union, List
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import reduce
+from ding.utils import list_split, MODEL_REGISTRY
+from ding.torch_utils import fc_block, MLP, ScaledDotProductAttention
+from .q_learning import DRQN
+from .qmix import Mixer
+
+
+class CollaQMultiHeadAttention(nn.Module):
+ """
+ Overview:
+ The multi-head attention module of the CollaQ network.
+ Interface:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(
+ self,
+ n_head: int,
+ d_model_q: int,
+ d_model_v: int,
+ d_k: int,
+ d_v: int,
+ d_out: int,
+ dropout: float = 0.,
+ activation: nn.Module = nn.ReLU()
+ ):
+ """
+ Overview:
+ Initialize the multi-head attention module of CollaQ.
+ Arguments:
+ - n_head (:obj:`int`): the num of head
+ - d_model_q (:obj:`int`): the size of input q
+ - d_model_v (:obj:`int`): the size of input v
+ - d_k (:obj:`int`): the size of k, used by Scaled Dot Product Attention
+ - d_v (:obj:`int`): the size of v, used by Scaled Dot Product Attention
+ - d_out (:obj:`int`): the size of output q
+ - dropout (:obj:`float`): Dropout ratio, defaults to 0.
+ - activation (:obj:`nn.Module`): Activation in FFN after attention.
+ """
+ super(CollaQMultiHeadAttention, self).__init__()
+
+ self.act = activation
+
+ self.n_head = n_head
+ self.d_k = d_k
+ self.d_v = d_v
+
+ self.w_qs = nn.Linear(d_model_q, n_head * d_k)
+ self.w_ks = nn.Linear(d_model_v, n_head * d_k)
+ self.w_vs = nn.Linear(d_model_v, n_head * d_v)
+
+ self.fc1 = fc_block(n_head * d_v, n_head * d_v, activation=self.act)
+ self.fc2 = fc_block(n_head * d_v, d_out)
+
+ self.attention = ScaledDotProductAttention(d_k=d_k)
+ self.layer_norm_q = nn.LayerNorm(n_head * d_k, eps=1e-6)
+ self.layer_norm_k = nn.LayerNorm(n_head * d_k, eps=1e-6)
+ self.layer_norm_v = nn.LayerNorm(n_head * d_v, eps=1e-6)
+
+ def forward(self, q, k, v, mask=None):
+ """
+ Overview:
+ forward computation graph of collaQ multi head attention net.
+ Arguments:
+ - q (:obj:`torch.Tensor`): the query tensor
+ - k (:obj:`torch.Tensor`): the key tensor
+ - v (:obj:`torch.Tensor`): the value tensor
+ Returns:
+ - q (:obj:`torch.Tensor`): the attention output q
+ - residual (:obj:`torch.Tensor`): the residual, i.e. the pre-attention projection of q
+ Shapes:
+ - q (:obj:`torch.Tensor`): :math:`(B, L, N)` where B is batch_size, L is sequence length, \
+ N is the size of input q
+ - k (:obj:`torch.Tensor`): :math:`(B, L, N)` where B is batch_size, L is sequence length, \
+ N is the size of input k
+ - v (:obj:`torch.Tensor`): :math:`(B, L, N)` where B is batch_size, L is sequence length, \
+ N is the size of input v
+ - q (output) (:obj:`torch.Tensor`): :math:`(B, L, N)` where B is batch_size, L is sequence length, \
+ N is ``d_out``
+ - residual (:obj:`torch.Tensor`): :math:`(B, L, H, D)` where B is batch_size, L is sequence length, \
+ H is ``n_head`` and D is ``d_k``
+ Examples:
+ >>> net = CollaQMultiHeadAttention(1, 2, 3, 4, 5, 6)
+ >>> q = torch.randn(1, 2, 2)
+ >>> k = torch.randn(1, 3, 3)
+ >>> v = torch.randn(1, 3, 3)
+ >>> q, residual = net(q, k, v)
+ """
+ d_k, d_v, n_head = self.d_k, self.d_v, self.n_head
+ batch_size, len_q, len_k, len_v = q.size(0), q.size(1), k.size(1), v.size(1)
+
+ # Pass through the pre-attention projection: batch_size x len_q x (n_head * d_v)
+ # Separate different heads: batch_size x len_q x n_head x d_v
+ q = self.w_qs(q).view(batch_size, len_q, n_head, d_k)
+ k = self.w_ks(k).view(batch_size, len_k, n_head, d_k)
+ v = self.w_vs(v).view(batch_size, len_v, n_head, d_v)
+ residual = q
+
+ # Transpose for attention dot product: batch_size x n_head x len_q x d_v
+ q, k, v = self.layer_norm_q(q).transpose(1, 2), self.layer_norm_k(k).transpose(
+ 1, 2
+ ), self.layer_norm_v(v).transpose(1, 2)
+ # Unsqueeze the mask tensor for head axis broadcasting
+ if mask is not None:
+ mask = mask.unsqueeze(1)
+ q = self.attention(q, k, v, mask=mask)
+
+ # Transpose to move the head dimension back: batch_size x len_q x n_head x d_v
+ # Combine the last two dimensions to concatenate all the heads together: batch_size x len_q x (n*dv)
+ q = q.transpose(1, 2).contiguous().view(batch_size, len_q, -1)
+ q = self.fc2(self.fc1(q))
+ return q, residual
+
+
+class CollaQSMACAttentionModule(nn.Module):
+ """
+ Overview:
+ The CollaQ attention module, used to produce each agent's attention observation, which combines the \
+ agent's own features with the features of the allies the agent attends to.
+ Interface:
+ ``__init__``, ``_cut_obs``, ``forward``
+ """
+
+ def __init__(
+ self,
+ q_dim: int,
+ v_dim: int,
+ self_feature_range: List[int],
+ ally_feature_range: List[int],
+ attention_size: int,
+ activation: nn.Module = nn.ReLU()
+ ):
+ """
+ Overview:
+ Initialize the CollaQ attention module.
+ Arguments:
+ - q_dim (:obj:`int`): the dimension of the attention query (the agent's own feature width)
+ - v_dim (:obj:`int`): the dimension of the attention key/value (each ally's feature width)
+ - self_feature_range (:obj:`List[int]`): the [start, end) index range of the agent's own features in obs
+ - ally_feature_range (:obj:`List[int]`): the [start, end) index range of the ally features in obs
+ - attention_size (:obj:`int`): the size of attention net layer
+ - activation (:obj:`nn.Module`): Activation in FFN after attention.
+ """
+ super(CollaQSMACAttentionModule, self).__init__()
+ self.self_feature_range = self_feature_range
+ self.ally_feature_range = ally_feature_range
+ self.attention_layer = CollaQMultiHeadAttention(
+ 1, q_dim, v_dim, attention_size, attention_size, attention_size, activation=activation
+ )
+
+ def _cut_obs(self, obs: torch.Tensor):
+ """
+ Overview:
+ Cut the observation into the agent's own features and the ally features.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): input each agent's observation
+ Returns:
+ - self_features (:obj:`torch.Tensor`): output self agent's attention observation
+ - ally_features (:obj:`torch.Tensor`): output ally agent's attention observation
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(T, B, A, N)` where T is timestep, B is batch_size, \
+ A is agent_num, N is obs_shape
+ - self_features (:obj:`torch.Tensor`): :math:`(T, B, A, N)` where T is timestep, B is batch_size, \
+ A is agent_num, N is self_feature_range[1] - self_feature_range[0]
+ - ally_features (:obj:`torch.Tensor`): :math:`(T, B, A, N)` where T is timestep, B is batch_size, \
+ A is agent_num, N is ally_feature_range[1] - ally_feature_range[0]
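+ Examples:
+ >>> # A sketch with a hypothetical feature layout (obs dim 14, self features in [2, 4),
+ >>> # ally features in [4, 10)); the module arguments follow the constructor above.
+ >>> module = CollaQSMACAttentionModule(2, 3, [2, 4], [4, 10], 8)
+ >>> obs = torch.randn(5, 4, 3, 14)
+ >>> self_features, ally_features = module._cut_obs(obs)
+ >>> assert self_features.shape == (5, 4, 3, 2) and ally_features.shape == (5, 4, 3, 6)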
+ """
+ # obs shape = (T, B, A, obs_shape)
+ self_features = obs[:, :, :, self.self_feature_range[0]:self.self_feature_range[1]]
+ ally_features = obs[:, :, :, self.ally_feature_range[0]:self.ally_feature_range[1]]
+ return self_features, ally_features
+
+ def forward(self, inputs: torch.Tensor):
+ """
+ Overview:
+ Forward computation to get the agent's attention observation.
+ Arguments:
+ - inputs (:obj:`torch.Tensor`): each agent's observation
+ Returns:
+ - obs (:obj:`torch.Tensor`): output agent's attention observation
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(T, B, A, N)` where T is timestep, B is batch_size, \
+ A is agent_num, N is obs_shape
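+ Examples:
+ >>> # A sketch with the same hypothetical layout as ``_cut_obs``: 3 agents, obs dim 14,
+ >>> # self features in [2, 4), ally features in [4, 10) (3 features per ally), attention_size 8.
+ >>> module = CollaQSMACAttentionModule(2, 3, [2, 4], [4, 10], 8)
+ >>> obs = torch.randn(5, 4, 3, 14)
+ >>> out = module(obs)
+ >>> # The self/ally slices are replaced by their 8-dim attention embeddings.
+ >>> assert out.shape == (5, 4, 3, 2 + 8 + 8 + 4)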
+ """
+ # obs shape = (T, B ,A, obs_shape)
+ obs = inputs
+ self_features, ally_features = self._cut_obs(obs)
+ T, B, A, _ = self_features.shape
+ self_features = self_features.reshape(T * B * A, 1, -1)
+ ally_features = ally_features.reshape(T * B * A, A - 1, -1)
+ self_features, ally_features = self.attention_layer(self_features, ally_features, ally_features)
+ self_features = self_features.reshape(T, B, A, -1)
+ ally_features = ally_features.reshape(T, B, A, -1)
+ # note: we assume self_feature is near the ally_feature here so we can do this concat
+ obs = torch.cat(
+ [
+ obs[:, :, :, :self.self_feature_range[0]], self_features, ally_features,
+ obs[:, :, :, self.ally_feature_range[1]:]
+ ],
+ dim=-1
+ )
+ return obs
+
+
+@MODEL_REGISTRY.register('collaq')
+class CollaQ(nn.Module):
+ """
+ Overview:
+ The network of CollaQ (Collaborative Q-learning) algorithm.
+ It includes two parts: q_network and q_alone_network.
+ The q_network computes the Q value from the agent's own observation together with \
+ the observed features of the allies the agent attends to.
+ The q_alone_network computes the Q value from the agent's observation alone, \
+ i.e. without the features of the attended allies.
+ Multi-Agent Collaboration via Reward Attribution Decomposition
+ https://arxiv.org/abs/2010.08531
+ Interface:
+ ``__init__``, ``forward``, ``_setup_global_encoder``
+ """
+
+ def __init__(
+ self,
+ agent_num: int,
+ obs_shape: int,
+ alone_obs_shape: int,
+ global_obs_shape: int,
+ action_shape: int,
+ hidden_size_list: list,
+ attention: bool = False,
+ self_feature_range: Union[List[int], None] = None,
+ ally_feature_range: Union[List[int], None] = None,
+ attention_size: int = 32,
+ mixer: bool = True,
+ lstm_type: str = 'gru',
+ activation: nn.Module = nn.ReLU(),
+ dueling: bool = False,
+ ) -> None:
+ """
+ Overview:
+ Initialize Collaq network.
+ Arguments:
+ - agent_num (:obj:`int`): the number of agents
+ - obs_shape (:obj:`int`): the dimension of each agent's observation state
+ - alone_obs_shape (:obj:`int`): the dimension of each agent's observation state without\
+ other agents
+ - global_obs_shape (:obj:`int`): the dimension of global observation state
+ - action_shape (:obj:`int`): the dimension of action shape
+ - hidden_size_list (:obj:`list`): the list of hidden size
+ - attention (:obj:`bool`): use attention module or not, default to False
+ - self_feature_range (:obj:`Union[List[int], None]`): the agent's feature range
+ - ally_feature_range (:obj:`Union[List[int], None]`): the agent ally's feature range
+ - attention_size (:obj:`int`): the size of attention net layer
+ - mixer (:obj:`bool`): use mixer net or not, default to True
+ - lstm_type (:obj:`str`): use lstm or gru, default to gru
+ - activation (:obj:`nn.Module`): Activation function in network, defaults to nn.ReLU().
+ - dueling (:obj:`bool`): use dueling head or not, default to False.
+ """
+ super(CollaQ, self).__init__()
+ self.attention = attention
+ self.attention_size = attention_size
+ self._act = activation
+ self.mixer = mixer
+ if not self.attention:
+ self._q_network = DRQN(
+ obs_shape, action_shape, hidden_size_list, lstm_type=lstm_type, dueling=dueling, activation=activation
+ )
+ else:
+ # TODO set the attention layer here beautifully
+ self._self_attention = CollaQSMACAttentionModule(
+ self_feature_range[1] - self_feature_range[0],
+ (ally_feature_range[1] - ally_feature_range[0]) // (agent_num - 1),
+ self_feature_range,
+ ally_feature_range,
+ attention_size,
+ activation=activation
+ )
+ # TODO get the obs_dim_after_attention here beautifully
+ obs_shape_after_attention = self._self_attention(
+ # torch.randn(
+ # 1, 1, (ally_feature_range[1] - ally_feature_range[0]) //
+ # ((self_feature_range[1] - self_feature_range[0])*2) + 1, obs_dim
+ # )
+ torch.randn(1, 1, agent_num, obs_shape)
+ ).shape[-1]
+ self._q_network = DRQN(
+ obs_shape_after_attention,
+ action_shape,
+ hidden_size_list,
+ lstm_type=lstm_type,
+ dueling=dueling,
+ activation=activation
+ )
+ self._q_alone_network = DRQN(
+ alone_obs_shape,
+ action_shape,
+ hidden_size_list,
+ lstm_type=lstm_type,
+ dueling=dueling,
+ activation=activation
+ )
+ embedding_size = hidden_size_list[-1]
+ if self.mixer:
+ self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation)
+ self._global_state_encoder = nn.Identity()
+
+ def forward(self, data: dict, single_step: bool = True) -> dict:
+ """
+ Overview:
+ The forward method calculates the q_value of each agent and the total q_value of all agents.
+ The q_value of each agent is calculated by the q_network, and the total q_value is calculated by the mixer.
+ Arguments:
+ - data (:obj:`dict`): input data dict with keys ['obs', 'prev_state', 'action']
+ - agent_state (:obj:`torch.Tensor`): each agent's local state (obs)
+ - agent_alone_state (:obj:`torch.Tensor`): each agent's local state alone, which in the SMAC \
+ setting is the observation without ally features (obs_alone)
+ - global_state (:obj:`torch.Tensor`): the global state (obs)
+ - prev_state (:obj:`list`): previous rnn state, which should include 3 parts: two hidden states \
+ of q_network (for the obs and obs_alone_padding inputs) and one hidden state of q_alone_network
+ - action (:obj:`torch.Tensor` or None): if action is None, use argmax q_value index as action to \
+ calculate ``agent_q_act``
+ - single_step (:obj:`bool`): whether single_step forward, if so, add timestep dim before forward and \
+ remove it after forward
+ Return:
+ - ret (:obj:`dict`): output data dict with keys ['total_q', 'logit', 'next_state']
+ - total_q (:obj:`torch.Tensor`): total q_value, which is the result of mixer network
+ - agent_q (:obj:`torch.Tensor`): each agent q_value
+ - next_state (:obj:`list`): next rnn state
+ Shapes:
+ - agent_state (:obj:`torch.Tensor`): :math:`(T, B, A, N)`, where T is timestep, B is batch_size\
+ A is agent_num, N is obs_shape
+ - global_state (:obj:`torch.Tensor`): :math:`(T, B, M)`, where M is global_obs_shape
+ - prev_state (:obj:`list`): :math:`(B, A)`, a list of length B, and each element is a list of length A
+ - action (:obj:`torch.Tensor`): :math:`(T, B, A)`
+ - total_q (:obj:`torch.Tensor`): :math:`(T, B)`
+ - agent_q (:obj:`torch.Tensor`): :math:`(T, B, A, P)`, where P is action_shape
+ - next_state (:obj:`list`): :math:`(B, A)`, a list of length B, and each element is a list of length A
+ Examples:
+ >>> collaQ_model = CollaQ(
+ >>> agent_num=4,
+ >>> obs_shape=32,
+ >>> alone_obs_shape=24,
+ >>> global_obs_shape=32 * 4,
+ >>> action_shape=9,
+ >>> hidden_size_list=[128, 64],
+ >>> self_feature_range=[8, 10],
+ >>> ally_feature_range=[10, 16],
+ >>> attention_size=64,
+ >>> mixer=True,
+ >>> activation=torch.nn.Tanh()
+ >>> )
+ >>> data={
+ >>> 'obs': {
+ >>> 'agent_state': torch.randn(8, 4, 4, 32),
+ >>> 'agent_alone_state': torch.randn(8, 4, 4, 24),
+ >>> 'agent_alone_padding_state': torch.randn(8, 4, 4, 32),
+ >>> 'global_state': torch.randn(8, 4, 32 * 4),
+ >>> 'action_mask': torch.randint(0, 2, size=(8, 4, 4, 9))
+ >>> },
+ >>> 'prev_state': [[[None for _ in range(4)] for _ in range(3)] for _ in range(4)],
+ >>> 'action': torch.randint(0, 9, size=(8, 4, 4))
+ >>> }
+ >>> output = collaQ_model(data, single_step=False)
+ """
+ agent_state, agent_alone_state = data['obs']['agent_state'], data['obs']['agent_alone_state']
+ agent_alone_padding_state = data['obs']['agent_alone_padding_state']
+ global_state, prev_state = data['obs']['global_state'], data['prev_state']
+ # TODO find a better way to implement agent_alone_padding_state
+
+ action = data.get('action', None)
+ if single_step:
+ agent_state, agent_alone_state, agent_alone_padding_state, global_state = agent_state.unsqueeze(
+ 0
+ ), agent_alone_state.unsqueeze(0), agent_alone_padding_state.unsqueeze(0), global_state.unsqueeze(0)
+ T, B, A = agent_state.shape[:3]
+
+ if self.attention:
+ agent_state = self._self_attention(agent_state)
+ agent_alone_padding_state = self._self_attention(agent_alone_padding_state)
+
+ # prev_state should be a nested list of shape (B, 3, A), where each leaf is a DRQN hidden state.
+ # Note: to achieve this, the init_fn of the hidden_state plugin in the CollaQ policy must be
+ # changed accordingly.
+ assert len(prev_state) == B and all([len(p) == 3 for p in prev_state]) and all(
+ [len(q) == A] for p in prev_state for q in p
+ ), '{}-{}-{}-{}'.format([type(p) for p in prev_state], B, A, len(prev_state[0]))
+
+ alone_prev_state = [[None for _ in range(A)] for _ in range(B)]
+ colla_prev_state = [[None for _ in range(A)] for _ in range(B)]
+ colla_alone_prev_state = [[None for _ in range(A)] for _ in range(B)]
+
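+ # Unpack the nested (B, 3, A) prev_state into three flat per-agent hidden-state lists,
+ # one for each DRQN forward below (alone, colla and colla_alone).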
+ for i in range(B):
+ for j in range(3):
+ for k in range(A):
+ if j == 0:
+ alone_prev_state[i][k] = prev_state[i][j][k]
+ elif j == 1:
+ colla_prev_state[i][k] = prev_state[i][j][k]
+ elif j == 2:
+ colla_alone_prev_state[i][k] = prev_state[i][j][k]
+
+ alone_prev_state = reduce(lambda x, y: x + y, alone_prev_state)
+ colla_prev_state = reduce(lambda x, y: x + y, colla_prev_state)
+ colla_alone_prev_state = reduce(lambda x, y: x + y, colla_alone_prev_state)
+
+ agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
+ agent_alone_state = agent_alone_state.reshape(T, -1, *agent_alone_state.shape[3:])
+ agent_alone_padding_state = agent_alone_padding_state.reshape(T, -1, *agent_alone_padding_state.shape[3:])
+
+ colla_output = self._q_network(
+ {
+ 'obs': agent_state,
+ 'prev_state': colla_prev_state,
+ 'enable_fast_timestep': True
+ }
+ )
+ colla_alone_output = self._q_network(
+ {
+ 'obs': agent_alone_padding_state,
+ 'prev_state': colla_alone_prev_state,
+ 'enable_fast_timestep': True
+ }
+ )
+ alone_output = self._q_alone_network(
+ {
+ 'obs': agent_alone_state,
+ 'prev_state': alone_prev_state,
+ 'enable_fast_timestep': True
+ }
+ )
+
+ agent_alone_q, alone_next_state = alone_output['logit'], alone_output['next_state']
+ agent_colla_alone_q, colla_alone_next_state = colla_alone_output['logit'], colla_alone_output['next_state']
+ agent_colla_q, colla_next_state = colla_output['logit'], colla_output['next_state']
+
+ colla_next_state, _ = list_split(colla_next_state, step=A)
+ alone_next_state, _ = list_split(alone_next_state, step=A)
+ colla_alone_next_state, _ = list_split(colla_alone_next_state, step=A)
+
+ next_state = list(
+ map(lambda x: [x[0], x[1], x[2]], zip(alone_next_state, colla_next_state, colla_alone_next_state))
+ )
+
+ agent_alone_q = agent_alone_q.reshape(T, B, A, -1)
+ agent_colla_alone_q = agent_colla_alone_q.reshape(T, B, A, -1)
+ agent_colla_q = agent_colla_q.reshape(T, B, A, -1)
+
+ total_q_before_mix = agent_alone_q + agent_colla_q - agent_colla_alone_q
+ # total_q_before_mix = agent_colla_q
+ # total_q_before_mix = agent_alone_q
+ agent_q = total_q_before_mix
+
+ if action is None:
+ # For target forward process
+ if len(data['obs']['action_mask'].shape) == 3:
+ action_mask = data['obs']['action_mask'].unsqueeze(0)
+ else:
+ action_mask = data['obs']['action_mask']
+ agent_q[action_mask == 0.0] = -9999999
+ action = agent_q.argmax(dim=-1)
+ agent_q_act = torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1))
+ agent_q_act = agent_q_act.squeeze(-1) # T, B, A
+ if self.mixer:
+ global_state_embedding = self._global_state_encoder(global_state)
+ total_q = self._mixer(agent_q_act, global_state_embedding)
+ else:
+ total_q = agent_q_act.sum(-1)
+ if single_step:
+ total_q, agent_q, agent_colla_alone_q = total_q.squeeze(0), agent_q.squeeze(0), agent_colla_alone_q.squeeze(
+ 0
+ )
+ return {
+ 'total_q': total_q,
+ 'logit': agent_q,
+ 'agent_colla_alone_q': agent_colla_alone_q * data['obs']['action_mask'],
+ 'next_state': next_state,
+ 'action_mask': data['obs']['action_mask']
+ }
+
+ def _setup_global_encoder(self, global_obs_shape: int, embedding_size: int) -> torch.nn.Module:
+ """
+ Overview:
+ Used to encode the global observation.
+ Arguments:
+ - global_obs_shape (:obj:`int`): the dimension of global observation state
+ - embedding_size (:obj:`int`): the dimension of the state embedding
+ Returns:
+ - outputs (:obj:`torch.nn.Module`): Global observation encoding network
+ """
+ return MLP(global_obs_shape, embedding_size, embedding_size, 2, activation=self._act)
diff --git a/DI-engine/ding/model/template/coma.py b/DI-engine/ding/model/template/coma.py
new file mode 100644
index 0000000000000000000000000000000000000000..02eb286e842cb91f0bd9362cf1336a1c251486df
--- /dev/null
+++ b/DI-engine/ding/model/template/coma.py
@@ -0,0 +1,275 @@
+from typing import Dict, Union
+import torch
+import torch.nn as nn
+
+from functools import reduce
+from ding.torch_utils import one_hot, MLP
+from ding.utils import squeeze, list_split, MODEL_REGISTRY, SequenceType
+from .q_learning import DRQN
+
+
+class COMAActorNetwork(nn.Module):
+ """
+ Overview:
+ Decentralized actor network in COMA algorithm.
+ Interface:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(
+ self,
+ obs_shape: int,
+ action_shape: int,
+ hidden_size_list: SequenceType = [128, 128, 64],
+ ):
+ """
+ Overview:
+ Initialize COMA actor network
+ Arguments:
+ - obs_shape (:obj:`int`): the dimension of each agent's observation state
+ - action_shape (:obj:`int`): the dimension of action shape
+ - hidden_size_list (:obj:`list`): the list of hidden size, default to [128, 128, 64]
+ """
+ super(COMAActorNetwork, self).__init__()
+ self.main = DRQN(obs_shape, action_shape, hidden_size_list)
+
+ def forward(self, inputs: Dict) -> Dict:
+ """
+ Overview:
+ The forward computation graph of COMA actor network
+ Arguments:
+ - inputs (:obj:`dict`): input data dict with keys ['obs', 'prev_state']
+ - agent_state (:obj:`torch.Tensor`): each agent's local state (obs)
+ - action_mask (:obj:`torch.Tensor`): the mask of available actions
+ - prev_state (:obj:`torch.Tensor`): the previous hidden state
+ Returns:
+ - output (:obj:`dict`): output data dict with keys ['logit', 'next_state', 'action_mask']
+ ArgumentsKeys:
+ - necessary: ``obs`` { ``agent_state``, ``action_mask`` }, ``prev_state``
+ ReturnsKeys:
+ - necessary: ``logit``, ``next_state``, ``action_mask``
+ Examples:
+ >>> T, B, A, N = 4, 8, 3, 32
+ >>> embedding_dim = 64
+ >>> action_dim = 6
+ >>> data = torch.randn(T, B, A, N)
+ >>> model = COMAActorNetwork((N, ), action_dim, [128, embedding_dim])
+ >>> prev_state = [[None for _ in range(A)] for _ in range(B)]
+ >>> for t in range(T):
+ >>> inputs = {'obs': {'agent_state': data[t], 'action_mask': None}, 'prev_state': prev_state}
+ >>> outputs = model(inputs)
+ >>> logit, prev_state = outputs['logit'], outputs['next_state']
+ """
+ agent_state = inputs['obs']['agent_state']
+ prev_state = inputs['prev_state']
+ if len(agent_state.shape) == 3: # B, A, N
+ agent_state = agent_state.unsqueeze(0)
+ unsqueeze_flag = True
+ else:
+ unsqueeze_flag = False
+ T, B, A = agent_state.shape[:3]
+ agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
+ prev_state = reduce(lambda x, y: x + y, prev_state)
+ output = self.main({'obs': agent_state, 'prev_state': prev_state, 'enable_fast_timestep': True})
+ logit, next_state = output['logit'], output['next_state']
+ next_state, _ = list_split(next_state, step=A)
+ logit = logit.reshape(T, B, A, -1)
+ if unsqueeze_flag:
+ logit = logit.squeeze(0)
+ return {'logit': logit, 'next_state': next_state, 'action_mask': inputs['obs']['action_mask']}
+
+
+class COMACriticNetwork(nn.Module):
+ """
+ Overview:
+ Centralized critic network in COMA algorithm.
+ Interface:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(
+ self,
+ input_size: int,
+ action_shape: int,
+ hidden_size: int = 128,
+ ):
+ """
+ Overview:
+ Initialize the COMA critic network.
+ Arguments:
+ - input_size (:obj:`int`): the size of the concatenated critic input
+ - action_shape (:obj:`int`): the dimension of action shape
+ - hidden_size (:obj:`int`): the hidden size of the MLP, defaults to 128
+ Returns:
+ - output (:obj:`dict`): output data dict with keys ['q_value'], returned by ``forward``
+ Shapes:
+ - obs (:obj:`dict`): ``agent_state``: :math:`(T, B, A, N)`, ``global_state``: :math:`(T, B, M)`
+ - action (:obj:`torch.Tensor`): :math:`(T, B, A)`
+ - q_value (:obj:`torch.Tensor`): :math:`(T, B, A, P)`, where P is ``action_shape``
+ """
+ super(COMACriticNetwork, self).__init__()
+ self.action_shape = action_shape
+ self.act = nn.ReLU()
+ self.mlp = nn.Sequential(
+ MLP(input_size, hidden_size, hidden_size, 2, activation=self.act), nn.Linear(hidden_size, action_shape)
+ )
+
+ def forward(self, data: Dict) -> Dict:
+ """
+ Overview:
+ The forward computation graph of the COMA critic network.
+ Arguments:
+ - data (:obj:`dict`): input data dict with keys ['obs', 'prev_state', 'action']
+ - agent_state (:obj:`torch.Tensor`): each agent's local state (obs)
+ - global_state (:obj:`torch.Tensor`): the global state (obs)
+ - action (:obj:`torch.Tensor`): the action taken by each agent
+ ArgumentsKeys:
+ - necessary: ``obs`` { ``agent_state``, ``global_state`` }, ``action``, ``prev_state``
+ ReturnsKeys:
+ - necessary: ``q_value``
+ Examples:
+ >>> agent_num, bs, T = 4, 3, 8
+ >>> obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
+ >>> coma_model = COMACriticNetwork(
+ >>> obs_dim - action_dim + global_obs_dim + 2 * action_dim * agent_num, action_dim)
+ >>> data = {
+ >>> 'obs': {
+ >>> 'agent_state': torch.randn(T, bs, agent_num, obs_dim),
+ >>> 'global_state': torch.randn(T, bs, global_obs_dim),
+ >>> },
+ >>> 'action': torch.randint(0, action_dim, size=(T, bs, agent_num)),
+ >>> }
+ >>> output = coma_model(data)
+ """
+ x = self._preprocess_data(data)
+ q = self.mlp(x)
+ return {'q_value': q}
+
+ def _preprocess_data(self, data: Dict) -> torch.Tensor:
+ """
+ Overview:
+ Preprocess the data so that it can be consumed by the MLP network.
+ Arguments:
+ - data (:obj:`dict`): input data dict with keys ['obs', 'prev_state', 'action']
+ - agent_state (:obj:`torch.Tensor`): each agent's local state (obs)
+ - global_state (:obj:`torch.Tensor`): the global state (obs)
+ - action (:obj:`torch.Tensor`): the action taken by each agent
+ ArgumentsKeys:
+ - necessary: ``obs`` { ``agent_state``, ``global_state`` }, ``action``, ``prev_state``
+ Returns:
+ - x (:obj:`torch.Tensor`): the concatenated input tensor for the MLP network, including \
+ ``global_state``, ``agent_state``, ``last_action``, ``action``, ``agent_id``
+ """
+ t_size, batch_size, agent_num = data['obs']['agent_state'].shape[:3]
+ agent_state_ori, global_state = data['obs']['agent_state'], data['obs']['global_state']
+
+ # split obs, last_action and agent_id from the raw agent state
+ agent_state = agent_state_ori[..., :-self.action_shape - agent_num]
+ last_action = agent_state_ori[..., -self.action_shape - agent_num:-agent_num]
+ last_action = last_action.reshape(t_size, batch_size, 1, -1).repeat(1, 1, agent_num, 1)
+ agent_id = agent_state_ori[..., -agent_num:]
+
+ action = one_hot(data['action'], self.action_shape) # T, B, A,N
+ action = action.reshape(t_size, batch_size, -1, agent_num * self.action_shape).repeat(1, 1, agent_num, 1)
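+ # (1 - eye) zeroes out each agent's own one-hot action in the joint action
+ # vector, so the critic is conditioned only on the other agents' actions and
+ # outputs a Q value for every possible action of the current agent.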
+ action_mask = (1 - torch.eye(agent_num).to(action.device))
+ action_mask = action_mask.view(-1, 1).repeat(1, self.action_shape).view(agent_num, -1) # A, A*N
+ action = (action_mask.unsqueeze(0).unsqueeze(0)) * action # T, B, A, A*N
+ global_state = global_state.unsqueeze(2).repeat(1, 1, agent_num, 1)
+
+ x = torch.cat([global_state, agent_state, last_action, action, agent_id], -1)
+ return x
+
+
+@MODEL_REGISTRY.register('coma')
+class COMA(nn.Module):
+ """
+ Overview:
+ The network of COMA algorithm, which is QAC-type actor-critic.
+ Interface:
+ ``__init__``, ``forward``
+ Properties:
+ - mode (:obj:`list`): The list of forward mode, including ``compute_actor`` and ``compute_critic``
+ """
+
+ mode = ['compute_actor', 'compute_critic']
+
+ def __init__(
+ self, agent_num: int, obs_shape: Dict, action_shape: Union[int, SequenceType],
+ actor_hidden_size_list: SequenceType
+ ) -> None:
+ """
+ Overview:
+ Initialize the COMA network.
+ Arguments:
+ - agent_num (:obj:`int`): the number of agents
+ - obs_shape (:obj:`Dict`): the observation information, including agent_state and \
+ global_state
+ - action_shape (:obj:`Union[int, SequenceType]`): the dimension of action shape
+ - actor_hidden_size_list (:obj:`SequenceType`): the list of hidden size
+ """
+ super(COMA, self).__init__()
+ action_shape = squeeze(action_shape)
+ actor_input_size = squeeze(obs_shape['agent_state'])
+ critic_input_size = squeeze(obs_shape['agent_state']) + squeeze(obs_shape['global_state']) + \
+ agent_num * action_shape + (agent_num - 1) * action_shape
+ critic_hidden_size = actor_hidden_size_list[-1]
+ self.actor = COMAActorNetwork(actor_input_size, action_shape, actor_hidden_size_list)
+ self.critic = COMACriticNetwork(critic_input_size, action_shape, critic_hidden_size)
+
+ def forward(self, inputs: Dict, mode: str) -> Dict:
+ """
+ Overview:
+ The forward computation graph of the COMA network.
+ Arguments:
+ - inputs (:obj:`dict`): input data dict with keys ['obs', 'prev_state', 'action']
+ - agent_state (:obj:`torch.Tensor`): each agent's local state (obs)
+ - global_state (:obj:`torch.Tensor`): the global state (obs)
+ - action (:obj:`torch.Tensor`): the action taken by each agent
+ ArgumentsKeys:
+ - necessary: ``obs`` { ``agent_state``, ``global_state``, ``action_mask`` }, ``action``, ``prev_state``
+ ReturnsKeys:
+ - necessary:
+ - compute_critic: ``q_value``
+ - compute_actor: ``logit``, ``next_state``, ``action_mask``
+ Shapes:
+ - obs (:obj:`dict`): ``agent_state``: :math:`(T, B, A, N, D)`, ``action_mask``: :math:`(T, B, A, N, A)`
+ - prev_state (:obj:`list`): :math:`[[[h, c] for _ in range(A)] for _ in range(B)]`
+ - logit (:obj:`torch.Tensor`): :math:`(T, B, A, N, A)`
+ - next_state (:obj:`list`): :math:`[[[h, c] for _ in range(A)] for _ in range(B)]`
+ - action_mask (:obj:`torch.Tensor`): :math:`(T, B, A, N, A)`
+ - q_value (:obj:`torch.Tensor`): :math:`(T, B, A, N, A)`
+ Examples:
+ >>> agent_num, bs, T = 4, 3, 8
+ >>> obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
+ >>> coma_model = COMA(
+ >>> agent_num=agent_num,
+ >>> obs_shape=dict(agent_state=(obs_dim, ), global_state=(global_obs_dim, )),
+ >>> action_shape=action_dim,
+ >>> actor_hidden_size_list=[128, 64],
+ >>> )
+ >>> prev_state = [[None for _ in range(agent_num)] for _ in range(bs)]
+ >>> data = {
+ >>> 'obs': {
+ >>> 'agent_state': torch.randn(T, bs, agent_num, obs_dim),
+ >>> 'action_mask': None,
+ >>> },
+ >>> 'prev_state': prev_state,
+ >>> }
+ >>> output = coma_model(data, mode='compute_actor')
+ >>> data= {
+ >>> 'obs': {
+ >>> 'agent_state': torch.randn(T, bs, agent_num, obs_dim),
+ >>> 'global_state': torch.randn(T, bs, global_obs_dim),
+ >>> },
+ >>> 'action': torch.randint(0, action_dim, size=(T, bs, agent_num)),
+ >>> }
+ >>> output = coma_model(data, mode='compute_critic')
+ """
+ assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
+ if mode == 'compute_actor':
+ return self.actor(inputs)
+ elif mode == 'compute_critic':
+ return self.critic(inputs)
diff --git a/DI-engine/ding/model/template/decision_transformer.py b/DI-engine/ding/model/template/decision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d354973833c99a6746f8e7cb75f3ff4e2774e62
--- /dev/null
+++ b/DI-engine/ding/model/template/decision_transformer.py
@@ -0,0 +1,413 @@
+"""
+this extremely minimal Decision Transformer model is based on
+the following causal transformer (GPT) implementation:
+
+Misha Laskin's tweet:
+https://twitter.com/MishaLaskin/status/1481767788775628801?cxt=HHwWgoCzmYD9pZApAAAA
+
+and its corresponding notebook:
+https://colab.research.google.com/drive/1NUBqyboDcGte5qAJKOl8gaJC28V_73Iv?usp=sharing
+
+** the above colab notebook has a bug while applying masked_fill
+which is fixed in the following code
+"""
+
+import math
+from typing import Union, Optional, Tuple
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ding.utils import SequenceType
+
+
+class MaskedCausalAttention(nn.Module):
+ """
+ Overview:
+ The implementation of masked causal attention in the decision transformer. The input of this module is a \
+ sequence of tokens. The hidden embedding computed for the i-th token only depends on the tokens at positions \
+ no later than i, which is enforced by applying a lower-triangular mask to the attention map. Hence the name \
+ masked causal attention.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(self, h_dim: int, max_T: int, n_heads: int, drop_p: float) -> None:
+ """
+ Overview:
+ Initialize the MaskedCausalAttention Model according to input arguments.
+ Arguments:
+ - h_dim (:obj:`int`): The dimension of the hidden layers, such as 128.
+ - max_T (:obj:`int`): The max context length of the attention, such as 6.
+ - n_heads (:obj:`int`): The number of heads in calculating attention, such as 8.
+ - drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1.
+ """
+ super().__init__()
+
+ self.n_heads = n_heads
+ self.max_T = max_T
+
+ self.q_net = nn.Linear(h_dim, h_dim)
+ self.k_net = nn.Linear(h_dim, h_dim)
+ self.v_net = nn.Linear(h_dim, h_dim)
+
+ self.proj_net = nn.Linear(h_dim, h_dim)
+
+ self.att_drop = nn.Dropout(drop_p)
+ self.proj_drop = nn.Dropout(drop_p)
+
+ ones = torch.ones((max_T, max_T))
+ mask = torch.tril(ones).view(1, 1, max_T, max_T)
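+ # e.g. for max_T = 3 the resulting (1, 1, 3, 3) mask is
+ # [[1, 0, 0],
+ # [1, 1, 0],
+ # [1, 1, 1]]
+ # so position i can only attend to positions <= i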
+
+ # register buffer makes sure mask does not get updated
+ # during backpropagation
+ self.register_buffer('mask', mask)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ MaskedCausalAttention forward computation graph, input a sequence tensor \
+ and return a tensor with the same shape.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor.
+ Returns:
+ - out (:obj:`torch.Tensor`): Output tensor, the shape is the same as the input.
+ Examples:
+ >>> inputs = torch.randn(2, 4, 64)
+ >>> model = MaskedCausalAttention(64, 5, 4, 0.1)
+ >>> outputs = model(inputs)
+ >>> assert outputs.shape == torch.Size([2, 4, 64])
+ """
+ B, T, C = x.shape # batch size, seq length, h_dim * n_heads
+
+ N, D = self.n_heads, C // self.n_heads # N = num heads, D = attention dim
+
+ # rearrange q, k, v as (B, N, T, D)
+ q = self.q_net(x).view(B, T, N, D).transpose(1, 2)
+ k = self.k_net(x).view(B, T, N, D).transpose(1, 2)
+ v = self.v_net(x).view(B, T, N, D).transpose(1, 2)
+
+ # weights (B, N, T, T)
+ weights = q @ k.transpose(2, 3) / math.sqrt(D)
+ # causal mask applied to weights
+ weights = weights.masked_fill(self.mask[..., :T, :T] == 0, float('-inf'))
+ # normalize weights, all -inf -> 0 after softmax
+ normalized_weights = F.softmax(weights, dim=-1)
+
+ # attention (B, N, T, D)
+ attention = self.att_drop(normalized_weights @ v)
+
+ # gather heads and project (B, N, T, D) -> (B, T, N*D)
+ attention = attention.transpose(1, 2).contiguous().view(B, T, N * D)
+
+ out = self.proj_drop(self.proj_net(attention))
+ return out
+
+
+class Block(nn.Module):
+ """
+ Overview:
+ The implementation of a transformer block in decision transformer.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(self, h_dim: int, max_T: int, n_heads: int, drop_p: float) -> None:
+ """
+ Overview:
+ Initialize the Block Model according to input arguments.
+ Arguments:
+ - h_dim (:obj:`int`): The dimension of the hidden layers, such as 128.
+ - max_T (:obj:`int`): The max context length of the attention, such as 6.
+ - n_heads (:obj:`int`): The number of heads in calculating attention, such as 8.
+ - drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1.
+ """
+ super().__init__()
+ self.attention = MaskedCausalAttention(h_dim, max_T, n_heads, drop_p)
+ self.mlp = nn.Sequential(
+ nn.Linear(h_dim, 4 * h_dim),
+ nn.GELU(),
+ nn.Linear(4 * h_dim, h_dim),
+ nn.Dropout(drop_p),
+ )
+ self.ln1 = nn.LayerNorm(h_dim)
+ self.ln2 = nn.LayerNorm(h_dim)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Forward computation graph of the decision transformer block, input a sequence tensor \
+ and return a tensor with the same shape.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor.
+ Returns:
+ - output (:obj:`torch.Tensor`): Output tensor, the shape is the same as the input.
+ Examples:
+ >>> inputs = torch.randn(2, 4, 64)
+ >>> model = Block(64, 5, 4, 0.1)
+ >>> outputs = model(inputs)
+ >>> outputs.shape == torch.Size([2, 4, 64])
+ """
+ # Attention -> LayerNorm -> MLP -> LayerNorm
+ x = x + self.attention(x) # residual
+ x = self.ln1(x)
+ x = x + self.mlp(x) # residual
+ x = self.ln2(x)
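+ # note: LayerNorm is applied after each residual addition (post-LN);
+ # the two commented lines below are the pre-LN variant kept for reference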
+ # x = x + self.attention(self.ln1(x))
+ # x = x + self.mlp(self.ln2(x))
+ return x
+
+
+class DecisionTransformer(nn.Module):
+ """
+ Overview:
+ The implementation of decision transformer.
+ Interfaces:
+ ``__init__``, ``forward``, ``configure_optimizers``
+ """
+
+ def __init__(
+ self,
+ state_dim: Union[int, SequenceType],
+ act_dim: int,
+ n_blocks: int,
+ h_dim: int,
+ context_len: int,
+ n_heads: int,
+ drop_p: float,
+ max_timestep: int = 4096,
+ state_encoder: Optional[nn.Module] = None,
+ continuous: bool = False
+ ):
+ """
+ Overview:
+ Initialize the DecisionTransformer Model according to input arguments.
+ Arguments:
+ - state_dim (:obj:`Union[int, SequenceType]`): Dimension of state, such as 128 or (4, 84, 84).
+ - act_dim (:obj:`int`): The dimension of actions, such as 6.
+ - n_blocks (:obj:`int`): The number of transformer blocks in the decision transformer, such as 3.
+ - h_dim (:obj:`int`): The dimension of the hidden layers, such as 128.
+ - context_len (:obj:`int`): The max context length of the attention, such as 6.
+ - n_heads (:obj:`int`): The number of heads in calculating attention, such as 8.
+ - drop_p (:obj:`float`): The drop rate of the drop-out layer, such as 0.1.
+ - max_timestep (:obj:`int`): The max length of the total sequence, defaults to be 4096.
+ - state_encoder (:obj:`Optional[nn.Module]`): The encoder to pre-process the given input. If it is set to \
+ None, the raw state will be pushed into the transformer.
+ - continuous (:obj:`bool`): Whether the action space is continuous, defaults to be ``False``.
+ """
+ super().__init__()
+
+ self.state_dim = state_dim
+ self.act_dim = act_dim
+ self.h_dim = h_dim
+
+ # transformer blocks
+ input_seq_len = 3 * context_len
+
+ # projection heads (project to embedding)
+ self.embed_ln = nn.LayerNorm(h_dim)
+ self.embed_timestep = nn.Embedding(max_timestep, h_dim)
+ self.drop = nn.Dropout(drop_p)
+
+ self.pos_emb = nn.Parameter(torch.zeros(1, input_seq_len + 1, self.h_dim))
+ self.global_pos_emb = nn.Parameter(torch.zeros(1, max_timestep + 1, self.h_dim))
+
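+ # Two construction variants, chosen by whether a state encoder is given:
+ # 1) no state_encoder: vector states get linear embeddings and separate heads
+ # predict the next return-to-go, state and action;
+ # 2) with state_encoder: states (e.g. images) are encoded by the given module
+ # and a single linear head predicts the action logits.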
+ if state_encoder is None:
+ self.state_encoder = None
+ blocks = [Block(h_dim, input_seq_len, n_heads, drop_p) for _ in range(n_blocks)]
+ self.embed_rtg = torch.nn.Linear(1, h_dim)
+ self.embed_state = torch.nn.Linear(state_dim, h_dim)
+ self.predict_rtg = torch.nn.Linear(h_dim, 1)
+ self.predict_state = torch.nn.Linear(h_dim, state_dim)
+ if continuous:
+ # continuous actions
+ self.embed_action = torch.nn.Linear(act_dim, h_dim)
+ use_action_tanh = True # True for continuous actions
+ else:
+ # discrete actions
+ self.embed_action = torch.nn.Embedding(act_dim, h_dim)
+ use_action_tanh = False # False for discrete actions
+ self.predict_action = nn.Sequential(
+ *([nn.Linear(h_dim, act_dim)] + ([nn.Tanh()] if use_action_tanh else []))
+ )
+ else:
+ blocks = [Block(h_dim, input_seq_len + 1, n_heads, drop_p) for _ in range(n_blocks)]
+ self.state_encoder = state_encoder
+ self.embed_rtg = nn.Sequential(nn.Linear(1, h_dim), nn.Tanh())
+ self.head = nn.Linear(h_dim, act_dim, bias=False)
+ self.embed_action = nn.Sequential(nn.Embedding(act_dim, h_dim), nn.Tanh())
+ self.transformer = nn.Sequential(*blocks)
+
+ def forward(
+ self,
+ timesteps: torch.Tensor,
+ states: torch.Tensor,
+ actions: torch.Tensor,
+ returns_to_go: torch.Tensor,
+ tar: Optional[int] = None
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Overview:
+ Forward computation graph of the decision transformer, input a sequence tensor \
+ and return a tensor with the same shape.
+ Arguments:
+ - timesteps (:obj:`torch.Tensor`): The timestep for input sequence.
+ - states (:obj:`torch.Tensor`): The sequence of states.
+ - actions (:obj:`torch.Tensor`): The sequence of actions.
+ - returns_to_go (:obj:`torch.Tensor`): The sequence of return-to-go.
+ - tar (:obj:`Optional[int]`): Whether the action of the current timestep is available; when it is \
+ ``None``, the last action token is not appended to the input sequence.
+ Returns:
+ - output (:obj:`Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`): Output contains three tensors, \
+ they are correspondingly the predicted states, predicted actions and predicted return-to-go.
+ Examples:
+ >>> B, T = 4, 6
+ >>> state_dim = 3
+ >>> act_dim = 2
+ >>> DT_model = DecisionTransformer(\
+ state_dim=state_dim,\
+ act_dim=act_dim,\
+ n_blocks=3,\
+ h_dim=8,\
+ context_len=T,\
+ n_heads=2,\
+ drop_p=0.1,\
+ )
+ >>> timesteps = torch.randint(0, 100, [B, 3 * T - 1, 1], dtype=torch.long) # B x T
+ >>> states = torch.randn([B, T, state_dim]) # B x T x state_dim
+ >>> actions = torch.randint(0, act_dim, [B, T, 1])
+ >>> action_target = torch.randint(0, act_dim, [B, T, 1])
+ >>> returns_to_go_sample = torch.tensor([1, 0.8, 0.6, 0.4, 0.2, 0.]).repeat([B, 1]).unsqueeze(-1).float()
+ >>> traj_mask = torch.ones([B, T], dtype=torch.long) # B x T
+ >>> actions = actions.squeeze(-1)
+ >>> state_preds, action_preds, return_preds = DT_model.forward(\
+ timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go\
+ )
+ >>> assert state_preds.shape == torch.Size([B, T, state_dim])
+ >>> assert return_preds.shape == torch.Size([B, T, 1])
+ >>> assert action_preds.shape == torch.Size([B, T, act_dim])
+ """
+ B, T = states.shape[0], states.shape[1]
+ if self.state_encoder is None:
+ time_embeddings = self.embed_timestep(timesteps)
+
+ # time embeddings are treated similar to positional embeddings
+ state_embeddings = self.embed_state(states) + time_embeddings
+ action_embeddings = self.embed_action(actions) + time_embeddings
+ returns_embeddings = self.embed_rtg(returns_to_go) + time_embeddings
+
+ # stack rtg, states and actions and reshape sequence as
+ # (r_0, s_0, a_0, r_1, s_1, a_1, r_2, s_2, a_2 ...)
+ t_p = torch.stack((returns_embeddings, state_embeddings, action_embeddings),
+ dim=1).permute(0, 2, 1, 3).reshape(B, 3 * T, self.h_dim)
+ h = self.embed_ln(t_p)
+ # transformer and prediction
+ h = self.transformer(h)
+ # get h reshaped such that its size = (B x 3 x T x h_dim) and
+ # h[:, 0, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t
+ # h[:, 1, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t, s_t
+ # h[:, 2, t] is conditioned on the input sequence r_0, s_0, a_0 ... r_t, s_t, a_t
+ # that is, for each timestep (t) we have 3 output embeddings from the transformer,
+ # each conditioned on all previous timesteps plus
+ # the 3 input variables at that timestep (r_t, s_t, a_t) in sequence.
+ h = h.reshape(B, T, 3, self.h_dim).permute(0, 2, 1, 3)
+
+ return_preds = self.predict_rtg(h[:, 2]) # predict next rtg given r, s, a
+ state_preds = self.predict_state(h[:, 2]) # predict next state given r, s, a
+ action_preds = self.predict_action(h[:, 1]) # predict action given r, s
+ else:
+ state_embeddings = self.state_encoder(
+ states.reshape(-1, *self.state_dim).type(torch.float32).contiguous()
+ ) # (batch * block_size, h_dim)
+ state_embeddings = state_embeddings.reshape(B, T, self.h_dim) # (batch, block_size, h_dim)
+ returns_embeddings = self.embed_rtg(returns_to_go.type(torch.float32))
+ action_embeddings = self.embed_action(actions.type(torch.long).squeeze(-1)) # (batch, block_size, h_dim)
+
+ token_embeddings = torch.zeros(
+ (B, T * 3 - int(tar is None), self.h_dim), dtype=torch.float32, device=state_embeddings.device
+ )
+ token_embeddings[:, ::3, :] = returns_embeddings
+ token_embeddings[:, 1::3, :] = state_embeddings
+ token_embeddings[:, 2::3, :] = action_embeddings[:, -T + int(tar is None):, :]
+
+ all_global_pos_emb = torch.repeat_interleave(
+ self.global_pos_emb, B, dim=0
+ ) # batch_size, traj_length, h_dim
+
+ position_embeddings = torch.gather(
+ all_global_pos_emb, 1, torch.repeat_interleave(timesteps, self.h_dim, dim=-1)
+ ) + self.pos_emb[:, :token_embeddings.shape[1], :]
+
+ t_p = token_embeddings + position_embeddings
+
+ h = self.drop(t_p)
+ h = self.transformer(h)
+ h = self.embed_ln(h)
+ logits = self.head(h)
+
+ return_preds = None
+ state_preds = None
+ action_preds = logits[:, 1::3, :] # only keep predictions from state_embeddings
+
+ return state_preds, action_preds, return_preds
+
+ def configure_optimizers(
+ self, weight_decay: float, learning_rate: float, betas: Tuple[float, float] = (0.9, 0.95)
+ ) -> torch.optim.Optimizer:
+ """
+ Overview:
+ This function returns an optimizer given the input arguments. \
+ We are separating out all parameters of the model into two buckets: those that will experience \
+ weight decay for regularization and those that won't (biases, and layernorm/embedding weights).
+ Arguments:
+ - weight_decay (:obj:`float`): The weight decay of the optimizer.
+ - learning_rate (:obj:`float`): The learning rate of the optimizer.
+ - betas (:obj:`Tuple[float, float]`): The betas for the Adam optimizer.
+ Returns:
+ - optimizer (:obj:`torch.optim.Optimizer`): The desired optimizer.
+ """
+
+ # separate out all parameters to those that will and won't experience regularizing weight decay
+ decay = set()
+ no_decay = set()
+ # whitelist_weight_modules = (torch.nn.Linear, )
+ whitelist_weight_modules = (torch.nn.Linear, torch.nn.Conv2d)
+ blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+ for mn, m in self.named_modules():
+ for pn, p in m.named_parameters():
+ fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
+
+ if pn.endswith('bias'):
+ # all biases will not be decayed
+ no_decay.add(fpn)
+ elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
+ # weights of whitelist modules will be weight decayed
+ decay.add(fpn)
+ elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
+ # weights of blacklist modules will NOT be weight decayed
+ no_decay.add(fpn)
+
+ # special case the position embedding parameter in the root GPT module as not decayed
+ no_decay.add('pos_emb')
+ no_decay.add('global_pos_emb')
+
+ # validate that we considered every parameter
+ param_dict = {pn: p for pn, p in self.named_parameters()}
+ inter_params = decay & no_decay
+ union_params = decay | no_decay
+ assert len(inter_params) == 0, "parameters %s made it into both decay/no_decay sets!" % (str(inter_params), )
+ assert len(param_dict.keys() - union_params) == 0,\
+ "parameters %s were not separated into either decay/no_decay set!" \
+ % (str(param_dict.keys() - union_params), )
+
+ # create the pytorch optimizer object
+ optim_groups = [
+ {
+ "params": [param_dict[pn] for pn in sorted(list(decay))],
+ "weight_decay": weight_decay
+ },
+ {
+ "params": [param_dict[pn] for pn in sorted(list(no_decay))],
+ "weight_decay": 0.0
+ },
+ ]
+ optimizer = torch.optim.AdamW(optim_groups, lr=learning_rate, betas=betas)
+ return optimizer
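+
+# A minimal usage sketch of the optimizer setup above (all hyper-parameters here
+# are illustrative assumptions, not values prescribed by this file):
+#
+# model = DecisionTransformer(
+# state_dim=3, act_dim=2, n_blocks=3, h_dim=8, context_len=6, n_heads=2, drop_p=0.1
+# )
+# optimizer = model.configure_optimizers(weight_decay=1e-4, learning_rate=3e-4)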
diff --git a/DI-engine/ding/model/template/diffusion.py b/DI-engine/ding/model/template/diffusion.py
new file mode 100755
index 0000000000000000000000000000000000000000..f8b48f3061bbf49e5dec97a7bbf4bc79608bc1eb
--- /dev/null
+++ b/DI-engine/ding/model/template/diffusion.py
@@ -0,0 +1,645 @@
+from typing import Union, List, Dict
+from collections import namedtuple
+import numpy as np
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ding.utils import list_split, MODEL_REGISTRY, squeeze, SequenceType
+from ding.torch_utils.network.diffusion import extract, cosine_beta_schedule, apply_conditioning, \
+ DiffusionUNet1d, TemporalValue
+
+Sample = namedtuple('Sample', 'trajectories values chains')
+
+
+def default_sample_fn(model, x, cond, t):
+ b, *_, device = *x.shape, x.device
+ model_mean, _, model_log_variance = model.p_mean_variance(
+ x=x,
+ cond=cond,
+ t=t,
+ )
+ noise = 0.5 * torch.randn_like(x)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1, ) * (len(x.shape) - 1)))
+ values = torch.zeros(len(x), device=device)
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise, values
+
+
+def get_guide_output(guide, x, cond, t):
+ x.requires_grad_()
+ y = guide(x, cond, t).squeeze(dim=-1)
+ grad = torch.autograd.grad([y.sum()], [x])[0]
+ x.detach()
+ return y, grad
+
+
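+# Guided reverse-diffusion step: before drawing the sample, take `n_guide_steps`
+# gradient steps on the noisy trajectory in the direction that increases the
+# value predicted by `guide`; `scale` sets the step size, `t_stopgrad` disables
+# guidance for the last diffusion steps, and `scale_grad_by_std` rescales the
+# gradient by the model variance.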
+def n_step_guided_p_sample(
+ model,
+ x,
+ cond,
+ t,
+ guide,
+ scale=0.001,
+ t_stopgrad=0,
+ n_guide_steps=1,
+ scale_grad_by_std=True,
+):
+ model_log_variance = extract(model.posterior_log_variance_clipped, t, x.shape)
+ model_std = torch.exp(0.5 * model_log_variance)
+ model_var = torch.exp(model_log_variance)
+
+ for _ in range(n_guide_steps):
+ with torch.enable_grad():
+ y, grad = get_guide_output(guide, x, cond, t)
+
+ if scale_grad_by_std:
+ grad = model_var * grad
+
+ grad[t < t_stopgrad] = 0
+
+ x = x + scale * grad
+ x = apply_conditioning(x, cond, model.action_dim)
+
+ model_mean, _, model_log_variance = model.p_mean_variance(x=x, cond=cond, t=t)
+
+ # no noise when t == 0
+ noise = torch.randn_like(x)
+ noise[t == 0] = 0
+
+ return model_mean + model_std * noise, y
+
+
+class GaussianDiffusion(nn.Module):
+ """
+ Overview:
+ Gaussian diffusion model
+ Arguments:
+ - model (:obj:`str`): type of model
+ - model_cfg (:obj:`dict`): config of model
+ - horizon (:obj:`int`): horizon of the trajectory
+ - obs_dim (:obj:`int`): dimension of the observation
+ - action_dim (:obj:`int`): dimension of the action
+ - n_timesteps (:obj:`int`): number of diffusion timesteps
+ - predict_epsilon (:obj:`bool`): whether the model predicts epsilon (noise) instead of x0
+ - loss_discount (:obj:`float`): discount factor of the per-timestep loss
+ - clip_denoised (:obj:`bool`): whether to clip the denoised trajectory
+ - action_weight (:obj:`float`): weight of the first action loss
+ - loss_weights (:obj:`dict`): per-dimension weights of the observation loss
+ """
+
+ def __init__(
+ self,
+ model: str,
+ model_cfg: dict,
+ horizon: int,
+ obs_dim: Union[int, SequenceType],
+ action_dim: Union[int, SequenceType],
+ n_timesteps: int = 1000,
+ predict_epsilon: bool = True,
+ loss_discount: float = 1.0,
+ clip_denoised: bool = False,
+ action_weight: float = 1.0,
+ loss_weights: dict = None,
+ ) -> None:
+ super().__init__()
+ self.horizon = horizon
+ self.obs_dim = obs_dim
+ self.action_dim = action_dim
+ self.transition_dim = obs_dim + action_dim
+ if type(model) == str:
+ model = eval(model)
+ self.model = model(**model_cfg)
+ self.predict_epsilon = predict_epsilon
+ self.clip_denoised = clip_denoised
+
+ betas = cosine_beta_schedule(n_timesteps)
+ alphas = 1. - betas
+ alphas_cumprod = torch.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
+ self.n_timesteps = int(n_timesteps)
+
+ self.register_buffer('betas', betas)
+ self.register_buffer('alphas_cumprod', alphas_cumprod)
+ self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
+ self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
+ self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
+ self.register_buffer('posterior_variance', posterior_variance)
+
+ # log calculation clipped because the posterior variance
+ # is 0 at the beginning of the diffusion chain
+ self.register_buffer('posterior_log_variance_clipped', torch.log(torch.clamp(posterior_variance, min=1e-20)))
+ self.register_buffer('posterior_mean_coef1', betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
+ self.register_buffer(
+ 'posterior_mean_coef2', (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)
+ )
+
+ self.loss_weights = self.get_loss_weights(action_weight, loss_discount, loss_weights)
+
+ def get_loss_weights(self, action_weight: float, discount: float, weights_dict: dict):
+ """
+ Overview:
+ Set the loss coefficients for the trajectory.
+ Arguments:
+ - action_weight (:obj:`float`): coefficient on the first action loss
+ - discount (:obj:`float`): multiplies the t-th timestep of the trajectory loss by discount**t
+ - weights_dict (:obj:`dict`): { i: c } multiplies dimension i of the observation loss by c
+ """
+ self.action_weight = action_weight
+ dim_weights = torch.ones(self.transition_dim, dtype=torch.float32)
+
+ # set loss coefficients for dimensions of observation
+ if weights_dict is None:
+ weights_dict = {}
+ for ind, w in weights_dict.items():
+ dim_weights[self.action_dim + ind] *= w
+
+ # decay loss with trajectory timestep: discount**t
+ discounts = discount ** torch.arange(self.horizon, dtype=torch.float)
+ discounts = discounts / discounts.mean()
+ loss_weights = torch.einsum('h,t->ht', discounts, dim_weights)
+
+ # manually set a0 weight
+ loss_weights[0, :self.action_dim] = action_weight
+ return loss_weights
+
+ def predict_start_from_noise(self, x_t, t, noise):
+ """
+ if self.predict_epsilon, model output is (scaled) noise;
+ otherwise, model predicts x0 directly
+ """
+ if self.predict_epsilon:
+ return (
+ extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+ extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+ )
+ else:
+ return noise
+
+ def q_posterior(self, x_start, x_t, t):
+ """
+ Overview:
+ Given the predicted clean trajectory and the noisy trajectory at step t, compute the posterior mean and variance.
+ Arguments:
+ - x_start (:obj:`torch.Tensor`): the (predicted) clean trajectory at diffusion step 0
+ - x_t (:obj:`torch.Tensor`): the noisy trajectory at diffusion step t
+ - t (:obj:`torch.Tensor`): the current diffusion timestep
+ """
+ posterior_mean = (
+ extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+ extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = extract(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self, x, cond, t):
+ x_recon = self.predict_start_from_noise(x, t=t, noise=self.model(x, cond, t))
+
+ if self.clip_denoised:
+ x_recon.clamp_(-1., 1.)
+ else:
+ assert RuntimeError()
+
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample_loop(self, shape, cond, return_chain=False, sample_fn=default_sample_fn, plan_size=1, **sample_kwargs):
+ device = self.betas.device
+
+ batch_size = shape[0]
+ x = torch.randn(shape, device=device)
+ x = apply_conditioning(x, cond, self.action_dim)
+
+ chain = [x] if return_chain else None
+
+ for i in reversed(range(0, self.n_timesteps)):
+ t = torch.full((batch_size, ), i, device=device, dtype=torch.long)
+ x, values = sample_fn(self, x, cond, t, **sample_kwargs)
+ x = apply_conditioning(x, cond, self.action_dim)
+
+ if return_chain:
+ chain.append(x)
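+ # When several plans are sampled per condition (plan_size > 1), group them
+ # and sort each group by the predicted value in descending order, so the
+ # best plan of every condition comes first.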
+ values = values.reshape(-1, plan_size, *values.shape[1:])
+ x = x.reshape(-1, plan_size, *x.shape[1:])
+ if plan_size > 1:
+ inds = torch.argsort(values, dim=1, descending=True)
+ x = x[torch.arange(x.size(0)).unsqueeze(1), inds]
+ values = values[torch.arange(values.size(0)).unsqueeze(1), inds]
+ if return_chain:
+ chain = torch.stack(chain, dim=1)
+ return Sample(x, values, chain)
+
+ @torch.no_grad()
+ def conditional_sample(self, cond, horizon=None, **sample_kwargs):
+ """
+ conditions : [ (time, state), ... ]
+ """
+ device = self.betas.device
+ batch_size = len(cond[0])
+ horizon = horizon or self.horizon
+ shape = (batch_size, horizon, self.transition_dim)
+
+ return self.p_sample_loop(shape, cond, **sample_kwargs)
+
+ def q_sample(self, x_start, t, noise=None):
+ """
+ Arguments:
+ - x_start (:obj:`torch.Tensor`): the clean trajectory at diffusion step 0
+ - t (:obj:`torch.Tensor`): the diffusion timestep
+ - noise (:obj:`torch.Tensor`): the noise added at this timestep; sampled from a standard normal if None
+ """
+ if noise is None:
+ noise = torch.randn_like(x_start)
+
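+ # forward diffusion: x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise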
+ sample = (
+ extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+ extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+ )
+
+ return sample
+
+ def p_losses(self, x_start, cond, t):
+ noise = torch.randn_like(x_start)
+
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ x_noisy = apply_conditioning(x_noisy, cond, self.action_dim)
+
+ x_recon = self.model(x_noisy, cond, t)
+ x_recon = apply_conditioning(x_recon, cond, self.action_dim)
+
+ assert noise.shape == x_recon.shape
+
+ if self.predict_epsilon:
+ loss = F.mse_loss(x_recon, noise, reduction='none')
+ a0_loss = (loss[:, 0, :self.action_dim] / self.loss_weights[0, :self.action_dim].to(loss.device)).mean()
+ loss = (loss * self.loss_weights.to(loss.device)).mean()
+ else:
+ loss = F.mse_loss(x_recon, x_start, reduction='none')
+ a0_loss = (loss[:, 0, :self.action_dim] / self.loss_weights[0, :self.action_dim].to(loss.device)).mean()
+ loss = (loss * self.loss_weights.to(loss.device)).mean()
+ return loss, a0_loss
+
+ def forward(self, cond, *args, **kwargs):
+ return self.conditional_sample(cond, *args, **kwargs)
+
+
+class ValueDiffusion(GaussianDiffusion):
+ """
+ Overview:
+ Gaussian diffusion model for value function.
+ """
+
+ def p_losses(self, x_start, cond, target, t):
+ noise = torch.randn_like(x_start)
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ x_noisy = apply_conditioning(x_noisy, cond, self.action_dim)
+
+ pred = self.model(x_noisy, cond, t)
+ loss = F.mse_loss(pred, target, reduction='none').mean()
+ log = {
+ 'mean_pred': pred.mean().item(),
+ 'max_pred': pred.max().item(),
+ 'min_pred': pred.min().item(),
+ }
+
+ return loss, log
+
+ def forward(self, x, cond, t):
+ return self.model(x, cond, t)
+
+
+@MODEL_REGISTRY.register('pd')
+class PlanDiffuser(nn.Module):
+ """
+ Overview:
+ Diffusion model used for planning.
+ Arguments:
+ - diffuser_model (:obj:`str`): type of the plan (diffusion) model
+ - diffuser_model_cfg (:obj:`dict`): config of diffuser_model
+ - value_model (:obj:`str`): type of the value model; set it to None if no value guidance is used
+ - value_model_cfg (:obj:`dict`): config of value_model
+ - sample_kwargs: config of the sample function
+ """
+
+ def __init__(
+ self, diffuser_model: str, diffuser_model_cfg: dict, value_model: str, value_model_cfg: dict, **sample_kwargs
+ ):
+ super().__init__()
+ diffuser_model = eval(diffuser_model)
+ self.diffuser = diffuser_model(**diffuser_model_cfg)
+ self.value = None
+ if value_model:
+ value_model = eval(value_model)
+ self.value = value_model(**value_model_cfg)
+ self.sample_kwargs = sample_kwargs
+
+ def diffuser_loss(self, x_start, cond, t):
+ return self.diffuser.p_losses(x_start, cond, t)
+
+ def value_loss(self, x_start, cond, target, t):
+ return self.value.p_losses(x_start, cond, target, t)
+
+ def get_eval(self, cond, batch_size=1):
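+ # Replicate the condition `batch_size` times; with a value model, run guided
+ # sampling and return the first action of the highest-value plan, otherwise
+ # return the observation part of the sampled plans.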
+ cond = self.repeat_cond(cond, batch_size)
+ if self.value:
+ samples = self.diffuser(
+ cond, sample_fn=n_step_guided_p_sample, plan_size=batch_size, guide=self.value, **self.sample_kwargs
+ )
+ # extract action [eval_num, batch_size, horizon, transition_dim]
+ actions = samples.trajectories[:, :, :, :self.diffuser.action_dim]
+ action = actions[:, 0, 0]
+ return action
+ else:
+ samples = self.diffuser(cond, plan_size=batch_size)
+ return samples.trajectories[:, :, :, self.diffuser.action_dim:].squeeze(1)
+
+ def repeat_cond(self, cond, batch_size):
+ for k, v in cond.items():
+ cond[k] = v.repeat_interleave(batch_size, dim=0)
+ return cond
+
+
+@MODEL_REGISTRY.register('dd')
+class GaussianInvDynDiffusion(nn.Module):
+ """
+ Overview:
+ Gaussian diffusion model with an inverse dynamics (InvDyn) action model.
+ Arguments:
+ - model (:obj:`str`): type of model
+ - model_cfg (:obj:`dict`): config of model
+ - horizon (:obj:`int`): horizon of the trajectory
+ - obs_dim (:obj:`int`): dimension of the observation
+ - action_dim (:obj:`int`): dimension of the action
+ - n_timesteps (:obj:`int`): number of diffusion timesteps
+ - hidden_dim (:obj:`int`): hidden dim of the inverse dynamics model
+ - returns_condition (:obj:`bool`): whether to condition the model on returns
+ - ar_inv (:obj:`bool`): whether to use an autoregressive inverse dynamics model
+ - train_only_inv (:obj:`bool`): whether to train the inverse dynamics model only
+ - predict_epsilon (:obj:`bool`): whether the model predicts epsilon (noise) instead of x0
+ - condition_guidance_w (:obj:`float`): weight of the (classifier-free) condition guidance
+ - loss_discount (:obj:`float`): discount factor of the per-timestep loss
+ """
+
+ def __init__(
+ self,
+ model: str,
+ model_cfg: dict,
+ horizon: int,
+ obs_dim: Union[int, SequenceType],
+ action_dim: Union[int, SequenceType],
+ n_timesteps: int = 1000,
+ hidden_dim: int = 256,
+ returns_condition: bool = False,
+ ar_inv: bool = False,
+ train_only_inv: bool = False,
+ predict_epsilon: bool = True,
+ condition_guidance_w: float = 0.1,
+ loss_discount: float = 1.0,
+ clip_denoised: bool = False,
+ ) -> None:
+ super().__init__()
+ self.horizon = horizon
+ self.obs_dim = obs_dim
+ self.action_dim = action_dim
+ self.transition_dim = obs_dim + action_dim
+ if type(model) == str:
+ model = eval(model)
+ self.model = model(**model_cfg)
+ self.ar_inv = ar_inv
+ self.train_only_inv = train_only_inv
+ self.predict_epsilon = predict_epsilon
+ self.condition_guidance_w = condition_guidance_w
+
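+ # inverse dynamics model: predicts the action a_t from the concatenated
+ # observation pair (s_t, s_{t+1}), hence the 2 * obs_dim input width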
+ self.inv_model = nn.Sequential(
+ nn.Linear(2 * self.obs_dim, hidden_dim),
+ nn.ReLU(),
+ nn.Linear(hidden_dim, hidden_dim),
+ nn.ReLU(),
+ nn.Linear(hidden_dim, self.action_dim),
+ )
+
+ self.returns_condition = returns_condition
+ self.clip_denoised = clip_denoised
+
+ betas = cosine_beta_schedule(n_timesteps)
+ alphas = 1. - betas
+ alphas_cumprod = torch.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = torch.cat([torch.ones(1), alphas_cumprod[:-1]])
+ self.n_timesteps = int(n_timesteps)
+
+ self.register_buffer('betas', betas)
+ self.register_buffer('alphas_cumprod', alphas_cumprod)
+ self.register_buffer('alphas_cumprod_prev', alphas_cumprod_prev)
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', torch.sqrt(alphas_cumprod))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', torch.sqrt(1. - alphas_cumprod))
+ self.register_buffer('log_one_minus_alphas_cumprod', torch.log(1. - alphas_cumprod))
+ self.register_buffer('sqrt_recip_alphas_cumprod', torch.sqrt(1. / alphas_cumprod))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', torch.sqrt(1. / alphas_cumprod - 1))
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ posterior_variance = betas * (1. - alphas_cumprod_prev) / (1. - alphas_cumprod)
+ self.register_buffer('posterior_variance', posterior_variance)
+
+ # log calculation clipped because the posterior variance
+ # is 0 at the beginning of the diffusion chain
+ self.register_buffer('posterior_log_variance_clipped', torch.log(torch.clamp(posterior_variance, min=1e-20)))
+ self.register_buffer('posterior_mean_coef1', betas * np.sqrt(alphas_cumprod_prev) / (1. - alphas_cumprod))
+ self.register_buffer(
+ 'posterior_mean_coef2', (1. - alphas_cumprod_prev) * np.sqrt(alphas) / (1. - alphas_cumprod)
+ )
+
+ self.loss_weights = self.get_loss_weights(loss_discount)
+
+ def get_loss_weights(self, discount: int):
+ self.action_weight = 1
+ dim_weights = torch.ones(self.obs_dim, dtype=torch.float32)
+
+ # decay loss with trajectory timestep: discount**t
+ discounts = discount ** torch.arange(self.horizon, dtype=torch.float)
+ discounts = discounts / discounts.mean()
+ loss_weights = torch.einsum('h,t->ht', discounts, dim_weights)
+ # Because the first timestep is conditioned on the known state at t=0, its loss is masked out when predicting epsilon
+ if self.predict_epsilon:
+ loss_weights[0, :] = 0
+
+ return loss_weights
+
+ def predict_start_from_noise(self, x_t, t, noise):
+ """
+ if self.predict_epsilon, model output is (scaled) noise;
+ otherwise, model predicts x0 directly
+ """
+ if self.predict_epsilon:
+ return (
+ extract(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t -
+ extract(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * noise
+ )
+ else:
+ return noise
+
+ def q_posterior(self, x_start, x_t, t):
+ """
+ Arguments:
+ - x_start (:obj:`torch.Tensor`): the (predicted) clean trajectory at diffusion step 0
+ - x_t (:obj:`torch.Tensor`): the noisy trajectory at diffusion step t
+ - t (:obj:`torch.Tensor`): the current diffusion timestep
+ """
+ posterior_mean = (
+ extract(self.posterior_mean_coef1, t, x_t.shape) * x_start +
+ extract(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = extract(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = extract(self.posterior_log_variance_clipped, t, x_t.shape)
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self, x, cond, t, returns=None):
+ """
+ Arguments:
+ - x (:obj:`torch.Tensor`): the noisy trajectory at diffusion step t
+ - cond (:obj:`tuple`): [ (time, state), ... ], where state is the initial state of the env and time = 0
+ - t (:obj:`torch.Tensor`): the current diffusion timestep
+ - returns (:obj:`torch.Tensor`): the normalized return used to condition the trajectory
+ Returns:
+ - model_mean (:obj:`torch.Tensor`)
+ - posterior_variance (:obj:`torch.Tensor`)
+ - posterior_log_variance (:obj:`torch.Tensor`)
+ """
+ if self.returns_condition:
+ # epsilon could be epsilon or x0 itself
+ epsilon_cond = self.model(x, cond, t, returns, use_dropout=False)
+ epsilon_uncond = self.model(x, cond, t, returns, force_dropout=True)
+ epsilon = epsilon_uncond + self.condition_guidance_w * (epsilon_cond - epsilon_uncond)
+ else:
+ epsilon = self.model(x, cond, t)
+
+ t = t.detach().to(torch.int64)
+ x_recon = self.predict_start_from_noise(x, t=t, noise=epsilon)
+
+ if self.clip_denoised:
+ x_recon.clamp_(-1., 1.)
+ else:
+ assert RuntimeError()
+
+ model_mean, posterior_variance, posterior_log_variance = self.q_posterior(x_start=x_recon, x_t=x, t=t)
+ return model_mean, posterior_variance, posterior_log_variance
+
+ @torch.no_grad()
+ def p_sample(self, x, cond, t, returns=None):
+ """
+ Arguments:
+ - x (:obj:`torch.Tensor`): the noisy trajectory at diffusion step t
+ - cond (:obj:`tuple`): [ (time, state), ... ], where state is the initial state of the env and time = 0
+ - t (:obj:`torch.Tensor`): the current diffusion timestep
+ - returns (:obj:`torch.Tensor`): the normalized return used to condition the trajectory
+ """
+ b, *_, device = *x.shape, x.device
+ model_mean, _, model_log_variance = self.p_mean_variance(x=x, cond=cond, t=t, returns=returns)
+ noise = 0.5 * torch.randn_like(x)
+ # no noise when t == 0
+ nonzero_mask = (1 - (t == 0).float()).reshape(b, *((1, ) * (len(x.shape) - 1)))
+ return model_mean + nonzero_mask * (0.5 * model_log_variance).exp() * noise
+
+ @torch.no_grad()
+ def p_sample_loop(self, shape, cond, returns=None, verbose=True, return_diffusion=False):
+ """
+ Arguments:
+ - shape (:obj:`tuple`): (batch_size, horizon, self.obs_dim)
+ - cond (:obj:`tuple`): [ (time, state), ... ], where state is the initial state of the env and time = 0
+ - returns (:obj:`torch.Tensor`): the normalized return used to condition the trajectory
+ - verbose (:obj:`bool`): whether to log the diffusion progress
+ - return_diffusion (:obj:`bool`): whether to also return the full diffusion chain
+ """
+ device = self.betas.device
+
+ batch_size = shape[0]
+ x = 0.5 * torch.randn(shape, device=device)
+ # In this model, init state must be given by the env and without noise.
+ x = apply_conditioning(x, cond, 0)
+
+ if return_diffusion:
+ diffusion = [x]
+
+ for i in reversed(range(0, self.n_timesteps)):
+ timesteps = torch.full((batch_size, ), i, device=device, dtype=torch.long)
+ x = self.p_sample(x, cond, timesteps, returns)
+ x = apply_conditioning(x, cond, 0)
+
+ if return_diffusion:
+ diffusion.append(x)
+
+ if return_diffusion:
+ return x, torch.stack(diffusion, dim=1)
+ else:
+ return x
+
+ @torch.no_grad()
+ def conditional_sample(self, cond, returns=None, horizon=None, *args, **kwargs):
+ """
+ Arguments:
+ - cond (:obj:`tuple`): [ (time, state), ... ], where state is the initial state of the env and time is the timestep within the trajectory
+ - returns (:obj:`torch.Tensor`): the normalized return used to condition the trajectory
+ - horizon (:obj:`int`): horizon of the trajectory
+ Returns:
+ - x (:obj:`torch.Tensor`): the sampled trajectory
+ """
+ device = self.betas.device
+ batch_size = len(cond[0])
+ horizon = horizon or self.horizon
+ shape = (batch_size, horizon, self.obs_dim)
+
+ return self.p_sample_loop(shape, cond, returns, *args, **kwargs)
+
+ def q_sample(self, x_start, t, noise=None):
+ """
+ Arguments:
+ - x_start (:obj:`torch.Tensor`): the clean trajectory at diffusion step 0
+ - t (:obj:`torch.Tensor`): the diffusion timestep
+ - noise (:obj:`torch.Tensor`): the noise added at this timestep; sampled from a standard normal if None
+ """
+ if noise is None:
+ noise = torch.randn_like(x_start)
+
+ sample = (
+ extract(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+ extract(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+ )
+
+ return sample
+
+ def p_losses(self, x_start, cond, t, returns=None):
+ noise = torch.randn_like(x_start)
+
+ x_noisy = self.q_sample(x_start=x_start, t=t, noise=noise)
+ x_noisy = apply_conditioning(x_noisy, cond, 0)
+
+ x_recon = self.model(x_noisy, cond, t, returns)
+
+ if not self.predict_epsilon:
+ x_recon = apply_conditioning(x_recon, cond, 0)
+
+ assert noise.shape == x_recon.shape
+
+ if self.predict_epsilon:
+ loss = F.mse_loss(x_recon, noise, reduction='none')
+ loss = (loss * self.loss_weights.to(loss.device)).mean()
+ else:
+ loss = F.mse_loss(x_recon, x_start, reduction='none')
+ loss = (loss * self.loss_weights.to(loss.device)).mean()
+
+ return loss
+
+ def forward(self, cond, *args, **kwargs):
+ return self.conditional_sample(cond=cond, *args, **kwargs)
diff --git a/DI-engine/ding/model/template/ebm.py b/DI-engine/ding/model/template/ebm.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b91fd1b6dcd75b8b6fce1ec14afb5675e9b2504
--- /dev/null
+++ b/DI-engine/ding/model/template/ebm.py
@@ -0,0 +1,851 @@
+"""
+Vanilla DFO and EBM are adapted from https://github.com/kevinzakka/ibc.
+MCMC is adapted from https://github.com/google-research/ibc.
+"""
+from typing import Callable, Tuple
+from functools import wraps
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from abc import ABC, abstractmethod
+
+from ding.utils import MODEL_REGISTRY, STOCHASTIC_OPTIMIZER_REGISTRY
+from ding.torch_utils import unsqueeze_repeat
+from ding.model.wrapper import IModelWrapper
+from ding.model.common import RegressionHead
+
+
+def create_stochastic_optimizer(device: str, stochastic_optimizer_config: dict):
+ """
+ Overview:
+ Create stochastic optimizer.
+ Arguments:
+ - device (:obj:`str`): Device.
+ - stochastic_optimizer_config (:obj:`dict`): Stochastic optimizer config.
+ """
+ return STOCHASTIC_OPTIMIZER_REGISTRY.build(
+ stochastic_optimizer_config.pop("type"), device=device, **stochastic_optimizer_config
+ )
+
+
+def no_ebm_grad():
+ """Wrapper that disables energy based model gradients"""
+
+ def ebm_disable_grad_wrapper(func: Callable):
+
+ @wraps(func)
+ def wrapper(*args, **kwargs):
+ ebm = args[-1]
+ assert isinstance(ebm, (IModelWrapper, nn.Module)),\
+ 'Make sure ebm is the last positional arguments.'
+ ebm.requires_grad_(False)
+ result = func(*args, **kwargs)
+ ebm.requires_grad_(True)
+ return result
+
+ return wrapper
+
+ return ebm_disable_grad_wrapper
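+
+
+# Usage sketch (illustrative): the decorated function must take the energy-based
+# model as its last positional argument, e.g.
+#
+# @no_ebm_grad()
+# def infer(self, obs, ebm):
+# ...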
+
+
+class StochasticOptimizer(ABC):
+ """
+ Overview:
+ Base class for stochastic optimizers.
+ Interface:
+ ``__init__``, ``_sample``, ``_get_best_action_sample``, ``set_action_bounds``, ``sample``, ``infer``
+ """
+
+ def _sample(self, obs: torch.Tensor, num_samples: int) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Overview:
+ Drawing action samples from the uniform random distribution \
+ and tiling observations to the same shape as action samples.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): Observation.
+ - num_samples (:obj:`int`): The number of negative samples.
+ Returns:
+ - tiled_obs (:obj:`torch.Tensor`): Observations tiled.
+ - action (:obj:`torch.Tensor`): Action sampled.
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
+ - num_samples (:obj:`int`): :math:`N`.
+ - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
+ - action (:obj:`torch.Tensor`): :math:`(B, N, A)`.
+ Examples:
+ >>> obs = torch.randn(2, 4)
+ >>> opt = StochasticOptimizer()
+ >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
+ >>> tiled_obs, action = opt._sample(obs, 8)
+ """
+ size = (obs.shape[0], num_samples, self.action_bounds.shape[1])
+ low, high = self.action_bounds[0, :], self.action_bounds[1, :]
+ action_samples = low + (high - low) * torch.rand(size).to(self.device)
+ tiled_obs = unsqueeze_repeat(obs, num_samples, 1)
+ return tiled_obs, action_samples
+
+ @staticmethod
+ @torch.no_grad()
+ def _get_best_action_sample(obs: torch.Tensor, action_samples: torch.Tensor, ebm: nn.Module):
+ """
+ Overview:
+ Return one action for each batch with highest probability (lowest energy).
+ Arguments:
+ - obs (:obj:`torch.Tensor`): Observation.
+ - action_samples (:obj:`torch.Tensor`): Action from uniform distributions.
+ Returns:
+ - best_action_samples (:obj:`torch.Tensor`): Best action.
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
+ - action_samples (:obj:`torch.Tensor`): :math:`(B, N, A)`.
+ - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.
+ Examples:
+ >>> obs = torch.randn(2, 4)
+ >>> action_samples = torch.randn(2, 8, 5)
+ >>> ebm = EBM(4, 5)
+ >>> opt = StochasticOptimizer()
+ >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
+ >>> best_action_samples = opt._get_best_action_sample(obs, action_samples, ebm)
+ """
+ # (B, N)
+ energies = ebm.forward(obs, action_samples)
+ probs = F.softmax(-1.0 * energies, dim=-1)
+ # (B, )
+ best_idxs = probs.argmax(dim=-1)
+ return action_samples[torch.arange(action_samples.size(0)), best_idxs]
+
+ def set_action_bounds(self, action_bounds: np.ndarray):
+ """
+ Overview:
+ Set action bounds calculated from the dataset statistics.
+ Arguments:
+ - action_bounds (:obj:`np.ndarray`): Array of shape (2, A), \
+ where action_bounds[0] is lower bound and action_bounds[1] is upper bound.
+ Returns:
+ - action_bounds (:obj:`torch.Tensor`): Action bounds.
+ Shapes:
+ - action_bounds (:obj:`np.ndarray`): :math:`(2, A)`.
+ - action_bounds (:obj:`torch.Tensor`): :math:`(2, A)`.
+ Examples:
+ >>> opt = StochasticOptimizer()
+ >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
+ """
+ self.action_bounds = torch.as_tensor(action_bounds, dtype=torch.float32).to(self.device)
+
+ @abstractmethod
+ def sample(self, obs: torch.Tensor, ebm: nn.Module) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Overview:
+ Create tiled observations and sample negative action examples for the InfoNCE loss.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): Observations.
+ - ebm (:obj:`torch.nn.Module`): Energy based model.
+ Returns:
+ - tiled_obs (:obj:`torch.Tensor`): Tiled observations.
+ - action (:obj:`torch.Tensor`): Actions.
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
+ - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`.
+ - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
+ - action (:obj:`torch.Tensor`): :math:`(B, N, A)`.
+
+ .. note:: In the case of derivative-free optimization, this function will simply call _sample.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
+ """
+ Overview:
+ Optimize for the best action conditioned on the current observation.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): Observations.
+ - ebm (:obj:`torch.nn.Module`): Energy based model.
+ Returns:
+ - best_action_samples (:obj:`torch.Tensor`): Best actions.
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
+ - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`.
+ - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.
+ """
+ raise NotImplementedError
+
+
+@STOCHASTIC_OPTIMIZER_REGISTRY.register('dfo')
+class DFO(StochasticOptimizer):
+ """
+ Overview:
+ Derivative-Free Optimizer in paper Implicit Behavioral Cloning.
+ https://arxiv.org/abs/2109.00137
+ Interface:
+ ``__init__``, ``sample``, ``infer``
+ """
+
+ def __init__(
+ self,
+ noise_scale: float = 0.33,
+ noise_shrink: float = 0.5,
+ iters: int = 3,
+ train_samples: int = 8,
+ inference_samples: int = 16384,
+ device: str = 'cpu',
+ ):
+ """
+ Overview:
+ Initialize the Derivative-Free Optimizer
+ Arguments:
+ - noise_scale (:obj:`float`): Initial noise scale.
+ - noise_shrink (:obj:`float`): Noise scale shrink rate.
+ - iters (:obj:`int`): Number of iterations.
+ - train_samples (:obj:`int`): Number of samples for training.
+ - inference_samples (:obj:`int`): Number of samples for inference.
+ - device (:obj:`str`): Device.
+ """
+ self.action_bounds = None
+ self.noise_scale = noise_scale
+ self.noise_shrink = noise_shrink
+ self.iters = iters
+ self.train_samples = train_samples
+ self.inference_samples = inference_samples
+ self.device = device
+
+ def sample(self, obs: torch.Tensor, ebm: nn.Module) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Overview:
+ Draw action samples from a uniform random distribution \
+ and tile the observations to the same shape as the action samples.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): Observations.
+ - ebm (:obj:`torch.nn.Module`): Energy based model.
+ Returns:
+ - tiled_obs (:obj:`torch.Tensor`): Tiled observation.
+ - action_samples (:obj:`torch.Tensor`): Action samples.
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
+ - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`.
+ - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
+ - action_samples (:obj:`torch.Tensor`): :math:`(B, N, A)`.
+ Examples:
+ >>> obs = torch.randn(2, 4)
+ >>> ebm = EBM(4, 5)
+ >>> opt = DFO()
+ >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
+ >>> tiled_obs, action_samples = opt.sample(obs, ebm)
+ """
+ return self._sample(obs, self.train_samples)
+
+ @torch.no_grad()
+ def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
+ """
+ Overview:
+ Optimize for the best action conditioned on the current observation.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): Observations.
+ - ebm (:obj:`torch.nn.Module`): Energy based model.
+ Returns:
+ - best_action_samples (:obj:`torch.Tensor`): Actions.
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
+ - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`.
+ - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.
+ Examples:
+ >>> obs = torch.randn(2, 4)
+ >>> ebm = EBM(4, 5)
+ >>> opt = DFO()
+ >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
+ >>> best_action_samples = opt.infer(obs, ebm)
+ """
+ noise_scale = self.noise_scale
+
+ # (B, N, O), (B, N, A)
+ obs, action_samples = self._sample(obs, self.inference_samples)
+
+ for i in range(self.iters):
+ # (B, N)
+ energies = ebm.forward(obs, action_samples)
+ probs = F.softmax(-1.0 * energies, dim=-1)
+
+ # Resample with replacement.
+ idxs = torch.multinomial(probs, self.inference_samples, replacement=True)
+ action_samples = action_samples[torch.arange(action_samples.size(0)).unsqueeze(-1), idxs]
+
+ # Add noise and clip to target bounds.
+ action_samples = action_samples + torch.randn_like(action_samples) * noise_scale
+ action_samples = action_samples.clamp(min=self.action_bounds[0, :], max=self.action_bounds[1, :])
+
+ noise_scale *= self.noise_shrink
+
+ # Return target with highest probability.
+ return self._get_best_action_sample(obs, action_samples, ebm)
+
+
+@STOCHASTIC_OPTIMIZER_REGISTRY.register('ardfo')
+class AutoRegressiveDFO(DFO):
+ """
+ Overview:
+ AutoRegressive Derivative-Free Optimizer proposed in the paper Implicit Behavioral Cloning.
+ https://arxiv.org/abs/2109.00137
+ Interface:
+ ``__init__``, ``infer``
+ """
+
+ def __init__(
+ self,
+ noise_scale: float = 0.33,
+ noise_shrink: float = 0.5,
+ iters: int = 3,
+ train_samples: int = 8,
+ inference_samples: int = 4096,
+ device: str = 'cpu',
+ ):
+ """
+ Overview:
+ Initialize the AutoRegressive Derivative-Free Optimizer
+ Arguments:
+ - noise_scale (:obj:`float`): Initial noise scale.
+ - noise_shrink (:obj:`float`): Noise scale shrink rate.
+ - iters (:obj:`int`): Number of iterations.
+ - train_samples (:obj:`int`): Number of samples for training.
+ - inference_samples (:obj:`int`): Number of samples for inference.
+ - device (:obj:`str`): Device.
+ """
+ super().__init__(noise_scale, noise_shrink, iters, train_samples, inference_samples, device)
+
+ @torch.no_grad()
+ def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
+ """
+ Overview:
+ Optimize for the best action conditioned on the current observation.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): Observations.
+ - ebm (:obj:`torch.nn.Module`): Energy based model.
+ Returns:
+ - best_action_samples (:obj:`torch.Tensor`): Actions.
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
+ - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`.
+ - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.
+ Examples:
+ >>> obs = torch.randn(2, 4)
+ >>> ebm = EBM(4, 5)
+ >>> opt = AutoRegressiveDFO()
+ >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
+ >>> best_action_samples = opt.infer(obs, ebm)
+ """
+ noise_scale = self.noise_scale
+
+ # (B, N, O), (B, N, A)
+ obs, action_samples = self._sample(obs, self.inference_samples)
+
+ for i in range(self.iters):
+ # j: action_dim index
+ for j in range(action_samples.shape[-1]):
+ # (B, N)
+ energies = ebm.forward(obs, action_samples)[..., j]
+ probs = F.softmax(-1.0 * energies, dim=-1)
+
+ # Resample with replacement.
+ idxs = torch.multinomial(probs, self.inference_samples, replacement=True)
+ action_samples = action_samples[torch.arange(action_samples.size(0)).unsqueeze(-1), idxs]
+
+ # Add noise and clip to target bounds.
+ action_samples[..., j] = action_samples[..., j] + torch.randn_like(action_samples[..., j]) * noise_scale
+
+ action_samples[..., j] = action_samples[..., j].clamp(
+ min=self.action_bounds[0, j], max=self.action_bounds[1, j]
+ )
+
+ noise_scale *= self.noise_shrink
+
+ # (B, N)
+ energies = ebm.forward(obs, action_samples)[..., -1]
+ probs = F.softmax(-1.0 * energies, dim=-1)
+ # (B, )
+ best_idxs = probs.argmax(dim=-1)
+ return action_samples[torch.arange(action_samples.size(0)), best_idxs]
+
+
+@STOCHASTIC_OPTIMIZER_REGISTRY.register('mcmc')
+class MCMC(StochasticOptimizer):
+ """
+ Overview:
+ MCMC-based stochastic optimizer proposed in the paper Implicit Behavioral Cloning.
+ https://arxiv.org/abs/2109.00137
+ Interface:
+ ``__init__``, ``sample``, ``infer``, ``grad_penalty``
+ """
+
+ class BaseScheduler(ABC):
+ """
+ Overview:
+ Base class for learning rate scheduler.
+ Interface:
+ ``get_rate``
+ """
+
+ @abstractmethod
+ def get_rate(self, index):
+ """
+ Overview:
+ Abstract method for getting learning rate.
+ """
+ raise NotImplementedError
+
+ class ExponentialScheduler:
+ """
+ Overview:
+ Exponential learning rate schedule for Langevin sampler.
+ Interface:
+ ``__init__``, ``get_rate``
+ """
+
+ def __init__(self, init, decay):
+ """
+ Overview:
+ Initialize the ExponentialScheduler.
+ Arguments:
+ - init (:obj:`float`): Initial learning rate.
+ - decay (:obj:`float`): Decay rate.
+ """
+ self._decay = decay
+ self._latest_lr = init
+
+ def get_rate(self, index):
+ """
+ Overview:
+ Get the learning rate. Assumes it is called sequentially.
+ Arguments:
+ - index (:obj:`int`): Current iteration.
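+ Examples:
+ >>> # Illustrative sketch: the rate decays by ``decay`` on every call (``index`` is ignored).
+ >>> scheduler = MCMC.ExponentialScheduler(init=1.0, decay=0.5)
+ >>> scheduler.get_rate(0) # 1.0
+ >>> scheduler.get_rate(1) # 0.5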
+ """
+ del index
+ lr = self._latest_lr
+ self._latest_lr *= self._decay
+ return lr
+
+ class PolynomialScheduler:
+ """
+ Overview:
+ Polynomial learning rate schedule for Langevin sampler.
+ Interface:
+ ``__init__``, ``get_rate``
+ """
+
+ def __init__(self, init, final, power, num_steps):
+ """
+ Overview:
+ Initialize the PolynomialScheduler.
+ Arguments:
+ - init (:obj:`float`): Initial learning rate.
+ - final (:obj:`float`): Final learning rate.
+ - power (:obj:`float`): Power of polynomial.
+ - num_steps (:obj:`int`): Number of steps.
+ """
+ self._init = init
+ self._final = final
+ self._power = power
+ self._num_steps = num_steps
+
+ def get_rate(self, index):
+ """
+ Overview:
+ Get learning rate for index.
+ Arguments:
+ - index (:obj:`int`): Current iteration.
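+ Examples:
+ >>> # Illustrative sketch: the rate anneals polynomially from ``init`` to ``final``.
+ >>> scheduler = MCMC.PolynomialScheduler(init=0.5, final=1e-5, power=2.0, num_steps=100)
+ >>> scheduler.get_rate(-1) # 0.5 (initial rate)
+ >>> scheduler.get_rate(99) # 1e-5 (final rate)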
+ """
+ if index == -1:
+ return self._init
+ return (
+ (self._init - self._final) * ((1 - (float(index) / float(self._num_steps - 1))) ** (self._power))
+ ) + self._final
+
+ def __init__(
+ self,
+ iters: int = 100,
+ use_langevin_negative_samples: bool = True,
+ train_samples: int = 8,
+ inference_samples: int = 512,
+ stepsize_scheduler: dict = dict(
+ init=0.5,
+ final=1e-5,
+ power=2.0,
+ # num_steps is set at runtime (to ``self.iters``) before the scheduler is constructed.
+ ),
+ optimize_again: bool = True,
+ again_stepsize_scheduler: dict = dict(
+ init=1e-5,
+ final=1e-5,
+ power=2.0,
+ # num_steps is set at runtime (to ``self.iters``) before the scheduler is constructed.
+ ),
+ device: str = 'cpu',
+ # langevin_step
+ noise_scale: float = 0.5,
+ grad_clip=None,
+ delta_action_clip: float = 0.5,
+ add_grad_penalty: bool = True,
+ grad_norm_type: str = 'inf',
+ grad_margin: float = 1.0,
+ grad_loss_weight: float = 1.0,
+ **kwargs,
+ ):
+ """
+ Overview:
+ Initialize the MCMC.
+ Arguments:
+ - iters (:obj:`int`): Number of iterations.
+ - use_langevin_negative_samples (:obj:`bool`): Whether to use Langevin sampler.
+ - train_samples (:obj:`int`): Number of samples for training.
+ - inference_samples (:obj:`int`): Number of samples for inference.
+ - stepsize_scheduler (:obj:`dict`): Step size scheduler for Langevin sampler.
+ - optimize_again (:obj:`bool`): Whether to run a second optimization.
+ - again_stepsize_scheduler (:obj:`dict`): Step size scheduler for the second optimization.
+ - device (:obj:`str`): Device.
+ - noise_scale (:obj:`float`): Initial noise scale.
+ - grad_clip (:obj:`float`): Gradient clip.
+ - delta_action_clip (:obj:`float`): Action clip.
+ - add_grad_penalty (:obj:`bool`): Whether to add gradient penalty.
+ - grad_norm_type (:obj:`str`): Gradient norm type.
+ - grad_margin (:obj:`float`): Gradient margin.
+ - grad_loss_weight (:obj:`float`): Gradient loss weight.
+ """
+ self.iters = iters
+ self.use_langevin_negative_samples = use_langevin_negative_samples
+ self.train_samples = train_samples
+ self.inference_samples = inference_samples
+ self.stepsize_scheduler = stepsize_scheduler
+ self.optimize_again = optimize_again
+ self.again_stepsize_scheduler = again_stepsize_scheduler
+ self.device = device
+
+ self.noise_scale = noise_scale
+ self.grad_clip = grad_clip
+ self.delta_action_clip = delta_action_clip
+ self.add_grad_penalty = add_grad_penalty
+ self.grad_norm_type = grad_norm_type
+ self.grad_margin = grad_margin
+ self.grad_loss_weight = grad_loss_weight
+
+ @staticmethod
+ def _gradient_wrt_act(
+ obs: torch.Tensor,
+ action: torch.Tensor,
+ ebm: nn.Module,
+ create_graph: bool = False,
+ ) -> torch.Tensor:
+ """
+ Overview:
+ Calculate gradient w.r.t action.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): Observations.
+ - action (:obj:`torch.Tensor`): Actions.
+ - ebm (:obj:`torch.nn.Module`): Energy based model.
+ - create_graph (:obj:`bool`): Whether to create graph.
+ Returns:
+ - grad (:obj:`torch.Tensor`): Gradient w.r.t action.
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
+ - action (:obj:`torch.Tensor`): :math:`(B, N, A)`.
+ - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`.
+ - grad (:obj:`torch.Tensor`): :math:`(B, N, A)`.
+ """
+ action.requires_grad_(True)
+ energy = ebm.forward(obs, action).sum()
+ # `create_graph` is set to `True` when the second order derivative
+ # is needed, i.e., d(de/da)/d_param.
+ grad = torch.autograd.grad(energy, action, create_graph=create_graph)[0]
+ action.requires_grad_(False)
+ return grad
+
+ def grad_penalty(self, obs: torch.Tensor, action: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
+ """
+ Overview:
+ Calculate gradient penalty.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): Observations.
+ - action (:obj:`torch.Tensor`): Actions.
+ - ebm (:obj:`torch.nn.Module`): Energy based model.
+ Returns:
+ - loss (:obj:`torch.Tensor`): Gradient penalty.
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, N+1, O)`.
+ - action (:obj:`torch.Tensor`): :math:`(B, N+1, A)`.
+ - ebm (:obj:`torch.nn.Module`): :math:`(B, N+1, O)`.
+ - loss (:obj:`torch.Tensor`): scalar (the penalty is averaged over the batch).
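+ Examples:
+ >>> # Illustrative sketch; ``EBM`` is the energy model registered later in this file.
+ >>> obs = torch.randn(2, 9, 4)
+ >>> action = torch.randn(2, 9, 5)
+ >>> ebm = EBM(4, 5)
+ >>> opt = MCMC()
+ >>> loss = opt.grad_penalty(obs, action, ebm)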
+ """
+ if not self.add_grad_penalty:
+ return 0.
+ # (B, N+1, A), this gradient is differentiable w.r.t model parameters
+ de_dact = MCMC._gradient_wrt_act(obs, action, ebm, create_graph=True)
+
+ def compute_grad_norm(grad_norm_type, de_dact) -> torch.Tensor:
+ # de_dact: (B, N+1, A)
+ # return: (B, N+1)
+ grad_norm_type_to_ord = {
+ '1': 1,
+ '2': 2,
+ 'inf': float('inf'),
+ }
+ ord = grad_norm_type_to_ord[grad_norm_type]
+ return torch.linalg.norm(de_dact, ord, dim=-1)
+
+ # (B, N+1)
+ grad_norms = compute_grad_norm(self.grad_norm_type, de_dact)
+ grad_norms = grad_norms - self.grad_margin
+ grad_norms = grad_norms.clamp(min=0., max=1e10)
+ grad_norms = grad_norms.pow(2)
+
+ grad_loss = grad_norms.mean()
+ return grad_loss * self.grad_loss_weight
+
+ # Cannot use @torch.no_grad() during inference,
+ # because we need to calculate gradients w.r.t. the inputs for the MCMC updates.
+ @no_ebm_grad()
+ def _langevin_step(self, obs: torch.Tensor, action: torch.Tensor, stepsize: float, ebm: nn.Module) -> torch.Tensor:
+ """
+ Overview:
+ Run one Langevin MCMC step.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): Observations.
+ - action (:obj:`torch.Tensor`): Actions.
+ - stepsize (:obj:`float`): Step size.
+ - ebm (:obj:`torch.nn.Module`): Energy based model.
+ Returns:
+ - action (:obj:`torch.Tensor`): Actions.
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
+ - action (:obj:`torch.Tensor`): :math:`(B, N, A)`.
+ - stepsize (:obj:`float`): A scalar step size.
+ - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`.
+ """
+ l_lambda = 1.0
+ de_dact = MCMC._gradient_wrt_act(obs, action, ebm)
+
+ if self.grad_clip:
+ de_dact = de_dact.clamp(min=-self.grad_clip, max=self.grad_clip)
+
+ gradient_scale = 0.5
+ de_dact = (gradient_scale * l_lambda * de_dact + torch.randn_like(de_dact) * l_lambda * self.noise_scale)
+
+ delta_action = stepsize * de_dact
+ delta_action_clip = self.delta_action_clip * 0.5 * (self.action_bounds[1] - self.action_bounds[0])
+ delta_action = delta_action.clamp(min=-delta_action_clip, max=delta_action_clip)
+
+ action = action - delta_action
+ action = action.clamp(min=self.action_bounds[0], max=self.action_bounds[1])
+
+ return action
+
+ @no_ebm_grad()
+ def _langevin_action_given_obs(
+ self,
+ obs: torch.Tensor,
+ action: torch.Tensor,
+ ebm: nn.Module,
+ scheduler: BaseScheduler = None
+ ) -> torch.Tensor:
+ """
+ Overview:
+ Run Langevin MCMC for ``self.iters`` steps.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): Observations.
+ - action (:obj:`torch.Tensor`): Actions.
+ - ebm (:obj:`torch.nn.Module`): Energy based model.
+ - scheduler (:obj:`BaseScheduler`): Learning rate scheduler.
+ Returns:
+ - action (:obj:`torch.Tensor`): Actions.
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
+ - action (:obj:`torch.Tensor`): :math:`(B, N, A)`.
+ - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`.
+ """
+ if not scheduler:
+ self.stepsize_scheduler['num_steps'] = self.iters
+ scheduler = MCMC.PolynomialScheduler(**self.stepsize_scheduler)
+ stepsize = scheduler.get_rate(-1)
+ for i in range(self.iters):
+ action = self._langevin_step(obs, action, stepsize, ebm)
+ stepsize = scheduler.get_rate(i)
+ return action
+
+ @no_ebm_grad()
+ def sample(self, obs: torch.Tensor, ebm: nn.Module) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Overview:
+ Create tiled observations and sample counter-negatives for InfoNCE loss.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): Observations.
+ - ebm (:obj:`torch.nn.Module`): Energy based model.
+ Returns:
+ - tiled_obs (:obj:`torch.Tensor`): Tiled observations.
+ - action_samples (:obj:`torch.Tensor`): Action samples.
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
+ - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`.
+ - tiled_obs (:obj:`torch.Tensor`): :math:`(B, N, O)`.
+ - action_samples (:obj:`torch.Tensor`): :math:`(B, N, A)`.
+ Examples:
+ >>> obs = torch.randn(2, 4)
+ >>> ebm = EBM(4, 5)
+ >>> opt = MCMC()
+ >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
+ >>> tiled_obs, action_samples = opt.sample(obs, ebm)
+ """
+ obs, uniform_action_samples = self._sample(obs, self.train_samples)
+ if not self.use_langevin_negative_samples:
+ return obs, uniform_action_samples
+ langevin_action_samples = self._langevin_action_given_obs(obs, uniform_action_samples, ebm)
+ return obs, langevin_action_samples
+
+ @no_ebm_grad()
+ def infer(self, obs: torch.Tensor, ebm: nn.Module) -> torch.Tensor:
+ """
+ Overview:
+ Optimize for the best action conditioned on the current observation.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): Observations.
+ - ebm (:obj:`torch.nn.Module`): Energy based model.
+ Returns:
+ - best_action_samples (:obj:`torch.Tensor`): Actions.
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, O)`.
+ - ebm (:obj:`torch.nn.Module`): :math:`(B, N, O)`.
+ - best_action_samples (:obj:`torch.Tensor`): :math:`(B, A)`.
+ Examples:
+ >>> obs = torch.randn(2, 4)
+ >>> ebm = EBM(4, 5)
+ >>> opt = MCMC()
+ >>> opt.set_action_bounds(np.stack([np.zeros(5), np.ones(5)], axis=0))
+ >>> best_action_samples = opt.infer(obs, ebm)
+ """
+ # (B, N, O), (B, N, A)
+ obs, uniform_action_samples = self._sample(obs, self.inference_samples)
+ action_samples = self._langevin_action_given_obs(
+ obs,
+ uniform_action_samples,
+ ebm,
+ )
+
+ # Run a second optimization, a trick for more precise inference
+ if self.optimize_again:
+ self.again_stepsize_scheduler['num_steps'] = self.iters
+ action_samples = self._langevin_action_given_obs(
+ obs,
+ action_samples,
+ ebm,
+ scheduler=MCMC.PolynomialScheduler(**self.again_stepsize_scheduler),
+ )
+
+ # action_samples: B, N, A
+ return self._get_best_action_sample(obs, action_samples, ebm)
+
+
+@MODEL_REGISTRY.register('ebm')
+class EBM(nn.Module):
+ """
+ Overview:
+ Energy based model.
+ Interface:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(
+ self,
+ obs_shape: int,
+ action_shape: int,
+ hidden_size: int = 512,
+ hidden_layer_num: int = 4,
+ **kwargs,
+ ):
+ """
+ Overview:
+ Initialize the EBM.
+ Arguments:
+ - obs_shape (:obj:`int`): Observation shape.
+ - action_shape (:obj:`int`): Action shape.
+ - hidden_size (:obj:`int`): Hidden size.
+ - hidden_layer_num (:obj:`int`): Number of hidden layers.
+ """
+ super().__init__()
+ input_size = obs_shape + action_shape
+ self.net = nn.Sequential(
+ nn.Linear(input_size, hidden_size), nn.ReLU(),
+ RegressionHead(
+ hidden_size,
+ 1,
+ hidden_layer_num,
+ final_tanh=False,
+ )
+ )
+
+ def forward(self, obs, action):
+ """
+ Overview:
+ Forward computation graph of EBM.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): Observation of shape (B, N, O).
+ - action (:obj:`torch.Tensor`): Action of shape (B, N, A).
+ Returns:
+ - pred (:obj:`torch.Tensor`): Energy of shape (B, N).
+ Examples:
+ >>> obs = torch.randn(2, 3, 4)
+ >>> action = torch.randn(2, 3, 5)
+ >>> ebm = EBM(4, 5)
+ >>> pred = ebm(obs, action)
+ """
+ x = torch.cat([obs, action], -1)
+ x = self.net(x)
+ return x['pred']
+
+
+@MODEL_REGISTRY.register('arebm')
+class AutoregressiveEBM(nn.Module):
+ """
+ Overview:
+ Autoregressive energy based model.
+ Interface:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(
+ self,
+ obs_shape: int,
+ action_shape: int,
+ hidden_size: int = 512,
+ hidden_layer_num: int = 4,
+ ):
+ """
+ Overview:
+ Initialize the AutoregressiveEBM.
+ Arguments:
+ - obs_shape (:obj:`int`): Observation shape.
+ - action_shape (:obj:`int`): Action shape.
+ - hidden_size (:obj:`int`): Hidden size.
+ - hidden_layer_num (:obj:`int`): Number of hidden layers.
+ """
+ super().__init__()
+ self.ebm_list = nn.ModuleList()
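+ # The i-th sub-EBM scores the first (i + 1) action dimensions, which enables the
+ # autoregressive energy evaluation used in forward.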
+ for i in range(action_shape):
+ self.ebm_list.append(EBM(obs_shape, i + 1, hidden_size, hidden_layer_num))
+
+ def forward(self, obs, action):
+ """
+ Overview:
+ Forward computation graph of AutoregressiveEBM.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): Observation of shape (B, N, O).
+ - action (:obj:`torch.Tensor`): Action of shape (B, N, A).
+ Returns:
+ - pred (:obj:`torch.Tensor`): Energy of shape (B, N, A).
+ Examples:
+ >>> obs = torch.randn(2, 3, 4)
+ >>> action = torch.randn(2, 3, 5)
+ >>> arebm = AutoregressiveEBM(4, 5)
+ >>> pred = arebm(obs, action)
+ """
+ output_list = []
+ for i, ebm in enumerate(self.ebm_list):
+ output_list.append(ebm(obs, action[..., :i + 1]))
+ return torch.stack(output_list, dim=-1)
diff --git a/DI-engine/ding/model/template/edac.py b/DI-engine/ding/model/template/edac.py
new file mode 100755
index 0000000000000000000000000000000000000000..397ba69763a0b7d1ca298c6811f4f8f7859c89e4
--- /dev/null
+++ b/DI-engine/ding/model/template/edac.py
@@ -0,0 +1,182 @@
+from typing import Union, Optional, Dict
+from easydict import EasyDict
+
+import torch
+import torch.nn as nn
+from ding.model.common import ReparameterizationHead, EnsembleHead
+from ding.utils import SequenceType, squeeze
+
+from ding.utils import MODEL_REGISTRY
+
+
+@MODEL_REGISTRY.register('edac')
+class EDAC(nn.Module):
+ """
+ Overview:
+ The Q-value Actor-Critic network with the ensemble mechanism, which is used in EDAC.
+ Interfaces:
+ ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
+ """
+ mode = ['compute_actor', 'compute_critic']
+
+ def __init__(
+ self,
+ obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType, EasyDict],
+ ensemble_num: int = 2,
+ actor_head_hidden_size: int = 64,
+ actor_head_layer_num: int = 1,
+ critic_head_hidden_size: int = 64,
+ critic_head_layer_num: int = 1,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ **kwargs
+ ) -> None:
+ """
+ Overview:
+ Initialize the EDAC Model according to input arguments.
+ Arguments:
+ - obs_shape (:obj:`Union[int, SequenceType]`): Observation's shape, such as 128, (156, ).
+ - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's shape, such as 4, (3, ), \
+ EasyDict({'action_type_shape': 3, 'action_args_shape': 4}).
+ - ensemble_num (:obj:`int`): Q-net number.
+ - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor head.
+ - actor_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output \
+ for actor head.
+ - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic head.
+ - critic_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output \
+ for critic head.
+ - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` \
+ after each FC layer, if ``None`` then default set to ``nn.ReLU()``.
+ - norm_type (:obj:`Optional[str]`): The type of normalization to use after each network layer (FC, Conv), \
+ see ``ding.torch_utils.network`` for more details.
+ """
+ super(EDAC, self).__init__()
+ obs_shape: int = squeeze(obs_shape)
+ action_shape = squeeze(action_shape)
+ self.action_shape = action_shape
+ self.ensemble_num = ensemble_num
+ self.actor = nn.Sequential(
+ nn.Linear(obs_shape, actor_head_hidden_size), activation,
+ ReparameterizationHead(
+ actor_head_hidden_size,
+ action_shape,
+ actor_head_layer_num,
+ sigma_type='conditioned',
+ activation=activation,
+ norm_type=norm_type
+ )
+ )
+
+ critic_input_size = obs_shape + action_shape
+ self.critic = EnsembleHead(
+ critic_input_size,
+ 1,
+ critic_head_hidden_size,
+ critic_head_layer_num,
+ self.ensemble_num,
+ activation=activation,
+ norm_type=norm_type
+ )
+
+ def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], mode: str) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ The unique execution (forward) method of EDAC, where one can indicate different modes to implement \
+ different computation graphs, including ``compute_actor`` and ``compute_critic``.
+ Mode compute_actor:
+ Arguments:
+ - inputs (:obj:`torch.Tensor`): Observation data, defaults to tensor.
+ Returns:
+ - output (:obj:`Dict`): Output dict data, including different key-values among distinct action spaces.
+ Mode compute_critic:
+ Arguments:
+ - inputs (:obj:`Dict`): Input dict data, including obs and action tensor.
+ Returns:
+ - output (:obj:`Dict`): Output dict data, including q_value tensor.
+
+ .. note::
+ For specific examples, one can refer to API doc of ``compute_actor`` and ``compute_critic`` respectively.
+ """
+ assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
+ return getattr(self, mode)(inputs)
+
+ def compute_actor(self, obs: torch.Tensor) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
+ """
+ Overview:
+ The forward computation graph of compute_actor mode, which uses the observation tensor to produce actor outputs
+ such as ``action``, ``logit`` and so on.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): Observation tensor data, now supports a batch of 1-dim vector data, \
+ i.e. ``(B, obs_shape)``.
+ Returns:
+ - outputs (:obj:`Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]`): Actor output varying \
+ from action_space: ``reparameterization``.
+ ReturnsKeys (either):
+ - logit (:obj:`Dict[str, torch.Tensor]`): Reparameterization logit, usually in SAC.
+ - mu (:obj:`torch.Tensor`): Mean of the reparameterized Gaussian distribution.
+ - sigma (:obj:`torch.Tensor`): Standard deviation of the reparameterized Gaussian distribution.
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``obs_shape``.
+ - action (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size and N1 corresponds to ``action_shape``.
+ - logit.mu (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size and N1 corresponds to ``action_shape``.
+ - logit.sigma (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size.
+ - logit (:obj:`torch.Tensor`): :math:`(B, N2)`, B is batch size and N2 corresponds to \
+ ``action_shape.action_type_shape``.
+ - action_args (:obj:`torch.Tensor`): :math:`(B, N3)`, B is batch size and N3 corresponds to \
+ ``action_shape.action_args_shape``.
+ Examples:
+ >>> model = EDAC(64, 64,)
+ >>> obs = torch.randn(4, 64)
+ >>> actor_outputs = model(obs,'compute_actor')
+ >>> assert actor_outputs['logit'][0].shape == torch.Size([4, 64]) # mu
+ >>> actor_outputs['logit'][1].shape == torch.Size([4, 64]) # sigma
+ """
+ x = self.actor(obs)
+ return {'logit': [x['mu'], x['sigma']]}
+
+ def compute_critic(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ The forward computation graph of compute_critic mode, which uses observation and action tensors to produce critic
+ output, such as ``q_value``.
+ Arguments:
+ - inputs (:obj:`Dict[str, torch.Tensor]`): Dict structure of input data, including ``obs`` and \
+ ``action`` tensors.
+ Returns:
+ - outputs (:obj:`Dict[str, torch.Tensor]`): Critic output, such as ``q_value``.
+ ArgumentsKeys:
+ - obs: (:obj:`torch.Tensor`): Observation tensor data, now supports a batch of 1-dim vector data.
+ - action (:obj:`Union[torch.Tensor, Dict]`): Continuous action with same size as ``action_shape``.
+ ReturnKeys:
+ - q_value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, N1)` or :math:`(Ensemble_num, B, N1)`, where B is batch size and N1 is \
+ ``obs_shape``.
+ - action (:obj:`torch.Tensor`): :math:`(B, N2)` or :math:`(Ensemble_num, B, N2)`, where B is batch size and N2 \
+ is ``action_shape``.
+ - q_value (:obj:`torch.Tensor`): :math:`(Ensemble_num, B)`, where B is batch size.
+ Examples:
+ >>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)}
+ >>> model = EDAC(obs_shape=(8, ),action_shape=1)
+ >>> model(inputs, mode='compute_critic')['q_value'] # q value
+ ... tensor([0.0773, 0.1639, 0.0917, 0.0370], grad_fn=)
+ """
+
+ obs, action = inputs['obs'], inputs['action']
+ if len(action.shape) == 1: # (B, ) -> (B, 1)
+ action = action.unsqueeze(1)
+ x = torch.cat([obs, action], dim=-1)
+ if len(obs.shape) < 3:
+ # [batch_size,dim] -> [batch_size,Ensemble_num * dim,1]
+ x = x.repeat(1, self.ensemble_num).unsqueeze(-1)
+ else:
+ # [Ensemble_num,batch_size,dim] -> [batch_size,Ensemble_num,dim] -> [batch_size,Ensemble_num * dim, 1]
+ x = x.transpose(0, 1)
+ batch_size = obs.shape[1]
+ x = x.reshape(batch_size, -1, 1)
+ # [Ensemble_num,batch_size,1]
+ x = self.critic(x)['pred']
+ # [batch_size,1*Ensemble_num] -> [Ensemble_num,batch_size]
+ x = x.permute(1, 0)
+ return {'q_value': x}
diff --git a/DI-engine/ding/model/template/havac.py b/DI-engine/ding/model/template/havac.py
new file mode 100644
index 0000000000000000000000000000000000000000..77489ed517656faee1517626771b70bdecaff6a3
--- /dev/null
+++ b/DI-engine/ding/model/template/havac.py
@@ -0,0 +1,500 @@
+from typing import Union, Dict, Optional
+import torch
+import torch.nn as nn
+
+from ding.torch_utils import get_lstm
+from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
+from ding.model.template.q_learning import parallel_wrapper
+from ..common import ReparameterizationHead, RegressionHead, DiscreteHead, \
+ FCEncoder, ConvEncoder
+
+
+class RNNLayer(nn.Module):
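+ """
+ Overview:
+ A thin LSTM/GRU wrapper used by the HAPPO actor and critic. In inference mode it processes a single \
+ timestep; in training mode it unrolls the RNN over a (T, B, N) sequence and optionally adds a residual \
+ link to the input embedding.
+ Interfaces:
+ ``__init__``, ``forward``
+ """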
+
+ def __init__(self, lstm_type, input_size, hidden_size, res_link: bool = False):
+ super(RNNLayer, self).__init__()
+ self.rnn = get_lstm(lstm_type, input_size=input_size, hidden_size=hidden_size)
+ self.res_link = res_link
+
+ def forward(self, x, prev_state, inference: bool = False):
+ """
+ Forward pass of the RNN layer.
+ If inference is True, sequence length of input is set to 1.
+ If res_link is True, a residual link is added to the output.
+ """
+ # x: obs_embedding
+ if self.res_link:
+ a = x
+ if inference:
+ x = x.unsqueeze(0) # add a seq_len dimension of 1 for the RNN input.
+ # prev_state: DataType: List[Tuple[torch.Tensor]]; Initially, it is a list of None
+ x, next_state = self.rnn(x, prev_state)
+ x = x.squeeze(0) # to delete the seq_len dim to match head network input
+ if self.res_link:
+ x = x + a
+ return {'output': x, 'next_state': next_state}
+ else:
+ # lstm_embedding stores all hidden_state
+ lstm_embedding = []
+ hidden_state_list = []
+ for t in range(x.shape[0]): # T timesteps
+ # use x[t:t+1] instead of x[t] to keep the original dimension
+ output, prev_state = self.rnn(x[t:t + 1], prev_state) # output: (1,B, head_hidden_size)
+ lstm_embedding.append(output)
+ hidden_state = [p['h'] for p in prev_state]
+ # only keep ht, {list: x.shape[0]{Tensor:(1, batch_size, head_hidden_size)}}
+ hidden_state_list.append(torch.cat(hidden_state, dim=1))
+ x = torch.cat(lstm_embedding, 0) # (T, B, head_hidden_size)
+ if self.res_link:
+ x = x + a
+ all_hidden_state = torch.cat(hidden_state_list, dim=0)
+ return {'output': x, 'next_state': prev_state, 'hidden_state': all_hidden_state}
+
+
+@MODEL_REGISTRY.register('havac')
+class HAVAC(nn.Module):
+ """
+ Overview:
+ The HAVAC model of each agent for HAPPO.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+ mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']
+
+ def __init__(
+ self,
+ agent_obs_shape: Union[int, SequenceType],
+ global_obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType],
+ agent_num: int,
+ use_lstm: bool = False,
+ lstm_type: str = 'gru',
+ encoder_hidden_size_list: SequenceType = [128, 128, 64],
+ actor_head_hidden_size: int = 64,
+ actor_head_layer_num: int = 2,
+ critic_head_hidden_size: int = 64,
+ critic_head_layer_num: int = 1,
+ action_space: str = 'discrete',
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ sigma_type: Optional[str] = 'independent',
+ bound_type: Optional[str] = None,
+ res_link: bool = False,
+ ) -> None:
+ r"""
+ Overview:
+ Initialize the HAVAC model for HAPPO according to the given arguments.
+ Arguments:
+ - agent_obs_shape (:obj:`Union[int, SequenceType]`): Observation's space for single agent.
+ - global_obs_shape (:obj:`Union[int, SequenceType]`): Observation's space for the global state.
+ - action_shape (:obj:`Union[int, SequenceType]`): Action's space.
+ - agent_num (:obj:`int`): Number of agents.
+ - lstm_type (:obj:`str`): Which RNN cell to use, 'lstm' or 'gru', defaults to 'gru'.
+ - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``
+ - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor-nn's ``Head``.
+ - actor_head_layer_num (:obj:`int`):
+ The num of layers used in the network to compute Q value output for actor's nn.
+ - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic-nn's ``Head``.
+ - critic_head_layer_num (:obj:`int`):
+ The num of layers used in the network to compute Q value output for critic's nn.
+ - activation (:obj:`Optional[nn.Module]`):
+ The type of activation function to use in ``MLP`` after the ``layer_fn``,
+ if ``None`` then default set to ``nn.ReLU()``.
+ - norm_type (:obj:`Optional[str]`):
+ The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details.
+ - res_link (:obj:`bool`): Whether to use the residual link, defaults to ``False``.
+ """
+ super(HAVAC, self).__init__()
+ self.agent_num = agent_num
+ self.agent_models = nn.ModuleList(
+ [
+ HAVACAgent(
+ agent_obs_shape=agent_obs_shape,
+ global_obs_shape=global_obs_shape,
+ action_shape=action_shape,
+ use_lstm=use_lstm,
+ action_space=action_space,
+ ) for _ in range(agent_num)
+ ]
+ )
+
+ def forward(self, agent_idx, input_data, mode):
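+ """
+ Overview:
+ Dispatch ``input_data`` to the model of the agent indexed by ``agent_idx`` and run it in the given \
+ ``mode`` (one of ``compute_actor``, ``compute_critic`` or ``compute_actor_critic``).
+ """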
+ selected_agent_model = self.agent_models[agent_idx]
+ output = selected_agent_model(input_data, mode)
+ return output
+
+
+class HAVACAgent(nn.Module):
+ """
+ Overview:
+ The HAVAC model of each agent for HAPPO.
+ Interfaces:
+ ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``, ``compute_actor_critic``
+ """
+ mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']
+
+ def __init__(
+ self,
+ agent_obs_shape: Union[int, SequenceType],
+ global_obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType],
+ use_lstm: bool = False,
+ lstm_type: str = 'gru',
+ encoder_hidden_size_list: SequenceType = [128, 128, 64],
+ actor_head_hidden_size: int = 64,
+ actor_head_layer_num: int = 2,
+ critic_head_hidden_size: int = 64,
+ critic_head_layer_num: int = 1,
+ action_space: str = 'discrete',
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ sigma_type: Optional[str] = 'happo',
+ bound_type: Optional[str] = None,
+ res_link: bool = False,
+ ) -> None:
+ r"""
+ Overview:
+ Initialize the HAVACAgent model for HAPPO according to the given arguments.
+ Arguments:
+ - agent_obs_shape (:obj:`Union[int, SequenceType]`): Observation's space for single agent.
+ - global_obs_shape (:obj:`Union[int, SequenceType]`): Observation's space for the global state.
+ - action_shape (:obj:`Union[int, SequenceType]`): Action's space.
+ - lstm_type (:obj:`str`): Which RNN cell to use, 'lstm' or 'gru', defaults to 'gru'.
+ - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``
+ - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor-nn's ``Head``.
+ - actor_head_layer_num (:obj:`int`):
+ The num of layers used in the network to compute Q value output for actor's nn.
+ - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic-nn's ``Head``.
+ - critic_head_layer_num (:obj:`int`):
+ The num of layers used in the network to compute Q value output for critic's nn.
+ - activation (:obj:`Optional[nn.Module]`):
+ The type of activation function to use in ``MLP`` after the ``layer_fn``,
+ if ``None`` then default set to ``nn.ReLU()``.
+ - norm_type (:obj:`Optional[str]`):
+ The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details.
+ - res_link (:obj:`bool`): Whether to use the residual link, defaults to ``False``.
+ """
+ super(HAVACAgent, self).__init__()
+ agent_obs_shape: int = squeeze(agent_obs_shape)
+ global_obs_shape: int = squeeze(global_obs_shape)
+ action_shape: int = squeeze(action_shape)
+ self.global_obs_shape, self.agent_obs_shape, self.action_shape = global_obs_shape, agent_obs_shape, action_shape
+ self.action_space = action_space
+ # Encoder Type
+ if isinstance(agent_obs_shape, int) or len(agent_obs_shape) == 1:
+ actor_encoder_cls = FCEncoder
+ elif len(agent_obs_shape) == 3:
+ actor_encoder_cls = ConvEncoder
+ else:
+ raise RuntimeError(
+ "not support obs_shape for pre-defined encoder: {}, please customize your own VAC".
+ format(agent_obs_shape)
+ )
+ if isinstance(global_obs_shape, int) or len(global_obs_shape) == 1:
+ critic_encoder_cls = FCEncoder
+ elif len(global_obs_shape) == 3:
+ critic_encoder_cls = ConvEncoder
+ else:
+ raise RuntimeError(
+ "not support obs_shape for pre-defined encoder: {}, please customize your own VAC".
+ format(global_obs_shape)
+ )
+
+ # We directly connect the Head after a Linear layer instead of using the 3-layer FCEncoder.
+ # In the SMAC task this can noticeably improve the performance.
+ # Users can change the model according to their own needs.
+ self.actor_encoder = actor_encoder_cls(
+ obs_shape=agent_obs_shape,
+ hidden_size_list=encoder_hidden_size_list,
+ activation=activation,
+ norm_type=norm_type
+ )
+ self.critic_encoder = critic_encoder_cls(
+ obs_shape=global_obs_shape,
+ hidden_size_list=encoder_hidden_size_list,
+ activation=activation,
+ norm_type=norm_type
+ )
+ # RNN part
+ self.use_lstm = use_lstm
+ if self.use_lstm:
+ self.actor_rnn = RNNLayer(
+ lstm_type,
+ input_size=encoder_hidden_size_list[-1],
+ hidden_size=actor_head_hidden_size,
+ res_link=res_link
+ )
+ self.critic_rnn = RNNLayer(
+ lstm_type,
+ input_size=encoder_hidden_size_list[-1],
+ hidden_size=critic_head_hidden_size,
+ res_link=res_link
+ )
+ # Head Type
+ self.critic_head = RegressionHead(
+ critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
+ )
+ assert self.action_space in ['discrete', 'continuous'], self.action_space
+ if self.action_space == 'discrete':
+ self.actor_head = DiscreteHead(
+ actor_head_hidden_size, action_shape, actor_head_layer_num, activation=activation, norm_type=norm_type
+ )
+ elif self.action_space == 'continuous':
+ self.actor_head = ReparameterizationHead(
+ actor_head_hidden_size,
+ action_shape,
+ actor_head_layer_num,
+ sigma_type=sigma_type,
+ activation=activation,
+ norm_type=norm_type,
+ bound_type=bound_type
+ )
+ # must use list, not nn.ModuleList
+ self.actor = [self.actor_encoder, self.actor_rnn, self.actor_head] if self.use_lstm \
+ else [self.actor_encoder, self.actor_head]
+ self.critic = [self.critic_encoder, self.critic_rnn, self.critic_head] if self.use_lstm \
+ else [self.critic_encoder, self.critic_head]
+ # for convenience of calling some APIs (such as self.critic.parameters()), but this may cause
+ # misunderstanding when calling print(self)
+ self.actor = nn.ModuleList(self.actor)
+ self.critic = nn.ModuleList(self.critic)
+
+ def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
+ r"""
+ Overview:
+ Use encoded embedding tensor to predict output.
+ Parameter updates with VAC's MLPs forward setup.
+ Arguments:
+ Forward with ``'compute_actor'`` or ``'compute_critic'``:
+ - inputs (:obj:`torch.Tensor`):
+ The encoded embedding tensor, determined with given ``hidden_size``, i.e. ``(B, N=hidden_size)``.
+ Whether ``actor_head_hidden_size`` or ``critic_head_hidden_size`` depend on ``mode``.
+ Returns:
+ - outputs (:obj:`Dict`):
+ Run with encoder and head.
+
+ Forward with ``'compute_actor'``, Necessary Keys:
+ - logit (:obj:`torch.Tensor`): Logit encoding tensor, with same size as input ``x``.
+
+ Forward with ``'compute_critic'``, Necessary Keys:
+ - value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.
+ Shapes:
+ - inputs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N corresponding ``hidden_size``
+ - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape``
+ - value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size.
+
+ .. note::
+ For concrete input/output examples, one can refer to the API doc of ``compute_actor`` and \
+ ``compute_critic`` below.
+
+ """
+ assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
+ return getattr(self, mode)(inputs)
+
+ def compute_actor(self, inputs: Dict, inference: bool = False) -> Dict:
+ r"""
+ Overview:
+ Execute parameter updates with ``'compute_actor'`` mode
+ Use encoded embedding tensor to predict output.
+ Arguments:
+ - inputs (:obj:`torch.Tensor`):
+ input data dict with keys ['obs'(with keys ['agent_state', 'global_state', 'action_mask']),
+ 'actor_prev_state']
+ Returns:
+ - outputs (:obj:`Dict`):
+ Run with encoder RNN(optional) and head.
+
+ ReturnsKeys:
+ - logit (:obj:`torch.Tensor`): Logit encoding tensor.
+ - actor_next_state:
+ - hidden_state
+ Shapes:
+ - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape``
+ - actor_next_state: (B,)
+ - hidden_state:
+
+ Examples:
+ >>> model = HAVAC(
+ agent_obs_shape=obs_dim,
+ global_obs_shape=global_obs_dim,
+ action_shape=action_dim,
+ use_lstm = True,
+ )
+ >>> inputs = {
+ 'obs': {
+ 'agent_state': torch.randn(T, bs, obs_dim),
+ 'global_state': torch.randn(T, bs, global_obs_dim),
+ 'action_mask': torch.randint(0, 2, size=(T, bs, action_dim))
+ },
+ 'actor_prev_state': [None for _ in range(bs)],
+ }
+ >>> actor_outputs = model(inputs,'compute_actor')
+ >>> assert actor_outputs['logit'].shape == (T, bs, action_dim)
+ """
+ x = inputs['obs']['agent_state']
+ output = {}
+ if self.use_lstm:
+ rnn_actor_prev_state = inputs['actor_prev_state']
+ if inference:
+ x = self.actor_encoder(x)
+ rnn_output = self.actor_rnn(x, rnn_actor_prev_state, inference)
+ x = rnn_output['output']
+ x = self.actor_head(x)
+ output['next_state'] = rnn_output['next_state']
+ # output: 'logit'/'next_state'
+ else:
+ assert len(x.shape) in [3, 5], x.shape
+ x = parallel_wrapper(self.actor_encoder)(x) # (T, B, N)
+ rnn_output = self.actor_rnn(x, rnn_actor_prev_state, inference)
+ x = rnn_output['output']
+ x = parallel_wrapper(self.actor_head)(x)
+ output['actor_next_state'] = rnn_output['next_state']
+ output['actor_hidden_state'] = rnn_output['hidden_state']
+ # output: 'logit'/'actor_next_state'/'hidden_state'
+ else:
+ x = self.actor_encoder(x)
+ x = self.actor_head(x)
+ # output: 'logit'
+
+ if self.action_space == 'discrete':
+ action_mask = inputs['obs']['action_mask']
+ logit = x['logit']
+ logit[action_mask == 0.0] = -99999999
+ elif self.action_space == 'continuous':
+ logit = x
+ output['logit'] = logit
+ return output
+
+ def compute_critic(self, inputs: Dict, inference: bool = False) -> Dict:
+ r"""
+ Overview:
+ Execute parameter updates with ``'compute_critic'`` mode
+ Use encoded embedding tensor to predict output.
+ Arguments:
+ - inputs (:obj:`Dict`):
+ input data dict with keys ['obs'(with keys ['agent_state', 'global_state', 'action_mask']),
+ 'critic_prev_state'(when you are using rnn)]
+ Returns:
+ - outputs (:obj:`Dict`):
+ Run with encoder [rnn] and head.
+
+ Necessary Keys:
+ - value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.
+ Shapes:
+ - value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size.
+
+ Examples:
+ >>> model = HAVAC(
+ agent_obs_shape=obs_dim,
+ global_obs_shape=global_obs_dim,
+ action_shape=action_dim,
+ use_lstm = True,
+ )
+ >>> inputs = {
+ 'obs': {
+ 'agent_state': torch.randn(T, bs, obs_dim),
+ 'global_state': torch.randn(T, bs, global_obs_dim),
+ 'action_mask': torch.randint(0, 2, size=(T, bs, action_dim))
+ },
+ 'critic_prev_state': [None for _ in range(bs)],
+ }
+ >>> critic_outputs = model(inputs,'compute_critic')
+ >>> assert critic_outputs['value'].shape == (T, bs)
+ """
+ global_obs = inputs['obs']['global_state']
+ output = {}
+ if self.use_lstm:
+ rnn_critic_prev_state = inputs['critic_prev_state']
+ if inference:
+ x = self.critic_encoder(global_obs)
+ rnn_output = self.critic_rnn(x, rnn_critic_prev_state, inference)
+ x = rnn_output['output']
+ x = self.critic_head(x)
+ output['next_state'] = rnn_output['next_state']
+ # output: 'value'/'next_state'
+ else:
+ assert len(global_obs.shape) in [3, 5], global_obs.shape
+ x = parallel_wrapper(self.critic_encoder)(global_obs) # (T, B, N)
+ rnn_output = self.critic_rnn(x, rnn_critic_prev_state, inference)
+ x = rnn_output['output']
+ x = parallel_wrapper(self.critic_head)(x)
+ output['critic_next_state'] = rnn_output['next_state']
+ output['critic_hidden_state'] = rnn_output['hidden_state']
+ # output: 'value'/'critic_next_state'/'hidden_state'
+ else:
+ x = self.critic_encoder(global_obs)
+ x = self.critic_head(x)
+ # output: 'value'
+ output['value'] = x['pred']
+ return output
+
+ def compute_actor_critic(self, inputs: Dict, inference: bool = False) -> Dict:
+ r"""
+ Overview:
+ Execute parameter updates with ``'compute_actor_critic'`` mode
+ Use encoded embedding tensor to predict output.
+ Arguments:
+ - inputs (:dict): input data dict with keys
+ ['obs'(with keys ['agent_state', 'global_state', 'action_mask']),
+ 'actor_prev_state', 'critic_prev_state'(when you are using rnn)]
+
+ Returns:
+ - outputs (:obj:`Dict`):
+ Run with encoder and head.
+
+ ReturnsKeys:
+ - logit (:obj:`torch.Tensor`): Logit encoding tensor, with same size as input ``x``.
+ - value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.
+ Shapes:
+ - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape``
+ - value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size.
+
+ Examples:
+ >>> model = HAVAC(
+ agent_obs_shape=obs_dim,
+ global_obs_shape=global_obs_dim,
+ action_shape=action_dim,
+ use_lstm = True,
+ )
+ >>> inputs = {
+ 'obs': {
+ 'agent_state': torch.randn(T, bs, obs_dim),
+ 'global_state': torch.randn(T, bs, global_obs_dim),
+ 'action_mask': torch.randint(0, 2, size=(T, bs, action_dim))
+ },
+ 'actor_prev_state': [None for _ in range(bs)],
+ 'critic_prev_state': [None for _ in range(bs)],
+ }
+ >>> outputs = model(inputs, 'compute_actor_critic')
+ >>> assert outputs['logit'].shape == (T, bs, action_dim)
+ >>> assert outputs['value'].shape == (T, bs)
+
+ .. note::
+ The ``compute_actor_critic`` interface aims to save computation when the actor and critic share an \
+ encoder, returning the combined dictionary.
+
+ """
+ actor_output = self.compute_actor(inputs, inference)
+ critic_output = self.compute_critic(inputs, inference)
+ if self.use_lstm:
+ return {
+ 'logit': actor_output['logit'],
+ 'value': critic_output['value'],
+ 'actor_next_state': actor_output['actor_next_state'],
+ 'actor_hidden_state': actor_output['actor_hidden_state'],
+ 'critic_next_state': critic_output['critic_next_state'],
+ 'critic_hidden_state': critic_output['critic_hidden_state'],
+ }
+ else:
+ return {
+ 'logit': actor_output['logit'],
+ 'value': critic_output['value'],
+ }
diff --git a/DI-engine/ding/model/template/language_transformer.py b/DI-engine/ding/model/template/language_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..cac2d69adf3705ef1d4592a9d3a790cf3e16801a
--- /dev/null
+++ b/DI-engine/ding/model/template/language_transformer.py
@@ -0,0 +1,102 @@
+from typing import List, Dict
+import torch
+from torch import nn
+
+try:
+ from transformers import AutoTokenizer, AutoModelForTokenClassification
+except ImportError:
+ from ditk import logging
+ logging.warning("transformers not found, please install it using: pip install transformers")
+from ding.utils import MODEL_REGISTRY
+
+
+@MODEL_REGISTRY.register('language_transformer')
+class LanguageTransformer(nn.Module):
+ """
+ Overview:
+ The LanguageTransformer network. Download a pre-trained language model and add head on it.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(
+ self,
+ model_name: str = "bert-base-uncased",
+ add_linear: bool = False,
+ embedding_size: int = 128,
+ freeze_encoder: bool = True
+ ) -> None:
+ """
+ Overview:
+ Init the LanguageTransformer Model according to input arguments.
+ Arguments:
+ - model_name (:obj:`str`): The base language model name in huggingface, such as "bert-base-uncased".
+ - add_linear (:obj:`bool`): Whether to add a linear layer on the top of language model, defaults to be \
+ ``False``.
+ - embedding_size (:obj:`int`): The embedding size of the added linear layer, such as 128.
+ - freeze_encoder (:obj:`bool`): Whether to freeze the encoder language model while training, \
+ defaults to be ``True``.
+ """
+ super().__init__()
+ self.tokenizer = AutoTokenizer.from_pretrained(model_name)
+ self.model = AutoModelForTokenClassification.from_pretrained(model_name)
+
+ # Freeze transformer encoder and only train the linear layer
+ if freeze_encoder:
+ for param in self.model.parameters():
+ param.requires_grad = False
+
+ if add_linear:
+ # Add a small, adjustable linear layer on top of language model tuned through RL
+ self.embedding_size = embedding_size
+ self.linear = nn.Linear(
+ self.model.config.hidden_size, embedding_size
+ ) # 768 for bert-base-uncased, distilbert-base-uncased
+ else:
+ self.linear = None
+
+ def _calc_embedding(self, x: list) -> torch.Tensor:
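+ """
+ Overview:
+ Tokenize a list of strings and return their [CLS] sentence embeddings, optionally projected by the \
+ added linear layer.
+ """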
+ # ``truncation=True`` means that if the length of the prompt exceed the ``max_length`` of the tokenizer,
+ # the exceeded part will be truncated. ``padding=True`` means that if the length of the prompt does not reach
+ # the ``max_length``, the latter part will be padded. These settings ensure the length of encoded tokens is
+ # exactly ``max_length``, which can enable batch-wise computing.
+ input = self.tokenizer(x, truncation=True, padding=True, return_tensors="pt").to(self.model.device)
+ output = self.model(**input, output_hidden_states=True)
+ # Get last layer hidden states
+ last_hidden_states = output.hidden_states[-1]
+ # Get [CLS] hidden states
+ sentence_embedding = last_hidden_states[:, 0, :] # len(input_list) x hidden_size
+
+ if self.linear:
+ sentence_embedding = self.linear(sentence_embedding) # len(input_list) x embedding_size
+
+ return sentence_embedding
+
+ def forward(self, train_samples: List[str], candidate_samples: List[str]) -> Dict:
+ """
+ Overview:
+ LanguageTransformer forward computation graph, input two lists of strings and predict their matching scores.
+ Arguments:
+ - train_samples (:obj:`List[str]`): One list of strings.
+ - candidate_samples (:obj:`List[str]`): The other list of strings to calculate the matching scores.
+ Returns:
+ - output (:obj:`Dict`): Output dict data, including the logit of matching scores and the \
+ corresponding ``torch.distributions.Categorical`` object.
+
+ Examples:
+ >>> test_pids = [1]
+ >>> cand_pids = [0, 2, 4]
+ >>> problems = [ \
+ "This is problem 0", "This is the first question", "Second problem is here", "Another problem", \
+ "This is the last problem" \
+ ]
+ >>> ctxt_list = [problems[pid] for pid in test_pids]
+ >>> cands_list = [problems[pid] for pid in cand_pids]
+ >>> model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
+ >>> scores = model(ctxt_list, cands_list)
+ >>> assert scores['logit'].shape == (1, 3)
+ """
+ prompt_embedding = self._calc_embedding(train_samples)
+ cands_embedding = self._calc_embedding(candidate_samples)
+ scores = torch.mm(prompt_embedding, cands_embedding.t())
+ return {'dist': torch.distributions.Categorical(logits=scores), 'logit': scores}
diff --git a/DI-engine/ding/model/template/madqn.py b/DI-engine/ding/model/template/madqn.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cab2b1e98c1e5b266754d555629515876ca5e1b
--- /dev/null
+++ b/DI-engine/ding/model/template/madqn.py
@@ -0,0 +1,54 @@
+import torch.nn as nn
+from ding.utils import MODEL_REGISTRY
+from .qmix import QMix
+
+
+@MODEL_REGISTRY.register('madqn')
+class MADQN(nn.Module):
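+ """
+ Overview:
+ The MADQN model, which wraps two ``QMix`` networks: ``current`` for the ordinary per-agent Q values and \
+ ``cooperation`` for the cooperation Q values (computed from the global state when \
+ ``global_cooperation`` is True).
+ Interfaces:
+ ``__init__``, ``forward``
+ """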
+
+ def __init__(
+ self,
+ agent_num: int,
+ obs_shape: int,
+ action_shape: int,
+ hidden_size_list: list,
+ global_obs_shape: int = None,
+ mixer: bool = False,
+ global_cooperation: bool = True,
+ lstm_type: str = 'gru',
+ dueling: bool = False
+ ) -> None:
+ super(MADQN, self).__init__()
+ self.current = QMix(
+ agent_num=agent_num,
+ obs_shape=obs_shape,
+ action_shape=action_shape,
+ hidden_size_list=hidden_size_list,
+ global_obs_shape=global_obs_shape,
+ mixer=mixer,
+ lstm_type=lstm_type,
+ dueling=dueling
+ )
+ self.global_cooperation = global_cooperation
+ if self.global_cooperation:
+ cooperation_obs_shape = global_obs_shape
+ else:
+ cooperation_obs_shape = obs_shape
+ self.cooperation = QMix(
+ agent_num=agent_num,
+ obs_shape=cooperation_obs_shape,
+ action_shape=action_shape,
+ hidden_size_list=hidden_size_list,
+ global_obs_shape=global_obs_shape,
+ mixer=mixer,
+ lstm_type=lstm_type,
+ dueling=dueling
+ )
+
+ def forward(self, data: dict, cooperation: bool = False, single_step: bool = True) -> dict:
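+ """
+ Overview:
+ Forward with the cooperation QMix network when ``cooperation`` is True (optionally replacing the agent \
+ state with the global state), otherwise with the current QMix network.
+ """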
+ if cooperation:
+ if self.global_cooperation:
+ data['obs']['agent_state'] = data['obs']['global_state']
+ return self.cooperation(data, single_step=single_step)
+ else:
+ return self.current(data, single_step=single_step)
diff --git a/DI-engine/ding/model/template/maqac.py b/DI-engine/ding/model/template/maqac.py
new file mode 100644
index 0000000000000000000000000000000000000000..2d72e43d534852231fe09a61c57193a8f68dd614
--- /dev/null
+++ b/DI-engine/ding/model/template/maqac.py
@@ -0,0 +1,477 @@
+from typing import Union, Dict, Optional
+from easydict import EasyDict
+import torch
+import torch.nn as nn
+
+from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
+from ..common import RegressionHead, ReparameterizationHead, DiscreteHead, MultiHead, \
+ FCEncoder, ConvEncoder
+
+
+@MODEL_REGISTRY.register('discrete_maqac')
+class DiscreteMAQAC(nn.Module):
+ """
+ Overview:
+ The neural network and computation graph of algorithms related to discrete action Multi-Agent Q-value \
+ Actor-CritiC (MAQAC) model. The model is composed of an actor and a critic, both of which are MLP \
+ networks. The actor network is used to predict the action probability distribution, and the \
+ critic network is used to predict the Q value of the state-action pair.
+ Interfaces:
+ ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
+ """
+ mode = ['compute_actor', 'compute_critic']
+
+ def __init__(
+ self,
+ agent_obs_shape: Union[int, SequenceType],
+ global_obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType],
+ twin_critic: bool = False,
+ actor_head_hidden_size: int = 64,
+ actor_head_layer_num: int = 1,
+ critic_head_hidden_size: int = 64,
+ critic_head_layer_num: int = 1,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ ) -> None:
+ """
+ Overview:
+ Initialize the DiscreteMAQAC Model according to arguments.
+ Arguments:
+ - agent_obs_shape (:obj:`Union[int, SequenceType]`): Agent's observation's space.
+ - global_obs_shape (:obj:`Union[int, SequenceType]`): Global observation's space.
+ - action_shape (:obj:`Union[int, SequenceType]`): Action's space.
+ - twin_critic (:obj:`bool`): Whether include twin critic.
+ - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor-nn's ``Head``.
+ - actor_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output \
+ for actor's nn.
+ - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic-nn's ``Head``.
+ - critic_head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output \
+ for critic's nn.
+ - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` after the \
+ ``layer_fn``, if ``None`` then default set to ``nn.ReLU()``.
+ - norm_type (:obj:`Optional[str]`): The type of normalization to use, see ``ding.torch_utils.fc_block`` \
+ for more details.
+ """
+ super(DiscreteMAQAC, self).__init__()
+ agent_obs_shape: int = squeeze(agent_obs_shape)
+ action_shape: int = squeeze(action_shape)
+ self.actor = nn.Sequential(
+ nn.Linear(agent_obs_shape, actor_head_hidden_size), activation,
+ DiscreteHead(
+ actor_head_hidden_size, action_shape, actor_head_layer_num, activation=activation, norm_type=norm_type
+ )
+ )
+
+ self.twin_critic = twin_critic
+ if self.twin_critic:
+ self.critic = nn.ModuleList()
+ for _ in range(2):
+ self.critic.append(
+ nn.Sequential(
+ nn.Linear(global_obs_shape, critic_head_hidden_size), activation,
+ DiscreteHead(
+ critic_head_hidden_size,
+ action_shape,
+ critic_head_layer_num,
+ activation=activation,
+ norm_type=norm_type
+ )
+ )
+ )
+ else:
+ self.critic = nn.Sequential(
+ nn.Linear(global_obs_shape, critic_head_hidden_size), activation,
+ DiscreteHead(
+ critic_head_hidden_size,
+ action_shape,
+ critic_head_layer_num,
+ activation=activation,
+ norm_type=norm_type
+ )
+ )
+
+ def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
+ """
+ Overview:
+ Use observation tensor to predict output, with ``compute_actor`` or ``compute_critic`` mode.
+ Arguments:
+ - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
+ - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
+ - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \
+ with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \
+ N0 corresponds to ``agent_obs_shape``.
+ - ``global_state`` (:obj:`torch.Tensor`): The global observation tensor data, \
+ with shape :math:`(B, A, N1)`, where B is batch size and A is agent num. \
+ N1 corresponds to ``global_obs_shape``.
+ - ``action_mask`` (:obj:`torch.Tensor`): The action mask tensor data, \
+ with shape :math:`(B, A, N2)`, where B is batch size and A is agent num. \
+ N2 corresponds to ``action_shape``.
+
+ - mode (:obj:`str`): The forward mode, all the modes are defined in the beginning of this class.
+ Returns:
+ - output (:obj:`Dict[str, torch.Tensor]`): The output dict of DiscreteMAQAC forward computation graph, \
+ whose key-values vary in different forward modes.
+ Examples:
+ >>> B = 32
+ >>> agent_obs_shape = 216
+ >>> global_obs_shape = 264
+ >>> agent_num = 8
+ >>> action_shape = 14
+ >>> data = {
+ >>> 'obs': {
+ >>> 'agent_state': torch.randn(B, agent_num, agent_obs_shape),
+ >>> 'global_state': torch.randn(B, agent_num, global_obs_shape),
+ >>> 'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
+ >>> }
+ >>> }
+ >>> model = DiscreteMAQAC(agent_obs_shape, global_obs_shape, action_shape, twin_critic=True)
+ >>> logit = model(data, mode='compute_actor')['logit']
+ >>> value = model(data, mode='compute_critic')['q_value']
+ """
+ assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
+ return getattr(self, mode)(inputs)
+
+ def compute_actor(self, inputs: Dict) -> Dict:
+ """
+ Overview:
+ Use observation tensor to predict action logits.
+ Arguments:
+ - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
+ - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
+ - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \
+ with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \
+ N0 corresponds to ``agent_obs_shape``.
+ - ``global_state`` (:obj:`torch.Tensor`): The global observation tensor data, \
+ with shape :math:`(B, A, N1)`, where B is batch size and A is agent num. \
+ N1 corresponds to ``global_obs_shape``.
+ - ``action_mask`` (:obj:`torch.Tensor`): The action mask tensor data, \
+ with shape :math:`(B, A, N2)`, where B is batch size and A is agent num. \
+ N2 corresponds to ``action_shape``.
+ Returns:
+ - output (:obj:`Dict[str, torch.Tensor]`): The output dict of DiscreteMAQAC forward computation graph, \
+ whose key-values vary in different forward modes.
+ - logit (:obj:`torch.Tensor`): Action's output logit (real value range), whose shape is \
+ :math:`(B, A, N2)`, where N2 corresponds to ``action_shape``.
+ - action_mask (:obj:`torch.Tensor`): The same action mask tensor as in the input, with shape \
+ :math:`(B, A, N2)`.
+ Examples:
+ >>> B = 32
+ >>> agent_obs_shape = 216
+ >>> global_obs_shape = 264
+ >>> agent_num = 8
+ >>> action_shape = 14
+ >>> data = {
+ >>> 'obs': {
+ >>> 'agent_state': torch.randn(B, agent_num, agent_obs_shape),
+ >>> 'global_state': torch.randn(B, agent_num, global_obs_shape),
+ >>> 'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
+ >>> }
+ >>> }
+ >>> model = DiscreteMAQAC(agent_obs_shape, global_obs_shape, action_shape, twin_critic=True)
+ >>> logit = model.compute_actor(data)['logit']
+ """
+ action_mask = inputs['obs']['action_mask']
+ x = self.actor(inputs['obs']['agent_state'])
+ return {'logit': x['logit'], 'action_mask': action_mask}
+
+ def compute_critic(self, inputs: Dict) -> Dict:
+ """
+ Overview:
+ Use the global observation tensor to predict Q values.
+ Arguments:
+ - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
+ - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
+ - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \
+ with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \
+ N0 corresponds to ``agent_obs_shape``.
+ - ``global_state`` (:obj:`torch.Tensor`): The global observation tensor data, \
+ with shape :math:`(B, A, N1)`, where B is batch size and A is agent num. \
+ N1 corresponds to ``global_obs_shape``.
+ - ``action_mask`` (:obj:`torch.Tensor`): The action mask tensor data, \
+ with shape :math:`(B, A, N2)`, where B is batch size and A is agent num. \
+ N2 corresponds to ``action_shape``.
+ Returns:
+ - output (:obj:`Dict[str, torch.Tensor]`): The output dict of DiscreteMAQAC forward computation graph, \
+ whose key-values vary in different values of ``twin_critic``.
+ - q_value (:obj:`Union[list, torch.Tensor]`): If ``twin_critic=True``, ``q_value`` is a list of 2 tensors, \
+ each with shape :math:`(B, A, N2)`, where B is batch size, A is agent num and N2 corresponds to \
+ ``action_shape``. Otherwise, ``q_value`` is a single ``torch.Tensor`` with that shape.
+ Examples:
+ >>> B = 32
+ >>> agent_obs_shape = 216
+ >>> global_obs_shape = 264
+ >>> agent_num = 8
+ >>> action_shape = 14
+ >>> data = {
+ >>> 'obs': {
+ >>> 'agent_state': torch.randn(B, agent_num, agent_obs_shape),
+ >>> 'global_state': torch.randn(B, agent_num, global_obs_shape),
+ >>> 'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
+ >>> }
+ >>> }
+ >>> model = DiscreteMAQAC(agent_obs_shape, global_obs_shape, action_shape, twin_critic=True)
+ >>> value = model.compute_critic(data)['q_value']
+ """
+
+ if self.twin_critic:
+ x = [m(inputs['obs']['global_state'])['logit'] for m in self.critic]
+ else:
+ x = self.critic(inputs['obs']['global_state'])['logit']
+ return {'q_value': x}
+
+
+@MODEL_REGISTRY.register('continuous_maqac')
+class ContinuousMAQAC(nn.Module):
+ """
+ Overview:
+ The neural network and computation graph of algorithms related to continuous action Multi-Agent Q-value \
+ Actor-CritiC (MAQAC) model. The model is composed of an actor and a critic, both of which are MLP \
+ networks. The actor network predicts either a deterministic continuous action (``regression``) or the \
+ parameters of a Gaussian action distribution (``reparameterization``), and the critic network predicts \
+ the Q value of the state-action pair.
+ Interfaces:
+ ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
+ """
+ mode = ['compute_actor', 'compute_critic']
+
+ def __init__(
+ self,
+ agent_obs_shape: Union[int, SequenceType],
+ global_obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType, EasyDict],
+ action_space: str,
+ twin_critic: bool = False,
+ actor_head_hidden_size: int = 64,
+ actor_head_layer_num: int = 1,
+ critic_head_hidden_size: int = 64,
+ critic_head_layer_num: int = 1,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ ) -> None:
+ """
+ Overview:
+ Initialize the ContinuousMAQAC Model according to arguments.
+ Arguments:
+ - agent_obs_shape (:obj:`Union[int, SequenceType]`): Agent's observation's space.
+ - global_obs_shape (:obj:`Union[int, SequenceType]`): Global observation's space.
+ - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's space, such as 4, (3, ).
+ - action_space (:obj:`str`): The action space type, choose between ``regression`` and ``reparameterization``.
+ - twin_critic (:obj:`bool`): Whether to include the twin critic.
+ - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor-nn's ``Head``.
+ - actor_head_layer_num (:obj:`int`): The num of layers used in the actor network to compute the action \
+ output.
+ - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic-nn's ``Head``.
+ - critic_head_layer_num (:obj:`int`): The num of layers used in the critic network to compute the Q value \
+ output.
+ - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in the ``MLP`` after \
+ ``layer_fn``; if ``None``, it defaults to ``nn.ReLU()``.
+ - norm_type (:obj:`Optional[str]`): The type of normalization to use, see ``ding.torch_utils.fc_block`` \
+ for more details.
+ """
+ super(ContinuousMAQAC, self).__init__()
+ obs_shape: int = squeeze(agent_obs_shape)
+ global_obs_shape: int = squeeze(global_obs_shape)
+ action_shape = squeeze(action_shape)
+ self.action_shape = action_shape
+ self.action_space = action_space
+ assert self.action_space in ['regression', 'reparameterization'], self.action_space
+ if self.action_space == 'regression': # DDPG, TD3
+ self.actor = nn.Sequential(
+ nn.Linear(obs_shape, actor_head_hidden_size), activation,
+ RegressionHead(
+ actor_head_hidden_size,
+ action_shape,
+ actor_head_layer_num,
+ final_tanh=True,
+ activation=activation,
+ norm_type=norm_type
+ )
+ )
+ else: # SAC
+ self.actor = nn.Sequential(
+ nn.Linear(obs_shape, actor_head_hidden_size), activation,
+ ReparameterizationHead(
+ actor_head_hidden_size,
+ action_shape,
+ actor_head_layer_num,
+ sigma_type='conditioned',
+ activation=activation,
+ norm_type=norm_type
+ )
+ )
+ self.twin_critic = twin_critic
+ critic_input_size = global_obs_shape + action_shape
+ if self.twin_critic:
+ self.critic = nn.ModuleList()
+ for _ in range(2):
+ self.critic.append(
+ nn.Sequential(
+ nn.Linear(critic_input_size, critic_head_hidden_size), activation,
+ RegressionHead(
+ critic_head_hidden_size,
+ 1,
+ critic_head_layer_num,
+ final_tanh=False,
+ activation=activation,
+ norm_type=norm_type
+ )
+ )
+ )
+ else:
+ self.critic = nn.Sequential(
+ nn.Linear(critic_input_size, critic_head_hidden_size), activation,
+ RegressionHead(
+ critic_head_hidden_size,
+ 1,
+ critic_head_layer_num,
+ final_tanh=False,
+ activation=activation,
+ norm_type=norm_type
+ )
+ )
+
+ def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
+ """
+ Overview:
+ Use observation and action tensor to predict output in ``compute_actor`` or ``compute_critic`` mode.
+ Arguments:
+ - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
+ - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
+ - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \
+ with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \
+ N0 corresponds to ``agent_obs_shape``.
+ - ``global_state`` (:obj:`torch.Tensor`): The global observation tensor data, \
+ with shape :math:`(B, A, N1)`, where B is batch size and A is agent num. \
+ N1 corresponds to ``global_obs_shape``.
+ - ``action_mask`` (:obj:`torch.Tensor`): The action mask tensor data, \
+ with shape :math:`(B, A, N2)`, where B is batch size and A is agent num. \
+ N2 corresponds to ``action_shape``.
+
+ - ``action`` (:obj:`torch.Tensor`): The action tensor data, \
+ with shape :math:`(B, A, N3)`, where B is batch size and A is agent num. \
+ N3 corresponds to ``action_shape``.
+ - mode (:obj:`str`): Name of the forward mode.
+ Returns:
+ - outputs (:obj:`Dict`): Outputs of network forward, whose key-values will be different for different \
+ ``mode``, ``twin_critic``, ``action_space``.
+ Examples:
+ >>> B = 32
+ >>> agent_obs_shape = 216
+ >>> global_obs_shape = 264
+ >>> agent_num = 8
+ >>> action_shape = 14
+ >>> act_space = 'reparameterization' # regression
+ >>> data = {
+ >>> 'obs': {
+ >>> 'agent_state': torch.randn(B, agent_num, agent_obs_shape),
+ >>> 'global_state': torch.randn(B, agent_num, global_obs_shape),
+ >>> 'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
+ >>> },
+ >>> 'action': torch.randn(B, agent_num, squeeze(action_shape))
+ >>> }
+ >>> model = ContinuousMAQAC(agent_obs_shape, global_obs_shape, action_shape, act_space, twin_critic=False)
+ >>> if act_space == 'regression':
+ >>> action = model(data['obs'], mode='compute_actor')['action']
+ >>> elif act_space == 'reparameterization':
+ >>> (mu, sigma) = model(data['obs'], mode='compute_actor')['logit']
+ >>> value = model(data, mode='compute_critic')['q_value']
+ """
+ assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
+ return getattr(self, mode)(inputs)
+
+ def compute_actor(self, inputs: Dict) -> Dict:
+ """
+ Overview:
+ Use observation tensor to predict action logits.
+ Arguments:
+ - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
+ - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \
+ with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \
+ N0 corresponds to ``agent_obs_shape``.
+
+ Returns:
+ - outputs (:obj:`Dict`): Outputs of network forward.
+ ReturnKeys (``action_space == 'regression'``):
+ - action (:obj:`torch.Tensor`): Action tensor with same size as ``action_shape``.
+ ReturnKeys (``action_space == 'reparameterization'``):
+ - logit (:obj:`list`): A list of 2 tensors (mu and sigma), each with shape :math:`(B, A, N3)`, where B is \
+ batch size, A is agent num and N3 corresponds to ``action_shape``.
+ Examples:
+ >>> B = 32
+ >>> agent_obs_shape = 216
+ >>> global_obs_shape = 264
+ >>> agent_num = 8
+ >>> action_shape = 14
+ >>> act_space = 'reparameterization' # 'regression'
+ >>> data = {
+ >>> 'agent_state': torch.randn(B, agent_num, agent_obs_shape),
+ >>> }
+ >>> model = ContinuousMAQAC(agent_obs_shape, global_obs_shape, action_shape, act_space, twin_critic=False)
+ >>> if act_space == 'regression':
+ >>> action = model.compute_actor(data)['action']
+ >>> elif act_space == 'reparameterization':
+ >>> (mu, sigma) = model.compute_actor(data)['logit']
+ """
+ inputs = inputs['agent_state']
+ if self.action_space == 'regression':
+ x = self.actor(inputs)
+ return {'action': x['pred']}
+ else:
+ x = self.actor(inputs)
+ return {'logit': [x['mu'], x['sigma']]}
+
+ def compute_critic(self, inputs: Dict) -> Dict:
+ """
+ Overview:
+ Use observation tensor and action tensor to predict Q value.
+ Arguments:
+ - inputs (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
+ - ``obs`` (:obj:`Dict[str, torch.Tensor]`): The input dict tensor data, has keys:
+ - ``agent_state`` (:obj:`torch.Tensor`): The agent's observation tensor data, \
+ with shape :math:`(B, A, N0)`, where B is batch size and A is agent num. \
+ N0 corresponds to ``agent_obs_shape``.
+ - ``global_state`` (:obj:`torch.Tensor`): The global observation tensor data, \
+ with shape :math:`(B, A, N1)`, where B is batch size and A is agent num. \
+ N1 corresponds to ``global_obs_shape``.
+ - ``action_mask`` (:obj:`torch.Tensor`): The action mask tensor data, \
+ with shape :math:`(B, A, N2)`, where B is batch size and A is agent num. \
+ N2 corresponds to ``action_shape``.
+
+ - ``action`` (:obj:`torch.Tensor`): The action tensor data, \
+ with shape :math:`(B, A, N3)`, where B is batch size and A is agent num. \
+ N3 corresponds to ``action_shape``.
+
+ Returns:
+ - outputs (:obj:`Dict`): Outputs of network forward.
+ ReturnKeys (``twin_critic=True``):
+ - q_value (:obj:`list`): A list of 2 tensors, each with shape :math:`(B, A)`, where B is batch size and \
+ A is agent num.
+ ReturnKeys (``twin_critic=False``):
+ - q_value (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is agent num.
+ Examples:
+ >>> B = 32
+ >>> agent_obs_shape = 216
+ >>> global_obs_shape = 264
+ >>> agent_num = 8
+ >>> action_shape = 14
+ >>> act_space = 'reparameterization' # 'regression'
+ >>> data = {
+ >>> 'obs': {
+ >>> 'agent_state': torch.randn(B, agent_num, agent_obs_shape),
+ >>> 'global_state': torch.randn(B, agent_num, global_obs_shape),
+ >>> 'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
+ >>> },
+ >>> 'action': torch.randn(B, agent_num, squeeze(action_shape))
+ >>> }
+ >>> model = ContinuousMAQAC(agent_obs_shape, global_obs_shape, action_shape, act_space, twin_critic=False)
+ >>> value = model.compute_critic(data)['q_value']
+ """
+
+ obs, action = inputs['obs']['global_state'], inputs['action']
+ if len(action.shape) == 1: # (B, ) -> (B, 1)
+ action = action.unsqueeze(1)
+ x = torch.cat([obs, action], dim=-1)
+ if self.twin_critic:
+ x = [m(x)['pred'] for m in self.critic]
+ else:
+ x = self.critic(x)['pred']
+ return {'q_value': x}
diff --git a/DI-engine/ding/model/template/mavac.py b/DI-engine/ding/model/template/mavac.py
new file mode 100644
index 0000000000000000000000000000000000000000..78071e6783296a77591c90de5503ec0a7ef89983
--- /dev/null
+++ b/DI-engine/ding/model/template/mavac.py
@@ -0,0 +1,280 @@
+from typing import Union, Dict, Optional
+import torch
+import torch.nn as nn
+
+from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
+from ..common import ReparameterizationHead, RegressionHead, DiscreteHead
+
+
+@MODEL_REGISTRY.register('mavac')
+class MAVAC(nn.Module):
+ """
+ Overview:
+ The neural network and computation graph of algorithms related to (state) Value Actor-Critic (VAC) for \
+ multi-agent, such as MAPPO(https://arxiv.org/abs/2103.01955). This model now supports discrete and \
+ continuous action space. The MAVAC is composed of four parts: ``actor_encoder``, ``critic_encoder``, \
+ ``actor_head`` and ``critic_head``. Encoders are used to extract features from various observations, and \
+ heads are used to predict the corresponding value or action logit.
+ Interfaces:
+ ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``, ``compute_actor_critic``.
+ """
+ mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']
+
+ def __init__(
+ self,
+ agent_obs_shape: Union[int, SequenceType],
+ global_obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType],
+ agent_num: int,
+ actor_head_hidden_size: int = 256,
+ actor_head_layer_num: int = 2,
+ critic_head_hidden_size: int = 512,
+ critic_head_layer_num: int = 1,
+ action_space: str = 'discrete',
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ sigma_type: Optional[str] = 'independent',
+ bound_type: Optional[str] = None,
+ ) -> None:
+ """
+ Overview:
+ Init the MAVAC Model according to arguments.
+ Arguments:
+ - agent_obs_shape (:obj:`Union[int, SequenceType]`): Observation's space for single agent, \
+ such as 8 or [4, 84, 84].
+ - global_obs_shape (:obj:`Union[int, SequenceType]`): Global observation's space, such as 8 or [4, 84, 84].
+ - action_shape (:obj:`Union[int, SequenceType]`): Action space shape for single agent, such as 6 \
+ or [2, 3, 3].
+ - agent_num (:obj:`int`): This parameter is temporarily reserved and may be required by subsequent \
+ changes to the model.
+ - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of the ``actor_head`` network, \
+ defaults to 256.
+ - actor_head_layer_num (:obj:`int`): The num of layers used in the ``actor_head`` network to compute action.
+ - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of the ``critic_head`` network, \
+ defaults to 512.
+ - critic_head_layer_num (:obj:`int`): The num of layers used in the ``critic_head`` network to compute \
+ the state value.
+ - action_space (:obj:`str`): The type of different action spaces, including \
+ ['discrete', 'continuous'], then will instantiate corresponding head, including ``DiscreteHead`` \
+ and ``ReparameterizationHead``.
+ - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in the ``MLP`` after \
+ ``layer_fn``; if ``None``, it defaults to ``nn.ReLU()``.
+ - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
+ ``ding.torch_utils.fc_block`` for more details. You can choose one of ['BN', 'IN', 'SyncBN', 'LN'].
+ - sigma_type (:obj:`Optional[str]`): The type of sigma in continuous action space, see \
+ ``ding.torch_utils.network.dreamer.ReparameterizationHead`` for more details, in MAPPO, it defaults \
+ to ``independent``, which means state-independent sigma parameters.
+ - bound_type (:obj:`Optional[str]`): The type of action bound methods in continuous action space, defaults \
+ to ``None``, which means no bound.
+ """
+ super(MAVAC, self).__init__()
+ agent_obs_shape: int = squeeze(agent_obs_shape)
+ global_obs_shape: int = squeeze(global_obs_shape)
+ action_shape: int = squeeze(action_shape)
+ self.global_obs_shape, self.agent_obs_shape, self.action_shape = global_obs_shape, agent_obs_shape, action_shape
+ self.action_space = action_space
+ # Encoder Type
+ # We directly connect the Head after a Linear layer instead of using the 3-layer FCEncoder.
+ # In the SMAC task this noticeably improves performance.
+ # Users can change the model according to their own needs (see the commented sketch below).
+ self.actor_encoder = nn.Identity()
+ self.critic_encoder = nn.Identity()
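+ # A commented sketch (not the default) of plugging in an encoder instead, e.g. ``FCEncoder`` from
+ # ``..common`` (an illustrative assumption, it is not imported here); note that the first ``nn.Linear``
+ # of the heads below would then need its input size changed to the encoder's output size:
+ #   self.actor_encoder = FCEncoder(agent_obs_shape, [128, 128, actor_head_hidden_size], activation=activation)
+ #   self.critic_encoder = FCEncoder(global_obs_shape, [128, 128, critic_head_hidden_size], activation=activation)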
+ # Head Type
+ self.critic_head = nn.Sequential(
+ nn.Linear(global_obs_shape, critic_head_hidden_size), activation,
+ RegressionHead(
+ critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
+ )
+ )
+ assert self.action_space in ['discrete', 'continuous'], self.action_space
+ if self.action_space == 'discrete':
+ self.actor_head = nn.Sequential(
+ nn.Linear(agent_obs_shape, actor_head_hidden_size), activation,
+ DiscreteHead(
+ actor_head_hidden_size,
+ action_shape,
+ actor_head_layer_num,
+ activation=activation,
+ norm_type=norm_type
+ )
+ )
+ elif self.action_space == 'continuous':
+ self.actor_head = nn.Sequential(
+ nn.Linear(agent_obs_shape, actor_head_hidden_size), activation,
+ ReparameterizationHead(
+ actor_head_hidden_size,
+ action_shape,
+ actor_head_layer_num,
+ sigma_type=sigma_type,
+ activation=activation,
+ norm_type=norm_type,
+ bound_type=bound_type
+ )
+ )
+ # Group the encoder and head of each branch as plain lists first.
+ self.actor = [self.actor_encoder, self.actor_head]
+ self.critic = [self.critic_encoder, self.critic_head]
+ # Then wrap them in nn.ModuleList for the convenience of calling APIs such as ``self.critic.parameters()``,
+ # although this may look slightly confusing when printing the model with ``print(self)``.
+ self.actor = nn.ModuleList(self.actor)
+ self.critic = nn.ModuleList(self.critic)
+
+ def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
+ """
+ Overview:
+ MAVAC forward computation graph, input observation tensor to predict state value or action logit. \
+ ``mode`` includes ``compute_actor``, ``compute_critic``, ``compute_actor_critic``.
+ Different ``mode`` will forward with different network modules to get different outputs and save \
+ computation.
+ Arguments:
+ - inputs (:obj:`Dict`): The input dict including observation and related info, \
+ whose key-values vary with different ``mode``.
+ - mode (:obj:`str`): The forward mode, all the modes are defined in the beginning of this class.
+ Returns:
+ - outputs (:obj:`Dict`): The output dict of MAVAC's forward computation graph, whose key-values vary \
+ with different ``mode``.
+
+ Examples (Actor):
+ >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
+ >>> inputs = {
+ 'agent_state': torch.randn(10, 8, 64),
+ 'global_state': torch.randn(10, 8, 128),
+ 'action_mask': torch.randint(0, 2, size=(10, 8, 14))
+ }
+ >>> actor_outputs = model(inputs,'compute_actor')
+ >>> assert actor_outputs['logit'].shape == torch.Size([10, 8, 14])
+
+ Examples (Critic):
+ >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
+ >>> inputs = {
+ 'agent_state': torch.randn(10, 8, 64),
+ 'global_state': torch.randn(10, 8, 128),
+ 'action_mask': torch.randint(0, 2, size=(10, 8, 14))
+ }
+ >>> critic_outputs = model(inputs,'compute_critic')
+ >>> assert critic_outputs['value'].shape == torch.Size([10, 8])
+
+ Examples (Actor-Critic):
+ >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
+ >>> inputs = {
+ 'agent_state': torch.randn(10, 8, 64),
+ 'global_state': torch.randn(10, 8, 128),
+ 'action_mask': torch.randint(0, 2, size=(10, 8, 14))
+ }
+ >>> outputs = model(inputs,'compute_actor_critic')
+ >>> assert outputs['value'].shape == torch.Size([10, 8])
+ >>> assert outputs['logit'].shape == torch.Size([10, 8, 14])
+
+ """
+ assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
+ return getattr(self, mode)(inputs)
+
+ def compute_actor(self, x: Dict) -> Dict:
+ """
+ Overview:
+ MAVAC forward computation graph for actor part, \
+ predicting action logit with agent observation tensor in ``x``.
+ Arguments:
+ - x (:obj:`Dict`): Input data dict with keys ['agent_state', 'action_mask'(optional)].
+ - agent_state (:obj:`torch.Tensor`): Each agent's local state (obs).
+ - action_mask (optional) (:obj:`torch.Tensor`): When ``action_space`` is discrete, an action mask needs \
+ to be provided to mask out illegal actions.
+ Returns:
+ - outputs (:obj:`Dict`): The output dict of the forward computation graph for actor, including ``logit``.
+ ReturnsKeys:
+ - logit (:obj:`torch.Tensor`): The predicted action logit tensor, for discrete action space, it will be \
+ the same dimension real-value ranged tensor of possible action choices, and for continuous action \
+ space, it will be the mu and sigma of the Gaussian distribution, and the number of mu and sigma is the \
+ same as the number of continuous actions.
+ Shapes:
+ - logit (:obj:`torch.FloatTensor`): :math:`(B, M, N)`, where B is batch size, M is ``agent_num`` and N \
+ is ``action_shape``.
+
+ Examples:
+ >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
+ >>> inputs = {
+ 'agent_state': torch.randn(10, 8, 64),
+ 'global_state': torch.randn(10, 8, 128),
+ 'action_mask': torch.randint(0, 2, size=(10, 8, 14))
+ }
+ >>> actor_outputs = model(inputs,'compute_actor')
+ >>> assert actor_outputs['logit'].shape == torch.Size([10, 8, 14])
+
+ """
+ if self.action_space == 'discrete':
+ action_mask = x['action_mask']
+ x = x['agent_state']
+ x = self.actor_encoder(x)
+ x = self.actor_head(x)
+ logit = x['logit']
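+ # mask illegal actions by assigning a very large negative logit, so their softmax probability is ~0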
+ logit[action_mask == 0.0] = -99999999
+ elif self.action_space == 'continuous':
+ x = x['agent_state']
+ x = self.actor_encoder(x)
+ x = self.actor_head(x)
+ logit = x
+ return {'logit': logit}
+
+ def compute_critic(self, x: Dict) -> Dict:
+ """
+ Overview:
+ MAVAC forward computation graph for critic part. \
+ Predict state value with global observation tensor in ``x``.
+ Arguments:
+ - x (:obj:`Dict`): Input data dict with keys ['global_state'].
+ - global_state (:obj:`torch.Tensor`): The global state (obs).
+ Returns:
+ - outputs (:obj:`Dict`): The output dict of MAVAC's forward computation graph for critic, \
+ including ``value``.
+ ReturnsKeys:
+ - value (:obj:`torch.Tensor`): The predicted state value tensor.
+ Shapes:
+ - value (:obj:`torch.FloatTensor`): :math:`(B, M)`, where B is batch size and M is ``agent_num``.
+
+ Examples:
+ >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
+ >>> inputs = {
+ 'agent_state': torch.randn(10, 8, 64),
+ 'global_state': torch.randn(10, 8, 128),
+ 'action_mask': torch.randint(0, 2, size=(10, 8, 14))
+ }
+ >>> critic_outputs = model(inputs,'compute_critic')
+ >>> assert critic_outputs['value'].shape == torch.Size([10, 8])
+ """
+
+ x = self.critic_encoder(x['global_state'])
+ x = self.critic_head(x)
+ return {'value': x['pred']}
+
+ def compute_actor_critic(self, x: Dict) -> Dict:
+ """
+ Overview:
+ MAVAC forward computation graph for both actor and critic part, input observation to predict action \
+ logit and state value.
+ Arguments:
+ - x (:obj:`Dict`): The input dict contains ``agent_state``, ``global_state`` and other related info.
+ Returns:
+ - outputs (:obj:`Dict`): The output dict of MAVAC's forward computation graph for both actor and critic, \
+ including ``logit`` and ``value``.
+ ReturnsKeys:
+ - logit (:obj:`torch.Tensor`): The predicted action logit tensor.
+ - value (:obj:`torch.Tensor`): The predicted state value tensor.
+ Shapes:
+ - logit (:obj:`torch.FloatTensor`): :math:`(B, M, N)`, where B is batch size and N is ``action_shape`` \
+ and M is ``agent_num``.
+ - value (:obj:`torch.FloatTensor`): :math:`(B, M)`, where B is batch size and M is ``agent_num``.
+
+ Examples:
+ >>> model = MAVAC(agent_obs_shape=64, global_obs_shape=128, action_shape=14)
+ >>> inputs = {
+ 'agent_state': torch.randn(10, 8, 64),
+ 'global_state': torch.randn(10, 8, 128),
+ 'action_mask': torch.randint(0, 2, size=(10, 8, 14))
+ }
+ >>> outputs = model(inputs,'compute_actor_critic')
+ >>> assert outputs['value'].shape == torch.Size([10, 8])
+ >>> assert outputs['logit'].shape == torch.Size([10, 8, 14])
+ """
+ logit = self.compute_actor(x)['logit']
+ value = self.compute_critic(x)['value']
+ return {'logit': logit, 'value': value}
diff --git a/DI-engine/ding/model/template/ngu.py b/DI-engine/ding/model/template/ngu.py
new file mode 100644
index 0000000000000000000000000000000000000000..caa3c14760ab57a730cad93a67d5a2725baaceee
--- /dev/null
+++ b/DI-engine/ding/model/template/ngu.py
@@ -0,0 +1,225 @@
+from typing import Union, Optional, Dict, Callable, List
+import torch
+import torch.nn as nn
+
+from ding.torch_utils import get_lstm, one_hot, to_tensor, to_ndarray
+from ding.utils import MODEL_REGISTRY, SequenceType, squeeze
+# from ding.torch_utils.data_helper import one_hot_embedding, one_hot_embedding_none
+from ..common import FCEncoder, ConvEncoder, DiscreteHead, DuelingHead, MultiHead, RainbowHead, \
+ QuantileHead, QRDQNHead, DistributionHead
+
+
+def parallel_wrapper(forward_fn: Callable) -> Callable:
+ """
+ Overview:
+ Process timestep T and batch_size B at the same time, in other words, treat different timestep data as \
+ different trajectories in a batch.
+ Arguments:
+ - forward_fn (:obj:`Callable`): Normal ``nn.Module`` 's forward function.
+ Returns:
+ - wrapper (:obj:`Callable`): Wrapped function.
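+ Examples:
+ A small illustrative sketch (the wrapped module and the shapes below are assumptions for demonstration):
+ >>> layer = nn.Linear(4, 8)
+ >>> wrapped = parallel_wrapper(layer)
+ >>> out = wrapped(torch.randn(5, 3, 4))  # T=5 timesteps, B=3 batch size
+ >>> assert out.shape == (5, 3, 8)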
+ """
+
+ def wrapper(x: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]:
+ T, B = x.shape[:2]
+
+ def reshape(d):
+ if isinstance(d, list):
+ d = [reshape(t) for t in d]
+ elif isinstance(d, dict):
+ d = {k: reshape(v) for k, v in d.items()}
+ else:
+ d = d.reshape(T, B, *d.shape[1:])
+ return d
+
+ x = x.reshape(T * B, *x.shape[2:])
+ x = forward_fn(x)
+ x = reshape(x)
+ return x
+
+ return wrapper
+
+
+@MODEL_REGISTRY.register('ngu')
+class NGU(nn.Module):
+ """
+ Overview:
+ The recurrent Q model for the NGU (https://arxiv.org/pdf/2002.06038.pdf) policy, modified from the DRQN \
+ class in q_learning.py. The implementation mentioned in the original paper is to 'adapt the R2D2 agent \
+ that uses the dueling network architecture with an LSTM layer after a convolutional neural network'. \
+ The NGU network includes an encoder, an LSTM core (rnn) and a head.
+ Interface:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(
+ self,
+ obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType],
+ encoder_hidden_size_list: SequenceType = [128, 128, 64],
+ collector_env_num: Optional[int] = 1, # TODO
+ dueling: bool = True,
+ head_hidden_size: Optional[int] = None,
+ head_layer_num: int = 1,
+ lstm_type: Optional[str] = 'normal',
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None
+ ) -> None:
+ """
+ Overview:
+ Init the DRQN Model for NGU according to arguments.
+ Arguments:
+ - obs_shape (:obj:`Union[int, SequenceType]`): Observation's space, such as 8 or [4, 84, 84].
+ - action_shape (:obj:`Union[int, SequenceType]`): Action's space, such as 6 or [2, 3, 3].
+ - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``.
+ - collector_env_num (:obj:`Optional[int]`): The number of environments used to collect data simultaneously.
+ - dueling (:obj:`bool`): Whether to choose ``DuelingHead`` (True) or ``DiscreteHead`` (False), \
+ defaults to True.
+ - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to ``Head``, should match the \
+ last element of ``encoder_hidden_size_list``.
+ - head_layer_num (:obj:`int`): The number of layers in head network.
+ - lstm_type (:obj:`Optional[str]`): Version of rnn cell, now supports ['normal', 'pytorch', 'hpc', 'gru'], \
+ default is 'normal'.
+ - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in the ``MLP`` after \
+ ``layer_fn``; if ``None``, it defaults to ``nn.ReLU()``.
+ - norm_type (:obj:`Optional[str]`): The type of normalization to use, see ``ding.torch_utils.fc_block`` \
+ for more details.
+ """
+ super(NGU, self).__init__()
+ # For compatibility: 1, (1, ), [4, H, H]
+ obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape)
+ self.action_shape = action_shape
+ self.collector_env_num = collector_env_num
+ if head_hidden_size is None:
+ head_hidden_size = encoder_hidden_size_list[-1]
+ # FC Encoder
+ if isinstance(obs_shape, int) or len(obs_shape) == 1:
+ self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
+ # Conv Encoder
+ elif len(obs_shape) == 3:
+ self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
+ else:
+ raise RuntimeError(
+ "not support obs_shape for pre-defined encoder: {}, please customize your own DRQN".format(obs_shape)
+ )
+ # NOTE: current obs hidden_state_dim, previous action, previous extrinsic reward, beta
+ # TODO(pu): add prev_reward_intrinsic to network input, reward uses some kind of embedding instead of 1D value
+ input_size = head_hidden_size + action_shape + 1 + self.collector_env_num
+ # LSTM Type
+ self.rnn = get_lstm(lstm_type, input_size=input_size, hidden_size=head_hidden_size)
+ # Head Type
+ if dueling:
+ head_cls = DuelingHead
+ else:
+ head_cls = DiscreteHead
+ multi_head = not isinstance(action_shape, int)
+ if multi_head:
+ self.head = MultiHead(
+ head_cls,
+ head_hidden_size,
+ action_shape,
+ layer_num=head_layer_num,
+ activation=activation,
+ norm_type=norm_type
+ )
+ else:
+ self.head = head_cls(
+ head_hidden_size, action_shape, head_layer_num, activation=activation, norm_type=norm_type
+ )
+
+ def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps: Optional[list] = None) -> Dict:
+ """
+ Overview:
+ Forward computation graph of the NGU R2D2 network. Input observation, prev_action and \
+ prev_reward_extrinsic to predict the NGU Q output.
+ Arguments:
+ - inputs (:obj:`Dict`):
+ - obs (:obj:`torch.Tensor`): Encoded observation.
+ - prev_state (:obj:`list`): Previous state's tensor of size ``(B, N)``.
+ - inference (:obj:`bool`): If inference is True, we unroll a one-timestep transition; \
+ if inference is False, we unroll the whole sequence of transitions.
+ - saved_state_timesteps (:obj:`Optional[list]`): When inference is False, we unroll the whole sequence \
+ of transitions and save the rnn hidden states at the timesteps listed in ``saved_state_timesteps``.
+ Returns:
+ - outputs (:obj:`Dict`): The prediction dictionary of the NGU network forward, including ``logit`` and \
+ ``next_state``.
+
+ ReturnsKeys:
+ - logit (:obj:`torch.Tensor`): The predicted Q value logit tensor.
+ - next_state (:obj:`list`): Next state's tensor of size ``(B, N)``.
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, N=obs_space)`, where B is batch size.
+ - prev_state(:obj:`torch.FloatTensor list`): :math:`[(B, N)]`.
+ - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`.
+ - next_state(:obj:`torch.FloatTensor list`): :math:`[(B, N)]`.
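+ Examples:
+ A minimal inference-mode sketch; the shapes below and the ``prev_state=None`` (zero-initialized \
+ hidden state) handling are illustrative assumptions:
+ >>> obs_shape, action_shape, env_num = 8, 6, 4
+ >>> model = NGU(obs_shape, action_shape, collector_env_num=env_num)
+ >>> inputs = {
+ >>> 'obs': torch.randn(env_num, obs_shape),
+ >>> 'prev_state': None,
+ >>> 'prev_action': torch.randint(0, action_shape, (env_num, )),
+ >>> 'prev_reward_extrinsic': torch.randn(env_num),
+ >>> 'beta': torch.arange(env_num),
+ >>> }
+ >>> outputs = model(inputs, inference=True)
+ >>> assert outputs['logit'].shape == torch.Size([env_num, action_shape])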
+ """
+ x, prev_state = inputs['obs'], inputs['prev_state']
+ if 'prev_action' in inputs.keys():
+ # collect, eval mode: pass into one timestep mini-batch data (batchsize=env_num)
+ prev_action = inputs['prev_action']
+ prev_reward_extrinsic = inputs['prev_reward_extrinsic']
+ else:
+ # train mode: pass into H timesteps mini-batch data (batchsize=train_batch_size)
+ prev_action = torch.cat(
+ [torch.ones_like(inputs['action'][:, 0].unsqueeze(1)) * (-1), inputs['action'][:, :-1]], dim=1
+ ) # (B, 1) cat (B, H-1) -> (B, H)
+ prev_reward_extrinsic = torch.cat(
+ [torch.zeros_like(inputs['reward'][:, 0].unsqueeze(1)), inputs['reward'][:, :-1]], dim=1
+ ) # (B, 1, nstep) (B, H-1, nstep) -> (B, H, nstep)
+ beta = inputs['beta'] # beta_index
+ if inference:
+ # collect, eval mode: pass into one timestep mini-batch data (batchsize=env_num)
+ x = self.encoder(x)
+ x = x.unsqueeze(0)
+ prev_reward_extrinsic = prev_reward_extrinsic.unsqueeze(0).unsqueeze(-1)
+
+ env_num = self.collector_env_num
+ beta_onehot = one_hot(beta, env_num).unsqueeze(0)
+ prev_action_onehot = one_hot(prev_action, self.action_shape).unsqueeze(0)
+ x_a_r_beta = torch.cat(
+ [x, prev_action_onehot, prev_reward_extrinsic, beta_onehot], dim=-1
+ ) # shape (1, B, hidden_dim + action_dim + 1 + env_num)
+ x, next_state = self.rnn(x_a_r_beta.to(torch.float32), prev_state)
+ # TODO(pu): x, next_state = self.rnn(x, prev_state)
+ x = x.squeeze(0)
+ x = self.head(x)
+ x['next_state'] = next_state
+ return x
+ else:
+ # train mode: pass into H timesteps mini-batch data (batchsize=train_batch_size)
+ assert len(x.shape) in [3, 5], x.shape # (B, H, obs_dim)
+ x = parallel_wrapper(self.encoder)(x) # (B, H, hidden_dim)
+ prev_reward_extrinsic = prev_reward_extrinsic[:, :, 0].unsqueeze(-1) # (B,H,1)
+ env_num = self.collector_env_num
+ beta_onehot = one_hot(beta.view(-1), env_num).view([beta.shape[0], beta.shape[1], -1]) # (B, H, env_num)
+ prev_action_onehot = one_hot(prev_action.view(-1), self.action_shape).view(
+ [prev_action.shape[0], prev_action.shape[1], -1]
+ ) # (B, H, action_dim)
+ x_a_r_beta = torch.cat(
+ [x, prev_action_onehot, prev_reward_extrinsic, beta_onehot], dim=-1
+ ) # (B, H, hidden_dim + action_dim + 1 + env_num)
+ x = x_a_r_beta
+ lstm_embedding = []
+ # TODO(nyz) how to deal with hidden_size key-value
+ hidden_state_list = []
+ if saved_state_timesteps is not None:
+ saved_state = []
+ for t in range(x.shape[0]): # T timesteps
+ output, prev_state = self.rnn(x[t:t + 1], prev_state)
+ if saved_state_timesteps is not None and t + 1 in saved_state_timesteps:
+ saved_state.append(prev_state)
+ lstm_embedding.append(output)
+ # only take the hidden state h
+ hidden_state_list.append(torch.cat([item['h'] for item in prev_state], dim=1))
+
+ x = torch.cat(lstm_embedding, 0) # [B, H, 64]
+ x = parallel_wrapper(self.head)(x)
+ # the last timestep state including the hidden state (h) and the cell state (c)
+ x['next_state'] = prev_state
+ x['hidden_state'] = torch.cat(hidden_state_list, dim=-3)
+ if saved_state_timesteps is not None:
+ # the selected saved hidden states, including the hidden state (h) and the cell state (c)
+ x['saved_state'] = saved_state
+ return x
diff --git a/DI-engine/ding/model/template/pdqn.py b/DI-engine/ding/model/template/pdqn.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec94cb3fe14c8480ac2ff3f7999cd4463b69e007
--- /dev/null
+++ b/DI-engine/ding/model/template/pdqn.py
@@ -0,0 +1,229 @@
+from typing import Union, Optional, Dict
+from easydict import EasyDict
+
+import torch
+import torch.nn as nn
+
+from ding.torch_utils import get_lstm
+from ding.utils import MODEL_REGISTRY, SequenceType, squeeze
+from ..common import FCEncoder, ConvEncoder, DiscreteHead, DuelingHead, RegressionHead
+
+
+@MODEL_REGISTRY.register('pdqn')
+class PDQN(nn.Module):
+ """
+ Overview:
+ The neural network and computation graph of PDQN(https://arxiv.org/abs/1810.06394v1) and \
+ MPDQN(https://arxiv.org/abs/1905.04388) algorithms for parameterized action space. \
+ This model supports parameterized action space with discrete ``action_type`` and continuous ``action_arg``. \
+ In principle, PDQN consists of the x network (continuous action parameter network) and the Q network \
+ (discrete action type network). But for simplicity, the code is organized into ``encoder`` and \
+ ``actor_head``, which contain the encoders and heads of the above two networks respectively.
+ Interface:
+ ``__init__``, ``forward``, ``compute_discrete``, ``compute_continuous``.
+ """
+ mode = ['compute_discrete', 'compute_continuous']
+
+ def __init__(
+ self,
+ obs_shape: Union[int, SequenceType],
+ action_shape: EasyDict,
+ encoder_hidden_size_list: SequenceType = [128, 128, 64],
+ dueling: bool = True,
+ head_hidden_size: Optional[int] = None,
+ head_layer_num: int = 1,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ multi_pass: Optional[bool] = False,
+ action_mask: Optional[list] = None
+ ) -> None:
+ """
+ Overview:
+ Init the PDQN (encoder + head) Model according to input arguments.
+ Arguments:
+ - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84].
+ - action_shape (:obj:`EasyDict`): Action space shape in dict type, such as \
+ EasyDict({'action_type_shape': 3, 'action_args_shape': 5}).
+ - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
+ the last element must match ``head_hidden_size``.
+ - dueling (:obj:`bool`): Whether to choose ``DuelingHead`` (True, default) or ``DiscreteHead`` (False).
+ - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network.
+ - head_layer_num (:obj:`int`): The number of layers used in the head network to compute Q value output.
+ - activation (:obj:`Optional[nn.Module]`): The type of activation function in networks; \
+ if ``None``, it defaults to ``nn.ReLU()``.
+ - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
+ ``ding.torch_utils.fc_block`` for more details.
+ - multi_pass (:obj:`Optional[bool]`): Whether to use multi pass version.
+ - action_mask (:obj:`Optional[list]`): An action mask indicating how action args are \
+ associated with each discrete action. For example, if there are 3 discrete actions and \
+ 4 continuous action args, and the first discrete action is associated with the first \
+ continuous action arg, the second discrete action with the second continuous action arg, \
+ and the third discrete action with the remaining 2 action args, the action mask will be \
+ [[1,0,0,0],[0,1,0,0],[0,0,1,1]] with shape 3*4 (see the example below).
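+ Examples:
+ An illustrative sketch of the multi-pass setting described above (the shapes here are assumptions \
+ for demonstration):
+ >>> act_shape = EasyDict({'action_type_shape': 3, 'action_args_shape': 4})
+ >>> mask = [[1, 0, 0, 0], [0, 1, 0, 0], [0, 0, 1, 1]]
+ >>> model = PDQN(4, act_shape, multi_pass=True, action_mask=mask)
+ >>> inputs = {'state': torch.randn(64, 4), 'action_args': torch.randn(64, 4)}
+ >>> outputs = model.forward(inputs, mode='compute_discrete')
+ >>> assert outputs['logit'].shape == torch.Size([64, 3])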
+ """
+ super(PDQN, self).__init__()
+ self.multi_pass = multi_pass
+ if self.multi_pass:
+ assert isinstance(
+ action_mask, list
+ ), 'Please indicate action mask in list form if you set multi_pass to True'
+ self.action_mask = torch.LongTensor(action_mask)
+ nonzero = torch.nonzero(self.action_mask)
+ index = torch.zeros(action_shape.action_args_shape).long()
+ index.scatter_(dim=0, index=nonzero[:, 1], src=nonzero[:, 0])
+ self.action_scatter_index = index # (self.action_args_shape, )
+
+ # squeeze action shape input like (3,) to 3
+ action_shape.action_args_shape = squeeze(action_shape.action_args_shape)
+ action_shape.action_type_shape = squeeze(action_shape.action_type_shape)
+ self.action_args_shape = action_shape.action_args_shape
+ self.action_type_shape = action_shape.action_type_shape
+
+ # init head hidden size
+ if head_hidden_size is None:
+ head_hidden_size = encoder_hidden_size_list[-1]
+
+ # squeeze obs input for compatibility: 1, (1, ), [4, 32, 32]
+ obs_shape = squeeze(obs_shape)
+
+ # Obs Encoder Type
+ if isinstance(obs_shape, int) or len(obs_shape) == 1: # FC Encoder
+ self.dis_encoder = FCEncoder(
+ obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
+ )
+ self.cont_encoder = FCEncoder(
+ obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
+ )
+ elif len(obs_shape) == 3: # Conv Encoder
+ self.dis_encoder = ConvEncoder(
+ obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
+ )
+ self.cont_encoder = ConvEncoder(
+ obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type
+ )
+ else:
+ raise RuntimeError(
+ "Pre-defined encoder not support obs_shape {}, please customize your own PDQN.".format(obs_shape)
+ )
+
+ # Continuous Action Head Type
+ self.cont_head = RegressionHead(
+ head_hidden_size,
+ action_shape.action_args_shape,
+ head_layer_num,
+ final_tanh=True,
+ activation=activation,
+ norm_type=norm_type
+ )
+
+ # Discrete Action Head Type
+ if dueling:
+ dis_head_cls = DuelingHead
+ else:
+ dis_head_cls = DiscreteHead
+ self.dis_head = dis_head_cls(
+ head_hidden_size + action_shape.action_args_shape,
+ action_shape.action_type_shape,
+ head_layer_num,
+ activation=activation,
+ norm_type=norm_type
+ )
+
+ self.actor_head = nn.ModuleList([self.dis_head, self.cont_head])
+ # self.encoder = nn.ModuleList([self.dis_encoder, self.cont_encoder])
+ # To speed up the training process, the X network and the Q network share the encoder for the state
+ self.encoder = nn.ModuleList([self.cont_encoder, self.cont_encoder])
+
+ def forward(self, inputs: Union[torch.Tensor, Dict, EasyDict], mode: str) -> Dict:
+ """
+ Overview:
+ PDQN forward computation graph, input observation tensor to predict q_value for \
+ discrete actions and values for continuous action_args.
+ Arguments:
+ - inputs (:obj:`Union[torch.Tensor, Dict, EasyDict]`): Inputs including observation and \
+ other info according to `mode`.
+ - mode (:obj:`str`): Name of the forward mode.
+ Shapes:
+ - inputs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``.
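+ Examples:
+ A minimal sketch mirroring the mode-specific examples below:
+ >>> act_shape = EasyDict({'action_type_shape': (3, ), 'action_args_shape': (5, )})
+ >>> model = PDQN(4, act_shape)
+ >>> obs = torch.randn(64, 4)
+ >>> cont_out = model.forward(obs, mode='compute_continuous')
+ >>> inputs = {'state': obs, 'action_args': cont_out['action_args']}
+ >>> dis_out = model.forward(inputs, mode='compute_discrete')
+ >>> assert dis_out['logit'].shape == torch.Size([64, 3])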
+ """
+ assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
+ return getattr(self, mode)(inputs)
+
+ def compute_continuous(self, inputs: torch.Tensor) -> Dict:
+ """
+ Overview:
+ Use observation tensor to predict continuous action args.
+ Arguments:
+ - inputs (:obj:`torch.Tensor`): Observation inputs.
+ Returns:
+ - outputs (:obj:`Dict`): A dict with key 'action_args'.
+ - 'action_args' (:obj:`torch.Tensor`): The continuous action args.
+ Shapes:
+ - inputs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``.
+ - action_args (:obj:`torch.Tensor`): :math:`(B, M)`, where M is ``action_args_shape``.
+ Examples:
+ >>> act_shape = EasyDict({'action_type_shape': (3, ), 'action_args_shape': (5, )})
+ >>> model = PDQN(4, act_shape)
+ >>> inputs = torch.randn(64, 4)
+ >>> outputs = model.forward(inputs, mode='compute_continuous')
+ >>> assert outputs['action_args'].shape == torch.Size([64, 5])
+ """
+ cont_x = self.encoder[1](inputs) # size (B, encoded_state_shape)
+ action_args = self.actor_head[1](cont_x)['pred'] # size (B, action_args_shape)
+ outputs = {'action_args': action_args}
+ return outputs
+
+ def compute_discrete(self, inputs: Union[Dict, EasyDict]) -> Dict:
+ """
+ Overview:
+ Use observation tensor and continuous action args to predict discrete action types.
+ Arguments:
+ - inputs (:obj:`Union[Dict, EasyDict]`): A dict with keys 'state', 'action_args'.
+ - state (:obj:`torch.Tensor`): Observation inputs.
+ - action_args (:obj:`torch.Tensor`): Action parameters are used to concatenate with the observation \
+ and serve as input to the discrete action type network.
+ Returns:
+ - outputs (:obj:`Dict`): A dict with keys 'logit', 'action_args'.
+ - 'logit': The logit value for each discrete action.
+ - 'action_args': The continuous action args(same as the inputs['action_args']) for later usage.
+ Examples:
+ >>> act_shape = EasyDict({'action_type_shape': (3, ), 'action_args_shape': (5, )})
+ >>> model = PDQN(4, act_shape)
+ >>> inputs = {'state': torch.randn(64, 4), 'action_args': torch.randn(64, 5)}
+ >>> outputs = model.forward(inputs, mode='compute_discrete')
+ >>> assert outputs['logit'].shape == torch.Size([64, 3])
+ >>> assert outputs['action_args'].shape == torch.Size([64, 5])
+ """
+ dis_x = self.encoder[0](inputs['state']) # size (B, encoded_state_shape)
+ action_args = inputs['action_args'] # size (B, action_args_shape)
+
+ if self.multi_pass: # mpdqn
+ # fill_value=-2 is a mask value, which is not in the normal action range
+ # (B, action_args_shape, K) where K is the action_type_shape
+ mp_action = torch.full(
+ (dis_x.shape[0], self.action_args_shape, self.action_type_shape),
+ fill_value=-2,
+ device=dis_x.device,
+ dtype=dis_x.dtype
+ )
+ index = self.action_scatter_index.view(1, -1, 1).repeat(dis_x.shape[0], 1, 1).to(dis_x.device)
+
+ # index: (B, action_args_shape, 1) src: (B, action_args_shape, 1)
+ mp_action.scatter_(dim=-1, index=index, src=action_args.unsqueeze(-1))
+ mp_action = mp_action.permute(0, 2, 1) # (B, K, action_args_shape)
+
+ mp_state = dis_x.unsqueeze(1).repeat(1, self.action_type_shape, 1) # (B, K, obs_shape)
+ mp_state_action_cat = torch.cat([mp_state, mp_action], dim=-1)
+
+ logit = self.actor_head[0](mp_state_action_cat)['logit'] # (B, K, K)
+
+ logit = torch.diagonal(logit, dim1=-2, dim2=-1) # (B, K)
+ else: # pdqn
+ # size (B, encoded_state_shape + action_args_shape)
+ if len(action_args.shape) == 1: # (B, ) -> (B, 1)
+ action_args = action_args.unsqueeze(1)
+ state_action_cat = torch.cat((dis_x, action_args), dim=-1)
+ logit = self.actor_head[0](state_action_cat)['logit'] # size (B, K) where K is action_type_shape
+
+ outputs = {'logit': logit, 'action_args': action_args}
+ return outputs
diff --git a/DI-engine/ding/model/template/pg.py b/DI-engine/ding/model/template/pg.py
new file mode 100644
index 0000000000000000000000000000000000000000..6059642dd34bdf8814916f765b1eb098c9228e89
--- /dev/null
+++ b/DI-engine/ding/model/template/pg.py
@@ -0,0 +1,111 @@
+from typing import Union, Optional, Dict, Callable, List
+import torch
+import torch.nn as nn
+from easydict import EasyDict
+
+from ding.torch_utils import get_lstm
+from ding.utils import MODEL_REGISTRY, SequenceType, squeeze
+from ..common import FCEncoder, ConvEncoder, DiscreteHead, DuelingHead, \
+ MultiHead, RegressionHead, ReparameterizationHead, independent_normal_dist
+
+
+@MODEL_REGISTRY.register('pg')
+class PG(nn.Module):
+ """
+ Overview:
+ The neural network and computation graph of algorithms related to Policy Gradient(PG) \
+ (https://proceedings.neurips.cc/paper/1999/file/464d828b85b0bed98e80ade0a5c43b0f-Paper.pdf). \
+ The PG model is composed of two parts: encoder and head. Encoders are used to extract features from \
+ various observations, and heads are used to predict the corresponding action logit.
+ Interface:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(
+ self,
+ obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType],
+ action_space: str = 'discrete',
+ encoder_hidden_size_list: SequenceType = [128, 128, 64],
+ head_hidden_size: Optional[int] = None,
+ head_layer_num: int = 1,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None
+ ) -> None:
+ """
+ Overview:
+ Initialize the PG model according to corresponding input arguments.
+ Arguments:
+ - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84].
+ - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3].
+ - action_space (:obj:`str`): The type of different action spaces, including ['discrete', 'continuous'], \
+ then will instantiate corresponding head, including ``DiscreteHead`` and ``ReparameterizationHead``.
+ - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
+ the last element must match ``head_hidden_size``.
+ - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of ``head`` network, defaults \
+ to None, it must match the last element of ``encoder_hidden_size_list``.
+ - head_layer_num (:obj:`int`): The num of layers used in the ``head`` network to compute action.
+ - activation (:obj:`Optional[nn.Module]`): The type of activation function in networks; \
+ if ``None``, it defaults to ``nn.ReLU()``.
+ - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
+ ``ding.torch_utils.fc_block`` for more details. You can choose one of ['BN', 'IN', 'SyncBN', 'LN'].
+ Examples:
+ >>> model = PG((4, 84, 84), 5)
+ >>> inputs = torch.randn(8, 4, 84, 84)
+ >>> outputs = model(inputs)
+ >>> assert isinstance(outputs, dict)
+ >>> assert outputs['logit'].shape == (8, 5)
+ >>> assert outputs['dist'].sample().shape == (8, )
+ """
+ super(PG, self).__init__()
+ # For compatibility: 1, (1, ), [4, 32, 32]
+ obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape)
+ if head_hidden_size is None:
+ head_hidden_size = encoder_hidden_size_list[-1]
+ # FC Encoder
+ if isinstance(obs_shape, int) or len(obs_shape) == 1:
+ self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
+ # Conv Encoder
+ elif len(obs_shape) == 3:
+ self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
+ else:
+ raise RuntimeError(
+ "not support obs_shape for pre-defined encoder: {}, please customize your own BC".format(obs_shape)
+ )
+ self.action_space = action_space
+ # Head
+ if self.action_space == 'discrete':
+ self.head = DiscreteHead(
+ head_hidden_size, action_shape, head_layer_num, activation=activation, norm_type=norm_type
+ )
+ elif self.action_space == 'continuous':
+ self.head = ReparameterizationHead(
+ head_hidden_size,
+ action_shape,
+ head_layer_num,
+ activation=activation,
+ norm_type=norm_type,
+ sigma_type='independent'
+ )
+ else:
+ raise KeyError("not support action space: {}".format(self.action_space))
+
+ def forward(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+ PG forward computation graph, input observation tensor to predict policy distribution.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input observation tensor data.
+ Returns:
+ - outputs (:obj:`Dict`): The output dict containing ``logit`` and ``dist``. If the action space is \
+ discrete, ``dist`` is a Categorical distribution; if the action space is continuous, ``dist`` is a \
+ Normal distribution with independent dimensions.
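+ Examples:
+ A brief sketch mirroring the example in ``__init__``:
+ >>> model = PG((4, 84, 84), 5)
+ >>> inputs = torch.randn(8, 4, 84, 84)
+ >>> outputs = model(inputs)
+ >>> assert outputs['logit'].shape == torch.Size([8, 5])
+ >>> assert outputs['dist'].sample().shape == (8, )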
+ """
+ x = self.encoder(x)
+ x = self.head(x)
+ if self.action_space == 'discrete':
+ x['dist'] = torch.distributions.Categorical(logits=x['logit'])
+ elif self.action_space == 'continuous':
+ x = {'logit': {'mu': x['mu'], 'sigma': x['sigma']}}
+ x['dist'] = independent_normal_dist(x['logit'])
+ return x
diff --git a/DI-engine/ding/model/template/ppg.py b/DI-engine/ding/model/template/ppg.py
new file mode 100644
index 0000000000000000000000000000000000000000..76df579e715bda89e60901f6a269325f6d2b009e
--- /dev/null
+++ b/DI-engine/ding/model/template/ppg.py
@@ -0,0 +1,152 @@
+from typing import Optional, Dict, Union
+import copy
+import torch
+import torch.nn as nn
+from ding.utils import SequenceType, MODEL_REGISTRY
+from .vac import VAC
+
+
+@MODEL_REGISTRY.register('ppg')
+class PPG(nn.Module):
+ """
+ Overview:
+ Phasic Policy Gradient (PPG) model from the paper `Phasic Policy Gradient` \
+ (https://arxiv.org/abs/2009.04416). \
+ This module contains the VAC module and an auxiliary critic module.
+ Interfaces:
+ ``forward``, ``compute_actor``, ``compute_critic``, ``compute_actor_critic``
+ """
+
+ mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']
+
+ def __init__(
+ self,
+ obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType],
+ action_space: str = 'discrete',
+ share_encoder: bool = True,
+ encoder_hidden_size_list: SequenceType = [128, 128, 64],
+ actor_head_hidden_size: int = 64,
+ actor_head_layer_num: int = 2,
+ critic_head_hidden_size: int = 64,
+ critic_head_layer_num: int = 1,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ impala_cnn_encoder: bool = False,
+ ) -> None:
+ """
+ Overview:
+ Initialize the PPG Model according to input arguments.
+ Arguments:
+ - obs_shape (:obj:`Union[int, SequenceType]`): Observation's shape, such as 128, (156, ).
+ - action_shape (:obj:`Union[int, SequenceType]`): Action's shape, such as 4, (3, ).
+ - action_space (:obj:`str`): The action space type, such as 'discrete', 'continuous'.
+ - share_encoder (:obj:`bool`): Whether to share encoder.
+ - encoder_hidden_size_list (:obj:`SequenceType`): The hidden size list of encoder.
+ - actor_head_hidden_size (:obj:`int`): The ``hidden_size`` to pass to actor head.
+ - actor_head_layer_num (:obj:`int`): The num of layers used in the network to compute the action logit \
+ output for the actor head.
+ - critic_head_hidden_size (:obj:`int`): The ``hidden_size`` to pass to critic head.
+ - critic_head_layer_num (:obj:`int`): The num of layers used in the network to compute the value output \
+ for the critic head.
+ - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` \
+ after each FC layer, if ``None`` then default set to ``nn.ReLU()``.
+ - norm_type (:obj:`Optional[str]`): The type of normalization to use after each network layer (FC, Conv), \
+ see ``ding.torch_utils.network`` for more details.
+ - impala_cnn_encoder (:obj:`bool`): Whether to use impala cnn encoder.
+ """
+ super(PPG, self).__init__()
+ self.actor_critic = VAC(
+ obs_shape,
+ action_shape,
+ action_space,
+ share_encoder,
+ encoder_hidden_size_list,
+ actor_head_hidden_size,
+ actor_head_layer_num,
+ critic_head_hidden_size,
+ critic_head_layer_num,
+ activation,
+ norm_type,
+ impala_cnn_encoder=impala_cnn_encoder
+ )
+ self.aux_critic = copy.deepcopy(self.actor_critic.critic)
+
+ def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
+ """
+ Overview:
+ Compute action logits or value according to mode being ``compute_actor``, ``compute_critic`` or \
+ ``compute_actor_critic``.
+ Arguments:
+ - inputs (:obj:`torch.Tensor`): The input observation tensor data.
+ - mode (:obj:`str`): The forward mode, all the modes are defined in the beginning of this class.
+ Returns:
+ - outputs (:obj:`Dict`): The output dict of PPG's forward computation graph, whose key-values vary from \
+ different ``mode``.
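+ Examples:
+ A minimal sketch assuming a flat observation of size 128 and a discrete action space of size 6:
+ >>> model = PPG(obs_shape=128, action_shape=6)
+ >>> obs = torch.randn(4, 128)
+ >>> actor_out = model(obs, mode='compute_actor')
+ >>> assert actor_out['logit'].shape == torch.Size([4, 6])
+ >>> critic_out = model(obs, mode='compute_critic')
+ >>> assert 'value' in critic_out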
+ """
+ assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
+ return getattr(self, mode)(inputs)
+
+ def compute_actor(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+ Use actor to compute action logits.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input observation tensor data.
+ Returns:
+ - output (:obj:`Dict`): The output data containing action logits.
+ ReturnsKeys:
+                - logit (:obj:`torch.Tensor`): The predicted action logit tensor. For discrete action space, it is \
+                    a real-valued tensor whose last dimension equals the number of possible actions; for continuous \
+                    action space, it is the mu and sigma of the Gaussian distribution, whose number matches the \
+                    number of continuous actions. Hybrid action space is a combination of discrete and continuous \
+                    action spaces, so the logit will be a dict with ``action_type`` and ``action_args``.
+ Shapes:
+ - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is the input feature size.
+ - output (:obj:`Dict`): ``logit``: :math:`(B, A)`, where B is batch size and A is the action space size.
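+        Examples:
+            >>> # An illustrative sketch, assuming the default discrete action space settings.
+            >>> model = PPG(obs_shape=128, action_shape=6)
+            >>> logit = model.compute_actor(torch.randn(4, 128))['logit']
+            >>> assert logit.shape == torch.Size([4, 6])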
+ """
+ return self.actor_critic(x, mode='compute_actor')
+
+ def compute_critic(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+ Use critic to compute value.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input observation tensor data.
+ Returns:
+ - output (:obj:`Dict`): The output dict of VAC's forward computation graph for critic, including ``value``.
+ ReturnsKeys:
+ - necessary: ``value``
+ Shapes:
+ - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is the input feature size.
+ - output (:obj:`Dict`): ``value``: :math:`(B, 1)`, where B is batch size.
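+        Examples:
+            >>> # An illustrative sketch; the exact value shape depends on the critic head configuration.
+            >>> model = PPG(obs_shape=128, action_shape=6)
+            >>> value = model.compute_critic(torch.randn(4, 128))['value']
+            >>> assert value.shape[0] == 4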
+ """
+ x = self.aux_critic[0](x) # encoder
+ x = self.aux_critic[1](x) # head
+ return {'value': x['pred']}
+
+ def compute_actor_critic(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+ Use actor and critic to compute action logits and value.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input observation tensor data.
+ Returns:
+ - outputs (:obj:`Dict`): The output dict of PPG's forward computation graph for both actor and critic, \
+ including ``logit`` and ``value``.
+ ReturnsKeys:
+            - logit (:obj:`torch.Tensor`): The predicted action logit tensor. For discrete action space, it is \
+                a real-valued tensor whose last dimension equals the number of possible actions; for continuous \
+                action space, it is the mu and sigma of the Gaussian distribution, whose number matches the \
+                number of continuous actions. Hybrid action space is a combination of discrete and continuous \
+                action spaces, so the logit will be a dict with ``action_type`` and ``action_args``.
+ - value (:obj:`torch.Tensor`): The predicted state value tensor.
+ Shapes:
+ - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is the input feature size.
+ - output (:obj:`Dict`): ``value``: :math:`(B, 1)`, where B is batch size.
+ - output (:obj:`Dict`): ``logit``: :math:`(B, A)`, where B is batch size and A is the action space size.
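+        Examples:
+            >>> # A minimal sketch that exercises the shared-encoder path with a discrete action space.
+            >>> model = PPG(obs_shape=128, action_shape=6)
+            >>> outputs = model.compute_actor_critic(torch.randn(4, 128))
+            >>> assert outputs['logit'].shape == torch.Size([4, 6])
+            >>> assert outputs['value'].shape[0] == 4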
+
+ .. note::
+            ``compute_actor_critic`` interface aims to save computation when the encoder is shared.
+ """
+ return self.actor_critic(x, mode='compute_actor_critic')
diff --git a/DI-engine/ding/model/template/procedure_cloning.py b/DI-engine/ding/model/template/procedure_cloning.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f03c8a4bf2b9c10f84f35907abbcbee795dd23f
--- /dev/null
+++ b/DI-engine/ding/model/template/procedure_cloning.py
@@ -0,0 +1,327 @@
+from typing import Optional, Tuple, Union, Dict
+
+import torch
+import torch.nn as nn
+
+from ding.utils import MODEL_REGISTRY, SequenceType
+from ding.torch_utils.network.transformer import Attention
+from ding.torch_utils.network.nn_module import fc_block, build_normalization
+from ..common import FCEncoder, ConvEncoder
+
+
+class PCTransformer(nn.Module):
+ """
+ Overview:
+        The transformer block for the neural network of algorithms related to Procedure cloning (PC).
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(
+ self, cnn_hidden: int, att_hidden: int, att_heads: int, drop_p: float, max_T: int, n_att: int,
+ feedforward_hidden: int, n_feedforward: int
+ ) -> None:
+ """
+ Overview:
+ Initialize the procedure cloning transformer model according to corresponding input arguments.
+ Arguments:
+ - cnn_hidden (:obj:`int`): The last channel dimension of CNN encoder, such as 32.
+ - att_hidden (:obj:`int`): The dimension of attention blocks, such as 32.
+ - att_heads (:obj:`int`): The number of heads in attention blocks, such as 4.
+ - drop_p (:obj:`float`): The drop out rate of attention, such as 0.5.
+ - max_T (:obj:`int`): The sequence length of procedure cloning, such as 4.
+            - n_att (:obj:`int`): The number of attention layers, such as 4.
+            - feedforward_hidden (:obj:`int`): The dimension of feedforward layers, such as 32.
+ - n_feedforward (:obj:`int`): The number of feedforward layers, such as 4.
+ """
+ super().__init__()
+ self.n_att = n_att
+ self.n_feedforward = n_feedforward
+        # Use ModuleList (instead of plain Python lists) so that the attention, normalization and feedforward
+        # sub-modules are registered: their parameters are tracked by the optimizer and moved across devices.
+        self.attention_layer = nn.ModuleList()
+
+        self.norm_layer = nn.ModuleList([nn.LayerNorm(att_hidden) for _ in range(n_att)])
+        self.attention_layer.append(Attention(cnn_hidden, att_hidden, att_hidden, att_heads, nn.Dropout(drop_p)))
+        for i in range(n_att - 1):
+            self.attention_layer.append(Attention(att_hidden, att_hidden, att_hidden, att_heads, nn.Dropout(drop_p)))
+
+        self.att_drop = nn.Dropout(drop_p)
+
+        self.fc_blocks = nn.ModuleList()
+        self.fc_blocks.append(fc_block(att_hidden, feedforward_hidden, activation=nn.ReLU()))
+        for i in range(n_feedforward - 1):
+            self.fc_blocks.append(fc_block(feedforward_hidden, feedforward_hidden, activation=nn.ReLU()))
+        self.norm_layer.extend([nn.LayerNorm(feedforward_hidden) for _ in range(n_feedforward)])
+ self.mask = torch.tril(torch.ones((max_T, max_T), dtype=torch.bool)).view(1, 1, max_T, max_T)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ The unique execution (forward) method of PCTransformer.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Sequential data of several hidden states.
+ Returns:
+ - output (:obj:`torch.Tensor`): A tensor with the same shape as the input.
+ Examples:
+ >>> model = PCTransformer(128, 128, 8, 0, 16, 2, 128, 2)
+ >>> h = torch.randn((2, 16, 128))
+ >>> h = model(h)
+ >>> assert h.shape == torch.Size([2, 16, 128])
+ """
+ for i in range(self.n_att):
+ x = self.att_drop(self.attention_layer[i](x, self.mask))
+ x = self.norm_layer[i](x)
+ for i in range(self.n_feedforward):
+ x = self.fc_blocks[i](x)
+ x = self.norm_layer[i + self.n_att](x)
+ return x
+
+
+@MODEL_REGISTRY.register('pc_mcts')
+class ProcedureCloningMCTS(nn.Module):
+ """
+ Overview:
+ The neural network of algorithms related to Procedure cloning (PC).
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(
+ self,
+ obs_shape: SequenceType,
+ action_dim: int,
+ cnn_hidden_list: SequenceType = [128, 128, 256, 256, 256],
+ cnn_activation: nn.Module = nn.ReLU(),
+ cnn_kernel_size: SequenceType = [3, 3, 3, 3, 3],
+ cnn_stride: SequenceType = [1, 1, 1, 1, 1],
+ cnn_padding: SequenceType = [1, 1, 1, 1, 1],
+ mlp_hidden_list: SequenceType = [256, 256],
+ mlp_activation: nn.Module = nn.ReLU(),
+ att_heads: int = 8,
+ att_hidden: int = 128,
+ n_att: int = 4,
+ n_feedforward: int = 2,
+ feedforward_hidden: int = 256,
+ drop_p: float = 0.5,
+ max_T: int = 17
+ ) -> None:
+ """
+ Overview:
+ Initialize the MCTS procedure cloning model according to corresponding input arguments.
+ Arguments:
+ - obs_shape (:obj:`SequenceType`): Observation space shape, such as [4, 84, 84].
+ - action_dim (:obj:`int`): Action space shape, such as 6.
+ - cnn_hidden_list (:obj:`SequenceType`): The cnn channel dims for each block, such as\
+ [128, 128, 256, 256, 256].
+ - cnn_activation (:obj:`nn.Module`): The activation function for cnn blocks, such as ``nn.ReLU()``.
+ - cnn_kernel_size (:obj:`SequenceType`): The kernel size for each cnn block, such as [3, 3, 3, 3, 3].
+ - cnn_stride (:obj:`SequenceType`): The stride for each cnn block, such as [1, 1, 1, 1, 1].
+ - cnn_padding (:obj:`SequenceType`): The padding for each cnn block, such as [1, 1, 1, 1, 1].
+ - mlp_hidden_list (:obj:`SequenceType`): The last dim for this must match the last dim of \
+ ``cnn_hidden_list``, such as [256, 256].
+ - mlp_activation (:obj:`nn.Module`): The activation function for mlp layers, such as ``nn.ReLU()``.
+ - att_heads (:obj:`int`): The number of attention heads in transformer, such as 8.
+ - att_hidden (:obj:`int`): The number of attention dimension in transformer, such as 128.
+ - n_att (:obj:`int`): The number of attention blocks in transformer, such as 4.
+            - n_feedforward (:obj:`int`): The number of feedforward layers in transformer, such as 2.
+            - feedforward_hidden (:obj:`int`): The dimension of feedforward layers in transformer, such as 256.
+            - drop_p (:obj:`float`): The drop out rate of attention, such as 0.5.
+ - max_T (:obj:`int`): The sequence length of procedure cloning, such as 17.
+ """
+ super().__init__()
+
+ # Conv Encoder
+ self.embed_state = ConvEncoder(
+ obs_shape, cnn_hidden_list, cnn_activation, cnn_kernel_size, cnn_stride, cnn_padding
+ )
+ self.embed_action = FCEncoder(action_dim, mlp_hidden_list, activation=mlp_activation)
+
+ self.cnn_hidden_list = cnn_hidden_list
+
+ assert cnn_hidden_list[-1] == mlp_hidden_list[-1]
+ layers = []
+ for i in range(n_att):
+ if i == 0:
+ layers.append(Attention(cnn_hidden_list[-1], att_hidden, att_hidden, att_heads, nn.Dropout(drop_p)))
+ else:
+ layers.append(Attention(att_hidden, att_hidden, att_hidden, att_heads, nn.Dropout(drop_p)))
+ layers.append(build_normalization('LN')(att_hidden))
+ for i in range(n_feedforward):
+ if i == 0:
+ layers.append(fc_block(att_hidden, feedforward_hidden, activation=nn.ReLU()))
+ else:
+ layers.append(fc_block(feedforward_hidden, feedforward_hidden, activation=nn.ReLU()))
+ self.layernorm2 = build_normalization('LN')(feedforward_hidden)
+
+ self.transformer = PCTransformer(
+ cnn_hidden_list[-1], att_hidden, att_heads, drop_p, max_T, n_att, feedforward_hidden, n_feedforward
+ )
+
+ self.predict_goal = torch.nn.Linear(cnn_hidden_list[-1], cnn_hidden_list[-1])
+ self.predict_action = torch.nn.Linear(cnn_hidden_list[-1], action_dim)
+
+ def forward(self, states: torch.Tensor, goals: torch.Tensor,
+ actions: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Overview:
+            ProcedureCloningMCTS forward computation graph, input state, goal and action tensors, \
+            calculate the predicted goals and actions.
+ Arguments:
+ - states (:obj:`torch.Tensor`): The observation of current time.
+ - goals (:obj:`torch.Tensor`): The target observation after a period.
+ - actions (:obj:`torch.Tensor`): The actions executed during the period.
+ Returns:
+            - outputs (:obj:`Tuple[torch.Tensor, torch.Tensor]`): The predicted goal embedding and actions.
+ Examples:
+ >>> inputs = { \
+ 'states': torch.randn(2, 3, 64, 64), \
+ 'goals': torch.randn(2, 3, 64, 64), \
+ 'actions': torch.randn(2, 15, 9) \
+ }
+ >>> model = ProcedureCloningMCTS(obs_shape=(3, 64, 64), action_dim=9)
+ >>> goal_preds, action_preds = model(inputs['states'], inputs['goals'], inputs['actions'])
+ >>> assert goal_preds.shape == (2, 256)
+ >>> assert action_preds.shape == (2, 16, 9)
+ """
+ B, T, _ = actions.shape
+
+ # shape: (B, h_dim)
+ state_embeddings = self.embed_state(states).reshape(B, 1, self.cnn_hidden_list[-1])
+ goal_embeddings = self.embed_state(goals).reshape(B, 1, self.cnn_hidden_list[-1])
+ # shape: (B, context_len, h_dim)
+ actions_embeddings = self.embed_action(actions)
+
+ h = torch.cat((state_embeddings, goal_embeddings, actions_embeddings), dim=1)
+ h = self.transformer(h)
+ h = h.reshape(B, T + 2, self.cnn_hidden_list[-1])
+
+ goal_preds = self.predict_goal(h[:, 0, :])
+ action_preds = self.predict_action(h[:, 1:, :])
+
+ return goal_preds, action_preds
+
+
+class BFSConvEncoder(nn.Module):
+ """
+ Overview:
+        The ``BFSConvolution Encoder`` used to encode raw 3-dim observations and output a feature map with the \
+        same height and width as the input.
+    Interfaces:
+        ``__init__``, ``forward``.
+ """
+
+ def __init__(
+ self,
+ obs_shape: SequenceType,
+ hidden_size_list: SequenceType = [32, 64, 64, 128],
+ activation: Optional[nn.Module] = nn.ReLU(),
+ kernel_size: SequenceType = [8, 4, 3],
+ stride: SequenceType = [4, 2, 1],
+ padding: Optional[SequenceType] = None,
+ ) -> None:
+ """
+ Overview:
+ Init the ``BFSConvolution Encoder`` according to the provided arguments.
+ Arguments:
+ - obs_shape (:obj:`SequenceType`): Sequence of ``in_channel``, plus one or more ``input size``.
+ - hidden_size_list (:obj:`SequenceType`): Sequence of ``hidden_size`` of subsequent conv layers \
+ and the final dense layer.
+ - activation (:obj:`nn.Module`): Type of activation to use in the conv ``layers`` and ``ResBlock``. \
+ Default is ``nn.ReLU()``.
+ - kernel_size (:obj:`SequenceType`): Sequence of ``kernel_size`` of subsequent conv layers.
+ - stride (:obj:`SequenceType`): Sequence of ``stride`` of subsequent conv layers.
+ - padding (:obj:`SequenceType`): Padding added to all four sides of the input for each conv layer. \
+ See ``nn.Conv2d`` for more details. Default is ``None``.
+ """
+ super(BFSConvEncoder, self).__init__()
+ self.obs_shape = obs_shape
+ self.act = activation
+ self.hidden_size_list = hidden_size_list
+ if padding is None:
+ padding = [0 for _ in range(len(kernel_size))]
+
+ layers = []
+ input_size = obs_shape[0] # in_channel
+ for i in range(len(kernel_size)):
+ layers.append(nn.Conv2d(input_size, hidden_size_list[i], kernel_size[i], stride[i], padding[i]))
+ layers.append(self.act)
+ input_size = hidden_size_list[i]
+ layers = layers[:-1]
+ self.main = nn.Sequential(*layers)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Return output tensor of the env observation.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Env raw observation.
+ Returns:
+ - outputs (:obj:`torch.Tensor`): Output embedding tensor.
+ Examples:
+ >>> model = BFSConvEncoder([3, 16, 16], [32, 32, 4], kernel_size=[3, 3, 3], stride=[1, 1, 1]\
+ , padding=[1, 1, 1])
+ >>> inputs = torch.randn(3, 16, 16).unsqueeze(0)
+ >>> outputs = model(inputs)
+            >>> assert outputs.shape == torch.Size([1, 4, 16, 16])
+ """
+ return self.main(x)
+
+
+@MODEL_REGISTRY.register('pc_bfs')
+class ProcedureCloningBFS(nn.Module):
+ """
+ Overview:
+        The neural network introduced in procedure cloning (PC) to process 3-dim observations. \
+        Given an input, this model will perform several 3x3 convolutions and output a feature map with \
+        the same height and width as the input. The channel number of the output will be ``action_shape + 1``.
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(
+ self,
+ obs_shape: SequenceType,
+ action_shape: int,
+ encoder_hidden_size_list: SequenceType = [128, 128, 256, 256],
+ ):
+ """
+ Overview:
+            Init the ``ProcedureCloningBFS`` model according to the provided arguments.
+ Arguments:
+ - obs_shape (:obj:`SequenceType`): Sequence of ``in_channel``, plus one or more ``input size``,\
+ such as [4, 84, 84].
+            - action_shape (:obj:`int`): Action space shape, such as 6.
+            - encoder_hidden_size_list (:obj:`SequenceType`): The cnn channel dims for each block, \
+                such as [128, 128, 256, 256].
+ """
+ super().__init__()
+ num_layers = len(encoder_hidden_size_list)
+
+ kernel_sizes = (3, ) * (num_layers + 1)
+ stride_sizes = (1, ) * (num_layers + 1)
+ padding_sizes = (1, ) * (num_layers + 1)
+        # The output channel equals action_shape + 1; copy the list to avoid mutating the mutable default argument.
+        encoder_hidden_size_list = list(encoder_hidden_size_list) + [action_shape + 1]
+
+ self._encoder = BFSConvEncoder(
+ obs_shape=obs_shape,
+ hidden_size_list=encoder_hidden_size_list,
+ kernel_size=kernel_sizes,
+ stride=stride_sizes,
+ padding=padding_sizes,
+ )
+
+ def forward(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+            The computation graph. Given a 3-dim observation, this function will return a tensor with the same \
+            height and width. The channel number of the output will be ``action_shape + 1``.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input observation tensor data.
+ Returns:
+ - outputs (:obj:`Dict`): The output dict of model's forward computation graph, \
+ only contains a single key ``logit``.
+ Examples:
+ >>> model = ProcedureCloningBFS([3, 16, 16], 4)
+ >>> inputs = torch.randn(16, 16, 3).unsqueeze(0)
+ >>> outputs = model(inputs)
+            >>> assert outputs['logit'].shape == torch.Size([1, 16, 16, 5])
+ """
+ x = x.permute(0, 3, 1, 2)
+ x = self._encoder(x)
+ return {'logit': x.permute(0, 2, 3, 1)}
diff --git a/DI-engine/ding/model/template/q_learning.py b/DI-engine/ding/model/template/q_learning.py
new file mode 100644
index 0000000000000000000000000000000000000000..ece076bd81c4c0b95de854300c9d2417ecd5150c
--- /dev/null
+++ b/DI-engine/ding/model/template/q_learning.py
@@ -0,0 +1,1201 @@
+from typing import Union, Optional, Dict, Callable, List
+import torch
+import torch.nn as nn
+
+from ding.torch_utils import get_lstm
+from ding.utils import MODEL_REGISTRY, SequenceType, squeeze
+from ..common import FCEncoder, ConvEncoder, DiscreteHead, DuelingHead, MultiHead, RainbowHead, \
+ QuantileHead, FQFHead, QRDQNHead, DistributionHead, BranchingHead
+from ding.torch_utils.network.gtrxl import GTrXL
+
+
+@MODEL_REGISTRY.register('dqn')
+class DQN(nn.Module):
+ """
+ Overview:
+        The neural network structure and computation graph of Deep Q Network (DQN) algorithm, which is the most \
+        classic value-based RL algorithm for discrete action spaces. The DQN is composed of two parts: ``encoder`` \
+        and ``head``. The ``encoder`` is used to extract features from various observations, and the ``head`` is \
+        used to compute the Q value of each action dimension.
+ Interfaces:
+ ``__init__``, ``forward``.
+
+ .. note::
+ Current ``DQN`` supports two types of encoder: ``FCEncoder`` and ``ConvEncoder``, two types of head: \
+ ``DiscreteHead`` and ``DuelingHead``. You can customize your own encoder or head by inheriting this class.
+ """
+
+ def __init__(
+ self,
+ obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType],
+ encoder_hidden_size_list: SequenceType = [128, 128, 64],
+ dueling: bool = True,
+ head_hidden_size: Optional[int] = None,
+ head_layer_num: int = 1,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ dropout: Optional[float] = None
+ ) -> None:
+ """
+ Overview:
+            Initialize the DQN (encoder + head) Model according to corresponding input arguments.
+ Arguments:
+ - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84].
+ - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3].
+ - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
+ the last element must match ``head_hidden_size``.
+ - dueling (:obj:`Optional[bool]`): Whether choose ``DuelingHead`` or ``DiscreteHead (default)``.
+ - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network, defaults to None, \
+ then it will be set to the last element of ``encoder_hidden_size_list``.
+ - head_layer_num (:obj:`int`): The number of layers used in the head network to compute Q value output.
+ - activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \
+ if ``None`` then default set it to ``nn.ReLU()``.
+ - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
+                ``ding.torch_utils.fc_block`` for more details. You can choose one of ['BN', 'IN', 'SyncBN', 'LN'].
+ - dropout (:obj:`Optional[float]`): The dropout rate of the dropout layer. \
+ if ``None`` then default disable dropout layer.
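+        Examples:
+            >>> # An illustrative sketch: 1-dim obs uses ``FCEncoder``, 3-dim image obs uses ``ConvEncoder``.
+            >>> vector_model = DQN(obs_shape=32, action_shape=6)
+            >>> image_model = DQN(obs_shape=[4, 84, 84], action_shape=6)
+            >>> assert vector_model(torch.randn(4, 32))['logit'].shape == torch.Size([4, 6])
+            >>> assert image_model(torch.randn(4, 4, 84, 84))['logit'].shape == torch.Size([4, 6])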
+ """
+ super(DQN, self).__init__()
+ # Squeeze data from tuple, list or dict to single object. For example, from (4, ) to 4
+ obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape)
+ if head_hidden_size is None:
+ head_hidden_size = encoder_hidden_size_list[-1]
+ # FC Encoder
+ if isinstance(obs_shape, int) or len(obs_shape) == 1:
+ self.encoder = FCEncoder(
+ obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type, dropout=dropout
+ )
+ # Conv Encoder
+ elif len(obs_shape) == 3:
+ assert dropout is None, "dropout is not supported in ConvEncoder"
+ self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
+ else:
+ raise RuntimeError(
+ "not support obs_shape for pre-defined encoder: {}, please customize your own DQN".format(obs_shape)
+ )
+ # Head Type
+ if dueling:
+ head_cls = DuelingHead
+ else:
+ head_cls = DiscreteHead
+ multi_head = not isinstance(action_shape, int)
+ if multi_head:
+ self.head = MultiHead(
+ head_cls,
+ head_hidden_size,
+ action_shape,
+ layer_num=head_layer_num,
+ activation=activation,
+ norm_type=norm_type,
+ dropout=dropout
+ )
+ else:
+ self.head = head_cls(
+ head_hidden_size,
+ action_shape,
+ head_layer_num,
+ activation=activation,
+ norm_type=norm_type,
+ dropout=dropout
+ )
+
+ def forward(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+ DQN forward computation graph, input observation tensor to predict q_value.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input observation tensor data.
+ Returns:
+ - outputs (:obj:`Dict`): The output of DQN's forward, including q_value.
+ ReturnsKeys:
+ - logit (:obj:`torch.Tensor`): Discrete Q-value output of each possible action dimension.
+ Shapes:
+ - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``
+ - logit (:obj:`torch.Tensor`): :math:`(B, M)`, where B is batch size and M is ``action_shape``
+ Examples:
+ >>> model = DQN(32, 6) # arguments: 'obs_shape' and 'action_shape'
+ >>> inputs = torch.randn(4, 32)
+ >>> outputs = model(inputs)
+ >>> assert isinstance(outputs, dict) and outputs['logit'].shape == torch.Size([4, 6])
+
+ .. note::
+ For consistency and compatibility, we name all the outputs of the network which are related to action \
+ selections as ``logit``.
+ """
+ x = self.encoder(x)
+ x = self.head(x)
+ return x
+
+
+@MODEL_REGISTRY.register('bdq')
+class BDQ(nn.Module):
+
+ def __init__(
+ self,
+ obs_shape: Union[int, SequenceType],
+ num_branches: int = 0,
+ action_bins_per_branch: int = 2,
+ layer_num: int = 3,
+ a_layer_num: Optional[int] = None,
+ v_layer_num: Optional[int] = None,
+ encoder_hidden_size_list: SequenceType = [128, 128, 64],
+ head_hidden_size: Optional[int] = None,
+            norm_type: Optional[str] = None,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ ) -> None:
+ """
+ Overview:
+            Initialize the BDQ (encoder + head) Model according to the input arguments, \
+            referenced paper: Action Branching Architectures for Deep Reinforcement Learning.
+
+ Arguments:
+ - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84].
+ - num_branches (:obj:`int`): The number of branches, which is equivalent to the action dimension, \
+ such as 6 in mujoco's halfcheetah environment.
+ - action_bins_per_branch (:obj:`int`): The number of actions in each dimension.
+ - layer_num (:obj:`int`): The number of layers used in the network to compute Advantage and Value output.
+ - a_layer_num (:obj:`int`): The number of layers used in the network to compute Advantage output.
+ - v_layer_num (:obj:`int`): The number of layers used in the network to compute Value output.
+ - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
+ the last element must match ``head_hidden_size``.
+ - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network.
+ - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
+ ``ding.torch_utils.fc_block`` for more details.
+ - activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \
+ if ``None`` then default set it to ``nn.ReLU()``
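+        Examples:
+            >>> # A minimal sketch, assuming a 5-dim action space discretized into 2 bins per dimension.
+            >>> model = BDQ(obs_shape=8, num_branches=5, action_bins_per_branch=2)
+            >>> logit = model(torch.randn(4, 8))['logit']  # (B, num_branches, action_bins_per_branch)
+            >>> action = logit.argmax(dim=-1)  # one discrete bin index per action dimension
+            >>> assert action.shape == torch.Size([4, 5])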
+ """
+ super(BDQ, self).__init__()
+ # For compatibility: 1, (1, ), [4, 32, 32]
+ obs_shape, num_branches = squeeze(obs_shape), squeeze(num_branches)
+ if head_hidden_size is None:
+ head_hidden_size = encoder_hidden_size_list[-1]
+
+ # backbone
+ # FC Encoder
+ if isinstance(obs_shape, int) or len(obs_shape) == 1:
+ self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
+ # Conv Encoder
+ elif len(obs_shape) == 3:
+ self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
+ else:
+ raise RuntimeError(
+ "not support obs_shape for pre-defined encoder: {}, please customize your own DQN".format(obs_shape)
+ )
+
+ self.num_branches = num_branches
+ self.action_bins_per_branch = action_bins_per_branch
+
+ # head
+ self.head = BranchingHead(
+ head_hidden_size,
+ num_branches=self.num_branches,
+ action_bins_per_branch=self.action_bins_per_branch,
+ layer_num=layer_num,
+ a_layer_num=a_layer_num,
+ v_layer_num=v_layer_num,
+ activation=activation,
+ norm_type=norm_type
+ )
+
+ def forward(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+ BDQ forward computation graph, input observation tensor to predict q_value.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Observation inputs
+ Returns:
+ - outputs (:obj:`Dict`): BDQ forward outputs, such as q_value.
+ ReturnsKeys:
+ - logit (:obj:`torch.Tensor`): Discrete Q-value output of each action dimension.
+ Shapes:
+ - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``
+            - logit (:obj:`torch.FloatTensor`): :math:`(B, D, A)`, where B is batch size, D is ``num_branches`` \
+                and A is ``action_bins_per_branch``.
+ Examples:
+ >>> model = BDQ(8, 5, 2) # arguments: 'obs_shape', 'num_branches' and 'action_bins_per_branch'.
+ >>> inputs = torch.randn(4, 8)
+ >>> outputs = model(inputs)
+ >>> assert isinstance(outputs, dict) and outputs['logit'].shape == torch.Size([4, 5, 2])
+ """
+ x = self.encoder(x) / (self.num_branches + 1) # corresponds to the "Gradient Rescaling" in the paper
+ x = self.head(x)
+ return x
+
+
+@MODEL_REGISTRY.register('c51dqn')
+class C51DQN(nn.Module):
+ """
+ Overview:
+ The neural network structure and computation graph of C51DQN, which combines distributional RL and DQN. \
+ You can refer to https://arxiv.org/pdf/1707.06887.pdf for more details. The C51DQN is composed of \
+ ``encoder`` and ``head``. ``encoder`` is used to extract the feature of observation, and ``head`` is \
+ used to compute the distribution of Q-value.
+ Interfaces:
+ ``__init__``, ``forward``
+
+ .. note::
+ Current C51DQN supports two types of encoder: ``FCEncoder`` and ``ConvEncoder``.
+ """
+
+ def __init__(
+ self,
+ obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType],
+ encoder_hidden_size_list: SequenceType = [128, 128, 64],
+ head_hidden_size: int = None,
+ head_layer_num: int = 1,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ v_min: Optional[float] = -10,
+ v_max: Optional[float] = 10,
+ n_atom: Optional[int] = 51,
+ ) -> None:
+ """
+ Overview:
+            Initialize the C51 Model according to corresponding input arguments.
+ Arguments:
+ - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84].
+ - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3].
+ - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
+ the last element must match ``head_hidden_size``.
+ - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network, defaults to None, \
+ then it will be set to the last element of ``encoder_hidden_size_list``.
+ - head_layer_num (:obj:`int`): The number of layers used in the head network to compute Q value output.
+ - activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \
+ if ``None`` then default set it to ``nn.ReLU()``.
+ - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
+                ``ding.torch_utils.fc_block`` for more details. You can choose one of ['BN', 'IN', 'SyncBN', 'LN'].
+ - v_min (:obj:`Optional[float]`): The minimum value of the support of the distribution, which is related \
+ to the value (discounted sum of reward) scale of the specific environment. Defaults to -10.
+ - v_max (:obj:`Optional[float]`): The maximum value of the support of the distribution, which is related \
+ to the value (discounted sum of reward) scale of the specific environment. Defaults to 10.
+ - n_atom (:obj:`Optional[int]`): The number of atoms in the prediction distribution, 51 is the default \
+ value in the paper, you can also try other values such as 301.
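+        Examples:
+            >>> # An illustrative sketch of the value support implied by ``v_min``, ``v_max`` and ``n_atom``:
+            >>> # with the defaults, 51 atoms uniformly spaced on [-10, 10], so the middle atom sits at 0.
+            >>> support = torch.linspace(-10, 10, 51)
+            >>> assert support.shape == torch.Size([51]) and abs(support[25].item()) < 1e-6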
+ """
+ super(C51DQN, self).__init__()
+ # For compatibility: 1, (1, ), [4, 32, 32]
+ obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape)
+ if head_hidden_size is None:
+ head_hidden_size = encoder_hidden_size_list[-1]
+ # FC Encoder
+ if isinstance(obs_shape, int) or len(obs_shape) == 1:
+ self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
+ # Conv Encoder
+ elif len(obs_shape) == 3:
+ self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
+ else:
+ raise RuntimeError(
+ "not support obs_shape for pre-defined encoder: {}, please customize your own C51DQN".format(obs_shape)
+ )
+ # Head Type
+ multi_head = not isinstance(action_shape, int)
+ if multi_head:
+ self.head = MultiHead(
+ DistributionHead,
+ head_hidden_size,
+ action_shape,
+ layer_num=head_layer_num,
+ activation=activation,
+ norm_type=norm_type,
+ n_atom=n_atom,
+ v_min=v_min,
+ v_max=v_max,
+ )
+ else:
+ self.head = DistributionHead(
+ head_hidden_size,
+ action_shape,
+ head_layer_num,
+ activation=activation,
+ norm_type=norm_type,
+ n_atom=n_atom,
+ v_min=v_min,
+ v_max=v_max,
+ )
+
+ def forward(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+ C51DQN forward computation graph, input observation tensor to predict q_value and its distribution.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input observation tensor data.
+ Returns:
+            - outputs (:obj:`Dict`): The output of C51DQN's forward, including q_value and distribution.
+ ReturnsKeys:
+ - logit (:obj:`torch.Tensor`): Discrete Q-value output of each possible action dimension.
+ - distribution (:obj:`torch.Tensor`): Q-Value discretized distribution, i.e., probability of each \
+ uniformly spaced atom Q-value, such as dividing [-10, 10] into 51 uniform spaces.
+ Shapes:
+            - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``.
+ - logit (:obj:`torch.Tensor`): :math:`(B, M)`, where M is action_shape.
+ - distribution(:obj:`torch.Tensor`): :math:`(B, M, P)`, where P is n_atom.
+ Examples:
+ >>> model = C51DQN(128, 64) # arguments: 'obs_shape' and 'action_shape'
+ >>> inputs = torch.randn(4, 128)
+ >>> outputs = model(inputs)
+ >>> assert isinstance(outputs, dict)
+            >>> # action_shape is 64 in this example
+ >>> assert outputs['logit'].shape == torch.Size([4, 64])
+ >>> # default n_atom: int = 51
+ >>> assert outputs['distribution'].shape == torch.Size([4, 64, 51])
+
+ .. note::
+ For consistency and compatibility, we name all the outputs of the network which are related to action \
+ selections as ``logit``.
+
+ .. note::
+            For convenience, we recommend that the number of atoms is odd, so that the middle atom lies exactly at \
+            the center of the value support ``[v_min, v_max]``.
+ """
+ x = self.encoder(x)
+ x = self.head(x)
+ return x
+
+
+@MODEL_REGISTRY.register('qrdqn')
+class QRDQN(nn.Module):
+ """
+ Overview:
+ The neural network structure and computation graph of QRDQN, which combines distributional RL and DQN. \
+ You can refer to Distributional Reinforcement Learning with Quantile Regression \
+ https://arxiv.org/pdf/1710.10044.pdf for more details.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(
+ self,
+ obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType],
+ encoder_hidden_size_list: SequenceType = [128, 128, 64],
+ head_hidden_size: Optional[int] = None,
+ head_layer_num: int = 1,
+ num_quantiles: int = 32,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ ) -> None:
+ """
+ Overview:
+ Initialize the QRDQN Model according to input arguments.
+ Arguments:
+ - obs_shape (:obj:`Union[int, SequenceType]`): Observation's space.
+ - action_shape (:obj:`Union[int, SequenceType]`): Action's space.
+ - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``
+ - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to ``Head``.
+ - head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output
+ - num_quantiles (:obj:`int`): Number of quantiles in the prediction distribution.
+ - activation (:obj:`Optional[nn.Module]`):
+                The type of activation function to use in ``MLP`` after each ``layer_fn``,
+ if ``None`` then default set to ``nn.ReLU()``
+ - norm_type (:obj:`Optional[str]`):
+                The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details.
+ """
+ super(QRDQN, self).__init__()
+ # For compatibility: 1, (1, ), [4, 32, 32]
+ obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape)
+ if head_hidden_size is None:
+ head_hidden_size = encoder_hidden_size_list[-1]
+ # FC Encoder
+ if isinstance(obs_shape, int) or len(obs_shape) == 1:
+ self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
+ # Conv Encoder
+ elif len(obs_shape) == 3:
+ self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
+ else:
+ raise RuntimeError(
+ "not support obs_shape for pre-defined encoder: {}, please customize your own QRDQN".format(obs_shape)
+ )
+ # Head Type
+ multi_head = not isinstance(action_shape, int)
+ if multi_head:
+ self.head = MultiHead(
+ QRDQNHead,
+ head_hidden_size,
+ action_shape,
+ layer_num=head_layer_num,
+ num_quantiles=num_quantiles,
+ activation=activation,
+ norm_type=norm_type,
+ )
+ else:
+ self.head = QRDQNHead(
+ head_hidden_size,
+ action_shape,
+ head_layer_num,
+ num_quantiles=num_quantiles,
+ activation=activation,
+ norm_type=norm_type,
+ )
+
+ def forward(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+            Use observation tensor to predict QRDQN's output, i.e., the Q-value quantiles of each action.
+        Arguments:
+            - x (:obj:`torch.Tensor`): The input observation tensor with shape ``(B, N=obs_shape)``.
+ Returns:
+ - outputs (:obj:`Dict`):
+ Run with encoder and head. Return the result prediction dictionary.
+ ReturnsKeys:
+            - logit (:obj:`torch.Tensor`): Discrete Q-value logit of each action dimension.
+            - q (:obj:`torch.Tensor`): Q-value tensor of size ``(B, M, num_quantiles)``.
+            - tau (:obj:`torch.Tensor`): tau tensor of size ``(B, num_quantiles, 1)``.
+        Shapes:
+            - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``.
+            - logit (:obj:`torch.FloatTensor`): :math:`(B, M)`, where M is ``action_shape``.
+            - tau (:obj:`torch.Tensor`): :math:`(B, P, 1)`, where P is ``num_quantiles``.
+ Examples:
+ >>> model = QRDQN(64, 64)
+ >>> inputs = torch.randn(4, 64)
+ >>> outputs = model(inputs)
+ >>> assert isinstance(outputs, dict)
+ >>> assert outputs['logit'].shape == torch.Size([4, 64])
+ >>> # default num_quantiles : int = 32
+ >>> assert outputs['q'].shape == torch.Size([4, 64, 32])
+ >>> assert outputs['tau'].shape == torch.Size([4, 32, 1])
+ """
+ x = self.encoder(x)
+ x = self.head(x)
+ return x
+
+
+@MODEL_REGISTRY.register('iqn')
+class IQN(nn.Module):
+ """
+ Overview:
+ The neural network structure and computation graph of IQN, which combines distributional RL and DQN. \
+ You can refer to paper Implicit Quantile Networks for Distributional Reinforcement Learning \
+ https://arxiv.org/pdf/1806.06923.pdf for more details.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(
+ self,
+ obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType],
+ encoder_hidden_size_list: SequenceType = [128, 128, 64],
+ head_hidden_size: Optional[int] = None,
+ head_layer_num: int = 1,
+ num_quantiles: int = 32,
+ quantile_embedding_size: int = 128,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None
+ ) -> None:
+ """
+ Overview:
+ Initialize the IQN Model according to input arguments.
+ Arguments:
+ - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape.
+ - action_shape (:obj:`Union[int, SequenceType]`): Action space shape.
+ - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``
+ - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to ``Head``.
+ - head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output
+            - num_quantiles (:obj:`int`): Number of quantiles in the prediction distribution.
+            - quantile_embedding_size (:obj:`int`): The embedding size of each quantile fraction.
+            - activation (:obj:`Optional[nn.Module]`):
+                The type of activation function to use in ``MLP`` after each ``layer_fn``,
+                if ``None`` then default set to ``nn.ReLU()``
+ - norm_type (:obj:`Optional[str]`):
+ The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details.
+ """
+ super(IQN, self).__init__()
+ # For compatibility: 1, (1, ), [4, 32, 32]
+ obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape)
+ if head_hidden_size is None:
+ head_hidden_size = encoder_hidden_size_list[-1]
+ # FC Encoder
+ if isinstance(obs_shape, int) or len(obs_shape) == 1:
+ self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
+ # Conv Encoder
+ elif len(obs_shape) == 3:
+ self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
+ else:
+ raise RuntimeError(
+ "not support obs_shape for pre-defined encoder: {}, please customize your own IQN".format(obs_shape)
+ )
+ # Head Type
+ head_cls = QuantileHead
+ multi_head = not isinstance(action_shape, int)
+ if multi_head:
+ self.head = MultiHead(
+ head_cls,
+ head_hidden_size,
+ action_shape,
+ layer_num=head_layer_num,
+ num_quantiles=num_quantiles,
+ quantile_embedding_size=quantile_embedding_size,
+ activation=activation,
+ norm_type=norm_type
+ )
+ else:
+ self.head = head_cls(
+ head_hidden_size,
+ action_shape,
+ head_layer_num,
+ activation=activation,
+ norm_type=norm_type,
+ num_quantiles=num_quantiles,
+ quantile_embedding_size=quantile_embedding_size,
+ )
+
+ def forward(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+            Use observation tensor to predict IQN's output, i.e., the Q-value quantiles of each action.
+        Arguments:
+            - x (:obj:`torch.Tensor`): The input observation tensor with shape ``(B, N=obs_shape)``.
+ Returns:
+ - outputs (:obj:`Dict`):
+ Run with encoder and head. Return the result prediction dictionary.
+ ReturnsKeys:
+            - logit (:obj:`torch.Tensor`): Discrete Q-value logit of each action dimension.
+            - q (:obj:`torch.Tensor`): Q-value tensor of size ``(num_quantiles, B, M)``.
+            - quantiles (:obj:`torch.Tensor`): quantiles tensor of size ``(quantile_embedding_size, 1)``
+        Shapes:
+            - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``.
+            - logit (:obj:`torch.FloatTensor`): :math:`(B, M)`, where M is ``action_shape``.
+            - quantiles (:obj:`torch.Tensor`): :math:`(P, 1)`, where P is quantile_embedding_size.
+ Examples:
+ >>> model = IQN(64, 64) # arguments: 'obs_shape' and 'action_shape'
+ >>> inputs = torch.randn(4, 64)
+ >>> outputs = model(inputs)
+ >>> assert isinstance(outputs, dict)
+ >>> assert outputs['logit'].shape == torch.Size([4, 64])
+ >>> # default num_quantiles: int = 32
+            >>> assert outputs['q'].shape == torch.Size([32, 4, 64])
+ >>> # default quantile_embedding_size: int = 128
+ >>> assert outputs['quantiles'].shape == torch.Size([128, 1])
+ """
+ x = self.encoder(x)
+ x = self.head(x)
+ return x
+
+
+@MODEL_REGISTRY.register('fqf')
+class FQF(nn.Module):
+ """
+ Overview:
+ The neural network structure and computation graph of FQF, which combines distributional RL and DQN. \
+ You can refer to paper Fully Parameterized Quantile Function for Distributional Reinforcement Learning \
+ https://arxiv.org/pdf/1911.02140.pdf for more details.
+ Interface:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(
+ self,
+ obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType],
+ encoder_hidden_size_list: SequenceType = [128, 128, 64],
+ head_hidden_size: Optional[int] = None,
+ head_layer_num: int = 1,
+ num_quantiles: int = 32,
+ quantile_embedding_size: int = 128,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None
+ ) -> None:
+ """
+ Overview:
+ Initialize the FQF Model according to input arguments.
+ Arguments:
+ - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape.
+ - action_shape (:obj:`Union[int, SequenceType]`): Action space shape.
+ - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``
+ - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to ``Head``.
+ - head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output
+            - num_quantiles (:obj:`int`): Number of quantiles in the prediction distribution.
+            - quantile_embedding_size (:obj:`int`): The embedding size of each quantile fraction.
+            - activation (:obj:`Optional[nn.Module]`):
+                The type of activation function to use in ``MLP`` after each ``layer_fn``,
+                if ``None`` then default set to ``nn.ReLU()``
+ - norm_type (:obj:`Optional[str]`):
+ The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details.
+ """
+ super(FQF, self).__init__()
+ # For compatibility: 1, (1, ), [4, 32, 32]
+ obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape)
+ if head_hidden_size is None:
+ head_hidden_size = encoder_hidden_size_list[-1]
+ # FC Encoder
+ if isinstance(obs_shape, int) or len(obs_shape) == 1:
+ self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
+ # Conv Encoder
+ elif len(obs_shape) == 3:
+ self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
+ else:
+ raise RuntimeError(
+ "not support obs_shape for pre-defined encoder: {}, please customize your own FQF".format(obs_shape)
+ )
+ # Head Type
+ head_cls = FQFHead
+ multi_head = not isinstance(action_shape, int)
+ if multi_head:
+ self.head = MultiHead(
+ head_cls,
+ head_hidden_size,
+ action_shape,
+ layer_num=head_layer_num,
+ num_quantiles=num_quantiles,
+ quantile_embedding_size=quantile_embedding_size,
+ activation=activation,
+ norm_type=norm_type
+ )
+ else:
+ self.head = head_cls(
+ head_hidden_size,
+ action_shape,
+ head_layer_num,
+ activation=activation,
+ norm_type=norm_type,
+ num_quantiles=num_quantiles,
+ quantile_embedding_size=quantile_embedding_size,
+ )
+
+ def forward(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+            Use observation tensor to predict FQF's output, i.e., the parameterized Q-value quantiles of each action.
+        Arguments:
+            - x (:obj:`torch.Tensor`): The input observation tensor with shape ``(B, N=obs_shape)``.
+ Returns:
+ - outputs (:obj:`Dict`): Dict containing keywords ``logit`` (:obj:`torch.Tensor`), \
+ ``q`` (:obj:`torch.Tensor`), ``quantiles`` (:obj:`torch.Tensor`), \
+ ``quantiles_hats`` (:obj:`torch.Tensor`), \
+ ``q_tau_i`` (:obj:`torch.Tensor`), ``entropies`` (:obj:`torch.Tensor`).
+ Shapes:
+            - x: :math:`(B, N)`, where B is batch size and N is ``obs_shape``.
+ - logit: :math:`(B, M)`, where M is action_shape.
+ - q: :math:`(B, num_quantiles, M)`.
+ - quantiles: :math:`(B, num_quantiles + 1)`.
+ - quantiles_hats: :math:`(B, num_quantiles)`.
+ - q_tau_i: :math:`(B, num_quantiles - 1, M)`.
+ - entropies: :math:`(B, 1)`.
+ Examples:
+ >>> model = FQF(64, 64) # arguments: 'obs_shape' and 'action_shape'
+ >>> inputs = torch.randn(4, 64)
+ >>> outputs = model(inputs)
+ >>> assert isinstance(outputs, dict)
+ >>> assert outputs['logit'].shape == torch.Size([4, 64])
+ >>> # default num_quantiles: int = 32
+ >>> assert outputs['q'].shape == torch.Size([4, 32, 64])
+ >>> assert outputs['quantiles'].shape == torch.Size([4, 33])
+ >>> assert outputs['quantiles_hats'].shape == torch.Size([4, 32])
+ >>> assert outputs['q_tau_i'].shape == torch.Size([4, 31, 64])
+            >>> assert outputs['entropies'].shape == torch.Size([4, 1])
+ """
+ x = self.encoder(x)
+ x = self.head(x)
+ return x
+
+
+@MODEL_REGISTRY.register('rainbowdqn')
+class RainbowDQN(nn.Module):
+ """
+ Overview:
+ The neural network structure and computation graph of RainbowDQN, which combines distributional RL and DQN. \
+ You can refer to paper Rainbow: Combining Improvements in Deep Reinforcement Learning \
+ https://arxiv.org/pdf/1710.02298.pdf for more details.
+ Interfaces:
+ ``__init__``, ``forward``
+
+ .. note::
+ RainbowDQN contains dueling architecture by default.
+ """
+
+ def __init__(
+ self,
+ obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType],
+ encoder_hidden_size_list: SequenceType = [128, 128, 64],
+ head_hidden_size: Optional[int] = None,
+ head_layer_num: int = 1,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ v_min: Optional[float] = -10,
+ v_max: Optional[float] = 10,
+ n_atom: Optional[int] = 51,
+ ) -> None:
+ """
+ Overview:
+ Init the Rainbow Model according to arguments.
+ Arguments:
+ - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape.
+ - action_shape (:obj:`Union[int, SequenceType]`): Action space shape.
+ - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``
+ - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to ``Head``.
+ - head_layer_num (:obj:`int`): The num of layers used in the network to compute Q value output
+            - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` after \
+                each ``layer_fn``, if ``None`` then default set to ``nn.ReLU()``.
+            - norm_type (:obj:`Optional[str]`): The type of normalization to use, see ``ding.torch_utils.fc_block`` \
+                for more details.
+            - v_min (:obj:`Optional[float]`): The minimum value of the support of the distribution, defaults to -10.
+            - v_max (:obj:`Optional[float]`): The maximum value of the support of the distribution, defaults to 10.
+            - n_atom (:obj:`Optional[int]`): Number of atoms in the prediction distribution.
+ """
+ super(RainbowDQN, self).__init__()
+ # For compatibility: 1, (1, ), [4, 32, 32]
+ obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape)
+ if head_hidden_size is None:
+ head_hidden_size = encoder_hidden_size_list[-1]
+ # FC Encoder
+ if isinstance(obs_shape, int) or len(obs_shape) == 1:
+ self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
+ # Conv Encoder
+ elif len(obs_shape) == 3:
+ self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
+ else:
+ raise RuntimeError(
+ "not support obs_shape for pre-defined encoder: {}, please customize your own RainbowDQN".
+ format(obs_shape)
+ )
+ # Head Type
+ multi_head = not isinstance(action_shape, int)
+ if multi_head:
+ self.head = MultiHead(
+ RainbowHead,
+ head_hidden_size,
+ action_shape,
+ layer_num=head_layer_num,
+ activation=activation,
+ norm_type=norm_type,
+ n_atom=n_atom,
+ v_min=v_min,
+ v_max=v_max,
+ )
+ else:
+ self.head = RainbowHead(
+ head_hidden_size,
+ action_shape,
+ head_layer_num,
+ activation=activation,
+ norm_type=norm_type,
+ n_atom=n_atom,
+ v_min=v_min,
+ v_max=v_max,
+ )
+
+ def forward(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+            Use observation tensor to predict Rainbow's output, i.e., the distributional Q-value of each action.
+        Arguments:
+            - x (:obj:`torch.Tensor`): The input observation tensor with shape ``(B, N=obs_shape)``.
+ Returns:
+ - outputs (:obj:`Dict`):
+ Run ``MLP`` with ``RainbowHead`` setups and return the result prediction dictionary.
+ ReturnsKeys:
+            - logit (:obj:`torch.Tensor`): Discrete Q-value logit of each action dimension.
+            - distribution (:obj:`torch.Tensor`): Distribution tensor of size ``(B, M, n_atom)``.
+        Shapes:
+            - x (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``.
+            - logit (:obj:`torch.FloatTensor`): :math:`(B, M)`, where M is ``action_shape``.
+            - distribution (:obj:`torch.FloatTensor`): :math:`(B, M, P)`, where P is ``n_atom``.
+ Examples:
+ >>> model = RainbowDQN(64, 64) # arguments: 'obs_shape' and 'action_shape'
+ >>> inputs = torch.randn(4, 64)
+ >>> outputs = model(inputs)
+ >>> assert isinstance(outputs, dict)
+ >>> assert outputs['logit'].shape == torch.Size([4, 64])
+            >>> # default n_atom: int = 51
+ >>> assert outputs['distribution'].shape == torch.Size([4, 64, 51])
+ """
+ x = self.encoder(x)
+ x = self.head(x)
+ return x
+
+
+def parallel_wrapper(forward_fn: Callable) -> Callable:
+ """
+ Overview:
+ Process timestep T and batch_size B at the same time, in other words, treat different timestep data as
+ different trajectories in a batch.
+ Arguments:
+ - forward_fn (:obj:`Callable`): Normal ``nn.Module`` 's forward function.
+ Returns:
+ - wrapper (:obj:`Callable`): Wrapped function.
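+    Examples:
+        >>> # A minimal sketch using a plain ``nn.Linear`` as a stand-in for an encoder or head.
+        >>> layer = nn.Linear(8, 4)
+        >>> wrapped = parallel_wrapper(layer)
+        >>> x = torch.randn(16, 5, 8)  # (T, B, N)
+        >>> assert wrapped(x).shape == torch.Size([16, 5, 4])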
+ """
+
+ def wrapper(x: torch.Tensor) -> Union[torch.Tensor, List[torch.Tensor]]:
+ T, B = x.shape[:2]
+
+ def reshape(d):
+ if isinstance(d, list):
+ d = [reshape(t) for t in d]
+ elif isinstance(d, dict):
+ d = {k: reshape(v) for k, v in d.items()}
+ else:
+ d = d.reshape(T, B, *d.shape[1:])
+ return d
+
+ # NOTE(rjy): the initial input shape will be (T, B, N),
+    # means the encoder or head should process B trajectories, each trajectory has T timesteps,
+ # but T and B dimension can be both treated as batch_size in encoder and head,
+ # i.e., independent and parallel processing,
+ # so here we need such fn to reshape for encoder or head
+ x = x.reshape(T * B, *x.shape[2:])
+ x = forward_fn(x)
+ x = reshape(x)
+ return x
+
+ return wrapper
+
+
+@MODEL_REGISTRY.register('drqn')
+class DRQN(nn.Module):
+ """
+ Overview:
+ The neural network structure and computation graph of DRQN (DQN + RNN = DRQN) algorithm, which is the most \
+        common DQN variant for sequential data and partially observable environments. The DRQN is composed of three \
+ parts: ``encoder``, ``head`` and ``rnn``. The ``encoder`` is used to extract the feature from various \
+ observation, the ``rnn`` is used to process the sequential observation and other data, and the ``head`` is \
+ used to compute the Q value of each action dimension.
+ Interfaces:
+ ``__init__``, ``forward``.
+
+ .. note::
+ Current ``DRQN`` supports two types of encoder: ``FCEncoder`` and ``ConvEncoder``, two types of head: \
+ ``DiscreteHead`` and ``DuelingHead``, three types of rnn: ``normal (LSTM with LayerNorm)``, ``pytorch`` and \
+ ``gru``. You can customize your own encoder, rnn or head by inheriting this class.
+ """
+
+ def __init__(
+ self,
+ obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType],
+ encoder_hidden_size_list: SequenceType = [128, 128, 64],
+ dueling: bool = True,
+ head_hidden_size: Optional[int] = None,
+ head_layer_num: int = 1,
+ lstm_type: Optional[str] = 'normal',
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ res_link: bool = False
+ ) -> None:
+ """
+ Overview:
+ Initialize the DRQN Model according to the corresponding input arguments.
+ Arguments:
+ - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84].
+ - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3].
+ - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
+ the last element must match ``head_hidden_size``.
+ - dueling (:obj:`Optional[bool]`): Whether choose ``DuelingHead`` or ``DiscreteHead (default)``.
+ - head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of head network, defaults to None, \
+ then it will be set to the last element of ``encoder_hidden_size_list``.
+ - head_layer_num (:obj:`int`): The number of layers used in the head network to compute Q value output.
+ - lstm_type (:obj:`Optional[str]`): The type of RNN module, now support ['normal', 'pytorch', 'gru'].
+ - activation (:obj:`Optional[nn.Module]`): The type of activation function in networks \
+ if ``None`` then default set it to ``nn.ReLU()``.
+ - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
+ ``ding.torch_utils.fc_block`` for more details. you can choose one of ['BN', 'IN', 'SyncBN', 'LN']
+            - res_link (:obj:`bool`): Whether to enable the residual link, which is the skip connection between \
+ single frame data and the sequential data, defaults to False.
+ """
+ super(DRQN, self).__init__()
+ # For compatibility: 1, (1, ), [4, 32, 32]
+ obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape)
+ if head_hidden_size is None:
+ head_hidden_size = encoder_hidden_size_list[-1]
+ # FC Encoder
+ if isinstance(obs_shape, int) or len(obs_shape) == 1:
+ self.encoder = FCEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
+ # Conv Encoder
+ elif len(obs_shape) == 3:
+ self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
+ else:
+ raise RuntimeError(
+ "not support obs_shape for pre-defined encoder: {}, please customize your own DRQN".format(obs_shape)
+ )
+ # LSTM Type
+ self.rnn = get_lstm(lstm_type, input_size=head_hidden_size, hidden_size=head_hidden_size)
+ self.res_link = res_link
+ # Head Type
+ if dueling:
+ head_cls = DuelingHead
+ else:
+ head_cls = DiscreteHead
+ multi_head = not isinstance(action_shape, int)
+ if multi_head:
+ self.head = MultiHead(
+ head_cls,
+ head_hidden_size,
+ action_shape,
+ layer_num=head_layer_num,
+ activation=activation,
+ norm_type=norm_type
+ )
+ else:
+ self.head = head_cls(
+ head_hidden_size, action_shape, head_layer_num, activation=activation, norm_type=norm_type
+ )
+
+ def forward(self, inputs: Dict, inference: bool = False, saved_state_timesteps: Optional[list] = None) -> Dict:
+ """
+ Overview:
+ DRQN forward computation graph, input observation tensor to predict q_value.
+ Arguments:
+            - inputs (:obj:`Dict`): The dict of input data, including observation and previous rnn state.
+            - inference (:obj:`bool`): Whether to enable inference forward mode, if True, we unroll the one-timestep \
+                transition, otherwise, we unroll the entire sequence transitions.
+            - saved_state_timesteps (:obj:`Optional[list]`): When inference is False, we unroll the sequence \
+                transitions, then we would use this list to indicate how to save and return hidden state.
+ ArgumentsKeys:
+ - obs (:obj:`torch.Tensor`): The raw observation tensor.
+ - prev_state (:obj:`list`): The previous rnn state tensor, whose structure depends on ``lstm_type``.
+ Returns:
+ - outputs (:obj:`Dict`): The output of DRQN's forward, including logit (q_value) and next state.
+ ReturnsKeys:
+ - logit (:obj:`torch.Tensor`): Discrete Q-value output of each possible action dimension.
+ - next_state (:obj:`list`): The next rnn state tensor, whose structure depends on ``lstm_type``.
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``obs_shape``
+ - logit (:obj:`torch.Tensor`): :math:`(B, M)`, where B is batch size and M is ``action_shape``
+ Examples:
+ >>> # Init input's Keys:
+ >>> prev_state = [[torch.randn(1, 1, 64) for __ in range(2)] for _ in range(4)] # B=4
+            >>> obs = torch.randn(4, 64)
+            >>> model = DRQN(64, 64) # arguments: 'obs_shape' and 'action_shape'
+            >>> outputs = model({'obs': obs, 'prev_state': prev_state}, inference=True)
+ >>> # Check outputs's Keys
+ >>> assert isinstance(outputs, dict)
+ >>> assert outputs['logit'].shape == (4, 64)
+ >>> assert len(outputs['next_state']) == 4
+ >>> assert all([len(t) == 2 for t in outputs['next_state']])
+ >>> assert all([t[0].shape == (1, 1, 64) for t in outputs['next_state']])
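+            >>> # A further sketch of the training (sequence) mode, assuming T=8 timesteps and B=4.
+            >>> seq_obs = torch.randn(8, 4, 64)  # (T, B, N)
+            >>> outputs = model({'obs': seq_obs, 'prev_state': prev_state}, inference=False)
+            >>> assert outputs['logit'].shape == (8, 4, 64)
+            >>> assert len(outputs['next_state']) == 4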
+ """
+
+ x, prev_state = inputs['obs'], inputs['prev_state']
+ # for both inference and other cases, the network structure is encoder -> rnn network -> head
+        # the difference is that inference takes data with seq_len=1 (or T = 1)
+ # NOTE(rjy): in most situations, set inference=True when evaluate and inference=False when training
+ if inference:
+ x = self.encoder(x)
+ if self.res_link:
+ a = x
+ x = x.unsqueeze(0) # for rnn input, put the seq_len of x as 1 instead of none.
+ # prev_state: DataType: List[Tuple[torch.Tensor]]; Initially, it is a list of None
+ x, next_state = self.rnn(x, prev_state)
+ x = x.squeeze(0) # to delete the seq_len dim to match head network input
+ if self.res_link:
+ x = x + a
+ x = self.head(x)
+ x['next_state'] = next_state
+ return x
+ else:
+ # In order to better explain why rnn needs saved_state and which states need to be stored,
+ # let's take r2d2 as an example
+ # in r2d2,
+ # 1) data['burnin_nstep_obs'] = data['obs'][:bs + self._nstep]
+ # 2) data['main_obs'] = data['obs'][bs:-self._nstep]
+ # 3) data['target_obs'] = data['obs'][bs + self._nstep:]
+ # NOTE(rjy): (T, B, N) or (T, B, C, H, W)
+ assert len(x.shape) in [3, 5], x.shape
+ x = parallel_wrapper(self.encoder)(x) # (T, B, N)
+ if self.res_link:
+ a = x
+ # NOTE(rjy) lstm_embedding stores all hidden_state
+ lstm_embedding = []
+ # TODO(nyz) how to deal with hidden_size key-value
+ hidden_state_list = []
+ if saved_state_timesteps is not None:
+ saved_state = []
+ for t in range(x.shape[0]): # T timesteps
+                # NOTE(rjy): using x[t:t+1] instead of x[t] keeps the original dimension
+ output, prev_state = self.rnn(x[t:t + 1], prev_state) # output: (1,B, head_hidden_size)
+ if saved_state_timesteps is not None and t + 1 in saved_state_timesteps:
+ saved_state.append(prev_state)
+ lstm_embedding.append(output)
+ hidden_state = [p['h'] for p in prev_state]
+ # only keep ht, {list: x.shape[0]{Tensor:(1, batch_size, head_hidden_size)}}
+ hidden_state_list.append(torch.cat(hidden_state, dim=1))
+ x = torch.cat(lstm_embedding, 0) # (T, B, head_hidden_size)
+ if self.res_link:
+ x = x + a
+ x = parallel_wrapper(self.head)(x) # (T, B, action_shape)
+            # NOTE(rjy): x['next_state'] is the rnn state after the last timestep is fed to the lstm,
+            # i.e. the last timestep state, including the hidden state (h) and the cell state (c)
+            # shape: {list: B {dict: 2 {Tensor: (1, 1, head_hidden_size)}}}
+ x['next_state'] = prev_state
+ # all hidden state h, this returns a tensor of the dim: seq_len*batch_size*head_hidden_size
+            # This key is used in qtran, whose algorithm requires retaining all h_{t} during training
+ x['hidden_state'] = torch.cat(hidden_state_list, dim=0)
+ if saved_state_timesteps is not None:
+ # the selected saved hidden states, including the hidden state (h) and the cell state (c)
+ # in r2d2, set 'saved_hidden_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep]',
+ # then saved_state will record the hidden_state for main_obs and target_obs to
+ # initialize their lstm (h c)
+ x['saved_state'] = saved_state
+ return x
+
+
+@MODEL_REGISTRY.register('gtrxldqn')
+class GTrXLDQN(nn.Module):
+ """
+ Overview:
+ The neural network structure and computation graph of Gated Transformer-XL DQN algorithm, which is the \
+ enhanced version of DRQN, using Transformer-XL to improve long-term sequential modelling ability. The \
+ GTrXL-DQN is composed of three parts: ``encoder``, ``head`` and ``core``. The ``encoder`` is used to extract \
+        features from various observations, the ``core`` is used to process the sequential observation and other \
+ data, and the ``head`` is used to compute the Q value of each action dimension.
+ Interfaces:
+        ``__init__``, ``forward``, ``reset_memory``, ``get_memory``.
+ """
+
+ def __init__(
+ self,
+ obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType],
+ head_layer_num: int = 1,
+ att_head_dim: int = 16,
+ hidden_size: int = 16,
+ att_head_num: int = 2,
+ att_mlp_num: int = 2,
+ att_layer_num: int = 3,
+ memory_len: int = 64,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ head_norm_type: Optional[str] = None,
+ dropout: float = 0.,
+ gru_gating: bool = True,
+ gru_bias: float = 2.,
+ dueling: bool = True,
+ encoder_hidden_size_list: SequenceType = [128, 128, 256],
+ encoder_norm_type: Optional[str] = None,
+ ) -> None:
+ """
+ Overview:
+            Initialize the GTrXLDQN model according to corresponding input arguments.
+
+ .. tip::
+ You can refer to GTrXl class in ``ding.torch_utils.network.gtrxl`` for more details about the input \
+ arguments.
+
+ Arguments:
+ - obs_shape (:obj:`Union[int, SequenceType]`): Used by Transformer. Observation's space.
+            - action_shape (:obj:`Union[int, SequenceType]`): Used by Head. Action's space.
+ - head_layer_num (:obj:`int`): Used by Head. Number of layers.
+ - att_head_dim (:obj:`int`): Used by Transformer.
+ - hidden_size (:obj:`int`): Used by Transformer and Head.
+ - att_head_num (:obj:`int`): Used by Transformer.
+ - att_mlp_num (:obj:`int`): Used by Transformer.
+ - att_layer_num (:obj:`int`): Used by Transformer.
+ - memory_len (:obj:`int`): Used by Transformer.
+ - activation (:obj:`Optional[nn.Module]`): Used by Transformer and Head. if ``None`` then default set to \
+ ``nn.ReLU()``.
+ - head_norm_type (:obj:`Optional[str]`): Used by Head. The type of normalization to use, see \
+                ``ding.torch_utils.fc_block`` for more details.
+            - dropout (:obj:`float`): Used by Transformer.
+ - gru_gating (:obj:`bool`): Used by Transformer.
+ - gru_bias (:obj:`float`): Used by Transformer.
+ - dueling (:obj:`bool`): Used by Head. Make the head dueling.
+            - encoder_hidden_size_list (:obj:`SequenceType`): Used by Encoder. The collection of ``hidden_size`` if \
+ using a custom convolutional encoder.
+ - encoder_norm_type (:obj:`Optional[str]`): Used by Encoder. The type of normalization to use, see \
+                ``ding.torch_utils.fc_block`` for more details.
+ """
+ super(GTrXLDQN, self).__init__()
+ self.core = GTrXL(
+ input_dim=obs_shape,
+ head_dim=att_head_dim,
+ embedding_dim=hidden_size,
+ head_num=att_head_num,
+ mlp_num=att_mlp_num,
+ layer_num=att_layer_num,
+ memory_len=memory_len,
+ activation=activation,
+ dropout_ratio=dropout,
+ gru_gating=gru_gating,
+ gru_bias=gru_bias,
+ )
+
+ if isinstance(obs_shape, int) or len(obs_shape) == 1:
+ raise NotImplementedError("not support obs_shape for pre-defined encoder: {}".format(obs_shape))
+ # replace the embedding layer of Transformer with Conv Encoder
+ elif len(obs_shape) == 3:
+ assert encoder_hidden_size_list[-1] == hidden_size
+ self.obs_encoder = ConvEncoder(
+ obs_shape, encoder_hidden_size_list, activation=activation, norm_type=encoder_norm_type
+ )
+ self.dropout = nn.Dropout(dropout)
+ self.core.use_embedding_layer = False
+ else:
+ raise RuntimeError(
+ "not support obs_shape for pre-defined encoder: {}, please customize your own GTrXL".format(obs_shape)
+ )
+ # Head Type
+ if dueling:
+ head_cls = DuelingHead
+ else:
+ head_cls = DiscreteHead
+ multi_head = not isinstance(action_shape, int)
+ if multi_head:
+ self.head = MultiHead(
+ head_cls,
+ hidden_size,
+ action_shape,
+ layer_num=head_layer_num,
+ activation=activation,
+ norm_type=head_norm_type
+ )
+ else:
+ self.head = head_cls(
+ hidden_size, action_shape, head_layer_num, activation=activation, norm_type=head_norm_type
+ )
+
+ def forward(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+ Let input tensor go through GTrXl and the Head sequentially.
+ Arguments:
+ - x (:obj:`torch.Tensor`): input tensor of shape (seq_len, bs, obs_shape).
+ Returns:
+            - out (:obj:`Dict`): Run ``GTrXL`` with ``DiscreteHead`` setups and return the resulting prediction dictionary.
+ ReturnKeys:
+ - logit (:obj:`torch.Tensor`): discrete Q-value output of each action dimension, shape is (B, action_space).
+ - memory (:obj:`torch.Tensor`): memory tensor of size ``(bs x layer_num+1 x memory_len x embedding_dim)``.
+ - transformer_out (:obj:`torch.Tensor`): output tensor of transformer with same size as input ``x``.
+ Examples:
+ >>> # Init input's Keys:
+ >>> obs_dim, seq_len, bs, action_dim = 128, 64, 32, 4
+ >>> obs = torch.rand(seq_len, bs, obs_dim)
+ >>> model = GTrXLDQN(obs_dim, action_dim)
+ >>> outputs = model(obs)
+ >>> assert isinstance(outputs, dict)
+ """
+ if len(x.shape) == 5:
+ # 3d obs: cur_seq, bs, ch, h, w
+ x_ = x.reshape([x.shape[0] * x.shape[1]] + list(x.shape[-3:]))
+ x_ = self.dropout(self.obs_encoder(x_))
+ x = x_.reshape(x.shape[0], x.shape[1], -1)
+ o1 = self.core(x)
+ out = self.head(o1['logit'])
+        # layer_num+1 x memory_len x bs x embedding_dim -> bs x layer_num+1 x memory_len x embedding_dim
+ out['memory'] = o1['memory'].permute((2, 0, 1, 3)).contiguous()
+ out['transformer_out'] = o1['logit'] # output of gtrxl, out['logit'] is final output
+ return out
+
+ def reset_memory(self, batch_size: Optional[int] = None, state: Optional[torch.Tensor] = None) -> None:
+ """
+ Overview:
+ Clear or reset the memory of GTrXL.
+ Arguments:
+ - batch_size (:obj:`Optional[int]`): The number of samples in a training batch.
+ - state (:obj:`Optional[torch.Tensor]`): The input memory data, whose shape is \
+ (layer_num, memory_len, bs, embedding_dim).
+ """
+ self.core.reset_memory(batch_size, state)
+
+ def get_memory(self) -> Optional[torch.Tensor]:
+ """
+ Overview:
+ Return the memory of GTrXL.
+ Returns:
+            - memory (:obj:`Optional[torch.Tensor]`): Output memory, or None if the memory has not been initialized, \
+ whose shape is (layer_num, memory_len, bs, embedding_dim).
+ """
+ return self.core.get_memory()
diff --git a/DI-engine/ding/model/template/qac.py b/DI-engine/ding/model/template/qac.py
new file mode 100755
index 0000000000000000000000000000000000000000..6034a4d74cf970adbd7c540707a047604692f20b
--- /dev/null
+++ b/DI-engine/ding/model/template/qac.py
@@ -0,0 +1,541 @@
+from typing import Union, Dict, Optional
+from easydict import EasyDict
+import numpy as np
+import torch
+import torch.nn as nn
+
+from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
+from ..common import RegressionHead, ReparameterizationHead, DiscreteHead, MultiHead, \
+ FCEncoder, ConvEncoder
+
+
+@MODEL_REGISTRY.register('continuous_qac')
+class ContinuousQAC(nn.Module):
+ """
+ Overview:
+ The neural network and computation graph of algorithms related to Q-value Actor-Critic (QAC), such as \
+ DDPG/TD3/SAC. This model now supports continuous and hybrid action space. The ContinuousQAC is composed of \
+ four parts: ``actor_encoder``, ``critic_encoder``, ``actor_head`` and ``critic_head``. Encoders are used to \
+ extract the feature from various observation. Heads are used to predict corresponding Q-value or action logit. \
+ In high-dimensional observation space like 2D image, we often use a shared encoder for both ``actor_encoder`` \
+ and ``critic_encoder``. In low-dimensional observation space like 1D vector, we often use different encoders.
+ Interfaces:
+ ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
+ """
+ mode = ['compute_actor', 'compute_critic']
+
+ def __init__(
+ self,
+ obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType, EasyDict],
+ action_space: str,
+ twin_critic: bool = False,
+ actor_head_hidden_size: int = 64,
+ actor_head_layer_num: int = 1,
+ critic_head_hidden_size: int = 64,
+ critic_head_layer_num: int = 1,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ encoder_hidden_size_list: Optional[SequenceType] = None,
+ share_encoder: Optional[bool] = False,
+ ) -> None:
+ """
+ Overview:
+            Initialize the ContinuousQAC Model according to input arguments.
+ Arguments:
+ - obs_shape (:obj:`Union[int, SequenceType]`): Observation's shape, such as 128, (156, ).
+ - action_shape (:obj:`Union[int, SequenceType, EasyDict]`): Action's shape, such as 4, (3, ), \
+ EasyDict({'action_type_shape': 3, 'action_args_shape': 4}).
+ - action_space (:obj:`str`): The type of action space, including [``regression``, ``reparameterization``, \
+ ``hybrid``], ``regression`` is used for DDPG/TD3, ``reparameterization`` is used for SAC and \
+ ``hybrid`` for PADDPG.
+            - twin_critic (:obj:`bool`): Whether to use twin critic, one of the tricks in TD3.
+ - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor head.
+ - actor_head_layer_num (:obj:`int`): The num of layers used in the actor network to compute action.
+ - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic head.
+ - critic_head_layer_num (:obj:`int`): The num of layers used in the critic network to compute Q-value.
+ - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` \
+ after each FC layer, if ``None`` then default set to ``nn.ReLU()``.
+            - norm_type (:obj:`Optional[str]`): The type of normalization to use after network layers (FC, Conv), \
+ see ``ding.torch_utils.network`` for more details.
+ - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
+                the last element must match ``head_hidden_size``; this argument is only used for image observations.
+ - share_encoder (:obj:`Optional[bool]`): Whether to share encoder between actor and critic.
+ """
+ super(ContinuousQAC, self).__init__()
+ obs_shape: int = squeeze(obs_shape)
+ action_shape = squeeze(action_shape)
+ self.action_shape = action_shape
+ self.action_space = action_space
+ assert self.action_space in ['regression', 'reparameterization', 'hybrid'], self.action_space
+
+ # encoder
+ self.share_encoder = share_encoder
+ if np.isscalar(obs_shape) or len(obs_shape) == 1:
+ assert not self.share_encoder, "Vector observation doesn't need share encoder."
+ assert encoder_hidden_size_list is None, "Vector obs encoder only uses one layer nn.Linear"
+            # Because there is already an nn.Linear layer in the head, we use nn.Identity here to stay
+            # compatible with the image observation case and avoid adding an extra nn.Linear layer.
+ self.actor_encoder = nn.Identity()
+ self.critic_encoder = nn.Identity()
+ encoder_output_size = obs_shape
+ elif len(obs_shape) == 3:
+
+ def setup_conv_encoder():
+ kernel_size = [3 for _ in range(len(encoder_hidden_size_list))]
+ stride = [2] + [1 for _ in range(len(encoder_hidden_size_list) - 1)]
+ return ConvEncoder(
+ obs_shape,
+ encoder_hidden_size_list,
+ activation=activation,
+ norm_type=norm_type,
+ kernel_size=kernel_size,
+ stride=stride
+ )
+
+ if self.share_encoder:
+ encoder = setup_conv_encoder()
+ self.actor_encoder = self.critic_encoder = encoder
+ else:
+ self.actor_encoder = setup_conv_encoder()
+ self.critic_encoder = setup_conv_encoder()
+ encoder_output_size = self.actor_encoder.output_size
+ else:
+ raise RuntimeError("not support observation shape: {}".format(obs_shape))
+ # head
+ if self.action_space == 'regression': # DDPG, TD3
+ self.actor_head = nn.Sequential(
+ nn.Linear(encoder_output_size, actor_head_hidden_size), activation,
+ RegressionHead(
+ actor_head_hidden_size,
+ action_shape,
+ actor_head_layer_num,
+ final_tanh=True,
+ activation=activation,
+ norm_type=norm_type
+ )
+ )
+ elif self.action_space == 'reparameterization': # SAC
+ self.actor_head = nn.Sequential(
+ nn.Linear(encoder_output_size, actor_head_hidden_size), activation,
+ ReparameterizationHead(
+ actor_head_hidden_size,
+ action_shape,
+ actor_head_layer_num,
+ sigma_type='conditioned',
+ activation=activation,
+ norm_type=norm_type
+ )
+ )
+ elif self.action_space == 'hybrid': # PADDPG
+ # hybrid action space: action_type(discrete) + action_args(continuous),
+ # such as {'action_type_shape': torch.LongTensor([0]), 'action_args_shape': torch.FloatTensor([0.1, -0.27])}
+ action_shape.action_args_shape = squeeze(action_shape.action_args_shape)
+ action_shape.action_type_shape = squeeze(action_shape.action_type_shape)
+ actor_action_args = nn.Sequential(
+ nn.Linear(encoder_output_size, actor_head_hidden_size), activation,
+ RegressionHead(
+ actor_head_hidden_size,
+ action_shape.action_args_shape,
+ actor_head_layer_num,
+ final_tanh=True,
+ activation=activation,
+ norm_type=norm_type
+ )
+ )
+ actor_action_type = nn.Sequential(
+ nn.Linear(encoder_output_size, actor_head_hidden_size), activation,
+ DiscreteHead(
+ actor_head_hidden_size,
+ action_shape.action_type_shape,
+ actor_head_layer_num,
+ activation=activation,
+ norm_type=norm_type,
+ )
+ )
+ self.actor_head = nn.ModuleList([actor_action_type, actor_action_args])
+
+ self.twin_critic = twin_critic
+ if self.action_space == 'hybrid':
+ critic_input_size = encoder_output_size + action_shape.action_type_shape + action_shape.action_args_shape
+ else:
+ critic_input_size = encoder_output_size + action_shape
+ if self.twin_critic:
+ self.critic_head = nn.ModuleList()
+ for _ in range(2):
+ self.critic_head.append(
+ nn.Sequential(
+ nn.Linear(critic_input_size, critic_head_hidden_size), activation,
+ RegressionHead(
+ critic_head_hidden_size,
+ 1,
+ critic_head_layer_num,
+ final_tanh=False,
+ activation=activation,
+ norm_type=norm_type
+ )
+ )
+ )
+ else:
+ self.critic_head = nn.Sequential(
+ nn.Linear(critic_input_size, critic_head_hidden_size), activation,
+ RegressionHead(
+ critic_head_hidden_size,
+ 1,
+ critic_head_layer_num,
+ final_tanh=False,
+ activation=activation,
+ norm_type=norm_type
+ )
+ )
+
+ # Convenient for calling some apis (e.g. self.critic.parameters()),
+ # but may cause misunderstanding when `print(self)`
+ self.actor = nn.ModuleList([self.actor_encoder, self.actor_head])
+ self.critic = nn.ModuleList([self.critic_encoder, self.critic_head])
+
+ def forward(self, inputs: Union[torch.Tensor, Dict[str, torch.Tensor]], mode: str) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ QAC forward computation graph, input observation tensor to predict Q-value or action logit. Different \
+ ``mode`` will forward with different network modules to get different outputs and save computation.
+ Arguments:
+ - inputs (:obj:`Union[torch.Tensor, Dict[str, torch.Tensor]]`): The input data for forward computation \
+ graph, for ``compute_actor``, it is the observation tensor, for ``compute_critic``, it is the \
+ dict data including obs and action tensor.
+ - mode (:obj:`str`): The forward mode, all the modes are defined in the beginning of this class.
+ Returns:
+ - output (:obj:`Dict[str, torch.Tensor]`): The output dict of QAC forward computation graph, whose \
+ key-values vary in different forward modes.
+ Examples (Actor):
+ >>> # Regression mode
+ >>> model = ContinuousQAC(64, 6, 'regression')
+ >>> obs = torch.randn(4, 64)
+ >>> actor_outputs = model(obs,'compute_actor')
+ >>> assert actor_outputs['action'].shape == torch.Size([4, 6])
+ >>> # Reparameterization Mode
+ >>> model = ContinuousQAC(64, 6, 'reparameterization')
+ >>> obs = torch.randn(4, 64)
+ >>> actor_outputs = model(obs,'compute_actor')
+ >>> assert actor_outputs['logit'][0].shape == torch.Size([4, 6]) # mu
+            >>> assert actor_outputs['logit'][1].shape == torch.Size([4, 6]) # sigma
+
+ Examples (Critic):
+ >>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)}
+ >>> model = ContinuousQAC(obs_shape=(8, ),action_shape=1, action_space='regression')
+ >>> assert model(inputs, mode='compute_critic')['q_value'].shape == (4, ) # q value
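+
+        Examples (Hybrid):
+            >>> # A minimal sketch of the hybrid (PADDPG-style) action space. The input/output keys follow
+            >>> # ``compute_actor``/``compute_critic`` below; the concrete numbers are only illustrative.
+            >>> from easydict import EasyDict
+            >>> action_shape = EasyDict({'action_type_shape': 3, 'action_args_shape': 4})
+            >>> model = ContinuousQAC(obs_shape=8, action_shape=action_shape, action_space='hybrid')
+            >>> obs = torch.randn(4, 8)
+            >>> actor_outputs = model(obs, 'compute_actor')
+            >>> assert actor_outputs['logit'].shape == torch.Size([4, 3])
+            >>> assert actor_outputs['action_args'].shape == torch.Size([4, 4])
+            >>> critic_inputs = {'obs': obs, 'logit': actor_outputs['logit'],
+            ...                  'action': {'action_args': actor_outputs['action_args']}}
+            >>> assert 'q_value' in model(critic_inputs, 'compute_critic')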
+ """
+ assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
+ return getattr(self, mode)(inputs)
+
+ def compute_actor(self, obs: torch.Tensor) -> Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]:
+ """
+ Overview:
+ QAC forward computation graph for actor part, input observation tensor to predict action or action logit.
+ Arguments:
+            - obs (:obj:`torch.Tensor`): The input observation tensor data.
+ Returns:
+ - outputs (:obj:`Dict[str, Union[torch.Tensor, Dict[str, torch.Tensor]]]`): Actor output dict varying \
+ from action_space: ``regression``, ``reparameterization``, ``hybrid``.
+ ReturnsKeys (regression):
+ - action (:obj:`torch.Tensor`): Continuous action with same size as ``action_shape``, usually in DDPG/TD3.
+ ReturnsKeys (reparameterization):
+            - logit (:obj:`Dict[str, torch.Tensor]`): The predicted reparameterization action logit, usually in SAC. \
+ It is a list containing two tensors: ``mu`` and ``sigma``. The former is the mean of the gaussian \
+ distribution, the latter is the standard deviation of the gaussian distribution.
+ ReturnsKeys (hybrid):
+ - logit (:obj:`torch.Tensor`): The predicted discrete action type logit, it will be the same dimension \
+ as ``action_type_shape``, i.e., all the possible discrete action types.
+ - action_args (:obj:`torch.Tensor`): Continuous action arguments with same size as ``action_args_shape``.
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``obs_shape``.
+ - action (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size and N1 corresponds to ``action_shape``.
+ - logit.mu (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size and N1 corresponds to ``action_shape``.
+ - logit.sigma (:obj:`torch.Tensor`): :math:`(B, N1)`, B is batch size.
+ - logit (:obj:`torch.Tensor`): :math:`(B, N2)`, B is batch size and N2 corresponds to \
+ ``action_shape.action_type_shape``.
+ - action_args (:obj:`torch.Tensor`): :math:`(B, N3)`, B is batch size and N3 corresponds to \
+ ``action_shape.action_args_shape``.
+ Examples:
+ >>> # Regression mode
+ >>> model = ContinuousQAC(64, 6, 'regression')
+ >>> obs = torch.randn(4, 64)
+ >>> actor_outputs = model(obs,'compute_actor')
+ >>> assert actor_outputs['action'].shape == torch.Size([4, 6])
+ >>> # Reparameterization Mode
+ >>> model = ContinuousQAC(64, 6, 'reparameterization')
+ >>> obs = torch.randn(4, 64)
+ >>> actor_outputs = model(obs,'compute_actor')
+ >>> assert actor_outputs['logit'][0].shape == torch.Size([4, 6]) # mu
+            >>> assert actor_outputs['logit'][1].shape == torch.Size([4, 6]) # sigma
+ """
+ obs = self.actor_encoder(obs)
+ if self.action_space == 'regression':
+ x = self.actor_head(obs)
+ return {'action': x['pred']}
+ elif self.action_space == 'reparameterization':
+ x = self.actor_head(obs)
+ return {'logit': [x['mu'], x['sigma']]}
+ elif self.action_space == 'hybrid':
+ logit = self.actor_head[0](obs)
+ action_args = self.actor_head[1](obs)
+ return {'logit': logit['logit'], 'action_args': action_args['pred']}
+
+ def compute_critic(self, inputs: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ QAC forward computation graph for critic part, input observation and action tensor to predict Q-value.
+ Arguments:
+ - inputs (:obj:`Dict[str, torch.Tensor]`): The dict of input data, including ``obs`` and ``action`` \
+ tensor, also contains ``logit`` and ``action_args`` tensor in hybrid action_space.
+ ArgumentsKeys:
+            - obs (:obj:`torch.Tensor`): Observation tensor data, now supports a batch of 1-dim vector data.
+ - action (:obj:`Union[torch.Tensor, Dict]`): Continuous action with same size as ``action_shape``.
+ - logit (:obj:`torch.Tensor`): Discrete action logit, only in hybrid action_space.
+ - action_args (:obj:`torch.Tensor`): Continuous action arguments, only in hybrid action_space.
+ Returns:
+ - outputs (:obj:`Dict[str, torch.Tensor]`): The output dict of QAC's forward computation graph for critic, \
+ including ``q_value``.
+ ReturnKeys:
+ - q_value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``obs_shape``.
+ - logit (:obj:`torch.Tensor`): :math:`(B, N2)`, B is batch size and N2 corresponds to \
+ ``action_shape.action_type_shape``.
+ - action_args (:obj:`torch.Tensor`): :math:`(B, N3)`, B is batch size and N3 corresponds to \
+ ``action_shape.action_args_shape``.
+ - action (:obj:`torch.Tensor`): :math:`(B, N4)`, where B is batch size and N4 is ``action_shape``.
+ - q_value (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch size.
+
+ Examples:
+ >>> inputs = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 1)}
+ >>> model = ContinuousQAC(obs_shape=(8, ),action_shape=1, action_space='regression')
+ >>> assert model(inputs, mode='compute_critic')['q_value'].shape == (4, ) # q value
+ """
+
+ obs, action = inputs['obs'], inputs['action']
+ obs = self.critic_encoder(obs)
+ assert len(obs.shape) == 2
+ if self.action_space == 'hybrid':
+ action_type_logit = inputs['logit']
+ action_type_logit = torch.softmax(action_type_logit, dim=-1)
+ action_args = action['action_args']
+ if len(action_args.shape) == 1:
+ action_args = action_args.unsqueeze(1)
+ x = torch.cat([obs, action_type_logit, action_args], dim=1)
+ else:
+ if len(action.shape) == 1: # (B, ) -> (B, 1)
+ action = action.unsqueeze(1)
+ x = torch.cat([obs, action], dim=1)
+ if self.twin_critic:
+ x = [m(x)['pred'] for m in self.critic_head]
+ else:
+ x = self.critic_head(x)['pred']
+ return {'q_value': x}
+
+
+@MODEL_REGISTRY.register('discrete_qac')
+class DiscreteQAC(nn.Module):
+ """
+ Overview:
+ The neural network and computation graph of algorithms related to discrete action Q-value Actor-Critic (QAC), \
+ such as DiscreteSAC. This model now supports only discrete action space. The DiscreteQAC is composed of \
+ four parts: ``actor_encoder``, ``critic_encoder``, ``actor_head`` and ``critic_head``. Encoders are used to \
+ extract the feature from various observation. Heads are used to predict corresponding Q-value or action logit. \
+ In high-dimensional observation space like 2D image, we often use a shared encoder for both ``actor_encoder`` \
+ and ``critic_encoder``. In low-dimensional observation space like 1D vector, we often use different encoders.
+ Interfaces:
+ ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
+ """
+ mode = ['compute_actor', 'compute_critic']
+
+ def __init__(
+ self,
+ obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType],
+ twin_critic: bool = False,
+ actor_head_hidden_size: int = 64,
+ actor_head_layer_num: int = 1,
+ critic_head_hidden_size: int = 64,
+ critic_head_layer_num: int = 1,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+        encoder_hidden_size_list: Optional[SequenceType] = None,
+ share_encoder: Optional[bool] = False,
+ ) -> None:
+ """
+ Overview:
+            Initialize the DiscreteQAC Model according to input arguments.
+ Arguments:
+ - obs_shape (:obj:`Union[int, SequenceType]`): Observation's shape, such as 128, (156, ).
+            - action_shape (:obj:`Union[int, SequenceType]`): Action's shape, such as 4, (3, ).
+ - twin_critic (:obj:`bool`): Whether to use twin critic.
+ - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor head.
+ - actor_head_layer_num (:obj:`int`): The num of layers used in the actor network to compute action.
+ - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic head.
+ - critic_head_layer_num (:obj:`int`): The num of layers used in the critic network to compute Q-value.
+ - activation (:obj:`Optional[nn.Module]`): The type of activation function to use in ``MLP`` \
+ after each FC layer, if ``None`` then default set to ``nn.ReLU()``.
+            - norm_type (:obj:`Optional[str]`): The type of normalization to use after network layers (FC, Conv), \
+ see ``ding.torch_utils.network`` for more details.
+ - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
+                the last element must match ``head_hidden_size``; this argument is only used for image observations.
+ - share_encoder (:obj:`Optional[bool]`): Whether to share encoder between actor and critic.
+ """
+ super(DiscreteQAC, self).__init__()
+ obs_shape: int = squeeze(obs_shape)
+ action_shape: int = squeeze(action_shape)
+ # encoder
+ self.share_encoder = share_encoder
+ if np.isscalar(obs_shape) or len(obs_shape) == 1:
+ assert not self.share_encoder, "Vector observation doesn't need share encoder."
+ assert encoder_hidden_size_list is None, "Vector obs encoder only uses one layer nn.Linear"
+            # Because there is already an nn.Linear layer in the head, we use nn.Identity here to stay
+            # compatible with the image observation case and avoid adding an extra nn.Linear layer.
+ self.actor_encoder = nn.Identity()
+ self.critic_encoder = nn.Identity()
+ encoder_output_size = obs_shape
+ elif len(obs_shape) == 3:
+
+ def setup_conv_encoder():
+ kernel_size = [3 for _ in range(len(encoder_hidden_size_list))]
+ stride = [2] + [1 for _ in range(len(encoder_hidden_size_list) - 1)]
+ return ConvEncoder(
+ obs_shape,
+ encoder_hidden_size_list,
+ activation=activation,
+ norm_type=norm_type,
+ kernel_size=kernel_size,
+ stride=stride
+ )
+
+ if self.share_encoder:
+ encoder = setup_conv_encoder()
+ self.actor_encoder = self.critic_encoder = encoder
+ else:
+ self.actor_encoder = setup_conv_encoder()
+ self.critic_encoder = setup_conv_encoder()
+ encoder_output_size = self.actor_encoder.output_size
+ else:
+ raise RuntimeError("not support observation shape: {}".format(obs_shape))
+
+ # head
+ self.actor_head = nn.Sequential(
+ nn.Linear(encoder_output_size, actor_head_hidden_size), activation,
+ DiscreteHead(
+ actor_head_hidden_size, action_shape, actor_head_layer_num, activation=activation, norm_type=norm_type
+ )
+ )
+
+ self.twin_critic = twin_critic
+ if self.twin_critic:
+ self.critic_head = nn.ModuleList()
+ for _ in range(2):
+ self.critic_head.append(
+ nn.Sequential(
+ nn.Linear(encoder_output_size, critic_head_hidden_size), activation,
+ DiscreteHead(
+ critic_head_hidden_size,
+ action_shape,
+ critic_head_layer_num,
+ activation=activation,
+ norm_type=norm_type
+ )
+ )
+ )
+ else:
+ self.critic_head = nn.Sequential(
+ nn.Linear(encoder_output_size, critic_head_hidden_size), activation,
+ DiscreteHead(
+ critic_head_hidden_size,
+ action_shape,
+ critic_head_layer_num,
+ activation=activation,
+ norm_type=norm_type
+ )
+ )
+ # Convenient for calling some apis (e.g. self.critic.parameters()),
+ # but may cause misunderstanding when `print(self)`
+ self.actor = nn.ModuleList([self.actor_encoder, self.actor_head])
+ self.critic = nn.ModuleList([self.critic_encoder, self.critic_head])
+
+ def forward(self, inputs: torch.Tensor, mode: str) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ QAC forward computation graph, input observation tensor to predict Q-value or action logit. Different \
+ ``mode`` will forward with different network modules to get different outputs and save computation.
+ Arguments:
+ - inputs (:obj:`torch.Tensor`): The input observation tensor data.
+ - mode (:obj:`str`): The forward mode, all the modes are defined in the beginning of this class.
+ Returns:
+ - output (:obj:`Dict[str, torch.Tensor]`): The output dict of QAC forward computation graph, whose \
+ key-values vary in different forward modes.
+ Examples (Actor):
+ >>> model = DiscreteQAC(64, 6)
+ >>> obs = torch.randn(4, 64)
+ >>> actor_outputs = model(obs,'compute_actor')
+ >>> assert actor_outputs['logit'].shape == torch.Size([4, 6])
+
+ Examples(Critic):
+ >>> model = DiscreteQAC(64, 6, twin_critic=False)
+ >>> obs = torch.randn(4, 64)
+ >>> actor_outputs = model(obs,'compute_critic')
+ >>> assert actor_outputs['q_value'].shape == torch.Size([4, 6])
+ """
+ assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
+ return getattr(self, mode)(inputs)
+
+ def compute_actor(self, inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ QAC forward computation graph for actor part, input observation tensor to predict action or action logit.
+ Arguments:
+ - inputs (:obj:`torch.Tensor`): The input observation tensor data.
+ Returns:
+ - outputs (:obj:`Dict[str, torch.Tensor]`): The output dict of QAC forward computation graph for actor, \
+ including discrete action ``logit``.
+ ReturnsKeys:
+ - logit (:obj:`torch.Tensor`): The predicted discrete action type logit, it will be the same dimension \
+ as ``action_shape``, i.e., all the possible discrete action choices.
+ Shapes:
+ - inputs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``obs_shape``.
+ - logit (:obj:`torch.Tensor`): :math:`(B, N2)`, B is batch size and N2 corresponds to \
+ ``action_shape``.
+ Examples:
+ >>> model = DiscreteQAC(64, 6)
+ >>> obs = torch.randn(4, 64)
+ >>> actor_outputs = model(obs,'compute_actor')
+ >>> assert actor_outputs['logit'].shape == torch.Size([4, 6])
+ """
+ x = self.actor_encoder(inputs)
+ x = self.actor_head(x)
+ return {'logit': x['logit']}
+
+ def compute_critic(self, inputs: torch.Tensor) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ QAC forward computation graph for critic part, input observation to predict Q-value for each possible \
+ discrete action choices.
+ Arguments:
+ - inputs (:obj:`torch.Tensor`): The input observation tensor data.
+ Returns:
+ - outputs (:obj:`Dict[str, torch.Tensor]`): The output dict of QAC forward computation graph for critic, \
+ including ``q_value`` for each possible discrete action choices.
+ ReturnKeys:
+ - q_value (:obj:`torch.Tensor`): The predicted Q-value for each possible discrete action choices, it will \
+ be the same dimension as ``action_shape`` and used to calculate the loss.
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``obs_shape``.
+ - q_value (:obj:`torch.Tensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape``.
+ Examples:
+ >>> model = DiscreteQAC(64, 6, twin_critic=False)
+ >>> obs = torch.randn(4, 64)
+ >>> actor_outputs = model(obs,'compute_critic')
+ >>> assert actor_outputs['q_value'].shape == torch.Size([4, 6])
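+            >>> # Twin critic (a sketch based on the ``twin_critic`` branch above): ``q_value`` becomes a
+            >>> # list with one tensor per critic head.
+            >>> model = DiscreteQAC(64, 6, twin_critic=True)
+            >>> obs = torch.randn(4, 64)
+            >>> critic_outputs = model(obs, 'compute_critic')
+            >>> assert isinstance(critic_outputs['q_value'], list) and len(critic_outputs['q_value']) == 2
+            >>> assert all(q.shape == torch.Size([4, 6]) for q in critic_outputs['q_value'])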
+ """
+ inputs = self.critic_encoder(inputs)
+ if self.twin_critic:
+ x = [m(inputs)['logit'] for m in self.critic_head]
+ else:
+ x = self.critic_head(inputs)['logit']
+ return {'q_value': x}
diff --git a/DI-engine/ding/model/template/qac_dist.py b/DI-engine/ding/model/template/qac_dist.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9390cb06e0c7dcc00c8cbfe1bdca8d3eae5dd30
--- /dev/null
+++ b/DI-engine/ding/model/template/qac_dist.py
@@ -0,0 +1,247 @@
+from typing import Union, Dict, Optional
+import torch
+import torch.nn as nn
+
+from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
+from ..common import RegressionHead, ReparameterizationHead, DistributionHead
+
+
+@MODEL_REGISTRY.register('qac_dist')
+class QACDIST(nn.Module):
+ """
+ Overview:
+ The QAC model with distributional Q-value.
+ Interfaces:
+ ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``
+ """
+ mode = ['compute_actor', 'compute_critic']
+
+ def __init__(
+ self,
+ obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType],
+ action_space: str = "regression",
+ critic_head_type: str = "categorical",
+ actor_head_hidden_size: int = 64,
+ actor_head_layer_num: int = 1,
+ critic_head_hidden_size: int = 64,
+ critic_head_layer_num: int = 1,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ v_min: Optional[float] = -10,
+ v_max: Optional[float] = 10,
+ n_atom: Optional[int] = 51,
+ ) -> None:
+ """
+ Overview:
+ Init the QAC Distributional Model according to arguments.
+ Arguments:
+ - obs_shape (:obj:`Union[int, SequenceType]`): Observation's space.
+ - action_shape (:obj:`Union[int, SequenceType]`): Action's space.
+            - action_space (:obj:`str`): Whether to choose ``regression`` or ``reparameterization``.
+ - critic_head_type (:obj:`str`): Only ``categorical``.
+ - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to actor-nn's ``Head``.
+ - actor_head_layer_num (:obj:`int`):
+ The num of layers used in the network to compute Q value output for actor's nn.
+ - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` to pass to critic-nn's ``Head``.
+ - critic_head_layer_num (:obj:`int`):
+ The num of layers used in the network to compute Q value output for critic's nn.
+ - activation (:obj:`Optional[nn.Module]`):
+                The type of activation function to use in ``MLP`` after each ``layer_fn``,
+                if ``None`` then default set to ``nn.ReLU()``.
+ - norm_type (:obj:`Optional[str]`):
+ The type of normalization to use, see ``ding.torch_utils.fc_block`` for more details.
+            - v_min (:obj:`float`): Value of the smallest atom in the support, default to -10.
+            - v_max (:obj:`float`): Value of the largest atom in the support, default to 10.
+            - n_atom (:obj:`int`): Number of atoms in the support, default to 51.
+ """
+ super(QACDIST, self).__init__()
+ obs_shape: int = squeeze(obs_shape)
+ action_shape: int = squeeze(action_shape)
+ self.action_space = action_space
+ assert self.action_space in ['regression', 'reparameterization']
+ if self.action_space == 'regression':
+ self.actor = nn.Sequential(
+ nn.Linear(obs_shape, actor_head_hidden_size), activation,
+ RegressionHead(
+ actor_head_hidden_size,
+ action_shape,
+ actor_head_layer_num,
+ final_tanh=True,
+ activation=activation,
+ norm_type=norm_type
+ )
+ )
+ elif self.action_space == 'reparameterization':
+ self.actor = nn.Sequential(
+ nn.Linear(obs_shape, actor_head_hidden_size), activation,
+ ReparameterizationHead(
+ actor_head_hidden_size,
+ action_shape,
+ actor_head_layer_num,
+ sigma_type='conditioned',
+ activation=activation,
+ norm_type=norm_type
+ )
+ )
+ self.critic_head_type = critic_head_type
+ assert self.critic_head_type in ['categorical'], self.critic_head_type
+ if self.critic_head_type == 'categorical':
+ self.critic = nn.Sequential(
+ nn.Linear(obs_shape + action_shape, critic_head_hidden_size), activation,
+ DistributionHead(
+ critic_head_hidden_size,
+ 1,
+ critic_head_layer_num,
+ n_atom=n_atom,
+ v_min=v_min,
+ v_max=v_max,
+ activation=activation,
+ norm_type=norm_type
+ )
+ )
+
+ def forward(self, inputs: Union[torch.Tensor, Dict], mode: str) -> Dict:
+ """
+ Overview:
+            Use observation and action tensors to predict the output,
+            following the forward setup of QACDIST's MLPs selected by ``mode``.
+ Arguments:
+ Forward with ``'compute_actor'``:
+ - inputs (:obj:`torch.Tensor`):
+ The encoded embedding tensor, determined with given ``hidden_size``, i.e. ``(B, N=hidden_size)``.
+ Whether ``actor_head_hidden_size`` or ``critic_head_hidden_size`` depend on ``mode``.
+
+ Forward with ``'compute_critic'``, inputs (`Dict`) Necessary Keys:
+ - ``obs``, ``action`` encoded tensors.
+
+ - mode (:obj:`str`): Name of the forward mode.
+ Returns:
+ - outputs (:obj:`Dict`): Outputs of network forward.
+
+ Forward with ``'compute_actor'``, Necessary Keys (either):
+ - action (:obj:`torch.Tensor`): Action tensor with same size as input ``x``.
+ - logit (:obj:`torch.Tensor`):
+ Logit tensor encoding ``mu`` and ``sigma``, both with same size as input ``x``.
+
+ Forward with ``'compute_critic'``, Necessary Keys:
+ - q_value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.
+ - distribution (:obj:`torch.Tensor`): Q value distribution tensor.
+ Actor Shapes:
+ - inputs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``hidden_size``
+ - action (:obj:`torch.Tensor`): :math:`(B, N0)`
+ - q_value (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size.
+
+ Critic Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``obs_shape``
+            - action (:obj:`torch.Tensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape``
+ - q_value (:obj:`torch.FloatTensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape``
+ - distribution (:obj:`torch.FloatTensor`): :math:`(B, 1, N3)`, where B is batch size and N3 is ``num_atom``
+
+ Actor Examples:
+ >>> # Regression mode
+ >>> model = QACDIST(64, 64, 'regression')
+ >>> inputs = torch.randn(4, 64)
+ >>> actor_outputs = model(inputs,'compute_actor')
+ >>> assert actor_outputs['action'].shape == torch.Size([4, 64])
+ >>> # Reparameterization Mode
+ >>> model = QACDIST(64, 64, 'reparameterization')
+ >>> inputs = torch.randn(4, 64)
+ >>> actor_outputs = model(inputs,'compute_actor')
+ >>> actor_outputs['logit'][0].shape # mu
+ >>> torch.Size([4, 64])
+ >>> actor_outputs['logit'][1].shape # sigma
+ >>> torch.Size([4, 64])
+
+ Critic Examples:
+ >>> # Categorical mode
+            >>> obs_dim = 8
+            >>> inputs = {'obs': torch.randn(4, obs_dim), 'action': torch.randn(4, 1)}
+            >>> model = QACDIST(obs_shape=(obs_dim, ), action_shape=1, action_space='regression', \
+            ... critic_head_type='categorical', n_atom=51)
+ >>> q_value = model(inputs, mode='compute_critic') # q value
+ >>> assert q_value['q_value'].shape == torch.Size([4, 1])
+ >>> assert q_value['distribution'].shape == torch.Size([4, 1, 51])
+ """
+ assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
+ return getattr(self, mode)(inputs)
+
+ def compute_actor(self, inputs: torch.Tensor) -> Dict:
+ """
+ Overview:
+            Use the encoded embedding tensor to predict the actor output in ``'compute_actor'`` mode.
+        Arguments:
+            - inputs (:obj:`torch.Tensor`):
+                The encoded embedding tensor, determined with given ``hidden_size``, i.e. ``(B, N=hidden_size)``.
+                ``hidden_size = actor_head_hidden_size``
+ Returns:
+ - outputs (:obj:`Dict`): Outputs of forward pass encoder and head.
+
+ ReturnsKeys (either):
+ - action (:obj:`torch.Tensor`): Continuous action tensor with same size as ``action_shape``.
+ - logit (:obj:`torch.Tensor`):
+ Logit tensor encoding ``mu`` and ``sigma``, both with same size as input ``x``.
+ Shapes:
+ - inputs (:obj:`torch.Tensor`): :math:`(B, N0)`, B is batch size and N0 corresponds to ``hidden_size``
+ - action (:obj:`torch.Tensor`): :math:`(B, N0)`
+ - logit (:obj:`list`): 2 elements, mu and sigma, each is the shape of :math:`(B, N0)`.
+ - q_value (:obj:`torch.FloatTensor`): :math:`(B, )`, B is batch size.
+ Examples:
+ >>> # Regression mode
+ >>> model = QACDIST(64, 64, 'regression')
+ >>> inputs = torch.randn(4, 64)
+ >>> actor_outputs = model(inputs,'compute_actor')
+ >>> assert actor_outputs['action'].shape == torch.Size([4, 64])
+ >>> # Reparameterization Mode
+ >>> model = QACDIST(64, 64, 'reparameterization')
+ >>> inputs = torch.randn(4, 64)
+ >>> actor_outputs = model(inputs,'compute_actor')
+ >>> actor_outputs['logit'][0].shape # mu
+ >>> torch.Size([4, 64])
+ >>> actor_outputs['logit'][1].shape # sigma
+ >>> torch.Size([4, 64])
+ """
+ x = self.actor(inputs)
+ if self.action_space == 'regression':
+ return {'action': x['pred']}
+ elif self.action_space == 'reparameterization':
+ return {'logit': [x['mu'], x['sigma']]}
+
+ def compute_critic(self, inputs: Dict) -> Dict:
+ """
+ Overview:
+            Use the encoded ``obs`` and ``action`` tensors to predict the Q-value output in \
+            ``'compute_critic'`` mode.
+        Arguments:
+            - inputs (:obj:`Dict`): Input dict data, including ``obs`` and ``action`` encoded tensors.
+ Returns:
+ - outputs (:obj:`Dict`): Q-value output and distribution.
+
+ ReturnKeys:
+ - q_value (:obj:`torch.Tensor`): Q value tensor with same size as batch size.
+ - distribution (:obj:`torch.Tensor`): Q value distribution tensor.
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, N1)`, where B is batch size and N1 is ``obs_shape``
+            - action (:obj:`torch.Tensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape``
+ - q_value (:obj:`torch.FloatTensor`): :math:`(B, N2)`, where B is batch size and N2 is ``action_shape``
+ - distribution (:obj:`torch.FloatTensor`): :math:`(B, 1, N3)`, where B is batch size and N3 is ``num_atom``
+
+ Examples:
+ >>> # Categorical mode
+            >>> obs_dim = 8
+            >>> inputs = {'obs': torch.randn(4, obs_dim), 'action': torch.randn(4, 1)}
+            >>> model = QACDIST(obs_shape=(obs_dim, ), action_shape=1, action_space='regression', \
+            ... critic_head_type='categorical', n_atom=51)
+ >>> q_value = model(inputs, mode='compute_critic') # q value
+ >>> assert q_value['q_value'].shape == torch.Size([4, 1])
+ >>> assert q_value['distribution'].shape == torch.Size([4, 1, 51])
+ """
+ obs, action = inputs['obs'], inputs['action']
+ assert len(obs.shape) == 2
+ if len(action.shape) == 1: # (B, ) -> (B, 1)
+ action = action.unsqueeze(1)
+ x = torch.cat([obs, action], dim=1)
+ x = self.critic(x)
+ return {'q_value': x['logit'], 'distribution': x['distribution']}
diff --git a/DI-engine/ding/model/template/qmix.py b/DI-engine/ding/model/template/qmix.py
new file mode 100644
index 0000000000000000000000000000000000000000..68354e0cf7441c64638688b9ece17110c1524ab2
--- /dev/null
+++ b/DI-engine/ding/model/template/qmix.py
@@ -0,0 +1,219 @@
+from typing import Union, List
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import reduce
+from ding.utils import list_split, MODEL_REGISTRY
+from ding.torch_utils import fc_block, MLP
+from .q_learning import DRQN
+
+
+class Mixer(nn.Module):
+ """
+ Overview:
+        Mixer network in QMIX, which mixes the independent q_value of each agent into a total q_value. \
+        The weights (but not the biases) of the Mixer network are restricted to be non-negative and \
+        produced by separate hypernetworks. Each hypernetwork takes the global state s as input and generates \
+ the weights of one layer of the Mixer network.
+ Interface:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(
+ self,
+ agent_num: int,
+ state_dim: int,
+ mixing_embed_dim: int,
+ hypernet_embed: int = 64,
+ activation: nn.Module = nn.ReLU()
+ ):
+ """
+ Overview:
+ Initialize mixer network proposed in QMIX according to arguments. Each hypernetwork consists of \
+ linear layers, followed by an absolute activation function, to ensure that the Mixer network weights are \
+ non-negative.
+ Arguments:
+ - agent_num (:obj:`int`): The number of agent, such as 8.
+            - state_dim (:obj:`int`): The dimension of global observation state, such as 16.
+            - mixing_embed_dim (:obj:`int`): The dimension of mixing state embedding, such as 128.
+            - hypernet_embed (:obj:`int`): The dimension of hypernet embedding, default to 64.
+ - activation (:obj:`nn.Module`): Activation function in network, defaults to nn.ReLU().
+ """
+ super(Mixer, self).__init__()
+
+ self.n_agents = agent_num
+ self.state_dim = state_dim
+ self.embed_dim = mixing_embed_dim
+ self.act = activation
+ self.hyper_w_1 = nn.Sequential(
+ nn.Linear(self.state_dim, hypernet_embed), self.act,
+ nn.Linear(hypernet_embed, self.embed_dim * self.n_agents)
+ )
+ self.hyper_w_final = nn.Sequential(
+ nn.Linear(self.state_dim, hypernet_embed), self.act, nn.Linear(hypernet_embed, self.embed_dim)
+ )
+
+ # state dependent bias for hidden layer
+ self.hyper_b_1 = nn.Linear(self.state_dim, self.embed_dim)
+
+ # V(s) instead of a bias for the last layers
+ self.V = nn.Sequential(nn.Linear(self.state_dim, self.embed_dim), self.act, nn.Linear(self.embed_dim, 1))
+
+ def forward(self, agent_qs, states):
+ """
+ Overview:
+            Forward computation graph of the pymarl mixer network. Mix the independent q_value of each agent \
+            into a total q_value, with weights generated by the hypernetworks according to the global ``states``.
+ Arguments:
+ - agent_qs (:obj:`torch.FloatTensor`): The independent q_value of each agent.
+            - states (:obj:`torch.FloatTensor`): The embedding vector of the global state.
+ Returns:
+ - q_tot (:obj:`torch.FloatTensor`): The total mixed q_value.
+ Shapes:
+ - agent_qs (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is agent_num.
+ - states (:obj:`torch.FloatTensor`): :math:`(B, M)`, where M is embedding_size.
+ - q_tot (:obj:`torch.FloatTensor`): :math:`(B, )`.
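+        Examples:
+            >>> # A minimal sketch consistent with the shapes above; the concrete sizes are illustrative only.
+            >>> mixer = Mixer(agent_num=4, state_dim=16, mixing_embed_dim=32)
+            >>> agent_qs = torch.randn(8, 4)    # (B, N)
+            >>> states = torch.randn(8, 16)     # (B, M)
+            >>> q_tot = mixer(agent_qs, states)
+            >>> assert q_tot.shape == torch.Size([8])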
+ """
+ bs = agent_qs.shape[:-1]
+ states = states.reshape(-1, self.state_dim)
+ agent_qs = agent_qs.view(-1, 1, self.n_agents)
+ # First layer
+ w1 = torch.abs(self.hyper_w_1(states))
+ b1 = self.hyper_b_1(states)
+ w1 = w1.view(-1, self.n_agents, self.embed_dim)
+ b1 = b1.view(-1, 1, self.embed_dim)
+ hidden = F.elu(torch.bmm(agent_qs, w1) + b1)
+ # Second layer
+ w_final = torch.abs(self.hyper_w_final(states))
+ w_final = w_final.view(-1, self.embed_dim, 1)
+ # State-dependent bias
+ v = self.V(states).view(-1, 1, 1)
+ # Compute final output
+ y = torch.bmm(hidden, w_final) + v
+ # Reshape and return
+ q_tot = y.view(*bs)
+ return q_tot
+
+
+@MODEL_REGISTRY.register('qmix')
+class QMix(nn.Module):
+ """
+ Overview:
+ The neural network and computation graph of algorithms related to QMIX(https://arxiv.org/abs/1803.11485). \
+        The QMIX is composed of two parts: agent Q network and mixer (optional). The QMIX paper mentions that all \
+        agents share local Q network parameters, so only one Q network is initialized here. Then, depending on the \
+        ``mixer`` setting, either summation or the Mixer network processes the local Q values into the global Q.
+ Interface:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(
+ self,
+ agent_num: int,
+ obs_shape: int,
+ global_obs_shape: int,
+ action_shape: int,
+ hidden_size_list: list,
+ mixer: bool = True,
+ lstm_type: str = 'gru',
+ activation: nn.Module = nn.ReLU(),
+ dueling: bool = False
+ ) -> None:
+ """
+ Overview:
+ Initialize QMIX neural network according to arguments, i.e. agent Q network and mixer.
+ Arguments:
+ - agent_num (:obj:`int`): The number of agent, such as 8.
+ - obs_shape (:obj:`int`): The dimension of each agent's observation state, such as 8 or [4, 84, 84].
+ - global_obs_shape (:obj:`int`): The dimension of global observation state, such as 8 or [4, 84, 84].
+ - action_shape (:obj:`int`): The dimension of action shape, such as 6 or [2, 3, 3].
+ - hidden_size_list (:obj:`list`): The list of hidden size for ``q_network``, \
+ the last element must match mixer's ``mixing_embed_dim``.
+ - mixer (:obj:`bool`): Use mixer net or not, default to True. If it is false, \
+ the final local Q is added to obtain the global Q.
+ - lstm_type (:obj:`str`): The type of RNN module in ``q_network``, now support \
+ ['normal', 'pytorch', 'gru'], default to gru.
+            - activation (:obj:`nn.Module`): The type of activation function to use in ``MLP`` after each \
+                ``layer_fn``, if ``None`` then default set to ``nn.ReLU()``.
+            - dueling (:obj:`bool`): Whether to choose ``DuelingHead`` (True) or ``DiscreteHead`` (False), \
+ default to False.
+ """
+ super(QMix, self).__init__()
+ self._act = activation
+ self._q_network = DRQN(
+ obs_shape, action_shape, hidden_size_list, lstm_type=lstm_type, dueling=dueling, activation=activation
+ )
+ embedding_size = hidden_size_list[-1]
+ self.mixer = mixer
+ if self.mixer:
+ self._mixer = Mixer(agent_num, global_obs_shape, embedding_size, activation=activation)
+ self._global_state_encoder = nn.Identity()
+
+ def forward(self, data: dict, single_step: bool = True) -> dict:
+ """
+ Overview:
+ QMIX forward computation graph, input dict including time series observation and related data to predict \
+ total q_value and each agent q_value.
+ Arguments:
+ - data (:obj:`dict`): Input data dict with keys ['obs', 'prev_state', 'action'].
+            - agent_state (:obj:`torch.Tensor`): Time series local observation data of each agent.
+ - global_state (:obj:`torch.Tensor`): Time series global observation data.
+ - prev_state (:obj:`list`): Previous rnn state for ``q_network``.
+ - action (:obj:`torch.Tensor` or None): The actions of each agent given outside the function. \
+ If action is None, use argmax q_value index as action to calculate ``agent_q_act``.
+ - single_step (:obj:`bool`): Whether single_step forward, if so, add timestep dim before forward and\
+ remove it after forward.
+ Returns:
+ - ret (:obj:`dict`): Output data dict with keys [``total_q``, ``logit``, ``next_state``].
+ ReturnsKeys:
+ - total_q (:obj:`torch.Tensor`): Total q_value, which is the result of mixer network.
+ - agent_q (:obj:`torch.Tensor`): Each agent q_value.
+ - next_state (:obj:`list`): Next rnn state for ``q_network``.
+ Shapes:
+            - agent_state (:obj:`torch.Tensor`): :math:`(T, B, A, N)`, where T is timestep, B is batch_size, \
+                A is agent_num and N is obs_shape.
+            - global_state (:obj:`torch.Tensor`): :math:`(T, B, M)`, where M is global_obs_shape.
+            - prev_state (:obj:`list`): :math:`(B, A)`, a list of length B, and each element is a list of length A.
+            - action (:obj:`torch.Tensor`): :math:`(T, B, A)`.
+            - total_q (:obj:`torch.Tensor`): :math:`(T, B)`.
+            - agent_q (:obj:`torch.Tensor`): :math:`(T, B, A, P)`, where P is action_shape.
+            - next_state (:obj:`list`): :math:`(B, A)`, a list of length B, and each element is a list of length A.
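+        Examples:
+            >>> # A minimal sketch. The ``None`` initial rnn states and the nested dict layout follow the
+            >>> # shapes documented above and are assumptions for illustration only.
+            >>> T, B, A, N, M, P = 2, 3, 4, 8, 16, 6
+            >>> model = QMix(agent_num=A, obs_shape=N, global_obs_shape=M, action_shape=P, hidden_size_list=[32, 32])
+            >>> data = {
+            ...     'obs': {
+            ...         'agent_state': torch.randn(T, B, A, N),
+            ...         'global_state': torch.randn(T, B, M),
+            ...         'action_mask': torch.randint(0, 2, size=(T, B, A, P)),
+            ...     },
+            ...     'prev_state': [[None for _ in range(A)] for _ in range(B)],
+            ...     'action': torch.randint(0, P, size=(T, B, A)),
+            ... }
+            >>> output = model(data, single_step=False)
+            >>> assert output['total_q'].shape == (T, B)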
+ """
+ agent_state, global_state, prev_state = data['obs']['agent_state'], data['obs']['global_state'], data[
+ 'prev_state']
+ action = data.get('action', None)
+ if single_step:
+ agent_state, global_state = agent_state.unsqueeze(0), global_state.unsqueeze(0)
+ T, B, A = agent_state.shape[:3]
+ assert len(prev_state) == B and all(
+ [len(p) == A for p in prev_state]
+ ), '{}-{}-{}-{}'.format([type(p) for p in prev_state], B, A, len(prev_state[0]))
+ prev_state = reduce(lambda x, y: x + y, prev_state)
+ agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
+ output = self._q_network({'obs': agent_state, 'prev_state': prev_state, 'enable_fast_timestep': True})
+ agent_q, next_state = output['logit'], output['next_state']
+ next_state, _ = list_split(next_state, step=A)
+ agent_q = agent_q.reshape(T, B, A, -1)
+ if action is None:
+ # for target forward process
+ if len(data['obs']['action_mask'].shape) == 3:
+ action_mask = data['obs']['action_mask'].unsqueeze(0)
+ else:
+ action_mask = data['obs']['action_mask']
+ agent_q[action_mask == 0.0] = -9999999
+ action = agent_q.argmax(dim=-1)
+ agent_q_act = torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1))
+ agent_q_act = agent_q_act.squeeze(-1) # T, B, A
+ if self.mixer:
+ global_state_embedding = self._global_state_encoder(global_state)
+ total_q = self._mixer(agent_q_act, global_state_embedding)
+ else:
+ total_q = agent_q_act.sum(-1)
+ if single_step:
+ total_q, agent_q = total_q.squeeze(0), agent_q.squeeze(0)
+ return {
+ 'total_q': total_q,
+ 'logit': agent_q,
+ 'next_state': next_state,
+ 'action_mask': data['obs']['action_mask']
+ }
diff --git a/DI-engine/ding/model/template/qtran.py b/DI-engine/ding/model/template/qtran.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e627f1d15721b7238feab2bf8ba191bb3210091
--- /dev/null
+++ b/DI-engine/ding/model/template/qtran.py
@@ -0,0 +1,143 @@
+from typing import Union, List
+import torch
+import numpy as np
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import reduce
+from ding.utils import list_split, squeeze, MODEL_REGISTRY
+from ding.torch_utils.network.nn_module import fc_block, MLP
+from ding.torch_utils.network.transformer import ScaledDotProductAttention
+from ding.torch_utils import to_tensor, tensor_to_list
+from .q_learning import DRQN
+
+
+@MODEL_REGISTRY.register('qtran')
+class QTran(nn.Module):
+ """
+ Overview:
+ QTRAN network
+ Interface:
+ __init__, forward
+ """
+
+ def __init__(
+ self,
+ agent_num: int,
+ obs_shape: int,
+ global_obs_shape: int,
+ action_shape: int,
+ hidden_size_list: list,
+ embedding_size: int,
+ lstm_type: str = 'gru',
+ dueling: bool = False
+ ) -> None:
+ """
+ Overview:
+ initialize QTRAN network
+ Arguments:
+ - agent_num (:obj:`int`): the number of agent
+ - obs_shape (:obj:`int`): the dimension of each agent's observation state
+ - global_obs_shape (:obj:`int`): the dimension of global observation state
+ - action_shape (:obj:`int`): the dimension of action shape
+ - hidden_size_list (:obj:`list`): the list of hidden size
+ - embedding_size (:obj:`int`): the dimension of embedding
+ - lstm_type (:obj:`str`): use lstm or gru, default to gru
+ - dueling (:obj:`bool`): use dueling head or not, default to False.
+ """
+ super(QTran, self).__init__()
+ self._act = nn.ReLU()
+ self._q_network = DRQN(obs_shape, action_shape, hidden_size_list, lstm_type=lstm_type, dueling=dueling)
+ q_input_size = global_obs_shape + hidden_size_list[-1] + action_shape
+ self.Q = nn.Sequential(
+ nn.Linear(q_input_size, embedding_size), nn.ReLU(), nn.Linear(embedding_size, embedding_size), nn.ReLU(),
+ nn.Linear(embedding_size, 1)
+ )
+
+ # V(s)
+ self.V = nn.Sequential(
+ nn.Linear(global_obs_shape, embedding_size), nn.ReLU(), nn.Linear(embedding_size, embedding_size),
+ nn.ReLU(), nn.Linear(embedding_size, 1)
+ )
+ ae_input = hidden_size_list[-1] + action_shape
+ self.action_encoding = nn.Sequential(nn.Linear(ae_input, ae_input), nn.ReLU(), nn.Linear(ae_input, ae_input))
+
+ def forward(self, data: dict, single_step: bool = True) -> dict:
+ """
+ Overview:
+ forward computation graph of qtran network
+ Arguments:
+ - data (:obj:`dict`): input data dict with keys ['obs', 'prev_state', 'action']
+            - agent_state (:obj:`torch.Tensor`): each agent local state (obs)
+            - global_state (:obj:`torch.Tensor`): global state (obs)
+ - prev_state (:obj:`list`): previous rnn state
+ - action (:obj:`torch.Tensor` or None): if action is None, use argmax q_value index as action to\
+ calculate ``agent_q_act``
+ - single_step (:obj:`bool`): whether single_step forward, if so, add timestep dim before forward and\
+ remove it after forward
+ Return:
+ - ret (:obj:`dict`): output data dict with keys ['total_q', 'logit', 'next_state']
+ - total_q (:obj:`torch.Tensor`): total q_value, which is the result of mixer network
+ - agent_q (:obj:`torch.Tensor`): each agent q_value
+ - next_state (:obj:`list`): next rnn state
+ Shapes:
+            - agent_state (:obj:`torch.Tensor`): :math:`(T, B, A, N)`, where T is timestep, B is batch_size, \
+                A is agent_num, N is obs_shape
+            - global_state (:obj:`torch.Tensor`): :math:`(T, B, M)`, where M is global_obs_shape
+            - prev_state (:obj:`list`): :math:`(B, A)`, a list of length B, and each element is a list of length A
+            - action (:obj:`torch.Tensor`): :math:`(T, B, A)`
+            - total_q (:obj:`torch.Tensor`): :math:`(T, B)`
+            - agent_q (:obj:`torch.Tensor`): :math:`(T, B, A, P)`, where P is action_shape
+            - next_state (:obj:`list`): :math:`(B, A)`, a list of length B, and each element is a list of length A
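+        Examples:
+            >>> # a minimal sketch; the ``None`` initial rnn states and the nested dict layout follow the
+            >>> # shapes documented above and are assumptions for illustration only
+            >>> T, B, A, N, M, P = 2, 3, 4, 8, 16, 6
+            >>> model = QTran(A, N, M, P, hidden_size_list=[32, 32], embedding_size=32)
+            >>> data = {
+            ...     'obs': {
+            ...         'agent_state': torch.randn(T, B, A, N),
+            ...         'global_state': torch.randn(T, B, M),
+            ...         'action_mask': torch.randint(0, 2, size=(T, B, A, P)),
+            ...     },
+            ...     'prev_state': [[None for _ in range(A)] for _ in range(B)],
+            ...     'action': torch.randint(0, P, size=(T, B, A)),
+            ... }
+            >>> output = model(data, single_step=False)
+            >>> assert output['total_q'].shape == (T, B) and output['vs'].shape == (T, B)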
+ """
+ agent_state, global_state, prev_state = data['obs']['agent_state'], data['obs']['global_state'], data[
+ 'prev_state']
+ action = data.get('action', None)
+ if single_step:
+ agent_state, global_state = agent_state.unsqueeze(0), global_state.unsqueeze(0)
+ T, B, A = agent_state.shape[:3]
+ assert len(prev_state) == B and all(
+ [len(p) == A for p in prev_state]
+ ), '{}-{}-{}-{}'.format([type(p) for p in prev_state], B, A, len(prev_state[0]))
+ prev_state = reduce(lambda x, y: x + y, prev_state)
+ agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
+ output = self._q_network({'obs': agent_state, 'prev_state': prev_state, 'enable_fast_timestep': True})
+ agent_q, next_state = output['logit'], output['next_state']
+ next_state, _ = list_split(next_state, step=A)
+ agent_q = agent_q.reshape(T, B, A, -1)
+ if action is None:
+ # For target forward process
+ if len(data['obs']['action_mask'].shape) == 3:
+ action_mask = data['obs']['action_mask'].unsqueeze(0)
+ else:
+ action_mask = data['obs']['action_mask']
+ agent_q[action_mask == 0.0] = -9999999
+ action = agent_q.argmax(dim=-1)
+ agent_q_act = torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1))
+ agent_q_act = agent_q_act.squeeze(-1) # T, B, A
+
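+ # QTRAN joint-action encoding: concatenate each agent's RNN hidden state with its one-hot action,
+ # encode the pair, then sum the encodings over agents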
+ hidden_states = output['hidden_state'].reshape(T * B, A, -1)
+ action = action.reshape(T * B, A).unsqueeze(-1)
+ action_onehot = torch.zeros(size=(T * B, A, agent_q.shape[-1]), device=action.device)
+ action_onehot = action_onehot.scatter(2, action, 1)
+ agent_state_action_input = torch.cat([hidden_states, action_onehot], dim=2)
+ agent_state_action_encoding = self.action_encoding(agent_state_action_input.reshape(T * B * A,
+ -1)).reshape(T * B, A, -1)
+ agent_state_action_encoding = agent_state_action_encoding.sum(dim=1) # Sum across agents
+
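+ # the joint Q conditions on the global state plus the summed action encoding, while V only uses the global state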
+ inputs = torch.cat([global_state.reshape(T * B, -1), agent_state_action_encoding], dim=1)
+ q_outputs = self.Q(inputs)
+ q_outputs = q_outputs.reshape(T, B)
+ v_outputs = self.V(global_state.reshape(T * B, -1))
+ v_outputs = v_outputs.reshape(T, B)
+ if single_step:
+ q_outputs, agent_q, agent_q_act, v_outputs = q_outputs.squeeze(0), agent_q.squeeze(0), agent_q_act.squeeze(
+ 0
+ ), v_outputs.squeeze(0)
+ return {
+ 'total_q': q_outputs,
+ 'logit': agent_q,
+ 'agent_q_act': agent_q_act,
+ 'vs': v_outputs,
+ 'next_state': next_state,
+ 'action_mask': data['obs']['action_mask']
+ }
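+# A minimal usage sketch (illustrative only, not part of this file in the diff); shapes follow the
+# ``forward`` docstring above, with single_step=False:
+#
+#     T, B, A, N, M, P = 8, 6, 4, 32, 128, 9
+#     model = QTran(A, N, M, P, hidden_size_list=[128, 64], embedding_size=64)
+#     data = {
+#         'obs': {
+#             'agent_state': torch.randn(T, B, A, N),
+#             'global_state': torch.randn(T, B, M),
+#             'action_mask': torch.randint(0, 2, size=(T, B, A, P)),
+#         },
+#         'prev_state': [[None for _ in range(A)] for _ in range(B)],
+#         'action': torch.randint(0, P, size=(T, B, A)),
+#     }
+#     out = model(data, single_step=False)
+#     assert out['total_q'].shape == (T, B) and out['logit'].shape == (T, B, A, P)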
diff --git a/DI-engine/ding/model/template/sqn.py b/DI-engine/ding/model/template/sqn.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d71850a5f9076e97eaf799821c7b4e6551f0b01
--- /dev/null
+++ b/DI-engine/ding/model/template/sqn.py
@@ -0,0 +1,23 @@
+from typing import Dict
+import torch
+import torch.nn as nn
+
+from ding.utils import MODEL_REGISTRY
+from .q_learning import DQN
+
+
+@MODEL_REGISTRY.register('sqn')
+class SQN(nn.Module):
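+ """
+ Overview:
+ Twin Q-network wrapper used by SQN: two independent DQN sub-networks whose logits are both\
+ returned as ``q_value``, with the first one also exposed as ``logit``.
+ """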
+
+ def __init__(self, *args, **kwargs) -> None:
+ super(SQN, self).__init__()
+ self.q0 = DQN(*args, **kwargs)
+ self.q1 = DQN(*args, **kwargs)
+
+ def forward(self, data: torch.Tensor) -> Dict:
+ output0 = self.q0(data)
+ output1 = self.q1(data)
+ return {
+ 'q_value': [output0['logit'], output1['logit']],
+ 'logit': output0['logit'],
+ }
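+# A minimal usage sketch (illustrative only, not part of this file in the diff); SQN accepts the same
+# constructor arguments as DQN, e.g.:
+#
+#     model = SQN(obs_shape=8, action_shape=4)
+#     out = model(torch.randn(2, 8))
+#     assert len(out['q_value']) == 2 and out['logit'].shape == (2, 4)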
diff --git a/DI-engine/ding/model/template/tests/test_acer.py b/DI-engine/ding/model/template/tests/test_acer.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c3877335a7e00cea58b00928130cb661a508870
--- /dev/null
+++ b/DI-engine/ding/model/template/tests/test_acer.py
@@ -0,0 +1,41 @@
+import torch
+import pytest
+from itertools import product
+
+from ding.model.template import ACER
+from ding.torch_utils import is_differentiable
+
+B = 4
+obs_shape = [4, (8, ), (4, 64, 64)]
+act_shape = [3, (6, )]
+args = list(product(*[obs_shape, act_shape]))
+
+
+@pytest.mark.unittest
+class TestACER:
+
+ @pytest.mark.parametrize('obs_shape, act_shape', args)
+ def test_ACER(self, obs_shape, act_shape):
+ if isinstance(obs_shape, int):
+ inputs = torch.randn(B, obs_shape)
+ else:
+ inputs = torch.randn(B, *obs_shape)
+ model = ACER(obs_shape, act_shape)
+
+ outputs_c = model(inputs, mode='compute_critic')
+ assert isinstance(outputs_c, dict)
+ if isinstance(act_shape, int):
+ assert outputs_c['q_value'].shape == (B, act_shape)
+ elif len(act_shape) == 1:
+ assert outputs_c['q_value'].shape == (B, *act_shape)
+
+ outputs_a = model(inputs, mode='compute_actor')
+ assert isinstance(outputs_a, dict)
+ if isinstance(act_shape, int):
+ assert outputs_a['logit'].shape == (B, act_shape)
+ elif len(act_shape) == 1:
+ assert outputs_a['logit'].shape == (B, *act_shape)
+
+ outputs = {**outputs_a, **outputs_c}
+ loss = sum([v.sum() for v in outputs.values()])
+ is_differentiable(loss, model)
diff --git a/DI-engine/ding/model/template/tests/test_atoc.py b/DI-engine/ding/model/template/tests/test_atoc.py
new file mode 100644
index 0000000000000000000000000000000000000000..45a503b3f28b90daaa8ee4c6234d424937f8958a
--- /dev/null
+++ b/DI-engine/ding/model/template/tests/test_atoc.py
@@ -0,0 +1,64 @@
+import pytest
+import torch
+from ding.model.template.atoc import ATOCActorNet, ATOC
+from ding.torch_utils import is_differentiable
+
+
+@pytest.mark.unittest
+class TestATOC:
+
+ @pytest.mark.tmp
+ def test_actor_net(self):
+ B, A, obs_dim, act_dim, thought_dim = 6, 5, 12, 6, 14
+ torch.autograd.set_detect_anomaly(True)
+ model = ATOCActorNet(obs_dim, thought_dim, act_dim, A, True, 2, initiator_threshold=0.001)
+ for i in range(10):
+ out = model.forward(torch.randn(B, A, obs_dim))
+ assert out['action'].shape == (B, A, act_dim)
+ assert out['group'].shape == (B, A, A)
+ loss1 = out['action'].sum()
+ if i == 0:
+ is_differentiable(loss1, model, print_instead=True)
+ else:
+ loss1.backward()
+
+ def test_qac_net(self):
+ B, A, obs_dim, act_dim, thought_dim = 6, 5, 12, 6, 14
+ model = ATOC(obs_dim, act_dim, thought_dim, A, True, 2, 2)
+
+ # test basic forward path
+
+ optimize_critic = torch.optim.SGD(model.critic.parameters(), 0.1)
+ obs = torch.randn(B, A, obs_dim)
+ act = torch.rand(B, A, act_dim)
+ out = model({'obs': obs, 'action': act}, mode='compute_critic')
+ assert out['q_value'].shape == (B, A)
+ q_loss = out['q_value'].sum()
+ q_loss.backward()
+ optimize_critic.step()
+
+ out = model(obs, mode='compute_actor', get_delta_q=True)
+ assert out['delta_q'].shape == (B, A)
+ assert out['initiator_prob'].shape == (B, A)
+ assert out['is_initiator'].shape == (B, A)
+ optimizer_act = torch.optim.SGD(model.actor.parameters(), 0.1)
+ optimizer_att = torch.optim.SGD(model.actor.attention.parameters(), 0.1)
+
+ obs = torch.randn(B, A, obs_dim)
+ delta_q = model(obs, mode='compute_actor', get_delta_q=True)
+ attention_loss = model(delta_q, mode='optimize_actor_attention')
+ optimizer_att.zero_grad()
+ loss = attention_loss['loss']
+ loss.backward()
+ optimizer_att.step()
+
+ weights = dict(model.actor.named_parameters())
+ output = model(obs, mode='compute_actor')
+ output['obs'] = obs
+ q_loss = model(output, mode='compute_critic')
+ loss = q_loss['q_value'].sum()
+ before_update_weights = model.actor.named_parameters()
+ optimizer_act.zero_grad()
+
+ loss.backward()
+ optimizer_act.step()
diff --git a/DI-engine/ding/model/template/tests/test_bc.py b/DI-engine/ding/model/template/tests/test_bc.py
new file mode 100644
index 0000000000000000000000000000000000000000..17a2075c671dabd956857612db2970b0ac9e5bf6
--- /dev/null
+++ b/DI-engine/ding/model/template/tests/test_bc.py
@@ -0,0 +1,83 @@
+import torch
+import numpy as np
+import pytest
+from itertools import product
+
+from ding.model.template import DiscreteBC, ContinuousBC
+from ding.torch_utils import is_differentiable
+from ding.utils import squeeze
+
+B = 4
+T = 6
+embedding_size = 32
+action_shape_args = [(6, ), [1]]
+args = list(product(*[action_shape_args, ['regression', 'reparameterization']]))
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('action_shape, action_space', args)
+class TestContinuousBC:
+
+ def test_continuous_bc(self, action_shape, action_space):
+ N = 32
+ inputs = {'obs': torch.randn(B, N), 'action': torch.randn(B, squeeze(action_shape))}
+ model = ContinuousBC(
+ obs_shape=(N, ),
+ action_shape=action_shape,
+ action_space=action_space,
+ actor_head_hidden_size=embedding_size,
+ )
+ # compute_action
+ print(model)
+ if action_space == 'regression':
+ action = model(inputs['obs'])['action']
+ if squeeze(action_shape) == 1:
+ assert action.shape == (B, )
+ else:
+ assert action.shape == (B, squeeze(action_shape))
+ assert action.eq(action.clamp(-1, 1)).all()
+ is_differentiable(action.sum(), model.actor)
+ elif action_space == 'reparameterization':
+ (mu, sigma) = model(inputs['obs'])['logit']
+ assert mu.shape == (B, *action_shape)
+ assert sigma.shape == (B, *action_shape)
+ is_differentiable(mu.sum() + sigma.sum(), model.actor)
+
+
+T, B = 3, 4
+obs_shape = [4, (8, ), (4, 64, 64)]
+act_shape = [3, (6, ), [2, 3, 6]]
+args = list(product(*[obs_shape, act_shape]))
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('obs_shape, act_shape', args)
+class TestDiscreteBC:
+
+ def output_check(self, model, outputs):
+ if isinstance(outputs, torch.Tensor):
+ loss = outputs.sum()
+ elif isinstance(outputs, list):
+ loss = sum([t.sum() for t in outputs])
+ elif isinstance(outputs, dict):
+ loss = sum([v.sum() for v in outputs.values()])
+ is_differentiable(loss, model)
+
+ def test_discrete_bc(self, obs_shape, act_shape):
+ if isinstance(obs_shape, int):
+ inputs = torch.randn(B, obs_shape)
+ else:
+ inputs = torch.randn(B, *obs_shape)
+ model = DiscreteBC(obs_shape, act_shape)
+ outputs = model(inputs)
+ assert isinstance(outputs, dict)
+ if isinstance(act_shape, int):
+ assert outputs['logit'].shape == (B, act_shape)
+ elif len(act_shape) == 1:
+ assert outputs['logit'].shape == (B, *act_shape)
+ else:
+ for i, s in enumerate(act_shape):
+ assert outputs['logit'][i].shape == (B, s)
+ self.output_check(model, outputs['logit'])
diff --git a/DI-engine/ding/model/template/tests/test_bcq.py b/DI-engine/ding/model/template/tests/test_bcq.py
new file mode 100644
index 0000000000000000000000000000000000000000..101cfd9b9cdb95f66eaf4a4b072042911351c395
--- /dev/null
+++ b/DI-engine/ding/model/template/tests/test_bcq.py
@@ -0,0 +1,75 @@
+import pytest
+from itertools import product
+import torch
+from ding.model.template import BCQ
+from ding.torch_utils import is_differentiable
+
+B = 4
+obs_shape = [4, (8, )]
+act_shape = [3, (6, )]
+args = list(product(*[obs_shape, act_shape]))
+
+
+@pytest.mark.unittest
+class TestBCQ:
+
+ def output_check(self, model, outputs):
+ if isinstance(outputs, torch.Tensor):
+ loss = outputs.sum()
+ elif isinstance(outputs, dict):
+ loss = sum([v.sum() for v in outputs.values()])
+ is_differentiable(loss, model)
+
+ @pytest.mark.parametrize('obs_shape, act_shape', args)
+ def test_BCQ(self, obs_shape, act_shape):
+ if isinstance(obs_shape, int):
+ inputs_obs = torch.randn(B, obs_shape)
+ else:
+ inputs_obs = torch.randn(B, *obs_shape)
+ if isinstance(act_shape, int):
+ inputs_act = torch.randn(B, act_shape)
+ else:
+ inputs_act = torch.randn(B, *act_shape)
+ inputs = {'obs': inputs_obs, 'action': inputs_act}
+ model = BCQ(obs_shape, act_shape)
+
+ outputs_c = model(inputs, mode='compute_critic')
+ assert isinstance(outputs_c, dict)
+ assert torch.stack(outputs_c['q_value']).shape == (2, B)
+ self.output_check(model.critic, torch.stack(outputs_c['q_value']))
+
+ outputs_a = model(inputs, mode='compute_actor')
+ assert isinstance(outputs_a, dict)
+ if isinstance(act_shape, int):
+ assert outputs_a['action'].shape == (B, act_shape)
+ elif len(act_shape) == 1:
+ assert outputs_a['action'].shape == (B, *act_shape)
+ self.output_check(model.actor, outputs_a)
+
+ outputs_vae = model(inputs, mode='compute_vae')
+ assert isinstance(outputs_vae, dict)
+ if isinstance(act_shape, int):
+ assert outputs_vae['recons_action'].shape == (B, act_shape)
+ assert outputs_vae['mu'].shape == (B, act_shape * 2)
+ assert outputs_vae['log_var'].shape == (B, act_shape * 2)
+ assert outputs_vae['z'].shape == (B, act_shape * 2)
+ elif len(act_shape) == 1:
+ assert outputs_vae['recons_action'].shape == (B, *act_shape)
+ assert outputs_vae['mu'].shape == (B, act_shape[0] * 2)
+ assert outputs_vae['log_var'].shape == (B, act_shape[0] * 2)
+ assert outputs_vae['z'].shape == (B, act_shape[0] * 2)
+ if isinstance(obs_shape, int):
+ assert outputs_vae['prediction_residual'].shape == (B, obs_shape)
+ else:
+ assert outputs_vae['prediction_residual'].shape == (B, *obs_shape)
+
+ outputs_eval = model(inputs, mode='compute_eval')
+ assert isinstance(outputs_eval, dict)
+ if isinstance(act_shape, int):
+ assert outputs_eval['action'].shape == (B, act_shape)
+ elif len(act_shape) == 1:
+ assert outputs_eval['action'].shape == (B, *act_shape)
diff --git a/DI-engine/ding/model/template/tests/test_collaq.py b/DI-engine/ding/model/template/tests/test_collaq.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf2969655a5d1b612e6c77226fc68254dc3081b2
--- /dev/null
+++ b/DI-engine/ding/model/template/tests/test_collaq.py
@@ -0,0 +1,53 @@
+import pytest
+import torch
+from ding.torch_utils import is_differentiable
+from ding.model.template import CollaQ
+
+
+@pytest.mark.unittest
+def test_collaQ():
+ use_mixer = [True, False]
+ agent_num, bs, T = 4, 6, 8
+ obs_dim, obs_alone_dim, global_obs_dim, action_dim = 32, 24, 32 * 4, 9
+ self_feature_range = [8, 10]
+ allay_feature_range = [10, 16]
+ embedding_dim = 64
+ for mix in use_mixer:
+ collaQ_model = CollaQ(
+ agent_num,
+ obs_dim,
+ obs_alone_dim,
+ global_obs_dim,
+ action_dim, [128, embedding_dim],
+ True,
+ self_feature_range,
+ allay_feature_range,
+ 32,
+ mix,
+ activation=torch.nn.Tanh()
+ )
+ print(collaQ_model)
+ data = {
+ 'obs': {
+ 'agent_state': torch.randn(T, bs, agent_num, obs_dim),
+ 'agent_alone_state': torch.randn(T, bs, agent_num, obs_alone_dim),
+ 'agent_alone_padding_state': torch.randn(T, bs, agent_num, obs_dim),
+ 'global_state': torch.randn(T, bs, global_obs_dim),
+ 'action_mask': torch.randint(0, 2, size=(T, bs, agent_num, action_dim))
+ },
+ 'prev_state': [[[None for _ in range(agent_num)] for _ in range(3)] for _ in range(bs)],
+ 'action': torch.randint(0, action_dim, size=(T, bs, agent_num))
+ }
+ output = collaQ_model(data, single_step=False)
+ assert set(output.keys()) == set(['total_q', 'logit', 'next_state', 'action_mask', 'agent_colla_alone_q'])
+ assert output['total_q'].shape == (T, bs)
+ assert output['logit'].shape == (T, bs, agent_num, action_dim)
+ assert len(output['next_state']) == bs and all([len(n) == 3 for n in output['next_state']]) and all(
+ [len(q) == agent_num for n in output['next_state'] for q in n]
+ )
+ print(output['next_state'][0][0][0]['h'].shape)
+ # data['prev_state'] = output['next_state']
+ loss = output['total_q'].sum()
+ is_differentiable(loss, collaQ_model)
+ data.pop('action')
+ output = collaQ_model(data, single_step=False)
diff --git a/DI-engine/ding/model/template/tests/test_coma_nn.py b/DI-engine/ding/model/template/tests/test_coma_nn.py
new file mode 100644
index 0000000000000000000000000000000000000000..c747da9aca716c73886cc705908ca85261a2a550
--- /dev/null
+++ b/DI-engine/ding/model/template/tests/test_coma_nn.py
@@ -0,0 +1,43 @@
+import pytest
+import torch
+from ding.torch_utils import is_differentiable
+from ding.model.template.coma import COMACriticNetwork, COMAActorNetwork
+
+
+@pytest.mark.unittest
+def test_coma_critic():
+ agent_num, bs, T = 4, 3, 8
+ obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
+ coma_model = COMACriticNetwork(obs_dim - action_dim + global_obs_dim + 2 * action_dim * agent_num, action_dim)
+ data = {
+ 'obs': {
+ 'agent_state': torch.randn(T, bs, agent_num, obs_dim),
+ 'global_state': torch.randn(T, bs, global_obs_dim),
+ },
+ 'action': torch.randint(0, action_dim, size=(T, bs, agent_num)),
+ }
+ output = coma_model(data)
+ assert set(output.keys()) == set(['q_value'])
+ assert output['q_value'].shape == (T, bs, agent_num, action_dim)
+ loss = output['q_value'].sum()
+ is_differentiable(loss, coma_model)
+
+
+@pytest.mark.unittest
+def test_rnn_actor_net():
+ T, B, A, N = 4, 8, 3, 32
+ embedding_dim = 64
+ action_dim = 6
+ data = torch.randn(T, B, A, N)
+ model = COMAActorNetwork((N, ), action_dim, [128, embedding_dim])
+ prev_state = [[None for _ in range(A)] for _ in range(B)]
+ for t in range(T):
+ inputs = {'obs': {'agent_state': data[t], 'action_mask': None}, 'prev_state': prev_state}
+ outputs = model(inputs)
+ logit, prev_state = outputs['logit'], outputs['next_state']
+ assert len(prev_state) == B
+ assert all([len(o) == A and all([len(o1) == 2 for o1 in o]) for o in prev_state])
+ assert logit.shape == (B, A, action_dim)
+ # test the last step can backward correctly
+ loss = logit.sum()
+ is_differentiable(loss, model)
diff --git a/DI-engine/ding/model/template/tests/test_decision_transformer.py b/DI-engine/ding/model/template/tests/test_decision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..71f52da4a93d49573f80a79a821819efaeecd918
--- /dev/null
+++ b/DI-engine/ding/model/template/tests/test_decision_transformer.py
@@ -0,0 +1,103 @@
+import pytest
+from itertools import product
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ding.model.template import DecisionTransformer
+from ding.torch_utils import is_differentiable
+
+action_space = ['continuous', 'discrete']
+state_encoder = [None, nn.Sequential(nn.Flatten(), nn.Linear(8, 8), nn.Tanh())]
+args = list(product(*[action_space, state_encoder]))
+args.pop(1)
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('action_space, state_encoder', args)
+def test_decision_transformer(action_space, state_encoder):
+ B, T = 4, 6
+ if state_encoder:
+ state_dim = (2, 2, 2)
+ else:
+ state_dim = 3
+ act_dim = 2
+ DT_model = DecisionTransformer(
+ state_dim=state_dim,
+ act_dim=act_dim,
+ state_encoder=state_encoder,
+ n_blocks=3,
+ h_dim=8,
+ context_len=T,
+ n_heads=2,
+ drop_p=0.1,
+ continuous=(action_space == 'continuous')
+ )
+ DT_model.configure_optimizers(1.0, 0.0003)
+
+ is_continuous = True if action_space == 'continuous' else False
+ if state_encoder:
+ timesteps = torch.randint(0, 100, [B, 3 * T - 1, 1], dtype=torch.long) # B x T
+ else:
+ timesteps = torch.randint(0, 100, [B, T], dtype=torch.long) # B x T
+ if isinstance(state_dim, int):
+ states = torch.randn([B, T, state_dim]) # B x T x state_dim
+ else:
+ states = torch.randn([B, T, *state_dim]) # B x T x state_dim
+ if action_space == 'continuous':
+ actions = torch.randn([B, T, act_dim]) # B x T x act_dim
+ action_target = torch.randn([B, T, act_dim])
+ else:
+ actions = torch.randint(0, act_dim, [B, T, 1])
+ action_target = torch.randint(0, act_dim, [B, T, 1])
+ returns_to_go_sample = torch.tensor([1, 0.8, 0.6, 0.4, 0.2, 0.])
+ returns_to_go = returns_to_go_sample.repeat([B, 1]).unsqueeze(-1) # B x T x 1
+
+ # all ones since no padding
+ traj_mask = torch.ones([B, T], dtype=torch.long) # B x T
+
+ if is_continuous:
+ assert action_target.shape == (B, T, act_dim)
+ else:
+ assert action_target.shape == (B, T, 1)
+ actions = actions.squeeze(-1)
+
+ returns_to_go = returns_to_go.float()
+ state_preds, action_preds, return_preds = DT_model.forward(
+ timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go
+ )
+ if state_encoder:
+ assert state_preds is None
+ assert return_preds is None
+ else:
+ assert state_preds.shape == (B, T, state_dim)
+ assert return_preds.shape == (B, T, 1)
+ assert action_preds.shape == (B, T, act_dim)
+
+ # only consider non padded elements
+ if state_encoder:
+ action_preds = action_preds.reshape(-1, act_dim)
+ else:
+ action_preds = action_preds.view(-1, act_dim)[traj_mask.view(-1, ) > 0]
+
+ if is_continuous:
+ action_target = action_target.view(-1, act_dim)[traj_mask.view(-1, ) > 0]
+ else:
+ action_target = action_target.view(-1)[traj_mask.view(-1, ) > 0]
+
+ if is_continuous:
+ action_loss = F.mse_loss(action_preds, action_target)
+ else:
+ action_loss = F.cross_entropy(action_preds, action_target)
+
+ if state_encoder:
+ is_differentiable(
+ action_loss, [DT_model.transformer, DT_model.embed_action, DT_model.embed_rtg, DT_model.state_encoder]
+ )
+ else:
+ is_differentiable(
+ action_loss, [
+ DT_model.transformer, DT_model.embed_action, DT_model.predict_action, DT_model.embed_rtg,
+ DT_model.embed_state
+ ]
+ )
diff --git a/DI-engine/ding/model/template/tests/test_ebm.py b/DI-engine/ding/model/template/tests/test_ebm.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba5faaea60fc196b58dd87665947162df0dda208
--- /dev/null
+++ b/DI-engine/ding/model/template/tests/test_ebm.py
@@ -0,0 +1,116 @@
+import pytest
+
+import torch
+import numpy as np
+from ding.model.template.ebm import EBM, AutoregressiveEBM
+from ding.model.template.ebm import DFO, AutoRegressiveDFO, MCMC
+
+# batch, negative_samples, obs_shape, action_shape
+B, N, O, A = 32, 1024, 11, 3
+
+
+@pytest.mark.unittest
+class TestEBM:
+
+ def test_forward(self):
+ obs = torch.randn(B, N, O)
+ action = torch.randn(B, N, A)
+ ebm = EBM(O, A)
+ energy = ebm(obs, action)
+ assert energy.shape == (B, N)
+
+
+@pytest.mark.unittest
+class TestDFO:
+ opt = DFO(train_samples=N, inference_samples=N)
+ opt.set_action_bounds(np.stack([np.zeros(A), np.ones(A)], axis=0))
+ ebm = EBM(O, A)
+
+ def test_sample(self):
+ obs = torch.randn(B, O)
+ tiled_obs, action_samples = self.opt.sample(obs, self.ebm)
+ assert tiled_obs.shape == (B, N, O)
+ assert action_samples.shape == (B, N, A)
+
+ def test_infer(self):
+ obs = torch.randn(B, O)
+ action = self.opt.infer(obs, self.ebm)
+ assert action.shape == (B, A)
+
+
+@pytest.mark.unittest
+class TestAutoregressiveEBM:
+
+ def test_forward(self):
+ obs = torch.randn(B, N, O)
+ action = torch.randn(B, N, A)
+ arebm = AutoregressiveEBM(O, A)
+ energy = arebm(obs, action)
+ assert energy.shape == (B, N, A)
+
+
+@pytest.mark.unittest
+class TestAutoregressiveDFO:
+ opt = AutoRegressiveDFO(train_samples=N, inference_samples=N)
+ opt.set_action_bounds(np.stack([np.zeros(A), np.ones(A)], axis=0))
+ ebm = AutoregressiveEBM(O, A)
+
+ def test_sample(self):
+ obs = torch.randn(B, O)
+ tiled_obs, action_samples = self.opt.sample(obs, self.ebm)
+ assert tiled_obs.shape == (B, N, O)
+ assert action_samples.shape == (B, N, A)
+
+ def test_infer(self):
+ obs = torch.randn(B, O)
+ action = self.opt.infer(obs, self.ebm)
+ assert action.shape == (B, A)
+
+
+@pytest.mark.unittest
+class TestMCMC:
+ opt = MCMC(iters=3, train_samples=N, inference_samples=N)
+ opt.set_action_bounds(np.stack([np.zeros(A), np.ones(A)], axis=0))
+ obs = torch.randn(B, N, O)
+ action = torch.randn(B, N, A)
+ ebm = EBM(O, A)
+
+ def test_gradient_wrt_act(self):
+ ebm = EBM(O, A)
+ # inference mode
+ de_dact = MCMC._gradient_wrt_act(self.obs, self.action, ebm)
+ assert de_dact.shape == (B, N, A)
+ # train mode
+ de_dact = MCMC._gradient_wrt_act(self.obs, self.action, ebm, create_graph=True)
+ loss = de_dact.pow(2).sum()
+ loss.backward()
+ assert de_dact.shape == (B, N, A)
+ assert ebm.net[0].weight.grad is not None
+
+ def test_langevin_step(self):
+ stepsize = 1
+ action = self.opt._langevin_step(self.obs, self.action, stepsize, self.ebm)
+ assert action.shape == (B, N, A)
+ # TODO: new action should have lower energy
+
+ def test_langevin_action_given_obs(self):
+ action = self.opt._langevin_action_given_obs(self.obs, self.action, self.ebm)
+ assert action.shape == (B, N, A)
+
+ def test_grad_penalty(self):
+ ebm = EBM(O, A)
+ self.opt.add_grad_penalty = True
+ loss = self.opt.grad_penalty(self.obs, self.action, ebm)
+ loss.backward()
+ assert ebm.net[0].weight.grad is not None
+
+ def test_sample(self):
+ obs = torch.randn(B, O)
+ tiled_obs, action_samples = self.opt.sample(obs, self.ebm)
+ assert tiled_obs.shape == (B, N, O)
+ assert action_samples.shape == (B, N, A)
+
+ def test_infer(self):
+ obs = torch.randn(B, O)
+ action = self.opt.infer(obs, self.ebm)
+ assert action.shape == (B, A)
diff --git a/DI-engine/ding/model/template/tests/test_edac.py b/DI-engine/ding/model/template/tests/test_edac.py
new file mode 100644
index 0000000000000000000000000000000000000000..76f0cca60a23380bca0b83beea42e3a24652f095
--- /dev/null
+++ b/DI-engine/ding/model/template/tests/test_edac.py
@@ -0,0 +1,57 @@
+import torch
+import pytest
+from itertools import product
+
+from ding.model.template import EDAC
+from ding.torch_utils import is_differentiable
+
+B = 4
+obs_shape = [4, (8, )]
+act_shape = [3, (6, )]
+args = list(product(*[obs_shape, act_shape]))
+
+
+@pytest.mark.unittest
+class TestEDAC:
+
+ def output_check(self, model, outputs):
+ if isinstance(outputs, torch.Tensor):
+ loss = outputs.sum()
+ elif isinstance(outputs, list):
+ loss = sum([t.sum() for t in outputs])
+ elif isinstance(outputs, dict):
+ loss = sum([v.sum() for v in outputs.values()])
+ is_differentiable(loss, model)
+
+ @pytest.mark.parametrize('obs_shape, act_shape', args)
+ def test_EDAC(self, obs_shape, act_shape):
+ if isinstance(obs_shape, int):
+ inputs_obs = torch.randn(B, obs_shape)
+ else:
+ inputs_obs = torch.randn(B, *obs_shape)
+ if isinstance(act_shape, int):
+ inputs_act = torch.randn(B, act_shape)
+ else:
+ inputs_act = torch.randn(B, *act_shape)
+ inputs = {'obs': inputs_obs, 'action': inputs_act}
+ model = EDAC(obs_shape, act_shape, ensemble_num=2)
+
+ outputs_c = model(inputs, mode='compute_critic')
+ assert isinstance(outputs_c, dict)
+ assert outputs_c['q_value'].shape == (2, B)
+ self.output_check(model.critic, outputs_c)
+
+ if isinstance(obs_shape, int):
+ inputs = torch.randn(B, obs_shape)
+ else:
+ inputs = torch.randn(B, *obs_shape)
+ outputs_a = model(inputs, mode='compute_actor')
+ assert isinstance(outputs_a, dict)
+ if isinstance(act_shape, int):
+ assert outputs_a['logit'][0].shape == (B, act_shape)
+ assert outputs_a['logit'][1].shape == (B, act_shape)
+ elif len(act_shape) == 1:
+ assert outputs_a['logit'][0].shape == (B, *act_shape)
+ assert outputs_a['logit'][1].shape == (B, *act_shape)
+ outputs = {'mu': outputs_a['logit'][0], 'sigma': outputs_a['logit'][1]}
+ self.output_check(model.actor, outputs)
diff --git a/DI-engine/ding/model/template/tests/test_havac.py b/DI-engine/ding/model/template/tests/test_havac.py
new file mode 100644
index 0000000000000000000000000000000000000000..42982ec5aed38c42d41a2b5005301ad9b2c71e67
--- /dev/null
+++ b/DI-engine/ding/model/template/tests/test_havac.py
@@ -0,0 +1,103 @@
+import pytest
+import torch
+import random
+from ding.torch_utils import is_differentiable
+from ding.model.template import HAVAC
+
+
+@pytest.mark.unittest
+class TestHAVAC:
+
+ def test_havac_rnn_actor(self):
+ # discrete+rnn
+ bs, T = 3, 8
+ obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
+ agent_num = 5
+ data = {
+ 'obs': {
+ 'agent_state': torch.randn(T, bs, obs_dim),
+ 'global_state': torch.randn(T, bs, global_obs_dim),
+ 'action_mask': torch.randint(0, 2, size=(T, bs, action_dim))
+ },
+ 'actor_prev_state': [None for _ in range(bs)],
+ }
+ model = HAVAC(
+ agent_obs_shape=obs_dim,
+ global_obs_shape=global_obs_dim,
+ action_shape=action_dim,
+ agent_num=agent_num,
+ use_lstm=True,
+ )
+ agent_idx = random.randint(0, agent_num - 1)
+ output = model(agent_idx, data, mode='compute_actor')
+ assert set(output.keys()) == set(['logit', 'actor_next_state', 'actor_hidden_state'])
+ assert output['logit'].shape == (T, bs, action_dim)
+ assert len(output['actor_next_state']) == bs
+ print(output['actor_next_state'][0]['h'].shape)
+ loss = output['logit'].sum()
+ is_differentiable(loss, model.agent_models[agent_idx].actor)
+
+ def test_havac_rnn_critic(self):
+ # discrete+rnn
+ bs, T = 3, 8
+ obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
+ agent_num = 5
+ data = {
+ 'obs': {
+ 'agent_state': torch.randn(T, bs, obs_dim),
+ 'global_state': torch.randn(T, bs, global_obs_dim),
+ 'action_mask': torch.randint(0, 2, size=(T, bs, action_dim))
+ },
+ 'critic_prev_state': [None for _ in range(bs)],
+ }
+ model = HAVAC(
+ agent_obs_shape=obs_dim,
+ global_obs_shape=global_obs_dim,
+ action_shape=action_dim,
+ agent_num=agent_num,
+ use_lstm=True,
+ )
+ agent_idx = random.randint(0, agent_num - 1)
+ output = model(agent_idx, data, mode='compute_critic')
+ assert set(output.keys()) == set(['value', 'critic_next_state', 'critic_hidden_state'])
+ assert output['value'].shape == (T, bs)
+ assert len(output['critic_next_state']) == bs
+ print(output['critic_next_state'][0]['h'].shape)
+ loss = output['value'].sum()
+ is_differentiable(loss, model.agent_models[agent_idx].critic)
+
+ def test_havac_rnn_actor_critic(self):
+ # discrete+rnn
+ bs, T = 3, 8
+ obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
+ agent_num = 5
+ data = {
+ 'obs': {
+ 'agent_state': torch.randn(T, bs, obs_dim),
+ 'global_state': torch.randn(T, bs, global_obs_dim),
+ 'action_mask': torch.randint(0, 2, size=(T, bs, action_dim))
+ },
+ 'actor_prev_state': [None for _ in range(bs)],
+ 'critic_prev_state': [None for _ in range(bs)],
+ }
+ model = HAVAC(
+ agent_obs_shape=obs_dim,
+ global_obs_shape=global_obs_dim,
+ action_shape=action_dim,
+ agent_num=agent_num,
+ use_lstm=True,
+ )
+ agent_idx = random.randint(0, agent_num - 1)
+ output = model(agent_idx, data, mode='compute_actor_critic')
+ assert set(output.keys()) == set(
+ ['logit', 'actor_next_state', 'actor_hidden_state', 'value', 'critic_next_state', 'critic_hidden_state']
+ )
+ assert output['logit'].shape == (T, bs, action_dim)
+ assert output['value'].shape == (T, bs)
+ loss = output['logit'].sum() + output['value'].sum()
+ is_differentiable(loss, model.agent_models[agent_idx])
+
+
+# test_havac_rnn_actor()
+# test_havac_rnn_critic()
+# test_havac_rnn_actor_critic()
diff --git a/DI-engine/ding/model/template/tests/test_hybrid_qac.py b/DI-engine/ding/model/template/tests/test_hybrid_qac.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a81d553508b8345d87a42307587afd9ebb5a73c
--- /dev/null
+++ b/DI-engine/ding/model/template/tests/test_hybrid_qac.py
@@ -0,0 +1,70 @@
+import torch
+import numpy as np
+import pytest
+from itertools import product
+
+from ding.model.template import ContinuousQAC
+from ding.torch_utils import is_differentiable
+from ding.utils import squeeze
+from easydict import EasyDict
+B = 4
+T = 6
+embedding_size = 32
+hybrid_args = {
+ 'action_shape': EasyDict({
+ 'action_type_shape': (4, ),
+ 'action_args_shape': (6, )
+ }),
+ 'twin': True,
+ 'action_space': 'hybrid'
+}
+
+
+@pytest.mark.unittest
+class TestHybridContinuousQAC:
+
+ def test_hybrid_qac(
+ self,
+ action_shape=hybrid_args['action_shape'],
+ twin=hybrid_args['twin'],
+ action_space=hybrid_args['action_space']
+ ):
+ N = 32
+ assert action_space == 'hybrid'
+ inputs = {
+ 'obs': torch.randn(B, N),
+ 'action': {
+ 'action_type': torch.randint(0, squeeze(action_shape.action_type_shape), (B, )),
+ 'action_args': torch.rand(B, squeeze(action_shape.action_args_shape))
+ },
+ 'logit': torch.randn(B, squeeze(action_shape.action_type_shape))
+ }
+ model = ContinuousQAC(
+ obs_shape=(N, ),
+ action_shape=action_shape,
+ action_space=action_space,
+ critic_head_hidden_size=embedding_size,
+ actor_head_hidden_size=embedding_size,
+ twin_critic=twin,
+ )
+ # compute_q
+ q = model(inputs, mode='compute_critic')['q_value']
+ if twin:
+ is_differentiable(q[0].sum(), model.critic[1][0])
+ is_differentiable(q[1].sum(), model.critic[1][1])
+ else:
+ is_differentiable(q.sum(), model.critic)
+
+ # compute_action
+ print(model)
+
+ output = model(inputs['obs'], mode='compute_actor')
+ discrete_logit = output['logit']
+ continuous_args = output['action_args']
+ # test discrete action_type + continuous action_args
+ if squeeze(action_shape.action_type_shape) == 1:
+ assert discrete_logit.shape == (B, )
+ else:
+ assert discrete_logit.shape == (B, squeeze(action_shape.action_type_shape))
+ assert continuous_args.shape == (B, squeeze(action_shape.action_args_shape))
+ is_differentiable(discrete_logit.sum() + continuous_args.sum(), model.actor)
diff --git a/DI-engine/ding/model/template/tests/test_language_transformer.py b/DI-engine/ding/model/template/tests/test_language_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..40095c2ab29b0c9aa3bc3322a52be65fe2b67271
--- /dev/null
+++ b/DI-engine/ding/model/template/tests/test_language_transformer.py
@@ -0,0 +1,25 @@
+import pytest
+
+from ding.model.template.language_transformer import LanguageTransformer
+
+
+@pytest.mark.unittest
+class TestNLPPretrainedModel:
+
+ def check_model(self):
+ test_pids = [1]
+ cand_pids = [0, 2, 4]
+ problems = [
+ "This is problem 0", "This is the first question", "Second problem is here", "Another problem",
+ "This is the last problem"
+ ]
+ ctxt_list = [problems[pid] for pid in test_pids]
+ cands_list = [problems[pid] for pid in cand_pids]
+
+ model = LanguageTransformer(model_name="bert-base-uncased", add_linear=True, embedding_size=256)
+ scores = model(ctxt_list, cands_list)
+ assert scores.shape == (1, 3)
+
+ model = LanguageTransformer(model_name="bert-base-uncased", add_linear=False, embedding_size=256)
+ scores = model(ctxt_list, cands_list)
+ assert scores.shape == (1, 3)
diff --git a/DI-engine/ding/model/template/tests/test_madqn.py b/DI-engine/ding/model/template/tests/test_madqn.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2245c332f09fb7553a0c8524434fa857594d645
--- /dev/null
+++ b/DI-engine/ding/model/template/tests/test_madqn.py
@@ -0,0 +1,30 @@
+import pytest
+import torch
+from ding.torch_utils import is_differentiable
+from ding.model.template import MADQN
+
+
+@pytest.mark.unittest
+def test_madqn():
+ agent_num, bs, T = 4, 3, 8
+ obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
+ embedding_dim = 64
+ madqn_model = MADQN(
+ agent_num=agent_num,
+ obs_shape=obs_dim,
+ action_shape=action_dim,
+ hidden_size_list=[embedding_dim, embedding_dim],
+ global_obs_shape=global_obs_dim
+ )
+ data = {
+ 'obs': {
+ 'agent_state': torch.randn(T, bs, agent_num, obs_dim),
+ 'global_state': torch.randn(T, bs, agent_num, global_obs_dim),
+ 'action_mask': torch.randint(0, 2, size=(T, bs, agent_num, action_dim))
+ },
+ 'prev_state': [[None for _ in range(agent_num)] for _ in range(bs)],
+ 'action': torch.randint(0, action_dim, size=(T, bs, agent_num))
+ }
+ output = madqn_model(data, cooperation=True, single_step=False)
+ assert output['total_q'].shape == (T, bs)
+ assert len(output['next_state']) == bs and all([len(n) == agent_num for n in output['next_state']])
diff --git a/DI-engine/ding/model/template/tests/test_maqac.py b/DI-engine/ding/model/template/tests/test_maqac.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa917e7ebc4ee092b38c949663770a34d8672e4a
--- /dev/null
+++ b/DI-engine/ding/model/template/tests/test_maqac.py
@@ -0,0 +1,119 @@
+import torch
+import numpy as np
+import pytest
+from itertools import product
+
+from ding.model.template import DiscreteMAQAC, ContinuousMAQAC
+from ding.torch_utils import is_differentiable
+from ding.utils.default_helper import squeeze
+
+B = 32
+agent_obs_shape = [216, 265]
+global_obs_shape = [264, 324]
+agent_num = 8
+action_shape = 14
+args = list(product(*[agent_obs_shape, global_obs_shape, [False, True]]))
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('agent_obs_shape, global_obs_shape, twin_critic', args)
+class TestDiscreteMAQAC:
+
+ def output_check(self, model, outputs, action_shape):
+ if isinstance(action_shape, tuple):
+ loss = sum([t.sum() for t in outputs])
+ elif np.isscalar(action_shape):
+ loss = outputs.sum()
+ is_differentiable(loss, model)
+
+ def test_maqac(self, agent_obs_shape, global_obs_shape, twin_critic):
+ data = {
+ 'obs': {
+ 'agent_state': torch.randn(B, agent_num, agent_obs_shape),
+ 'global_state': torch.randn(B, agent_num, global_obs_shape),
+ 'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
+ }
+ }
+ model = DiscreteMAQAC(agent_obs_shape, global_obs_shape, action_shape, twin_critic=twin_critic)
+
+ logit = model(data, mode='compute_actor')['logit']
+ value = model(data, mode='compute_critic')['q_value']
+
+ value_sum = sum(t.sum() for t in value) if twin_critic else value.sum()
+ outputs = value_sum + logit.sum()
+ self.output_check(model, outputs, action_shape)
+
+ for p in model.parameters():
+ p.grad = None
+ logit = model(data, mode='compute_actor')['logit']
+ self.output_check(model.actor, logit, action_shape)
+
+ for p in model.parameters():
+ p.grad = None
+ value = model(data, mode='compute_critic')['q_value']
+ if twin_critic:
+ for v in value:
+ assert v.shape == (B, agent_num, action_shape)
+ else:
+ assert value.shape == (B, agent_num, action_shape)
+ self.output_check(model.critic, sum(t.sum() for t in value) if twin_critic else value.sum(), action_shape)
+
+
+B = 32
+agent_obs_shape = [216, 265]
+global_obs_shape = [264, 324]
+agent_num = 8
+action_shape = 14
+action_space = ['regression', 'reparameterization']
+args = list(product(*[agent_obs_shape, global_obs_shape, action_space, [False, True]]))
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('agent_obs_shape, global_obs_shape, action_space, twin_critic', args)
+class TestContinuousMAQAC:
+
+ def output_check(self, model, outputs, action_shape):
+ if isinstance(action_shape, tuple):
+ loss = sum([t.sum() for t in outputs])
+ elif np.isscalar(action_shape):
+ loss = outputs.sum()
+ is_differentiable(loss, model)
+
+ def test_continuousmaqac(self, agent_obs_shape, global_obs_shape, action_space, twin_critic):
+ data = {
+ 'obs': {
+ 'agent_state': torch.randn(B, agent_num, agent_obs_shape),
+ 'global_state': torch.randn(B, agent_num, global_obs_shape),
+ 'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
+ },
+ 'action': torch.randn(B, agent_num, squeeze(action_shape))
+ }
+ model = ContinuousMAQAC(agent_obs_shape, global_obs_shape, action_shape, action_space, twin_critic=twin_critic)
+
+ for p in model.parameters():
+ p.grad = None
+
+ if action_space == 'regression':
+ action = model(data['obs'], mode='compute_actor')['action']
+ if squeeze(action_shape) == 1:
+ assert action.shape == (B, )
+ else:
+ assert action.shape == (B, agent_num, squeeze(action_shape))
+ assert action.eq(action.clamp(-1, 1)).all()
+ self.output_check(model.actor, action, action_shape)
+ #is_differentiable(action.sum(), model.actor)
+ elif action_space == 'reparameterization':
+ (mu, sigma) = model(data['obs'], mode='compute_actor')['logit']
+ assert mu.shape == (B, agent_num, action_shape)
+ assert sigma.shape == (B, agent_num, action_shape)
+ is_differentiable(mu.sum() + sigma.sum(), model.actor)
+
+ for p in model.parameters():
+ p.grad = None
+ value = model(data, mode='compute_critic')['q_value']
+ if twin_critic:
+ for v in value:
+ assert v.shape == (B, agent_num)
+ else:
+ assert value.shape == (B, agent_num)
+ self.output_check(model.critic, sum(t.sum() for t in value) if twin_critic else value.sum(), action_shape)
diff --git a/DI-engine/ding/model/template/tests/test_mavac.py b/DI-engine/ding/model/template/tests/test_mavac.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6c6927373e0644dbd02a52971f1beb210612bf4
--- /dev/null
+++ b/DI-engine/ding/model/template/tests/test_mavac.py
@@ -0,0 +1,52 @@
+import pytest
+import numpy as np
+import torch
+from itertools import product
+
+from ding.model import mavac
+from ding.model.template.mavac import MAVAC
+from ding.torch_utils import is_differentiable
+
+B = 32
+agent_obs_shape = [216, 265]
+global_obs_shape = [264, 324]
+agent_num = 8
+action_shape = 14
+args = list(product(*[agent_obs_shape, global_obs_shape]))
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('agent_obs_shape, global_obs_shape', args)
+class TestVAC:
+
+ def output_check(self, model, outputs, action_shape):
+ if isinstance(action_shape, tuple):
+ loss = sum([t.sum() for t in outputs])
+ elif np.isscalar(action_shape):
+ loss = outputs.sum()
+ is_differentiable(loss, model)
+
+ def test_vac(self, agent_obs_shape, global_obs_shape):
+ data = {
+ 'agent_state': torch.randn(B, agent_num, agent_obs_shape),
+ 'global_state': torch.randn(B, agent_num, global_obs_shape),
+ 'action_mask': torch.randint(0, 2, size=(B, agent_num, action_shape))
+ }
+ model = MAVAC(agent_obs_shape, global_obs_shape, action_shape, agent_num)
+
+ logit = model(data, mode='compute_actor_critic')['logit']
+ value = model(data, mode='compute_actor_critic')['value']
+
+ outputs = value.sum() + logit.sum()
+ self.output_check(model, outputs, action_shape)
+
+ for p in model.parameters():
+ p.grad = None
+ logit = model(data, mode='compute_actor')['logit']
+ self.output_check(model.actor, logit, model.action_shape)
+
+ for p in model.parameters():
+ p.grad = None
+ value = model(data, mode='compute_critic')['value']
+ assert value.shape == (B, agent_num)
+ self.output_check(model.critic, value, action_shape)
diff --git a/DI-engine/ding/model/template/tests/test_ngu.py b/DI-engine/ding/model/template/tests/test_ngu.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed0e86f194768fd3cdf371bf73d9db0097a20ccd
--- /dev/null
+++ b/DI-engine/ding/model/template/tests/test_ngu.py
@@ -0,0 +1,70 @@
+import pytest
+from itertools import product
+import torch
+from ding.model.template import NGU
+from ding.torch_utils import is_differentiable
+
+B = 4
+H = 4
+obs_shape = [4, (8, ), (4, 64, 64)]
+act_shape = [4, (4, )]
+args = list(product(*[obs_shape, act_shape]))
+
+
+@pytest.mark.unittest
+class TestNGU:
+
+ def output_check(self, model, outputs):
+ if isinstance(outputs, torch.Tensor):
+ loss = outputs.sum()
+ elif isinstance(outputs, list):
+ loss = sum([t.sum() for t in outputs])
+ elif isinstance(outputs, dict):
+ loss = sum([v.sum() for v in outputs.values()])
+ is_differentiable(loss, model)
+
+ @pytest.mark.parametrize('obs_shape, act_shape', args)
+ def test_ngu(self, obs_shape, act_shape):
+ if isinstance(obs_shape, int):
+ inputs_obs = torch.randn(B, H, obs_shape)
+ else:
+ inputs_obs = torch.randn(B, H, *obs_shape)
+ if isinstance(act_shape, int):
+ inputs_prev_action = torch.ones(B, act_shape).long()
+ else:
+ inputs_prev_action = torch.ones(B, *act_shape).long()
+ inputs_prev_reward_extrinsic = torch.randn(B, H, 1)
+ inputs_beta = 2 * torch.ones([4, 4], dtype=torch.long)
+ inputs = {
+ 'obs': inputs_obs,
+ 'prev_state': None,
+ 'prev_action': inputs_prev_action,
+ 'prev_reward_extrinsic': inputs_prev_reward_extrinsic,
+ 'beta': inputs_beta
+ }
+
+ model = NGU(obs_shape, act_shape, collector_env_num=3)
+ outputs = model(inputs)
+ assert isinstance(outputs, dict)
+ if isinstance(act_shape, int):
+ assert outputs['logit'].shape == (B, act_shape, act_shape)
+ elif len(act_shape) == 1:
+ assert outputs['logit'].shape == (B, *act_shape, *act_shape)
+ self.output_check(model, outputs['logit'])
+
+ inputs = {
+ 'obs': inputs_obs,
+ 'prev_state': None,
+ 'action': inputs_prev_action,
+ 'reward': inputs_prev_reward_extrinsic,
+ 'prev_reward_extrinsic': inputs_prev_reward_extrinsic,
+ 'beta': inputs_beta
+ }
+ model = NGU(obs_shape, act_shape, collector_env_num=3)
+ outputs = model(inputs)
+ assert isinstance(outputs, dict)
+ if isinstance(act_shape, int):
+ assert outputs['logit'].shape == (B, act_shape, act_shape)
+ elif len(act_shape) == 1:
+ assert outputs['logit'].shape == (B, *act_shape, *act_shape)
+ self.output_check(model, outputs['logit'])
diff --git a/DI-engine/ding/model/template/tests/test_pdqn.py b/DI-engine/ding/model/template/tests/test_pdqn.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f9b66f9af029ad26b45e653b457447eb80bf160
--- /dev/null
+++ b/DI-engine/ding/model/template/tests/test_pdqn.py
@@ -0,0 +1,61 @@
+import pytest
+import torch
+from easydict import EasyDict
+
+from ding.model.template import PDQN
+
+action_args_shape_values = [1, 5]
+
+
+@pytest.mark.unittest
+class TestPDQN:
+
+ @pytest.mark.unittest
+ @pytest.mark.parametrize('action_args_shape', action_args_shape_values)
+ def test_dqn(self, action_args_shape):
+ T, B = 3, 4
+ obs_shape = (4, )
+ act_shape = EasyDict({'action_type_shape': (3, ), 'action_args_shape': (action_args_shape, )})
+ if isinstance(obs_shape, int):
+ cont_inputs = torch.randn(B, obs_shape)
+ else:
+ cont_inputs = torch.randn(B, *obs_shape)
+ model = PDQN(obs_shape, act_shape)
+ cont_outputs = model.forward(cont_inputs, mode='compute_continuous')
+ assert isinstance(cont_outputs, dict)
+ dis_inputs = {'state': cont_inputs, 'action_args': cont_outputs['action_args']}
+ dis_outputs = model.forward(dis_inputs, mode='compute_discrete')
+ assert isinstance(dis_outputs, dict)
+ if isinstance(act_shape['action_type_shape'], int):
+ assert dis_outputs['logit'].shape == (B, act_shape.action_type_shape)
+ elif len(act_shape['action_type_shape']) == 1:
+ assert dis_outputs['logit'].shape == (B, *act_shape.action_type_shape)
+ else:
+ for i, s in enumerate(act_shape):
+ assert dis_outputs['logit'][i].shape == (B, s)
+
+ def test_mdqn(self):
+ T, B = 3, 4
+ obs_shape = (4, )
+ act_shape = EasyDict({'action_type_shape': 3, 'action_args_shape': 5})
+ if isinstance(obs_shape, int):
+ cont_inputs = torch.randn(B, obs_shape)
+ else:
+ cont_inputs = torch.randn(B, *obs_shape)
+ model = PDQN(
+ obs_shape, act_shape, multi_pass=True, action_mask=[[1, 1, 0, 0, 0], [0, 0, 1, 1, 1], [0, 0, 0, 0, 0]]
+ )
+ cont_outputs = model.forward(cont_inputs, mode='compute_continuous')
+ assert isinstance(cont_outputs, dict)
+ dis_inputs = {'state': cont_inputs, 'action_args': cont_outputs['action_args']}
+
+ dis_outputs = model.forward(dis_inputs, mode='compute_discrete')
+
+ assert isinstance(dis_outputs, dict)
+ if isinstance(act_shape['action_type_shape'], int):
+ assert dis_outputs['logit'].shape == (B, act_shape.action_type_shape)
+ elif len(act_shape['action_type_shape']) == 1:
+ assert dis_outputs['logit'].shape == (B, *act_shape.action_type_shape)
+ else:
+ for i, s in enumerate(act_shape):
+ assert dis_outputs['logit'][i].shape == (B, s)
diff --git a/DI-engine/ding/model/template/tests/test_pg.py b/DI-engine/ding/model/template/tests/test_pg.py
new file mode 100644
index 0000000000000000000000000000000000000000..2eb0dfba907a5cf50d31542b9f6678fc51f8f227
--- /dev/null
+++ b/DI-engine/ding/model/template/tests/test_pg.py
@@ -0,0 +1,61 @@
+import torch
+import numpy as np
+import pytest
+from itertools import product
+
+from ding.model.template import PG
+from ding.torch_utils import is_differentiable
+from ding.utils import squeeze
+
+B = 4
+
+
+@pytest.mark.unittest
+class TestDiscretePG:
+
+ def output_check(self, model, outputs):
+ if isinstance(outputs, torch.Tensor):
+ loss = outputs.sum()
+ elif isinstance(outputs, list):
+ loss = sum([t.sum() for t in outputs])
+ elif isinstance(outputs, dict):
+ loss = sum([v.sum() for v in outputs.values()])
+ is_differentiable(loss, model)
+
+ def test_discrete_pg(self):
+ obs_shape = (4, 84, 84)
+ action_shape = 5
+ model = PG(
+ obs_shape,
+ action_shape,
+ )
+ inputs = torch.randn(B, 4, 84, 84)
+
+ outputs = model(inputs)
+ assert isinstance(outputs, dict)
+ assert outputs['logit'].shape == (B, action_shape)
+ assert outputs['dist'].sample().shape == (B, )
+ self.output_check(model, outputs['logit'])
+
+ def test_continuous_pg(self):
+ N = 32
+ action_shape = (6, )
+ inputs = {'obs': torch.randn(B, N), 'action': torch.randn(B, squeeze(action_shape))}
+ model = PG(
+ obs_shape=(N, ),
+ action_shape=action_shape,
+ action_space='continuous',
+ )
+ # compute_action
+ print(model)
+ outputs = model(inputs['obs'])
+ assert isinstance(outputs, dict)
+ dist = outputs['dist']
+ action = dist.sample()
+ assert action.shape == (B, *action_shape)
+
+ logit = outputs['logit']
+ mu, sigma = logit['mu'], logit['sigma']
+ assert mu.shape == (B, *action_shape)
+ assert sigma.shape == (B, *action_shape)
+ is_differentiable(mu.sum() + sigma.sum(), model)
diff --git a/DI-engine/ding/model/template/tests/test_procedure_cloning.py b/DI-engine/ding/model/template/tests/test_procedure_cloning.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2bb1979545db7dcbb8092ae5ea476659d655539
--- /dev/null
+++ b/DI-engine/ding/model/template/tests/test_procedure_cloning.py
@@ -0,0 +1,37 @@
+import pytest
+from itertools import product
+
+import torch
+
+from ding.model.template import ProcedureCloningMCTS, ProcedureCloningBFS
+
+B = 4
+T = 15
+obs_shape = [(64, 64, 3)]
+action_dim = [9]
+obs_embeddings = 256
+args = list(product(*[obs_shape, action_dim]))
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('obs_shape, action_dim', args)
+class TestProcedureCloning:
+
+ def test_procedure_cloning_mcts(self, obs_shape, action_dim):
+ inputs = {
+ 'states': torch.randn(B, *obs_shape),
+ 'goals': torch.randn(B, *obs_shape),
+ 'actions': torch.randn(B, T, action_dim)
+ }
+ model = ProcedureCloningMCTS(obs_shape=obs_shape, action_dim=action_dim)
+ goal_preds, action_preds = model(inputs['states'], inputs['goals'], inputs['actions'])
+ assert goal_preds.shape == (B, obs_embeddings)
+ assert action_preds.shape == (B, T + 1, action_dim)
+
+ def test_procedure_cloning_bfs(self, obs_shape, action_dim):
+ o_shape = (obs_shape[2], obs_shape[0], obs_shape[1])
+ model = ProcedureCloningBFS(obs_shape=o_shape, action_shape=action_dim)
+
+ inputs = torch.randn(B, *obs_shape)
+ map_preds = model(inputs)
+ assert map_preds['logit'].shape == (B, obs_shape[0], obs_shape[1], action_dim + 1)
diff --git a/DI-engine/ding/model/template/tests/test_q_learning.py b/DI-engine/ding/model/template/tests/test_q_learning.py
new file mode 100644
index 0000000000000000000000000000000000000000..2307a372d1f9b4e56675d2c8e21b78d1568a5260
--- /dev/null
+++ b/DI-engine/ding/model/template/tests/test_q_learning.py
@@ -0,0 +1,293 @@
+import pytest
+from itertools import product
+import torch
+from ding.model.template import DQN, RainbowDQN, QRDQN, IQN, FQF, DRQN, C51DQN, BDQ, GTrXLDQN
+from ding.torch_utils import is_differentiable
+
+T, B = 3, 4
+obs_shape = [4, (8, ), (4, 64, 64)]
+act_shape = [3, (6, ), [2, 3, 6]]
+args = list(product(*[obs_shape, act_shape]))
+
+
+@pytest.mark.unittest
+class TestQLearning:
+
+ def output_check(self, model, outputs):
+ if isinstance(outputs, torch.Tensor):
+ loss = outputs.sum()
+ elif isinstance(outputs, list):
+ loss = sum([t.sum() for t in outputs])
+ elif isinstance(outputs, dict):
+ loss = sum([v.sum() for v in outputs.values()])
+ is_differentiable(loss, model)
+
+ @pytest.mark.parametrize('obs_shape, act_shape', args)
+ def test_dqn(self, obs_shape, act_shape):
+ if isinstance(obs_shape, int):
+ inputs = torch.randn(B, obs_shape)
+ else:
+ inputs = torch.randn(B, *obs_shape)
+ model = DQN(obs_shape, act_shape)
+ outputs = model(inputs)
+ assert isinstance(outputs, dict)
+ if isinstance(act_shape, int):
+ assert outputs['logit'].shape == (B, act_shape)
+ elif len(act_shape) == 1:
+ assert outputs['logit'].shape == (B, *act_shape)
+ else:
+ for i, s in enumerate(act_shape):
+ assert outputs['logit'][i].shape == (B, s)
+ self.output_check(model, outputs['logit'])
+
+ @pytest.mark.parametrize('obs_shape, act_shape', args)
+ def test_bdq(self, obs_shape, act_shape):
+ if isinstance(obs_shape, int):
+ inputs = torch.randn(B, obs_shape)
+ else:
+ inputs = torch.randn(B, *obs_shape)
+ if not isinstance(act_shape, int) and len(act_shape) > 1:
+ return
+ num_branches = act_shape
+ for action_bins_per_branch in range(1, 10):
+ model = BDQ(obs_shape, num_branches, action_bins_per_branch)
+ outputs = model(inputs)
+ assert isinstance(outputs, dict)
+ if isinstance(act_shape, int):
+ assert outputs['logit'].shape == (B, act_shape, action_bins_per_branch)
+ else:
+ assert outputs['logit'].shape == (B, *act_shape, action_bins_per_branch)
+ self.output_check(model, outputs['logit'])
+
+ @pytest.mark.parametrize('obs_shape, act_shape', args)
+ def test_rainbowdqn(self, obs_shape, act_shape):
+ if isinstance(obs_shape, int):
+ inputs = torch.randn(B, obs_shape)
+ else:
+ inputs = torch.randn(B, *obs_shape)
+ model = RainbowDQN(obs_shape, act_shape, n_atom=41)
+ outputs = model(inputs)
+ assert isinstance(outputs, dict)
+ if isinstance(act_shape, int):
+ assert outputs['logit'].shape == (B, act_shape)
+ assert outputs['distribution'].shape == (B, act_shape, 41)
+ elif len(act_shape) == 1:
+ assert outputs['logit'].shape == (B, *act_shape)
+ assert outputs['distribution'].shape == (B, *act_shape, 41)
+ else:
+ for i, s in enumerate(act_shape):
+ assert outputs['logit'][i].shape == (B, s)
+ assert outputs['distribution'][i].shape == (B, s, 41)
+ self.output_check(model, outputs['logit'])
+
+ @pytest.mark.parametrize('obs_shape, act_shape', args)
+ def test_c51(self, obs_shape, act_shape):
+ if isinstance(obs_shape, int):
+ inputs = torch.randn(B, obs_shape)
+ else:
+ inputs = torch.randn(B, *obs_shape)
+ model = C51DQN(obs_shape, act_shape, n_atom=41)
+ outputs = model(inputs)
+ assert isinstance(outputs, dict)
+ if isinstance(act_shape, int):
+ assert outputs['logit'].shape == (B, act_shape)
+ assert outputs['distribution'].shape == (B, act_shape, 41)
+ elif len(act_shape) == 1:
+ assert outputs['logit'].shape == (B, *act_shape)
+ assert outputs['distribution'].shape == (B, *act_shape, 41)
+ else:
+ for i, s in enumerate(act_shape):
+ assert outputs['logit'][i].shape == (B, s)
+ assert outputs['distribution'][i].shape == (B, s, 41)
+ self.output_check(model, outputs['logit'])
+
+ @pytest.mark.parametrize('obs_shape, act_shape', args)
+ def test_iqn(self, obs_shape, act_shape):
+ if isinstance(obs_shape, int):
+ inputs = torch.randn(B, obs_shape)
+ else:
+ inputs = torch.randn(B, *obs_shape)
+ num_quantiles = 48
+ model = IQN(obs_shape, act_shape, num_quantiles=num_quantiles, quantile_embedding_size=64)
+ outputs = model(inputs)
+ print(model)
+ assert isinstance(outputs, dict)
+ if isinstance(act_shape, int):
+ assert outputs['logit'].shape == (B, act_shape)
+ assert outputs['q'].shape == (num_quantiles, B, act_shape)
+ assert outputs['quantiles'].shape == (B * num_quantiles, 1)
+ elif len(act_shape) == 1:
+ assert outputs['logit'].shape == (B, *act_shape)
+ assert outputs['q'].shape == (num_quantiles, B, *act_shape)
+ assert outputs['quantiles'].shape == (B * num_quantiles, 1)
+ else:
+ for i, s in enumerate(act_shape):
+ assert outputs['logit'][i].shape == (B, s)
+ assert outputs['q'][i].shape == (num_quantiles, B, s)
+ assert outputs['quantiles'][i].shape == (B * num_quantiles, 1)
+ self.output_check(model, outputs['logit'])
+
+ @pytest.mark.parametrize('obs_shape, act_shape', args)
+ def test_fqf(self, obs_shape, act_shape):
+ if isinstance(obs_shape, int):
+ inputs = torch.randn(B, obs_shape)
+ else:
+ inputs = torch.randn(B, *obs_shape)
+ num_quantiles = 48
+ model = FQF(obs_shape, act_shape, num_quantiles=num_quantiles, quantile_embedding_size=64)
+ outputs = model(inputs)
+ print(model)
+ assert isinstance(outputs, dict)
+ if isinstance(act_shape, int):
+ assert outputs['logit'].shape == (B, act_shape)
+ assert outputs['q'].shape == (B, num_quantiles, act_shape)
+ assert outputs['quantiles'].shape == (B, num_quantiles + 1)
+ assert outputs['quantiles_hats'].shape == (B, num_quantiles)
+ assert outputs['q_tau_i'].shape == (B, num_quantiles - 1, act_shape)
+ all_quantiles_proposal = model.head.quantiles_proposal
+ all_fqf_fc = model.head.fqf_fc
+ elif len(act_shape) == 1:
+ assert outputs['logit'].shape == (B, *act_shape)
+ assert outputs['q'].shape == (B, num_quantiles, *act_shape)
+ assert outputs['quantiles'].shape == (B, num_quantiles + 1)
+ assert outputs['quantiles_hats'].shape == (B, num_quantiles)
+ assert outputs['q_tau_i'].shape == (B, num_quantiles - 1, *act_shape)
+ all_quantiles_proposal = model.head.quantiles_proposal
+ all_fqf_fc = model.head.fqf_fc
+ else:
+ for i, s in enumerate(act_shape):
+ assert outputs['logit'][i].shape == (B, s)
+ assert outputs['q'][i].shape == (B, num_quantiles, s)
+ assert outputs['quantiles'][i].shape == (B, num_quantiles + 1)
+ assert outputs['quantiles_hats'][i].shape == (B, num_quantiles)
+ assert outputs['q_tau_i'][i].shape == (B, num_quantiles - 1, s)
+ all_quantiles_proposal = [h.quantiles_proposal for h in model.head.pred]
+ all_fqf_fc = [h.fqf_fc for h in model.head.pred]
+ self.output_check(all_quantiles_proposal, outputs['quantiles'])
+ for p in model.parameters():
+ p.grad = None
+ self.output_check(all_fqf_fc, outputs['q'])
+
+ @pytest.mark.parametrize('obs_shape, act_shape', args)
+ def test_qrdqn(self, obs_shape, act_shape):
+ if isinstance(obs_shape, int):
+ inputs = torch.randn(B, obs_shape)
+ else:
+ inputs = torch.randn(B, *obs_shape)
+ model = QRDQN(obs_shape, act_shape, num_quantiles=32)
+ outputs = model(inputs)
+ assert isinstance(outputs, dict)
+ if isinstance(act_shape, int):
+ assert outputs['logit'].shape == (B, act_shape)
+ assert outputs['q'].shape == (B, act_shape, 32)
+ assert outputs['tau'].shape == (B, 32, 1)
+ elif len(act_shape) == 1:
+ assert outputs['logit'].shape == (B, *act_shape)
+ assert outputs['q'].shape == (B, *act_shape, 32)
+ assert outputs['tau'].shape == (B, 32, 1)
+ else:
+ for i, s in enumerate(act_shape):
+ assert outputs['logit'][i].shape == (B, s)
+ assert outputs['q'][i].shape == (B, s, 32)
+ assert outputs['tau'][i].shape == (B, 32, 1)
+ self.output_check(model, outputs['logit'])
+
+ @pytest.mark.parametrize('obs_shape, act_shape', args)
+ def test_drqn(self, obs_shape, act_shape):
+ if isinstance(obs_shape, int):
+ inputs = torch.randn(T, B, obs_shape)
+ else:
+ inputs = torch.randn(T, B, *obs_shape)
+ # (num_layer * num_direction, 1, head_hidden_size)
+ prev_state = [{k: torch.randn(1, 1, 64) for k in ['h', 'c']} for _ in range(B)]
+ model = DRQN(obs_shape, act_shape)
+ outputs = model({'obs': inputs, 'prev_state': prev_state}, inference=False)
+ assert isinstance(outputs, dict)
+ if isinstance(act_shape, int):
+ assert outputs['logit'].shape == (T, B, act_shape)
+ elif len(act_shape) == 1:
+ assert outputs['logit'].shape == (T, B, *act_shape)
+ else:
+ for i, s in enumerate(act_shape):
+ assert outputs['logit'][i].shape == (T, B, s)
+ assert len(outputs['next_state']) == B
+ assert all([len(t) == 2 for t in outputs['next_state']])
+ assert all([t['h'].shape == (1, 1, 64) for t in outputs['next_state']])
+ self.output_check(model, outputs['logit'])
+
+ @pytest.mark.parametrize('obs_shape, act_shape', args)
+ def test_drqn_inference(self, obs_shape, act_shape):
+ if isinstance(obs_shape, int):
+ inputs = torch.randn(B, obs_shape)
+ else:
+ inputs = torch.randn(B, *obs_shape)
+ # (num_layer * num_direction, 1, head_hidden_size)
+ prev_state = [{k: torch.randn(1, 1, 64) for k in ['h', 'c']} for _ in range(B)]
+ model = DRQN(obs_shape, act_shape)
+ outputs = model({'obs': inputs, 'prev_state': prev_state}, inference=True)
+ assert isinstance(outputs, dict)
+ if isinstance(act_shape, int):
+ assert outputs['logit'].shape == (B, act_shape)
+ elif len(act_shape) == 1:
+ assert outputs['logit'].shape == (B, *act_shape)
+ else:
+ for i, s in enumerate(act_shape):
+ assert outputs['logit'][i].shape == (B, s)
+ assert len(outputs['next_state']) == B
+ assert all([len(t) == 2 for t in outputs['next_state']])
+ assert all([t['h'].shape == (1, 1, 64) for t in outputs['next_state']])
+ self.output_check(model, outputs['logit'])
+
+ @pytest.mark.parametrize('obs_shape, act_shape', args)
+ def test_drqn_res_link(self, obs_shape, act_shape):
+ if isinstance(obs_shape, int):
+ inputs = torch.randn(T, B, obs_shape)
+ else:
+ inputs = torch.randn(T, B, *obs_shape)
+ # (num_layer * num_direction, 1, head_hidden_size)
+ prev_state = [{k: torch.randn(1, 1, 64) for k in ['h', 'c']} for _ in range(B)]
+ model = DRQN(obs_shape, act_shape, res_link=True)
+ outputs = model({'obs': inputs, 'prev_state': prev_state}, inference=False)
+ assert isinstance(outputs, dict)
+ if isinstance(act_shape, int):
+ assert outputs['logit'].shape == (T, B, act_shape)
+ elif len(act_shape) == 1:
+ assert outputs['logit'].shape == (T, B, *act_shape)
+ else:
+ for i, s in enumerate(act_shape):
+ assert outputs['logit'][i].shape == (T, B, s)
+ assert len(outputs['next_state']) == B
+ assert all([len(t) == 2 for t in outputs['next_state']])
+ assert all([t['h'].shape == (1, 1, 64) for t in outputs['next_state']])
+ self.output_check(model, outputs['logit'])
+
+ @pytest.mark.parametrize('obs_shape, act_shape', args)
+ def test_drqn_inference_res_link(self, obs_shape, act_shape):
+ if isinstance(obs_shape, int):
+ inputs = torch.randn(B, obs_shape)
+ else:
+ inputs = torch.randn(B, *obs_shape)
+ # (num_layer * num_direction, 1, head_hidden_size)
+ prev_state = [{k: torch.randn(1, 1, 64) for k in ['h', 'c']} for _ in range(B)]
+ model = DRQN(obs_shape, act_shape, res_link=True)
+ outputs = model({'obs': inputs, 'prev_state': prev_state}, inference=True)
+ assert isinstance(outputs, dict)
+ if isinstance(act_shape, int):
+ assert outputs['logit'].shape == (B, act_shape)
+ elif len(act_shape) == 1:
+ assert outputs['logit'].shape == (B, *act_shape)
+ else:
+ for i, s in enumerate(act_shape):
+ assert outputs['logit'][i].shape == (B, s)
+ assert len(outputs['next_state']) == B
+ assert all([len(t) == 2 for t in outputs['next_state']])
+ assert all([t['h'].shape == (1, 1, 64) for t in outputs['next_state']])
+ self.output_check(model, outputs['logit'])
+
+ @pytest.mark.tmp
+ def test_GTrXLDQN(self):
+ obs_dim, seq_len, bs, action_dim = [4, 64, 64], 64, 32, 4
+ obs = torch.rand(seq_len, bs, *obs_dim)
+ model = GTrXLDQN(obs_dim, action_dim, encoder_hidden_size_list=[16, 16, 16])
+ outputs = model(obs)
+ assert isinstance(outputs, dict)
diff --git a/DI-engine/ding/model/template/tests/test_qac.py b/DI-engine/ding/model/template/tests/test_qac.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ddbf9d5115691a9f863f386319b85ab3a8ab1af
--- /dev/null
+++ b/DI-engine/ding/model/template/tests/test_qac.py
@@ -0,0 +1,130 @@
+import torch
+import numpy as np
+import pytest
+from itertools import product
+
+from ding.model.template import ContinuousQAC, DiscreteMAQAC, DiscreteQAC
+from ding.torch_utils import is_differentiable
+from ding.utils import squeeze
+
+B = 4
+T = 6
+embedding_size = 32
+action_shape_args = [(6, ), [1]]
+args = list(product(*[action_shape_args, [True, False], ['regression', 'reparameterization']]))
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('action_shape, twin, action_space', args)
+class TestContinuousQAC:
+
+ def test_fcqac(self, action_shape, twin, action_space):
+ N = 32
+ inputs = {'obs': torch.randn(B, N), 'action': torch.randn(B, squeeze(action_shape))}
+ model = ContinuousQAC(
+ obs_shape=(N, ),
+ action_shape=action_shape,
+ action_space=action_space,
+ critic_head_hidden_size=embedding_size,
+ actor_head_hidden_size=embedding_size,
+ twin_critic=twin,
+ )
+ # compute_q
+ q = model(inputs, mode='compute_critic')['q_value']
+ if twin:
+ is_differentiable(q[0].sum(), model.critic[1][0])
+ is_differentiable(q[1].sum(), model.critic[1][1])
+ else:
+ is_differentiable(q.sum(), model.critic)
+
+ # compute_action
+ print(model)
+ if action_space == 'regression':
+ action = model(inputs['obs'], mode='compute_actor')['action']
+ if squeeze(action_shape) == 1:
+ assert action.shape == (B, )
+ else:
+ assert action.shape == (B, squeeze(action_shape))
+ assert action.eq(action.clamp(-1, 1)).all()
+ is_differentiable(action.sum(), model.actor)
+ elif action_space == 'reparameterization':
+ (mu, sigma) = model(inputs['obs'], mode='compute_actor')['logit']
+ assert mu.shape == (B, *action_shape)
+ assert sigma.shape == (B, *action_shape)
+ is_differentiable(mu.sum() + sigma.sum(), model.actor)
+
+
+args = list(product(*[[True, False], [(13, ), [4, 84, 84]]]))
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('twin, obs_shape', args)
+class TestDiscreteQAC:
+
+ def test_discreteqac(self, twin, obs_shape):
+ action_shape = 6
+ inputs = torch.randn(B, *obs_shape)
+ model = DiscreteQAC(
+ obs_shape=obs_shape,
+ action_shape=action_shape,
+ twin_critic=twin,
+ encoder_hidden_size_list=[32, 32, 64] if len(obs_shape) > 1 else None,
+ )
+ # compute_critic
+ q = model(inputs, mode='compute_critic')['q_value']
+ if twin:
+ is_differentiable(q[0].sum(), model.critic[1][0])
+ # is_differentiable(q[1].sum(), model.critic[1][1]) # backward encoder twice
+ assert q[0].shape == (B, action_shape)
+ assert q[1].shape == (B, action_shape)
+ else:
+ is_differentiable(q.sum(), model.critic[1])
+ assert q.shape == (B, action_shape)
+
+ # compute_actor
+ print(model)
+ logit = model(inputs, mode='compute_actor')['logit']
+ assert logit.shape == (B, action_shape)
+ is_differentiable(logit.sum(), model.actor)
+
+
+B = 4
+embedding_size = 64
+action_shape_args = [(6, ), 1]
+args = list(product(*[action_shape_args, [True, False], [True, False]]))
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('action_shape, twin, share_encoder', args)
+class TestContinuousQACPixel:
+
+ def test_qacpixel(self, action_shape, twin, share_encoder):
+ inputs = {'obs': torch.randn(B, 3, 84, 84), 'action': torch.randn(B, squeeze(action_shape))}
+ model = ContinuousQAC(
+ obs_shape=(3, 84, 84),
+ action_shape=action_shape,
+ action_space='reparameterization',
+ critic_head_hidden_size=embedding_size,
+ actor_head_hidden_size=embedding_size,
+ twin_critic=twin,
+ share_encoder=share_encoder,
+ encoder_hidden_size_list=[32, 32, 64],
+ )
+ # compute_q
+ q = model(inputs, mode='compute_critic')['q_value']
+ if twin:
+ q = torch.min(q[0], q[1])
+ is_differentiable(q.sum(), model.critic)
+
+ # compute_action
+ print(model)
+ (mu, sigma) = model(inputs['obs'], mode='compute_actor')['logit']
+ action_shape = squeeze(action_shape)
+ assert mu.shape == (B, action_shape)
+ assert sigma.shape == (B, action_shape)
+ if share_encoder: # if share_encoder, actor_encoder's grad is not None
+ is_differentiable(mu.sum() + sigma.sum(), model.actor_head)
+ else:
+ is_differentiable(mu.sum() + sigma.sum(), model.actor)
diff --git a/DI-engine/ding/model/template/tests/test_qac_dist.py b/DI-engine/ding/model/template/tests/test_qac_dist.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e6f8548092e32b21171fc31f7dc31b24e4865d6
--- /dev/null
+++ b/DI-engine/ding/model/template/tests/test_qac_dist.py
@@ -0,0 +1,58 @@
+import torch
+import numpy as np
+import pytest
+from itertools import product
+
+from ding.model.template import QACDIST
+from ding.torch_utils import is_differentiable
+from ding.utils import squeeze
+
+B = 4
+T = 6
+embedding_size = 32
+action_shape_args = [(6, ), [1]]
+args = list(product(*[action_shape_args, ['regression', 'reparameterization']]))
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('action_shape, action_space', args)
+class TestQACDIST:
+
+ def test_fcqac_dist(self, action_shape, action_space):
+ N = 32
+ inputs = {'obs': torch.randn(B, N), 'action': torch.randn(B, squeeze(action_shape))}
+ model = QACDIST(
+ obs_shape=(N, ),
+ action_shape=action_shape,
+ action_space=action_space,
+ critic_head_hidden_size=embedding_size,
+ actor_head_hidden_size=embedding_size,
+ )
+ # compute_q
+ q = model(inputs, mode='compute_critic')
+ is_differentiable(q['q_value'].sum(), model.critic)
+
+ if isinstance(action_shape, int):
+ assert q['q_value'].shape == (B, 1)
+ assert q['distribution'].shape == (B, 1, 51)
+ elif len(action_shape) == 1:
+ assert q['q_value'].shape == (B, 1)
+ assert q['distribution'].shape == (B, 1, 51)
+
+ # compute_action
+ print(model)
+ if action_space == 'regression':
+ action = model(inputs['obs'], mode='compute_actor')['action']
+ if squeeze(action_shape) == 1:
+ assert action.shape == (B, )
+ else:
+ assert action.shape == (B, squeeze(action_shape))
+ assert action.eq(action.clamp(-1, 1)).all()
+ is_differentiable(action.sum(), model.actor)
+ elif action_space == 'reparameterization':
+ (mu, sigma) = model(inputs['obs'], mode='compute_actor')['logit']
+ assert mu.shape == (B, *action_shape)
+ assert sigma.shape == (B, *action_shape)
+ is_differentiable(mu.sum() + sigma.sum(), model.actor)
diff --git a/DI-engine/ding/model/template/tests/test_qmix.py b/DI-engine/ding/model/template/tests/test_qmix.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce1817b6974f23206d1ab9084d61624c2446d04a
--- /dev/null
+++ b/DI-engine/ding/model/template/tests/test_qmix.py
@@ -0,0 +1,45 @@
+import pytest
+import torch
+from ding.torch_utils import is_differentiable
+from ding.model.template import Mixer, QMix
+
+
+@pytest.mark.unittest
+def test_mixer():
+ agent_num, bs, embedding_dim = 4, 3, 32
+ agent_q = torch.randn(bs, agent_num)
+ state_embedding = torch.randn(bs, embedding_dim)
+ mixer = Mixer(agent_num, embedding_dim, 64)
+ total_q = mixer(agent_q, state_embedding)
+ assert total_q.shape == (bs, )
+ loss = total_q.mean()
+ is_differentiable(loss, mixer)
+
+
+@pytest.mark.unittest
+def test_qmix():
+ use_mixer = [True, False]
+ agent_num, bs, T = 4, 3, 8
+ obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
+ embedding_dim = 64
+ for mix in use_mixer:
+ qmix_model = QMix(agent_num, obs_dim, global_obs_dim, action_dim, [128, embedding_dim], mix)
+ data = {
+ 'obs': {
+ 'agent_state': torch.randn(T, bs, agent_num, obs_dim),
+ 'global_state': torch.randn(T, bs, global_obs_dim),
+ 'action_mask': torch.randint(0, 2, size=(T, bs, agent_num, action_dim))
+ },
+ 'prev_state': [[None for _ in range(agent_num)] for _ in range(bs)],
+ 'action': torch.randint(0, action_dim, size=(T, bs, agent_num))
+ }
+ output = qmix_model(data, single_step=False)
+ assert set(output.keys()) == set(['total_q', 'logit', 'next_state', 'action_mask'])
+ assert output['total_q'].shape == (T, bs)
+ assert output['logit'].shape == (T, bs, agent_num, action_dim)
+ assert len(output['next_state']) == bs and all([len(n) == agent_num for n in output['next_state']])
+ print(output['next_state'][0][0]['h'].shape)
+ loss = output['total_q'].sum()
+ is_differentiable(loss, qmix_model)
+ data.pop('action')
+ output = qmix_model(data, single_step=False)
diff --git a/DI-engine/ding/model/template/tests/test_qtran.py b/DI-engine/ding/model/template/tests/test_qtran.py
new file mode 100644
index 0000000000000000000000000000000000000000..c468ec0f15ca3dc5c86893eab38e2a4c046b935b
--- /dev/null
+++ b/DI-engine/ding/model/template/tests/test_qtran.py
@@ -0,0 +1,33 @@
+import pytest
+from itertools import product
+import torch
+from ding.model.template import QTran
+from ding.torch_utils import is_differentiable
+
+
+@pytest.mark.unittest
+def test_qtran():
+ agent_num, bs, T = 4, 3, 8
+ obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
+ embedding_dim = 64
+ data = {
+ 'obs': {
+ 'agent_state': torch.randn(T, bs, agent_num, obs_dim),
+ 'global_state': torch.randn(T, bs, global_obs_dim),
+ 'action_mask': torch.randint(0, 2, size=(T, bs, agent_num, action_dim))
+ },
+ 'prev_state': [[None for _ in range(agent_num)] for _ in range(bs)],
+ 'action': torch.randint(0, action_dim, size=(T, bs, agent_num))
+ }
+ model = QTran(agent_num, obs_dim, global_obs_dim, action_dim, [32, embedding_dim], embedding_dim)
+ output = model.forward(data, single_step=False)
+ assert set(output.keys()) == set(['next_state', 'agent_q_act', 'vs', 'logit', 'action_mask', 'total_q'])
+ assert output['total_q'].shape == (T, bs)
+ assert output['logit'].shape == (T, bs, agent_num, action_dim)
+ assert len(output['next_state']) == bs and all([len(n) == agent_num for n in output['next_state']])
+ print(output['next_state'][0][0]['h'].shape)
+ loss = output['total_q'].sum() + output['agent_q_act'].sum() + output['vs'].sum()
+ is_differentiable(loss, model)
+
+ data.pop('action')
+ outputs = model.forward(data, single_step=False)
diff --git a/DI-engine/ding/model/template/tests/test_vac.py b/DI-engine/ding/model/template/tests/test_vac.py
new file mode 100644
index 0000000000000000000000000000000000000000..85e44e8a4c334239ece6ffd72cc9a5a6a4194b22
--- /dev/null
+++ b/DI-engine/ding/model/template/tests/test_vac.py
@@ -0,0 +1,112 @@
+import pytest
+import numpy as np
+import torch
+from itertools import product
+
+from ding.model import VAC, DREAMERVAC
+from ding.torch_utils import is_differentiable
+
+from ding.model import ConvEncoder
+from easydict import EasyDict
+
+ezD = EasyDict({'action_args_shape': (3, ), 'action_type_shape': 4})
+B, C, H, W = 4, 3, 128, 128
+obs_shape = [4, (8, ), (4, 64, 64)]
+act_args = [[6, 'discrete'], [(3, ), 'continuous'], [[2, 3, 6], 'discrete'], [ezD, 'hybrid']]
+# act_args = [[(3, ), True]]
+args = list(product(*[obs_shape, act_args, [False, True]]))
+
+
+def output_check(model, outputs, action_shape):
+ if isinstance(action_shape, tuple):
+ loss = sum([t.sum() for t in outputs])
+ elif np.isscalar(action_shape):
+ loss = outputs.sum()
+ elif isinstance(action_shape, dict):
+ loss = outputs.sum()
+ is_differentiable(loss, model)
+
+
+def model_check(model, inputs):
+ outputs = model(inputs, mode='compute_actor_critic')
+ value, logit = outputs['value'], outputs['logit']
+ if model.action_space == 'continuous':
+ outputs = value.sum() + logit['mu'].sum() + logit['sigma'].sum()
+ elif model.action_space == 'hybrid':
+ outputs = value.sum() + logit['action_type'].sum() + logit['action_args']['mu'].sum(
+ ) + logit['action_args']['sigma'].sum()
+ else:
+ if model.multi_head:
+ outputs = value.sum() + sum([t.sum() for t in logit])
+ else:
+ outputs = value.sum() + logit.sum()
+ output_check(model, outputs, 1)
+
+ for p in model.parameters():
+ p.grad = None
+ logit = model(inputs, mode='compute_actor')['logit']
+ if model.action_space == 'continuous':
+ logit = logit['mu'].sum() + logit['sigma'].sum()
+ elif model.action_space == 'hybrid':
+ logit = logit['action_type'].sum() + logit['action_args']['mu'].sum() + logit['action_args']['sigma'].sum()
+ output_check(model.actor, logit, model.action_shape)
+
+ for p in model.parameters():
+ p.grad = None
+ value = model(inputs, mode='compute_critic')['value']
+ assert value.shape == (B, )
+ output_check(model.critic, value, 1)
+
+
+@pytest.mark.unittest
+class TestDREAMERVAC:
+
+ def test_DREAMERVAC(self):
+ obs_shape = 8
+ act_shape = 6
+ model = DREAMERVAC(obs_shape, act_shape)
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('obs_shape, act_args, share_encoder', args)
+class TestVACGeneral:
+
+ def test_vac(self, obs_shape, act_args, share_encoder):
+ if isinstance(obs_shape, int):
+ inputs = torch.randn(B, obs_shape)
+ else:
+ inputs = torch.randn(B, *obs_shape)
+ model = VAC(obs_shape, action_shape=act_args[0], action_space=act_args[1], share_encoder=share_encoder)
+ model_check(model, inputs)
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('share_encoder', [False, True])
+class TestVACEncoder:
+
+ def test_vac_with_impala_encoder(self, share_encoder):
+ inputs = torch.randn(B, 4, 64, 64)
+ model = VAC(
+ obs_shape=(4, 64, 64),
+ action_shape=6,
+ action_space='discrete',
+ share_encoder=share_encoder,
+ impala_cnn_encoder=True
+ )
+ model_check(model, inputs)
+
+ def test_encoder_assignment(self, share_encoder):
+ inputs = torch.randn(B, 4, 64, 64)
+
+ special_encoder = ConvEncoder(obs_shape=(4, 64, 64), hidden_size_list=[16, 32, 32, 64])
+
+ model = VAC(
+ obs_shape=(4, 64, 64),
+ action_shape=6,
+ action_space='discrete',
+ share_encoder=share_encoder,
+ actor_head_hidden_size=64,
+ critic_head_hidden_size=64,
+ encoder=special_encoder
+ )
+ model_check(model, inputs)
diff --git a/DI-engine/ding/model/template/tests/test_vae.py b/DI-engine/ding/model/template/tests/test_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..600c4a66e4b99d83e0acd9b82341877b5ab24493
--- /dev/null
+++ b/DI-engine/ding/model/template/tests/test_vae.py
@@ -0,0 +1,37 @@
+import pytest
+import torch
+from ding.torch_utils import is_differentiable
+from ding.model.template.vae import VanillaVAE
+
+
+@pytest.mark.unittest
+def test_vae():
+ batch_size = 32
+ action_shape = 6
+ original_action_shape = 2
+ obs_shape = 6
+ hidden_size_list = [256, 256]
+ inputs = {
+ 'action': torch.randn(batch_size, original_action_shape),
+ 'obs': torch.randn(batch_size, obs_shape),
+ 'next_obs': torch.randn(batch_size, obs_shape)
+ }
+
+ vae_model = VanillaVAE(original_action_shape, obs_shape, action_shape, hidden_size_list)
+ outputs = vae_model(inputs)
+
+ assert outputs['recons_action'].shape == (batch_size, original_action_shape)
+ assert outputs['prediction_residual'].shape == (batch_size, obs_shape)
+ assert isinstance(outputs['input'], dict)
+ assert outputs['mu'].shape == (batch_size, obs_shape)
+ assert outputs['log_var'].shape == (batch_size, obs_shape)
+ assert outputs['z'].shape == (batch_size, action_shape)
+
+ outputs_decode = vae_model.decode_with_obs(outputs['z'], inputs['obs'])
+ assert outputs_decode['reconstruction_action'].shape == (batch_size, original_action_shape)
+ assert outputs_decode['predition_residual'].shape == (batch_size, obs_shape)
+
+ outputs['original_action'] = inputs['action']
+ outputs['true_residual'] = inputs['next_obs'] - inputs['obs']
+ vae_loss = vae_model.loss_function(outputs, kld_weight=0.01, predict_weight=0.01)
+ is_differentiable(vae_loss['loss'], vae_model)
diff --git a/DI-engine/ding/model/template/tests/test_wqmix.py b/DI-engine/ding/model/template/tests/test_wqmix.py
new file mode 100644
index 0000000000000000000000000000000000000000..350b0f00d4ec7c24cde7bea7664ca27931d5fd4b
--- /dev/null
+++ b/DI-engine/ding/model/template/tests/test_wqmix.py
@@ -0,0 +1,49 @@
+import pytest
+import torch
+from ding.torch_utils import is_differentiable
+from ding.model.template.wqmix import MixerStar, WQMix
+
+args = [True, False]
+
+
+@pytest.mark.unittest
+def test_mixer_star():
+ agent_num, bs, embedding_dim = 4, 3, 32
+ agent_q = torch.randn(bs, agent_num)
+ state_embedding = torch.randn(bs, embedding_dim)
+ mixer_star = MixerStar(agent_num, embedding_dim, 64)
+ total_q = mixer_star(agent_q, state_embedding)
+ assert total_q.shape == (bs, )
+ loss = total_q.mean()
+ is_differentiable(loss, mixer_star)
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('is_q_star', args)
+def test_wqmix(is_q_star):
+ agent_num, bs, T = 4, 3, 8
+ obs_dim, global_obs_dim, action_dim = 32, 32 * 4, 9
+ embedding_dim = 64
+ wqmix_model = WQMix(agent_num, obs_dim, global_obs_dim, action_dim, [128, embedding_dim], 'gru')
+ data = {
+ 'obs': {
+ 'agent_state': torch.randn(T, bs, agent_num, obs_dim),
+ 'global_state': torch.randn(T, bs, global_obs_dim),
+ 'action_mask': torch.randint(0, 2, size=(T, bs, agent_num, action_dim))
+ },
+ 'prev_state': [[None for _ in range(agent_num)] for _ in range(bs)],
+ 'action': torch.randint(0, action_dim, size=(T, bs, agent_num))
+ }
+ output = wqmix_model(data, single_step=False, q_star=is_q_star)
+ assert set(output.keys()) == set(['total_q', 'logit', 'next_state', 'action_mask'])
+ assert output['total_q'].shape == (T, bs)
+ assert output['logit'].shape == (T, bs, agent_num, action_dim)
+ assert len(output['next_state']) == bs and all([len(n) == agent_num for n in output['next_state']])
+ print(output['next_state'][0][0]['h'].shape)
+ loss = output['total_q'].sum()
+ if is_q_star:
+ is_differentiable(loss, [wqmix_model._q_network_star, wqmix_model._mixer_star])
+ else:
+ is_differentiable(loss, [wqmix_model._q_network, wqmix_model._mixer])
+ data.pop('action')
+ output = wqmix_model(data, single_step=False, q_star=is_q_star)
diff --git a/DI-engine/ding/model/template/vac.py b/DI-engine/ding/model/template/vac.py
new file mode 100644
index 0000000000000000000000000000000000000000..29363d3570082948a6a7eb8ec8c4b57dc7666600
--- /dev/null
+++ b/DI-engine/ding/model/template/vac.py
@@ -0,0 +1,427 @@
+from typing import Union, Dict, Optional
+from easydict import EasyDict
+import torch
+import torch.nn as nn
+from copy import deepcopy
+from ding.utils import SequenceType, squeeze, MODEL_REGISTRY
+from ..common import ReparameterizationHead, RegressionHead, DiscreteHead, MultiHead, \
+ FCEncoder, ConvEncoder, IMPALAConvEncoder
+from ding.torch_utils.network.dreamer import ActionHead, DenseHead
+
+
+@MODEL_REGISTRY.register('vac')
+class VAC(nn.Module):
+ """
+ Overview:
+ The neural network and computation graph of algorithms related to (state) Value Actor-Critic (VAC), such as \
+ A2C/PPO/IMPALA. This model now supports discrete, continuous and hybrid action spaces. The VAC is composed of \
+ four parts: ``actor_encoder``, ``critic_encoder``, ``actor_head`` and ``critic_head``. Encoders extract \
+ features from various observations, while heads predict the corresponding value or action logit. \
+ In a high-dimensional observation space such as a 2D image, we often share one encoder between \
+ ``actor_encoder`` and ``critic_encoder``; in a low-dimensional space such as a 1D vector, we often use \
+ separate encoders.
+ Interfaces:
+ ``__init__``, ``forward``, ``compute_actor``, ``compute_critic``, ``compute_actor_critic``.
+ """
+ mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']
+
+ def __init__(
+ self,
+ obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType, EasyDict],
+ action_space: str = 'discrete',
+ share_encoder: bool = True,
+ encoder_hidden_size_list: SequenceType = [128, 128, 64],
+ actor_head_hidden_size: int = 64,
+ actor_head_layer_num: int = 1,
+ critic_head_hidden_size: int = 64,
+ critic_head_layer_num: int = 1,
+ activation: Optional[nn.Module] = nn.ReLU(),
+ norm_type: Optional[str] = None,
+ sigma_type: Optional[str] = 'independent',
+ fixed_sigma_value: Optional[float] = 0.3,
+ bound_type: Optional[str] = None,
+ encoder: Optional[torch.nn.Module] = None,
+ impala_cnn_encoder: bool = False,
+ ) -> None:
+ """
+ Overview:
+ Initialize the VAC model according to corresponding input arguments.
+ Arguments:
+ - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84].
+ - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3].
+ - action_space (:obj:`str`): The type of action space, one of ['discrete', 'continuous', 'hybrid']; the \
+ corresponding head will be instantiated accordingly, i.e. ``DiscreteHead``, ``ReparameterizationHead``, \
+ or a hybrid combination of both.
+ - share_encoder (:obj:`bool`): Whether to share the observation encoder between actor and critic.
+ - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
+ the last element must match ``head_hidden_size``.
+ - actor_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of ``actor_head`` network, defaults \
+ to 64, it must match the last element of ``encoder_hidden_size_list``.
+ - actor_head_layer_num (:obj:`int`): The number of layers used in the ``actor_head`` network to compute actions.
+ - critic_head_hidden_size (:obj:`Optional[int]`): The ``hidden_size`` of ``critic_head`` network, defaults \
+ to 64, it must match the last element of ``encoder_hidden_size_list``.
+ - critic_head_layer_num (:obj:`int`): The number of layers used in the ``critic_head`` network.
+ - activation (:obj:`Optional[nn.Module]`): The type of activation function used in the networks; \
+ if ``None``, it defaults to ``nn.ReLU()``.
+ - norm_type (:obj:`Optional[str]`): The type of normalization in the networks, see \
+ ``ding.torch_utils.fc_block`` for more details. You can choose one of ['BN', 'IN', 'SyncBN', 'LN'].
+ - sigma_type (:obj:`Optional[str]`): The type of sigma in continuous action space, see \
+ ``ding.model.common.head.ReparameterizationHead`` for more details; in A2C/PPO, it defaults \
+ to ``independent``, which means state-independent sigma parameters.
+ - fixed_sigma_value (:obj:`Optional[float]`): If ``sigma_type`` is ``fixed``, use this value as sigma.
+ - bound_type (:obj:`Optional[str]`): The type of action bound methods in continuous action space, defaults \
+ to ``None``, which means no bound.
+ - encoder (:obj:`Optional[torch.nn.Module]`): The encoder module, defaults to ``None``, you can define \
+ your own encoder module and pass it into VAC to deal with different observation space.
+ - impala_cnn_encoder (:obj:`bool`): Whether to use IMPALA CNN encoder, defaults to ``False``.
+ """
+ super(VAC, self).__init__()
+ obs_shape: int = squeeze(obs_shape)
+ action_shape = squeeze(action_shape)
+ self.obs_shape, self.action_shape = obs_shape, action_shape
+ self.impala_cnn_encoder = impala_cnn_encoder
+ self.share_encoder = share_encoder
+
+ # Encoder Type
+ def new_encoder(outsize, activation):
+ if impala_cnn_encoder:
+ return IMPALAConvEncoder(obs_shape=obs_shape, channels=encoder_hidden_size_list, outsize=outsize)
+ else:
+ if isinstance(obs_shape, int) or len(obs_shape) == 1:
+ return FCEncoder(
+ obs_shape=obs_shape,
+ hidden_size_list=encoder_hidden_size_list,
+ activation=activation,
+ norm_type=norm_type
+ )
+ elif len(obs_shape) == 3:
+ return ConvEncoder(
+ obs_shape=obs_shape,
+ hidden_size_list=encoder_hidden_size_list,
+ activation=activation,
+ norm_type=norm_type
+ )
+ else:
+ raise RuntimeError(
+ "not support obs_shape for pre-defined encoder: {}, please customize your own encoder".
+ format(obs_shape)
+ )
+
+ if self.share_encoder:
+ assert actor_head_hidden_size == critic_head_hidden_size, \
+ "actor and critic network head should have same size."
+ if encoder:
+ if isinstance(encoder, torch.nn.Module):
+ self.encoder = encoder
+ else:
+ raise ValueError("illegal encoder instance.")
+ else:
+ self.encoder = new_encoder(actor_head_hidden_size, activation)
+ else:
+ if encoder:
+ if isinstance(encoder, torch.nn.Module):
+ self.actor_encoder = encoder
+ self.critic_encoder = deepcopy(encoder)
+ else:
+ raise ValueError("illegal encoder instance.")
+ else:
+ self.actor_encoder = new_encoder(actor_head_hidden_size, activation)
+ self.critic_encoder = new_encoder(critic_head_hidden_size, activation)
+
+ # Head Type
+ self.critic_head = RegressionHead(
+ critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
+ )
+ self.action_space = action_space
+ assert self.action_space in ['discrete', 'continuous', 'hybrid'], self.action_space
+ if self.action_space == 'continuous':
+ self.multi_head = False
+ self.actor_head = ReparameterizationHead(
+ actor_head_hidden_size,
+ action_shape,
+ actor_head_layer_num,
+ sigma_type=sigma_type,
+ activation=activation,
+ norm_type=norm_type,
+ bound_type=bound_type
+ )
+ elif self.action_space == 'discrete':
+ actor_head_cls = DiscreteHead
+ multi_head = not isinstance(action_shape, int)
+ self.multi_head = multi_head
+ if multi_head:
+ self.actor_head = MultiHead(
+ actor_head_cls,
+ actor_head_hidden_size,
+ action_shape,
+ layer_num=actor_head_layer_num,
+ activation=activation,
+ norm_type=norm_type
+ )
+ else:
+ self.actor_head = actor_head_cls(
+ actor_head_hidden_size,
+ action_shape,
+ actor_head_layer_num,
+ activation=activation,
+ norm_type=norm_type
+ )
+ elif self.action_space == 'hybrid': # HPPO
+ # hybrid action space: action_type(discrete) + action_args(continuous),
+ # such as {'action_type_shape': torch.LongTensor([0]), 'action_args_shape': torch.FloatTensor([0.1, -0.27])}
+ action_shape.action_args_shape = squeeze(action_shape.action_args_shape)
+ action_shape.action_type_shape = squeeze(action_shape.action_type_shape)
+ actor_action_args = ReparameterizationHead(
+ actor_head_hidden_size,
+ action_shape.action_args_shape,
+ actor_head_layer_num,
+ sigma_type=sigma_type,
+ fixed_sigma_value=fixed_sigma_value,
+ activation=activation,
+ norm_type=norm_type,
+ bound_type=bound_type,
+ )
+ actor_action_type = DiscreteHead(
+ actor_head_hidden_size,
+ action_shape.action_type_shape,
+ actor_head_layer_num,
+ activation=activation,
+ norm_type=norm_type,
+ )
+ self.actor_head = nn.ModuleList([actor_action_type, actor_action_args])
+
+ if self.share_encoder:
+ self.actor = [self.encoder, self.actor_head]
+ self.critic = [self.encoder, self.critic_head]
+ else:
+ self.actor = [self.actor_encoder, self.actor_head]
+ self.critic = [self.critic_encoder, self.critic_head]
+ # Convenient for calling some apis (e.g. self.critic.parameters()),
+ # but may cause misunderstanding when `print(self)`
+ self.actor = nn.ModuleList(self.actor)
+ self.critic = nn.ModuleList(self.critic)
+
+ def forward(self, x: torch.Tensor, mode: str) -> Dict:
+ """
+ Overview:
+ VAC forward computation graph, input observation tensor to predict state value or action logit. Different \
+ ``mode`` will forward with different network modules to get different outputs and save computation.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input observation tensor data.
+ - mode (:obj:`str`): The forward mode, all the modes are defined in the beginning of this class.
+ Returns:
+ - outputs (:obj:`Dict`): The output dict of VAC's forward computation graph, whose key-values vary from \
+ different ``mode``.
+
+ Examples (Actor):
+ >>> model = VAC(64, 128)
+ >>> inputs = torch.randn(4, 64)
+ >>> actor_outputs = model(inputs,'compute_actor')
+ >>> assert actor_outputs['logit'].shape == torch.Size([4, 128])
+
+ Examples (Critic):
+ >>> model = VAC(64, 64)
+ >>> inputs = torch.randn(4, 64)
+ >>> critic_outputs = model(inputs,'compute_critic')
+ >>> assert critic_outputs['value'].shape == torch.Size([4])
+
+ Examples (Actor-Critic):
+ >>> model = VAC(64, 64)
+ >>> inputs = torch.randn(4, 64)
+ >>> outputs = model(inputs,'compute_actor_critic')
+ >>> assert outputs['value'].shape == torch.Size([4])
+ >>> assert outputs['logit'].shape == torch.Size([4, 64])
+
+ """
+ assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
+ return getattr(self, mode)(x)
+
+ def compute_actor(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+ VAC forward computation graph for actor part, input observation tensor to predict action logit.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input observation tensor data.
+ Returns:
+ - outputs (:obj:`Dict`): The output dict of VAC's forward computation graph for actor, including ``logit``.
+ ReturnsKeys:
+ - logit (:obj:`torch.Tensor`): The predicted action logit tensor. For a discrete action space, it is a \
+ real-valued tensor with one entry per possible action choice; for a continuous action space, it is the \
+ mu and sigma of the Gaussian distribution, whose number equals the number of continuous actions. A \
+ hybrid action space is a combination of discrete and continuous action spaces, so the logit is a dict \
+ with ``action_type`` and ``action_args``.
+ Shapes:
+ - logit (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape``
+
+ Examples:
+ >>> model = VAC(64, 64)
+ >>> inputs = torch.randn(4, 64)
+ >>> actor_outputs = model(inputs,'compute_actor')
+ >>> assert actor_outputs['logit'].shape == torch.Size([4, 64])
+ """
+ if self.share_encoder:
+ x = self.encoder(x)
+ else:
+ x = self.actor_encoder(x)
+
+ if self.action_space == 'discrete':
+ return self.actor_head(x)
+ elif self.action_space == 'continuous':
+ x = self.actor_head(x) # mu, sigma
+ return {'logit': x}
+ elif self.action_space == 'hybrid':
+ action_type = self.actor_head[0](x)
+ action_args = self.actor_head[1](x)
+ return {'logit': {'action_type': action_type['logit'], 'action_args': action_args}}
+
+ def compute_critic(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+ VAC forward computation graph for critic part, input observation tensor to predict state value.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input observation tensor data.
+ Returns:
+ - outputs (:obj:`Dict`): The output dict of VAC's forward computation graph for critic, including ``value``.
+ ReturnsKeys:
+ - value (:obj:`torch.Tensor`): The predicted state value tensor.
+ Shapes:
+ - value (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch size, (B, 1) is squeezed to (B, ).
+
+ Examples:
+ >>> model = VAC(64, 64)
+ >>> inputs = torch.randn(4, 64)
+ >>> critic_outputs = model(inputs,'compute_critic')
+ >>> assert critic_outputs['value'].shape == torch.Size([4])
+ """
+ if self.share_encoder:
+ x = self.encoder(x)
+ else:
+ x = self.critic_encoder(x)
+ x = self.critic_head(x)
+ return {'value': x['pred']}
+
+ def compute_actor_critic(self, x: torch.Tensor) -> Dict:
+ """
+ Overview:
+ VAC forward computation graph for both actor and critic part, input observation tensor to predict action \
+ logit and state value.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input observation tensor data.
+ Returns:
+ - outputs (:obj:`Dict`): The output dict of VAC's forward computation graph for both actor and critic, \
+ including ``logit`` and ``value``.
+ ReturnsKeys:
+ - logit (:obj:`torch.Tensor`): The predicted action logit tensor. For a discrete action space, it is a \
+ real-valued tensor with one entry per possible action choice; for a continuous action space, it is the \
+ mu and sigma of the Gaussian distribution, whose number equals the number of continuous actions. A \
+ hybrid action space is a combination of discrete and continuous action spaces, so the logit is a dict \
+ with ``action_type`` and ``action_args``.
+ - value (:obj:`torch.Tensor`): The predicted state value tensor.
+ Shapes:
+ - logit (:obj:`torch.Tensor`): :math:`(B, N)`, where B is batch size and N is ``action_shape``
+ - value (:obj:`torch.Tensor`): :math:`(B, )`, where B is batch size, (B, 1) is squeezed to (B, ).
+
+ Examples:
+ >>> model = VAC(64, 64)
+ >>> inputs = torch.randn(4, 64)
+ >>> outputs = model(inputs,'compute_actor_critic')
+ >>> assert outputs['value'].shape == torch.Size([4])
+ >>> assert outputs['logit'].shape == torch.Size([4, 64])
+
+
+ .. note::
+ The ``compute_actor_critic`` interface aims to save computation when the encoder is shared, and returns the \
+ combined dict output.
+ """
+ if self.share_encoder:
+ actor_embedding = critic_embedding = self.encoder(x)
+ else:
+ actor_embedding = self.actor_encoder(x)
+ critic_embedding = self.critic_encoder(x)
+
+ value = self.critic_head(critic_embedding)['pred']
+
+ if self.action_space == 'discrete':
+ logit = self.actor_head(actor_embedding)['logit']
+ return {'logit': logit, 'value': value}
+ elif self.action_space == 'continuous':
+ x = self.actor_head(actor_embedding)
+ return {'logit': x, 'value': value}
+ elif self.action_space == 'hybrid':
+ action_type = self.actor_head[0](actor_embedding)
+ action_args = self.actor_head[1](actor_embedding)
+ return {'logit': {'action_type': action_type['logit'], 'action_args': action_args}, 'value': value}
+
+
+@MODEL_REGISTRY.register('dreamervac')
+class DREAMERVAC(nn.Module):
+ """
+ Overview:
+ The neural network and computation graph of the DreamerV3 (state) Value Actor-Critic (VAC).
+ This model now supports discrete and continuous action spaces.
+ Interfaces:
+ ``__init__``, ``forward``.
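+ Examples:
+ >>> # A minimal construction sketch, mirroring ``TestDREAMERVAC`` in ``test_vac.py``;
+ >>> # the default DreamerV3 hyper-parameters are kept as assumptions.
+ >>> model = DREAMERVAC(8, 6)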
+ """
+ mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']
+
+ def __init__(
+ self,
+ obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType, EasyDict],
+ dyn_stoch=32,
+ dyn_deter=512,
+ dyn_discrete=32,
+ actor_layers=2,
+ value_layers=2,
+ units=512,
+ act='SiLU',
+ norm='LayerNorm',
+ actor_dist='normal',
+ actor_init_std=1.0,
+ actor_min_std=0.1,
+ actor_max_std=1.0,
+ actor_temp=0.1,
+ action_unimix_ratio=0.01,
+ ) -> None:
+ """
+ Overview:
+ Initialize the ``DREAMERVAC`` model according to arguments.
+ Arguments:
+ - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84].
+ - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3].
+ """
+ super(DREAMERVAC, self).__init__()
+ obs_shape: int = squeeze(obs_shape)
+ action_shape = squeeze(action_shape)
+ self.obs_shape, self.action_shape = obs_shape, action_shape
+
+ if dyn_discrete:
+ feat_size = dyn_stoch * dyn_discrete + dyn_deter
+ else:
+ feat_size = dyn_stoch + dyn_deter
+ self.actor = ActionHead(
+ feat_size, # pytorch version
+ action_shape,
+ actor_layers,
+ units,
+ act,
+ norm,
+ actor_dist,
+ actor_init_std,
+ actor_min_std,
+ actor_max_std,
+ actor_temp,
+ outscale=1.0,
+ unimix_ratio=action_unimix_ratio,
+ )
+ self.critic = DenseHead(
+ feat_size, # pytorch version
+ (255, ),
+ value_layers,
+ units,
+ 'SiLU', # act
+ 'LN', # norm
+ 'twohot_symlog',
+ outscale=0.0,
+ device='cuda' if torch.cuda.is_available() else 'cpu',
+ )
diff --git a/DI-engine/ding/model/template/vae.py b/DI-engine/ding/model/template/vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3181361c7415d015cd180a76c0661223c699b79
--- /dev/null
+++ b/DI-engine/ding/model/template/vae.py
@@ -0,0 +1,223 @@
+"""Credit: Note the following vae model is modified from https://github.com/AntixK/PyTorch-VAE"""
+
+import torch
+from torch.nn import functional as F
+from torch import nn
+from abc import abstractmethod
+from typing import List, Dict, Callable, Union, Any, TypeVar, Tuple, Optional
+from ding.utils.type_helper import Tensor
+
+
+class VanillaVAE(nn.Module):
+ """
+ Overview:
+ Implementation of Vanilla variational autoencoder for action reconstruction.
+ Interfaces:
+ ``__init__``, ``encode``, ``decode``, ``decode_with_obs``, ``reparameterize``, \
+ ``forward``, ``loss_function`` .
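+ Examples:
+ >>> # A minimal sketch mirroring ``test_vae.py`` (action dim 2, obs dim 6, latent size 6 are assumed toy sizes).
+ >>> vae = VanillaVAE(2, 6, 6, [256, 256])
+ >>> inputs = {'action': torch.randn(4, 2), 'obs': torch.randn(4, 6)}
+ >>> outputs = vae(inputs)
+ >>> assert outputs['recons_action'].shape == (4, 2)
+ >>> assert outputs['z'].shape == (4, 6)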
+ """
+
+ def __init__(
+ self,
+ action_shape: int,
+ obs_shape: int,
+ latent_size: int,
+ hidden_dims: List = [256, 256],
+ **kwargs
+ ) -> None:
+ super(VanillaVAE, self).__init__()
+ self.action_shape = action_shape
+ self.obs_shape = obs_shape
+ self.latent_size = latent_size
+ self.hidden_dims = hidden_dims
+
+ # Build Encoder
+ self.encode_action_head = nn.Sequential(nn.Linear(self.action_shape, hidden_dims[0]), nn.ReLU())
+ self.encode_obs_head = nn.Sequential(nn.Linear(self.obs_shape, hidden_dims[0]), nn.ReLU())
+
+ self.encode_common = nn.Sequential(nn.Linear(hidden_dims[0], hidden_dims[1]), nn.ReLU())
+ self.encode_mu_head = nn.Linear(hidden_dims[1], latent_size)
+ self.encode_logvar_head = nn.Linear(hidden_dims[1], latent_size)
+
+ # Build Decoder
+ self.decode_action_head = nn.Sequential(nn.Linear(latent_size, hidden_dims[-1]), nn.ReLU())
+ self.decode_common = nn.Sequential(nn.Linear(hidden_dims[-1], hidden_dims[-2]), nn.ReLU())
+ # TODO(pu): tanh
+ self.decode_reconst_action_head = nn.Sequential(nn.Linear(hidden_dims[-2], self.action_shape), nn.Tanh())
+
+ # residual prediction
+ self.decode_prediction_head_layer1 = nn.Sequential(nn.Linear(hidden_dims[-2], hidden_dims[-2]), nn.ReLU())
+ self.decode_prediction_head_layer2 = nn.Linear(hidden_dims[-2], self.obs_shape)
+
+ self.obs_encoding = None
+
+ def encode(self, input: Dict[str, Tensor]) -> Dict[str, Any]:
+ """
+ Overview:
+ Encodes the input by passing through the encoder network and returns the latent codes.
+ Arguments:
+ - input (:obj:`Dict`): Dict containing keywords `obs` (:obj:`torch.Tensor`) and \
+ `action` (:obj:`torch.Tensor`), representing the observation and agent's action respectively.
+ Returns:
+ - outputs (:obj:`Dict`): Dict containing keywords ``mu`` (:obj:`torch.Tensor`), \
+ ``log_var`` (:obj:`torch.Tensor`) and ``obs_encoding`` (:obj:`torch.Tensor`) \
+ representing latent codes.
+ Shapes:
+ - obs (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is ``observation dim``.
+ - action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``.
+ - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
+ - log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
+ - obs_encoding (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch size and H is ``hidden dim``.
+ """
+ action_encoding = self.encode_action_head(input['action'])
+ obs_encoding = self.encode_obs_head(input['obs'])
+ # obs_encoding = self.condition_obs(input['obs']) # TODO(pu): using a different network
+ input = obs_encoding * action_encoding # TODO(pu): what about add, cat?
+ result = self.encode_common(input)
+
+ # Split the result into mu and var components
+ # of the latent Gaussian distribution
+ mu = self.encode_mu_head(result)
+ log_var = self.encode_logvar_head(result)
+
+ return {'mu': mu, 'log_var': log_var, 'obs_encoding': obs_encoding}
+
+ def decode(self, z: Tensor, obs_encoding: Tensor) -> Dict[str, Any]:
+ """
+ Overview:
+ Maps the given latent action and obs_encoding onto the original action space.
+ Arguments:
+ - z (:obj:`torch.Tensor`): the sampled latent action
+ - obs_encoding (:obj:`torch.Tensor`): observation encoding
+ Returns:
+ - outputs (:obj:`Dict`): VAE decode outputs.
+ ReturnsKeys:
+ - reconstruction_action (:obj:`torch.Tensor`): The action reconstructed by the VAE decoder.
+ - predition_residual (:obj:`torch.Tensor`): The predicted observation residual.
+ Shapes:
+ - z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
+ - obs_encoding (:obj:`torch.Tensor`): :math:`(B, H)`, where B is batch size and H is ``hidden dim``
+ """
+ action_decoding = self.decode_action_head(torch.tanh(z)) # NOTE: tanh, here z is not bounded
+ action_obs_decoding = action_decoding * obs_encoding
+ action_obs_decoding_tmp = self.decode_common(action_obs_decoding)
+
+ reconstruction_action = self.decode_reconst_action_head(action_obs_decoding_tmp)
+ predition_residual_tmp = self.decode_prediction_head_layer1(action_obs_decoding_tmp)
+ predition_residual = self.decode_prediction_head_layer2(predition_residual_tmp)
+ return {'reconstruction_action': reconstruction_action, 'predition_residual': predition_residual}
+
+ def decode_with_obs(self, z: Tensor, obs: Tensor) -> Dict[str, Any]:
+ """
+ Overview:
+ Maps the given latent action and obs onto the original action space.
+ Using the method self.encode_obs_head(obs) to get the obs_encoding.
+ Arguments:
+ - z (:obj:`torch.Tensor`): the sampled latent action
+ - obs (:obj:`torch.Tensor`): observation
+ Returns:
+ - outputs (:obj:`Dict`): VAE decode outputs.
+ ReturnsKeys:
+ - reconstruction_action (:obj:`torch.Tensor`): The action reconstructed by the VAE decoder.
+ - predition_residual (:obj:`torch.Tensor`): The observation residual predicted by the VAE.
+ Shapes:
+ - z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
+ - obs (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is ``obs_shape``
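+ Examples:
+ >>> # A small sketch with assumed toy dims (action dim 2, obs dim 6, latent size 6), as in ``test_vae.py``.
+ >>> vae = VanillaVAE(2, 6, 6, [256, 256])
+ >>> out = vae.decode_with_obs(torch.randn(4, 6), torch.randn(4, 6))
+ >>> assert out['reconstruction_action'].shape == (4, 2)
+ >>> assert out['predition_residual'].shape == (4, 6)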
+ """
+ obs_encoding = self.encode_obs_head(obs)
+ # TODO(pu): here z is already bounded, z is produced by td3 policy, it has been operated by tanh
+ action_decoding = self.decode_action_head(z)
+ action_obs_decoding = action_decoding * obs_encoding
+ action_obs_decoding_tmp = self.decode_common(action_obs_decoding)
+ reconstruction_action = self.decode_reconst_action_head(action_obs_decoding_tmp)
+ predition_residual_tmp = self.decode_prediction_head_layer1(action_obs_decoding_tmp)
+ predition_residual = self.decode_prediction_head_layer2(predition_residual_tmp)
+
+ return {'reconstruction_action': reconstruction_action, 'predition_residual': predition_residual}
+
+ def reparameterize(self, mu: Tensor, logvar: Tensor) -> Tensor:
+ """
+ Overview:
+ Reparameterization trick: sample from N(mu, var) using samples from N(0, 1).
+ Arguments:
+ - mu (:obj:`torch.Tensor`): Mean of the latent Gaussian.
+ - logvar (:obj:`torch.Tensor`): Log variance of the latent Gaussian.
+ Returns:
+ - z (:obj:`torch.Tensor`): The sampled latent variable.
+ Shapes:
+ - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``.
+ - logvar (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``.
+ """
+ std = torch.exp(0.5 * logvar)
+ eps = torch.randn_like(std)
+ return eps * std + mu
+
+ def forward(self, input: Dict[str, Tensor], **kwargs) -> dict:
+ """
+ Overview:
+ Encode the input, sample a latent variable by reparameterizing ``mu`` and ``log_var``, then decode it \
+ together with ``obs_encoding``.
+ Arguments:
+ - input (:obj:`Dict`): Dict containing keywords `obs` (:obj:`torch.Tensor`) \
+ and `action` (:obj:`torch.Tensor`), representing the observation \
+ and agent's action respectively.
+ Returns:
+ - outputs (:obj:`Dict`): Dict containing keywords ``recons_action`` \
+ (:obj:`torch.Tensor`), ``prediction_residual`` (:obj:`torch.Tensor`), \
+ ``input`` (:obj:`torch.Tensor`), ``mu`` (:obj:`torch.Tensor`), \
+ ``log_var`` (:obj:`torch.Tensor`) and ``z`` (:obj:`torch.Tensor`).
+ Shapes:
+ - recons_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``.
+ - prediction_residual (:obj:`torch.Tensor`): :math:`(B, O)`, \
+ where B is batch size and O is ``observation dim``.
+ - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
+ - log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
+ - z (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent_size``
+ """
+
+ encode_output = self.encode(input)
+ z = self.reparameterize(encode_output['mu'], encode_output['log_var'])
+ decode_output = self.decode(z, encode_output['obs_encoding'])
+ return {
+ 'recons_action': decode_output['reconstruction_action'],
+ 'prediction_residual': decode_output['predition_residual'],
+ 'input': input,
+ 'mu': encode_output['mu'],
+ 'log_var': encode_output['log_var'],
+ 'z': z
+ }
+
+ def loss_function(self, args: Dict[str, Tensor], **kwargs) -> Dict[str, Tensor]:
+ """
+ Overview:
+ Computes the VAE loss function.
+ Arguments:
+ - args (:obj:`Dict[str, Tensor]`): Dict containing keywords ``recons_action``, ``prediction_residual`` \
+ ``original_action``, ``mu``, ``log_var`` and ``true_residual``.
+ - kwargs (:obj:`Dict`): Dict containing keywords ``kld_weight`` and ``predict_weight``.
+ Returns:
+ - outputs (:obj:`Dict[str, Tensor]`): Dict containing different ``loss`` results, including ``loss``, \
+ ``reconstruction_loss``, ``kld_loss``, ``predict_loss``.
+ Shapes:
+ - recons_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size \
+ and A is ``action dim``.
+ - prediction_residual (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size \
+ and O is ``observation dim``.
+ - original_action (:obj:`torch.Tensor`): :math:`(B, A)`, where B is batch size and A is ``action dim``.
+ - mu (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
+ - log_var (:obj:`torch.Tensor`): :math:`(B, L)`, where B is batch size and L is ``latent size``.
+ - true_residual (:obj:`torch.Tensor`): :math:`(B, O)`, where B is batch size and O is ``observation dim``.
+ """
+ recons_action = args['recons_action']
+ prediction_residual = args['prediction_residual']
+ original_action = args['original_action']
+ mu = args['mu']
+ log_var = args['log_var']
+ true_residual = args['true_residual']
+
+ kld_weight = kwargs['kld_weight']
+ predict_weight = kwargs['predict_weight']
+
+ recons_loss = F.mse_loss(recons_action, original_action)
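+ # closed-form KL divergence between N(mu, sigma^2) and the standard normal prior, averaged over the batch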
+ kld_loss = torch.mean(-0.5 * torch.sum(1 + log_var - mu ** 2 - log_var.exp(), dim=1), dim=0)
+ predict_loss = F.mse_loss(prediction_residual, true_residual)
+
+ loss = recons_loss + kld_weight * kld_loss + predict_weight * predict_loss
+ return {'loss': loss, 'reconstruction_loss': recons_loss, 'kld_loss': kld_loss, 'predict_loss': predict_loss}
diff --git a/DI-engine/ding/model/template/wqmix.py b/DI-engine/ding/model/template/wqmix.py
new file mode 100644
index 0000000000000000000000000000000000000000..f80aa25d4ae87f734ee971439ebb811a63a270c3
--- /dev/null
+++ b/DI-engine/ding/model/template/wqmix.py
@@ -0,0 +1,255 @@
+from typing import Union, List
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from functools import reduce
+from ding.utils import list_split, MODEL_REGISTRY
+from ding.torch_utils.network.nn_module import fc_block, MLP
+from ding.torch_utils.network.transformer import ScaledDotProductAttention
+from .q_learning import DRQN
+from ding.model.template.qmix import Mixer
+
+
+class MixerStar(nn.Module):
+ """
+ Overview:
+ Mixer network for Q_star in WQMIX (https://arxiv.org/abs/2006.10800), which mixes the independent q_values \
+ of each agent into a total q_value. It is different from QMIX's mixer network: here the mixing network is a \
+ feedforward network with 3 hidden layers of 256 dim. This Q_star mixing network is not constrained to be \
+ monotonic: it does not use non-negative weights, and it takes the state and agent_q directly as inputs, \
+ as opposed to QMIX, where hypernetworks take the state as input and generate the mixing weights.
+ Interface:
+ ``__init__``, ``forward``.
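+ Examples:
+ >>> # A minimal sketch mirroring ``test_mixer_star`` in ``test_wqmix.py`` (toy sizes are assumed).
+ >>> mixer = MixerStar(4, 32, 64)
+ >>> total_q = mixer(torch.randn(3, 4), torch.randn(3, 32))
+ >>> assert total_q.shape == (3, )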
+ """
+
+ def __init__(self, agent_num: int, state_dim: int, mixing_embed_dim: int) -> None:
+ """
+ Overview:
+ Initialize the mixer network of Q_star in WQMIX.
+ Arguments:
+ - agent_num (:obj:`int`): The number of agent, e.g., 8.
+ - state_dim(:obj:`int`): The dimension of global observation state, e.g., 16.
+ - mixing_embed_dim (:obj:`int`): The dimension of the mixing state embedding, e.g., 128.
+ """
+ super(MixerStar, self).__init__()
+ self.agent_num = agent_num
+ self.state_dim = state_dim
+ self.embed_dim = mixing_embed_dim
+ self.input_dim = self.agent_num + self.state_dim # shape N+A
+ non_lin = nn.ReLU()
+ self.net = nn.Sequential(
+ nn.Linear(self.input_dim, self.embed_dim), non_lin, nn.Linear(self.embed_dim, self.embed_dim), non_lin,
+ nn.Linear(self.embed_dim, self.embed_dim), non_lin, nn.Linear(self.embed_dim, 1)
+ )
+
+ # V(s) instead of a bias for the last layers
+ self.V = nn.Sequential(nn.Linear(self.state_dim, self.embed_dim), non_lin, nn.Linear(self.embed_dim, 1))
+
+ def forward(self, agent_qs: torch.FloatTensor, states: torch.FloatTensor) -> torch.FloatTensor:
+ """
+ Overview:
+ Forward computation graph of the mixer network for Q_star in WQMIX. This mixer network \
+ is a feed-forward network that takes the state and the appropriate actions' utilities as input.
+ Arguments:
+ - agent_qs (:obj:`torch.FloatTensor`): The independent q_value of each agent.
+ - states (:obj:`torch.FloatTensor`): The embedding vector of the global state.
+ Returns:
+ - q_tot (:obj:`torch.FloatTensor`): The total mixed q_value.
+ Shapes:
+ - agent_qs (:obj:`torch.FloatTensor`): :math:`(T, B, A)`, where T is timestep, \
+ B is batch size and A is agent_num.
+ - states (:obj:`torch.FloatTensor`): :math:`(T, B, M)`, where M is global_obs_shape.
+ - q_tot (:obj:`torch.FloatTensor`): :math:`(T, B, )`.
+ """
+ # in the shape annotations below, T is timestep, B is batch_size, A is agent_num and N is the
+ # global state dim; for example, in 3s5z we can set T=10, B=32, A=8, N=216
+ bs = agent_qs.shape[:-1] # (T, B)
+ states = states.reshape(-1, self.state_dim) # (T, B, N) -> (T*B, N)
+ agent_qs = agent_qs.reshape(-1, self.agent_num) # (T, B, A) -> (T*B, A)
+ inputs = torch.cat([states, agent_qs], dim=1) # (T*B, N) + (T*B, A) -> (T*B, N+A)
+ advs = self.net(inputs) # (T*B, 1)
+ vs = self.V(states) # (T*B, 1)
+ y = advs + vs
+ q_tot = y.view(*bs) # (T*B, 1) -> (T, B)
+
+ return q_tot
+
+
+@MODEL_REGISTRY.register('wqmix')
+class WQMix(nn.Module):
+ """
+ Overview:
+ WQMIX (https://arxiv.org/abs/2006.10800) network. There are two components: \
+ 1) Q_tot, which is the same as the QMIX network and is composed of agent Q networks and a mixer network. \
+ 2) An unrestricted joint action Q_star, which is composed of agent Q networks and a mixer_star network. \
+ The QMIX paper mentions that all agents share local Q network parameters, so only one Q network is \
+ initialized in Q_tot or Q_star.
+ Interface:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(
+ self,
+ agent_num: int,
+ obs_shape: int,
+ global_obs_shape: int,
+ action_shape: int,
+ hidden_size_list: list,
+ lstm_type: str = 'gru',
+ dueling: bool = False
+ ) -> None:
+ """
+ Overview:
+ Initialize WQMIX neural network according to arguments, i.e. agent Q network and mixer, \
+ Q_star network and mixer_star.
+ Arguments:
+ - agent_num (:obj:`int`): The number of agent, such as 8.
+ - obs_shape (:obj:`int`): The dimension of each agent's observation state, such as 8.
+ - global_obs_shape (:obj:`int`): The dimension of global observation state, such as 8.
+ - action_shape (:obj:`int`): The dimension of action shape, such as 6.
+ - hidden_size_list (:obj:`list`): The list of hidden size for ``q_network``, \
+ the last element must match mixer's ``mixing_embed_dim``.
+ - lstm_type (:obj:`str`): The type of RNN module in ``q_network``, now support \
+ ['normal', 'pytorch', 'gru'], default to gru.
+ - dueling (:obj:`bool`): Whether to choose ``DuelingHead`` (True) or ``DiscreteHead`` (False), \
+ default to False.
+ """
+ super(WQMix, self).__init__()
+ self._act = nn.ReLU()
+ self._q_network = DRQN(obs_shape, action_shape, hidden_size_list, lstm_type=lstm_type, dueling=dueling)
+ self._q_network_star = DRQN(obs_shape, action_shape, hidden_size_list, lstm_type=lstm_type, dueling=dueling)
+ embedding_size = hidden_size_list[-1]
+ self._mixer = Mixer(agent_num, global_obs_shape, mixing_embed_dim=embedding_size)
+ self._mixer_star = MixerStar(
+ agent_num, global_obs_shape, mixing_embed_dim=256
+ ) # the mixing network of Q_star is a feedforward network with 3 hidden layers of 256 dim
+ self._global_state_encoder = nn.Identity() # nn.Sequential()
+
+ def forward(self, data: dict, single_step: bool = True, q_star: bool = False) -> dict:
+ """
+ Overview:
+ Forward computation graph of the WQMIX network. The input dict includes time-series observations and \
+ related data, which are used to predict the total q_value and each agent's q_value. Whether Q_tot or \
+ Q_star is calculated is determined by the ``q_star`` parameter.
+ Arguments:
+ - data (:obj:`dict`): Input data dict with keys ['obs', 'prev_state', 'action'].
+ - agent_state (:obj:`torch.Tensor`): Time series local observation data of each agents.
+ - global_state (:obj:`torch.Tensor`): Time series global observation data.
+ - prev_state (:obj:`list`): Previous rnn state for ``q_network`` or ``_q_network_star``.
+ - action (:obj:`torch.Tensor` or None): If action is None, use argmax q_value index as action to\
+ calculate ``agent_q_act``.
+ - single_step (:obj:`bool`): Whether single_step forward, if so, add timestep dim before forward and\
+ remove it after forward.
+ - q_star (:obj:`bool`): Whether to forward the Q_star network. If True, use the Q_star network, whose\
+ agent networks have the same architecture as the Q network but do not share parameters, and whose mixing\
+ network is a feedforward network with 3 hidden layers of 256 dim; if False, use the Q network, which is\
+ the same as the Q network in the QMIX paper.
+ Returns:
+ - ret (:obj:`dict`): Output data dict with keys [``total_q``, ``logit``, ``next_state``].
+ - total_q (:obj:`torch.Tensor`): Total q_value, which is the result of mixer network.
+ - agent_q (:obj:`torch.Tensor`): Each agent q_value.
+ - next_state (:obj:`list`): Next rnn state.
+ Shapes:
+ - agent_state (:obj:`torch.Tensor`): :math:`(T, B, A, N)`, where T is timestep, B is batch_size\
+ A is agent_num, N is obs_shape.
+ - global_state (:obj:`torch.Tensor`): :math:`(T, B, M)`, where M is global_obs_shape.
+ - prev_state (:obj:`list`): :math:`(T, B, A)`, a list of length B, and each element is a list of length A.
+ - action (:obj:`torch.Tensor`): :math:`(T, B, A)`.
+ - total_q (:obj:`torch.Tensor`): :math:`(T, B)`.
+ - agent_q (:obj:`torch.Tensor`): :math:`(T, B, A, P)`, where P is action_shape.
+ - next_state (:obj:`list`): :math:`(T, B, A)`, a list of length B, and each element is a list of length A.
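+ Examples:
+ A minimal shape-check sketch; the sizes below (T=2, B=3, agent_num=4, obs_shape=32, \
+ global_obs_shape=32, action_shape=6, hidden_size_list=[128, 64]) are illustrative assumptions, \
+ not a tested configuration:
+
+ >>> net = WQMix(4, 32, 32, 6, [128, 64])
+ >>> data = {
+ >>> 'obs': {
+ >>> 'agent_state': torch.randn(2, 3, 4, 32),
+ >>> 'global_state': torch.randn(2, 3, 32),
+ >>> 'action_mask': torch.randint(0, 2, size=(2, 3, 4, 6))
+ >>> },
+ >>> 'prev_state': [[None for _ in range(4)] for _ in range(3)],
+ >>> 'action': torch.randint(0, 6, size=(2, 3, 4))
+ >>> }
+ >>> output = net(data, single_step=False, q_star=False)
+ >>> assert output['total_q'].shape == (2, 3)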
+ """
+ if q_star: # forward using Q_star network
+ agent_state, global_state, prev_state = data['obs']['agent_state'], data['obs']['global_state'], data[
+ 'prev_state']
+ action = data.get('action', None)
+ if single_step:
+ agent_state, global_state = agent_state.unsqueeze(0), global_state.unsqueeze(0)
+ T, B, A = agent_state.shape[:3]
+ assert len(prev_state) == B and all(
+ [len(p) == A for p in prev_state]
+ ), '{}-{}-{}-{}'.format([type(p) for p in prev_state], B, A, len(prev_state[0]))
+ prev_state = reduce(lambda x, y: x + y, prev_state)
+ agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
+ output = self._q_network_star(
+ {
+ 'obs': agent_state,
+ 'prev_state': prev_state,
+ 'enable_fast_timestep': True
+ }
+ ) # here is the forward pass of the agent networks of Q_star
+ agent_q, next_state = output['logit'], output['next_state']
+ next_state, _ = list_split(next_state, step=A)
+ agent_q = agent_q.reshape(T, B, A, -1)
+ if action is None:
+ # For target forward process
+ if len(data['obs']['action_mask'].shape) == 3:
+ action_mask = data['obs']['action_mask'].unsqueeze(0)
+ else:
+ action_mask = data['obs']['action_mask']
+ agent_q[action_mask == 0.0] = -9999999
+ action = agent_q.argmax(dim=-1)
+ agent_q_act = torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1))
+ agent_q_act = agent_q_act.squeeze(-1) # T, B, A
+
+ global_state_embedding = self._global_state_encoder(global_state)
+ total_q = self._mixer_star(
+ agent_q_act, global_state_embedding
+ ) # here is the forward pass of the mixer networks of Q_star
+
+ if single_step:
+ total_q, agent_q = total_q.squeeze(0), agent_q.squeeze(0)
+ return {
+ 'total_q': total_q,
+ 'logit': agent_q,
+ 'next_state': next_state,
+ 'action_mask': data['obs']['action_mask']
+ }
+ else: # forward using Q network
+ agent_state, global_state, prev_state = data['obs']['agent_state'], data['obs']['global_state'], data[
+ 'prev_state']
+ action = data.get('action', None)
+ if single_step:
+ agent_state, global_state = agent_state.unsqueeze(0), global_state.unsqueeze(0)
+ T, B, A = agent_state.shape[:3]
+ assert len(prev_state) == B and all(
+ [len(p) == A for p in prev_state]
+ ), '{}-{}-{}-{}'.format([type(p) for p in prev_state], B, A, len(prev_state[0]))
+ prev_state = reduce(lambda x, y: x + y, prev_state)
+ agent_state = agent_state.reshape(T, -1, *agent_state.shape[3:])
+ output = self._q_network(
+ {
+ 'obs': agent_state,
+ 'prev_state': prev_state,
+ 'enable_fast_timestep': True
+ }
+ ) # here is the forward pass of the agent networks of Q
+ agent_q, next_state = output['logit'], output['next_state']
+ next_state, _ = list_split(next_state, step=A)
+ agent_q = agent_q.reshape(T, B, A, -1)
+ if action is None:
+ # For target forward process
+ if len(data['obs']['action_mask'].shape) == 3:
+ action_mask = data['obs']['action_mask'].unsqueeze(0)
+ else:
+ action_mask = data['obs']['action_mask']
+ agent_q[action_mask == 0.0] = -9999999
+ action = agent_q.argmax(dim=-1)
+ agent_q_act = torch.gather(agent_q, dim=-1, index=action.unsqueeze(-1))
+ agent_q_act = agent_q_act.squeeze(-1) # T, B, A
+
+ global_state_embedding = self._global_state_encoder(global_state)
+ total_q = self._mixer(
+ agent_q_act, global_state_embedding
+ ) # here is the forward pass of the mixer networks of Q
+
+ if single_step:
+ total_q, agent_q = total_q.squeeze(0), agent_q.squeeze(0)
+ return {
+ 'total_q': total_q,
+ 'logit': agent_q,
+ 'next_state': next_state,
+ 'action_mask': data['obs']['action_mask']
+ }
diff --git a/DI-engine/ding/model/wrapper/__init__.py b/DI-engine/ding/model/wrapper/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..24d621e97366150acbb0378df27bf473f0079d39
--- /dev/null
+++ b/DI-engine/ding/model/wrapper/__init__.py
@@ -0,0 +1 @@
+from .model_wrappers import model_wrap, register_wrapper, IModelWrapper
diff --git a/DI-engine/ding/model/wrapper/model_wrappers.py b/DI-engine/ding/model/wrapper/model_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..e4275873274237ab3c47bcfb89c55dc0c6a2fa88
--- /dev/null
+++ b/DI-engine/ding/model/wrapper/model_wrappers.py
@@ -0,0 +1,1020 @@
+from typing import Any, Tuple, Callable, Optional, List, Dict, Union
+from abc import ABC
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.distributions import Categorical, Independent, Normal
+from ding.torch_utils import get_tensor_data, zeros_like
+from ding.rl_utils import create_noise_generator
+from ding.utils.data import default_collate
+
+
+class IModelWrapper(ABC):
+ """
+ Overview:
+ The basic interface class of model wrappers. Model wrapper is a wrapper class of torch.nn.Module model, which \
+ is used to add some extra operations for the wrapped model, such as hidden state maintain for RNN-base model, \
+ argmax action selection for discrete action space, etc.
+ Interfaces:
+ ``__init__``, ``__getattr__``, ``info``, ``reset``, ``forward``.
+ """
+
+ def __init__(self, model: nn.Module) -> None:
+ """
+ Overview:
+ Initialize the model and other necessary member variables in the model wrapper.
+ """
+ self._model = model
+
+ def __getattr__(self, key: str) -> Any:
+ """
+ Overview:
+ Get original attributes of the torch.nn.Module model, such as variables and methods defined in the model.
+ Arguments:
+ - key (:obj:`str`): The string key to query.
+ Returns:
+ - ret (:obj:`Any`): The queried attribute.
+ """
+ return getattr(self._model, key)
+
+ def info(self, attr_name: str) -> str:
+ """
+ Overview:
+ Get some string information of the indicated ``attr_name``, which is used for debug wrappers.
+ This method will recursively search for the indicated ``attr_name``.
+ Arguments:
+ - attr_name (:obj:`str`): The string key to query information.
+ Returns:
+ - info_string (:obj:`str`): The information string of the indicated ``attr_name``.
+ """
+ if attr_name in dir(self):
+ if isinstance(self._model, IModelWrapper):
+ return '{} {}'.format(self.__class__.__name__, self._model.info(attr_name))
+ else:
+ if attr_name in dir(self._model):
+ return '{} {}'.format(self.__class__.__name__, self._model.__class__.__name__)
+ else:
+ return '{}'.format(self.__class__.__name__)
+ else:
+ if isinstance(self._model, IModelWrapper):
+ return '{}'.format(self._model.info(attr_name))
+ else:
+ return '{}'.format(self._model.__class__.__name__)
+
+ def reset(self, data_id: List[int] = None, **kwargs) -> None:
+ """
+ Overview:
+ Basic interface, reset some stateful variables in the model wrapper, such as the hidden state of an RNN.
+ Here we do nothing and just implement this interface method.
+ Other derived model wrappers can override this method to add some extra operations.
+ Arguments:
+ - data_id (:obj:`List[int]`): The data id list to reset. If None, reset all data. In practice, \
+ model wrappers often need to maintain some stateful variables for each data trajectory, \
+ so we leave this ``data_id`` argument to reset the stateful variables of the indicated data.
+ """
+ pass
+
+ def forward(self, *args, **kwargs) -> Any:
+ """
+ Overview:
+ Basic interface, call the wrapped model's forward method. Other derived model wrappers can override this \
+ method to add some extra operations.
+ """
+ return self._model.forward(*args, **kwargs)
+
+
+class BaseModelWrapper(IModelWrapper):
+ """
+ Overview:
+ Placeholder model wrapper. This class wraps the model without any extra operations, \
+ providing an empty ``reset`` method and a ``forward`` method that directly calls the wrapped model's forward.
+ To keep the model wrapper interface consistent, DI-engine's policy implementations use this class to wrap \
+ models that need no specific wrapper operations.
+ """
+ pass
+
+
+class HiddenStateWrapper(IModelWrapper):
+ """
+ Overview:
+ Maintain the hidden state for an RNN-based model. Each sample in a batch has its own state.
+ Interfaces:
+ ``__init__``, ``reset``, ``forward``.
+ """
+
+ def __init__(
+ self,
+ model: Any,
+ state_num: int,
+ save_prev_state: bool = False,
+ init_fn: Callable = lambda: None,
+ ) -> None:
+ """
+ Overview:
+ Maintain the hidden state for an RNN-based model. Each sample in a batch has its own state. \
+ Initialize the maintained states and the state init function; then wrap the ``model.forward`` method so \
+ that the saved states are automatically fed in as ``data['prev_state']``, and create the ``model.reset`` method.
+ Arguments:
+ - model(:obj:`Any`): Wrapped model class, should contain forward method.
+ - state_num (:obj:`int`): Number of states to process.
+ - save_prev_state (:obj:`bool`): Whether to output the prev state in output.
+ - init_fn (:obj:`Callable`): The function which is used to init every hidden state when init and reset, \
+ default return None for hidden states.
+
+ .. note::
+ 1. This helper must deal with a batch that contains only part of the samples, e.g., 6 samples when state_num is 8.
+ 2. This helper must deal with the single sample state reset.
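+
+ Examples:
+ A minimal sketch; ``my_rnn_model`` and ``data`` are placeholders for any model whose forward consumes \
+ ``data['prev_state']`` and returns ``next_state`` (e.g. a DRQN-style network) and its input:
+
+ >>> model = model_wrap(my_rnn_model, wrapper_name='hidden_state', state_num=4, save_prev_state=True)
+ >>> model.reset() # init all 4 per-sample hidden states with init_fn()
+ >>> out = model.forward(data, data_id=[0, 1, 3]) # only these samples' states are read and updated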
+ """
+ super().__init__(model)
+ self._state_num = state_num
+ # This is to maintain hidden states: when forward reaches this wrapper, \
+ # self._state is mapped into data['prev_state'], and the returned next_state is stored back in self._state.
+ self._state = {i: init_fn() for i in range(state_num)}
+ self._save_prev_state = save_prev_state
+ self._init_fn = init_fn
+
+ def forward(self, data, **kwargs):
+ state_id = kwargs.pop('data_id', None)
+ valid_id = kwargs.pop('valid_id', None) # None, not used in any code in DI-engine
+ data, state_info = self.before_forward(data, state_id) # update data['prev_state'] with self._state
+ output = self._model.forward(data, **kwargs)
+ h = output.pop('next_state', None)
+ if h is not None:
+ self.after_forward(h, state_info, valid_id) # this is to store the 'next hidden state' for each time step
+ if self._save_prev_state:
+ prev_state = get_tensor_data(data['prev_state'])
+ # for compatibility, because of the incompatibility between None and torch.Tensor
+ for i in range(len(prev_state)):
+ if prev_state[i] is None:
+ prev_state[i] = zeros_like(h[0])
+ output['prev_state'] = prev_state
+ return output
+
+ def reset(self, *args, **kwargs):
+ state = kwargs.pop('state', None)
+ state_id = kwargs.get('data_id', None)
+ self.reset_state(state, state_id)
+ if hasattr(self._model, 'reset'):
+ return self._model.reset(*args, **kwargs)
+
+ def reset_state(self, state: Optional[list] = None, state_id: Optional[list] = None) -> None:
+ if state_id is None: # train: init all states
+ state_id = [i for i in range(self._state_num)]
+ if state is None: # collect: init state that are done
+ state = [self._init_fn() for i in range(len(state_id))]
+ assert len(state) == len(state_id), '{}/{}'.format(len(state), len(state_id))
+ for idx, s in zip(state_id, state):
+ self._state[idx] = s
+
+ def before_forward(self, data: dict, state_id: Optional[list]) -> Tuple[dict, dict]:
+ if state_id is None:
+ state_id = [i for i in range(self._state_num)]
+
+ state_info = {idx: self._state[idx] for idx in state_id}
+ data['prev_state'] = list(state_info.values())
+ return data, state_info
+
+ def after_forward(self, h: Any, state_info: dict, valid_id: Optional[list] = None) -> None:
+ assert len(h) == len(state_info), '{}/{}'.format(len(h), len(state_info))
+ for i, idx in enumerate(state_info.keys()):
+ if valid_id is None:
+ self._state[idx] = h[i]
+ else:
+ if idx in valid_id:
+ self._state[idx] = h[i]
+
+
+class TransformerInputWrapper(IModelWrapper):
+
+ def __init__(self, model: Any, seq_len: int, init_fn: Callable = lambda: None) -> None:
+ """
+ Overview:
+ Given N, the length of the sequences received by a Transformer model, maintain the last N-1 input
+ observations. In this way we can provide at each step all the observations needed by the Transformer to
+ compute its output. We need this because some methods such as 'collect' and 'evaluate' only provide the
+ model 1 observation per step and have no memory of past observations, but the Transformer needs a sequence
+ of N observations. The wrapper method ``forward`` saves the input observation in a FIFO memory of
+ length N and the method ``reset`` resets the memory. The empty memory slots are initialized
+ with 'init_fn' or zero by calling the method ``reset_input``. Since different envs can terminate at
+ different steps, the method ``reset_memory_entry`` only re-initializes the memory of the specified
+ environments in the batch.
+ Arguments:
+ - model (:obj:`Any`): Wrapped model class, should contain forward method.
+ - seq_len (:obj:`int`): Number of past observations to remember.
+ - init_fn (:obj:`Callable`): The function which is used to init every memory location on init and reset.
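+
+ Examples:
+ A minimal sketch mirroring the pattern in the unit tests below; the shapes are illustrative:
+
+ >>> model = GTrXL(input_dim=32, embedding_dim=64) # any sequence model with the same interface works
+ >>> model = model_wrap(model, wrapper_name='transformer_input', seq_len=8)
+ >>> out = model.forward(torch.randn(4, 32)) # one observation per env, 4 envs
+ >>> assert out['input_seq'].shape == (8, 4, 32) # the FIFO memory fed to the wrapped model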
+ """
+ super().__init__(model)
+ self.seq_len = seq_len
+ self._init_fn = init_fn
+ self.obs_memory = None # shape (N, bs, *obs_shape)
+ self.init_obs = None # sample of observation used to initialize the memory
+ self.bs = None
+ self.memory_idx = [] # len bs, index of where to put the next element in the sequence for each batch
+
+ def forward(self,
+ input_obs: torch.Tensor,
+ only_last_logit: bool = True,
+ data_id: List = None,
+ **kwargs) -> Dict[str, torch.Tensor]:
+ """
+ Arguments:
+ - input_obs (:obj:`torch.Tensor`): Input observation without sequence shape: ``(bs, *obs_shape)``.
+ - only_last_logit (:obj:`bool`): If True, 'logit' only contains the output corresponding to the current \
+ observation (shape: (bs, embedding_dim)); otherwise 'logit' has shape (seq_len, bs, embedding_dim).
+ - data_id (:obj:`List`): Ids of the envs that are currently running. Memory update and logit return only \
+ take effect for those environments. If `None`, all envs are considered running.
+ Returns:
+ - Dictionary containing the input_sequence 'input_seq' stored in memory and the transformer output 'logit'.
+ """
+ if self.obs_memory is None:
+ self.reset_input(torch.zeros_like(input_obs)) # init the memory with the size of the input observation
+ if data_id is None:
+ data_id = list(range(self.bs))
+ assert self.obs_memory.shape[0] == self.seq_len
+ # implements a fifo queue, self.memory_idx is index where to put the last element
+ for i, b in enumerate(data_id):
+ if self.memory_idx[b] == self.seq_len:
+ # roll back of 1 position along dim 1 (sequence dim)
+ self.obs_memory[:, b] = torch.roll(self.obs_memory[:, b], -1, 0)
+ self.obs_memory[self.memory_idx[b] - 1, b] = input_obs[i]
+ if self.memory_idx[b] < self.seq_len:
+ self.obs_memory[self.memory_idx[b], b] = input_obs[i]
+ if self.memory_idx[b] != self.seq_len:
+ self.memory_idx[b] += 1
+ out = self._model.forward(self.obs_memory, **kwargs)
+ out['input_seq'] = self.obs_memory
+ if only_last_logit:
+ # return only the logits for running environments
+ out['logit'] = [out['logit'][self.memory_idx[b] - 1][b] for b in range(self.bs) if b in data_id]
+ out['logit'] = default_collate(out['logit'])
+ return out
+
+ def reset_input(self, input_obs: torch.Tensor):
+ """
+ Overview:
+ Initialize the whole memory
+ """
+ init_obs = torch.zeros_like(input_obs)
+ self.init_obs = init_obs
+ self.obs_memory = [] # List(bs, *obs_shape)
+ for i in range(self.seq_len):
+ self.obs_memory.append(init_obs.clone() if init_obs is not None else self._init_fn())
+ self.obs_memory = default_collate(self.obs_memory) # shape (N, bs, *obs_shape)
+ self.bs = self.init_obs.shape[0]
+ self.memory_idx = [0 for _ in range(self.bs)]
+
+ # called before evaluation
+ # called after each evaluation iteration for each done env
+ # called after each collect iteration for each done env
+ def reset(self, *args, **kwargs):
+ state_id = kwargs.get('data_id', None)
+ input_obs = kwargs.get('input_obs', None)
+ if input_obs is not None:
+ self.reset_input(input_obs)
+ if state_id is not None:
+ self.reset_memory_entry(state_id)
+ if input_obs is None and state_id is None:
+ self.obs_memory = None
+ if hasattr(self._model, 'reset'):
+ return self._model.reset(*args, **kwargs)
+
+ def reset_memory_entry(self, state_id: Optional[list] = None) -> None:
+ """
+ Overview:
+ Reset specific batch of the memory, batch ids are specified in 'state_id'
+ """
+ assert self.init_obs is not None, 'Call method "reset_memory" first'
+ for _id in state_id:
+ self.memory_idx[_id] = 0
+ self.obs_memory[:, _id] = self.init_obs[_id] # init the corresponding sequence with broadcasting
+
+
+class TransformerSegmentWrapper(IModelWrapper):
+
+ def __init__(self, model: Any, seq_len: int) -> None:
+ """
+ Overview:
+ Given T, the length of a trajectory, and N, the length of the sequences received by a Transformer model,
+ split the trajectory into sequences of N elements and forward each sequence one by one. If T % N != 0,
+ the last sequence is zero-padded. Usually used during the Transformer training phase.
+ Arguments:
+ - model (:obj:`Any`): Wrapped model class, should contain forward method.
+ - seq_len (:obj:`int`): N, length of a sequence.
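+
+ Examples:
+ A minimal sketch (shapes are illustrative); a trajectory of 12 steps is split into two sequences of 6:
+
+ >>> model = model_wrap(GTrXL(input_dim=32, embedding_dim=4), wrapper_name='transformer_segment', seq_len=6)
+ >>> out = model.forward(torch.randn(12, 8, 32)) # (T, bs, obs_shape)
+ >>> assert out['logit'].shape[0] == 12 # per-sequence outputs concatenated along the first dim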
+ """
+ super().__init__(model)
+ self.seq_len = seq_len
+
+ def forward(self, obs: torch.Tensor, **kwargs) -> Dict[str, torch.Tensor]:
+ """
+ Arguments:
+ - obs (:obj:`torch.Tensor`): Trajectory observations of shape (T, bs, *obs_shape), split into \
+ sequences of length ``seq_len`` along the time dimension.
+ Returns:
+ - Dict whose values are the per-sequence model outputs concatenated along the first dimension.
+ """
+ sequences = list(torch.split(obs, self.seq_len, dim=0))
+ if sequences[-1].shape[0] < self.seq_len:
+ last = sequences[-1].clone()
+ diff = self.seq_len - last.shape[0]
+ sequences[-1] = F.pad(input=last, pad=(0, 0, 0, 0, 0, diff), mode='constant', value=0)
+ outputs = []
+ for i, seq in enumerate(sequences):
+ out = self._model.forward(seq, **kwargs)
+ outputs.append(out)
+ out = {}
+ for k in outputs[0].keys():
+ out_k = [o[k] for o in outputs]
+ out_k = torch.cat(out_k, dim=0)
+ out[k] = out_k
+ return out
+
+
+class TransformerMemoryWrapper(IModelWrapper):
+
+ def __init__(
+ self,
+ model: Any,
+ batch_size: int,
+ ) -> None:
+ """
+ Overview:
+ Stores a copy of the Transformer memory so that it can be reused across different phases. To make this
+ clearer, suppose the training pipeline is divided into 3 phases: evaluate, collect, learn. The goal of the
+ wrapper is to keep the content of the memory at the end of each phase and reuse it when the same phase
+ is executed again. In this way, it prevents different phases from interfering with each other's memory.
+ Arguments:
+ - model (:obj:`Any`): Wrapped model class, should contain forward method.
+ - batch_size (:obj:`int`): Memory batch size.
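+
+ Examples:
+ A minimal sketch, one wrapper per phase so that each phase keeps its own copy of the memory; \
+ the shapes are illustrative:
+
+ >>> model = GTrXL(input_dim=32, embedding_dim=4, memory_len=4, layer_num=3)
+ >>> collect_model = model_wrap(model, wrapper_name='transformer_memory', batch_size=8)
+ >>> eval_model = model_wrap(model, wrapper_name='transformer_memory', batch_size=8)
+ >>> out = collect_model.forward(torch.randn(12, 8, 32)) # only collect_model.memory is updated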
+ """
+ super().__init__(model)
+ # shape (layer_num, memory_len, bs, embedding_dim)
+ self._model.reset_memory(batch_size=batch_size)
+ self.memory = self._model.get_memory()
+ self.mem_shape = self.memory.shape
+
+ def forward(self, *args, **kwargs) -> Dict[str, torch.Tensor]:
+ """
+ Arguments:
+ - args, kwargs: Arguments passed through to the wrapped model's forward method.
+ Returns:
+ - Output of the wrapped model's forward method, computed with this wrapper's stored memory.
+ """
+ self._model.reset_memory(state=self.memory)
+ out = self._model.forward(*args, **kwargs)
+ self.memory = self._model.get_memory()
+ return out
+
+ def reset(self, *args, **kwargs):
+ state_id = kwargs.get('data_id', None)
+ if state_id is None:
+ self.memory = torch.zeros(self.mem_shape)
+ else:
+ self.reset_memory_entry(state_id)
+ if hasattr(self._model, 'reset'):
+ return self._model.reset(*args, **kwargs)
+
+ def reset_memory_entry(self, state_id: Optional[list] = None) -> None:
+ """
+ Overview:
+ Reset specific batch of the memory, batch ids are specified in 'state_id'
+ """
+ for _id in state_id:
+ self.memory[:, :, _id] = torch.zeros((self.mem_shape[-1]))
+
+ def show_memory_occupancy(self, layer=0) -> None:
+ memory = self.memory
+ memory_shape = memory.shape
+ print('Layer {}-------------------------------------------'.format(layer))
+ for b in range(memory_shape[-2]):
+ print('b{}: '.format(b), end='')
+ for m in range(memory_shape[1]):
+ if sum(abs(memory[layer][m][b].flatten())) != 0:
+ print(1, end='')
+ else:
+ print(0, end='')
+ print()
+
+
+def sample_action(logit=None, prob=None):
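+ # Sample one action index per (batched) categorical distribution, either from raw logits
+ # (softmax applied here) or from already-normalized probabilities; supports arbitrary leading batch dims.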
+ if prob is None:
+ prob = torch.softmax(logit, dim=-1)
+ shape = prob.shape
+ prob += 1e-8
+ prob = prob.view(-1, shape[-1])
+ # prob can also be treated as weight in multinomial sample
+ action = torch.multinomial(prob, 1).squeeze(-1)
+ action = action.view(*shape[:-1])
+ return action
+
+
+class ArgmaxSampleWrapper(IModelWrapper):
+ """
+ Overview:
+ Used to help the model to sample argmax action.
+ Interfaces:
+ ``forward``.
+ """
+
+ def forward(self, *args, **kwargs):
+ """
+ Overview:
+ Employ model forward computation graph, and use the output logit to greedily select max action (argmax).
+ """
+ output = self._model.forward(*args, **kwargs)
+ assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
+ logit = output['logit']
+ assert isinstance(logit, torch.Tensor) or isinstance(logit, list)
+ if isinstance(logit, torch.Tensor):
+ logit = [logit]
+ if 'action_mask' in output:
+ mask = output['action_mask']
+ if isinstance(mask, torch.Tensor):
+ mask = [mask]
+ logit = [l.sub_(1e8 * (1 - m)) for l, m in zip(logit, mask)]
+ action = [l.argmax(dim=-1) for l in logit]
+ if len(action) == 1:
+ action, logit = action[0], logit[0]
+ output['action'] = action
+ return output
+
+
+class CombinationArgmaxSampleWrapper(IModelWrapper):
+ r"""
+ Overview:
+ Used to help the model to sample combination argmax action.
+ Interfaces:
+ ``forward``.
+ """
+
+ def forward(self, shot_number, *args, **kwargs):
+ output = self._model.forward(*args, **kwargs)
+ # Generate actions.
+ act = []
+ mask = torch.zeros_like(output['logit'])
+ for ii in range(shot_number):
+ masked_logit = output['logit'] + mask
+ actions = masked_logit.argmax(dim=-1)
+ act.append(actions)
+ for jj in range(actions.shape[0]):
+ mask[jj][actions[jj]] = -1e8
+ # `act` is shaped: (B, shot_number)
+ act = torch.stack(act, dim=1)
+ output['action'] = act
+ return output
+
+
+class CombinationMultinomialSampleWrapper(IModelWrapper):
+ r"""
+ Overview:
+ Used to help the model to sample combination multinomial action.
+ Interfaces:
+ ``forward``.
+ """
+
+ def forward(self, shot_number, *args, **kwargs):
+ output = self._model.forward(*args, **kwargs)
+ # Generate actions.
+ act = []
+ mask = torch.zeros_like(output['logit'])
+ for ii in range(shot_number):
+ dist = torch.distributions.Categorical(logits=output['logit'] + mask)
+ actions = dist.sample()
+ act.append(actions)
+ for jj in range(actions.shape[0]):
+ mask[jj][actions[jj]] = -1e8
+
+ # `act` is shaped: (B, shot_number)
+ act = torch.stack(act, dim=1)
+ output['action'] = act
+ return output
+
+
+class HybridArgmaxSampleWrapper(IModelWrapper):
+ r"""
+ Overview:
+ Used to help the model to sample argmax action in hybrid action space,
+ i.e. {'action_type': discrete, 'action_args': continuous}.
+ Interfaces:
+ ``forward``.
+ """
+
+ def forward(self, *args, **kwargs):
+ output = self._model.forward(*args, **kwargs)
+ assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
+ if 'logit' not in output:
+ return output
+ logit = output['logit']
+ assert isinstance(logit, torch.Tensor) or isinstance(logit, list)
+ if isinstance(logit, torch.Tensor):
+ logit = [logit]
+ if 'action_mask' in output:
+ mask = output['action_mask']
+ if isinstance(mask, torch.Tensor):
+ mask = [mask]
+ logit = [l.sub_(1e8 * (1 - m)) for l, m in zip(logit, mask)]
+ action = [l.argmax(dim=-1) for l in logit]
+ if len(action) == 1:
+ action, logit = action[0], logit[0]
+ output = {'action': {'action_type': action, 'action_args': output['action_args']}, 'logit': logit}
+ return output
+
+
+class MultinomialSampleWrapper(IModelWrapper):
+ """
+ Overview:
+ Used to help the model sample an action from output['logit'] with multinomial sampling.
+ Interfaces:
+ ``forward``.
+ """
+
+ def forward(self, *args, **kwargs):
+ if 'alpha' in kwargs.keys():
+ alpha = kwargs.pop('alpha')
+ else:
+ alpha = None
+ output = self._model.forward(*args, **kwargs)
+ assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
+ logit = output['logit']
+ assert isinstance(logit, torch.Tensor) or isinstance(logit, list)
+ if isinstance(logit, torch.Tensor):
+ logit = [logit]
+ if 'action_mask' in output:
+ mask = output['action_mask']
+ if isinstance(mask, torch.Tensor):
+ mask = [mask]
+ logit = [l.sub_(1e8 * (1 - m)) for l, m in zip(logit, mask)]
+ if alpha is None:
+ action = [sample_action(logit=l) for l in logit]
+ else:
+ # Note that if alpha is passed in here, we will divide logit by alpha.
+ action = [sample_action(logit=l / alpha) for l in logit]
+ if len(action) == 1:
+ action, logit = action[0], logit[0]
+ output['action'] = action
+ return output
+
+
+class EpsGreedySampleWrapper(IModelWrapper):
+ r"""
+ Overview:
+ Epsilon greedy sampler used in collector_model to help balance exploration and exploitation.
+ The type of eps can vary from different algorithms, such as:
+ - float (i.e. python native scalar): for almost normal case
+ - Dict[str, float]: for algorithm NGU
+ Interfaces:
+ ``forward``.
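+
+ Examples:
+ A minimal sketch; ``policy_model`` and ``data`` are placeholders for a model whose output dict \
+ contains 'logit' and its corresponding input:
+
+ >>> model = model_wrap(policy_model, wrapper_name='eps_greedy_sample')
+ >>> output = model.forward(data, eps=0.1) # argmax action with prob 0.9, random (or mask-weighted) otherwise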
+ """
+
+ def forward(self, *args, **kwargs):
+ eps = kwargs.pop('eps')
+ output = self._model.forward(*args, **kwargs)
+ assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
+ logit = output['logit']
+ assert isinstance(logit, torch.Tensor) or isinstance(logit, list)
+ if isinstance(logit, torch.Tensor):
+ logit = [logit]
+ if 'action_mask' in output:
+ mask = output['action_mask']
+ if isinstance(mask, torch.Tensor):
+ mask = [mask]
+ logit = [l.sub_(1e8 * (1 - m)) for l, m in zip(logit, mask)]
+ else:
+ mask = None
+ action = []
+ if isinstance(eps, dict):
+ # for NGU policy, eps is a dict, each collect env has a different eps
+ for i, l in enumerate(logit[0]):
+ eps_tmp = eps[i]
+ if np.random.random() > eps_tmp:
+ action.append(l.argmax(dim=-1))
+ else:
+ if mask is not None:
+ action.append(
+ sample_action(prob=mask[0][i].float().unsqueeze(0)).to(logit[0].device).squeeze(0)
+ )
+ else:
+ action.append(torch.randint(0, l.shape[-1], size=l.shape[:-1]).to(logit[0].device))
+ action = torch.stack(action, dim=-1) # shape torch.size([env_num])
+ else:
+ for i, l in enumerate(logit):
+ if np.random.random() > eps:
+ action.append(l.argmax(dim=-1))
+ else:
+ if mask is not None:
+ action.append(sample_action(prob=mask[i].float()))
+ else:
+ action.append(torch.randint(0, l.shape[-1], size=l.shape[:-1]))
+ if len(action) == 1:
+ action, logit = action[0], logit[0]
+ output['action'] = action
+ return output
+
+
+class EpsGreedyMultinomialSampleWrapper(IModelWrapper):
+ r"""
+ Overview:
+ Epsilon greedy sampler coupled with multinomial sample used in collector_model
+ to help balance exploration and exploitation.
+ Interfaces:
+ ``forward``.
+ """
+
+ def forward(self, *args, **kwargs):
+ eps = kwargs.pop('eps')
+ if 'alpha' in kwargs.keys():
+ alpha = kwargs.pop('alpha')
+ else:
+ alpha = None
+ output = self._model.forward(*args, **kwargs)
+ assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
+ logit = output['logit']
+ assert isinstance(logit, torch.Tensor) or isinstance(logit, list)
+ if isinstance(logit, torch.Tensor):
+ logit = [logit]
+ if 'action_mask' in output:
+ mask = output['action_mask']
+ if isinstance(mask, torch.Tensor):
+ mask = [mask]
+ logit = [l.sub_(1e8 * (1 - m)) for l, m in zip(logit, mask)]
+ else:
+ mask = None
+ action = []
+ for i, l in enumerate(logit):
+ if np.random.random() > eps:
+ if alpha is None:
+ action.append(sample_action(logit=l))
+ else:
+ # Note that if alpha is passed in here, we will divide logit by alpha (temperature scaling).
+ action.append(sample_action(logit=l / alpha))
+ else:
+ if mask:
+ action.append(sample_action(prob=mask[i].float()))
+ else:
+ action.append(torch.randint(0, l.shape[-1], size=l.shape[:-1]))
+ if len(action) == 1:
+ action, logit = action[0], logit[0]
+ output['action'] = action
+ return output
+
+
+class HybridEpsGreedySampleWrapper(IModelWrapper):
+ r"""
+ Overview:
+ Epsilon greedy sampler used in collector_model to help balance exploration and exploitation.
+ In hybrid action space, i.e. {'action_type': discrete, 'action_args': continuous}.
+ Interfaces:
+ ``forward``.
+ """
+
+ def forward(self, *args, **kwargs):
+ eps = kwargs.pop('eps')
+ output = self._model.forward(*args, **kwargs)
+ assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
+ logit = output['logit']
+ assert isinstance(logit, torch.Tensor) or isinstance(logit, list)
+ if isinstance(logit, torch.Tensor):
+ logit = [logit]
+ if 'action_mask' in output:
+ mask = output['action_mask']
+ if isinstance(mask, torch.Tensor):
+ mask = [mask]
+ logit = [l.sub_(1e8 * (1 - m)) for l, m in zip(logit, mask)]
+ else:
+ mask = None
+ action = []
+ for i, l in enumerate(logit):
+ if np.random.random() > eps:
+ action.append(l.argmax(dim=-1))
+ else:
+ if mask:
+ action.append(sample_action(prob=mask[i].float()))
+ else:
+ action.append(torch.randint(0, l.shape[-1], size=l.shape[:-1]))
+ if len(action) == 1:
+ action, logit = action[0], logit[0]
+ output = {'action': {'action_type': action, 'action_args': output['action_args']}, 'logit': logit}
+ return output
+
+
+class HybridEpsGreedyMultinomialSampleWrapper(IModelWrapper):
+ """
+ Overview:
+ Epsilon greedy sampler coupled with multinomial sample used in collector_model
+ to help balance exploration and exploitation.
+ In hybrid action space, i.e. {'action_type': discrete, 'action_args': continuous}.
+ Interfaces:
+ ``forward``.
+ """
+
+ def forward(self, *args, **kwargs):
+ eps = kwargs.pop('eps')
+ output = self._model.forward(*args, **kwargs)
+ assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
+ if 'logit' not in output:
+ return output
+
+ logit = output['logit']
+ assert isinstance(logit, torch.Tensor) or isinstance(logit, list)
+ if isinstance(logit, torch.Tensor):
+ logit = [logit]
+ if 'action_mask' in output:
+ mask = output['action_mask']
+ if isinstance(mask, torch.Tensor):
+ mask = [mask]
+ logit = [l.sub_(1e8 * (1 - m)) for l, m in zip(logit, mask)]
+ else:
+ mask = None
+ action = []
+ for i, l in enumerate(logit):
+ if np.random.random() > eps:
+ action.append(sample_action(logit=l))
+ else:
+ if mask:
+ action.append(sample_action(prob=mask[i].float()))
+ else:
+ action.append(torch.randint(0, l.shape[-1], size=l.shape[:-1]))
+ if len(action) == 1:
+ action, logit = action[0], logit[0]
+ output = {'action': {'action_type': action, 'action_args': output['action_args']}, 'logit': logit}
+ return output
+
+
+class HybridReparamMultinomialSampleWrapper(IModelWrapper):
+ """
+ Overview:
+ Reparameterization sampler coupled with multinomial sample used in collector_model
+ to help balance exploration and exploitation.
+ In hybrid action space, i.e. {'action_type': discrete, 'action_args': continuous}.
+ Interfaces:
+ forward
+ """
+
+ def forward(self, *args, **kwargs):
+ output = self._model.forward(*args, **kwargs)
+ assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
+
+ logit = output['logit'] # logit: {'action_type': action_type_logit, 'action_args': action_args_logit}
+ # discrete part
+ action_type_logit = logit['action_type']
+ prob = torch.softmax(action_type_logit, dim=-1)
+ pi_action = Categorical(prob)
+ action_type = pi_action.sample()
+ # continuous part
+ mu, sigma = logit['action_args']['mu'], logit['action_args']['sigma']
+ dist = Independent(Normal(mu, sigma), 1)
+ action_args = dist.sample()
+ action = {'action_type': action_type, 'action_args': action_args}
+ output['action'] = action
+ return output
+
+
+class HybridDeterministicArgmaxSampleWrapper(IModelWrapper):
+ """
+ Overview:
+ Deterministic sampler coupled with argmax sample used in eval_model.
+ In hybrid action space, i.e. {'action_type': discrete, 'action_args': continuous}.
+ Interfaces:
+ forward
+ """
+
+ def forward(self, *args, **kwargs):
+ output = self._model.forward(*args, **kwargs)
+ assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
+ logit = output['logit'] # logit: {'action_type': action_type_logit, 'action_args': action_args_logit}
+ # discrete part
+ action_type_logit = logit['action_type']
+ action_type = action_type_logit.argmax(dim=-1)
+ # continuous part
+ mu = logit['action_args']['mu']
+ action_args = mu
+ action = {'action_type': action_type, 'action_args': action_args}
+ output['action'] = action
+ return output
+
+
+class DeterministicSampleWrapper(IModelWrapper):
+ """
+ Overview:
+ Deterministic sampler (just use mu directly) used in eval_model.
+ Interfaces:
+ forward
+ """
+
+ def forward(self, *args, **kwargs):
+ output = self._model.forward(*args, **kwargs)
+ assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
+ output['action'] = output['logit']['mu']
+ return output
+
+
+class ReparamSampleWrapper(IModelWrapper):
+ """
+ Overview:
+ Reparameterization gaussian sampler used in collector_model.
+ Interfaces:
+ forward
+ """
+
+ def forward(self, *args, **kwargs):
+ output = self._model.forward(*args, **kwargs)
+ assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
+ mu, sigma = output['logit']['mu'], output['logit']['sigma']
+ dist = Independent(Normal(mu, sigma), 1)
+ output['action'] = dist.sample()
+ return output
+
+
+class ActionNoiseWrapper(IModelWrapper):
+ r"""
+ Overview:
+ Add noise to the collector's action output; clip both the generated noise and the action after adding noise.
+ Interfaces:
+ ``__init__``, ``forward``.
+ Arguments:
+ - model (:obj:`Any`): Wrapped model class. Should contain ``forward`` method.
+ - noise_type (:obj:`str`): The type of noise that should be generated, supports ['gauss', 'ou'].
+ - noise_kwargs (:obj:`dict`): Keyword args that should be used in noise init. Depends on ``noise_type``.
+ - noise_range (:obj:`Optional[dict]`): Range of noise, used for clipping.
+ - action_range (:obj:`Optional[dict]`): Range of action + noise, used for clipping, defaults to [-1, 1].
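+
+ Examples:
+ A minimal sketch for Gaussian exploration noise; ``continuous_policy_model`` and ``data`` are \
+ placeholders, and the ranges below are illustrative:
+
+ >>> model = model_wrap(
+ >>> continuous_policy_model, # a model that outputs a continuous 'action'
+ >>> wrapper_name='action_noise',
+ >>> noise_type='gauss',
+ >>> noise_range={'min': -0.5, 'max': 0.5},
+ >>> action_range={'min': -1, 'max': 1}
+ >>> )
+ >>> output = model.forward(data) # output['action'] is clipped to [-1, 1] after noise is added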
+ """
+
+ def __init__(
+ self,
+ model: Any,
+ noise_type: str = 'gauss',
+ noise_kwargs: dict = {},
+ noise_range: Optional[dict] = None,
+ action_range: Optional[dict] = {
+ 'min': -1,
+ 'max': 1
+ }
+ ) -> None:
+ super().__init__(model)
+ self.noise_generator = create_noise_generator(noise_type, noise_kwargs)
+ self.noise_range = noise_range
+ self.action_range = action_range
+
+ def forward(self, *args, **kwargs):
+ # if noise sigma need decay, update noise kwargs.
+ if 'sigma' in kwargs:
+ sigma = kwargs.pop('sigma')
+ if sigma is not None:
+ self.noise_generator.sigma = sigma
+ output = self._model.forward(*args, **kwargs)
+ assert isinstance(output, dict), "model output must be dict, but find {}".format(type(output))
+ if 'action' in output or 'action_args' in output:
+ key = 'action' if 'action' in output else 'action_args'
+ action = output[key]
+ assert isinstance(action, torch.Tensor)
+ action = self.add_noise(action)
+ output[key] = action
+ return output
+
+ def add_noise(self, action: torch.Tensor) -> torch.Tensor:
+ r"""
+ Overview:
+ Generate noise and clip noise if needed. Add noise to action and clip action if needed.
+ Arguments:
+ - action (:obj:`torch.Tensor`): Model's action output.
+ Returns:
+ - noised_action (:obj:`torch.Tensor`): Action processed after adding noise and clipping.
+ """
+ noise = self.noise_generator(action.shape, action.device)
+ if self.noise_range is not None:
+ noise = noise.clamp(self.noise_range['min'], self.noise_range['max'])
+ action += noise
+ if self.action_range is not None:
+ action = action.clamp(self.action_range['min'], self.action_range['max'])
+ return action
+
+
+class TargetNetworkWrapper(IModelWrapper):
+ r"""
+ Overview:
+ Maintain and update the target network
+ Interfaces:
+ update, reset
+ """
+
+ def __init__(self, model: Any, update_type: str, update_kwargs: dict):
+ super().__init__(model)
+ assert update_type in ['momentum', 'assign']
+ self._update_type = update_type
+ self._update_kwargs = update_kwargs
+ self._update_count = 0
+
+ def reset(self, *args, **kwargs):
+ target_update_count = kwargs.pop('target_update_count', None)
+ self.reset_state(target_update_count)
+ if hasattr(self._model, 'reset'):
+ return self._model.reset(*args, **kwargs)
+
+ def update(self, state_dict: dict, direct: bool = False) -> None:
+ r"""
+ Overview:
+ Update the target network state dict
+
+ Arguments:
+ - state_dict (:obj:`dict`): The state_dict from the learner model.
+ - direct (:obj:`bool`): Whether to update the target network directly; \
+ if True, simply call the ``load_state_dict`` method of the model.
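+
+ Examples:
+ A minimal sketch of a hard (assign) target update; ``freq=100`` is an illustrative value set when wrapping:
+
+ >>> import copy
+ >>> target_model = model_wrap(copy.deepcopy(model), wrapper_name='target',
+ >>> update_type='assign', update_kwargs={'freq': 100})
+ >>> target_model.update(model.state_dict(), direct=True) # hard copy once at init
+ >>> target_model.update(model.state_dict()) # afterwards, weights are copied every 100th call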
+ """
+ if direct:
+ self._model.load_state_dict(state_dict, strict=True)
+ self._update_count = 0
+ else:
+ if self._update_type == 'assign':
+ if (self._update_count + 1) % self._update_kwargs['freq'] == 0:
+ self._model.load_state_dict(state_dict, strict=True)
+ self._update_count += 1
+ elif self._update_type == 'momentum':
+ theta = self._update_kwargs['theta']
+ for name, p in self._model.named_parameters():
+ # default theta = 0.001
+ p.data = (1 - theta) * p.data + theta * state_dict[name]
+
+ def reset_state(self, target_update_count: int = None) -> None:
+ r"""
+ Overview:
+ Reset the update_count
+ Arguments:
+ - target_update_count (:obj:`int`): The target update count value to reset to.
+ """
+ if target_update_count is not None:
+ self._update_count = target_update_count
+
+
+class TeacherNetworkWrapper(IModelWrapper):
+ """
+ Overview:
+ Set the teacher network: assign the input ``teacher_cfg`` to the wrapped model's ``_teacher_cfg``. Not implemented yet.
+ """
+
+ def __init__(self, model, teacher_cfg):
+ super().__init__(model)
+ self._model._teacher_cfg = teacher_cfg
+ raise NotImplementedError
+
+
+wrapper_name_map = {
+ 'base': BaseModelWrapper,
+ 'hidden_state': HiddenStateWrapper,
+ 'argmax_sample': ArgmaxSampleWrapper,
+ 'hybrid_argmax_sample': HybridArgmaxSampleWrapper,
+ 'eps_greedy_sample': EpsGreedySampleWrapper,
+ 'eps_greedy_multinomial_sample': EpsGreedyMultinomialSampleWrapper,
+ 'deterministic_sample': DeterministicSampleWrapper,
+ 'reparam_sample': ReparamSampleWrapper,
+ 'hybrid_eps_greedy_sample': HybridEpsGreedySampleWrapper,
+ 'hybrid_eps_greedy_multinomial_sample': HybridEpsGreedyMultinomialSampleWrapper,
+ 'hybrid_reparam_multinomial_sample': HybridReparamMultinomialSampleWrapper,
+ 'hybrid_deterministic_argmax_sample': HybridDeterministicArgmaxSampleWrapper,
+ 'multinomial_sample': MultinomialSampleWrapper,
+ 'action_noise': ActionNoiseWrapper,
+ 'transformer_input': TransformerInputWrapper,
+ 'transformer_segment': TransformerSegmentWrapper,
+ 'transformer_memory': TransformerMemoryWrapper,
+ # model wrapper
+ 'target': TargetNetworkWrapper,
+ 'teacher': TeacherNetworkWrapper,
+ 'combination_argmax_sample': CombinationArgmaxSampleWrapper,
+ 'combination_multinomial_sample': CombinationMultinomialSampleWrapper,
+}
+
+
+def model_wrap(model: Union[nn.Module, IModelWrapper], wrapper_name: str = None, **kwargs):
+ """
+ Overview:
+ Wrap the model with the specified wrapper and return the wrapped model.
+ Arguments:
+ - model (:obj:`Any`): The model to be wrapped.
+ - wrapper_name (:obj:`str`): The name of the wrapper to be used.
+
+ .. note::
+ The arguments of the wrapper should be passed in as kwargs.
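+
+ Examples:
+ Two typical uses (``policy_model``, ``rnn_policy_model`` and ``env_num`` are placeholders):
+
+ >>> eval_model = model_wrap(policy_model, wrapper_name='argmax_sample')
+ >>> collector_model = model_wrap(rnn_policy_model, wrapper_name='hidden_state', state_num=env_num)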
+ """
+ if wrapper_name in wrapper_name_map:
+ # TODO test whether to remove this if branch
+ if not isinstance(model, IModelWrapper):
+ model = wrapper_name_map['base'](model)
+ model = wrapper_name_map[wrapper_name](model, **kwargs)
+ else:
+ raise TypeError("unsupported model_wrapper type: {}".format(wrapper_name))
+ return model
+
+
+def register_wrapper(name: str, wrapper_type: type) -> None:
+ """
+ Overview:
+ Register a new wrapper in ``wrapper_name_map``. When users implement a new wrapper, they must call this \
+ function to complete the registration; then the wrapper can be used via ``model_wrap``.
+ Arguments:
+ - name (:obj:`str`): The name of the new wrapper to be registered.
+ - wrapper_type (:obj:`type`): The wrapper class needs to be added in ``wrapper_name_map``. This argument \
+ should be the subclass of ``IModelWrapper``.
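+
+ Examples:
+ A minimal sketch with a hypothetical wrapper class and a placeholder ``policy_model``:
+
+ >>> class MyWrapper(IModelWrapper):
+ >>> def forward(self, *args, **kwargs):
+ >>> return self._model.forward(*args, **kwargs)
+ >>> register_wrapper('my_wrapper', MyWrapper)
+ >>> model = model_wrap(policy_model, wrapper_name='my_wrapper') # now available via its registered name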
+ """
+ assert isinstance(name, str)
+ assert issubclass(wrapper_type, IModelWrapper)
+ wrapper_name_map[name] = wrapper_type
diff --git a/DI-engine/ding/model/wrapper/test_model_wrappers.py b/DI-engine/ding/model/wrapper/test_model_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..1da744d36fb504e9e2cc37e7774d6c24137673ba
--- /dev/null
+++ b/DI-engine/ding/model/wrapper/test_model_wrappers.py
@@ -0,0 +1,578 @@
+import copy
+from copy import deepcopy
+from collections import OrderedDict
+
+import pytest
+import torch
+import torch.nn as nn
+from ditk import logging
+
+from ding.torch_utils import get_lstm
+from ding.torch_utils.network.gtrxl import GTrXL
+from ding.model import model_wrap, register_wrapper, IModelWrapper
+from ding.model.wrapper.model_wrappers import BaseModelWrapper
+
+
+class TempMLP(torch.nn.Module):
+
+ def __init__(self):
+ super(TempMLP, self).__init__()
+ self.fc1 = nn.Linear(3, 4)
+ self.bn1 = nn.BatchNorm1d(4)
+ self.fc2 = nn.Linear(4, 6)
+ self.act = nn.ReLU()
+
+ def forward(self, x):
+ x = self.fc1(x)
+ x = self.bn1(x)
+ x = self.act(x)
+ x = self.fc2(x)
+ x = self.act(x)
+ return x
+
+
+class ActorMLP(torch.nn.Module):
+
+ def __init__(self):
+ super(ActorMLP, self).__init__()
+ self.fc1 = nn.Linear(3, 4)
+ self.bn1 = nn.BatchNorm1d(4)
+ self.fc2 = nn.Linear(4, 6)
+ self.act = nn.ReLU()
+ self.out = nn.Softmax(dim=-1)
+
+ def forward(self, inputs, tmp=0):
+ x = self.fc1(inputs['obs'])
+ x = self.bn1(x)
+ x = self.act(x)
+ x = self.fc2(x)
+ x = self.act(x)
+ x = self.out(x)
+ ret = {'logit': x, 'tmp': tmp, 'action': x + torch.rand_like(x)}
+ if 'mask' in inputs:
+ ret['action_mask'] = inputs['mask']
+ return ret
+
+
+class HybridActorMLP(torch.nn.Module):
+
+ def __init__(self):
+ super(HybridActorMLP, self).__init__()
+ self.fc1 = nn.Linear(3, 4)
+ self.bn1 = nn.BatchNorm1d(4)
+ self.fc2 = nn.Linear(4, 6)
+ self.act = nn.ReLU()
+ self.out = nn.Softmax(dim=-1)
+
+ self.fc2_cont = nn.Linear(4, 6)
+ self.act_cont = nn.ReLU()
+
+ def forward(self, inputs, tmp=0):
+ x = self.fc1(inputs['obs'])
+ x = self.bn1(x)
+ x_ = self.act(x)
+
+ x = self.fc2(x_)
+ x = self.act(x)
+ x_disc = self.out(x)
+
+ x = self.fc2_cont(x_)
+ x_cont = self.act_cont(x)
+
+ ret = {'logit': x_disc, 'action_args': x_cont, 'tmp': tmp}
+
+ if 'mask' in inputs:
+ ret['action_mask'] = inputs['mask']
+ return ret
+
+
+class HybridReparamActorMLP(torch.nn.Module):
+
+ def __init__(self):
+ super(HybridReparamActorMLP, self).__init__()
+ self.fc1 = nn.Linear(3, 4)
+ self.bn1 = nn.BatchNorm1d(4)
+ self.fc2 = nn.Linear(4, 6)
+ self.act = nn.ReLU()
+ self.out = nn.Softmax(dim=-1)
+
+ self.fc2_cont_mu = nn.Linear(4, 6)
+ self.act_cont_mu = nn.ReLU()
+
+ self.fc2_cont_sigma = nn.Linear(4, 6)
+ self.act_cont_sigma = nn.ReLU()
+
+ def forward(self, inputs, tmp=0):
+ x = self.fc1(inputs['obs'])
+ x = self.bn1(x)
+ x_ = self.act(x)
+
+ x = self.fc2(x_)
+ x = self.act(x)
+ x_disc = self.out(x)
+
+ x = self.fc2_cont_mu(x_)
+ x_cont_mu = self.act_cont_mu(x)
+
+ x = self.fc2_cont_sigma(x_)
+ x_cont_sigma = self.act_cont_sigma(x) + 1e-8
+
+ ret = {'logit': {'action_type': x_disc, 'action_args': {'mu': x_cont_mu, 'sigma': x_cont_sigma}}, 'tmp': tmp}
+
+ if 'mask' in inputs:
+ ret['action_mask'] = inputs['mask']
+ return ret
+
+
+class ReparamActorMLP(torch.nn.Module):
+
+ def __init__(self):
+ super(ReparamActorMLP, self).__init__()
+ self.fc1 = nn.Linear(3, 4)
+ self.bn1 = nn.BatchNorm1d(4)
+ self.fc2 = nn.Linear(4, 6)
+ self.act = nn.ReLU()
+
+ self.fc2_cont_mu = nn.Linear(4, 6)
+ self.fc2_cont_sigma = nn.Linear(4, 6)
+
+ def forward(self, inputs, tmp=0):
+ x = self.fc1(inputs['obs'])
+ x = self.bn1(x)
+ x_ = self.act(x)
+
+ x = self.fc2_cont_mu(x_)
+ x_cont_mu = self.act(x)
+
+ x = self.fc2_cont_sigma(x_)
+ x_cont_sigma = self.act(x) + 1e-8
+
+ ret = {'logit': {'mu': x_cont_mu, 'sigma': x_cont_sigma}, 'tmp': tmp}
+
+ if 'mask' in inputs:
+ ret['action_mask'] = inputs['mask']
+ return ret
+
+
+class DeterministicActorMLP(torch.nn.Module):
+
+ def __init__(self):
+ super(DeterministicActorMLP, self).__init__()
+ self.fc1 = nn.Linear(3, 4)
+ self.bn1 = nn.BatchNorm1d(4)
+ self.act = nn.ReLU()
+
+ self.fc2_cont_mu = nn.Linear(4, 6)
+ self.act_cont_mu = nn.ReLU()
+
+ def forward(self, inputs):
+ x = self.fc1(inputs['obs'])
+ x = self.bn1(x)
+ x_ = self.act(x)
+
+ x = self.fc2_cont_mu(x_)
+ x_cont_mu = self.act_cont_mu(x)
+
+ ret = {
+ 'logit': {
+ 'mu': x_cont_mu,
+ }
+ }
+
+ if 'mask' in inputs:
+ ret['action_mask'] = inputs['mask']
+ return ret
+
+
+class TempLSTM(torch.nn.Module):
+
+ def __init__(self):
+ super(TempLSTM, self).__init__()
+ self.model = get_lstm(lstm_type='pytorch', input_size=36, hidden_size=32, num_layers=2, norm_type=None)
+
+ def forward(self, data):
+ output, next_state = self.model(data['f'], data['prev_state'], list_next_state=True)
+ return {'output': output, 'next_state': next_state}
+
+
+@pytest.fixture(scope='function')
+def setup_model():
+ return torch.nn.Linear(3, 6)
+
+
+@pytest.mark.unittest
+class TestModelWrappers:
+
+ def test_hidden_state_wrapper(self):
+
+ model = TempLSTM()
+ state_num = 4
+ model = model_wrap(model, wrapper_name='hidden_state', state_num=state_num, save_prev_state=True)
+ model.reset()
+ data = {'f': torch.randn(2, 4, 36)}
+ output = model.forward(data)
+ assert output['output'].shape == (2, state_num, 32)
+ assert len(output['prev_state']) == 4
+ assert output['prev_state'][0]['h'].shape == (2, 1, 32)
+ for item in model._state.values():
+ assert isinstance(item, dict) and len(item) == 2
+ assert all(t.shape == (2, 1, 32) for t in item.values())
+
+ data = {'f': torch.randn(2, 3, 36)}
+ data_id = [0, 1, 3]
+ output = model.forward(data, data_id=data_id)
+ assert output['output'].shape == (2, 3, 32)
+ assert all([len(s) == 2 for s in output['prev_state']])
+ for item in model._state.values():
+ assert isinstance(item, dict) and len(item) == 2
+ assert all(t.shape == (2, 1, 32) for t in item.values())
+
+ data = {'f': torch.randn(2, 2, 36)}
+ data_id = [0, 1]
+ output = model.forward(data, data_id=data_id)
+ assert output['output'].shape == (2, 2, 32)
+
+ assert all([isinstance(s, dict) and len(s) == 2 for s in model._state.values()])
+ model.reset()
+ assert all([isinstance(s, type(None)) for s in model._state.values()])
+
+ def test_target_network_wrapper(self):
+
+ model = TempMLP()
+ target_model = deepcopy(model)
+ target_model2 = deepcopy(model)
+ target_model = model_wrap(target_model, wrapper_name='target', update_type='assign', update_kwargs={'freq': 2})
+ model = model_wrap(model, wrapper_name='base')
+ register_wrapper('abstract', IModelWrapper)
+ assert all([hasattr(target_model, n) for n in ['reset', 'forward', 'update']])
+ assert model.fc1.weight.eq(target_model.fc1.weight).sum() == 12
+ model.fc1.weight.data = torch.randn_like(model.fc1.weight)
+ assert model.fc1.weight.ne(target_model.fc1.weight).sum() == 12
+ target_model.update(model.state_dict(), direct=True)
+ assert model.fc1.weight.eq(target_model.fc1.weight).sum() == 12
+ model.reset()
+ target_model.reset()
+
+ inputs = torch.randn(2, 3)
+ model.train()
+ target_model.train()
+ output = model.forward(inputs)
+ with torch.no_grad():
+ output_target = target_model.forward(inputs)
+ assert output.eq(output_target).sum() == 2 * 6
+ model.fc1.weight.data = torch.randn_like(model.fc1.weight)
+ assert model.fc1.weight.ne(target_model.fc1.weight).sum() == 12
+ target_model.update(model.state_dict())
+ assert model.fc1.weight.ne(target_model.fc1.weight).sum() == 12
+ target_model.update(model.state_dict())
+ assert model.fc1.weight.eq(target_model.fc1.weight).sum() == 12
+ # test real reset update_count
+ assert target_model._update_count != 0
+ target_model.reset()
+ assert target_model._update_count != 0
+ target_model.reset(target_update_count=0)
+ assert target_model._update_count == 0
+
+ target_model2 = model_wrap(
+ target_model2, wrapper_name='target', update_type='momentum', update_kwargs={'theta': 0.01}
+ )
+ target_model2.update(model.state_dict(), direct=True)
+ assert model.fc1.weight.eq(target_model2.fc1.weight).sum() == 12
+ model.fc1.weight.data = torch.randn_like(model.fc1.weight)
+ old_state_dict = target_model2.state_dict()
+ target_model2.update(model.state_dict())
+ assert target_model2.fc1.weight.data.eq(
+ old_state_dict['fc1.weight'] * (1 - 0.01) + model.fc1.weight.data * 0.01
+ ).all()
+
+ def test_eps_greedy_wrapper(self):
+ model = ActorMLP()
+ model = model_wrap(model, wrapper_name='eps_greedy_sample')
+ model.eval()
+ eps_threshold = 0.5
+ data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
+ with torch.no_grad():
+ output = model.forward(data, eps=eps_threshold)
+ assert output['tmp'] == 0
+ for i in range(10):
+ if i == 5:
+ data.pop('mask')
+ with torch.no_grad():
+ output = model.forward(data, eps=eps_threshold, tmp=1)
+ assert isinstance(output, dict)
+ assert output['tmp'] == 1
+
+ def test_multinomial_sample_wrapper(self):
+ model = model_wrap(ActorMLP(), wrapper_name='multinomial_sample')
+ data = {'obs': torch.randn(4, 3)}
+ output = model.forward(data)
+ assert output['action'].shape == (4, )
+ data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
+ output = model.forward(data)
+ assert output['action'].shape == (4, )
+
+ def test_eps_greedy_multinomial_wrapper(self):
+ model = ActorMLP()
+ model = model_wrap(model, wrapper_name='eps_greedy_multinomial_sample')
+ model.eval()
+ eps_threshold = 0.5
+ data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
+ with torch.no_grad():
+ output = model.forward(data, eps=eps_threshold, alpha=0.2)
+ assert output['tmp'] == 0
+ for i in range(10):
+ if i == 5:
+ data.pop('mask')
+ with torch.no_grad():
+ output = model.forward(data, eps=eps_threshold, tmp=1, alpha=0.2)
+ assert isinstance(output, dict)
+ assert output['tmp'] == 1
+
+ def test_hybrid_eps_greedy_wrapper(self):
+ model = HybridActorMLP()
+ model = model_wrap(model, wrapper_name='hybrid_eps_greedy_sample')
+ model.eval()
+ eps_threshold = 0.5
+ data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
+ with torch.no_grad():
+ output = model.forward(data, eps=eps_threshold)
+ # logit = output['logit']
+ # assert output['action']['action_type'].eq(logit.argmax(dim=-1)).all()
+ assert isinstance(output['action']['action_args'],
+ torch.Tensor) and output['action']['action_args'].shape == (4, 6)
+ for i in range(10):
+ if i == 5:
+ data.pop('mask')
+ with torch.no_grad():
+ output = model.forward(data, eps=eps_threshold, tmp=1)
+ assert isinstance(output, dict)
+
+ def test_hybrid_eps_greedy_multinomial_wrapper(self):
+ model = HybridActorMLP()
+ model = model_wrap(model, wrapper_name='hybrid_eps_greedy_multinomial_sample')
+ model.eval()
+ eps_threshold = 0.5
+ data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
+ with torch.no_grad():
+ output = model.forward(data, eps=eps_threshold)
+ assert isinstance(output['logit'], torch.Tensor) and output['logit'].shape == (4, 6)
+ assert isinstance(output['action']['action_type'],
+ torch.Tensor) and output['action']['action_type'].shape == (4, )
+ assert isinstance(output['action']['action_args'],
+ torch.Tensor) and output['action']['action_args'].shape == (4, 6)
+ for i in range(10):
+ if i == 5:
+ data.pop('mask')
+ with torch.no_grad():
+ output = model.forward(data, eps=eps_threshold, tmp=1)
+ assert isinstance(output, dict)
+
+ def test_hybrid_reparam_multinomial_wrapper(self):
+ model = HybridReparamActorMLP()
+ model = model_wrap(model, wrapper_name='hybrid_reparam_multinomial_sample')
+ model.eval()
+ data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
+ with torch.no_grad():
+ output = model.forward(data)
+ assert isinstance(output['logit'], dict) and output['logit']['action_type'].shape == (4, 6)
+ assert isinstance(output['logit']['action_args'], dict) and output['logit']['action_args']['mu'].shape == (
+ 4, 6
+ ) and output['logit']['action_args']['sigma'].shape == (4, 6)
+ assert isinstance(output['action']['action_type'],
+ torch.Tensor) and output['action']['action_type'].shape == (4, )
+ assert isinstance(output['action']['action_args'],
+ torch.Tensor) and output['action']['action_args'].shape == (4, 6)
+ for i in range(10):
+ if i == 5:
+ data.pop('mask')
+ with torch.no_grad():
+ output = model.forward(data, tmp=1)
+ assert isinstance(output, dict)
+
+ def test_argmax_sample_wrapper(self):
+ model = model_wrap(ActorMLP(), wrapper_name='argmax_sample')
+ data = {'obs': torch.randn(4, 3)}
+ output = model.forward(data)
+ logit = output['logit']
+ assert output['action'].eq(logit.argmax(dim=-1)).all()
+ data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
+ output = model.forward(data)
+ logit = output['logit'].sub(1e8 * (1 - data['mask']))
+ assert output['action'].eq(logit.argmax(dim=-1)).all()
+
+ def test_hybrid_argmax_sample_wrapper(self):
+ model = model_wrap(HybridActorMLP(), wrapper_name='hybrid_argmax_sample')
+ data = {'obs': torch.randn(4, 3)}
+ output = model.forward(data)
+ logit = output['logit']
+ assert output['action']['action_type'].eq(logit.argmax(dim=-1)).all()
+ assert isinstance(output['action']['action_args'],
+ torch.Tensor) and output['action']['action_args'].shape == (4, 6)
+ data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
+ output = model.forward(data)
+ logit = output['logit'].sub(1e8 * (1 - data['mask']))
+ assert output['action']['action_type'].eq(logit.argmax(dim=-1)).all()
+ assert output['action']['action_args'].shape == (4, 6)
+
+ def test_hybrid_deterministic_argmax_sample_wrapper(self):
+ model = model_wrap(HybridReparamActorMLP(), wrapper_name='hybrid_deterministic_argmax_sample')
+ data = {'obs': torch.randn(4, 3)}
+ output = model.forward(data)
+ assert output['action']['action_type'].eq(output['logit']['action_type'].argmax(dim=-1)).all()
+ assert isinstance(output['action']['action_args'],
+ torch.Tensor) and output['action']['action_args'].shape == (4, 6)
+ assert output['action']['action_args'].eq(output['logit']['action_args']['mu']).all()
+
+ def test_deterministic_sample_wrapper(self):
+ model = model_wrap(DeterministicActorMLP(), wrapper_name='deterministic_sample')
+ data = {'obs': torch.randn(4, 3)}
+ output = model.forward(data)
+ assert output['action'].eq(output['logit']['mu']).all()
+ assert isinstance(output['action'], torch.Tensor) and output['action'].shape == (4, 6)
+
+ def test_reparam_wrapper(self):
+ model = ReparamActorMLP()
+ model = model_wrap(model, wrapper_name='reparam_sample')
+ model.eval()
+ data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
+ with torch.no_grad():
+ output = model.forward(data)
+ assert isinstance(output['logit'],
+ dict) and output['logit']['mu'].shape == (4, 6) and output['logit']['sigma'].shape == (4, 6)
+ for i in range(10):
+ if i == 5:
+ data.pop('mask')
+ with torch.no_grad():
+ output = model.forward(data, tmp=1)
+ assert isinstance(output, dict)
+
+ def test_eps_greedy_wrapper_with_list_eps(self):
+ model = ActorMLP()
+ model = model_wrap(model, wrapper_name='eps_greedy_sample')
+ model.eval()
+ eps_threshold = {i: 0.5 for i in range(4)} # for NGU
+ data = {'obs': torch.randn(4, 3), 'mask': torch.randint(0, 2, size=(4, 6))}
+ with torch.no_grad():
+ output = model.forward(data, eps=eps_threshold)
+ assert output['tmp'] == 0
+ for i in range(10):
+ if i == 5:
+ data.pop('mask')
+ with torch.no_grad():
+ output = model.forward(data, eps=eps_threshold, tmp=1)
+ assert isinstance(output, dict)
+ assert output['tmp'] == 1
+
+ def test_action_noise_wrapper(self):
+ model = model_wrap(
+ ActorMLP(),
+ wrapper_name='action_noise',
+ noise_type='gauss',
+ noise_range={
+ 'min': -0.1,
+ 'max': 0.1
+ },
+ action_range={
+ 'min': -0.05,
+ 'max': 0.05
+ }
+ )
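+        # the action_noise wrapper adds Gaussian noise (clipped into noise_range) to the raw action
+        # and then clamps the noisy action into action_range, which the assertions below verify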
+ data = {'obs': torch.randn(4, 3)}
+ output = model.forward(data)
+ action = output['action']
+ assert action.shape == (4, 6)
+ assert action.eq(action.clamp(-0.05, 0.05)).all()
+
+ def test_transformer_input_wrapper(self):
+ seq_len, bs, obs_shape = 8, 8, 32
+ emb_dim = 64
+ model = GTrXL(input_dim=obs_shape, embedding_dim=emb_dim)
+ model = model_wrap(model, wrapper_name='transformer_input', seq_len=seq_len)
+ obs = []
+ for i in range(seq_len + 1):
+ obs.append(torch.randn((bs, obs_shape)))
+ out = model.forward(obs[0], only_last_logit=False)
+ assert out['logit'].shape == (seq_len, bs, emb_dim)
+ assert out['input_seq'].shape == (seq_len, bs, obs_shape)
+ assert sum(out['input_seq'][1:].flatten()) == 0
+ for i in range(1, seq_len - 1):
+ out = model.forward(obs[i])
+ assert out['logit'].shape == (bs, emb_dim)
+ assert out['input_seq'].shape == (seq_len, bs, obs_shape)
+ assert sum(out['input_seq'][seq_len - 1:].flatten()) == 0
+ assert sum(out['input_seq'][:seq_len - 1].flatten()) != 0
+ out = model.forward(obs[seq_len - 1])
+ prev_memory = torch.clone(out['input_seq'])
+ out = model.forward(obs[seq_len])
+ assert torch.all(torch.eq(out['input_seq'][seq_len - 2], prev_memory[seq_len - 1]))
+ # test update of single batches in the memory
+ model.reset(data_id=[0, 5]) # reset memory batch in position 0 and 5
+ assert sum(model.obs_memory[:, 0].flatten()) == 0 and sum(model.obs_memory[:, 5].flatten()) == 0
+ assert sum(model.obs_memory[:, 1].flatten()) != 0
+ assert model.memory_idx[0] == 0 and model.memory_idx[5] == 0 and model.memory_idx[1] == seq_len
+ # test reset
+ model.reset()
+ assert model.obs_memory is None
+
+ def test_transformer_segment_wrapper(self):
+ seq_len, bs, obs_shape = 12, 8, 32
+ layer_num, memory_len, emb_dim = 3, 4, 4
+ model = GTrXL(input_dim=obs_shape, embedding_dim=emb_dim, memory_len=memory_len, layer_num=layer_num)
+ model = model_wrap(model, wrapper_name='transformer_segment', seq_len=seq_len)
+ inputs1 = torch.randn((seq_len, bs, obs_shape))
+ out = model.forward(inputs1)
+ info = model.info('info')
+ info = model.info('x')
+
+ def test_transformer_memory_wrapper(self):
+ seq_len, bs, obs_shape = 12, 8, 32
+ layer_num, memory_len, emb_dim = 3, 4, 4
+ model = GTrXL(input_dim=obs_shape, embedding_dim=emb_dim, memory_len=memory_len, layer_num=layer_num)
+ model1 = model_wrap(model, wrapper_name='transformer_memory', batch_size=bs)
+ model2 = model_wrap(model, wrapper_name='transformer_memory', batch_size=bs)
+ model1.show_memory_occupancy()
+ inputs1 = torch.randn((seq_len, bs, obs_shape))
+ out = model1.forward(inputs1)
+ new_memory1 = model1.memory
+ inputs2 = torch.randn((seq_len, bs, obs_shape))
+ out = model2.forward(inputs2)
+ new_memory2 = model2.memory
+ assert not torch.all(torch.eq(new_memory1, new_memory2))
+ model1.reset(data_id=[0, 5])
+ assert sum(model1.memory[:, :, 0].flatten()) == 0 and sum(model1.memory[:, :, 5].flatten()) == 0
+ assert sum(model1.memory[:, :, 1].flatten()) != 0
+ model1.reset()
+ assert sum(model1.memory.flatten()) == 0
+
+ seq_len, bs, obs_shape = 8, 8, 32
+ layer_num, memory_len, emb_dim = 3, 20, 4
+ model = GTrXL(input_dim=obs_shape, embedding_dim=emb_dim, memory_len=memory_len, layer_num=layer_num)
+ model = model_wrap(model, wrapper_name='transformer_memory', batch_size=bs)
+ inputs1 = torch.randn((seq_len, bs, obs_shape))
+ out = model.forward(inputs1)
+ new_memory1 = model.memory
+ inputs2 = torch.randn((seq_len, bs, obs_shape))
+ out = model.forward(inputs2)
+ new_memory2 = model.memory
+ print(new_memory1.shape, inputs1.shape)
+ assert sum(new_memory1[:, -8:].flatten()) != 0
+ assert sum(new_memory1[:, :-8].flatten()) == 0
+ assert sum(new_memory2[:, -16:].flatten()) != 0
+ assert sum(new_memory2[:, :-16].flatten()) == 0
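+        # the memory behaves as a FIFO filled from the right: after the second forward the newest 8 slots
+        # hold embeddings from inputs2, and the 8 slots before them hold the embeddings produced from inputs1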
+ assert torch.all(torch.eq(new_memory1[:, -8:], new_memory2[:, -16:-8]))
+
+ def test_combination_argmax_sample_wrapper(self):
+ model = model_wrap(ActorMLP(), wrapper_name='combination_argmax_sample')
+ data = {'obs': torch.randn(4, 3)}
+ shot_number = 2
+ output = model.forward(shot_number=shot_number, inputs=data)
+ assert output['action'].shape == (4, shot_number)
+ assert (output['action'] >= 0).all() and (output['action'] < 64).all()
+
+ def test_combination_multinomial_sample_wrapper(self):
+ model = model_wrap(ActorMLP(), wrapper_name='combination_multinomial_sample')
+ data = {'obs': torch.randn(4, 3)}
+ shot_number = 2
+ output = model.forward(shot_number=shot_number, inputs=data)
+ assert output['action'].shape == (4, shot_number)
+ assert (output['action'] >= 0).all() and (output['action'] < 64).all()
diff --git a/DI-engine/ding/policy/__init__.py b/DI-engine/ding/policy/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..c85883a0afb2d54755f0bb1a70b012b47b7777a1
--- /dev/null
+++ b/DI-engine/ding/policy/__init__.py
@@ -0,0 +1,58 @@
+from .base_policy import Policy, CommandModePolicy, create_policy, get_policy_cls
+from .common_utils import single_env_forward_wrapper, single_env_forward_wrapper_ttorch, default_preprocess_learn
+from .dqn import DQNSTDIMPolicy, DQNPolicy
+from .mdqn import MDQNPolicy
+from .iqn import IQNPolicy
+from .fqf import FQFPolicy
+from .qrdqn import QRDQNPolicy
+from .c51 import C51Policy
+from .rainbow import RainbowDQNPolicy
+from .ddpg import DDPGPolicy
+from .d4pg import D4PGPolicy
+from .td3 import TD3Policy
+from .td3_vae import TD3VAEPolicy
+from .td3_bc import TD3BCPolicy
+from .dt import DTPolicy
+
+from .pg import PGPolicy
+from .a2c import A2CPolicy
+from .ppo import PPOPolicy, PPOPGPolicy, PPOOffPolicy
+from .sac import SACPolicy, DiscreteSACPolicy, SQILSACPolicy
+from .cql import CQLPolicy, DiscreteCQLPolicy
+from .edac import EDACPolicy
+from .impala import IMPALAPolicy
+from .ngu import NGUPolicy
+from .r2d2 import R2D2Policy
+from .r2d2_gtrxl import R2D2GTrXLPolicy
+from .ppg import PPGPolicy, PPGOffPolicy
+from .sqn import SQNPolicy
+from .bdq import BDQPolicy
+
+from .qmix import QMIXPolicy
+from .wqmix import WQMIXPolicy
+from .coma import COMAPolicy
+from .collaq import CollaQPolicy
+from .atoc import ATOCPolicy
+from .acer import ACERPolicy
+from .qtran import QTRANPolicy
+
+from .il import ILPolicy
+
+from .r2d3 import R2D3Policy
+
+from .command_mode_policy_instance import *
+
+from .policy_factory import PolicyFactory, get_random_policy
+from .pdqn import PDQNPolicy
+
+from .bc import BehaviourCloningPolicy
+from .ibc import IBCPolicy
+
+from .pc import ProcedureCloningBFSPolicy
+
+from .bcq import BCQPolicy
+
+# new-type policy
+from .ppof import PPOFPolicy
+from .prompt_pg import PromptPGPolicy
+from .happo import HAPPOPolicy
diff --git a/DI-engine/ding/policy/a2c.py b/DI-engine/ding/policy/a2c.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e05f4e7128a66afdbaee2e1eca603cb52f9abbd
--- /dev/null
+++ b/DI-engine/ding/policy/a2c.py
@@ -0,0 +1,295 @@
+from typing import List, Dict, Any, Tuple, Union
+from collections import namedtuple
+import torch
+
+from ding.rl_utils import a2c_data, a2c_error, get_gae_with_default_last_value, get_train_sample, \
+ a2c_error_continuous
+from ding.torch_utils import Adam, to_device
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY, split_data_generator
+from ding.utils.data import default_collate, default_decollate
+from .base_policy import Policy
+from .common_utils import default_preprocess_learn
+
+
+@POLICY_REGISTRY.register('a2c')
+class A2CPolicy(Policy):
+ r"""
+ Overview:
+ Policy class of A2C algorithm.
+ """
+ config = dict(
+ # (string) RL policy register name (refer to function "register_policy").
+ type='a2c',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+        # (bool) Whether to use the on-policy training pipeline (behaviour policy and training policy are the same).
+        on_policy=True,  # A2C is strictly on-policy; this value is fixed and should not be modified by users.
+ priority=False,
+ # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+        # (str) Which kind of action space is used in A2CPolicy, ['discrete', 'continuous']
+ action_space='discrete',
+ learn=dict(
+
+ # (int) for a2c, update_per_collect must be 1.
+ update_per_collect=1, # fixed value, this line should not be modified by users
+ batch_size=64,
+ learning_rate=0.001,
+ # (List[float])
+ betas=(0.9, 0.999),
+ # (float)
+ eps=1e-8,
+ # (float)
+ grad_norm=0.5,
+ # ==============================================================
+            # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) loss weight of the value network, the weight of policy network is set to 1
+ value_weight=0.5,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.01,
+ # (bool) Whether to normalize advantage. Default to False.
+ adv_norm=False,
+ ignore_done=False,
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ # n_sample=80,
+ unroll_len=1,
+ # ==============================================================
+            # The following configs are algorithm-specific
+ # ==============================================================
+            # (float) discount factor for future reward, defaults in [0, 1]
+ discount_factor=0.9,
+            # (float) the GAE trade-off factor lambda, balancing 1-step TD and Monte Carlo estimates
+ gae_lambda=0.95,
+ ),
+ eval=dict(),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ return 'vac', ['ding.model.template.vac']
+
+ def _init_learn(self) -> None:
+ r"""
+ Overview:
+ Learn mode init method. Called by ``self.__init__``.
+            Init the optimizer, algorithm config and the main model.
+ """
+        assert self._cfg.action_space in ["continuous", "discrete"]
+        self._action_space = self._cfg.action_space
+ # Optimizer
+ self._optimizer = Adam(
+ self._model.parameters(),
+ lr=self._cfg.learn.learning_rate,
+ betas=self._cfg.learn.betas,
+ eps=self._cfg.learn.eps
+ )
+
+ # Algorithm config
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ self._value_weight = self._cfg.learn.value_weight
+ self._entropy_weight = self._cfg.learn.entropy_weight
+ self._adv_norm = self._cfg.learn.adv_norm
+ self._grad_norm = self._cfg.learn.grad_norm
+
+ # Main and target models
+ self._learn_model = model_wrap(self._model, wrapper_name='base')
+ self._learn_model.reset()
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Forward and backward function of learn mode.
+ Arguments:
+ - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs','adv']
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss.
+ """
+ data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=False)
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._learn_model.train()
+
+ for batch in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True):
+ # forward
+ output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic')
+
+ adv = batch['adv']
+ return_ = batch['value'] + adv
+ if self._adv_norm:
+ # norm adv in total train_batch
+ adv = (adv - adv.mean()) / (adv.std() + 1e-8)
+ error_data = a2c_data(output['logit'], batch['action'], output['value'], adv, return_, batch['weight'])
+
+ # Calculate A2C loss
+ if self._action_space == 'continuous':
+ a2c_loss = a2c_error_continuous(error_data)
+ elif self._action_space == 'discrete':
+ a2c_loss = a2c_error(error_data)
+
+ wv, we = self._value_weight, self._entropy_weight
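+            # total loss = policy_loss + value_weight * value_loss - entropy_weight * entropy (entropy bonus)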
+ total_loss = a2c_loss.policy_loss + wv * a2c_loss.value_loss - we * a2c_loss.entropy_loss
+
+ # ====================
+ # A2C-learning update
+ # ====================
+
+ self._optimizer.zero_grad()
+ total_loss.backward()
+
+ grad_norm = torch.nn.utils.clip_grad_norm_(
+ list(self._learn_model.parameters()),
+ max_norm=self._grad_norm,
+ )
+ self._optimizer.step()
+
+ # =============
+ # after update
+ # =============
+ # only record last updates information in logger
+ return {
+ 'cur_lr': self._optimizer.param_groups[0]['lr'],
+ 'total_loss': total_loss.item(),
+ 'policy_loss': a2c_loss.policy_loss.item(),
+ 'value_loss': a2c_loss.value_loss.item(),
+ 'entropy_loss': a2c_loss.entropy_loss.item(),
+ 'adv_abs_max': adv.abs().max().item(),
+ 'grad_norm': grad_norm,
+ }
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'optimizer': self._optimizer.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._optimizer.load_state_dict(state_dict['optimizer'])
+
+ def _init_collect(self) -> None:
+ r"""
+ Overview:
+ Collect mode init method. Called by ``self.__init__``.
+ Init traj and unroll length, collect model.
+ """
+
+ assert self._cfg.action_space in ["continuous", "discrete"]
+ self._unroll_len = self._cfg.collect.unroll_len
+
+ self._action_space = self._cfg.action_space
+ if self._action_space == 'continuous':
+ self._collect_model = model_wrap(self._model, wrapper_name='reparam_sample')
+ elif self._action_space == 'discrete':
+ self._collect_model = model_wrap(self._model, wrapper_name='multinomial_sample')
+ self._collect_model.reset()
+ # Algorithm
+ self._gamma = self._cfg.collect.discount_factor
+ self._gae_lambda = self._cfg.collect.gae_lambda
+
+ def _forward_collect(self, data: dict) -> dict:
+ r"""
+ Overview:
+ Forward function of collect mode.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs.
+ ReturnsKeys
+ - necessary: ``action``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._collect_model.eval()
+ with torch.no_grad():
+ output = self._collect_model.forward(data, mode='compute_actor_critic')
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
+ r"""
+ Overview:
+ Generate dict type transition data from inputs.
+ Arguments:
+ - obs (:obj:`Any`): Env observation
+ - model_output (:obj:`dict`): Output of collect model, including at least ['action']
+ - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
+ (here 'obs' indicates obs after env step).
+ Returns:
+ - transition (:obj:`dict`): Dict type transition data.
+ """
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'action': model_output['action'],
+ 'value': model_output['value'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
+ r"""
+ Overview:
+            Compute the GAE advantage for each transition, then split the trajectory into training samples
+ Arguments:
+ - data (:obj:`list`): The trajectory's buffer list
+ Returns:
+ - samples (:obj:`dict`): The training samples generated
+ """
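+        # compute GAE advantages over the whole trajectory:
+        # A_t = sum_k (gamma * lambda)^k * delta_{t+k}, with delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)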
+ data = get_gae_with_default_last_value(
+ data,
+ data[-1]['done'],
+ gamma=self._gamma,
+ gae_lambda=self._gae_lambda,
+ cuda=self._cuda,
+ )
+ return get_train_sample(data, self._unroll_len)
+
+ def _init_eval(self) -> None:
+ r"""
+ Overview:
+ Evaluate mode init method. Called by ``self.__init__``.
+ Init eval model with argmax strategy.
+ """
+ assert self._cfg.action_space in ["continuous", "discrete"]
+ self._action_space = self._cfg.action_space
+ if self._action_space == 'continuous':
+ self._eval_model = model_wrap(self._model, wrapper_name='deterministic_sample')
+ elif self._action_space == 'discrete':
+ self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: dict) -> dict:
+ r"""
+ Overview:
+ Forward function of eval mode, similar to ``self._forward_collect``.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
+ ReturnsKeys
+ - necessary: ``action``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data, mode='compute_actor')
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _monitor_vars_learn(self) -> List[str]:
+ return super()._monitor_vars_learn() + ['policy_loss', 'value_loss', 'entropy_loss', 'adv_abs_max', 'grad_norm']
diff --git a/DI-engine/ding/policy/acer.py b/DI-engine/ding/policy/acer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7ac4db7753eec8af94bdf8657a4bfcf9e9524df9
--- /dev/null
+++ b/DI-engine/ding/policy/acer.py
@@ -0,0 +1,485 @@
+from collections import namedtuple
+from typing import List, Dict, Any, Tuple
+import copy
+
+import torch
+
+from ding.model import model_wrap
+from ding.rl_utils import get_train_sample, compute_q_retraces, acer_policy_error,\
+ acer_value_error, acer_trust_region_update
+from ding.torch_utils import Adam, RMSprop, to_device
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_collate, default_decollate
+from ding.policy.base_policy import Policy
+
+EPS = 1e-8
+
+
+@POLICY_REGISTRY.register('acer')
+class ACERPolicy(Policy):
+ r"""
+ Overview:
+ Policy class of ACER algorithm.
+
+ Config:
+ == ======================= ======== ============== ===================================== =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ======================= ======== ============== ===================================== =======================
+ 1 ``type`` str acer | RL policy register name, refer to | this arg is optional,
+ | registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool False | Whether to use cuda for network | this arg can be diff-
+ | | erent from modes
+ 3 ``on_policy`` bool False | Whether the RL algorithm is
+ | on-policy or off-policy
+ 4 ``trust_region`` bool True | Whether the RL algorithm use trust |
+ | region constraint |
+ 5 ``trust_region_value`` float 1.0 | maximum range of the trust region |
+ 6 ``unroll_len`` int 32 | trajectory length to calculate
+ | Q retrace target
+ 7 ``learn.update`` int 4 | How many updates(iterations) to | this args can be vary
+ ``per_collect`` | train after collector's one | from envs. Bigger val
+ | collection. Only |
+ | valid in serial training | means more off-policy
+ 8 ``c_clip_ratio`` float 1.0 | clip ratio of importance weights |
+ == ======================= ======== ============== ===================================== =======================
+ """
+ unroll_len = 32
+ config = dict(
+ type='acer',
+ cuda=False,
+        # (bool) Whether to use the on-policy training pipeline (behaviour policy and training policy are the same).
+        # Here we follow the PPO serial pipeline; ACER itself is an off-policy algorithm.
+ on_policy=False,
+ priority=False,
+ # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ learn=dict(
+ # (str) the type of gradient clip method
+ grad_clip_type=None,
+ # (float) max value when ACER use gradient clip
+ clip_value=None,
+
+ # (int) collect n_sample data, train model update_per_collect times
+ # here we follow ppo serial pipeline
+ update_per_collect=4,
+ # (int) the number of data for a train iteration
+ batch_size=16,
+ # (float) loss weight of the value network, the weight of policy network is set to 1
+ value_weight=0.5,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.0001,
+            # (float) discount factor for future reward, defaults in [0, 1]
+ discount_factor=0.9,
+ # (float) additional discounting parameter
+ lambda_=0.95,
+            # (int) the trajectory length to calculate the Q-retrace target
+ unroll_len=unroll_len,
+ # (float) clip ratio of importance weights
+ c_clip_ratio=10,
+ trust_region=True,
+ trust_region_value=1.0,
+ learning_rate_actor=0.0005,
+ learning_rate_critic=0.0005,
+ target_theta=0.01
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ # n_sample=16,
+            # (int) the trajectory length to calculate the Q-retrace target
+ unroll_len=unroll_len,
+            # (float) discount factor for future reward, defaults in [0, 1]
+ discount_factor=0.9,
+ gae_lambda=0.95,
+ collector=dict(
+ type='sample',
+ collect_print_freq=1000,
+ ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=200, ), ),
+ other=dict(replay_buffer=dict(
+ replay_buffer_size=1000,
+ max_use=16,
+ ), ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ return 'acer', ['ding.model.template.acer']
+
+ def _init_learn(self) -> None:
+ r"""
+ Overview:
+ Learn mode init method. Called by ``self.__init__``.
+ Initialize the optimizer, algorithm config and main model.
+ """
+ # Optimizer
+ self._optimizer_actor = Adam(
+ self._model.actor.parameters(),
+ lr=self._cfg.learn.learning_rate_actor,
+ grad_clip_type=self._cfg.learn.grad_clip_type,
+ clip_value=self._cfg.learn.clip_value
+ )
+ self._optimizer_critic = Adam(
+ self._model.critic.parameters(),
+ lr=self._cfg.learn.learning_rate_critic,
+ )
+ self._target_model = copy.deepcopy(self._model)
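+        # the target model serves as ACER's average policy network, kept as an EMA (momentum) of the online model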
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='momentum',
+ update_kwargs={'theta': self._cfg.learn.target_theta}
+ )
+ self._learn_model = model_wrap(self._model, wrapper_name='base')
+
+ self._action_shape = self._cfg.model.action_shape
+ self._unroll_len = self._cfg.learn.unroll_len
+
+ # Algorithm config
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ self._value_weight = self._cfg.learn.value_weight
+ self._entropy_weight = self._cfg.learn.entropy_weight
+ self._gamma = self._cfg.learn.discount_factor
+ # self._rho_clip_ratio = self._cfg.learn.rho_clip_ratio
+ self._c_clip_ratio = self._cfg.learn.c_clip_ratio
+ # self._rho_pg_clip_ratio = self._cfg.learn.rho_pg_clip_ratio
+ self._use_trust_region = self._cfg.learn.trust_region
+ self._trust_region_value = self._cfg.learn.trust_region_value
+ # Main model
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ def _data_preprocess_learn(self, data: List[Dict[str, Any]]):
+ """
+ Overview:
+ Data preprocess function of learn mode.
+            Convert the list of trajectory data into batched trajectory data, which is a dict of tensors.
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]]`): List type data, a list of data for training. Each list element is a \
+ dict, whose values are torch.Tensor or np.ndarray or dict/list combinations, keys include at least 'obs',\
+ 'next_obs', 'logit', 'action', 'reward', 'done'
+ Returns:
+ - data (:obj:`dict`): Dict type data. Values are torch.Tensor or np.ndarray or dict/list combinations. \
+ ReturnsKeys:
+ - necessary: 'logit', 'action', 'reward', 'done', 'weight', 'obs_plus_1'.
+            - optional and not used in later computation: 'obs', 'next_obs', 'IS', 'collect_iter', 'replay_unique_id', \
+ 'replay_buffer_idx', 'priority', 'staleness', 'use'.
+ ReturnsShapes:
+ - obs_plus_1 (:obj:`torch.FloatTensor`): :math:`(T * B, obs_shape)`, where T is timestep, B is batch size \
+ and obs_shape is the shape of single env observation
+ - logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where N is action dim
+ - action (:obj:`torch.LongTensor`): :math:`(T, B)`
+ - reward (:obj:`torch.FloatTensor`): :math:`(T+1, B)`
+ - done (:obj:`torch.FloatTensor`): :math:`(T, B)`
+ - weight (:obj:`torch.FloatTensor`): :math:`(T, B)`
+ """
+ data = default_collate(data)
+ if self._cuda:
+ data = to_device(data, self._device)
+ data['weight'] = data.get('weight', None)
+ # shape (T+1)*B,env_obs_shape
+ data['obs_plus_1'] = torch.cat((data['obs'] + data['next_obs'][-1:]), dim=0)
+ data['logit'] = torch.cat(
+ data['logit'], dim=0
+ ).reshape(self._unroll_len, -1, self._action_shape) # shape T,B,env_action_shape
+ data['action'] = torch.cat(data['action'], dim=0).reshape(self._unroll_len, -1) # shape T,B,
+ data['done'] = torch.cat(data['done'], dim=0).reshape(self._unroll_len, -1).float() # shape T,B,
+ data['reward'] = torch.cat(data['reward'], dim=0).reshape(self._unroll_len, -1) # shape T,B,
+ data['weight'] = torch.cat(
+ data['weight'], dim=0
+ ).reshape(self._unroll_len, -1) if data['weight'] else None # shape T,B
+ return data
+
+ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Forward computation graph of learn mode(updating policy).
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]]`): List type data, a list of data for training. Each list element is a \
+ dict, whose values are torch.Tensor or np.ndarray or dict/list combinations, keys include at least 'obs',\
+ 'next_obs', 'logit', 'action', 'reward', 'done'
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \
+ recorded in text log and tensorboard, values are python scalar or a list of scalars.
+ ArgumentsKeys:
+ - necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done``
+ - optional: 'collect_iter', 'replay_unique_id', 'replay_buffer_idx', 'priority', 'staleness', 'use', 'IS'
+ ReturnsKeys:
+            - necessary: ``cur_actor_lr``, ``cur_critic_lr``, ``actor_loss``, ``bc_loss``, ``policy_loss``, \
+                ``critic_loss``, ``entropy_loss``
+ """
+ data = self._data_preprocess_learn(data)
+ self._learn_model.train()
+ action_data = self._learn_model.forward(data['obs_plus_1'], mode='compute_actor')
+ q_value_data = self._learn_model.forward(data['obs_plus_1'], mode='compute_critic')
+ avg_action_data = self._target_model.forward(data['obs_plus_1'], mode='compute_actor')
+
+ target_logit, behaviour_logit, avg_logit, actions, q_values, rewards, weights = self._reshape_data(
+ action_data, avg_action_data, q_value_data, data
+ )
+ # shape (T+1),B,env_action_shape
+ target_logit = torch.log_softmax(target_logit, dim=-1)
+ # shape T,B,env_action_shape
+ behaviour_logit = torch.log_softmax(behaviour_logit, dim=-1)
+ # shape (T+1),B,env_action_shape
+ avg_logit = torch.log_softmax(avg_logit, dim=-1)
+ with torch.no_grad():
+ # shape T,B,env_action_shape
+ ratio = torch.exp(target_logit[0:-1] - behaviour_logit)
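+            # per-action importance weights rho = pi(a|s) / mu(a|s), used by Retrace and the bias correction term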
+ # shape (T+1),B,1
+ v_pred = (q_values * torch.exp(target_logit)).sum(-1).unsqueeze(-1)
+ # Calculate retrace
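+            # Retrace target, computed backwards in time (masked by the done-based weights):
+            # Q_ret_t = r_t + gamma * (min(1, rho_{t+1}) * (Q_ret_{t+1} - Q(s_{t+1}, a_{t+1})) + V(s_{t+1}))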
+ q_retraces = compute_q_retraces(q_values, v_pred, rewards, actions, weights, ratio, self._gamma)
+
+        # the terminal states' weights are 0; shift them by one step so that valid states are counted correctly
+ weights_ext = torch.ones_like(weights)
+ weights_ext[1:] = weights[0:-1]
+ weights = weights_ext
+ q_retraces = q_retraces[0:-1] # shape T,B,1
+ q_values = q_values[0:-1] # shape T,B,env_action_shape
+ v_pred = v_pred[0:-1] # shape T,B,1
+ target_logit = target_logit[0:-1] # shape T,B,env_action_shape
+ avg_logit = avg_logit[0:-1] # shape T,B,env_action_shape
+ total_valid = weights.sum() # 1
+ # ====================
+ # policy update
+ # ====================
+ actor_loss, bc_loss = acer_policy_error(
+ q_values, q_retraces, v_pred, target_logit, actions, ratio, self._c_clip_ratio
+ )
+ actor_loss = actor_loss * weights.unsqueeze(-1)
+ bc_loss = bc_loss * weights.unsqueeze(-1)
+ dist_new = torch.distributions.categorical.Categorical(logits=target_logit)
+ entropy_loss = (dist_new.entropy() * weights).unsqueeze(-1) # shape T,B,1
+ total_actor_loss = (actor_loss + bc_loss + self._entropy_weight * entropy_loss).sum() / total_valid
+ self._optimizer_actor.zero_grad()
+ actor_gradients = torch.autograd.grad(-total_actor_loss, target_logit, retain_graph=True)
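+        # with trust region enabled, the gradient w.r.t. the logits is adjusted so that the update does not move
+        # the policy too far (in KL divergence) from the average policy network, then back-propagated to the actor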
+ if self._use_trust_region:
+ actor_gradients = acer_trust_region_update(
+ actor_gradients, target_logit, avg_logit, self._trust_region_value
+ )
+ target_logit.backward(actor_gradients)
+ self._optimizer_actor.step()
+
+ # ====================
+ # critic update
+ # ====================
+ critic_loss = (acer_value_error(q_values, q_retraces, actions) * weights.unsqueeze(-1)).sum() / total_valid
+ self._optimizer_critic.zero_grad()
+ critic_loss.backward()
+ self._optimizer_critic.step()
+ self._target_model.update(self._learn_model.state_dict())
+
+ with torch.no_grad():
+ kl_div = torch.exp(avg_logit) * (avg_logit - target_logit)
+ kl_div = (kl_div.sum(-1) * weights).sum() / total_valid
+
+ return {
+ 'cur_actor_lr': self._optimizer_actor.defaults['lr'],
+ 'cur_critic_lr': self._optimizer_critic.defaults['lr'],
+ 'actor_loss': (actor_loss.sum() / total_valid).item(),
+ 'bc_loss': (bc_loss.sum() / total_valid).item(),
+ 'policy_loss': total_actor_loss.item(),
+ 'critic_loss': critic_loss.item(),
+ 'entropy_loss': (entropy_loss.sum() / total_valid).item(),
+ 'kl_div': kl_div.item()
+ }
+
+ def _reshape_data(
+ self, action_data: Dict[str, Any], avg_action_data: Dict[str, Any], q_value_data: Dict[str, Any],
+ data: Dict[str, Any]
+    ) -> Tuple[Any, Any, Any, Any, Any, Any, Any]:
+ r"""
+ Overview:
+            Reshape the network outputs and trajectory data into (T, B, ...) tensors, and obtain the loss weights, \
+            which are 0 at done positions.
+ Arguments:
+ - output (:obj:`Dict[int, Any]`): Dict type data, output of learn_model forward. \
+ Values are torch.Tensor or np.ndarray or dict/list combinations,keys are value, logit.
+ - data (:obj:`Dict[int, Any]`): Dict type data, input of policy._forward_learn \
+ Values are torch.Tensor or np.ndarray or dict/list combinations. Keys includes at \
+ least ['logit', 'action', 'reward', 'done',]
+ Returns:
+            - data (:obj:`Tuple[Any]`): Tuple of target_logit, behaviour_logit, avg_action_logit, actions, \
+                values, rewards, weights
+ ReturnsShapes:
+            - target_logit (:obj:`torch.FloatTensor`): :math:`(T+1, B, N)`, where T is timestep, \
+                B is batch size and N is action dim.
+ - behaviour_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where N is action dim.
+ - avg_action_logit (:obj:`torch.FloatTensor`): :math: `(T+1, B, N)`, where N is action dim.
+ - actions (:obj:`torch.LongTensor`): :math:`(T, B)`
+ - values (:obj:`torch.FloatTensor`): :math:`(T+1, B)`
+ - rewards (:obj:`torch.FloatTensor`): :math:`(T, B)`
+ - weights (:obj:`torch.FloatTensor`): :math:`(T, B)`
+ """
+ target_logit = action_data['logit'].reshape(
+ self._unroll_len + 1, -1, self._action_shape
+ ) # shape (T+1),B,env_action_shape
+ behaviour_logit = data['logit'] # shape T,B,env_action_shape
+ avg_action_logit = avg_action_data['logit'].reshape(
+ self._unroll_len + 1, -1, self._action_shape
+ ) # shape (T+1),B,env_action_shape
+ actions = data['action'] # shape T,B
+ values = q_value_data['q_value'].reshape(
+ self._unroll_len + 1, -1, self._action_shape
+ ) # shape (T+1),B,env_action_shape
+ rewards = data['reward'] # shape T,B
+        weights = 1 - data['done']  # shape T,B; 0 at done positions
+ return target_logit, behaviour_logit, avg_action_logit, actions, values, rewards, weights
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Return the state_dict of learn mode, usually including model and optimizer.
+ Returns:
+ - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring.
+ """
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'target_model': self._target_model.state_dict(),
+ 'actor_optimizer': self._optimizer_actor.state_dict(),
+ 'critic_optimizer': self._optimizer_critic.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ r"""
+ Overview:
+ Load the state_dict variable into policy learn mode.
+ Arguments:
+ - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before.
+ .. tip::
+ If you want to only load some parts of model, you can simply set the ``strict`` argument in \
+ load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
+ complicated operation.
+ """
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._target_model.load_state_dict(state_dict['target_model'])
+ self._optimizer_actor.load_state_dict(state_dict['actor_optimizer'])
+ self._optimizer_critic.load_state_dict(state_dict['critic_optimizer'])
+
+ def _init_collect(self) -> None:
+ r"""
+ Overview:
+ Collect mode init method. Called by ``self.__init__``, initialize algorithm arguments and collect_model.
+ Use multinomial_sample to choose action.
+ """
+ self._collect_unroll_len = self._cfg.collect.unroll_len
+ self._collect_model = model_wrap(self._model, wrapper_name='multinomial_sample')
+ self._collect_model.reset()
+
+ def _forward_collect(self, data: Dict[int, Any]) -> Dict[int, Dict[str, Any]]:
+ r"""
+ Overview:
+ Forward computation graph of collect mode(collect training data).
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): Dict type data, stacked env data for predicting \
+ action, values are torch.Tensor or np.ndarray or dict/list combinations,keys \
+ are env_id indicated by integer.
+ Returns:
+ - output (:obj:`Dict[int, Dict[str,Any]]`): Dict of predicting policy_output(logit, action) for each env.
+ ReturnsKeys
+ - necessary: ``logit``, ``action``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._collect_model.eval()
+ with torch.no_grad():
+ output = self._collect_model.forward(data, mode='compute_actor')
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ output = {i: d for i, d in zip(data_id, output)}
+ return output
+
+ def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ r"""
+ Overview:
+ For a given trajectory(transitions, a list of transition) data, process it into a list of sample that \
+ can be used for training directly.
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]`): The trajectory data(a list of transition), each element is the same \
+ format as the return value of ``self._process_transition`` method.
+ Returns:
+ - samples (:obj:`dict`): List of training samples.
+ .. note::
+            We will vectorize the ``process_transition`` and ``get_train_sample`` methods in a following release. \
+            Users can customize this data processing procedure by overriding these two methods and the collector \
+            itself.
+ """
+ return get_train_sample(data, self._unroll_len)
+
+ def _process_transition(self, obs: Any, policy_output: Dict[str, Any], timestep: namedtuple) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Generate dict type transition data from inputs.
+ Arguments:
+            - obs (:obj:`Any`): Env observation, which can be torch.Tensor or np.ndarray or dict/list combinations.
+            - policy_output (:obj:`dict`): Output of collect model, including at least ['logit', 'action']
+ - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done']\
+ (here 'obs' indicates obs after env step).
+ Returns:
+ - transition (:obj:`dict`): Dict type transition data, including at least ['obs','next_obs', 'logit',\
+ 'action','reward', 'done']
+ """
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'logit': policy_output['logit'],
+ 'action': policy_output['action'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _init_eval(self) -> None:
+ r"""
+ Overview:
+ Evaluate mode init method. Called by ``self.__init__``, initialize eval_model,
+ and use argmax_sample to choose action.
+ """
+ self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
+ r"""
+ Overview:
+ Forward computation graph of eval mode(evaluate policy performance), at most cases, it is similar to \
+ ``self._forward_collect``.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
+ ReturnsKeys
+ - necessary: ``action``
+ - optional: ``logit``
+
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data, mode='compute_actor')
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ output = {i: d for i, d in zip(data_id, output)}
+ return output
+
+ def _monitor_vars_learn(self) -> List[str]:
+ r"""
+ Overview:
+            Return the names of variables to be monitored during training, which will be recorded in text log and \
+            tensorboard.
+        Returns:
+            - vars (:obj:`List[str]`): Variables' name list.
+ """
+ return ['actor_loss', 'bc_loss', 'policy_loss', 'critic_loss', 'entropy_loss', 'kl_div']
diff --git a/DI-engine/ding/policy/atoc.py b/DI-engine/ding/policy/atoc.py
new file mode 100644
index 0000000000000000000000000000000000000000..8addc327166a75e4951dab5a4e51c8b4277ce8c9
--- /dev/null
+++ b/DI-engine/ding/policy/atoc.py
@@ -0,0 +1,380 @@
+from typing import List, Dict, Any, Tuple, Union
+from collections import namedtuple
+import copy
+import torch
+
+from ding.torch_utils import Adam, to_device
+from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_collate, default_decollate
+from .base_policy import Policy
+from .common_utils import default_preprocess_learn
+
+
+@POLICY_REGISTRY.register('atoc')
+class ATOCPolicy(Policy):
+ r"""
+ Overview:
+ Policy class of ATOC algorithm.
+ Interface:
+ __init__, set_setting, __repr__, state_dict_handle
+ Property:
+ learn_mode, collect_mode, eval_mode
+ """
+
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='atoc',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+        # (bool) Whether to use the on-policy training pipeline (behaviour policy and training policy are the same).
+ on_policy=False,
+ # (bool) Whether use priority(priority sample, IS weight, update priority)
+ priority=False,
+ # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ model=dict(
+ # (bool) Whether to use communication module in ATOC, if not, it is a multi-agent DDPG
+ communication=True,
+            # (int) The size of the thought (communication) vector
+            thought_size=8,
+            # (int) The number of agents in each communication group
+ agent_per_group=2,
+ ),
+ learn=dict(
+ # (int) Collect n_sample data, update model n_iteration time
+ update_per_collect=5,
+ # (int) The number of data for a train iteration
+ batch_size=64,
+ # (float) Gradient-descent step size of actor
+ learning_rate_actor=0.001,
+ # (float) Gradient-descent step size of critic
+ learning_rate_critic=0.001,
+ # ==============================================================
+            # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) Target network update weight, theta * new_w + (1 - theta) * old_w, defaults in [0, 0.1]
+ target_theta=0.005,
+            # (float) Discount factor for future reward, defaults in [0, 1]
+ discount_factor=0.99,
+ # (bool) Whether to use communication module in ATOC, if not, it is a multi-agent DDPG
+ communication=True,
+ # (int) The frequency of actor update, each critic update
+            # (int) The frequency of actor updates: the actor updates once every ``actor_update_freq`` critic updates
+ # (bool) Whether use noise in action output when learning
+ noise=True,
+ # (float) The std of noise distribution for target policy smooth
+ noise_sigma=0.15,
+ # (float, float) The minimum and maximum value of noise
+ noise_range=dict(
+ min=-0.5,
+ max=0.5,
+ ),
+ # (bool) Whether to use reward batch norm in the total batch
+ reward_batch_norm=False,
+ ignore_done=False,
+ ),
+ collect=dict(
+ # (int) Collect n_sample data, update model n_iteration time
+ # n_sample=64,
+ # (int) Unroll length of a train iteration(gradient update step)
+ unroll_len=1,
+ # ==============================================================
+            # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) The std of noise distribution for exploration
+ noise_sigma=0.4,
+ ),
+ eval=dict(),
+ other=dict(
+ replay_buffer=dict(
+ # (int) The max size of replay buffer
+ replay_buffer_size=100000,
+ # (int) The max use count of data, if count is bigger than this value, the data will be removed
+ max_use=10,
+ ),
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ return 'atoc', ['ding.model.template.atoc']
+
+ def _init_learn(self) -> None:
+ r"""
+ Overview:
+ Learn mode init method. Called by ``self.__init__``.
+ Init actor and critic optimizers, algorithm config, main and target models.
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ assert not self._priority and not self._priority_IS_weight
+ # algorithm config
+ self._communication = self._cfg.learn.communication
+ self._gamma = self._cfg.learn.discount_factor
+ self._actor_update_freq = self._cfg.learn.actor_update_freq
+ # actor and critic optimizer
+ self._optimizer_actor = Adam(
+ self._model.actor.parameters(),
+ lr=self._cfg.learn.learning_rate_actor,
+ )
+ self._optimizer_critic = Adam(
+ self._model.critic.parameters(),
+ lr=self._cfg.learn.learning_rate_critic,
+ )
+ if self._communication:
+ self._optimizer_actor_attention = Adam(
+ self._model.actor.attention.parameters(),
+ lr=self._cfg.learn.learning_rate_actor,
+ )
+ self._reward_batch_norm = self._cfg.learn.reward_batch_norm
+
+ # main and target models
+ self._target_model = copy.deepcopy(self._model)
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='momentum',
+ update_kwargs={'theta': self._cfg.learn.target_theta}
+ )
+ if self._cfg.learn.noise:
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='action_noise',
+ noise_type='gauss',
+ noise_kwargs={
+ 'mu': 0.0,
+ 'sigma': self._cfg.learn.noise_sigma
+ },
+ noise_range=self._cfg.learn.noise_range
+ )
+ self._learn_model = model_wrap(self._model, wrapper_name='base')
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ self._forward_learn_cnt = 0 # count iterations
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Forward and backward function of learn mode.
+ Arguments:
+ - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs']
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Including at least actor and critic lr, different losses.
+ """
+ loss_dict = {}
+ data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=False)
+ if self._cuda:
+ data = to_device(data, self._device)
+ # ====================
+ # critic learn forward
+ # ====================
+ self._learn_model.train()
+ self._target_model.train()
+ next_obs = data['next_obs']
+ reward = data['reward']
+ if self._reward_batch_norm:
+ reward = (reward - reward.mean()) / (reward.std() + 1e-8)
+ # current q value
+ q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']
+ # target q value.
+ with torch.no_grad():
+ next_action = self._target_model.forward(next_obs, mode='compute_actor')['action']
+ next_data = {'obs': next_obs, 'action': next_action}
+ target_q_value = self._target_model.forward(next_data, mode='compute_critic')['q_value']
+ td_data = v_1step_td_data(q_value.mean(-1), target_q_value.mean(-1), reward, data['done'], data['weight'])
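+        # 1-step TD error on the agent-averaged Q: target = r + gamma * (1 - done) * mean_agent(Q_target)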
+ critic_loss, td_error_per_sample = v_1step_td_error(td_data, self._gamma)
+ loss_dict['critic_loss'] = critic_loss
+ # ================
+ # critic update
+ # ================
+ self._optimizer_critic.zero_grad()
+ critic_loss.backward()
+ self._optimizer_critic.step()
+ # ===============================
+ # actor learn forward and update
+ # ===============================
+ # actor updates every ``self._actor_update_freq`` iters
+ if (self._forward_learn_cnt + 1) % self._actor_update_freq == 0:
+ if self._communication:
+ output = self._learn_model.forward(data['obs'], mode='compute_actor', get_delta_q=False)
+ output['delta_q'] = data['delta_q']
+ attention_loss = self._learn_model.forward(output, mode='optimize_actor_attention')['loss']
+ loss_dict['attention_loss'] = attention_loss
+ self._optimizer_actor_attention.zero_grad()
+ attention_loss.backward()
+ self._optimizer_actor_attention.step()
+
+ output = self._learn_model.forward(data['obs'], mode='compute_actor', get_delta_q=False)
+
+ critic_input = {'obs': data['obs'], 'action': output['action']}
+ actor_loss = -self._learn_model.forward(critic_input, mode='compute_critic')['q_value'].mean()
+ loss_dict['actor_loss'] = actor_loss
+ # actor update
+ self._optimizer_actor.zero_grad()
+ actor_loss.backward()
+ self._optimizer_actor.step()
+ # =============
+ # after update
+ # =============
+ loss_dict['total_loss'] = sum(loss_dict.values())
+ self._forward_learn_cnt += 1
+ self._target_model.update(self._learn_model.state_dict())
+ return {
+ 'cur_lr_actor': self._optimizer_actor.defaults['lr'],
+ 'cur_lr_critic': self._optimizer_critic.defaults['lr'],
+ 'priority': td_error_per_sample.abs().tolist(),
+ 'q_value': q_value.mean().item(),
+ **loss_dict,
+ }
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'target_model': self._target_model.state_dict(),
+ 'optimizer_actor': self._optimizer_actor.state_dict(),
+ 'optimizer_critic': self._optimizer_critic.state_dict(),
+ 'optimize_actor_attention': self._optimizer_actor_attention.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._target_model.load_state_dict(state_dict['target_model'])
+ self._optimizer_actor.load_state_dict(state_dict['optimizer_actor'])
+ self._optimizer_critic.load_state_dict(state_dict['optimizer_critic'])
+ self._optimizer_actor_attention.load_state_dict(state_dict['optimize_actor_attention'])
+
+ def _init_collect(self) -> None:
+ r"""
+ Overview:
+ Collect mode init method. Called by ``self.__init__``.
+ Init traj and unroll length, collect model.
+ """
+ self._unroll_len = self._cfg.collect.unroll_len
+ # collect model
+ self._collect_model = model_wrap(
+ self._model,
+ wrapper_name='action_noise',
+ noise_type='gauss',
+ noise_kwargs={
+ 'mu': 0.0,
+ 'sigma': self._cfg.collect.noise_sigma
+ },
+ noise_range=None, # no noise clip in actor
+ )
+ self._collect_model.reset()
+
+ def _forward_collect(self, data: dict) -> dict:
+ r"""
+ Overview:
+ Forward function of collect mode.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs.
+ ReturnsKeys
+ - necessary: ``action``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._collect_model.eval()
+ with torch.no_grad():
+ output = self._collect_model.forward(data, mode='compute_actor', get_delta_q=True)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Generate dict type transition data from inputs.
+ Arguments:
+ - obs (:obj:`Any`): Env observation
+ - model_output (:obj:`dict`): Output of collect model, including at least ['action']
+ - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
+ (here 'obs' indicates obs after env step, i.e. next_obs).
+ Return:
+ - transition (:obj:`Dict[str, Any]`): Dict type transition data.
+ """
+ if self._communication:
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'action': model_output['action'],
+ 'delta_q': model_output['delta_q'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ else:
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'action': model_output['action'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
+ if self._communication:
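+            # min-max normalize delta_q (the estimated gain of Q from communication) over the trajectory to [0, 1],
+            # which is used as the supervision signal for the attention unit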
+ delta_q_batch = [d['delta_q'] for d in data]
+ delta_min = torch.stack(delta_q_batch).min()
+ delta_max = torch.stack(delta_q_batch).max()
+ for i in range(len(data)):
+ data[i]['delta_q'] = (data[i]['delta_q'] - delta_min) / (delta_max - delta_min + 1e-8)
+ return get_train_sample(data, self._unroll_len)
+
+ def _init_eval(self) -> None:
+ r"""
+ Overview:
+ Evaluate mode init method. Called by ``self.__init__``.
+ Init eval model. Unlike learn and collect model, eval model does not need noise.
+ """
+ self._eval_model = model_wrap(self._model, wrapper_name='base')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: dict) -> dict:
+ r"""
+ Overview:
+ Forward function of eval mode, similar to ``self._forward_collect``.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
+ ReturnsKeys
+ - necessary: ``action``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data, mode='compute_actor')
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _monitor_vars_learn(self) -> List[str]:
+ r"""
+ Overview:
+            Return the names of the variables to be monitored.
+ Returns:
+ - vars (:obj:`List[str]`): Variables' name list.
+ """
+ return [
+ 'cur_lr_actor',
+ 'cur_lr_critic',
+ 'critic_loss',
+ 'actor_loss',
+ 'attention_loss',
+ 'total_loss',
+ 'q_value',
+ ]
diff --git a/DI-engine/ding/policy/base_policy.py b/DI-engine/ding/policy/base_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ff99c7b43f880ceb76dc9375f90b8e65c2a2295
--- /dev/null
+++ b/DI-engine/ding/policy/base_policy.py
@@ -0,0 +1,861 @@
+from typing import Optional, List, Dict, Any, Tuple, Union
+from abc import ABC, abstractmethod
+from collections import namedtuple
+from easydict import EasyDict
+
+import copy
+import torch
+
+from ding.model import create_model
+from ding.utils import import_module, allreduce, broadcast, get_rank, allreduce_async, synchronize, deep_merge_dicts, \
+ POLICY_REGISTRY
+
+
+class Policy(ABC):
+ """
+ Overview:
+ The basic class of Reinforcement Learning (RL) and Imitation Learning (IL) policy in DI-engine.
+ Property:
+ ``cfg``, ``learn_mode``, ``collect_mode``, ``eval_mode``
+ """
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ """
+ Overview:
+ Get the default config of policy. This method is used to create the default config of policy.
+ Returns:
+ - cfg (:obj:`EasyDict`): The default config of corresponding policy. For the derived policy class, \
+ it will recursively merge the default config of base class and its own default config.
+
+ .. tip::
+ This method will deepcopy the ``config`` attribute of the class and return the result. So users don't need \
+ to worry about the modification of the returned config.
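+
+        A minimal usage sketch (``DQNPolicy`` is used purely as an example of a derived policy class):
+
+        .. code-block:: python
+
+            cfg = DQNPolicy.default_config()
+            assert cfg.cfg_type == 'DQNPolicyDict'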
+ """
+ if cls == Policy:
+ raise RuntimeError("Basic class Policy doesn't have completed default_config")
+
+ base_cls = cls.__base__
+ if base_cls == Policy:
+ base_policy_cfg = EasyDict(copy.deepcopy(Policy.config))
+ else:
+ base_policy_cfg = copy.deepcopy(base_cls.default_config())
+ cfg = EasyDict(copy.deepcopy(cls.config))
+ cfg = deep_merge_dicts(base_policy_cfg, cfg)
+ cfg.cfg_type = cls.__name__ + 'Dict'
+ return cfg
+
+ learn_function = namedtuple(
+ 'learn_function', [
+ 'forward',
+ 'reset',
+ 'info',
+ 'monitor_vars',
+ 'get_attribute',
+ 'set_attribute',
+ 'state_dict',
+ 'load_state_dict',
+ ]
+ )
+ collect_function = namedtuple(
+ 'collect_function', [
+ 'forward',
+ 'process_transition',
+ 'get_train_sample',
+ 'reset',
+ 'get_attribute',
+ 'set_attribute',
+ 'state_dict',
+ 'load_state_dict',
+ ]
+ )
+ eval_function = namedtuple(
+ 'eval_function', [
+ 'forward',
+ 'reset',
+ 'get_attribute',
+ 'set_attribute',
+ 'state_dict',
+ 'load_state_dict',
+ ]
+ )
+ total_field = set(['learn', 'collect', 'eval'])
+ config = dict(
+ # (bool) Whether the learning policy is the same as the collecting data policy (on-policy).
+ on_policy=False,
+ # (bool) Whether to use cuda in policy.
+ cuda=False,
+ # (bool) Whether to use data parallel multi-gpu mode in policy.
+ multi_gpu=False,
+        # (bool) Whether to synchronously update the model parameters after allreducing their gradients.
+ bp_update_sync=True,
+ # (bool) Whether to enable infinite trajectory length in data collecting.
+ traj_len_inf=False,
+ # neural network model config
+ model=dict(),
+ )
+
+ def __init__(
+ self,
+ cfg: EasyDict,
+ model: Optional[torch.nn.Module] = None,
+ enable_field: Optional[List[str]] = None
+ ) -> None:
+ """
+ Overview:
+            Initialize policy instance according to input config and model. This method will initialize different \
+ fields in policy, including ``learn``, ``collect``, ``eval``. The ``learn`` field is used to train the \
+ policy, the ``collect`` field is used to collect data for training, and the ``eval`` field is used to \
+ evaluate the policy. The ``enable_field`` is used to specify which field to initialize, if it is None, \
+ then all fields will be initialized.
+ Arguments:
+ - cfg (:obj:`EasyDict`): The final merged config used to initialize policy. For the default config, \
+ see the ``config`` attribute and its comments of policy class.
+ - model (:obj:`torch.nn.Module`): The neural network model used to initialize policy. If it \
+ is None, then the model will be created according to ``default_model`` method and ``cfg.model`` field. \
+ Otherwise, the model will be set to the ``model`` instance created by outside caller.
+ - enable_field (:obj:`Optional[List[str]]`): The field list to initialize. If it is None, then all fields \
+ will be initialized. Otherwise, only the fields in ``enable_field`` will be initialized, which is \
+ beneficial to save resources.
+
+ .. note::
+ For the derived policy class, it should implement the ``_init_learn``, ``_init_collect``, ``_init_eval`` \
+ method to initialize the corresponding field.
+ """
+ self._cfg = cfg
+ self._on_policy = self._cfg.on_policy
+ if enable_field is None:
+ self._enable_field = self.total_field
+ else:
+ self._enable_field = enable_field
+ assert set(self._enable_field).issubset(self.total_field), self._enable_field
+
+ if len(set(self._enable_field).intersection(set(['learn', 'collect', 'eval']))) > 0:
+ model = self._create_model(cfg, model)
+ self._cuda = cfg.cuda and torch.cuda.is_available()
+ # now only support multi-gpu for only enable learn mode
+ if len(set(self._enable_field).intersection(set(['learn']))) > 0:
+ multi_gpu = self._cfg.multi_gpu
+ self._rank = get_rank() if multi_gpu else 0
+ if self._cuda:
+ # model.cuda() is an in-place operation.
+ model.cuda()
+ if multi_gpu:
+ bp_update_sync = self._cfg.bp_update_sync
+ self._bp_update_sync = bp_update_sync
+ self._init_multi_gpu_setting(model, bp_update_sync)
+ else:
+ self._rank = 0
+ if self._cuda:
+ # model.cuda() is an in-place operation.
+ model.cuda()
+ self._model = model
+ self._device = 'cuda:{}'.format(self._rank % torch.cuda.device_count()) if self._cuda else 'cpu'
+ else:
+ self._cuda = False
+ self._rank = 0
+ self._device = 'cpu'
+
+ # call the initialization method of different modes, such as ``_init_learn``, ``_init_collect``, ``_init_eval``
+ for field in self._enable_field:
+ getattr(self, '_init_' + field)()
+
+ def _init_multi_gpu_setting(self, model: torch.nn.Module, bp_update_sync: bool) -> None:
+ """
+ Overview:
+ Initialize multi-gpu data parallel training setting, including broadcast model parameters at the beginning \
+ of the training, and prepare the hook function to allreduce the gradients of model parameters.
+ Arguments:
+ - model (:obj:`torch.nn.Module`): The neural network model to be trained.
+ - bp_update_sync (:obj:`bool`): Whether to synchronously update the model parameters after allreducing \
+ their gradients. Async update can overlap the allreduce of different network layers in a pipelined \
+ fashion, which saves time.
+ """
+ for name, param in model.state_dict().items():
+ assert isinstance(param.data, torch.Tensor), type(param.data)
+ broadcast(param.data, 0)
+ # here we manually set the gradients to zero tensors at the beginning of training, which is necessary for
+ # the case where different GPUs have different computation graphs.
+ for name, param in model.named_parameters():
+ setattr(param, 'grad', torch.zeros_like(param))
+ if not bp_update_sync:
+
+ def make_hook(name, p):
+
+ def hook(*ignore):
+ allreduce_async(name, p.grad.data)
+
+ return hook
+
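+ # register the hook on each parameter's gradient accumulator node (``AccumulateGrad``), so that
+ # ``allreduce_async`` is launched for that parameter as soon as its gradient is ready in backward.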
+ for i, (name, p) in enumerate(model.named_parameters()):
+ if p.requires_grad:
+ p_tmp = p.expand_as(p)
+ grad_acc = p_tmp.grad_fn.next_functions[0][0]
+ grad_acc.register_hook(make_hook(name, p))
+
+ def _create_model(self, cfg: EasyDict, model: Optional[torch.nn.Module] = None) -> torch.nn.Module:
+ """
+ Overview:
+ Create or validate the neural network model according to the input config and model. If the input model is \
+ None, then the model will be created according to ``default_model`` method and ``cfg.model`` field. \
+ Otherwise, the model will be verified as an instance of ``torch.nn.Module`` and set to the ``model`` \
+ instance created by outside caller.
+ Arguments:
+ - cfg (:obj:`EasyDict`): The final merged config used to initialize policy.
+ - model (:obj:`torch.nn.Module`): The neural network model used to initialize policy. User can refer to \
+ the default model defined in corresponding policy to customize its own model.
+ Returns:
+ - model (:obj:`torch.nn.Module`): The created neural network model. The different modes of policy will \
+ add distinct wrappers and plugins to the model, which is used to train, collect and evaluate.
+ Raises:
+ - RuntimeError: If the input model is not None and is not an instance of ``torch.nn.Module``.
+ """
+ if model is None:
+ model_cfg = cfg.model
+ if 'type' not in model_cfg:
+ m_type, import_names = self.default_model()
+ model_cfg.type = m_type
+ model_cfg.import_names = import_names
+ return create_model(model_cfg)
+ else:
+ if isinstance(model, torch.nn.Module):
+ return model
+ else:
+ raise RuntimeError("invalid model: {}".format(type(model)))
+
+ @property
+ def cfg(self) -> EasyDict:
+ return self._cfg
+
+ @abstractmethod
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Initialize the learn mode of policy, including related attributes and modules. This method will be \
+ called in the ``__init__`` method if the ``learn`` field is in ``enable_field``. Almost every policy has \
+ its own learn mode, so this method must be overridden in the subclass.
+
+ .. note::
+ For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+ and ``_load_state_dict_learn`` methods.
+
+ .. note::
+ For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+ .. note::
+ If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+ with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def _init_collect(self) -> None:
+ """
+ Overview:
+ Initialize the collect mode of policy, including related attributes and modules. This method will be \
+ called in the ``__init__`` method if the ``collect`` field is in ``enable_field``. Almost every policy has \
+ its own collect mode, so this method must be overridden in the subclass.
+
+ .. note::
+ For the member variables that need to be saved and loaded, please refer to the ``_state_dict_collect`` \
+ and ``_load_state_dict_collect`` methods.
+
+ .. note::
+ If you want to set some special member variables in ``_init_collect`` method, you'd better name them \
+ with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def _init_eval(self) -> None:
+ """
+ Overview:
+ Initialize the eval mode of policy, including related attributes and modules. This method will be \
+ called in the ``__init__`` method if the ``eval`` field is in ``enable_field``. Almost every policy has \
+ its own eval mode, so this method must be overridden in the subclass.
+
+ .. note::
+ For the member variables that need to be saved and loaded, please refer to the ``_state_dict_eval`` \
+ and ``_load_state_dict_eval`` methods.
+
+ .. note::
+ If you want to set some special member variables in ``_init_eval`` method, you'd better name them \
+ with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
+ """
+ raise NotImplementedError
+
+ @property
+ def learn_mode(self) -> 'Policy.learn_function': # noqa
+ """
+ Overview:
+ Return the interfaces of learn mode of policy, which is used to train the model. Here we use namedtuple \
+ to define immutable interfaces and restrict the usage of policy in different mode. Moreover, derived \
+ subclass can override the interfaces to customize its own learn mode.
+ Returns:
+ - interfaces (:obj:`Policy.learn_function`): The interfaces of learn mode of policy, it is a namedtuple \
+ whose values of distinct fields are different internal methods.
+ Examples:
+ >>> policy = Policy(cfg, model)
+ >>> policy_learn = policy.learn_mode
+ >>> train_output = policy_learn.forward(data)
+ >>> state_dict = policy_learn.state_dict()
+ """
+ return Policy.learn_function(
+ self._forward_learn,
+ self._reset_learn,
+ self.__repr__,
+ self._monitor_vars_learn,
+ self._get_attribute,
+ self._set_attribute,
+ self._state_dict_learn,
+ self._load_state_dict_learn,
+ )
+
+ @property
+ def collect_mode(self) -> 'Policy.collect_function': # noqa
+ """
+ Overview:
+ Return the interfaces of collect mode of policy, which is used to collect training data. Here we use namedtuple \
+ to define immutable interfaces and restrict the usage of policy in different mode. Moreover, derived \
+ subclass can override the interfaces to customize its own collect mode.
+ Returns:
+ - interfaces (:obj:`Policy.collect_function`): The interfaces of collect mode of policy, it is a \
+ namedtuple whose values of distinct fields are different internal methods.
+ Examples:
+ >>> policy = Policy(cfg, model)
+ >>> policy_collect = policy.collect_mode
+ >>> obs = env_manager.ready_obs
+ >>> inference_output = policy_collect.forward(obs)
+ >>> next_obs, rew, done, info = env_manager.step(inference_output.action)
+ """
+ return Policy.collect_function(
+ self._forward_collect,
+ self._process_transition,
+ self._get_train_sample,
+ self._reset_collect,
+ self._get_attribute,
+ self._set_attribute,
+ self._state_dict_collect,
+ self._load_state_dict_collect,
+ )
+
+ @property
+ def eval_mode(self) -> 'Policy.eval_function': # noqa
+ """
+ Overview:
+ Return the interfaces of eval mode of policy, which is used to evaluate the policy performance. Here we use namedtuple \
+ to define immutable interfaces and restrict the usage of policy in different mode. Moreover, derived \
+ subclass can override the interfaces to customize its own eval mode.
+ Returns:
+ - interfaces (:obj:`Policy.eval_function`): The interfaces of eval mode of policy, it is a namedtuple \
+ whose values of distinct fields are different internal methods.
+ Examples:
+ >>> policy = Policy(cfg, model)
+ >>> policy_eval = policy.eval_mode
+ >>> obs = env_manager.ready_obs
+ >>> inference_output = policy_eval.forward(obs)
+ >>> next_obs, rew, done, info = env_manager.step(inference_output.action)
+ """
+ return Policy.eval_function(
+ self._forward_eval,
+ self._reset_eval,
+ self._get_attribute,
+ self._set_attribute,
+ self._state_dict_eval,
+ self._load_state_dict_eval,
+ )
+
+ def _set_attribute(self, name: str, value: Any) -> None:
+ """
+ Overview:
+ In order to control access to the policy attributes, we expose different modes to the outside rather than \
+ the policy instance itself. This method sets an attribute of the policy that is shared among different \
+ modes, and the new attribute will be named ``_{name}``.
+ Arguments:
+ - name (:obj:`str`): The name of the attribute.
+ - value (:obj:`Any`): The value of the attribute.
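+ Examples:
+ >>> # a minimal sketch (hypothetical attribute name): after this call, ``self._priority`` is available
+ >>> # and can be fetched from any mode via ``_get_attribute('priority')``.
+ >>> policy._set_attribute('priority', True)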
+ """
+ setattr(self, '_' + name, value)
+
+ def _get_attribute(self, name: str) -> Any:
+ """
+ Overview:
+ In order to control access to the policy attributes, we expose different modes to the outside rather than \
+ the policy instance itself. This method gets an attribute of the policy that is shared among different \
+ modes.
+ Arguments:
+ - name (:obj:`str`): The name of the attribute.
+ Returns:
+ - value (:obj:`Any`): The value of the attribute.
+
+ .. note::
+ DI-engine's policy will first try to access `_get_{name}` method, and then try to access `_{name}` \
+ attribute. If both of them are not found, it will raise a ``NotImplementedError``.
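+ Examples:
+ >>> # a minimal sketch illustrating the lookup order: a ``_get_batch_size`` method takes precedence
+ >>> # over a plain ``_batch_size`` attribute set via ``_set_attribute``.
+ >>> batch_size = policy._get_attribute('batch_size')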
+ """
+ if hasattr(self, '_get_' + name):
+ return getattr(self, '_get_' + name)()
+ elif hasattr(self, '_' + name):
+ return getattr(self, '_' + name)
+ else:
+ raise NotImplementedError
+
+ def __repr__(self) -> str:
+ """
+ Overview:
+ Get the string representation of the policy.
+ Returns:
+ - repr (:obj:`str`): The string representation of the policy.
+ """
+ return "DI-engine DRL Policy\n{}".format(repr(self._model))
+
+ def sync_gradients(self, model: torch.nn.Module) -> None:
+ """
+ Overview:
+ Synchronize (allreduce) gradients of model parameters in data-parallel multi-gpu training.
+ Arguments:
+ - model (:obj:`torch.nn.Module`): The model to synchronize gradients.
+
+ .. note::
+ This method is only used in multi-gpu training, and it should be called after the ``backward`` method and \
+ before the ``step`` method. The ``bp_update_sync`` config controls whether gradient allreduce and \
+ optimizer updates are executed synchronously.
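+ Examples:
+ >>> # a minimal sketch of one synchronous multi-gpu update inside a subclass's ``_forward_learn``
+ >>> self._optimizer.zero_grad()
+ >>> loss.backward()
+ >>> if self._cfg.multi_gpu:
+ >>>     self.sync_gradients(self._learn_model)
+ >>> self._optimizer.step()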
+ """
+
+ if self._bp_update_sync:
+ for name, param in model.named_parameters():
+ if param.requires_grad:
+ allreduce(param.grad.data)
+ else:
+ synchronize()
+
+ # don't need to implement default_model method by force
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+ Return this algorithm's default neural network model setting for demonstration. ``__init__`` method will \
+ automatically call this method to get the default model setting and create model.
+ Returns:
+ - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
+
+ .. note::
+ The user can define and use a customized network model but must obey the same interface definition indicated \
+ by import_names path. For example about DQN, its registered name is ``dqn`` and the import_names is \
+ ``ding.model.template.q_learning.DQN``
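+ Examples:
+ >>> # a sketch of a subclass override: return the registered model name and its module import path
+ >>> def default_model(self) -> Tuple[str, List[str]]:
+ >>>     return 'discrete_bc', ['ding.model.template.bc']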
+ """
+ raise NotImplementedError
+
+ # *************************************** learn function ************************************
+
+ @abstractmethod
+ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
+ """
+ Overview:
+ Policy forward function of learn mode (training policy and updating parameters). Forward means \
+ that the policy inputs some training batch data from the replay buffer and then returns the output \
+ result, including various training information such as loss value, policy entropy, q value, priority, \
+ and so on. This method is left to be implemented by the subclass, and more arguments can be added in \
+ ``data`` item if necessary.
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]]`): The input data used for policy forward, including a batch of \
+ training samples. For each element in list, the key of the dict is the name of data items and the \
+ value is the corresponding data. Usually, in the ``_forward_learn`` method, data should be stacked in \
+ the batch dimension by some utility functions such as ``default_preprocess_learn``.
+ Returns:
+ - output (:obj:`Dict[str, Any]`): The training information of policy forward, including some metrics for \
+ monitoring training such as loss, priority, q value, policy entropy, and some data for next step \
+ training such as priority. Note that the output data items should be Python native scalars rather than \
+ PyTorch tensors, which is convenient for the outside caller to use.
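+ Examples:
+ >>> # a minimal sketch; the exact keys of each sample depend on the concrete policy
+ >>> data = [{'obs': ..., 'action': ..., 'reward': ..., 'next_obs': ..., 'done': ...} for _ in range(64)]
+ >>> output = policy.learn_mode.forward(data)
+ >>> # output contains Python scalars such as the loss values listed in ``_monitor_vars_learn``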
+ """
+ raise NotImplementedError
+
+ # don't need to implement _reset_learn method by force
+ def _reset_learn(self, data_id: Optional[List[int]] = None) -> None:
+ """
+ Overview:
+ Reset some stateful variables for learn mode when necessary, such as the hidden state of RNN or the \
+ memory bank of some special algorithms. If ``data_id`` is None, it means to reset all the stateful \
+ variables. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \
+ different trajectories in ``data_id`` will have different hidden state in RNN.
+ Arguments:
+ - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \
+ specified by ``data_id``.
+
+ .. note::
+ This method is not mandatory to be implemented. The sub-class can overwrite this method if necessary.
+ """
+ pass
+
+ def _monitor_vars_learn(self) -> List[str]:
+ """
+ Overview:
+ Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
+ as text logger, tensorboard logger, will use these keys to save the corresponding data.
+ Returns:
+ - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
+
+ .. tip::
+ The default implementation is ``['cur_lr', 'total_loss']``. Other derived classes can overwrite this \
+ method to add their own keys if necessary.
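+ Examples:
+ >>> # a sketch of a subclass override that appends algorithm-specific keys to the defaults
+ >>> def _monitor_vars_learn(self) -> List[str]:
+ >>>     return super()._monitor_vars_learn() + ['policy_entropy', 'q_value']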
+ """
+ return ['cur_lr', 'total_loss']
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ """
+ Overview:
+ Return the state_dict of learn mode, usually including model and optimizer.
+ Returns:
+ - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring.
+ """
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'optimizer': self._optimizer.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ """
+ Overview:
+ Load the state_dict variable into policy learn mode.
+ Arguments:
+ - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before.
+
+ .. tip::
+ If you want to load only some parts of the model, you can simply set the ``strict`` argument in \
+ load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
+ complicated operation.
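+ Examples:
+ >>> # a sketch of a partial (non-strict) load as described in the tip above; the checkpoint path is hypothetical
+ >>> state_dict = torch.load('ckpt_best.pth.tar', map_location='cpu')
+ >>> policy._learn_model.load_state_dict(state_dict['model'], strict=False)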
+ """
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._optimizer.load_state_dict(state_dict['optimizer'])
+
+ def _get_batch_size(self) -> Union[int, Dict[str, int]]:
+ # some special algorithms use different batch sizes for different optimization parts.
+ if 'batch_size' in self._cfg:
+ return self._cfg.batch_size
+ else: # for compatibility
+ return self._cfg.learn.batch_size
+
+ # *************************************** collect function ************************************
+
+ @abstractmethod
+ def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]:
+ """
+ Overview:
+ Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \
+ that the policy gets some necessary data (mainly observation) from the envs and then returns the output \
+ data, such as the action to interact with the envs, or the action logits to calculate the loss in learn \
+ mode. This method is left to be implemented by the subclass, and more arguments can be added in ``kwargs`` \
+ part if necessary.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
+ other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \
+ dict is the same as the input data, i.e. environment id.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def _process_transition(
+ self, obs: Union[torch.Tensor, Dict[str, torch.Tensor]], policy_output: Dict[str, torch.Tensor],
+ timestep: namedtuple
+ ) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ Process and pack one timestep transition data into a dict, such as <obs, action, reward, done, next_obs>. Some policies \
+ need to do some special process and pack its own necessary attributes (e.g. hidden state and logit), \
+ so this method is left to be implemented by the subclass.
+ Arguments:
+ - obs (:obj:`Union[torch.Tensor, Dict[str, torch.Tensor]]`): The observation of the current timestep.
+ - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \
+ as input. Usually, it contains the action and the logit of the action.
+ - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \
+ except all the elements have been transformed into tensor data. Usually, it contains the next obs, \
+ reward, done, info, etc.
+ Returns:
+ - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+ Overview:
+ For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \
+ can be used for training directly. A train sample can be a processed transition (DQN with nstep TD) \
+ or some multi-timestep transitions (DRQN). This method is usually used in collectors to execute necessary \
+ RL data preprocessing before training, which can help the learner amortize the relevant time consumption. \
+ In addition, you can also implement this method as an identity function and do the data processing \
+ in ``self._forward_learn`` method.
+ Arguments:
+ - transitions (:obj:`List[Dict[str, Any]`): The trajectory data (a list of transition), each element is \
+ the same format as the return value of ``self._process_transition`` method.
+ Returns:
+ - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is the similar format \
+ as input transitions, but may contain more data for training, such as nstep reward, advantage, etc.
+
+ .. note::
+ We will vectorize the ``process_transition`` and ``get_train_sample`` methods in a following release. \
+ The user can also customize this data processing procedure by overriding these two methods and the \
+ collector itself.
+ """
+ raise NotImplementedError
+
+ # don't need to implement _reset_collect method by force
+ def _reset_collect(self, data_id: Optional[List[int]] = None) -> None:
+ """
+ Overview:
+ Reset some stateful variables for collect mode when necessary, such as the hidden state of RNN or the \
+ memory bank of some special algorithms. If ``data_id`` is None, it means to reset all the stateful \
+ variables. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \
+ different environments/episodes in collecting in ``data_id`` will have different hidden state in RNN.
+ Arguments:
+ - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \
+ specified by ``data_id``.
+
+ .. note::
+ This method is not mandatory to be implemented. The sub-class can overwrite this method if necessary.
+ """
+ pass
+
+ def _state_dict_collect(self) -> Dict[str, Any]:
+ """
+ Overview:
+ Return the state_dict of collect mode, usually only including the model, which is necessary for distributed \
+ training scenarios to auto-recover collectors.
+ Returns:
+ - state_dict (:obj:`Dict[str, Any]`): The dict of current policy collect state, for saving and restoring.
+
+ .. tip::
+ Not all scenarios need to auto-recover collectors; sometimes we can directly shut down the crashed \
+ collector and start a new one.
+ """
+ return {'model': self._collect_model.state_dict()}
+
+ def _load_state_dict_collect(self, state_dict: Dict[str, Any]) -> None:
+ """
+ Overview:
+ Load the state_dict variable into policy collect mode, such as load pretrained state_dict, auto-recover \
+ checkpoint, or model replica from learner in distributed training scenarios.
+ Arguments:
+ - state_dict (:obj:`Dict[str, Any]`): The dict of policy collect state saved before.
+
+ .. tip::
+ If you want to load only some parts of the model, you can simply set the ``strict`` argument in \
+ load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
+ complicated operation.
+ """
+ self._collect_model.load_state_dict(state_dict['model'], strict=True)
+
+ def _get_n_sample(self) -> Union[int, None]:
+ if 'n_sample' in self._cfg:
+ return self._cfg.n_sample
+ else: # for compatibility
+ return self._cfg.collect.get('n_sample', None) # for some adaptive data collection cases
+
+ def _get_n_episode(self) -> Union[int, None]:
+ if 'n_episode' in self._cfg:
+ return self._cfg.n_episode
+ else: # for compatibility
+ return self._cfg.collect.get('n_episode', None) # for some adaptive data collection cases
+
+ # *************************************** eval function ************************************
+
+ @abstractmethod
+ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
+ """
+ Overview:
+ Policy forward function of eval mode (evaluation policy performance, such as interacting with envs or \
+ computing metrics on validation dataset). Forward means that the policy gets some necessary data (mainly \
+ observation) from the envs and then returns the output data, such as the action to interact with the envs. \
+ This method is left to be implemented by the subclass.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
+ key of the dict is the same as the input data, i.e. environment id.
+ """
+ raise NotImplementedError
+
+ # don't need to implement _reset_eval method by force
+ def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
+ """
+ Overview:
+ Reset some stateful variables for eval mode when necessary, such as the hidden state of RNN or the \
+ memory bank of some special algorithms. If ``data_id`` is None, it means to reset all the stateful \
+ variables. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \
+ different environments/episodes in evaluation in ``data_id`` will have different hidden state in RNN.
+ Arguments:
+ - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \
+ specified by ``data_id``.
+
+ .. note::
+ This method is not mandatory to be implemented. The sub-class can overwrite this method if necessary.
+ """
+ pass
+
+ def _state_dict_eval(self) -> Dict[str, Any]:
+ """
+ Overview:
+ Return the state_dict of eval mode, usually only including the model, which is necessary for distributed \
+ training scenarios to auto-recover evaluators.
+ Returns:
+ - state_dict (:obj:`Dict[str, Any]`): The dict of current policy eval state, for saving and restoring.
+
+ .. tip::
+ Not all scenarios need to auto-recover evaluators; sometimes we can directly shut down the crashed \
+ evaluator and start a new one.
+ """
+ return {'model': self._eval_model.state_dict()}
+
+ def _load_state_dict_eval(self, state_dict: Dict[str, Any]) -> None:
+ """
+ Overview:
+ Load the state_dict variable into policy eval mode, such as load auto-recover \
+ checkpoint, or model replica from learner in distributed training scenarios.
+ Arguments:
+ - state_dict (:obj:`Dict[str, Any]`): The dict of policy eval state saved before.
+
+ .. tip::
+ If you want to load only some parts of the model, you can simply set the ``strict`` argument in \
+ load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
+ complicated operation.
+ """
+ self._eval_model.load_state_dict(state_dict['model'], strict=True)
+
+
+class CommandModePolicy(Policy):
+ """
+ Overview:
+ Policy with command mode, which can be used in the old version of the DI-engine pipeline: ``serial_pipeline``. \
+ ``CommandModePolicy`` uses ``_get_setting_learn``, ``_get_setting_collect``, ``_get_setting_eval`` methods \
+ to exchange information between different workers.
+
+ Interface:
+ ``_init_command``, ``_get_setting_learn``, ``_get_setting_collect``, ``_get_setting_eval``
+ Property:
+ ``command_mode``
+ """
+ command_function = namedtuple('command_function', ['get_setting_learn', 'get_setting_collect', 'get_setting_eval'])
+ total_field = set(['learn', 'collect', 'eval', 'command'])
+
+ @property
+ def command_mode(self) -> 'Policy.command_function': # noqa
+ """
+ Overview:
+ Return the interfaces of command mode of policy, which is used to exchange settings between workers. Here we use namedtuple \
+ to define immutable interfaces and restrict the usage of policy in different mode. Moreover, derived \
+ subclass can override the interfaces to customize its own command mode.
+ Returns:
+ - interfaces (:obj:`Policy.command_function`): The interfaces of command mode, it is a namedtuple \
+ whose values of distinct fields are different internal methods.
+ Examples:
+ >>> policy = CommandModePolicy(cfg, model)
+ >>> policy_command = policy.command_mode
+ >>> settings = policy_command.get_setting_learn(command_info)
+ """
+ return CommandModePolicy.command_function(
+ self._get_setting_learn, self._get_setting_collect, self._get_setting_eval
+ )
+
+ @abstractmethod
+ def _init_command(self) -> None:
+ """
+ Overview:
+ Initialize the command mode of policy, including related attributes and modules. This method will be \
+ called in the ``__init__`` method if the ``command`` field is in ``enable_field``. Almost every policy has \
+ its own command mode, so this method must be overridden in the subclass.
+
+ .. note::
+ If you want to set some special member variables in ``_init_command`` method, you'd better name them \
+ with prefix ``_command_`` to avoid conflict with other modes, such as ``self._command_attr1``.
+ """
+ raise NotImplementedError
+
+ # *************************************** command function ************************************
+ @abstractmethod
+ def _get_setting_learn(self, command_info: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Overview:
+ According to ``command_info``, i.e., global training information (e.g. training iteration, collected env \
+ step, evaluation results, etc.), return the setting of learn mode, which contains dynamically changed \
+ hyperparameters for learn mode, such as ``batch_size``, ``learning_rate``, etc.
+ Arguments:
+ - command_info (:obj:`Dict[str, Any]`): The global training information, which is defined in ``commander``.
+ Returns:
+ - setting (:obj:`Dict[str, Any]`): The latest setting of learn mode, which is usually used as extra \
+ arguments of the ``policy._forward_learn`` method.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def _get_setting_collect(self, command_info: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Overview:
+ According to ``command_info``, i.e., global training information (e.g. training iteration, collected env \
+ step, evaluation results, etc.), return the setting of collect mode, which contains dynamically changed \
+ hyperparameters for collect mode, such as ``eps``, ``temperature``, etc.
+ Arguments:
+ - command_info (:obj:`Dict[str, Any]`): The global training information, which is defined in ``commander``.
+ Returns:
+ - setting (:obj:`Dict[str, Any]`): The latest setting of collect mode, which is usually used as extra \
+ arguments of the ``policy._forward_collect`` method.
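+ Examples:
+ >>> # a sketch for an epsilon-greedy policy: decay eps according to the collected env step
+ >>> # (the ``command_info`` key and the ``_epsilon_greedy_fn`` helper are assumptions for illustration)
+ >>> def _get_setting_collect(self, command_info: Dict[str, Any]) -> Dict[str, Any]:
+ >>>     return {'eps': self._epsilon_greedy_fn(command_info['envstep'])}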
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def _get_setting_eval(self, command_info: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Overview:
+ According to ``command_info``, i.e., global training information (e.g. training iteration, collected env \
+ step, evaluation results, etc.), return the setting of eval mode, which contains dynamically changed \
+ hyperparameters for eval mode, such as ``temperature``, etc.
+ Arguments:
+ - command_info (:obj:`Dict[str, Any]`): The global training information, which is defined in ``commander``.
+ Returns:
+ - setting (:obj:`Dict[str, Any]`): The latest setting of eval mode, which is usually used as extra \
+ arguments of the ``policy._forward_eval`` method.
+ """
+ raise NotImplementedError
+
+
+def create_policy(cfg: EasyDict, **kwargs) -> Policy:
+ """
+ Overview:
+ Create a policy instance according to ``cfg`` and other kwargs.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Final merged policy config.
+ ArgumentsKeys:
+ - type (:obj:`str`): Policy type set in the ``POLICY_REGISTRY.register`` method, such as ``dqn``.
+ - import_names (:obj:`List[str]`): A list of module names (paths) to import before creating policy, such \
+ as ``ding.policy.dqn``.
+ Returns:
+ - policy (:obj:`Policy`): The created policy instance.
+
+ .. tip::
+ ``kwargs`` contains other arguments that need to be passed to the policy constructor. You can refer to \
+ the ``__init__`` method of the corresponding policy class for details.
+
+ .. note::
+ For more details about how to merge config, please refer to the system document of DI-engine \
+ (`en link <../03_system/config.html>`_).
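+ Examples:
+ >>> # a minimal sketch: ``cfg`` is the final merged policy config containing ``type`` (and optionally
+ >>> # ``import_names``); extra kwargs such as ``enable_field`` are forwarded to the policy constructor.
+ >>> policy = create_policy(cfg, model=None, enable_field=['eval'])
+ >>> inference_output = policy.eval_mode.forward(obs_dict)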
+ """
+ import_module(cfg.get('import_names', []))
+ return POLICY_REGISTRY.build(cfg.type, cfg=cfg, **kwargs)
+
+
+def get_policy_cls(cfg: EasyDict) -> type:
+ """
+ Overview:
+ Get policy class according to ``cfg``, which is used to access related class variables/methods.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Final merged policy config.
+ ArgumentsKeys:
+ - type (:obj:`str`): Policy type set in the ``POLICY_REGISTRY.register`` method, such as ``dqn``.
+ - import_names (:obj:`List[str]`): A list of module names (paths) to import before creating policy, such \
+ as ``ding.policy.dqn``.
+ Returns:
+ - policy (:obj:`type`): The policy class.
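+ Examples:
+ >>> # a minimal sketch: access class-level variables (e.g. the default ``config``) without instantiation
+ >>> policy_cls = get_policy_cls(cfg)
+ >>> default_cfg = policy_cls.config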
+ """
+ import_module(cfg.get('import_names', []))
+ return POLICY_REGISTRY.get(cfg.type)
diff --git a/DI-engine/ding/policy/bc.py b/DI-engine/ding/policy/bc.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c95b8abec31e83d50855d0c7895990903626d5d
--- /dev/null
+++ b/DI-engine/ding/policy/bc.py
@@ -0,0 +1,353 @@
+import math
+import torch
+import torch.nn as nn
+import copy
+from torch.optim import Adam, SGD, AdamW
+from torch.optim.lr_scheduler import LambdaLR
+import logging
+from typing import List, Dict, Any, Tuple, Union, Optional
+from collections import namedtuple
+from easydict import EasyDict
+from ding.policy import Policy
+from ding.model import model_wrap
+from ding.torch_utils import to_device, to_list
+from ding.utils import EasyTimer
+from ding.utils.data import default_collate, default_decollate
+from ding.rl_utils import get_nstep_return_data, get_train_sample
+from ding.utils import POLICY_REGISTRY
+from ding.torch_utils.loss.cross_entropy_loss import LabelSmoothCELoss
+
+
+@POLICY_REGISTRY.register('bc')
+class BehaviourCloningPolicy(Policy):
+ """
+ Overview:
+ Behaviour Cloning (BC) policy class, which supports both discrete and continuous action space. \
+ The policy is trained by supervised learning, and the data is an offline dataset collected by an expert.
+ """
+
+ config = dict(
+ type='bc',
+ cuda=False,
+ on_policy=False,
+ continuous=False,
+ action_shape=19,
+ learn=dict(
+ update_per_collect=1,
+ batch_size=32,
+ learning_rate=1e-5,
+ lr_decay=False,
+ decay_epoch=30,
+ decay_rate=0.1,
+ warmup_lr=1e-4,
+ warmup_epoch=3,
+ optimizer='SGD',
+ momentum=0.9,
+ weight_decay=1e-4,
+ ce_label_smooth=False,
+ show_accuracy=False,
+ tanh_mask=False, # if actions always converge to 1 or -1, use this.
+ ),
+ collect=dict(
+ unroll_len=1,
+ noise=False,
+ noise_sigma=0.2,
+ noise_range=dict(
+ min=-0.5,
+ max=0.5,
+ ),
+ ),
+ eval=dict(), # for compatibility
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+ Return this algorithm's default neural network model setting for demonstration. ``__init__`` method will \
+ automatically call this method to get the default model setting and create model.
+ Returns:
+ - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
+
+ .. note::
+ The user can define and use a customized network model but must obey the same interface definition indicated \
+ by import_names path. For example about discrete BC, its registered name is ``discrete_bc`` and the \
+ import_names is ``ding.model.template.bc``.
+ """
+ if self._cfg.continuous:
+ return 'continuous_bc', ['ding.model.template.bc']
+ else:
+ return 'discrete_bc', ['ding.model.template.bc']
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Initialize the learn mode of policy, including related attributes and modules. For BC, it mainly contains the \
+ optimizer and algorithm-specific arguments such as lr_scheduler and loss. \
+ This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+ .. note::
+ For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+ and ``_load_state_dict_learn`` methods.
+
+ .. note::
+ For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+ .. note::
+ If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+ with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
+ """
+ assert self._cfg.learn.optimizer in ['SGD', 'Adam'], self._cfg.learn.optimizer
+ if self._cfg.learn.optimizer == 'SGD':
+ self._optimizer = SGD(
+ self._model.parameters(),
+ lr=self._cfg.learn.learning_rate,
+ weight_decay=self._cfg.learn.weight_decay,
+ momentum=self._cfg.learn.momentum
+ )
+ elif self._cfg.learn.optimizer == 'Adam':
+ if self._cfg.learn.weight_decay is None:
+ self._optimizer = Adam(
+ self._model.parameters(),
+ lr=self._cfg.learn.learning_rate,
+ )
+ else:
+ self._optimizer = AdamW(
+ self._model.parameters(),
+ lr=self._cfg.learn.learning_rate,
+ weight_decay=self._cfg.learn.weight_decay
+ )
+ if self._cfg.learn.lr_decay:
+
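+ # A sketch of the resulting schedule: for epochs within ``warmup_epoch`` the effective lr equals
+ # ``warmup_lr``; afterwards the base ``learning_rate`` is multiplied by ``decay_rate`` once every
+ # ``decay_epoch`` epochs (LambdaLR scales the base lr by the returned factor).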
+ def lr_scheduler_fn(epoch):
+ if epoch <= self._cfg.learn.warmup_epoch:
+ return self._cfg.learn.warmup_lr / self._cfg.learn.learning_rate
+ else:
+ ratio = (epoch - self._cfg.learn.warmup_epoch) // self._cfg.learn.decay_epoch
+ return math.pow(self._cfg.learn.decay_rate, ratio)
+
+ self._lr_scheduler = LambdaLR(self._optimizer, lr_scheduler_fn)
+ self._timer = EasyTimer(cuda=True)
+ self._learn_model = model_wrap(self._model, 'base')
+ self._learn_model.reset()
+
+ if self._cfg.continuous:
+ if self._cfg.loss_type == 'l1_loss':
+ self._loss = nn.L1Loss()
+ elif self._cfg.loss_type == 'mse_loss':
+ self._loss = nn.MSELoss()
+ else:
+ raise KeyError("not support loss type: {}".format(self._cfg.loss_type))
+ else:
+ if not self._cfg.learn.ce_label_smooth:
+ self._loss = nn.CrossEntropyLoss()
+ else:
+ self._loss = LabelSmoothCELoss(0.1)
+
+ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
+ """
+ Overview:
+ Policy forward function of learn mode (training policy and updating parameters). Forward means \
+ that the policy inputs some training batch data from the replay buffer and then returns the output \
+ result, including various training information such as loss and time.
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]]`): The input data used for policy forward, including a batch of \
+ training samples. For each element in list, the key of the dict is the name of data items and the \
+ value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
+ combinations. In the ``_forward_learn`` method, data often needs to first be stacked in the batch \
+ dimension by some utility functions such as ``default_preprocess_learn``. \
+ For BC, each element in list is a dict containing at least the following keys: ``obs``, ``action``.
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will be \
+ recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
+ detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+ For data types that are not supported, the main reason is that the corresponding model does not support \
+ them. You can implement your own model rather than use the default model. For more information, please \
+ raise an issue in the GitHub repo and we will continue to follow up.
+ """
+ if isinstance(data, list):
+ data = default_collate(data)
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._learn_model.train()
+ with self._timer:
+ obs, action = data['obs'], data['action'].squeeze()
+ if self._cfg.continuous:
+ if self._cfg.learn.tanh_mask:
+ """tanh_mask
+ We mask the action out of range of [tanh(-1),tanh(1)], model will learn information
+ and produce action in [-1,1]. So the action won't always converge to -1 or 1.
+ """
+ mu = self._eval_model.forward(data['obs'])['action']
+ bound = 1 - 2 / (math.exp(2) + 1) # tanh(1): (e-e**(-1))/(e+e**(-1))
+ mask = mu.ge(-bound) & mu.le(bound)
+ mask_percent = 1 - mask.sum().item() / mu.numel()
+ if mask_percent > 0.8: # if more than 80% of actions are masked out, too little data remains, so use all data.
+ loss = self._loss(mu, action.detach())
+ else:
+ loss = self._loss(mu.masked_select(mask), action.masked_select(mask).detach())
+ else:
+ mu = self._learn_model.forward(data['obs'])['action']
+ # When we use BCO, the action is predicted by an inverse dynamics model (IDM), so no gradient is expected through it.
+ loss = self._loss(mu, action.detach())
+ else:
+ a_logit = self._learn_model.forward(obs)
+ # When we use BCO, the action is predicted by an inverse dynamics model (IDM), so no gradient is expected through it.
+ loss = self._loss(a_logit['logit'], action.detach())
+
+ if self._cfg.learn.show_accuracy:
+ # Calculate the overall accuracy and the accuracy of each class
+ total_accuracy = (a_logit['action'] == action.view(-1)).float().mean()
+ self.total_accuracy_in_dataset.append(total_accuracy)
+ logging.info(f'the total accuracy in current train mini-batch is: {total_accuracy.item()}')
+ for action_unique in to_list(torch.unique(action)):
+ action_index = (action == action_unique).nonzero(as_tuple=True)[0]
+ action_accuracy = (a_logit['action'][action_index] == action.view(-1)[action_index]
+ ).float().mean()
+ if math.isnan(action_accuracy):
+ action_accuracy = 0.0
+ self.action_accuracy_in_dataset[action_unique].append(action_accuracy)
+ logging.info(
+ f'the accuracy of action {action_unique} in current train mini-batch is: '
+ f'{action_accuracy.item()}, '
+ f'(nan means the action does not appear in the mini-batch)'
+ )
+ forward_time = self._timer.value
+ with self._timer:
+ self._optimizer.zero_grad()
+ loss.backward()
+ backward_time = self._timer.value
+ with self._timer:
+ if self._cfg.multi_gpu:
+ self.sync_gradients(self._learn_model)
+ sync_time = self._timer.value
+ self._optimizer.step()
+ cur_lr = [param_group['lr'] for param_group in self._optimizer.param_groups]
+ cur_lr = sum(cur_lr) / len(cur_lr)
+ return {
+ 'cur_lr': cur_lr,
+ 'total_loss': loss.item(),
+ 'forward_time': forward_time,
+ 'backward_time': backward_time,
+ 'sync_time': sync_time,
+ }
+
+ def _monitor_vars_learn(self) -> List[str]:
+ """
+ Overview:
+ Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
+ as text logger, tensorboard logger, will use these keys to save the corresponding data.
+ Returns:
+ - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
+ """
+ return ['cur_lr', 'total_loss', 'forward_time', 'backward_time', 'sync_time']
+
+ def _init_eval(self):
+ """
+ Overview:
+ Initialize the eval mode of policy, including related attributes and modules. For BC, it contains the \
+ eval model, which greedily selects the action with the argmax mechanism for discrete action spaces.
+ This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
+
+ .. note::
+ If you want to set some special member variables in ``_init_eval`` method, you'd better name them \
+ with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
+ """
+ if self._cfg.continuous:
+ self._eval_model = model_wrap(self._model, wrapper_name='base')
+ else:
+ self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
+ """
+ Overview:
+ Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \
+ means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
+ action to interact with the envs.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
+ key of the dict is the same as the input data, i.e. environment id.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+ For data types that are not supported, the main reason is that the corresponding model does not support \
+ them. You can implement your own model rather than use the default model. For more information, please \
+ raise an issue in the GitHub repo and we will continue to follow up.
+ """
+ tensor_input = isinstance(data, torch.Tensor)
+ if tensor_input:
+ data = default_collate(list(data))
+ else:
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ if tensor_input:
+ return output
+ else:
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _init_collect(self) -> None:
+ """
+ Overview:
+ The BC policy is trained on an offline dataset, so it does not need to collect data. However, sometimes we \
+ need the trained BC policy to collect data for other purposes.
+ """
+ self._unroll_len = self._cfg.collect.unroll_len
+ if self._cfg.continuous:
+ self._collect_model = model_wrap(
+ self._model,
+ wrapper_name='action_noise',
+ noise_type='gauss',
+ noise_kwargs={
+ 'mu': 0.0,
+ 'sigma': self._cfg.collect.noise_sigma.start
+ },
+ noise_range=self._cfg.collect.noise_range
+ )
+ else:
+ self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_sample')
+ self._collect_model.reset()
+
+ def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]:
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._collect_model.eval()
+ with torch.no_grad():
+ if self._cfg.continuous:
+ # output = self._collect_model.forward(data)
+ output = self._collect_model.forward(data, **kwargs)
+ else:
+ output = self._collect_model.forward(data, **kwargs)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _process_transition(self, obs: Any, policy_output: dict, timestep: namedtuple) -> dict:
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'action': policy_output['action'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return EasyDict(transition)
+
+ def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ data = get_nstep_return_data(data, 1, 1)
+ return get_train_sample(data, self._unroll_len)
diff --git a/DI-engine/ding/policy/bcq.py b/DI-engine/ding/policy/bcq.py
new file mode 100755
index 0000000000000000000000000000000000000000..9a8388b00ff63008b40c23c93503c58a60b19335
--- /dev/null
+++ b/DI-engine/ding/policy/bcq.py
@@ -0,0 +1,289 @@
+from typing import List, Dict, Any, Tuple, Union
+from collections import namedtuple
+import copy
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ding.torch_utils import Adam, to_device
+from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, get_nstep_return_data
+from ding.model import model_wrap
+from ding.policy import Policy
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_collate, default_decollate
+from .common_utils import default_preprocess_learn
+
+
+@POLICY_REGISTRY.register('bcq')
+class BCQPolicy(Policy):
+ config = dict(
+ type='bcq',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ # (bool type) priority: Determine whether to use priority in buffer sampling.
+ # Default to False.
+ priority=False,
+ # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ # (int) Number of training samples (randomly collected) in the replay buffer when training starts.
+ # Default to 10000.
+ random_collect_size=10000,
+ nstep=1,
+ model=dict(
+ # (List) Hidden list for actor network head.
+ actor_head_hidden_size=[400, 300],
+
+ # (List) Hidden list for critic network head.
+ critic_head_hidden_size=[400, 300],
+ # Max perturbation hyper-parameter for BCQ
+ phi=0.05,
+ ),
+ learn=dict(
+
+ # How many updates(iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=1,
+ # (int) Minibatch size for gradient descent.
+ batch_size=100,
+
+ # (float type) learning_rate_q: Learning rate for soft q network.
+ # Default to 3e-4.
+ # Please set to 1e-3, when model.value_network is True.
+ learning_rate_q=3e-4,
+ # (float type) learning_rate_policy: Learning rate for policy network.
+ # Default to 3e-4.
+ # Please set to 1e-3, when model.value_network is True.
+ learning_rate_policy=3e-4,
+ # (float type) learning_rate_vae: Learning rate for vae network.
+ # `learning_rate_value` should be initialized, when model.vae_network is True.
+ # Please set to 3e-4, when model.vae_network is True.
+ learning_rate_vae=3e-4,
+ # (bool) Whether to ignore done (usually for max-step-termination envs, e.g. pendulum).
+ # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
+ # These limit HalfCheetah, and several other MuJoCo envs, to a max length of 1000.
+ # However, interaction with HalfCheetah never terminates naturally (done is always False),
+ # so we replace done==True with done==False when the episode step exceeds the max episode step,
+ # to keep the TD-error computation accurate (``gamma * (1 - done) * next_v + reward``).
+ ignore_done=False,
+
+ # (float type) target_theta: Used for soft update of the target network,
+ # aka. Interpolation factor in polyak averaging for target networks.
+ # Default to 0.005.
+ target_theta=0.005,
+ # (float) discount factor for the discounted sum of rewards, aka. gamma.
+ discount_factor=0.99,
+ lmbda=0.75,
+
+ # (float) Weight uniform initialization range in the last output layer
+ init_w=3e-3,
+ ),
+ collect=dict(
+ # (int) Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ ),
+ eval=dict(),
+ other=dict(
+ replay_buffer=dict(
+ # (int type) replay_buffer_size: Max size of replay buffer.
+ replay_buffer_size=1000000,
+ # (int type) max_use: Max use times of one data in the buffer.
+ # Data will be removed once used for too many times.
+ # Default to infinite.
+ # max_use=256,
+ ),
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ return 'bcq', ['ding.model.template.bcq']
+
+ def _init_learn(self) -> None:
+ r"""
+ Overview:
+ Learn mode init method. Called by ``self.__init__``.
+ Init the q, policy and vae optimizers, algorithm config, main and target models.
+ """
+ # Init
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ self.lmbda = self._cfg.learn.lmbda
+ self.latent_dim = self._cfg.model.action_shape * 2
+
+ # Optimizers
+ self._optimizer_q = Adam(
+ self._model.critic.parameters(),
+ lr=self._cfg.learn.learning_rate_q,
+ )
+ self._optimizer_policy = Adam(
+ self._model.actor.parameters(),
+ lr=self._cfg.learn.learning_rate_policy,
+ )
+ self._optimizer_vae = Adam(
+ self._model.vae.parameters(),
+ lr=self._cfg.learn.learning_rate_vae,
+ )
+
+ # Algorithm config
+ self._gamma = self._cfg.learn.discount_factor
+
+ # Main and target models
+ self._target_model = copy.deepcopy(self._model)
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='momentum',
+ update_kwargs={'theta': self._cfg.learn.target_theta}
+ )
+ self._learn_model = model_wrap(self._model, wrapper_name='base')
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ self._forward_learn_cnt = 0
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ loss_dict = {}
+
+ data = default_preprocess_learn(
+ data,
+ use_priority=self._priority,
+ use_priority_IS_weight=self._cfg.priority_IS_weight,
+ ignore_done=self._cfg.learn.ignore_done,
+ use_nstep=False
+ )
+ if len(data.get('action').shape) == 1:
+ data['action'] = data['action'].reshape(-1, 1)
+
+ if self._cuda:
+ data = to_device(data, self._device)
+
+ self._learn_model.train()
+ self._target_model.train()
+ obs = data['obs']
+ next_obs = data['next_obs']
+ reward = data['reward']
+ done = data['done']
+ batch_size = obs.shape[0]
+
+ # train_vae
+ vae_out = self._model.forward(data, mode='compute_vae')
+ recon, mean, log_std = vae_out['recons_action'], vae_out['mu'], vae_out['log_var']
+ recons_loss = F.mse_loss(recon, data['action'])
+ kld_loss = torch.mean(-0.5 * torch.sum(1 + log_std - mean ** 2 - log_std.exp(), dim=1), dim=0)
+ loss_dict['recons_loss'] = recons_loss
+ loss_dict['kld_loss'] = kld_loss
+ vae_loss = recons_loss + 0.5 * kld_loss
+ loss_dict['vae_loss'] = vae_loss
+ self._optimizer_vae.zero_grad()
+ vae_loss.backward()
+ self._optimizer_vae.step()
+
+ # train_critic
+ q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']
+
+ with torch.no_grad():
+ next_obs_rep = torch.repeat_interleave(next_obs, 10, 0)
+ z = torch.randn((next_obs_rep.shape[0], self.latent_dim)).to(self._device).clamp(-0.5, 0.5)
+ vae_action = self._model.vae.decode_with_obs(z, next_obs_rep)['reconstruction_action']
+ next_action = self._target_model.forward({
+ 'obs': next_obs_rep,
+ 'action': vae_action
+ }, mode='compute_actor')['action']
+
+ next_data = {'obs': next_obs_rep, 'action': next_action}
+ target_q_value = self._target_model.forward(next_data, mode='compute_critic')['q_value']
+ # soft clipped double-Q target: a convex combination (weight ``lmbda``) of the min and max
+ # of the twin target Q-values
+ target_q_value = self.lmbda * torch.min(target_q_value[0], target_q_value[1]) \
+ + (1 - self.lmbda) * torch.max(target_q_value[0], target_q_value[1])
+ target_q_value = target_q_value.reshape(batch_size, -1).max(1)[0].reshape(-1, 1)
+
+ q_data0 = v_1step_td_data(q_value[0], target_q_value, reward, done, data['weight'])
+ loss_dict['critic_loss'], td_error_per_sample0 = v_1step_td_error(q_data0, self._gamma)
+ q_data1 = v_1step_td_data(q_value[1], target_q_value, reward, done, data['weight'])
+ loss_dict['twin_critic_loss'], td_error_per_sample1 = v_1step_td_error(q_data1, self._gamma)
+ td_error_per_sample = (td_error_per_sample0 + td_error_per_sample1) / 2
+
+ self._optimizer_q.zero_grad()
+ (loss_dict['critic_loss'] + loss_dict['twin_critic_loss']).backward()
+ self._optimizer_q.step()
+
+ # train_policy
+ z = torch.randn((obs.shape[0], self.latent_dim)).to(self._device).clamp(-0.5, 0.5)
+ sample_action = self._model.vae.decode_with_obs(z, obs)['reconstruction_action']
+ input = {'obs': obs, 'action': sample_action}
+ perturbed_action = self._model.forward(input, mode='compute_actor')['action']
+ q_input = {'obs': obs, 'action': perturbed_action}
+ q = self._learn_model.forward(q_input, mode='compute_critic')['q_value'][0]
+ loss_dict['actor_loss'] = -q.mean()
+ self._optimizer_policy.zero_grad()
+ loss_dict['actor_loss'].backward()
+ self._optimizer_policy.step()
+ self._forward_learn_cnt += 1
+ self._target_model.update(self._learn_model.state_dict())
+ return {
+ 'td_error': td_error_per_sample.detach().mean().item(),
+ 'target_q_value': target_q_value.detach().mean().item(),
+ **loss_dict
+ }
+
+ def _monitor_vars_learn(self) -> List[str]:
+ return [
+ 'td_error', 'target_q_value', 'critic_loss', 'twin_critic_loss', 'actor_loss', 'recons_loss', 'kld_loss',
+ 'vae_loss'
+ ]
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ ret = {
+ 'model': self._learn_model.state_dict(),
+ 'target_model': self._target_model.state_dict(),
+ 'optimizer_q': self._optimizer_q.state_dict(),
+ 'optimizer_policy': self._optimizer_policy.state_dict(),
+ 'optimizer_vae': self._optimizer_vae.state_dict(),
+ }
+ return ret
+
+ def _init_eval(self):
+ self._eval_model = model_wrap(self._model, wrapper_name='base')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: dict) -> Dict[str, Any]:
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ data = {'obs': data}
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data, mode='compute_eval')
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _init_collect(self) -> None:
+ self._unroll_len = self._cfg.collect.unroll_len
+ self._gamma = self._cfg.discount_factor # necessary for parallel
+ self._nstep = self._cfg.nstep # necessary for parallel
+ self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_sample')
+ self._collect_model.reset()
+
+ def _forward_collect(self, data: dict, **kwargs) -> dict:
+ pass
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
+ pass
+
+ def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
+ r"""
+ Overview:
+ Get the trajectory and the n step return data, then sample from the n_step return data
+ Arguments:
+ - data (:obj:`list`): The trajectory's cache
+ Returns:
+ - samples (:obj:`dict`): The training samples generated
+ """
+ data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)
+ return get_train_sample(data, self._unroll_len)
diff --git a/DI-engine/ding/policy/bdq.py b/DI-engine/ding/policy/bdq.py
new file mode 100644
index 0000000000000000000000000000000000000000..c994c6cd456bf07ac4253b68e8493ac5e8ff0f63
--- /dev/null
+++ b/DI-engine/ding/policy/bdq.py
@@ -0,0 +1,393 @@
+from typing import List, Dict, Any, Tuple
+from collections import namedtuple
+import copy
+import torch
+
+from ding.torch_utils import Adam, to_device, ContrastiveLoss
+from ding.rl_utils import q_nstep_td_data, bdq_nstep_td_error, get_nstep_return_data, get_train_sample
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_collate, default_decollate
+
+from .base_policy import Policy
+from .common_utils import default_preprocess_learn
+
+
+@POLICY_REGISTRY.register('bdq')
+class BDQPolicy(Policy):
+ r"""
+ Overview:
+ Policy class of the BDQ algorithm, extended with PER and multi-step TD. \
+ Reference paper: Action Branching Architectures for Deep Reinforcement Learning.
+
+ .. note::
+ The BDQ algorithm uses a neural architecture with a shared decision module \
+ followed by several network branches, one for each action dimension.
+ Config:
+ == ==================== ======== ============== ======================================== =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============== ======================================== =======================
+ 1 ``type`` str bdq | RL policy register name, refer to | This arg is optional,
+ | registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool False | Whether to use cuda for network | This arg can be diff-
+ | erent from modes
+ 3 ``on_policy`` bool False | Whether the RL algorithm is on-policy
+ | or off-policy
+ 4 ``priority`` bool False | Whether use priority(PER) | Priority sample,
+ | update priority
+ 5 | ``priority_IS`` bool False | Whether use Importance Sampling Weight
+ | ``_weight`` | to correct biased update. If True,
+ | priority must be True.
+ 6 | ``discount_`` float 0.97, | Reward's future discount factor, aka. | May be 1 when sparse
+ | ``factor`` [0.95, 0.999] | gamma | reward env
+ 7 ``nstep`` int 1, | N-step reward discount sum for target
+ [3, 5] | q_value estimation
+ 8 | ``learn.update`` int 3 | How many updates(iterations) to train | This arg can vary
+ | ``per_collect`` | after collector's one collection. Only | across envs. Bigger val
+ | valid in serial training | means more off-policy
+ 9 | ``learn.multi`` bool False | Whether to use multi gpu during training
+ | ``_gpu``
+ 10 | ``learn.batch_`` int 64 | The number of samples of an iteration
+ | ``size``
+ 11 | ``learn.learning`` float 0.001 | Gradient step length of an iteration.
+ | ``_rate``
+ 12 | ``learn.target_`` int 100 | Frequency of target network update. | Hard(assign) update
+ | ``update_freq``
+ 13 | ``learn.ignore_`` bool False | Whether to ignore done for target value | Enable it for some
+ | ``done`` | calculation. | fake termination env
+ 14 ``collect.n_sample`` int [8, 128] | The number of training samples of a | It varies from
+ | call of collector. | different envs
+ 15 | ``collect.unroll`` int 1 | unroll length of an iteration | In RNN, unroll_len>1
+ | ``_len``
+ 16 | ``other.eps.type`` str exp | exploration rate decay type | Support ['exp',
+ | 'linear'].
+ 17 | ``other.eps.`` float 0.95 | start value of exploration rate | [0,1]
+ | ``start``
+ 18 | ``other.eps.`` float 0.1 | end value of exploration rate | [0,1]
+ | ``end``
+ 19 | ``other.eps.`` int 10000 | decay length of exploration | greater than 0. Setting
+ | ``decay`` | decay=10000 means
+ | the exploration rate
+ | decays from the start
+ | value to the end value
+ | over the decay length.
+ == ==================== ======== ============== ======================================== =======================
+ """
+
+ config = dict(
+ type='bdq',
+ # (bool) Whether use cuda in policy
+ cuda=False,
+ # (bool) Whether learning policy is the same as collecting data policy(on-policy)
+ on_policy=False,
+ # (bool) Whether enable priority experience sample
+ priority=False,
+ # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ # (float) Discount factor(gamma) for returns
+ discount_factor=0.97,
+ # (int) The number of step for calculating target q_value
+ nstep=1,
+ learn=dict(
+
+ # How many updates(iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=3,
+ # (int) How many samples in a training batch
+ batch_size=64,
+ # (float) The step size of gradient descent
+ learning_rate=0.001,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (int) Frequency of target network update.
+ target_update_freq=100,
+ # (bool) Whether to ignore done (usually for max-step termination envs)
+ ignore_done=False,
+ ),
+ # collect_mode config
+ collect=dict(
+ # (int) Only one of [n_sample, n_episode] should be set
+ # n_sample=8,
+ # (int) Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ ),
+ eval=dict(),
+ # other config
+ other=dict(
+ # Epsilon greedy with decay.
+ eps=dict(
+ # (str) Decay type. Support ['exp', 'linear'].
+ type='exp',
+ # (float) Epsilon start value
+ start=0.95,
+ # (float) Epsilon end value
+ end=0.1,
+ # (int) Decay length(env step)
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=10000, ),
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+ Return this algorithm default model setting for demonstration.
+ Returns:
+ - model_info (:obj:`Tuple[str, List[str]]`): model name and mode import_names
+
+ .. note::
+ The user can define and use a customized network model but must obey the same interface definition indicated \
+ by import_names path. For BDQ, ``ding.model.template.q_learning.BDQ``
+ """
+ return 'bdq', ['ding.model.template.q_learning']
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Learn mode init method. Called by ``self.__init__``, initialize the optimizer, algorithm arguments, main \
+ and target model.
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ # Optimizer
+ self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
+
+ self._gamma = self._cfg.discount_factor
+ self._nstep = self._cfg.nstep
+
+ # use model_wrapper for specialized demands of different modes
+ self._target_model = copy.deepcopy(self._model)
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='assign',
+ update_kwargs={'freq': self._cfg.learn.target_update_freq}
+ )
+ self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Overview:
+ Forward computation graph of learn mode(updating policy).
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \
+ np.ndarray or dict/list combinations.
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \
+ recorded in text log and tensorboard, values are python scalar or a list of scalars.
+ ArgumentsKeys:
+ - necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done``
+ - optional: ``value_gamma``, ``IS``
+ ReturnsKeys:
+ - necessary: ``cur_lr``, ``total_loss``, ``priority``
+ - optional: ``action_distribution``
+ """
+ data = default_preprocess_learn(
+ data,
+ use_priority=self._priority,
+ use_priority_IS_weight=self._cfg.priority_IS_weight,
+ ignore_done=self._cfg.learn.ignore_done,
+ use_nstep=True
+ )
+
+ if self._cuda:
+ data = to_device(data, self._device)
+ # ====================
+ # Q-learning forward
+ # ====================
+ self._learn_model.train()
+ self._target_model.train()
+ # Current q value (main model)
+ q_value = self._learn_model.forward(data['obs'])['logit']
+ # Target q value
+ with torch.no_grad():
+ target_q_value = self._target_model.forward(data['next_obs'])['logit']
+ # Max q value action (main model)
+ target_q_action = self._learn_model.forward(data['next_obs'])['action']
+ if data['action'].shape != target_q_action.shape:
+ data['action'] = data['action'].unsqueeze(-1)
+
+ data_n = q_nstep_td_data(
+ q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], data['weight']
+ )
+ value_gamma = data.get('value_gamma')
+ loss, td_error_per_sample = bdq_nstep_td_error(data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma)
+
+ # ====================
+ # Q-learning update
+ # ====================
+ self._optimizer.zero_grad()
+ loss.backward()
+ if self._cfg.multi_gpu:
+ self.sync_gradients(self._learn_model)
+ self._optimizer.step()
+
+ # =============
+ # after update
+ # =============
+ self._target_model.update(self._learn_model.state_dict())
+ update_info = {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': loss.item(),
+ 'q_value': q_value.mean().item(),
+ 'target_q_value': target_q_value.mean().item(),
+ 'priority': td_error_per_sample.abs().tolist(),
+ # Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard.
+ # '[histogram]action_distribution': data['action'],
+ }
+ q_value_per_branch = torch.mean(q_value, 2, keepdim=False)
+ for i in range(self._model.num_branches):
+ update_info['q_value_b_' + str(i)] = q_value_per_branch[:, i].mean().item()
+ return update_info
+
+ def _monitor_vars_learn(self) -> List[str]:
+ return ['cur_lr', 'total_loss', 'q_value'] + ['q_value_b_' + str(i) for i in range(self._model.num_branches)]
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ """
+ Overview:
+ Return the state_dict of learn mode, usually including model and optimizer.
+ Returns:
+ - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring.
+ """
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'target_model': self._target_model.state_dict(),
+ 'optimizer': self._optimizer.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ """
+ Overview:
+ Load the state_dict variable into policy learn mode.
+ Arguments:
+ - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before.
+
+ .. tip::
+ If you want to only load some parts of model, you can simply set the ``strict`` argument in \
+ load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
+ complicated operation.
+ """
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._target_model.load_state_dict(state_dict['target_model'])
+ self._optimizer.load_state_dict(state_dict['optimizer'])
+
+ def _init_collect(self) -> None:
+ """
+ Overview:
+ Collect mode init method. Called by ``self.__init__``, initialize algorithm arguments and collect_model, \
+ enable the eps_greedy_sample for exploration.
+ """
+ self._unroll_len = self._cfg.collect.unroll_len
+ self._gamma = self._cfg.discount_factor # necessary for parallel
+ self._nstep = self._cfg.nstep # necessary for parallel
+ self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_sample')
+ self._collect_model.reset()
+
+ def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
+ """
+ Overview:
+ Forward computation graph of collect mode(collect training data), with eps_greedy for exploration.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ - eps (:obj:`float`): epsilon value for exploration, which is decayed by collected env step.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The dict of predicting policy_output(action) for the interaction with \
+ env and the constructing of transition.
+ ArgumentsKeys:
+ - necessary: ``obs``
+ ReturnsKeys
+ - necessary: ``logit``, ``action``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._collect_model.eval()
+ with torch.no_grad():
+ output = self._collect_model.forward(data, eps=eps)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+ Overview:
+ For given trajectory data (a list of transitions), process it into a list of samples that \
+ can be used for training directly. A train sample can be a processed transition (BDQ with nstep TD).
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]]`): The trajectory data (a list of transitions), each element is in the \
+ same format as the return value of the ``self._process_transition`` method.
+ Returns:
+ - samples (:obj:`dict`): The list of training samples.
+
+ .. note::
+ We will vectorize the ``process_transition`` and ``get_train_sample`` methods in the following release version. \
+ The user can customize this data processing procedure by overriding these two methods and the collector \
+ itself.
+ """
+ data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)
+ return get_train_sample(data, self._unroll_len)
+
+ def _process_transition(self, obs: Any, policy_output: Dict[str, Any], timestep: namedtuple) -> Dict[str, Any]:
+ """
+ Overview:
+ Generate a transition for this algorithm's training.
+ Arguments:
+ - obs (:obj:`Any`): Env observation.
+ - policy_output (:obj:`Dict[str, Any]`): The output of policy collect mode(``self._forward_collect``),\
+ including at least ``action``.
+ - timestep (:obj:`namedtuple`): The output after env step(execute policy output action), including at \
+ least ``obs``, ``reward``, ``done``, (here obs indicates obs after env step).
+ Returns:
+ - transition (:obj:`dict`): Dict type transition data.
+ """
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'action': policy_output['action'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _init_eval(self) -> None:
+ r"""
+ Overview:
+ Evaluate mode init method. Called by ``self.__init__``, initialize eval_model.
+ """
+ self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
+ """
+ Overview:
+ Forward computation graph of eval mode(evaluate policy performance), at most cases, it is similar to \
+ ``self._forward_collect``.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
+ ArgumentsKeys:
+ - necessary: ``obs``
+ ReturnsKeys
+ - necessary: ``action``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
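+
+# --- Editor's note: standalone sketch (not part of BDQPolicy) of the branch-wise tensor layout assumed above.
+# A BDQ head outputs one Q vector per action dimension, i.e. a tensor of shape (batch, num_branches, num_bins).
+# The helpers below only illustrate how greedy actions and the per-branch means logged as ``q_value_b_i``
+# are read off that tensor; the names are illustrative, not DI-engine APIs.
+import torch
+
+
+def _branch_greedy_action_sketch(q: torch.Tensor) -> torch.Tensor:
+    """Pick the greedy bin independently for each branch: (batch, num_branches, num_bins) -> (batch, num_branches)."""
+    return q.argmax(dim=-1)
+
+
+def _branch_mean_q_sketch(q: torch.Tensor) -> torch.Tensor:
+    """Mean Q per branch, matching ``torch.mean(q_value, 2)`` in ``_forward_learn``: -> (batch, num_branches)."""
+    return q.mean(dim=2)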
diff --git a/DI-engine/ding/policy/c51.py b/DI-engine/ding/policy/c51.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b6f36d68eeba9e6017c35021504f4d7751e1550
--- /dev/null
+++ b/DI-engine/ding/policy/c51.py
@@ -0,0 +1,268 @@
+from typing import List, Dict, Any, Tuple, Union
+import copy
+import torch
+
+from ding.torch_utils import Adam, to_device
+from ding.rl_utils import dist_nstep_td_data, dist_nstep_td_error, get_train_sample, get_nstep_return_data
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_collate, default_decollate
+from .dqn import DQNPolicy
+from .common_utils import default_preprocess_learn
+
+
+@POLICY_REGISTRY.register('c51')
+class C51Policy(DQNPolicy):
+ r"""
+ Overview:
+ Policy class of C51 algorithm.
+
+ Config:
+ == ==================== ======== ============== ======================================== =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============== ======================================== =======================
+ 1 ``type`` str c51 | RL policy register name, refer to | this arg is optional,
+ | registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool False | Whether to use cuda for network | this arg can be diff-
+ | erent from modes
+ 3 ``on_policy`` bool False | Whether the RL algorithm is on-policy
+ | or off-policy
+ 4 ``priority`` bool False | Whether use priority(PER) | priority sample,
+ | update priority
+ 5 ``model.v_min`` float -10 | Value of the smallest atom
+ | in the support set.
+ 6 ``model.v_max`` float 10 | Value of the largest atom
+ | in the support set.
+ 7 ``model.n_atom`` int 51 | Number of atoms in the support set
+ | of the value distribution.
+ 8 | ``other.eps`` float 0.95 | Start value for epsilon decay.
+ | ``.start`` |
+ 9 | ``other.eps`` float 0.1 | End value for epsilon decay.
+ | ``.end``
+ 10 | ``discount_`` float 0.97, | Reward's future discount factor, aka. | may be 1 when sparse
+ | ``factor`` [0.95, 0.999] | gamma | reward env
+ 11 ``nstep`` int 1, | N-step reward discount sum for target
+ | q_value estimation
+ 12 | ``learn.update`` int 3 | How many updates(iterations) to train | this arg can vary
+ | ``per_collect`` | after collector's one collection. Only | across envs. Bigger val
+ | valid in serial training | means more off-policy
+ == ==================== ======== ============== ======================================== =======================
+ """
+
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='c51',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ # (bool) Whether the RL algorithm is on-policy or off-policy.
+ on_policy=False,
+ # (bool) Whether use priority(priority sample, IS weight, update priority)
+ priority=False,
+ # (float) Reward's future discount factor, aka. gamma.
+ discount_factor=0.97,
+ # (int) N-step reward for target q_value estimation
+ nstep=1,
+ model=dict(
+ v_min=-10,
+ v_max=10,
+ n_atom=51,
+ ),
+ learn=dict(
+
+ # How many updates(iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=3,
+ batch_size=64,
+ learning_rate=0.001,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (int) Frequency of target network update.
+ target_update_freq=100,
+ # (bool) Whether to ignore done (usually for max-step termination envs)
+ ignore_done=False,
+ ),
+ # collect_mode config
+ collect=dict(
+ # (int) Only one of [n_sample, n_step, n_episode] should be set
+ # n_sample=8,
+ # (int) Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ ),
+ eval=dict(),
+ # other config
+ other=dict(
+ # Epsilon greedy with decay.
+ eps=dict(
+ # (str) Decay type. Support ['exp', 'linear'].
+ type='exp',
+ start=0.95,
+ end=0.1,
+ # (int) Decay length(env step)
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=10000, )
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ return 'c51dqn', ['ding.model.template.q_learning']
+
+ def _init_learn(self) -> None:
+ r"""
+ Overview:
+ Learn mode init method. Called by ``self.__init__``.
+ Init the optimizer, algorithm config, main and target models.
+ """
+ self._priority = self._cfg.priority
+ # Optimizer
+ self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
+
+ self._gamma = self._cfg.discount_factor
+ self._nstep = self._cfg.nstep
+ self._v_max = self._cfg.model.v_max
+ self._v_min = self._cfg.model.v_min
+ self._n_atom = self._cfg.model.n_atom
+
+ # use wrapper instead of plugin
+ self._target_model = copy.deepcopy(self._model)
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='assign',
+ update_kwargs={'freq': self._cfg.learn.target_update_freq}
+ )
+ self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Forward and backward function of learn mode.
+ Arguments:
+ - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs']
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss.
+ """
+ data = default_preprocess_learn(
+ data, use_priority=self._priority, ignore_done=self._cfg.learn.ignore_done, use_nstep=True
+ )
+ if self._cuda:
+ data = to_device(data, self._device)
+ # ====================
+ # Q-learning forward
+ # ====================
+ self._learn_model.train()
+ self._target_model.train()
+ # Current q value (main model)
+ output = self._learn_model.forward(data['obs'])
+ q_value = output['logit']
+ q_value_dist = output['distribution']
+ # Target q value
+ with torch.no_grad():
+ target_output = self._target_model.forward(data['next_obs'])
+ target_q_value_dist = target_output['distribution']
+ target_q_value = target_output['logit']
+ # Max q value action (main model)
+ target_q_action = self._learn_model.forward(data['next_obs'])['action']
+
+ data_n = dist_nstep_td_data(
+ q_value_dist, target_q_value_dist, data['action'], target_q_action, data['reward'], data['done'],
+ data['weight']
+ )
+ value_gamma = data.get('value_gamma')
+ loss, td_error_per_sample = dist_nstep_td_error(
+ data_n, self._gamma, self._v_min, self._v_max, self._n_atom, nstep=self._nstep, value_gamma=value_gamma
+ )
+
+ # ====================
+ # Q-learning update
+ # ====================
+ self._optimizer.zero_grad()
+ loss.backward()
+ if self._cfg.multi_gpu:
+ self.sync_gradients(self._learn_model)
+ self._optimizer.step()
+
+ # =============
+ # after update
+ # =============
+ self._target_model.update(self._learn_model.state_dict())
+ return {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': loss.item(),
+ 'q_value': q_value.mean().item(),
+ 'target_q_value': target_q_value.mean().item(),
+ 'priority': td_error_per_sample.abs().tolist(),
+ # Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard.
+ # '[histogram]action_distribution': data['action'],
+ }
+
+ def _monitor_vars_learn(self) -> List[str]:
+ return ['cur_lr', 'total_loss', 'q_value', 'target_q_value']
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'target_model': self._target_model.state_dict(),
+ 'optimizer': self._optimizer.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._target_model.load_state_dict(state_dict['target_model'])
+ self._optimizer.load_state_dict(state_dict['optimizer'])
+
+ def _init_collect(self) -> None:
+ """
+ Overview:
+ Collect mode init method. Called by ``self.__init__``. Initialize necessary arguments for nstep return \
+ calculation and collect_model for exploration (eps_greedy_sample).
+ """
+ self._unroll_len = self._cfg.collect.unroll_len
+ self._gamma = self._cfg.discount_factor # necessary for parallel
+ self._nstep = self._cfg.nstep # necessary for parallel
+ self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_sample')
+ self._collect_model.reset()
+
+ def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
+ """
+ Overview:
+ Forward computation graph of collect mode(collect training data), with eps_greedy for exploration.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ - eps (:obj:`float`): epsilon value for exploration, which is decayed by collected env step.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The dict of predicting policy_output(action) for the interaction with \
+ env and the constructing of transition.
+ ArgumentsKeys:
+ - necessary: ``obs``
+ ReturnsKeys
+ - necessary: ``logit``, ``action``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._collect_model.eval()
+ with torch.no_grad():
+ output = self._collect_model.forward(data, eps=eps)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
+ """
+ Overview:
+ Calculate nstep return data and transform a trajectory into many train samples.
+ Arguments:
+ - data (:obj:`list`): The collected data of a trajectory, which is a list that contains dict elements.
+ Returns:
+ - samples (:obj:`dict`): The training samples generated.
+ """
+ data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)
+ return get_train_sample(data, self._unroll_len)
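+
+# --- Editor's note: standalone sketch (not part of C51Policy) of how a C51 value distribution is turned into
+# a scalar Q value. The atoms form the fixed support linspace(v_min, v_max, n_atom) and Q(s, a) is the
+# probability-weighted sum over that support; tensor names and shapes are illustrative only.
+import torch
+
+
+def _expected_q_sketch(distribution: torch.Tensor, v_min: float = -10., v_max: float = 10., n_atom: int = 51):
+    """distribution: (batch, action_dim, n_atom) probabilities -> (batch, action_dim) expected Q values."""
+    support = torch.linspace(v_min, v_max, n_atom)  # atom locations z_i, shared by all state-action pairs
+    return (distribution * support).sum(-1)         # Q(s, a) = sum_i p_i(s, a) * z_i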
diff --git a/DI-engine/ding/policy/collaq.py b/DI-engine/ding/policy/collaq.py
new file mode 100644
index 0000000000000000000000000000000000000000..961d4fed8081e6a05de7a53211790642e9bbd168
--- /dev/null
+++ b/DI-engine/ding/policy/collaq.py
@@ -0,0 +1,455 @@
+from typing import List, Dict, Any, Tuple, Union, Optional
+from collections import namedtuple
+import torch
+import copy
+
+from ding.torch_utils import to_device, RMSprop
+from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import timestep_collate, default_collate, default_decollate
+from .base_policy import Policy
+from .common_utils import default_preprocess_learn
+
+
+@POLICY_REGISTRY.register('collaq')
+class CollaQPolicy(Policy):
+ r"""
+ Overview:
+ Policy class of CollaQ algorithm. CollaQ is a multi-agent reinforcement learning algorithm.
+ Interface:
+ _init_learn, _data_preprocess_learn, _forward_learn, _reset_learn, _state_dict_learn, _load_state_dict_learn, \
+ _init_collect, _forward_collect, _reset_collect, _process_transition, _init_eval, _forward_eval, \
+ _reset_eval, _get_train_sample, default_model
+ Config:
+ == ==================== ======== ============== ======================================== =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============== ======================================== =======================
+ 1 ``type`` str collaq | RL policy register name, refer to | this arg is optional,
+ | registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool True | Whether to use cuda for network | this arg can be diff-
+ | erent from modes
+ 3 ``on_policy`` bool False | Whether the RL algorithm is on-policy
+ | or off-policy
+ 4 ``priority`` bool False | Whether use priority(PER) | priority sample,
+ | update priority
+ 5 | ``priority_`` bool False | Whether use Importance Sampling | IS weight
+ | ``IS_weight`` | Weight to correct biased update.
+ 6 | ``learn.update_`` int 20 | How many updates(iterations) to train | this arg can vary
+ | ``per_collect`` | after collector's one collection. Only | across envs. Bigger val
+ | valid in serial training | means more off-policy
+ 7 | ``learn.target_`` float 0.001 | Target network update momentum | between[0,1]
+ | ``update_theta`` | parameter.
+ 8 | ``learn.discount`` float 0.99 | Reward's future discount factor, aka. | may be 1 when sparse
+ | ``_factor`` | gamma | reward env
+ 9 | ``learn.collaq`` float 1.0 | The weight of collaq MARA loss
+ | ``_loss_weight``
+ == ==================== ======== ============== ======================================== =======================
+ """
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='collaq',
+ # (bool) Whether to use cuda for network.
+ cuda=True,
+ # (bool) Whether the RL algorithm is on-policy or off-policy.
+ on_policy=False,
+ # (bool) Whether use priority(priority sample, IS weight, update priority)
+ priority=False,
+ # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ learn=dict(
+
+ # (int) Collect n_episode data, update_model n_iteration times
+ update_per_collect=20,
+ # (int) The number of data for a train iteration
+ batch_size=32,
+ # (float) Gradient-descent step size
+ learning_rate=0.0005,
+ # ==============================================================
+ # The following configs is algorithm-specific
+ # ==============================================================
+ # (float) Target network update weight, theta * new_w + (1 - theta) * old_w, defaults in [0, 0.1]
+ target_update_theta=0.001,
+ # (float) Discount factor for future reward, defaults in [0, 1]
+ discount_factor=0.99,
+ # (float) The weight of collaq MARA loss
+ collaq_loss_weight=1.0,
+ # (float) The max norm of gradients used for gradient clipping
+ clip_value=100,
+ # (bool) Whether to use double DQN mechanism(target q for reducing overestimation)
+ double_q=False,
+ ),
+ collect=dict(
+ # (int) Only one of [n_sample, n_episode] should be set
+ # n_episode=32,
+ # (int) Cut trajectories into pieces with length "unroll_len", the length of timesteps
+ # in each forward pass when training. In CollaQ, as in QMIX, it is greater than 1 because there is an RNN.
+ unroll_len=10,
+ ),
+ eval=dict(),
+ other=dict(
+ eps=dict(
+ # (str) Type of epsilon decay
+ type='exp',
+ # (float) Start value for epsilon decay, in [0, 1].
+ # 0 means epsilon decay is not used.
+ start=1,
+ # (float) End value for epsilon decay, in [0, 1].
+ end=0.05,
+ # (int) Decay length(env step)
+ decay=200000,
+ ),
+ replay_buffer=dict(
+ # (int) max size of replay buffer
+ replay_buffer_size=5000,
+ max_reuse=10,
+ ),
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+ Return this algorithm default model setting for demonstration.
+ Returns:
+ - model_info (:obj:`Tuple[str, List[str]]`): model name and mode import_names
+
+ .. note::
+ The user can define and use a customized network model but must obey the same interface definition indicated \
+ by import_names path. For CollaQ, ``ding.model.collaq.CollaQ``.
+ """
+ return 'collaq', ['ding.model.template.collaq']
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Learn mode init method. Called by ``self.__init__``.
+ Init the learner model of CollaQPolicy
+ Arguments:
+ .. note::
+
+ The _init_learn method takes its arguments from ``self._cfg.learn`` in the config file.
+
+ - learning_rate (:obj:`float`): The learning rate of the optimizer
+ - gamma (:obj:`float`): The discount factor
+ - alpha (:obj:`float`): The collaQ loss factor, the weight for calculating MARL loss
+ - agent_num (:obj:`int`): Since this is a multi-agent algorithm, we need to input the agent num.
+ - batch_size (:obj:`int`): Need batch size info to init hidden_state plugins
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ self._optimizer = RMSprop(
+ params=self._model.parameters(), lr=self._cfg.learn.learning_rate, alpha=0.99, eps=0.00001
+ )
+ self._gamma = self._cfg.learn.discount_factor
+ self._alpha = self._cfg.learn.collaq_loss_weight
+
+ self._target_model = copy.deepcopy(self._model)
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='momentum',
+ update_kwargs={'theta': self._cfg.learn.target_update_theta}
+ )
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='hidden_state',
+ state_num=self._cfg.learn.batch_size,
+ init_fn=lambda: [[None for _ in range(self._cfg.model.agent_num)] for _ in range(3)]
+ )
+ self._learn_model = model_wrap(
+ self._model,
+ wrapper_name='hidden_state',
+ state_num=self._cfg.learn.batch_size,
+ init_fn=lambda: [[None for _ in range(self._cfg.model.agent_num)] for _ in range(3)]
+ )
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ def _data_preprocess_learn(
+ self,
+ data: List[Any],
+ use_priority_IS_weight: bool = False,
+ use_priority: bool = False,
+ ) -> dict:
+ r"""
+ Overview:
+ Preprocess the data to fit the required data format for learning
+
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]]`): the data collected from collect function
+
+ Returns:
+ - data (:obj:`Dict[str, Any]`): the processed data, from \
+ [len=B, ele={dict_key: [len=T, ele=Tensor(any_dims)]}] -> {dict_key: Tensor([T, B, any_dims])}
+ """
+ # data preprocess
+ data = timestep_collate(data)
+ if self._cuda:
+ data = to_device(data, self._device)
+ if use_priority_IS_weight:
+ assert use_priority, "Use IS Weight correction, but Priority is not used."
+ if use_priority and use_priority_IS_weight:
+ if 'priority_IS' in data:
+ data['weight'] = data['priority_IS']
+ else: # for compatibility
+ data['weight'] = data['IS']
+ else:
+ data['weight'] = data.get('weight', None)
+ data['done'] = data['done'].float()
+ return data
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Forward and backward function of learn mode.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \
+ np.ndarray or dict/list combinations.
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \
+ recorded in text log and tensorboard, values are python scalar or a list of scalars.
+ ArgumentsKeys:
+ - necessary: ``obs``, ``next_obs``, ``action``, ``reward``, ``weight``, ``prev_state``, ``done``
+ ReturnsKeys:
+ - necessary: ``cur_lr``, ``total_loss``
+ - cur_lr (:obj:`float`): Current learning rate
+ - total_loss (:obj:`float`): The calculated loss
+ """
+ data = self._data_preprocess_learn(data, self.cfg.priority_IS_weight, self.cfg.priority)
+ # ====================
+ # CollaQ forward
+ # ====================
+ self._learn_model.train()
+ self._target_model.train()
+ # for hidden_state plugin, we need to reset the main model and target model
+ self._learn_model.reset(state=data['prev_state'][0])
+ self._target_model.reset(state=data['prev_state'][0])
+ inputs = {'obs': data['obs'], 'action': data['action']}
+ ret = self._learn_model.forward(inputs, single_step=False)
+ total_q = ret['total_q']
+ agent_colla_alone_q = ret['agent_colla_alone_q'].sum(-1).sum(-1)
+
+ if self._cfg.learn.double_q:
+ next_inputs = {'obs': data['next_obs']}
+ logit_detach = self._learn_model.forward(next_inputs, single_step=False)['logit'].clone().detach()
+ next_inputs = {'obs': data['next_obs'], 'action': logit_detach.argmax(dim=-1)}
+ else:
+ next_inputs = {'obs': data['next_obs']}
+ with torch.no_grad():
+ target_total_q = self._target_model.forward(next_inputs, single_step=False)['total_q']
+
+ # td_loss calculation
+ td_data = v_1step_td_data(total_q, target_total_q, data['reward'], data['done'], data['weight'])
+ td_loss, td_error_per_sample = v_1step_td_error(td_data, self._gamma)
+ # collaQ loss calculation
+ colla_loss = (agent_colla_alone_q ** 2).mean()
+ # combine loss with factor
+ loss = colla_loss * self._alpha + td_loss
+ # ====================
+ # CollaQ update
+ # ====================
+ self._optimizer.zero_grad()
+ loss.backward()
+ grad_norm = torch.nn.utils.clip_grad_norm_(self._model.parameters(), self._cfg.learn.clip_value)
+ self._optimizer.step()
+ # =============
+ # after update
+ # =============
+ self._target_model.update(self._learn_model.state_dict())
+ return {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': loss.item(),
+ 'colla_loss': colla_loss.item(),
+ 'td_loss': td_loss.item(),
+ 'grad_norm': grad_norm,
+ 'priority': torch.mean(td_error_per_sample.abs(), dim=0).tolist(),
+ }
+
+ def _reset_learn(self, data_id: Optional[List[int]] = None) -> None:
+ r"""
+ Overview:
+ Reset learn model to the state indicated by data_id
+ Arguments:
+ - data_id (:obj:`Optional[List[int]]`): The ids that index the stored states; the learn model state will \
+ be reset to the states indicated by data_id
+ """
+ self._learn_model.reset(data_id=data_id)
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Return the state_dict of learn mode, usually including model and optimizer.
+ Returns:
+ - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring.
+ """
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'target_model': self._target_model.state_dict(),
+ 'optimizer': self._optimizer.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ r"""
+ Overview:
+ Load the state_dict variable into policy learn mode.
+ Arguments:
+ - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before.
+
+ .. tip::
+ If you want to only load some parts of model, you can simply set the ``strict`` argument in \
+ load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
+ complicated operation.
+ """
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._target_model.load_state_dict(state_dict['target_model'])
+ self._optimizer.load_state_dict(state_dict['optimizer'])
+
+ def _init_collect(self) -> None:
+ r"""
+ Overview:
+ Collect mode init method. Called by ``self.__init__``.
+ Init traj and unroll length, collect model.
+ Enable the eps_greedy_sample and the hidden_state plugin.
+ """
+ self._unroll_len = self._cfg.collect.unroll_len
+ self._collect_model = model_wrap(
+ self._model,
+ wrapper_name='hidden_state',
+ state_num=self._cfg.collect.env_num,
+ save_prev_state=True,
+ init_fn=lambda: [[None for _ in range(self._cfg.model.agent_num)] for _ in range(3)]
+ )
+ self._collect_model = model_wrap(self._collect_model, wrapper_name='eps_greedy_sample')
+ self._collect_model.reset()
+
+ def _forward_collect(self, data: dict, eps: float) -> dict:
+ r"""
+ Overview:
+ Forward function for collect mode with eps_greedy
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ - eps (:obj:`float`): epsilon value for exploration, which is decayed by collected env step.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs.
+ ReturnsKeys
+ - necessary: ``action``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ data = {'obs': data}
+ self._collect_model.eval()
+ with torch.no_grad():
+ output = self._collect_model.forward(data, eps=eps, data_id=data_id)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _reset_collect(self, data_id: Optional[List[int]] = None) -> None:
+ r"""
+ Overview:
+ Reset collect model to the state indicated by data_id
+ Arguments:
+ - data_id (:obj:`Optional[List[int]]`): The ids that index the stored states; the collect model state will \
+ be reset to the states indicated by data_id
+ """
+ self._collect_model.reset(data_id=data_id)
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
+ r"""
+ Overview:
+ Generate dict type transition data from inputs.
+ Arguments:
+ - obs (:obj:`Any`): Env observation
+ - model_output (:obj:`dict`): Output of collect model, including at least \
+ ['action', 'prev_state', 'agent_colla_alone_q']
+ - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done']\
+ (here 'obs' indicates obs after env step).
+ Returns:
+ - transition (:obj:`dict`): Dict type transition data.
+ """
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'prev_state': model_output['prev_state'],
+ 'action': model_output['action'],
+ 'agent_colla_alone_q': model_output['agent_colla_alone_q'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _init_eval(self) -> None:
+ r"""
+ Overview:
+ Evaluate mode init method. Called by ``self.__init__``.
+ Init eval model with argmax strategy and the hidden_state plugin.
+ """
+ self._eval_model = model_wrap(
+ self._model,
+ wrapper_name='hidden_state',
+ state_num=self._cfg.eval.env_num,
+ save_prev_state=True,
+ init_fn=lambda: [[None for _ in range(self._cfg.model.agent_num)] for _ in range(3)]
+ )
+ self._eval_model = model_wrap(self._eval_model, wrapper_name='argmax_sample')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: dict) -> dict:
+ r"""
+ Overview:
+ Forward function for eval mode, similar to ``self._forward_collect``.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
+ ReturnsKeys
+ - necessary: ``action``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ data = {'obs': data}
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data, data_id=data_id)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
+ r"""
+ Overview:
+ Reset eval model to the state indicated by data_id
+ Arguments:
+ - data_id (:obj:`Optional[List[int]]`): The ids that index the stored states; the eval model state will \
+ be reset to the states indicated by data_id
+ """
+ self._eval_model.reset(data_id=data_id)
+
+ def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
+ r"""
+ Overview:
+ Get the train sample from trajectory.
+ Arguments:
+ - data (:obj:`list`): The trajectory's cache
+ Returns:
+ - samples (:obj:`dict`): The training samples generated
+ """
+ return get_train_sample(data, self._unroll_len)
+
+ def _monitor_vars_learn(self) -> List[str]:
+ r"""
+ Overview:
+ Return the names of variables that are to be monitored.
+ Returns:
+ - vars (:obj:`List[str]`): Variables' name list.
+ """
+ return ['cur_lr', 'total_loss', 'colla_loss', 'td_loss', 'grad_norm']
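+
+# --- Editor's note: standalone sketch (not DI-engine's ``timestep_collate``) of the shape transform described in
+# ``_data_preprocess_learn``: a batch (len B) of trajectories, each a dict whose values are length-T lists of
+# tensors, is stacked into {key: Tensor[T, B, *dims]}. The helper name and the handling of nested keys are
+# assumptions for illustration only.
+import torch
+
+
+def _stack_timesteps_sketch(batch):
+    """batch: list(len B) of {key: list(len T) of Tensor(*dims)} -> {key: Tensor[T, B, *dims]}."""
+    keys = batch[0].keys()
+    return {
+        # stack the T per-step tensors of one trajectory, then stack the B trajectories along dim=1
+        k: torch.stack([torch.stack(traj[k], dim=0) for traj in batch], dim=1)
+        for k in keys
+    }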
diff --git a/DI-engine/ding/policy/coma.py b/DI-engine/ding/policy/coma.py
new file mode 100644
index 0000000000000000000000000000000000000000..5940a25f0eded3845b764537520dc1d0b7cf61db
--- /dev/null
+++ b/DI-engine/ding/policy/coma.py
@@ -0,0 +1,379 @@
+from typing import List, Dict, Any, Tuple, Union, Optional
+from collections import namedtuple
+import torch
+import copy
+
+from ding.torch_utils import Adam, to_device
+from ding.rl_utils import coma_data, coma_error, get_train_sample
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_collate, default_decollate, timestep_collate
+from .base_policy import Policy
+
+
+@POLICY_REGISTRY.register('coma')
+class COMAPolicy(Policy):
+ r"""
+ Overview:
+ Policy class of COMA algorithm. COMA is a multi-agent reinforcement learning algorithm.
+ Interface:
+ _init_learn, _data_preprocess_learn, _forward_learn, _reset_learn, _state_dict_learn, _load_state_dict_learn, \
+ _init_collect, _forward_collect, _reset_collect, _process_transition, _init_eval, _forward_eval, \
+ _reset_eval, _get_train_sample, default_model, _monitor_vars_learn
+ Config:
+ == ==================== ======== ============== ======================================== =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============== ======================================== =======================
+ 1 ``type`` str coma | RL policy register name, refer to | this arg is optional,
+ | registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool False | Whether to use cuda for network | this arg can be diff-
+ | erent from modes
+ 3 ``on_policy`` bool True | Whether the RL algorithm is on-policy
+ | or off-policy
+ 4 ``priority`` bool False | Whether use priority(PER) | priority sample,
+ | update priority
+ 5 | ``priority_`` bool False | Whether use Importance Sampling | IS weight
+ | ``IS_weight`` | Weight to correct biased update.
+ 6 | ``learn.update`` int 1 | How many updates(iterations) to train | this arg can vary
+ | ``_per_collect`` | after collector's one collection. Only | across envs. Bigger val
+ | valid in serial training | means more off-policy
+ 7 | ``learn.target_`` float 0.001 | Target network update momentum | between[0,1]
+ | ``update_theta`` | parameter.
+ 8 | ``learn.discount`` float 0.99 | Reward's future discount factor, aka. | may be 1 when sparse
+ | ``_factor`` | gamma | reward env
+ 9 | ``learn.td_`` float 0.8 | The trade-off factor of td-lambda,
+ | ``lambda`` | which balances 1step td and mc
+ 10 | ``learn.value_`` float 1.0 | The loss weight of value network | policy network weight
+ | ``weight`` | is set to 1
+ 11 | ``learn.entropy_`` float 0.01 | The loss weight of entropy | policy network weight
+ | ``weight`` | regularization | is set to 1
+ == ==================== ======== ============== ======================================== =======================
+ """
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='coma',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ # (bool) Whether the RL algorithm is on-policy or off-policy.
+ on_policy=False,
+ # (bool) Whether use priority(priority sample, IS weight, update priority)
+ priority=False,
+ # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ learn=dict(
+ update_per_collect=20,
+ batch_size=32,
+ learning_rate=0.0005,
+ # ==============================================================
+ # The following configs is algorithm-specific
+ # ==============================================================
+ # (float) target network update weight, theta * new_w + (1 - theta) * old_w, defaults in [0, 0.1]
+ target_update_theta=0.001,
+ # (float) discount factor for future reward, defaults in [0, 1]
+ discount_factor=0.99,
+ # (float) the trade-off factor of td-lambda, which balances 1step td and mc(nstep td in practice)
+ td_lambda=0.8,
+ # (float) the loss weight of the policy network
+ policy_weight=0.001,
+ # (float) the loss weight of value network
+ value_weight=1,
+ # (float) the loss weight of entropy regularization
+ entropy_weight=0.01,
+ ),
+ collect=dict(
+ # (int) collect n_episode data, train model n_iteration times
+ # n_episode=32,
+ # (int) unroll length of a train iteration(gradient update step)
+ unroll_len=20,
+ ),
+ eval=dict(),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+ Return this algorithm default model setting for demonstration.
+ Returns:
+ - model_info (:obj:`Tuple[str, List[str]]`): model name and mode import_names
+
+ .. note::
+ The user can define and use a customized network model but must obey the same interface definition indicated \
+ by import_names path. For COMA, ``ding.model.coma.coma``
+ """
+ return 'coma', ['ding.model.template.coma']
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Init the learner model of COMAPolicy
+
+ Arguments:
+ .. note::
+
+ The _init_learn method takes its arguments from ``self._cfg.learn`` in the config file.
+
+ - learning_rate (:obj:`float`): The learning rate of the optimizer
+ - gamma (:obj:`float`): The discount factor
+ - lambda (:obj:`float`): The lambda factor, determining the mix of bootstrapping \
+ vs further accumulation of multi-step returns at each timestep.
+ - value_weight (:obj:`float`): The weight of value loss in total loss
+ - entropy_weight(:obj:`float`): The weight of entropy loss in total loss
+ - agent_num (:obj:`int`): Since this is a multi-agent algorithm, we need to input the agent num.
+ - batch_size (:obj:`int`): Need batch size info to init hidden_state plugins
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ assert not self._priority, "not implemented priority in COMA"
+ self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
+ self._gamma = self._cfg.learn.discount_factor
+ self._lambda = self._cfg.learn.td_lambda
+ self._policy_weight = self._cfg.learn.policy_weight
+ self._value_weight = self._cfg.learn.value_weight
+ self._entropy_weight = self._cfg.learn.entropy_weight
+
+ self._target_model = copy.deepcopy(self._model)
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='momentum',
+ update_kwargs={'theta': self._cfg.learn.target_update_theta}
+ )
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='hidden_state',
+ state_num=self._cfg.learn.batch_size,
+ init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
+ )
+ self._learn_model = model_wrap(
+ self._model,
+ wrapper_name='hidden_state',
+ state_num=self._cfg.learn.batch_size,
+ init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
+ )
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ def _data_preprocess_learn(self, data: List[Any]) -> dict:
+ r"""
+ Overview:
+ Preprocess the data to fit the required data format for learning
+
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]]`): the data collected from collect function, the Dict
+ in data should contain keys including at least ['obs', 'action', 'reward']
+
+ Returns:
+ - data (:obj:`Dict[str, Any]`): the processed data, including at least \
+ ['obs', 'action', 'reward', 'done', 'weight']
+ """
+ # data preprocess
+ data = timestep_collate(data)
+ assert set(data.keys()) > set(['obs', 'action', 'reward'])
+ if self._cuda:
+ data = to_device(data, self._device)
+ data['weight'] = data.get('weight', None)
+ data['done'] = data['done'].float()
+ return data
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Forward and backward function of learn mode, acquire the data and calculate the loss and\
+ optimize learner model
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \
+ np.ndarray or dict/list combinations.
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \
+ recorded in text log and tensorboard, values are python scalar or a list of scalars.
+ ArgumentsKeys:
+ - necessary: ``obs``, ``action``, ``reward``, ``done``, ``weight``
+ ReturnsKeys:
+ - necessary: ``cur_lr``, ``total_loss``, ``policy_loss``, ``value_loss``, ``entropy_loss``
+ - cur_lr (:obj:`float`): Current learning rate
+ - total_loss (:obj:`float`): The calculated loss
+ - policy_loss (:obj:`float`): The policy(actor) loss of coma
+ - value_loss (:obj:`float`): The value(critic) loss of coma
+ - entropy_loss (:obj:`float`): The entropy loss
+ """
+ data = self._data_preprocess_learn(data)
+ # forward
+ self._learn_model.train()
+ self._target_model.train()
+ self._learn_model.reset(state=data['prev_state'][0])
+ self._target_model.reset(state=data['prev_state'][0])
+ q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']
+ with torch.no_grad():
+ target_q_value = self._target_model.forward(data, mode='compute_critic')['q_value']
+ logit = self._learn_model.forward(data, mode='compute_actor')['logit']
+ logit[data['obs']['action_mask'] == 0.0] = -9999999
+
+ data = coma_data(logit, data['action'], q_value, target_q_value, data['reward'], data['weight'])
+ coma_loss = coma_error(data, self._gamma, self._lambda)
+ total_loss = self._policy_weight * coma_loss.policy_loss + self._value_weight * coma_loss.q_value_loss - \
+ self._entropy_weight * coma_loss.entropy_loss
+
+ # update
+ self._optimizer.zero_grad()
+ total_loss.backward()
+ self._optimizer.step()
+ # after update
+ self._target_model.update(self._learn_model.state_dict())
+ return {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': total_loss.item(),
+ 'policy_loss': coma_loss.policy_loss.item(),
+ 'value_loss': coma_loss.q_value_loss.item(),
+ 'entropy_loss': coma_loss.entropy_loss.item(),
+ }
+
+ def _reset_learn(self, data_id: Optional[List[int]] = None) -> None:
+ self._learn_model.reset(data_id=data_id)
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'target_model': self._target_model.state_dict(),
+ 'optimizer': self._optimizer.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._target_model.load_state_dict(state_dict['target_model'])
+ self._optimizer.load_state_dict(state_dict['optimizer'])
+
+ def _init_collect(self) -> None:
+ r"""
+ Overview:
+ Collect mode init method. Called by ``self.__init__``.
+ Init traj and unroll length, collect model.
+ Model has eps_greedy_sample wrapper and hidden state wrapper
+ """
+ self._unroll_len = self._cfg.collect.unroll_len
+ self._collect_model = model_wrap(
+ self._model,
+ wrapper_name='hidden_state',
+ state_num=self._cfg.collect.env_num,
+ save_prev_state=True,
+ init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
+ )
+ self._collect_model = model_wrap(self._collect_model, wrapper_name='eps_greedy_sample')
+ self._collect_model.reset()
+
+ def _forward_collect(self, data: dict, eps: float) -> dict:
+ r"""
+ Overview:
+ Collect output according to eps_greedy plugin
+
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ - eps (:obj:`float`): epsilon value for exploration, which is decayed by collected env step.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs.
+ ReturnsKeys
+ - necessary: ``action``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ data = {'obs': data}
+ self._collect_model.eval()
+ with torch.no_grad():
+ output = self._collect_model.forward(data, eps=eps, data_id=data_id, mode='compute_actor')
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _reset_collect(self, data_id: Optional[List[int]] = None) -> None:
+ self._collect_model.reset(data_id=data_id)
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
+ r"""
+ Overview:
+ Generate dict type transition data from inputs.
+ Arguments:
+ - obs (:obj:`Any`): Env observation
+ - model_output (:obj:`dict`): Output of collect model, including at least ['action', 'prev_state']
+ - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
+ (here 'obs' indicates obs after env step).
+ Returns:
+ - transition (:obj:`dict`): Dict type transition data.
+ """
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'prev_state': model_output['prev_state'],
+ 'action': model_output['action'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _init_eval(self) -> None:
+ r"""
+ Overview:
+ Evaluate mode init method. Called by ``self.__init__``.
+ Init eval model with argmax strategy and hidden_state plugin.
+ """
+ self._eval_model = model_wrap(
+ self._model,
+ wrapper_name='hidden_state',
+ state_num=self._cfg.eval.env_num,
+ save_prev_state=True,
+ init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
+ )
+ self._eval_model = model_wrap(self._eval_model, wrapper_name='argmax_sample')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: dict) -> dict:
+ r"""
+ Overview:
+ Forward function of eval mode, similar to ``self._forward_collect``.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
+ ReturnsKeys
+ - necessary: ``action``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ data = {'obs': data}
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data, data_id=data_id, mode='compute_actor')
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
+ self._eval_model.reset(data_id=data_id)
+
+ def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
+ r"""
+ Overview:
+            Get training samples from the trajectory cache by cutting it into pieces of length ``unroll_len``.
+
+ Arguments:
+ - data (:obj:`list`): The trajectory's cache
+
+ Returns:
+            - samples (:obj:`list`): The generated training samples.
+ """
+ return get_train_sample(data, self._unroll_len)
+
+ def _monitor_vars_learn(self) -> List[str]:
+ r"""
+ Overview:
+            Return the names of the variables that will be monitored during training.
+ Returns:
+ - vars (:obj:`List[str]`): Variables' name list.
+ """
+ return super()._monitor_vars_learn() + ['policy_loss', 'value_loss', 'entropy_loss']
diff --git a/DI-engine/ding/policy/command_mode_policy_instance.py b/DI-engine/ding/policy/command_mode_policy_instance.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e817ead4bcd4c33a918a4b28503441a14618e95
--- /dev/null
+++ b/DI-engine/ding/policy/command_mode_policy_instance.py
@@ -0,0 +1,457 @@
+from ding.utils import POLICY_REGISTRY
+from ding.rl_utils import get_epsilon_greedy_fn
+from .base_policy import CommandModePolicy
+
+from .dqn import DQNPolicy, DQNSTDIMPolicy
+from .mdqn import MDQNPolicy
+from .c51 import C51Policy
+from .qrdqn import QRDQNPolicy
+from .iqn import IQNPolicy
+from .fqf import FQFPolicy
+from .rainbow import RainbowDQNPolicy
+from .r2d2 import R2D2Policy
+from .r2d2_gtrxl import R2D2GTrXLPolicy
+from .r2d2_collect_traj import R2D2CollectTrajPolicy
+from .sqn import SQNPolicy
+from .ppo import PPOPolicy, PPOOffPolicy, PPOPGPolicy, PPOSTDIMPolicy
+from .offppo_collect_traj import OffPPOCollectTrajPolicy
+from .ppg import PPGPolicy, PPGOffPolicy
+from .pg import PGPolicy
+from .a2c import A2CPolicy
+from .impala import IMPALAPolicy
+from .ngu import NGUPolicy
+from .ddpg import DDPGPolicy
+from .td3 import TD3Policy
+from .td3_vae import TD3VAEPolicy
+from .td3_bc import TD3BCPolicy
+from .sac import SACPolicy, DiscreteSACPolicy, SQILSACPolicy
+from .mbpolicy.mbsac import MBSACPolicy, STEVESACPolicy
+from .mbpolicy.dreamer import DREAMERPolicy
+from .qmix import QMIXPolicy
+from .wqmix import WQMIXPolicy
+from .collaq import CollaQPolicy
+from .coma import COMAPolicy
+from .atoc import ATOCPolicy
+from .acer import ACERPolicy
+from .qtran import QTRANPolicy
+from .sql import SQLPolicy
+from .bc import BehaviourCloningPolicy
+from .ibc import IBCPolicy
+
+from .dqfd import DQFDPolicy
+from .r2d3 import R2D3Policy
+
+from .d4pg import D4PGPolicy
+from .cql import CQLPolicy, DiscreteCQLPolicy
+from .dt import DTPolicy
+from .pdqn import PDQNPolicy
+from .madqn import MADQNPolicy
+from .bdq import BDQPolicy
+from .bcq import BCQPolicy
+from .edac import EDACPolicy
+from .prompt_pg import PromptPGPolicy
+from .plan_diffuser import PDPolicy
+from .happo import HAPPOPolicy
+
+
+class EpsCommandModePolicy(CommandModePolicy):
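+    """
+    Overview:
+        Command-mode base class for policies that explore with epsilon greedy. Concrete command-mode \
+        policies combine it with an algorithm policy via multiple inheritance (e.g. ``DQNCommandModePolicy`` \
+        below), so that the command mode can anneal ``eps`` according to the collected environment steps.
+    """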
+
+ def _init_command(self) -> None:
+ r"""
+ Overview:
+ Command mode init method. Called by ``self.__init__``.
+ Set the eps_greedy rule according to the config for command
+ """
+ eps_cfg = self._cfg.other.eps
+ self.epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
+
+ def _get_setting_collect(self, command_info: dict) -> dict:
+ r"""
+ Overview:
+ Collect mode setting information including eps
+ Arguments:
+ - command_info (:obj:`dict`): Dict type, including at least ['learner_train_iter', 'collector_envstep']
+ Returns:
+ - collect_setting (:obj:`dict`): Including eps in collect mode.
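+        Examples:
+            >>> # Illustrative sketch of the underlying schedule only; the concrete numbers below are
+            >>> # placeholders for the ``cfg.policy.other.eps`` fields (start, end, decay, type).
+            >>> from ding.rl_utils import get_epsilon_greedy_fn
+            >>> epsilon_greedy = get_epsilon_greedy_fn(0.95, 0.1, 10000, 'exp')
+            >>> epsilon_greedy(0)        # close to the start value 0.95
+            >>> epsilon_greedy(100000)   # decays towards the end value 0.1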
+ """
+ # Decay according to `learner_train_iter`
+ # step = command_info['learner_train_iter']
+ # Decay according to `envstep`
+ step = command_info['envstep']
+ return {'eps': self.epsilon_greedy(step)}
+
+ def _get_setting_learn(self, command_info: dict) -> dict:
+ return {}
+
+ def _get_setting_eval(self, command_info: dict) -> dict:
+ return {}
+
+
+class DummyCommandModePolicy(CommandModePolicy):
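+    """
+    Overview:
+        Command-mode base class for policies that need no command-level scheduling: every setting getter \
+        simply returns an empty dict.
+    """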
+
+ def _init_command(self) -> None:
+ pass
+
+ def _get_setting_collect(self, command_info: dict) -> dict:
+ return {}
+
+ def _get_setting_learn(self, command_info: dict) -> dict:
+ return {}
+
+ def _get_setting_eval(self, command_info: dict) -> dict:
+ return {}
+
+
+@POLICY_REGISTRY.register('bdq_command')
+class BDQCommandModePolicy(BDQPolicy, EpsCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('mdqn_command')
+class MDQNCommandModePolicy(MDQNPolicy, EpsCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('dqn_command')
+class DQNCommandModePolicy(DQNPolicy, EpsCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('dqn_stdim_command')
+class DQNSTDIMCommandModePolicy(DQNSTDIMPolicy, EpsCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('dqfd_command')
+class DQFDCommandModePolicy(DQFDPolicy, EpsCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('c51_command')
+class C51CommandModePolicy(C51Policy, EpsCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('qrdqn_command')
+class QRDQNCommandModePolicy(QRDQNPolicy, EpsCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('iqn_command')
+class IQNCommandModePolicy(IQNPolicy, EpsCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('fqf_command')
+class FQFCommandModePolicy(FQFPolicy, EpsCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('rainbow_command')
+class RainbowDQNCommandModePolicy(RainbowDQNPolicy, EpsCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('r2d2_command')
+class R2D2CommandModePolicy(R2D2Policy, EpsCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('r2d2_gtrxl_command')
+class R2D2GTrXLCommandModePolicy(R2D2GTrXLPolicy, EpsCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('r2d2_collect_traj_command')
+class R2D2CollectTrajCommandModePolicy(R2D2CollectTrajPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('r2d3_command')
+class R2D3CommandModePolicy(R2D3Policy, EpsCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('sqn_command')
+class SQNCommandModePolicy(SQNPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('sql_command')
+class SQLCommandModePolicy(SQLPolicy, EpsCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('ppo_command')
+class PPOCommandModePolicy(PPOPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('happo_command')
+class HAPPOCommandModePolicy(HAPPOPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('ppo_stdim_command')
+class PPOSTDIMCommandModePolicy(PPOSTDIMPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('ppo_pg_command')
+class PPOPGCommandModePolicy(PPOPGPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('ppo_offpolicy_command')
+class PPOOffCommandModePolicy(PPOOffPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('offppo_collect_traj_command')
+class PPOOffCollectTrajCommandModePolicy(OffPPOCollectTrajPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('pg_command')
+class PGCommandModePolicy(PGPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('a2c_command')
+class A2CCommandModePolicy(A2CPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('impala_command')
+class IMPALACommandModePolicy(IMPALAPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('ppg_offpolicy_command')
+class PPGOffCommandModePolicy(PPGOffPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('ppg_command')
+class PPGCommandModePolicy(PPGPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('madqn_command')
+class MADQNCommandModePolicy(MADQNPolicy, EpsCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('ddpg_command')
+class DDPGCommandModePolicy(DDPGPolicy, CommandModePolicy):
+
+ def _init_command(self) -> None:
+ r"""
+ Overview:
+ Command mode init method. Called by ``self.__init__``.
+ If hybrid action space, set the eps_greedy rule according to the config for command,
+            otherwise, this is just an empty method.
+ """
+ if self._cfg.action_space == 'hybrid':
+ eps_cfg = self._cfg.other.eps
+ self.epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
+
+ def _get_setting_collect(self, command_info: dict) -> dict:
+ r"""
+ Overview:
+ Collect mode setting information including eps when hybrid action space
+ Arguments:
+ - command_info (:obj:`dict`): Dict type, including at least ['learner_step', 'envstep']
+ Returns:
+ - collect_setting (:obj:`dict`): Including eps in collect mode.
+ """
+ if self._cfg.action_space == 'hybrid':
+ # Decay according to `learner_step`
+ # step = command_info['learner_step']
+ # Decay according to `envstep`
+ step = command_info['envstep']
+ return {'eps': self.epsilon_greedy(step)}
+ else:
+ return {}
+
+ def _get_setting_learn(self, command_info: dict) -> dict:
+ return {}
+
+ def _get_setting_eval(self, command_info: dict) -> dict:
+ return {}
+
+
+@POLICY_REGISTRY.register('td3_command')
+class TD3CommandModePolicy(TD3Policy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('td3_vae_command')
+class TD3VAECommandModePolicy(TD3VAEPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('td3_bc_command')
+class TD3BCCommandModePolicy(TD3BCPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('sac_command')
+class SACCommandModePolicy(SACPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('mbsac_command')
+class MBSACCommandModePolicy(MBSACPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('stevesac_command')
+class STEVESACCommandModePolicy(STEVESACPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('dreamer_command')
+class DREAMERCommandModePolicy(DREAMERPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('cql_command')
+class CQLCommandModePolicy(CQLPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('discrete_cql_command')
+class DiscreteCQLCommandModePolicy(DiscreteCQLPolicy, EpsCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('dt_command')
+class DTCommandModePolicy(DTPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('qmix_command')
+class QMIXCommandModePolicy(QMIXPolicy, EpsCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('wqmix_command')
+class WQMIXCommandModePolicy(WQMIXPolicy, EpsCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('collaq_command')
+class CollaQCommandModePolicy(CollaQPolicy, EpsCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('coma_command')
+class COMACommandModePolicy(COMAPolicy, EpsCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('atoc_command')
+class ATOCCommandModePolicy(ATOCPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('acer_command')
+class ACERCommandModePolicy(ACERPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('qtran_command')
+class QTRANCommandModePolicy(QTRANPolicy, EpsCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('ngu_command')
+class NGUCommandModePolicy(NGUPolicy, EpsCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('d4pg_command')
+class D4PGCommandModePolicy(D4PGPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('pdqn_command')
+class PDQNCommandModePolicy(PDQNPolicy, EpsCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('discrete_sac_command')
+class DiscreteSACCommandModePolicy(DiscreteSACPolicy, EpsCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('sqil_sac_command')
+class SQILSACCommandModePolicy(SQILSACPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('ibc_command')
+class IBCCommandModePolicy(IBCPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('bcq_command')
+class BCQCommandModelPolicy(BCQPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('edac_command')
+class EDACCommandModelPolicy(EDACPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('pd_command')
+class PDCommandModelPolicy(PDPolicy, DummyCommandModePolicy):
+ pass
+
+
+@POLICY_REGISTRY.register('bc_command')
+class BCCommandModePolicy(BehaviourCloningPolicy, DummyCommandModePolicy):
+
+ def _init_command(self) -> None:
+ r"""
+ Overview:
+ Command mode init method. Called by ``self.__init__``.
+            Set the noise sigma (continuous) or eps_greedy (discrete) decay rule according to the config.
+ """
+ if self._cfg.continuous:
+ noise_cfg = self._cfg.collect.noise_sigma
+ self.epsilon_greedy = get_epsilon_greedy_fn(noise_cfg.start, noise_cfg.end, noise_cfg.decay, noise_cfg.type)
+ else:
+ eps_cfg = self._cfg.other.eps
+ self.epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
+
+ def _get_setting_collect(self, command_info: dict) -> dict:
+ r"""
+ Overview:
+            Collect mode setting information, including sigma (continuous) or eps (discrete).
+ Arguments:
+ - command_info (:obj:`dict`): Dict type, including at least ['learner_train_iter', 'collector_envstep']
+ Returns:
+ - collect_setting (:obj:`dict`): Including eps in collect mode.
+ """
+ if self._cfg.continuous:
+ # Decay according to `learner_step`
+ step = command_info['learner_step']
+ return {'sigma': self.epsilon_greedy(step)}
+ else:
+ # Decay according to `envstep`
+ step = command_info['envstep']
+ return {'eps': self.epsilon_greedy(step)}
+
+ def _get_setting_learn(self, command_info: dict) -> dict:
+ return {}
+
+ def _get_setting_eval(self, command_info: dict) -> dict:
+ return {}
+
+
+@POLICY_REGISTRY.register('prompt_pg_command')
+class PromptPGCommandModePolicy(PromptPGPolicy, DummyCommandModePolicy):
+ pass
diff --git a/DI-engine/ding/policy/common_utils.py b/DI-engine/ding/policy/common_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..de1d697152d81aaa010fa39a65f83459d4e28423
--- /dev/null
+++ b/DI-engine/ding/policy/common_utils.py
@@ -0,0 +1,129 @@
+from typing import List, Any, Dict, Callable
+import torch
+import numpy as np
+import treetensor.torch as ttorch
+from ding.utils.data import default_collate
+from ding.torch_utils import to_tensor, to_ndarray, unsqueeze, squeeze
+
+
+def default_preprocess_learn(
+ data: List[Any],
+ use_priority_IS_weight: bool = False,
+ use_priority: bool = False,
+ use_nstep: bool = False,
+ ignore_done: bool = False,
+) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ Default data pre-processing in policy's ``_forward_learn`` method, including stacking batch data, preprocess \
+ ignore done, nstep and priority IS weight.
+ Arguments:
+ - data (:obj:`List[Any]`): The list of a training batch samples, each sample is a dict of PyTorch Tensor.
+ - use_priority_IS_weight (:obj:`bool`): Whether to use priority IS weight correction, if True, this function \
+ will set the weight of each sample to the priority IS weight.
+ - use_priority (:obj:`bool`): Whether to use priority, if True, this function will set the priority IS weight.
+ - use_nstep (:obj:`bool`): Whether to use nstep TD error, if True, this function will reshape the reward.
+ - ignore_done (:obj:`bool`): Whether to ignore done, if True, this function will set the done to 0.
+ Returns:
+ - data (:obj:`Dict[str, torch.Tensor]`): The preprocessed dict data whose values can be directly used for \
+ the following model forward and loss computation.
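+    Examples:
+        >>> # Typical call pattern inside a policy's ``_forward_learn`` (mirroring the CQL and D4PG usage
+        >>> # below); ``data`` is the list of transition dicts sampled by the learner.
+        >>> data = default_preprocess_learn(data, use_priority=True, use_nstep=True)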
+ """
+ # data preprocess
+ elem = data[0]
+ if isinstance(elem['action'], (np.ndarray, torch.Tensor)) and elem['action'].dtype in [np.int64, torch.int64]:
+ data = default_collate(data, cat_1dim=True) # for discrete action
+ else:
+ data = default_collate(data, cat_1dim=False) # for continuous action
+ if 'value' in data and data['value'].dim() == 2 and data['value'].shape[1] == 1:
+ data['value'] = data['value'].squeeze(-1)
+ if 'adv' in data and data['adv'].dim() == 2 and data['adv'].shape[1] == 1:
+ data['adv'] = data['adv'].squeeze(-1)
+
+ if ignore_done:
+ data['done'] = torch.zeros_like(data['done']).float()
+ else:
+ data['done'] = data['done'].float()
+
+ if data['done'].dim() == 2 and data['done'].shape[1] == 1:
+ data['done'] = data['done'].squeeze(-1)
+
+ if use_priority_IS_weight:
+ assert use_priority, "Use IS Weight correction, but Priority is not used."
+ if use_priority and use_priority_IS_weight:
+ if 'priority_IS' in data:
+ data['weight'] = data['priority_IS']
+        else:  # for compatibility
+ data['weight'] = data['IS']
+ else:
+ data['weight'] = data.get('weight', None)
+ if use_nstep:
+ # reward reshaping for n-step
+ reward = data['reward']
+ if len(reward.shape) == 1:
+ reward = reward.unsqueeze(1)
+ # reward: (batch_size, nstep) -> (nstep, batch_size)
+ data['reward'] = reward.permute(1, 0).contiguous()
+ else:
+ if data['reward'].dim() == 2 and data['reward'].shape[1] == 1:
+ data['reward'] = data['reward'].squeeze(-1)
+
+ return data
+
+
+def single_env_forward_wrapper(forward_fn: Callable) -> Callable:
+ """
+ Overview:
+ Wrap policy to support gym-style interaction between policy and single environment.
+ Arguments:
+ - forward_fn (:obj:`Callable`): The original forward function of policy.
+ Returns:
+ - wrapped_forward_fn (:obj:`Callable`): The wrapped forward function of policy.
+ Examples:
+ >>> env = gym.make('CartPole-v0')
+ >>> policy = DQNPolicy(...)
+ >>> forward_fn = single_env_forward_wrapper(policy.eval_mode.forward)
+ >>> obs = env.reset()
+ >>> action = forward_fn(obs)
+ >>> next_obs, rew, done, info = env.step(action)
+
+ """
+
+ def _forward(obs):
+ obs = {0: unsqueeze(to_tensor(obs))}
+ action = forward_fn(obs)[0]['action']
+ action = to_ndarray(squeeze(action))
+ return action
+
+ return _forward
+
+
+def single_env_forward_wrapper_ttorch(forward_fn: Callable, cuda: bool = True) -> Callable:
+ """
+ Overview:
+ Wrap policy to support gym-style interaction between policy and single environment for treetensor (ttorch) data.
+ Arguments:
+ - forward_fn (:obj:`Callable`): The original forward function of policy.
+ - cuda (:obj:`bool`): Whether to use cuda in policy, if True, this function will move the input data to cuda.
+ Returns:
+ - wrapped_forward_fn (:obj:`Callable`): The wrapped forward function of policy.
+
+ Examples:
+ >>> env = gym.make('CartPole-v0')
+ >>> policy = PPOFPolicy(...)
+ >>> forward_fn = single_env_forward_wrapper_ttorch(policy.eval)
+ >>> obs = env.reset()
+ >>> action = forward_fn(obs)
+ >>> next_obs, rew, done, info = env.step(action)
+ """
+
+ def _forward(obs):
+ # unsqueeze means add batch dim, i.e. (O, ) -> (1, O)
+ obs = ttorch.as_tensor(obs).unsqueeze(0)
+ if cuda and torch.cuda.is_available():
+ obs = obs.cuda()
+ action = forward_fn(obs).action
+ # squeeze means delete batch dim, i.e. (1, A) -> (A, )
+ action = action.squeeze(0).cpu().numpy()
+ return action
+
+ return _forward
diff --git a/DI-engine/ding/policy/cql.py b/DI-engine/ding/policy/cql.py
new file mode 100644
index 0000000000000000000000000000000000000000..b82ffd65df4af8a879e8f5e7331f8c8e30bf0451
--- /dev/null
+++ b/DI-engine/ding/policy/cql.py
@@ -0,0 +1,677 @@
+from typing import List, Dict, Any, Tuple, Union
+import copy
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.distributions import Normal, Independent
+
+from ding.torch_utils import Adam, to_device
+from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, \
+ qrdqn_nstep_td_data, qrdqn_nstep_td_error, get_nstep_return_data
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_collate, default_decollate
+from .sac import SACPolicy
+from .qrdqn import QRDQNPolicy
+from .common_utils import default_preprocess_learn
+
+
+@POLICY_REGISTRY.register('cql')
+class CQLPolicy(SACPolicy):
+ """
+ Overview:
+ Policy class of CQL algorithm for continuous control. Paper link: https://arxiv.org/abs/2006.04779.
+
+ Config:
+ == ==================== ======== ============= ================================= =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============= ================================= =======================
+ 1 ``type`` str cql | RL policy register name, refer | this arg is optional,
+ | to registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool True | Whether to use cuda for network |
+ 3 | ``random_`` int 10000 | Number of randomly collected | Default to 10000 for
+ | ``collect_size`` | training samples in replay | SAC, 25000 for DDPG/
+ | | buffer when training starts. | TD3.
+ 4 | ``model.policy_`` int 256 | Linear layer size for policy |
+ | ``embedding_size`` | network. |
+ 5 | ``model.soft_q_`` int 256 | Linear layer size for soft q |
+ | ``embedding_size`` | network. |
+       6  | ``model.value_``    int      256            | Linear layer size for value      | Default to None when
+ | ``embedding_size`` | network. | model.value_network
+ | | | is False.
+       7  | ``learn.learning``  float    3e-4           | Learning rate for soft q         | Default to 1e-3, when
+ | ``_rate_q`` | network. | model.value_network
+ | | | is True.
+       8  | ``learn.learning``  float    3e-4           | Learning rate for policy         | Default to 1e-3, when
+ | ``_rate_policy`` | network. | model.value_network
+ | | | is True.
+       9  | ``learn.learning``  float    3e-4           | Learning rate for value          | Default to None when
+ | ``_rate_value`` | network. | model.value_network
+ | | | is False.
+ 10 | ``learn.alpha`` float 0.2 | Entropy regularization | alpha is initiali-
+ | | coefficient. | zation for auto
+ | | | `alpha`, when
+ | | | auto_alpha is True
+ 11 | ``learn.repara_`` bool True | Determine whether to use |
+ | ``meterization`` | reparameterization trick. |
+ 12 | ``learn.`` bool False | Determine whether to use | Temperature parameter
+ | ``auto_alpha`` | auto temperature parameter | determines the
+ | | `alpha`. | relative importance
+ | | | of the entropy term
+ | | | against the reward.
+ 13 | ``learn.-`` bool False | Determine whether to ignore | Use ignore_done only
+ | ``ignore_done`` | done flag. | in halfcheetah env.
+ 14 | ``learn.-`` float 0.005 | Used for soft update of the | aka. Interpolation
+ | ``target_theta`` | target network. | factor in polyak aver
+ | | | aging for target
+ | | | networks.
+ == ==================== ======== ============= ================================= =======================
+ """
+
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='cql',
+ # (bool) Whether to use cuda for policy.
+ cuda=False,
+ # (bool) on_policy: Determine whether on-policy or off-policy.
+ # on-policy setting influences the behaviour of buffer.
+ on_policy=False,
+ # (bool) priority: Determine whether to use priority in buffer sample.
+ priority=False,
+ # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ # (int) Number of training samples(randomly collected) in replay buffer when training starts.
+ random_collect_size=10000,
+ model=dict(
+ # (bool type) twin_critic: Determine whether to use double-soft-q-net for target q computation.
+            # Please refer to TD3 about Clipped Double-Q Learning trick, which learns two Q-functions instead of one.
+ # Default to True.
+ twin_critic=True,
+            # (str type) action_space: Use reparameterization trick for continuous action
+ action_space='reparameterization',
+ # (int) Hidden size for actor network head.
+ actor_head_hidden_size=256,
+ # (int) Hidden size for critic network head.
+ critic_head_hidden_size=256,
+ ),
+ # learn_mode config
+ learn=dict(
+ # (int) How many updates (iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ update_per_collect=1,
+ # (int) Minibatch size for gradient descent.
+ batch_size=256,
+ # (float) learning_rate_q: Learning rate for soft q network.
+ learning_rate_q=3e-4,
+ # (float) learning_rate_policy: Learning rate for policy network.
+ learning_rate_policy=3e-4,
+ # (float) learning_rate_alpha: Learning rate for auto temperature parameter ``alpha``.
+ learning_rate_alpha=3e-4,
+ # (float) target_theta: Used for soft update of the target network,
+ # aka. Interpolation factor in polyak averaging for target networks.
+ target_theta=0.005,
+ # (float) discount factor for the discounted sum of rewards, aka. gamma.
+ discount_factor=0.99,
+ # (float) alpha: Entropy regularization coefficient.
+ # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details.
+ # If auto_alpha is set to `True`, alpha is initialization for auto `\alpha`.
+ # Default to 0.2.
+ alpha=0.2,
+ # (bool) auto_alpha: Determine whether to use auto temperature parameter `\alpha` .
+ # Temperature parameter determines the relative importance of the entropy term against the reward.
+ # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details.
+ # Default to False.
+ # Note that: Using auto alpha needs to set learning_rate_alpha in `cfg.policy.learn`.
+ auto_alpha=True,
+ # (bool) log_space: Determine whether to use auto `\alpha` in log space.
+ log_space=True,
+ # (bool) Whether ignore done(usually for max step termination env. e.g. pendulum)
+ # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
+ # These limit HalfCheetah, and several other MuJoCo envs, to max length of 1000.
+            # However, interaction with HalfCheetah always ends with done=False,
+            # since we replace done==True with done==False to keep the TD-error computation
+            # accurate (``gamma * (1 - done) * next_v + reward``)
+            # when the episode step is greater than the max episode step.
+ ignore_done=False,
+ # (float) Weight uniform initialization range in the last output layer.
+ init_w=3e-3,
+            # (int) The number of actions sampled per state for estimating the CQL regularization term.
+ num_actions=10,
+ # (bool) Whether use lagrange multiplier in q value loss.
+ with_lagrange=False,
+ # (float) The threshold for difference in Q-values.
+ lagrange_thresh=-1,
+ # (float) Loss weight for conservative item.
+ min_q_weight=1.0,
+ # (bool) Whether to use entropy in target q.
+ with_q_entropy=False,
+ ),
+ eval=dict(), # for compatibility
+ )
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+            Initialize the learn mode of policy, including related attributes and modules. For CQL, it mainly \
+ contains three optimizers, algorithm-specific arguments such as gamma, min_q_weight, with_lagrange and \
+ with_q_entropy, main and target model. Especially, the ``auto_alpha`` mechanism for balancing max entropy \
+ target is also initialized here.
+ This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+ .. note::
+ For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+ and ``_load_state_dict_learn`` methods.
+
+ .. note::
+ For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+ .. note::
+            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+ with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ self._twin_critic = self._cfg.model.twin_critic
+ self._num_actions = self._cfg.learn.num_actions
+
+ self._min_q_version = 3
+ self._min_q_weight = self._cfg.learn.min_q_weight
+        self._lagrange_thresh = self._cfg.learn.lagrange_thresh
+        self._with_lagrange = self._cfg.learn.with_lagrange and (self._lagrange_thresh > 0)
+ if self._with_lagrange:
+ self.target_action_gap = self._lagrange_thresh
+ self.log_alpha_prime = torch.tensor(0.).to(self._device).requires_grad_()
+ self.alpha_prime_optimizer = Adam(
+ [self.log_alpha_prime],
+ lr=self._cfg.learn.learning_rate_q,
+ )
+
+ self._with_q_entropy = self._cfg.learn.with_q_entropy
+
+ # Weight Init
+ init_w = self._cfg.learn.init_w
+ self._model.actor_head[-1].mu.weight.data.uniform_(-init_w, init_w)
+ self._model.actor_head[-1].mu.bias.data.uniform_(-init_w, init_w)
+ self._model.actor_head[-1].log_sigma_layer.weight.data.uniform_(-init_w, init_w)
+ self._model.actor_head[-1].log_sigma_layer.bias.data.uniform_(-init_w, init_w)
+ if self._twin_critic:
+ self._model.critic_head[0][-1].last.weight.data.uniform_(-init_w, init_w)
+ self._model.critic_head[0][-1].last.bias.data.uniform_(-init_w, init_w)
+ self._model.critic_head[1][-1].last.weight.data.uniform_(-init_w, init_w)
+ self._model.critic_head[1][-1].last.bias.data.uniform_(-init_w, init_w)
+ else:
+ self._model.critic_head[2].last.weight.data.uniform_(-init_w, init_w)
+ self._model.critic_head[-1].last.bias.data.uniform_(-init_w, init_w)
+
+ # Optimizers
+ self._optimizer_q = Adam(
+ self._model.critic.parameters(),
+ lr=self._cfg.learn.learning_rate_q,
+ )
+ self._optimizer_policy = Adam(
+ self._model.actor.parameters(),
+ lr=self._cfg.learn.learning_rate_policy,
+ )
+
+ # Algorithm config
+ self._gamma = self._cfg.learn.discount_factor
+ # Init auto alpha
+ if self._cfg.learn.auto_alpha:
+ if self._cfg.learn.target_entropy is None:
+ assert 'action_shape' in self._cfg.model, "CQL need network model with action_shape variable"
+ self._target_entropy = -np.prod(self._cfg.model.action_shape)
+ else:
+ self._target_entropy = self._cfg.learn.target_entropy
+ if self._cfg.learn.log_space:
+ self._log_alpha = torch.log(torch.FloatTensor([self._cfg.learn.alpha]))
+ self._log_alpha = self._log_alpha.to(self._device).requires_grad_()
+ self._alpha_optim = torch.optim.Adam([self._log_alpha], lr=self._cfg.learn.learning_rate_alpha)
+ assert self._log_alpha.shape == torch.Size([1]) and self._log_alpha.requires_grad
+ self._alpha = self._log_alpha.detach().exp()
+ self._auto_alpha = True
+ self._log_space = True
+ else:
+ self._alpha = torch.FloatTensor([self._cfg.learn.alpha]).to(self._device).requires_grad_()
+ self._alpha_optim = torch.optim.Adam([self._alpha], lr=self._cfg.learn.learning_rate_alpha)
+ self._auto_alpha = True
+ self._log_space = False
+ else:
+ self._alpha = torch.tensor(
+ [self._cfg.learn.alpha], requires_grad=False, device=self._device, dtype=torch.float32
+ )
+ self._auto_alpha = False
+
+ # Main and target models
+ self._target_model = copy.deepcopy(self._model)
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='momentum',
+ update_kwargs={'theta': self._cfg.learn.target_theta}
+ )
+ self._learn_model = model_wrap(self._model, wrapper_name='base')
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ self._forward_learn_cnt = 0
+
+ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
+ """
+ Overview:
+ Policy forward function of learn mode (training policy and updating parameters). Forward means \
+ that the policy inputs some training batch data from the offline dataset and then returns the output \
+ result, including various training information such as loss, action, priority.
+ Arguments:
+ - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
+ training samples. For each element in list, the key of the dict is the name of data items and the \
+                value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
+                combinations. In the ``_forward_learn`` method, data often needs to first be stacked in the batch \
+ dimension by some utility functions such as ``default_preprocess_learn``. \
+ For CQL, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \
+ ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight``.
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \
+ recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
+ detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
+            For unsupported data types, the main reason is that the corresponding model does not support them. \
+            You can implement your own model rather than use the default model. For more information, please raise an \
+            issue in the GitHub repo and we will continue to follow up.
+ """
+ loss_dict = {}
+ data = default_preprocess_learn(
+ data,
+ use_priority=self._priority,
+ use_priority_IS_weight=self._cfg.priority_IS_weight,
+ ignore_done=self._cfg.learn.ignore_done,
+ use_nstep=False
+ )
+ if len(data.get('action').shape) == 1:
+ data['action'] = data['action'].reshape(-1, 1)
+
+ if self._cuda:
+ data = to_device(data, self._device)
+
+ self._learn_model.train()
+ self._target_model.train()
+ obs = data['obs']
+ next_obs = data['next_obs']
+ reward = data['reward']
+ done = data['done']
+
+ # 1. predict q value
+ q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']
+
+ # 2. predict target value
+ with torch.no_grad():
+ (mu, sigma) = self._learn_model.forward(next_obs, mode='compute_actor')['logit']
+
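+            # Squash the Gaussian sample with tanh and correct its log-probability with the
+            # change-of-variables term log(1 - tanh(x)^2); the 1e-6 keeps the log numerically stable.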
+ dist = Independent(Normal(mu, sigma), 1)
+ pred = dist.rsample()
+ next_action = torch.tanh(pred)
+ y = 1 - next_action.pow(2) + 1e-6
+ next_log_prob = dist.log_prob(pred).unsqueeze(-1)
+ next_log_prob = next_log_prob - torch.log(y).sum(-1, keepdim=True)
+
+ next_data = {'obs': next_obs, 'action': next_action}
+ target_q_value = self._target_model.forward(next_data, mode='compute_critic')['q_value']
+ # the value of a policy according to the maximum entropy objective
+ if self._twin_critic:
+ # find min one as target q value
+ if self._with_q_entropy:
+ target_q_value = torch.min(target_q_value[0],
+ target_q_value[1]) - self._alpha * next_log_prob.squeeze(-1)
+ else:
+ target_q_value = torch.min(target_q_value[0], target_q_value[1])
+ else:
+ if self._with_q_entropy:
+ target_q_value = target_q_value - self._alpha * next_log_prob.squeeze(-1)
+
+ # 3. compute q loss
+ if self._twin_critic:
+ q_data0 = v_1step_td_data(q_value[0], target_q_value, reward, done, data['weight'])
+ loss_dict['critic_loss'], td_error_per_sample0 = v_1step_td_error(q_data0, self._gamma)
+ q_data1 = v_1step_td_data(q_value[1], target_q_value, reward, done, data['weight'])
+ loss_dict['twin_critic_loss'], td_error_per_sample1 = v_1step_td_error(q_data1, self._gamma)
+ td_error_per_sample = (td_error_per_sample0 + td_error_per_sample1) / 2
+ else:
+ q_data = v_1step_td_data(q_value, target_q_value, reward, done, data['weight'])
+ loss_dict['critic_loss'], td_error_per_sample = v_1step_td_error(q_data, self._gamma)
+
+ # 4. add CQL
+
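+        # CQL(H) conservative term: push down a soft maximum (logsumexp) of Q over sampled actions
+        # (uniform-random, current-policy and next-observation-policy samples) while pushing up the
+        # Q-values of the actions actually stored in the dataset.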
+ curr_actions_tensor, curr_log_pis = self._get_policy_actions(data, self._num_actions)
+ new_curr_actions_tensor, new_log_pis = self._get_policy_actions({'obs': next_obs}, self._num_actions)
+
+ random_actions_tensor = torch.FloatTensor(curr_actions_tensor.shape).uniform_(-1,
+ 1).to(curr_actions_tensor.device)
+
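+        # Tile obs and dataset actions along an extra action-sample axis: (B, D) -> (B * num_actions, D),
+        # so that every sampled action has a matching observation row.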
+ obs_repeat = obs.unsqueeze(1).repeat(1, self._num_actions,
+ 1).view(obs.shape[0] * self._num_actions, obs.shape[1])
+ act_repeat = data['action'].unsqueeze(1).repeat(1, self._num_actions, 1).view(
+ data['action'].shape[0] * self._num_actions, data['action'].shape[1]
+ )
+
+ q_rand = self._get_q_value({'obs': obs_repeat, 'action': random_actions_tensor})
+ # q2_rand = self._get_q_value(obs, random_actions_tensor, network=self.qf2)
+ q_curr_actions = self._get_q_value({'obs': obs_repeat, 'action': curr_actions_tensor})
+ # q2_curr_actions = self._get_tensor_values(obs, curr_actions_tensor, network=self.qf2)
+ q_next_actions = self._get_q_value({'obs': obs_repeat, 'action': new_curr_actions_tensor})
+ # q2_next_actions = self._get_tensor_values(obs, new_curr_actions_tensor, network=self.qf2)
+
+ cat_q1 = torch.cat([q_rand[0], q_value[0].reshape(-1, 1, 1), q_next_actions[0], q_curr_actions[0]], 1)
+ cat_q2 = torch.cat([q_rand[1], q_value[1].reshape(-1, 1, 1), q_next_actions[1], q_curr_actions[1]], 1)
+ std_q1 = torch.std(cat_q1, dim=1)
+ std_q2 = torch.std(cat_q2, dim=1)
+ if self._min_q_version == 3:
+ # importance sampled version
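+            # log-density of the Uniform(-1, 1) proposal for a d-dimensional action: log((1/2) ** d)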
+ random_density = np.log(0.5 ** curr_actions_tensor.shape[-1])
+ cat_q1 = torch.cat(
+ [
+ q_rand[0] - random_density, q_next_actions[0] - new_log_pis.detach(),
+ q_curr_actions[0] - curr_log_pis.detach()
+ ], 1
+ )
+ cat_q2 = torch.cat(
+ [
+ q_rand[1] - random_density, q_next_actions[1] - new_log_pis.detach(),
+ q_curr_actions[1] - curr_log_pis.detach()
+ ], 1
+ )
+
+ min_qf1_loss = torch.logsumexp(cat_q1, dim=1).mean() * self._min_q_weight
+ min_qf2_loss = torch.logsumexp(cat_q2, dim=1).mean() * self._min_q_weight
+ """Subtract the log likelihood of data"""
+ min_qf1_loss = min_qf1_loss - q_value[0].mean() * self._min_q_weight
+ min_qf2_loss = min_qf2_loss - q_value[1].mean() * self._min_q_weight
+
+ if self._with_lagrange:
+ alpha_prime = torch.clamp(self.log_alpha_prime.exp(), min=0.0, max=1000000.0)
+ min_qf1_loss = alpha_prime * (min_qf1_loss - self.target_action_gap)
+ min_qf2_loss = alpha_prime * (min_qf2_loss - self.target_action_gap)
+
+ self.alpha_prime_optimizer.zero_grad()
+ alpha_prime_loss = (-min_qf1_loss - min_qf2_loss) * 0.5
+ alpha_prime_loss.backward(retain_graph=True)
+ self.alpha_prime_optimizer.step()
+
+ loss_dict['critic_loss'] += min_qf1_loss
+ if self._twin_critic:
+ loss_dict['twin_critic_loss'] += min_qf2_loss
+
+ # 5. update q network
+ self._optimizer_q.zero_grad()
+ loss_dict['critic_loss'].backward(retain_graph=True)
+ if self._twin_critic:
+ loss_dict['twin_critic_loss'].backward()
+ self._optimizer_q.step()
+
+ # 6. evaluate to get action distribution
+ (mu, sigma) = self._learn_model.forward(data['obs'], mode='compute_actor')['logit']
+ dist = Independent(Normal(mu, sigma), 1)
+ pred = dist.rsample()
+ action = torch.tanh(pred)
+ y = 1 - action.pow(2) + 1e-6
+ log_prob = dist.log_prob(pred).unsqueeze(-1)
+ log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
+
+ eval_data = {'obs': obs, 'action': action}
+ new_q_value = self._learn_model.forward(eval_data, mode='compute_critic')['q_value']
+ if self._twin_critic:
+ new_q_value = torch.min(new_q_value[0], new_q_value[1])
+
+        # 7. compute policy loss
+ policy_loss = (self._alpha * log_prob - new_q_value.unsqueeze(-1)).mean()
+
+ loss_dict['policy_loss'] = policy_loss
+
+        # 8. update policy network
+ self._optimizer_policy.zero_grad()
+ loss_dict['policy_loss'].backward()
+ self._optimizer_policy.step()
+
+        # 9. compute alpha loss
+ if self._auto_alpha:
+ if self._log_space:
+ log_prob = log_prob + self._target_entropy
+ loss_dict['alpha_loss'] = -(self._log_alpha * log_prob.detach()).mean()
+
+ self._alpha_optim.zero_grad()
+ loss_dict['alpha_loss'].backward()
+ self._alpha_optim.step()
+ self._alpha = self._log_alpha.detach().exp()
+ else:
+ log_prob = log_prob + self._target_entropy
+ loss_dict['alpha_loss'] = -(self._alpha * log_prob.detach()).mean()
+
+ self._alpha_optim.zero_grad()
+ loss_dict['alpha_loss'].backward()
+ self._alpha_optim.step()
+ self._alpha = max(0, self._alpha)
+
+ loss_dict['total_loss'] = sum(loss_dict.values())
+
+ # =============
+ # after update
+ # =============
+ self._forward_learn_cnt += 1
+ # target update
+ self._target_model.update(self._learn_model.state_dict())
+ return {
+ 'cur_lr_q': self._optimizer_q.defaults['lr'],
+ 'cur_lr_p': self._optimizer_policy.defaults['lr'],
+ 'priority': td_error_per_sample.abs().tolist(),
+ 'td_error': td_error_per_sample.detach().mean().item(),
+ 'alpha': self._alpha.item(),
+ 'target_q_value': target_q_value.detach().mean().item(),
+ **loss_dict
+ }
+
+ def _get_policy_actions(self, data: Dict, num_actions: int = 10, epsilon: float = 1e-6) -> List:
+ # evaluate to get action distribution
+ obs = data['obs']
+ obs = obs.unsqueeze(1).repeat(1, num_actions, 1).view(obs.shape[0] * num_actions, obs.shape[1])
+ (mu, sigma) = self._learn_model.forward(obs, mode='compute_actor')['logit']
+ dist = Independent(Normal(mu, sigma), 1)
+ pred = dist.rsample()
+ action = torch.tanh(pred)
+
+ # evaluate action log prob depending on Jacobi determinant.
+ y = 1 - action.pow(2) + epsilon
+ log_prob = dist.log_prob(pred).unsqueeze(-1)
+ log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
+
+ return action, log_prob.view(-1, num_actions, 1)
+
+ def _get_q_value(self, data: Dict, keep: bool = True) -> torch.Tensor:
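+        # Evaluate Q(s, a) for the repeated (obs, action) pairs and reshape to (B, num_actions, 1);
+        # with twin critics and ``keep=True`` a list with one such tensor per critic is returned.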
+ new_q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']
+ if self._twin_critic:
+ new_q_value = [value.view(-1, self._num_actions, 1) for value in new_q_value]
+ else:
+ new_q_value = new_q_value.view(-1, self._num_actions, 1)
+ if self._twin_critic and not keep:
+ new_q_value = torch.min(new_q_value[0], new_q_value[1])
+ return new_q_value
+
+
+@POLICY_REGISTRY.register('discrete_cql')
+class DiscreteCQLPolicy(QRDQNPolicy):
+ """
+ Overview:
+ Policy class of discrete CQL algorithm in discrete action space environments.
+ Paper link: https://arxiv.org/abs/2006.04779.
+ """
+
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='discrete_cql',
+ # (bool) Whether to use cuda for policy.
+ cuda=False,
+ # (bool) Whether the RL algorithm is on-policy or off-policy.
+ on_policy=False,
+ # (bool) Whether use priority(priority sample, IS weight, update priority)
+ priority=False,
+ # (float) Reward's future discount factor, aka. gamma.
+ discount_factor=0.97,
+ # (int) N-step reward for target q_value estimation
+ nstep=1,
+ # learn_mode config
+ learn=dict(
+ # (int) How many updates (iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ update_per_collect=1,
+ # (int) Minibatch size for one gradient descent.
+ batch_size=64,
+ # (float) Learning rate for soft q network.
+ learning_rate=0.001,
+            # (int) Frequency of target network update.
+ target_update_freq=100,
+ # (bool) Whether ignore done(usually for max step termination env).
+ ignore_done=False,
+ # (float) Loss weight for conservative item.
+ min_q_weight=1.0,
+ ),
+ eval=dict(), # for compatibility
+ )
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Initialize the learn mode of policy, including related attributes and modules. For DiscreteCQL, it mainly \
+ contains the optimizer, algorithm-specific arguments such as gamma, nstep and min_q_weight, main and \
+ target model. This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+ .. note::
+ For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+ and ``_load_state_dict_learn`` methods.
+
+ .. note::
+ For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+ .. note::
+            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+ with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
+ """
+ self._min_q_weight = self._cfg.learn.min_q_weight
+ self._priority = self._cfg.priority
+ # Optimizer
+ self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
+
+ self._gamma = self._cfg.discount_factor
+ self._nstep = self._cfg.nstep
+
+ # use wrapper instead of plugin
+ self._target_model = copy.deepcopy(self._model)
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='assign',
+ update_kwargs={'freq': self._cfg.learn.target_update_freq}
+ )
+ self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
+ """
+ Overview:
+ Policy forward function of learn mode (training policy and updating parameters). Forward means \
+ that the policy inputs some training batch data from the offline dataset and then returns the output \
+ result, including various training information such as loss, action, priority.
+ Arguments:
+ - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
+ training samples. For each element in list, the key of the dict is the name of data items and the \
+                value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
+                combinations. In the ``_forward_learn`` method, data often needs to first be stacked in the batch \
+ dimension by some utility functions such as ``default_preprocess_learn``. \
+ For DiscreteCQL, each element in list is a dict containing at least the following keys: ``obs``, \
+ ``action``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys like ``weight`` \
+ and ``value_gamma`` for nstep return computation.
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \
+ recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
+ detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
+            For unsupported data types, the main reason is that the corresponding model does not support them. \
+            You can implement your own model rather than use the default model. For more information, please raise an \
+            issue in the GitHub repo and we will continue to follow up.
+ """
+ data = default_preprocess_learn(
+ data, use_priority=self._priority, ignore_done=self._cfg.learn.ignore_done, use_nstep=True
+ )
+ if self._cuda:
+ data = to_device(data, self._device)
+ if data['action'].dim() == 2 and data['action'].shape[-1] == 1:
+ data['action'] = data['action'].squeeze(-1)
+ # ====================
+ # Q-learning forward
+ # ====================
+ self._learn_model.train()
+ self._target_model.train()
+ # Current q value (main model)
+ ret = self._learn_model.forward(data['obs'])
+ q_value, tau = ret['q'], ret['tau']
+ # Target q value
+ with torch.no_grad():
+ target_q_value = self._target_model.forward(data['next_obs'])['q']
+ # Max q value action (main model)
+ target_q_action = self._learn_model.forward(data['next_obs'])['action']
+
+        # add CQL
+        # 1. compute the Q-values of the actions taken in the dataset.
+        # 2. compute the conservative loss (negative_sampling - dataset_expec).
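+        # The logsumexp over the (mean-over-quantiles) Q-values acts as a soft maximum over actions,
+        # so ``negative_sampling - dataset_expec`` penalizes Q-values of out-of-dataset actions.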
+ replay_action_one_hot = F.one_hot(data['action'], self._cfg.model.action_shape)
+ replay_chosen_q = (q_value.mean(-1) * replay_action_one_hot).sum(dim=1)
+
+ dataset_expec = replay_chosen_q.mean()
+
+ negative_sampling = torch.logsumexp(q_value.mean(-1), dim=1).mean()
+
+ min_q_loss = negative_sampling - dataset_expec
+
+ data_n = qrdqn_nstep_td_data(
+ q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], tau, data['weight']
+ )
+ value_gamma = data.get('value_gamma')
+ loss, td_error_per_sample = qrdqn_nstep_td_error(
+ data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma
+ )
+
+ loss += self._min_q_weight * min_q_loss
+
+ # ====================
+ # Q-learning update
+ # ====================
+ self._optimizer.zero_grad()
+ loss.backward()
+ if self._cfg.multi_gpu:
+ self.sync_gradients(self._learn_model)
+ self._optimizer.step()
+
+ # =============
+ # after update
+ # =============
+ self._target_model.update(self._learn_model.state_dict())
+ return {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': loss.item(),
+ 'priority': td_error_per_sample.abs().tolist(),
+ 'q_target': target_q_value.mean().item(),
+ 'q_value': q_value.mean().item(),
+ # Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard.
+ # '[histogram]action_distribution': data['action'],
+ }
+
+ def _monitor_vars_learn(self) -> List[str]:
+ """
+ Overview:
+ Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
+ as text logger, tensorboard logger, will use these keys to save the corresponding data.
+ Returns:
+ - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
+ """
+ return ['cur_lr', 'total_loss', 'q_target', 'q_value']
diff --git a/DI-engine/ding/policy/d4pg.py b/DI-engine/ding/policy/d4pg.py
new file mode 100644
index 0000000000000000000000000000000000000000..c7f4f4ebd70f98dab3c472083a924c432b58b197
--- /dev/null
+++ b/DI-engine/ding/policy/d4pg.py
@@ -0,0 +1,376 @@
+from typing import List, Dict, Any, Tuple, Union
+import torch
+import copy
+
+from ding.torch_utils import Adam, to_device
+from ding.rl_utils import get_train_sample
+from ding.rl_utils import dist_nstep_td_data, dist_nstep_td_error, get_nstep_return_data
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from .ddpg import DDPGPolicy
+from .common_utils import default_preprocess_learn
+import numpy as np
+
+
+@POLICY_REGISTRY.register('d4pg')
+class D4PGPolicy(DDPGPolicy):
+ """
+ Overview:
+        Policy class of D4PG algorithm. D4PG is a variant of DDPG that uses a distributional critic. \
+        The distributional critic predicts a categorical value distribution over a fixed support of \
+        ``n_atom`` atoms in ``[v_min, v_max]``, as in C51. \
+ Paper link: https://arxiv.org/abs/1804.08617.
+
+ Property:
+ learn_mode, collect_mode, eval_mode
+ Config:
+ == ==================== ======== ============= ================================= =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============= ================================= =======================
+ 1 ``type`` str d4pg | RL policy register name, refer | this arg is optional,
+ | to registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool True | Whether to use cuda for network |
+ 3 | ``random_`` int 25000 | Number of randomly collected | Default to 25000 for
+ | ``collect_size`` | training samples in replay | DDPG/TD3, 10000 for
+ | | buffer when training starts. | sac.
+ 5 | ``learn.learning`` float 1e-3 | Learning rate for actor |
+ | ``_rate_actor`` | network(aka. policy). |
+ 6 | ``learn.learning`` float 1e-3 | Learning rates for critic |
+ | ``_rate_critic`` | network (aka. Q-network). |
+ 7 | ``learn.actor_`` int 1 | When critic network updates | Default 1
+ | ``update_freq`` | once, how many times will actor |
+ | | network update. |
+ 8 | ``learn.noise`` bool False | Whether to add noise on target | Default False for
+ | | network's action. | D4PG.
+ | | | Target Policy Smoo-
+ | | | thing Regularization
+ | | | in TD3 paper.
+ 9 | ``learn.-`` bool False | Determine whether to ignore | Use ignore_done only
+ | ``ignore_done`` | done flag. | in halfcheetah env.
+ 10 | ``learn.-`` float 0.005 | Used for soft update of the | aka. Interpolation
+ | ``target_theta`` | target network. | factor in polyak aver
+ | | | aging for target
+ | | | networks.
+ 11 | ``collect.-`` float 0.1 | Used for add noise during co- | Sample noise from dis
+ | ``noise_sigma`` | llection, through controlling | tribution, Gaussian
+ | | the sigma of distribution | process.
+ 12 | ``model.v_min`` float -10 | Value of the smallest atom |
+ | | in the support set. |
+ 13 | ``model.v_max`` float 10 | Value of the largest atom |
+ | | in the support set. |
+ 14 | ``model.n_atom`` int 51 | Number of atoms in the support |
+ | | set of the value distribution. |
+ 15 | ``nstep`` int 3, [1, 5] | N-step reward discount sum for |
+ | | target q_value estimation |
+ 16 | ``priority`` bool True | Whether use priority(PER) | priority sample,
+ | update priority
+ == ==================== ======== ============= ================================= =======================
+ """
+
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='d4pg',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ # (bool type) on_policy: Determine whether on-policy or off-policy.
+ # on-policy setting influences the behaviour of buffer.
+ # Default False in D4PG.
+ on_policy=False,
+ # (bool) Whether use priority(priority sample, IS weight, update priority)
+ # Default True in D4PG.
+ priority=True,
+ # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=True,
+ # (int) Number of training samples(randomly collected) in replay buffer when training starts.
+ # Default 25000 in D4PG.
+ random_collect_size=25000,
+ # (int) N-step reward for target q_value estimation
+ nstep=3,
+ # (str) Action space type
+ action_space='continuous', # ['continuous', 'hybrid']
+ # (bool) Whether use batch normalization for reward
+ reward_batch_norm=False,
+ # (bool) Whether to need policy data in process transition
+ transition_with_policy_data=False,
+ model=dict(
+ # (float) Value of the smallest atom in the support set.
+ # Default to -10.0.
+ v_min=-10,
+            # (float) Value of the largest atom in the support set.
+ # Default to 10.0.
+ v_max=10,
+ # (int) Number of atoms in the support set of the
+ # value distribution. Default to 51.
+ n_atom=51
+ ),
+ learn=dict(
+
+ # How many updates(iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=1,
+ # (int) Minibatch size for gradient descent.
+ batch_size=256,
+ # Learning rates for actor network(aka. policy).
+ learning_rate_actor=1e-3,
+ # Learning rates for critic network(aka. Q-network).
+ learning_rate_critic=1e-3,
+ # (bool) Whether ignore done(usually for max step termination env. e.g. pendulum)
+ # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
+ # These limit HalfCheetah, and several other MuJoCo envs, to max length of 1000.
+            # However, interaction with HalfCheetah always ends with done=False,
+            # since we replace done==True with done==False to keep the TD-error computation
+            # accurate (``gamma * (1 - done) * next_v + reward``)
+            # when the episode step is greater than the max episode step.
+ ignore_done=False,
+ # (float type) target_theta: Used for soft update of the target network,
+ # aka. Interpolation factor in polyak averaging for target networks.
+ # Default to 0.005.
+ target_theta=0.005,
+ # (float) discount factor for the discounted sum of rewards, aka. gamma.
+ discount_factor=0.99,
+ # (int) When critic network updates once, how many times will actor network update.
+ actor_update_freq=1,
+ # (bool) Whether to add noise on target network's action.
+ # Target Policy Smoothing Regularization in original TD3 paper.
+ noise=False,
+ ),
+ collect=dict(
+ # (int) Only one of [n_sample, n_episode] should be set
+ # n_sample=1,
+ # (int) Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ # It is a must to add noise during collection. So here omits "noise" and only set "noise_sigma".
+ noise_sigma=0.1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, ), ),
+ other=dict(
+ replay_buffer=dict(
+ # (int) Maximum size of replay buffer.
+ replay_buffer_size=1000000,
+ ),
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+ Return the default neural network model class for D4PGPolicy. ``__init__`` method will \
+ automatically call this method to get the default model setting and create model.
+
+ Returns:
+ - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
+ """
+ return 'qac_dist', ['ding.model.template.qac_dist']
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Initialize the D4PG policy's learning mode, which involves setting up key components \
+ specific to the D4PG algorithm. This includes creating separate optimizers for the actor \
+ and critic networks, a distinctive trait of D4PG's actor-critic approach, and configuring \
+ algorithm-specific parameters such as v_min, v_max, and n_atom for the distributional aspect \
+ of the critic. Additionally, the method sets up the target model with momentum-based updates, \
+ crucial for stabilizing learning, and optionally integrates noise into the target model for \
+ effective exploration. This method is invoked during the '__init__' if 'learn' is specified \
+ in 'enable_field'.
+
+ .. note::
+ For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+ and ``_load_state_dict_learn`` methods.
+
+ .. note::
+ For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+ .. note::
+            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+ with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ # actor and critic optimizer
+ self._optimizer_actor = Adam(
+ self._model.actor.parameters(),
+ lr=self._cfg.learn.learning_rate_actor,
+ )
+ self._optimizer_critic = Adam(
+ self._model.critic.parameters(),
+ lr=self._cfg.learn.learning_rate_critic,
+ )
+ self._reward_batch_norm = self._cfg.reward_batch_norm
+
+ self._gamma = self._cfg.learn.discount_factor
+ self._nstep = self._cfg.nstep
+ self._actor_update_freq = self._cfg.learn.actor_update_freq
+
+ # main and target models
+ self._target_model = copy.deepcopy(self._model)
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='momentum',
+ update_kwargs={'theta': self._cfg.learn.target_theta}
+ )
+ if self._cfg.learn.noise:
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='action_noise',
+ noise_type='gauss',
+ noise_kwargs={
+ 'mu': 0.0,
+ 'sigma': self._cfg.learn.noise_sigma
+ },
+ noise_range=self._cfg.learn.noise_range
+ )
+ self._learn_model = model_wrap(self._model, wrapper_name='base')
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ self._v_max = self._cfg.model.v_max
+ self._v_min = self._cfg.model.v_min
+ self._n_atom = self._cfg.model.n_atom
+
+ self._forward_learn_cnt = 0 # count iterations
+
+ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
+ """
+ Overview:
+ Policy forward function of learn mode (training policy and updating parameters). Forward means \
+ that the policy inputs some training batch data from the replay buffer and then returns the output \
+ result, including various training information such as different loss, actor and critic lr.
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]]`): Input data used for policy forward, including the \
+ collected training samples from the replay buffer. For each element in the list, the key of the \
+ dict is the name of a data item and the value is the corresponding data. Usually, the value is \
+ torch.Tensor or np.ndarray or their dict/list combinations. In the ``_forward_learn`` method, data \
+ often needs to first be stacked in the batch dimension by some utility functions such as \
+ ``default_preprocess_learn``. \
+ For D4PG, each element in the list is a dict containing at least the following keys: ``obs``, \
+ ``action``, ``reward``, ``next_obs``. Sometimes, it also contains other keys such as ``weight``.
+
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): The output result dict of forward learn, containing at \
+ least the "cur_lr_actor", "cur_lr_critic", "q_value", "action" and "priority" keys, together with \
+ the different losses. Additionally, loss_dict contains other keys that are mainly used for monitoring \
+ and debugging, and "q_value_dict" records the q_value statistics.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and the current policy supports all of them. \
+ For a data type that is not supported, the main reason is usually that the corresponding model does not \
+ support it. You can implement your own model rather than use the default model. For more information, \
+ please raise an issue in the GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for D4PGPolicy: ``ding.policy.tests.test_d4pg``.
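+
+ Examples:
+ A minimal sketch of the expected input batch; the observation/action dimensions, the nstep value and \
+ the ``learn_mode`` entry point are illustrative assumptions, not requirements of this policy.
+
+ >>> import torch
+ >>> transition = dict(obs=torch.randn(17), action=torch.randn(6), next_obs=torch.randn(17),
+ ... reward=torch.randn(5), done=torch.tensor(False)) # reward holds the stacked n-step rewards
+ >>> data = [transition for _ in range(256)] # one dict per transition sampled from the replay buffer
+ >>> # info = policy.learn_mode.forward(data) # usual entry point in DI-engine training pipelines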
+ """
+ loss_dict = {}
+ data = default_preprocess_learn(
+ data,
+ use_priority=self._cfg.priority,
+ use_priority_IS_weight=self._cfg.priority_IS_weight,
+ ignore_done=self._cfg.learn.ignore_done,
+ use_nstep=True
+ )
+ if self._cuda:
+ data = to_device(data, self._device)
+ # ====================
+ # critic learn forward
+ # ====================
+ self._learn_model.train()
+ self._target_model.train()
+ next_obs = data.get('next_obs')
+ reward = data.get('reward')
+ if self._reward_batch_norm:
+ reward = (reward - reward.mean()) / (reward.std() + 1e-8)
+ # current q value
+ q_value = self._learn_model.forward(data, mode='compute_critic')
+ q_value_dict = {}
+ q_dist = q_value['distribution']
+ q_value_dict['q_value'] = q_value['q_value'].mean()
+ # target q value.
+ with torch.no_grad():
+ next_action = self._target_model.forward(next_obs, mode='compute_actor')['action']
+ next_data = {'obs': next_obs, 'action': next_action}
+ target_q_dist = self._target_model.forward(next_data, mode='compute_critic')['distribution']
+
+ value_gamma = data.get('value_gamma')
+ action_index = np.zeros(next_action.shape[0])
+ # Since the action is a scalar value, the action index is set to 0, which is the only possible choice.
+ td_data = dist_nstep_td_data(
+ q_dist, target_q_dist, action_index, action_index, reward, data['done'], data['weight']
+ )
+ critic_loss, td_error_per_sample = dist_nstep_td_error(
+ td_data, self._gamma, self._v_min, self._v_max, self._n_atom, nstep=self._nstep, value_gamma=value_gamma
+ )
+ loss_dict['critic_loss'] = critic_loss
+ # ================
+ # critic update
+ # ================
+ self._optimizer_critic.zero_grad()
+ for k in loss_dict:
+ if 'critic' in k:
+ loss_dict[k].backward()
+ self._optimizer_critic.step()
+ # ===============================
+ # actor learn forward and update
+ # ===============================
+ # actor updates every ``self._actor_update_freq`` iters
+ if (self._forward_learn_cnt + 1) % self._actor_update_freq == 0:
+ actor_data = self._learn_model.forward(data['obs'], mode='compute_actor')
+ actor_data['obs'] = data['obs']
+ actor_loss = -self._learn_model.forward(actor_data, mode='compute_critic')['q_value'].mean()
+
+ loss_dict['actor_loss'] = actor_loss
+ # actor update
+ self._optimizer_actor.zero_grad()
+ actor_loss.backward()
+ self._optimizer_actor.step()
+ # =============
+ # after update
+ # =============
+ loss_dict['total_loss'] = sum(loss_dict.values())
+ self._forward_learn_cnt += 1
+ self._target_model.update(self._learn_model.state_dict())
+ return {
+ 'cur_lr_actor': self._optimizer_actor.defaults['lr'],
+ 'cur_lr_critic': self._optimizer_critic.defaults['lr'],
+ 'q_value': q_value['q_value'].mean().item(),
+ 'action': data['action'].mean().item(),
+ 'priority': td_error_per_sample.abs().tolist(),
+ **loss_dict,
+ **q_value_dict,
+ }
+
+ def _get_train_sample(self, traj: list) -> Union[None, List[Any]]:
+ """
+ Overview:
+ Process the data of a given trajectory (transitions, a list of transitions) into a list of samples that \
+ can be used for training directly. The sample is generated by the following steps: \
+ 1. Calculate the nstep return data. \
+ 2. Sample the data from the nstep return data. \
+ 3. Stack the data in the batch dimension. \
+ 4. Return the sample data. \
+ For D4PG, the nstep return data is generated by ``get_nstep_return_data`` and the sample data is \
+ generated by ``get_train_sample``.
+
+ Arguments:
+ - traj (:obj:`list`): The trajectory data (a list of transition), each element is \
+ the same format as the return value of ``self._process_transition`` method.
+
+ Returns:
+ - samples (:obj:`List[Dict[str, Any]]`): The training samples generated, including at least the following keys: \
+ ``'obs'``, ``'next_obs'``, ``'action'``, ``'reward'``, ``'done'``, ``'weight'``, ``'value_gamma'``. \
+ For more information, please refer to the ``get_train_sample`` method.
+ """
+ data = get_nstep_return_data(traj, self._nstep, gamma=self._gamma)
+ return get_train_sample(data, self._unroll_len)
+
+ def _monitor_vars_learn(self) -> List[str]:
+ """
+ Overview:
+ Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
+ as text logger, tensorboard logger, will use these keys to save the corresponding data.
+ Returns:
+ - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
+ """
+ ret = ['cur_lr_actor', 'cur_lr_critic', 'critic_loss', 'actor_loss', 'total_loss', 'q_value', 'action']
+ return ret
diff --git a/DI-engine/ding/policy/ddpg.py b/DI-engine/ding/policy/ddpg.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e253370b89c086126b92c664174b84771c39cfd
--- /dev/null
+++ b/DI-engine/ding/policy/ddpg.py
@@ -0,0 +1,542 @@
+from typing import List, Dict, Any, Tuple, Union
+from collections import namedtuple
+import torch
+import copy
+
+from ding.torch_utils import Adam, to_device
+from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_collate, default_decollate
+from .base_policy import Policy
+from .common_utils import default_preprocess_learn
+
+
+@POLICY_REGISTRY.register('ddpg')
+class DDPGPolicy(Policy):
+ """
+ Overview:
+ Policy class of DDPG algorithm. Paper link: https://arxiv.org/abs/1509.02971.
+
+ Config:
+ == ==================== ======== ============= ================================= =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============= ================================= =======================
+ 1 | ``type`` str ddpg | RL policy register name, refer | this arg is optional,
+ | | to registry ``POLICY_REGISTRY`` | a placeholder
+ 2 | ``cuda`` bool False | Whether to use cuda for network |
+ 3 | ``random_`` int 25000 | Number of randomly collected | Default to 25000 for
+ | ``collect_size`` | training samples in replay | DDPG/TD3, 10000 for
+ | | buffer when training starts. | sac.
+ 4 | ``model.twin_`` bool False | Whether to use two critic | Default False for
+ | ``critic`` | networks or only one. | DDPG, Clipped Double
+ | | | Q-learning method in
+ | | | TD3 paper.
+ 5 | ``learn.learning`` float 1e-3 | Learning rate for actor |
+ | ``_rate_actor`` | network(aka. policy). |
+ 6 | ``learn.learning`` float 1e-3 | Learning rates for critic |
+ | ``_rate_critic`` | network (aka. Q-network). |
+ 7 | ``learn.actor_`` int 2 | When critic network updates | Default 1 for DDPG,
+ | ``update_freq`` | once, how many times will actor | 2 for TD3. Delayed
+ | | network update. | Policy Updates method
+ | | | in TD3 paper.
+ 8 | ``learn.noise`` bool False | Whether to add noise on target | Default False for
+ | | network's action. | DDPG, True for TD3.
+ | | | Target Policy Smoo-
+ | | | thing Regularization
+ | | | in TD3 paper.
+ 9 | ``learn.-`` bool False | Determine whether to ignore | Use ignore_done only
+ | ``ignore_done`` | done flag. | in halfcheetah env.
+ 10 | ``learn.-`` float 0.005 | Used for soft update of the | aka. Interpolation
+ | ``target_theta`` | target network. | factor in polyak aver-
+ | | | aging for target
+ | | | networks.
+ 11 | ``collect.-`` float 0.1 | Used for add noise during co- | Sample noise from dis-
+ | ``noise_sigma`` | llection, through controlling | tribution, Ornstein-
+ | | the sigma of distribution | Uhlenbeck process in
+ | | | DDPG paper, Gaussian
+ | | | process in ours.
+ == ==================== ======== ============= ================================= =======================
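+
+ Examples:
+ An illustrative (not prescriptive) override of the defaults above that switches on the TD3-style \
+ options listed in the table; ``easydict`` usage here is just for the sketch, in practice the full \
+ experiment config is assembled and merged by DI-engine's config utilities.
+
+ >>> import copy
+ >>> from easydict import EasyDict
+ >>> cfg = EasyDict(copy.deepcopy(DDPGPolicy.config))
+ >>> cfg.learn.actor_update_freq = 2 # Delayed Policy Updates (TD3)
+ >>> cfg.learn.noise = True # Target Policy Smoothing Regularization (TD3)
+ >>> # model.twin_critic=True (Clipped Double Q-learning) is set in the model part of the experiment config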
+ """
+
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='ddpg',
+ # (bool) Whether to use cuda in policy.
+ cuda=False,
+ # (bool) Whether learning policy is the same as collecting data policy(on-policy). Default False in DDPG.
+ on_policy=False,
+ # (bool) Whether to enable priority experience sample.
+ priority=False,
+ # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ # (int) Number of training samples(randomly collected) in replay buffer when training starts.
+ # Default 25000 in DDPG/TD3.
+ random_collect_size=25000,
+ # (bool) Whether policy data is needed when processing transitions.
+ transition_with_policy_data=False,
+ # (str) Action space type, including ['continuous', 'hybrid'].
+ action_space='continuous',
+ # (bool) Whether use batch normalization for reward.
+ reward_batch_norm=False,
+ # (bool) Whether to enable multi-agent training setting.
+ multi_agent=False,
+ # learn_mode config
+ learn=dict(
+ # (int) How many updates(iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=1,
+ # (int) Minibatch size for gradient descent.
+ batch_size=256,
+ # (float) Learning rates for actor network(aka. policy).
+ learning_rate_actor=1e-3,
+ # (float) Learning rates for critic network(aka. Q-network).
+ learning_rate_critic=1e-3,
+ # (bool) Whether to ignore done (usually for env with max-step termination, e.g. pendulum).
+ # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
+ # These limit HalfCheetah, and several other MuJoCo envs, to a max episode length of 1000.
+ # However, HalfCheetah episodes only end because of this time limit, never a true terminal state,
+ # so when the episode step reaches the max episode step we replace done==True with done==False
+ # to keep the TD-error computation (``gamma * (1 - done) * next_v + reward``) accurate.
+ ignore_done=False,
+ # (float) target_theta: Used for soft update of the target network,
+ # aka. Interpolation factor in polyak averaging for target networks.
+ # Default to 0.005.
+ target_theta=0.005,
+ # (float) discount factor for the discounted sum of rewards, aka. gamma.
+ discount_factor=0.99,
+ # (int) The number of actor network updates per critic network update.
+ # Delayed Policy Updates in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
+ # Default 1 for DDPG, 2 for TD3.
+ actor_update_freq=1,
+ # (bool) Whether to add noise on target network's action.
+ # Target Policy Smoothing Regularization in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
+ # Default True for TD3, False for DDPG.
+ noise=False,
+ ),
+ # collect_mode config
+ collect=dict(
+ # (int) How many training samples collected in one collection procedure.
+ # Only one of [n_sample, n_episode] should be set.
+ # n_sample=1,
+ # (int) Split episodes or trajectories into pieces with length `unroll_len`.
+ unroll_len=1,
+ # (float) Noise must be added during collection, so "noise" is omitted here and only "noise_sigma" is set.
+ noise_sigma=0.1,
+ ),
+ eval=dict(), # for compatibility
+ other=dict(
+ replay_buffer=dict(
+ # (int) Maximum size of replay buffer. Usually, larger buffer size is better.
+ replay_buffer_size=100000,
+ ),
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+ Return this algorithm's default neural network model setting for demonstration. ``__init__`` method will \
+ automatically call this method to get the default model setting and create model.
+ Returns:
+ - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
+ """
+ if self._cfg.multi_agent:
+ return 'continuous_maqac', ['ding.model.template.maqac']
+ else:
+ return 'continuous_qac', ['ding.model.template.qac']
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Initialize the learn mode of policy, including related attributes and modules. For DDPG, it mainly \
+ contains two optimizers, algorithm-specific arguments such as gamma and twin_critic, main and target model.
+ This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+ .. note::
+ For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+ and ``_load_state_dict_learn`` methods.
+
+ .. note::
+ For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+ .. note::
+ If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+ with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ # actor and critic optimizer
+ self._optimizer_actor = Adam(
+ self._model.actor.parameters(),
+ lr=self._cfg.learn.learning_rate_actor,
+ )
+ self._optimizer_critic = Adam(
+ self._model.critic.parameters(),
+ lr=self._cfg.learn.learning_rate_critic,
+ )
+ self._reward_batch_norm = self._cfg.reward_batch_norm
+
+ self._gamma = self._cfg.learn.discount_factor
+ self._actor_update_freq = self._cfg.learn.actor_update_freq
+ self._twin_critic = self._cfg.model.twin_critic # True for TD3, False for DDPG
+
+ # main and target models
+ self._target_model = copy.deepcopy(self._model)
+ self._learn_model = model_wrap(self._model, wrapper_name='base')
+ if self._cfg.action_space == 'hybrid':
+ self._learn_model = model_wrap(self._learn_model, wrapper_name='hybrid_argmax_sample')
+ self._target_model = model_wrap(self._target_model, wrapper_name='hybrid_argmax_sample')
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='momentum',
+ update_kwargs={'theta': self._cfg.learn.target_theta}
+ )
+ if self._cfg.learn.noise:
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='action_noise',
+ noise_type='gauss',
+ noise_kwargs={
+ 'mu': 0.0,
+ 'sigma': self._cfg.learn.noise_sigma
+ },
+ noise_range=self._cfg.learn.noise_range
+ )
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ self._forward_learn_cnt = 0 # count iterations
+
+ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
+ """
+ Overview:
+ Policy forward function of learn mode (training policy and updating parameters). Forward means \
+ that the policy inputs some training batch data from the replay buffer and then returns the output \
+ result, including various training information such as loss, action, priority.
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]]`): The input data used for policy forward, including a batch of \
+ training samples. For each element in the list, the key of the dict is the name of a data item and the \
+ value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
+ combinations. In the ``_forward_learn`` method, data often needs to first be stacked in the batch \
+ dimension by some utility functions such as ``default_preprocess_learn``. \
+ For DDPG, each element in the list is a dict containing at least the following keys: ``obs``, ``action``, \
+ ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight`` \
+ and ``logit`` which is used for hybrid action space.
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \
+ recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
+ detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and the current policy supports all of them. \
+ For a data type that is not supported, the main reason is usually that the corresponding model does not \
+ support it. You can implement your own model rather than use the default model. For more information, \
+ please raise an issue in the GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for DDPGPolicy: ``ding.policy.tests.test_ddpg``.
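+
+ Examples:
+ A worked sketch of the one-step TD target computed here, including the clipped double-Q trick used \
+ when ``twin_critic=True``; all tensors are made-up numbers, the real computation is performed by \
+ ``v_1step_td_error``.
+
+ >>> import torch
+ >>> q1_next, q2_next = torch.tensor([1.0, 2.0]), torch.tensor([1.5, 1.8])
+ >>> reward, done, gamma = torch.tensor([0.1, 0.1]), torch.tensor([0.0, 1.0]), 0.99
+ >>> target_q = torch.min(q1_next, q2_next) # pessimistic target (TD3); plain DDPG uses its single critic
+ >>> td_target = reward + gamma * (1 - done) * target_q # one-step TD target, zeroed after termination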
+ """
+ loss_dict = {}
+ data = default_preprocess_learn(
+ data,
+ use_priority=self._cfg.priority,
+ use_priority_IS_weight=self._cfg.priority_IS_weight,
+ ignore_done=self._cfg.learn.ignore_done,
+ use_nstep=False
+ )
+ if self._cuda:
+ data = to_device(data, self._device)
+ # ====================
+ # critic learn forward
+ # ====================
+ self._learn_model.train()
+ self._target_model.train()
+ next_obs = data['next_obs']
+ reward = data['reward']
+ if self._reward_batch_norm:
+ reward = (reward - reward.mean()) / (reward.std() + 1e-8)
+ # current q value
+ q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']
+
+ # target q value.
+ with torch.no_grad():
+ next_actor_data = self._target_model.forward(next_obs, mode='compute_actor')
+ next_actor_data['obs'] = next_obs
+ target_q_value = self._target_model.forward(next_actor_data, mode='compute_critic')['q_value']
+
+ q_value_dict = {}
+ target_q_value_dict = {}
+
+ if self._twin_critic:
+ # TD3: two critic networks
+ target_q_value = torch.min(target_q_value[0], target_q_value[1]) # find min one as target q value
+ q_value_dict['q_value'] = q_value[0].mean().data.item()
+ q_value_dict['q_value_twin'] = q_value[1].mean().data.item()
+ target_q_value_dict['target q_value'] = target_q_value.mean().data.item()
+ # critic network1
+ td_data = v_1step_td_data(q_value[0], target_q_value, reward, data['done'], data['weight'])
+ critic_loss, td_error_per_sample1 = v_1step_td_error(td_data, self._gamma)
+ loss_dict['critic_loss'] = critic_loss
+ # critic network2(twin network)
+ td_data_twin = v_1step_td_data(q_value[1], target_q_value, reward, data['done'], data['weight'])
+ critic_twin_loss, td_error_per_sample2 = v_1step_td_error(td_data_twin, self._gamma)
+ loss_dict['critic_twin_loss'] = critic_twin_loss
+ td_error_per_sample = (td_error_per_sample1 + td_error_per_sample2) / 2
+ else:
+ # DDPG: single critic network
+ q_value_dict['q_value'] = q_value.mean().data.item()
+ target_q_value_dict['target q_value'] = target_q_value.mean().data.item()
+ td_data = v_1step_td_data(q_value, target_q_value, reward, data['done'], data['weight'])
+ critic_loss, td_error_per_sample = v_1step_td_error(td_data, self._gamma)
+ loss_dict['critic_loss'] = critic_loss
+ # ================
+ # critic update
+ # ================
+ self._optimizer_critic.zero_grad()
+ for k in loss_dict:
+ if 'critic' in k:
+ loss_dict[k].backward()
+ self._optimizer_critic.step()
+ # ===============================
+ # actor learn forward and update
+ # ===============================
+ # actor updates every ``self._actor_update_freq`` iters
+ if (self._forward_learn_cnt + 1) % self._actor_update_freq == 0:
+ actor_data = self._learn_model.forward(data['obs'], mode='compute_actor')
+ actor_data['obs'] = data['obs']
+ if self._twin_critic:
+ actor_loss = -self._learn_model.forward(actor_data, mode='compute_critic')['q_value'][0].mean()
+ else:
+ actor_loss = -self._learn_model.forward(actor_data, mode='compute_critic')['q_value'].mean()
+
+ loss_dict['actor_loss'] = actor_loss
+ # actor update
+ self._optimizer_actor.zero_grad()
+ actor_loss.backward()
+ self._optimizer_actor.step()
+ # =============
+ # after update
+ # =============
+ loss_dict['total_loss'] = sum(loss_dict.values())
+ self._forward_learn_cnt += 1
+ self._target_model.update(self._learn_model.state_dict())
+ if self._cfg.action_space == 'hybrid':
+ action_log_value = -1. # TODO(nyz) better way to viz hybrid action
+ else:
+ action_log_value = data['action'].mean()
+ return {
+ 'cur_lr_actor': self._optimizer_actor.defaults['lr'],
+ 'cur_lr_critic': self._optimizer_critic.defaults['lr'],
+ # 'q_value': np.array(q_value).mean(),
+ 'action': action_log_value,
+ 'priority': td_error_per_sample.abs().tolist(),
+ 'td_error': td_error_per_sample.abs().mean(),
+ **loss_dict,
+ **q_value_dict,
+ **target_q_value_dict,
+ }
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ """
+ Overview:
+ Return the state_dict of learn mode, usually including model, target_model and optimizers.
+ Returns:
+ - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring.
+ """
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'target_model': self._target_model.state_dict(),
+ 'optimizer_actor': self._optimizer_actor.state_dict(),
+ 'optimizer_critic': self._optimizer_critic.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ """
+ Overview:
+ Load the state_dict variable into policy learn mode.
+ Arguments:
+ - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before.
+
+ .. tip::
+ If you want to only load some parts of model, you can simply set the ``strict`` argument in \
+ load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
+ complicated operation.
+ """
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._target_model.load_state_dict(state_dict['target_model'])
+ self._optimizer_actor.load_state_dict(state_dict['optimizer_actor'])
+ self._optimizer_critic.load_state_dict(state_dict['optimizer_critic'])
+
+ def _init_collect(self) -> None:
+ """
+ Overview:
+ Initialize the collect mode of policy, including related attributes and modules. For DDPG, it contains the \
+ collect_model to balance the exploration and exploitation with the perturbed noise mechanism, and other \
+ algorithm-specific arguments such as unroll_len. \
+ This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.
+
+ .. note::
+ If you want to set some special member variables in ``_init_collect`` method, you'd better name them \
+ with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``.
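+
+ Examples:
+ A sketch of the Gaussian exploration noise the ``action_noise`` wrapper adds at collect time; the \
+ clamp range below is an assumption for illustration, the actual wrapper follows its own noise settings.
+
+ >>> import torch
+ >>> action = torch.tensor([0.2, -0.5])
+ >>> noisy_action = action + 0.1 * torch.randn_like(action) # sigma mirrors ``cfg.collect.noise_sigma``
+ >>> noisy_action = noisy_action.clamp(-1.0, 1.0) # keep the perturbed action inside the action bounds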
+ """
+ self._unroll_len = self._cfg.collect.unroll_len
+ # collect model
+ self._collect_model = model_wrap(
+ self._model,
+ wrapper_name='action_noise',
+ noise_type='gauss',
+ noise_kwargs={
+ 'mu': 0.0,
+ 'sigma': self._cfg.collect.noise_sigma
+ },
+ noise_range=None
+ )
+ if self._cfg.action_space == 'hybrid':
+ self._collect_model = model_wrap(self._collect_model, wrapper_name='hybrid_eps_greedy_multinomial_sample')
+ self._collect_model.reset()
+
+ def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]:
+ """
+ Overview:
+ Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \
+ that the policy gets some necessary data (mainly observation) from the envs and then returns the output \
+ data, such as the action to interact with the envs.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
+ other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \
+ dict is the same as the input data, i.e. environment id.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and the current policy supports all of them. \
+ For a data type that is not supported, the main reason is usually that the corresponding model does not \
+ support it. You can implement your own model rather than use the default model. For more information, \
+ please raise an issue in the GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for DDPGPolicy: ``ding.policy.tests.test_ddpg``.
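+
+ Examples:
+ A minimal sketch of the env-id keyed input/output format; the observation dimension and the \
+ ``collect_mode`` entry point are illustrative assumptions.
+
+ >>> import torch
+ >>> data = {0: torch.randn(17), 1: torch.randn(17)} # env_id -> current observation
+ >>> # output = policy.collect_mode.forward(data, **kwargs) # -> {0: {'action': ...}, 1: {'action': ...}}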
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._collect_model.eval()
+ with torch.no_grad():
+ output = self._collect_model.forward(data, mode='compute_actor', **kwargs)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor],
+ timestep: namedtuple) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ Process and pack one timestep transition data into a dict, which can be directly used for training and \
+ saved in replay buffer. For DDPG, it contains obs, next_obs, action, reward, done.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari.
+ - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \
+ as input. For DDPG, it contains the action and the logit of the action (in hybrid action space).
+ - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \
+ except all the elements have been transformed into tensor data. Usually, it contains the next obs, \
+ reward, done, info, etc.
+ Returns:
+ - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep.
+ """
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'action': policy_output['action'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ if self._cfg.action_space == 'hybrid':
+ transition['logit'] = policy_output['logit']
+ return transition
+
+ def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+ Overview:
+ For a given trajectory (transitions, a list of transitions), process it into a list of samples that \
+ can be used for training directly. In DDPG, a train sample is a processed transition (unroll_len=1).
+ Arguments:
+ - transitions (:obj:`List[Dict[str, Any]]`): The trajectory data (a list of transitions), each element is \
+ the same format as the return value of ``self._process_transition`` method.
+ Returns:
+ - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is in a similar format \
+ to the input transitions, but may contain more data for training.
+ """
+ return get_train_sample(transitions, self._unroll_len)
+
+ def _init_eval(self) -> None:
+ """
+ Overview:
+ Initialize the eval mode of policy, including related attributes and modules. For DDPG, it contains the \
+ eval model to greedily select action type with argmax q_value mechanism for hybrid action space. \
+ This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
+
+ .. note::
+ If you want to set some special member variables in ``_init_eval`` method, you'd better name them \
+ with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
+ """
+ self._eval_model = model_wrap(self._model, wrapper_name='base')
+ if self._cfg.action_space == 'hybrid':
+ self._eval_model = model_wrap(self._eval_model, wrapper_name='hybrid_argmax_sample')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
+ """
+ Overview:
+ Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \
+ means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
+ action to interact with the envs.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
+ key of the dict is the same as the input data, i.e. environment id.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and the current policy supports all of them. \
+ For a data type that is not supported, the main reason is usually that the corresponding model does not \
+ support it. You can implement your own model rather than use the default model. For more information, \
+ please raise an issue in the GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for DDPGPolicy: ``ding.policy.tests.test_ddpg``.
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data, mode='compute_actor')
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _monitor_vars_learn(self) -> List[str]:
+ """
+ Overview:
+ Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
+ as text logger, tensorboard logger, will use these keys to save the corresponding data.
+ Returns:
+ - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
+ """
+ ret = [
+ 'cur_lr_actor', 'cur_lr_critic', 'critic_loss', 'actor_loss', 'total_loss', 'q_value', 'q_value_twin',
+ 'action', 'td_error'
+ ]
+ if self._twin_critic:
+ ret += ['critic_twin_loss']
+ return ret
diff --git a/DI-engine/ding/policy/dqfd.py b/DI-engine/ding/policy/dqfd.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e9ecab8530dd69f4a889d9bf6c42c5ff8f0561d
--- /dev/null
+++ b/DI-engine/ding/policy/dqfd.py
@@ -0,0 +1,273 @@
+from typing import List, Dict, Any, Tuple
+from collections import namedtuple
+import copy
+import torch
+from torch.optim import AdamW
+from ding.torch_utils import Adam, to_device
+from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, get_nstep_return_data, get_train_sample, \
+ dqfd_nstep_td_error, dqfd_nstep_td_data
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_collate, default_decollate
+from .dqn import DQNPolicy
+from .common_utils import default_preprocess_learn
+from copy import deepcopy
+
+
+@POLICY_REGISTRY.register('dqfd')
+class DQFDPolicy(DQNPolicy):
+ r"""
+ Overview:
+ Policy class of DQFD algorithm, extended by Double DQN/Dueling DQN/PER/multi-step TD.
+
+ Config:
+ == ==================== ======== ============== ======================================== =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============== ======================================== =======================
+ 1 ``type`` str dqfd | RL policy register name, refer to | This arg is optional,
+ | registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool False | Whether to use cuda for network | This arg can be diff-
+ | erent from modes
+ 3 ``on_policy`` bool False | Whether the RL algorithm is on-policy
+ | or off-policy
+ 4 ``priority`` bool True | Whether use priority(PER) | Priority sample,
+ | update priority
+ 5 | ``priority_IS`` bool True | Whether use Importance Sampling Weight
+ | ``_weight`` | to correct biased update. If True,
+ | priority must be True.
+ 6 | ``discount_`` float 0.97, | Reward's future discount factor, aka. | May be 1 when sparse
+ | ``factor`` [0.95, 0.999] | gamma | reward env
+ 7 ``nstep`` int 10, | N-step reward discount sum for target
+ [3, 5] | q_value estimation
+ 8 | ``lambda1`` float 1 | multiplicative factor for n-step
+ 9 | ``lambda2`` float 1 | multiplicative factor for the
+ | supervised margin loss
+ 10 | ``lambda3`` float 1e-5 | L2 loss
+ 11 | ``margin_fn`` float 0.8 | margin function in JE, here we set
+ | this as a constant
+ 12 | ``per_train_`` int 10 | number of pretraining iterations
+ | ``iter_k``
+ 13 | ``learn.update`` int 3 | How many updates(iterations) to train | This arg can vary
+ | ``per_collect`` | after collector's one collection. Only | from envs. Bigger val
+ | valid in serial training | means more off-policy
+ 14 | ``learn.batch_`` int 64 | The number of samples of an iteration
+ | ``size``
+ 15 | ``learn.learning`` float 0.001 | Gradient step length of an iteration.
+ | ``_rate``
+ 16 | ``learn.target_`` int 100 | Frequency of target network update. | Hard(assign) update
+ | ``update_freq``
+ 17 | ``learn.ignore_`` bool False | Whether ignore done for target value | Enable it for some
+ | ``done`` | calculation. | fake termination env
+ 18 ``collect.n_sample`` int [8, 128] | The number of training samples of a | It varies from
+ | call of collector. | different envs
+ 19 | ``collect.unroll`` int 1 | unroll length of an iteration | In RNN, unroll_len>1
+ | ``_len``
+ == ==================== ======== ============== ======================================== =======================
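+
+ Examples:
+ A schematic of how the loss weights in the table combine; the variable names and numbers are made up \
+ for illustration, the actual computation is performed inside ``dqfd_nstep_td_error``.
+
+ >>> lambda1, lambda2 = 1.0, 1.0
+ >>> l_one_step, l_nstep, l_margin = 0.50, 0.40, 0.20 # hypothetical per-batch loss values
+ >>> total_loss = l_one_step + lambda1 * l_nstep + lambda2 * l_margin
+ >>> # the L2 term weighted by lambda3 is applied through the optimizer's weight_decay, see ``_init_learn``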
+ """
+
+ config = dict(
+ type='dqfd',
+ cuda=False,
+ on_policy=False,
+ priority=True,
+ # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=True,
+ discount_factor=0.99,
+ nstep=10,
+ learn=dict(
+ # multiplicative factor for each loss
+ lambda1=1.0, # n-step return
+ lambda2=1.0, # supervised loss
+ lambda3=1e-5, # L2
+ # margin function in JE, here we implement this as a constant
+ margin_function=0.8,
+ # number of pretraining iterations
+ per_train_iter_k=10,
+
+ # How many updates(iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=3,
+ batch_size=64,
+ learning_rate=0.001,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (int) Frequency of target network update.
+ target_update_freq=100,
+ # (bool) Whether to ignore done (usually for env with max-step termination).
+ ignore_done=False,
+ ),
+ # collect_mode config
+ collect=dict(
+ # (int) Only one of [n_sample, n_episode] should be set
+ # n_sample=8,
+ # (int) Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ # The hyperparameter pho (the demo ratio) controls the proportion of data
+ # coming from expert demonstrations versus the agent's own experience.
+ pho=0.5,
+ ),
+ eval=dict(),
+ # other config
+ other=dict(
+ # Epsilon greedy with decay.
+ eps=dict(
+ # (str) Decay type. Support ['exp', 'linear'].
+ type='exp',
+ start=0.95,
+ end=0.1,
+ # (int) Decay length(env step)
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=10000, ),
+ ),
+ )
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Learn mode init method. Called by ``self.__init__``, initialize the optimizer, algorithm arguments, main \
+ and target model.
+ """
+ self.lambda1 = self._cfg.learn.lambda1 # n-step return
+ self.lambda2 = self._cfg.learn.lambda2 # supervised loss
+ self.lambda3 = self._cfg.learn.lambda3 # L2
+ # margin function in JE, here we implement this as a constant
+ self.margin_function = self._cfg.learn.margin_function
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ # Optimizer
+ # Two optimizer choices are provided: AdamW performs better than Adam here, so AdamW is recommended.
+ # Note that the lambda3 L2 term is applied through the optimizer's weight_decay.
+ self._optimizer = AdamW(self._model.parameters(), lr=self._cfg.learn.learning_rate, weight_decay=self.lambda3)
+ # self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate, weight_decay=self.lambda3)
+
+ self._gamma = self._cfg.discount_factor
+ self._nstep = self._cfg.nstep
+
+ # use model_wrapper for specialized demands of different modes
+ self._target_model = copy.deepcopy(self._model)
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='assign',
+ update_kwargs={'freq': self._cfg.learn.target_update_freq}
+ )
+ self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Overview:
+ Forward computation graph of learn mode(updating policy).
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \
+ np.ndarray or dict/list combinations.
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \
+ recorded in text log and tensorboard, values are python scalar or a list of scalars.
+ ArgumentsKeys:
+ - necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done``
+ - optional: ``value_gamma``, ``IS``
+ ReturnsKeys:
+ - necessary: ``cur_lr``, ``total_loss``, ``priority``
+ - optional: ``action_distribution``
+ """
+ data = default_preprocess_learn(
+ data,
+ use_priority=self._priority,
+ use_priority_IS_weight=self._cfg.priority_IS_weight,
+ ignore_done=self._cfg.learn.ignore_done,
+ use_nstep=True
+ )
+ data['done_1'] = data['done_1'].float()
+ if self._cuda:
+ data = to_device(data, self._device)
+ # ====================
+ # Q-learning forward
+ # ====================
+ self._learn_model.train()
+ self._target_model.train()
+ # Current q value (main model)
+ q_value = self._learn_model.forward(data['obs'])['logit']
+ # Target q value
+ with torch.no_grad():
+ target_q_value = self._target_model.forward(data['next_obs'])['logit']
+ target_q_value_one_step = self._target_model.forward(data['next_obs_1'])['logit']
+ # Max q value action (main model)
+ target_q_action = self._learn_model.forward(data['next_obs'])['action']
+ target_q_action_one_step = self._learn_model.forward(data['next_obs_1'])['action']
+
+ # modify the tensor type to match the JE computation in dqfd_nstep_td_error
+ is_expert = data['is_expert'].float()
+ data_n = dqfd_nstep_td_data(
+ q_value,
+ target_q_value,
+ data['action'],
+ target_q_action,
+ data['reward'],
+ data['done'],
+ data['done_1'],
+ data['weight'],
+ target_q_value_one_step,
+ target_q_action_one_step,
+ is_expert # set is_expert flag(expert 1, agent 0)
+ )
+ value_gamma = data.get('value_gamma')
+ loss, td_error_per_sample, loss_statistics = dqfd_nstep_td_error(
+ data_n,
+ self._gamma,
+ self.lambda1,
+ self.lambda2,
+ self.margin_function,
+ nstep=self._nstep,
+ value_gamma=value_gamma
+ )
+
+ # ====================
+ # Q-learning update
+ # ====================
+ self._optimizer.zero_grad()
+ loss.backward()
+ if self._cfg.multi_gpu:
+ self.sync_gradients(self._learn_model)
+ self._optimizer.step()
+
+ # =============
+ # after update
+ # =============
+ self._target_model.update(self._learn_model.state_dict())
+ return {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': loss.item(),
+ 'priority': td_error_per_sample.abs().tolist(),
+ # Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard.
+ # '[histogram]action_distribution': data['action'],
+ }
+
+ def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+ Overview:
+ For a given trajectory (transitions, a list of transitions), process it into a list of samples that \
+ can be used for training directly. A train sample can be a processed transition (DQN with nstep TD) \
+ or some continuous transitions (DRQN).
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]]`): The trajectory data (a list of transitions), each element is the same \
+ format as the return value of ``self._process_transition`` method.
+ Returns:
+ - samples (:obj:`dict`): The list of training samples.
+
+ .. note::
+ We will vectorize the ``process_transition`` and ``get_train_sample`` methods in the following release \
+ version. The user can customize this data processing procedure by overriding these two methods and the \
+ collector itself.
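+
+ Examples:
+ A sketch of the extra keys each DQFD sample carries after this method; ``traj`` is a hypothetical list \
+ of transitions, the key names follow the code below.
+
+ >>> samples = policy._get_train_sample(traj)
+ >>> assert {'next_obs_1', 'done_1'} <= set(samples[0].keys()) # one-step companions of the n-step keys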
+ """
+ data_1 = deepcopy(get_nstep_return_data(data, 1, gamma=self._gamma))
+ data = get_nstep_return_data(
+ data, self._nstep, gamma=self._gamma
+ ) # here we want to include one-step next observation
+ for i in range(len(data)):
+ data[i]['next_obs_1'] = data_1[i]['next_obs'] # concat the one-step next observation
+ data[i]['done_1'] = data_1[i]['done']
+ return get_train_sample(data, self._unroll_len)
diff --git a/DI-engine/ding/policy/dqn.py b/DI-engine/ding/policy/dqn.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1f6fdbb49d7c5f2fdd1b1a15c490855f0626db3
--- /dev/null
+++ b/DI-engine/ding/policy/dqn.py
@@ -0,0 +1,864 @@
+from typing import List, Dict, Any, Tuple
+from collections import namedtuple
+import copy
+import torch
+
+from ding.torch_utils import Adam, to_device, ContrastiveLoss
+from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, get_nstep_return_data, get_train_sample
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_collate, default_decollate
+
+from .base_policy import Policy
+from .common_utils import default_preprocess_learn
+
+
+@POLICY_REGISTRY.register('dqn')
+class DQNPolicy(Policy):
+ """
+ Overview:
+ Policy class of DQN algorithm, extended by Double DQN/Dueling DQN/PER/multi-step TD.
+
+ Config:
+ == ===================== ======== ============== ======================================= =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ===================== ======== ============== ======================================= =======================
+ 1 ``type`` str dqn | RL policy register name, refer to | This arg is optional,
+ | registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool False | Whether to use cuda for network | This arg can be diff-
+ | erent from modes
+ 3 ``on_policy`` bool False | Whether the RL algorithm is on-policy
+ | or off-policy
+ 4 ``priority`` bool False | Whether use priority(PER) | Priority sample,
+ | update priority
+ 5 | ``priority_IS`` bool False | Whether use Importance Sampling
+ | ``_weight`` | Weight to correct biased update. If
+ | True, priority must be True.
+ 6 | ``discount_`` float 0.97, | Reward's future discount factor, aka. | May be 1 when sparse
+ | ``factor`` [0.95, 0.999] | gamma | reward env
+ 7 ``nstep`` int 1, | N-step reward discount sum for target
+ [3, 5] | q_value estimation
+ 8 | ``model.dueling`` bool True | dueling head architecture
+ 9 | ``model.encoder`` list [32, 64, | Sequence of ``hidden_size`` of | default kernel_size
+ | ``_hidden`` (int) 64, 128] | subsequent conv layers and the | is [8, 4, 3]
+ | ``_size_list`` | final dense layer. | default stride is
+ | [4, 2 ,1]
+ 10 | ``model.dropout`` float None | Dropout rate for dropout layers. | [0,1]
+ | If set to ``None``
+ | means no dropout
+ 11 | ``learn.update`` int 3 | How many updates(iterations) to train | This arg can vary
+ | ``per_collect`` | after collector's one collection. | from envs. Bigger val
+ | Only valid in serial training | means more off-policy
+ 12 | ``learn.batch_`` int 64 | The number of samples of an iteration
+ | ``size``
+ 13 | ``learn.learning`` float 0.001 | Gradient step length of an iteration.
+ | ``_rate``
+ 14 | ``learn.target_`` int 100 | Frequency of target network update. | Hard(assign) update
+ | ``update_freq``
+ 15 | ``learn.target_`` float 0.005 | Interpolation factor for the soft | Soft(momentum) update
+ | ``theta`` | update of the target network.
+ | | Only one of [target_update_freq,
+ | | target_theta] should be set
+ 16 | ``learn.ignore_`` bool False | Whether ignore done for target value | Enable it for some
+ | ``done`` | calculation. | fake termination env
+ 17 ``collect.n_sample`` int [8, 128] | The number of training samples of a | It varies from
+ | call of collector. | different envs
+ 18 ``collect.n_episode`` int 8 | The number of training episodes of a | only one of [n_sample
+ | call of collector | ,n_episode] should
+ | | be set
+ 19 | ``collect.unroll`` int 1 | unroll length of an iteration | In RNN, unroll_len>1
+ | ``_len``
+ 20 | ``other.eps.type`` str exp | exploration rate decay type | Support ['exp',
+ | 'linear'].
+ 21 | ``other.eps.`` float 0.95 | start value of exploration rate | [0,1]
+ | ``start``
+ 22 | ``other.eps.`` float 0.1 | end value of exploration rate | [0,1]
+ | ``end``
+ 23 | ``other.eps.`` int 10000 | decay length of exploration | greater than 0. set
+ | ``decay`` | decay=10000 means
+ | the exploration rate
+ | decay from start
+ | value to end value
+ | during decay length.
+ == ===================== ======== ============== ======================================= =======================
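+
+ Examples:
+ An illustrative override of a few defaults from the table above; in a real experiment these fields \
+ are usually written in the experiment config and merged by DI-engine's config utilities before the \
+ policy is built.
+
+ >>> import copy
+ >>> from easydict import EasyDict
+ >>> cfg = EasyDict(copy.deepcopy(DQNPolicy.config))
+ >>> cfg.nstep = 3 # 3-step TD target
+ >>> cfg.discount_factor = 0.99
+ >>> cfg.learn.update_per_collect = 10 # more gradient steps per collection (more off-policy)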
+ """
+
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='dqn',
+ # (bool) Whether to use cuda in policy.
+ cuda=False,
+ # (bool) Whether learning policy is the same as collecting data policy(on-policy).
+ on_policy=False,
+ # (bool) Whether to enable priority experience sample.
+ priority=False,
+ # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ # (float) Discount factor(gamma) for returns.
+ discount_factor=0.97,
+ # (int) The number of step for calculating target q_value.
+ nstep=1,
+ model=dict(
+ # (list(int)) Sequence of ``hidden_size`` of subsequent conv layers and the final dense layer.
+ encoder_hidden_size_list=[128, 128, 64],
+ ),
+ # learn_mode config
+ learn=dict(
+ # (int) How many updates(iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=3,
+ # (int) How many samples in a training batch.
+ batch_size=64,
+ # (float) The step size of gradient descent.
+ learning_rate=0.001,
+ # (int) Frequency of target network update.
+ # Only one of [target_update_freq, target_theta] should be set.
+ target_update_freq=100,
+ # (float) Used for soft update of the target network.
+ # aka. Interpolation factor in EMA update for target network.
+ # Only one of [target_update_freq, target_theta] should be set.
+ target_theta=0.005,
+ # (bool) Whether to ignore done (usually for env with max-step termination).
+ # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
+ # These limit HalfCheetah, and several other MuJoCo envs, to a max episode length of 1000.
+ # However, HalfCheetah episodes only end because of this time limit, never a true terminal state,
+ # so when the episode step reaches the max episode step we replace done==True with done==False
+ # to keep the TD-error computation (``gamma * (1 - done) * next_v + reward``) accurate.
+ ignore_done=False,
+ ),
+ # collect_mode config
+ collect=dict(
+ # (int) How many training samples collected in one collection procedure.
+ # Only one of [n_sample, n_episode] should be set.
+ n_sample=8,
+ # (int) Split episodes or trajectories into pieces with length `unroll_len`.
+ unroll_len=1,
+ ),
+ eval=dict(), # for compatibility
+ # other config
+ other=dict(
+ # Epsilon greedy with decay.
+ eps=dict(
+ # (str) Decay type. Support ['exp', 'linear'].
+ type='exp',
+ # (float) Epsilon start value.
+ start=0.95,
+ # (float) Epsilon end value.
+ end=0.1,
+ # (int) Decay length(env step).
+ decay=10000,
+ ),
+ replay_buffer=dict(
+ # (int) Maximum size of replay buffer. Usually, larger buffer size is better.
+ replay_buffer_size=10000,
+ ),
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+ Return this algorithm's default neural network model setting for demonstration. ``__init__`` method will \
+ automatically call this method to get the default model setting and create model.
+ Returns:
+ - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
+
+ .. note::
+ The user can define and use a customized network model but must obey the same interface definition indicated \
+ by the import_names path. For example, for DQN, its registered name is ``dqn`` and the import_names is \
+ ``ding.model.template.q_learning``.
+ """
+ return 'dqn', ['ding.model.template.q_learning']
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Initialize the learn mode of policy, including related attributes and modules. For DQN, it mainly contains \
+ optimizer, algorithm-specific arguments such as nstep and gamma, main and target model.
+ This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+ .. note::
+ For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+ and ``_load_state_dict_learn`` methods.
+
+ .. note::
+ For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+ .. note::
+ If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+ with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ # Optimizer
+ self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
+
+ self._gamma = self._cfg.discount_factor
+ self._nstep = self._cfg.nstep
+
+ # use model_wrapper for specialized demands of different modes
+ self._target_model = copy.deepcopy(self._model)
+ if 'target_update_freq' in self._cfg.learn:
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='assign',
+ update_kwargs={'freq': self._cfg.learn.target_update_freq}
+ )
+ elif 'target_theta' in self._cfg.learn:
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='momentum',
+ update_kwargs={'theta': self._cfg.learn.target_theta}
+ )
+ else:
+ raise RuntimeError("DQN needs target network, please either indicate target_update_freq or target_theta")
+ self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
+ """
+ Overview:
+ Policy forward function of learn mode (training policy and updating parameters). Forward means \
+ that the policy inputs some training batch data from the replay buffer and then returns the output \
+ result, including various training information such as loss, q value, priority.
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]]`): The input data used for policy forward, including a batch of \
+ training samples. For each element in the list, the key of the dict is the name of a data item and the \
+ value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
+ combinations. In the ``_forward_learn`` method, data often needs to first be stacked in the batch \
+ dimension by some utility functions such as ``default_preprocess_learn``. \
+ For DQN, each element in the list is a dict containing at least the following keys: ``obs``, ``action``, \
+ ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight`` \
+ and ``value_gamma``.
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \
+ recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
+ detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and the current policy supports all of them. \
+ For a data type that is not supported, the main reason is usually that the corresponding model does not \
+ support it. You can implement your own model rather than use the default model. For more information, \
+ please raise an issue in the GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for DQNPolicy: ``ding.policy.tests.test_dqn``.
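+
+ Examples:
+ A worked sketch of the n-step TD target that ``q_nstep_td_error`` is built around (termination is \
+ ignored here for brevity); all numbers are hypothetical.
+
+ >>> import torch
+ >>> gamma, nstep = 0.97, 3
+ >>> rewards = torch.tensor([1.0, 0.0, 2.0]) # r_t, r_{t+1}, r_{t+2}
+ >>> next_q = torch.tensor(5.0) # Q_target(s_{t+3}, argmax_a Q_online(s_{t+3}, a)), i.e. Double DQN
+ >>> target = sum(gamma ** i * rewards[i] for i in range(nstep)) + gamma ** nstep * next_q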
+ """
+ # Data preprocessing operations, such as stacking data and moving it from cpu to cuda device.
+ data = default_preprocess_learn(
+ data,
+ use_priority=self._priority,
+ use_priority_IS_weight=self._cfg.priority_IS_weight,
+ ignore_done=self._cfg.learn.ignore_done,
+ use_nstep=True
+ )
+ if self._cuda:
+ data = to_device(data, self._device)
+ # Q-learning forward
+ self._learn_model.train()
+ self._target_model.train()
+ # Current q value (main model)
+ q_value = self._learn_model.forward(data['obs'])['logit']
+ # Target q value
+ with torch.no_grad():
+ target_q_value = self._target_model.forward(data['next_obs'])['logit']
+ # Max q value action (main model), i.e. Double DQN
+ target_q_action = self._learn_model.forward(data['next_obs'])['action']
+
+ data_n = q_nstep_td_data(
+ q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], data['weight']
+ )
+ value_gamma = data.get('value_gamma')
+ loss, td_error_per_sample = q_nstep_td_error(data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma)
+
+ # Update network parameters
+ self._optimizer.zero_grad()
+ loss.backward()
+ if self._cfg.multi_gpu:
+ self.sync_gradients(self._learn_model)
+ self._optimizer.step()
+
+ # Postprocessing operations, such as updating target model, return logged values and priority.
+ self._target_model.update(self._learn_model.state_dict())
+ return {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': loss.item(),
+ 'q_value': q_value.mean().item(),
+ 'target_q_value': target_q_value.mean().item(),
+ 'priority': td_error_per_sample.abs().tolist(),
+ # Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard.
+ # '[histogram]action_distribution': data['action'],
+ }
+
+ def _monitor_vars_learn(self) -> List[str]:
+ """
+ Overview:
+ Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
+ as text logger, tensorboard logger, will use these keys to save the corresponding data.
+ Returns:
+ - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
+ """
+ return ['cur_lr', 'total_loss', 'q_value', 'target_q_value']
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ """
+ Overview:
+ Return the state_dict of learn mode, usually including model, target_model and optimizer.
+ Returns:
+ - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring.
+ """
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'target_model': self._target_model.state_dict(),
+ 'optimizer': self._optimizer.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ """
+ Overview:
+ Load the state_dict variable into policy learn mode.
+ Arguments:
+ - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before.
+
+ .. tip::
+ If you only want to load some parts of the model, you can simply set the ``strict`` argument of \
+ load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
+ complicated operations.
+ """
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._target_model.load_state_dict(state_dict['target_model'])
+ self._optimizer.load_state_dict(state_dict['optimizer'])
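+ # Checkpoint sketch: the two methods above pair naturally with plain torch serialization. A minimal
+ # example, assuming direct access to the policy object and an arbitrary path (both illustrative):
+ #     torch.save(policy._state_dict_learn(), 'dqn_ckpt.pth.tar')
+ #     policy._load_state_dict_learn(torch.load('dqn_ckpt.pth.tar', map_location='cpu'))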
+
+ def _init_collect(self) -> None:
+ """
+ Overview:
+ Initialize the collect mode of policy, including related attributes and modules. For DQN, it contains the \
+ collect_model to balance the exploration and exploitation with epsilon-greedy sample mechanism, and other \
+ algorithm-specific arguments such as unroll_len and nstep.
+ This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.
+
+ .. note::
+ If you want to set some special member variables in the ``_init_collect`` method, you'd better name them \
+ with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``.
+
+ .. tip::
+ Some variables need to initialize independently in different modes, such as gamma and nstep in DQN. This \
+ design is for the convenience of parallel execution of different policy modes.
+ """
+ self._unroll_len = self._cfg.collect.unroll_len
+ self._gamma = self._cfg.discount_factor # necessary for parallel
+ self._nstep = self._cfg.nstep # necessary for parallel
+ self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_sample')
+ self._collect_model.reset()
+
+ def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
+ """
+ Overview:
+ Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \
+ that the policy gets some necessary data (mainly observation) from the envs and then returns the output \
+ data, such as the action to interact with the envs. Besides, this policy also needs ``eps`` argument for \
+ exploration, i.e., classic epsilon-greedy exploration strategy.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+ - eps (:obj:`float`): The epsilon value for exploration.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
+ other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \
+ dict is the same as the input data, i.e. environment id.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+ For data types that are not supported, the main reason is usually that the corresponding model does not \
+ support them. You can implement your own model rather than using the default one. For more information, \
+ please raise an issue in the GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for DQNPolicy: ``ding.policy.tests.test_dqn``.
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._collect_model.eval()
+ with torch.no_grad():
+ output = self._collect_model.forward(data, eps=eps)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
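+ # Usage sketch: a collector typically calls this method with a dict keyed by environment id and an
+ # epsilon value taken from an exploration schedule. With illustrative names:
+ #     obs = {env_id: obs_tensor for env_id, obs_tensor in current_obs.items()}
+ #     eps = epsilon_schedule(env_step)  # e.g. an exponentially decaying value in [0, 1]
+ #     output = policy.collect_mode.forward(obs, eps=eps)  # {env_id: {'action': ..., 'logit': ...}}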
+
+ def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+ Overview:
+ For a given trajectory (a list of transitions), process it into a list of samples that can be used for \
+ training directly. In DQN with nstep TD, a train sample is a processed transition. \
+ This method is usually used in collectors to execute the necessary \
+ RL data preprocessing before training, which can help the learner amortize the relevant time consumption. \
+ Alternatively, you can implement this method as an identity function and do the data processing \
+ in the ``self._forward_learn`` method.
+ Arguments:
+ - transitions (:obj:`List[Dict[str, Any]]`): The trajectory data (a list of transitions), each element is \
+ in the same format as the return value of ``self._process_transition`` method.
+ Returns:
+ - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is similar in format \
+ to input transitions, but may contain more data for training, such as nstep reward and target obs.
+ """
+ transitions = get_nstep_return_data(transitions, self._nstep, gamma=self._gamma)
+ return get_train_sample(transitions, self._unroll_len)
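+ # Worked example (sketch): with nstep=3 and gamma=0.99, ``get_nstep_return_data`` roughly rewrites each
+ # transition so that its reward field holds the next 3 single-step rewards (later discounted and summed
+ # inside the nstep TD target) and its next_obs points 3 steps ahead; ``get_train_sample`` then cuts the
+ # trajectory into pieces of length unroll_len (1 by default, i.e. one sample per transition).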
+
+ def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor],
+ timestep: namedtuple) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ Process and pack one timestep transition data into a dict, which can be directly used for training and \
+ saved in replay buffer. For DQN, it contains obs, next_obs, action, reward, done.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari.
+ - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \
+ as input. For DQN, it contains the action and the logit (q_value) of the action.
+ - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \
+ except all the elements have been transformed into tensor data. Usually, it contains the next obs, \
+ reward, done, info, etc.
+ Returns:
+ - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep.
+ """
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'action': policy_output['action'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _init_eval(self) -> None:
+ """
+ Overview:
+ Initialize the eval mode of policy, including related attributes and modules. For DQN, it contains the \
+ eval model to greedily select action with argmax q_value mechanism.
+ This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
+
+ .. note::
+ If you want to set some special member variables in the ``_init_eval`` method, you'd better name them \
+ with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
+ """
+ self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
+ """
+ Overview:
+ Policy forward function of eval mode (evaluating policy performance by interacting with envs). Forward \
+ means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
+ action to interact with the envs.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
+ key of the dict is the same as the input data, i.e. environment id.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+ For data types that are not supported, the main reason is usually that the corresponding model does not \
+ support them. You can implement your own model rather than using the default one. For more information, \
+ please raise an issue in the GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for DQNPolicy: ``ding.policy.tests.test_dqn``.
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def calculate_priority(self, data: Dict[int, Any], update_target_model: bool = False) -> Dict[str, Any]:
+ """
+ Overview:
+ Calculate priority for replay buffer.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training.
+ - update_target_model (:obj:`bool`): Whether to update target model.
+ Returns:
+ - priority (:obj:`Dict[str, Any]`): Dict type priority data, values are python scalar or a list of scalars.
+ ArgumentsKeys:
+ - necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done``
+ - optional: ``value_gamma``
+ ReturnsKeys:
+ - necessary: ``priority``
+ """
+
+ if update_target_model:
+ self._target_model.load_state_dict(self._learn_model.state_dict())
+
+ data = default_preprocess_learn(
+ data,
+ use_priority=False,
+ use_priority_IS_weight=False,
+ ignore_done=self._cfg.learn.ignore_done,
+ use_nstep=True
+ )
+ if self._cuda:
+ data = to_device(data, self._device)
+ # ====================
+ # Q-learning forward
+ # ====================
+ self._learn_model.eval()
+ self._target_model.eval()
+ with torch.no_grad():
+ # Current q value (main model)
+ q_value = self._learn_model.forward(data['obs'])['logit']
+ # Target q value
+ target_q_value = self._target_model.forward(data['next_obs'])['logit']
+ # Max q value action (main model), i.e. Double DQN
+ target_q_action = self._learn_model.forward(data['next_obs'])['action']
+ data_n = q_nstep_td_data(
+ q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], data['weight']
+ )
+ value_gamma = data.get('value_gamma')
+ loss, td_error_per_sample = q_nstep_td_error(
+ data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma
+ )
+ return {'priority': td_error_per_sample.abs().tolist()}
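+ # Usage sketch: this helper is handy when (re)filling a buffer with offline or expert data, where
+ # priorities are needed before any gradient step has produced TD errors. With illustrative names:
+ #     priority_info = dqn_policy.calculate_priority(batch_data, update_target_model=False)
+ #     buffer.update(idx, {'priority': priority_info['priority']})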
+
+
+@POLICY_REGISTRY.register('dqn_stdim')
+class DQNSTDIMPolicy(DQNPolicy):
+ """
+ Overview:
+ Policy class of DQN algorithm, extended by ST-DIM auxiliary objectives.
+ ST-DIM paper link: https://arxiv.org/abs/1906.08226.
+ Config:
+ == ==================== ======== ============== ======================================== =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============== ======================================== =======================
+ 1 ``type`` str dqn_stdim | RL policy register name, refer to | This arg is optional,
+ | registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool False | Whether to use cuda for network | This arg can be diff-
+ | erent from modes
+ 3 ``on_policy`` bool False | Whether the RL algorithm is on-policy
+ | or off-policy
+ 4 ``priority`` bool False | Whether to use priority(PER) | Priority sample,
+ | update priority
+ 5 | ``priority_IS`` bool False | Whether to use Importance Sampling Weight
+ | ``_weight`` | to correct biased update. If True,
+ | priority must be True.
+ 6 | ``discount_`` float 0.97, | Reward's future discount factor, aka. | May be 1 when sparse
+ | ``factor`` [0.95, 0.999] | gamma | reward env
+ 7 ``nstep`` int 1, | N-step reward discount sum for target
+ [3, 5] | q_value estimation
+ 8 | ``learn.update`` int 3 | How many updates(iterations) to train | This arg can vary
+ | ``per_collect`` | after collector's one collection. Only | from envs. Bigger val
+ | valid in serial training | means more off-policy
+ 9 | ``learn.multi`` bool False | Whether to use multi gpu during training
+ | ``_gpu``
+ 10 | ``learn.batch_`` int 64 | The number of samples of an iteration
+ | ``size``
+ 11 | ``learn.learning`` float 0.001 | Gradient step length of an iteration.
+ | ``_rate``
+ 12 | ``learn.target_`` int 100 | Frequency of target network update. | Hard(assign) update
+ | ``update_freq``
+ 13 | ``learn.ignore_`` bool False | Whether ignore done for target value | Enable it for some
+ | ``done`` | calculation. | fake termination env
+ 14 ``collect.n_sample`` int [8, 128] | The number of training samples of a | It varies from
+ | call of collector. | different envs
+ 15 | ``collect.unroll`` int 1 | unroll length of an iteration | In RNN, unroll_len>1
+ | ``_len``
+ 16 | ``other.eps.type`` str exp | exploration rate decay type | Support ['exp',
+ | 'linear'].
+ 17 | ``other.eps.`` float 0.95 | start value of exploration rate | [0,1]
+ | ``start``
+ 18 | ``other.eps.`` float 0.1 | end value of exploration rate | [0,1]
+ | ``end``
+ 19 | ``other.eps.`` int 10000 | decay length of exploration | greater than 0. set
+ | ``decay`` | decay=10000 means
+ | the exploration rate
+ | decay from start
+ | value to end value
+ | during decay length.
+ 20 | ``aux_loss`` float 0.001 | the ratio of the auxiliary loss to | any real value,
+ | ``_weight`` | the TD loss | typically in
+ | [-0.1, 0.1].
+ == ==================== ======== ============== ======================================== =======================
+ """
+
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='dqn_stdim',
+ # (bool) Whether to use cuda in policy.
+ cuda=False,
+ # (bool) Whether the learning policy is the same as the data-collecting policy (on-policy).
+ on_policy=False,
+ # (bool) Whether to enable priority experience sample.
+ priority=False,
+ # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ # (float) Discount factor(gamma) for returns.
+ discount_factor=0.97,
+ # (int) The number of step for calculating target q_value.
+ nstep=1,
+ # (float) The weight of auxiliary loss to main loss.
+ aux_loss_weight=0.001,
+ # learn_mode config
+ learn=dict(
+ # How many updates(iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=3,
+ # (int) How many samples in a training batch.
+ batch_size=64,
+ # (float) The step size of gradient descent.
+ learning_rate=0.001,
+ # (int) Frequency of target network update.
+ target_update_freq=100,
+ # (bool) Whether ignore done(usually for max step termination env).
+ ignore_done=False,
+ ),
+ # collect_mode config
+ collect=dict(
+ # (int) How many training samples collected in one collection procedure.
+ # Only one of [n_sample, n_episode] should be set.
+ # n_sample=8,
+ # (int) Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ ),
+ eval=dict(), # for compatibility
+ # other config
+ other=dict(
+ # Epsilon greedy with decay.
+ eps=dict(
+ # (str) Decay type. Support ['exp', 'linear'].
+ type='exp',
+ # (float) Epsilon start value.
+ start=0.95,
+ # (float) Epsilon end value.
+ end=0.1,
+ # (int) Decay length (env step).
+ decay=10000,
+ ),
+ replay_buffer=dict(
+ # (int) Maximum size of replay buffer. Usually, larger buffer size is better.
+ replay_buffer_size=10000,
+ ),
+ ),
+ )
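+ # Note: besides the keys above, ``_init_learn`` below also reads ``cfg.aux_model`` (the kwargs forwarded
+ # to ``ContrastiveLoss``) and ``cfg.model.obs_shape`` (used by ``_get_encoding_size``), so a user config
+ # is expected to supply them; the exact ``aux_model`` fields depend on the ``ContrastiveLoss`` signature.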
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Initialize the learn mode of policy, including related attributes and modules. For DQNSTDIM, it first \
+ calls the super class's ``_init_learn`` method, then initializes the extra auxiliary model, its optimizer, \
+ and the loss weight. This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+ .. note::
+ For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+ and ``_load_state_dict_learn`` methods.
+
+ .. note::
+ For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+ .. note::
+ If you want to set some special member variables in the ``_init_learn`` method, you'd better name them \
+ with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
+ """
+ super()._init_learn()
+ x_size, y_size = self._get_encoding_size()
+ self._aux_model = ContrastiveLoss(x_size, y_size, **self._cfg.aux_model)
+ if self._cuda:
+ self._aux_model.cuda()
+ self._aux_optimizer = Adam(self._aux_model.parameters(), lr=self._cfg.learn.learning_rate)
+ self._aux_loss_weight = self._cfg.aux_loss_weight
+
+ def _get_encoding_size(self) -> Tuple[Tuple[int], Tuple[int]]:
+ """
+ Overview:
+ Get the input encoding size of the ST-DIM auxiliary model.
+ Returns:
+ - encoding_size (:obj:`Tuple[Tuple[int], Tuple[int]]`): The encoding size without the first (batch) dimension.
+ """
+ obs = self._cfg.model.obs_shape
+ if isinstance(obs, int):
+ obs = [obs]
+ test_data = {
+ "obs": torch.randn(1, *obs),
+ "next_obs": torch.randn(1, *obs),
+ }
+ if self._cuda:
+ test_data = to_device(test_data, self._device)
+ with torch.no_grad():
+ x, y = self._model_encode(test_data)
+ return x.size()[1:], y.size()[1:]
+
+ def _model_encode(self, data: dict) -> Tuple[torch.Tensor]:
+ """
+ Overview:
+ Get the encoding of the main model as input for the auxiliary model.
+ Arguments:
+ - data (:obj:`dict`): Dict type data, same as the _forward_learn input.
+ Returns:
+ - (:obj:`Tuple[torch.Tensor]`): The tuple of two tensors used for contrastive embedding learning. \
+ In the ST-DIM algorithm, these two variables are the DQN encodings of ``obs`` and ``next_obs`` respectively.
+ """
+ assert hasattr(self._model, "encoder")
+ x = self._model.encoder(data["obs"])
+ y = self._model.encoder(data["next_obs"])
+ return x, y
+
+ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Overview:
+ Policy forward function of learn mode (training policy and updating parameters). Forward means \
+ that the policy inputs some training batch data from the replay buffer and then returns the output \
+ result, including various training information such as loss, q value, priority, aux_loss.
+ Arguments:
+ - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
+ training samples. For each element in list, the key of the dict is the name of data items and the \
+ value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
+ combinations. In the ``_forward_learn`` method, data often needs to first be stacked in the batch \
+ dimension by some utility functions such as ``default_preprocess_learn``. \
+ For DQNSTDIM, each element in the list is a dict containing at least the following keys: ``obs``, \
+ ``action``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as \
+ ``weight`` and ``value_gamma``.
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will be \
+ recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
+ detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+ For data types that are not supported, the main reason is usually that the corresponding model does not \
+ support them. You can implement your own model rather than using the default one. For more information, \
+ please raise an issue in the GitHub repo and we will continue to follow up.
+ """
+ data = default_preprocess_learn(
+ data,
+ use_priority=self._priority,
+ use_priority_IS_weight=self._cfg.priority_IS_weight,
+ ignore_done=self._cfg.learn.ignore_done,
+ use_nstep=True
+ )
+ if self._cuda:
+ data = to_device(data, self._device)
+
+ # ======================
+ # Auxiliary model update
+ # ======================
+ # RL network encoding
+ # To train only the auxiliary network here, x and y are computed without gradients,
+ # so no gradient flows back into the main model.
+ with torch.no_grad():
+ x_no_grad, y_no_grad = self._model_encode(data)
+ # the forward function of the auxiliary network
+ self._aux_model.train()
+ aux_loss_learn = self._aux_model.forward(x_no_grad, y_no_grad)
+ # the BP process of the auxiliary network
+ self._aux_optimizer.zero_grad()
+ aux_loss_learn.backward()
+ if self._cfg.multi_gpu:
+ self.sync_gradients(self._aux_model)
+ self._aux_optimizer.step()
+
+ # ====================
+ # Q-learning forward
+ # ====================
+ self._learn_model.train()
+ self._target_model.train()
+ # Current q value (main model)
+ q_value = self._learn_model.forward(data['obs'])['logit']
+ # Target q value
+ with torch.no_grad():
+ target_q_value = self._target_model.forward(data['next_obs'])['logit']
+ # Max q value action (main model)
+ target_q_action = self._learn_model.forward(data['next_obs'])['action']
+
+ data_n = q_nstep_td_data(
+ q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], data['weight']
+ )
+ value_gamma = data.get('value_gamma')
+ bellman_loss, td_error_per_sample = q_nstep_td_error(
+ data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma
+ )
+
+ # ======================
+ # Compute auxiliary loss
+ # ======================
+ x, y = self._model_encode(data)
+ self._aux_model.eval()
+ aux_loss_eval = self._aux_model.forward(x, y) * self._aux_loss_weight
+ loss = aux_loss_eval + bellman_loss
+
+ # ====================
+ # Q-learning update
+ # ====================
+ self._optimizer.zero_grad()
+ loss.backward()
+ if self._cfg.multi_gpu:
+ self.sync_gradients(self._learn_model)
+ self._optimizer.step()
+
+ # =============
+ # after update
+ # =============
+ self._target_model.update(self._learn_model.state_dict())
+ return {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'bellman_loss': bellman_loss.item(),
+ 'aux_loss_learn': aux_loss_learn.item(),
+ 'aux_loss_eval': aux_loss_eval.item(),
+ 'total_loss': loss.item(),
+ 'q_value': q_value.mean().item(),
+ 'priority': td_error_per_sample.abs().tolist(),
+ # Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard.
+ # '[histogram]action_distribution': data['action'],
+ }
+
+ def _monitor_vars_learn(self) -> List[str]:
+ """
+ Overview:
+ Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
+ as text logger, tensorboard logger, will use these keys to save the corresponding data.
+ Returns:
+ - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
+ """
+ return ['cur_lr', 'bellman_loss', 'aux_loss_learn', 'aux_loss_eval', 'total_loss', 'q_value']
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ """
+ Overview:
+ Return the state_dict of learn mode, usually including model, target_model, optimizer and aux_optimizer.
+ Returns:
+ - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring.
+ """
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'target_model': self._target_model.state_dict(),
+ 'optimizer': self._optimizer.state_dict(),
+ 'aux_optimizer': self._aux_optimizer.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ """
+ Overview:
+ Load the state_dict variable into policy learn mode.
+ Arguments:
+ - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before.
+
+ .. tip::
+ If you only want to load some parts of the model, you can simply set the ``strict`` argument of \
+ load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
+ complicated operations.
+ """
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._target_model.load_state_dict(state_dict['target_model'])
+ self._optimizer.load_state_dict(state_dict['optimizer'])
+ self._aux_optimizer.load_state_dict(state_dict['aux_optimizer'])
diff --git a/DI-engine/ding/policy/dt.py b/DI-engine/ding/policy/dt.py
new file mode 100644
index 0000000000000000000000000000000000000000..005a6246440197bb544165fc641f25e5f0bb832b
--- /dev/null
+++ b/DI-engine/ding/policy/dt.py
@@ -0,0 +1,433 @@
+from typing import List, Dict, Any, Tuple, Optional
+from collections import namedtuple
+import torch.nn.functional as F
+import torch
+import numpy as np
+from ding.torch_utils import to_device
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_decollate
+from .base_policy import Policy
+
+
+@POLICY_REGISTRY.register('dt')
+class DTPolicy(Policy):
+ """
+ Overview:
+ Policy class of Decision Transformer algorithm in discrete environments.
+ Paper link: https://arxiv.org/abs/2106.01345.
+ """
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='dt',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ # (bool) Whether the RL algorithm is on-policy or off-policy.
+ on_policy=False,
+ # (bool) Whether to use priority (priority sample, IS weight, update priority).
+ priority=False,
+ # (int) The dimension of the observation space.
+ obs_shape=4,
+ action_shape=2,
+ rtg_scale=1000, # normalize returns to go
+ max_eval_ep_len=1000, # max len of one episode
+ batch_size=64, # training batch size
+ wt_decay=1e-4, # decay weight in optimizer
+ warmup_steps=10000, # steps for learning rate warmup
+ context_len=20, # length of transformer input
+ learning_rate=1e-4,
+ )
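+ # Note: the methods below also read several keys that are not given defaults here, e.g. ``rtg_target``,
+ # ``clip_grad_norm_p``, ``evaluator_env_num``, ``model.state_dim``, ``model.act_dim``, ``model.context_len``
+ # and (for non-Atari envs) ``state_mean``/``state_std``; they are expected to come from the user/experiment
+ # config.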
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+ Return this algorithm's default neural network model setting for demonstration. ``__init__`` method will \
+ automatically call this method to get the default model setting and create model.
+ Returns:
+ - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
+
+ .. note::
+ The user can define and use a customized network model, but must obey the same interface definition indicated \
+ by the import_names path. For example, for DQN, the registered name is ``dqn`` and the import_names is \
+ ``ding.model.template.q_learning``.
+ """
+ return 'dt', ['ding.model.template.dt']
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Initialize the learn mode of policy, including related attributes and modules. For Decision Transformer, \
+ it mainly contains the optimizer, algorithm-specific arguments such as rtg_scale, and the lr scheduler.
+ This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+ .. note::
+ For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+ and ``_load_state_dict_learn`` methods.
+
+ .. note::
+ For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+ .. note::
+ If you want to set some special member variables in the ``_init_learn`` method, you'd better name them \
+ with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
+ """
+ # rtg_scale: scale of `return to go`
+ # rtg_target: max target of `return to go`
+ # Our goal is to normalize `return to go` to (0, 1), which favors convergence.
+ # As a result, we usually set rtg_scale == rtg_target.
+ self.rtg_scale = self._cfg.rtg_scale # normalize returns to go
+ self.rtg_target = self._cfg.rtg_target # max target reward_to_go
+ self.max_eval_ep_len = self._cfg.max_eval_ep_len # max len of one episode
+
+ lr = self._cfg.learning_rate # learning rate
+ wt_decay = self._cfg.wt_decay # weight decay
+ warmup_steps = self._cfg.warmup_steps # warmup steps for lr scheduler
+
+ self.clip_grad_norm_p = self._cfg.clip_grad_norm_p
+ self.context_len = self._cfg.model.context_len # K in decision transformer
+
+ self.state_dim = self._cfg.model.state_dim
+ self.act_dim = self._cfg.model.act_dim
+
+ self._learn_model = self._model
+ self._atari_env = 'state_mean' not in self._cfg
+ self._basic_discrete_env = not self._cfg.model.continuous and 'state_mean' in self._cfg
+
+ if self._atari_env:
+ self._optimizer = self._learn_model.configure_optimizers(wt_decay, lr)
+ else:
+ self._optimizer = torch.optim.AdamW(self._learn_model.parameters(), lr=lr, weight_decay=wt_decay)
+
+ self._scheduler = torch.optim.lr_scheduler.LambdaLR(
+ self._optimizer, lambda steps: min((steps + 1) / warmup_steps, 1)
+ )
+
+ self.max_env_score = -1.0
+
+ def _forward_learn(self, data: List[torch.Tensor]) -> Dict[str, Any]:
+ """
+ Overview:
+ Policy forward function of learn mode (training policy and updating parameters). Forward means \
+ that the policy inputs some training batch data from the offline dataset and then returns the output \
+ result, including various training information such as loss, current learning rate.
+ Arguments:
+ - data (:obj:`List[torch.Tensor]`): The input data used for policy forward, including a series of \
+ processed torch.Tensor data, i.e., timesteps, states, actions, returns_to_go, traj_mask.
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will be \
+ recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
+ detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+ For data types that are not supported, the main reason is usually that the corresponding model does not \
+ support them. You can implement your own model rather than using the default one. For more information, \
+ please raise an issue in the GitHub repo and we will continue to follow up.
+
+ """
+ self._learn_model.train()
+
+ timesteps, states, actions, returns_to_go, traj_mask = data
+
+ # The shape of `returns_to_go` may differ across datasets (B x T or B x T x 1),
+ # and we need a 3-dim tensor
+ if len(returns_to_go.shape) == 2:
+ returns_to_go = returns_to_go.unsqueeze(-1)
+
+ if self._basic_discrete_env:
+ actions = actions.to(torch.long)
+ actions = actions.squeeze(-1)
+ action_target = torch.clone(actions).detach().to(self._device)
+
+ if self._atari_env:
+ state_preds, action_preds, return_preds = self._learn_model.forward(
+ timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go, tar=1
+ )
+ else:
+ state_preds, action_preds, return_preds = self._learn_model.forward(
+ timesteps=timesteps, states=states, actions=actions, returns_to_go=returns_to_go
+ )
+
+ if self._atari_env:
+ action_loss = F.cross_entropy(action_preds.reshape(-1, action_preds.size(-1)), action_target.reshape(-1))
+ else:
+ traj_mask = traj_mask.view(-1, )
+
+ # only consider non padded elements
+ action_preds = action_preds.view(-1, self.act_dim)[traj_mask > 0]
+
+ if self._cfg.model.continuous:
+ action_target = action_target.view(-1, self.act_dim)[traj_mask > 0]
+ action_loss = F.mse_loss(action_preds, action_target)
+ else:
+ action_target = action_target.view(-1)[traj_mask > 0]
+ action_loss = F.cross_entropy(action_preds, action_target)
+
+ self._optimizer.zero_grad()
+ action_loss.backward()
+ if self._cfg.multi_gpu:
+ self.sync_gradients(self._learn_model)
+ torch.nn.utils.clip_grad_norm_(self._learn_model.parameters(), self.clip_grad_norm_p)
+ self._optimizer.step()
+ self._scheduler.step()
+
+ return {
+ 'cur_lr': self._optimizer.state_dict()['param_groups'][0]['lr'],
+ 'action_loss': action_loss.detach().cpu().item(),
+ 'total_loss': action_loss.detach().cpu().item(),
+ }
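+ # Worked example (sketch): with the LambdaLR warmup above and warmup_steps=10000, the effective learning
+ # rate at scheduler step k is learning_rate * min((k + 1) / 10000, 1), i.e. roughly 1% of the base lr
+ # after 100 steps and the full lr from step 9999 onwards.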
+
+ def _init_eval(self) -> None:
+ """
+ Overview:
+ Initialize the eval mode of policy, including related attributes and modules. For DT, it contains the \
+ eval model and some algorithm-specific parameters such as context_len, max_eval_ep_len, etc.
+ This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
+
+ .. tip::
+ For the evaluation of complete episodes, we need to maintain some historical information for transformer \
+ inference. These variables need to be initialized in ``_init_eval`` and reset in ``_reset_eval`` when \
+ necessary.
+
+ .. note::
+ If you want to set some special member variables in the ``_init_eval`` method, you'd better name them \
+ with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
+ """
+ self._eval_model = self._model
+ # init data
+ self._device = torch.device(self._device)
+ self.rtg_scale = self._cfg.rtg_scale # normalize returns to go
+ self.rtg_target = self._cfg.rtg_target # max target reward_to_go
+ self.state_dim = self._cfg.model.state_dim
+ self.act_dim = self._cfg.model.act_dim
+ self.eval_batch_size = self._cfg.evaluator_env_num
+ self.max_eval_ep_len = self._cfg.max_eval_ep_len
+ self.context_len = self._cfg.model.context_len # K in decision transformer
+
+ self.t = [0 for _ in range(self.eval_batch_size)]
+ if self._cfg.model.continuous:
+ self.actions = torch.zeros(
+ (self.eval_batch_size, self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self._device
+ )
+ else:
+ self.actions = torch.zeros(
+ (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self._device
+ )
+ self._atari_env = 'state_mean' not in self._cfg
+ self._basic_discrete_env = not self._cfg.model.continuous and 'state_mean' in self._cfg
+ if self._atari_env:
+ self.states = torch.zeros(
+ (
+ self.eval_batch_size,
+ self.max_eval_ep_len,
+ ) + tuple(self.state_dim),
+ dtype=torch.float32,
+ device=self._device
+ )
+ self.running_rtg = [self.rtg_target for _ in range(self.eval_batch_size)]
+ else:
+ self.running_rtg = [self.rtg_target / self.rtg_scale for _ in range(self.eval_batch_size)]
+ self.states = torch.zeros(
+ (self.eval_batch_size, self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self._device
+ )
+ self.state_mean = torch.from_numpy(np.array(self._cfg.state_mean)).to(self._device)
+ self.state_std = torch.from_numpy(np.array(self._cfg.state_std)).to(self._device)
+ self.timesteps = torch.arange(
+ start=0, end=self.max_eval_ep_len, step=1
+ ).repeat(self.eval_batch_size, 1).to(self._device)
+ self.rewards_to_go = torch.zeros(
+ (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device
+ )
+
+ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
+ """
+ Overview:
+ Policy forward function of eval mode (evaluating policy performance by interacting with envs). \
+ Forward means that the policy gets some input data (current obs/return-to-go and historical information) \
+ from the envs and then returns the output data, such as the action to interact with the envs.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs and \
+ reward to calculate running return-to-go. The key of the dict is environment id and the value is the \
+ corresponding data of the env.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
+ key of the dict is the same as the input data, i.e. environment id.
+
+ .. note::
+ Decision Transformer will do different operations for different types of envs in evaluation.
+ """
+ # save and forward
+ data_id = list(data.keys())
+
+ self._eval_model.eval()
+ with torch.no_grad():
+ if self._atari_env:
+ states = torch.zeros(
+ (
+ self.eval_batch_size,
+ self.context_len,
+ ) + tuple(self.state_dim),
+ dtype=torch.float32,
+ device=self._device
+ )
+ timesteps = torch.zeros((self.eval_batch_size, 1, 1), dtype=torch.long, device=self._device)
+ else:
+ states = torch.zeros(
+ (self.eval_batch_size, self.context_len, self.state_dim), dtype=torch.float32, device=self._device
+ )
+ timesteps = torch.zeros((self.eval_batch_size, self.context_len), dtype=torch.long, device=self._device)
+ if not self._cfg.model.continuous:
+ actions = torch.zeros(
+ (self.eval_batch_size, self.context_len, 1), dtype=torch.long, device=self._device
+ )
+ else:
+ actions = torch.zeros(
+ (self.eval_batch_size, self.context_len, self.act_dim), dtype=torch.float32, device=self._device
+ )
+ rewards_to_go = torch.zeros(
+ (self.eval_batch_size, self.context_len, 1), dtype=torch.float32, device=self._device
+ )
+ for i in data_id:
+ if self._atari_env:
+ self.states[i, self.t[i]] = data[i]['obs'].to(self._device)
+ else:
+ self.states[i, self.t[i]] = (data[i]['obs'].to(self._device) - self.state_mean) / self.state_std
+ self.running_rtg[i] = self.running_rtg[i] - (data[i]['reward'] / self.rtg_scale).to(self._device)
+ self.rewards_to_go[i, self.t[i]] = self.running_rtg[i]
+
+ if self.t[i] <= self.context_len:
+ if self._atari_env:
+ timesteps[i] = min(self.t[i], self._cfg.model.max_timestep) * torch.ones(
+ (1, 1), dtype=torch.int64
+ ).to(self._device)
+ else:
+ timesteps[i] = self.timesteps[i, :self.context_len]
+ states[i] = self.states[i, :self.context_len]
+ actions[i] = self.actions[i, :self.context_len]
+ rewards_to_go[i] = self.rewards_to_go[i, :self.context_len]
+ else:
+ if self._atari_env:
+ timesteps[i] = min(self.t[i], self._cfg.model.max_timestep) * torch.ones(
+ (1, 1), dtype=torch.int64
+ ).to(self._device)
+ else:
+ timesteps[i] = self.timesteps[i, self.t[i] - self.context_len + 1:self.t[i] + 1]
+ states[i] = self.states[i, self.t[i] - self.context_len + 1:self.t[i] + 1]
+ actions[i] = self.actions[i, self.t[i] - self.context_len + 1:self.t[i] + 1]
+ rewards_to_go[i] = self.rewards_to_go[i, self.t[i] - self.context_len + 1:self.t[i] + 1]
+ if self._basic_discrete_env:
+ actions = actions.squeeze(-1)
+ _, act_preds, _ = self._eval_model.forward(timesteps, states, actions, rewards_to_go)
+ del timesteps, states, actions, rewards_to_go
+
+ logits = act_preds[:, -1, :]
+ if not self._cfg.model.continuous:
+ if self._atari_env:
+ probs = F.softmax(logits, dim=-1)
+ act = torch.zeros((self.eval_batch_size, 1), dtype=torch.long, device=self._device)
+ for i in data_id:
+ act[i] = torch.multinomial(probs[i], num_samples=1)
+ else:
+ act = torch.argmax(logits, axis=1).unsqueeze(1)
+ else:
+ act = logits
+ for i in data_id:
+ self.actions[i, self.t[i]] = act[i] # TODO: self.actions[i] should be a queue when t exceeds max_t
+ self.t[i] += 1
+
+ if self._cuda:
+ act = to_device(act, 'cpu')
+ output = {'action': act}
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
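+ # Note: the bookkeeping above follows the standard DT evaluation protocol: each env starts from
+ # rtg_target (scaled by rtg_scale for non-Atari envs), every incoming reward is subtracted from the
+ # running return-to-go, and the last ``context_len`` steps of states/actions/returns are replayed
+ # through the transformer to pick the next action.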
+
+ def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
+ """
+ Overview:
+ Reset some stateful variables for eval mode when necessary, such as the historical info of transformer \
+ for decision transformer. If ``data_id`` is None, it means to reset all the stateful \
+ variables. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \
+ different environments/episodes in evaluation in ``data_id`` will have different history.
+ Arguments:
+ - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \
+ specified by ``data_id``.
+ """
+ # clean data
+ if data_id is None:
+ self.t = [0 for _ in range(self.eval_batch_size)]
+ self.timesteps = torch.arange(
+ start=0, end=self.max_eval_ep_len, step=1
+ ).repeat(self.eval_batch_size, 1).to(self._device)
+ if not self._cfg.model.continuous:
+ self.actions = torch.zeros(
+ (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.long, device=self._device
+ )
+ else:
+ self.actions = torch.zeros(
+ (self.eval_batch_size, self.max_eval_ep_len, self.act_dim),
+ dtype=torch.float32,
+ device=self._device
+ )
+ if self._atari_env:
+ self.states = torch.zeros(
+ (
+ self.eval_batch_size,
+ self.max_eval_ep_len,
+ ) + tuple(self.state_dim),
+ dtype=torch.float32,
+ device=self._device
+ )
+ self.running_rtg = [self.rtg_target for _ in range(self.eval_batch_size)]
+ else:
+ self.states = torch.zeros(
+ (self.eval_batch_size, self.max_eval_ep_len, self.state_dim),
+ dtype=torch.float32,
+ device=self._device
+ )
+ self.running_rtg = [self.rtg_target / self.rtg_scale for _ in range(self.eval_batch_size)]
+
+ self.rewards_to_go = torch.zeros(
+ (self.eval_batch_size, self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device
+ )
+ else:
+ for i in data_id:
+ self.t[i] = 0
+ if not self._cfg.model.continuous:
+ self.actions[i] = torch.zeros((self.max_eval_ep_len, 1), dtype=torch.long, device=self._device)
+ else:
+ self.actions[i] = torch.zeros(
+ (self.max_eval_ep_len, self.act_dim), dtype=torch.float32, device=self._device
+ )
+ if self._atari_env:
+ self.states[i] = torch.zeros(
+ (self.max_eval_ep_len, ) + tuple(self.state_dim), dtype=torch.float32, device=self._device
+ )
+ self.running_rtg[i] = self.rtg_target
+ else:
+ self.states[i] = torch.zeros(
+ (self.max_eval_ep_len, self.state_dim), dtype=torch.float32, device=self._device
+ )
+ self.running_rtg[i] = self.rtg_target / self.rtg_scale
+ self.timesteps[i] = torch.arange(start=0, end=self.max_eval_ep_len, step=1).to(self._device)
+ self.rewards_to_go[i] = torch.zeros((self.max_eval_ep_len, 1), dtype=torch.float32, device=self._device)
+
+ def _monitor_vars_learn(self) -> List[str]:
+ """
+ Overview:
+ Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
+ as text logger, tensorboard logger, will use these keys to save the corresponding data.
+ Returns:
+ - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
+ """
+ return ['cur_lr', 'action_loss']
+
+ def _init_collect(self) -> None:
+ pass
+
+ def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
+ pass
+
+ def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ pass
+
+ def _process_transition(self, obs: Any, policy_output: Dict[str, Any], timestep: namedtuple) -> Dict[str, Any]:
+ pass
diff --git a/DI-engine/ding/policy/edac.py b/DI-engine/ding/policy/edac.py
new file mode 100755
index 0000000000000000000000000000000000000000..0e8d44542e13d0a30eeb38ba5d4291c426e61bdb
--- /dev/null
+++ b/DI-engine/ding/policy/edac.py
@@ -0,0 +1,299 @@
+from typing import List, Dict, Any, Tuple, Union
+import copy
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.distributions import Normal, Independent
+
+from ding.torch_utils import Adam, to_device
+from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, \
+ qrdqn_nstep_td_data, qrdqn_nstep_td_error, get_nstep_return_data
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_collate, default_decollate
+from .sac import SACPolicy
+from .dqn import DQNPolicy
+from .common_utils import default_preprocess_learn
+
+
+@POLICY_REGISTRY.register('edac')
+class EDACPolicy(SACPolicy):
+ """
+ Overview:
+ Policy class of EDAC algorithm. Paper link: https://arxiv.org/pdf/2110.01548.pdf.
+
+ Config:
+ == ==================== ======== ============= ================================= =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============= ================================= =======================
+ 1 ``type`` str edac | RL policy register name, refer | this arg is optional,
+ | to registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool True | Whether to use cuda for network |
+ 3 | ``random_`` int 10000 | Number of randomly collected | Default to 10000 for
+ | ``collect_size`` | training samples in replay | SAC, 25000 for DDPG/
+ | | buffer when training starts. | TD3.
+ 4 | ``model.policy_`` int 256 | Linear layer size for policy |
+ | ``embedding_size`` | network. |
+ 5 | ``model.soft_q_`` int 256 | Linear layer size for soft q |
+ | ``embedding_size`` | network. |
+ 6 | ``model.ensemble`` int 10 | Number of Q-ensemble network |
+ | ``_num`` | |
+ | | | is False.
+ 7 | ``learn.learning`` float 3e-4 | Learning rate for soft q | Default to 1e-3, when
+ | ``_rate_q`` | network. | model.value_network
+ | | | is True.
+ 8 | ``learn.learning`` float 3e-4 | Learning rate for policy | Default to 1e-3, when
+ | ``_rate_policy`` | network. | model.value_network
+ | | | is True.
+ 9 | ``learn.learning`` float 3e-4 | Learning rate for value | Default to None when
+ | ``_rate_value`` | network. | model.value_network
+ | | | is False.
+ 10 | ``learn.alpha`` float 1.0 | Entropy regularization | alpha is initiali-
+ | | coefficient. | zation for auto
+ | | | `alpha`, when
+ | | | auto_alpha is True
+ 11 | ``learn.eta`` float 0.1 | Parameter of EDAC algorithm | Default to 0.1
+ 12 | ``learn.`` bool True | Determine whether to use | Temperature parameter
+ | ``auto_alpha`` | auto temperature parameter | determines the
+ | | `alpha`. | relative importance
+ | | | of the entropy term
+ | | | against the reward.
+ 13 | ``learn.`` bool False | Determine whether to ignore | Use ignore_done only
+ | ``ignore_done`` | done flag. | in halfcheetah env.
+ 14 | ``learn.`` float 0.005 | Used for soft update of the | aka. Interpolation
+ | ``target_theta`` | target network. | factor in polyak aver
+ | | | aging for target
+ | | | networks.
+ == ==================== ======== ============= ================================= =======================
+ """
+ config = dict(
+ # (str) RL policy register name
+ type='edac',
+ cuda=False,
+ on_policy=False,
+ multi_agent=False,
+ priority=False,
+ priority_IS_weight=False,
+ random_collect_size=10000,
+ model=dict(
+ # (int) ensemble_num: the number of Q networks in the ensemble.
+ ensemble_num=10,
+ # (bool) value_network: Determine whether to use a value network, as in the
+ # original EDAC paper (arXiv 2110.01548).
+ # using value_network needs to set learning_rate_value, learning_rate_q,
+ # and learning_rate_policy in `cfg.policy.learn`.
+ # Default to False.
+ # value_network=False,
+
+ # (int) Hidden size for actor network head.
+ actor_head_hidden_size=256,
+
+ # (int) Hidden size for critic network head.
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ multi_gpu=False,
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_q=3e-4,
+ learning_rate_policy=3e-4,
+ learning_rate_value=3e-4,
+ learning_rate_alpha=3e-4,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=1,
+ auto_alpha=True,
+ # (bool type) log_space: Determine whether to use auto `\alpha` in log space.
+ log_space=True,
+ # (bool) Whether to ignore done (usually for max step termination envs, e.g. pendulum).
+ # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
+ # These limit HalfCheetah, and several other MuJoCo envs, to a max episode length of 1000.
+ # However, such episodes only end because of the time limit rather than a true termination,
+ # so when the episode step exceeds the max episode step we replace done==True with done==False
+ # to keep the TD-error computation accurate (``gamma * (1 - done) * next_v + reward``).
+ ignore_done=False,
+ # (float) Weight uniform initialization range in the last output layer
+ init_w=3e-3,
+ # (float) Loss weight for conservative item.
+ min_q_weight=1.0,
+ # (bool) Whether to use entropy in target q.
+ with_q_entropy=False,
+ eta=0.1,
+ ),
+ collect=dict(
+ # (int) Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ ),
+ eval=dict(),
+ other=dict(
+ replay_buffer=dict(
+ # (int type) replay_buffer_size: Max size of replay buffer.
+ replay_buffer_size=1000000,
+ # (int type) max_use: Max use times of one data in the buffer.
+ # Data will be removed once it has been used too many times.
+ # Default to infinite.
+ # max_use=256,
+ ),
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+ Return this algorithm's default neural network model setting for demonstration. ``__init__`` method will \
+ automatically call this method to get the default model setting and create model.
+ Returns:
+ - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
+ """
+ return 'edac', ['ding.model.template.edac']
+
+ def _init_learn(self) -> None:
+ r"""
+ Overview:
+ Learn mode init method. Called by ``self.__init__``.
+ Init q, value and policy's optimizers, algorithm config, main and target models.
+ """
+ super()._init_learn()
+ # EDAC special implementation
+ self._eta = self._cfg.learn.eta
+ self._with_q_entropy = self._cfg.learn.with_q_entropy
+ self._forward_learn_cnt = 0
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ loss_dict = {}
+ data = default_preprocess_learn(
+ data,
+ use_priority=self._priority,
+ use_priority_IS_weight=self._cfg.priority_IS_weight,
+ ignore_done=self._cfg.learn.ignore_done,
+ use_nstep=False
+ )
+ if len(data.get('action').shape) == 1:
+ data['action'] = data['action'].reshape(-1, 1)
+
+ if self._cuda:
+ data = to_device(data, self._device)
+
+ self._learn_model.train()
+ self._target_model.train()
+ obs = data['obs']
+ next_obs = data['next_obs']
+ reward = data['reward']
+ done = data['done']
+ acs = data['action']
+
+ # predict q value
+ q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']
+ with torch.no_grad():
+ (mu, sigma) = self._learn_model.forward(next_obs, mode='compute_actor')['logit']
+
+ dist = Independent(Normal(mu, sigma), 1)
+ pred = dist.rsample()
+ next_action = torch.tanh(pred)
+ y = 1 - next_action.pow(2) + 1e-6
+ next_log_prob = dist.log_prob(pred).unsqueeze(-1)
+ next_log_prob = next_log_prob - torch.log(y).sum(-1, keepdim=True)
+
+ next_data = {'obs': next_obs, 'action': next_action}
+ target_q_value = self._target_model.forward(next_data, mode='compute_critic')['q_value']
+ # the value of a policy according to the maximum entropy objective
+
+ target_q_value, _ = torch.min(target_q_value, dim=0)
+ if self._with_q_entropy:
+ target_q_value -= self._alpha * next_log_prob.squeeze(-1)
+ target_q_value = self._gamma * (1 - done) * target_q_value + reward
+
+ weight = data['weight']
+ if weight is None:
+ weight = torch.ones_like(q_value)
+ td_error_per_sample = nn.MSELoss(reduction='none')(q_value, target_q_value).mean(dim=1).sum()
+ loss_dict['critic_loss'] = (td_error_per_sample * weight).mean()
+
+ # penalty term of EDAC
+ if self._eta > 0:
+ # [batch_size,dim] -> [Ensemble_num,batch_size,dim]
+ pre_obs = obs.unsqueeze(0).repeat_interleave(self._cfg.model.ensemble_num, dim=0)
+ pre_acs = acs.unsqueeze(0).repeat_interleave(self._cfg.model.ensemble_num, dim=0).requires_grad_(True)
+
+ # [Ensemble_num,batch_size]
+ q_pred_tile = self._learn_model.forward({
+ 'obs': pre_obs,
+ 'action': pre_acs
+ }, mode='compute_critic')['q_value'].requires_grad_(True)
+
+ q_pred_grads = torch.autograd.grad(q_pred_tile.sum(), pre_acs, retain_graph=True, create_graph=True)[0]
+ q_pred_grads = q_pred_grads / (torch.norm(q_pred_grads, p=2, dim=2).unsqueeze(-1) + 1e-10)
+ # [Ensemble_num,batch_size,act_dim] -> [batch_size,Ensemble_num,act_dim]
+ q_pred_grads = q_pred_grads.transpose(0, 1)
+
+ q_pred_grads = q_pred_grads @ q_pred_grads.permute(0, 2, 1)
+ masks = torch.eye(
+ self._cfg.model.ensemble_num, device=obs.device
+ ).unsqueeze(dim=0).repeat(q_pred_grads.size(0), 1, 1)
+ q_pred_grads = (1 - masks) * q_pred_grads
+ grad_loss = torch.mean(torch.sum(q_pred_grads, dim=(1, 2))) / (self._cfg.model.ensemble_num - 1)
+ loss_dict['critic_loss'] += grad_loss * self._eta
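+ # Note: the block above is the ensemble-diversity regularizer of EDAC: it takes the gradient of each
+ # ensemble member's Q w.r.t. the action, normalizes it, and penalizes the pairwise inner products
+ # (the identity mask removes the diagonal terms), which pushes the members' local Q landscapes apart.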
+
+ self._optimizer_q.zero_grad()
+ loss_dict['critic_loss'].backward()
+ self._optimizer_q.step()
+
+ (mu, sigma) = self._learn_model.forward(data['obs'], mode='compute_actor')['logit']
+ dist = Independent(Normal(mu, sigma), 1)
+ pred = dist.rsample()
+ action = torch.tanh(pred)
+ y = 1 - action.pow(2) + 1e-6
+ log_prob = dist.log_prob(pred).unsqueeze(-1)
+ log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
+
+ eval_data = {'obs': obs, 'action': action}
+ new_q_value = self._learn_model.forward(eval_data, mode='compute_critic')['q_value']
+ new_q_value, _ = torch.min(new_q_value, dim=0)
+
+ # compute policy loss
+ policy_loss = (self._alpha * log_prob - new_q_value.unsqueeze(-1)).mean()
+
+ loss_dict['policy_loss'] = policy_loss
+
+ # update policy network
+ self._optimizer_policy.zero_grad()
+ loss_dict['policy_loss'].backward()
+ self._optimizer_policy.step()
+
+ # compute alpha loss
+ if self._auto_alpha:
+ if self._log_space:
+ log_prob = log_prob + self._target_entropy
+ loss_dict['alpha_loss'] = -(self._log_alpha * log_prob.detach()).mean()
+
+ self._alpha_optim.zero_grad()
+ loss_dict['alpha_loss'].backward()
+ self._alpha_optim.step()
+ self._alpha = self._log_alpha.detach().exp()
+ else:
+ log_prob = log_prob + self._target_entropy
+ loss_dict['alpha_loss'] = -(self._alpha * log_prob.detach()).mean()
+
+ self._alpha_optim.zero_grad()
+ loss_dict['alpha_loss'].backward()
+ self._alpha_optim.step()
+ self._alpha = max(0, self._alpha)
+
+ loss_dict['total_loss'] = sum(loss_dict.values())
+
+ # =============
+ # after update
+ # =============
+ self._forward_learn_cnt += 1
+ # target update
+ self._target_model.update(self._learn_model.state_dict())
+ return {
+ 'cur_lr_q': self._optimizer_q.defaults['lr'],
+ 'cur_lr_p': self._optimizer_policy.defaults['lr'],
+ 'priority': td_error_per_sample.abs().tolist(),
+ 'td_error': td_error_per_sample.detach().mean().item(),
+ 'alpha': self._alpha.item(),
+ 'target_q_value': target_q_value.detach().mean().item(),
+ **loss_dict
+ }
diff --git a/DI-engine/ding/policy/fqf.py b/DI-engine/ding/policy/fqf.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1ba86fd91ae45dc2c526b4865e6cd12bb7b01a8
--- /dev/null
+++ b/DI-engine/ding/policy/fqf.py
@@ -0,0 +1,260 @@
+from typing import List, Dict, Any, Tuple, Union
+import copy
+import torch
+
+from ding.torch_utils import Adam, RMSprop, to_device
+from ding.rl_utils import fqf_nstep_td_data, fqf_nstep_td_error, fqf_calculate_fraction_loss, \
+ get_train_sample, get_nstep_return_data
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_collate, default_decollate
+from .dqn import DQNPolicy
+from .common_utils import default_preprocess_learn
+
+
+@POLICY_REGISTRY.register('fqf')
+class FQFPolicy(DQNPolicy):
+ r"""
+ Overview:
+ Policy class of FQF algorithm.
+
+ Config:
+ == ==================== ======== ============== ======================================== =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============== ======================================== =======================
+ 1 ``type`` str fqf | RL policy register name, refer to | this arg is optional,
+ | registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool False | Whether to use cuda for network | this arg can be diff-
+ | erent from modes
+ 3 ``on_policy`` bool False | Whether the RL algorithm is on-policy
+ | or off-policy
+ 4 ``priority`` bool True | Whether to use priority(PER) | priority sample,
+ | update priority
+ 6 | ``other.eps`` float 0.95 | Start value for epsilon decay.
+ | ``.start``
+ 7 | ``other.eps`` float 0.1 | End value for epsilon decay.
+ | ``.end``
+ 8 | ``discount_`` float 0.97, | Reward's future discount factor, aka. | may be 1 when sparse
+ | ``factor`` [0.95, 0.999] | gamma | reward env
+ 9 ``nstep`` int 3, | N-step reward discount sum for target
+ [3, 5] | q_value estimation
+ 10 | ``learn.update`` int 3 | How many updates(iterations) to train | this arg can vary
+ | ``per_collect`` | after collector's one collection. Only | from envs. Bigger val
+ | valid in serial training | means more off-policy
+ 11 ``learn.kappa`` float / | Threshold of Huber loss
+ == ==================== ======== ============== ======================================== =======================
+ """
+
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='fqf',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ # (bool) Whether the RL algorithm is on-policy or off-policy.
+ on_policy=False,
+ # (bool) Whether to use priority (priority sample, IS weight, update priority).
+ priority=False,
+ # (float) Reward's future discount factor, aka. gamma.
+ discount_factor=0.97,
+ # (int) N-step reward for target q_value estimation
+ nstep=1,
+ learn=dict(
+
+ # How many updates(iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=3,
+ batch_size=64,
+ learning_rate_fraction=2.5e-9,
+ learning_rate_quantile=0.00005,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (int) Frequency of target network update.
+ target_update_freq=100,
+ # (float) Threshold of Huber loss. In the FQF paper, this is denoted by kappa. Default to 1.0.
+ kappa=1.0,
+ # (float) Coefficient of entropy_loss.
+ ent_coef=0,
+ # (bool) Whether ignore done(usually for max step termination env)
+ ignore_done=False,
+ ),
+ # collect_mode config
+ collect=dict(
+ # (int) Only one of [n_sample, n_step, n_episode] should be set
+ # n_sample=8,
+ # (int) Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ ),
+ eval=dict(),
+ # other config
+ other=dict(
+ # Epsilon greedy with decay.
+ eps=dict(
+ # (str) Decay type. Support ['exp', 'linear'].
+ type='exp',
+ start=0.95,
+ end=0.1,
+ # (int) Decay length (env step)
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=10000, )
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
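+ """
+ Overview:
+ Return this algorithm's default neural network model setting: the registered model name ``fqf`` \
+ and the import path ``ding.model.template.q_learning``.
+ """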
+ return 'fqf', ['ding.model.template.q_learning']
+
+ def _init_learn(self) -> None:
+ r"""
+ Overview:
+ Learn mode init method. Called by ``self.__init__``.
+ Init the optimizer, algorithm config, main and target models.
+ """
+ self._priority = self._cfg.priority
+ # Optimizer
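+ # Two separate optimizers are used, as in the original FQF setup: RMSprop with a very small learning
+ # rate for the quantile-fraction proposal network, and Adam for the quantile head, fqf_fc and encoder.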
+ self._fraction_loss_optimizer = RMSprop(
+ self._model.head.quantiles_proposal.parameters(),
+ lr=self._cfg.learn.learning_rate_fraction,
+ alpha=0.95,
+ eps=0.00001
+ )
+ self._quantile_loss_optimizer = Adam(
+ list(self._model.head.Q.parameters()) + list(self._model.head.fqf_fc.parameters()) +
+ list(self._model.encoder.parameters()),
+ lr=self._cfg.learn.learning_rate_quantile,
+ eps=1e-2 / self._cfg.learn.batch_size
+ )
+
+ self._gamma = self._cfg.discount_factor
+ self._nstep = self._cfg.nstep
+ self._kappa = self._cfg.learn.kappa
+ self._ent_coef = self._cfg.learn.ent_coef
+
+ # use model_wrapper for specialized demands of different modes
+ self._target_model = copy.deepcopy(self._model)
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='assign',
+ update_kwargs={'freq': self._cfg.learn.target_update_freq}
+ )
+ self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Forward and backward function of learn mode.
+ Arguments:
+ - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs']
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss.
+ """
+ data = default_preprocess_learn(
+ data, use_priority=self._priority, ignore_done=self._cfg.learn.ignore_done, use_nstep=True
+ )
+ if self._cuda:
+ data = to_device(data, self._device)
+ # ====================
+ # Q-learning forward
+ # ====================
+ self._learn_model.train()
+ self._target_model.train()
+ # Current q value (main model)
+ ret = self._learn_model.forward(data['obs'])
+ logit = ret['logit'] # [batch, action_dim(64)]
+ q_value = ret['q'] # [batch, num_quantiles, action_dim(64)]
+ quantiles = ret['quantiles'] # [batch, num_quantiles+1]
+ quantiles_hats = ret['quantiles_hats'] # [batch, num_quantiles], requires_grad = False
+ q_tau_i = ret['q_tau_i'] # [batch_size, num_quantiles-1, action_dim(64)]
+ entropies = ret['entropies'] # [batch, 1]
+
+ # Target q value
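+ # Double-DQN-style target: the greedy action is selected by the online (learn) model below, while its
+ # value distribution is taken from the target model.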
+ with torch.no_grad():
+ target_q_value = self._target_model.forward(data['next_obs'])['q']
+ # Max q value action (main model)
+ target_q_action = self._learn_model.forward(data['next_obs'])['action']
+
+ data_n = fqf_nstep_td_data(
+ q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], quantiles_hats,
+ data['weight']
+ )
+ value_gamma = data.get('value_gamma')
+
+ entropy_loss = -self._ent_coef * entropies.mean()
+
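+ # Fraction loss (FQF paper): this term only trains the quantile-fraction proposal network, so the
+ # quantile values q_tau_i are detached; the -ent_coef * entropy term encourages diverse fractions.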
+ fraction_loss = fqf_calculate_fraction_loss(q_tau_i.detach(), q_value, quantiles, data['action']) + entropy_loss
+
+ quantile_loss, td_error_per_sample = fqf_nstep_td_error(
+ data_n, self._gamma, nstep=self._nstep, kappa=self._kappa, value_gamma=value_gamma
+ )
+
+ # compute grad norm of a network's parameters
+ def compute_grad_norm(model):
+ return torch.norm(torch.stack([torch.norm(p.grad.detach(), 2.0) for p in model.parameters()]), 2.0)
+
+ # ====================
+ # fraction_proposal network update
+ # ====================
+ self._fraction_loss_optimizer.zero_grad()
+ fraction_loss.backward(retain_graph=True)
+ if self._cfg.multi_gpu:
+ self.sync_gradients(self._learn_model)
+ with torch.no_grad():
+ total_norm_quantiles_proposal = compute_grad_norm(self._model.head.quantiles_proposal)
+ self._fraction_loss_optimizer.step()
+
+ # ====================
+ # Q-learning update
+ # ====================
+ self._quantile_loss_optimizer.zero_grad()
+ quantile_loss.backward()
+ if self._cfg.multi_gpu:
+ self.sync_gradients(self._learn_model)
+ with torch.no_grad():
+ total_norm_Q = compute_grad_norm(self._model.head.Q)
+ total_norm_fqf_fc = compute_grad_norm(self._model.head.fqf_fc)
+ total_norm_encoder = compute_grad_norm(self._model.encoder)
+ self._quantile_loss_optimizer.step()
+
+ # =============
+ # after update
+ # =============
+ self._target_model.update(self._learn_model.state_dict())
+ return {
+ 'cur_lr_fraction_loss': self._fraction_loss_optimizer.defaults['lr'],
+ 'cur_lr_quantile_loss': self._quantile_loss_optimizer.defaults['lr'],
+ 'logit': logit.mean().item(),
+ 'fraction_loss': fraction_loss.item(),
+ 'quantile_loss': quantile_loss.item(),
+ 'total_norm_quantiles_proposal': total_norm_quantiles_proposal,
+ 'total_norm_Q': total_norm_Q,
+ 'total_norm_fqf_fc': total_norm_fqf_fc,
+ 'total_norm_encoder': total_norm_encoder,
+ 'priority': td_error_per_sample.abs().tolist(),
+ # Only discrete actions satisfying len(data['action']) == 1 can return this and be drawn as a histogram on tensorboard.
+ '[histogram]action_distribution': data['action'],
+ '[histogram]quantiles_hats': quantiles_hats[0], # quantiles_hats.requires_grad = False
+ }
+
+ def _monitor_vars_learn(self) -> List[str]:
+ return [
+ 'cur_lr_fraction_loss', 'cur_lr_quantile_loss', 'logit', 'fraction_loss', 'quantile_loss',
+ 'total_norm_quantiles_proposal', 'total_norm_Q', 'total_norm_fqf_fc', 'total_norm_encoder'
+ ]
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'target_model': self._target_model.state_dict(),
+ 'optimizer_fraction_loss': self._fraction_loss_optimizer.state_dict(),
+ 'optimizer_quantile_loss': self._quantile_loss_optimizer.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._target_model.load_state_dict(state_dict['target_model'])
+ self._fraction_loss_optimizer.load_state_dict(state_dict['optimizer_fraction_loss'])
+ self._quantile_loss_optimizer.load_state_dict(state_dict['optimizer_quantile_loss'])
diff --git a/DI-engine/ding/policy/happo.py b/DI-engine/ding/policy/happo.py
new file mode 100644
index 0000000000000000000000000000000000000000..4cbd38324bbae0b750c1839962cf845c71903634
--- /dev/null
+++ b/DI-engine/ding/policy/happo.py
@@ -0,0 +1,734 @@
+from typing import List, Dict, Any, Tuple, Union
+from collections import namedtuple
+import torch
+import copy
+import numpy as np
+from torch.distributions import Independent, Normal
+
+from ding.torch_utils import Adam, to_device, to_dtype, unsqueeze, ContrastiveLoss
+from ding.rl_utils import happo_data, happo_error, happo_policy_error, happo_policy_data, \
+ v_nstep_td_data, v_nstep_td_error, get_train_sample, gae, gae_data, happo_error_continuous, \
+ get_gae
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY, split_data_generator, RunningMeanStd
+from ding.utils.data import default_collate, default_decollate
+from .base_policy import Policy
+from .common_utils import default_preprocess_learn
+
+
+@POLICY_REGISTRY.register('happo')
+class HAPPOPolicy(Policy):
+ """
+ Overview:
+ Policy class of on policy version HAPPO algorithm. Paper link: https://arxiv.org/abs/2109.11251.
+ """
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='happo',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ # (bool) Whether the RL algorithm is on-policy or off-policy. (Note: in practice PPO can be used off-policy.)
+ on_policy=True,
+ # (bool) Whether to use priority(priority sample, IS weight, update priority)
+ priority=False,
+ # (bool) Whether to use Importance Sampling Weight to correct biased update due to priority.
+ # If True, priority must be True.
+ priority_IS_weight=False,
+ # (bool) Whether to recompute advantages in each iteration of on-policy PPO
+ recompute_adv=True,
+ # (str) Which kind of action space used in PPOPolicy, ['discrete', 'continuous', 'hybrid']
+ action_space='discrete',
+ # (bool) Whether to use nstep return to calculate value target, otherwise, use return = adv + value
+ nstep_return=False,
+ # (bool) Whether to enable multi-agent training, i.e.: MAPPO
+ multi_agent=False,
+ # (bool) Whether to need policy data in process transition
+ transition_with_policy_data=True,
+ learn=dict(
+ epoch_per_collect=10,
+ batch_size=64,
+ learning_rate=3e-4,
+ # ==============================================================
+ # The following configs is algorithm-specific
+ # ==============================================================
+ # (float) The loss weight of value network, policy network weight is set to 1
+ value_weight=0.5,
+ # (float) The loss weight of entropy regularization, policy network weight is set to 1
+ entropy_weight=0.0,
+ # (float) PPO clip ratio, defaults to 0.2
+ clip_ratio=0.2,
+ # (bool) Whether to use advantage norm in a whole training batch
+ adv_norm=True,
+ value_norm=True,
+ ppo_param_init=True,
+ grad_clip_type='clip_norm',
+ grad_clip_value=0.5,
+ ignore_done=False,
+ ),
+ collect=dict(
+ # (int) Only one of [n_sample, n_episode] should be set
+ # n_sample=64,
+ # (int) Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ # ==============================================================
+ # The following configs is algorithm-specific
+ # ==============================================================
+ # (float) Reward's future discount factor, aka. gamma.
+ discount_factor=0.99,
+ # (float) GAE lambda factor for the balance of bias and variance(1-step td and mc)
+ gae_lambda=0.95,
+ ),
+ eval=dict(),
+ )
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Initialize the learn mode of policy, including related attributes and modules. For HAPPO, it mainly \
+ contains optimizer, algorithm-specific arguments such as loss weight, clip_ratio and recompute_adv. This \
+ method also executes some special network initializations and prepares running mean/std monitor for value.
+ This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+ .. note::
+ For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+ and ``_load_state_dict_learn`` methods.
+
+ .. note::
+ For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+ .. note::
+ If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+ with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ assert not self._priority and not self._priority_IS_weight, "Priority is not implemented in PPO"
+
+ assert self._cfg.action_space in ["continuous", "discrete"]
+ self._action_space = self._cfg.action_space
+ if self._cfg.learn.ppo_param_init:
+ for n, m in self._model.named_modules():
+ if isinstance(m, torch.nn.Linear):
+ torch.nn.init.orthogonal_(m.weight)
+ torch.nn.init.zeros_(m.bias)
+ if self._action_space in ['continuous']:
+ # init log sigma
+ for agent_id in range(self._cfg.agent_num):
+ # if hasattr(self._model.agent_models[agent_id].actor_head, 'log_sigma_param'):
+ # torch.nn.init.constant_(self._model.agent_models[agent_id].actor_head.log_sigma_param, 1)
+ # The above initialization step has been changed to reparameterizationHead.
+ for m in list(self._model.agent_models[agent_id].critic.modules()) + \
+ list(self._model.agent_models[agent_id].actor.modules()):
+ if isinstance(m, torch.nn.Linear):
+ # orthogonal initialization
+ torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
+ torch.nn.init.zeros_(m.bias)
+ # do last policy layer scaling, this will make initial actions have (close to)
+ # 0 mean and std, and will help boost performances,
+ # see https://arxiv.org/abs/2006.05990, Fig.24 for details
+ for m in self._model.agent_models[agent_id].actor.modules():
+ if isinstance(m, torch.nn.Linear):
+ torch.nn.init.zeros_(m.bias)
+ m.weight.data.copy_(0.01 * m.weight.data)
+
+ # Add the actor/critic parameters of each HAVACAgent in HAVAC to the parameter list of actor/critic_optimizer
+ actor_params = []
+ critic_params = []
+ for agent_idx in range(self._model.agent_num):
+ actor_params.append({'params': self._model.agent_models[agent_idx].actor.parameters()})
+ critic_params.append({'params': self._model.agent_models[agent_idx].critic.parameters()})
+
+ self._actor_optimizer = Adam(
+ actor_params,
+ lr=self._cfg.learn.learning_rate,
+ grad_clip_type=self._cfg.learn.grad_clip_type,
+ clip_value=self._cfg.learn.grad_clip_value,
+ # eps = 1e-5,
+ )
+
+ self._critic_optimizer = Adam(
+ critic_params,
+ lr=self._cfg.learn.critic_learning_rate,
+ grad_clip_type=self._cfg.learn.grad_clip_type,
+ clip_value=self._cfg.learn.grad_clip_value,
+ # eps = 1e-5,
+ )
+
+ self._learn_model = model_wrap(self._model, wrapper_name='base')
+ # self._learn_model = model_wrap(
+ # self._model,
+ # wrapper_name='hidden_state',
+ # state_num=self._cfg.learn.batch_size,
+ # init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
+ # )
+
+ # Algorithm config
+ self._value_weight = self._cfg.learn.value_weight
+ self._entropy_weight = self._cfg.learn.entropy_weight
+ self._clip_ratio = self._cfg.learn.clip_ratio
+ self._adv_norm = self._cfg.learn.adv_norm
+ self._value_norm = self._cfg.learn.value_norm
+ if self._value_norm:
+ self._running_mean_std = RunningMeanStd(epsilon=1e-4, device=self._device)
+ self._gamma = self._cfg.collect.discount_factor
+ self._gae_lambda = self._cfg.collect.gae_lambda
+ self._recompute_adv = self._cfg.recompute_adv
+ # Main model
+ self._learn_model.reset()
+
+ def prepocess_data_agent(self, data: Dict[str, Any]):
+ """
+ Overview:
+ Preprocess data along the agent dimension. This function is used in learn mode. \
+ It is called recursively to process nested dict data and transposes every tensor \
+ with shape (B, agent_num, ...) to (agent_num, B, ...).
+ Arguments:
+ - data (:obj:`dict`): Dict type data whose tensor values are batch-first, i.e. with shape (B, agent_num, ...).
+ Returns:
+ - ret (:obj:`dict`): Dict type data with the same structure, whose tensor values are agent-first, \
+ i.e. with shape (agent_num, B, ...).
+ """
+ ret = {}
+ for key, value in data.items():
+ if isinstance(value, dict):
+ ret[key] = self.prepocess_data_agent(value)
+ elif isinstance(value, torch.Tensor) and len(value.shape) > 1:
+ ret[key] = value.transpose(0, 1)
+ else:
+ ret[key] = value
+ return ret
+
+ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Overview:
+ Policy forward function of learn mode (training policy and updating parameters). Forward means \
+ that the policy inputs some training batch data from the replay buffer and then returns the output \
+ result, including various training information such as loss, clipfrac, approx_kl.
+ Arguments:
+ - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including the latest \
+ collected training samples for on-policy algorithms like HAPPO. For each element in list, the key of \
+ dict is the name of data items and the value is the corresponding data. Usually, the value is \
+ torch.Tensor or np.ndarray or their dict/list combinations. In the ``_forward_learn`` method, data \
+ often needs to first be stacked in the batch dimension by some utility functions such as \
+ ``default_preprocess_learn``. \
+ For HAPPO, each element in list is a dict containing at least the following keys: ``obs``, \
+ ``action``, ``reward``, ``logit``, ``value``, ``done``. Sometimes, it also contains other keys \
+ such as ``weight``.
+ Returns:
+ - return_infos (:obj:`List[Dict[str, Any]]`): The information list that indicates the training result; each \
+ training iteration appends an information dict to the final list. The list will be processed \
+ and recorded in text log and tensorboard. The value of the dict must be python scalar or a list of \
+ scalars. For the detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
+
+ .. tip::
+ The training procedure of HAPPO consists of three nested for-loops. The outermost loop trains each agent \
+ separately. The middle loop trains on all the collected training samples for ``epoch_per_collect`` epochs. \
+ The inner loop splits the data into mini-batches of length ``batch_size``.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+ For data types that are not supported, the main reason is that the corresponding model does not support them. \
+ You can implement your own model rather than use the default model. For more information, please raise an \
+ issue in the GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for HAPPOPolicy: ``ding.policy.tests.test_happo``.
+ """
+ data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=False)
+ all_data_len = data['obs']['agent_state'].shape[0]
+ # factor is the ratio of the new and old policies of the previously updated (first m-1) agents, initialized to 1.
+ # Each transition has its own factor. ref: http://arxiv.org/abs/2109.11251
+ factor = torch.ones(all_data_len, 1) # (B, 1)
+ if self._cuda:
+ data = to_device(data, self._device)
+ factor = to_device(factor, self._device)
+ # process agent dim
+ data = self.prepocess_data_agent(data)
+ # ====================
+ # PPO forward
+ # ====================
+ return_infos = []
+ self._learn_model.train()
+
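+ # Agents are updated sequentially. The HAPPO paper samples a random agent permutation each iteration;
+ # here agents are updated in index order, and each agent's surrogate objective is weighted by the
+ # factor accumulated from the agents updated before it.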
+ for agent_id in range(self._cfg.agent_num):
+ agent_data = {}
+ for key, value in data.items():
+ if value is not None:
+ if type(value) is dict:
+ agent_data[key] = {k: v[agent_id] for k, v in value.items()} # not feasible for rnn
+ elif len(value.shape) > 1:
+ agent_data[key] = data[key][agent_id]
+ else:
+ agent_data[key] = data[key]
+ else:
+ agent_data[key] = data[key]
+
+ # update factor
+ agent_data['factor'] = factor
+ # calculate old_logits of all data in buffer for later factor
+ inputs = {
+ 'obs': agent_data['obs'],
+ # 'actor_prev_state': agent_data['actor_prev_state'],
+ # 'critic_prev_state': agent_data['critic_prev_state'],
+ }
+ old_logits = self._learn_model.forward(agent_id, inputs, mode='compute_actor')['logit']
+
+ for epoch in range(self._cfg.learn.epoch_per_collect):
+ if self._recompute_adv: # calculate new value using the new updated value network
+ with torch.no_grad():
+ inputs['obs'] = agent_data['obs']
+ # value = self._learn_model.forward(agent_id, agent_data['obs'], mode='compute_critic')['value']
+ value = self._learn_model.forward(agent_id, inputs, mode='compute_critic')['value']
+ inputs['obs'] = agent_data['next_obs']
+ next_value = self._learn_model.forward(agent_id, inputs, mode='compute_critic')['value']
+ if self._value_norm:
+ value *= self._running_mean_std.std
+ next_value *= self._running_mean_std.std
+
+ traj_flag = agent_data.get('traj_flag', None) # traj_flag indicates termination of trajectory
+ compute_adv_data = gae_data(
+ value, next_value, agent_data['reward'], agent_data['done'], traj_flag
+ )
+ agent_data['adv'] = gae(compute_adv_data, self._gamma, self._gae_lambda)
+
+ unnormalized_returns = value + agent_data['adv']
+
+ if self._value_norm:
+ agent_data['value'] = value / self._running_mean_std.std
+ agent_data['return'] = unnormalized_returns / self._running_mean_std.std
+ self._running_mean_std.update(unnormalized_returns.cpu().numpy())
+ else:
+ agent_data['value'] = value
+ agent_data['return'] = unnormalized_returns
+
+ else: # don't recompute adv
+ if self._value_norm:
+ unnormalized_return = agent_data['adv'] + agent_data['value'] * self._running_mean_std.std
+ agent_data['return'] = unnormalized_return / self._running_mean_std.std
+ self._running_mean_std.update(unnormalized_return.cpu().numpy())
+ else:
+ agent_data['return'] = agent_data['adv'] + agent_data['value']
+
+ for batch in split_data_generator(agent_data, self._cfg.learn.batch_size, shuffle=True):
+ inputs = {
+ 'obs': batch['obs'],
+ # 'actor_prev_state': batch['actor_prev_state'],
+ # 'critic_prev_state': batch['critic_prev_state'],
+ }
+ output = self._learn_model.forward(agent_id, inputs, mode='compute_actor_critic')
+ adv = batch['adv']
+ if self._adv_norm:
+ # Normalize advantage in a train_batch
+ adv = (adv - adv.mean()) / (adv.std() + 1e-8)
+
+ # Calculate happo error
+ if self._action_space == 'continuous':
+ happo_batch = happo_data(
+ output['logit'], batch['logit'], batch['action'], output['value'], batch['value'], adv,
+ batch['return'], batch['weight'], batch['factor']
+ )
+ happo_loss, happo_info = happo_error_continuous(happo_batch, self._clip_ratio)
+ elif self._action_space == 'discrete':
+ happo_batch = happo_data(
+ output['logit'], batch['logit'], batch['action'], output['value'], batch['value'], adv,
+ batch['return'], batch['weight'], batch['factor']
+ )
+ happo_loss, happo_info = happo_error(happo_batch, self._clip_ratio)
+ wv, we = self._value_weight, self._entropy_weight
+ total_loss = happo_loss.policy_loss + wv * happo_loss.value_loss - we * happo_loss.entropy_loss
+
+ # actor update
+ # critic update
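+ # The combined loss is back-propagated once and both optimizers step; with the disjoint actor/critic
+ # parameter groups built in _init_learn, policy/entropy gradients update the actor and the value loss
+ # updates the critic.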
+ self._actor_optimizer.zero_grad()
+ self._critic_optimizer.zero_grad()
+ total_loss.backward()
+ self._actor_optimizer.step()
+ self._critic_optimizer.step()
+
+ return_info = {
+ 'agent{}_cur_lr'.format(agent_id): self._actor_optimizer.defaults['lr'],
+ 'agent{}_total_loss'.format(agent_id): total_loss.item(),
+ 'agent{}_policy_loss'.format(agent_id): happo_loss.policy_loss.item(),
+ 'agent{}_value_loss'.format(agent_id): happo_loss.value_loss.item(),
+ 'agent{}_entropy_loss'.format(agent_id): happo_loss.entropy_loss.item(),
+ 'agent{}_adv_max'.format(agent_id): adv.max().item(),
+ 'agent{}_adv_mean'.format(agent_id): adv.mean().item(),
+ 'agent{}_value_mean'.format(agent_id): output['value'].mean().item(),
+ 'agent{}_value_max'.format(agent_id): output['value'].max().item(),
+ 'agent{}_approx_kl'.format(agent_id): happo_info.approx_kl,
+ 'agent{}_clipfrac'.format(agent_id): happo_info.clipfrac,
+ }
+ if self._action_space == 'continuous':
+ return_info.update(
+ {
+ 'agent{}_act'.format(agent_id): batch['action'].float().mean().item(),
+ 'agent{}_mu_mean'.format(agent_id): output['logit']['mu'].mean().item(),
+ 'agent{}_sigma_mean'.format(agent_id): output['logit']['sigma'].mean().item(),
+ }
+ )
+ return_infos.append(return_info)
+ # calculate the factor
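+ # After this agent's update, the factor is multiplied by prod(pi_new(a|s) / pi_old(a|s)) over action
+ # dimensions, so agents updated later optimize a surrogate weighted by the policy ratios of the agents
+ # already updated (the sequential update scheme of HAPPO).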
+ inputs = {
+ 'obs': agent_data['obs'],
+ # 'actor_prev_state': agent_data['actor_prev_state'],
+ }
+ new_logits = self._learn_model.forward(agent_id, inputs, mode='compute_actor')['logit']
+ if self._cfg.action_space == 'discrete':
+ dist_new = torch.distributions.categorical.Categorical(logits=new_logits)
+ dist_old = torch.distributions.categorical.Categorical(logits=old_logits)
+ elif self._cfg.action_space == 'continuous':
+ dist_new = Normal(new_logits['mu'], new_logits['sigma'])
+ dist_old = Normal(old_logits['mu'], old_logits['sigma'])
+ logp_new = dist_new.log_prob(agent_data['action'])
+ logp_old = dist_old.log_prob(agent_data['action'])
+ if len(logp_new.shape) > 1:
+ # for logp with shape(B, action_shape), we need to calculate the product of all action dimensions.
+ factor = factor * torch.prod(
+ torch.exp(logp_new - logp_old), dim=-1
+ ).reshape(all_data_len, 1).detach() # attention the shape
+ else:
+ # for logp with shape(B, ), directly calculate factor
+ factor = factor * torch.exp(logp_new - logp_old).reshape(all_data_len, 1).detach()
+ return return_infos
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ """
+ Overview:
+ Return the state_dict of learn mode optimizer and model.
+ Returns:
+ - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn mode. It contains the \
+ state_dict of current policy network and optimizer.
+ """
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'actor_optimizer': self._actor_optimizer.state_dict(),
+ 'critic_optimizer': self._critic_optimizer.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ """
+ Overview:
+ Load the state_dict of learn mode optimizer and model.
+ Arguments:
+ - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn mode. It contains the state_dict \
+ of current policy network and optimizer.
+ """
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._actor_optimizer.load_state_dict(state_dict['actor_optimizer'])
+ self._critic_optimizer.load_state_dict(state_dict['critic_optimizer'])
+
+ def _init_collect(self) -> None:
+ """
+ Overview:
+ Initialize the collect mode of policy, including related attributes and modules. For HAPPO, it contains \
+ the collect_model to balance the exploration and exploitation (e.g. the multinomial sample mechanism in \
+ discrete action space), and other algorithm-specific arguments such as unroll_len and gae_lambda.
+ This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.
+
+ .. note::
+ If you want to set some special member variables in ``_init_collect`` method, you'd better name them \
+ with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``.
+
+ .. tip::
+ Some variables need to initialize independently in different modes, such as gamma and gae_lambda in PPO. \
+ This design is for the convenience of parallel execution of different policy modes.
+ """
+ self._unroll_len = self._cfg.collect.unroll_len
+ assert self._cfg.action_space in ["continuous", "discrete"]
+ self._action_space = self._cfg.action_space
+ if self._action_space == 'continuous':
+ self._collect_model = model_wrap(self._model, wrapper_name='reparam_sample')
+ elif self._action_space == 'discrete':
+ self._collect_model = model_wrap(self._model, wrapper_name='multinomial_sample')
+ self._collect_model.reset()
+ self._gamma = self._cfg.collect.discount_factor
+ self._gae_lambda = self._cfg.collect.gae_lambda
+ self._recompute_adv = self._cfg.recompute_adv
+
+ def _forward_collect(self, data: Dict[int, Any]) -> dict:
+ """
+ Overview:
+ Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \
+ that the policy gets some necessary data (mainly observation) from the envs and then returns the output \
+ data, such as the action to interact with the envs.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
+ other necessary data (action logit and value) for learn mode defined in ``self._process_transition`` \
+ method. The key of the dict is the same as the input data, i.e. environment id.
+
+ .. tip::
+ If you want to add more tricks on this policy, like temperature factor in multinomial sample, you can pass \
+ related data as extra keyword arguments of this method.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+ For data types that are not supported, the main reason is that the corresponding model does not support them. \
+ You can implement your own model rather than use the default model. For more information, please raise an \
+ issue in the GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for HAPPOPolicy: ``ding.policy.tests.test_happo``.
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
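+ # transpose data from (B, M, N) -> (M, B, N) so that each agent's batch can be indexed by agent_id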
+ data = {k: v.transpose(0, 1) for k, v in data.items()} # not feasible for rnn
+ self._collect_model.eval()
+ with torch.no_grad():
+ outputs = []
+ for agent_id in range(self._cfg.agent_num):
+ # output = self._collect_model.forward(agent_id, data, mode='compute_actor_critic')
+ single_agent_obs = {k: v[agent_id] for k, v in data.items()}
+ input = {
+ 'obs': single_agent_obs,
+ }
+ output = self._collect_model.forward(agent_id, input, mode='compute_actor_critic')
+ outputs.append(output)
+ # transfer data from (M, B, N)->(B, M, N)
+ result = {}
+ for key in outputs[0].keys():
+ if isinstance(outputs[0][key], dict):
+ subkeys = outputs[0][key].keys()
+ stacked_subvalues = {}
+ for subkey in subkeys:
+ stacked_subvalues[subkey] = \
+ torch.stack([output[key][subkey] for output in outputs], dim=0).transpose(0, 1)
+ result[key] = stacked_subvalues
+ else:
+ # If Value is tensor, stack it directly
+ if isinstance(outputs[0][key], torch.Tensor):
+ result[key] = torch.stack([output[key] for output in outputs], dim=0).transpose(0, 1)
+ else:
+ # If it is not tensor, assume that it is a non-stackable data type \
+ # (such as int, float, etc.), and directly retain the original value
+ result[key] = [output[key] for output in outputs]
+ output = result
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
+ """
+ Overview:
+ Process and pack one timestep transition data into a dict, which can be directly used for training and \
+ saved in replay buffer. For HAPPO, it contains obs, next_obs, action, reward, done, logit, value.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): The env observation of current timestep.
+ - model_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \
+ as input. For HAPPO, it contains the state value, action and the logit of the action.
+ - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \
+ except all the elements have been transformed into tensor data. Usually, it contains the next obs, \
+ reward, done, info, etc.
+ Returns:
+ - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep.
+
+ .. note::
+ ``next_obs`` is used to calculate the nstep return when necessary, so we place it into the transition by default. \
+ You can delete this field to save memory occupancy if you do not need nstep return.
+ """
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'action': model_output['action'],
+ 'logit': model_output['logit'],
+ 'value': model_output['value'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
+ """
+ Overview:
+ For a given trajectory (a list of transitions), process it into a list of samples that \
+ can be used for training directly. In HAPPO, a train sample is a processed transition with newly computed \
+ ``traj_flag`` and ``adv`` fields. This method is usually used in collectors to execute necessary \
+ RL data preprocessing before training, which can help the learner amortize the relevant time consumption. \
+ In addition, you can also implement this method as an identity function and do the data processing \
+ in ``self._forward_learn`` method.
+ Arguments:
+ - transitions (:obj:`List[Dict[str, Any]]`): The trajectory data (a list of transitions); each element is \
+ in the same format as the return value of the ``self._process_transition`` method.
+ Returns:
+ - samples (:obj:`List[Dict[str, Any]]`): The processed train samples; each element is in a similar format \
+ to the input transitions, but may contain more data for training, such as the GAE advantage.
+ """
+ data = to_device(data, self._device)
+ for transition in data:
+ transition['traj_flag'] = copy.deepcopy(transition['done'])
+ data[-1]['traj_flag'] = True
+
+ if self._cfg.learn.ignore_done:
+ data[-1]['done'] = False
+
+ if data[-1]['done']:
+ last_value = torch.zeros_like(data[-1]['value'])
+ else:
+ with torch.no_grad():
+ last_values = []
+ for agent_id in range(self._cfg.agent_num):
+ inputs = {'obs': {k: unsqueeze(v[agent_id], 0) for k, v in data[-1]['next_obs'].items()}}
+ last_value = self._collect_model.forward(agent_id, inputs, mode='compute_actor_critic')['value']
+ last_values.append(last_value)
+ last_value = torch.cat(last_values)
+ if len(last_value.shape) == 2: # multi_agent case:
+ last_value = last_value.squeeze(0)
+ if self._value_norm:
+ last_value *= self._running_mean_std.std
+ for i in range(len(data)):
+ data[i]['value'] *= self._running_mean_std.std
+ data = get_gae(
+ data,
+ to_device(last_value, self._device),
+ gamma=self._gamma,
+ gae_lambda=self._gae_lambda,
+ cuda=False,
+ )
+ if self._value_norm:
+ for i in range(len(data)):
+ data[i]['value'] /= self._running_mean_std.std
+
+ # remove next_obs for save memory when not recompute adv
+ if not self._recompute_adv:
+ for i in range(len(data)):
+ data[i].pop('next_obs')
+ return get_train_sample(data, self._unroll_len)
+
+ def _init_eval(self) -> None:
+ """
+ Overview:
+ Initialize the eval mode of policy, including related attributes and modules. For HAPPO, it contains the \
+ eval model to select the optimal action (e.g. greedily select the action with the argmax mechanism in \
+ discrete action space).
+ This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
+
+ .. note::
+ If you want to set some special member variables in ``_init_eval`` method, you'd better name them \
+ with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
+ """
+ assert self._cfg.action_space in ["continuous", "discrete"]
+ self._action_space = self._cfg.action_space
+ if self._action_space == 'continuous':
+ self._eval_model = model_wrap(self._model, wrapper_name='deterministic_sample')
+ elif self._action_space == 'discrete':
+ self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: dict) -> dict:
+ """
+ Overview:
+ Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \
+ means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
+ action to interact with the envs. ``_forward_eval`` in HAPPO often uses deterministic sample method to \
+ get actions while ``_forward_collect`` usually uses stochastic sample method for balance exploration and \
+ exploitation.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
+ key of the dict is the same as the input data, i.e. environment id.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+ For data types that are not supported, the main reason is that the corresponding model does not support them. \
+ You can implement your own model rather than use the default model. For more information, please raise an \
+ issue in the GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for HAPPOPolicy: ``ding.policy.tests.test_happo``.
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ # transfer data from (B, M, N)->(M, B, N)
+ data = {k: v.transpose(0, 1) for k, v in data.items()} # not feasible for rnn
+ self._eval_model.eval()
+ with torch.no_grad():
+ outputs = []
+ for agent_id in range(self._cfg.agent_num):
+ single_agent_obs = {k: v[agent_id] for k, v in data.items()}
+ input = {
+ 'obs': single_agent_obs,
+ }
+ output = self._eval_model.forward(agent_id, input, mode='compute_actor')
+ outputs.append(output)
+ output = self.revert_agent_data(outputs)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+ Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \
+ automatically call this method to get the default model setting and create model.
+ Returns:
+ - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
+
+ .. note::
+ The user can define and use a customized network model but must obey the same interface definition indicated \
+ by the import_names path. For example, for HAPPO the registered model name is ``havac`` and the import_names is \
+ ``ding.model.template.havac``.
+ """
+ return 'havac', ['ding.model.template.havac']
+
+ def _monitor_vars_learn(self) -> List[str]:
+ """
+ Overview:
+ Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
+ as text logger, tensorboard logger, will use these keys to save the corresponding data.
+ Returns:
+ - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
+ """
+ variables = super()._monitor_vars_learn() + [
+ 'policy_loss',
+ 'value_loss',
+ 'entropy_loss',
+ 'adv_max',
+ 'adv_mean',
+ 'approx_kl',
+ 'clipfrac',
+ 'value_max',
+ 'value_mean',
+ ]
+ if self._action_space == 'continuous':
+ variables += ['mu_mean', 'sigma_mean', 'sigma_grad', 'act']
+ prefixes = [f'agent{i}_' for i in range(self._cfg.agent_num)]
+ variables = [prefix + var for prefix in prefixes for var in variables]
+ return variables
+
+ def revert_agent_data(self, data: list):
+ """
+ Overview:
+ Revert the data of each agent to the original data format.
+ Arguments:
+ - data (:obj:`list`): List type data, where each element is the data of an agent of dict type.
+ Returns:
+ - ret (:obj:`dict`): Dict type data, where each element is the data of an agent of dict type.
+ """
+ ret = {}
+ # Traverse all keys of the first output
+ for key in data[0].keys():
+ if isinstance(data[0][key], torch.Tensor):
+ # If the value corresponding to the current key is tensor, stack N tensors
+ stacked_tensor = torch.stack([output[key] for output in data], dim=0)
+ ret[key] = stacked_tensor.transpose(0, 1)
+ elif isinstance(data[0][key], dict):
+ # If the value corresponding to the current key is a dictionary, recursively \
+ # call the function to process the contents inside the dictionary.
+ ret[key] = self.revert_agent_data([output[key] for output in data])
+ return ret
diff --git a/DI-engine/ding/policy/ibc.py b/DI-engine/ding/policy/ibc.py
new file mode 100644
index 0000000000000000000000000000000000000000..b39e14f53aad834fc47b487044aa12a91e74b7cc
--- /dev/null
+++ b/DI-engine/ding/policy/ibc.py
@@ -0,0 +1,186 @@
+from typing import Dict, Any, List, Tuple
+from collections import namedtuple
+from easydict import EasyDict
+
+import torch
+import torch.nn.functional as F
+
+from ding.model import model_wrap
+from ding.torch_utils import to_device
+from ding.utils.data import default_collate, default_decollate
+from ding.utils import POLICY_REGISTRY
+from .bc import BehaviourCloningPolicy
+from ding.model.template.ebm import create_stochastic_optimizer
+from ding.model.template.ebm import StochasticOptimizer, MCMC, AutoRegressiveDFO
+from ding.torch_utils import unsqueeze_repeat
+from ding.utils import EasyTimer
+
+
+@POLICY_REGISTRY.register('ibc')
+class IBCPolicy(BehaviourCloningPolicy):
+ r"""
+ Overview:
+ Implicit Behavior Cloning
+ https://arxiv.org/abs/2109.00137.pdf
+
+ .. note::
+ The code is adapted from the pytorch version of IBC https://github.com/kevinzakka/ibc,
+ which only supports the derivative-free optimization (dfo) variants.
+ This implementation moves a step forward and supports all variants of energy-based model
+ mentioned in the paper (dfo, autoregressive dfo, and mcmc).
+ """
+
+ config = dict(
+ type='ibc',
+ cuda=False,
+ on_policy=False,
+ continuous=True,
+ model=dict(stochastic_optim=dict(type='mcmc', )),
+ learn=dict(
+ train_epoch=30,
+ batch_size=256,
+ optim=dict(
+ learning_rate=1e-5,
+ weight_decay=0.0,
+ beta1=0.9,
+ beta2=0.999,
+ ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=10000, )),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ return 'ebm', ['ding.model.template.ebm']
+
+ def _init_learn(self):
+ self._timer = EasyTimer(cuda=self._cfg.cuda)
+ self._sync_timer = EasyTimer(cuda=self._cfg.cuda)
+ optim_cfg = self._cfg.learn.optim
+ self._optimizer = torch.optim.AdamW(
+ self._model.parameters(),
+ lr=optim_cfg.learning_rate,
+ weight_decay=optim_cfg.weight_decay,
+ betas=(optim_cfg.beta1, optim_cfg.beta2),
+ )
+ self._stochastic_optimizer: StochasticOptimizer = \
+ create_stochastic_optimizer(self._device, self._cfg.model.stochastic_optim)
+ self._learn_model = model_wrap(self._model, 'base')
+ self._learn_model.reset()
+
+ def _forward_learn(self, data):
+ with self._timer:
+ data = default_collate(data)
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._learn_model.train()
+
+ loss_dict = dict()
+
+ # obs: (B, O)
+ # action: (B, A)
+ obs, action = data['obs'], data['action']
+ # When the action/observation dimension is 1, it has already been squeezed upstream, so unsqueeze it
+ # here to make the data compatible with the IBC pipeline.
+ if len(obs.shape) == 1:
+ obs = obs.unsqueeze(-1)
+ if len(action.shape) == 1:
+ action = action.unsqueeze(-1)
+
+ # N refers to the number of negative samples, i.e. self._stochastic_optimizer.inference_samples.
+ # (B, N, O), (B, N, A)
+ obs, negatives = self._stochastic_optimizer.sample(obs, self._learn_model)
+
+ # (B, N+1, A)
+ targets = torch.cat([action.unsqueeze(dim=1), negatives], dim=1)
+ # (B, N+1, O)
+ obs = torch.cat([obs[:, :1], obs], dim=1)
+
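+ # Shuffle the ground-truth action among the negative samples and remember its new index, so that
+ # training the EBM reduces to an (N+1)-way classification problem over the candidate actions.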
+ permutation = torch.rand(targets.shape[0], targets.shape[1]).argsort(dim=1)
+ targets = targets[torch.arange(targets.shape[0]).unsqueeze(-1), permutation]
+
+ # (B, )
+ ground_truth = (permutation == 0).nonzero()[:, 1].to(self._device)
+
+ # (B, N+1) for ebm
+ # (B, N+1, A) for autoregressive ebm
+ energy = self._learn_model.forward(obs, targets)
+
+ logits = -1.0 * energy
+ if isinstance(self._stochastic_optimizer, AutoRegressiveDFO):
+ # autoregressive case
+ # (B, A)
+ ground_truth = unsqueeze_repeat(ground_truth, logits.shape[-1], -1)
+ loss = F.cross_entropy(logits, ground_truth)
+ loss_dict['ebm_loss'] = loss.item()
+
+ if isinstance(self._stochastic_optimizer, MCMC):
+ grad_penalty = self._stochastic_optimizer.grad_penalty(obs, targets, self._learn_model)
+ loss += grad_penalty
+ loss_dict['grad_penalty'] = grad_penalty.item()
+ loss_dict['total_loss'] = loss.item()
+
+ self._optimizer.zero_grad()
+ loss.backward()
+ with self._sync_timer:
+ if self._cfg.multi_gpu:
+ self.sync_gradients(self._learn_model)
+ sync_time = self._sync_timer.value
+ self._optimizer.step()
+
+ total_time = self._timer.value
+
+ return {
+ 'total_time': total_time,
+ 'sync_time': sync_time,
+ **loss_dict,
+ }
+
+ def _monitor_vars_learn(self):
+ if isinstance(self._stochastic_optimizer, MCMC):
+ return ['total_loss', 'ebm_loss', 'grad_penalty', 'total_time', 'sync_time']
+ else:
+ return ['total_loss', 'ebm_loss', 'total_time', 'sync_time']
+
+ def _init_eval(self):
+ self._eval_model = model_wrap(self._model, wrapper_name='base')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: dict) -> dict:
+ tensor_input = isinstance(data, torch.Tensor)
+ if not tensor_input:
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+
+ if self._cuda:
+ data = to_device(data, self._device)
+
+ self._eval_model.eval()
+ output = self._stochastic_optimizer.infer(data, self._eval_model)
+ output = dict(action=output)
+
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ if tensor_input:
+ return output
+ else:
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def set_statistic(self, statistics: EasyDict) -> None:
+ self._stochastic_optimizer.set_action_bounds(statistics.action_bounds)
+
+ # =================================================================== #
+ # Implicit Behavioral Cloning does not need `collect`-related functions
+ # =================================================================== #
+ def _init_collect(self):
+ raise NotImplementedError
+
+ def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
+ raise NotImplementedError
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
+ raise NotImplementedError
+
+ def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ raise NotImplementedError
diff --git a/DI-engine/ding/policy/il.py b/DI-engine/ding/policy/il.py
new file mode 100644
index 0000000000000000000000000000000000000000..77989facecaabb4557fa1fb043f23435b1689631
--- /dev/null
+++ b/DI-engine/ding/policy/il.py
@@ -0,0 +1,231 @@
+from typing import List, Dict, Any, Tuple, Union
+from collections import namedtuple
+import torch
+import torch.nn as nn
+
+from ding.torch_utils import Adam, to_device
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_collate, default_decollate
+from .base_policy import Policy
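+ # Placeholder for the external expert model used in collect mode (see ``_init_collect``); it must be
+ # replaced with a real expert model class (e.g. the gfootball Kaggle 5th-place model) before collecting data.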
+FootballKaggle5thPlaceModel = None
+
+
+@POLICY_REGISTRY.register('IL')
+class ILPolicy(Policy):
+ r"""
+ Overview:
+ Policy class of Imitation learning algorithm
+ Interface:
+ __init__, set_setting, __repr__, state_dict_handle
+ Property:
+ learn_mode, collect_mode, eval_mode
+ """
+ config = dict(
+ type='IL',
+ cuda=True,
+ # (bool) Whether to use the on-policy training pipeline (behaviour policy and training policy are the same)
+ on_policy=False,
+ priority=False,
+ # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ learn=dict(
+
+ # (int) collect n_episode data, train model n_iteration time
+ update_per_collect=20,
+ # (int) the number of data for a train iteration
+ batch_size=64,
+ # (float) gradient-descent step size
+ learning_rate=0.0002,
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration time
+ # n_sample=128,
+ # (float) Discount factor for future reward, defaults in [0, 1]
+ discount_factor=0.99,
+ ),
+ eval=dict(evaluator=dict(eval_freq=800, ), ),
+ other=dict(
+ replay_buffer=dict(
+ replay_buffer_size=100000,
+ # (int) max use count of data, if count is bigger than this value,
+ # the data will be removed from buffer
+ max_reuse=10,
+ ),
+ command=dict(),
+ ),
+ )
+
+ # TODO different collect model and learn model
+ def default_model(self) -> Tuple[str, List[str]]:
+ return 'football_iql', ['dizoo.gfootball.model.iql.iql_network']
+
+ def _init_learn(self) -> None:
+ r"""
+ Overview:
+ Learn mode init method. Called by ``self.__init__``.
+ Init the optimizer and the main (learn) model.
+ """
+ # optimizer for the whole model
+ self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
+
+ # main and target models
+ self._learn_model = model_wrap(self._model, wrapper_name='base')
+ self._learn_model.train()
+ self._learn_model.reset()
+
+ self._forward_learn_cnt = 0 # count iterations
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Forward and backward function of learn mode.
+ Arguments:
+ - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs']
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Including at least the current lr and the supervised loss.
+ """
+ data = default_collate(data, cat_1dim=False)
+ data['done'] = None
+ if self._cuda:
+ data = to_device(data, self._device)
+ loss_dict = {}
+ # ====================
+ # imitation learn forward
+ # ====================
+ obs = data.get('obs')
+ logit = data.get('logit')
+ assert isinstance(obs['processed_obs'], torch.Tensor), obs['processed_obs']
+ model_action_logit = self._learn_model.forward(obs['processed_obs'])['logit']
+ supervised_loss = nn.MSELoss(reduction='none')(model_action_logit, logit).mean()
+ self._optimizer.zero_grad()
+ supervised_loss.backward()
+ self._optimizer.step()
+ loss_dict['supervised_loss'] = supervised_loss
+ return {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ **loss_dict,
+ }
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'optimizer': self._optimizer.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._optimizer.load_state_dict(state_dict['optimizer'])
+
+ def _init_collect(self) -> None:
+ r"""
+ Overview:
+ Collect mode init method. Called by ``self.__init__``.
+ Init traj and unroll length, collect model.
+ """
+ self._collect_model = model_wrap(FootballKaggle5thPlaceModel(), wrapper_name='base')
+ self._gamma = self._cfg.collect.discount_factor
+ self._collect_model.eval()
+ self._collect_model.reset()
+
+ def _forward_collect(self, data: dict) -> dict:
+ r"""
+ Overview:
+ Forward function of collect mode.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs.
+ ReturnsKeys
+ - necessary: ``action``
+ - optional: ``logit``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ with torch.no_grad():
+ output = self._collect_model.forward(default_decollate(data['obs']['raw_obs']))
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Generate dict type transition data from inputs.
+ Arguments:
+ - obs (:obj:`Any`): Env observation
+ - model_output (:obj:`dict`): Output of collect model, including at least ['action']
+ - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
+ (here 'obs' indicates obs after env step, i.e. next_obs).
+ Return:
+ - transition (:obj:`Dict[str, Any]`): Dict type transition data.
+ """
+ transition = {
+ 'obs': obs,
+ 'action': model_output['action'],
+ 'logit': model_output['logit'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _get_train_sample(self, origin_data: list) -> Union[None, List[Any]]:
+ datas = []
+ pre_rew = 0
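+ # traverse the trajectory in reverse so that pre_rew accumulates the discounted return:
+ # pre_rew = r_t + gamma * pre_rew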
+ for i in range(len(origin_data) - 1, -1, -1):
+ data = {}
+ data['obs'] = origin_data[i]['obs']
+ data['action'] = origin_data[i]['action']
+ cur_rew = origin_data[i]['reward']
+ pre_rew = cur_rew + (pre_rew * self._gamma)
+ # sample uniformly
+ data['priority'] = 1
+ data['logit'] = origin_data[i]['logit']
+ datas.append(data)
+ return datas
+
+ def _init_eval(self) -> None:
+ r"""
+ Overview:
+ Evaluate mode init method. Called by ``self.__init__``.
+ Init eval model. Unlike learn and collect model, eval model does not need noise.
+ """
+ self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: dict) -> dict:
+ r"""
+ Overview:
+ Forward function of eval mode, similar to ``self._forward_collect``.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
+ ReturnsKeys
+ - necessary: ``action``
+ - optional: ``logit``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ with torch.no_grad():
+ output = self._eval_model.forward(data['obs']['processed_obs'])
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _monitor_vars_learn(self) -> List[str]:
+ r"""
+ Overview:
+ Return the names of the variables to be monitored.
+ Returns:
+ - vars (:obj:`List[str]`): Variables' name list.
+ """
+ return ['cur_lr', 'supervised_loss']
diff --git a/DI-engine/ding/policy/impala.py b/DI-engine/ding/policy/impala.py
new file mode 100644
index 0000000000000000000000000000000000000000..46adeb1204dd68b00d0cf628a3767f0f2d1ed00a
--- /dev/null
+++ b/DI-engine/ding/policy/impala.py
@@ -0,0 +1,490 @@
+from collections import namedtuple
+from typing import List, Dict, Any, Tuple
+
+import torch
+import treetensor.torch as ttorch
+
+from ding.model import model_wrap
+from ding.rl_utils import vtrace_data, vtrace_error_discrete_action, vtrace_error_continuous_action, get_train_sample
+from ding.torch_utils import Adam, RMSprop, to_device
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_collate, default_decollate, ttorch_collate
+from ding.policy.base_policy import Policy
+
+
+@POLICY_REGISTRY.register('impala')
+class IMPALAPolicy(Policy):
+ """
+ Overview:
+ Policy class of IMPALA algorithm. Paper link: https://arxiv.org/abs/1802.01561.
+
+ Config:
+ == ==================== ======== ============== ======================================== =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============== ======================================== =======================
+ 1 ``type`` str impala | RL policy register name, refer to | this arg is optional,
+ | registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool False | Whether to use cuda for network | this arg can be diff-
+ | erent from modes
+ 3 ``on_policy`` bool False | Whether the RL algorithm is on-policy
+ | or off-policy
+ 4 ``priority`` bool False | Whether to use priority (PER) | priority sample,
+ | update priority
+
+ 5 | ``priority_`` bool False | Whether to use Importance Sampling Weight | If True, priority
+ | ``IS_weight`` | | must be True
+ 6 ``unroll_len`` int 32 | trajectory length to calculate v-trace
+ | target
+ 7 | ``learn.update`` int 4 | How many updates(iterations) to train | this arg can vary
+ | ``per_collect`` | after collector's one collection. Only | from envs. Bigger val
+ | valid in serial training | means more off-policy
+ == ==================== ======== ============== ======================================== =======================
+ """
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='impala',
+ # (bool) Whether to use cuda in policy.
+ cuda=False,
+ # (bool) Whether learning policy is the same as collecting data policy(on-policy).
+ on_policy=False,
+ # (bool) Whether to enable priority experience sample.
+ priority=False,
+ # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ # (str) Which kind of action space used in IMPALAPolicy, ['discrete', 'continuous'].
+ action_space='discrete',
+ # (int) the trajectory length to calculate v-trace target.
+ unroll_len=32,
+ # (bool) Whether to need policy data in process transition.
+ transition_with_policy_data=True,
+ # learn_mode config
+ learn=dict(
+ # (int) collect n_sample data, train model update_per_collect times.
+ update_per_collect=4,
+ # (int) the number of data for a train iteration.
+ batch_size=16,
+ # (float) The step size of gradient descent.
+ learning_rate=0.0005,
+ # (float) loss weight of the value network, the weight of policy network is set to 1.
+ value_weight=0.5,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1.
+ entropy_weight=0.0001,
+ # (float) Discount factor for future reward, defaults in [0, 1].
+ discount_factor=0.99,
+ # (float) additional discounting parameter.
+ lambda_=0.95,
+ # (float) Upper clip threshold (rho_bar) of the importance weights used in the v-trace value target.
+ rho_clip_ratio=1.0,
+ # (float) Upper clip threshold (c_bar) of the importance weights used as trace-cutting coefficients.
+ c_clip_ratio=1.0,
+ # (float) Upper clip threshold of the importance weights used in the policy-gradient (advantage) term.
+ rho_pg_clip_ratio=1.0,
+ # (str) The gradient clip operation type used in IMPALA, ['clip_norm', 'clip_value', 'clip_momentum_norm'].
+ grad_clip_type=None,
+ # (float) The gradient clip target value used in IMPALA.
+ # If ``grad_clip_type`` is 'clip_norm', then the maximum of gradient will be normalized to this value.
+ clip_value=0.5,
+ # (str) Optimizer used to train the network, ['adam', 'rmsprop'].
+ optim='adam',
+ ),
+ # collect_mode config
+ collect=dict(
+ # (int) How many training samples collected in one collection procedure.
+ # Only one of [n_sample, n_episode] should be set.
+ # n_sample=16,
+ ),
+ eval=dict(), # for compatibility
+ other=dict(
+ replay_buffer=dict(
+ # (int) Maximum size of replay buffer. Usually, larger buffer size is better.
+ replay_buffer_size=1000,
+ # (int) Maximum number of uses for a sample in the buffer. When a sample reaches this value, it is removed.
+ max_use=16,
+ ),
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+ Return this algorithm's default neural network model setting for demonstration. The ``__init__`` method will \
+ automatically call this method to get the default model setting and create model.
+ Returns:
+ - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
+
+ .. note::
+ The user can define and use a customized network model, but must obey the same interface definition indicated \
+ by the import_names path. For example, for IMPALA, its registered name is ``vac`` and the import_names is \
+ ``ding.model.template.vac``.
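+
+ Examples:
+ >>> # illustrative: ``policy`` is assumed to be an initialized IMPALAPolicy instance
+ >>> policy.default_model()
+ ('vac', ['ding.model.template.vac'])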
+ """
+ return 'vac', ['ding.model.template.vac']
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Initialize the learn mode of policy, including related attributes and modules. For IMPALA, it mainly \
+ contains optimizer, algorithm-specific arguments such as loss weight and gamma, main (learn) model.
+ This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+ .. note::
+ For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+ and ``_load_state_dict_learn`` methods.
+
+ .. note::
+ For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+ .. note::
+ If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+ with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
+ """
+ assert self._cfg.action_space in ["continuous", "discrete"], self._cfg.action_space
+ self._action_space = self._cfg.action_space
+ # Optimizer
+ optim_type = self._cfg.learn.optim
+ if optim_type == 'rmsprop':
+ self._optimizer = RMSprop(self._model.parameters(), lr=self._cfg.learn.learning_rate)
+ elif optim_type == 'adam':
+ self._optimizer = Adam(
+ self._model.parameters(),
+ grad_clip_type=self._cfg.learn.grad_clip_type,
+ clip_value=self._cfg.learn.clip_value,
+ lr=self._cfg.learn.learning_rate
+ )
+ else:
+ raise NotImplementedError("Now only support rmsprop and adam, but input is {}".format(optim_type))
+ self._learn_model = model_wrap(self._model, wrapper_name='base')
+
+ self._action_shape = self._cfg.model.action_shape
+ self._unroll_len = self._cfg.unroll_len
+
+ # Algorithm config
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ self._value_weight = self._cfg.learn.value_weight
+ self._entropy_weight = self._cfg.learn.entropy_weight
+ self._gamma = self._cfg.learn.discount_factor
+ self._lambda = self._cfg.learn.lambda_
+ self._rho_clip_ratio = self._cfg.learn.rho_clip_ratio
+ self._c_clip_ratio = self._cfg.learn.c_clip_ratio
+ self._rho_pg_clip_ratio = self._cfg.learn.rho_pg_clip_ratio
+
+ # Main model
+ self._learn_model.reset()
+
+ def _data_preprocess_learn(self, data: List[Dict[str, Any]]):
+ """
+ Overview:
+ Data preprocess function of learn mode.
+ Convert the list of trajectory data into trajectory data, which is a dict of stacked tensors.
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]]`): List type data, a list of data for training. Each list element is a \
+ dict, whose values are torch.Tensor or np.ndarray or dict/list combinations, keys include at least \
+ 'obs', 'next_obs', 'logit', 'action', 'reward', 'done'
+ Returns:
+ - data (:obj:`dict`): Dict type data. Values are torch.Tensor or np.ndarray or dict/list combinations. \
+ ReturnsKeys:
+ - necessary: 'logit', 'action', 'reward', 'done', 'weight', 'obs_plus_1'.
+ - optional and not used in later computation: 'obs', 'next_obs', 'IS', 'collect_iter', 'replay_unique_id', \
+ 'replay_buffer_idx', 'priority', 'staleness', 'use'.
+ ReturnsShapes:
+ - obs_plus_1 (:obj:`torch.FloatTensor`): :math:`(T+1, B, obs_shape)`, where T is timestep, B is batch size \
+ and obs_shape is the shape of single env observation
+ - logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where N is action dim
+ - action (:obj:`torch.LongTensor`): :math:`(T, B)`
+ - reward (:obj:`torch.FloatTensor`): :math:`(T+1, B)`
+ - done (:obj:`torch.FloatTensor`): :math:`(T, B)`
+ - weight (:obj:`torch.FloatTensor`): :math:`(T, B)`
+ """
+ elem = data[0]
+ if isinstance(elem, dict): # old pipeline
+ data = default_collate(data)
+ elif isinstance(elem, list): # new task pipeline
+ data = default_collate(default_collate(data))
+ else:
+ raise TypeError("not support element type ({}) in IMPALA".format(type(elem)))
+ if self._cuda:
+ data = to_device(data, self._device)
+ if self._priority_IS_weight:
+ assert self._priority, "Use IS Weight correction, but Priority is not used."
+ if self._priority and self._priority_IS_weight:
+ data['weight'] = data['IS']
+ else:
+ data['weight'] = data.get('weight', None)
+ if isinstance(elem, dict): # old pipeline
+ for k in data:
+ if isinstance(data[k], list):
+ data[k] = default_collate(data[k])
+ data['obs_plus_1'] = torch.cat([data['obs'], data['next_obs'][-1:]], dim=0) # shape (T+1), B, env_obs_shape
+ return data
+
+ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
+ """
+ Overview:
+ Policy forward function of learn mode (training policy and updating parameters). Forward means \
+ that the policy inputs some training batch data from the replay buffer and then returns the output \
+ result, including various training information such as loss and current learning rate.
+ Arguments:
+ - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
+ training samples. For each element in list, the key of the dict is the name of data items and the \
+ value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
+ combinations. In the ``_forward_learn`` method, the data often needs to first be stacked in the batch \
+ dimension by some utility functions such as ``default_preprocess_learn``. \
+ For IMPALA, each element in list is a dict containing at least the following keys: ``obs``, \
+ ``action``, ``logit``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such \
+ as ``weight``.
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will be \
+ recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
+ detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
+ For data types that are not supported, the main reason is that the corresponding model does not support them. \
+ You can implement your own model rather than use the default model. For more information, please raise an \
+ issue in GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to unittest for IMPALAPolicy: ``ding.policy.tests.test_impala``.
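+
+ The sketch below only illustrates the call signature and the returned keys (matching the code of this \
+ method); ``policy`` is assumed to be an initialized IMPALAPolicy and ``data`` a batch of unrolled \
+ training samples produced by the collector pipeline:
+
+ Examples:
+ >>> info = policy.learn_mode.forward(data)
+ >>> sorted(info.keys())
+ ['cur_lr', 'entropy_loss', 'policy_loss', 'total_loss', 'value_loss']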
+ """
+ data = self._data_preprocess_learn(data)
+ # ====================
+ # IMPALA forward
+ # ====================
+ self._learn_model.train()
+ output = self._learn_model.forward(
+ data['obs_plus_1'].view((-1, ) + data['obs_plus_1'].shape[2:]), mode='compute_actor_critic'
+ )
+ target_logit, behaviour_logit, actions, values, rewards, weights = self._reshape_data(output, data)
+ # Calculate vtrace error
+ data = vtrace_data(target_logit, behaviour_logit, actions, values, rewards, weights)
+ g, l, r, c, rg = self._gamma, self._lambda, self._rho_clip_ratio, self._c_clip_ratio, self._rho_pg_clip_ratio
+ if self._action_space == 'continuous':
+ vtrace_loss = vtrace_error_continuous_action(data, g, l, r, c, rg)
+ elif self._action_space == 'discrete':
+ vtrace_loss = vtrace_error_discrete_action(data, g, l, r, c, rg)
+
+ wv, we = self._value_weight, self._entropy_weight
+ total_loss = vtrace_loss.policy_loss + wv * vtrace_loss.value_loss - we * vtrace_loss.entropy_loss
+ # ====================
+ # IMPALA update
+ # ====================
+ self._optimizer.zero_grad()
+ total_loss.backward()
+ self._optimizer.step()
+ return {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': total_loss.item(),
+ 'policy_loss': vtrace_loss.policy_loss.item(),
+ 'value_loss': vtrace_loss.value_loss.item(),
+ 'entropy_loss': vtrace_loss.entropy_loss.item(),
+ }
+
+ def _reshape_data(self, output: Dict[str, Any], data: Dict[str, Any]) -> Tuple:
+ """
+ Overview:
+ Obtain weights for loss calculation, which should be 0 at done positions. Update values and rewards with \
+ the weight.
+ Arguments:
+ - output (:obj:`Dict[int, Any]`): Dict type data, output of learn_model forward. \
+ Values are torch.Tensor or np.ndarray or dict/list combinations, keys are value, logit.
+ - data (:obj:`Dict[int, Any]`): Dict type data, input of policy._forward_learn. Values are torch.Tensor or \
+ np.ndarray or dict/list combinations. Keys include at least ['logit', 'action', 'reward', 'done'].
+ Returns:
+ - data (:obj:`Tuple[Any]`): Tuple of target_logit, behaviour_logit, actions, values, rewards, weights.
+ ReturnsShapes:
+ - target_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where T is timestep, \
+ B is batch size and N is action dim.
+ - behaviour_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where N is action dim.
+ - actions (:obj:`torch.LongTensor`): :math:`(T, B)`
+ - values (:obj:`torch.FloatTensor`): :math:`(T+1, B)`
+ - rewards (:obj:`torch.FloatTensor`): :math:`(T, B)`
+ - weights (:obj:`torch.FloatTensor`): :math:`(T, B)`
+ """
+ if self._action_space == 'continuous':
+ target_logit = {}
+ # shape (T+1), B, env_action_shape -> T, B, env_action_shape
+ target_logit['mu'] = output['logit']['mu'].reshape(self._unroll_len + 1, -1, self._action_shape)[:-1]
+ target_logit['sigma'] = output['logit']['sigma'].reshape(self._unroll_len + 1, -1, self._action_shape)[:-1]
+ elif self._action_space == 'discrete':
+ # shape (T+1), B, env_action_shape -> T, B, env_action_shape
+ target_logit = output['logit'].reshape(self._unroll_len + 1, -1, self._action_shape)[:-1]
+ behaviour_logit = data['logit'] # shape T, B, env_action_shape
+ actions = data['action'] # shape T, B for discrete; T, B, env_action_shape for continuous
+ values = output['value'].reshape(self._unroll_len + 1, -1) # shape (T+1), B
+ rewards = data['reward'] # shape T,B
+ weights_ = 1 - data['done'].float() # shape T,B
+ weights = torch.ones_like(rewards) # shape T,B
+ values[1:] = values[1:] * weights_
+ weights[1:] = weights_[:-1]
+ rewards = rewards * weights # shape T,B
+ return target_logit, behaviour_logit, actions, values, rewards, weights
+
+ def _init_collect(self) -> None:
+ """
+ Overview:
+ Initialize the collect mode of policy, including related attributes and modules. For IMPALA, it contains \
+ the collect_model to balance the exploration and exploitation (e.g. the multinomial sample mechanism in \
+ discrete action space), and other algorithm-specific arguments such as unroll_len.
+ This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.
+
+ .. note::
+ If you want to set some special member variables in ``_init_collect`` method, you'd better name them \
+ with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``.
+ """
+ assert self._cfg.action_space in ["continuous", "discrete"]
+ self._action_space = self._cfg.action_space
+ if self._action_space == 'continuous':
+ self._collect_model = model_wrap(self._model, wrapper_name='reparam_sample')
+ elif self._action_space == 'discrete':
+ self._collect_model = model_wrap(self._model, wrapper_name='multinomial_sample')
+
+ self._collect_model.reset()
+
+ def _forward_collect(self, data: Dict[int, Any]) -> Dict[int, Any]:
+ """
+ Overview:
+ Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \
+ that the policy gets some necessary data (mainly observation) from the envs and then returns the output \
+ data, such as the action to interact with the envs.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
+ other necessary data (action logit and value) for learn mode defined in ``self._process_transition`` \
+ method. The key of the dict is the same as the input data, i.e. environment id.
+
+ .. tip::
+ If you want to add more tricks on this policy, like temperature factor in multinomial sample, you can pass \
+ related data as extra keyword arguments of this method.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
+ For data types that are not supported, the main reason is that the corresponding model does not support them. \
+ You can implement your own model rather than use the default model. For more information, please raise an \
+ issue in GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to unittest for IMPALAPolicy: ``ding.policy.tests.test_impala``.
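+
+ A minimal illustrative sketch (env ids, observation shape and ``policy`` are hypothetical assumptions):
+
+ Examples:
+ >>> obs = {0: torch.randn(4), 1: torch.randn(4)} # two envs with 4-dim observations
+ >>> out = policy.collect_mode.forward(obs)
+ >>> assert set(out.keys()) == {0, 1} and {'action', 'logit'} <= set(out[0].keys())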
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._collect_model.eval()
+ with torch.no_grad():
+ output = self._collect_model.forward(data, mode='compute_actor')
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ output = {i: d for i, d in zip(data_id, output)}
+ return output
+
+ def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+ Overview:
+ For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \
+ can be used for training. In IMPALA, a train sample is a sequence of processed transitions with length ``unroll_len``.
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]]`): The trajectory data (a list of transitions), each element is \
+ the same format as the return value of ``self._process_transition`` method.
+ Returns:
+ - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is the similar format \
+ as input transitions, but may contain more data for training.
+ """
+ return get_train_sample(data, self._unroll_len)
+
+ def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor],
+ timestep: namedtuple) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ Process and pack one timestep transition data into a dict, which can be directly used for training and \
+ saved in replay buffer. For IMPALA, it contains obs, next_obs, action, reward, done, logit.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari.
+ - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \
+ as input. For IMPALA, it contains the action and the logit of the action.
+ - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \
+ except all the elements have been transformed into tensor data. Usually, it contains the next obs, \
+ reward, done, info, etc.
+ Returns:
+ - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep.
+ """
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'logit': policy_output['logit'],
+ 'action': policy_output['action'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _init_eval(self) -> None:
+ """
+ Overview:
+ Initialize the eval mode of policy, including related attributes and modules. For IMPALA, it contains the \
+ eval model to select the optimal action (e.g. greedily selecting the action via argmax in discrete action spaces).
+ This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
+
+ .. note::
+ If you want to set some special member variables in ``_init_eval`` method, you'd better name them \
+ with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
+ """
+ assert self._cfg.action_space in ["continuous", "discrete"], self._cfg.action_space
+ self._action_space = self._cfg.action_space
+ if self._action_space == 'continuous':
+ self._eval_model = model_wrap(self._model, wrapper_name='deterministic_sample')
+ elif self._action_space == 'discrete':
+ self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
+
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
+ """
+ Overview:
+ Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \
+ means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
+ action to interact with the envs. ``_forward_eval`` in IMPALA often uses deterministic sample to get \
+ actions while ``_forward_collect`` usually uses a stochastic sample method to balance exploration and \
+ exploitation.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
+ key of the dict is the same as the input data, i.e. environment id.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
+ For data types that are not supported, the main reason is that the corresponding model does not support them. \
+ You can implement your own model rather than use the default model. For more information, please raise an \
+ issue in GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to unittest for IMPALAPolicy: ``ding.policy.tests.test_impala``.
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data, mode='compute_actor')
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ output = {i: d for i, d in zip(data_id, output)}
+ return output
+
+ def _monitor_vars_learn(self) -> List[str]:
+ """
+ Overview:
+ Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
+ as text logger, tensorboard logger, will use these keys to save the corresponding data.
+ Returns:
+ - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
+ """
+ return super()._monitor_vars_learn() + ['policy_loss', 'value_loss', 'entropy_loss']
diff --git a/DI-engine/ding/policy/iqn.py b/DI-engine/ding/policy/iqn.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bb9b683ccad9f165c8c54f82d9b8bb0f819ccb4
--- /dev/null
+++ b/DI-engine/ding/policy/iqn.py
@@ -0,0 +1,198 @@
+from typing import List, Dict, Any, Tuple, Union
+import copy
+import torch
+
+from ding.torch_utils import Adam, to_device
+from ding.rl_utils import iqn_nstep_td_data, iqn_nstep_td_error, get_train_sample, get_nstep_return_data
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_collate, default_decollate
+from .dqn import DQNPolicy
+from .common_utils import default_preprocess_learn
+
+
+@POLICY_REGISTRY.register('iqn')
+class IQNPolicy(DQNPolicy):
+ r"""
+ Overview:
+ Policy class of IQN algorithm.
+
+ Config:
+ == ==================== ======== ============== ======================================== =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============== ======================================== =======================
+ 1 ``type`` str iqn | RL policy register name, refer to | this arg is optional,
+ | registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool False | Whether to use cuda for network | this arg can be diff-
+ | erent from modes
+ 3 ``on_policy`` bool False | Whether the RL algorithm is on-policy
+ | or off-policy
+ 4 ``priority`` bool True | Whether use priority(PER) | priority sample,
+ | update priority
+ 6 | ``other.eps`` float 0.05 | Start value for epsilon decay. It's
+ | ``.start`` | small because Rainbow uses noisy net.
+ 7 | ``other.eps`` float 0.05 | End value for epsilon decay.
+ | ``.end``
+ 8 | ``discount_`` float 0.97, | Reward's future discount factor, aka. | may be 1 when sparse
+ | ``factor`` [0.95, 0.999] | gamma | reward env
+ 9 ``nstep`` int 3, | N-step reward discount sum for target
+ [3, 5] | q_value estimation
+ 10 | ``learn.update`` int 3 | How many updates (iterations) to train | this arg can vary
+ | ``per_collect`` | after collector's one collection. Only | across envs. Bigger val
+ | valid in serial training | means more off-policy
+ 11 ``learn.kappa`` float / | Threshold of Huber loss
+ == ==================== ======== ============== ======================================== =======================
+ """
+
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='iqn',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ # (bool) Whether the RL algorithm is on-policy or off-policy.
+ on_policy=False,
+ # (bool) Whether to use priority (priority sample, IS weight, update priority)
+ priority=False,
+ # (float) Reward's future discount factor, aka. gamma.
+ discount_factor=0.97,
+ # (int) N-step reward for target q_value estimation
+ nstep=1,
+ learn=dict(
+ # How many updates(iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=3,
+ batch_size=64,
+ learning_rate=0.001,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (int) Frequency of target network update.
+ target_update_freq=100,
+ # (float) Threshold of Huber loss. In the IQN paper, this is denoted by kappa. Default to 1.0.
+ kappa=1.0,
+ # (bool) Whether to ignore done (usually for max-step termination envs)
+ ignore_done=False,
+ ),
+ # collect_mode config
+ collect=dict(
+ # (int) Only one of [n_sample, n_step, n_episode] should be set
+ # n_sample=8,
+ # (int) Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ ),
+ eval=dict(),
+ # other config
+ other=dict(
+ # Epsilon greedy with decay.
+ eps=dict(
+ # (str) Decay type. Support ['exp', 'linear'].
+ type='exp',
+ start=0.95,
+ end=0.1,
+ # (int) Decay length(env step)
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=10000, )
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ return 'iqn', ['ding.model.template.q_learning']
+
+ def _init_learn(self) -> None:
+ r"""
+ Overview:
+ Learn mode init method. Called by ``self.__init__``.
+ Init the optimizer, algorithm config, main and target models.
+ """
+ self._priority = self._cfg.priority
+ # Optimizer
+ self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
+
+ self._gamma = self._cfg.discount_factor
+ self._nstep = self._cfg.nstep
+ self._kappa = self._cfg.learn.kappa
+
+ # use model_wrapper for specialized demands of different modes
+ self._target_model = copy.deepcopy(self._model)
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='assign',
+ update_kwargs={'freq': self._cfg.learn.target_update_freq}
+ )
+ self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Forward and backward function of learn mode.
+ Arguments:
+ - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs']
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss.
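+
+ A minimal illustrative sketch (shapes and ``policy`` are hypothetical assumptions; ``data`` is normally \
+ sampled from the replay buffer):
+
+ Examples:
+ >>> data = [
+ >>> dict(obs=torch.randn(4), next_obs=torch.randn(4), action=torch.tensor(1),
+ >>> reward=torch.randn(1), done=torch.tensor(False))
+ >>> for _ in range(64)
+ >>> ]
+ >>> info = policy.learn_mode.forward(data)
+ >>> assert {'cur_lr', 'total_loss', 'priority'} <= set(info.keys())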
+ """
+ data = default_preprocess_learn(
+ data, use_priority=self._priority, ignore_done=self._cfg.learn.ignore_done, use_nstep=True
+ )
+ if self._cuda:
+ data = to_device(data, self._device)
+ # ====================
+ # Q-learning forward
+ # ====================
+ self._learn_model.train()
+ self._target_model.train()
+ # Current q value (main model)
+ ret = self._learn_model.forward(data['obs'])
+ q_value = ret['q']
+ replay_quantiles = ret['quantiles']
+ # Target q value
+ with torch.no_grad():
+ target_q_value = self._target_model.forward(data['next_obs'])['q']
+ # Max q value action (main model)
+ target_q_action = self._learn_model.forward(data['next_obs'])['action']
+
+ data_n = iqn_nstep_td_data(
+ q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], replay_quantiles,
+ data['weight']
+ )
+ value_gamma = data.get('value_gamma')
+ loss, td_error_per_sample = iqn_nstep_td_error(
+ data_n, self._gamma, nstep=self._nstep, kappa=self._kappa, value_gamma=value_gamma
+ )
+
+ # ====================
+ # Q-learning update
+ # ====================
+ self._optimizer.zero_grad()
+ loss.backward()
+ if self._cfg.multi_gpu:
+ self.sync_gradients(self._learn_model)
+ self._optimizer.step()
+
+ # =============
+ # after update
+ # =============
+ self._target_model.update(self._learn_model.state_dict())
+ return {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': loss.item(),
+ 'priority': td_error_per_sample.abs().tolist(),
+ # Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard.
+ # '[histogram]action_distribution': data['action'],
+ }
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'target_model': self._target_model.state_dict(),
+ 'optimizer': self._optimizer.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._target_model.load_state_dict(state_dict['target_model'])
+ self._optimizer.load_state_dict(state_dict['optimizer'])
diff --git a/DI-engine/ding/policy/madqn.py b/DI-engine/ding/policy/madqn.py
new file mode 100644
index 0000000000000000000000000000000000000000..50ceb40b0f6c715bc98f7168ede58296ac29e80e
--- /dev/null
+++ b/DI-engine/ding/policy/madqn.py
@@ -0,0 +1,350 @@
+from typing import List, Dict, Any, Tuple, Union, Optional
+from collections import namedtuple
+import torch
+import copy
+
+from ding.torch_utils import RMSprop, to_device
+from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, \
+ v_nstep_td_data, v_nstep_td_error, get_nstep_return_data
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import timestep_collate, default_collate, default_decollate
+from .qmix import QMIXPolicy
+
+
+@POLICY_REGISTRY.register('madqn')
+class MADQNPolicy(QMIXPolicy):
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='madqn',
+ # (bool) Whether to use cuda for network.
+ cuda=True,
+ # (bool) Whether the RL algorithm is on-policy or off-policy.
+ on_policy=False,
+ # (bool) Whether to use priority (priority sample, IS weight, update priority)
+ priority=False,
+ # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ nstep=3,
+ learn=dict(
+ update_per_collect=20,
+ batch_size=32,
+ learning_rate=0.0005,
+ clip_value=100,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) Target network update momentum parameter.
+ # in [0, 1].
+ target_update_theta=0.008,
+ # (float) The discount factor for future rewards,
+ # in [0, 1].
+ discount_factor=0.99,
+ # (bool) Whether to use the double DQN mechanism (use the target q value to mitigate overestimation)
+ double_q=False,
+ weight_decay=1e-5,
+ ),
+ collect=dict(
+ # (int) Only one of [n_sample, n_episode] should be set
+ n_episode=32,
+ # (int) Cut trajectories into pieces with length "unroll_len", i.e. the number of timesteps
+ # in each forward pass during training. In QMIX, it is greater than 1 because an RNN is used.
+ unroll_len=10,
+ ),
+ eval=dict(),
+ other=dict(
+ eps=dict(
+ # (str) Type of epsilon decay
+ type='exp',
+ # (float) Start value for epsilon decay, in [0, 1].
+ # 0 means not use epsilon decay.
+ start=1,
+ # (float) End value for epsilon decay, in [0, 1].
+ end=0.05,
+ # (int) Decay length(env step)
+ decay=50000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=5000,
+ # (int) The maximum reuse times of each data
+ max_reuse=1e+9,
+ max_staleness=1e+9,
+ ),
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+ Return this algorithm default model setting for demonstration.
+ Returns:
+ - model_info (:obj:`Tuple[str, List[str]]`): model name and mode import_names
+ """
+ return 'madqn', ['ding.model.template.madqn']
+
+ def _init_learn(self) -> None:
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ assert not self._priority and not self._priority_IS_weight, "Priority is not implemented in MADQN"
+ self._optimizer_current = RMSprop(
+ params=self._model.current.parameters(),
+ lr=self._cfg.learn.learning_rate,
+ alpha=0.99,
+ eps=0.00001,
+ weight_decay=self._cfg.learn.weight_decay
+ )
+ self._optimizer_cooperation = RMSprop(
+ params=self._model.cooperation.parameters(),
+ lr=self._cfg.learn.learning_rate,
+ alpha=0.99,
+ eps=0.00001,
+ weight_decay=self._cfg.learn.weight_decay
+ )
+ self._gamma = self._cfg.learn.discount_factor
+ self._nstep = self._cfg.nstep
+ self._target_model = copy.deepcopy(self._model)
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='momentum',
+ update_kwargs={'theta': self._cfg.learn.target_update_theta}
+ )
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='hidden_state',
+ state_num=self._cfg.learn.batch_size,
+ init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
+ )
+ self._learn_model = model_wrap(
+ self._model,
+ wrapper_name='hidden_state',
+ state_num=self._cfg.learn.batch_size,
+ init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
+ )
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ def _data_preprocess_learn(self, data: List[Any]) -> dict:
+ r"""
+ Overview:
+ Preprocess the data to fit the required data format for learning
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]]`): the data collected from collect function
+ Returns:
+ - data (:obj:`Dict[str, Any]`): the processed data, from \
+ [len=B, ele={dict_key: [len=T, ele=Tensor(any_dims)]}] -> {dict_key: Tensor([T, B, any_dims])}
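+
+ A minimal illustrative sketch of the shape transform described above (values are hypothetical):
+
+ Examples:
+ >>> B, T = 2, 3
+ >>> data = [{'reward': [torch.randn(1) for _ in range(T)]} for _ in range(B)]
+ >>> batch = timestep_collate(data) # -> {'reward': Tensor of shape (T, B, 1)}, as documented above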
+ """
+ # data preprocess
+ data = timestep_collate(data)
+ if self._cuda:
+ data = to_device(data, self._device)
+ data['weight'] = data.get('weight', None)
+ data['done'] = data['done'].float()
+ return data
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Forward and backward function of learn mode.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \
+ np.ndarray or dict/list combinations.
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Dict type data, an info dict indicating the training result, which will be \
+ recorded in text log and tensorboard, values are python scalar or a list of scalars.
+ ArgumentsKeys:
+ - necessary: ``obs``, ``next_obs``, ``action``, ``reward``, ``weight``, ``prev_state``, ``done``
+ ReturnsKeys:
+ - necessary: ``cur_lr``, ``total_loss``
+ - cur_lr (:obj:`float`): Current learning rate
+ - total_loss (:obj:`float`): The calculated loss
+ """
+ data = self._data_preprocess_learn(data)
+ # ====================
+ # Q-mix forward
+ # ====================
+ self._learn_model.train()
+ self._target_model.train()
+ # for hidden_state plugin, we need to reset the main model and target model
+ self._learn_model.reset(state=data['prev_state'][0])
+ self._target_model.reset(state=data['prev_state'][0])
+ inputs = {'obs': data['obs'], 'action': data['action']}
+
+ total_q = self._learn_model.forward(inputs, single_step=False)['total_q']
+
+ if self._cfg.learn.double_q:
+ next_inputs = {'obs': data['next_obs']}
+ self._learn_model.reset(state=data['prev_state'][1])
+ logit_detach = self._learn_model.forward(next_inputs, single_step=False)['logit'].clone().detach()
+ next_inputs = {'obs': data['next_obs'], 'action': logit_detach.argmax(dim=-1)}
+ else:
+ next_inputs = {'obs': data['next_obs']}
+ with torch.no_grad():
+ target_total_q = self._target_model.forward(next_inputs, cooperation=True, single_step=False)['total_q']
+
+ if self._nstep == 1:
+
+ v_data = v_1step_td_data(total_q, target_total_q, data['reward'], data['done'], data['weight'])
+ loss, td_error_per_sample = v_1step_td_error(v_data, self._gamma)
+ # for visualization
+ with torch.no_grad():
+ if data['done'] is not None:
+ target_v = self._gamma * (1 - data['done']) * target_total_q + data['reward']
+ else:
+ target_v = self._gamma * target_total_q + data['reward']
+ else:
+ data['reward'] = data['reward'].permute(0, 2, 1).contiguous()
+ loss = []
+ td_error_per_sample = []
+ for t in range(self._cfg.collect.unroll_len):
+ v_data = v_nstep_td_data(
+ total_q[t], target_total_q[t], data['reward'][t], data['done'][t], data['weight'], self._gamma
+ )
+ # calculate v_nstep_td critic_loss
+ loss_i, td_error_per_sample_i = v_nstep_td_error(v_data, self._gamma, self._nstep)
+ loss.append(loss_i)
+ td_error_per_sample.append(td_error_per_sample_i)
+ loss = sum(loss) / (len(loss) + 1e-8)
+ td_error_per_sample = sum(td_error_per_sample) / (len(td_error_per_sample) + 1e-8)
+
+ self._optimizer_current.zero_grad()
+ loss.backward()
+ grad_norm = torch.nn.utils.clip_grad_norm_(self._model.current.parameters(), self._cfg.learn.clip_value)
+ self._optimizer_current.step()
+
+ # cooperation
+ self._learn_model.reset(state=data['prev_state'][0])
+ self._target_model.reset(state=data['prev_state'][0])
+ cooperation_total_q = self._learn_model.forward(inputs, cooperation=True, single_step=False)['total_q']
+ next_inputs = {'obs': data['next_obs']}
+ with torch.no_grad():
+ cooperation_target_total_q = self._target_model.forward(
+ next_inputs, cooperation=True, single_step=False
+ )['total_q']
+
+ if self._nstep == 1:
+ v_data = v_1step_td_data(
+ cooperation_total_q, cooperation_target_total_q, data['reward'], data['done'], data['weight']
+ )
+ cooperation_loss, _ = v_1step_td_error(v_data, self._gamma)
+ else:
+ cooperation_loss_all = []
+ for t in range(self._cfg.collect.unroll_len):
+ v_data = v_nstep_td_data(
+ cooperation_total_q[t], cooperation_target_total_q[t], data['reward'][t], data['done'][t],
+ data['weight'], self._gamma
+ )
+ cooperation_loss, _ = v_nstep_td_error(v_data, self._gamma, self._nstep)
+ cooperation_loss_all.append(cooperation_loss)
+ cooperation_loss = sum(cooperation_loss_all) / (len(cooperation_loss_all) + 1e-8)
+ self._optimizer_cooperation.zero_grad()
+ cooperation_loss.backward()
+ cooperation_grad_norm = torch.nn.utils.clip_grad_norm_(
+ self._model.cooperation.parameters(), self._cfg.learn.clip_value
+ )
+ self._optimizer_cooperation.step()
+
+ # =============
+ # after update
+ # =============
+ self._target_model.update(self._learn_model.state_dict())
+ return {
+ 'cur_lr': self._optimizer_current.defaults['lr'],
+ 'total_loss': loss.item(),
+ 'total_q': total_q.mean().item() / self._cfg.model.agent_num,
+ 'target_total_q': target_total_q.mean().item() / self._cfg.model.agent_num,
+ 'grad_norm': grad_norm,
+ 'cooperation_grad_norm': cooperation_grad_norm,
+ 'cooperation_loss': cooperation_loss.item(),
+ }
+
+ def _reset_learn(self, data_id: Optional[List[int]] = None) -> None:
+ r"""
+ Overview:
+ Reset learn model to the state indicated by data_id
+ Arguments:
+ - data_id (:obj:`Optional[List[int]]`): The ids that store the states; the model states \
+ indicated by data_id will be reset.
+ """
+ self._learn_model.reset(data_id=data_id)
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Return the state_dict of learn mode, usually including model and optimizer.
+ Returns:
+ - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring.
+ """
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'target_model': self._target_model.state_dict(),
+ 'optimizer_current': self._optimizer_current.state_dict(),
+ 'optimizer_cooperation': self._optimizer_cooperation.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ """
+ Overview:
+ Load the state_dict variable into policy learn mode.
+ Arguments:
+ - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before.
+
+ .. tip::
+ If you want to only load some parts of model, you can simply set the ``strict`` argument in \
+ load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
+ complicated operation.
+ """
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._target_model.load_state_dict(state_dict['target_model'])
+ self._optimizer_current.load_state_dict(state_dict['optimizer_current'])
+ self._optimizer_cooperation.load_state_dict(state_dict['optimizer_cooperation'])
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
+ r"""
+ Overview:
+ Generate dict type transition data from inputs.
+ Arguments:
+ - obs (:obj:`Any`): Env observation
+ - model_output (:obj:`dict`): Output of collect model, including at least ['action', 'prev_state']
+ - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done']\
+ (here 'obs' indicates obs after env step).
+ Returns:
+ - transition (:obj:`dict`): Dict type transition data, including 'obs', 'next_obs', 'prev_state',\
+ 'action', 'reward', 'done'
+ """
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'prev_state': model_output['prev_state'],
+ 'action': model_output['action'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
+ r"""
+ Overview:
+ Get the train sample from trajectory.
+ Arguments:
+ - data (:obj:`list`): The cached trajectory data
+ Returns:
+ - samples (:obj:`List[Any]`): The generated training samples
+ """
+ if self._cfg.nstep == 1:
+ return get_train_sample(data, self._unroll_len)
+ else:
+ data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)
+ return get_train_sample(data, self._unroll_len)
+
+ def _monitor_vars_learn(self) -> List[str]:
+ r"""
+ Overview:
+ Return the names of the variables to be monitored (logged).
+ Returns:
+ - vars (:obj:`List[str]`): Variables' name list.
+ """
+ return [
+ 'cur_lr', 'total_loss', 'total_q', 'target_total_q', 'grad_norm', 'target_reward_total_q',
+ 'cooperation_grad_norm', 'cooperation_loss'
+ ]
diff --git a/DI-engine/ding/policy/mbpolicy/__init__.py b/DI-engine/ding/policy/mbpolicy/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e23c8d823da77ccb14d09bcf6fb941e25b828d65
--- /dev/null
+++ b/DI-engine/ding/policy/mbpolicy/__init__.py
@@ -0,0 +1,2 @@
+from .mbsac import MBSACPolicy
+from .dreamer import DREAMERPolicy
diff --git a/DI-engine/ding/policy/mbpolicy/dreamer.py b/DI-engine/ding/policy/mbpolicy/dreamer.py
new file mode 100644
index 0000000000000000000000000000000000000000..43d3b886198bccca7ad4c86e380d25175dc5d8f6
--- /dev/null
+++ b/DI-engine/ding/policy/mbpolicy/dreamer.py
@@ -0,0 +1,344 @@
+from typing import List, Dict, Any, Tuple, Union
+from collections import namedtuple
+import torch
+from torch import nn
+from copy import deepcopy
+from ding.torch_utils import Adam, to_device
+from ding.rl_utils import get_train_sample
+from ding.utils import POLICY_REGISTRY, deep_merge_dicts
+from ding.utils.data import default_collate, default_decollate
+from ding.policy import Policy
+from ding.model import model_wrap
+from ding.policy.common_utils import default_preprocess_learn
+
+from .utils import imagine, compute_target, compute_actor_loss, RewardEMA, tensorstats
+
+
+@POLICY_REGISTRY.register('dreamer')
+class DREAMERPolicy(Policy):
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='dreamer',
+ # (bool) Whether to use cuda for network and loss computation.
+ cuda=False,
+ # (int) Number of training samples (randomly collected) in replay buffer when training starts.
+ random_collect_size=5000,
+ # (bool) Whether the processed transition includes policy-specific data.
+ transition_with_policy_data=False,
+ # (int) The imagination horizon, i.e. the number of latent rollout steps used for actor-critic learning.
+ imag_horizon=15,
+ learn=dict(
+ # (float) Lambda for TD-lambda return.
+ lambda_=0.95,
+ # (float) Max norm of gradients.
+ grad_clip=100,
+ learning_rate=3e-5,
+ batch_size=16,
+ batch_length=64,
+ imag_sample=True,
+ slow_value_target=True,
+ slow_target_update=1,
+ slow_target_fraction=0.02,
+ discount=0.997,
+ reward_EMA=True,
+ actor_entropy=3e-4,
+ actor_state_entropy=0.0,
+ value_decay=0.0,
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ return 'dreamervac', ['ding.model.template.vac']
+
+ def _init_learn(self) -> None:
+ r"""
+ Overview:
+ Learn mode init method. Called by ``self.__init__``.
+ Init the optimizer, algorithm config, main and target models.
+ """
+ # Algorithm config
+ self._lambda = self._cfg.learn.lambda_
+ self._grad_clip = self._cfg.learn.grad_clip
+
+ self._critic = self._model.critic
+ self._actor = self._model.actor
+
+ if self._cfg.learn.slow_value_target:
+ self._slow_value = deepcopy(self._critic)
+ self._updates = 0
+
+ # Optimizer
+ self._optimizer_value = Adam(
+ self._critic.parameters(),
+ lr=self._cfg.learn.learning_rate,
+ )
+ self._optimizer_actor = Adam(
+ self._actor.parameters(),
+ lr=self._cfg.learn.learning_rate,
+ )
+
+ self._learn_model = model_wrap(self._model, wrapper_name='base')
+ self._learn_model.reset()
+
+ self._forward_learn_cnt = 0
+
+ if self._cfg.learn.reward_EMA:
+ self.reward_ema = RewardEMA(device=self._device)
+
+ def _forward_learn(self, start: dict, world_model, envstep) -> Dict[str, Any]:
+ # log dict
+ log_vars = {}
+ self._learn_model.train()
+ self._update_slow_target()
+
+ self._actor.requires_grad_(requires_grad=True)
+ # start is dict of {stoch, deter, logit}
+ if self._cuda:
+ start = to_device(start, self._device)
+
+ # train self._actor
+ imag_feat, imag_state, imag_action = imagine(
+ self._cfg.learn, world_model, start, self._actor, self._cfg.imag_horizon
+ )
+ reward = world_model.heads["reward"](world_model.dynamics.get_feat(imag_state)).mode()
+ actor_ent = self._actor(imag_feat).entropy()
+ state_ent = world_model.dynamics.get_dist(imag_state).entropy()
+ # this target is not scaled
+ # slow is a flag indicating whether the slow target is used for the lambda-return
+ target, weights, base = compute_target(
+ self._cfg.learn, world_model, self._critic, imag_feat, imag_state, reward, actor_ent, state_ent
+ )
+ actor_loss, mets = compute_actor_loss(
+ self._cfg.learn,
+ self._actor,
+ self.reward_ema,
+ imag_feat,
+ imag_state,
+ imag_action,
+ target,
+ actor_ent,
+ state_ent,
+ weights,
+ base,
+ )
+ log_vars.update(mets)
+ value_input = imag_feat
+ self._actor.requires_grad_(requires_grad=False)
+
+ self._critic.requires_grad_(requires_grad=True)
+ value = self._critic(value_input[:-1].detach())
+ # to do
+ # target = torch.stack(target, dim=1)
+ # (time, batch, 1), (time, batch, 1) -> (time, batch)
+ value_loss = -value.log_prob(target.detach())
+ slow_target = self._slow_value(value_input[:-1].detach())
+ if self._cfg.learn.slow_value_target:
+ value_loss = value_loss - value.log_prob(slow_target.mode().detach())
+ if self._cfg.learn.value_decay:
+ value_loss += self._cfg.learn.value_decay * value.mode()
+ # (time, batch, 1), (time, batch, 1) -> (1,)
+ value_loss = torch.mean(weights[:-1] * value_loss[:, :, None])
+ self._critic.requires_grad_(requires_grad=False)
+
+ log_vars.update(tensorstats(value.mode(), "value"))
+ log_vars.update(tensorstats(target, "target"))
+ log_vars.update(tensorstats(reward, "imag_reward"))
+ log_vars.update(tensorstats(imag_action, "imag_action"))
+ log_vars["actor_ent"] = torch.mean(actor_ent).detach().cpu().numpy().item()
+ # ====================
+ # actor-critic update
+ # ====================
+ self._model.requires_grad_(requires_grad=True)
+ world_model.requires_grad_(requires_grad=True)
+
+ loss_dict = {
+ 'critic_loss': value_loss,
+ 'actor_loss': actor_loss,
+ }
+
+ norm_dict = self._update(loss_dict)
+
+ self._model.requires_grad_(requires_grad=False)
+ world_model.requires_grad_(requires_grad=False)
+ # =============
+ # after update
+ # =============
+ self._forward_learn_cnt += 1
+
+ return {
+ **log_vars,
+ **norm_dict,
+ **loss_dict,
+ }
+
+ def _update(self, loss_dict):
+ # update actor
+ self._optimizer_actor.zero_grad()
+ loss_dict['actor_loss'].backward()
+ actor_norm = nn.utils.clip_grad_norm_(self._model.actor.parameters(), self._grad_clip)
+ self._optimizer_actor.step()
+ # update critic
+ self._optimizer_value.zero_grad()
+ loss_dict['critic_loss'].backward()
+ critic_norm = nn.utils.clip_grad_norm_(self._model.critic.parameters(), self._grad_clip)
+ self._optimizer_value.step()
+ return {'actor_grad_norm': actor_norm, 'critic_grad_norm': critic_norm}
+
+ def _update_slow_target(self):
+ if self._cfg.learn.slow_value_target:
+ if self._updates % self._cfg.learn.slow_target_update == 0:
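+ # EMA (Polyak) update of the slow value network: slow <- fraction * critic + (1 - fraction) * slow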
+ mix = self._cfg.learn.slow_target_fraction
+ for s, d in zip(self._critic.parameters(), self._slow_value.parameters()):
+ d.data = mix * s.data + (1 - mix) * d.data
+ self._updates += 1
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ ret = {
+ 'model': self._learn_model.state_dict(),
+ 'optimizer_value': self._optimizer_value.state_dict(),
+ 'optimizer_actor': self._optimizer_actor.state_dict(),
+ }
+ return ret
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._optimizer_value.load_state_dict(state_dict['optimizer_value'])
+ self._optimizer_actor.load_state_dict(state_dict['optimizer_actor'])
+
+ def _init_collect(self) -> None:
+ self._unroll_len = self._cfg.collect.unroll_len
+ self._collect_model = model_wrap(self._model, wrapper_name='base')
+ self._collect_model.reset()
+
+ def _forward_collect(self, data: dict, world_model, envstep, reset=None, state=None) -> dict:
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._collect_model.eval()
+
+ if state is None:
+ batch_size = len(data_id)
+ latent = world_model.dynamics.initial(batch_size) # {logit, stoch, deter}
+ action = torch.zeros((batch_size, self._cfg.collect.action_size)).to(self._device)
+ else:
+ #state = default_collate(list(state.values()))
+ latent = to_device(default_collate(list(zip(*state))[0]), self._device)
+ action = to_device(default_collate(list(zip(*state))[1]), self._device)
+ if len(action.shape) == 1:
+ action = action.unsqueeze(-1)
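+ # for envs that have just been reset, zero out the cached latent state and previous action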
+ if reset.any():
+ mask = 1 - reset
+ for key in latent.keys():
+ for i in range(latent[key].shape[0]):
+ latent[key][i] *= mask[i]
+ for i in range(len(action)):
+ action[i] *= mask[i]
+
+ data = data - 0.5
+ embed = world_model.encoder(data)
+ latent, _ = world_model.dynamics.obs_step(latent, action, embed, self._cfg.collect.collect_dyn_sample)
+ feat = world_model.dynamics.get_feat(latent)
+
+ actor = self._actor(feat)
+ action = actor.sample()
+ logprob = actor.log_prob(action)
+ latent = {k: v.detach() for k, v in latent.items()}
+ action = action.detach()
+
+ state = (latent, action)
+ output = {"action": action, "logprob": logprob, "state": state}
+
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
+ r"""
+ Overview:
+ Generate dict type transition data from inputs.
+ Arguments:
+ - obs (:obj:`Any`): Env observation
+ - model_output (:obj:`dict`): Output of collect model, including at least ['action']
+ - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
+ (here 'obs' indicates obs after env step).
+ Returns:
+ - transition (:obj:`dict`): Dict type transition data.
+ """
+ transition = {
+ 'obs': obs,
+ 'action': model_output['action'],
+ # TODO(zp): random_collect only has 'action'
+ #'logprob': model_output['logprob'],
+ 'reward': timestep.reward,
+ 'discount': timestep.info['discount'],
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
+ return get_train_sample(data, self._unroll_len)
+
+ def _init_eval(self) -> None:
+ self._eval_model = model_wrap(self._model, wrapper_name='base')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: dict, world_model, reset=None, state=None) -> dict:
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._eval_model.eval()
+
+ if state is None:
+ batch_size = len(data_id)
+ latent = world_model.dynamics.initial(batch_size) # {logit, stoch, deter}
+ action = torch.zeros((batch_size, self._cfg.collect.action_size)).to(self._device)
+ else:
+ #state = default_collate(list(state.values()))
+ latent = to_device(default_collate(list(zip(*state))[0]), self._device)
+ action = to_device(default_collate(list(zip(*state))[1]), self._device)
+ if len(action.shape) == 1:
+ action = action.unsqueeze(-1)
+ if reset.any():
+ mask = 1 - reset
+ for key in latent.keys():
+ for i in range(latent[key].shape[0]):
+ latent[key][i] *= mask[i]
+ for i in range(len(action)):
+ action[i] *= mask[i]
+
+ data = data - 0.5
+ embed = world_model.encoder(data)
+ latent, _ = world_model.dynamics.obs_step(latent, action, embed, self._cfg.collect.collect_dyn_sample)
+ feat = world_model.dynamics.get_feat(latent)
+
+ actor = self._actor(feat)
+ action = actor.mode()
+ logprob = actor.log_prob(action)
+ latent = {k: v.detach() for k, v in latent.items()}
+ action = action.detach()
+
+ state = (latent, action)
+ output = {"action": action, "logprob": logprob, "state": state}
+
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _monitor_vars_learn(self) -> List[str]:
+ r"""
+ Overview:
+ Return the names of the variables to be monitored (logged).
+ Returns:
+ - vars (:obj:`List[str]`): Variables' name list.
+ """
+ return [
+ 'normed_target_mean', 'normed_target_std', 'normed_target_min', 'normed_target_max', 'EMA_005', 'EMA_095',
+ 'actor_entropy', 'actor_state_entropy', 'value_mean', 'value_std', 'value_min', 'value_max', 'target_mean',
+ 'target_std', 'target_min', 'target_max', 'imag_reward_mean', 'imag_reward_std', 'imag_reward_min',
+ 'imag_reward_max', 'imag_action_mean', 'imag_action_std', 'imag_action_min', 'imag_action_max', 'actor_ent',
+ 'actor_loss', 'critic_loss', 'actor_grad_norm', 'critic_grad_norm'
+ ]
diff --git a/DI-engine/ding/policy/mbpolicy/mbsac.py b/DI-engine/ding/policy/mbpolicy/mbsac.py
new file mode 100644
index 0000000000000000000000000000000000000000..1918e161db04ae88f678ce0ff61d213ecc0ae939
--- /dev/null
+++ b/DI-engine/ding/policy/mbpolicy/mbsac.py
@@ -0,0 +1,406 @@
+from typing import Dict, Any, List
+from functools import partial
+
+import torch
+from torch import Tensor
+from torch import nn
+from torch.distributions import Normal, Independent
+
+from ding.torch_utils import to_device, fold_batch, unfold_batch, unsqueeze_repeat
+from ding.utils import POLICY_REGISTRY
+from ding.policy import SACPolicy
+from ding.rl_utils import generalized_lambda_returns
+from ding.policy.common_utils import default_preprocess_learn
+
+from .utils import q_evaluation
+
+
+@POLICY_REGISTRY.register('mbsac')
+class MBSACPolicy(SACPolicy):
+ """
+ Overview:
+ Model-based SAC with value expansion (arXiv: 1803.00101)
+ and value gradients (arXiv: 1510.09142) w.r.t. the lambda-return.
+
+ https://arxiv.org/pdf/1803.00101.pdf
+ https://arxiv.org/pdf/1510.09142.pdf
+
+ Config:
+ == ==================== ======== ============= ==================================
+ ID Symbol Type Default Value Description
+ == ==================== ======== ============= ==================================
+ 1 ``learn._lambda`` float 0.8 | Lambda for TD-lambda return.
+ 2 ``learn.grad_clip`` float 100.0 | Max norm of gradients.
+ 3 | ``learn.sample`` bool True | Whether to sample states or
+ | ``_state`` | transitions from env buffer.
+ == ==================== ======== ============= ==================================
+
+ .. note::
+ For other configs, please refer to ding.policy.sac.SACPolicy.
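+
+ The snippet below is an illustrative override sketch (values are the defaults listed above); every other \
+ field falls back to the ``SACPolicy`` defaults:
+
+ .. code-block:: python
+
+ mbsac_learn_cfg = dict(
+ lambda_=0.8, # TD-lambda return coefficient
+ grad_clip=100, # max gradient norm
+ sample_state=True, # sample states (not transitions) from the env buffer
+ )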
+ """
+
+ config = dict(
+ learn=dict(
+ # (float) Lambda for TD-lambda return.
+ lambda_=0.8,
+ # (float) Max norm of gradients.
+ grad_clip=100,
+ # (bool) Whether to sample states or transitions from environment buffer.
+ sample_state=True,
+ )
+ )
+
+ def _init_learn(self) -> None:
+ super()._init_learn()
+ self._target_model.requires_grad_(False)
+
+ self._lambda = self._cfg.learn.lambda_
+ self._grad_clip = self._cfg.learn.grad_clip
+ self._sample_state = self._cfg.learn.sample_state
+ self._auto_alpha = self._cfg.learn.auto_alpha
+ # TODO: auto alpha
+ assert not self._auto_alpha, "NotImplemented"
+
+ # TODO: TanhTransform leads to NaN
+ def actor_fn(obs: Tensor):
+ # (mu, sigma) = self._learn_model.forward(
+ # obs, mode='compute_actor')['logit']
+ # # enforce action bounds
+ # dist = TransformedDistribution(
+ # Independent(Normal(mu, sigma), 1), [TanhTransform()])
+ # action = dist.rsample()
+ # log_prob = dist.log_prob(action)
+ # return action, -self._alpha.detach() * log_prob
+ (mu, sigma) = self._learn_model.forward(obs, mode='compute_actor')['logit']
+ dist = Independent(Normal(mu, sigma), 1)
+ pred = dist.rsample()
+ action = torch.tanh(pred)
+
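+ # tanh squashing correction: log p(a) = log p(x) - sum log(1 - tanh(x)^2), where the numerically
+ # stable identity -log(1 - tanh(x)^2) = 2 * (x + softplus(-2x) - log(2)) is used below.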
+ log_prob = dist.log_prob(
+ pred
+ ) + 2 * (pred + torch.nn.functional.softplus(-2. * pred) - torch.log(torch.tensor(2.))).sum(-1)
+ return action, -self._alpha.detach() * log_prob
+
+ self._actor_fn = actor_fn
+
+ def critic_fn(obss: Tensor, actions: Tensor, model: nn.Module):
+ eval_data = {'obs': obss, 'action': actions}
+ q_values = model.forward(eval_data, mode='compute_critic')['q_value']
+ return q_values
+
+ self._critic_fn = critic_fn
+ self._forward_learn_cnt = 0
+
+ def _forward_learn(self, data: dict, world_model, envstep) -> Dict[str, Any]:
+ # preprocess data
+ data = default_preprocess_learn(
+ data,
+ use_priority=self._priority,
+ use_priority_IS_weight=self._cfg.priority_IS_weight,
+ ignore_done=self._cfg.learn.ignore_done,
+ use_nstep=False
+ )
+ if self._cuda:
+ data = to_device(data, self._device)
+
+ if len(data['action'].shape) == 1:
+ data['action'] = data['action'].unsqueeze(1)
+
+ self._learn_model.train()
+ self._target_model.train()
+
+ # TODO: use treetensor
+ # rollout length is determined by world_model.rollout_length_scheduler
+ if self._sample_state:
+ # data['reward'], ... are not used
+ obss, actions, rewards, aug_rewards, dones = \
+ world_model.rollout(data['obs'], self._actor_fn, envstep)
+ else:
+ obss, actions, rewards, aug_rewards, dones = \
+ world_model.rollout(data['next_obs'], self._actor_fn, envstep)
+ obss = torch.cat([data['obs'].unsqueeze(0), obss])
+ actions = torch.cat([data['action'].unsqueeze(0), actions])
+ rewards = torch.cat([data['reward'].unsqueeze(0), rewards])
+ aug_rewards = torch.cat([torch.zeros_like(data['reward']).unsqueeze(0), aug_rewards])
+ dones = torch.cat([data['done'].unsqueeze(0), dones])
+
+ dones = torch.cat([torch.zeros_like(data['done']).unsqueeze(0), dones])
+
+ # (T+1, B)
+ target_q_values = q_evaluation(obss, actions, partial(self._critic_fn, model=self._target_model))
+ if self._twin_critic:
+ target_q_values = torch.min(target_q_values[0], target_q_values[1]) + aug_rewards
+ else:
+ target_q_values = target_q_values + aug_rewards
+
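+ # TD(lambda) target along the imagined rollout: an exponentially weighted
+ # mixture of the n-step value-expansion targets bootstrapped from target_q_values.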
+ # (T, B)
+ lambda_return = generalized_lambda_returns(target_q_values, rewards, self._gamma, self._lambda, dones[1:])
+
+ # (T, B)
+ # If S_t terminates, we should not consider loss from t+1,...
+ weight = (1 - dones[:-1].detach()).cumprod(dim=0)
+
+ # (T+1, B)
+ q_values = q_evaluation(obss.detach(), actions.detach(), partial(self._critic_fn, model=self._learn_model))
+ if self._twin_critic:
+ critic_loss = 0.5 * torch.square(q_values[0][:-1] - lambda_return.detach()) \
+ + 0.5 * torch.square(q_values[1][:-1] - lambda_return.detach())
+ else:
+ critic_loss = 0.5 * torch.square(q_values[:-1] - lambda_return.detach())
+
+ # value expansion loss
+ critic_loss = (critic_loss * weight).mean()
+
+ # value gradient loss
+ policy_loss = -(lambda_return * weight).mean()
+
+ # alpha_loss = None
+
+ loss_dict = {
+ 'critic_loss': critic_loss,
+ 'policy_loss': policy_loss,
+ # 'alpha_loss': alpha_loss.detach(),
+ }
+
+ norm_dict = self._update(loss_dict)
+
+ # =============
+ # after update
+ # =============
+ self._forward_learn_cnt += 1
+ # target update
+ self._target_model.update(self._learn_model.state_dict())
+
+ return {
+ 'cur_lr_q': self._optimizer_q.defaults['lr'],
+ 'cur_lr_p': self._optimizer_policy.defaults['lr'],
+ 'alpha': self._alpha.item(),
+ 'target_q_value': target_q_values.detach().mean().item(),
+ **norm_dict,
+ **loss_dict,
+ }
+
+ def _update(self, loss_dict):
+ # update critic
+ self._optimizer_q.zero_grad()
+ loss_dict['critic_loss'].backward()
+ critic_norm = nn.utils.clip_grad_norm_(self._model.critic.parameters(), self._grad_clip)
+ self._optimizer_q.step()
+ # update policy
+ self._optimizer_policy.zero_grad()
+ loss_dict['policy_loss'].backward()
+ policy_norm = nn.utils.clip_grad_norm_(self._model.actor.parameters(), self._grad_clip)
+ self._optimizer_policy.step()
+ # update temperature
+ # self._alpha_optim.zero_grad()
+ # loss_dict['alpha_loss'].backward()
+ # self._alpha_optim.step()
+ return {'policy_norm': policy_norm, 'critic_norm': critic_norm}
+
+ def _monitor_vars_learn(self) -> List[str]:
+ r"""
+ Overview:
+ Return the variables' names if the variables are to be used in the monitor.
+ Returns:
+ - vars (:obj:`List[str]`): Variables' name list.
+ """
+ alpha_loss = ['alpha_loss'] if self._auto_alpha else []
+ return [
+ 'policy_loss',
+ 'critic_loss',
+ 'policy_norm',
+ 'critic_norm',
+ 'cur_lr_q',
+ 'cur_lr_p',
+ 'alpha',
+ 'target_q_value',
+ ] + alpha_loss
+
+
+@POLICY_REGISTRY.register('stevesac')
+class STEVESACPolicy(SACPolicy):
+ r"""
+ Overview:
+ Model-based SAC with stochastic ensemble value expansion (STEVE, arXiv: 1807.01675). \
+ This implementation also applies the value gradient w.r.t. the same STEVE target.
+
+ https://arxiv.org/pdf/1807.01675.pdf
+
+ Config:
+ == ==================== ======== ============= =====================================
+ ID Symbol Type Default Value Description
+ == ==================== ======== ============= =====================================
+ 1 ``learn.grad_clip`` float 100.0 | Max norm of gradients.
+ 2 ``learn.ensemble_size`` int 1 | The number of ensemble world models.
+ == ==================== ======== ============= =====================================
+
+ .. note::
+ For other configs, please refer to ding.policy.sac.SACPolicy.
+ """
+
+ config = dict(
+ learn=dict(
+ # (float) Max norm of gradients.
+ grad_clip=100,
+ # (int) The number of ensemble world models.
+ ensemble_size=1,
+ )
+ )
+
+ def _init_learn(self) -> None:
+ super()._init_learn()
+ self._target_model.requires_grad_(False)
+
+ self._grad_clip = self._cfg.learn.grad_clip
+ self._ensemble_size = self._cfg.learn.ensemble_size
+ self._auto_alpha = self._cfg.learn.auto_alpha
+ # TODO: auto alpha
+ assert not self._auto_alpha, "NotImplemented"
+
+ def actor_fn(obs: Tensor):
+ obs, dim = fold_batch(obs, 1)
+ (mu, sigma) = self._learn_model.forward(obs, mode='compute_actor')['logit']
+ dist = Independent(Normal(mu, sigma), 1)
+ pred = dist.rsample()
+ action = torch.tanh(pred)
+
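+ # Same numerically stable tanh squash log-prob correction as in MBSACPolicy.actor_fn above.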
+ log_prob = dist.log_prob(
+ pred
+ ) + 2 * (pred + torch.nn.functional.softplus(-2. * pred) - torch.log(torch.tensor(2.))).sum(-1)
+ aug_reward = -self._alpha.detach() * log_prob
+
+ return unfold_batch(action, dim), unfold_batch(aug_reward, dim)
+
+ self._actor_fn = actor_fn
+
+ def critic_fn(obss: Tensor, actions: Tensor, model: nn.Module):
+ eval_data = {'obs': obss, 'action': actions}
+ q_values = model.forward(eval_data, mode='compute_critic')['q_value']
+ return q_values
+
+ self._critic_fn = critic_fn
+ self._forward_learn_cnt = 0
+
+ def _forward_learn(self, data: dict, world_model, envstep) -> Dict[str, Any]:
+ # preprocess data
+ data = default_preprocess_learn(
+ data,
+ use_priority=self._priority,
+ use_priority_IS_weight=self._cfg.priority_IS_weight,
+ ignore_done=self._cfg.learn.ignore_done,
+ use_nstep=False
+ )
+ if self._cuda:
+ data = to_device(data, self._device)
+
+ if len(data['action'].shape) == 1:
+ data['action'] = data['action'].unsqueeze(1)
+
+ # [B, D] -> [E, B, D]
+ data['next_obs'] = unsqueeze_repeat(data['next_obs'], self._ensemble_size)
+ data['reward'] = unsqueeze_repeat(data['reward'], self._ensemble_size)
+ data['done'] = unsqueeze_repeat(data['done'], self._ensemble_size)
+
+ self._learn_model.train()
+ self._target_model.train()
+
+ obss, actions, rewards, aug_rewards, dones = \
+ world_model.rollout(data['next_obs'], self._actor_fn, envstep, keep_ensemble=True)
+ rewards = torch.cat([data['reward'].unsqueeze(0), rewards])
+ dones = torch.cat([data['done'].unsqueeze(0), dones])
+
+ # (T, E, B)
+ target_q_values = q_evaluation(obss, actions, partial(self._critic_fn, model=self._target_model))
+ if self._twin_critic:
+ target_q_values = torch.min(target_q_values[0], target_q_values[1]) + aug_rewards
+ else:
+ target_q_values = target_q_values + aug_rewards
+
+ # (T+1, E, B)
+ discounts = ((1 - dones) * self._gamma).cumprod(dim=0)
+ discounts = torch.cat([torch.ones_like(discounts)[:1], discounts])
+ # (T, E, B)
+ cum_rewards = (rewards * discounts[:-1]).cumsum(dim=0)
+ discounted_q_values = target_q_values * discounts[1:]
+ steve_return = cum_rewards + discounted_q_values
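+ # STEVE weighting: average the candidate returns over the ensemble dimension and
+ # weight each rollout horizon by the inverse variance of its candidates, so
+ # lower-variance (more reliable) horizons contribute more to the final target.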
+ # (T, B)
+ steve_return_mean = steve_return.mean(1)
+ with torch.no_grad():
+ steve_return_inv_var = 1 / (1e-8 + steve_return.var(1, unbiased=False))
+ steve_return_weight = steve_return_inv_var / (1e-8 + steve_return_inv_var.sum(dim=0))
+ # (B, )
+ steve_return = (steve_return_mean * steve_return_weight).sum(0)
+
+ eval_data = {'obs': data['obs'], 'action': data['action']}
+ q_values = self._learn_model.forward(eval_data, mode='compute_critic')['q_value']
+ if self._twin_critic:
+ critic_loss = 0.5 * torch.square(q_values[0] - steve_return.detach()) \
+ + 0.5 * torch.square(q_values[1] - steve_return.detach())
+ else:
+ critic_loss = 0.5 * torch.square(q_values - steve_return.detach())
+
+ critic_loss = critic_loss.mean()
+
+ policy_loss = -steve_return.mean()
+
+ # alpha_loss = None
+
+ loss_dict = {
+ 'critic_loss': critic_loss,
+ 'policy_loss': policy_loss,
+ # 'alpha_loss': alpha_loss.detach(),
+ }
+
+ norm_dict = self._update(loss_dict)
+
+ # =============
+ # after update
+ # =============
+ self._forward_learn_cnt += 1
+ # target update
+ self._target_model.update(self._learn_model.state_dict())
+
+ return {
+ 'cur_lr_q': self._optimizer_q.defaults['lr'],
+ 'cur_lr_p': self._optimizer_policy.defaults['lr'],
+ 'alpha': self._alpha.item(),
+ 'target_q_value': target_q_values.detach().mean().item(),
+ **norm_dict,
+ **loss_dict,
+ }
+
+ def _update(self, loss_dict):
+ # update critic
+ self._optimizer_q.zero_grad()
+ loss_dict['critic_loss'].backward()
+ critic_norm = nn.utils.clip_grad_norm_(self._model.critic.parameters(), self._grad_clip)
+ self._optimizer_q.step()
+ # update policy
+ self._optimizer_policy.zero_grad()
+ loss_dict['policy_loss'].backward()
+ policy_norm = nn.utils.clip_grad_norm_(self._model.actor.parameters(), self._grad_clip)
+ self._optimizer_policy.step()
+ # update temperature
+ # self._alpha_optim.zero_grad()
+ # loss_dict['alpha_loss'].backward()
+ # self._alpha_optim.step()
+ return {'policy_norm': policy_norm, 'critic_norm': critic_norm}
+
+ def _monitor_vars_learn(self) -> List[str]:
+ r"""
+ Overview:
+ Return the variables' names if the variables are to be used in the monitor.
+ Returns:
+ - vars (:obj:`List[str]`): Variables' name list.
+ """
+ alpha_loss = ['alpha_loss'] if self._auto_alpha else []
+ return [
+ 'policy_loss',
+ 'critic_loss',
+ 'policy_norm',
+ 'critic_norm',
+ 'cur_lr_q',
+ 'cur_lr_p',
+ 'alpha',
+ 'target_q_value',
+ ] + alpha_loss
diff --git a/DI-engine/ding/policy/mbpolicy/tests/test_mbpolicy_utils.py b/DI-engine/ding/policy/mbpolicy/tests/test_mbpolicy_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d09ab215abaafee9103ea7abab47ceefb7175f8
--- /dev/null
+++ b/DI-engine/ding/policy/mbpolicy/tests/test_mbpolicy_utils.py
@@ -0,0 +1,19 @@
+import pytest
+import torch
+from ding.policy.mbpolicy.utils import q_evaluation
+
+
+@pytest.mark.unittest
+def test_q_evaluation():
+ T, B, O, A = 10, 20, 100, 30
+ obss = torch.randn(T, B, O)
+ actions = torch.randn(T, B, A)
+
+ def fake_q_fn(obss, actions):
+ # obss: flatten_B * O
+ # actions: flatten_B * A
+ # return: flatten_B
+ return obss.sum(-1)
+
+ q_value = q_evaluation(obss, actions, fake_q_fn)
+ assert q_value.shape == (T, B)
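+
+
+@pytest.mark.unittest
+def test_q_evaluation_twin_critic():
+ # A minimal sketch of the twin-critic path, assuming the critic callable returns a
+ # list of two Q tensors (the case handled inside q_evaluation); both outputs should
+ # be unfolded back to the (T, B) trajectory shape.
+ T, B, O, A = 10, 20, 100, 30
+ obss = torch.randn(T, B, O)
+ actions = torch.randn(T, B, A)
+
+ def fake_twin_q_fn(obss, actions):
+ # return two flattened Q estimates, mimicking a twin critic
+ return [obss.sum(-1), actions.sum(-1)]
+
+ q_values = q_evaluation(obss, actions, fake_twin_q_fn)
+ assert q_values[0].shape == (T, B)
+ assert q_values[1].shape == (T, B)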
diff --git a/DI-engine/ding/policy/mbpolicy/utils.py b/DI-engine/ding/policy/mbpolicy/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b17c36e47f35474df81fcf3f450e8ad4f32fc448
--- /dev/null
+++ b/DI-engine/ding/policy/mbpolicy/utils.py
@@ -0,0 +1,148 @@
+from typing import Callable, Tuple, Union
+import torch
+from torch import Tensor
+from ding.torch_utils import fold_batch, unfold_batch
+from ding.rl_utils import generalized_lambda_returns
+from ding.torch_utils.network.dreamer import static_scan
+
+
+def q_evaluation(obss: Tensor, actions: Tensor, q_critic_fn: Callable[[Tensor, Tensor],
+ Tensor]) -> Union[Tensor, Tuple[Tensor, Tensor]]:
+ """
+ Overview:
+ Evaluate (observation, action) pairs along the trajectory
+
+ Arguments:
+ - obss (:obj:`torch.Tensor`): the observations along the trajectory
+ - actions (:obj:`torch.Tensor`): the actions along the trajectory
+ - q_critic_fn (:obj:`Callable`): the unified API :math:`Q(S_t, A_t)`
+
+ Returns:
+ - q_value (:obj:`torch.Tensor`): the action-value function evaluated along the trajectory
+
+ Shapes:
+ :math:`N`: time step
+ :math:`B`: batch size
+ :math:`O`: observation dimension
+ :math:`A`: action dimension
+
+ - obss: [N, B, O]
+ - actions: [N, B, A]
+ - q_value: [N, B]
+
+ """
+ obss, dim = fold_batch(obss, 1)
+ actions, _ = fold_batch(actions, 1)
+ q_values = q_critic_fn(obss, actions)
+ # twin critic
+ if isinstance(q_values, list):
+ return [unfold_batch(q_values[0], dim), unfold_batch(q_values[1], dim)]
+ return unfold_batch(q_values, dim)
+
+
+def imagine(cfg, world_model, start, actor, horizon, repeats=None):
+ dynamics = world_model.dynamics
+ flatten = lambda x: x.reshape([-1] + list(x.shape[2:]))
+ start = {k: flatten(v) for k, v in start.items()}
+
+ def step(prev, _):
+ state, _, _ = prev
+ feat = dynamics.get_feat(state)
+ inp = feat.detach()
+ action = actor(inp).sample()
+ succ = dynamics.img_step(state, action, sample=cfg.imag_sample)
+ return succ, feat, action
+
+ succ, feats, actions = static_scan(step, [torch.arange(horizon)], (start, None, None))
+ states = {k: torch.cat([start[k][None], v[:-1]], 0) for k, v in succ.items()}
+
+ return feats, states, actions
+
+
+def compute_target(cfg, world_model, critic, imag_feat, imag_state, reward, actor_ent, state_ent):
+ if "discount" in world_model.heads:
+ inp = world_model.dynamics.get_feat(imag_state)
+ discount = cfg.discount * world_model.heads["discount"](inp).mean
+ # TODO whether to detach
+ discount = discount.detach()
+ else:
+ discount = cfg.discount * torch.ones_like(reward)
+
+ value = critic(imag_feat).mode()
+ # value(imag_horizon, 16*64, 1)
+ # action(imag_horizon, 16*64, ch)
+ # discount(imag_horizon, 16*64, 1)
+ target = generalized_lambda_returns(value, reward[:-1], discount[:-1], cfg.lambda_)
+ weights = torch.cumprod(torch.cat([torch.ones_like(discount[:1]), discount[:-1]], 0), 0).detach()
+ return target, weights, value[:-1]
+
+
+def compute_actor_loss(
+ cfg,
+ actor,
+ reward_ema,
+ imag_feat,
+ imag_state,
+ imag_action,
+ target,
+ actor_ent,
+ state_ent,
+ weights,
+ base,
+):
+ metrics = {}
+ inp = imag_feat.detach()
+ policy = actor(inp)
+ actor_ent = policy.entropy()
+ # Q-val for actor is not transformed using symlog
+ if cfg.reward_EMA:
+ offset, scale = reward_ema(target)
+ normed_target = (target - offset) / scale
+ normed_base = (base - offset) / scale
+ adv = normed_target - normed_base
+ metrics.update(tensorstats(normed_target, "normed_target"))
+ values = reward_ema.values
+ metrics["EMA_005"] = values[0].detach().cpu().numpy().item()
+ metrics["EMA_095"] = values[1].detach().cpu().numpy().item()
+
+ actor_target = adv
+ if cfg.actor_entropy > 0:
+ actor_entropy = cfg.actor_entropy * actor_ent[:-1][:, :, None]
+ actor_target += actor_entropy
+ metrics["actor_entropy"] = torch.mean(actor_entropy).detach().cpu().numpy().item()
+ if cfg.actor_state_entropy > 0:
+ state_entropy = cfg.actor_state_entropy * state_ent[:-1]
+ actor_target += state_entropy
+ metrics["actor_state_entropy"] = torch.mean(state_entropy).detach().cpu().numpy().item()
+ actor_loss = -torch.mean(weights[:-1] * actor_target)
+ return actor_loss, metrics
+
+
+class RewardEMA(object):
+ """running mean and std"""
+
+ def __init__(self, device, alpha=1e-2):
+ self.device = device
+ self.values = torch.zeros((2, )).to(device)
+ self.alpha = alpha
+ self.range = torch.tensor([0.05, 0.95]).to(device)
+
+ def __call__(self, x):
+ flat_x = torch.flatten(x.detach())
+ x_quantile = torch.quantile(input=flat_x, q=self.range)
+ self.values = self.alpha * x_quantile + (1 - self.alpha) * self.values
+ scale = torch.clip(self.values[1] - self.values[0], min=1.0)
+ offset = self.values[0]
+ return offset.detach(), scale.detach()
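+ # Usage sketch (tensor names are illustrative): normalize imagined returns by the
+ # running 5%-95% quantile range before forming the actor loss, e.g.
+ # ema = RewardEMA(device=target.device)
+ # offset, scale = ema(target)
+ # normed_target = (target - offset) / scale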
+
+
+def tensorstats(tensor, prefix=None):
+ metrics = {
+ 'mean': torch.mean(tensor).detach().cpu().numpy(),
+ 'std': torch.std(tensor).detach().cpu().numpy(),
+ 'min': torch.min(tensor).detach().cpu().numpy(),
+ 'max': torch.max(tensor).detach().cpu().numpy(),
+ }
+ if prefix:
+ metrics = {f'{prefix}_{k}': v.item() for k, v in metrics.items()}
+ return metrics
diff --git a/DI-engine/ding/policy/mdqn.py b/DI-engine/ding/policy/mdqn.py
new file mode 100644
index 0000000000000000000000000000000000000000..8842c11102e76edff3f210a9b48388a78039db5a
--- /dev/null
+++ b/DI-engine/ding/policy/mdqn.py
@@ -0,0 +1,281 @@
+from typing import List, Dict, Any
+import copy
+import torch
+
+from ding.torch_utils import Adam, to_device
+from ding.rl_utils import m_q_1step_td_data, m_q_1step_td_error
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+
+from .dqn import DQNPolicy
+from .common_utils import default_preprocess_learn
+
+
+@POLICY_REGISTRY.register('mdqn')
+class MDQNPolicy(DQNPolicy):
+ """
+ Overview:
+ Policy class of Munchausen DQN algorithm, extended by auxiliary objectives.
+ Paper link: https://arxiv.org/abs/2007.14430.
+ Config:
+ == ==================== ======== ============== ======================================== =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============== ======================================== =======================
+ 1 ``type`` str mdqn | RL policy register name, refer to | This arg is optional,
+ | registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool False | Whether to use cuda for network | This arg can be diff-
+ | erent from modes
+ 3 ``on_policy`` bool False | Whether the RL algorithm is on-policy
+ | or off-policy
+ 4 ``priority`` bool False | Whether use priority(PER) | Priority sample,
+ | update priority
+ 5 | ``priority_IS`` bool False | Whether use Importance Sampling Weight
+ | ``_weight`` | to correct biased update. If True,
+ | priority must be True.
+ 6 | ``discount_`` float 0.97, | Reward's future discount factor, aka. | May be 1 when sparse
+ | ``factor`` [0.95, 0.999] | gamma | reward env
+ 7 ``nstep`` int 1, | N-step reward discount sum for target
+ [3, 5] | q_value estimation
+ 8 | ``learn.update`` int 1 | How many updates(iterations) to train | This arg can vary
+ | ``per_collect`` | after collector's one collection. Only | across envs. Bigger val
+ | valid in serial training | means more off-policy
+ 9 | ``learn.multi`` bool False | Whether to use multi gpu during
+ | ``_gpu`` | training
+ 10 | ``learn.batch_`` int 32 | The number of samples of an iteration
+ | ``size``
+ 11 | ``learn.learning`` float 0.001 | Gradient step length of an iteration.
+ | ``_rate``
+ 12 | ``learn.target_`` int 2000 | Frequency of target network update. | Hard(assign) update
+ | ``update_freq``
+ 13 | ``learn.ignore_`` bool False | Whether ignore done for target value | Enable it for some
+ | ``done`` | calculation. | fake termination env
+ 14 ``collect.n_sample`` int 4 | The number of training samples of a | It varies from
+ | call of collector. | different envs
+ 15 | ``collect.unroll`` int 1 | unroll length of an iteration | In RNN, unroll_len>1
+ | ``_len``
+ 16 | ``other.eps.type`` str exp | exploration rate decay type | Support ['exp',
+ | 'linear'].
+ 17 | ``other.eps.`` float 0.01 | start value of exploration rate | [0,1]
+ | ``start``
+ 18 | ``other.eps.`` float 0.001 | end value of exploration rate | [0,1]
+ | ``end``
+ 19 | ``other.eps.`` int 250000 | decay length of exploration | greater than 0. set
+ | ``decay`` | decay=250000 means
+ | the exploration rate
+ | decay from start
+ | value to end value
+ | during decay length.
+ 20 | ``entropy_tau`` float 0.03 | The ratio of entropy in the TD loss
+ 21 | ``m_alpha`` float 0.9 | The ratio of the Munchausen term in
+ | the TD loss
+ == ==================== ======== ============== ======================================== =======================
+ """
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='mdqn',
+ # (bool) Whether to use cuda in policy.
+ cuda=False,
+ # (bool) Whether learning policy is the same as collecting data policy(on-policy).
+ on_policy=False,
+ # (bool) Whether to enable priority experience sample.
+ priority=False,
+ # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ # (float) Discount factor(gamma) for returns.
+ discount_factor=0.97,
+ # (float) Entropy factor (tau) for Munchausen DQN.
+ entropy_tau=0.03,
+ # (float) Discount factor (alpha) for Munchausen term.
+ m_alpha=0.9,
+ # (int) The number of step for calculating target q_value.
+ nstep=1,
+ # learn_mode config
+ learn=dict(
+ # (int) How many updates(iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=3,
+ # (int) How many samples in a training batch
+ batch_size=64,
+ # (float) The step size of gradient descent
+ learning_rate=0.001,
+ # (int) Frequency of target network update.
+ target_update_freq=100,
+ # (bool) Whether to ignore done (usually for envs terminated by a max-step limit).
+ # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
+ # These limit HalfCheetah, and several other MuJoCo envs, to a max length of 1000.
+ # However, HalfCheetah never truly terminates on its own, so when the episode step exceeds
+ # the max episode step we replace done==True with done==False to keep the TD-error
+ # computation (``gamma * (1 - done) * next_v + reward``) accurate.
+ ignore_done=False,
+ ),
+ # collect_mode config
+ collect=dict(
+ # (int) How many training samples collected in one collection procedure.
+ # Only one of [n_sample, n_episode] should be set.
+ n_sample=4,
+ # (int) Split episodes or trajectories into pieces with length `unroll_len`.
+ unroll_len=1,
+ ),
+ eval=dict(), # for compatibility
+ # other config
+ other=dict(
+ # Epsilon greedy with decay.
+ eps=dict(
+ # (str) Decay type. Support ['exp', 'linear'].
+ type='exp',
+ # (float) Epsilon start value.
+ start=0.95,
+ # (float) Epsilon end value.
+ end=0.1,
+ # (int) Decay length(env step).
+ decay=10000,
+ ),
+ replay_buffer=dict(
+ # (int) Maximum size of replay buffer. Usually, larger buffer size is better.
+ replay_buffer_size=10000,
+ ),
+ ),
+ )
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Initialize the learn mode of policy, including related attributes and modules. For MDQN, it contains \
+ optimizer, algorithm-specific arguments such as entropy_tau, m_alpha and nstep, main and target model.
+ This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+ .. note::
+ For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+ and ``_load_state_dict_learn`` methods.
+
+ .. note::
+ For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+ .. note::
+ If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+ with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ # Optimizer
+ # set eps in order to be consistent with the original paper implementation
+ self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate, eps=0.0003125)
+
+ self._gamma = self._cfg.discount_factor
+ self._nstep = self._cfg.nstep
+ self._entropy_tau = self._cfg.entropy_tau
+ self._m_alpha = self._cfg.m_alpha
+
+ # use model_wrapper for specialized demands of different modes
+ self._target_model = copy.deepcopy(self._model)
+ if 'target_update_freq' in self._cfg.learn:
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='assign',
+ update_kwargs={'freq': self._cfg.learn.target_update_freq}
+ )
+ elif 'target_theta' in self._cfg.learn:
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='momentum',
+ update_kwargs={'theta': self._cfg.learn.target_theta}
+ )
+ else:
+ raise RuntimeError("DQN needs target network, please either indicate target_update_freq or target_theta")
+ self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Overview:
+ Policy forward function of learn mode (training policy and updating parameters). Forward means \
+ that the policy inputs some training batch data from the replay buffer and then returns the output \
+ result, including various training information such as loss, action_gap, clip_frac, priority.
+ Arguments:
+ - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
+ training samples. For each element in list, the key of the dict is the name of data items and the \
+ value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
+ combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \
+ dimension by some utility functions such as ``default_preprocess_learn``. \
+ For MDQN, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \
+ ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight`` \
+ and ``value_gamma``.
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \
+ recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
+ detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+ For unsupported data types, the main reason is usually that the corresponding model does not support them. \
+ You can implement your own model rather than use the default one. For more information, please raise an \
+ issue in the GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for MDQNPolicy: ``ding.policy.tests.test_mdqn``.
+ """
+ data = default_preprocess_learn(
+ data,
+ use_priority=self._priority,
+ use_priority_IS_weight=self._cfg.priority_IS_weight,
+ ignore_done=self._cfg.learn.ignore_done,
+ use_nstep=True
+ )
+ if self._cuda:
+ data = to_device(data, self._device)
+ # ====================
+ # Q-learning forward
+ # ====================
+ self._learn_model.train()
+ self._target_model.train()
+ # Current q value (main model)
+ q_value = self._learn_model.forward(data['obs'])['logit']
+ # Target q value
+ with torch.no_grad():
+ target_q_value_current = self._target_model.forward(data['obs'])['logit']
+ target_q_value = self._target_model.forward(data['next_obs'])['logit']
+
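+ # Munchausen modification: the one-step target augments the reward with a scaled
+ # log-policy bonus alpha * tau * log pi(a|s) and bootstraps from a tau-scaled
+ # soft (log-sum-exp) value of the next state; both are computed inside
+ # ``m_q_1step_td_error`` below.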
+ data_m = m_q_1step_td_data(
+ q_value, target_q_value_current, target_q_value, data['action'], data['reward'].squeeze(0), data['done'],
+ data['weight']
+ )
+
+ loss, td_error_per_sample, action_gap, clipfrac = m_q_1step_td_error(
+ data_m, self._gamma, self._entropy_tau, self._m_alpha
+ )
+ # ====================
+ # Q-learning update
+ # ====================
+ self._optimizer.zero_grad()
+ loss.backward()
+ if self._cfg.multi_gpu:
+ self.sync_gradients(self._learn_model)
+ self._optimizer.step()
+
+ # =============
+ # after update
+ # =============
+ self._target_model.update(self._learn_model.state_dict())
+ return {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': loss.item(),
+ 'q_value': q_value.mean().item(),
+ 'target_q_value': target_q_value.mean().item(),
+ 'priority': td_error_per_sample.abs().tolist(),
+ 'action_gap': action_gap.item(),
+ 'clip_frac': clipfrac.mean().item(),
+ }
+
+ def _monitor_vars_learn(self) -> List[str]:
+ """
+ Overview:
+ Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
+ as text logger, tensorboard logger, will use these keys to save the corresponding data.
+ Returns:
+ - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
+ """
+ return ['cur_lr', 'total_loss', 'q_value', 'action_gap', 'clip_frac']
diff --git a/DI-engine/ding/policy/ngu.py b/DI-engine/ding/policy/ngu.py
new file mode 100644
index 0000000000000000000000000000000000000000..95fe2dd82ab1b0587a3c3e55270de0a6223b867b
--- /dev/null
+++ b/DI-engine/ding/policy/ngu.py
@@ -0,0 +1,597 @@
+from typing import List, Dict, Any, Tuple, Union, Optional
+from collections import namedtuple
+import torch
+import copy
+
+from ding.torch_utils import Adam, to_device
+from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, q_nstep_td_error_with_rescale, get_nstep_return_data, \
+ get_train_sample
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import timestep_collate, default_collate, default_decollate
+from .base_policy import Policy
+
+
+@POLICY_REGISTRY.register('ngu')
+class NGUPolicy(Policy):
+ r"""
+ Overview:
+ Policy class of NGU. The corresponding paper is `Never Give Up: Learning Directed Exploration Strategies`.
+
+ Config:
+ == ==================== ======== ============== ======================================== =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============== ======================================== =======================
+ 1 ``type`` str ngu | RL policy register name, refer to | This arg is optional,
+ | registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool False | Whether to use cuda for network | This arg can be diff-
+ | erent from modes
+ 3 ``on_policy`` bool False | Whether the RL algorithm is on-policy
+ | or off-policy
+ 4 ``priority`` bool False | Whether use priority(PER) | Priority sample,
+ | update priority
+ 5 | ``priority_IS`` bool False | Whether use Importance Sampling Weight
+ | ``_weight`` | to correct biased update. If True,
+ | priority must be True.
+ 6 | ``discount_`` float 0.997, | Reward's future discount factor, aka. | May be 1 when sparse
+ | ``factor`` [0.95, 0.999] | gamma | reward env
+ 7 ``nstep`` int 3, | N-step reward discount sum for target
+ [3, 5] | q_value estimation
+ 8 ``burnin_step`` int 20 | The timestep of the burn-in operation,
+ | which warms up the RNN hidden state to
+ | reduce the mismatch caused by
+ | off-policy data
+ 9 | ``learn.update`` int 1 | How many updates(iterations) to train | This arg can vary
+ | ``per_collect`` | after collector's one collection. Only | across envs. Bigger val
+ | valid in serial training | means more off-policy
+ 10 | ``learn.batch_`` int 64 | The number of samples of an iteration
+ | ``size``
+ 11 | ``learn.learning`` float 0.001 | Gradient step length of an iteration.
+ | ``_rate``
+ 12 | ``learn.value_`` bool True | Whether use value_rescale function for
+ | ``rescale`` | predicted value
+ 13 | ``learn.target_`` int 100 | Frequence of target network update. | Hard(assign) update
+ | ``update_freq``
+ 14 | ``learn.ignore_`` bool False | Whether ignore done for target value | Enable it for some
+ | ``done`` | calculation. | fake termination env
+ 15 ``collect.n_sample`` int [8, 128] | The number of training samples of a | It varies from
+ | call of collector. | different envs
+ 16 | ``collect.unroll`` int 1 | unroll length of an iteration | In RNN, unroll_len>1
+ | ``_len``
+ == ==================== ======== ============== ======================================== =======================
+ """
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='ngu',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ # (bool) Whether the RL algorithm is on-policy or off-policy.
+ on_policy=False,
+ # (bool) Whether use priority(priority sample, IS weight, update priority)
+ priority=True,
+ # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=True,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) Reward's future discount factor, aka. gamma.
+ discount_factor=0.997,
+ # (int) N-step reward for target q_value estimation
+ nstep=5,
+ # (int) the timestep of the burn-in operation, which is designed to warm up the RNN hidden state
+ # and reduce the mismatch caused by off-policy data
+ burnin_step=20,
+ # (int) learn_unroll_len is the total length of a [sequence sample] minus
+ # the length of its burn-in part,
+ # i.e., sequence length = burnin_step + learn_unroll_len
+ learn_unroll_len=80, # set this key according to the episode length
+ learn=dict(
+ update_per_collect=1,
+ batch_size=64,
+ learning_rate=0.0001,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float type) target_update_theta: Used for soft update of the target network,
+ # aka. Interpolation factor in polyak averaging for target networks.
+ target_update_theta=0.001,
+ # (bool) whether use value_rescale function for predicted value
+ value_rescale=True,
+ ignore_done=False,
+ ),
+ collect=dict(
+ # NOTE: It is important to set the key traj_len_inf=True here,
+ # to make sure self._traj_len=INF in serial_sample_collector.py.
+ # In a sequence-based policy, for each collect_env,
+ # we want to collect data of length self._traj_len=INF
+ # unless the episode enters the 'done' state.
+ # In each collect phase, we collect a total of n_sample sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ # `env_num` is used in the hidden state and should be equal to the one in the env config.
+ # User should specify this value in user config.
+ env_num=None,
+ ),
+ eval=dict(
+ # `env_num` is used in the hidden state and should be equal to the one in the env config.
+ # User should specify this value in user config.
+ env_num=None,
+ ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=10000, ),
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ return 'ngu', ['ding.model.template.ngu']
+
+ def _init_learn(self) -> None:
+ r"""
+ Overview:
+ Init the learner model of NGUPolicy
+
+ Arguments:
+ .. note::
+
+ The ``_init_learn`` method takes its arguments from ``self._cfg.learn`` in the config file
+
+ - learning_rate (:obj:`float`): The learning rate of the optimizer
+ - gamma (:obj:`float`): The discount factor
+ - nstep (:obj:`int`): The num of n-step return
+ - value_rescale (:obj:`bool`): Whether to use value-rescaled loss in the algorithm
+ - burnin_step (:obj:`int`): The num of burn-in steps
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
+ self._gamma = self._cfg.discount_factor
+ self._nstep = self._cfg.nstep
+ self._burnin_step = self._cfg.burnin_step
+ self._value_rescale = self._cfg.learn.value_rescale
+
+ self._target_model = copy.deepcopy(self._model)
+ # we should not adopt the 'assign' mode of the target network here because of the reset bug
+ # self._target_model = model_wrap(
+ # self._target_model,
+ # wrapper_name='target',
+ # update_type='assign',
+ # update_kwargs={'freq': self._cfg.learn.target_update_freq}
+ # )
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='momentum',
+ update_kwargs={'theta': self._cfg.learn.target_update_theta}
+ )
+ self._target_model = model_wrap(
+ self._target_model, wrapper_name='hidden_state', state_num=self._cfg.learn.batch_size, save_prev_state=True
+ )
+ self._learn_model = model_wrap(
+ self._model, wrapper_name='hidden_state', state_num=self._cfg.learn.batch_size, save_prev_state=True
+ )
+ self._learn_model = model_wrap(self._learn_model, wrapper_name='argmax_sample')
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ def _data_preprocess_learn(self, data: List[Dict[str, Any]]) -> dict:
+ r"""
+ Overview:
+ Preprocess the data to fit the required data format for learning
+
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]]`): the data collected from collect function
+
+ Returns:
+ - data (:obj:`Dict[str, Any]`): the processed data, including at least \
+ ['main_obs', 'target_obs', 'burnin_obs', 'action', 'reward', 'done', 'weight']
+ - data_info (:obj:`dict`): the data info, such as replay_buffer_idx, replay_unique_id
+ """
+
+ # data preprocess
+ data = timestep_collate(data)
+ if self._cuda:
+ data = to_device(data, self._device)
+
+ if self._priority_IS_weight:
+ assert self._priority, "Use IS Weight correction, but Priority is not used."
+ if self._priority and self._priority_IS_weight:
+ data['weight'] = data['IS']
+ else:
+ data['weight'] = data.get('weight', None)
+
+ bs = self._burnin_step
+
+ # data['done'], data['weight'], data['value_gamma'] are used in _forward_learn() to calculate
+ # the q_nstep_td_error and should have length [self._sequence_len - self._burnin_step]
+ ignore_done = self._cfg.learn.ignore_done
+ if ignore_done:
+ data['done'] = [None for _ in range(self._sequence_len - bs - self._nstep)]
+ else:
+ data['done'] = data['done'][bs:].float() # for computation of online model self._learn_model
+ # NOTE that after the preprocessing of get_nstep_return_data() in _get_train_sample,
+ # data['done'][t] is already the n-step done
+
+ # if the data doesn't include 'weight' or 'value_gamma', fill a list of None
+ # with length [self._sequence_len - self._burnin_step];
+ # below are two different implementations
+ if 'value_gamma' not in data:
+ data['value_gamma'] = [None for _ in range(self._sequence_len - bs)]
+ else:
+ data['value_gamma'] = data['value_gamma'][bs:]
+
+ if 'weight' not in data:
+ data['weight'] = [None for _ in range(self._sequence_len - bs)]
+ else:
+ data['weight'] = data['weight'] * torch.ones_like(data['done'])
+ # every timestep in sequence has same weight, which is the _priority_IS_weight in PER
+
+ # the burnin_nstep_obs is used to calculate the init hidden state of rnn for the calculation of the q_value,
+ # target_q_value, and target_q_action
+ data['burnin_nstep_obs'] = data['obs'][:bs + self._nstep]
+ data['burnin_nstep_action'] = data['action'][:bs + self._nstep]
+ data['burnin_nstep_reward'] = data['reward'][:bs + self._nstep]
+ data['burnin_nstep_beta'] = data['beta'][:bs + self._nstep]
+
+ # split obs into three parts: 'burnin_obs' [0:bs], 'main_obs' [bs:-nstep], 'target_obs' [bs+nstep:]
+ # data['burnin_obs'] = data['obs'][:bs]
+ data['main_obs'] = data['obs'][bs:-self._nstep]
+ data['target_obs'] = data['obs'][bs + self._nstep:]
+
+ # data['burnin_action'] = data['action'][:bs]
+ data['main_action'] = data['action'][bs:-self._nstep]
+ data['target_action'] = data['action'][bs + self._nstep:]
+
+ # data['burnin_reward'] = data['reward'][:bs]
+ data['main_reward'] = data['reward'][bs:-self._nstep]
+ data['target_reward'] = data['reward'][bs + self._nstep:]
+
+ # data['burnin_beta'] = data['beta'][:bs]
+ data['main_beta'] = data['beta'][bs:-self._nstep]
+ data['target_beta'] = data['beta'][bs + self._nstep:]
+
+ # Note: this must come after the previous slicing operations
+ data['action'] = data['action'][bs:-self._nstep]
+ data['reward'] = data['reward'][bs:-self._nstep]
+
+ return data
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Forward and backward function of learn mode.
+ Acquire the data, calculate the loss and optimize learner model.
+
+ Arguments:
+ - data (:obj:`dict`): Dict type data, including at least \
+ ['main_obs', 'target_obs', 'burnin_obs', 'action', 'reward', 'done', 'weight']
+
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Including cur_lr and total_loss
+ - cur_lr (:obj:`float`): Current learning rate
+ - total_loss (:obj:`float`): The calculated loss
+ """
+ # forward
+ data = self._data_preprocess_learn(data)
+ self._learn_model.train()
+ self._target_model.train()
+ # use the hidden state in timestep=0
+ self._learn_model.reset(data_id=None, state=data['prev_state'][0])
+ self._target_model.reset(data_id=None, state=data['prev_state'][0])
+
+ if len(data['burnin_nstep_obs']) != 0:
+ with torch.no_grad():
+ inputs = {
+ 'obs': data['burnin_nstep_obs'],
+ 'action': data['burnin_nstep_action'],
+ 'reward': data['burnin_nstep_reward'],
+ 'beta': data['burnin_nstep_beta'],
+ 'enable_fast_timestep': True
+ }
+ tmp = self._learn_model.forward(
+ inputs, saved_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep]
+ )
+ tmp_target = self._target_model.forward(
+ inputs, saved_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep]
+ )
+
+ inputs = {
+ 'obs': data['main_obs'],
+ 'action': data['main_action'],
+ 'reward': data['main_reward'],
+ 'beta': data['main_beta'],
+ 'enable_fast_timestep': True
+ }
+ self._learn_model.reset(data_id=None, state=tmp['saved_state'][0])
+ q_value = self._learn_model.forward(inputs)['logit']
+
+ self._learn_model.reset(data_id=None, state=tmp['saved_state'][1])
+ self._target_model.reset(data_id=None, state=tmp_target['saved_state'][1])
+
+ next_inputs = {
+ 'obs': data['target_obs'],
+ 'action': data['target_action'],
+ 'reward': data['target_reward'],
+ 'beta': data['target_beta'],
+ 'enable_fast_timestep': True
+ }
+ with torch.no_grad():
+ target_q_value = self._target_model.forward(next_inputs)['logit']
+ # argmax_action double_dqn
+ target_q_action = self._learn_model.forward(next_inputs)['action']
+
+ action, reward, done, weight = data['action'], data['reward'], data['done'], data['weight']
+ value_gamma = [
+ None for _ in range(self._sequence_len - self._burnin_step)
+ ] # NOTE this is important, because we use different gammas according to their betas in the NGU algorithm.
+
+ # T, B, nstep -> T, nstep, B
+ reward = reward.permute(0, 2, 1).contiguous()
+ loss = []
+ td_error = []
+ self._gamma = [self.index_to_gamma[int(i)] for i in data['main_beta'][0]] # T, B -> B, e.g. 75,64 -> 64
+
+ # reward torch.Size([4, 5, 64])
+ for t in range(self._sequence_len - self._burnin_step - self._nstep):
+ # here t=0 means the first timestep of the post-burn-in sample sequence; we subtract self._nstep
+ # because the last nstep timesteps in the sequence do not have their target obs
+ td_data = q_nstep_td_data(
+ q_value[t], target_q_value[t], action[t], target_q_action[t], reward[t], done[t], weight[t]
+ )
+ if self._value_rescale:
+ l, e = q_nstep_td_error_with_rescale(td_data, self._gamma, self._nstep, value_gamma=value_gamma[t])
+ loss.append(l)
+ td_error.append(e.abs())
+ else:
+ l, e = q_nstep_td_error(td_data, self._gamma, self._nstep, value_gamma=value_gamma[t])
+ loss.append(l)
+ td_error.append(e.abs())
+ loss = sum(loss) / (len(loss) + 1e-8)
+
+ # using the mixture of max and mean absolute n-step TD-errors as the priority of the sequence
+ td_error_per_sample = 0.9 * torch.max(
+ torch.stack(td_error), dim=0
+ )[0] + (1 - 0.9) * (torch.sum(torch.stack(td_error), dim=0) / (len(td_error) + 1e-8))
+ # td_error is a list of per-timestep tensors of shape (B, );
+ # stacked, e.g., (75, 64)
+ # torch.sum(torch.stack(td_error), dim=0) can also be replaced with sum(td_error)
+
+ # update
+ self._optimizer.zero_grad()
+ loss.backward()
+ self._optimizer.step()
+ # after update
+ self._target_model.update(self._learn_model.state_dict())
+
+ # the information for debug
+ batch_range = torch.arange(action[0].shape[0])
+ q_s_a_t0 = q_value[0][batch_range, action[0]]
+ target_q_s_a_t0 = target_q_value[0][batch_range, target_q_action[0]]
+
+ return {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': loss.item(),
+ 'priority': td_error_per_sample.abs().tolist(),
+ # the first timestep in the sequence, may not be the start of episode
+ 'q_s_taken-a_t0': q_s_a_t0.mean().item(),
+ 'target_q_s_max-a_t0': target_q_s_a_t0.mean().item(),
+ 'q_s_a-mean_t0': q_value[0].mean().item(),
+ }
+
+ def _reset_learn(self, data_id: Optional[List[int]] = None) -> None:
+ self._learn_model.reset(data_id=data_id)
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'target_model': self._target_model.state_dict(),
+ 'optimizer': self._optimizer.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._target_model.load_state_dict(state_dict['target_model'])
+ self._optimizer.load_state_dict(state_dict['optimizer'])
+
+ def _init_collect(self) -> None:
+ r"""
+ Overview:
+ Collect mode init method. Called by ``self.__init__``.
+ Init traj and unroll length, collect model.
+ """
+ assert 'unroll_len' not in self._cfg.collect, "NGU uses the default unroll_len"
+ self._nstep = self._cfg.nstep
+ self._burnin_step = self._cfg.burnin_step
+ self._gamma = self._cfg.discount_factor
+ self._sequence_len = self._cfg.learn_unroll_len + self._cfg.burnin_step
+ self._unroll_len = self._sequence_len
+ self._collect_model = model_wrap(
+ self._model, wrapper_name='hidden_state', state_num=self._cfg.collect.env_num, save_prev_state=True
+ )
+ self._collect_model = model_wrap(self._collect_model, wrapper_name='eps_greedy_sample')
+ self._collect_model.reset()
+ self.index_to_gamma = { # NOTE
+ i: 1 - torch.exp(
+ (
+ (self._cfg.collect.env_num - 1 - i) * torch.log(torch.tensor(1 - 0.997)) +
+ i * torch.log(torch.tensor(1 - 0.99))
+ ) / (self._cfg.collect.env_num - 1)
+ )
+ for i in range(self._cfg.collect.env_num)
+ }
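+ # The per-index discount factors above interpolate log(1 - gamma) linearly between
+ # 0.997 (index 0) and 0.99 (the largest index), so each beta/env index gets its own
+ # effective horizon, following the NGU setup.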
+ # NOTE: for NGU policy collect phase
+ self.beta_index = {
+ i: torch.randint(0, self._cfg.collect.env_num, [1])
+ for i in range(self._cfg.collect.env_num)
+ }
+ # per-env exploration rate: eps_i = 0.4 ** (1 + alpha * i / (env_num - 1)) with alpha=8
+ self.eps = {i: 0.4 ** (1 + 8 * i / (self._cfg.collect.env_num - 1)) for i in range(self._cfg.collect.env_num)}
+
+ def _forward_collect(self, data: dict) -> dict:
+ r"""
+ Overview:
+ Collect output according to eps_greedy plugin
+
+ Arguments:
+ - data (:obj:`dict`): Dict type data, including at least ['obs'].
+
+ Returns:
+ - data (:obj:`dict`): The collected data
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+
+ obs = data['obs']
+ prev_action = data['prev_action'].long()
+ prev_reward_extrinsic = data['prev_reward_extrinsic']
+
+ beta_index = default_collate(list(self.beta_index.values()))
+ if len(data_id) != self._cfg.collect.env_num:
+ # in case some envs are in the reset state and only part of the data is returned
+ beta_index = beta_index[data_id]
+
+ if self._cuda:
+ obs = to_device(obs, self._device)
+ beta_index = to_device(beta_index, self._device)
+ prev_action = to_device(prev_action, self._device)
+ prev_reward_extrinsic = to_device(prev_reward_extrinsic, self._device)
+ # TODO(pu): add prev_reward_intrinsic to network input,
+ # reward uses some kind of embedding instead of 1D value
+ data = {
+ 'obs': obs,
+ 'prev_action': prev_action,
+ 'prev_reward_extrinsic': prev_reward_extrinsic,
+ 'beta': beta_index
+ }
+ self._collect_model.eval()
+ with torch.no_grad():
+ output = self._collect_model.forward(data, data_id=data_id, eps=self.eps, inference=True)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _reset_collect(self, data_id: Optional[List[int]] = None) -> None:
+ self._collect_model.reset(data_id=data_id)
+ # NOTE: for NGU policy, in collect phase, each episode, we sample a new beta for each env
+ if data_id is not None:
+ self.beta_index[data_id[0]] = torch.randint(0, self._cfg.collect.env_num, [1])
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple, env_id) -> dict:
+ r"""
+ Overview:
+ Generate dict type transition data from inputs.
+ Arguments:
+ - obs (:obj:`Any`): Env observation
+ - model_output (:obj:`dict`): Output of collect model, including at least ['action', 'prev_state']
+ - timestep (:obj:`namedtuple`): Output after env step, including at least ['reward', 'done'] \
+ (here 'obs' indicates obs after env step).
+ Returns:
+ - transition (:obj:`dict`): Dict type transition data.
+ """
+ if hasattr(timestep, 'null'):
+ transition = {
+ 'beta': self.beta_index[env_id],
+ 'obs': obs['obs'], # NOTE: input obs including obs, prev_action, prev_reward_extrinsic
+ 'action': model_output['action'],
+ 'prev_state': model_output['prev_state'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ 'null': timestep.null,
+ }
+ else:
+ transition = {
+ 'beta': self.beta_index[env_id],
+ 'obs': obs['obs'], # NOTE: input obs including obs, prev_action, prev_reward_extrinsic
+ 'action': model_output['action'],
+ 'prev_state': model_output['prev_state'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ 'null': False,
+ }
+ return transition
+
+ def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
+ r"""
+ Overview:
+ Get the trajectory and the n step return data, then sample from the n_step return data
+
+ Arguments:
+ - data (:obj:`list`): The trajectory's cache
+
+ Returns:
+ - samples (:obj:`dict`): The training samples generated
+ """
+ data = get_nstep_return_data(data, self._nstep, gamma=self.index_to_gamma[int(data[0]['beta'])].item())
+ return get_train_sample(data, self._sequence_len)
+
+ def _init_eval(self) -> None:
+ r"""
+ Overview:
+ Evaluate mode init method. Called by ``self.__init__``.
+ Init eval model with argmax strategy.
+ """
+ self._eval_model = model_wrap(self._model, wrapper_name='hidden_state', state_num=self._cfg.eval.env_num)
+ self._eval_model = model_wrap(self._eval_model, wrapper_name='argmax_sample')
+ self._eval_model.reset()
+ # NOTE: for NGU policy eval phase
+ # beta_index = 0 -> beta is approximately 0
+ self.beta_index = {i: torch.tensor([0]) for i in range(self._cfg.eval.env_num)}
+
+ def _forward_eval(self, data: dict) -> dict:
+ r"""
+ Overview:
+ Forward function of eval mode, similar to ``self._forward_collect``.
+
+ Arguments:
+ - data (:obj:`dict`): Dict type data, including at least ['obs'].
+
+ Returns:
+ - output (:obj:`dict`): Dict type data, including at least inferred action according to input obs.
+ """
+
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+
+ obs = data['obs']
+ prev_action = data['prev_action'].long()
+ prev_reward_extrinsic = data['prev_reward_extrinsic']
+
+ beta_index = default_collate(list(self.beta_index.values()))
+ if len(data_id) != self._cfg.collect.env_num:
+ # in case some envs are in the reset state and only part of the data is returned
+ beta_index = beta_index[data_id]
+
+ if self._cuda:
+ obs = to_device(obs, self._device)
+ beta_index = to_device(beta_index, self._device)
+ prev_action = to_device(prev_action, self._device)
+ prev_reward_extrinsic = to_device(prev_reward_extrinsic, self._device)
+ # TODO(pu): add prev_reward_intrinsic to network input,
+ # reward uses some kind of embedding instead of 1D value
+ data = {
+ 'obs': obs,
+ 'prev_action': prev_action,
+ 'prev_reward_extrinsic': prev_reward_extrinsic,
+ 'beta': beta_index
+ }
+
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data, data_id=data_id, inference=True)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
+ self._eval_model.reset(data_id=data_id)
+
+ def _monitor_vars_learn(self) -> List[str]:
+ return super()._monitor_vars_learn() + [
+ 'total_loss', 'priority', 'q_s_taken-a_t0', 'target_q_s_max-a_t0', 'q_s_a-mean_t0'
+ ]
diff --git a/DI-engine/ding/policy/offppo_collect_traj.py b/DI-engine/ding/policy/offppo_collect_traj.py
new file mode 100644
index 0000000000000000000000000000000000000000..219d582c830ef15a6e3cdc5d122a77722b2aff98
--- /dev/null
+++ b/DI-engine/ding/policy/offppo_collect_traj.py
@@ -0,0 +1,309 @@
+from typing import List, Dict, Any, Tuple, Union
+from collections import namedtuple
+import torch
+import copy
+import numpy as np
+from torch.distributions import Independent, Normal
+
+from ding.torch_utils import Adam, to_device
+from ding.rl_utils import ppo_data, ppo_error, ppo_policy_error, ppo_policy_data, get_gae_with_default_last_value, \
+ v_nstep_td_data, v_nstep_td_error, get_nstep_return_data, get_train_sample, gae, gae_data, ppo_error_continuous,\
+ get_gae
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY, split_data_generator, RunningMeanStd
+from ding.utils.data import default_collate, default_decollate
+from .base_policy import Policy
+from .common_utils import default_preprocess_learn
+
+
+@POLICY_REGISTRY.register('offppo_collect_traj')
+class OffPPOCollectTrajPolicy(Policy):
+ r"""
+ Overview:
+ Policy class of the off-policy PPO algorithm used to collect expert trajectories for R2D3.
+ """
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='ppo',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ # (bool) Whether the RL algorithm is on-policy or off-policy. (Note: in practice PPO can be used off-policy)
+ on_policy=True,
+ # (bool) Whether to use priority(priority sample, IS weight, update priority)
+ priority=False,
+ # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ # (bool) Whether to use nstep_return for value loss
+ nstep_return=False,
+ nstep=3,
+ learn=dict(
+
+ # How many updates(iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=5,
+ batch_size=64,
+ learning_rate=0.001,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) The loss weight of value network, policy network weight is set to 1
+ value_weight=0.5,
+ # (float) The loss weight of entropy regularization, policy network weight is set to 1
+ entropy_weight=0.01,
+ # (float) PPO clip ratio, defaults to 0.2
+ clip_ratio=0.2,
+ # (bool) Whether to use advantage norm in a whole training batch
+ adv_norm=False,
+ ignore_done=False,
+ ),
+ collect=dict(
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) Reward's future discount factor, aka. gamma.
+ discount_factor=0.99,
+ # (float) GAE lambda factor for the balance of bias and variance(1-step td and mc)
+ gae_lambda=0.95,
+ ),
+ eval=dict(),
+ other=dict(replay_buffer=dict(replay_buffer_size=10000, ), ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ return 'vac', ['ding.model.template.vac']
+
+ def _init_learn(self) -> None:
+ r"""
+ Overview:
+ Learn mode init method. Called by ``self.__init__``.
+ Init the optimizer, algorithm config and the main model.
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ assert not self._priority and not self._priority_IS_weight, "Priority is not implemented in PPO"
+ # Orthogonal init
+ for m in self._model.modules():
+ if isinstance(m, torch.nn.Conv2d):
+ torch.nn.init.orthogonal_(m.weight)
+ if isinstance(m, torch.nn.Linear):
+ torch.nn.init.orthogonal_(m.weight)
+ # Optimizer
+ self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
+ self._learn_model = model_wrap(self._model, wrapper_name='base')
+
+ # Algorithm config
+ self._value_weight = self._cfg.learn.value_weight
+ self._entropy_weight = self._cfg.learn.entropy_weight
+ self._clip_ratio = self._cfg.learn.clip_ratio
+ self._adv_norm = self._cfg.learn.adv_norm
+ self._nstep = self._cfg.nstep
+ self._nstep_return = self._cfg.nstep_return
+ # Main model
+ self._learn_model.reset()
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Forward and backward function of learn mode.
+ Arguments:
+ - data (:obj:`dict`): Dict type data
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`):
+ Including current lr, total_loss, policy_loss, value_loss, entropy_loss, \
+ adv_abs_max, approx_kl, clipfrac
+ """
+ data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=self._nstep_return)
+ if self._cuda:
+ data = to_device(data, self._device)
+ # ====================
+ # PPO forward
+ # ====================
+
+ self._learn_model.train()
+ # normal ppo
+ if not self._nstep_return:
+ output = self._learn_model.forward(data['obs'], mode='compute_actor_critic')
+ adv = data['adv']
+ return_ = data['value'] + adv
+ if self._adv_norm:
+ # Normalize advantage in a total train_batch
+ adv = (adv - adv.mean()) / (adv.std() + 1e-8)
+ # Calculate ppo error
+ ppodata = ppo_data(
+ output['logit'], data['logit'], data['action'], output['value'], data['value'], adv, return_,
+ data['weight']
+ )
+ ppo_loss, ppo_info = ppo_error(ppodata, self._clip_ratio)
+ wv, we = self._value_weight, self._entropy_weight
+ total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss
+
+ else:
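+ # n-step branch: the actor still uses the PPO clipped policy loss, while the
+ # critic is trained with an n-step TD target instead of the GAE-based return.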
+ output = self._learn_model.forward(data['obs'], mode='compute_actor')
+ adv = data['adv']
+ if self._adv_norm:
+ # Normalize advantage in a total train_batch
+ adv = (adv - adv.mean()) / (adv.std() + 1e-8)
+
+ # Calculate ppo error
+ ppodata = ppo_policy_data(output['logit'], data['logit'], data['action'], adv, data['weight'])
+ ppo_policy_loss, ppo_info = ppo_policy_error(ppodata, self._clip_ratio)
+ wv, we = self._value_weight, self._entropy_weight
+ next_obs = data.get('next_obs')
+ value_gamma = data.get('value_gamma')
+ reward = data.get('reward')
+ # current value
+ value = self._learn_model.forward(data['obs'], mode='compute_critic')
+ # target value
+ next_data = {'obs': next_obs}
+ target_value = self._learn_model.forward(next_data['obs'], mode='compute_critic')
+ # TODO what should we do here to keep shape
+ assert self._nstep > 1
+ td_data = v_nstep_td_data(
+ value['value'], target_value['value'], reward.t(), data['done'], data['weight'], value_gamma
+ )
+ # calculate v_nstep_td critic_loss
+ critic_loss, td_error_per_sample = v_nstep_td_error(td_data, self._gamma, self._nstep)
+ ppo_loss_data = namedtuple('ppo_loss', ['policy_loss', 'value_loss', 'entropy_loss'])
+ ppo_loss = ppo_loss_data(ppo_policy_loss.policy_loss, critic_loss, ppo_policy_loss.entropy_loss)
+ total_loss = ppo_policy_loss.policy_loss + wv * critic_loss - we * ppo_policy_loss.entropy_loss
+
+ # ====================
+ # PPO update
+ # ====================
+ self._optimizer.zero_grad()
+ total_loss.backward()
+ self._optimizer.step()
+ return {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': total_loss.item(),
+ 'policy_loss': ppo_loss.policy_loss.item(),
+ 'value_loss': ppo_loss.value_loss.item(),
+ 'entropy_loss': ppo_loss.entropy_loss.item(),
+ 'adv_abs_max': adv.abs().max().item(),
+ 'approx_kl': ppo_info.approx_kl,
+ 'clipfrac': ppo_info.clipfrac,
+ }
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'optimizer': self._optimizer.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._optimizer.load_state_dict(state_dict['optimizer'])
+
+ def _init_collect(self) -> None:
+ r"""
+ Overview:
+ Collect mode init method. Called by ``self.__init__``.
+ Init traj and unroll length, collect model.
+ """
+ self._unroll_len = self._cfg.collect.unroll_len
+ # self._collect_model = model_wrap(self._model, wrapper_name='multinomial_sample')
+ # NOTE this policy is to collect expert traj, so we have to use argmax_sample wrapper
+ self._collect_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ self._collect_model.reset()
+ self._gamma = self._cfg.collect.discount_factor
+ self._gae_lambda = self._cfg.collect.gae_lambda
+ self._nstep = self._cfg.nstep
+ self._nstep_return = self._cfg.nstep_return
+
+ def _forward_collect(self, data: dict) -> dict:
+ r"""
+ Overview:
+ Forward function for collect mode
+ Arguments:
+ - data (:obj:`dict`): Dict type data, including at least ['obs'].
+ Returns:
+ - data (:obj:`dict`): The collected data
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._collect_model.eval()
+ with torch.no_grad():
+ output = self._collect_model.forward(data, mode='compute_actor_critic')
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
+ """
+ Overview:
+ Generate dict type transition data from inputs.
+ Arguments:
+ - obs (:obj:`Any`): Env observation
+ - model_output (:obj:`dict`): Output of collect model, including at least ['action']
+ - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done']\
+ (here 'obs' indicates obs after env step).
+ Returns:
+ - transition (:obj:`dict`): Dict type transition data.
+ """
+ transition = {
+ 'obs': obs,
+ 'action': model_output['action'],
+ # 'prev_state': model_output['prev_state'],
+ 'prev_state': None,
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
+ r"""
+ Overview:
+ Get the trajectory, compute the n-step return data, and generate training samples
+ Arguments:
+ - data (:obj:`list`): The trajectory's cache
+ Returns:
+ - samples (:obj:`dict`): The training samples generated
+ """
+ from copy import deepcopy
+ # data_one_step = deepcopy(get_nstep_return_data(data, 1, gamma=self._gamma))
+ data_one_step = deepcopy(data)
+ data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)
+ for i in range(len(data)):
+ # here we record the one-step done; we don't need to record the one-step reward,
+ # because the n-step reward in data already includes the one-step reward
+ data[i]['done_one_step'] = data_one_step[i]['done']
+ return get_train_sample(data, self._unroll_len) # self._unroll_len_add_burnin_step
+
+ def _init_eval(self) -> None:
+ r"""
+ Overview:
+ Evaluate mode init method. Called by ``self.__init__``.
+ Init eval model with argmax strategy.
+ """
+ self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: dict) -> dict:
+ r"""
+ Overview:
+ Forward function for eval mode, similar to ``self._forward_collect``.
+ Arguments:
+ - data (:obj:`dict`): Dict type data, including at least ['obs'].
+ Returns:
+ - output (:obj:`dict`): Dict type data, including at least inferred action according to input obs.
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data, mode='compute_actor')
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _monitor_vars_learn(self) -> List[str]:
+ return super()._monitor_vars_learn() + [
+ 'policy_loss', 'value_loss', 'entropy_loss', 'adv_abs_max', 'approx_kl', 'clipfrac'
+ ]
diff --git a/DI-engine/ding/policy/pc.py b/DI-engine/ding/policy/pc.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c472462b0940d3b25ac0935c99862c9f53a6191
--- /dev/null
+++ b/DI-engine/ding/policy/pc.py
@@ -0,0 +1,186 @@
+import math
+from typing import List, Dict, Any, Tuple
+from collections import namedtuple
+
+import torch
+import torch.nn as nn
+from torch.optim import Adam, SGD, AdamW
+from torch.optim.lr_scheduler import LambdaLR
+
+from ding.policy import Policy
+from ding.model import model_wrap
+from ding.torch_utils import to_device
+from ding.utils import EasyTimer
+from ding.utils import POLICY_REGISTRY
+
+
+@POLICY_REGISTRY.register('pc_bfs')
+class ProcedureCloningBFSPolicy(Policy):
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ return 'pc_bfs', ['ding.model.template.procedure_cloning']
+
+ config = dict(
+ type='pc',
+ cuda=False,
+ on_policy=False,
+ continuous=False,
+ max_bfs_steps=100,
+ learn=dict(
+ update_per_collect=1,
+ batch_size=32,
+ learning_rate=1e-5,
+ lr_decay=False,
+ decay_epoch=30,
+ decay_rate=0.1,
+ warmup_lr=1e-4,
+ warmup_epoch=3,
+ optimizer='SGD',
+ momentum=0.9,
+ weight_decay=1e-4,
+ ),
+ collect=dict(
+ unroll_len=1,
+ noise=False,
+ noise_sigma=0.2,
+ noise_range=dict(
+ min=-0.5,
+ max=0.5,
+ ),
+ ),
+ eval=dict(),
+ other=dict(replay_buffer=dict(replay_buffer_size=10000)),
+ )
+
+ def _init_learn(self):
+ assert self._cfg.learn.optimizer in ['SGD', 'Adam']
+ if self._cfg.learn.optimizer == 'SGD':
+ self._optimizer = SGD(
+ self._model.parameters(),
+ lr=self._cfg.learn.learning_rate,
+ weight_decay=self._cfg.learn.weight_decay,
+ momentum=self._cfg.learn.momentum
+ )
+ elif self._cfg.learn.optimizer == 'Adam':
+ if self._cfg.learn.weight_decay is None:
+ self._optimizer = Adam(
+ self._model.parameters(),
+ lr=self._cfg.learn.learning_rate,
+ )
+ else:
+ self._optimizer = AdamW(
+ self._model.parameters(),
+ lr=self._cfg.learn.learning_rate,
+ weight_decay=self._cfg.learn.weight_decay
+ )
+ if self._cfg.learn.lr_decay:
+
+ def lr_scheduler_fn(epoch):
+ if epoch <= self._cfg.learn.warmup_epoch:
+ return self._cfg.learn.warmup_lr / self._cfg.learn.learning_rate
+ else:
+ ratio = (epoch - self._cfg.learn.warmup_epoch) // self._cfg.learn.decay_epoch
+ return math.pow(self._cfg.learn.decay_rate, ratio)
+
+ self._lr_scheduler = LambdaLR(self._optimizer, lr_scheduler_fn)
+ self._timer = EasyTimer(cuda=True)
+ self._learn_model = model_wrap(self._model, 'base')
+ self._learn_model.reset()
+ self._max_bfs_steps = self._cfg.max_bfs_steps
+ self._maze_size = self._cfg.maze_size
+ self._num_actions = self._cfg.num_actions
+
+ self._loss = nn.CrossEntropyLoss()
+
+ def process_states(self, observations, maze_maps):
+ """Returns [B, W, W, 3] binary values. Channels are (wall; goal; obs)"""
+ loc = torch.nn.functional.one_hot(
+ (observations[:, 0] * self._maze_size + observations[:, 1]).long(),
+ self._maze_size * self._maze_size,
+ ).long()
+ loc = torch.reshape(loc, [observations.shape[0], self._maze_size, self._maze_size])
+ states = torch.cat([maze_maps, loc], dim=-1).long()
+ return states
+
+ def _forward_learn(self, data):
+ if self._cuda:
+ collated_data = to_device(data, self._device)
+ else:
+ collated_data = data
+ observations = collated_data['obs']
+ bfs_input_maps, bfs_output_maps = collated_data['bfs_in'].long(), collated_data['bfs_out'].long()
+ states = observations
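+ # One-hot encode the BFS input map over (num_actions + 1) classes, where the extra class marks
+ # cells that have not been assigned an action yet, then concatenate it with the state encoding.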
+ bfs_input_onehot = torch.nn.functional.one_hot(bfs_input_maps, self._num_actions + 1).float()
+
+ bfs_states = torch.cat([
+ states,
+ bfs_input_onehot,
+ ], dim=-1)
+ logits = self._model(bfs_states)['logit']
+ logits = logits.flatten(0, -2)
+ labels = bfs_output_maps.flatten(0, -1)
+
+ loss = self._loss(logits, labels)
+ preds = torch.argmax(logits, dim=-1)
+ acc = (torch.sum(preds == labels) / preds.shape[0]).item()
+
+ self._optimizer.zero_grad()
+ loss.backward()
+ self._optimizer.step()
+ pred_loss = loss.item()
+
+ cur_lr = [param_group['lr'] for param_group in self._optimizer.param_groups]
+ cur_lr = sum(cur_lr) / len(cur_lr)
+ return {'cur_lr': cur_lr, 'total_loss': pred_loss, 'acc': acc}
+
+ def _monitor_vars_learn(self):
+ return ['cur_lr', 'total_loss', 'acc']
+
+ def _init_eval(self):
+ self._eval_model = model_wrap(self._model, wrapper_name='base')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data):
+ if self._cuda:
+ data = to_device(data, self._device)
+ max_len = self._max_bfs_steps
+ data_id = list(data.keys())
+ output = {}
+
+ for ii in data_id:
+ states = data[ii].unsqueeze(0)
+ bfs_input_maps = self._num_actions * torch.ones([1, self._maze_size, self._maze_size]).long()
+ if self._cuda:
+ bfs_input_maps = to_device(bfs_input_maps, self._device)
+ xy = torch.where(states[:, :, :, -1] == 1)
+ observation = (xy[1][0].item(), xy[2][0].item())
+
+ i = 0
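+ # Iteratively refine the predicted BFS action map until the agent's current cell is assigned a
+ # concrete action (i.e. not the placeholder class ``num_actions``) or the step budget runs out.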
+ while bfs_input_maps[0, observation[0], observation[1]].item() == self._num_actions and i < max_len:
+ bfs_input_onehot = torch.nn.functional.one_hot(bfs_input_maps, self._num_actions + 1).long()
+
+ bfs_states = torch.cat([
+ states,
+ bfs_input_onehot,
+ ], dim=-1)
+ logits = self._model(bfs_states)['logit']
+ bfs_input_maps = torch.argmax(logits, dim=-1)
+ i += 1
+ output[ii] = bfs_input_maps[0, observation[0], observation[1]]
+ if self._cuda:
+ output[ii] = {'action': to_device(output[ii], 'cpu'), 'info': {}}
+ if output[ii]['action'].item() == self._num_actions:
+ output[ii]['action'] = torch.randint(low=0, high=self._num_actions, size=[1])[0]
+ return output
+
+ def _init_collect(self) -> None:
+ raise NotImplementedError
+
+ def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]:
+ raise NotImplementedError
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
+ raise NotImplementedError
+
+ def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ raise NotImplementedError
diff --git a/DI-engine/ding/policy/pdqn.py b/DI-engine/ding/policy/pdqn.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b66e263abbea65b2c0afc87e151f6ec72db1878
--- /dev/null
+++ b/DI-engine/ding/policy/pdqn.py
@@ -0,0 +1,527 @@
+from typing import List, Dict, Any, Tuple
+from collections import namedtuple
+import copy
+import torch
+
+from ding.torch_utils import Adam, to_device
+from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, get_nstep_return_data, get_train_sample
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_collate, default_decollate
+from .base_policy import Policy
+from .common_utils import default_preprocess_learn
+
+
+@POLICY_REGISTRY.register('pdqn')
+class PDQNPolicy(Policy):
+ """
+ Overview:
+ Policy class of PDQN algorithm, which extends the DQN algorithm on discrete-continuous hybrid action spaces.
+ Paper link: https://arxiv.org/abs/1810.06394.
+
+ Config:
+ == ==================== ======== ============== ======================================== =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============== ======================================== =======================
+ 1 ``type`` str pdqn | RL policy register name, refer to | This arg is optional,
+ | registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool False | Whether to use cuda for network | This arg can be diff-
+ | erent from modes
+ 3 ``on_policy`` bool False | Whether the RL algorithm is on-policy | This value is always
+ | or off-policy | False for PDQN
+ 4 ``priority`` bool False | Whether use priority(PER) | Priority sample,
+ | update priority
+ 5 | ``priority_IS`` bool False | Whether use Importance Sampling Weight
+ | ``_weight`` | to correct biased update. If True,
+ | priority must be True.
+ 6 | ``discount_`` float 0.97, | Reward's future discount factor, aka. | May be 1 when sparse
+ | ``factor`` [0.95, 0.999] | gamma | reward env
+
+ 7 ``nstep`` int 1, | N-step reward discount sum for target
+ [3, 5] | q_value estimation
+ 8 | ``learn.update`` int 3 | How many updates(iterations) to train | This args can be vary
+ | ``per_collect`` | after collector's one collection. Only | from envs. Bigger val
+ | valid in serial training | means more off-policy
+ 9 | ``learn.batch_`` int 64 | The number of samples of an iteration
+ | ``size``
+ 10 | ``learn.multi`` bool False | whether to use multi gpu
+ | ``_gpu``
+ 11 | ``learn.learning`` float 0.001 | Gradient step length of an iteration.
+ | ``_rate``
+ 12 | ``learn.target_`` int 100 | Frequency of target network update. | Hard(assign) update
+ | ``update_freq``
+ 13 | ``learn.ignore_`` bool False | Whether ignore done for target value | Enable it for some
+ | ``done`` | calculation. | fake termination env
+ 14 ``collect.n_sample`` int [8, 128] | The number of training samples of a | It varies from
+ | call of collector. | different envs
+ 15 | ``collect.unroll`` int 1 | unroll length of an iteration | In RNN, unroll_len>1
+ | ``_len``
+ 16 | ``collect.noise`` float 0.1 | add noise to continuous args
+ | ``_sigma`` | during collection
+ 17 | ``other.eps.type`` str exp | exploration rate decay type | Support ['exp',
+ | 'linear'].
+ 18 | ``other.eps.`` float 0.95 | start value of exploration rate | [0,1]
+ | ``start``
+ 19 | ``other.eps.`` float 0.05 | end value of exploration rate | [0,1]
+ | ``end``
+ 20 | ``other.eps.`` int 10000 | decay length of exploration | greater than 0. set
+ | ``decay`` | decay=10000 means
+ | the exploration rate
+ | decay from start
+ | value to end value
+ | during decay length.
+ == ==================== ======== ============== ======================================== =======================
+ """
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='pdqn',
+ # (bool) Whether to use cuda in policy.
+ cuda=False,
+ # (bool) Whether learning policy is the same as collecting data policy(on-policy).
+ on_policy=False,
+ # (bool) Whether to enable priority experience sample.
+ priority=False,
+ # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ # (float) Discount factor(gamma) for returns.
+ discount_factor=0.97,
+ # (int) The number of step for calculating target q_value.
+ nstep=1,
+ # learn_mode config
+ learn=dict(
+ # (int) How many updates(iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=3,
+ # (int) How many samples in a training batch.
+ batch_size=64,
+ # (float) The step size of gradient descent.
+ learning_rate=0.001,
+ # (float) Coefficient (theta) for the soft/momentum update of the target network.
+ target_theta=0.005,
+ # (bool) Whether to ignore done (usually for max-step termination envs).
+ # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
+ # These limit HalfCheetah, and several other MuJoCo envs, to a max length of 1000.
+ # However, interaction with HalfCheetah never terminates on its own (done is always False),
+ # so when the episode step exceeds the max episode step we replace done==True with done==False
+ # to keep the TD-error computation (``gamma * (1 - done) * next_v + reward``) accurate.
+ ignore_done=False,
+ ),
+ # collect_mode config
+ collect=dict(
+ # (int) How many training samples collected in one collection procedure.
+ # Only one of [n_sample, n_episode] should be set.
+ # n_sample=8,
+ # (int) Split episodes or trajectories into pieces with length `unroll_len`.
+ unroll_len=1,
+ # (float) Noise is always added during collection, so there is no on/off switch here; only the Gaussian ``noise_sigma`` is configurable.
+ noise_sigma=0.1,
+ ),
+ eval=dict(), # for compatibility
+ # other config
+ other=dict(
+ # Epsilon greedy with decay.
+ eps=dict(
+ # (str) Decay type. Support ['exp', 'linear'].
+ type='exp',
+ # (float) Epsilon start value.
+ start=0.95,
+ # (float) Epsilon end value.
+ end=0.1,
+ # (int) Decay length(env step)
+ decay=10000,
+ ),
+ replay_buffer=dict(
+ # (int) Maximum size of replay buffer. Usually, larger buffer size is better.
+ replay_buffer_size=10000,
+ ),
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+ Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \
+ automatically call this method to get the default model setting and create model.
+ Returns:
+ - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
+
+ .. note::
+ The user can define and use customized network model but must obey the same interface definition indicated \
+ by import_names path. For example about PDQN, its registered name is ``pdqn`` and the import_names is \
+ ``ding.model.template.pdqn``.
+ """
+ return 'pdqn', ['ding.model.template.pdqn']
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Initialize the learn mode of policy, including related attributes and modules. For PDQN, it mainly \
+ contains two optimizers, algorithm-specific arguments such as nstep and gamma, main and target model.
+ This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+ .. note::
+ For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+ and ``_load_state_dict_learn`` methods.
+
+ .. note::
+ For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+ .. note::
+ If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+ with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ # Optimizer
+ self._dis_optimizer = Adam(
+ list(self._model.dis_head.parameters()) + list(self._model.cont_encoder.parameters()),
+ # it is very important to include cont_encoder.parameters() here.
+ lr=self._cfg.learn.learning_rate_dis
+ )
+ self._cont_optimizer = Adam(list(self._model.cont_head.parameters()), lr=self._cfg.learn.learning_rate_cont)
+
+ self._gamma = self._cfg.discount_factor
+ self._nstep = self._cfg.nstep
+
+ # use model_wrapper for specialized demands of different modes
+ self._target_model = copy.deepcopy(self._model)
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='momentum',
+ update_kwargs={'theta': self._cfg.learn.target_theta}
+ )
+ self._learn_model = model_wrap(self._model, wrapper_name='hybrid_argmax_sample')
+ self._learn_model.reset()
+ self._target_model.reset()
+ self.cont_train_cnt = 0
+ self.disc_train_cnt = 0
+ self.train_cnt = 0
+
+ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Overview:
+ Policy forward function of learn mode (training policy and updating parameters). Forward means \
+ that the policy inputs some training batch data from the replay buffer and then returns the output \
+ result, including various training information such as loss, q value, target_q_value, priority.
+ Arguments:
+ - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
+ training samples. For each element in list, the key of the dict is the name of data items and the \
+ value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
+ combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \
+ dimension by some utility functions such as ``default_preprocess_learn``. \
+ For PDQN, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \
+ ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight`` \
+ and ``value_gamma``.
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \
+ recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
+ detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
+ For any data type that is not supported, the main reason is that the corresponding model does not support it. \
+ You can implement your own model rather than use the default model. For more information, please raise an \
+ issue in GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for PDQNPolicy: ``ding.policy.tests.test_pdqn``.
+ """
+ data = default_preprocess_learn(
+ data,
+ use_priority=self._priority,
+ use_priority_IS_weight=self._cfg.priority_IS_weight,
+ ignore_done=self._cfg.learn.ignore_done,
+ use_nstep=True
+ )
+ if self._cuda:
+ data = to_device(data, self._device)
+
+ self.train_cnt += 1
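+ # PDQN alternates the two updates: iterations whose count modulo ``update_circle`` falls in
+ # [5, 10) update the continuous-args head, those in [0, 5) update the discrete Q head,
+ # and the very first iteration runs both branches.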
+ # ================================
+ # Continuous args network forward
+ # ================================
+ if self.train_cnt == 1 or self.train_cnt % self._cfg.learn.update_circle in range(5, 10):
+ dis_loss = torch.Tensor([0])
+ td_error_per_sample = torch.Tensor([0])
+ target_q_value = torch.Tensor([0])
+
+ action_args = self._learn_model.forward(data['obs'], mode='compute_continuous')['action_args']
+
+ # Current q value (main model) for cont loss
+ discrete_inputs = {'state': data['obs'], 'action_args': action_args}
+ # with torch.no_grad():
+ q_pi_action_value = self._learn_model.forward(discrete_inputs, mode='compute_discrete')['logit']
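+ # Train the continuous-args head by maximizing the Q values the discrete head assigns to its
+ # output args, i.e. the loss is the negative summed Q value.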
+ cont_loss = -q_pi_action_value.sum(dim=-1).mean()
+
+ # ================================
+ # Continuous args network update
+ # ================================
+ self._cont_optimizer.zero_grad()
+ cont_loss.backward()
+ self._cont_optimizer.step()
+
+ # ====================
+ # Q-learning forward
+ # ====================
+ if self.train_cnt == 1 or self.train_cnt % self._cfg.learn.update_circle in range(0, 5):
+ cont_loss = torch.Tensor([0])
+ q_pi_action_value = torch.Tensor([0])
+ self._learn_model.train()
+ self._target_model.train()
+ # Current q value (main model)
+ discrete_inputs = {'state': data['obs'], 'action_args': data['action']['action_args']}
+ q_data_action_args_value = self._learn_model.forward(discrete_inputs, mode='compute_discrete')['logit']
+
+ # Target q value
+ with torch.no_grad():
+ next_action_args = self._learn_model.forward(data['next_obs'], mode='compute_continuous')['action_args']
+ next_action_args_cp = next_action_args.clone().detach()
+ next_discrete_inputs = {'state': data['next_obs'], 'action_args': next_action_args_cp}
+ target_q_value = self._target_model.forward(next_discrete_inputs, mode='compute_discrete')['logit']
+ # Max q value action (main model)
+ target_q_discrete_action = self._learn_model.forward(
+ next_discrete_inputs, mode='compute_discrete'
+ )['action']['action_type']
+
+ data_n = q_nstep_td_data(
+ q_data_action_args_value, target_q_value, data['action']['action_type'], target_q_discrete_action,
+ data['reward'], data['done'], data['weight']
+ )
+ value_gamma = data.get('value_gamma')
+ dis_loss, td_error_per_sample = q_nstep_td_error(
+ data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma
+ )
+
+ # ====================
+ # Q-learning update
+ # ====================
+ self._dis_optimizer.zero_grad()
+ dis_loss.backward()
+ self._dis_optimizer.step()
+
+ # =============
+ # after update
+ # =============
+ self._target_model.update(self._learn_model.state_dict())
+
+ return {
+ 'cur_lr': self._dis_optimizer.defaults['lr'],
+ 'q_loss': dis_loss.item(),
+ 'total_loss': cont_loss.item() + dis_loss.item(),
+ 'continuous_loss': cont_loss.item(),
+ 'q_value': q_pi_action_value.mean().item(),
+ 'priority': td_error_per_sample.abs().tolist(),
+ 'reward': data['reward'].mean().item(),
+ 'target_q_value': target_q_value.mean().item(),
+ }
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ """
+ Overview:
+ Return the state_dict of learn mode, usually including model, target model, discrete part optimizer, and \
+ continuous part optimizer.
+ Returns:
+ - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring.
+ """
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'target_model': self._target_model.state_dict(),
+ 'dis_optimizer': self._dis_optimizer.state_dict(),
+ 'cont_optimizer': self._cont_optimizer.state_dict()
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ """
+ Overview:
+ Load the state_dict variable into policy learn mode.
+ Arguments:
+ - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before.
+
+ .. tip::
+ If you want to only load some parts of model, you can simply set the ``strict`` argument in \
+ load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
+ complicated operation.
+ """
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._target_model.load_state_dict(state_dict['target_model'])
+ self._dis_optimizer.load_state_dict(state_dict['dis_optimizer'])
+ self._cont_optimizer.load_state_dict(state_dict['cont_optimizer'])
+
+ def _init_collect(self) -> None:
+ """
+ Overview:
+ Initialize the collect mode of policy, including related attributes and modules. For PDQN, it contains the \
+ collect_model to balance the exploration and exploitation with epsilon-greedy sample mechanism and \
+ continuous action mechanism, besides, other algorithm-specific arguments such as unroll_len and nstep are \
+ also initialized here.
+ This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.
+
+ .. note::
+ If you want to set some special member variables in ``_init_collect`` method, you'd better name them \
+ with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``.
+
+ .. tip::
+ Some variables need to initialize independently in different modes, such as gamma and nstep in PDQN. This \
+ design is for the convenience of parallel execution of different policy modes.
+ """
+ self._unroll_len = self._cfg.collect.unroll_len
+ self._gamma = self._cfg.discount_factor # necessary for parallel
+ self._nstep = self._cfg.nstep # necessary for parallel
+ self._collect_model = model_wrap(
+ self._model,
+ wrapper_name='action_noise',
+ noise_type='gauss',
+ noise_kwargs={
+ 'mu': 0.0,
+ 'sigma': self._cfg.collect.noise_sigma
+ },
+ noise_range=None
+ )
+ self._collect_model = model_wrap(self._collect_model, wrapper_name='hybrid_eps_greedy_multinomial_sample')
+ self._collect_model.reset()
+
+ def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
+ """
+ Overview:
+ Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \
+ that the policy gets some necessary data (mainly observation) from the envs and then returns the output \
+ data, such as the action to interact with the envs. Besides, this policy also needs ``eps`` argument for \
+ exploration, i.e., classic epsilon-greedy exploration strategy.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+ - eps (:obj:`float`): The epsilon value for exploration.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
+ other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \
+ dict is the same as the input data, i.e. environment id.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
+ For any data type that is not supported, the main reason is that the corresponding model does not support it. \
+ You can implement your own model rather than use the default model. For more information, please raise an \
+ issue in GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for PDQNPolicy: ``ding.policy.tests.test_pdqn``.
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._collect_model.eval()
+ with torch.no_grad():
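+ # Hybrid action forward: first compute the continuous action args, then feed them together
+ # with the obs into the discrete head to select the action type (both with exploration).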
+ action_args = self._collect_model.forward(data, 'compute_continuous', eps=eps)['action_args']
+ inputs = {'state': data, 'action_args': action_args.clone().detach()}
+ output = self._collect_model.forward(inputs, 'compute_discrete', eps=eps)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+ Overview:
+ For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \
+ can be used for training directly. In PDQN, a train sample is a processed transition. \
+ This method is usually used in collectors to execute necessary \
+ RL data preprocessing before training, which can help the learner amortize relevant time consumption. \
+ In addition, you can also implement this method as an identity function and do the data processing \
+ in ``self._forward_learn`` method.
+ Arguments:
+ - transitions (:obj:`List[Dict[str, Any]]`): The trajectory data (a list of transition), each element is \
+ the same format as the return value of ``self._process_transition`` method.
+ Returns:
+ - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is the similar format \
+ as input transitions, but may contain more data for training, such as nstep reward and target obs.
+ """
+ transitions = get_nstep_return_data(transitions, self._nstep, gamma=self._gamma)
+ return get_train_sample(transitions, self._unroll_len)
+
+ def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor],
+ timestep: namedtuple) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ Process and pack one timestep transition data into a dict, which can be directly used for training and \
+ saved in replay buffer. For PDQN, it contains obs, next_obs, action, reward, done and logit.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari.
+ - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \
+ as input. For PDQN, it contains the hybrid action and the logit (discrete part q_value) of the action.
+ - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \
+ except all the elements have been transformed into tensor data. Usually, it contains the next obs, \
+ reward, done, info, etc.
+ Returns:
+ - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep.
+ """
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'action': policy_output['action'],
+ 'logit': policy_output['logit'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _init_eval(self) -> None:
+ """
+ Overview:
+ Initialize the eval mode of policy, including related attributes and modules. For PDQN, it contains the \
+ eval model to greedily select action with argmax q_value mechanism.
+ This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
+
+ .. note::
+ If you want to set some special member variables in ``_init_eval`` method, you'd better name them \
+ with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
+ """
+ self._eval_model = model_wrap(self._model, wrapper_name='hybrid_argmax_sample')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
+ """
+ Overview:
+ Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \
+ means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
+ action to interact with the envs.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
+ key of the dict is the same as the input data, i.e. environment id.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
+ For any data type that is not supported, the main reason is that the corresponding model does not support it. \
+ You can implement your own model rather than use the default model. For more information, please raise an \
+ issue in GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for PDQNPolicy: ``ding.policy.tests.test_pdqn``.
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._eval_model.eval()
+ with torch.no_grad():
+ action_args = self._eval_model.forward(data, mode='compute_continuous')['action_args']
+ inputs = {'state': data, 'action_args': action_args.clone().detach()}
+ output = self._eval_model.forward(inputs, mode='compute_discrete')
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _monitor_vars_learn(self) -> List[str]:
+ """
+ Overview:
+ Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
+ as text logger, tensorboard logger, will use these keys to save the corresponding data.
+ Returns:
+ - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
+ """
+ return ['cur_lr', 'total_loss', 'q_loss', 'continuous_loss', 'q_value', 'reward', 'target_q_value']
diff --git a/DI-engine/ding/policy/pg.py b/DI-engine/ding/policy/pg.py
new file mode 100644
index 0000000000000000000000000000000000000000..667439d07bc3f5c5eaa92eff50776b946648fae2
--- /dev/null
+++ b/DI-engine/ding/policy/pg.py
@@ -0,0 +1,219 @@
+from typing import List, Dict, Any, Tuple, Union
+from collections import namedtuple
+import torch
+import treetensor as ttorch
+
+from ding.rl_utils import get_gae_with_default_last_value, get_train_sample
+from ding.torch_utils import Adam, to_device
+from ding.utils import POLICY_REGISTRY, split_data_generator
+from ding.utils.data import default_collate, default_decollate
+from .base_policy import Policy
+from .common_utils import default_preprocess_learn
+
+
+@POLICY_REGISTRY.register('pg')
+class PGPolicy(Policy):
+ r"""
+ Overview:
+ Policy class of Policy Gradient (REINFORCE) algorithm.
+ """
+ config = dict(
+ # (string) RL policy register name (refer to function "register_policy").
+ type='pg',
+ # (bool) whether to use cuda for network.
+ cuda=False,
+ # (bool) whether to use the on-policy training pipeline (behaviour policy and training policy are the same)
+ on_policy=True, # PG is a strictly on-policy algorithm, so this line should not be modified by users
+ # (str) action space type: ['discrete', 'continuous']
+ action_space='discrete',
+ # (bool) whether to use deterministic action for evaluation.
+ deterministic_eval=True,
+ learn=dict(
+
+ # (int) the number of samples for one update.
+ batch_size=64,
+ # (float) the step size of one gradient descent update.
+ learning_rate=0.001,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.01,
+ # (float) max grad norm value.
+ grad_norm=5,
+ # (bool) whether to ignore done signal for non-termination env.
+ ignore_done=False,
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ # n_episode=8,
+ # (int) trajectory unroll length
+ unroll_len=1,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) discount factor for future reward, in range [0, 1]
+ discount_factor=0.99,
+ collector=dict(get_train_sample=True),
+ ),
+ eval=dict(),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ return 'pg', ['ding.model.template.pg']
+
+ def _init_learn(self) -> None:
+ r"""
+ Overview:
+ Learn mode init method. Called by ``self.__init__``.
+ Init the optimizer, algorithm config, main and target models.
+ """
+ # Optimizer
+ self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
+
+ self._entropy_weight = self._cfg.learn.entropy_weight
+ self._grad_norm = self._cfg.learn.grad_norm
+ self._learn_model = self._model # for compatibility
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Forward and backward function of learn mode.
+ Arguments:
+ - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs','adv']
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss.
+ """
+ data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=False)
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._model.train()
+
+ return_infos = []
+ for batch in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True):
+ # forward
+ output = self._learn_model.forward(batch['obs'])
+ return_ = batch['return']
+ dist = output['dist']
+ # calculate PG loss
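+ # REINFORCE: the policy loss is the negative mean of log-prob weighted by the episode return.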
+ log_prob = dist.log_prob(batch['action'])
+ policy_loss = -(log_prob * return_).mean()
+ entropy_loss = -self._cfg.learn.entropy_weight * dist.entropy().mean()
+ total_loss = policy_loss + entropy_loss
+
+ # update
+ self._optimizer.zero_grad()
+ total_loss.backward()
+
+ grad_norm = torch.nn.utils.clip_grad_norm_(
+ list(self._learn_model.parameters()),
+ max_norm=self._grad_norm,
+ )
+ self._optimizer.step()
+
+ # only record last updates information in logger
+ return_info = {
+ 'cur_lr': self._optimizer.param_groups[0]['lr'],
+ 'total_loss': total_loss.item(),
+ 'policy_loss': policy_loss.item(),
+ 'entropy_loss': entropy_loss.item(),
+ 'return_abs_max': return_.abs().max().item(),
+ 'grad_norm': grad_norm.item(),
+ }
+ return_infos.append(return_info)
+ return return_infos
+
+ def _init_collect(self) -> None:
+ self._unroll_len = self._cfg.collect.unroll_len
+ self._gamma = self._cfg.collect.discount_factor
+
+ def _forward_collect(self, data: dict) -> dict:
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._model.eval()
+ with torch.no_grad():
+ output = self._model.forward(data)
+ output['action'] = output['dist'].sample()
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
+ r"""
+ Overview:
+ Generate dict type transition data from inputs.
+ Arguments:
+ - obs (:obj:`Any`): Env observation
+ - model_output (:obj:`dict`): Output of collect model, including at least ['action']
+ - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
+ (here 'obs' indicates obs after env step).
+ Returns:
+ - transition (:obj:`dict`): Dict type transition data.
+ """
+ return {
+ 'obs': obs,
+ 'action': model_output['action'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+
+ def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
+ r"""
+ Overview:
+ Compute the discounted return for each transition of the complete episode, then generate training samples
+ Arguments:
+ - data (:obj:`list`): The trajectory's buffer list
+ Returns:
+ - samples (:obj:`dict`): The training samples generated
+ """
+ assert data[-1]['done'], "PG needs a complete episode"
+
+ if self._cfg.learn.ignore_done:
+ raise NotImplementedError
+
+ R = 0.
+ if isinstance(data, list):
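+ # Accumulate the discounted return backwards over the episode: R_t = r_t + gamma * R_{t+1}.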
+ for i in reversed(range(len(data))):
+ R = self._gamma * R + data[i]['reward']
+ data[i]['return'] = R
+ return get_train_sample(data, self._unroll_len)
+ elif isinstance(data, ttorch.Tensor):
+ data_size = data['done'].shape[0]
+ data['return'] = ttorch.torch.zeros(data_size)
+ for i in reversed(range(data_size)):
+ R = self._gamma * R + data['reward'][i]
+ data['return'][i] = R
+ return get_train_sample(data, self._unroll_len)
+ else:
+ raise ValueError
+
+ def _init_eval(self) -> None:
+ pass
+
+ def _forward_eval(self, data: dict) -> dict:
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._model.eval()
+ with torch.no_grad():
+ output = self._model.forward(data)
+ if self._cfg.deterministic_eval:
+ if self._cfg.action_space == 'discrete':
+ output['action'] = output['logit'].argmax(dim=-1)
+ elif self._cfg.action_space == 'continuous':
+ output['action'] = output['logit']['mu']
+ else:
+ raise KeyError("invalid action_space: {}".format(self._cfg.action_space))
+ else:
+ output['action'] = output['dist'].sample()
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _monitor_vars_learn(self) -> List[str]:
+ return super()._monitor_vars_learn() + ['policy_loss', 'entropy_loss', 'return_abs_max', 'grad_norm']
diff --git a/DI-engine/ding/policy/plan_diffuser.py b/DI-engine/ding/policy/plan_diffuser.py
new file mode 100755
index 0000000000000000000000000000000000000000..ad58546a154f945a07ec5c9853c23f4a6ca6bc7e
--- /dev/null
+++ b/DI-engine/ding/policy/plan_diffuser.py
@@ -0,0 +1,400 @@
+from typing import List, Dict, Any, Optional, Tuple, Union
+from collections import namedtuple, defaultdict
+import copy
+import numpy as np
+import torch
+import torch.nn.functional as F
+from torch.distributions import Normal, Independent
+
+from ding.torch_utils import Adam, to_device
+from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, \
+ qrdqn_nstep_td_data, qrdqn_nstep_td_error, get_nstep_return_data
+from ding.policy import Policy
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY, DatasetNormalizer
+from ding.utils.data import default_collate, default_decollate
+from .common_utils import default_preprocess_learn
+
+
+@POLICY_REGISTRY.register('pd')
+class PDPolicy(Policy):
+ r"""
+ Overview:
+ Implicit Plan Diffuser
+ https://arxiv.org/pdf/2205.09991.pdf
+
+ """
+ config = dict(
+ type='pd',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ # (bool type) priority: Determine whether to use priority in buffer sample.
+ # Default to False.
+ priority=False,
+ # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ # (int) Number of training samples(randomly collected) in replay buffer when training starts.
+ # Default to 10000.
+ random_collect_size=10000,
+ nstep=1,
+ # normalizer type
+ normalizer='GaussianNormalizer',
+ model=dict(
+ diffuser_model='GaussianDiffusion',
+ diffuser_model_cfg=dict(
+ # the type of model
+ model='TemporalUnet',
+ # config of model
+ model_cfg=dict(
+ # model dim: in GaussianInvDynDiffusion it is obs_dim, in others it is obs_dim + action_dim
+ transition_dim=23,
+ dim=32,
+ dim_mults=[1, 2, 4, 8],
+ # whether use return as a condition
+ returns_condition=False,
+ condition_dropout=0.1,
+ # whether use calc energy
+ calc_energy=False,
+ kernel_size=5,
+ # whether use attention
+ attention=False,
+ ),
+ # horizon of the trajectory generated by the model
+ horizon=80,
+ # timesteps of diffusion
+ n_timesteps=1000,
+ # whether the diffusion model predicts the noise (epsilon) rather than the denoised sample
+ predict_epsilon=True,
+ # discount of loss
+ loss_discount=1.0,
+ # whether clip denoise
+ clip_denoised=False,
+ action_weight=10,
+ ),
+ value_model='ValueDiffusion',
+ value_model_cfg=dict(
+ # the type of model
+ model='TemporalValue',
+ # config of model
+ model_cfg=dict(
+ horizon=4,
+ # model dim: in GaussianInvDynDiffusion it is obs_dim, in others it is obs_dim + action_dim
+ transition_dim=23,
+ dim=32,
+ dim_mults=[1, 2, 4, 8],
+ # whether use calc energy
+ kernel_size=5,
+ ),
+ # horizon of the trajectory generated by the model
+ horizon=80,
+ # timesteps of diffusion
+ n_timesteps=1000,
+ # whether the diffusion model predicts the noise (epsilon) rather than the denoised sample
+ predict_epsilon=True,
+ # discount of loss
+ loss_discount=1.0,
+ # whether clip denoise
+ clip_denoised=False,
+ action_weight=1.0,
+ ),
+ # guide_steps for p sample
+ n_guide_steps=2,
+ # scale of grad for p sample
+ scale=0.1,
+ # t of stopgrad for p sample
+ t_stopgrad=2,
+ # whether use std as a scale for grad
+ scale_grad_by_std=True,
+ ),
+ learn=dict(
+
+ # How many updates(iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=1,
+ # (int) Minibatch size for gradient descent.
+ batch_size=100,
+
+ # (float) learning_rate: Learning rate for the model.
+ # Default to 3e-4.
+ # Please set to 1e-3, when model.value_network is True.
+ learning_rate=3e-4,
+ # (bool) Whether to ignore done (usually for max-step termination envs, e.g. pendulum).
+ # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
+ # These limit HalfCheetah, and several other MuJoCo envs, to a max length of 1000.
+ # However, interaction with HalfCheetah never terminates on its own (done is always False),
+ # so when the episode step exceeds the max episode step we replace done==True with done==False
+ # to keep the TD-error computation (``gamma * (1 - done) * next_v + reward``) accurate.
+ ignore_done=False,
+
+ # (float type) target_theta: Used for soft update of the target network,
+ # aka. Interpolation factor in polyak averaging for target networks.
+ # Default to 0.005.
+ target_theta=0.005,
+ # (float) discount factor for the discounted sum of rewards, aka. gamma.
+ discount_factor=0.99,
+ gradient_accumulate_every=2,
+ # train_epoch = train_epoch * gradient_accumulate_every
+ train_epoch=60000,
+ # batch_size of every env when eval
+ plan_batch_size=64,
+
+ # step at which to start updating the target model, and the update frequency
+ step_start_update_target=2000,
+ update_target_freq=10,
+ # update weight of target net
+ target_weight=0.995,
+ value_step=200e3,
+
+ # dataset weight include returns
+ include_returns=True,
+
+ # (float) Weight uniform initialization range in the last output layer
+ init_w=3e-3,
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ return 'pd', ['ding.model.template.diffusion']
+
+ def _init_learn(self) -> None:
+ r"""
+ Overview:
+ Learn mode init method. Called by ``self.__init__``.
+ Init q, value and policy's optimizers, algorithm config, main and target models.
+ """
+ # Init
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ self.action_dim = self._cfg.model.diffuser_model_cfg.action_dim
+ self.obs_dim = self._cfg.model.diffuser_model_cfg.obs_dim
+ self.n_timesteps = self._cfg.model.diffuser_model_cfg.n_timesteps
+ self.gradient_accumulate_every = self._cfg.learn.gradient_accumulate_every
+ self.plan_batch_size = self._cfg.learn.plan_batch_size
+ self.gradient_steps = 1
+ self.update_target_freq = self._cfg.learn.update_target_freq
+ self.step_start_update_target = self._cfg.learn.step_start_update_target
+ self.target_weight = self._cfg.learn.target_weight
+ self.value_step = self._cfg.learn.value_step
+ self.use_target = False
+ self.horizon = self._cfg.model.diffuser_model_cfg.horizon
+ self.include_returns = self._cfg.learn.include_returns
+
+ # Optimizers
+ self._plan_optimizer = Adam(
+ self._model.diffuser.model.parameters(),
+ lr=self._cfg.learn.learning_rate,
+ )
+ if self._model.value:
+ self._value_optimizer = Adam(
+ self._model.value.model.parameters(),
+ lr=self._cfg.learn.learning_rate,
+ )
+
+ # Algorithm config
+ self._gamma = self._cfg.learn.discount_factor
+
+ # Main and target models
+ self._target_model = copy.deepcopy(self._model)
+ # self._target_model = model_wrap(
+ # self._target_model,
+ # wrapper_name='target',
+ # update_type='momentum',
+ # update_kwargs={'theta': self._cfg.learn.target_theta}
+ # )
+ self._learn_model = model_wrap(self._model, wrapper_name='base')
+ self._learn_model.reset()
+ # self._target_model.reset()
+
+ self._forward_learn_cnt = 0
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ loss_dict = {}
+
+ data = default_preprocess_learn(
+ data,
+ use_priority=self._priority,
+ use_priority_IS_weight=self._cfg.priority_IS_weight,
+ ignore_done=self._cfg.learn.ignore_done,
+ use_nstep=False
+ )
+
+ conds = {}
+ vals = data['condition_val']
+ ids = data['condition_id']
+ for i in range(len(ids)):
+ conds[ids[i][0].item()] = vals[i]
+ if len(ids) > 1:
+ self.use_target = True
+ data['conditions'] = conds
+ if 'returns' in data.keys():
+ data['returns'] = data['returns'].unsqueeze(-1)
+ if self._cuda:
+ data = to_device(data, self._device)
+
+ self._learn_model.train()
+ # self._target_model.train()
+ x = data['trajectories']
+
+ batch_size = len(x)
+ t = torch.randint(0, self.n_timesteps, (batch_size, ), device=x.device).long()
+ cond = data['conditions']
+ if 'returns' in data.keys():
+ target = data['returns']
+ loss_dict['diffuse_loss'], loss_dict['a0_loss'] = self._model.diffuser_loss(x, cond, t)
+ loss_dict['diffuse_loss'] = loss_dict['diffuse_loss'] / self.gradient_accumulate_every
+ loss_dict['diffuse_loss'].backward()
+ if self._forward_learn_cnt < self.value_step and self._model.value:
+ loss_dict['value_loss'], logs = self._model.value_loss(x, cond, target, t)
+ loss_dict['value_loss'] = loss_dict['value_loss'] / self.gradient_accumulate_every
+ loss_dict['value_loss'].backward()
+ loss_dict.update(logs)
+
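+ # Gradient accumulation: the optimizers only step after ``gradient_accumulate_every`` backward passes.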
+ if self.gradient_steps >= self.gradient_accumulate_every:
+ self._plan_optimizer.step()
+ self._plan_optimizer.zero_grad()
+ if self._forward_learn_cnt < self.value_step and self._model.value:
+ self._value_optimizer.step()
+ self._value_optimizer.zero_grad()
+ self.gradient_steps = 1
+ else:
+ self.gradient_steps += 1
+ self._forward_learn_cnt += 1
+ if self._forward_learn_cnt % self.update_target_freq == 0:
+ if self._forward_learn_cnt < self.step_start_update_target:
+ self._target_model.load_state_dict(self._model.state_dict())
+ else:
+ self.update_model_average(self._target_model, self._learn_model)
+
+ if 'returns' in data.keys():
+ loss_dict['max_return'] = target.max().item()
+ loss_dict['min_return'] = target.min().item()
+ loss_dict['mean_return'] = target.mean().item()
+ loss_dict['max_traj'] = x.max().item()
+ loss_dict['min_traj'] = x.min().item()
+ loss_dict['mean_traj'] = x.mean().item()
+ return loss_dict
+
+ def update_model_average(self, ma_model, current_model):
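+ # Exponential moving average update of the target model: target <- w * target + (1 - w) * current.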
+ for current_params, ma_params in zip(current_model.parameters(), ma_model.parameters()):
+ old_weight, up_weight = ma_params.data, current_params.data
+ if old_weight is None:
+ ma_params.data = up_weight
+ else:
+ ma_params.data = old_weight * self.target_weight + (1 - self.target_weight) * up_weight
+
+ def _monitor_vars_learn(self) -> List[str]:
+ return [
+ 'diffuse_loss',
+ 'value_loss',
+ 'max_return',
+ 'min_return',
+ 'mean_return',
+ 'max_traj',
+ 'min_traj',
+ 'mean_traj',
+ 'mean_pred',
+ 'max_pred',
+ 'min_pred',
+ 'a0_loss',
+ ]
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ if self._model.value:
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'target_model': self._target_model.state_dict(),
+ 'plan_optimizer': self._plan_optimizer.state_dict(),
+ 'value_optimizer': self._value_optimizer.state_dict(),
+ }
+ else:
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'target_model': self._target_model.state_dict(),
+ 'plan_optimizer': self._plan_optimizer.state_dict(),
+ }
+
+ def _init_eval(self):
+ self._eval_model = model_wrap(self._target_model, wrapper_name='base')
+ self._eval_model.reset()
+ if self.use_target:
+ self._plan_seq = []
+
+ def init_data_normalizer(self, normalizer: DatasetNormalizer = None):
+ self.normalizer = normalizer
+
+ def _forward_eval(self, data: dict) -> Dict[str, Any]:
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+
+ self._eval_model.eval()
+ if self.use_target:
+ cur_obs = self.normalizer.normalize(data[:, :self.obs_dim], 'observations')
+ target_obs = self.normalizer.normalize(data[:, self.obs_dim:], 'observations')
+ else:
+ obs = self.normalizer.normalize(data, 'observations')
+ with torch.no_grad():
+ if self.use_target:
+ cur_obs = torch.tensor(cur_obs)
+ target_obs = torch.tensor(target_obs)
+ if self._cuda:
+ cur_obs = to_device(cur_obs, self._device)
+ target_obs = to_device(target_obs, self._device)
+ conditions = {0: cur_obs, self.horizon - 1: target_obs}
+ else:
+ obs = torch.tensor(obs)
+ if self._cuda:
+ obs = to_device(obs, self._device)
+ conditions = {0: obs}
+
+ if self.use_target:
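+ # Goal-conditioned planning: (re)generate a full trajectory plan whenever an env starts a new
+ # episode, then follow that plan waypoint by waypoint on subsequent calls.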
+ if self._plan_seq == [] or 0 in self._eval_t:
+ plan_traj = self._eval_model.get_eval(conditions, self.plan_batch_size)
+ plan_traj = to_device(plan_traj, 'cpu').numpy()
+ if self._plan_seq == []:
+ self._plan_seq = plan_traj
+ self._eval_t = [0] * len(data_id)
+ else:
+ for id in data_id:
+ if self._eval_t[id] == 0:
+ self._plan_seq[id] = plan_traj[id]
+ action = []
+ for id in data_id:
+ if self._eval_t[id] < len(self._plan_seq[id]) - 1:
+ next_waypoint = self._plan_seq[id][self._eval_t[id] + 1]
+ else:
+ next_waypoint = self._plan_seq[id][-1].copy()
+ next_waypoint[2:] = 0
+ cur_ob = cur_obs[id]
+ cur_ob = to_device(cur_ob, 'cpu').numpy()
+ act = next_waypoint[:2] - cur_ob[:2] + (next_waypoint[2:] - cur_ob[2:])
+ action.append(act)
+ self._eval_t[id] += 1
+ else:
+ action = self._eval_model.get_eval(conditions, self.plan_batch_size)
+ if self._cuda:
+ action = to_device(action, 'cpu')
+ action = self.normalizer.unnormalize(action, 'actions')
+ action = torch.tensor(action).to('cpu')
+ output = {'action': action}
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
+ if self.use_target and data_id:
+ for id in data_id:
+ self._eval_t[id] = 0
+
+ def _init_collect(self) -> None:
+ pass
+
+ def _forward_collect(self, data: dict, **kwargs) -> dict:
+ pass
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
+ pass
+
+ def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
+ pass
diff --git a/DI-engine/ding/policy/policy_factory.py b/DI-engine/ding/policy/policy_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba9b77df290d48a7e192ae1c51c283909d66ebf9
--- /dev/null
+++ b/DI-engine/ding/policy/policy_factory.py
@@ -0,0 +1,108 @@
+from typing import Dict, Any, Callable
+from collections import namedtuple
+from easydict import EasyDict
+import gym
+import torch
+
+from ding.torch_utils import to_device
+
+
+class PolicyFactory:
+ """
+ Overview:
+ Policy factory class, used to generate different policies for general purposes, such as the random action \
+ policy, which is used for initial sample collection to improve exploration when ``random_collect_size`` > 0.
+ Interfaces:
+ ``get_random_policy``
+ """
+
+ @staticmethod
+ def get_random_policy(
+ policy: 'Policy.collect_mode', # noqa
+ action_space: 'gym.spaces.Space' = None, # noqa
+ forward_fn: Callable = None,
+ ) -> 'Policy.collect_mode': # noqa
+ """
+ Overview:
+ According to the given action space, define the forward function of the random policy, then pack it with \
+ other interfaces of the given policy, and return the final collect mode interfaces of policy.
+ Arguments:
+ - policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the policy.
+ - action_space (:obj:`gym.spaces.Space`): The action space of the environment, gym-style.
+ - forward_fn (:obj:`Callable`): If the action space is too complex, you can define your own forward \
+ function and pass it to this method; note that you should set ``action_space`` to ``None`` in this case.
+ Returns:
+ - random_policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the random policy.
+ """
+ assert not (action_space is None and forward_fn is None)
+ random_collect_function = namedtuple(
+ 'random_collect_function', [
+ 'forward',
+ 'process_transition',
+ 'get_train_sample',
+ 'reset',
+ 'get_attribute',
+ ]
+ )
+
+ def forward(data: Dict[int, Any], *args, **kwargs) -> Dict[int, Any]:
+
+ actions = {}
+ for env_id in data:
+ if not isinstance(action_space, list):
+ if isinstance(action_space, gym.spaces.Discrete):
+ action = torch.LongTensor([action_space.sample()])
+ elif isinstance(action_space, gym.spaces.MultiDiscrete):
+ action = [torch.LongTensor([v]) for v in action_space.sample()]
+ else:
+ action = torch.as_tensor(action_space.sample())
+ actions[env_id] = {'action': action}
+ elif 'global_state' in data[env_id].keys():
+ # for smac
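+ # mask out unavailable actions with a large negative logit so the Categorical never samples them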
+ logit = torch.ones_like(data[env_id]['action_mask'])
+ logit[data[env_id]['action_mask'] == 0.0] = -1e8
+ dist = torch.distributions.categorical.Categorical(logits=torch.Tensor(logit))
+ actions[env_id] = {'action': dist.sample(), 'logit': torch.as_tensor(logit)}
+ else:
+ # for gfootball
+ actions[env_id] = {
+ 'action': torch.as_tensor([action_space_agent.sample() for action_space_agent in action_space]),
+ 'logit': torch.ones([len(action_space), action_space[0].n])
+ }
+ return actions
+
+ def reset(*args, **kwargs) -> None:
+ pass
+
+ if action_space is None:
+ return random_collect_function(
+ forward_fn, policy.process_transition, policy.get_train_sample, reset, policy.get_attribute
+ )
+ elif forward_fn is None:
+ return random_collect_function(
+ forward, policy.process_transition, policy.get_train_sample, reset, policy.get_attribute
+ )
+
+
+def get_random_policy(
+ cfg: EasyDict,
+ policy: 'Policy.collect_mode', # noqa
+ env: 'BaseEnvManager' # noqa
+) -> 'Policy.collect_mode': # noqa
+ """
+ Overview:
+ The entry function to get the corresponding random policy. If a policy needs special data items in a \
+ transition, return the policy itself; otherwise, use ``PolicyFactory`` to return a general random policy.
+ Arguments:
+ - cfg (:obj:`EasyDict`): The EasyDict-type dict configuration.
+ - policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the policy.
+ - env (:obj:`BaseEnvManager`): The env manager instance, which is used to get the action space for random \
+ action generation.
+ Returns:
+ - random_policy (:obj:`Policy.collect_mode`): The collect mode interfaces of the random policy.
+ """
+ if cfg.policy.get('transition_with_policy_data', False):
+ return policy
+ else:
+ action_space = env.action_space
+ return PolicyFactory.get_random_policy(policy, action_space=action_space)
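+
+
+# Illustrative usage sketch (not part of the original module; ``policy`` and the observations ``obs_0``,
+# ``obs_1`` are assumed to come from an existing pipeline). The returned collect-mode namedtuple exposes
+# ``forward``, which maps env_id-keyed observations to env_id-keyed random actions sampled from the space:
+#
+#   space = gym.spaces.Discrete(4)
+#   random_policy = PolicyFactory.get_random_policy(policy.collect_mode, action_space=space)
+#   output = random_policy.forward({0: obs_0, 1: obs_1})
+#   # e.g. {0: {'action': tensor([2])}, 1: {'action': tensor([0])}}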
diff --git a/DI-engine/ding/policy/ppg.py b/DI-engine/ding/policy/ppg.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e691281c9b852099dc36723c460df255bcfb55a
--- /dev/null
+++ b/DI-engine/ding/policy/ppg.py
@@ -0,0 +1,1322 @@
+from typing import List, Dict, Any, Tuple, Union
+from collections import namedtuple
+import copy
+import torch
+from torch.utils.data import Dataset, DataLoader
+
+from ding.utils import POLICY_REGISTRY, split_data_generator, RunningMeanStd
+from ding.utils.data import default_collate, default_decollate
+from ding.torch_utils import Adam, to_device
+from ding.rl_utils import get_gae_with_default_last_value, get_train_sample, gae, gae_data, get_gae, \
+ ppo_policy_data, ppo_policy_error, ppo_value_data, ppo_value_error, ppg_data, ppg_joint_error
+from ding.model import model_wrap
+from .base_policy import Policy
+
+
+class ExperienceDataset(Dataset):
+ """
+ Overview:
+ A dataset class for storing and accessing experience data.
+
+ Interface:
+ ``__init__``, ``__len__``, ``__getitem__``.
+ """
+
+ def __init__(self, data):
+ """
+ Arguments:
+ - data (:obj:`dict`): A dictionary containing the experience data, where the keys are the data item \
+ names (e.g. ``obs``, ``action``) and the values are the corresponding data arrays.
+ """
+ super().__init__()
+ self.data = data
+
+ def __len__(self):
+ return list(self.data.values())[0].shape[0]
+
+ def __getitem__(self, ind):
+ data = {}
+ for key in self.data.keys():
+ data[key] = self.data[key][ind]
+ return data
+
+
+def create_shuffled_dataloader(data, batch_size):
+ ds = ExperienceDataset(data)
+ return DataLoader(ds, batch_size=batch_size, shuffle=True)
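+
+# A minimal usage sketch (illustrative only; the tensors below are made up): the dataloader yields
+# dict-shaped batches whose keys mirror the input dict, which is how the auxiliary phase iterates
+# over its gathered memories.
+#
+#   data = {'obs': torch.randn(128, 4), 'action': torch.randint(0, 2, (128, ))}
+#   for batch in create_shuffled_dataloader(data, batch_size=32):
+#       assert batch['obs'].shape[0] <= 32 and batch['action'].shape[0] <= 32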
+
+
+@POLICY_REGISTRY.register('ppg')
+class PPGPolicy(Policy):
+ """
+ Overview:
+ Policy class of PPG algorithm. PPG is a policy gradient algorithm with auxiliary phase training. \
+ The auxiliary phase training is proposed to distill the value into the policy network, \
+ while making sure the policy network does not change the action predictions (kl div loss). \
+ Paper link: https://arxiv.org/abs/2009.04416.
+
+ Interface:
+ ``_init_learn``, ``_data_preprocess_learn``, ``_forward_learn``, ``_state_dict_learn``, \
+ ``_load_state_dict_learn``, ``_init_collect``, ``_forward_collect``, ``_process_transition``, \
+ ``_get_train_sample``, ``_get_batch_size``, ``_init_eval``, ``_forward_eval``, ``default_model``, \
+ ``_monitor_vars_learn``, ``learn_aux``.
+ Config:
+ == ==================== ======== ============== ======================================== =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============== ======================================== =======================
+ 1 ``type`` str ppg | RL policy register name, refer to | this arg is optional,
+ | registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool False | Whether to use cuda for network | this arg can be diff-
+ | erent from modes
+ 3 ``on_policy`` bool True | Whether the RL algorithm is on-policy
+ | or off-policy
+ 4. ``priority`` bool False | Whether use priority(PER) | priority sample,
+ | update priority
+ 5 | ``priority_`` bool False | Whether use Importance Sampling | IS weight
+ | ``IS_weight`` | Weight to correct biased update.
+ 6 | ``learn.update`` int 5 | How many updates(iterations) to train | this args can be vary
+ | ``_per_collect`` | after collector's one collection. Only | from envs. Bigger val
+ | valid in serial training | means more off-policy
+ 7 | ``learn.value_`` float 1.0 | The loss weight of value network | policy network weight
+ | ``weight`` | is set to 1
+ 8 | ``learn.entropy_`` float 0.01 | The loss weight of entropy | policy network weight
+ | ``weight`` | regularization | is set to 1
+ 9 | ``learn.clip_`` float 0.2 | PPO clip ratio
+ | ``ratio``
+ 10 | ``learn.adv_`` bool False | Whether to use advantage norm in
+ | ``norm`` | a whole training batch
+ 11 | ``learn.aux_`` int 5 | The frequency(normal update times)
+ | ``freq`` | of auxiliary phase training
+ 12 | ``learn.aux_`` int 6 | The training epochs of auxiliary
+ | ``train_epoch`` | phase
+ 13 | ``learn.aux_`` int 1 | The loss weight of behavioral_cloning
+ | ``bc_weight`` | in auxiliary phase
+ 14 | ``collect.dis`` float 0.99 | Reward's future discount factor, aka. | may be 1 when sparse
+ | ``count_factor`` | gamma | reward env
+ 15 | ``collect.gae_`` float 0.95 | GAE lambda factor for the balance
+ | ``lambda`` | of bias and variance(1-step td and mc)
+ == ==================== ======== ============== ======================================== =======================
+ """
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='ppg',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ # (bool) Whether the RL algorithm is on-policy or off-policy. (Note: in practice PPO can be off-policy used)
+ on_policy=True,
+ priority=False,
+ # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ learn=dict(
+ actor_epoch_per_collect=1,
+ critic_epoch_per_collect=1,
+ batch_size=64,
+ learning_rate=0.001,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) The loss weight of value network, policy network weight is set to 1
+ value_weight=0.5,
+ # (float) The loss weight of entropy regularization, policy network weight is set to 1
+ entropy_weight=0.01,
+ # (float) PPO clip ratio, defaults to 0.2
+ clip_ratio=0.2,
+ value_norm=False,
+ # (bool) Whether to use advantage norm in a whole training batch
+ adv_norm=False,
+ # (int) The frequency(normal update times) of auxiliary phase training
+ aux_freq=8,
+ # (int) The training epochs of auxiliary phase
+ aux_train_epoch=6,
+ # (int) The loss weight of behavioral_cloning in auxiliary phase
+ aux_bc_weight=1,
+ grad_clip_type='clip_norm',
+ grad_clip_value=10,
+ ignore_done=False,
+ ),
+ collect=dict(
+ # n_sample=64,
+ unroll_len=1,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) Reward's future discount factor, aka. gamma.
+ discount_factor=0.99,
+ # (float) GAE lambda factor for the balance of bias and variance(1-step td and mc)
+ gae_lambda=0.95,
+ ),
+ eval=dict(),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+ Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \
+ automatically call this method to get the default model setting and create model.
+
+ Returns:
+ - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
+ """
+ return 'ppg', ['ding.model.template.ppg']
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Initialize the learn mode of policy, including related attributes and modules. For PPG, it mainly \
+ contains optimizer, algorithm-specific arguments such as aux_bc_weight and aux_train_epoch. This method \
+ also executes some special network initializations and prepares running mean/std monitor for value. \
+ This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+ .. note::
+ For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+ and ``_load_state_dict_learn`` methods.
+
+ .. note::
+ For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+ .. note::
+ If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+ with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
+ """
+ # Optimizer
+ self._optimizer_ac = Adam(self._model.actor_critic.parameters(), lr=self._cfg.learn.learning_rate)
+ self._optimizer_aux_critic = Adam(self._model.aux_critic.parameters(), lr=self._cfg.learn.learning_rate)
+ self._learn_model = model_wrap(self._model, wrapper_name='base')
+
+ # Algorithm config
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ assert not self._priority and not self._priority_IS_weight, "Priority is not implemented in PPG"
+ self._value_weight = self._cfg.learn.value_weight
+ self._entropy_weight = self._cfg.learn.entropy_weight
+ self._value_norm = self._cfg.learn.value_norm
+ if self._value_norm:
+ self._running_mean_std = RunningMeanStd(epsilon=1e-4, device=self._device)
+ self._clip_ratio = self._cfg.learn.clip_ratio
+ self._adv_norm = self._cfg.learn.adv_norm
+
+ # Main model
+ self._learn_model.reset()
+
+ # Auxiliary memories
+ self._aux_train_epoch = self._cfg.learn.aux_train_epoch
+ self._train_iteration = 0
+ self._aux_memories = []
+ self._aux_bc_weight = self._cfg.learn.aux_bc_weight
+
+ def _data_preprocess_learn(self, data: List[Any]) -> dict:
+ """
+ Overview:
+ Preprocess the data to fit the required data format for learning, including \
+ collating (stacking data into a batch), ignoring done flags (for some fake-terminate envs), \
+ preparing per-sample loss weight, and moving cpu tensors to cuda.
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]]`): The data collected from collect function.
+ Returns:
+ - data (:obj:`Dict[str, Any]`): The processed data, including at least ['done', 'weight'].
+ """
+ # data preprocess
+ data = default_collate(data)
+ ignore_done = self._cfg.learn.ignore_done
+ if ignore_done:
+ data['done'] = None
+ else:
+ data['done'] = data['done'].float()
+ data['weight'] = None
+ if self._cuda:
+ data = to_device(data, self._device)
+ return data
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ """
+ Overview:
+ Forward and backward function of learn mode.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Input data used for policy forward, including the \
+ collected training samples from replay buffer. For each element in dict, the key of the \
+ dict is the name of data items and the value is the corresponding data. Usually, the value is \
+ torch.Tensor or np.ndarray or their dict/list combinations. In the ``_forward_learn`` method, data \
+ often need to first be stacked in the batch dimension by some utility functions such as \
+ ``default_preprocess_learn``. \
+ For PPG, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \
+ ``reward``, ``logit``, ``value``, ``done``. Sometimes, it also contains other keys such as ``weight``.
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \
+ recorded in text log and tensorboard, values are python scalar or a list of scalars. \
+ For the detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
+ For data types that are not supported, the main reason is that the corresponding model does not support them. \
+ You can implement your own model rather than use the default model. For more information, please raise an \
+ issue in the GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for PPGPolicy: ``ding.policy.tests.test_ppg``.
+ """
+ data = self._data_preprocess_learn(data)
+ # ====================
+ # PPG forward
+ # ====================
+ self._learn_model.train()
+ return_infos = []
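+ # With value_norm, stored values are kept in a normalized scale: recover the unnormalized return from
+ # the stored value, re-normalize it with the running std to form the value target, and update the
+ # running statistics with the unnormalized return.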
+ if self._value_norm:
+ unnormalized_return = data['adv'] + data['value'] * self._running_mean_std.std
+ data['return'] = unnormalized_return / self._running_mean_std.std
+ self._running_mean_std.update(unnormalized_return.cpu().numpy())
+ else:
+ data['return'] = data['adv'] + data['value']
+
+ for epoch in range(self._cfg.learn.actor_epoch_per_collect):
+ for policy_data in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True):
+ policy_adv = policy_data['adv']
+ if self._adv_norm:
+ # Normalize advantage in a total train_batch
+ policy_adv = (policy_adv - policy_adv.mean()) / (policy_adv.std() + 1e-8)
+ # Policy Phase(Policy)
+ policy_output = self._learn_model.forward(policy_data['obs'], mode='compute_actor')
+ policy_error_data = ppo_policy_data(
+ policy_output['logit'], policy_data['logit'], policy_data['action'], policy_adv,
+ policy_data['weight']
+ )
+ ppo_policy_loss, ppo_info = ppo_policy_error(policy_error_data, self._clip_ratio)
+ policy_loss = ppo_policy_loss.policy_loss - self._entropy_weight * ppo_policy_loss.entropy_loss
+ self._optimizer_ac.zero_grad()
+ policy_loss.backward()
+ self._optimizer_ac.step()
+
+ for epoch in range(self._cfg.learn.critic_epoch_per_collect):
+ for value_data in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True):
+ value_adv = value_data['adv']
+ return_ = value_data['return']
+ if self._adv_norm:
+ # Normalize advantage in a total train_batch
+ value_adv = (value_adv - value_adv.mean()) / (value_adv.std() + 1e-8)
+ # Policy Phase(Value)
+ value_output = self._learn_model.forward(value_data['obs'], mode='compute_critic')
+ value_error_data = ppo_value_data(
+ value_output['value'], value_data['value'], return_, value_data['weight']
+ )
+ value_loss = self._value_weight * ppo_value_error(value_error_data, self._clip_ratio)
+ self._optimizer_aux_critic.zero_grad()
+ value_loss.backward()
+ self._optimizer_aux_critic.step()
+
+ data['return_'] = data['return']
+
+ self._aux_memories.append(copy.deepcopy(data))
+
+ self._train_iteration += 1
+
+ # ====================
+ # PPG update
+ # use aux loss after iterations and reset aux_memories
+ # ====================
+
+ # Auxiliary Phase
+ # record data for auxiliary head
+
+ if self._train_iteration % self._cfg.learn.aux_freq == 0:
+ aux_loss, bc_loss, aux_value_loss = self.learn_aux()
+ return {
+ 'policy_cur_lr': self._optimizer_ac.defaults['lr'],
+ 'value_cur_lr': self._optimizer_aux_critic.defaults['lr'],
+ 'policy_loss': ppo_policy_loss.policy_loss.item(),
+ 'value_loss': value_loss.item(),
+ 'entropy_loss': ppo_policy_loss.entropy_loss.item(),
+ 'policy_adv_abs_max': policy_adv.abs().max().item(),
+ 'approx_kl': ppo_info.approx_kl,
+ 'clipfrac': ppo_info.clipfrac,
+ 'aux_value_loss': aux_value_loss,
+ 'auxiliary_loss': aux_loss,
+ 'behavioral_cloning_loss': bc_loss,
+ }
+ else:
+ return {
+ 'policy_cur_lr': self._optimizer_ac.defaults['lr'],
+ 'value_cur_lr': self._optimizer_aux_critic.defaults['lr'],
+ 'policy_loss': ppo_policy_loss.policy_loss.item(),
+ 'value_loss': value_loss.item(),
+ 'entropy_loss': ppo_policy_loss.entropy_loss.item(),
+ 'policy_adv_abs_max': policy_adv.abs().max().item(),
+ 'approx_kl': ppo_info.approx_kl,
+ 'clipfrac': ppo_info.clipfrac,
+ }
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ """
+ Overview:
+ Return the state_dict of learn mode, usually including model and optimizer.
+ Returns:
+ - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring.
+ """
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'optimizer_ac': self._optimizer_ac.state_dict(),
+ 'optimizer_aux_critic': self._optimizer_aux_critic.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ """
+ Overview:
+ Load the state_dict variable into policy learn mode.
+ Arguments:
+ - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved previously. \
+ Because the value is distilled into the policy network while making sure the policy network \
+ does not change the action predictions, we need two optimizers: ``_optimizer_ac`` is used for the \
+ policy net, and ``_optimizer_aux_critic`` is used for the value net.
+
+ .. tip::
+ If you want to only load some parts of model, you can simply set the ``strict`` argument in \
+ load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
+ complicated operation.
+ """
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._optimizer_ac.load_state_dict(state_dict['optimizer_ac'])
+ self._optimizer_aux_critic.load_state_dict(state_dict['optimizer_aux_critic'])
+
+ def _init_collect(self) -> None:
+ """
+ Overview:
+ Initialize the collect mode of policy, including related attributes and modules. For PPG, it contains the \
+ collect_model to balance the exploration and exploitation (e.g. the multinomial sample mechanism in \
+ discrete action space), and other algorithm-specific arguments such as unroll_len and gae_lambda.
+ This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.
+
+ .. note::
+ If you want to set some special member variables in ``_init_collect`` method, you'd better name them \
+ with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``.
+ """
+ self._unroll_len = self._cfg.collect.unroll_len
+ self._collect_model = model_wrap(self._model, wrapper_name='multinomial_sample')
+ # TODO continuous action space exploration
+ self._collect_model.reset()
+ self._gamma = self._cfg.collect.discount_factor
+ self._gae_lambda = self._cfg.collect.gae_lambda
+
+ def _forward_collect(self, data: dict) -> dict:
+ """
+ Overview:
+ Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \
+ that the policy gets some necessary data (mainly observation) from the envs and then returns the output \
+ data, such as the action to interact with the envs.
+
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
+ other necessary data (action logit and value) for learn mode defined in \
+ ``self._process_transition`` method. The key of the dict is the same as the input data, \
+ i.e. environment id.
+
+ .. tip::
+ If you want to add more tricks on this policy, like temperature factor in multinomial sample, you can pass \
+ related data as extra keyword arguments of this method.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
+ For data types that are not supported, the main reason is that the corresponding model does not support them. \
+ You can implement your own model rather than use the default model. For more information, please raise an \
+ issue in the GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for PPGPolicy: ``ding.policy.tests.test_ppg``.
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._collect_model.eval()
+ with torch.no_grad():
+ output = self._collect_model.forward(data, mode='compute_actor_critic')
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
+ """
+ Overview:
+ Process and pack one timestep transition data into a dict, which can be directly used for training and \
+ saved in replay buffer. For PPG, it contains obs, next_obs, action, reward, done, logit, value.
+ Arguments:
+ - obs (:obj:`Any`): Env observation
+ - model_output (:obj:`dict`): The output of the policy network with the observation \
+ as input. For PPG, it contains the state value, action and the logit of the action.
+ - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step \
+ method, except all the elements have been transformed into tensor data. Usually, it contains the next \
+ obs, reward, done, info, etc.
+ Returns:
+ - transition (:obj:`dict`): The processed transition data of the current timestep.
+
+ .. note::
+ ``next_obs`` is used to calculate nstep return when necessary, so we place it into the transition by default. \
+ You can delete this field to save memory occupancy if you do not need nstep return.
+ """
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'logit': model_output['logit'],
+ 'action': model_output['action'],
+ 'value': model_output['value'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _get_train_sample(self, data: List[Dict[str, Any]]) -> Union[None, List[Any]]:
+ """
+ Overview:
+ For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \
+ can be used for training directly. In PPG, a train sample is a processed transition with new computed \
+ ``adv`` field. This method is usually used in collectors to execute necessary \
+ RL data preprocessing before training, which can help the learner amortize relevant time consumption. \
+ In addition, you can also implement this method as an identity function and do the data processing \
+ in ``self._forward_learn`` method.
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]]`): The trajectory data (a list of transition), each element is \
+ the same format as the return value of ``self._process_transition`` method.
+ Returns:
+ - samples (:obj:`dict`): The processed train samples, each element is the similar format \
+ as input transitions, but may contain more data for training, such as GAE advantage.
+ """
+ data = to_device(data, self._device)
+ if self._cfg.learn.ignore_done:
+ data[-1]['done'] = False
+
+ if data[-1]['done']:
+ last_value = torch.zeros_like(data[-1]['value'])
+ else:
+ with torch.no_grad():
+ last_value = self._collect_model.forward(
+ data[-1]['next_obs'].unsqueeze(0), mode='compute_actor_critic'
+ )['value']
+ if self._value_norm:
+ last_value *= self._running_mean_std.std
+ for i in range(len(data)):
+ data[i]['value'] *= self._running_mean_std.std
+ data = get_gae(
+ data,
+ to_device(last_value, self._device),
+ gamma=self._gamma,
+ gae_lambda=self._gae_lambda,
+ cuda=False,
+ )
+ if self._value_norm:
+ for i in range(len(data)):
+ data[i]['value'] /= self._running_mean_std.std
+
+ return get_train_sample(data, self._unroll_len)
+
+ def _get_batch_size(self) -> Dict[str, int]:
+ """
+ Overview:
+ Get learn batch size. In the PPG algorithm, different networks require different data.\
+ We need to get data['policy'] and data['value'] to train policy net and value net,\
+ this function is used to get the batch size of data['policy'] and data['value'].
+ Returns:
+ - output (:obj:`dict[str, int]`): Dict type data, whose keys are the buffer names (str) and whose \
+ values are the corresponding batch sizes (int).
+ """
+ bs = self._cfg.learn.batch_size
+ return {'policy': bs, 'value': bs}
+
+ def _init_eval(self) -> None:
+ """
+ Overview:
+ Initialize the eval mode of policy, including related attributes and modules. For PPG, it contains the \
+ eval model to select optimal action (e.g. greedily select action with argmax mechanism in discrete \
+ action). This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
+
+ .. note::
+ If you want to set some special member variables in ``_init_eval`` method, you'd better name them \
+ with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
+ """
+ self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: dict) -> dict:
+ """
+ Overview:
+ Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \
+ means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
+ action to interact with the envs. ``_forward_eval`` in PPG often uses deterministic sample method to get \
+ actions while ``_forward_collect`` usually uses stochastic sample method to balance exploration and \
+ exploitation.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
+ key of the dict is the same as the input data, i.e. environment id.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
+ For data types that are not supported, the main reason is that the corresponding model does not support them. \
+ You can implement your own model rather than use the default model. For more information, please raise an \
+ issue in the GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for PPGPolicy: ``ding.policy.tests.test_ppg``.
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data, mode='compute_actor')
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _monitor_vars_learn(self) -> List[str]:
+ """
+ Overview:
+ Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
+ as text logger, tensorboard logger, will use these keys to save the corresponding data.
+ Returns:
+ - vars (:obj:`List[str]`): The list of the necessary keys to be logged.
+ """
+ return [
+ 'policy_cur_lr',
+ 'value_cur_lr',
+ 'policy_loss',
+ 'value_loss',
+ 'entropy_loss',
+ 'policy_adv_abs_max',
+ 'approx_kl',
+ 'clipfrac',
+ 'aux_value_loss',
+ 'auxiliary_loss',
+ 'behavioral_cloning_loss',
+ ]
+
+ def learn_aux(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Overview:
+ The auxiliary phase training, where the value is distilled into the policy network. In PPG algorithm, \
+ we use the value function loss as the auxiliary objective, thereby sharing features between the policy \
+ and value function while minimizing distortions to the policy. We also use behavioral cloning loss to \
+ optimize the auxiliary objective while otherwise preserving the original policy.
+ Returns:
+ - aux_loss (:obj:`Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`): Including average auxiliary loss, \
+ average behavioral cloning loss, and average auxiliary value loss.
+ """
+ aux_memories = self._aux_memories
+ # gather states and target values into one tensor
+ data = {}
+ states = []
+ actions = []
+ return_ = []
+ old_values = []
+ weights = []
+ for memory in aux_memories:
+ # for memory in memories:
+ states.append(memory['obs'])
+ actions.append(memory['action'])
+ return_.append(memory['return_'])
+ old_values.append(memory['value'])
+ if memory['weight'] is None:
+ weight = torch.ones_like(memory['action'])
+ else:
+ weight = torch.tensor(memory['weight'])
+ weights.append(weight)
+
+ data['obs'] = torch.cat(states)
+ data['action'] = torch.cat(actions)
+ data['return_'] = torch.cat(return_)
+ data['value'] = torch.cat(old_values)
+ data['weight'] = torch.cat(weights).float()
+ # compute current policy logit_old
+ with torch.no_grad():
+ data['logit_old'] = self._model.forward(data['obs'], mode='compute_actor')['logit']
+
+ # prepared dataloader for auxiliary phase training
+ dl = create_shuffled_dataloader(data, self._cfg.learn.batch_size)
+
+ # the proposed auxiliary phase training
+ # where the value is distilled into the policy network,
+ # while making sure the policy network does not change the action predictions (kl div loss)
+
+ i = 0
+ auxiliary_loss_ = 0
+ behavioral_cloning_loss_ = 0
+ value_loss_ = 0
+
+ for epoch in range(self._aux_train_epoch):
+ for data in dl:
+ policy_output = self._model.forward(data['obs'], mode='compute_actor_critic')
+
+ # Calculate ppg error 'logit_new', 'logit_old', 'action', 'value_new', 'value_old', 'return_', 'weight'
+ data_ppg = ppg_data(
+ policy_output['logit'], data['logit_old'], data['action'], policy_output['value'], data['value'],
+ data['return_'], data['weight']
+ )
+ ppg_joint_loss = ppg_joint_error(data_ppg, self._clip_ratio)
+ wb = self._aux_bc_weight
+ total_loss = ppg_joint_loss.auxiliary_loss + wb * ppg_joint_loss.behavioral_cloning_loss
+
+ # # policy network loss is composed of both the kl div loss as well as the auxiliary loss
+ # aux_loss = clipped_value_loss(policy_values, rewards, old_values, self.value_clip)
+ # loss_kl = F.kl_div(action_logprobs, old_action_probs, reduction='batchmean')
+ # policy_loss = aux_loss + loss_kl
+
+ self._optimizer_ac.zero_grad()
+ total_loss.backward()
+ self._optimizer_ac.step()
+
+ # the paper says it is important to give the value network extra training during the auxiliary phase
+ # Calculate ppg error 'value_new', 'value_old', 'return_', 'weight'
+ values = self._model.forward(data['obs'], mode='compute_critic')['value']
+ data_aux = ppo_value_data(values, data['value'], data['return_'], data['weight'])
+
+ value_loss = ppo_value_error(data_aux, self._clip_ratio)
+
+ self._optimizer_aux_critic.zero_grad()
+ value_loss.backward()
+ self._optimizer_aux_critic.step()
+
+ auxiliary_loss_ += ppg_joint_loss.auxiliary_loss.item()
+ behavioral_cloning_loss_ += ppg_joint_loss.behavioral_cloning_loss.item()
+ value_loss_ += value_loss.item()
+ i += 1
+
+ self._aux_memories = []
+
+ return auxiliary_loss_ / i, behavioral_cloning_loss_ / i, value_loss_ / i
+
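+
+# Illustrative only (not part of the original module): a minimal user-side override of ``PPGPolicy.config``
+# for a small discrete-action task. All values are examples rather than tuned defaults; any key not shown
+# keeps the default defined above, and ``n_sample`` mirrors the commented-out key in ``config.collect``.
+#
+#   ppg_config = dict(
+#       cuda=False,
+#       learn=dict(actor_epoch_per_collect=2, critic_epoch_per_collect=2, batch_size=64,
+#                  learning_rate=3e-4, aux_freq=8, aux_train_epoch=6),
+#       collect=dict(n_sample=1024, discount_factor=0.99, gae_lambda=0.95),
+#   )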
+
+@POLICY_REGISTRY.register('ppg_offpolicy')
+class PPGOffPolicy(Policy):
+ """
+ Overview:
+ Policy class of PPG algorithm with off-policy training mode. Off-policy PPG maintains two replay buffers \
+ with different max reuse: the policy buffer offers data for the policy phase, while the value buffer \
+ provides the auxiliary phase's data. The whole training procedure is similar to off-policy PPO but \
+ executes an additional auxiliary phase at a fixed frequency.
+ Interface:
+ ``_init_learn``, ``_data_preprocess_learn``, ``_forward_learn``, ``_state_dict_learn``, \
+ ``_load_state_dict_learn``, ``_init_collect``, ``_forward_collect``, ``_process_transition``, \
+ ``_get_train_sample``, ``_get_batch_size``, ``_init_eval``, ``_forward_eval``, ``default_model``, \
+ ``_monitor_vars_learn``, ``learn_aux``.
+ Config:
+ == ==================== ======== ============== ======================================== =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============== ======================================== =======================
+ 1 ``type`` str ppg | RL policy register name, refer to | this arg is optional,
+ | registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool False | Whether to use cuda for network | this arg can be diff-
+ | erent from modes
+ 3 ``on_policy`` bool True | Whether the RL algorithm is on-policy
+ | or off-policy
+ 4. ``priority`` bool False | Whether use priority(PER) | priority sample,
+ | update priority
+ 5 | ``priority_`` bool False | Whether use Importance Sampling | IS weight
+ | ``IS_weight`` | Weight to correct biased update.
+ 6 | ``learn.update`` int 5 | How many updates(iterations) to train | this args can be vary
+ | ``_per_collect`` | after collector's one collection. Only | from envs. Bigger val
+ | valid in serial training | means more off-policy
+ 7 | ``learn.value_`` float 1.0 | The loss weight of value network | policy network weight
+ | ``weight`` | is set to 1
+ 8 | ``learn.entropy_`` float 0.01 | The loss weight of entropy | policy network weight
+ | ``weight`` | regularization | is set to 1
+ 9 | ``learn.clip_`` float 0.2 | PPO clip ratio
+ | ``ratio``
+ 10 | ``learn.adv_`` bool False | Whether to use advantage norm in
+ | ``norm`` | a whole training batch
+ 11 | ``learn.aux_`` int 5 | The frequency(normal update times)
+ | ``freq`` | of auxiliary phase training
+ 12 | ``learn.aux_`` int 6 | The training epochs of auxiliary
+ | ``train_epoch`` | phase
+ 13 | ``learn.aux_`` int 1 | The loss weight of behavioral_cloning
+ | ``bc_weight`` | in auxiliary phase
+ 14 | ``collect.dis`` float 0.99 | Reward's future discount factor, aka. | may be 1 when sparse
+ | ``count_factor`` | gamma | reward env
+ 15 | ``collect.gae_`` float 0.95 | GAE lambda factor for the balance
+ | ``lambda`` | of bias and variance(1-step td and mc)
+ == ==================== ======== ============== ======================================== =======================
+ """
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='ppg_offpolicy',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ # (bool) Whether the RL algorithm is on-policy or off-policy. (Note: in practice PPO can be off-policy used)
+ on_policy=False,
+ priority=False,
+ # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ # (bool) Whether to need policy data in process transition
+ transition_with_policy_data=True,
+ learn=dict(
+ update_per_collect=5,
+ batch_size=64,
+ learning_rate=0.001,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) The loss weight of value network, policy network weight is set to 1
+ value_weight=0.5,
+ # (float) The loss weight of entropy regularization, policy network weight is set to 1
+ entropy_weight=0.01,
+ # (float) PPO clip ratio, defaults to 0.2
+ clip_ratio=0.2,
+ # (bool) Whether to use advantage norm in a whole training batch
+ adv_norm=False,
+ # (int) The frequency(normal update times) of auxiliary phase training
+ aux_freq=5,
+ # (int) The training epochs of auxiliary phase
+ aux_train_epoch=6,
+ # (int) The loss weight of behavioral_cloning in auxiliary phase
+ aux_bc_weight=1,
+ ignore_done=False,
+ ),
+ collect=dict(
+ # n_sample=64,
+ unroll_len=1,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) Reward's future discount factor, aka. gamma.
+ discount_factor=0.99,
+ # (float) GAE lambda factor for the balance of bias and variance(1-step td and mc)
+ gae_lambda=0.95,
+ ),
+ eval=dict(),
+ other=dict(
+ replay_buffer=dict(
+ # PPG uses two separate buffers for different reuse
+ multi_buffer=True,
+ policy=dict(replay_buffer_size=1000, ),
+ value=dict(replay_buffer_size=1000, ),
+ ),
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+ Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \
+ automatically call this method to get the default model setting and create model.
+
+ Returns:
+ - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
+
+ .. note::
+ The user can define and use customized network model but must obey the same interface definition indicated \
+ by import_names path.
+ """
+ return 'ppg', ['ding.model.template.ppg']
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Initialize the learn mode of policy, including related attributes and modules. For PPG, it mainly \
+ contains optimizer, algorithm-specific arguments such as aux_bc_weight and aux_train_epoch. This method \
+ also executes some special network initializations and prepares running mean/std monitor for value. \
+ This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+ .. note::
+ For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+ and ``_load_state_dict_learn`` methods.
+
+ .. note::
+ For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+ .. note::
+ If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+ with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
+ """
+ # Optimizer
+ self._optimizer_ac = Adam(self._model.actor_critic.parameters(), lr=self._cfg.learn.learning_rate)
+ self._optimizer_aux_critic = Adam(self._model.aux_critic.parameters(), lr=self._cfg.learn.learning_rate)
+ self._learn_model = model_wrap(self._model, wrapper_name='base')
+
+ # Algorithm config
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ assert not self._priority and not self._priority_IS_weight, "Priority is not implemented in PPG"
+ self._value_weight = self._cfg.learn.value_weight
+ self._entropy_weight = self._cfg.learn.entropy_weight
+ self._clip_ratio = self._cfg.learn.clip_ratio
+ self._adv_norm = self._cfg.learn.adv_norm
+
+ # Main model
+ self._learn_model.reset()
+
+ # Auxiliary memories
+ self._aux_train_epoch = self._cfg.learn.aux_train_epoch
+ self._train_iteration = 0
+ self._aux_memories = []
+ self._aux_bc_weight = self._cfg.learn.aux_bc_weight
+
+ def _data_preprocess_learn(self, data: Dict[str, List[Any]]) -> dict:
+ """
+ Overview:
+ Preprocess the data to fit the required data format for learning, including \
+ collating (stacking data into a batch), ignoring done flags (for some fake-terminate envs), \
+ preparing per-sample loss weight, and moving cpu tensors to cuda.
+ Arguments:
+ - data (:obj:`Dict[str, List[Dict[str, Any]]]`): The data collected from the collect function, keyed by \
+ buffer name ('policy' and 'value').
+ Returns:
+ - data (:obj:`Dict[str, Any]`): The processed data, including at least ['done', 'weight'].
+ """
+ # data preprocess
+ for k, data_item in data.items():
+ data_item = default_collate(data_item)
+ ignore_done = self._cfg.learn.ignore_done
+ if ignore_done:
+ data_item['done'] = None
+ else:
+ data_item['done'] = data_item['done'].float()
+ data_item['weight'] = None
+ data[k] = data_item
+ if self._cuda:
+ data = to_device(data, self._device)
+ return data
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ """
+ Overview:
+ Forward and backward function of learn mode.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Input data used for policy forward, including the \
+ collected training samples from replay buffer. For each element in dict, the key of the \
+ dict is the name of data items and the value is the corresponding data. Usually, \
+ the class type of value is either torch.Tensor or np.ndarray, or a dict/list containing \
+ either torch.Tensor or np.ndarray items. In the ``_forward_learn`` method, data \
+ often need to first be stacked in the batch dimension by some utility functions such as \
+ ``default_preprocess_learn``. \
+ For PPGOff, each element in list is a dict containing at least the following keys: ``obs``, \
+ ``action``, ``reward``, ``logit``, ``value``, ``done``. Sometimes, it also contains other keys \
+ such as ``weight``.
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \
+ recorded in text log and tensorboard, values are python scalar or a list of scalars. \
+ For the detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
+
+ ReturnsKeys:
+ - necessary: "current lr", "total_loss", "policy_loss", "value_loss", "entropy_loss", \
+ "adv_abs_max", "approx_kl", "clipfrac", \
+ "aux_value_loss", "auxiliary_loss", "behavioral_cloning_loss".
+
+ - policy_cur_lr/value_cur_lr (:obj:`float`): Current learning rates of the policy and value optimizers.
+ - total_loss (:obj:`float`): The calculated loss.
+ - policy_loss (:obj:`float`): The policy(actor) loss of ppg.
+ - value_loss (:obj:`float`): The value(critic) loss of ppg.
+ - entropy_loss (:obj:`float`): The entropy loss.
+ - auxiliary_loss (:obj:`float`): The auxiliary loss, we use the value function loss \
+ as the auxiliary objective, thereby sharing features between the policy and value function\
+ while minimizing distortions to the policy.
+ - aux_value_loss (:obj:`float`): The auxiliary value loss, we need to train the value network extra \
+ during the auxiliary phase, it's the value loss we train the value network during auxiliary phase.
+ - behavioral_cloning_loss (:obj:`float`): The behavioral cloning loss, used to optimize the auxiliary\
+ objective while otherwise preserving the original policy.
+ """
+ data = self._data_preprocess_learn(data)
+ # ====================
+ # PPG forward
+ # ====================
+ self._learn_model.train()
+ policy_data, value_data = data['policy'], data['value']
+ policy_adv, value_adv = policy_data['adv'], value_data['adv']
+ return_ = value_data['value'] + value_adv
+ if self._adv_norm:
+ # Normalize advantage in a total train_batch
+ policy_adv = (policy_adv - policy_adv.mean()) / (policy_adv.std() + 1e-8)
+ value_adv = (value_adv - value_adv.mean()) / (value_adv.std() + 1e-8)
+ # Policy Phase(Policy)
+ policy_output = self._learn_model.forward(policy_data['obs'], mode='compute_actor')
+ policy_error_data = ppo_policy_data(
+ policy_output['logit'], policy_data['logit'], policy_data['action'], policy_adv, policy_data['weight']
+ )
+ ppo_policy_loss, ppo_info = ppo_policy_error(policy_error_data, self._clip_ratio)
+ policy_loss = ppo_policy_loss.policy_loss - self._entropy_weight * ppo_policy_loss.entropy_loss
+ self._optimizer_ac.zero_grad()
+ policy_loss.backward()
+ self._optimizer_ac.step()
+
+ # Policy Phase(Value)
+ value_output = self._learn_model.forward(value_data['obs'], mode='compute_critic')
+ value_error_data = ppo_value_data(value_output['value'], value_data['value'], return_, value_data['weight'])
+ value_loss = self._value_weight * ppo_value_error(value_error_data, self._clip_ratio)
+ self._optimizer_aux_critic.zero_grad()
+ value_loss.backward()
+ self._optimizer_aux_critic.step()
+
+ # ====================
+ # PPG update
+ # use aux loss after iterations and reset aux_memories
+ # ====================
+
+ # Auxiliary Phase
+ # record data for auxiliary head
+ data = data['value']
+ data['return_'] = return_.data
+ self._aux_memories.append(copy.deepcopy(data))
+
+ self._train_iteration += 1
+ total_loss = policy_loss + value_loss
+ if self._train_iteration % self._cfg.learn.aux_freq == 0:
+ aux_loss, bc_loss, aux_value_loss = self.learn_aux()
+ total_loss += aux_loss + bc_loss + aux_value_loss
+ return {
+ 'policy_cur_lr': self._optimizer_ac.defaults['lr'],
+ 'value_cur_lr': self._optimizer_aux_critic.defaults['lr'],
+ 'policy_loss': ppo_policy_loss.policy_loss.item(),
+ 'value_loss': value_loss.item(),
+ 'entropy_loss': ppo_policy_loss.entropy_loss.item(),
+ 'policy_adv_abs_max': policy_adv.abs().max().item(),
+ 'approx_kl': ppo_info.approx_kl,
+ 'clipfrac': ppo_info.clipfrac,
+ 'aux_value_loss': aux_value_loss,
+ 'auxiliary_loss': aux_loss,
+ 'behavioral_cloning_loss': bc_loss,
+ 'total_loss': total_loss.item(),
+ }
+ else:
+ return {
+ 'policy_cur_lr': self._optimizer_ac.defaults['lr'],
+ 'value_cur_lr': self._optimizer_aux_critic.defaults['lr'],
+ 'policy_loss': ppo_policy_loss.policy_loss.item(),
+ 'value_loss': value_loss.item(),
+ 'entropy_loss': ppo_policy_loss.entropy_loss.item(),
+ 'policy_adv_abs_max': policy_adv.abs().max().item(),
+ 'approx_kl': ppo_info.approx_kl,
+ 'clipfrac': ppo_info.clipfrac,
+ 'total_loss': total_loss.item(),
+ }
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ """
+ Overview:
+ Return the state_dict of learn mode, usually including model and optimizer.
+ Returns:
+ - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring.
+ """
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'optimizer_ac': self._optimizer_ac.state_dict(),
+ 'optimizer_aux_critic': self._optimizer_aux_critic.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ """
+ Overview:
+ Load the state_dict variable into policy learn mode.
+ Arguments:
+ - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved previously. \
+ Because the value is distilled into the policy network while making sure the policy network \
+ does not change the action predictions, we need two optimizers: ``_optimizer_ac`` is used for the \
+ policy net, and ``_optimizer_aux_critic`` is used for the value net.
+
+ .. tip::
+ If you want to only load some parts of model, you can simply set the ``strict`` argument in \
+ load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
+ complicated operation.
+ """
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._optimizer_ac.load_state_dict(state_dict['optimizer_ac'])
+ self._optimizer_aux_critic.load_state_dict(state_dict['optimizer_aux_critic'])
+
+ def _init_collect(self) -> None:
+ """
+ Overview:
+ Initialize the collect mode of policy, including related attributes and modules. For PPG, it contains the \
+ collect_model to balance the exploration and exploitation (e.g. the multinomial sample mechanism in \
+ discrete action space), and other algorithm-specific arguments such as unroll_len and gae_lambda.
+ This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.
+
+ .. note::
+ If you want to set some special member variables in ``_init_collect`` method, you'd better name them \
+ with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``.
+ """
+ self._unroll_len = self._cfg.collect.unroll_len
+ self._collect_model = model_wrap(self._model, wrapper_name='multinomial_sample')
+ # TODO continuous action space exploration
+ self._collect_model.reset()
+ self._gamma = self._cfg.collect.discount_factor
+ self._gae_lambda = self._cfg.collect.gae_lambda
+
+ def _forward_collect(self, data: dict) -> dict:
+ """
+ Overview:
+ Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \
+ that the policy gets some necessary data (mainly observation) from the envs and then returns the output \
+ data, such as the action to interact with the envs.
+
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
+ other necessary data (action logit and value) for learn mode defined in \
+ ``self._process_transition`` method. The key of the dict is the same as the input data, \
+ i.e. environment id.
+
+ .. tip::
+ If you want to add more tricks on this policy, like temperature factor in multinomial sample, you can pass \
+ related data as extra keyword arguments of this method.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
+ For data types that are not supported, the main reason is that the corresponding model does not support them. \
+ You can implement your own model rather than use the default model. For more information, please raise an \
+ issue in the GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for PPGOffPolicy: ``ding.policy.tests.test_ppg``.
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._collect_model.eval()
+ with torch.no_grad():
+ output = self._collect_model.forward(data, mode='compute_actor_critic')
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
+ """
+ Overview:
+ Process and pack one timestep transition data into a dict, which can be directly used for training and \
+ saved in replay buffer. For PPG, it contains obs, next_obs, action, reward, done, logit, value.
+ Arguments:
+ - obs (:obj:`Any`): Env observation
+ - model_output (:obj:`dict`): The output of the policy network with the observation \
+ as input. For PPG, it contains the state value, action and the logit of the action.
+ - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step \
+ method, except all the elements have been transformed into tensor data. Usually, it contains the next \
+ obs, reward, done, info, etc.
+ Returns:
+ - transition (:obj:`dict`): The processed transition data of the current timestep.
+
+ .. note::
+ ``next_obs`` is used to calculate nstep return when necessary, so we place it into the transition by default. \
+ You can delete this field to save memory occupancy if you do not need nstep return.
+ """
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'logit': model_output['logit'],
+ 'action': model_output['action'],
+ 'value': model_output['value'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
+ """
+ Overview:
+ For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \
+ can be used for training directly. In PPG, a train sample is a processed transition with new computed \
+ ``adv`` field. This method is usually used in collectors to execute necessary \
+ RL data preprocessing before training, which can help the learner amortize relevant time consumption. \
+ In addition, you can also implement this method as an identity function and do the data processing \
+ in ``self._forward_learn`` method.
+ Arguments:
+ - data (:obj:`list`): The trajectory data (a list of transition), each element is \
+ the same format as the return value of ``self._process_transition`` method.
+ Returns:
+ - samples (:obj:`dict`): The processed train samples, each element is the similar format \
+ as input transitions, but may contain more data for training, such as GAE advantage.
+ """
+ data = get_gae_with_default_last_value(
+ data,
+ data[-1]['done'],
+ gamma=self._gamma,
+ gae_lambda=self._gae_lambda,
+ cuda=False,
+ )
+ data = get_train_sample(data, self._unroll_len)
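+ # Tag each sample with both buffer names so the multi-buffer setup (see ``config.other.replay_buffer``)
+ # can route a copy to the 'policy' buffer (policy phase data) and the 'value' buffer (auxiliary phase
+ # data), each with its own reuse setting.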
+ for d in data:
+ d['buffer_name'] = ["policy", "value"]
+ return data
+
+ def _get_batch_size(self) -> Dict[str, int]:
+ """
+ Overview:
+ Get learn batch size. In the PPG algorithm, different networks require different data.\
+ We need to get data['policy'] and data['value'] to train policy net and value net,\
+ this function is used to get the batch size of data['policy'] and data['value'].
+ Returns:
+ - output (:obj:`dict[str, int]`): Dict type data, whose keys are the buffer names (str) and whose \
+ values are the corresponding batch sizes (int).
+ """
+ bs = self._cfg.learn.batch_size
+ return {'policy': bs, 'value': bs}
+
+ def _init_eval(self) -> None:
+ """
+ Overview:
+ Initialize the eval mode of policy, including related attributes and modules. For PPG, it contains the \
+ eval model to select optimal action (e.g. greedily select action with argmax mechanism in discrete \
+ action). This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
+
+ .. note::
+ If you want to set some special member variables in ``_init_eval`` method, you'd better name them \
+ with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
+ """
+ self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: dict) -> dict:
+ r"""
+ Overview:
+ Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \
+ means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
+ action to interact with the envs. ``_forward_eval`` in PPG often uses deterministic sample method to get \
+ actions while ``_forward_collect`` usually uses stochastic sample method to balance exploration and \
+ exploitation.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
+ key of the dict is the same as the input data, i.e. environment id.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
+ For data types that are not supported, the main reason is that the corresponding model does not support them. \
+ You can implement your own model rather than use the default model. For more information, please raise an \
+ issue in the GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for PPGOffPolicy: ``ding.policy.tests.test_ppg``.
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data, mode='compute_actor')
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
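+ # Minimal usage sketch for ``_forward_eval`` above (added comment; the observation shape, env ids and
+ # the ``eval_mode`` access pattern are assumptions for illustration only):
+ # >>> obs = {0: torch.randn(4), 1: torch.randn(4)}   # env_id -> observation
+ # >>> out = policy.eval_mode.forward(obs)
+ # >>> assert set(out.keys()) == {0, 1} and 'action' in out[0]
+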
+ def _monitor_vars_learn(self) -> List[str]:
+ """
+ Overview:
+ Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
+ as text logger, tensorboard logger, will use these keys to save the corresponding data.
+ Returns:
+ - vars (:obj:`List[str]`): The list of the necessary keys to be logged.
+ """
+ return [
+ 'policy_cur_lr',
+ 'value_cur_lr',
+ 'policy_loss',
+ 'value_loss',
+ 'entropy_loss',
+ 'policy_adv_abs_max',
+ 'approx_kl',
+ 'clipfrac',
+ 'aux_value_loss',
+ 'auxiliary_loss',
+ 'behavioral_cloning_loss',
+ ]
+
+ def learn_aux(self) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Overview:
+ The auxiliary phase training, where the value is distilled into the policy network. In the PPG \
+ algorithm, we use the value function loss as the auxiliary objective, thereby sharing features between \
+ the policy and value function while minimizing distortions to the policy. A behavioral cloning (KL) loss \
+ is also used so that the original policy is otherwise preserved while the auxiliary objective is optimized.
+ Returns:
+ - aux_loss (:obj:`Tuple[torch.Tensor, torch.Tensor, torch.Tensor]`): Including the average auxiliary loss, \
+ average behavioral cloning loss, and average auxiliary value loss.
+ """
+ aux_memories = self._aux_memories
+ # gather states and target values into one tensor
+ data = {}
+ states = []
+ actions = []
+ return_ = []
+ old_values = []
+ weights = []
+ for memory in aux_memories:
+ # for memory in memories:
+ states.append(memory['obs'])
+ actions.append(memory['action'])
+ return_.append(memory['return_'])
+ old_values.append(memory['value'])
+ if memory['weight'] is None:
+ weight = torch.ones_like(memory['action'])
+ else:
+ weight = torch.tensor(memory['weight'])
+ weights.append(weight)
+
+ data['obs'] = torch.cat(states)
+ data['action'] = torch.cat(actions)
+ data['return_'] = torch.cat(return_)
+ data['value'] = torch.cat(old_values)
+ data['weight'] = torch.cat(weights)
+ # compute logit_old with the current (pre-auxiliary-phase) policy
+ with torch.no_grad():
+ data['logit_old'] = self._model.forward(data['obs'], mode='compute_actor')['logit']
+
+ # prepare the dataloader for auxiliary phase training
+ dl = create_shuffled_dataloader(data, self._cfg.learn.batch_size)
+
+ # the proposed auxiliary phase training
+ # where the value is distilled into the policy network,
+ # while making sure the policy network does not change the action predictions (kl div loss)
+
+ i = 0
+ auxiliary_loss_ = 0
+ behavioral_cloning_loss_ = 0
+ value_loss_ = 0
+
+ for epoch in range(self._aux_train_epoch):
+ for data in dl:
+ policy_output = self._model.forward(data['obs'], mode='compute_actor_critic')
+
+ # Calculate ppg error 'logit_new', 'logit_old', 'action', 'value_new', 'value_old', 'return_', 'weight'
+ data_ppg = ppg_data(
+ policy_output['logit'], data['logit_old'], data['action'], policy_output['value'], data['value'],
+ data['return_'], data['weight']
+ )
+ ppg_joint_loss = ppg_joint_error(data_ppg, self._clip_ratio)
+ wb = self._aux_bc_weight
+ total_loss = ppg_joint_loss.auxiliary_loss + wb * ppg_joint_loss.behavioral_cloning_loss
+
+ # # policy network loss is composed of both the kl div loss as well as the auxiliary loss
+ # aux_loss = clipped_value_loss(policy_values, rewards, old_values, self.value_clip)
+ # loss_kl = F.kl_div(action_logprobs, old_action_probs, reduction='batchmean')
+ # policy_loss = aux_loss + loss_kl
+
+ self._optimizer_ac.zero_grad()
+ total_loss.backward()
+ self._optimizer_ac.step()
+
+ # the paper says it is important to give the value network extra training during the auxiliary phase
+ # Calculate ppg error 'value_new', 'value_old', 'return_', 'weight'
+ values = self._model.forward(data['obs'], mode='compute_critic')['value']
+ data_aux = ppo_value_data(values, data['value'], data['return_'], data['weight'])
+
+ value_loss = ppo_value_error(data_aux, self._clip_ratio)
+
+ self._optimizer_aux_critic.zero_grad()
+ value_loss.backward()
+ self._optimizer_aux_critic.step()
+
+ auxiliary_loss_ += ppg_joint_loss.auxiliary_loss.item()
+ behavioral_cloning_loss_ += ppg_joint_loss.behavioral_cloning_loss.item()
+ value_loss_ += value_loss.item()
+ i += 1
+
+ self._aux_memories = []
+
+ return auxiliary_loss_ / i, behavioral_cloning_loss_ / i, value_loss_ / i
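+
+ # Added summary comment for the auxiliary phase above (clarification, not original code): the actor is
+ # optimized with a joint objective of roughly the form
+ #     L_joint = L_aux + wb * L_bc,
+ # where L_aux distills the value targets into the policy network's auxiliary value head, L_bc is the
+ # behavioral cloning (KL) term that keeps the action distribution close to ``logit_old`` recorded before
+ # this phase, and wb is ``self._aux_bc_weight``; the critic is additionally trained with the clipped
+ # value loss (``ppo_value_error``) on the same data, following the PPG recipe.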
diff --git a/DI-engine/ding/policy/ppo.py b/DI-engine/ding/policy/ppo.py
new file mode 100644
index 0000000000000000000000000000000000000000..289bc72c44e6129d83c79923464a93f3e5632ad3
--- /dev/null
+++ b/DI-engine/ding/policy/ppo.py
@@ -0,0 +1,1841 @@
+from typing import List, Dict, Any, Tuple, Union
+from collections import namedtuple
+import torch
+import copy
+import numpy as np
+
+from ding.torch_utils import Adam, to_device, to_dtype, unsqueeze, ContrastiveLoss
+from ding.rl_utils import ppo_data, ppo_error, ppo_policy_error, ppo_policy_data, get_gae_with_default_last_value, \
+ v_nstep_td_data, v_nstep_td_error, get_nstep_return_data, get_train_sample, gae, gae_data, ppo_error_continuous, \
+ get_gae, ppo_policy_error_continuous
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY, split_data_generator, RunningMeanStd
+from ding.utils.data import default_collate, default_decollate
+from .base_policy import Policy
+from .common_utils import default_preprocess_learn
+
+
+@POLICY_REGISTRY.register('ppo')
+class PPOPolicy(Policy):
+ """
+ Overview:
+ Policy class of on-policy version PPO algorithm. Paper link: https://arxiv.org/abs/1707.06347.
+ """
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='ppo',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ # (bool) Whether the RL algorithm is on-policy or off-policy. (Note: in practice PPO can be off-policy used)
+ on_policy=True,
+ # (bool) Whether to use priority (priority sample, IS weight, update priority).
+ priority=False,
+ # (bool) Whether to use Importance Sampling Weight to correct biased update due to priority.
+ # If True, priority must be True.
+ priority_IS_weight=False,
+ # (bool) Whether to recompute advantages in each iteration of on-policy PPO.
+ recompute_adv=True,
+ # (str) Which kind of action space used in PPOPolicy, ['discrete', 'continuous', 'hybrid']
+ action_space='discrete',
+ # (bool) Whether to use nstep return to calculate value target, otherwise, use return = adv + value.
+ nstep_return=False,
+ # (bool) Whether to enable multi-agent training, i.e.: MAPPO.
+ multi_agent=False,
+ # (bool) Whether the policy needs its ``_forward_collect`` output data when processing transitions.
+ transition_with_policy_data=True,
+ # learn_mode config
+ learn=dict(
+ # (int) After collecting n_sample/n_episode data, how many epochs to train models.
+ # Each epoch means one entire pass over the training data.
+ epoch_per_collect=10,
+ # (int) How many samples in a training batch.
+ batch_size=64,
+ # (float) The step size of gradient descent.
+ learning_rate=3e-4,
+ # (float) The loss weight of value network, policy network weight is set to 1.
+ value_weight=0.5,
+ # (float) The loss weight of entropy regularization, policy network weight is set to 1.
+ entropy_weight=0.0,
+ # (float) PPO clip ratio, defaults to 0.2.
+ clip_ratio=0.2,
+ # (bool) Whether to use advantage norm in a whole training batch.
+ adv_norm=True,
+ # (bool) Whether to use value norm with running mean and std in the whole training process.
+ value_norm=True,
+ # (bool) Whether to enable special network parameters initialization scheme in PPO, such as orthogonal init.
+ ppo_param_init=True,
+ # (str) The gradient clip operation type used in PPO, ['clip_norm', 'clip_value', 'clip_momentum_norm'].
+ grad_clip_type='clip_norm',
+ # (float) The gradient clip target value used in PPO.
+ # If ``grad_clip_type`` is 'clip_norm', then the maximum of gradient will be normalized to this value.
+ grad_clip_value=0.5,
+ # (bool) Whether ignore done (usually for max step termination env).
+ ignore_done=False,
+ ),
+ # collect_mode config
+ collect=dict(
+ # (int) How many training samples collected in one collection procedure.
+ # Only one of [n_sample, n_episode] should be set.
+ # n_sample=64,
+ # (int) Split episodes or trajectories into pieces with length `unroll_len`.
+ unroll_len=1,
+ # (float) Reward's future discount factor, aka. gamma.
+ discount_factor=0.99,
+ # (float) GAE lambda factor for the balance of bias and variance (1-step td and mc).
+ gae_lambda=0.95,
+ ),
+ eval=dict(), # for compatibility
+ )
+
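+ # Hypothetical user-side config sketch (added comment): in practice only a few of the default fields
+ # above are overridden, e.g. something like
+ # >>> from easydict import EasyDict
+ # >>> cfg = EasyDict(dict(
+ # ...     cuda=True, action_space='continuous',
+ # ...     learn=dict(epoch_per_collect=8, batch_size=128, learning_rate=3e-4),
+ # ...     collect=dict(n_sample=2048, discount_factor=0.99, gae_lambda=0.95),
+ # ... ))
+ # with the remaining keys falling back to ``PPOPolicy.config``; the exact merge utility depends on the
+ # surrounding training pipeline, so this is only an illustration.
+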
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+ Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \
+ automatically call this method to get the default model setting and create model.
+ Returns:
+ - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
+
+ .. note::
+ The user can define and use a customized network model but must obey the same interface definition indicated \
+ by import_names path. For example about PPO, its registered name is ``ppo`` and the import_names is \
+ ``ding.model.template.vac``.
+
+ .. note::
+ Because PPO now supports both single-agent and multi-agent usages, we can implement these functions \
+ with the same policy and two different default models, which is controlled by ``self._cfg.multi_agent``.
+ """
+ if self._cfg.multi_agent:
+ return 'mavac', ['ding.model.template.mavac']
+ else:
+ return 'vac', ['ding.model.template.vac']
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Initialize the learn mode of policy, including related attributes and modules. For PPO, it mainly contains \
+ optimizer, algorithm-specific arguments such as loss weight, clip_ratio and recompute_adv. This method \
+ also executes some special network initializations and prepares running mean/std monitor for value.
+ This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+ .. note::
+ For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+ and ``_load_state_dict_learn`` methods.
+
+ .. note::
+ For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+ .. note::
+ If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+ with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ assert not self._priority and not self._priority_IS_weight, "Priority is not implemented in PPO"
+
+ assert self._cfg.action_space in ["continuous", "discrete", "hybrid"]
+ self._action_space = self._cfg.action_space
+ if self._cfg.learn.ppo_param_init:
+ for n, m in self._model.named_modules():
+ if isinstance(m, torch.nn.Linear):
+ torch.nn.init.orthogonal_(m.weight)
+ torch.nn.init.zeros_(m.bias)
+ if self._action_space in ['continuous', 'hybrid']:
+ # init log sigma
+ if self._action_space == 'continuous':
+ if hasattr(self._model.actor_head, 'log_sigma_param'):
+ torch.nn.init.constant_(self._model.actor_head.log_sigma_param, -0.5)
+ elif self._action_space == 'hybrid': # actor_head[1]: ReparameterizationHead, for action_args
+ if hasattr(self._model.actor_head[1], 'log_sigma_param'):
+ torch.nn.init.constant_(self._model.actor_head[1].log_sigma_param, -0.5)
+
+ for m in list(self._model.critic.modules()) + list(self._model.actor.modules()):
+ if isinstance(m, torch.nn.Linear):
+ # orthogonal initialization
+ torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
+ torch.nn.init.zeros_(m.bias)
+ # do last policy layer scaling, this will make initial actions have (close to)
+ # 0 mean and std, and will help boost performances,
+ # see https://arxiv.org/abs/2006.05990, Fig.24 for details
+ for m in self._model.actor.modules():
+ if isinstance(m, torch.nn.Linear):
+ torch.nn.init.zeros_(m.bias)
+ m.weight.data.copy_(0.01 * m.weight.data)
+
+ # Optimizer
+ self._optimizer = Adam(
+ self._model.parameters(),
+ lr=self._cfg.learn.learning_rate,
+ grad_clip_type=self._cfg.learn.grad_clip_type,
+ clip_value=self._cfg.learn.grad_clip_value
+ )
+
+ self._learn_model = model_wrap(self._model, wrapper_name='base')
+
+ # Algorithm config
+ self._value_weight = self._cfg.learn.value_weight
+ self._entropy_weight = self._cfg.learn.entropy_weight
+ self._clip_ratio = self._cfg.learn.clip_ratio
+ self._adv_norm = self._cfg.learn.adv_norm
+ self._value_norm = self._cfg.learn.value_norm
+ if self._value_norm:
+ self._running_mean_std = RunningMeanStd(epsilon=1e-4, device=self._device)
+ self._gamma = self._cfg.collect.discount_factor
+ self._gae_lambda = self._cfg.collect.gae_lambda
+ self._recompute_adv = self._cfg.recompute_adv
+ # Main model
+ self._learn_model.reset()
+
+ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+ Overview:
+ Policy forward function of learn mode (training policy and updating parameters). Forward means \
+ that the policy inputs some training batch data from the replay buffer and then returns the output \
+ result, including various training information such as loss, clipfrac, approx_kl.
+ Arguments:
+ - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including the latest \
+ collected training samples for on-policy algorithms like PPO. For each element in list, the key of the \
+ dict is the name of data items and the value is the corresponding data. Usually, the value is \
+ torch.Tensor or np.ndarray or their dict/list combinations. In the ``_forward_learn`` method, data \
+ often needs to be stacked first in the batch dimension by some utility functions such as \
+ ``default_preprocess_learn``. \
+ For PPO, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \
+ ``reward``, ``logit``, ``value``, ``done``. Sometimes, it also contains other keys such as ``weight``.
+ Returns:
+ - return_infos (:obj:`List[Dict[str, Any]]`): The information list that indicates the training result; each \
+ training iteration appends an information dict to the final list. The list will be processed \
+ and recorded in the text log and tensorboard. The value of the dict must be python scalar or a list of \
+ scalars. For the detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
+
+ .. tip::
+ The training procedure of PPO consists of two for loops. The outer loop iterates over all the collected \
+ training samples for ``epoch_per_collect`` epochs. The inner loop splits the data into mini-batches of \
+ length ``batch_size``.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
+ For the data types that are not supported, the main reason is that the corresponding model does not support them. \
+ You can implement your own model rather than use the default model. For more information, please raise an \
+ issue in GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for PPOPolicy: ``ding.policy.tests.test_ppo``.
+ """
+ data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=False)
+ if self._cuda:
+ data = to_device(data, self._device)
+ data['obs'] = to_dtype(data['obs'], torch.float32)
+ if 'next_obs' in data:
+ data['next_obs'] = to_dtype(data['next_obs'], torch.float32)
+ # ====================
+ # PPO forward
+ # ====================
+ return_infos = []
+ self._learn_model.train()
+
+ for epoch in range(self._cfg.learn.epoch_per_collect):
+ if self._recompute_adv: # calculate new value using the new updated value network
+ with torch.no_grad():
+ value = self._learn_model.forward(data['obs'], mode='compute_critic')['value']
+ next_value = self._learn_model.forward(data['next_obs'], mode='compute_critic')['value']
+ if self._value_norm:
+ value *= self._running_mean_std.std
+ next_value *= self._running_mean_std.std
+
+ traj_flag = data.get('traj_flag', None) # traj_flag indicates termination of trajectory
+ compute_adv_data = gae_data(value, next_value, data['reward'], data['done'], traj_flag)
+ data['adv'] = gae(compute_adv_data, self._gamma, self._gae_lambda)
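+ # Added explanatory comment: ``gae`` above essentially implements the standard GAE recursion
+ #     delta_t = r_t + gamma * V(s_{t+1}) * (1 - done_t) - V(s_t)
+ #     A_t     = delta_t + gamma * lambda * (1 - done_t) * A_{t+1}
+ # (with ``traj_flag`` additionally marking trajectory boundaries), so recomputing it every epoch keeps
+ # the advantages consistent with the freshly updated value network.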
+
+ unnormalized_returns = value + data['adv']
+
+ if self._value_norm:
+ data['value'] = value / self._running_mean_std.std
+ data['return'] = unnormalized_returns / self._running_mean_std.std
+ self._running_mean_std.update(unnormalized_returns.cpu().numpy())
+ else:
+ data['value'] = value
+ data['return'] = unnormalized_returns
+
+ else: # don't recompute adv
+ if self._value_norm:
+ unnormalized_return = data['adv'] + data['value'] * self._running_mean_std.std
+ data['return'] = unnormalized_return / self._running_mean_std.std
+ self._running_mean_std.update(unnormalized_return.cpu().numpy())
+ else:
+ data['return'] = data['adv'] + data['value']
+
+ for batch in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True):
+ output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic')
+ adv = batch['adv']
+ if self._adv_norm:
+ # Normalize advantage in a train_batch
+ adv = (adv - adv.mean()) / (adv.std() + 1e-8)
+
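+ # Added explanatory comment: ``ppo_error`` / ``ppo_error_continuous`` below compute the clipped
+ # surrogate objective of the PPO paper,
+ #     L_clip = E[ min( r_t(theta) * A_t, clip(r_t(theta), 1 - eps, 1 + eps) * A_t ) ],
+ # with r_t(theta) = pi_theta(a_t|s_t) / pi_theta_old(a_t|s_t) and eps = ``self._clip_ratio``, together
+ # with a value loss and an entropy bonus that are combined into ``total_loss`` a few lines below.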
+ # Calculate ppo error
+ if self._action_space == 'continuous':
+ ppo_batch = ppo_data(
+ output['logit'], batch['logit'], batch['action'], output['value'], batch['value'], adv,
+ batch['return'], batch['weight']
+ )
+ ppo_loss, ppo_info = ppo_error_continuous(ppo_batch, self._clip_ratio)
+ elif self._action_space == 'discrete':
+ ppo_batch = ppo_data(
+ output['logit'], batch['logit'], batch['action'], output['value'], batch['value'], adv,
+ batch['return'], batch['weight']
+ )
+ ppo_loss, ppo_info = ppo_error(ppo_batch, self._clip_ratio)
+ elif self._action_space == 'hybrid':
+ # discrete part (discrete policy loss and entropy loss)
+ ppo_discrete_batch = ppo_policy_data(
+ output['logit']['action_type'], batch['logit']['action_type'], batch['action']['action_type'],
+ adv, batch['weight']
+ )
+ ppo_discrete_loss, ppo_discrete_info = ppo_policy_error(ppo_discrete_batch, self._clip_ratio)
+ # continuous part (continuous policy loss and entropy loss, value loss)
+ ppo_continuous_batch = ppo_data(
+ output['logit']['action_args'], batch['logit']['action_args'], batch['action']['action_args'],
+ output['value'], batch['value'], adv, batch['return'], batch['weight']
+ )
+ ppo_continuous_loss, ppo_continuous_info = ppo_error_continuous(
+ ppo_continuous_batch, self._clip_ratio
+ )
+ # sum discrete and continuous loss
+ ppo_loss = type(ppo_continuous_loss)(
+ ppo_continuous_loss.policy_loss + ppo_discrete_loss.policy_loss, ppo_continuous_loss.value_loss,
+ ppo_continuous_loss.entropy_loss + ppo_discrete_loss.entropy_loss
+ )
+ ppo_info = type(ppo_continuous_info)(
+ max(ppo_continuous_info.approx_kl, ppo_discrete_info.approx_kl),
+ max(ppo_continuous_info.clipfrac, ppo_discrete_info.clipfrac)
+ )
+ wv, we = self._value_weight, self._entropy_weight
+ total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss
+
+ self._optimizer.zero_grad()
+ total_loss.backward()
+ self._optimizer.step()
+
+ return_info = {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': total_loss.item(),
+ 'policy_loss': ppo_loss.policy_loss.item(),
+ 'value_loss': ppo_loss.value_loss.item(),
+ 'entropy_loss': ppo_loss.entropy_loss.item(),
+ 'adv_max': adv.max().item(),
+ 'adv_mean': adv.mean().item(),
+ 'value_mean': output['value'].mean().item(),
+ 'value_max': output['value'].max().item(),
+ 'approx_kl': ppo_info.approx_kl,
+ 'clipfrac': ppo_info.clipfrac,
+ }
+ if self._action_space == 'continuous':
+ return_info.update(
+ {
+ 'act': batch['action'].float().mean().item(),
+ 'mu_mean': output['logit']['mu'].mean().item(),
+ 'sigma_mean': output['logit']['sigma'].mean().item(),
+ }
+ )
+ return_infos.append(return_info)
+ return return_infos
+
+ def _init_collect(self) -> None:
+ """
+ Overview:
+ Initialize the collect mode of policy, including related attributes and modules. For PPO, it contains the \
+ collect_model to balance the exploration and exploitation (e.g. the multinomial sample mechanism in \
+ discrete action space), and other algorithm-specific arguments such as unroll_len and gae_lambda.
+ This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.
+
+ .. note::
+ If you want to set some special member variables in ``_init_collect`` method, you'd better name them \
+ with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``.
+
+ .. tip::
+ Some variables need to be initialized independently in different modes, such as gamma and gae_lambda in PPO. \
+ This design is for the convenience of parallel execution of different policy modes.
+ """
+ self._unroll_len = self._cfg.collect.unroll_len
+ assert self._cfg.action_space in ["continuous", "discrete", "hybrid"], self._cfg.action_space
+ self._action_space = self._cfg.action_space
+ if self._action_space == 'continuous':
+ self._collect_model = model_wrap(self._model, wrapper_name='reparam_sample')
+ elif self._action_space == 'discrete':
+ self._collect_model = model_wrap(self._model, wrapper_name='multinomial_sample')
+ elif self._action_space == 'hybrid':
+ self._collect_model = model_wrap(self._model, wrapper_name='hybrid_reparam_multinomial_sample')
+ self._collect_model.reset()
+ self._gamma = self._cfg.collect.discount_factor
+ self._gae_lambda = self._cfg.collect.gae_lambda
+ self._recompute_adv = self._cfg.recompute_adv
+
+ def _forward_collect(self, data: Dict[int, Any]) -> Dict[int, Any]:
+ """
+ Overview:
+ Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \
+ that the policy gets some necessary data (mainly observation) from the envs and then returns the output \
+ data, such as the action to interact with the envs.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
+ other necessary data (action logit and value) for learn mode defined in ``self._process_transition`` \
+ method. The key of the dict is the same as the input data, i.e. environment id.
+
+ .. tip::
+ If you want to add more tricks on this policy, like temperature factor in multinomial sample, you can pass \
+ related data as extra keyword arguments of this method.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
+ For the data types that are not supported, the main reason is that the corresponding model does not support them. \
+ You can implement your own model rather than use the default model. For more information, please raise an \
+ issue in GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for PPOPolicy: ``ding.policy.tests.test_ppo``.
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._collect_model.eval()
+ with torch.no_grad():
+ output = self._collect_model.forward(data, mode='compute_actor_critic')
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
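+ # Minimal usage sketch for ``_forward_collect`` above (added comment; the observation shape, env ids and
+ # the ``collect_mode`` access pattern are assumptions for illustration only):
+ # >>> obs = {0: torch.randn(8), 3: torch.randn(8)}   # env_id -> observation
+ # >>> out = policy.collect_mode.forward(obs)
+ # >>> sorted(out[0].keys())                          # each env id gets its own output dict
+ # ['action', 'logit', 'value']
+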
+ def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor],
+ timestep: namedtuple) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ Process and pack one timestep transition data into a dict, which can be directly used for training and \
+ saved in replay buffer. For PPO, it contains obs, next_obs, action, reward, done, logit, value.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari.
+ - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \
+ as input. For PPO, it contains the state value, action and the logit of the action.
+ - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \
+ except all the elements have been transformed into tensor data. Usually, it contains the next obs, \
+ reward, done, info, etc.
+ Returns:
+ - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep.
+
+ .. note::
+ ``next_obs`` is used to calculate nstep return when necessary, so we place it into the transition by default. \
+ You can delete this field to save memory occupancy if you do not need nstep return.
+ """
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'action': policy_output['action'],
+ 'logit': policy_output['logit'],
+ 'value': policy_output['value'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+ Overview:
+ For a given trajectory (transitions, a list of transition) data, process it into a list of sample that \
+ can be used for training directly. In PPO, a train sample is a processed transition with new computed \
+ ``traj_flag`` and ``adv`` field. This method is usually used in collectors to execute necessary \
+ RL data preprocessing before training, which can help the learner amortize the relevant time consumption. \
+ In addition, you can also implement this method as an identity function and do the data processing \
+ in ``self._forward_learn`` method.
+ Arguments:
+ - transitions (:obj:`List[Dict[str, Any]]`): The trajectory data (a list of transition), each element is \
+ the same format as the return value of ``self._process_transition`` method.
+ Returns:
+ - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is the similar format \
+ as input transitions, but may contain more data for training, such as GAE advantage.
+ """
+ data = transitions
+ data = to_device(data, self._device)
+ for transition in data:
+ transition['traj_flag'] = copy.deepcopy(transition['done'])
+ data[-1]['traj_flag'] = True
+
+ if self._cfg.learn.ignore_done:
+ data[-1]['done'] = False
+
+ if data[-1]['done']:
+ last_value = torch.zeros_like(data[-1]['value'])
+ else:
+ with torch.no_grad():
+ last_value = self._collect_model.forward(
+ unsqueeze(data[-1]['next_obs'], 0), mode='compute_actor_critic'
+ )['value']
+ if len(last_value.shape) == 2: # multi_agent case:
+ last_value = last_value.squeeze(0)
+ if self._value_norm:
+ last_value *= self._running_mean_std.std
+ for i in range(len(data)):
+ data[i]['value'] *= self._running_mean_std.std
+ data = get_gae(
+ data,
+ to_device(last_value, self._device),
+ gamma=self._gamma,
+ gae_lambda=self._gae_lambda,
+ cuda=False,
+ )
+ if self._value_norm:
+ for i in range(len(data)):
+ data[i]['value'] /= self._running_mean_std.std
+
+ # remove next_obs for save memory when not recompute adv
+ if not self._recompute_adv:
+ for i in range(len(data)):
+ data[i].pop('next_obs')
+ return get_train_sample(data, self._unroll_len)
+
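+ # Added explanatory comment for ``_get_train_sample`` above: the trajectory is closed either with a zero
+ # last value (terminal state) or with a bootstrapped critic estimate on the last ``next_obs``; when
+ # value_norm is enabled, the stored (normalized) values are temporarily rescaled by the running std so
+ # that GAE is computed in the unnormalized return space, and then scaled back afterwards.
+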
+ def _init_eval(self) -> None:
+ """
+ Overview:
+ Initialize the eval mode of policy, including related attributes and modules. For PPO, it contains the \
+ eval model to select the optimal action (e.g. greedily select action with argmax mechanism in discrete action).
+ This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
+
+ .. note::
+ If you want to set some special member variables in ``_init_eval`` method, you'd better name them \
+ with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
+ """
+ assert self._cfg.action_space in ["continuous", "discrete", "hybrid"]
+ self._action_space = self._cfg.action_space
+ if self._action_space == 'continuous':
+ self._eval_model = model_wrap(self._model, wrapper_name='deterministic_sample')
+ elif self._action_space == 'discrete':
+ self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ elif self._action_space == 'hybrid':
+ self._eval_model = model_wrap(self._model, wrapper_name='hybrid_reparam_multinomial_sample')
+
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
+ """
+ Overview:
+ Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \
+ means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
+ action to interact with the envs. ``_forward_eval`` in PPO often uses deterministic sample method to get \
+ actions while ``_forward_collect`` usually uses stochastic sample method to balance exploration and \
+ exploitation.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
+ key of the dict is the same as the input data, i.e. environment id.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
+ For the data types that are not supported, the main reason is that the corresponding model does not support them. \
+ You can implement your own model rather than use the default model. For more information, please raise an \
+ issue in GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for PPOPolicy: ``ding.policy.tests.test_ppo``.
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data, mode='compute_actor')
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _monitor_vars_learn(self) -> List[str]:
+ """
+ Overview:
+ Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
+ as text logger, tensorboard logger, will use these keys to save the corresponding data.
+ Returns:
+ - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
+ """
+ variables = super()._monitor_vars_learn() + [
+ 'policy_loss',
+ 'value_loss',
+ 'entropy_loss',
+ 'adv_max',
+ 'adv_mean',
+ 'approx_kl',
+ 'clipfrac',
+ 'value_max',
+ 'value_mean',
+ ]
+ if self._action_space == 'continuous':
+ variables += ['mu_mean', 'sigma_mean', 'sigma_grad', 'act']
+ return variables
+
+
+@POLICY_REGISTRY.register('ppo_pg')
+class PPOPGPolicy(Policy):
+ """
+ Overview:
+ Policy class of on-policy version PPO algorithm (pure policy gradient without value network).
+ Paper link: https://arxiv.org/abs/1707.06347.
+ """
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='ppo_pg',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ # (bool) Whether the RL algorithm is on-policy or off-policy. (Note: in practice PPO can be off-policy used)
+ on_policy=True,
+ # (str) Which kind of action space used in PPOPolicy, ['discrete', 'continuous', 'hybrid']
+ action_space='discrete',
+ # (bool) Whether to enable multi-agent training, i.e.: MAPPO.
+ multi_agent=False,
+ # (bool) Whether policy output data is needed when processing transitions.
+ transition_with_policy_data=True,
+ # learn_mode config
+ learn=dict(
+ # (int) After collecting n_sample/n_episode data, how many epochs to train models.
+ # Each epoch means one entire pass over the training data.
+ epoch_per_collect=10,
+ # (int) How many samples in a training batch.
+ batch_size=64,
+ # (float) The step size of gradient descent.
+ learning_rate=3e-4,
+ # (float) The loss weight of entropy regularization, policy network weight is set to 1.
+ entropy_weight=0.0,
+ # (float) PPO clip ratio, defaults to 0.2.
+ clip_ratio=0.2,
+ # (bool) Whether to enable special network parameters initialization scheme in PPO, such as orthogonal init.
+ ppo_param_init=True,
+ # (str) The gradient clip operation type used in PPO, ['clip_norm', 'clip_value', 'clip_momentum_norm'].
+ grad_clip_type='clip_norm',
+ # (float) The gradient clip target value used in PPO.
+ # If ``grad_clip_type`` is 'clip_norm', then the maximum of gradient will be normalized to this value.
+ grad_clip_value=0.5,
+ # (bool) Whether ignore done (usually for max step termination env).
+ ignore_done=False,
+ ),
+ # collect_mode config
+ collect=dict(
+ # (int) How many training episodes are collected in one collection procedure. Only n_episode should be set.
+ # n_episode=8,
+ # (int) Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ # (float) Reward's future discount factor, aka. gamma.
+ discount_factor=0.99,
+ ),
+ eval=dict(), # for compatibility
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+ Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \
+ automatically call this method to get the default model setting and create model.
+ Returns:
+ - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
+ """
+ return 'pg', ['ding.model.template.pg']
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Initialize the learn mode of policy, including related attributes and modules. For PPOPG, it mainly \
+ contains optimizer, algorithm-specific arguments such as loss weight and clip_ratio. This method \
+ also executes some special network initializations.
+ This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+ .. note::
+ For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+ and ``_load_state_dict_learn`` methods.
+
+ .. note::
+ For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+ .. note::
+ If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+ with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
+ """
+ assert self._cfg.action_space in ["continuous", "discrete"]
+ self._action_space = self._cfg.action_space
+ if self._cfg.learn.ppo_param_init:
+ for n, m in self._model.named_modules():
+ if isinstance(m, torch.nn.Linear):
+ torch.nn.init.orthogonal_(m.weight)
+ torch.nn.init.zeros_(m.bias)
+ if self._action_space == 'continuous':
+ if hasattr(self._model.head, 'log_sigma_param'):
+ torch.nn.init.constant_(self._model.head.log_sigma_param, -0.5)
+ for m in self._model.modules():
+ if isinstance(m, torch.nn.Linear):
+ torch.nn.init.zeros_(m.bias)
+ m.weight.data.copy_(0.01 * m.weight.data)
+
+ # Optimizer
+ self._optimizer = Adam(
+ self._model.parameters(),
+ lr=self._cfg.learn.learning_rate,
+ grad_clip_type=self._cfg.learn.grad_clip_type,
+ clip_value=self._cfg.learn.grad_clip_value
+ )
+
+ self._learn_model = model_wrap(self._model, wrapper_name='base')
+
+ # Algorithm config
+ self._entropy_weight = self._cfg.learn.entropy_weight
+ self._clip_ratio = self._cfg.learn.clip_ratio
+ self._gamma = self._cfg.collect.discount_factor
+ # Main model
+ self._learn_model.reset()
+
+ def _forward_learn(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+ Overview:
+ Policy forward function of learn mode (training policy and updating parameters). Forward means \
+ that the policy inputs some training batch data from the replay buffer and then returns the output \
+ result, including various training information such as loss, clipfrac, approx_kl.
+ Arguments:
+ - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including the latest \
+ collected training samples for on-policy algorithms like PPO. For each element in list, the key of the \
+ dict is the name of data items and the value is the corresponding data. Usually, the value is \
+ torch.Tensor or np.ndarray or their dict/list combinations. In the ``_forward_learn`` method, data \
+ often needs to be stacked first in the batch dimension by some utility functions such as \
+ ``default_preprocess_learn``. \
+ For PPOPG, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \
+ ``return``, ``logit``, ``done``. Sometimes, it also contains other keys such as ``weight``.
+ Returns:
+ - return_infos (:obj:`List[Dict[str, Any]]`): The information list that indicates the training result; each \
+ training iteration appends an information dict to the final list. The list will be processed \
+ and recorded in the text log and tensorboard. The value of the dict must be python scalar or a list of \
+ scalars. For the detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
+
+ .. tip::
+ The training procedure of PPOPG consists of two for loops. The outer loop iterates over all the collected \
+ training samples for ``epoch_per_collect`` epochs. The inner loop splits the data into mini-batches of \
+ length ``batch_size``.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
+ For the data types that are not supported, the main reason is that the corresponding model does not support them. \
+ You can implement your own model rather than use the default model. For more information, please raise an \
+ issue in GitHub repo and we will continue to follow up.
+ """
+
+ data = default_preprocess_learn(data)
+ if self._cuda:
+ data = to_device(data, self._device)
+ return_infos = []
+ self._learn_model.train()
+
+ for epoch in range(self._cfg.learn.epoch_per_collect):
+ for batch in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True):
+ output = self._learn_model.forward(batch['obs'])
+
+ ppo_batch = ppo_policy_data(
+ output['logit'], batch['logit'], batch['action'], batch['return'], batch['weight']
+ )
+ if self._action_space == 'continuous':
+ ppo_loss, ppo_info = ppo_policy_error_continuous(ppo_batch, self._clip_ratio)
+ elif self._action_space == 'discrete':
+ ppo_loss, ppo_info = ppo_policy_error(ppo_batch, self._clip_ratio)
+ total_loss = ppo_loss.policy_loss - self._entropy_weight * ppo_loss.entropy_loss
+
+ self._optimizer.zero_grad()
+ total_loss.backward()
+ self._optimizer.step()
+
+ return_info = {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': total_loss.item(),
+ 'policy_loss': ppo_loss.policy_loss.item(),
+ 'entropy_loss': ppo_loss.entropy_loss.item(),
+ 'approx_kl': ppo_info.approx_kl,
+ 'clipfrac': ppo_info.clipfrac,
+ }
+ if self._action_space == 'continuous':
+ return_info.update(
+ {
+ 'act': batch['action'].float().mean().item(),
+ 'mu_mean': output['logit']['mu'].mean().item(),
+ 'sigma_mean': output['logit']['sigma'].mean().item(),
+ }
+ )
+ return_infos.append(return_info)
+ return return_infos
+
+ def _init_collect(self) -> None:
+ """
+ Overview:
+ Initialize the collect mode of policy, including related attributes and modules. For PPOPG, it contains \
+ the collect_model to balance the exploration and exploitation (e.g. the multinomial sample mechanism in \
+ discrete action space), and other algorithm-specific arguments such as unroll_len and discount_factor.
+ This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.
+
+ .. note::
+ If you want to set some special member variables in ``_init_collect`` method, you'd better name them \
+ with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``.
+
+ .. tip::
+ Some variables need to be initialized independently in different modes, such as gamma and gae_lambda in PPO. \
+ This design is for the convenience of parallel execution of different policy modes.
+ """
+ assert self._cfg.action_space in ["continuous", "discrete"], self._cfg.action_space
+ self._action_space = self._cfg.action_space
+ self._unroll_len = self._cfg.collect.unroll_len
+ if self._action_space == 'continuous':
+ self._collect_model = model_wrap(self._model, wrapper_name='reparam_sample')
+ elif self._action_space == 'discrete':
+ self._collect_model = model_wrap(self._model, wrapper_name='multinomial_sample')
+ self._collect_model.reset()
+ self._gamma = self._cfg.collect.discount_factor
+
+ def _forward_collect(self, data: Dict[int, Any]) -> Dict[int, Any]:
+ """
+ Overview:
+ Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \
+ that the policy gets some necessary data (mainly observation) from the envs and then returns the output \
+ data, such as the action to interact with the envs.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
+ other necessary data (action logit) for learn mode defined in ``self._process_transition`` \
+ method. The key of the dict is the same as the input data, i.e. environment id.
+
+ .. tip::
+ If you want to add more tricks on this policy, like temperature factor in multinomial sample, you can pass \
+ related data as extra keyword arguments of this method.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
+ For the data types that are not supported, the main reason is that the corresponding model does not support them. \
+ You can implement your own model rather than use the default model. For more information, please raise an \
+ issue in GitHub repo and we will continue to follow up.
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._collect_model.eval()
+ with torch.no_grad():
+ output = self._collect_model.forward(data)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor],
+ timestep: namedtuple) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ Process and pack one timestep transition data into a dict, which can be directly used for training and \
+ saved in replay buffer. For PPOPG, it contains obs, action, reward, done, logit.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari.
+ - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \
+ as input. For PPOPG, it contains the action and the logit of the action.
+ - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \
+ except all the elements have been transformed into tensor data. Usually, it contains the next obs, \
+ reward, done, info, etc.
+ Returns:
+ - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep.
+ """
+ transition = {
+ 'obs': obs,
+ 'action': policy_output['action'],
+ 'logit': policy_output['logit'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+ Overview:
+ For a given entire episode data (a list of transition), process it into a list of sample that \
+ can be used for training directly. In PPOPG, a train sample is a processed transition with new computed \
+ ``return`` field. This method is usually used in collectors to execute necessary \
+ RL data preprocessing before training, which can help the learner amortize the relevant time consumption. \
+ In addition, you can also implement this method as an identity function and do the data processing \
+ in ``self._forward_learn`` method.
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]]`): The episode data (a list of transition), each element is \
+ the same format as the return value of ``self._process_transition`` method.
+ Returns:
+ - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is the similar format \
+ as input transitions, but may contain more data for training, such as discounted episode return.
+ """
+ assert data[-1]['done'] is True, "PPO-PG needs a complete episode"
+
+ if self._cfg.learn.ignore_done:
+ raise NotImplementedError
+
+ R = 0.
+ for i in reversed(range(len(data))):
+ R = self._gamma * R + data[i]['reward']
+ data[i]['return'] = R
+
+ return get_train_sample(data, self._unroll_len)
+
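+ # Added worked example for the discounted-return loop above (purely illustrative numbers): with
+ # gamma = 0.99 and per-step rewards [1.0, 1.0, 1.0], the backward recursion
+ #     R_i = r_i + gamma * R_{i+1}
+ # yields returns [2.9701, 1.99, 1.0], which are written into data[i]['return'].
+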
+ def _init_eval(self) -> None:
+ """
+ Overview:
+ Initialize the eval mode of policy, including related attributes and modules. For PPOPG, it contains the \
+ eval model to select the optimal action (e.g. greedily select action with argmax mechanism in discrete action).
+ This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
+
+ .. note::
+ If you want to set some special member variables in ``_init_eval`` method, you'd better name them \
+ with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
+ """
+ assert self._cfg.action_space in ["continuous", "discrete"]
+ self._action_space = self._cfg.action_space
+ if self._action_space == 'continuous':
+ self._eval_model = model_wrap(self._model, wrapper_name='deterministic_sample')
+ elif self._action_space == 'discrete':
+ self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
+ """
+ Overview:
+ Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \
+ means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
+ action to interact with the envs. ``_forward_eval`` in PPO often uses deterministic sample method to get \
+ actions while ``_forward_collect`` usually uses stochastic sample method to balance exploration and \
+ exploitation.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
+ key of the dict is the same as the input data, i.e. environment id.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
+ For the data types that are not supported, the main reason is that the corresponding model does not support them. \
+ You can implement your own model rather than use the default model. For more information, please raise an \
+ issue in GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for PPOPGPolicy: ``ding.policy.tests.test_ppo``.
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _monitor_vars_learn(self) -> List[str]:
+ """
+ Overview:
+ Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
+ as text logger, tensorboard logger, will use these keys to save the corresponding data.
+ Returns:
+ - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
+ """
+ return super()._monitor_vars_learn() + [
+ 'policy_loss',
+ 'entropy_loss',
+ 'approx_kl',
+ 'clipfrac',
+ ]
+
+
+@POLICY_REGISTRY.register('ppo_offpolicy')
+class PPOOffPolicy(Policy):
+ """
+ Overview:
+ Policy class of off-policy version PPO algorithm. Paper link: https://arxiv.org/abs/1707.06347.
+ This version is more suitable for large-scale distributed training.
+ """
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='ppo',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ on_policy=False,
+ # (bool) Whether to use priority (priority sample, IS weight, update priority).
+ priority=False,
+ # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ # (str) Which kind of action space used in PPOPolicy, ["continuous", "discrete", "hybrid"].
+ action_space='discrete',
+ # (bool) Whether to use nstep_return for value loss.
+ nstep_return=False,
+ # (int) The timestep of TD (temporal-difference) loss.
+ nstep=3,
+ # (bool) Whether policy output data is needed when processing transitions.
+ transition_with_policy_data=True,
+ # learn_mode config
+ learn=dict(
+ # (int) How many updates(iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=5,
+ # (int) How many samples in a training batch.
+ batch_size=64,
+ # (float) The step size of gradient descent.
+ learning_rate=0.001,
+ # (float) The loss weight of value network, policy network weight is set to 1.
+ value_weight=0.5,
+ # (float) The loss weight of entropy regularization, policy network weight is set to 1.
+ entropy_weight=0.01,
+ # (float) PPO clip ratio, defaults to 0.2.
+ clip_ratio=0.2,
+ # (bool) Whether to use advantage norm in a whole training batch.
+ adv_norm=False,
+ # (bool) Whether to use value norm with running mean and std in the whole training process.
+ value_norm=True,
+ # (bool) Whether to enable special network parameters initialization scheme in PPO, such as orthogonal init.
+ ppo_param_init=True,
+ # (str) The gradient clip operation type used in PPO, ['clip_norm', 'clip_value', 'clip_momentum_norm'].
+ grad_clip_type='clip_norm',
+ # (float) The gradient clip target value used in PPO.
+ # If ``grad_clip_type`` is 'clip_norm', then the maximum of gradient will be normalized to this value.
+ grad_clip_value=0.5,
+ # (bool) Whether ignore done (usually for max step termination env).
+ ignore_done=False,
+ # (float) The weight decay (L2 regularization) loss weight, defaults to 0.0.
+ weight_decay=0.0,
+ ),
+ # collect_mode config
+ collect=dict(
+ # (int) How many training samples collected in one collection procedure.
+ # Only one of [n_sample, n_episode] should be set.
+ # n_sample=64,
+ # (int) Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ # (float) Reward's future discount factor, aka. gamma.
+ discount_factor=0.99,
+ # (float) GAE lambda factor for the balance of bias and variance (1-step td and mc).
+ gae_lambda=0.95,
+ ),
+ eval=dict(), # for compatibility
+ other=dict(
+ replay_buffer=dict(
+ # (int) Maximum size of replay buffer. Usually, larger buffer size is better.
+ replay_buffer_size=10000,
+ ),
+ ),
+ )
+
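+ # Added explanatory comment: when ``nstep_return`` above is enabled, the value target is the usual
+ # n-step bootstrapped return (up to done masking)
+ #     G_t^(n) = r_t + gamma * r_{t+1} + ... + gamma^(n-1) * r_{t+n-1} + gamma^n * V(s_{t+n}),
+ # with n = ``nstep`` (3 by default here); otherwise the target is simply ``adv + value``, as in the
+ # on-policy version.
+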
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+ Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \
+ automatically call this method to get the default model setting and create model.
+ Returns:
+ - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
+ """
+ return 'vac', ['ding.model.template.vac']
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Initialize the learn mode of policy, including related attributes and modules. For PPOOff, it mainly \
+ contains optimizer, algorithm-specific arguments such as loss weight and clip_ratio. This method \
+ also executes some special network initializations and prepares running mean/std monitor for value.
+ This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+ .. note::
+ For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+ and ``_load_state_dict_learn`` methods.
+
+ .. note::
+ For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+ .. note::
+ If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+ with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ assert not self._priority and not self._priority_IS_weight, "Priority is not implemented in PPOOff"
+
+ assert self._cfg.action_space in ["continuous", "discrete", "hybrid"]
+ self._action_space = self._cfg.action_space
+
+ if self._cfg.learn.ppo_param_init:
+ for n, m in self._model.named_modules():
+ if isinstance(m, torch.nn.Linear):
+ torch.nn.init.orthogonal_(m.weight)
+ torch.nn.init.zeros_(m.bias)
+ if self._action_space in ['continuous', 'hybrid']:
+ # init log sigma
+ if self._action_space == 'continuous':
+ if hasattr(self._model.actor_head, 'log_sigma_param'):
+ torch.nn.init.constant_(self._model.actor_head.log_sigma_param, -2.0)
+ elif self._action_space == 'hybrid': # actor_head[1]: ReparameterizationHead, for action_args
+ if hasattr(self._model.actor_head[1], 'log_sigma_param'):
+ torch.nn.init.constant_(self._model.actor_head[1].log_sigma_param, -0.5)
+
+ for m in list(self._model.critic.modules()) + list(self._model.actor.modules()):
+ if isinstance(m, torch.nn.Linear):
+ # orthogonal initialization
+ torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
+ torch.nn.init.zeros_(m.bias)
+ # do last policy layer scaling, this will make initial actions have (close to)
+ # 0 mean and std, and will help boost performances,
+ # see https://arxiv.org/abs/2006.05990, Fig.24 for details
+ for m in self._model.actor.modules():
+ if isinstance(m, torch.nn.Linear):
+ torch.nn.init.zeros_(m.bias)
+ m.weight.data.copy_(0.01 * m.weight.data)
+
+ # Optimizer
+ self._optimizer = Adam(
+ self._model.parameters(),
+ lr=self._cfg.learn.learning_rate,
+ grad_clip_type=self._cfg.learn.grad_clip_type,
+ clip_value=self._cfg.learn.grad_clip_value
+ )
+
+ self._learn_model = model_wrap(self._model, wrapper_name='base')
+
+ # Algorithm config
+ self._value_weight = self._cfg.learn.value_weight
+ self._entropy_weight = self._cfg.learn.entropy_weight
+ self._clip_ratio = self._cfg.learn.clip_ratio
+ self._adv_norm = self._cfg.learn.adv_norm
+ self._value_norm = self._cfg.learn.value_norm
+ if self._value_norm:
+ self._running_mean_std = RunningMeanStd(epsilon=1e-4, device=self._device)
+ self._gamma = self._cfg.collect.discount_factor
+ self._gae_lambda = self._cfg.collect.gae_lambda
+ self._nstep = self._cfg.nstep
+ self._nstep_return = self._cfg.nstep_return
+ # Main model
+ self._learn_model.reset()
+
+ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
+ """
+ Overview:
+ Policy forward function of learn mode (training policy and updating parameters). Forward means \
+ that the policy inputs some training batch data from the replay buffer and then returns the output \
+ result, including various training information such as loss, clipfrac and approx_kl.
+ Arguments:
+ - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
+ training samples. For each element in list, the key of the dict is the name of data items and the \
+ value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
+ combinations. In the ``_forward_learn`` method, data often needs to be stacked first in the batch \
+ dimension by some utility functions such as ``default_preprocess_learn``. \
+ For PPOOff, each element in list is a dict containing at least the following keys: ``obs``, ``adv``, \
+ ``action``, ``logit``, ``value``, ``done``. Sometimes, it also contains other keys such as ``weight`` \
+ and ``value_gamma``.
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will be \
+ recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
+ detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
+
+ .. note::
+            The input value can be a torch.Tensor or dict/list combinations, and the current policy supports all \
+            of them. For any data type that is not supported, the main reason is that the corresponding model does \
+            not support it. You can implement your own model rather than using the default model. For more \
+            information, please raise an issue in the GitHub repo and we will continue to follow up.
+ """
+ data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=self._nstep_return)
+ if self._cuda:
+ data = to_device(data, self._device)
+ data['obs'] = to_dtype(data['obs'], torch.float32)
+ if 'next_obs' in data:
+ data['next_obs'] = to_dtype(data['next_obs'], torch.float32)
+ # ====================
+ # PPO forward
+ # ====================
+
+ self._learn_model.train()
+
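+        # When ``value_norm`` is enabled, stored values live in the normalized space, so the target
+        # return below is computed as (adv + value * running_std) / running_std, i.e. the unnormalized
+        # return rescaled by the running std of returns.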
+ with torch.no_grad():
+ if self._value_norm:
+ unnormalized_return = data['adv'] + data['value'] * self._running_mean_std.std
+ data['return'] = unnormalized_return / self._running_mean_std.std
+ self._running_mean_std.update(unnormalized_return.cpu().numpy())
+ else:
+ data['return'] = data['adv'] + data['value']
+
+ # normal ppo
+ if not self._nstep_return:
+ output = self._learn_model.forward(data['obs'], mode='compute_actor_critic')
+ adv = data['adv']
+
+ if self._adv_norm:
+ # Normalize advantage in a total train_batch
+ adv = (adv - adv.mean()) / (adv.std() + 1e-8)
+ # Calculate ppo loss
+ if self._action_space == 'continuous':
+ ppodata = ppo_data(
+ output['logit'], data['logit'], data['action'], output['value'], data['value'], adv, data['return'],
+ data['weight']
+ )
+ ppo_loss, ppo_info = ppo_error_continuous(ppodata, self._clip_ratio)
+ elif self._action_space == 'discrete':
+ ppodata = ppo_data(
+ output['logit'], data['logit'], data['action'], output['value'], data['value'], adv, data['return'],
+ data['weight']
+ )
+ ppo_loss, ppo_info = ppo_error(ppodata, self._clip_ratio)
+ elif self._action_space == 'hybrid':
+ # discrete part (discrete policy loss and entropy loss)
+ ppo_discrete_batch = ppo_policy_data(
+ output['logit']['action_type'], data['logit']['action_type'], data['action']['action_type'], adv,
+ data['weight']
+ )
+ ppo_discrete_loss, ppo_discrete_info = ppo_policy_error(ppo_discrete_batch, self._clip_ratio)
+ # continuous part (continuous policy loss and entropy loss, value loss)
+ ppo_continuous_batch = ppo_data(
+ output['logit']['action_args'], data['logit']['action_args'], data['action']['action_args'],
+ output['value'], data['value'], adv, data['return'], data['weight']
+ )
+ ppo_continuous_loss, ppo_continuous_info = ppo_error_continuous(ppo_continuous_batch, self._clip_ratio)
+ # sum discrete and continuous loss
+ ppo_loss = type(ppo_continuous_loss)(
+ ppo_continuous_loss.policy_loss + ppo_discrete_loss.policy_loss, ppo_continuous_loss.value_loss,
+ ppo_continuous_loss.entropy_loss + ppo_discrete_loss.entropy_loss
+ )
+ ppo_info = type(ppo_continuous_info)(
+ max(ppo_continuous_info.approx_kl, ppo_discrete_info.approx_kl),
+ max(ppo_continuous_info.clipfrac, ppo_discrete_info.clipfrac)
+ )
+
+ wv, we = self._value_weight, self._entropy_weight
+ total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss
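+            # Standard PPO objective: the clipped surrogate policy loss, plus the value regression loss
+            # weighted by ``value_weight``, minus the entropy bonus weighted by ``entropy_weight``
+            # (subtracted to encourage exploration).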
+
+ else:
+ output = self._learn_model.forward(data['obs'], mode='compute_actor')
+ adv = data['adv']
+ if self._adv_norm:
+ # Normalize advantage in a total train_batch
+ adv = (adv - adv.mean()) / (adv.std() + 1e-8)
+
+ # Calculate ppo loss
+ if self._action_space == 'continuous':
+ ppodata = ppo_policy_data(output['logit'], data['logit'], data['action'], adv, data['weight'])
+ ppo_policy_loss, ppo_info = ppo_policy_error_continuous(ppodata, self._clip_ratio)
+ elif self._action_space == 'discrete':
+ ppodata = ppo_policy_data(output['logit'], data['logit'], data['action'], adv, data['weight'])
+ ppo_policy_loss, ppo_info = ppo_policy_error(ppodata, self._clip_ratio)
+ elif self._action_space == 'hybrid':
+ # discrete part (discrete policy loss and entropy loss)
+ ppo_discrete_data = ppo_policy_data(
+ output['logit']['action_type'], data['logit']['action_type'], data['action']['action_type'], adv,
+ data['weight']
+ )
+ ppo_discrete_loss, ppo_discrete_info = ppo_policy_error(ppo_discrete_data, self._clip_ratio)
+ # continuous part (continuous policy loss and entropy loss, value loss)
+ ppo_continuous_data = ppo_policy_data(
+ output['logit']['action_args'], data['logit']['action_args'], data['action']['action_args'], adv,
+ data['weight']
+ )
+ ppo_continuous_loss, ppo_continuous_info = ppo_policy_error_continuous(
+ ppo_continuous_data, self._clip_ratio
+ )
+ # sum discrete and continuous loss
+ ppo_policy_loss = type(ppo_continuous_loss)(
+ ppo_continuous_loss.policy_loss + ppo_discrete_loss.policy_loss,
+ ppo_continuous_loss.entropy_loss + ppo_discrete_loss.entropy_loss
+ )
+ ppo_info = type(ppo_continuous_info)(
+ max(ppo_continuous_info.approx_kl, ppo_discrete_info.approx_kl),
+ max(ppo_continuous_info.clipfrac, ppo_discrete_info.clipfrac)
+ )
+
+ wv, we = self._value_weight, self._entropy_weight
+ next_obs = data.get('next_obs')
+ value_gamma = data.get('value_gamma')
+ reward = data.get('reward')
+ # current value
+ value = self._learn_model.forward(data['obs'], mode='compute_critic')
+ # target value
+ next_data = {'obs': next_obs}
+ target_value = self._learn_model.forward(next_data['obs'], mode='compute_critic')
+ # TODO what should we do here to keep shape
+ assert self._nstep > 1
+ td_data = v_nstep_td_data(
+ value['value'], target_value['value'], reward, data['done'], data['weight'], value_gamma
+ )
+ # calculate v_nstep_td critic_loss
+ critic_loss, td_error_per_sample = v_nstep_td_error(td_data, self._gamma, self._nstep)
+ ppo_loss_data = namedtuple('ppo_loss', ['policy_loss', 'value_loss', 'entropy_loss'])
+ ppo_loss = ppo_loss_data(ppo_policy_loss.policy_loss, critic_loss, ppo_policy_loss.entropy_loss)
+ total_loss = ppo_policy_loss.policy_loss + wv * critic_loss - we * ppo_policy_loss.entropy_loss
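+            # Note: in this nstep-return branch the value loss comes from the n-step TD error rather than
+            # the GAE-based return regression used in the branch above; the actor part is still the clipped
+            # PPO policy loss computed earlier in this branch.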
+
+ # ====================
+ # PPO update
+ # ====================
+ self._optimizer.zero_grad()
+ total_loss.backward()
+ self._optimizer.step()
+ return_info = {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': total_loss.item(),
+ 'policy_loss': ppo_loss.policy_loss.item(),
+ 'value': data['value'].mean().item(),
+ 'value_loss': ppo_loss.value_loss.item(),
+ 'entropy_loss': ppo_loss.entropy_loss.item(),
+ 'adv_abs_max': adv.abs().max().item(),
+ 'approx_kl': ppo_info.approx_kl,
+ 'clipfrac': ppo_info.clipfrac,
+ }
+ if self._action_space == 'continuous':
+ return_info.update(
+ {
+ 'act': data['action'].float().mean().item(),
+ 'mu_mean': output['logit']['mu'].mean().item(),
+ 'sigma_mean': output['logit']['sigma'].mean().item(),
+ }
+ )
+ return return_info
+
+ def _init_collect(self) -> None:
+ """
+ Overview:
+            Initialize the collect mode of policy, including related attributes and modules. For PPOOff, it \
+            contains the collect_model used to balance exploration and exploitation (e.g. the multinomial sample \
+            mechanism in discrete action spaces), and other algorithm-specific arguments such as unroll_len and \
+            gae_lambda.
+ This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.
+
+ .. note::
+            If you want to set some special member variables in the ``_init_collect`` method, you'd better name \
+            them with the prefix ``_collect_`` to avoid conflicts with other modes, such as ``self._collect_attr1``.
+
+ .. tip::
+            Some variables need to be initialized independently in different modes, such as gamma and gae_lambda
+            in PPOOff. This design is for the convenience of parallel execution of different policy modes.
+ """
+ self._unroll_len = self._cfg.collect.unroll_len
+ assert self._cfg.action_space in ["continuous", "discrete", "hybrid"]
+ self._action_space = self._cfg.action_space
+ if self._action_space == 'continuous':
+ self._collect_model = model_wrap(self._model, wrapper_name='reparam_sample')
+ elif self._action_space == 'discrete':
+ self._collect_model = model_wrap(self._model, wrapper_name='multinomial_sample')
+ elif self._action_space == 'hybrid':
+ self._collect_model = model_wrap(self._model, wrapper_name='hybrid_reparam_multinomial_sample')
+ self._collect_model.reset()
+ self._gamma = self._cfg.collect.discount_factor
+ self._gae_lambda = self._cfg.collect.gae_lambda
+ self._nstep = self._cfg.nstep
+ self._nstep_return = self._cfg.nstep_return
+ self._value_norm = self._cfg.learn.value_norm
+ if self._value_norm:
+ self._running_mean_std = RunningMeanStd(epsilon=1e-4, device=self._device)
+
+ def _forward_collect(self, data: Dict[int, Any]) -> Dict[int, Any]:
+ """
+ Overview:
+ Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \
+ that the policy gets some necessary data (mainly observation) from the envs and then returns the output \
+ data, such as the action to interact with the envs.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
+ other necessary data (action logit and value) for learn mode defined in ``self._process_transition`` \
+ method. The key of the dict is the same as the input data, i.e. environment id.
+
+ .. tip::
+ If you want to add more tricks on this policy, like temperature factor in multinomial sample, you can pass \
+ related data as extra keyword arguments of this method.
+
+ .. note::
+            The input value can be a torch.Tensor or dict/list combinations, and the current policy supports all \
+            of them. For any data type that is not supported, the main reason is that the corresponding model does \
+            not support it. You can implement your own model rather than using the default model. For more \
+            information, please raise an issue in the GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for PPOOffPolicy: ``ding.policy.tests.test_ppo``.
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._collect_model.eval()
+ with torch.no_grad():
+ output = self._collect_model.forward(data, mode='compute_actor_critic')
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor],
+ timestep: namedtuple) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ Process and pack one timestep transition data into a dict, which can be directly used for training and \
+ saved in replay buffer. For PPO, it contains obs, next_obs, action, reward, done, logit, value.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari.
+ - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \
+ as input. For PPO, it contains the state value, action and the logit of the action.
+ - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \
+ except all the elements have been transformed into tensor data. Usually, it contains the next obs, \
+ reward, done, info, etc.
+ Returns:
+ - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep.
+
+ .. note::
+            ``next_obs`` is used to calculate the nstep return when necessary, so we place it into the transition \
+            by default. You can delete this field to save memory if you do not need the nstep return.
+ """
+
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'logit': policy_output['logit'],
+ 'action': policy_output['action'],
+ 'value': policy_output['value'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+ Overview:
+            For a given trajectory (transitions, a list of transitions) data, process it into a list of samples \
+            that can be used for training directly. In PPO, a train sample is a processed transition with newly \
+            computed ``traj_flag`` and ``adv`` fields. This method is usually used in collectors to execute \
+            necessary RL data preprocessing before training, which can help the learner amortize the relevant \
+            time consumption. In addition, you can also implement this method as an identity function and do \
+            the data processing in the ``self._forward_learn`` method.
+ Arguments:
+            - transitions (:obj:`List[Dict[str, Any]]`): The trajectory data (a list of transitions), each element \
+                is in the same format as the return value of the ``self._process_transition`` method.
+        Returns:
+            - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each element is in a similar \
+                format to the input transitions, but may contain more data for training, such as the GAE advantage.
+ """
+ data = transitions
+ data = to_device(data, self._device)
+ for transition in data:
+ transition['traj_flag'] = copy.deepcopy(transition['done'])
+ data[-1]['traj_flag'] = True
+
+ if self._cfg.learn.ignore_done:
+ data[-1]['done'] = False
+
+ if data[-1]['done']:
+ last_value = torch.zeros_like(data[-1]['value'])
+ else:
+ with torch.no_grad():
+ last_value = self._collect_model.forward(
+ unsqueeze(data[-1]['next_obs'], 0), mode='compute_actor_critic'
+ )['value']
+ if len(last_value.shape) == 2: # multi_agent case:
+ last_value = last_value.squeeze(0)
+ if self._value_norm:
+ last_value *= self._running_mean_std.std
+ for i in range(len(data)):
+ data[i]['value'] *= self._running_mean_std.std
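+        # Generalized Advantage Estimation (GAE): A_t = sum_{l>=0} (gamma * lambda)^l * delta_{t+l},
+        # where delta_t = r_t + gamma * V(s_{t+1}) - V(s_t). Values are temporarily un-normalized above
+        # so that advantages are computed in the raw reward scale, then re-normalized below.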
+ data = get_gae(
+ data,
+ to_device(last_value, self._device),
+ gamma=self._gamma,
+ gae_lambda=self._gae_lambda,
+ cuda=False,
+ )
+ if self._value_norm:
+ for i in range(len(data)):
+ data[i]['value'] /= self._running_mean_std.std
+
+ if not self._nstep_return:
+ return get_train_sample(data, self._unroll_len)
+ else:
+ return get_nstep_return_data(data, self._nstep)
+
+ def _init_eval(self) -> None:
+ """
+ Overview:
+            Initialize the eval mode of policy, including related attributes and modules. For PPOOff, it contains \
+            the eval model to select the optimal action (e.g. greedily select the action with the argmax mechanism \
+            in discrete action spaces).
+ This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
+
+ .. note::
+            If you want to set some special member variables in the ``_init_eval`` method, you'd better name them \
+            with the prefix ``_eval_`` to avoid conflicts with other modes, such as ``self._eval_attr1``.
+ """
+ assert self._cfg.action_space in ["continuous", "discrete", "hybrid"]
+ self._action_space = self._cfg.action_space
+ if self._action_space == 'continuous':
+ self._eval_model = model_wrap(self._model, wrapper_name='deterministic_sample')
+ elif self._action_space == 'discrete':
+ self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ elif self._action_space == 'hybrid':
+ self._eval_model = model_wrap(self._model, wrapper_name='hybrid_deterministic_argmax_sample')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
+ """
+ Overview:
+ Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \
+ means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
+            action to interact with the envs. ``_forward_eval`` in PPO often uses a deterministic sampling method \
+            to get actions, while ``_forward_collect`` usually uses a stochastic sampling method to balance \
+            exploration and exploitation.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
+ key of the dict is the same as the input data, i.e. environment id.
+
+ .. note::
+            The input value can be a torch.Tensor or dict/list combinations, and the current policy supports all \
+            of them. For any data type that is not supported, the main reason is that the corresponding model does \
+            not support it. You can implement your own model rather than using the default model. For more \
+            information, please raise an issue in the GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for PPOOffPolicy: ``ding.policy.tests.test_ppo``.
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data, mode='compute_actor')
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _monitor_vars_learn(self) -> List[str]:
+ """
+ Overview:
+ Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
+ as text logger, tensorboard logger, will use these keys to save the corresponding data.
+ Returns:
+ - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
+ """
+ variables = super()._monitor_vars_learn() + [
+ 'policy_loss', 'value', 'value_loss', 'entropy_loss', 'adv_abs_max', 'approx_kl', 'clipfrac'
+ ]
+ if self._action_space == 'continuous':
+ variables += ['mu_mean', 'sigma_mean', 'sigma_grad', 'act']
+ return variables
+
+
+@POLICY_REGISTRY.register('ppo_stdim')
+class PPOSTDIMPolicy(PPOPolicy):
+ """
+ Overview:
+ Policy class of on policy version PPO algorithm with ST-DIM auxiliary model.
+ PPO paper link: https://arxiv.org/abs/1707.06347.
+ ST-DIM paper link: https://arxiv.org/abs/1906.08226.
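+        ST-DIM adds an auxiliary contrastive (InfoNCE) loss between the encodings of ``obs`` and ``next_obs``
+        produced by the shared encoder, which is optimized alongside the standard PPO loss.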
+ """
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='ppo_stdim',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+        # (bool) Whether the RL algorithm is on-policy or off-policy. (Note: in practice PPO can be used off-policy)
+ on_policy=True,
+ # (bool) Whether to use priority(priority sample, IS weight, update priority)
+ priority=False,
+ # (bool) Whether to use Importance Sampling Weight to correct biased update due to priority.
+ # If True, priority must be True.
+ priority_IS_weight=False,
+        # (bool) Whether to recompute advantages in each iteration of on-policy PPO
+        recompute_adv=True,
+        # (str) Which kind of action space is used in PPOPolicy, ['discrete', 'continuous']
+ action_space='discrete',
+ # (bool) Whether to use nstep return to calculate value target, otherwise, use return = adv + value
+ nstep_return=False,
+ # (bool) Whether to enable multi-agent training, i.e.: MAPPO
+ multi_agent=False,
+ # (bool) Whether to need policy data in process transition
+ transition_with_policy_data=True,
+ # (float) The loss weight of the auxiliary model to the main loss.
+ aux_loss_weight=0.001,
+ aux_model=dict(
+ # (int) the encoding size (of each head) to apply contrastive loss.
+ encode_shape=64,
+            # ([int, int]) the number of heads for the obs encoding and next_obs encoding, respectively.
+ heads=[1, 1],
+ # (str) the contrastive loss type.
+ loss_type='infonce',
+            # (float) the temperature parameter that adjusts the contrast between positive and negative samples.
+ temperature=1.0,
+ ),
+ # learn_mode config
+ learn=dict(
+            # (int) After collecting n_sample/n_episode data, how many epochs to train the model.
+            # Each epoch means one entire pass over the training data.
+ epoch_per_collect=10,
+ # (int) How many samples in a training batch.
+ batch_size=64,
+ # (float) The step size of gradient descent.
+ learning_rate=3e-4,
+ # (float) The loss weight of value network, policy network weight is set to 1.
+ value_weight=0.5,
+ # (float) The loss weight of entropy regularization, policy network weight is set to 1.
+ entropy_weight=0.0,
+ # (float) PPO clip ratio, defaults to 0.2.
+ clip_ratio=0.2,
+ # (bool) Whether to use advantage norm in a whole training batch.
+ adv_norm=True,
+ # (bool) Whether to use value norm with running mean and std in the whole training process.
+ value_norm=True,
+            # (bool) Whether to enable PPO's special network parameter initialization scheme, such as orthogonal init.
+ ppo_param_init=True,
+            # (str) The gradient clip operation type used in PPO, ['clip_norm', 'clip_value', 'clip_momentum_norm'].
+ grad_clip_type='clip_norm',
+            # (float) The gradient clip target value used in PPO.
+            # If ``grad_clip_type`` is 'clip_norm', then the gradient norm will be clipped to this maximum value.
+ grad_clip_value=0.5,
+            # (bool) Whether to ignore done (usually for max-step termination envs).
+ ignore_done=False,
+ ),
+ # collect_mode config
+ collect=dict(
+ # (int) How many training samples collected in one collection procedure.
+            # Only one of [n_sample, n_episode] should be set.
+ # n_sample=64,
+ # (int) Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ # (float) Reward's future discount factor, aka. gamma.
+ discount_factor=0.99,
+ # (float) GAE lambda factor for the balance of bias and variance (1-step td and mc).
+ gae_lambda=0.95,
+ ),
+        eval=dict(),  # for compatibility
+ )
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Learn mode init method. Called by ``self.__init__``.
+            Init the auxiliary model, its optimizer, and the auxiliary loss weight relative to the main loss.
+ """
+ super()._init_learn()
+ x_size, y_size = self._get_encoding_size()
+ self._aux_model = ContrastiveLoss(x_size, y_size, **self._cfg.aux_model)
+ if self._cuda:
+ self._aux_model.cuda()
+ self._aux_optimizer = Adam(self._aux_model.parameters(), lr=self._cfg.learn.learning_rate)
+ self._aux_loss_weight = self._cfg.aux_loss_weight
+
+ def _get_encoding_size(self):
+ """
+ Overview:
+            Get the input encoding size of the ST-DIM auxiliary model.
+        Returns:
+            - sizes (:obj:`Tuple[torch.Size, torch.Size]`): The encoding sizes of ``obs`` and ``next_obs``, \
+                without the first (batch) dimension.
+ """
+ obs = self._cfg.model.obs_shape
+ if isinstance(obs, int):
+ obs = [obs]
+ test_data = {
+ "obs": torch.randn(1, *obs),
+ "next_obs": torch.randn(1, *obs),
+ }
+ if self._cuda:
+ test_data = to_device(test_data, self._device)
+ with torch.no_grad():
+ x, y = self._model_encode(test_data)
+ return x.size()[1:], y.size()[1:]
+
+ def _model_encode(self, data):
+ """
+ Overview:
+ Get the encoding of the main model as input for the auxiliary model.
+ Arguments:
+ - data (:obj:`dict`): Dict type data, same as the _forward_learn input.
+ Returns:
+            - (:obj:`Tuple[Tensor]`): The tuple of two tensors used for contrastive embedding learning. \
+                In the ST-DIM algorithm, these two variables are the encoder outputs of ``obs`` and ``next_obs``, \
+                respectively.
+ """
+ assert hasattr(self._model, "encoder")
+ x = self._model.encoder(data["obs"])
+ y = self._model.encoder(data["next_obs"])
+ return x, y
+
+ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Overview:
+ Forward and backward function of learn mode.
+ Arguments:
+ - data (:obj:`dict`): Dict type data
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`):
+ Including current lr, total_loss, policy_loss, value_loss, entropy_loss, \
+ adv_abs_max, approx_kl, clipfrac
+ """
+ data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=False)
+ if self._cuda:
+ data = to_device(data, self._device)
+ # ====================
+ # PPO forward
+ # ====================
+ return_infos = []
+ self._learn_model.train()
+
+ for epoch in range(self._cfg.learn.epoch_per_collect):
+ if self._recompute_adv: # calculate new value using the new updated value network
+ with torch.no_grad():
+ value = self._learn_model.forward(data['obs'], mode='compute_critic')['value']
+ next_value = self._learn_model.forward(data['next_obs'], mode='compute_critic')['value']
+ if self._value_norm:
+ value *= self._running_mean_std.std
+ next_value *= self._running_mean_std.std
+
+ traj_flag = data.get('traj_flag', None) # traj_flag indicates termination of trajectory
+ compute_adv_data = gae_data(value, next_value, data['reward'], data['done'], traj_flag)
+ data['adv'] = gae(compute_adv_data, self._gamma, self._gae_lambda)
+
+ unnormalized_returns = value + data['adv']
+
+ if self._value_norm:
+ data['value'] = value / self._running_mean_std.std
+ data['return'] = unnormalized_returns / self._running_mean_std.std
+ self._running_mean_std.update(unnormalized_returns.cpu().numpy())
+ else:
+ data['value'] = value
+ data['return'] = unnormalized_returns
+
+ else: # don't recompute adv
+ if self._value_norm:
+ unnormalized_return = data['adv'] + data['value'] * self._running_mean_std.std
+ data['return'] = unnormalized_return / self._running_mean_std.std
+ self._running_mean_std.update(unnormalized_return.cpu().numpy())
+ else:
+ data['return'] = data['adv'] + data['value']
+
+ for batch in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True):
+ # ======================
+ # Auxiliary model update
+ # ======================
+
+ # RL network encoding
+                # To train only the auxiliary network, x and y are computed under ``no_grad`` so that no
+                # gradient flows back into the main model.
+ with torch.no_grad():
+ x_no_grad, y_no_grad = self._model_encode(batch)
+ # the forward function of the auxiliary network
+ self._aux_model.train()
+ aux_loss_learn = self._aux_model.forward(x_no_grad, y_no_grad)
+ # the BP process of the auxiliary network
+ self._aux_optimizer.zero_grad()
+ aux_loss_learn.backward()
+ if self._cfg.multi_gpu:
+ self.sync_gradients(self._aux_model)
+ self._aux_optimizer.step()
+
+ output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic')
+ adv = batch['adv']
+ if self._adv_norm:
+ # Normalize advantage in a train_batch
+ adv = (adv - adv.mean()) / (adv.std() + 1e-8)
+
+ # Calculate ppo loss
+ if self._action_space == 'continuous':
+ ppo_batch = ppo_data(
+ output['logit'], batch['logit'], batch['action'], output['value'], batch['value'], adv,
+ batch['return'], batch['weight']
+ )
+ ppo_loss, ppo_info = ppo_error_continuous(ppo_batch, self._clip_ratio)
+ elif self._action_space == 'discrete':
+ ppo_batch = ppo_data(
+ output['logit'], batch['logit'], batch['action'], output['value'], batch['value'], adv,
+ batch['return'], batch['weight']
+ )
+ ppo_loss, ppo_info = ppo_error(ppo_batch, self._clip_ratio)
+
+ # ======================
+ # Compute auxiliary loss
+ # ======================
+
+                # During the backward pass of total_loss, gradients w.r.t. x and y are required to update the
+                # encoding network. The auxiliary network itself will not be updated, since ``self._optimizer``
+                # does not contain its weights.
+                x, y = self._model_encode(batch)
+ self._aux_model.eval()
+ aux_loss_eval = self._aux_model.forward(x, y) * self._aux_loss_weight
+
+ wv, we = self._value_weight, self._entropy_weight
+ total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss\
+ + aux_loss_eval
+
+ self._optimizer.zero_grad()
+ total_loss.backward()
+ self._optimizer.step()
+
+ return_info = {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': total_loss.item(),
+ 'aux_loss_learn': aux_loss_learn.item(),
+ 'aux_loss_eval': aux_loss_eval.item(),
+ 'policy_loss': ppo_loss.policy_loss.item(),
+ 'value_loss': ppo_loss.value_loss.item(),
+ 'entropy_loss': ppo_loss.entropy_loss.item(),
+ 'adv_max': adv.max().item(),
+ 'adv_mean': adv.mean().item(),
+ 'value_mean': output['value'].mean().item(),
+ 'value_max': output['value'].max().item(),
+ 'approx_kl': ppo_info.approx_kl,
+ 'clipfrac': ppo_info.clipfrac,
+ }
+ if self._action_space == 'continuous':
+ return_info.update(
+ {
+ 'act': batch['action'].float().mean().item(),
+ 'mu_mean': output['logit']['mu'].mean().item(),
+ 'sigma_mean': output['logit']['sigma'].mean().item(),
+ }
+ )
+ return_infos.append(return_info)
+ return return_infos
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ """
+ Overview:
+ Return the state_dict of learn mode, usually including model, optimizer and aux_optimizer for \
+ representation learning.
+ Returns:
+ - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring.
+ """
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'optimizer': self._optimizer.state_dict(),
+ 'aux_optimizer': self._aux_optimizer.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ """
+ Overview:
+ Load the state_dict variable into policy learn mode.
+ Arguments:
+ - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before.
+
+ .. tip::
+            If you want to only load some parts of the model, you can simply set the ``strict`` argument in \
+            load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
+            complicated operations.
+ """
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._optimizer.load_state_dict(state_dict['optimizer'])
+ self._aux_optimizer.load_state_dict(state_dict['aux_optimizer'])
+
+ def _monitor_vars_learn(self) -> List[str]:
+ """
+ Overview:
+ Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
+ as text logger, tensorboard logger, will use these keys to save the corresponding data.
+ Returns:
+ - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
+ """
+ return super()._monitor_vars_learn() + ["aux_loss_learn", "aux_loss_eval"]
diff --git a/DI-engine/ding/policy/ppof.py b/DI-engine/ding/policy/ppof.py
new file mode 100644
index 0000000000000000000000000000000000000000..81e605384cd555d65e9e559dadbd213619be6c22
--- /dev/null
+++ b/DI-engine/ding/policy/ppof.py
@@ -0,0 +1,359 @@
+from typing import List, Dict, Any, Tuple, Union, Callable, Optional
+from collections import namedtuple
+from easydict import EasyDict
+import copy
+import random
+import numpy as np
+import torch
+import treetensor.torch as ttorch
+from torch.optim import AdamW
+
+from ding.rl_utils import ppo_data, ppo_error, ppo_policy_error, ppo_policy_data, gae, gae_data, ppo_error_continuous, \
+ get_gae, ppo_policy_error_continuous, ArgmaxSampler, MultinomialSampler, ReparameterizationSampler, MuSampler, \
+ HybridStochasticSampler, HybridDeterminsticSampler, value_transform, value_inv_transform, symlog, inv_symlog
+from ding.utils import POLICY_REGISTRY, RunningMeanStd
+
+
+@POLICY_REGISTRY.register('ppof')
+class PPOFPolicy:
+ config = dict(
+ type='ppo',
+ on_policy=True,
+ cuda=True,
+ action_space='discrete',
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ # learn
+ epoch_per_collect=10,
+ batch_size=64,
+ learning_rate=3e-4,
+        # learning rate scheduler, whose format is (epoch_num, min_lr_lambda), e.g. (10000, 0.1)
+ lr_scheduler=None,
+ weight_decay=0,
+ value_weight=0.5,
+ entropy_weight=0.01,
+ clip_ratio=0.2,
+ adv_norm=True,
+ value_norm='baseline',
+ ppo_param_init=True,
+ grad_norm=0.5,
+ # collect
+ n_sample=128,
+ unroll_len=1,
+ # eval
+ deterministic_eval=True,
+ # model
+ model=dict(),
+ )
+ mode = ['learn', 'collect', 'eval']
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ cfg = EasyDict(copy.deepcopy(cls.config))
+ cfg.cfg_type = cls.__name__ + 'Dict'
+ return cfg
+
+ @classmethod
+ def default_model(cls: type) -> Callable:
+ from .model import PPOFModel
+ return PPOFModel
+
+ def __init__(self, cfg: "EasyDict", model: torch.nn.Module, enable_mode: List[str] = None) -> None:
+ self._cfg = cfg
+ if model is None:
+ self._model = self.default_model()
+ else:
+ self._model = model
+ if self._cfg.cuda and torch.cuda.is_available():
+ self._device = 'cuda'
+ self._model.cuda()
+ else:
+ self._device = 'cpu'
+ assert self._cfg.action_space in ["continuous", "discrete", "hybrid", 'multi_discrete']
+ self._action_space = self._cfg.action_space
+ if self._cfg.ppo_param_init:
+ self._model_param_init()
+
+ if enable_mode is None:
+ enable_mode = self.mode
+ self.enable_mode = enable_mode
+ if 'learn' in enable_mode:
+ self._optimizer = AdamW(
+ self._model.parameters(),
+ lr=self._cfg.learning_rate,
+ weight_decay=self._cfg.weight_decay,
+ )
+ # define linear lr scheduler
+ if self._cfg.lr_scheduler is not None:
+ epoch_num, min_lr_lambda = self._cfg.lr_scheduler
+
+ self._lr_scheduler = torch.optim.lr_scheduler.LambdaLR(
+ self._optimizer,
+ lr_lambda=lambda epoch: max(1.0 - epoch * (1.0 - min_lr_lambda) / epoch_num, min_lr_lambda)
+ )
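+                # The multiplier decays linearly from 1.0 to ``min_lr_lambda`` over ``epoch_num`` scheduler
+                # steps (``self._lr_scheduler.step()`` is called once at the end of each ``forward`` call)
+                # and stays at ``min_lr_lambda`` afterwards.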
+
+ if self._cfg.value_norm:
+ self._running_mean_std = RunningMeanStd(epsilon=1e-4, device=self._device)
+ if 'collect' in enable_mode:
+ if self._action_space == 'discrete':
+ self._collect_sampler = MultinomialSampler()
+ elif self._action_space == 'continuous':
+ self._collect_sampler = ReparameterizationSampler()
+ elif self._action_space == 'hybrid':
+ self._collect_sampler = HybridStochasticSampler()
+ if 'eval' in enable_mode:
+ if self._action_space == 'discrete':
+ if self._cfg.deterministic_eval:
+ self._eval_sampler = ArgmaxSampler()
+ else:
+ self._eval_sampler = MultinomialSampler()
+ elif self._action_space == 'continuous':
+ if self._cfg.deterministic_eval:
+ self._eval_sampler = MuSampler()
+ else:
+ self._eval_sampler = ReparameterizationSampler()
+ elif self._action_space == 'hybrid':
+ if self._cfg.deterministic_eval:
+ self._eval_sampler = HybridDeterminsticSampler()
+ else:
+ self._eval_sampler = HybridStochasticSampler()
+ # for compatibility
+ self.learn_mode = self
+ self.collect_mode = self
+ self.eval_mode = self
+
+ def _model_param_init(self):
+ for n, m in self._model.named_modules():
+ if isinstance(m, torch.nn.Linear):
+ torch.nn.init.orthogonal_(m.weight)
+ torch.nn.init.zeros_(m.bias)
+ if self._action_space in ['continuous', 'hybrid']:
+ for m in list(self._model.critic.modules()) + list(self._model.actor.modules()):
+ if isinstance(m, torch.nn.Linear):
+ # orthogonal initialization
+ torch.nn.init.orthogonal_(m.weight, gain=np.sqrt(2))
+ torch.nn.init.zeros_(m.bias)
+ # init log sigma
+ if self._action_space == 'continuous':
+ torch.nn.init.constant_(self._model.actor_head.log_sigma_param, -0.5)
+ for m in self._model.actor_head.mu.modules():
+ if isinstance(m, torch.nn.Linear):
+ torch.nn.init.zeros_(m.bias)
+ m.weight.data.copy_(0.01 * m.weight.data)
+ elif self._action_space == 'hybrid': # actor_head[1]: ReparameterizationHead, for action_args
+ if hasattr(self._model.actor_head[1], 'log_sigma_param'):
+ torch.nn.init.constant_(self._model.actor_head[1].log_sigma_param, -0.5)
+ for m in self._model.actor_head[1].mu.modules():
+ if isinstance(m, torch.nn.Linear):
+ torch.nn.init.zeros_(m.bias)
+ m.weight.data.copy_(0.01 * m.weight.data)
+
+ def forward(self, data: ttorch.Tensor) -> Dict[str, Any]:
+ return_infos = []
+ self._model.train()
+ bs = self._cfg.batch_size
+        data = data[:self._cfg.n_sample // bs * bs]  # drop the tail so the sample count is a multiple of batch_size
+
+ # outer training loop
+ for epoch in range(self._cfg.epoch_per_collect):
+ # recompute adv
+ with torch.no_grad():
+ # get the value dictionary
+ # In popart, the dictionary has two keys: 'pred' and 'unnormalized_pred'
+ value = self._model.compute_critic(data.obs)
+ next_value = self._model.compute_critic(data.next_obs)
+ reward = data.reward
+
+ assert self._cfg.value_norm in ['popart', 'value_rescale', 'symlog', 'baseline'],\
+ 'Not supported value normalization! Value normalization supported: \
+ popart, value rescale, symlog, baseline'
+
+ if self._cfg.value_norm == 'popart':
+                    unnormalized_value = value['unnormalized_pred']
+                    unnormalized_next_value = next_value['unnormalized_pred']
+
+ mu = self._model.critic_head.popart.mu
+ sigma = self._model.critic_head.popart.sigma
+ reward = (reward - mu) / sigma
+
+ value = value['pred']
+ next_value = next_value['pred']
+ elif self._cfg.value_norm == 'value_rescale':
+ value = value_inv_transform(value['pred'])
+ next_value = value_inv_transform(next_value['pred'])
+ elif self._cfg.value_norm == 'symlog':
+ value = inv_symlog(value['pred'])
+ next_value = inv_symlog(next_value['pred'])
+ elif self._cfg.value_norm == 'baseline':
+ value = value['pred'] * self._running_mean_std.std
+ next_value = next_value['pred'] * self._running_mean_std.std
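+                # Note (assumed standard definitions): ``symlog`` is typically symlog(x) = sign(x) * ln(|x| + 1),
+                # and ``value_rescale`` follows h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x (Pohlen et al., 2018);
+                # the critic predicts in the squashed space and its outputs are inverse-transformed above before
+                # computing GAE in the raw return scale.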
+
+ traj_flag = data.get('traj_flag', None) # traj_flag indicates termination of trajectory
+ adv_data = gae_data(value, next_value, reward, data.done, traj_flag)
+ data.adv = gae(adv_data, self._cfg.discount_factor, self._cfg.gae_lambda)
+
+ unnormalized_returns = value + data.adv # In popart, this return is normalized
+
+ if self._cfg.value_norm == 'popart':
+ self._model.critic_head.popart.update_parameters((data.reward).unsqueeze(1))
+ elif self._cfg.value_norm == 'value_rescale':
+ value = value_transform(value)
+ unnormalized_returns = value_transform(unnormalized_returns)
+ elif self._cfg.value_norm == 'symlog':
+ value = symlog(value)
+ unnormalized_returns = symlog(unnormalized_returns)
+ elif self._cfg.value_norm == 'baseline':
+ value /= self._running_mean_std.std
+ unnormalized_returns /= self._running_mean_std.std
+ self._running_mean_std.update(unnormalized_returns.cpu().numpy())
+ data.value = value
+ data.return_ = unnormalized_returns
+
+ # inner training loop
+            # shuffle the mini-batches in place
+            split_data = list(ttorch.split(data, self._cfg.batch_size))
+            random.shuffle(split_data)
+ for batch in split_data:
+ output = self._model.compute_actor_critic(batch.obs)
+ adv = batch.adv
+ if self._cfg.adv_norm:
+ # Normalize advantage in a train_batch
+ adv = (adv - adv.mean()) / (adv.std() + 1e-8)
+
+ # Calculate ppo error
+ if self._action_space == 'continuous':
+ ppo_batch = ppo_data(
+ output.logit, batch.logit, batch.action, output.value, batch.value, adv, batch.return_, None
+ )
+ ppo_loss, ppo_info = ppo_error_continuous(ppo_batch, self._cfg.clip_ratio)
+ elif self._action_space == 'discrete':
+ ppo_batch = ppo_data(
+ output.logit, batch.logit, batch.action, output.value, batch.value, adv, batch.return_, None
+ )
+ ppo_loss, ppo_info = ppo_error(ppo_batch, self._cfg.clip_ratio)
+ elif self._action_space == 'hybrid':
+ # discrete part (discrete policy loss and entropy loss)
+ ppo_discrete_batch = ppo_policy_data(
+ output.logit.action_type, batch.logit.action_type, batch.action.action_type, adv, None
+ )
+ ppo_discrete_loss, ppo_discrete_info = ppo_policy_error(ppo_discrete_batch, self._cfg.clip_ratio)
+ # continuous part (continuous policy loss and entropy loss, value loss)
+ ppo_continuous_batch = ppo_data(
+ output.logit.action_args, batch.logit.action_args, batch.action.action_args, output.value,
+ batch.value, adv, batch.return_, None
+ )
+ ppo_continuous_loss, ppo_continuous_info = ppo_error_continuous(
+ ppo_continuous_batch, self._cfg.clip_ratio
+ )
+ # sum discrete and continuous loss
+ ppo_loss = type(ppo_continuous_loss)(
+ ppo_continuous_loss.policy_loss + ppo_discrete_loss.policy_loss, ppo_continuous_loss.value_loss,
+ ppo_continuous_loss.entropy_loss + ppo_discrete_loss.entropy_loss
+ )
+ ppo_info = type(ppo_continuous_info)(
+ max(ppo_continuous_info.approx_kl, ppo_discrete_info.approx_kl),
+ max(ppo_continuous_info.clipfrac, ppo_discrete_info.clipfrac)
+ )
+ wv, we = self._cfg.value_weight, self._cfg.entropy_weight
+ total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss
+
+ self._optimizer.zero_grad()
+ total_loss.backward()
+ torch.nn.utils.clip_grad_norm_(self._model.parameters(), self._cfg.grad_norm)
+ self._optimizer.step()
+
+ return_info = {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': total_loss.item(),
+ 'policy_loss': ppo_loss.policy_loss.item(),
+ 'value_loss': ppo_loss.value_loss.item(),
+ 'entropy_loss': ppo_loss.entropy_loss.item(),
+ 'adv_max': adv.max().item(),
+ 'adv_mean': adv.mean().item(),
+ 'value_mean': output.value.mean().item(),
+ 'value_max': output.value.max().item(),
+ 'approx_kl': ppo_info.approx_kl,
+ 'clipfrac': ppo_info.clipfrac,
+ }
+ if self._action_space == 'continuous':
+ return_info.update(
+ {
+ 'action': batch.action.float().mean().item(),
+ 'mu_mean': output.logit.mu.mean().item(),
+ 'sigma_mean': output.logit.sigma.mean().item(),
+ }
+ )
+ elif self._action_space == 'hybrid':
+ return_info.update(
+ {
+ 'action': batch.action.action_args.float().mean().item(),
+ 'mu_mean': output.logit.action_args.mu.mean().item(),
+ 'sigma_mean': output.logit.action_args.sigma.mean().item(),
+ }
+ )
+ return_infos.append(return_info)
+
+ if self._cfg.lr_scheduler is not None:
+ self._lr_scheduler.step()
+
+ return return_infos
+
+ def state_dict(self) -> Dict[str, Any]:
+ state_dict = {
+ 'model': self._model.state_dict(),
+ }
+ if 'learn' in self.enable_mode:
+ state_dict['optimizer'] = self._optimizer.state_dict()
+ return state_dict
+
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+ self._model.load_state_dict(state_dict['model'])
+ if 'learn' in self.enable_mode:
+ self._optimizer.load_state_dict(state_dict['optimizer'])
+
+ def collect(self, data: ttorch.Tensor) -> ttorch.Tensor:
+ self._model.eval()
+ with torch.no_grad():
+ output = self._model.compute_actor_critic(data)
+ action = self._collect_sampler(output.logit)
+ output.action = action
+ return output
+
+ def process_transition(self, obs: ttorch.Tensor, inference_output: dict, timestep: namedtuple) -> ttorch.Tensor:
+ return ttorch.as_tensor(
+ {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'action': inference_output.action,
+ 'logit': inference_output.logit,
+ 'value': inference_output.value,
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ )
+
+ def eval(self, data: ttorch.Tensor) -> ttorch.Tensor:
+ self._model.eval()
+ with torch.no_grad():
+ logit = self._model.compute_actor(data)
+ action = self._eval_sampler(logit)
+ return ttorch.as_tensor({'logit': logit, 'action': action})
+
+ def monitor_vars(self) -> List[str]:
+ variables = [
+ 'cur_lr',
+ 'policy_loss',
+ 'value_loss',
+ 'entropy_loss',
+ 'adv_max',
+ 'adv_mean',
+ 'approx_kl',
+ 'clipfrac',
+ 'value_max',
+ 'value_mean',
+ ]
+        if self._action_space in ['continuous', 'hybrid']:
+ variables += ['mu_mean', 'sigma_mean', 'action']
+ return variables
+
+ def reset(self, env_id_list: Optional[List[int]] = None) -> None:
+ pass
diff --git a/DI-engine/ding/policy/prompt_pg.py b/DI-engine/ding/policy/prompt_pg.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebccadb8a37324d5692c44afbb7eeeabbd84085d
--- /dev/null
+++ b/DI-engine/ding/policy/prompt_pg.py
@@ -0,0 +1,206 @@
+from typing import List, Dict, Any, Tuple, Union
+from collections import namedtuple
+import torch
+
+from ding.rl_utils import get_train_sample
+from ding.torch_utils import Adam, to_device
+from ding.utils import POLICY_REGISTRY, split_data_generator
+from ding.utils.data import default_collate, default_decollate
+from .base_policy import Policy
+from ..model import model_wrap
+
+
+@POLICY_REGISTRY.register('prompt_pg')
+class PromptPGPolicy(Policy):
+ r"""
+ Overview:
+ Policy class of Prompt Policy Gradient (PromptPG) algorithm.
+ Link of the original paper: https://arxiv.org/abs/2209.14610
+ """
+ config = dict(
+ # (string) RL policy register name (refer to function "register_policy").
+ type='prompt_pg',
+ # (bool) whether to use cuda for network.
+ cuda=True,
+        # (bool) whether to use the on-policy training pipeline (behaviour policy and training policy are the same)
+        on_policy=True,  # since PG is a strictly on-policy algorithm, this line should not be modified by users
+ # (bool) whether to use deterministic action for evaluation.
+ deterministic_eval=True,
+ learn=dict(
+ # (int) the number of samples for one update.
+ batch_size=64,
+            # (float) the step size of gradient descent.
+ learning_rate=0.001,
+ # ==============================================================
+            # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.01,
+ # (float) max grad norm value.
+ grad_norm=5,
+ # (bool) whether to ignore done signal for non-termination env.
+ ignore_done=False,
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ # n_episode=8,
+ # (int) trajectory unroll length
+ unroll_len=1,
+ # ==============================================================
+            # The following configs are algorithm-specific
+ # ==============================================================
+            # (float) discount factor for future rewards, in the range [0, 1]
+ discount_factor=0,
+ collector=dict(get_train_sample=True),
+ ),
+ eval=dict(),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ return 'language_transformer', ['ding.model.template.language_transformer']
+
+ def _init_learn(self) -> None:
+ r"""
+ Overview:
+ Learn mode init method. Called by ``self.__init__``.
+ Init the optimizer, algorithm config, main and target models.
+ """
+ # Optimizer
+ self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
+
+ self._entropy_weight = self._cfg.learn.entropy_weight
+ self._grad_norm = self._cfg.learn.grad_norm
+ self._learn_model = self._model # for compatibility
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Forward and backward function of learn mode.
+ Arguments:
+ - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward']
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss.
+ """
+ self._model.train()
+
+ return_infos = []
+ for i in range(0, len(data), self._cfg.learn.batch_size):
+ batch = default_collate(data[i:i + self._cfg.learn.batch_size])
+ if self._cuda:
+ batch = to_device(batch, self._device)
+
+ # Prepare train_sample (the question to be answered) and the candidate_samples (the prompts to be selected)
+ train_samples, cand_samples = batch["obs"]["train_sample"], batch["obs"]["candidate_samples"]
+ for ii in range(len(cand_samples)):
+ cand_samples[ii] = cand_samples[ii][0]
+ output = self._learn_model.forward(train_samples, cand_samples)
+ return_ = batch['return']
+
+ # calculate PG loss
+ real_act = batch['action'] # shape: (B, shot_number)
+ # Calculate loss.
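+            # REINFORCE-style objective: minimize -E[log pi(a_t|s_t) * G_t] for each of the ``shot_number``
+            # selected prompts, plus an entropy regularization term to encourage exploration.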
+ total_policy_loss, total_entropy_loss = 0, 0
+ for ii in range(self._cfg.shot_number):
+ log_prob = output['dist'].log_prob(real_act[:, ii])
+ policy_loss = -(log_prob * return_).mean()
+ total_policy_loss += policy_loss
+ total_entropy_loss += -self._cfg.learn.entropy_weight * output['dist'].entropy().mean()
+ total_loss = total_entropy_loss + total_policy_loss
+
+ # update
+ self._optimizer.zero_grad()
+ total_loss.backward()
+
+ grad_norm = torch.nn.utils.clip_grad_norm_(
+ list(self._learn_model.parameters()),
+ max_norm=self._grad_norm,
+ )
+ self._optimizer.step()
+
+            # record this update's information for the logger
+ return_info = {
+ 'cur_lr': self._optimizer.param_groups[0]['lr'],
+ 'total_loss': total_loss.item(),
+ 'policy_loss': total_policy_loss.item(),
+ 'entropy_loss': total_entropy_loss.item(),
+ 'return_abs_max': return_.abs().max().item(),
+ 'grad_norm': grad_norm,
+ }
+ return_infos.append(return_info)
+ return return_infos
+
+ def _init_collect(self) -> None:
+ self._unroll_len = self._cfg.collect.unroll_len
+ self._gamma = self._cfg.collect.discount_factor
+ self._collect_model = model_wrap(self._model, wrapper_name='combination_multinomial_sample')
+
+ def _forward_collect(self, data: dict) -> dict:
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ self._model.eval()
+ with torch.no_grad():
+ # Prepare train_sample (the question to be answered) and the candidate_samples (the prompts to be selected)
+ for ii in range(len(data['candidate_samples'])):
+ data['candidate_samples'][ii] = data['candidate_samples'][ii][0]
+ output = self._collect_model.forward(self._cfg.shot_number, data['train_sample'], data['candidate_samples'])
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
+ r"""
+ Overview:
+ Generate dict type transition data from inputs.
+ Arguments:
+ - obs (:obj:`Any`): Env observation
+ - model_output (:obj:`dict`): Output of collect model, including at least ['action']
+ - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
+ (here 'obs' indicates obs after env step).
+ Returns:
+ - transition (:obj:`dict`): Dict type transition data.
+ """
+ return {
+ 'obs': obs,
+ 'action': model_output['action'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+
+ def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
+ r"""
+ Overview:
+            For a given trajectory, compute the discounted return of each transition, then generate training samples
+ Arguments:
+ - data (:obj:`list`): The trajectory's buffer list
+ Returns:
+ - samples (:obj:`dict`): The training samples generated
+ """
+ if self._cfg.learn.ignore_done:
+ raise NotImplementedError
+
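+        # Compute the discounted return-to-go by iterating over the trajectory backwards:
+        # G_t = r_t + gamma * G_{t+1}, with G = 0 after the final step.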
+ R = 0.
+ for i in reversed(range(len(data))):
+ R = self._gamma * R + data[i]['reward']
+ data[i]['return'] = R
+ return get_train_sample(data, self._unroll_len)
+
+ def _init_eval(self) -> None:
+ self._eval_model = model_wrap(self._model, wrapper_name='combination_argmax_sample')
+
+ def _forward_eval(self, data: dict) -> dict:
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ self._model.eval()
+ with torch.no_grad():
+ # Prepare train_sample (the question to be answered) and the candidate_samples (the prompts to be selected)
+ for ii in range(len(data['candidate_samples'])):
+ data['candidate_samples'][ii] = data['candidate_samples'][ii][0]
+ output = self._eval_model.forward(self._cfg.shot_number, data['train_sample'], data['candidate_samples'])
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _monitor_vars_learn(self) -> List[str]:
+ return super()._monitor_vars_learn() + ['policy_loss', 'entropy_loss', 'return_abs_max', 'grad_norm']
diff --git a/DI-engine/ding/policy/qmix.py b/DI-engine/ding/policy/qmix.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff1d66f7c8231001ed2d540aa0d68f0d8f613714
--- /dev/null
+++ b/DI-engine/ding/policy/qmix.py
@@ -0,0 +1,516 @@
+from typing import List, Dict, Any, Tuple, Optional
+from collections import namedtuple
+import copy
+import torch
+
+from ding.torch_utils import RMSprop, to_device
+from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import timestep_collate, default_collate, default_decollate
+from .base_policy import Policy
+
+
+@POLICY_REGISTRY.register('qmix')
+class QMIXPolicy(Policy):
+ """
+ Overview:
+        Policy class of QMIX algorithm. QMIX is a multi-agent reinforcement learning algorithm; \
+        you can find the paper at the following link: https://arxiv.org/abs/1803.11485.
+ Config:
+ == ==================== ======== ============== ======================================== =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============== ======================================== =======================
+ 1 ``type`` str qmix | RL policy register name, refer to | this arg is optional,
+ | registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool True | Whether to use cuda for network | this arg can be diff-
+ | erent from modes
+ 3 ``on_policy`` bool False | Whether the RL algorithm is on-policy
+ | or off-policy
+ 4. ``priority`` bool False | Whether use priority(PER) | priority sample,
+ | update priority
+ 5 | ``priority_`` bool False | Whether use Importance Sampling | IS weight
+ | ``IS_weight`` | Weight to correct biased update.
+ 6 | ``learn.update_`` int 20 | How many updates(iterations) to train | this args can be vary
+ | ``per_collect`` | after collector's one collection. Only | from envs. Bigger val
+ | valid in serial training | means more off-policy
+ 7 | ``learn.target_`` float 0.001 | Target network update momentum | between[0,1]
+ | ``update_theta`` | parameter.
+ 8 | ``learn.discount`` float 0.99 | Reward's future discount factor, aka. | may be 1 when sparse
+ | ``_factor`` | gamma | reward env
+ == ==================== ======== ============== ======================================== =======================
+ """
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='qmix',
+ # (bool) Whether to use cuda for network.
+ cuda=True,
+ # (bool) Whether the RL algorithm is on-policy or off-policy.
+ on_policy=False,
+        # (bool) Whether to use priority (priority sample, IS weight, update priority)
+        priority=False,
+        # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ # learn_mode config
+ learn=dict(
+ # (int) How many updates(iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=20,
+ # (int) How many samples in a training batch.
+ batch_size=32,
+ # (float) The step size of gradient descent.
+ learning_rate=0.0005,
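+            # (float) Gradient clip threshold for the Q-network update (presumably used as the max norm for
+            # gradient clipping in the learn step).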
+ clip_value=100,
+ # (float) Target network update momentum parameter, in [0, 1].
+ target_update_theta=0.008,
+ # (float) The discount factor for future rewards, in [0, 1].
+ discount_factor=0.99,
+            # (bool) Whether to use the double DQN mechanism (target q for suppressing over-estimation).
+ double_q=False,
+ ),
+ # collect_mode config
+ collect=dict(
+ # (int) How many training samples collected in one collection procedure.
+ # In each collect phase, we collect a total of sequence samples, a sample with length unroll_len.
+ # n_sample=32,
+            # (int) Split trajectories into pieces of length ``unroll_len``, i.e. the number of timesteps
+            # in each forward pass during training. In QMIX, it is greater than 1 because the model contains an RNN.
+ unroll_len=10,
+ ),
+ eval=dict(), # for compatibility
+ other=dict(
+ eps=dict(
+ # (str) Type of epsilon decay.
+ type='exp',
+ # (float) Start value for epsilon decay, in [0, 1].
+ start=1,
+                # (float) End value for epsilon decay, in [0, 1].
+ end=0.05,
+ # (int) Decay length(env step).
+ decay=50000,
+ ),
+ replay_buffer=dict(
+ # (int) Maximum size of replay buffer. Usually, larger buffer size is better.
+ replay_buffer_size=5000,
+ ),
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+ Return this algorithm default model setting for demonstration.
+ Returns:
+ - model_info (:obj:`Tuple[str, List[str]]`): model name and mode import_names
+
+ .. note::
+            The user can define and use a customized network model, but must obey the same interface definition \
+            indicated by the import_names path. For QMIX, it is ``ding.model.template.qmix``.
+ """
+ return 'qmix', ['ding.model.template.qmix']
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+            Initialize the learn mode of policy, including some attributes and modules. For QMIX, it mainly \
+            contains the optimizer, algorithm-specific arguments such as gamma, and the main and target models. \
+            Because of the use of RNN, all the models should be wrapped with ``hidden_state``, which needs to be \
+            initialized with the proper size.
+ This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+ .. tip::
+ For multi-agent algorithm, we often need to use ``agent_num`` to initialize some necessary variables.
+
+ .. note::
+ For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+ and ``_load_state_dict_learn`` methods.
+
+ .. note::
+ For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+ .. note::
+            If you want to set some special member variables in the ``_init_learn`` method, you'd better name them \
+            with the prefix ``_learn_`` to avoid conflicts with other modes, such as ``self._learn_attr1``.
+            - agent_num (:obj:`int`): Since this is a multi-agent algorithm, we need to specify the number of agents.
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ assert not self._priority and not self._priority_IS_weight, "Priority is not implemented in QMIX"
+ self._optimizer = RMSprop(
+ params=self._model.parameters(),
+ lr=self._cfg.learn.learning_rate,
+ alpha=0.99,
+ eps=0.00001,
+ weight_decay=1e-5
+ )
+ self._gamma = self._cfg.learn.discount_factor
+
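+        # The target network is wrapped twice: the ``target`` wrapper keeps a slowly-updated copy of the
+        # main model (momentum update with ``target_update_theta``), and the ``hidden_state`` wrapper
+        # manages the per-sample RNN hidden states (one slot for each element of the training batch).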
+ self._target_model = copy.deepcopy(self._model)
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='momentum',
+ update_kwargs={'theta': self._cfg.learn.target_update_theta}
+ )
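+        # With the momentum updater, each call to ``self._target_model.update`` roughly performs
+        # target_param = (1 - theta) * target_param + theta * online_param.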
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='hidden_state',
+ state_num=self._cfg.learn.batch_size,
+ init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
+ )
+ self._learn_model = model_wrap(
+ self._model,
+ wrapper_name='hidden_state',
+ state_num=self._cfg.learn.batch_size,
+ init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
+ )
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ def _data_preprocess_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
+ """
+ Overview:
+ Preprocess the data to fit the required data format for learning
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]]`): the data collected from collect function
+ Returns:
+ - data (:obj:`Dict[str, Any]`): the processed data, from \
+ [len=B, ele={dict_key: [len=T, ele=Tensor(any_dims)]}] -> {dict_key: Tensor([T, B, any_dims])}
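+
+        .. note::
+            A shape-only illustration (sizes are assumptions, not fixed by this policy): with a batch of ``B=32`` \
+            trajectories of ``T=10`` timesteps, a per-timestep scalar field such as ``done`` becomes a tensor of \
+            shape ``(10, 32)``, while nested ``obs`` fields keep their trailing dims, \
+            e.g. ``(10, 32, agent_num, obs_dim)`` for an ``agent_state`` entry.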
+ """
+ # data preprocess
+ data = timestep_collate(data)
+ if self._cuda:
+ data = to_device(data, self._device)
+ data['weight'] = data.get('weight', None)
+ data['done'] = data['done'].float()
+ return data
+
+ def _forward_learn(self, data: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
+ """
+ Overview:
+ Policy forward function of learn mode (training policy and updating parameters). Forward means \
+ that the policy inputs some training batch data (trajectory for QMIX) from the replay buffer and then \
+ returns the output result, including various training information such as loss, q value, grad_norm.
+ Arguments:
+ - data (:obj:`List[List[Dict[int, Any]]]`): The input data used for policy forward, including a batch of \
+ training samples. For each dict element, the key of the dict is the name of data items and the \
+                value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
+                combinations. In the ``_forward_learn`` method, the data often needs to first be stacked along the \
+                time and batch dimensions by the utility function ``self._data_preprocess_learn``. \
+ For QMIX, each element in list is a trajectory with the length of ``unroll_len``, and the element in \
+ trajectory list is a dict containing at least the following keys: ``obs``, ``action``, ``prev_state``, \
+ ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight`` \
+ and ``value_gamma``.
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \
+ recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
+ detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
+            For the data types that are not supported, the main reason is that the corresponding model does not \
+            support them. You can implement your own model rather than use the default model. For more \
+            information, please raise an \
+ issue in GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for QMIXPolicy: ``ding.policy.tests.test_qmix``.
+ """
+ data = self._data_preprocess_learn(data)
+ # ====================
+ # Q-mix forward
+ # ====================
+ self._learn_model.train()
+ self._target_model.train()
+ # for hidden_state plugin, we need to reset the main model and target model
+ self._learn_model.reset(state=data['prev_state'][0])
+ self._target_model.reset(state=data['prev_state'][0])
+ inputs = {'obs': data['obs'], 'action': data['action']}
+ total_q = self._learn_model.forward(inputs, single_step=False)['total_q']
+
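+        # Target computation below: with double_q, the online (learn) model selects the greedy next action and
+        # the target model evaluates it, which helps mitigate Q-value overestimation; otherwise the target model
+        # both selects and evaluates the next action.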
+ if self._cfg.learn.double_q:
+ next_inputs = {'obs': data['next_obs']}
+ self._learn_model.reset(state=data['prev_state'][1])
+ logit_detach = self._learn_model.forward(next_inputs, single_step=False)['logit'].clone().detach()
+ next_inputs = {'obs': data['next_obs'], 'action': logit_detach.argmax(dim=-1)}
+ else:
+ next_inputs = {'obs': data['next_obs']}
+ with torch.no_grad():
+ target_total_q = self._target_model.forward(next_inputs, single_step=False)['total_q']
+
+ with torch.no_grad():
+ if data['done'] is not None:
+ target_v = self._gamma * (1 - data['done']) * target_total_q + data['reward']
+ else:
+ target_v = self._gamma * target_total_q + data['reward']
+
+ data = v_1step_td_data(total_q, target_total_q, data['reward'], data['done'], data['weight'])
+ loss, td_error_per_sample = v_1step_td_error(data, self._gamma)
+ # ====================
+ # Q-mix update
+ # ====================
+ self._optimizer.zero_grad()
+ loss.backward()
+ grad_norm = torch.nn.utils.clip_grad_norm_(self._model.parameters(), self._cfg.learn.clip_value)
+ self._optimizer.step()
+ # =============
+ # after update
+ # =============
+ self._target_model.update(self._learn_model.state_dict())
+ return {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': loss.item(),
+ 'total_q': total_q.mean().item() / self._cfg.model.agent_num,
+ 'target_reward_total_q': target_v.mean().item() / self._cfg.model.agent_num,
+ 'target_total_q': target_total_q.mean().item() / self._cfg.model.agent_num,
+ 'grad_norm': grad_norm,
+ }
+
+ def _reset_learn(self, data_id: Optional[List[int]] = None) -> None:
+ """
+ Overview:
+ Reset some stateful variables for learn mode when necessary, such as the hidden state of RNN or the \
+            memory bank of some special algorithms. If ``data_id`` is None, it means to reset all the stateful \
+            variables. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \
+ different trajectories in ``data_id`` will have different hidden state in RNN.
+ Arguments:
+ - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \
+ (i.e. RNN hidden_state in QMIX) specified by ``data_id``.
+ """
+ self._learn_model.reset(data_id=data_id)
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ """
+ Overview:
+ Return the state_dict of learn mode, usually including model, target_model and optimizer.
+ Returns:
+ - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring.
+ """
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'target_model': self._target_model.state_dict(),
+ 'optimizer': self._optimizer.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ """
+ Overview:
+ Load the state_dict variable into policy learn mode.
+ Arguments:
+ - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before.
+
+ .. tip::
+ If you want to only load some parts of model, you can simply set the ``strict`` argument in \
+ load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
+ complicated operation.
+ """
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._target_model.load_state_dict(state_dict['target_model'])
+ self._optimizer.load_state_dict(state_dict['optimizer'])
+
+ def _init_collect(self) -> None:
+ """
+ Overview:
+ Initialize the collect mode of policy, including related attributes and modules. For QMIX, it contains the \
+ collect_model to balance the exploration and exploitation with epsilon-greedy sample mechanism and \
+            maintain the hidden state of the RNN. Besides, there are some initialization operations for other \
+            algorithm-specific arguments such as unroll_len.
+ This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.
+
+ .. note::
+            If you want to set some special member variables in ``_init_collect`` method, you'd better name them \
+ with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``.
+ """
+ self._unroll_len = self._cfg.collect.unroll_len
+ self._collect_model = model_wrap(
+ self._model,
+ wrapper_name='hidden_state',
+ state_num=self._cfg.collect.env_num,
+ save_prev_state=True,
+ init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
+ )
+ self._collect_model = model_wrap(self._collect_model, wrapper_name='eps_greedy_sample')
+ self._collect_model.reset()
+
+ def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
+ """
+ Overview:
+ Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \
+ that the policy gets some necessary data (mainly observation) from the envs and then returns the output \
+ data, such as the action to interact with the envs. Besides, this policy also needs ``eps`` argument for \
+ exploration, i.e., classic epsilon-greedy exploration strategy.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+ - eps (:obj:`float`): The epsilon value for exploration.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
+ other necessary data (prev_state) for learn mode defined in ``self._process_transition`` method. The \
+ key of the dict is the same as the input data, i.e. environment id.
+
+ .. note::
+            RNN's hidden states are maintained in the policy, so we don't need to pass them in the data, but we do \
+            need to reset the hidden states with the ``_reset_collect`` method when an episode ends. Besides, the \
+            previous hidden states are \
+ necessary for training, so we need to return them in ``_process_transition`` method.
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
+            For the data types that are not supported, the main reason is that the corresponding model does not \
+            support them. You can implement your own model rather than use the default model. For more \
+            information, please raise an \
+ issue in GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for QMIXPolicy: ``ding.policy.tests.test_qmix``.
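+
+        .. note::
+            A minimal invocation sketch; ``obs0`` and ``obs1`` are illustrative placeholders for per-env \
+            multi-agent observations (their exact contents depend on the environment and model config):
+
+            >>> output = policy.collect_mode.forward({0: obs0, 1: obs1}, eps=0.1)
+            >>> assert set(output.keys()) == {0, 1}
+            >>> assert 'action' in output[0] and 'prev_state' in output[0]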
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ data = {'obs': data}
+ self._collect_model.eval()
+ with torch.no_grad():
+ output = self._collect_model.forward(data, eps=eps, data_id=data_id)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _reset_collect(self, data_id: Optional[List[int]] = None) -> None:
+ """
+ Overview:
+            Reset some stateful variables for collect mode when necessary, such as the hidden state of RNN or the \
+            memory bank of some special algorithms. If ``data_id`` is None, it means to reset all the stateful \
+            variables. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \
+            different environments/episodes during collection in ``data_id`` will have different hidden states in RNN.
+ Arguments:
+ - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \
+ (i.e., RNN hidden_state in QMIX) specified by ``data_id``.
+ """
+ self._collect_model.reset(data_id=data_id)
+
+ def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor],
+ timestep: namedtuple) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ Process and pack one timestep transition data into a dict, which can be directly used for training and \
+ saved in replay buffer. For QMIX, it contains obs, next_obs, action, prev_state, reward, done.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): The env observation of current timestep, usually including ``agent_obs`` \
+ and ``global_obs`` in multi-agent environment like MPE and SMAC.
+ - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \
+ as input. For QMIX, it contains the action and the prev_state of RNN.
+ - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \
+ except all the elements have been transformed into tensor data. Usually, it contains the next obs, \
+ reward, done, info, etc.
+ Returns:
+ - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep.
+ """
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'prev_state': policy_output['prev_state'],
+ 'action': policy_output['action'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+ Overview:
+            For a given trajectory (transitions, a list of transition) data, process it into a list of samples that \
+            can be used for training directly. In QMIX, a train sample is a processed sequence of transitions with \
+            length ``unroll_len``. This method is usually used in collectors to execute necessary \
+            RL data preprocessing before training, which can help the learner amortize the relevant time consumption. \
+ In addition, you can also implement this method as an identity function and do the data processing \
+ in ``self._forward_learn`` method.
+ Arguments:
+            - transitions (:obj:`List[Dict[str, Any]]`): The trajectory data (a list of transitions), each element is \
+ the same format as the return value of ``self._process_transition`` method.
+ Returns:
+ - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each sample is a fixed-length \
+                trajectory, and each element in a sample is in the same format as the input transitions.
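+
+        .. note::
+            For example (numbers are illustrative): with ``unroll_len=10``, a collected trajectory of 30 \
+            transitions is split into 3 training samples of 10 consecutive timesteps each. A trailing remainder \
+            shorter than ``unroll_len`` is handled inside ``get_train_sample`` (e.g. by reusing the last \
+            timesteps of the trajectory so that every sample has the full length).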
+ """
+ return get_train_sample(transitions, self._unroll_len)
+
+ def _init_eval(self) -> None:
+ """
+ Overview:
+ Initialize the eval mode of policy, including related attributes and modules. For QMIX, it contains the \
+            eval model to greedily select actions with the argmax q_value mechanism and maintain the hidden state.
+ This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
+
+ .. note::
+            If you want to set some special member variables in ``_init_eval`` method, you'd better name them \
+ with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
+ """
+ self._eval_model = model_wrap(
+ self._model,
+ wrapper_name='hidden_state',
+ state_num=self._cfg.eval.env_num,
+ save_prev_state=True,
+ init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
+ )
+ self._eval_model = model_wrap(self._eval_model, wrapper_name='argmax_sample')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: dict) -> dict:
+ """
+ Overview:
+ Policy forward function of eval mode (evaluation policy performance by interacting with envs). Forward \
+ means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
+            action to interact with the envs. ``_forward_eval`` often uses the argmax sampling method to select the \
+            action with the highest q_value.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
+ key of the dict is the same as the input data, i.e. environment id.
+
+ .. note::
+            RNN's hidden states are maintained in the policy, so we don't need to pass them in the data, but we do \
+            need to reset the hidden states with the ``_reset_eval`` method when an episode ends.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
+            For the data types that are not supported, the main reason is that the corresponding model does not \
+            support them. You can implement your own model rather than use the default model. For more \
+            information, please raise an \
+ issue in GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for QMIXPolicy: ``ding.policy.tests.test_qmix``.
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ data = {'obs': data}
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data, data_id=data_id)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
+ """
+ Overview:
+ Reset some stateful variables for eval mode when necessary, such as the hidden state of RNN or the \
+            memory bank of some special algorithms. If ``data_id`` is None, it means to reset all the stateful \
+            variables. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \
+ different environments/episodes in evaluation in ``data_id`` will have different hidden state in RNN.
+ Arguments:
+ - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \
+ (i.e., RNN hidden_state in QMIX) specified by ``data_id``.
+ """
+ self._eval_model.reset(data_id=data_id)
+
+ def _monitor_vars_learn(self) -> List[str]:
+ """
+ Overview:
+ Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
+ as text logger, tensorboard logger, will use these keys to save the corresponding data.
+ Returns:
+ - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
+ """
+ return ['cur_lr', 'total_loss', 'total_q', 'target_total_q', 'grad_norm', 'target_reward_total_q']
diff --git a/DI-engine/ding/policy/qrdqn.py b/DI-engine/ding/policy/qrdqn.py
new file mode 100644
index 0000000000000000000000000000000000000000..d2ed004464ee026ce2d237fbf32f2f785c86f3ae
--- /dev/null
+++ b/DI-engine/ding/policy/qrdqn.py
@@ -0,0 +1,239 @@
+from typing import List, Dict, Any, Tuple, Union
+import copy
+import torch
+
+from ding.torch_utils import Adam, to_device
+from ding.rl_utils import qrdqn_nstep_td_data, qrdqn_nstep_td_error, get_train_sample, get_nstep_return_data
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_collate, default_decollate
+from .dqn import DQNPolicy
+from .common_utils import default_preprocess_learn
+
+
+@POLICY_REGISTRY.register('qrdqn')
+class QRDQNPolicy(DQNPolicy):
+ r"""
+ Overview:
+        Policy class of the QRDQN algorithm. QRDQN (https://arxiv.org/pdf/1710.10044.pdf) is a distributional RL \
+        algorithm that extends DQN. The main idea of QRDQN is to use quantile regression to estimate a set of \
+        quantiles of the return distribution, and to train them with a quantile regression (Huber) loss.
+
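+    .. note::
+        The following is a minimal, self-contained sketch of the quantile regression (Huber) loss that this kind \
+        of method optimizes. It is for illustration only; shapes and the final reduction are assumptions and not \
+        the exact form used by ``qrdqn_nstep_td_error``:
+
+        >>> import torch
+        >>> B, N = 4, 32                                 # batch size and quantile number (illustrative)
+        >>> q = torch.randn(B, N)                        # predicted quantiles of Q(s, a)
+        >>> target = torch.randn(B, N)                   # (detached) target quantiles
+        >>> tau = (torch.arange(N, dtype=torch.float32) + 0.5) / N
+        >>> u = target.unsqueeze(1) - q.unsqueeze(-1)    # pairwise TD errors, shape (B, N, N)
+        >>> huber = torch.where(u.abs() <= 1, 0.5 * u ** 2, u.abs() - 0.5)
+        >>> loss = (torch.abs(tau.view(1, N, 1) - (u.detach() < 0).float()) * huber).mean()
+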
+ Config:
+ == ==================== ======== ============== ======================================== =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============== ======================================== =======================
+ 1 ``type`` str qrdqn | RL policy register name, refer to | this arg is optional,
+ | registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool False | Whether to use cuda for network | this arg can be diff-
+ | erent from modes
+ 3 ``on_policy`` bool False | Whether the RL algorithm is on-policy
+ | or off-policy
+ 4 ``priority`` bool True | Whether use priority(PER) | priority sample,
+ | update priority
+ 6 | ``other.eps`` float 0.05 | Start value for epsilon decay. It's
+ | ``.start`` | small because rainbow use noisy net.
+ 7 | ``other.eps`` float 0.05 | End value for epsilon decay.
+ | ``.end``
+ 8 | ``discount_`` float 0.97, | Reward's future discount factor, aka. | may be 1 when sparse
+ | ``factor`` [0.95, 0.999] | gamma | reward env
+ 9 ``nstep`` int 3, | N-step reward discount sum for target
+ [3, 5] | q_value estimation
+ 10 | ``learn.update`` int 3 | How many updates(iterations) to train | this args can be vary
+ | ``per_collect`` | after collector's one collection. Only | from envs. Bigger val
+ | valid in serial training | means more off-policy
+ 11 ``learn.kappa`` float / | Threshold of Huber loss
+ == ==================== ======== ============== ======================================== =======================
+ """
+
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='qrdqn',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ # (bool) Whether the RL algorithm is on-policy or off-policy.
+ on_policy=False,
+ # (bool) Whether use priority(priority sample, IS weight, update priority)
+ priority=False,
+ # (float) Reward's future discount factor, aka. gamma.
+ discount_factor=0.97,
+ # (int) N-step reward for target q_value estimation
+ nstep=1,
+ learn=dict(
+            # How many updates (iterations) to train after one collection by the collector.
+            # A bigger "update_per_collect" means a more off-policy style of training.
+            # collect data -> update policy -> collect data -> ...
+ update_per_collect=3,
+ batch_size=64,
+ learning_rate=0.001,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+            # (int) Frequency of target network update.
+ target_update_freq=100,
+            # (bool) Whether to ignore done (usually for max-step-termination envs).
+ ignore_done=False,
+ ),
+ # collect_mode config
+ collect=dict(
+            # (int) Only one of [n_sample, n_step, n_episode] should be set
+ # n_sample=8,
+ # (int) Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ ),
+ eval=dict(),
+ # other config
+ other=dict(
+ # Epsilon greedy with decay.
+ eps=dict(
+ # (str) Decay type. Support ['exp', 'linear'].
+ type='exp',
+ start=0.95,
+ end=0.1,
+                # (int) Decay length (in env steps).
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=10000, )
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+            Return this algorithm's default neural network model setting for demonstration. The ``__init__`` method \
+            will automatically call this method to get the default model setting and create the model.
+
+ Returns:
+ - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
+ """
+ return 'qrdqn', ['ding.model.template.q_learning']
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Initialize the learn mode of policy, including related attributes and modules. For QRDQN, it mainly \
+            contains the optimizer, algorithm-specific arguments such as nstep and gamma, and the main and \
+            target model.
+ This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+ .. note::
+ For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+ and ``_load_state_dict_learn`` methods.
+
+ .. note::
+            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+ with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
+ """
+ self._priority = self._cfg.priority
+ # Optimizer
+ self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
+
+ self._gamma = self._cfg.discount_factor
+ self._nstep = self._cfg.nstep
+
+ # use model_wrapper for specialized demands of different modes
+ self._target_model = copy.deepcopy(self._model)
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='assign',
+ update_kwargs={'freq': self._cfg.learn.target_update_freq}
+ )
+ self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ """
+ Overview:
+ Policy forward function of learn mode (training policy and updating parameters). Forward means \
+ that the policy inputs some training batch data from the replay buffer and then returns the output \
+ result, including various training information such as loss, current lr.
+
+ Arguments:
+ - data (:obj:`dict`): Input data used for policy forward, including the \
+ collected training samples from replay buffer. For each element in dict, the key of the \
+ dict is the name of data items and the value is the corresponding data. Usually, the value is \
+                torch.Tensor or np.ndarray or their dict/list combinations. In the ``_forward_learn`` method, the \
+                data often needs to first be stacked in the batch dimension by some utility functions such as \
+ ``default_preprocess_learn``. \
+ For QRDQN, each element in list is a dict containing at least the following keys: ``obs``, \
+ ``action``, ``reward``, ``next_obs``. Sometimes, it also contains other keys such as ``weight``.
+
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): The output result dict of forward learn, \
+                containing current lr, total_loss and priority. When the discrete action satisfies \
+                len(data['action'])==1, it could also contain ``action_distribution``, which is used \
+ to draw histogram on tensorboard. For more information, please refer to the :class:`DQNPolicy`.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
+            For the data types that are not supported, the main reason is that the corresponding model does not \
+            support them. You can implement your own model rather than use the default model. For more \
+            information, please raise an \
+ issue in GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for QRDQNPolicy: ``ding.policy.tests.test_qrdqn``.
+ """
+
+ data = default_preprocess_learn(
+ data, use_priority=self._priority, ignore_done=self._cfg.learn.ignore_done, use_nstep=True
+ )
+ if self._cuda:
+ data = to_device(data, self._device)
+ # ====================
+ # Q-learning forward
+ # ====================
+ self._learn_model.train()
+ self._target_model.train()
+ # Current q value (main model)
+ ret = self._learn_model.forward(data['obs'])
+ q_value, tau = ret['q'], ret['tau']
+ # Target q value
+ with torch.no_grad():
+ target_q_value = self._target_model.forward(data['next_obs'])['q']
+ # Max q value action (main model)
+ target_q_action = self._learn_model.forward(data['next_obs'])['action']
+
+ data_n = qrdqn_nstep_td_data(
+ q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], tau, data['weight']
+ )
+ value_gamma = data.get('value_gamma')
+ loss, td_error_per_sample = qrdqn_nstep_td_error(
+ data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma
+ )
+
+ # ====================
+ # Q-learning update
+ # ====================
+ self._optimizer.zero_grad()
+ loss.backward()
+ if self._cfg.multi_gpu:
+ self.sync_gradients(self._learn_model)
+ self._optimizer.step()
+
+ # =============
+ # after update
+ # =============
+ self._target_model.update(self._learn_model.state_dict())
+ return {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': loss.item(),
+ 'priority': td_error_per_sample.abs().tolist(),
+ # Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard.
+ # '[histogram]action_distribution': data['action'],
+ }
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'target_model': self._target_model.state_dict(),
+ 'optimizer': self._optimizer.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._target_model.load_state_dict(state_dict['target_model'])
+ self._optimizer.load_state_dict(state_dict['optimizer'])
diff --git a/DI-engine/ding/policy/qtran.py b/DI-engine/ding/policy/qtran.py
new file mode 100644
index 0000000000000000000000000000000000000000..c75b942eb8bc6f912e873a4e8814319b2c3f2778
--- /dev/null
+++ b/DI-engine/ding/policy/qtran.py
@@ -0,0 +1,457 @@
+from typing import List, Dict, Any, Tuple, Union, Optional
+from collections import namedtuple
+import torch
+import torch.nn.functional as F
+import copy
+from easydict import EasyDict
+
+from ding.torch_utils import Adam, RMSprop, to_device
+from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_epsilon_greedy_fn, get_train_sample
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import timestep_collate, default_collate, default_decollate
+from .base_policy import Policy
+
+
+@POLICY_REGISTRY.register('qtran')
+class QTRANPolicy(Policy):
+ """
+ Overview:
+        Policy class of the QTRAN algorithm. QTRAN is a multi-agent reinforcement learning algorithm; \
+        you can view the paper at the following link: https://arxiv.org/abs/1803.11485
+ Config:
+ == ==================== ======== ============== ======================================== =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============== ======================================== =======================
+ 1 ``type`` str qtran | RL policy register name, refer to | this arg is optional,
+ | registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool True | Whether to use cuda for network | this arg can be diff-
+ | erent from modes
+ 3 ``on_policy`` bool False | Whether the RL algorithm is on-policy
+ | or off-policy
+ 4. ``priority`` bool False | Whether use priority(PER) | priority sample,
+ | update priority
+ 5 | ``priority_`` bool False | Whether use Importance Sampling | IS weight
+ | ``IS_weight`` | Weight to correct biased update.
+ 6 | ``learn.update_`` int 20 | How many updates(iterations) to train | this args can be vary
+ | ``per_collect`` | after collector's one collection. Only | from envs. Bigger val
+ | valid in serial training | means more off-policy
+ 7 | ``learn.target_`` float 0.001 | Target network update momentum | between[0,1]
+ | ``update_theta`` | parameter.
+ 8 | ``learn.discount`` float 0.99 | Reward's future discount factor, aka. | may be 1 when sparse
+ | ``_factor`` | gamma | reward env
+ == ==================== ======== ============== ======================================== =======================
+ """
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='qtran',
+ # (bool) Whether to use cuda for network.
+ cuda=True,
+ # (bool) Whether the RL algorithm is on-policy or off-policy.
+ on_policy=False,
+ # (bool) Whether use priority(priority sample, IS weight, update priority)
+ priority=False,
+ # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ learn=dict(
+ update_per_collect=20,
+ batch_size=32,
+ learning_rate=0.0005,
+ clip_value=1.5,
+ # ==============================================================
+            # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) Target network update momentum parameter.
+ # in [0, 1].
+ target_update_theta=0.008,
+ # (float) The discount factor for future rewards,
+ # in [0, 1].
+ discount_factor=0.99,
+ # (float) the loss weight of TD-error
+ td_weight=1,
+ # (float) the loss weight of Opt Loss
+ opt_weight=0.01,
+ # (float) the loss weight of Nopt Loss
+ nopt_min_weight=0.0001,
+            # (bool) Whether to use the double DQN mechanism to mitigate Q-value overestimation.
+ double_q=True,
+ ),
+ collect=dict(
+            # (int) Only one of [n_sample, n_episode] should be set
+ # n_sample=32 * 16,
+ # (int) Cut trajectories into pieces with length "unroll_len", the length of timesteps
+            # in each forward pass when training. In QTRAN, it is greater than 1 because an RNN is used.
+ unroll_len=10,
+ ),
+ eval=dict(),
+ other=dict(
+ eps=dict(
+ # (str) Type of epsilon decay
+ type='exp',
+ # (float) Start value for epsilon decay, in [0, 1].
+ # 0 means not use epsilon decay.
+ start=1,
+                # (float) End value for epsilon decay, in [0, 1].
+ end=0.05,
+                # (int) Decay length (in env steps).
+ decay=50000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=5000,
+                # (int) The maximum number of times each data sample can be reused.
+ max_reuse=1e+9,
+ max_staleness=1e+9,
+ ),
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+            Return this algorithm's default model setting for demonstration.
+        Returns:
+            - model_info (:obj:`Tuple[str, List[str]]`): model name and model's import_names
+        .. note::
+            The user can define and use a customized network model but must obey the same interface definition \
+            indicated by the import_names path. For QTRAN, it is ``ding.model.template.qtran``.
+ """
+ return 'qtran', ['ding.model.template.qtran']
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Learn mode init method. Called by ``self.__init__``.
+ Init the learner model of QTRANPolicy
+ Arguments:
+ .. note::
+
+            The ``_init_learn`` method takes its arguments from ``self._cfg.learn`` in the config file.
+
+            - learning_rate (:obj:`float`): The learning rate of the optimizer
+ - gamma (:obj:`float`): The discount factor
+ - agent_num (:obj:`int`): This is a multi-agent algorithm, we need to input agent num.
+ - batch_size (:obj:`int`): Need batch size info to init hidden_state plugins
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ assert not self._priority and not self._priority_IS_weight, "Priority is not implemented in QTRAN"
+ self._optimizer = RMSprop(
+ params=self._model.parameters(), lr=self._cfg.learn.learning_rate, alpha=0.99, eps=0.00001
+ )
+ self._gamma = self._cfg.learn.discount_factor
+ self._td_weight = self._cfg.learn.td_weight
+ self._opt_weight = self._cfg.learn.opt_weight
+ self._nopt_min_weight = self._cfg.learn.nopt_min_weight
+
+ self._target_model = copy.deepcopy(self._model)
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='momentum',
+ update_kwargs={'theta': self._cfg.learn.target_update_theta}
+ )
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='hidden_state',
+ state_num=self._cfg.learn.batch_size,
+ init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
+ )
+ self._learn_model = model_wrap(
+ self._model,
+ wrapper_name='hidden_state',
+ state_num=self._cfg.learn.batch_size,
+ init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
+ )
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ def _data_preprocess_learn(self, data: List[Any]) -> dict:
+ r"""
+ Overview:
+ Preprocess the data to fit the required data format for learning
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]]`): the data collected from collect function
+ Returns:
+ - data (:obj:`Dict[str, Any]`): the processed data, from \
+ [len=B, ele={dict_key: [len=T, ele=Tensor(any_dims)]}] -> {dict_key: Tensor([T, B, any_dims])}
+ """
+ # data preprocess
+ data = timestep_collate(data)
+ if self._cuda:
+ data = to_device(data, self._device)
+ data['weight'] = data.get('weight', None)
+ data['done'] = data['done'].float()
+ return data
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Forward and backward function of learn mode.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \
+ np.ndarray or dict/list combinations.
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \
+ recorded in text log and tensorboard, values are python scalar or a list of scalars.
+ ArgumentsKeys:
+ - necessary: ``obs``, ``next_obs``, ``action``, ``reward``, ``weight``, ``prev_state``, ``done``
+ ReturnsKeys:
+ - necessary: ``cur_lr``, ``total_loss``
+ - cur_lr (:obj:`float`): Current learning rate
+ - total_loss (:obj:`float`): The calculated loss
+ """
+ data = self._data_preprocess_learn(data)
+ # ====================
+        # QTRAN forward
+ # ====================
+ self._learn_model.train()
+ self._target_model.train()
+ # for hidden_state plugin, we need to reset the main model and target model
+ self._learn_model.reset(state=data['prev_state'][0])
+ self._target_model.reset(state=data['prev_state'][0])
+ inputs = {'obs': data['obs'], 'action': data['action']}
+ learn_ret = self._learn_model.forward(inputs, single_step=False)
+ total_q = learn_ret['total_q']
+ vs = learn_ret['vs']
+ agent_q_act = learn_ret['agent_q_act']
+ logit_detach = learn_ret['logit'].clone()
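+        # mask out unavailable actions (action_mask == 0) with a large negative value so that the
+        # following max() never selects them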
+ logit_detach[data['obs']['action_mask'] == 0.0] = -9999999
+ logit_q, logit_action = logit_detach.max(dim=-1, keepdim=False)
+
+ if self._cfg.learn.double_q:
+ next_inputs = {'obs': data['next_obs']}
+ double_q_detach = self._learn_model.forward(next_inputs, single_step=False)['logit'].clone().detach()
+ _, double_q_action = double_q_detach.max(dim=-1, keepdim=False)
+ next_inputs = {'obs': data['next_obs'], 'action': double_q_action}
+ else:
+ next_inputs = {'obs': data['next_obs']}
+ with torch.no_grad():
+ target_total_q = self._target_model.forward(next_inputs, single_step=False)['total_q']
+
+ # -- TD Loss --
+ td_data = v_1step_td_data(total_q, target_total_q.detach(), data['reward'], data['done'], data['weight'])
+ td_loss, td_error_per_sample = v_1step_td_error(td_data, self._gamma)
+ # -- TD Loss --
+
+ # -- Opt Loss --
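+        # QTRAN opt constraint: for the greedy joint action, the sum of per-agent greedy Q values (logit_q)
+        # plus the state value vs should match the (detached) joint Q of that action, so the squared error
+        # below is driven towards zero.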
+ if data['weight'] is None:
+ weight = torch.ones_like(data['reward'])
+ opt_inputs = {'obs': data['obs'], 'action': logit_action}
+ max_q = self._learn_model.forward(opt_inputs, single_step=False)['total_q']
+ opt_error = logit_q.sum(dim=2) - max_q.detach() + vs
+ opt_loss = (opt_error ** 2 * weight).mean()
+ # -- Opt Loss --
+
+ # -- Nopt Loss --
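+        # QTRAN nopt constraint: for the executed (possibly non-greedy) joint action, the sum of per-agent
+        # Q values minus the (detached) joint Q plus vs should be non-negative, so only negative violations
+        # (kept by clamp(max=0)) contribute to this loss.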
+ nopt_values = agent_q_act.sum(dim=2) - total_q.detach() + vs
+ nopt_error = nopt_values.clamp(max=0)
+ nopt_min_loss = (nopt_error ** 2 * weight).mean()
+ # -- Nopt Loss --
+
+ total_loss = self._td_weight * td_loss + self._opt_weight * opt_loss + self._nopt_min_weight * nopt_min_loss
+ # ====================
+        # QTRAN update
+ # ====================
+ self._optimizer.zero_grad()
+ total_loss.backward()
+ # just get grad_norm
+ grad_norm = torch.nn.utils.clip_grad_norm_(self._model.parameters(), 10000000)
+ self._optimizer.step()
+ # =============
+ # after update
+ # =============
+ self._target_model.update(self._learn_model.state_dict())
+ return {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': total_loss.item(),
+ 'td_loss': td_loss.item(),
+ 'opt_loss': opt_loss.item(),
+ 'nopt_loss': nopt_min_loss.item(),
+ 'grad_norm': grad_norm,
+ }
+
+ def _reset_learn(self, data_id: Optional[List[int]] = None) -> None:
+ r"""
+ Overview:
+ Reset learn model to the state indicated by data_id
+ Arguments:
+            - data_id (:obj:`Optional[List[int]]`): The ids whose stored model states will be reset; if None, \
+                all the stored states are reset.
+ """
+ self._learn_model.reset(data_id=data_id)
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Return the state_dict of learn mode, usually including model and optimizer.
+ Returns:
+ - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring.
+ """
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'target_model': self._target_model.state_dict(),
+ 'optimizer': self._optimizer.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ r"""
+ Overview:
+ Load the state_dict variable into policy learn mode.
+ Arguments:
+ - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before.
+ .. tip::
+ If you want to only load some parts of model, you can simply set the ``strict`` argument in \
+ load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
+ complicated operation.
+ """
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._target_model.load_state_dict(state_dict['target_model'])
+ self._optimizer.load_state_dict(state_dict['optimizer'])
+
+ def _init_collect(self) -> None:
+ r"""
+ Overview:
+ Collect mode init method. Called by ``self.__init__``.
+            Init the unroll length and the collect model.
+ Enable the eps_greedy_sample and the hidden_state plugin.
+ """
+ self._unroll_len = self._cfg.collect.unroll_len
+ self._collect_model = model_wrap(
+ self._model,
+ wrapper_name='hidden_state',
+ state_num=self._cfg.collect.env_num,
+ save_prev_state=True,
+ init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
+ )
+ self._collect_model = model_wrap(self._collect_model, wrapper_name='eps_greedy_sample')
+ self._collect_model.reset()
+
+ def _forward_collect(self, data: dict, eps: float) -> dict:
+ r"""
+ Overview:
+ Forward function for collect mode with eps_greedy
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ - eps (:obj:`float`): epsilon value for exploration, which is decayed by collected env step.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs.
+ ReturnsKeys
+ - necessary: ``action``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ data = {'obs': data}
+ self._collect_model.eval()
+ with torch.no_grad():
+ output = self._collect_model.forward(data, eps=eps, data_id=data_id)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _reset_collect(self, data_id: Optional[List[int]] = None) -> None:
+ r"""
+ Overview:
+ Reset collect model to the state indicated by data_id
+ Arguments:
+            - data_id (:obj:`Optional[List[int]]`): The ids whose stored model states will be reset; if None, \
+                all the stored states are reset.
+ """
+ self._collect_model.reset(data_id=data_id)
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
+ r"""
+ Overview:
+ Generate dict type transition data from inputs.
+ Arguments:
+ - obs (:obj:`Any`): Env observation
+ - model_output (:obj:`dict`): Output of collect model, including at least ['action', 'prev_state']
+ - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done']\
+ (here 'obs' indicates obs after env step).
+ Returns:
+ - transition (:obj:`dict`): Dict type transition data, including 'obs', 'next_obs', 'prev_state',\
+ 'action', 'reward', 'done'
+ """
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'prev_state': model_output['prev_state'],
+ 'action': model_output['action'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _init_eval(self) -> None:
+ r"""
+ Overview:
+ Evaluate mode init method. Called by ``self.__init__``.
+ Init eval model with argmax strategy and the hidden_state plugin.
+ """
+ self._eval_model = model_wrap(
+ self._model,
+ wrapper_name='hidden_state',
+ state_num=self._cfg.eval.env_num,
+ save_prev_state=True,
+ init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
+ )
+ self._eval_model = model_wrap(self._eval_model, wrapper_name='argmax_sample')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: dict) -> dict:
+ r"""
+ Overview:
+ Forward function of eval mode, similar to ``self._forward_collect``.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
+ ReturnsKeys
+ - necessary: ``action``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ data = {'obs': data}
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data, data_id=data_id)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
+ r"""
+ Overview:
+ Reset eval model to the state indicated by data_id
+ Arguments:
+            - data_id (:obj:`Optional[List[int]]`): The ids whose stored model states will be reset; if None, \
+                all the stored states are reset.
+ """
+ self._eval_model.reset(data_id=data_id)
+
+ def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
+ r"""
+ Overview:
+ Get the train sample from trajectory.
+ Arguments:
+ - data (:obj:`list`): The trajectory's cache
+ Returns:
+            - samples (:obj:`List[Any]`): The training samples generated.
+ """
+ return get_train_sample(data, self._unroll_len)
+
+ def _monitor_vars_learn(self) -> List[str]:
+ r"""
+ Overview:
+            Return the names of the variables to be monitored.
+ Returns:
+ - vars (:obj:`List[str]`): Variables' name list.
+ """
+ return ['cur_lr', 'total_loss', 'td_loss', 'opt_loss', 'nopt_loss', 'grad_norm']
diff --git a/DI-engine/ding/policy/r2d2.py b/DI-engine/ding/policy/r2d2.py
new file mode 100644
index 0000000000000000000000000000000000000000..0726c2c8203d6975ea192ebd6d0b6acba881da72
--- /dev/null
+++ b/DI-engine/ding/policy/r2d2.py
@@ -0,0 +1,651 @@
+import copy
+from collections import namedtuple
+from typing import List, Dict, Any, Tuple, Union, Optional
+
+import torch
+
+from ding.model import model_wrap
+from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, q_nstep_td_error_with_rescale, get_nstep_return_data, \
+ get_train_sample
+from ding.torch_utils import Adam, to_device
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import timestep_collate, default_collate, default_decollate
+from .base_policy import Policy
+
+
+@POLICY_REGISTRY.register('r2d2')
+class R2D2Policy(Policy):
+ """
+ Overview:
+        Policy class of R2D2, from the paper `Recurrent Experience Replay in Distributed Reinforcement Learning`.
+        R2D2 proposes several tricks to improve upon DRQN, namely some recurrent experience replay \
+ tricks and the burn-in mechanism for off-policy training.
+ Config:
+ == ==================== ======== ============== ======================================== =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============== ======================================== =======================
+ 1 ``type`` str r2d2 | RL policy register name, refer to | This arg is optional,
+ | registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool False | Whether to use cuda for network | This arg can be diff-
+ | erent from modes
+ 3 ``on_policy`` bool False | Whether the RL algorithm is on-policy
+ | or off-policy
+ 4 ``priority`` bool False | Whether use priority(PER) | Priority sample,
+ | update priority
+ 5 | ``priority_IS`` bool False | Whether use Importance Sampling Weight
+ | ``_weight`` | to correct biased update. If True,
+ | priority must be True.
+ 6 | ``discount_`` float 0.997, | Reward's future discount factor, aka. | May be 1 when sparse
+ | ``factor`` [0.95, 0.999] | gamma | reward env
+ 7 ``nstep`` int 3, | N-step reward discount sum for target
+ [3, 5] | q_value estimation
+ 8 ``burnin_step`` int 2 | The timestep of burnin operation,
+ | which is designed to RNN hidden state
+ | difference caused by off-policy
+ 9 | ``learn.update`` int 1 | How many updates(iterations) to train | This args can be vary
+ | ``per_collect`` | after collector's one collection. Only | from envs. Bigger val
+ | valid in serial training | means more off-policy
+ 10 | ``learn.batch_`` int 64 | The number of samples of an iteration
+ | ``size``
+ 11 | ``learn.learning`` float 0.001 | Gradient step length of an iteration.
+ | ``_rate``
+ 12 | ``learn.value_`` bool True | Whether use value_rescale function for
+ | ``rescale`` | predicted value
+ 13 | ``learn.target_`` int 100 | Frequence of target network update. | Hard(assign) update
+ | ``update_freq``
+ 14 | ``learn.ignore_`` bool False | Whether ignore done for target value | Enable it for some
+ | ``done`` | calculation. | fake termination env
+ 15 ``collect.n_sample`` int [8, 128] | The number of training samples of a | It varies from
+ | call of collector. | different envs
+ 16 | ``collect.unroll`` int 1 | unroll length of an iteration | In RNN, unroll_len>1
+ | ``_len``
+ == ==================== ======== ============== ======================================== =======================
+ """
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='r2d2',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ # (bool) Whether the RL algorithm is on-policy or off-policy.
+ on_policy=False,
+ # (bool) Whether to use priority(priority sample, IS weight, update priority)
+ priority=True,
+ # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=True,
+ # (float) Reward's future discount factor, aka. gamma.
+ discount_factor=0.997,
+ # (int) N-step reward for target q_value estimation
+ nstep=5,
+        # (int) The number of burn-in timesteps, which are used to warm up the RNN hidden state and thus
+        # reduce the hidden state mismatch caused by off-policy training.
+ burnin_step=20,
+        # (int) The RNN training (unroll) length, excluding the burn-in timesteps, so the full stored
+        # sequence has length burnin_step + learn_unroll_len.
+ learn_unroll_len=80,
+ # learn_mode config
+ learn=dict(
+ # (int) The number of training updates (iterations) to perform after each data collection by the collector.
+ # A larger "update_per_collect" value implies a more off-policy approach.
+ # The whole pipeline process follows this cycle: collect data -> update policy -> collect data -> ...
+ update_per_collect=1,
+ # (int) The number of samples in a training batch.
+ batch_size=64,
+ # (float) The step size of gradient descent, determining the rate of learning.
+ learning_rate=0.0001,
+            # (int) Frequency of target network update (hard update, not used here by default).
+            # target_update_freq=100,
+            # (float) Target network update momentum parameter (soft update), in [0, 1].
+            target_update_theta=0.001,
+            # (bool) Whether to use the value_rescale function for the predicted value.
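+            # A commonly used form (the invertible transform from the R2D2 paper) is
+            # h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x with a small eps, which compresses
+            # large-magnitude value targets before the TD error is computed.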
+ value_rescale=True,
+            # (bool) Whether to ignore done (usually for max-step-termination envs).
+            # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
+            # These limit HalfCheetah, and several other MuJoCo envs, to a max length of 1000.
+            # However, interaction with HalfCheetah never terminates on its own (done is always False),
+            # so when the episode step count exceeds the max episode step we replace done==True with
+            # done==False to keep the TD-error computation accurate (``gamma * (1 - done) * next_v + reward``).
+ ignore_done=False,
+ ),
+ # collect_mode config
+ collect=dict(
+ # (int) How many training samples collected in one collection procedure.
+            # In each collect phase, we collect a total of ``n_sample`` sequence samples.
+ n_sample=32,
+            # (bool) It is important to set traj_len_inf=True here,
+ # to make sure self._traj_len=INF in serial_sample_collector.py.
+ # In R2D2 policy, for each collect_env, we want to collect data of length self._traj_len=INF
+ # unless the episode enters the 'done' state.
+ traj_len_inf=True,
+ # (int) `env_num` is used in hidden state, should equal to that one in env config (e.g. collector_env_num).
+ # User should specify this value in user config. `None` is a placeholder.
+ env_num=None,
+ ),
+ # eval_mode config
+ eval=dict(
+ # (int) `env_num` is used in hidden state, should equal to that one in env config (e.g. evaluator_env_num).
+ # User should specify this value in user config.
+ env_num=None,
+ ),
+ other=dict(
+ # Epsilon greedy with decay.
+ eps=dict(
+ # (str) Type of decay. Supports either 'exp' (exponential) or 'linear'.
+ type='exp',
+ # (float) Initial value of epsilon at the start.
+ start=0.95,
+ # (float) Final value of epsilon after decay.
+ end=0.05,
+ # (int) The number of environment steps over which epsilon should decay.
+ decay=10000,
+ ),
+ replay_buffer=dict(
+ # (int) Maximum size of replay buffer. Usually, larger buffer size is better.
+ replay_buffer_size=10000,
+ ),
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+            Return this algorithm's default neural network model setting for demonstration. The ``__init__`` method \
+            will automatically call this method to get the default model setting and create the model.
+ Returns:
+ - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
+
+ .. note::
+            The user can define and use a customized network model but must obey the same interface definition \
+            indicated by the import_names path. For example, for R2D2, its registered name is ``drqn`` and the \
+            import_names is \
+ ``ding.model.template.q_learning``.
+ """
+ return 'drqn', ['ding.model.template.q_learning']
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Initialize the learn mode of policy, including some attributes and modules. For R2D2, it mainly contains \
+ optimizer, algorithm-specific arguments such as burnin_step, value_rescale and gamma, main and target \
+            model. Because of the use of RNN, all the models should be wrapped with ``hidden_state``, which needs to \
+ be initialized with proper size.
+ This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+ .. note::
+ For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+ and ``_load_state_dict_learn`` methods.
+
+ .. note::
+ For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+ .. note::
+            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+ with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
+ self._gamma = self._cfg.discount_factor
+ self._nstep = self._cfg.nstep
+ self._burnin_step = self._cfg.burnin_step
+ self._value_rescale = self._cfg.learn.value_rescale
+
+ self._target_model = copy.deepcopy(self._model)
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='momentum',
+ update_kwargs={'theta': self._cfg.learn.target_update_theta}
+ )
+
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='hidden_state',
+ state_num=self._cfg.learn.batch_size,
+ )
+ self._learn_model = model_wrap(
+ self._model,
+ wrapper_name='hidden_state',
+ state_num=self._cfg.learn.batch_size,
+ )
+ self._learn_model = model_wrap(self._learn_model, wrapper_name='argmax_sample')
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ def _data_preprocess_learn(self, data: List[Dict[str, Any]]) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ Preprocess the data to fit the required data format for learning
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]]`): The data collected from collect function
+ Returns:
+ - data (:obj:`Dict[str, torch.Tensor]`): The processed data, including at least \
+ ['main_obs', 'target_obs', 'burnin_obs', 'action', 'reward', 'done', 'weight']
+ """
+ # data preprocess
+ data = timestep_collate(data)
+ if self._cuda:
+ data = to_device(data, self._device)
+
+ if self._priority_IS_weight:
+ assert self._priority, "Use IS Weight correction, but Priority is not used."
+ if self._priority and self._priority_IS_weight:
+ data['weight'] = data['IS']
+ else:
+ data['weight'] = data.get('weight', None)
+
+ burnin_step = self._burnin_step
+
+        # data['done'], data['weight'] and data['value_gamma'] are used in _forward_learn() to calculate
+        # the q_nstep_td_error, and should have length [self._sequence_len - self._burnin_step]
+ ignore_done = self._cfg.learn.ignore_done
+ if ignore_done:
+ data['done'] = [None for _ in range(self._sequence_len - burnin_step)]
+ else:
+ data['done'] = data['done'][burnin_step:].float() # for computation of online model self._learn_model
+        # NOTE that after the preprocessing of get_nstep_return_data() in _get_train_sample,
+        # data['done'][t] is already the n-step done
+
+        # If the data doesn't include 'weight' or 'value_gamma', fill in a list of None with
+        # length [self._sequence_len - self._burnin_step]; the two fields are handled below in
+        # two slightly different ways.
+ if 'value_gamma' not in data:
+ data['value_gamma'] = [None for _ in range(self._sequence_len - burnin_step)]
+ else:
+ data['value_gamma'] = data['value_gamma'][burnin_step:]
+
+ if 'weight' not in data or data['weight'] is None:
+ data['weight'] = [None for _ in range(self._sequence_len - burnin_step)]
+ else:
+ data['weight'] = data['weight'] * torch.ones_like(data['done'])
+ # every timestep in sequence has same weight, which is the _priority_IS_weight in PER
+
+ # cut the seq_len from burn_in step to (seq_len - nstep) step
+ data['action'] = data['action'][burnin_step:-self._nstep]
+ # cut the seq_len from burn_in step to (seq_len - nstep) step
+ data['reward'] = data['reward'][burnin_step:-self._nstep]
+
+ # the burnin_nstep_obs is used to calculate the init hidden state of rnn for the calculation of the q_value,
+ # target_q_value, and target_q_action
+
+ # all of this slicing is done along the outermost dim, i.e. the seq_len dim
+ data['burnin_nstep_obs'] = data['obs'][:burnin_step + self._nstep]
+ # the main_obs is used to calculate the q_value; [burnin_step:-self._nstep] means using the data from
+ # the [burnin_step] timestep to the [self._sequence_len - self._nstep] timestep
+ data['main_obs'] = data['obs'][burnin_step:-self._nstep]
+ # the target_obs is used to calculate the target_q_value
+ data['target_obs'] = data['obs'][burnin_step + self._nstep:]
+
+ return data
+
+ def _forward_learn(self, data: List[List[Dict[str, Any]]]) -> Dict[str, Any]:
+ """
+ Overview:
+ Policy forward function of learn mode (training policy and updating parameters). Forward means \
+ that the policy inputs some training batch data (trajectory for R2D2) from the replay buffer and then \
+ returns the output result, including various training information such as loss, q value, priority.
+ Arguments:
+ - data (:obj:`List[List[Dict[str, Any]]]`): The input data used for policy forward, including a batch of \
+ training samples. For each dict element, the key of the dict is the name of data items and the \
+ value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
+ combinations. In the ``_forward_learn`` method, data often needs to first be stacked in the time and \
+ batch dimension by the utility function ``self._data_preprocess_learn``. \
+ For R2D2, each element in list is a trajectory with the length of ``unroll_len``, and the element in \
+ trajectory list is a dict containing at least the following keys: ``obs``, ``action``, ``prev_state``, \
+ ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight`` \
+ and ``value_gamma``.
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \
+ recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
+ detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+ For data types that are not supported, the main reason is that the corresponding model does not support them. \
+ You can implement your own model rather than use the default model. For more information, please raise an \
+ issue in the GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for R2D2Policy: ``ding.policy.tests.test_r2d2``.
+ """
+ # forward
+ data = self._data_preprocess_learn(data) # output datatype: Dict
+ self._learn_model.train()
+ self._target_model.train()
+ # use the hidden state at timestep=0
+ # note the reset method is implemented in the hidden_state wrapper and resets self._state.
+ self._learn_model.reset(data_id=None, state=data['prev_state'][0])
+ self._target_model.reset(data_id=None, state=data['prev_state'][0])
+
+ if len(data['burnin_nstep_obs']) != 0:
+ with torch.no_grad():
+ inputs = {'obs': data['burnin_nstep_obs'], 'enable_fast_timestep': True}
+ burnin_output = self._learn_model.forward(
+ inputs, saved_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep]
+ ) # keys include 'logit', 'hidden_state', 'saved_state' and \
+ # 'action'; for their specific dims, please refer to the DRQN model
+ burnin_output_target = self._target_model.forward(
+ inputs, saved_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep]
+ )
+
+ self._learn_model.reset(data_id=None, state=burnin_output['saved_state'][0])
+ inputs = {'obs': data['main_obs'], 'enable_fast_timestep': True}
+ q_value = self._learn_model.forward(inputs)['logit']
+ self._learn_model.reset(data_id=None, state=burnin_output['saved_state'][1])
+ self._target_model.reset(data_id=None, state=burnin_output_target['saved_state'][1])
+
+ next_inputs = {'obs': data['target_obs'], 'enable_fast_timestep': True}
+ with torch.no_grad():
+ target_q_value = self._target_model.forward(next_inputs)['logit']
+ # argmax_action double_dqn
+ target_q_action = self._learn_model.forward(next_inputs)['action']
+
+ action, reward, done, weight = data['action'], data['reward'], data['done'], data['weight']
+ value_gamma = data['value_gamma']
+ # T, B, nstep -> T, nstep, B
+ reward = reward.permute(0, 2, 1).contiguous()
+ loss = []
+ td_error = []
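+ # A rough sketch of the per-timestep n-step target computed inside q_nstep_td_error(_with_rescale) below
+ # (double DQN style, where a* is selected by the online network):
+ #     a* = argmax_a Q_online(s_{t+n}, a)
+ #     y_t = r_t + gamma * r_{t+1} + ... + gamma^{n-1} * r_{t+n-1} + gamma^n * Q_target(s_{t+n}, a*)
+ # With value_rescale enabled, the bootstrap term is first un-rescaled and the whole target is re-rescaled:
+ #     y_t = h(n-step reward + gamma^n * h^{-1}(Q_target(s_{t+n}, a*)))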
+ for t in range(self._sequence_len - self._burnin_step - self._nstep):
+ # here t=0 means the first timestep after burn-in in the original sample sequence; we subtract self._nstep
+ # because the last nstep timesteps in the sequence don't have their target obs
+ td_data = q_nstep_td_data(
+ q_value[t], target_q_value[t], action[t], target_q_action[t], reward[t], done[t], weight[t]
+ )
+ if self._value_rescale:
+ l, e = q_nstep_td_error_with_rescale(td_data, self._gamma, self._nstep, value_gamma=value_gamma[t])
+ loss.append(l)
+ td_error.append(e.abs())
+ else:
+ l, e = q_nstep_td_error(td_data, self._gamma, self._nstep, value_gamma=value_gamma[t])
+ loss.append(l)
+ # td_error will be a list of length (sequence_len - burnin_step - nstep),
+ # and each element is a tensor of shape (batch_size, )
+ td_error.append(e.abs())
+ loss = sum(loss) / (len(loss) + 1e-8)
+
+ # using the mixture of max and mean absolute n-step TD-errors as the priority of the sequence
+ td_error_per_sample = 0.9 * torch.max(
+ torch.stack(td_error), dim=0
+ )[0] + (1 - 0.9) * (torch.sum(torch.stack(td_error), dim=0) / (len(td_error) + 1e-8))
+ # torch.max(torch.stack(td_error), dim=0) returns a (values, indices) namedtuple, so we take index [0]
+ # td_error is a list of tensors of shape (B, ),
+ # for example, 75 tensors of shape (64, )
+ # torch.sum(torch.stack(td_error), dim=0) can also be replaced with sum(td_error)
+
+ # update
+ self._optimizer.zero_grad()
+ loss.backward()
+ self._optimizer.step()
+ # after update
+ self._target_model.update(self._learn_model.state_dict())
+
+ # the information for debug
+ batch_range = torch.arange(action[0].shape[0])
+ q_s_a_t0 = q_value[0][batch_range, action[0]]
+ target_q_s_a_t0 = target_q_value[0][batch_range, target_q_action[0]]
+
+ return {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': loss.item(),
+ 'priority': td_error_per_sample.tolist(), # note abs operation has been performed above
+ # the first timestep in the sequence, may not be the start of episode
+ 'q_s_taken-a_t0': q_s_a_t0.mean().item(),
+ 'target_q_s_max-a_t0': target_q_s_a_t0.mean().item(),
+ 'q_s_a-mean_t0': q_value[0].mean().item(),
+ }
+
+ def _reset_learn(self, data_id: Optional[List[int]] = None) -> None:
+ """
+ Overview:
+ Reset some stateful variables for learn mode when necessary, such as the hidden state of RNN or the \
+ memory bank of some special algorithms. If ``data_id`` is None, it means to reset all the stateful \
+ variables. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \
+ different trajectories in ``data_id`` will have different hidden state in RNN.
+ Arguments:
+ - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \
+ (i.e. RNN hidden_state in R2D2) specified by ``data_id``.
+ """
+
+ self._learn_model.reset(data_id=data_id)
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ """
+ Overview:
+ Return the state_dict of learn mode, usually including model, target_model and optimizer.
+ Returns:
+ - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring.
+ """
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'target_model': self._target_model.state_dict(),
+ 'optimizer': self._optimizer.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ """
+ Overview:
+ Load the state_dict variable into policy learn mode.
+ Arguments:
+ - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before.
+
+ .. tip::
+ If you want to only load some parts of model, you can simply set the ``strict`` argument in \
+ load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
+ complicated operation.
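+
+ .. note::
+ A minimal usage sketch (the checkpoint path is only an example)::
+
+ torch.save(policy._state_dict_learn(), 'r2d2_ckpt.pth.tar')
+ policy._load_state_dict_learn(torch.load('r2d2_ckpt.pth.tar', map_location='cpu'))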
+ """
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._target_model.load_state_dict(state_dict['target_model'])
+ self._optimizer.load_state_dict(state_dict['optimizer'])
+
+ def _init_collect(self) -> None:
+ """
+ Overview:
+ Initialize the collect mode of policy, including related attributes and modules. For R2D2, it contains the \
+ collect_model to balance exploration and exploitation with the epsilon-greedy sampling mechanism and \
+ to maintain the hidden state of the RNN. Besides, there are some initialization operations for other \
+ algorithm-specific arguments such as burnin_step, unroll_len and nstep.
+ This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.
+
+ .. note::
+ If you want to set some special member variables in ``_init_collect`` method, you'd better name them \
+ with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``.
+
+ .. tip::
+ Some variables need to initialize independently in different modes, such as gamma and nstep in R2D2. This \
+ design is for the convenience of parallel execution of different policy modes.
+ """
+ self._nstep = self._cfg.nstep
+ self._burnin_step = self._cfg.burnin_step
+ self._gamma = self._cfg.discount_factor
+ self._sequence_len = self._cfg.learn_unroll_len + self._cfg.burnin_step
+ self._unroll_len = self._sequence_len
+
+ # for r2d2, this hidden_state wrapper is to add the 'prev hidden state' for each transition.
+ # Note that collect env forms a batch and the key is added for the batch simultaneously.
+ self._collect_model = model_wrap(
+ self._model, wrapper_name='hidden_state', state_num=self._cfg.collect.env_num, save_prev_state=True
+ )
+ self._collect_model = model_wrap(self._collect_model, wrapper_name='eps_greedy_sample')
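+ # NOTE: the wrappers are nested as eps_greedy_sample(hidden_state(model)), so each forward call first
+ # maintains the per-env RNN hidden states and then applies epsilon-greedy sampling to the resulting q-values.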
+ self._collect_model.reset()
+
+ def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
+ """
+ Overview:
+ Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \
+ that the policy gets some necessary data (mainly observation) from the envs and then returns the output \
+ data, such as the action to interact with the envs. Besides, this policy also needs ``eps`` argument for \
+ exploration, i.e., classic epsilon-greedy exploration strategy.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+ - eps (:obj:`float`): The epsilon value for exploration.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
+ other necessary data (prev_state) for learn mode defined in ``self._process_transition`` method. The \
+ key of the dict is the same as the input data, i.e. environment id.
+
+ .. note::
+ The RNN's hidden states are maintained in the policy, so we don't need to pass them in the data; instead, \
+ we reset the hidden states with the ``_reset_collect`` method when an episode ends. Besides, the previous \
+ hidden states are necessary for training, so we need to return them in the ``_process_transition`` method.
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+ For data types that are not supported, the main reason is that the corresponding model does not support them. \
+ You can implement your own model rather than use the default model. For more information, please raise an \
+ issue in the GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for R2D2Policy: ``ding.policy.tests.test_r2d2``.
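+
+ .. note::
+ An illustrative sketch of the input/output format (the env ids and obs shape below are hypothetical)::
+
+ data = {0: torch.randn(4), 1: torch.randn(4)}  # one obs per ready env, here a fake 1D obs of dim 4
+ output = policy._forward_collect(data, eps=0.05)
+ # output[0] contains at least 'action' and 'prev_state' for env 0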
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ data = {'obs': data}
+ self._collect_model.eval()
+ with torch.no_grad():
+ # in the collect phase, inference=True means that each time we only pass one timestep of data,
+ # so that we can get the hidden state of the RNN at each timestep.
+ output = self._collect_model.forward(data, data_id=data_id, eps=eps, inference=True)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _reset_collect(self, data_id: Optional[List[int]] = None) -> None:
+ """
+ Overview:
+ Reset some stateful variables for collect mode when necessary, such as the hidden state of RNN or the \
+ memory bank of some special algorithms. If ``data_id`` is None, it means to reset all the stateful \
+ variables. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \
+ different environments/episodes in collection in ``data_id`` will have different hidden state in RNN.
+ Arguments:
+ - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \
+ (i.e., RNN hidden_state in R2D2) specified by ``data_id``.
+ """
+ self._collect_model.reset(data_id=data_id)
+
+ def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor],
+ timestep: namedtuple) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ Process and pack one timestep transition data into a dict, which can be directly used for training and \
+ saved in replay buffer. For R2D2, it contains obs, action, prev_state, reward, and done.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari.
+ - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network given the observation \
+ as input. For R2D2, it contains the action and the prev_state of RNN.
+ - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \
+ except all the elements have been transformed into tensor data. Usually, it contains the next obs, \
+ reward, done, info, etc.
+ Returns:
+ - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep.
+ """
+ transition = {
+ 'obs': obs,
+ 'action': policy_output['action'],
+ 'prev_state': policy_output['prev_state'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+ Overview:
+ For a given trajectory (transitions, a list of transitions), process it into a list of samples that \
+ can be used for training directly. In R2D2, a train sample is a processed sub-trajectory with unroll_len \
+ length. This method is usually used in collectors to execute necessary \
+ RL data preprocessing before training, which can help the learner amortize the relevant time consumption. \
+ In addition, you can also implement this method as an identity function and do the data processing \
+ in ``self._forward_learn`` method.
+ Arguments:
+ - transitions (:obj:`List[Dict[str, Any]]`): The trajectory data (a list of transitions), each element is \
+ the same format as the return value of ``self._process_transition`` method.
+ Returns:
+ - samples (:obj:`List[Dict[str, Any]]`): The processed train samples, each sample is a fixed-length \
+ trajectory, and each element in a sample is in a similar format to the input transitions, but may contain \
+ more data for training, such as nstep reward and value_gamma factor.
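+
+ .. note::
+ A rough sketch of what this method does (the numbers are only illustrative): with ``nstep=5``, \
+ ``get_nstep_return_data`` augments every transition with its 5-step reward and done flag, and \
+ ``get_train_sample`` then splits the trajectory into consecutive fixed-length samples of \
+ ``unroll_len`` transitions each.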
+ """
+ transitions = get_nstep_return_data(transitions, self._nstep, gamma=self._gamma)
+ return get_train_sample(transitions, self._unroll_len)
+
+ def _init_eval(self) -> None:
+ """
+ Overview:
+ Initialize the eval mode of policy, including related attributes and modules. For R2D2, it contains the \
+ eval model to greedily select actions with the argmax q_value mechanism and maintain the hidden state.
+ This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
+
+ .. note::
+ If you want to set some special member variables in ``_init_eval`` method, you'd better name them \
+ with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
+ """
+ self._eval_model = model_wrap(self._model, wrapper_name='hidden_state', state_num=self._cfg.eval.env_num)
+ self._eval_model = model_wrap(self._eval_model, wrapper_name='argmax_sample')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
+ """
+ Overview:
+ Policy forward function of eval mode (evaluating policy performance by interacting with envs). Forward \
+ means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
+ action to interact with the envs. ``_forward_eval`` often uses the argmax sampling method to select the \
+ action with the highest q_value.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
+ key of the dict is the same as the input data, i.e. environment id.
+
+ .. note::
+ The RNN's hidden states are maintained in the policy, so we don't need to pass them in the data; instead, \
+ we reset the hidden states with the ``_reset_eval`` method when an episode ends.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+ For data types that are not supported, the main reason is that the corresponding model does not support them. \
+ You can implement your own model rather than use the default model. For more information, please raise an \
+ issue in the GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for R2D2Policy: ``ding.policy.tests.test_r2d2``.
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ data = {'obs': data}
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data, data_id=data_id, inference=True)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
+ """
+ Overview:
+ Reset some stateful variables for eval mode when necessary, such as the hidden state of RNN or the \
+ memory bank of some special algorithms. If ``data_id`` is None, it means to reset all the stateful \
+ variables. Otherwise, it will reset the stateful variables according to the ``data_id``. For example, \
+ different environments/episodes in evaluation in ``data_id`` will have different hidden state in RNN.
+ Arguments:
+ - data_id (:obj:`Optional[List[int]]`): The id of the data, which is used to reset the stateful variables \
+ (i.e., RNN hidden_state in R2D2) specified by ``data_id``.
+ """
+ self._eval_model.reset(data_id=data_id)
+
+ def _monitor_vars_learn(self) -> List[str]:
+ """
+ Overview:
+ Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
+ as text logger, tensorboard logger, will use these keys to save the corresponding data.
+ Returns:
+ - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
+ """
+ return super()._monitor_vars_learn() + [
+ 'total_loss', 'priority', 'q_s_taken-a_t0', 'target_q_s_max-a_t0', 'q_s_a-mean_t0'
+ ]
diff --git a/DI-engine/ding/policy/r2d2_collect_traj.py b/DI-engine/ding/policy/r2d2_collect_traj.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cf312010f1577be5ee6a90a667d934f3e3b360b
--- /dev/null
+++ b/DI-engine/ding/policy/r2d2_collect_traj.py
@@ -0,0 +1,491 @@
+import copy
+from collections import namedtuple
+from typing import List, Dict, Any, Tuple, Union, Optional
+
+import torch
+
+from ding.model import model_wrap
+from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, q_nstep_td_error_with_rescale, get_nstep_return_data, \
+ get_train_sample
+from ding.torch_utils import Adam, to_device
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import timestep_collate, default_collate, default_decollate
+from .base_policy import Policy
+
+
+@POLICY_REGISTRY.register('r2d2_collect_traj')
+class R2D2CollectTrajPolicy(Policy):
+ r"""
+ Overview:
+ Policy class of R2D2 for collecting expert trajectories for R2D3.
+
+ Config:
+ == ==================== ======== ============== ======================================== =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============== ======================================== =======================
+ 1 ``type`` str dqn | RL policy register name, refer to | This arg is optional,
+ | registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool False | Whether to use cuda for network | This arg can be diff-
+ | erent from modes
+ 3 ``on_policy`` bool False | Whether the RL algorithm is on-policy
+ | or off-policy
+ 4 ``priority`` bool False | Whether use priority(PER) | Priority sample,
+ | update priority
+ 5 | ``priority_IS`` bool False | Whether use Importance Sampling Weight
+ | ``_weight`` | to correct biased update. If True,
+ | priority must be True.
+ 6 | ``discount_`` float 0.997, | Reward's future discount factor, aka. | May be 1 when sparse
+ | ``factor`` [0.95, 0.999] | gamma | reward env
+ 7 ``nstep`` int 3, | N-step reward discount sum for target
+ [3, 5] | q_value estimation
+ 8 ``burnin_step`` int 2 | The timestep of burnin operation,
+ | which is designed to compensate for the
+ | RNN hidden state difference caused by
+ | off-policy training
+ 9 | ``learn.update`` int 1 | How many updates(iterations) to train | This arg can vary
+ | ``per_collect`` | after collector's one collection. Only | from envs. Bigger val
+ | valid in serial training | means more off-policy
+ 10 | ``learn.batch_`` int 64 | The number of samples of an iteration
+ | ``size``
+ 11 | ``learn.learning`` float 0.001 | Gradient step length of an iteration.
+ | ``_rate``
+ 12 | ``learn.value_`` bool True | Whether use value_rescale function for
+ | ``rescale`` | predicted value
+ 13 | ``learn.target_`` int 100 | Frequency of target network update. | Hard(assign) update
+ | ``update_freq``
+ 14 | ``learn.ignore_`` bool False | Whether ignore done for target value | Enable it for some
+ | ``done`` | calculation. | fake termination env
+ 15 ``collect.n_sample`` int [8, 128] | The number of training samples of a | It varies from
+ | call of collector. | different envs
+ 16 | ``collect.unroll`` int 1 | unroll length of an iteration | In RNN, unroll_len>1
+ | ``_len``
+ == ==================== ======== ============== ======================================== =======================
+ """
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='r2d2',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ # (bool) Whether the RL algorithm is on-policy or off-policy.
+ on_policy=False,
+ # (bool) Whether use priority(priority sample, IS weight, update priority)
+ priority=True,
+ # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=True,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) Reward's future discount factor, aka. gamma.
+ discount_factor=0.997,
+ # (int) N-step reward for target q_value estimation
+ nstep=5,
+ # (int) the timestep of the burnin operation, which is designed to compensate for the RNN hidden state
+ # difference caused by off-policy training
+ burnin_step=2,
+ # (int) the trajectory length to unroll the RNN network minus
+ # the timestep of burnin operation
+ unroll_len=80,
+ learn=dict(
+ update_per_collect=1,
+ batch_size=64,
+ learning_rate=0.0001,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (int) Frequency of target network update.
+ # target_update_freq=100,
+ target_update_theta=0.001,
+ # (bool) whether use value_rescale function for predicted value
+ value_rescale=True,
+ ignore_done=False,
+ ),
+ collect=dict(
+ # NOTE: it is important not to include the key n_sample here, to make sure self._traj_len=INF
+ # each_iter_n_sample=32,
+ # `env_num` is used in hidden state, should equal to that one in env config.
+ # User should specify this value in user config.
+ env_num=None,
+ ),
+ eval=dict(
+ # `env_num` is used in hidden state, should equal to that one in env config.
+ # User should specify this value in user config.
+ env_num=None,
+ ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=10000, ),
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ return 'drqn', ['ding.model.template.q_learning']
+
+ def _init_learn(self) -> None:
+ r"""
+ Overview:
+ Init the learner model of R2D2CollectTrajPolicy
+
+ Arguments:
+ .. note::
+
+ The ``_init_learn`` method takes its arguments from ``self._cfg.learn`` in the config file
+
+ - learning_rate (:obj:`float`): The learning rate of the optimizer
+ - gamma (:obj:`float`): The discount factor
+ - nstep (:obj:`int`): The number of steps in the n-step return
+ - value_rescale (:obj:`bool`): Whether to use the value-rescaled loss in the algorithm
+ - burnin_step (:obj:`int`): The number of burnin steps
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
+ self._gamma = self._cfg.discount_factor
+ self._nstep = self._cfg.nstep
+ self._burnin_step = self._cfg.burnin_step
+ self._value_rescale = self._cfg.learn.value_rescale
+
+ self._target_model = copy.deepcopy(self._model)
+ # self._target_model = model_wrap( TODO(pu)
+ # self._target_model,
+ # wrapper_name='target',
+ # update_type='assign',
+ # update_kwargs={'freq': self._cfg.learn.target_update_freq}
+ # )
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='momentum',
+ update_kwargs={'theta': self._cfg.learn.target_update_theta}
+ )
+
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='hidden_state',
+ state_num=self._cfg.learn.batch_size,
+ )
+ self._learn_model = model_wrap(
+ self._model,
+ wrapper_name='hidden_state',
+ state_num=self._cfg.learn.batch_size,
+ )
+ self._learn_model = model_wrap(self._learn_model, wrapper_name='argmax_sample')
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ def _data_preprocess_learn(self, data: List[Dict[str, Any]]) -> dict:
+ r"""
+ Overview:
+ Preprocess the data to fit the required data format for learning
+
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]]`): the data collected from collect function
+
+ Returns:
+ - data (:obj:`Dict[str, Any]`): the processed data, including at least \
+ ['main_obs', 'target_obs', 'burnin_nstep_obs', 'action', 'reward', 'done', 'weight']
+ """
+ # data preprocess
+ data = timestep_collate(data)
+ if self._cuda:
+ data = to_device(data, self._device)
+
+ if self._priority_IS_weight:
+ assert self._priority, "Use IS Weight correction, but Priority is not used."
+ if self._priority and self._priority_IS_weight:
+ data['weight'] = data['IS']
+ else:
+ data['weight'] = data.get('weight', None)
+
+ bs = self._burnin_step
+
+ # data['done'], data['weight'] and data['value_gamma'] are used in _forward_learn() to calculate
+ # the q_nstep_td_error, and should have length [self._unroll_len_add_burnin_step - self._burnin_step]
+ ignore_done = self._cfg.learn.ignore_done
+ if ignore_done:
+ data['done'] = [None for _ in range(self._unroll_len_add_burnin_step - bs)]
+ else:
+ data['done'] = data['done'][bs:].float() # for computation of online model self._learn_model
+ # NOTE that after the preprocessing of get_nstep_return_data() in _get_train_sample
+ # the data['done'][t] is already the n-step done
+
+ # if the data doesn't include 'weight' or 'value_gamma', fill in a list of None
+ # with length [self._unroll_len_add_burnin_step - self._burnin_step];
+ # below are two different implementations
+ if 'value_gamma' not in data:
+ data['value_gamma'] = [None for _ in range(self._unroll_len_add_burnin_step - bs)]
+ else:
+ data['value_gamma'] = data['value_gamma'][bs:]
+
+ if 'weight' not in data:
+ data['weight'] = [None for _ in range(self._unroll_len_add_burnin_step - bs)]
+ else:
+ data['weight'] = data['weight'] * torch.ones_like(data['done'])
+ # every timestep in sequence has same weight, which is the _priority_IS_weight in PER
+
+ data['action'] = data['action'][bs:-self._nstep]
+ data['reward'] = data['reward'][bs:-self._nstep]
+
+ # the burnin_nstep_obs is used to calculate the init hidden state of rnn for the calculation of the q_value,
+ # target_q_value, and target_q_action
+ data['burnin_nstep_obs'] = data['obs'][:bs + self._nstep]
+ # the main_obs is used to calculate the q_value, the [bs:-self._nstep] means using the data from
+ # [bs] timestep to [self._unroll_len_add_burnin_step-self._nstep] timestep
+ data['main_obs'] = data['obs'][bs:-self._nstep]
+ # the target_obs is used to calculate the target_q_value
+ data['target_obs'] = data['obs'][bs + self._nstep:]
+
+ return data
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Forward and backward function of learn mode.
+ Acquire the data, calculate the loss and optimize learner model.
+
+ Arguments:
+ - data (:obj:`dict`): Dict type data, including at least \
+ ['main_obs', 'target_obs', 'burnin_obs', 'action', 'reward', 'done', 'weight']
+
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Including cur_lr and total_loss
+ - cur_lr (:obj:`float`): Current learning rate
+ - total_loss (:obj:`float`): The calculated loss
+ """
+ # forward
+ data = self._data_preprocess_learn(data)
+ self._learn_model.train()
+ self._target_model.train()
+ # take out timestep=0
+ self._learn_model.reset(data_id=None, state=data['prev_state'][0])
+ self._target_model.reset(data_id=None, state=data['prev_state'][0])
+
+ if len(data['burnin_nstep_obs']) != 0:
+ with torch.no_grad():
+ inputs = {'obs': data['burnin_nstep_obs'], 'enable_fast_timestep': True}
+ burnin_output = self._learn_model.forward(
+ inputs, saved_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep]
+ )
+ burnin_output_target = self._target_model.forward(
+ inputs, saved_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep]
+ )
+
+ self._learn_model.reset(data_id=None, state=burnin_output['saved_state'][0])
+ inputs = {'obs': data['main_obs'], 'enable_fast_timestep': True}
+ q_value = self._learn_model.forward(inputs)['logit']
+ self._learn_model.reset(data_id=None, state=burnin_output['saved_state'][1])
+ self._target_model.reset(data_id=None, state=burnin_output_target['saved_state'][1])
+
+ next_inputs = {'obs': data['target_obs'], 'enable_fast_timestep': True}
+ with torch.no_grad():
+ target_q_value = self._target_model.forward(next_inputs)['logit']
+ # argmax_action double_dqn
+ target_q_action = self._learn_model.forward(next_inputs)['action']
+
+ action, reward, done, weight = data['action'], data['reward'], data['done'], data['weight']
+ value_gamma = data['value_gamma']
+ # T, B, nstep -> T, nstep, B
+ reward = reward.permute(0, 2, 1).contiguous()
+ loss = []
+ td_error = []
+ for t in range(self._unroll_len_add_burnin_step - self._burnin_step - self._nstep):
+ # here t=0 means the first timestep after burn-in in the original sample sequence; we subtract self._nstep
+ # because the last nstep timesteps in the sequence don't have their target obs
+ td_data = q_nstep_td_data(
+ q_value[t], target_q_value[t], action[t], target_q_action[t], reward[t], done[t], weight[t]
+ )
+ if self._value_rescale:
+ l, e = q_nstep_td_error_with_rescale(td_data, self._gamma, self._nstep, value_gamma=value_gamma[t])
+ loss.append(l)
+ td_error.append(e.abs())
+ else:
+ l, e = q_nstep_td_error(td_data, self._gamma, self._nstep, value_gamma=value_gamma[t])
+ loss.append(l)
+ td_error.append(e.abs())
+ loss = sum(loss) / (len(loss) + 1e-8)
+
+ # using the mixture of max and mean absolute n-step TD-errors as the priority of the sequence
+ td_error_per_sample = 0.9 * torch.max(
+ torch.stack(td_error), dim=0
+ )[0] + (1 - 0.9) * (torch.sum(torch.stack(td_error), dim=0) / (len(td_error) + 1e-8))
+ # td_error is a list of tensors of shape (B, ), e.g. 75 tensors of shape (64, )
+ # torch.sum(torch.stack(td_error), dim=0) can also be replaced with sum(td_error)
+
+ # update
+ self._optimizer.zero_grad()
+ loss.backward()
+ self._optimizer.step()
+ # after update
+ self._target_model.update(self._learn_model.state_dict())
+
+ # the information for debug
+ batch_range = torch.arange(action[0].shape[0])
+ q_s_a_t0 = q_value[0][batch_range, action[0]]
+ target_q_s_a_t0 = target_q_value[0][batch_range, target_q_action[0]]
+
+ return {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': loss.item(),
+ 'priority': td_error_per_sample.abs().tolist(),
+ # the first timestep in the sequence, may not be the start of episode TODO(pu)
+ 'q_s_taken-a_t0': q_s_a_t0.mean().item(),
+ 'target_q_s_max-a_t0': target_q_s_a_t0.mean().item(),
+ 'q_s_a-mean_t0': q_value[0].mean().item(),
+ }
+
+ def _reset_learn(self, data_id: Optional[List[int]] = None) -> None:
+ self._learn_model.reset(data_id=data_id)
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'optimizer': self._optimizer.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._optimizer.load_state_dict(state_dict['optimizer'])
+
+ def _init_collect(self) -> None:
+ r"""
+ Overview:
+ Collect mode init method. Called by ``self.__init__``.
+ Init traj and unroll length, collect model.
+ """
+ # assert 'unroll_len' not in self._cfg.collect, "r2d2 use default unroll_len"
+ self._nstep = self._cfg.nstep
+ self._burnin_step = self._cfg.burnin_step
+ self._gamma = self._cfg.discount_factor
+ self._unroll_len_add_burnin_step = self._cfg.unroll_len + self._cfg.burnin_step
+ self._unroll_len = self._unroll_len_add_burnin_step # for compatibility
+ # self._unroll_len = self._cfg.collect.unroll_len
+
+ self._collect_model = model_wrap(
+ self._model, wrapper_name='hidden_state', state_num=self._cfg.collect.env_num, save_prev_state=True
+ )
+ # self._collect_model = model_wrap(self._collect_model, wrapper_name='eps_greedy_sample')
+ self._collect_model = model_wrap(self._collect_model, wrapper_name='argmax_sample')
+
+ self._collect_model.reset()
+
+ # def _forward_collect(self, data: dict, eps: float) -> dict:
+ def _forward_collect(self, data: dict) -> dict:
+ r"""
+ Overview:
+ Collect output according to the argmax_sample wrapper (greedy action selection), which is used here to collect expert trajectories
+
+ Arguments:
+ - data (:obj:`dict`): Dict type data, including at least ['obs'].
+
+ Returns:
+ - data (:obj:`dict`): The collected data
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ data = {'obs': data}
+ self._collect_model.eval()
+ with torch.no_grad():
+ # in the collect phase, inference=True means that each time we only pass one timestep of data,
+ # so that we can get the hidden state of the RNN at each timestep.
+ # output = self._collect_model.forward(data, data_id=data_id, eps=eps, inference=True)
+ output = self._collect_model.forward(data, data_id=data_id, inference=True)
+ # output = self._collect_model.forward(data, inference=True)
+
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _reset_collect(self, data_id: Optional[List[int]] = None) -> None:
+ self._collect_model.reset(data_id=data_id)
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
+ r"""
+ Overview:
+ Generate dict type transition data from inputs.
+ Arguments:
+ - obs (:obj:`Any`): Env observation
+ - model_output (:obj:`dict`): Output of collect model, including at least ['action', 'prev_state']
+ - timestep (:obj:`namedtuple`): Output after env step, including at least ['reward', 'done'] \
+ (here 'obs' indicates obs after env step).
+ Returns:
+ - transition (:obj:`dict`): Dict type transition data.
+ """
+ transition = {
+ 'obs': obs,
+ 'action': model_output['action'],
+ # 'prev_state': model_output['prev_state'],
+ 'prev_state': None,
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
+ r"""
+ Overview:
+ Get the trajectory, compute the n-step return data, and then generate training samples from it
+
+ Arguments:
+ - data (:obj:`list`): The trajectory's cache
+
+ Returns:
+ - samples (:obj:`dict`): The training samples generated
+ """
+ from copy import deepcopy
+ data_one_step = deepcopy(get_nstep_return_data(data, 1, gamma=self._gamma))
+ # data_one_step = deepcopy(data)
+ data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)
+ for i in range(len(data)):
+ # here we record the one-step done; we don't need to record the one-step reward,
+ # because the n-step reward in data already includes the one-step reward
+ data[i]['done_one_step'] = data_one_step[i]['done']
+ return get_train_sample(data, self._unroll_len) # self._unroll_len_add_burnin_step
+
+ def _init_eval(self) -> None:
+ r"""
+ Overview:
+ Evaluate mode init method. Called by ``self.__init__``.
+ Init eval model with argmax strategy.
+ """
+ self._eval_model = model_wrap(self._model, wrapper_name='hidden_state', state_num=self._cfg.eval.env_num)
+ self._eval_model = model_wrap(self._eval_model, wrapper_name='argmax_sample')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: dict) -> dict:
+ r"""
+ Overview:
+ Forward function of eval mode, similar to ``self._forward_collect``.
+
+ Arguments:
+ - data (:obj:`dict`): Dict type data, including at least ['obs'].
+
+ Returns:
+ - output (:obj:`dict`): Dict type data, including at least inferred action according to input obs.
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ data = {'obs': data}
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data, data_id=data_id, inference=True)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
+ self._eval_model.reset(data_id=data_id)
+
+ def _monitor_vars_learn(self) -> List[str]:
+ return super()._monitor_vars_learn() + [
+ 'total_loss', 'priority', 'q_s_taken-a_t0', 'target_q_s_max-a_t0', 'q_s_a-mean_t0'
+ ]
diff --git a/DI-engine/ding/policy/r2d2_gtrxl.py b/DI-engine/ding/policy/r2d2_gtrxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..73b89239f3a65765f01a030b3e944edb0785e9c6
--- /dev/null
+++ b/DI-engine/ding/policy/r2d2_gtrxl.py
@@ -0,0 +1,475 @@
+import copy
+import torch
+from collections import namedtuple
+from typing import List, Dict, Any, Tuple, Union, Optional
+
+from ding.model import model_wrap
+from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, q_nstep_td_error_with_rescale, get_nstep_return_data, \
+ get_train_sample
+from ding.torch_utils import Adam, to_device
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import timestep_collate, default_collate, default_decollate
+from .base_policy import Policy
+
+
+@POLICY_REGISTRY.register('r2d2_gtrxl')
+class R2D2GTrXLPolicy(Policy):
+ r"""
+ Overview:
+ Policy class of R2D2 adopting the Transformer architecture GTrXL as backbone.
+
+ Config:
+ == ==================== ======== ============== ======================================== =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============== ======================================== =======================
+ 1 ``type`` str r2d2_gtrxl | RL policy register name, refer to | This arg is optional,
+ | registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool False | Whether to use cuda for network | This arg can be diff-
+ | erent from modes
+ 3 ``on_policy`` bool False | Whether the RL algorithm is on-policy
+ | or off-policy
+ 4 ``priority`` bool False | Whether use priority(PER) | Priority sample,
+ | update priority
+ 5 | ``priority_IS`` bool False | Whether use Importance Sampling Weight
+ | ``_weight`` | to correct biased update. If True,
+ | priority must be True.
+ 6 | ``discount_`` float 0.99, | Reward's future discount factor, aka. | May be 1 when sparse
+ | ``factor`` [0.95, 0.999] | gamma | reward env
+ 7 | ``nstep`` int 5, | N-step reward discount sum for target
+ [3, 5] | q_value estimation
+ 8 | ``burnin_step`` int 1 | The timestep of burnin operation,
+ | which is designed to warm up the GTrXL
+ | memory and mitigate the difference
+ | caused by off-policy training
+ 9 | ``learn.update`` int 1 | How many updates(iterations) to train | This arg can vary
+ | ``per_collect`` | after collector's one collection. Only | from envs. Bigger val
+ | valid in serial training | means more off-policy
+ 10 | ``learn.batch_`` int 64 | The number of samples of an iteration
+ | ``size``
+ 11 | ``learn.learning`` float 0.001 | Gradient step length of an iteration.
+ | ``_rate``
+ 12 | ``learn.value_`` bool True | Whether use value_rescale function for
+ | ``rescale`` | predicted value
+ 13 | ``learn.target_`` int 100 | Frequency of target network update. | Hard(assign) update
+ | ``update_freq``
+ 14 | ``learn.ignore_`` bool False | Whether ignore done for target value | Enable it for some
+ | ``done`` | calculation. | fake termination env
+ 15 ``collect.n_sample`` int [8, 128] | The number of training samples of a | It varies from
+ | call of collector. | different envs
+ 16 | ``collect.unroll`` int 25 | unroll length of an iteration | unroll_len>1
+ | ``_len``
+ 17 | ``collect.seq`` int 20 | Training sequence length | unroll_len>=seq_len>1
+ | ``_len``
+ 18 | ``learn.init_`` str zero | 'zero' or 'old', how to initialize the |
+ | ``memory`` | memory before each training iteration. |
+ == ==================== ======== ============== ======================================== =======================
+ """
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='r2d2_gtrxl',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ # (bool) Whether the RL algorithm is on-policy or off-policy.
+ on_policy=False,
+ # (bool) Whether use priority(priority sample, IS weight, update priority)
+ priority=True,
+ # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=True,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) Reward's future discount factor, aka. gamma.
+ discount_factor=0.99,
+ # (int) N-step reward for target q_value estimation
+ nstep=5,
+ # how many steps to use as burnin
+ burnin_step=1,
+ # (int) trajectory length
+ unroll_len=25,
+ # (int) training sequence length
+ seq_len=20,
+ learn=dict(
+ update_per_collect=1,
+ batch_size=64,
+ learning_rate=0.0001,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (int) Frequency of target network update.
+ # target_update_freq=100,
+ target_update_theta=0.001,
+ ignore_done=False,
+ # (bool) whether use value_rescale function for predicted value
+ value_rescale=False,
+ # 'zero' or 'old', how to initialize the memory in training
+ init_memory='zero'
+ ),
+ collect=dict(
+ # NOTE: it is important not to include the key n_sample here, to make sure self._traj_len=INF
+ each_iter_n_sample=32,
+ # `env_num` is used in hidden state, should equal to that one in env config.
+ # User should specify this value in user config.
+ env_num=None,
+ ),
+ eval=dict(
+ # `env_num` is used in hidden state, should equal to that one in env config.
+ # User should specify this value in user config.
+ env_num=None,
+ ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=10000, ),
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ return 'gtrxldqn', ['ding.model.template.q_learning']
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Init the learner model of R2D2GTrXLPolicy. \
+ The target model has 2 wrappers: 'target' for weight updates and 'transformer_segment' to split trajectories \
+ into segments. The learn model has 2 wrappers: 'argmax_sample' to select the best action and 'transformer_segment'.
+
+ Arguments:
+ - learning_rate (:obj:`float`): The learning rate of the optimizer
+ - gamma (:obj:`float`): The discount factor
+ - nstep (:obj:`int`): The number of steps in the n-step return
+ - value_rescale (:obj:`bool`): Whether to use the value-rescaled loss in the algorithm
+ - burnin_step (:obj:`int`): The number of burnin steps
+ - seq_len (:obj:`int`): Training sequence length
+ - init_memory (:obj:`str`): 'zero' or 'old', how to initialize the memory before each training iteration.
+
+ .. note::
+ The ``_init_learn`` method takes its arguments from ``self._cfg.learn`` in the config file
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
+ self._gamma = self._cfg.discount_factor
+ self._nstep = self._cfg.nstep
+ self._burnin_step = self._cfg.burnin_step
+ self._batch_size = self._cfg.learn.batch_size
+ self._seq_len = self._cfg.seq_len
+ self._value_rescale = self._cfg.learn.value_rescale
+ self._init_memory = self._cfg.learn.init_memory
+ assert self._init_memory in ['zero', 'old']
+
+ self._target_model = copy.deepcopy(self._model)
+
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='momentum',
+ update_kwargs={'theta': self._cfg.learn.target_update_theta}
+ )
+ self._target_model = model_wrap(self._target_model, seq_len=self._seq_len, wrapper_name='transformer_segment')
+
+ self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ self._learn_model = model_wrap(self._learn_model, seq_len=self._seq_len, wrapper_name='transformer_segment')
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ def _data_preprocess_learn(self, data: List[Dict[str, Any]]) -> dict:
+ r"""
+ Overview:
+ Preprocess the data to fit the required data format for learning
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]]`): the data collected from collect function
+ Returns:
+ - data (:obj:`Dict[str, Any]`): the processed data, including at least \
+ ['main_obs', 'target_obs', 'burnin_obs', 'action', 'reward', 'done', 'weight']
+ """
+ if self._init_memory == 'old' and 'prev_memory' in data[0].keys():
+ # retrieve the memory corresponding to the first and n_step(th) element in each trajectory and remove it
+ # from 'data'
+ prev_mem = [b['prev_memory'][0] for b in data]
+ prev_mem_target = [b['prev_memory'][self._nstep] for b in data]
+ # stack the memory entries along the batch dimension,
+ # reshape the new memory to have shape (layer_num+1, memory_len, bs, embedding_dim) compatible with GTrXL
+ prev_mem_batch = torch.stack(prev_mem, 0).permute(1, 2, 0, 3)
+ prev_mem_target_batch = torch.stack(prev_mem_target, 0).permute(1, 2, 0, 3)
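+ # i.e. each per-sample 'prev_memory' entry is assumed to have shape (layer_num + 1, memory_len, embedding_dim);
+ # torch.stack adds the batch dim in front and permute(1, 2, 0, 3) moves it to the third position, giving
+ # (layer_num + 1, memory_len, bs, embedding_dim) as expected by GTrXL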
+ data = timestep_collate(data)
+ data['prev_memory_batch'] = prev_mem_batch
+ data['prev_memory_target_batch'] = prev_mem_target_batch
+ else:
+ data = timestep_collate(data)
+ if self._cuda:
+ data = to_device(data, self._device)
+
+ if self._priority_IS_weight:
+ assert self._priority, "Use IS Weight correction, but Priority is not used."
+ if self._priority and self._priority_IS_weight:
+ data['weight'] = data['IS']
+ else:
+ data['weight'] = data.get('weight', None)
+
+ # data['done'], data['weight'] and data['value_gamma'] are used in _forward_learn() to calculate
+ # the q_nstep_td_error, and should have length [self._unroll_len]
+ ignore_done = self._cfg.learn.ignore_done
+ if ignore_done:
+ data['done'] = [None for _ in range(self._unroll_len)]
+ else:
+ data['done'] = data['done'].float() # for computation of online model self._learn_model
+ # NOTE that after the preprocessing of get_nstep_return_data() in _get_train_sample
+ # the data['done'][t] is already the n-step done
+
+ # if the data doesn't include 'weight' or 'value_gamma', fill in a list of None
+ # with length [self._unroll_len];
+ # below are two different implementations
+ if 'value_gamma' not in data:
+ data['value_gamma'] = [None for _ in range(self._unroll_len)]
+ else:
+ data['value_gamma'] = data['value_gamma']
+
+ if 'weight' not in data or data['weight'] is None:
+ data['weight'] = [None for _ in range(self._unroll_len)]
+ else:
+ data['weight'] = data['weight'] * torch.ones_like(data['done'])
+ # every timestep in sequence has same weight, which is the _priority_IS_weight in PER
+
+ data['action'] = data['action'][:-self._nstep]
+ data['reward'] = data['reward'][:-self._nstep]
+
+ data['main_obs'] = data['obs'][:-self._nstep]
+ # the target_obs is used to calculate the target_q_value
+ data['target_obs'] = data['obs'][self._nstep:]
+
+ return data
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Forward and backward function of learn mode.
+ Acquire the data, calculate the loss and optimize learner model.
+ Arguments:
+ - data (:obj:`dict`): Dict type data, including at least \
+ ['main_obs', 'target_obs', 'burnin_obs', 'action', 'reward', 'done', 'weight']
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Including cur_lr and total_loss
+ - cur_lr (:obj:`float`): Current learning rate
+ - total_loss (:obj:`float`): The calculated loss
+ """
+ data = self._data_preprocess_learn(data) # shape (seq_len, bs, obs_dim)
+ self._learn_model.train()
+ self._target_model.train()
+ if self._init_memory == 'old':
+ # use the previous hidden state memory
+ self._learn_model.reset_memory(state=data['prev_memory_batch'])
+ self._target_model.reset_memory(state=data['prev_memory_target_batch'])
+ elif self._init_memory == 'zero':
+ # use the zero-initialized state memory
+ self._learn_model.reset_memory()
+ self._target_model.reset_memory()
+
+ inputs = data['main_obs']
+ q_value = self._learn_model.forward(inputs)['logit'] # shape (seq_len, bs, act_dim)
+ next_inputs = data['target_obs']
+ with torch.no_grad():
+ target_q_value = self._target_model.forward(next_inputs)['logit']
+ if self._init_memory == 'old':
+ self._learn_model.reset_memory(state=data['prev_memory_target_batch'])
+ elif self._init_memory == 'zero':
+ self._learn_model.reset_memory()
+ target_q_action = self._learn_model.forward(next_inputs)['action'] # argmax_action double_dqn
+
+ action, reward, done, weight = data['action'], data['reward'], data['done'], data['weight']
+ value_gamma = data['value_gamma']
+ # T, B, nstep -> T, nstep, B
+ reward = reward.permute(0, 2, 1).contiguous()
+ loss = []
+ td_error = []
+ for t in range(self._burnin_step, self._unroll_len - self._nstep):
+ # here we skip the first 'burnin_step' steps because they are only needed to initialize the memory, and
+ # skip the last 'nstep' steps because we don't have their target obs
+ td_data = q_nstep_td_data(
+ q_value[t], target_q_value[t], action[t], target_q_action[t], reward[t], done[t], weight[t]
+ )
+ if self._value_rescale:
+ l, e = q_nstep_td_error_with_rescale(td_data, self._gamma, self._nstep, value_gamma=value_gamma[t])
+ else:
+ l, e = q_nstep_td_error(td_data, self._gamma, self._nstep, value_gamma=value_gamma[t])
+ loss.append(l)
+ td_error.append(e.abs())
+ loss = sum(loss) / (len(loss) + 1e-8)
+
+ # using the mixture of max and mean absolute n-step TD-errors as the priority of the sequence
+ td_error_per_sample = 0.9 * torch.max(
+ torch.stack(td_error), dim=0
+ )[0] + (1 - 0.9) * (torch.sum(torch.stack(td_error), dim=0) / (len(td_error) + 1e-8))
+ # td_error is a list of tensors of shape (B, ), one per training timestep
+ # torch.sum(torch.stack(td_error), dim=0) can also be replaced with sum(td_error)
+
+ # update
+ self._optimizer.zero_grad()
+ loss.backward()
+ self._optimizer.step()
+ # after update
+ self._target_model.update(self._learn_model.state_dict())
+
+ # the information for debug
+ batch_range = torch.arange(action[0].shape[0])
+ q_s_a_t0 = q_value[0][batch_range, action[0]]
+ target_q_s_a_t0 = target_q_value[0][batch_range, target_q_action[0]]
+
+ ret = {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': loss.item(),
+ 'priority': td_error_per_sample.abs().tolist(),
+ # the first timestep in the sequence, may not be the start of episode
+ 'q_s_taken-a_t0': q_s_a_t0.mean().item(),
+ 'target_q_s_max-a_t0': target_q_s_a_t0.mean().item(),
+ 'q_s_a-mean_t0': q_value[0].mean().item(),
+ }
+
+ return ret
+
+ def _reset_learn(self, data_id: Optional[List[int]] = None) -> None:
+ self._learn_model.reset(data_id=data_id)
+ self._target_model.reset(data_id=data_id)
+ self._learn_model.reset_memory()
+ self._target_model.reset_memory()
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'optimizer': self._optimizer.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._optimizer.load_state_dict(state_dict['optimizer'])
+
+ def _init_collect(self) -> None:
+ r"""
+ Overview:
+ Collect mode init method. Called by ``self.__init__``.
+ Init unroll length and sequence len, collect model.
+ """
+ assert 'unroll_len' not in self._cfg.collect, "Use default unroll_len"
+ self._nstep = self._cfg.nstep
+ self._gamma = self._cfg.discount_factor
+ self._unroll_len = self._cfg.unroll_len
+ self._seq_len = self._cfg.seq_len
+ self._collect_model = model_wrap(self._model, wrapper_name='transformer_input', seq_len=self._seq_len)
+ self._collect_model = model_wrap(self._collect_model, wrapper_name='eps_greedy_sample')
+ self._collect_model = model_wrap(
+ self._collect_model, wrapper_name='transformer_memory', batch_size=self.cfg.collect.env_num
+ )
+ self._collect_model.reset()
+
+ def _forward_collect(self, data: dict, eps: float) -> dict:
+ r"""
+ Overview:
+ Forward function for collect mode with eps_greedy
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ - eps (:obj:`float`): epsilon value for exploration, which is decayed by collected env step.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs.
+ ReturnsKeys
+ - necessary: ``action``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._collect_model.eval()
+ with torch.no_grad():
+ output = self._collect_model.forward(data, eps=eps, data_id=data_id)
+ del output['input_seq']
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _reset_collect(self, data_id: Optional[List[int]] = None) -> None:
+ # data_id is ID of env to be reset
+ self._collect_model.reset(data_id=data_id)
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
+ r"""
+ Overview:
+ Generate dict type transition data from inputs.
+ Arguments:
+ - obs (:obj:`Any`): Env observation
+ - model_output (:obj:`dict`): Output of collect model, including at least ['action', 'prev_state']
+ - timestep (:obj:`namedtuple`): Output after env step, including at least ['reward', 'done'] \
+ (here 'obs' indicates obs after env step).
+ Returns:
+ - transition (:obj:`dict`): Dict type transition data.
+ """
+ transition = {
+ 'obs': obs,
+ 'action': model_output['action'],
+ 'prev_memory': model_output['memory'], # state of the memory before taking the 'action'
+ 'prev_state': None,
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
+ r"""
+ Overview:
+            Compute the n-step return data from the trajectory, then cut it into training samples.
+ Arguments:
+ - data (:obj:`list`): The trajectory's cache
+ Returns:
+ - samples (:obj:`dict`): The training samples generated
+ """
+ self._seq_len = self._cfg.seq_len
+ data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)
+ return get_train_sample(data, self._unroll_len)
+
+ def _init_eval(self) -> None:
+ r"""
+ Overview:
+ Evaluate mode init method. Called by ``self.__init__``.
+ Init eval model with argmax strategy.
+ """
+ self._eval_model = model_wrap(self._model, wrapper_name='transformer_input', seq_len=self._seq_len)
+ self._eval_model = model_wrap(self._eval_model, wrapper_name='argmax_sample')
+ self._eval_model = model_wrap(
+ self._eval_model, wrapper_name='transformer_memory', batch_size=self.cfg.eval.env_num
+ )
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: dict) -> dict:
+ r"""
+ Overview:
+ Forward function of eval mode, similar to ``self._forward_collect``.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
+ ReturnsKeys
+ - necessary: ``action``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data, data_id=data_id)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
+ self._eval_model.reset(data_id=data_id)
+
+ def _monitor_vars_learn(self) -> List[str]:
+ return super()._monitor_vars_learn() + [
+ 'total_loss', 'priority', 'q_s_taken-a_t0', 'target_q_s_max-a_t0', 'q_s_a-mean_t0'
+ ]
diff --git a/DI-engine/ding/policy/r2d3.py b/DI-engine/ding/policy/r2d3.py
new file mode 100644
index 0000000000000000000000000000000000000000..feb836292142fb8898132d440e3da977bcb4f021
--- /dev/null
+++ b/DI-engine/ding/policy/r2d3.py
@@ -0,0 +1,563 @@
+import copy
+from collections import namedtuple
+from typing import List, Dict, Any, Tuple, Union, Optional
+
+import torch
+
+from ding.model import model_wrap
+from ding.rl_utils import q_nstep_td_error_with_rescale, get_nstep_return_data, \
+ get_train_sample, dqfd_nstep_td_error, dqfd_nstep_td_error_with_rescale, dqfd_nstep_td_data
+from ding.torch_utils import Adam, to_device
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import timestep_collate, default_collate, default_decollate
+from .base_policy import Policy
+
+
+@POLICY_REGISTRY.register('r2d3')
+class R2D3Policy(Policy):
+ r"""
+ Overview:
+        Policy class of R2D3, from the paper `Making Efficient Use of Demonstrations to Solve Hard Exploration Problems` .
+
+ Config:
+ == ==================== ======== ============== ======================================== =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============== ======================================== =======================
+ 1 ``type`` str dqn | RL policy register name, refer to | This arg is optional,
+ | registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool False | Whether to use cuda for network | This arg can be diff-
+ | erent from modes
+ 3 ``on_policy`` bool False | Whether the RL algorithm is on-policy
+ | or off-policy
+ 4 ``priority`` bool False | Whether use priority(PER) | Priority sample,
+ | update priority
+ 5 | ``priority_IS`` bool False | Whether use Importance Sampling Weight
+ | ``_weight`` | to correct biased update. If True,
+ | priority must be True.
+ 6 | ``discount_`` float 0.997, | Reward's future discount factor, aka. | May be 1 when sparse
+ | ``factor`` [0.95, 0.999] | gamma | reward env
+ 7 ``nstep`` int 3, | N-step reward discount sum for target
+ [3, 5] | q_value estimation
+ 8 ``burnin_step`` int 2 | The timestep of burnin operation,
+ | which is designed to RNN hidden state
+ | difference caused by off-policy
+ 9 | ``learn.update`` int 1 | How many updates(iterations) to train | This args can be vary
+ | ``per_collect`` | after collector's one collection. Only | from envs. Bigger val
+ | valid in serial training | means more off-policy
+ 10 | ``learn.batch_`` int 64 | The number of samples of an iteration
+ | ``size``
+ 11 | ``learn.learning`` float 0.001 | Gradient step length of an iteration.
+ | ``_rate``
+ 12 | ``learn.value_`` bool True | Whether use value_rescale function for
+ | ``rescale`` | predicted value
+ 13 | ``learn.target_`` int 100 | Frequence of target network update. | Hard(assign) update
+ | ``update_freq``
+ 14 | ``learn.ignore_`` bool False | Whether ignore done for target value | Enable it for some
+ | ``done`` | calculation. | fake termination env
+ 15 ``collect.n_sample`` int [8, 128] | The number of training samples of a | It varies from
+ | call of collector. | different envs
+ 16 | ``collect.unroll`` int 1 | unroll length of an iteration | In RNN, unroll_len>1
+ | ``_len``
+ == ==================== ======== ============== ======================================== =======================
+ """
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='r2d3',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ # (bool) Whether the RL algorithm is on-policy or off-policy.
+ on_policy=False,
+ # (bool) Whether use priority(priority sample, IS weight, update priority)
+ priority=True,
+ # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=True,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) Reward's future discount factor, aka. gamma.
+ discount_factor=0.997,
+ # (int) N-step reward for target q_value estimation
+ nstep=5,
+        # (int) the number of burn-in timesteps, which are used to mitigate the RNN hidden state
+        # mismatch caused by off-policy training
+ burnin_step=2,
+        # (int) the length of the training sequence used to unroll the RNN,
+        # excluding the burn-in timesteps (the stored sequence length is learn_unroll_len + burnin_step)
+ learn_unroll_len=80,
+ learn=dict(
+ update_per_collect=1,
+ batch_size=64,
+ learning_rate=0.0001,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+            # (int) Frequency of (hard) target network update (not used here, see target_update_theta).
+            # target_update_freq=100,
+            # (float) Interpolation factor for the soft (momentum/EMA) update of the target network.
+            target_update_theta=0.001,
+ # (bool) whether use value_rescale function for predicted value
+ value_rescale=True,
+ ignore_done=False,
+ ),
+ collect=dict(
+            # NOTE: it is important not to include the key n_sample here, to make sure self._traj_len=INF
+ # each_iter_n_sample=32,
+ # `env_num` is used in hidden state, should equal to that one in env config.
+ # User should specify this value in user config.
+ env_num=None,
+ ),
+ eval=dict(
+ # `env_num` is used in hidden state, should equal to that one in env config.
+ # User should specify this value in user config.
+ env_num=None,
+ ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=10000, ),
+ ),
+ )
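+    # NOTE: ``_init_learn`` below additionally reads ``learn.lambda1`` (n-step TD loss weight),
+    # ``learn.lambda2`` (supervised large-margin loss weight), ``learn.lambda3`` (L2 / weight decay),
+    # ``learn.lambda_one_step_td`` (one-step TD loss weight) and ``learn.margin_function``, none of
+    # which have defaults here, so the user config must provide them, e.g. (illustrative values only)
+    # ``learn=dict(lambda1=1.0, lambda2=1.0, lambda3=1e-5, lambda_one_step_td=1.0, margin_function=0.8, ...)``.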
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ return 'drqn', ['ding.model.template.q_learning']
+
+ def _init_learn(self) -> None:
+ r"""
+ Overview:
+            Init the learner model of R2D3Policy
+
+ Arguments:
+ .. note::
+
+ The _init_learn method takes the argument from the self._cfg.learn in the config file
+
+            - learning_rate (:obj:`float`): The learning rate of the optimizer
+            - gamma (:obj:`float`): The discount factor
+            - nstep (:obj:`int`): The number of steps for the n-step return
+            - value_rescale (:obj:`bool`): Whether to use the value-rescaled loss in the algorithm
+            - burnin_step (:obj:`int`): The number of burn-in steps
+ """
+        self.lambda1 = self._cfg.learn.lambda1  # weight of the n-step TD loss
+        self.lambda2 = self._cfg.learn.lambda2  # weight of the supervised (large-margin) loss
+        self.lambda3 = self._cfg.learn.lambda3  # L2 regularization coefficient (used as weight_decay below)
+        self.lambda_one_step_td = self._cfg.learn.lambda_one_step_td  # weight of the one-step TD loss
+        # margin used in the large-margin loss J_E; here it is implemented as a constant
+        self.margin_function = self._cfg.learn.margin_function
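+        # For reference, the supervised term weighted by lambda2 is the DQfD large-margin loss
+        # (Hester et al., 2018), applied only where is_expert == 1:
+        #   J_E(Q) = max_a [Q(s, a) + l(a_E, a)] - Q(s, a_E),
+        # with l(a_E, a) = margin_function if a != a_E else 0, so that every non-demonstrated action
+        # is pushed at least a margin below the value of the demonstrated action a_E.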
+
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ self._optimizer = Adam(
+ self._model.parameters(), lr=self._cfg.learn.learning_rate, weight_decay=self.lambda3, optim_type='adamw'
+ )
+ self._gamma = self._cfg.discount_factor
+ self._nstep = self._cfg.nstep
+ self._burnin_step = self._cfg.burnin_step
+ self._value_rescale = self._cfg.learn.value_rescale
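+        # ``value_rescale`` refers to the invertible transform used by R2D2-style losses
+        # (Pohlen et al., 2018): h(x) = sign(x) * (sqrt(|x| + 1) - 1) + eps * x for a small eps,
+        # and n-step targets are formed as h(r_{t:t+n} + gamma^n * h^{-1}(Q_target(s_{t+n}, a*))),
+        # which compresses large-magnitude returns into a range that is easier to learn.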
+
+ self._target_model = copy.deepcopy(self._model)
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='momentum',
+ update_kwargs={'theta': self._cfg.learn.target_update_theta}
+ )
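+        # Assuming the 'momentum' wrapper applies the usual EMA rule, each target update performs,
+        # per parameter: theta_target <- (1 - theta) * theta_target + theta * theta_online,
+        # with theta = target_update_theta (0.001 by default), i.e. a slowly moving target network.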
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='hidden_state',
+ state_num=self._cfg.learn.batch_size,
+ )
+ self._learn_model = model_wrap(
+ self._model,
+ wrapper_name='hidden_state',
+ state_num=self._cfg.learn.batch_size,
+ )
+ self._learn_model = model_wrap(self._learn_model, wrapper_name='argmax_sample')
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ def _data_preprocess_learn(self, data: List[Dict[str, Any]]) -> dict:
+ r"""
+ Overview:
+ Preprocess the data to fit the required data format for learning
+
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]]`): the data collected from collect function
+
+ Returns:
+            - data (:obj:`Dict[str, Any]`): the processed data, including at least \
+                ['main_obs', 'target_obs', 'burnin_nstep_obs', 'action', 'reward', 'done', 'weight']
+ """
+ # data preprocess
+ data = timestep_collate(data)
+ if self._cuda:
+ data = to_device(data, self._device)
+
+ if self._priority_IS_weight:
+ assert self._priority, "Use IS Weight correction, but Priority is not used."
+ if self._priority and self._priority_IS_weight:
+ data['weight'] = data['IS']
+ else:
+ data['weight'] = data.get('weight', None)
+
+ bs = self._burnin_step
+
+        # data['done'], data['weight'] and data['value_gamma'] are used in _forward_learn() to calculate
+        # the TD error, and should have length [self._sequence_len - self._burnin_step - self._nstep]
+ ignore_done = self._cfg.learn.ignore_done
+ if ignore_done:
+ data['done'] = [None for _ in range(self._sequence_len - bs)]
+ else:
+ data['done'] = data['done'][bs:].float()
+        # NOTE that after the preprocessing of get_nstep_return_data() in _get_train_sample,
+        # data['done'][t] is already the n-step done
+
+        # if the data doesn't include 'weight' or 'value_gamma', fill in a list of None
+        # with length [self._sequence_len - self._burnin_step - self._nstep];
+        # below are two different ways to implement this
+ if 'value_gamma' not in data:
+ data['value_gamma'] = [None for _ in range(self._sequence_len - bs)]
+ else:
+ data['value_gamma'] = data['value_gamma'][bs:]
+
+ if 'weight' not in data:
+ data['weight'] = [None for _ in range(self._sequence_len - bs)]
+ else:
+ data['weight'] = data['weight'] * torch.ones_like(data['done'])
+ # every timestep in sequence has same weight, which is the _priority_IS_weight in PER
+
+ data['action'] = data['action'][bs:-self._nstep]
+ data['reward'] = data['reward'][bs:-self._nstep]
+
+ # the burnin_nstep_obs is used to calculate the init hidden state of rnn for the calculation of the q_value,
+ # target_q_value, and target_q_action
+ data['burnin_nstep_obs'] = data['obs'][:bs + self._nstep]
+ # the main_obs is used to calculate the q_value, the [bs:-self._nstep] means using the data from
+ # [bs] timestep to [self._sequence_len-self._nstep] timestep
+ data['main_obs'] = data['obs'][bs:-self._nstep]
+ # the target_obs is used to calculate the target_q_value
+ data['target_obs'] = data['obs'][bs + self._nstep:]
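+        # A worked shape example with the defaults (learn_unroll_len=80, burnin_step=2, nstep=5,
+        # so sequence_len=82): obs holds 82 timesteps, burnin_nstep_obs = obs[:7],
+        # main_obs = obs[2:77] (75 steps) and target_obs = obs[7:] (75 steps), which is where the
+        # (75, B) shapes mentioned further below come from.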
+
+ # TODO(pu)
+ data['target_obs_one_step'] = data['obs'][bs + 1:]
+ if ignore_done:
+ data['done_one_step'] = [None for _ in range(self._sequence_len - bs)]
+ else:
+ data['done_one_step'] = data['done_one_step'][bs:].float()
+
+ return data
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Forward and backward function of learn mode.
+ Acquire the data, calculate the loss and optimize learner model.
+
+ Arguments:
+ - data (:obj:`dict`): Dict type data, including at least \
+            ['main_obs', 'target_obs', 'burnin_nstep_obs', 'action', 'reward', 'done', 'weight']
+
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Including cur_lr and total_loss
+ - cur_lr (:obj:`float`): Current learning rate
+ - total_loss (:obj:`float`): The calculated loss
+ """
+ # forward
+ data = self._data_preprocess_learn(data)
+ self._learn_model.train()
+ self._target_model.train()
+ # take out the hidden state in timestep=0
+ self._learn_model.reset(data_id=None, state=data['prev_state'][0])
+ self._target_model.reset(data_id=None, state=data['prev_state'][0])
+
+ if len(data['burnin_nstep_obs']) != 0:
+ with torch.no_grad():
+ inputs = {'obs': data['burnin_nstep_obs'], 'enable_fast_timestep': True}
+ burnin_output = self._learn_model.forward(
+ inputs,
+ saved_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep, self._burnin_step + 1]
+ )
+ burnin_output_target = self._target_model.forward(
+ inputs,
+ saved_state_timesteps=[self._burnin_step, self._burnin_step + self._nstep, self._burnin_step + 1]
+ )
+
+ self._learn_model.reset(data_id=None, state=burnin_output['saved_state'][0])
+ inputs = {'obs': data['main_obs'], 'enable_fast_timestep': True}
+ q_value = self._learn_model.forward(inputs)['logit']
+
+ # n-step
+ self._learn_model.reset(data_id=None, state=burnin_output['saved_state'][1])
+ self._target_model.reset(data_id=None, state=burnin_output_target['saved_state'][1])
+
+ next_inputs = {'obs': data['target_obs'], 'enable_fast_timestep': True}
+ with torch.no_grad():
+ target_q_value = self._target_model.forward(next_inputs)['logit']
+ # argmax_action double_dqn
+ target_q_action = self._learn_model.forward(next_inputs)['action']
+
+ # one-step
+ self._learn_model.reset(data_id=None, state=burnin_output['saved_state'][2])
+ self._target_model.reset(data_id=None, state=burnin_output_target['saved_state'][2])
+
+ next_inputs_one_step = {'obs': data['target_obs_one_step'], 'enable_fast_timestep': True}
+ with torch.no_grad():
+ target_q_value_one_step = self._target_model.forward(next_inputs_one_step)['logit']
+ # argmax_action double_dqn
+ target_q_action_one_step = self._learn_model.forward(next_inputs_one_step)['action']
+
+ action, reward, done, weight = data['action'], data['reward'], data['done'], data['weight']
+ value_gamma = data['value_gamma']
+ done_one_step = data['done_one_step']
+ # T, B, nstep -> T, nstep, B
+ reward = reward.permute(0, 2, 1).contiguous()
+ loss = []
+ loss_nstep = []
+ loss_1step = []
+ loss_sl = []
+ td_error = []
+ for t in range(self._sequence_len - self._burnin_step - self._nstep):
+            # here t indexes the timesteps of the original sample sequence; we subtract self._nstep
+            # because the last nstep timesteps of the sequence have no target obs
+ td_data = dqfd_nstep_td_data(
+ q_value[t],
+ target_q_value[t],
+ action[t],
+ target_q_action[t],
+ reward[t],
+ done[t],
+ done_one_step[t],
+ weight[t],
+ target_q_value_one_step[t],
+ target_q_action_one_step[t],
+ data['is_expert'][t], # is_expert flag(expert 1, agent 0)
+ )
+
+ if self._value_rescale:
+ l, e, loss_statistics = dqfd_nstep_td_error_with_rescale(
+ td_data,
+ self._gamma,
+ self.lambda1,
+ self.lambda2,
+ self.margin_function,
+ self.lambda_one_step_td,
+ self._nstep,
+ False,
+ value_gamma=value_gamma[t],
+ )
+ loss.append(l)
+ # td_error.append(e.abs()) # first sum then abs
+ td_error.append(e) # first abs then sum
+ # loss statistics for debugging
+ loss_nstep.append(loss_statistics[0])
+ loss_1step.append(loss_statistics[1])
+ loss_sl.append(loss_statistics[2])
+
+ else:
+ l, e, loss_statistics = dqfd_nstep_td_error(
+ td_data,
+ self._gamma,
+ self.lambda1,
+ self.lambda2,
+ self.margin_function,
+ self.lambda_one_step_td,
+ self._nstep,
+ False,
+ value_gamma=value_gamma[t],
+ )
+ loss.append(l)
+ # td_error.append(e.abs()) # first sum then abs
+ td_error.append(e) # first abs then sum
+ # loss statistics for debugging
+ loss_nstep.append(loss_statistics[0])
+ loss_1step.append(loss_statistics[1])
+ loss_sl.append(loss_statistics[2])
+
+ loss = sum(loss) / (len(loss) + 1e-8)
+ # loss statistics for debugging
+ loss_nstep = sum(loss_nstep) / (len(loss_nstep) + 1e-8)
+ loss_1step = sum(loss_1step) / (len(loss_1step) + 1e-8)
+ loss_sl = sum(loss_sl) / (len(loss_sl) + 1e-8)
+
+ # using the mixture of max and mean absolute n-step TD-errors as the priority of the sequence
+ td_error_per_sample = 0.9 * torch.max(
+ torch.stack(td_error), dim=0
+ )[0] + (1 - 0.9) * (torch.sum(torch.stack(td_error), dim=0) / (len(td_error) + 1e-8))
+        # td_error is a list of T per-timestep tensors, each of shape (B,); stacked it is (T, B),
+        # e.g. (75, 64). torch.sum(torch.stack(td_error), dim=0) could also be written as sum(td_error).
+
+ # update
+ self._optimizer.zero_grad()
+ loss.backward()
+ self._optimizer.step()
+ # after update
+ self._target_model.update(self._learn_model.state_dict())
+
+        # information for debugging
+ batch_range = torch.arange(action[0].shape[0])
+ q_s_a_t0 = q_value[0][batch_range, action[0]]
+ target_q_s_a_t0 = target_q_value[0][batch_range, target_q_action[0]]
+
+ return {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': loss.item(),
+ # loss statistics for debugging
+ 'nstep_loss': loss_nstep.item(),
+ '1step_loss': loss_1step.item(),
+ 'sl_loss': loss_sl.item(),
+ 'priority': td_error_per_sample.abs().tolist(),
+ # the first timestep in the sequence, may not be the start of episode
+ 'q_s_taken-a_t0': q_s_a_t0.mean().item(),
+ 'target_q_s_max-a_t0': target_q_s_a_t0.mean().item(),
+ 'q_s_a-mean_t0': q_value[0].mean().item(),
+ }
+
+ def _reset_learn(self, data_id: Optional[List[int]] = None) -> None:
+ self._learn_model.reset(data_id=data_id)
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'target_model': self._target_model.state_dict(),
+ 'optimizer': self._optimizer.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._target_model.load_state_dict(state_dict['target_model'])
+ self._optimizer.load_state_dict(state_dict['optimizer'])
+
+ def _init_collect(self) -> None:
+ r"""
+ Overview:
+ Collect mode init method. Called by ``self.__init__``.
+ Init traj and unroll length, collect model.
+ """
+ assert 'unroll_len' not in self._cfg.collect, "r2d3 use default unroll_len"
+ self._nstep = self._cfg.nstep
+ self._burnin_step = self._cfg.burnin_step
+ self._gamma = self._cfg.discount_factor
+ self._sequence_len = self._cfg.learn_unroll_len + self._cfg.burnin_step
+ self._unroll_len = self._sequence_len # for compatibility
+
+ self._collect_model = model_wrap(
+ self._model, wrapper_name='hidden_state', state_num=self._cfg.collect.env_num, save_prev_state=True
+ )
+ self._collect_model = model_wrap(self._collect_model, wrapper_name='eps_greedy_sample')
+ self._collect_model.reset()
+
+ def _forward_collect(self, data: dict, eps: float) -> dict:
+ r"""
+ Overview:
+ Collect output according to eps_greedy plugin
+
+ Arguments:
+ - data (:obj:`dict`): Dict type data, including at least ['obs'].
+
+ Returns:
+ - data (:obj:`dict`): The collected data
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ data = {'obs': data}
+ self._collect_model.eval()
+ with torch.no_grad():
+            # in the collect phase, inference=True means that we only pass one timestep of data at a time,
+            # so that we can get the hidden state of the RNN at each timestep.
+ output = self._collect_model.forward(data, data_id=data_id, eps=eps, inference=True)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _reset_collect(self, data_id: Optional[List[int]] = None) -> None:
+ self._collect_model.reset(data_id=data_id)
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
+ r"""
+ Overview:
+ Generate dict type transition data from inputs.
+ Arguments:
+ - obs (:obj:`Any`): Env observation
+ - model_output (:obj:`dict`): Output of collect model, including at least ['action', 'prev_state']
+ - timestep (:obj:`namedtuple`): Output after env step, including at least ['reward', 'done'] \
+ (here 'obs' indicates obs after env step).
+ Returns:
+ - transition (:obj:`dict`): Dict type transition data.
+ """
+ transition = {
+ 'obs': obs,
+ 'action': model_output['action'],
+ 'prev_state': model_output['prev_state'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
+ r"""
+ Overview:
+            Compute the n-step return data from the trajectory, then cut it into training samples.
+
+ Arguments:
+ - data (:obj:`list`): The trajectory's cache
+
+ Returns:
+ - samples (:obj:`dict`): The training samples generated
+ """
+ from copy import deepcopy
+ data_one_step = deepcopy(get_nstep_return_data(data, 1, gamma=self._gamma))
+ data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)
+ for i in range(len(data)):
+            # here we record the one-step done; we don't need to record the one-step reward,
+            # because the n-step reward in data already includes the one-step reward
+ data[i]['done_one_step'] = data_one_step[i]['done']
+ return get_train_sample(data, self._sequence_len)
+
+ def _init_eval(self) -> None:
+ r"""
+ Overview:
+ Evaluate mode init method. Called by ``self.__init__``.
+ Init eval model with argmax strategy.
+ """
+ self._eval_model = model_wrap(self._model, wrapper_name='hidden_state', state_num=self._cfg.eval.env_num)
+ self._eval_model = model_wrap(self._eval_model, wrapper_name='argmax_sample')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: dict) -> dict:
+ r"""
+ Overview:
+            Forward function of eval mode, similar to ``self._forward_collect``.
+
+ Arguments:
+ - data (:obj:`dict`): Dict type data, including at least ['obs'].
+
+ Returns:
+ - output (:obj:`dict`): Dict type data, including at least inferred action according to input obs.
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ data = {'obs': data}
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data, data_id=data_id, inference=True)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
+ self._eval_model.reset(data_id=data_id)
+
+ def _monitor_vars_learn(self) -> List[str]:
+ return super()._monitor_vars_learn() + [
+ 'total_loss', 'nstep_loss', '1step_loss', 'sl_loss', 'priority', 'q_s_taken-a_t0', 'target_q_s_max-a_t0',
+ 'q_s_a-mean_t0'
+ ]
diff --git a/DI-engine/ding/policy/rainbow.py b/DI-engine/ding/policy/rainbow.py
new file mode 100644
index 0000000000000000000000000000000000000000..1efd00e90b3dba18d720caf9e931487ffaa66729
--- /dev/null
+++ b/DI-engine/ding/policy/rainbow.py
@@ -0,0 +1,302 @@
+from typing import List, Dict, Any, Tuple, Union
+import torch
+import copy
+
+from ding.torch_utils import Adam, to_device
+from ding.rl_utils import dist_nstep_td_data, dist_nstep_td_error, get_train_sample, get_nstep_return_data
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_collate, default_decollate
+from .dqn import DQNPolicy
+from .common_utils import default_preprocess_learn
+
+
+@POLICY_REGISTRY.register('rainbow')
+class RainbowDQNPolicy(DQNPolicy):
+ r"""
+ Overview:
+        Rainbow DQN contains several improvements upon DQN, including:
+ - target network
+ - dueling architecture
+ - prioritized experience replay
+ - n_step return
+            - noisy net
+            - distributional value net
+
+        Therefore, the RainbowDQNPolicy class inherits from the DQNPolicy class.
+
+ Config:
+ == ==================== ======== ============== ======================================== =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============== ======================================== =======================
+ 1 ``type`` str rainbow | RL policy register name, refer to | this arg is optional,
+ | registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool False | Whether to use cuda for network | this arg can be diff-
+ | erent from modes
+ 3 ``on_policy`` bool False | Whether the RL algorithm is on-policy
+ | or off-policy
+ 4 ``priority`` bool True | Whether use priority(PER) | priority sample,
+ | update priority
+ 5 ``model.v_min`` float -10 | Value of the smallest atom
+ | in the support set.
+ 6 ``model.v_max`` float 10 | Value of the largest atom
+ | in the support set.
+ 7 ``model.n_atom`` int 51 | Number of atoms in the support set
+ | of the value distribution.
+ 8 | ``other.eps`` float 0.05 | Start value for epsilon decay. It's
+ | ``.start`` | small because rainbow use noisy net.
+ 9 | ``other.eps`` float 0.05 | End value for epsilon decay.
+ | ``.end``
+ 10 | ``discount_`` float 0.97, | Reward's future discount factor, aka. | may be 1 when sparse
+ | ``factor`` [0.95, 0.999] | gamma | reward env
+ 11 ``nstep`` int 3, | N-step reward discount sum for target
+ [3, 5] | q_value estimation
+ 12 | ``learn.update`` int 3 | How many updates(iterations) to train | this args can be vary
+ | ``per_collect`` | after collector's one collection. Only | from envs. Bigger val
+ | valid in serial training | means more off-policy
+ == ==================== ======== ============== ======================================== =======================
+
+ """
+
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='rainbow',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ # (bool) Whether the RL algorithm is on-policy or off-policy.
+ on_policy=False,
+ # (bool) Whether use priority(priority sample, IS weight, update priority)
+ priority=True,
+ # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=True,
+ # (int) Number of training samples(randomly collected) in replay buffer when training starts.
+ # random_collect_size=2000,
+ model=dict(
+ # (float) Value of the smallest atom in the support set.
+ # Default to -10.0.
+ v_min=-10,
+            # (float) Value of the largest atom in the support set.
+ # Default to 10.0.
+ v_max=10,
+ # (int) Number of atoms in the support set of the
+ # value distribution. Default to 51.
+ n_atom=51,
+ ),
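+        # As in C51, the value distribution is supported on n_atom fixed points
+        #   z_i = v_min + i * (v_max - v_min) / (n_atom - 1),  i = 0, ..., n_atom - 1
+        # (e.g. torch.linspace(-10, 10, 51)), and Q(s, a) = sum_i z_i * p_i(s, a).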
+ # (float) Reward's future discount factor, aka. gamma.
+ discount_factor=0.99,
+ # (int) N-step reward for target q_value estimation
+ nstep=3,
+ learn=dict(
+
+ # How many updates(iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=1,
+ batch_size=32,
+ learning_rate=0.001,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+            # (int) Frequency of target network update.
+ target_update_freq=100,
+ # (bool) Whether ignore done(usually for max step termination env)
+ ignore_done=False,
+ ),
+ # collect_mode config
+ collect=dict(
+            # (int) Only one of [n_sample, n_episode] should be set
+ # n_sample=32,
+ # (int) Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ ),
+ eval=dict(),
+ # other config
+ other=dict(
+ # Epsilon greedy with decay.
+ eps=dict(
+ # (str) Decay type. Support ['exp', 'linear'].
+ type='exp',
+                # (float) Start value for epsilon decay, in [0, 1]. It equals `end` because rainbow uses a noisy net.
+ start=0.05,
+ # (float) End value for epsilon decay, in [0, 1].
+ end=0.05,
+ # (int) Env steps of epsilon decay.
+ decay=100000,
+ ),
+ replay_buffer=dict(
+ # (int) Max size of replay buffer.
+ replay_buffer_size=100000,
+ # (float) Prioritization exponent.
+ alpha=0.6,
+ # (float) Importance sample soft coefficient.
+ # 0 means no correction, while 1 means full correction
+ beta=0.4,
+ # (int) Anneal step for beta: 0 means no annealing. Defaults to 0
+ anneal_step=100000,
+ )
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ return 'rainbowdqn', ['ding.model.template.q_learning']
+
+ def _init_learn(self) -> None:
+ r"""
+ Overview:
+ Init the learner model of RainbowDQNPolicy
+
+ Arguments:
+            - learning_rate (:obj:`float`): the learning rate of the optimizer
+            - gamma (:obj:`float`): the discount factor
+            - nstep (:obj:`int`): the number of steps for the n-step return
+            - v_min (:obj:`float`): minimum value of the value distribution support
+            - v_max (:obj:`float`): maximum value of the value distribution support
+            - n_atom (:obj:`int`): the number of atoms in the support
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
+ self._gamma = self._cfg.discount_factor
+ self._nstep = self._cfg.nstep
+ self._v_max = self._cfg.model.v_max
+ self._v_min = self._cfg.model.v_min
+ self._n_atom = self._cfg.model.n_atom
+
+ self._target_model = copy.deepcopy(self._model)
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='assign',
+ update_kwargs={'freq': self._cfg.learn.target_update_freq}
+ )
+ self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ """
+ Overview:
+            Forward and backward function of learn mode: acquire the data, calculate the loss and \
+            optimize the learner model.
+
+ Arguments:
+ - data (:obj:`dict`): Dict type data, including at least ['obs', 'next_obs', 'reward', 'action']
+
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Including cur_lr and total_loss
+ - cur_lr (:obj:`float`): current learning rate
+ - total_loss (:obj:`float`): the calculated loss
+ """
+ data = default_preprocess_learn(
+ data,
+ use_priority=self._priority,
+ use_priority_IS_weight=self._cfg.priority_IS_weight,
+ ignore_done=self._cfg.learn.ignore_done,
+ use_nstep=True
+ )
+ if self._cuda:
+ data = to_device(data, self._device)
+ # ====================
+ # Rainbow forward
+ # ====================
+ self._learn_model.train()
+ self._target_model.train()
+ # reset noise of noisenet for both main model and target model
+ self._reset_noise(self._learn_model)
+ self._reset_noise(self._target_model)
+ q_dist = self._learn_model.forward(data['obs'])['distribution']
+ with torch.no_grad():
+ target_q_dist = self._target_model.forward(data['next_obs'])['distribution']
+ self._reset_noise(self._learn_model)
+ target_q_action = self._learn_model.forward(data['next_obs'])['action']
+ value_gamma = data.get('value_gamma', None)
+ data = dist_nstep_td_data(
+ q_dist, target_q_dist, data['action'], target_q_action, data['reward'], data['done'], data['weight']
+ )
+ loss, td_error_per_sample = dist_nstep_td_error(
+ data, self._gamma, self._v_min, self._v_max, self._n_atom, nstep=self._nstep, value_gamma=value_gamma
+ )
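+        # C51-style distributional n-step TD: each support atom is shifted to
+        #   Tz_j = clamp(r_{t:t+n} + gamma^n * z_j, v_min, v_max),
+        # its probability mass is projected onto the neighbouring atoms of the fixed support, and the
+        # loss is the cross-entropy between this projected target distribution and the predicted
+        # distribution of the taken action (weighted per sample when PER/IS weights are enabled).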
+ # ====================
+ # Rainbow update
+ # ====================
+ self._optimizer.zero_grad()
+ loss.backward()
+ self._optimizer.step()
+ # =============
+ # after update
+ # =============
+ self._target_model.update(self._learn_model.state_dict())
+ return {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': loss.item(),
+ 'priority': td_error_per_sample.abs().tolist(),
+ }
+
+ def _init_collect(self) -> None:
+ r"""
+ Overview:
+            Collect mode init method. Called by ``self.__init__``.
+ Init traj and unroll length, collect model.
+
+ .. note::
+            Rainbow DQN enables eps_greedy_sample, but might not need to use it, \
+            as the noisy net already contains noise that can help exploration.
+ """
+ self._unroll_len = self._cfg.collect.unroll_len
+ self._nstep = self._cfg.nstep
+ self._gamma = self._cfg.discount_factor
+ self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_sample')
+ self._collect_model.reset()
+
+ def _forward_collect(self, data: dict, eps: float) -> dict:
+ r"""
+ Overview:
+            Reset the noise of the noisy net and collect output according to the eps_greedy plugin
+
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ - eps (:obj:`float`): epsilon value for exploration, which is decayed by collected env step.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs.
+ ReturnsKeys
+ - necessary: ``action``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._collect_model.eval()
+ self._reset_noise(self._collect_model)
+ with torch.no_grad():
+ output = self._collect_model.forward(data, eps=eps)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _get_train_sample(self, traj: list) -> Union[None, List[Any]]:
+ r"""
+ Overview:
+            Compute the n-step return data from the trajectory, then cut it into training samples.
+
+ Arguments:
+            - traj (:obj:`list`): The trajectory's buffer list
+
+ Returns:
+ - samples (:obj:`dict`): The training samples generated
+ """
+ data = get_nstep_return_data(traj, self._nstep, gamma=self._gamma)
+ return get_train_sample(data, self._unroll_len)
+
+ def _reset_noise(self, model: torch.nn.Module):
+ r"""
+ Overview:
+            Reset the noise of the model
+
+ Arguments:
+ - model (:obj:`torch.nn.Module`): the model to reset, must contain reset_noise method
+ """
+ for m in model.modules():
+ if hasattr(m, 'reset_noise'):
+ m.reset_noise()
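+        # In a factorised NoisyLinear layer (Fortunato et al., 2017), reset_noise typically resamples
+        # eps_in ~ f(N(0, 1)) and eps_out ~ f(N(0, 1)) with f(x) = sign(x) * sqrt(|x|), giving effective
+        # weights W = mu_W + sigma_W * outer(eps_out, eps_in); resampling this noise before forward
+        # passes is what provides exploration in Rainbow.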
diff --git a/DI-engine/ding/policy/sac.py b/DI-engine/ding/policy/sac.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bb870569db0ea1b84280a6ab4b70f4cf8363c72
--- /dev/null
+++ b/DI-engine/ding/policy/sac.py
@@ -0,0 +1,1491 @@
+from typing import List, Dict, Any, Tuple, Union
+from collections import namedtuple
+import copy
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.distributions import Normal, Independent
+
+from ding.torch_utils import Adam, to_device
+from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample, q_v_1step_td_error, q_v_1step_td_data
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_collate, default_decollate
+from .base_policy import Policy
+from .common_utils import default_preprocess_learn
+
+
+@POLICY_REGISTRY.register('discrete_sac')
+class DiscreteSACPolicy(Policy):
+ """
+ Overview:
+ Policy class of discrete SAC algorithm. Paper link: https://arxiv.org/abs/1910.07207.
+ """
+
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='discrete_sac',
+ # (bool) Whether to use cuda for network and loss computation.
+ cuda=False,
+ # (bool) Whether to belong to on-policy or off-policy algorithm, DiscreteSAC is an off-policy algorithm.
+ on_policy=False,
+ # (bool) Whether to use priority sampling in buffer. Default to False in DiscreteSAC.
+ priority=False,
+ # (bool) Whether use Importance Sampling weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ # (int) Number of training samples (randomly collected) in replay buffer when training starts.
+ random_collect_size=10000,
+ # (bool) Whether to need policy-specific data in process transition.
+ transition_with_policy_data=True,
+ # (bool) Whether to enable multi-agent training setting.
+ multi_agent=False,
+ model=dict(
+ # (bool) Whether to use double-soft-q-net for target q computation.
+ # For more details, please refer to TD3 about Clipped Double-Q Learning trick.
+ twin_critic=True,
+ ),
+ # learn_mode config
+ learn=dict(
+ # (int) How many updates (iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ update_per_collect=1,
+ # (int) Minibatch size for one gradient descent.
+ batch_size=256,
+ # (float) Learning rate for soft q network.
+ learning_rate_q=3e-4,
+ # (float) Learning rate for policy network.
+ learning_rate_policy=3e-4,
+ # (float) Learning rate for auto temperature parameter `\alpha`.
+ learning_rate_alpha=3e-4,
+ # (float) Used for soft update of the target network,
+ # aka. Interpolation factor in EMA update for target network.
+ target_theta=0.005,
+ # (float) Discount factor for the discounted sum of rewards, aka. gamma.
+ discount_factor=0.99,
+ # (float) Entropy regularization coefficient in SAC.
+ # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details.
+ # If auto_alpha is set to `True`, alpha is initialization for auto `\alpha`.
+ alpha=0.2,
+ # (bool) Whether to use auto temperature parameter `\alpha` .
+ # Temperature parameter `\alpha` determines the relative importance of the entropy term against the reward.
+ # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details.
+ # Note that: Using auto alpha needs to set the above `learning_rate_alpha`.
+ auto_alpha=True,
+ # (bool) Whether to use auto `\alpha` in log space.
+ log_space=True,
+ # (float) Target policy entropy value for auto temperature (alpha) adjustment.
+ target_entropy=None,
+            # (bool) Whether to ignore done (usually for max-step-termination envs, e.g. pendulum).
+            # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
+            # These limit HalfCheetah, and several other MuJoCo envs, to a max length of 1000.
+            # However, interaction with HalfCheetah should not be treated as terminated in that case,
+            # so we replace done==True with done==False when the episode step is greater than the
+            # max episode step, to keep the TD-error computation
+            # (``gamma * (1 - done) * next_v + reward``) accurate.
+ ignore_done=False,
+ # (float) Weight uniform initialization max range in the last output layer
+ init_w=3e-3,
+ ),
+ # collect_mode config
+ collect=dict(
+ # (int) How many training samples collected in one collection procedure.
+            # Only one of [n_sample, n_episode] should be set.
+ n_sample=1,
+ # (int) Split episodes or trajectories into pieces with length `unroll_len`.
+ unroll_len=1,
+ # (bool) Whether to collect logit in `process_transition`.
+ # In some algorithm like guided cost learning, we need to use logit to train the reward model.
+ collector_logit=False,
+ ),
+        eval=dict(),  # for compatibility
+ other=dict(
+ replay_buffer=dict(
+ # (int) Maximum size of replay buffer. Usually, larger buffer size is good
+ # for SAC but cost more storage.
+ replay_buffer_size=1000000,
+ ),
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+ Return this algorithm default neural network model setting for demonstration. ``__init__`` method will \
+ automatically call this method to get the default model setting and create model.
+ Returns:
+ - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
+ """
+ if self._cfg.multi_agent:
+ return 'discrete_maqac', ['ding.model.template.maqac']
+ else:
+ return 'discrete_qac', ['ding.model.template.qac']
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Initialize the learn mode of policy, including related attributes and modules. For DiscreteSAC, it mainly \
+ contains three optimizers, algorithm-specific arguments such as gamma and twin_critic, main and target \
+ model. Especially, the ``auto_alpha`` mechanism for balancing max entropy target is also initialized here.
+ This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+ .. note::
+ For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+ and ``_load_state_dict_learn`` methods.
+
+ .. note::
+ For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+ .. note::
+            If you want to set some special member variables in ``_init_learn`` method, you'd better name them \
+ with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ self._twin_critic = self._cfg.model.twin_critic
+
+ self._optimizer_q = Adam(
+ self._model.critic.parameters(),
+ lr=self._cfg.learn.learning_rate_q,
+ )
+ self._optimizer_policy = Adam(
+ self._model.actor.parameters(),
+ lr=self._cfg.learn.learning_rate_policy,
+ )
+
+ # Algorithm-Specific Config
+ self._gamma = self._cfg.learn.discount_factor
+ if self._cfg.learn.auto_alpha:
+ if self._cfg.learn.target_entropy is None:
+ assert 'action_shape' in self._cfg.model, "DiscreteSAC need network model with action_shape variable"
+ self._target_entropy = -np.prod(self._cfg.model.action_shape)
+ else:
+ self._target_entropy = self._cfg.learn.target_entropy
+ if self._cfg.learn.log_space:
+ self._log_alpha = torch.log(torch.FloatTensor([self._cfg.learn.alpha]))
+ self._log_alpha = self._log_alpha.to(self._device).requires_grad_()
+ self._alpha_optim = torch.optim.Adam([self._log_alpha], lr=self._cfg.learn.learning_rate_alpha)
+ assert self._log_alpha.shape == torch.Size([1]) and self._log_alpha.requires_grad
+ self._alpha = self._log_alpha.detach().exp()
+ self._auto_alpha = True
+ self._log_space = True
+ else:
+ self._alpha = torch.FloatTensor([self._cfg.learn.alpha]).to(self._device).requires_grad_()
+ self._alpha_optim = torch.optim.Adam([self._alpha], lr=self._cfg.learn.learning_rate_alpha)
+ self._auto_alpha = True
+ self._log_space = False
+ else:
+ self._alpha = torch.tensor(
+ [self._cfg.learn.alpha], requires_grad=False, device=self._device, dtype=torch.float32
+ )
+ self._auto_alpha = False
+
+ # Main and target models
+ self._target_model = copy.deepcopy(self._model)
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='momentum',
+ update_kwargs={'theta': self._cfg.learn.target_theta}
+ )
+ self._learn_model = model_wrap(self._model, wrapper_name='base')
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
+ """
+ Overview:
+ Policy forward function of learn mode (training policy and updating parameters). Forward means \
+ that the policy inputs some training batch data from the replay buffer and then returns the output \
+ result, including various training information such as loss, action, priority.
+ Arguments:
+ - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
+ training samples. For each element in list, the key of the dict is the name of data items and the \
+                value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
+ combinations. In the ``_forward_learn`` method, data often need to first be stacked in the batch \
+ dimension by some utility functions such as ``default_preprocess_learn``. \
+ For SAC, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \
+ ``logit``, ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys like ``weight``.
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): The information dict that indicated training result, which will be \
+ recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
+ detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
+            For an unsupported data type, the main reason is that the corresponding model does not support it. \
+            You can implement your own model rather than use the default model. For more information, please raise an \
+ issue in GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for DiscreteSACPolicy: \
+ ``ding.policy.tests.test_discrete_sac``.
+ """
+ loss_dict = {}
+ data = default_preprocess_learn(
+ data,
+ use_priority=self._priority,
+ use_priority_IS_weight=self._cfg.priority_IS_weight,
+ ignore_done=self._cfg.learn.ignore_done,
+ use_nstep=False
+ )
+ if self._cuda:
+ data = to_device(data, self._device)
+
+ self._learn_model.train()
+ self._target_model.train()
+ obs = data['obs']
+ next_obs = data['next_obs']
+ reward = data['reward']
+ done = data['done']
+ logit = data['logit']
+ action = data['action']
+
+ # 1. predict q value
+ q_value = self._learn_model.forward(obs, mode='compute_critic')['q_value']
+ dist = torch.distributions.categorical.Categorical(logits=logit)
+ dist_entropy = dist.entropy()
+ entropy = dist_entropy.mean()
+
+ # 2. predict target value
+
+        # target q value: first compute the next-action distribution under the current policy, then the expected value
+ with torch.no_grad():
+ policy_output_next = self._learn_model.forward(next_obs, mode='compute_actor')
+ if self._cfg.multi_agent:
+ policy_output_next['logit'][policy_output_next['action_mask'] == 0.0] = -1e8
+ prob = F.softmax(policy_output_next['logit'], dim=-1)
+ log_prob = torch.log(prob + 1e-8)
+ target_q_value = self._target_model.forward(next_obs, mode='compute_critic')['q_value']
+ # the value of a policy according to the maximum entropy objective
+ if self._twin_critic:
+ # find min one as target q value
+ target_value = (
+ prob * (torch.min(target_q_value[0], target_q_value[1]) - self._alpha * log_prob.squeeze(-1))
+ ).sum(dim=-1)
+ else:
+ target_value = (prob * (target_q_value - self._alpha * log_prob.squeeze(-1))).sum(dim=-1)
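+        # The block above computes the soft state value under the current policy,
+        #   V(s') = E_{a ~ pi}[ Q_target(s', a) - alpha * log pi(a|s') ],
+        # with Q_target the min over the twin critics when twin_critic=True, taken exactly over the
+        # categorical probabilities; the critic target below is then the usual one-step target
+        # ``reward + gamma * (1 - done) * V(s')``.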
+
+ # 3. compute q loss
+ if self._twin_critic:
+ q_data0 = q_v_1step_td_data(q_value[0], target_value, action, reward, done, data['weight'])
+ loss_dict['critic_loss'], td_error_per_sample0 = q_v_1step_td_error(q_data0, self._gamma)
+ q_data1 = q_v_1step_td_data(q_value[1], target_value, action, reward, done, data['weight'])
+ loss_dict['twin_critic_loss'], td_error_per_sample1 = q_v_1step_td_error(q_data1, self._gamma)
+ td_error_per_sample = (td_error_per_sample0 + td_error_per_sample1) / 2
+ else:
+ q_data = q_v_1step_td_data(q_value, target_value, action, reward, done, data['weight'])
+ loss_dict['critic_loss'], td_error_per_sample = q_v_1step_td_error(q_data, self._gamma)
+
+ # 4. update q network
+ self._optimizer_q.zero_grad()
+ loss_dict['critic_loss'].backward()
+ if self._twin_critic:
+ loss_dict['twin_critic_loss'].backward()
+ self._optimizer_q.step()
+
+ # 5. evaluate to get action distribution
+ policy_output = self._learn_model.forward(obs, mode='compute_actor')
+ # 6. apply discrete action mask in multi_agent setting
+ if self._cfg.multi_agent:
+ policy_output['logit'][policy_output['action_mask'] == 0.0] = -1e8
+ logit = policy_output['logit']
+ prob = F.softmax(logit, dim=-1)
+ log_prob = F.log_softmax(logit, dim=-1)
+
+ with torch.no_grad():
+ new_q_value = self._learn_model.forward(obs, mode='compute_critic')['q_value']
+ if self._twin_critic:
+ new_q_value = torch.min(new_q_value[0], new_q_value[1])
+ # 7. compute policy loss
+        # sum the policy loss over actions and average it over the batch
+ policy_loss = (prob * (self._alpha * log_prob - new_q_value)).sum(dim=-1).mean()
+
+ loss_dict['policy_loss'] = policy_loss
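+        # This is the discrete-action form of the SAC policy objective,
+        #   J_pi = E_s[ sum_a pi(a|s) * (alpha * log pi(a|s) - Q(s, a)) ],
+        # where the expectation over actions is computed exactly from the categorical probabilities
+        # instead of via reparameterised sampling as in continuous SAC.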
+
+ # 8. update policy network
+ self._optimizer_policy.zero_grad()
+ loss_dict['policy_loss'].backward()
+ self._optimizer_policy.step()
+
+ # 9. compute alpha loss
+ if self._auto_alpha:
+ if self._log_space:
+ log_prob = log_prob + self._target_entropy
+ loss_dict['alpha_loss'] = (-prob.detach() * (self._log_alpha * log_prob.detach())).sum(dim=-1).mean()
+
+ self._alpha_optim.zero_grad()
+ loss_dict['alpha_loss'].backward()
+ self._alpha_optim.step()
+ self._alpha = self._log_alpha.detach().exp()
+ else:
+ log_prob = log_prob + self._target_entropy
+ loss_dict['alpha_loss'] = (-prob.detach() * (self._alpha * log_prob.detach())).sum(dim=-1).mean()
+
+ self._alpha_optim.zero_grad()
+ loss_dict['alpha_loss'].backward()
+ self._alpha_optim.step()
+ self._alpha.data = torch.where(self._alpha > 0, self._alpha,
+ torch.zeros_like(self._alpha)).requires_grad_()
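+        # Both branches implement SAC's automatic entropy adjustment,
+        #   J(alpha) = E_{a ~ pi}[ -alpha * (log pi(a|s) + target_entropy) ],
+        # optimised w.r.t. log(alpha) when log_space=True, so alpha grows when the policy entropy
+        # falls below the target and shrinks otherwise.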
+ loss_dict['total_loss'] = sum(loss_dict.values())
+
+ # target update
+ self._target_model.update(self._learn_model.state_dict())
+ return {
+ 'total_loss': loss_dict['total_loss'].item(),
+ 'policy_loss': loss_dict['policy_loss'].item(),
+ 'critic_loss': loss_dict['critic_loss'].item(),
+ 'cur_lr_q': self._optimizer_q.defaults['lr'],
+ 'cur_lr_p': self._optimizer_policy.defaults['lr'],
+ 'priority': td_error_per_sample.abs().tolist(),
+ 'td_error': td_error_per_sample.detach().mean().item(),
+ 'alpha': self._alpha.item(),
+ 'q_value_1': target_q_value[0].detach().mean().item(),
+ 'q_value_2': target_q_value[1].detach().mean().item(),
+ 'target_value': target_value.detach().mean().item(),
+ 'entropy': entropy.item(),
+ }
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ """
+ Overview:
+ Return the state_dict of learn mode, usually including model, target_model and optimizers.
+ Returns:
+ - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring.
+ """
+ ret = {
+ 'model': self._learn_model.state_dict(),
+ 'target_model': self._target_model.state_dict(),
+ 'optimizer_q': self._optimizer_q.state_dict(),
+ 'optimizer_policy': self._optimizer_policy.state_dict(),
+ }
+ if self._auto_alpha:
+ ret.update({'optimizer_alpha': self._alpha_optim.state_dict()})
+ return ret
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ """
+ Overview:
+ Load the state_dict variable into policy learn mode.
+ Arguments:
+ - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before.
+
+ .. tip::
+ If you want to only load some parts of model, you can simply set the ``strict`` argument in \
+ load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
+ complicated operation.
+ """
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._target_model.load_state_dict(state_dict['target_model'])
+ self._optimizer_q.load_state_dict(state_dict['optimizer_q'])
+ self._optimizer_policy.load_state_dict(state_dict['optimizer_policy'])
+ if self._auto_alpha:
+ self._alpha_optim.load_state_dict(state_dict['optimizer_alpha'])
+
+ def _init_collect(self) -> None:
+ """
+ Overview:
+ Initialize the collect mode of policy, including related attributes and modules. For SAC, it contains the \
+ collect_model to balance the exploration and exploitation with the epsilon and multinomial sample \
+ mechanism, and other algorithm-specific arguments such as unroll_len. \
+ This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.
+
+ .. note::
+            If you want to set some special member variables in ``_init_collect`` method, you'd better name them \
+ with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``.
+ """
+ self._unroll_len = self._cfg.collect.unroll_len
+ # Empirically, we found that eps_greedy_multinomial_sample works better than multinomial_sample
+ # and eps_greedy_sample, and we don't divide logit by alpha,
+ # for the details please refer to ding/model/wrapper/model_wrappers
+ self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_multinomial_sample')
+ self._collect_model.reset()
+
+ def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
+ """
+ Overview:
+ Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \
+ that the policy gets some necessary data (mainly observation) from the envs and then returns the output \
+ data, such as the action to interact with the envs. Besides, this policy also needs ``eps`` argument for \
+ exploration, i.e., classic epsilon-greedy exploration strategy.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+ - eps (:obj:`float`): The epsilon value for exploration.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
+ other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \
+ dict is the same as the input data, i.e. environment id.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations and current policy supports all of them. \
+            For an unsupported data type, the main reason is that the corresponding model does not support it. \
+            You can implement your own model rather than use the default model. For more information, please raise an \
+ issue in GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for DiscreteSACPolicy: \
+ ``ding.policy.tests.test_discrete_sac``.
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._collect_model.eval()
+ with torch.no_grad():
+ output = self._collect_model.forward(data, mode='compute_actor', eps=eps)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor],
+ timestep: namedtuple) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ Process and pack one timestep transition data into a dict, which can be directly used for training and \
+ saved in replay buffer. For discrete SAC, it contains obs, next_obs, logit, action, reward, done.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari.
+ - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \
+ as input. For discrete SAC, it contains the action and the logit of the action.
+ - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \
+ except all the elements have been transformed into tensor data. Usually, it contains the next obs, \
+ reward, done, info, etc.
+ Returns:
+ - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep.
+ """
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'action': policy_output['action'],
+ 'logit': policy_output['logit'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+ Overview:
+ For a given trajectory (transitions, a list of transitions) data, process it into a list of samples that \
+ can be used directly for training. In discrete SAC, a train sample is a processed transition (unroll_len=1).
+ Arguments:
+ - transitions (:obj:`List[Dict[str, Any]]`): The trajectory data (a list of transitions), each element is \
+ in the same format as the return value of the ``self._process_transition`` method.
+ Returns:
+ - samples (:obj:`List[Dict[str, Any]]`): The processed train samples; each element is in a similar format \
+ to the input transitions, but may contain more data for training.
+ """
+ return get_train_sample(transitions, self._unroll_len)
+
+ def _init_eval(self) -> None:
+ """
+ Overview:
+ Initialize the eval mode of policy, including related attributes and modules. For DiscreteSAC, it contains \
+ the eval model, which greedily selects the action with the argmax q_value mechanism.
+ This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
+
+ .. note::
+ If you want to set some special member variables in the ``_init_eval`` method, you'd better name them \
+ with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
+ """
+ self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
+ """
+ Overview:
+ Policy forward function of eval mode (evaluating policy performance by interacting with envs). Forward \
+ means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
+ action to interact with the envs.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
+ key of the dict is the same as the input data, i.e. environment id.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+ For data types that are not supported, the main reason is that the corresponding model does not support them. \
+ You can implement your own model rather than using the default model. For more information, please raise an \
+ issue in the GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for DiscreteSACPolicy: \
+ ``ding.policy.tests.test_discrete_sac``.
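+
+ .. note::
+ An illustrative sketch (assuming an already constructed policy with ``eval`` in ``enable_field``; the \
+ observation shape is hypothetical)::
+
+ output = policy.eval_mode.forward({0: torch.randn(8)})
+ greedy_action = output[0]['action']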
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data, mode='compute_actor')
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _monitor_vars_learn(self) -> List[str]:
+ """
+ Overview:
+ Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
+ as text logger, tensorboard logger, will use these keys to save the corresponding data.
+ Returns:
+ - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
+ """
+ twin_critic = ['twin_critic_loss'] if self._twin_critic else []
+ if self._auto_alpha:
+ return super()._monitor_vars_learn() + [
+ 'alpha_loss', 'policy_loss', 'critic_loss', 'cur_lr_q', 'cur_lr_p', 'target_q_value', 'q_value_1',
+ 'q_value_2', 'alpha', 'td_error', 'target_value', 'entropy'
+ ] + twin_critic
+ else:
+ return super()._monitor_vars_learn() + [
+ 'policy_loss', 'critic_loss', 'cur_lr_q', 'cur_lr_p', 'target_q_value', 'q_value_1', 'q_value_2',
+ 'alpha', 'td_error', 'target_value', 'entropy'
+ ] + twin_critic
+
+
+@POLICY_REGISTRY.register('sac')
+class SACPolicy(Policy):
+ """
+ Overview:
+ Policy class of continuous SAC algorithm. Paper link: https://arxiv.org/pdf/1801.01290.pdf
+
+ Config:
+ == ==================== ======== ============= ================================= =======================
+ ID Symbol Type Default Value Description Other
+ == ==================== ======== ============= ================================= =======================
+ 1 ``type`` str sac | RL policy register name, refer | this arg is optional,
+ | to registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool True | Whether to use cuda for network |
+ 3 ``on_policy`` bool False | SAC is an off-policy |
+ | algorithm. |
+ 4 ``priority`` bool False | Whether to use priority |
+ | sampling in buffer. |
+ 5 | ``priority_IS_`` bool False | Whether use Importance Sampling |
+ | ``weight`` | weight to correct biased update |
+ 6 | ``random_`` int 10000 | Number of randomly collected | Default to 10000 for
+ | ``collect_size`` | training samples in replay | SAC, 25000 for DDPG/
+ | | buffer when training starts. | TD3.
+ 7 | ``learn.learning`` float 3e-4 | Learning rate for soft q | Defalut to 1e-3
+ | ``_rate_q`` | network. |
+ 8 | ``learn.learning`` float 3e-4 | Learning rate for policy | Defalut to 1e-3
+ | ``_rate_policy`` | network. |
+ 9 | ``learn.alpha`` float 0.2 | Entropy regularization | alpha is initiali-
+ | | coefficient. | zation for auto
+ | | | alpha, when
+ | | | auto_alpha is True
+ 10 | ``learn.`` bool False | Determine whether to use | Temperature parameter
+ | ``auto_alpha`` | auto temperature parameter | determines the
+ | | alpha. | relative importance
+ | | | of the entropy term
+ | | | against the reward.
+ 11 | ``learn.-`` bool False | Determine whether to ignore | Use ignore_done only
+ | ``ignore_done`` | done flag. | in env like Pendulum
+ 12 | ``learn.-`` float 0.005 | Used for soft update of the | aka. Interpolation
+ | ``target_theta`` | target network. | factor in polyak aver
+ | | | aging for target
+ | | | networks.
+ == ==================== ======== ============= ================================= =======================
+ """
+
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='sac',
+ # (bool) Whether to use cuda for network and loss computation.
+ cuda=False,
+ # (bool) Whether the algorithm is on-policy or off-policy. SAC is an off-policy algorithm.
+ on_policy=False,
+ # (bool) Whether to use priority sampling in buffer. Default to False in SAC.
+ priority=False,
+ # (bool) Whether to use Importance Sampling weight to correct the biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ # (int) Number of training samples (randomly collected) in replay buffer when training starts.
+ random_collect_size=10000,
+ # (bool) Whether policy-specific data is needed in ``_process_transition``.
+ transition_with_policy_data=True,
+ # (bool) Whether to enable multi-agent training setting.
+ multi_agent=False,
+ model=dict(
+ # (bool) Whether to use double-soft-q-net for target q computation.
+ # For more details, please refer to TD3 about Clipped Double-Q Learning trick.
+ twin_critic=True,
+ # (str) Action space type; SAC uses the reparameterization trick for continuous actions.
+ action_space='reparameterization',
+ ),
+ # learn_mode config
+ learn=dict(
+ # (int) How many updates (iterations) to train for after one collection by the collector.
+ # A bigger "update_per_collect" means more off-policy training.
+ update_per_collect=1,
+ # (int) Minibatch size for one gradient descent.
+ batch_size=256,
+ # (float) Learning rate for soft q network.
+ learning_rate_q=3e-4,
+ # (float) Learning rate for policy network.
+ learning_rate_policy=3e-4,
+ # (float) Learning rate for auto temperature parameter `\alpha`.
+ learning_rate_alpha=3e-4,
+ # (float) Used for soft update of the target network,
+ # aka. Interpolation factor in EMA update for target network.
+ target_theta=0.005,
+ # (float) discount factor for the discounted sum of rewards, aka. gamma.
+ discount_factor=0.99,
+ # (float) Entropy regularization coefficient in SAC.
+ # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details.
+ # If auto_alpha is set to `True`, alpha is initialization for auto `\alpha`.
+ alpha=0.2,
+ # (bool) Whether to use auto temperature parameter `\alpha` .
+ # Temperature parameter `\alpha` determines the relative importance of the entropy term against the reward.
+ # Please check out the original SAC paper (arXiv 1801.01290): Eq 1 for more details.
+ # Note that: Using auto alpha needs to set the above `learning_rate_alpha`.
+ auto_alpha=True,
+ # (bool) Whether to use auto `\alpha` in log space.
+ log_space=True,
+ # (float) Target policy entropy value for auto temperature (alpha) adjustment.
+ target_entropy=None,
+ # (bool) Whether to ignore the done flag (usually for max-step termination envs, e.g. Pendulum).
+ # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers,
+ # which limit HalfCheetah and several other MuJoCo envs to a max length of 1000 steps.
+ # Such a time-limit termination is not a true terminal state, so with ``ignore_done=True`` we replace
+ # done==True with done==False when the episode step is greater than the max episode step, in order to
+ # keep the TD-error computation (``gamma * (1 - done) * next_v + reward``) accurate.
+ ignore_done=False,
+ # (float) Weight uniform initialization max range in the last output layer.
+ init_w=3e-3,
+ ),
+ # collect_mode config
+ collect=dict(
+ # (int) How many training samples collected in one collection procedure.
+ n_sample=1,
+ # (int) Split episodes or trajectories into pieces with length `unroll_len`.
+ unroll_len=1,
+ # (bool) Whether to collect logit in `process_transition`.
+ # In some algorithms, like guided cost learning, we need to use the logit to train the reward model.
+ collector_logit=False,
+ ),
+ eval=dict(), # for compatibility
+ other=dict(
+ replay_buffer=dict(
+ # (int) Maximum size of replay buffer. Usually, larger buffer size is good
+ # for SAC but costs more storage.
+ replay_buffer_size=1000000,
+ ),
+ ),
+ )
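+ # A minimal sketch (typical usage, not part of this class) of how these defaults are usually overridden in a
+ # user config before the framework merges it with this class's ``default_config()``; the obs/action shapes
+ # below are hypothetical:
+ #
+ #     from easydict import EasyDict
+ #     user_cfg = EasyDict(dict(cuda=True, model=dict(obs_shape=17, action_shape=6), learn=dict(batch_size=256)))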
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+ Return this algorithm's default neural network model setting for demonstration. The ``__init__`` method will \
+ automatically call this method to get the default model setting and create the model.
+ Returns:
+ - model_info (:obj:`Tuple[str, List[str]]`): The registered model name and model's import_names.
+ """
+ if self._cfg.multi_agent:
+ return 'continuous_maqac', ['ding.model.template.maqac']
+ else:
+ return 'continuous_qac', ['ding.model.template.qac']
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Initialize the learn mode of policy, including related attributes and modules. For SAC, it mainly \
+ contains three optimizers, algorithm-specific arguments such as gamma and twin_critic, main and target \
+ model. Especially, the ``auto_alpha`` mechanism for balancing max entropy target is also initialized here.
+ This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+ .. note::
+ For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+ and ``_load_state_dict_learn`` methods.
+
+ .. note::
+ For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+ .. note::
+ If you want to set some special member variables in the ``_init_learn`` method, you'd better name them \
+ with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ self._twin_critic = self._cfg.model.twin_critic
+
+ # Weight Init for the last output layer
+ if hasattr(self._model, 'actor_head'): # keep compatibility
+ init_w = self._cfg.learn.init_w
+ self._model.actor_head[-1].mu.weight.data.uniform_(-init_w, init_w)
+ self._model.actor_head[-1].mu.bias.data.uniform_(-init_w, init_w)
+ self._model.actor_head[-1].log_sigma_layer.weight.data.uniform_(-init_w, init_w)
+ self._model.actor_head[-1].log_sigma_layer.bias.data.uniform_(-init_w, init_w)
+
+ self._optimizer_q = Adam(
+ self._model.critic.parameters(),
+ lr=self._cfg.learn.learning_rate_q,
+ )
+ self._optimizer_policy = Adam(
+ self._model.actor.parameters(),
+ lr=self._cfg.learn.learning_rate_policy,
+ )
+
+ # Algorithm-Specific Config
+ self._gamma = self._cfg.learn.discount_factor
+ if self._cfg.learn.auto_alpha:
+ if self._cfg.learn.target_entropy is None:
+ assert 'action_shape' in self._cfg.model, "SAC needs a network model with the action_shape variable"
+ self._target_entropy = -np.prod(self._cfg.model.action_shape)
+ else:
+ self._target_entropy = self._cfg.learn.target_entropy
+ if self._cfg.learn.log_space:
+ self._log_alpha = torch.log(torch.FloatTensor([self._cfg.learn.alpha]))
+ self._log_alpha = self._log_alpha.to(self._device).requires_grad_()
+ self._alpha_optim = torch.optim.Adam([self._log_alpha], lr=self._cfg.learn.learning_rate_alpha)
+ assert self._log_alpha.shape == torch.Size([1]) and self._log_alpha.requires_grad
+ self._alpha = self._log_alpha.detach().exp()
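+ # Optimizing log(alpha) instead of alpha keeps the temperature strictly positive without an explicit
+ # constraint; alpha itself is recovered by exponentiation after each update.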
+ self._auto_alpha = True
+ self._log_space = True
+ else:
+ self._alpha = torch.FloatTensor([self._cfg.learn.alpha]).to(self._device).requires_grad_()
+ self._alpha_optim = torch.optim.Adam([self._alpha], lr=self._cfg.learn.learning_rate_alpha)
+ self._auto_alpha = True
+ self._log_space = False
+ else:
+ self._alpha = torch.tensor(
+ [self._cfg.learn.alpha], requires_grad=False, device=self._device, dtype=torch.float32
+ )
+ self._auto_alpha = False
+
+ # Main and target models
+ self._target_model = copy.deepcopy(self._model)
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='momentum',
+ update_kwargs={'theta': self._cfg.learn.target_theta}
+ )
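+ # With the 'momentum' (EMA) wrapper above, each call to ``self._target_model.update`` softly updates the
+ # target parameters, roughly target_param <- (1 - theta) * target_param + theta * online_param.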
+ self._learn_model = model_wrap(self._model, wrapper_name='base')
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
+ """
+ Overview:
+ Policy forward function of learn mode (training policy and updating parameters). Forward means \
+ that the policy inputs some training batch data from the replay buffer and then returns the output \
+ result, including various training information such as loss, action, priority.
+ Arguments:
+ - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
+ training samples. For each element in the list, the key of the dict is the name of the data item and the \
+ value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
+ combinations. In the ``_forward_learn`` method, data often needs to first be stacked in the batch \
+ dimension by some utility functions such as ``default_preprocess_learn``. \
+ For SAC, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \
+ ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight``.
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will be \
+ recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
+ detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+ For data types that are not supported, the main reason is that the corresponding model does not support them. \
+ You can implement your own model rather than using the default model. For more information, please raise an \
+ issue in the GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for SACPolicy: ``ding.policy.tests.test_sac``.
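+
+ .. note::
+ A sketch of the expected input format and of how this method is usually invoked through \
+ ``policy.learn_mode.forward``; all shapes below are hypothetical::
+
+ transition = {'obs': torch.randn(17), 'action': torch.randn(6), 'reward': torch.randn(1),
+ 'next_obs': torch.randn(17), 'done': torch.zeros(1)}
+ info = policy.learn_mode.forward([transition for _ in range(256)])  # batch_size elements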
+ """
+ loss_dict = {}
+ data = default_preprocess_learn(
+ data,
+ use_priority=self._priority,
+ use_priority_IS_weight=self._cfg.priority_IS_weight,
+ ignore_done=self._cfg.learn.ignore_done,
+ use_nstep=False
+ )
+ if self._cuda:
+ data = to_device(data, self._device)
+
+ self._learn_model.train()
+ self._target_model.train()
+ obs = data['obs']
+ next_obs = data['next_obs']
+ reward = data['reward']
+ done = data['done']
+
+ # 1. predict q value
+ q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']
+
+ # 2. predict target value
+ with torch.no_grad():
+ (mu, sigma) = self._learn_model.forward(next_obs, mode='compute_actor')['logit']
+
+ dist = Independent(Normal(mu, sigma), 1)
+ pred = dist.rsample()
+ next_action = torch.tanh(pred)
+ y = 1 - next_action.pow(2) + 1e-6
+ # keep the dimension for loss computation (mainly for envs with a 1-dim action space, e.g. Pendulum)
+ next_log_prob = dist.log_prob(pred).unsqueeze(-1)
+ next_log_prob = next_log_prob - torch.log(y).sum(-1, keepdim=True)
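+ # Change-of-variables correction for the tanh squashing:
+ # log pi(a|s) = log N(u; mu, sigma) - sum_i log(1 - tanh(u_i)^2), with 1e-6 added for numerical
+ # stability (see Appendix C of the SAC paper).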
+
+ next_data = {'obs': next_obs, 'action': next_action}
+ target_q_value = self._target_model.forward(next_data, mode='compute_critic')['q_value']
+ # the value of a policy according to the maximum entropy objective
+ if self._twin_critic:
+ # take the min of the twin critic outputs as the target q value
+ target_q_value = torch.min(target_q_value[0],
+ target_q_value[1]) - self._alpha * next_log_prob.squeeze(-1)
+ else:
+ target_q_value = target_q_value - self._alpha * next_log_prob.squeeze(-1)
+
+ # 3. compute q loss
+ if self._twin_critic:
+ q_data0 = v_1step_td_data(q_value[0], target_q_value, reward, done, data['weight'])
+ loss_dict['critic_loss'], td_error_per_sample0 = v_1step_td_error(q_data0, self._gamma)
+ q_data1 = v_1step_td_data(q_value[1], target_q_value, reward, done, data['weight'])
+ loss_dict['twin_critic_loss'], td_error_per_sample1 = v_1step_td_error(q_data1, self._gamma)
+ td_error_per_sample = (td_error_per_sample0 + td_error_per_sample1) / 2
+ else:
+ q_data = v_1step_td_data(q_value, target_q_value, reward, done, data['weight'])
+ loss_dict['critic_loss'], td_error_per_sample = v_1step_td_error(q_data, self._gamma)
+
+ # 4. update q network
+ self._optimizer_q.zero_grad()
+ if self._twin_critic:
+ (loss_dict['critic_loss'] + loss_dict['twin_critic_loss']).backward()
+ else:
+ loss_dict['critic_loss'].backward()
+ self._optimizer_q.step()
+
+ # 5. evaluate to get action distribution
+ (mu, sigma) = self._learn_model.forward(data['obs'], mode='compute_actor')['logit']
+ dist = Independent(Normal(mu, sigma), 1)
+ pred = dist.rsample()
+ action = torch.tanh(pred)
+ y = 1 - action.pow(2) + 1e-6
+ # keep the dimension for loss computation (mainly for envs with a 1-dim action space, e.g. Pendulum)
+ log_prob = dist.log_prob(pred).unsqueeze(-1)
+ log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
+
+ eval_data = {'obs': obs, 'action': action}
+ new_q_value = self._learn_model.forward(eval_data, mode='compute_critic')['q_value']
+ if self._twin_critic:
+ new_q_value = torch.min(new_q_value[0], new_q_value[1])
+
+ # 6. compute policy loss
+ policy_loss = (self._alpha * log_prob - new_q_value.unsqueeze(-1)).mean()
+
+ loss_dict['policy_loss'] = policy_loss
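+ # This is the reparameterized actor objective E_{a~pi}[alpha * log pi(a|s) - Q(s, a)]; gradients flow
+ # into the actor through the rsample()-ed action.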
+
+ # 7. update policy network
+ self._optimizer_policy.zero_grad()
+ loss_dict['policy_loss'].backward()
+ self._optimizer_policy.step()
+
+ # 8. compute alpha loss
+ if self._auto_alpha:
+ if self._log_space:
+ log_prob = log_prob + self._target_entropy
+ loss_dict['alpha_loss'] = -(self._log_alpha * log_prob.detach()).mean()
+
+ self._alpha_optim.zero_grad()
+ loss_dict['alpha_loss'].backward()
+ self._alpha_optim.step()
+ self._alpha = self._log_alpha.detach().exp()
+ else:
+ log_prob = log_prob + self._target_entropy
+ loss_dict['alpha_loss'] = -(self._alpha * log_prob.detach()).mean()
+
+ self._alpha_optim.zero_grad()
+ loss_dict['alpha_loss'].backward()
+ self._alpha_optim.step()
+ self._alpha = max(0, self._alpha)
+
+ loss_dict['total_loss'] = sum(loss_dict.values())
+
+ # target update
+ self._target_model.update(self._learn_model.state_dict())
+ return {
+ 'cur_lr_q': self._optimizer_q.defaults['lr'],
+ 'cur_lr_p': self._optimizer_policy.defaults['lr'],
+ 'priority': td_error_per_sample.abs().tolist(),
+ 'td_error': td_error_per_sample.detach().mean().item(),
+ 'alpha': self._alpha.item(),
+ 'target_q_value': target_q_value.detach().mean().item(),
+ 'transformed_log_prob': log_prob.mean().item(),
+ **loss_dict
+ }
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ """
+ Overview:
+ Return the state_dict of learn mode, usually including model, target_model and optimizers.
+ Returns:
+ - state_dict (:obj:`Dict[str, Any]`): The dict of current policy learn state, for saving and restoring.
+ """
+ ret = {
+ 'model': self._learn_model.state_dict(),
+ 'target_model': self._target_model.state_dict(),
+ 'optimizer_q': self._optimizer_q.state_dict(),
+ 'optimizer_policy': self._optimizer_policy.state_dict(),
+ }
+ if self._auto_alpha:
+ ret.update({'optimizer_alpha': self._alpha_optim.state_dict()})
+ return ret
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ """
+ Overview:
+ Load the state_dict variable into policy learn mode.
+ Arguments:
+ - state_dict (:obj:`Dict[str, Any]`): The dict of policy learn state saved before.
+
+ .. tip::
+ If you want to load only some parts of the model, you can simply set the ``strict`` argument of \
+ load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
+ complicated operations.
+ """
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._target_model.load_state_dict(state_dict['target_model'])
+ self._optimizer_q.load_state_dict(state_dict['optimizer_q'])
+ self._optimizer_policy.load_state_dict(state_dict['optimizer_policy'])
+ if self._auto_alpha:
+ self._alpha_optim.load_state_dict(state_dict['optimizer_alpha'])
+
+ def _init_collect(self) -> None:
+ """
+ Overview:
+ Initialize the collect mode of policy, including related attributes and modules. For SAC, it contains the \
+ collect_model and other algorithm-specific arguments such as unroll_len. \
+ This method will be called in ``__init__`` method if ``collect`` field is in ``enable_field``.
+
+ .. note::
+ If you want to set some special member variables in the ``_init_collect`` method, you'd better name them \
+ with prefix ``_collect_`` to avoid conflict with other modes, such as ``self._collect_attr1``.
+ """
+ self._unroll_len = self._cfg.collect.unroll_len
+ self._collect_model = model_wrap(self._model, wrapper_name='base')
+ self._collect_model.reset()
+
+ def _forward_collect(self, data: Dict[int, Any], **kwargs) -> Dict[int, Any]:
+ """
+ Overview:
+ Policy forward function of collect mode (collecting training data by interacting with envs). Forward means \
+ that the policy gets some necessary data (mainly observation) from the envs and then returns the output \
+ data, such as the action to interact with the envs.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action and \
+ other necessary data for learn mode defined in ``self._process_transition`` method. The key of the \
+ dict is the same as the input data, i.e. environment id.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+ For data types that are not supported, the main reason is that the corresponding model does not support them. \
+ You can implement your own model rather than using the default model. For more information, please raise an \
+ issue in the GitHub repo and we will continue to follow up.
+
+ .. note::
+ ``logit`` in SAC refers to the mu and sigma of the Gaussian distribution. Here we use this name for consistency.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for SACPolicy: ``ding.policy.tests.test_sac``.
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._collect_model.eval()
+ with torch.no_grad():
+ (mu, sigma) = self._collect_model.forward(data, mode='compute_actor')['logit']
+ dist = Independent(Normal(mu, sigma), 1)
+ action = torch.tanh(dist.rsample())
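+ # stochastic exploration: sample with the reparameterization trick and squash to (-1, 1) via tanh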
+ output = {'logit': (mu, sigma), 'action': action}
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _process_transition(self, obs: torch.Tensor, policy_output: Dict[str, torch.Tensor],
+ timestep: namedtuple) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ Process and pack one timestep transition data into a dict, which can be directly used for training and \
+ saved in replay buffer. For continuous SAC, it contains obs, next_obs, action, reward, done. The logit \
+ will also be added when ``collector_logit`` is True.
+ Arguments:
+ - obs (:obj:`torch.Tensor`): The env observation of current timestep, such as stacked 2D image in Atari.
+ - policy_output (:obj:`Dict[str, torch.Tensor]`): The output of the policy network with the observation \
+ as input. For continuous SAC, it contains the action and the logit (mu and sigma) of the action.
+ - timestep (:obj:`namedtuple`): The execution result namedtuple returned by the environment step method, \
+ except all the elements have been transformed into tensor data. Usually, it contains the next obs, \
+ reward, done, info, etc.
+ Returns:
+ - transition (:obj:`Dict[str, torch.Tensor]`): The processed transition data of the current timestep.
+ """
+ if self._cfg.collect.collector_logit:
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'logit': policy_output['logit'],
+ 'action': policy_output['action'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ else:
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'action': policy_output['action'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _get_train_sample(self, transitions: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+ Overview:
+ For a given trajectory (transitions, a list of transitions) data, process it into a list of samples that \
+ can be used directly for training. In continuous SAC, a train sample is a processed transition \
+ (unroll_len=1).
+ Arguments:
+ - transitions (:obj:`List[Dict[str, Any]]`): The trajectory data (a list of transitions), each element is \
+ in the same format as the return value of the ``self._process_transition`` method.
+ Returns:
+ - samples (:obj:`List[Dict[str, Any]]`): The processed train samples; each element is in a similar format \
+ to the input transitions, but may contain more data for training.
+ """
+ return get_train_sample(transitions, self._unroll_len)
+
+ def _init_eval(self) -> None:
+ """
+ Overview:
+ Initialize the eval mode of policy, including related attributes and modules. For SAC, it contains the \
+ eval model, which is equipped with the ``base`` model wrapper to ensure compatibility.
+ This method will be called in ``__init__`` method if ``eval`` field is in ``enable_field``.
+
+ .. note::
+ If you want to set some special member variables in the ``_init_eval`` method, you'd better name them \
+ with prefix ``_eval_`` to avoid conflict with other modes, such as ``self._eval_attr1``.
+ """
+ self._eval_model = model_wrap(self._model, wrapper_name='base')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: Dict[int, Any]) -> Dict[int, Any]:
+ """
+ Overview:
+ Policy forward function of eval mode (evaluating policy performance by interacting with envs). Forward \
+ means that the policy gets some necessary data (mainly observation) from the envs and then returns the \
+ action to interact with the envs.
+ Arguments:
+ - data (:obj:`Dict[int, Any]`): The input data used for policy forward, including at least the obs. The \
+ key of the dict is environment id and the value is the corresponding data of the env.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The output data of policy forward, including at least the action. The \
+ key of the dict is the same as the input data, i.e. environment id.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+ For data types that are not supported, the main reason is that the corresponding model does not support them. \
+ You can implement your own model rather than using the default model. For more information, please raise an \
+ issue in the GitHub repo and we will continue to follow up.
+
+ .. note::
+ ``logit`` in SAC refers to the mu and sigma of the Gaussian distribution. Here we use this name for consistency.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for SACPolicy: ``ding.policy.tests.test_sac``.
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._eval_model.eval()
+ with torch.no_grad():
+ (mu, sigma) = self._eval_model.forward(data, mode='compute_actor')['logit']
+ action = torch.tanh(mu) # deterministic_eval
+ output = {'action': action}
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _monitor_vars_learn(self) -> List[str]:
+ """
+ Overview:
+ Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
+ as text logger, tensorboard logger, will use these keys to save the corresponding data.
+ Returns:
+ - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
+ """
+ twin_critic = ['twin_critic_loss'] if self._twin_critic else []
+ alpha_loss = ['alpha_loss'] if self._auto_alpha else []
+ return [
+ 'policy_loss',
+ 'critic_loss',
+ 'cur_lr_q',
+ 'cur_lr_p',
+ 'target_q_value',
+ 'alpha',
+ 'td_error',
+ 'transformed_log_prob',
+ ] + twin_critic + alpha_loss
+
+
+@POLICY_REGISTRY.register('sqil_sac')
+class SQILSACPolicy(SACPolicy):
+ """
+ Overview:
+ Policy class of continuous SAC algorithm with SQIL extension.
+ SAC paper link: https://arxiv.org/pdf/1801.01290.pdf
+ SQIL paper link: https://arxiv.org/abs/1905.11108
+ """
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Initialize the learn mode of policy, including related attributes and modules. For SAC, it mainly \
+ contains three optimizers, algorithm-specific arguments such as gamma and twin_critic, main and target \
+ model. Especially, the ``auto_alpha`` mechanism for balancing max entropy target is also initialized here.
+ This method will be called in ``__init__`` method if ``learn`` field is in ``enable_field``.
+
+ .. note::
+ For the member variables that need to be saved and loaded, please refer to the ``_state_dict_learn`` \
+ and ``_load_state_dict_learn`` methods.
+
+ .. note::
+ For the member variables that need to be monitored, please refer to the ``_monitor_vars_learn`` method.
+
+ .. note::
+ If you want to set some special member variables in the ``_init_learn`` method, you'd better name them \
+ with prefix ``_learn_`` to avoid conflict with other modes, such as ``self._learn_attr1``.
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ self._twin_critic = self._cfg.model.twin_critic
+
+ # Weight Init for the last output layer
+ init_w = self._cfg.learn.init_w
+ self._model.actor_head[-1].mu.weight.data.uniform_(-init_w, init_w)
+ self._model.actor_head[-1].mu.bias.data.uniform_(-init_w, init_w)
+ self._model.actor_head[-1].log_sigma_layer.weight.data.uniform_(-init_w, init_w)
+ self._model.actor_head[-1].log_sigma_layer.bias.data.uniform_(-init_w, init_w)
+
+ self._optimizer_q = Adam(
+ self._model.critic.parameters(),
+ lr=self._cfg.learn.learning_rate_q,
+ )
+ self._optimizer_policy = Adam(
+ self._model.actor.parameters(),
+ lr=self._cfg.learn.learning_rate_policy,
+ )
+
+ # Algorithm-Specific Config
+ self._gamma = self._cfg.learn.discount_factor
+ if self._cfg.learn.auto_alpha:
+ if self._cfg.learn.target_entropy is None:
+ assert 'action_shape' in self._cfg.model, "SQILSACPolicy needs a network model with the action_shape variable"
+ self._target_entropy = -np.prod(self._cfg.model.action_shape)
+ else:
+ self._target_entropy = self._cfg.learn.target_entropy
+ if self._cfg.learn.log_space:
+ self._log_alpha = torch.log(torch.FloatTensor([self._cfg.learn.alpha]))
+ self._log_alpha = self._log_alpha.to(self._device).requires_grad_()
+ self._alpha_optim = torch.optim.Adam([self._log_alpha], lr=self._cfg.learn.learning_rate_alpha)
+ assert self._log_alpha.shape == torch.Size([1]) and self._log_alpha.requires_grad
+ self._alpha = self._log_alpha.detach().exp()
+ self._auto_alpha = True
+ self._log_space = True
+ else:
+ self._alpha = torch.FloatTensor([self._cfg.learn.alpha]).to(self._device).requires_grad_()
+ self._alpha_optim = torch.optim.Adam([self._alpha], lr=self._cfg.learn.learning_rate_alpha)
+ self._auto_alpha = True
+ self._log_space = False
+ else:
+ self._alpha = torch.tensor(
+ [self._cfg.learn.alpha], requires_grad=False, device=self._device, dtype=torch.float32
+ )
+ self._auto_alpha = False
+
+ # Main and target models
+ self._target_model = copy.deepcopy(self._model)
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='momentum',
+ update_kwargs={'theta': self._cfg.learn.target_theta}
+ )
+ self._learn_model = model_wrap(self._model, wrapper_name='base')
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ # switches for monitoring the cosine similarity between agent/expert policy gradients and the policy entropy
+ self._monitor_cos = True
+ self._monitor_entropy = True
+
+ def _forward_learn(self, data: List[Dict[str, Any]]) -> Dict[str, Any]:
+ """
+ Overview:
+ Policy forward function of learn mode (training policy and updating parameters). Forward means \
+ that the policy inputs some training batch data from the replay buffer and then returns the output \
+ result, including various training information such as loss, action, priority.
+ Arguments:
+ - data (:obj:`List[Dict[int, Any]]`): The input data used for policy forward, including a batch of \
+ training samples. For each element in the list, the key of the dict is the name of the data item and the \
+ value is the corresponding data. Usually, the value is torch.Tensor or np.ndarray or their dict/list \
+ combinations. In the ``_forward_learn`` method, data often needs to first be stacked in the batch \
+ dimension by some utility functions such as ``default_preprocess_learn``. \
+ For SAC, each element in list is a dict containing at least the following keys: ``obs``, ``action``, \
+ ``reward``, ``next_obs``, ``done``. Sometimes, it also contains other keys such as ``weight``.
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): The information dict that indicates the training result, which will be \
+ recorded in text log and tensorboard, values must be python scalar or a list of scalars. For the \
+ detailed definition of the dict, refer to the code of ``_monitor_vars_learn`` method.
+
+ .. note::
+ For SQIL + SAC, the input data is composed of two parts of the same size: agent data and expert data. \
+ Both of them are relabelled with new rewards according to the SQIL algorithm.
+
+ .. note::
+ The input value can be torch.Tensor or dict/list combinations, and the current policy supports all of them. \
+ For data types that are not supported, the main reason is that the corresponding model does not support them. \
+ You can implement your own model rather than using the default model. For more information, please raise an \
+ issue in the GitHub repo and we will continue to follow up.
+
+ .. note::
+ For more detailed examples, please refer to our unittest for SACPolicy: ``ding.policy.tests.test_sac``.
+ """
+ loss_dict = {}
+ if self._monitor_cos:
+ agent_data = default_preprocess_learn(
+ data[0:len(data) // 2],
+ use_priority=self._priority,
+ use_priority_IS_weight=self._cfg.priority_IS_weight,
+ ignore_done=self._cfg.learn.ignore_done,
+ use_nstep=False
+ )
+
+ expert_data = default_preprocess_learn(
+ data[len(data) // 2:],
+ use_priority=self._priority,
+ use_priority_IS_weight=self._cfg.priority_IS_weight,
+ ignore_done=self._cfg.learn.ignore_done,
+ use_nstep=False
+ )
+ if self._cuda:
+ agent_data = to_device(agent_data, self._device)
+ expert_data = to_device(expert_data, self._device)
+
+ data = default_preprocess_learn(
+ data,
+ use_priority=self._priority,
+ use_priority_IS_weight=self._cfg.priority_IS_weight,
+ ignore_done=self._cfg.learn.ignore_done,
+ use_nstep=False
+ )
+ if self._cuda:
+ data = to_device(data, self._device)
+
+ self._learn_model.train()
+ self._target_model.train()
+ obs = data['obs']
+ next_obs = data['next_obs']
+ reward = data['reward']
+ done = data['done']
+
+ # 1. predict q value
+ q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']
+
+ # 2. predict target value
+ with torch.no_grad():
+ (mu, sigma) = self._learn_model.forward(next_obs, mode='compute_actor')['logit']
+ dist = Independent(Normal(mu, sigma), 1)
+ pred = dist.rsample()
+ next_action = torch.tanh(pred)
+ y = 1 - next_action.pow(2) + 1e-6
+ # keep the dimension for loss computation (mainly for envs with a 1-dim action space, e.g. Pendulum)
+ next_log_prob = dist.log_prob(pred).unsqueeze(-1)
+ next_log_prob = next_log_prob - torch.log(y).sum(-1, keepdim=True)
+
+ next_data = {'obs': next_obs, 'action': next_action}
+ target_q_value = self._target_model.forward(next_data, mode='compute_critic')['q_value']
+ # the value of a policy according to the maximum entropy objective
+ if self._twin_critic:
+ # take the min of the twin critic outputs as the target q value
+ target_q_value = torch.min(target_q_value[0],
+ target_q_value[1]) - self._alpha * next_log_prob.squeeze(-1)
+ else:
+ target_q_value = target_q_value - self._alpha * next_log_prob.squeeze(-1)
+
+ # 3. compute q loss
+ if self._twin_critic:
+ q_data0 = v_1step_td_data(q_value[0], target_q_value, reward, done, data['weight'])
+ loss_dict['critic_loss'], td_error_per_sample0 = v_1step_td_error(q_data0, self._gamma)
+ q_data1 = v_1step_td_data(q_value[1], target_q_value, reward, done, data['weight'])
+ loss_dict['twin_critic_loss'], td_error_per_sample1 = v_1step_td_error(q_data1, self._gamma)
+ td_error_per_sample = (td_error_per_sample0 + td_error_per_sample1) / 2
+ else:
+ q_data = v_1step_td_data(q_value, target_q_value, reward, done, data['weight'])
+ loss_dict['critic_loss'], td_error_per_sample = v_1step_td_error(q_data, self._gamma)
+
+ # 4. update q network
+ self._optimizer_q.zero_grad()
+ if self._twin_critic:
+ (loss_dict['critic_loss'] + loss_dict['twin_critic_loss']).backward()
+ else:
+ loss_dict['critic_loss'].backward()
+ self._optimizer_q.step()
+
+ # 5. evaluate to get action distribution
+ if self._monitor_cos:
+ # agent
+ (mu, sigma) = self._learn_model.forward(agent_data['obs'], mode='compute_actor')['logit']
+ dist = Independent(Normal(mu, sigma), 1)
+ pred = dist.rsample()
+ action = torch.tanh(pred)
+ y = 1 - action.pow(2) + 1e-6
+ # keep the dimension for loss computation (mainly for envs with a 1-dim action space, e.g. Pendulum)
+ agent_log_prob = dist.log_prob(pred).unsqueeze(-1)
+ agent_log_prob = agent_log_prob - torch.log(y).sum(-1, keepdim=True)
+
+ eval_data = {'obs': agent_data['obs'], 'action': action}
+ agent_new_q_value = self._learn_model.forward(eval_data, mode='compute_critic')['q_value']
+ if self._twin_critic:
+ agent_new_q_value = torch.min(agent_new_q_value[0], agent_new_q_value[1])
+ # expert
+ (mu, sigma) = self._learn_model.forward(expert_data['obs'], mode='compute_actor')['logit']
+ dist = Independent(Normal(mu, sigma), 1)
+ pred = dist.rsample()
+ action = torch.tanh(pred)
+ y = 1 - action.pow(2) + 1e-6
+ # keep the dimension for loss computation (mainly for envs with a 1-dim action space, e.g. Pendulum)
+ expert_log_prob = dist.log_prob(pred).unsqueeze(-1)
+ expert_log_prob = expert_log_prob - torch.log(y).sum(-1, keepdim=True)
+
+ eval_data = {'obs': expert_data['obs'], 'action': action}
+ expert_new_q_value = self._learn_model.forward(eval_data, mode='compute_critic')['q_value']
+ if self._twin_critic:
+ expert_new_q_value = torch.min(expert_new_q_value[0], expert_new_q_value[1])
+
+ (mu, sigma) = self._learn_model.forward(data['obs'], mode='compute_actor')['logit']
+ dist = Independent(Normal(mu, sigma), 1)
+ # for monitoring the entropy of the policy
+ if self._monitor_entropy:
+ dist_entropy = dist.entropy()
+ entropy = dist_entropy.mean()
+
+ pred = dist.rsample()
+ action = torch.tanh(pred)
+ y = 1 - action.pow(2) + 1e-6
+ # keep the dimension for loss computation (mainly for envs with a 1-dim action space, e.g. Pendulum)
+ log_prob = dist.log_prob(pred).unsqueeze(-1)
+ log_prob = log_prob - torch.log(y).sum(-1, keepdim=True)
+
+ eval_data = {'obs': obs, 'action': action}
+ new_q_value = self._learn_model.forward(eval_data, mode='compute_critic')['q_value']
+ if self._twin_critic:
+ new_q_value = torch.min(new_q_value[0], new_q_value[1])
+
+ # 6. compute policy loss
+ policy_loss = (self._alpha * log_prob - new_q_value.unsqueeze(-1)).mean()
+ loss_dict['policy_loss'] = policy_loss
+
+ # 7. update policy network
+ if self._monitor_cos:
+ agent_policy_loss = (self._alpha * agent_log_prob - agent_new_q_value.unsqueeze(-1)).mean()
+ expert_policy_loss = (self._alpha * expert_log_prob - expert_new_q_value.unsqueeze(-1)).mean()
+ loss_dict['agent_policy_loss'] = agent_policy_loss
+ loss_dict['expert_policy_loss'] = expert_policy_loss
+ self._optimizer_policy.zero_grad()
+ loss_dict['agent_policy_loss'].backward()
+ agent_grad = (list(list(self._learn_model.actor.children())[-1].children())[-1].weight.grad).mean()
+ self._optimizer_policy.zero_grad()
+ loss_dict['expert_policy_loss'].backward()
+ expert_grad = (list(list(self._learn_model.actor.children())[-1].children())[-1].weight.grad).mean()
+ cos = nn.CosineSimilarity(dim=0)
+ cos_similarity = cos(agent_grad, expert_grad)
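+ # The cosine similarity between the (scalar-averaged) actor gradients induced by agent data and by
+ # expert data is only monitored here; the actual update below uses the combined ``policy_loss``.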
+ self._optimizer_policy.zero_grad()
+ loss_dict['policy_loss'].backward()
+ self._optimizer_policy.step()
+
+ # 8. compute alpha loss
+ if self._auto_alpha:
+ if self._log_space:
+ log_prob = log_prob + self._target_entropy
+ loss_dict['alpha_loss'] = -(self._log_alpha * log_prob.detach()).mean()
+
+ self._alpha_optim.zero_grad()
+ loss_dict['alpha_loss'].backward()
+ self._alpha_optim.step()
+ self._alpha = self._log_alpha.detach().exp()
+ else:
+ log_prob = log_prob + self._target_entropy
+ loss_dict['alpha_loss'] = -(self._alpha * log_prob.detach()).mean()
+
+ self._alpha_optim.zero_grad()
+ loss_dict['alpha_loss'].backward()
+ self._alpha_optim.step()
+ self._alpha = max(0, self._alpha)
+
+ loss_dict['total_loss'] = sum(loss_dict.values())
+
+ # target update
+ self._target_model.update(self._learn_model.state_dict())
+ var_monitor = {
+ 'cur_lr_q': self._optimizer_q.defaults['lr'],
+ 'cur_lr_p': self._optimizer_policy.defaults['lr'],
+ 'priority': td_error_per_sample.abs().tolist(),
+ 'td_error': td_error_per_sample.detach().mean().item(),
+ 'agent_td_error': td_error_per_sample.detach().chunk(2, dim=0)[0].mean().item(),
+ 'expert_td_error': td_error_per_sample.detach().chunk(2, dim=0)[1].mean().item(),
+ 'alpha': self._alpha.item(),
+ 'target_q_value': target_q_value.detach().mean().item(),
+ 'mu': mu.detach().mean().item(),
+ 'sigma': sigma.detach().mean().item(),
+ 'q_value0': new_q_value[0].detach().mean().item(),
+ 'q_value1': new_q_value[1].detach().mean().item(),
+ **loss_dict,
+ }
+ if self._monitor_cos:
+ var_monitor['cos_similarity'] = cos_similarity.item()
+ if self._monitor_entropy:
+ var_monitor['entropy'] = entropy.item()
+ return var_monitor
+
+ def _monitor_vars_learn(self) -> List[str]:
+ """
+ Overview:
+ Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
+ as text logger, tensorboard logger, will use these keys to save the corresponding data.
+ Returns:
+ - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
+ """
+ twin_critic = ['twin_critic_loss'] if self._twin_critic else []
+ alpha_loss = ['alpha_loss'] if self._auto_alpha else []
+ cos_similarity = ['cos_similarity'] if self._monitor_cos else []
+ entropy = ['entropy'] if self._monitor_entropy else []
+ return [
+ 'policy_loss',
+ 'critic_loss',
+ 'cur_lr_q',
+ 'cur_lr_p',
+ 'target_q_value',
+ 'alpha',
+ 'td_error',
+ 'agent_td_error',
+ 'expert_td_error',
+ 'mu',
+ 'sigma',
+ 'q_value0',
+ 'q_value1',
+ ] + twin_critic + alpha_loss + cos_similarity + entropy
diff --git a/DI-engine/ding/policy/sql.py b/DI-engine/ding/policy/sql.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc6170dfb7b7cb9d035129108c3a4950df2e5212
--- /dev/null
+++ b/DI-engine/ding/policy/sql.py
@@ -0,0 +1,296 @@
+from typing import List, Dict, Any, Tuple, Union, Optional
+from collections import namedtuple, deque
+import copy
+import torch
+from torch.distributions import Categorical
+from ditk import logging
+from easydict import EasyDict
+from ding.torch_utils import Adam, to_device
+from ding.utils.data import default_collate, default_decollate
+from ding.rl_utils import q_nstep_td_data, q_nstep_sql_td_error, get_nstep_return_data, get_train_sample
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from .base_policy import Policy
+from .common_utils import default_preprocess_learn
+
+
+@POLICY_REGISTRY.register('sql')
+class SQLPolicy(Policy):
+ r"""
+ Overview:
+ Policy class of SQL algorithm.
+ """
+
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='sql',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ # (bool) Whether the RL algorithm is on-policy or off-policy.
+ on_policy=False,
+ # (bool) Whether to use priority (priority sampling, IS weight, priority update)
+ priority=False,
+ # (float) Reward's future discount factor, aka. gamma.
+ discount_factor=0.97,
+ # (int) N-step reward for target q_value estimation
+ nstep=1,
+ learn=dict(
+
+ # How many updates (iterations) to train for after one collection by the collector.
+ # A bigger "update_per_collect" means more off-policy training.
+ # collect data -> update policy -> collect data -> ...
+ update_per_collect=3, # after a batch of data comes into the learner, train with it 3 times
+ batch_size=64,
+ learning_rate=0.001,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (int) Frequency of target network updates.
+ target_update_freq=100,
+ # (bool) Whether to ignore done (usually for max-step termination envs)
+ ignore_done=False,
+ alpha=0.1,
+ ),
+ # collect_mode config
+ collect=dict(
+ # (int) Only one of [n_sample, n_episode] should be set
+ # n_sample=8, # collect 8 samples and put them in collector
+ # (int) Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ ),
+ eval=dict(),
+ # other config
+ other=dict(
+ # Epsilon greedy with decay.
+ eps=dict(
+ # (str) Decay type. Support ['exp', 'linear'].
+ type='exp',
+ start=0.95,
+ end=0.1,
+ # (int) Decay length (env steps)
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=10000, )
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+ Return this algorithm's default model setting for demonstration.
+ Returns:
+ - model_info (:obj:`Tuple[str, List[str]]`): model name and model import_names
+
+ .. note::
+ The user can define and use a customized network model but must obey the same interface definition indicated \
+ by the import_names path. Here SQL reuses the DQN model: ``ding.model.template.q_learning.DQN``.
+ """
+ return 'dqn', ['ding.model.template.q_learning']
+
+ def _init_learn(self) -> None:
+ r"""
+ Overview:
+ Learn mode init method. Called by ``self.__init__``.
+ Init the optimizer, algorithm config, main and target models.
+ """
+ self._priority = self._cfg.priority
+ # Optimizer
+ self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
+ self._gamma = self._cfg.discount_factor
+ self._nstep = self._cfg.nstep
+ self._alpha = self._cfg.learn.alpha
+ # use wrapper instead of plugin
+ self._target_model = copy.deepcopy(self._model)
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='assign',
+ update_kwargs={'freq': self._cfg.learn.target_update_freq}
+ )
+ self._learn_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Forward and backward function of learn mode.
+ Arguments:
+ - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs']
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Including current lr and loss.
+ """
+ data = default_preprocess_learn(
+ data, use_priority=self._priority, ignore_done=self._cfg.learn.ignore_done, use_nstep=True
+ )
+ if self._cuda:
+ data = to_device(data, self._device)
+ # ====================
+ # Q-learning forward
+ # ====================
+ self._learn_model.train()
+ self._target_model.train()
+ # Current q value (main model)
+ q_value = self._learn_model.forward(data['obs'])['logit']
+ with torch.no_grad():
+ # Target q value
+ target_q_value = self._target_model.forward(data['next_obs'])['logit']
+ # Max q value action (main model)
+ target_q_action = self._learn_model.forward(data['next_obs'])['action']
+
+ data_n = q_nstep_td_data(
+ q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], data['weight']
+ )
+ value_gamma = data.get('value_gamma')
+ loss, td_error_per_sample, record_target_v = q_nstep_sql_td_error(
+ data_n, self._gamma, self._cfg.learn.alpha, nstep=self._nstep, value_gamma=value_gamma
+ )
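+ # In soft Q-learning the state value generalizes the hard max over actions to a log-sum-exp:
+ # V(s) = alpha * log(sum_a exp(Q(s, a) / alpha)); ``record_target_v`` is logged to monitor this target.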
+ record_target_v = record_target_v.mean()
+ # ====================
+ # Q-learning update
+ # ====================
+ self._optimizer.zero_grad()
+ loss.backward()
+ if self._cfg.multi_gpu:
+ self.sync_gradients(self._learn_model)
+ self._optimizer.step()
+
+ # =============
+ # after update
+ # =============
+ self._target_model.update(self._learn_model.state_dict())
+ return {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': loss.item(),
+ 'priority': td_error_per_sample.abs().tolist(),
+ 'record_value_function': record_target_v
+ # Only discrete action satisfying len(data['action'])==1 can return this and draw histogram on tensorboard.
+ # '[histogram]action_distribution': data['action'],
+ }
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'target_model': self._target_model.state_dict(),
+ 'optimizer': self._optimizer.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._target_model.load_state_dict(state_dict['target_model'])
+ self._optimizer.load_state_dict(state_dict['optimizer'])
+
+ def _init_collect(self) -> None:
+ r"""
+ Overview:
+ Collect mode init method. Called by ``self.__init__``.
+ Init traj and unroll length, collect model.
+ Enable the eps_greedy_multinomial_sample model wrapper.
+ """
+ self._unroll_len = self._cfg.collect.unroll_len
+ self._gamma = self._cfg.discount_factor # necessary for parallel
+ self._nstep = self._cfg.nstep # necessary for parallel
+ self._collect_model = model_wrap(self._model, wrapper_name='eps_greedy_multinomial_sample')
+ self._collect_model.reset()
+
+ def _forward_collect(self, data: Dict[int, Any], eps: float) -> Dict[int, Any]:
+ r"""
+ Overview:
+ Forward function for collect mode with eps_greedy
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ - eps (:obj:`float`): epsilon value for exploration, which is decayed by collected env step.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs.
+ ReturnsKeys
+ - necessary: ``action``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._collect_model.eval()
+ with torch.no_grad():
+ output = self._collect_model.forward(data, eps=eps, alpha=self._cfg.learn.alpha)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _get_train_sample(self, data: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
+ """
+ Overview:
+ For a given trajectory (transitions, a list of transitions) data, process it into a list of samples that \
+ can be used directly for training. A train sample can be a processed transition (DQN with nstep TD) \
+ or some continuous transitions (DRQN).
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]]`): The trajectory data (a list of transitions), each element is in the \
+ same format as the return value of the ``self._process_transition`` method.
+ Returns:
+ - samples (:obj:`dict`): The list of training samples.
+
+ .. note::
+ We will vectorize the ``process_transition`` and ``get_train_sample`` methods in a following release version. \
+ The user can customize this data processing procedure by overriding these two methods and the collector \
+ itself.
+ """
+ data = get_nstep_return_data(data, self._nstep, gamma=self._gamma)
+ return get_train_sample(data, self._unroll_len)
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
+ r"""
+ Overview:
+ Generate dict type transition data from inputs.
+ Arguments:
+ - obs (:obj:`Any`): Env observation
+ - model_output (:obj:`dict`): Output of collect model, including at least ['action']
+ - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
+ (here 'obs' indicates obs after env step).
+ Returns:
+ - transition (:obj:`dict`): Dict type transition data.
+ """
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'action': model_output['action'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _init_eval(self) -> None:
+ r"""
+ Overview:
+ Evaluate mode init method. Called by ``self.__init__``.
+ Init eval model with argmax strategy.
+ """
+ self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: dict) -> dict:
+ r"""
+ Overview:
+ Forward function of eval mode, similar to ``self._forward_collect``.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
+ ReturnsKeys
+ - necessary: ``action``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _monitor_vars_learn(self) -> List[str]:
+ return super()._monitor_vars_learn() + ['record_value_function']
diff --git a/DI-engine/ding/policy/sqn.py b/DI-engine/ding/policy/sqn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5241ee993e5bcfbdf20caf89034fdd9dab5c1547
--- /dev/null
+++ b/DI-engine/ding/policy/sqn.py
@@ -0,0 +1,357 @@
+from typing import List, Dict, Any, Tuple, Union
+from collections import namedtuple
+import math
+import itertools
+import numpy as np
+import torch
+import torch.nn.functional as F
+import copy
+
+from ding.torch_utils import Adam, to_device
+from ding.rl_utils import get_train_sample
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_collate, default_decollate
+from .base_policy import Policy
+from .common_utils import default_preprocess_learn
+
+
+@POLICY_REGISTRY.register('sqn')
+class SQNPolicy(Policy):
+ r"""
+ Overview:
+ Policy class of SQN algorithm (arxiv: 1912.10891).
+ """
+
+ config = dict(
+ cuda=False,
+ type='sqn',
+ on_policy=False,
+ priority=False,
+ # (bool) Whether to use Importance Sampling weight to correct the biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ learn=dict(
+ update_per_collect=16,
+ batch_size=64,
+ learning_rate_q=0.001,
+ learning_rate_alpha=0.001,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ target_theta=0.005,
+ alpha=0.2,
+ discount_factor=0.99,
+ # If env's action shape is int type, we recommend `self._action_shape / 10`; else, we recommend 0.2
+ target_entropy=0.2,
+ # (bool) Whether ignore done(usually for max step termination env)
+ ignore_done=False,
+ ),
+ collect=dict(
+ # n_sample=16,
+ # Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ ),
+ eval=dict(),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.8,
+ decay=2000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ return 'sqn', ['ding.model.template.sqn']
+
+ def _init_learn(self) -> None:
+ r"""
+ Overview:
+ Learn mode init method. Called by ``self.__init__``.
+ Init q, value and policy's optimizers, algorithm config, main and target models.
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ # Optimizers
+ self._optimizer_q = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate_q)
+
+ # Algorithm config
+ self._gamma = self._cfg.learn.discount_factor
+ self._action_shape = self._cfg.model.action_shape
+ self._target_entropy = self._cfg.learn.target_entropy
+ self._log_alpha = torch.FloatTensor([math.log(self._cfg.learn.alpha)]).to(self._device).requires_grad_(True)
+ self._optimizer_alpha = Adam([self._log_alpha], lr=self._cfg.learn.learning_rate_alpha)
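+ # NOTE: alpha is parameterized in log-space (alpha = exp(log_alpha)) so that gradient
+ # updates on ``self._log_alpha`` keep the entropy temperature strictly positive.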
+
+ # Main and target models
+ self._target_model = copy.deepcopy(self._model)
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='momentum',
+ update_kwargs={'theta': self._cfg.learn.target_theta}
+ )
+ self._learn_model = model_wrap(self._model, wrapper_name='base')
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ self._forward_learn_cnt = 0
+
+ def q_1step_td_loss(self, td_data: dict) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ q_value = td_data["q_value"]
+ target_q_value = td_data["target_q_value"]
+ action = td_data.get('action')
+ done = td_data.get('done')
+ reward = td_data.get('reward')
+ q0 = q_value[0]
+ q1 = q_value[1]
+ batch_range = torch.arange(action.shape[0])
+ q0_a = q0[batch_range, action]
+ q1_a = q1[batch_range, action]
+ # Target
+ with torch.no_grad():
+ q0_targ = target_q_value[0]
+ q1_targ = target_q_value[1]
+ q_targ = torch.min(q0_targ, q1_targ)
+ # discrete policy
+ alpha = torch.exp(self._log_alpha.clone())
+ # TODO use q_targ or q0 for pi
+ log_pi = F.log_softmax(q_targ / alpha, dim=-1)
+ pi = torch.exp(log_pi)
+ # v = \sum_a \pi(a | s) (Q(s, a) - \alpha \log(\pi(a|s)))
+ target_v_value = (pi * (q_targ - alpha * log_pi)).sum(axis=-1)
+ # q = r + \gamma v
+ q_backup = reward + (1 - done) * self._gamma * target_v_value
+ # alpha_loss
+ entropy = (-pi * log_pi).sum(axis=-1)
+ expect_entropy = (pi * self._target_entropy).sum(axis=-1)
+
+ # Q loss
+ q0_loss = F.mse_loss(q0_a, q_backup)
+ q1_loss = F.mse_loss(q1_a, q_backup)
+ total_q_loss = q0_loss + q1_loss
+ # alpha loss
+ alpha_loss = self._log_alpha * (entropy - expect_entropy).mean()
+ return total_q_loss, alpha_loss, entropy
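+ # A minimal illustrative sketch (toy tensors, not executed here) of the soft target above,
+ # assuming a batch of 4 states and 3 discrete actions:
+ # q_targ = torch.randn(4, 3) # twin-min target Q-values, shape (B, A)
+ # alpha, gamma = 0.2, 0.99
+ # reward, done = torch.zeros(4), torch.zeros(4) # toy transition data
+ # log_pi = F.log_softmax(q_targ / alpha, dim=-1) # Boltzmann policy over actions
+ # pi = log_pi.exp()
+ # v = (pi * (q_targ - alpha * log_pi)).sum(-1) # soft state value, shape (B,)
+ # q_backup = reward + (1 - done) * gamma * v # 1-step soft TD target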
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ """
+ Overview:
+ Forward and backward function of learn mode.
+ Arguments:
+ - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs', 'done',\
+ 'weight']
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Learn info, including current lr and loss.
+ """
+ data = default_preprocess_learn(
+ data,
+ use_priority=self._cfg.priority,
+ use_priority_IS_weight=self._cfg.priority_IS_weight,
+ ignore_done=self._cfg.learn.ignore_done,
+ use_nstep=False
+ )
+ if self._cuda:
+ data = to_device(data, self._device)
+
+ self._learn_model.train()
+ self._target_model.train()
+ obs = data.get('obs')
+ next_obs = data.get('next_obs')
+ reward = data.get('reward')
+ action = data.get('action')
+ done = data.get('done')
+ # Q-function
+ q_value = self._learn_model.forward(obs)['q_value']
+ target_q_value = self._target_model.forward(next_obs)['q_value']
+
+ # number of separate discrete action spaces (1 if ``action_shape`` is an int)
+ num_s_env = 1 if isinstance(self._action_shape, int) else len(self._action_shape)
+
+ for s_env_id in range(num_s_env):
+ if isinstance(self._action_shape, int):
+ td_data = {
+ "q_value": q_value,
+ "target_q_value": target_q_value,
+ "obs": obs,
+ "next_obs": next_obs,
+ "reward": reward,
+ "action": action,
+ "done": done
+ }
+ else:
+ td_data = {
+ "q_value": [q_value[0][s_env_id], q_value[1][s_env_id]],
+ "target_q_value": [target_q_value[0][s_env_id], target_q_value[1][s_env_id]],
+ "obs": obs,
+ "next_obs": next_obs,
+ "reward": reward,
+ "action": action[s_env_id],
+ "done": done
+ }
+ total_q_loss, alpha_loss, entropy = self.q_1step_td_loss(td_data)
+ if s_env_id == 0:
+ a_total_q_loss, a_alpha_loss, a_entropy = total_q_loss, alpha_loss, entropy # accumulate
+ else: # accumulate losses from the remaining action heads (scaled by 1 / num_s_env)
+ a_total_q_loss += total_q_loss / (num_s_env + 1e-6)
+ a_alpha_loss += alpha_loss / (num_s_env + 1e-6)
+ a_entropy += entropy / (num_s_env + 1e-6)
+
+ self._optimizer_q.zero_grad()
+ a_total_q_loss.backward()
+ self._optimizer_q.step()
+
+ self._optimizer_alpha.zero_grad()
+ a_alpha_loss.backward()
+ self._optimizer_alpha.step()
+
+ # target update
+ self._target_model.update(self._learn_model.state_dict())
+ self._forward_learn_cnt += 1
+ # some useful info
+ return {
+ '[histogram]action_distribution': np.stack([a.cpu().numpy() for a in data['action']]).flatten(),
+ 'q_loss': a_total_q_loss.item(),
+ 'alpha_loss': a_alpha_loss.item(),
+ 'entropy': a_entropy.mean().item(),
+ 'alpha': math.exp(self._log_alpha.item()),
+ 'q_value': np.mean([x.cpu().detach().numpy() for x in itertools.chain(*q_value)], dtype=float),
+ }
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'optimizer_q': self._optimizer_q.state_dict(),
+ 'optimizer_alpha': self._optimizer_alpha.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._optimizer_q.load_state_dict(state_dict['optimizer_q'])
+ self._optimizer_alpha.load_state_dict(state_dict['optimizer_alpha'])
+
+ def _init_collect(self) -> None:
+ r"""
+ Overview:
+ Collect mode init method. Called by ``self.__init__``.
+ Init unroll length and collect model.
+ Exploration relies on sampling from the softmax over Q-values (plus occasional
+ uniformly random actions) rather than on action noise.
+ """
+ self._unroll_len = self._cfg.collect.unroll_len
+ self._collect_model = model_wrap(self._model, wrapper_name='base')
+ self._collect_model.reset()
+
+ def _forward_collect(self, data: dict) -> dict:
+ r"""
+ Overview:
+ Forward function of collect mode.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs.
+ ReturnsKeys
+ - necessary: ``action``
+ - optional: ``logit``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._collect_model.eval()
+ with torch.no_grad():
+ # start with random action for better exploration
+ output = self._collect_model.forward(data)
+ _decay = self._cfg.other.eps.decay
+ _act_p = 1 / \
+ (_decay - self._forward_learn_cnt) if self._forward_learn_cnt < _decay - 1000 else 0.999
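+ # _act_p is the probability of sampling from the Boltzmann policy softmax(Q / alpha);
+ # otherwise a uniformly random action is taken. Early in training (few learn iterations)
+ # _act_p is tiny, so exploration is almost entirely random; once ``self._forward_learn_cnt``
+ # reaches ``decay - 1000`` it jumps to 0.999 and actions mostly come from the Boltzmann policy.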
+
+ if np.random.random(1) < _act_p:
+ if isinstance(self._action_shape, int):
+ logits = output['logit'] / math.exp(self._log_alpha.item())
+ prob = torch.softmax(logits - logits.max(axis=-1, keepdim=True).values, dim=-1)
+ pi_action = torch.multinomial(prob, 1)
+ else:
+ logits = [_logit / math.exp(self._log_alpha.item()) for _logit in output['logit']]
+ prob = [
+ torch.softmax(_logits - _logits.max(axis=-1, keepdim=True).values, dim=-1) for _logits in logits
+ ]
+ pi_action = [torch.multinomial(_prob, 1) for _prob in prob]
+ else:
+ if isinstance(self._action_shape, int):
+ pi_action = torch.randint(0, self._action_shape, (output["logit"].shape[0], ))
+ else:
+ pi_action = [torch.randint(0, d, (output["logit"][0].shape[0], )) for d in self._action_shape]
+
+ output['action'] = pi_action
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
+ r"""
+ Overview:
+ Generate dict type transition data from inputs.
+ Arguments:
+ - obs (:obj:`Any`): Env observation
+ - model_output (:obj:`dict`): Output of collect model, including at least ['action']
+ - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
+ (here 'obs' indicates obs after env step, i.e. next_obs).
+ Return:
+ - transition (:obj:`Dict[str, Any]`): Dict type transition data.
+ """
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'action': model_output['action'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
+ return get_train_sample(data, self._unroll_len)
+
+ def _init_eval(self) -> None:
+ r"""
+ Overview:
+ Evaluate mode init method. Called by ``self.__init__``.
+ Init eval model, which uses argmax to select actions.
+ """
+ self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: dict) -> dict:
+ r"""
+ Overview:
+ Forward function of eval mode, similar to ``self._forward_collect``.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
+ ReturnsKeys
+ - necessary: ``action``
+ - optional: ``logit``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _monitor_vars_learn(self) -> List[str]:
+ r"""
+ Overview:
+ Return the names of the variables to be monitored.
+ Returns:
+ - vars (:obj:`List[str]`): Variables' name list.
+ """
+ return ['alpha_loss', 'alpha', 'entropy', 'q_loss', 'q_value']
diff --git a/DI-engine/ding/policy/td3.py b/DI-engine/ding/policy/td3.py
new file mode 100644
index 0000000000000000000000000000000000000000..7359190282c6181822d59601ea3e0e833108408a
--- /dev/null
+++ b/DI-engine/ding/policy/td3.py
@@ -0,0 +1,160 @@
+from typing import List
+from ding.utils import POLICY_REGISTRY
+from .ddpg import DDPGPolicy
+
+
+@POLICY_REGISTRY.register('td3')
+class TD3Policy(DDPGPolicy):
+ """
+ Overview:
+ Policy class of TD3 algorithm. Since DDPG and TD3 share many common things, we can easily derive this TD3 \
+ class from DDPG class by changing ``_actor_update_freq``, ``_twin_critic`` and noise in model wrapper.
+ Paper link: https://arxiv.org/pdf/1802.09477.pdf
+
+ Config:
+
+ == ==================== ======== ================== ================================= =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ================== ================================= =======================
+ 1 | ``type`` str td3 | RL policy register name, refer | this arg is optional,
+ | | to registry ``POLICY_REGISTRY`` | a placeholder
+ 2 | ``cuda`` bool False | Whether to use cuda for network |
+ 3 | ``random_`` int 25000 | Number of randomly collected | Default to 25000 for
+ | ``collect_size`` | training samples in replay | DDPG/TD3, 10000 for
+ | | buffer when training starts. | sac.
+ 4 | ``model.twin_`` bool True | Whether to use two critic | Default True for TD3,
+ | ``critic`` | networks or only one. | Clipped Double
+ | | | Q-learning method in
+ | | | TD3 paper.
+ 5 | ``learn.learning`` float 1e-3 | Learning rate for actor |
+ | ``_rate_actor`` | network(aka. policy). |
+ 6 | ``learn.learning`` float 1e-3 | Learning rates for critic |
+ | ``_rate_critic`` | network (aka. Q-network). |
+ 7 | ``learn.actor_`` int 2 | When critic network updates | Default 2 for TD3, 1
+ | ``update_freq`` | once, how many times will actor | for DDPG. Delayed
+ | | network update. | Policy Updates method
+ | | | in TD3 paper.
+ 8 | ``learn.noise`` bool True | Whether to add noise on target | Default True for TD3,
+ | | network's action. | False for DDPG.
+ | | | Target Policy Smoo-
+ | | | thing Regularization
+ | | | in TD3 paper.
+ 9 | ``learn.noise_`` dict | dict(min=-0.5, | Limit for range of target |
+ | ``range`` | max=0.5,) | policy smoothing noise, |
+ | | | aka. noise_clip. |
+ 10 | ``learn.-`` bool False | Determine whether to ignore | Use ignore_done only
+ | ``ignore_done`` | done flag. | in halfcheetah env.
+ 11 | ``learn.-`` float 0.005 | Used for soft update of the | aka. Interpolation
+ | ``target_theta`` | target network. | factor in polyak aver
+ | | | -aging for target
+ | | | networks.
+ 12 | ``collect.-`` float 0.1 | Used for add noise during co- | Sample noise from dis
+ | ``noise_sigma`` | llection, through controlling | -tribution, Ornstein-
+ | | the sigma of distribution | Uhlenbeck process in
+ | | | DDPG paper, Gaussian
+ | | | process in ours.
+ == ==================== ======== ================== ================================= =======================
+ """
+
+ # You can refer to DDPG's default config for more details.
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='td3',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ # (bool) on_policy: Determine whether on-policy or off-policy. Default False in TD3.
+ on_policy=False,
+ # (bool) Whether use priority(priority sample, IS weight, update priority)
+ # Default False in TD3.
+ priority=False,
+ # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ # (int) Number of training samples(randomly collected) in replay buffer when training starts.
+ # Default 25000 in DDPG/TD3.
+ random_collect_size=25000,
+ # (bool) Whether to need policy data in process transition.
+ transition_with_policy_data=False,
+ # (str) Action space type
+ action_space='continuous', # ['continuous', 'hybrid']
+ # (bool) Whether use batch normalization for reward
+ reward_batch_norm=False,
+ # (bool) Whether to enable multi-agent training setting
+ multi_agent=False,
+ model=dict(
+ # (bool) Whether to use two critic networks or only one.
+ # Clipped Double Q-Learning for Actor-Critic in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
+ # Default True for TD3, False for DDPG.
+ twin_critic=True,
+ ),
+ # learn_mode config
+ learn=dict(
+ # (int) How many updates(iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=1,
+ # (int) Minibatch size for gradient descent.
+ batch_size=256,
+ # (float) Learning rates for actor network(aka. policy).
+ learning_rate_actor=1e-3,
+ # (float) Learning rates for critic network(aka. Q-network).
+ learning_rate_critic=1e-3,
+ # (bool) Whether ignore done(usually for max step termination env. e.g. pendulum)
+ # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
+ # These limit HalfCheetah, and several other MuJoCo envs, to max length of 1000.
+ # However, interaction with HalfCheetah should always yield done=False; we therefore
+ # replace done==True with done==False when the episode step exceeds the max episode step,
+ # to keep the TD-error computation (``gamma * (1 - done) * next_v + reward``) accurate.
+ ignore_done=False,
+ # (float) target_theta: Used for soft update of the target network,
+ # aka. Interpolation factor in polyak averaging for target networks.
+ # Default to 0.005.
+ target_theta=0.005,
+ # (float) discount factor for the discounted sum of rewards, aka. gamma.
+ discount_factor=0.99,
+ # (int) When critic network updates once, how many times will actor network update.
+ # Delayed Policy Updates in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
+ # Default 1 for DDPG, 2 for TD3.
+ actor_update_freq=2,
+ # (bool) Whether to add noise on target network's action.
+ # Target Policy Smoothing Regularization in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
+ # Default True for TD3, False for DDPG.
+ noise=True,
+ # (float) Sigma for smoothing noise added to target policy.
+ noise_sigma=0.2,
+ # (dict) Limit for range of target policy smoothing noise, aka. noise_clip.
+ noise_range=dict(
+ # (int) min value of noise
+ min=-0.5,
+ # (int) max value of noise
+ max=0.5,
+ ),
+ ),
+ # collect_mode config
+ collect=dict(
+ # (int) How many training samples collected in one collection procedure.
+ # Only one of [n_sample, n_episode] should be set.
+ # n_sample=1,
+ # (int) Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ # (float) Noise is required during collection, so "noise" is omitted here and only "noise_sigma" is set.
+ noise_sigma=0.1,
+ ),
+ eval=dict(), # for compatibility
+ other=dict(
+ replay_buffer=dict(
+ # (int) Maximum size of replay buffer. Usually, larger buffer size is better.
+ replay_buffer_size=100000,
+ ),
+ ),
+ )
+
+ def _monitor_vars_learn(self) -> List[str]:
+ """
+ Overview:
+ Return the necessary keys for logging the return dict of ``self._forward_learn``. The logger module, such \
+ as text logger, tensorboard logger, will use these keys to save the corresponding data.
+ Returns:
+ - necessary_keys (:obj:`List[str]`): The list of the necessary keys to be logged.
+ """
+ return ["q_value", "loss", "lr", "entropy", "target_q_value", "td_error"]
diff --git a/DI-engine/ding/policy/td3_bc.py b/DI-engine/ding/policy/td3_bc.py
new file mode 100644
index 0000000000000000000000000000000000000000..e30b6bfc07d689b54801d3aa3514613555ca40fe
--- /dev/null
+++ b/DI-engine/ding/policy/td3_bc.py
@@ -0,0 +1,336 @@
+from typing import List, Dict, Any, Tuple, Union
+from easydict import EasyDict
+from collections import namedtuple
+import torch
+import torch.nn.functional as F
+import copy
+
+from ding.torch_utils import Adam, to_device
+from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_collate, default_decollate
+from .base_policy import Policy
+from .common_utils import default_preprocess_learn
+from .ddpg import DDPGPolicy
+
+
+@POLICY_REGISTRY.register('td3_bc')
+class TD3BCPolicy(DDPGPolicy):
+ r"""
+ Overview:
+ Policy class of TD3_BC algorithm.
+
+ Since DDPG and TD3 share many common things, we can easily derive this TD3_BC
+ class from DDPG class by changing ``_actor_update_freq``, ``_twin_critic`` and noise in model wrapper.
+
+ https://arxiv.org/pdf/2106.06860.pdf
+
+ Property:
+ learn_mode, collect_mode, eval_mode
+
+ Config:
+
+ == ==================== ======== ================== ================================= =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ================== ================================= =======================
+ 1 ``type`` str td3_bc | RL policy register name, refer | this arg is optional,
+ | to registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool True | Whether to use cuda for network |
+ 3 | ``random_`` int 25000 | Number of randomly collected | Default to 25000 for
+ | ``collect_size`` | training samples in replay | DDPG/TD3, 10000 for
+ | | buffer when training starts. | sac.
+ 4 | ``model.twin_`` bool True | Whether to use two critic | Default True for TD3,
+ | ``critic`` | networks or only one. | Clipped Double
+ | | | Q-learning method in
+ | | | TD3 paper.
+ 5 | ``learn.learning`` float 1e-3 | Learning rate for actor |
+ | ``_rate_actor`` | network(aka. policy). |
+ 6 | ``learn.learning`` float 1e-3 | Learning rates for critic |
+ | ``_rate_critic`` | network (aka. Q-network). |
+ 7 | ``learn.actor_`` int 2 | When critic network updates | Default 2 for TD3, 1
+ | ``update_freq`` | once, how many times will actor | for DDPG. Delayed
+ | | network update. | Policy Updates method
+ | | | in TD3 paper.
+ 8 | ``learn.noise`` bool True | Whether to add noise on target | Default True for TD3,
+ | | network's action. | False for DDPG.
+ | | | Target Policy Smoo-
+ | | | thing Regularization
+ | | | in TD3 paper.
+ 9 | ``learn.noise_`` dict | dict(min=-0.5, | Limit for range of target |
+ | ``range`` | max=0.5,) | policy smoothing noise, |
+ | | | aka. noise_clip. |
+ 10 | ``learn.-`` bool False | Determine whether to ignore | Use ignore_done only
+ | ``ignore_done`` | done flag. | in halfcheetah env.
+ 11 | ``learn.-`` float 0.005 | Used for soft update of the | aka. Interpolation
+ | ``target_theta`` | target network. | factor in polyak aver
+ | | | aging for target
+ | | | networks.
+ 12 | ``collect.-`` float 0.1 | Used for add noise during co- | Sample noise from dis
+ | ``noise_sigma`` | llection, through controlling | tribution, Ornstein-
+ | | the sigma of distribution | Uhlenbeck process in
+ | | | DDPG paper, Gaussian
+ | | | process in ours.
+ == ==================== ======== ================== ================================= =======================
+ """
+
+ # You can refer to DDPG's default config for more details.
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='td3_bc',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ # (bool type) on_policy: Determine whether on-policy or off-policy.
+ # on-policy setting influences the behaviour of buffer.
+ # Default False in TD3.
+ on_policy=False,
+ # (bool) Whether use priority(priority sample, IS weight, update priority)
+ # Default False in TD3.
+ priority=False,
+ # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ # (int) Number of training samples(randomly collected) in replay buffer when training starts.
+ # Default 25000 in DDPG/TD3.
+ random_collect_size=25000,
+ # (bool) Whether use batch normalization for reward
+ reward_batch_norm=False,
+ action_space='continuous',
+ model=dict(
+ # (bool) Whether to use two critic networks or only one.
+ # Clipped Double Q-Learning for Actor-Critic in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
+ # Default True for TD3, False for DDPG.
+ twin_critic=True,
+
+ # (str type) action_space: use the regression trick for continuous actions
+ action_space='regression',
+
+ # (int) Hidden size for actor network head.
+ actor_head_hidden_size=256,
+
+ # (int) Hidden size for critic network head.
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+
+ # How many updates(iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=1,
+ # (int) Minibatch size for gradient descent.
+ batch_size=256,
+ # (float) Learning rates for actor network(aka. policy).
+ learning_rate_actor=1e-3,
+ # (float) Learning rates for critic network(aka. Q-network).
+ learning_rate_critic=1e-3,
+ # (bool) Whether ignore done(usually for max step termination env. e.g. pendulum)
+ # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
+ # These limit HalfCheetah, and several other MuJoCo envs, to max length of 1000.
+ # However, interaction with HalfCheetah should always yield done=False; we therefore
+ # replace done==True with done==False when the episode step exceeds the max episode step,
+ # to keep the TD-error computation (``gamma * (1 - done) * next_v + reward``) accurate.
+ ignore_done=False,
+ # (float type) target_theta: Used for soft update of the target network,
+ # aka. Interpolation factor in polyak averaging for target networks.
+ # Default to 0.005.
+ target_theta=0.005,
+ # (float) discount factor for the discounted sum of rewards, aka. gamma.
+ discount_factor=0.99,
+ # (int) When critic network updates once, how many times will actor network update.
+ # Delayed Policy Updates in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
+ # Default 1 for DDPG, 2 for TD3.
+ actor_update_freq=2,
+ # (bool) Whether to add noise on target network's action.
+ # Target Policy Smoothing Regularization in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
+ # Default True for TD3, False for DDPG.
+ noise=True,
+ # (float) Sigma for smoothing noise added to target policy.
+ noise_sigma=0.2,
+ # (dict) Limit for range of target policy smoothing noise, aka. noise_clip.
+ noise_range=dict(
+ min=-0.5,
+ max=0.5,
+ ),
+ alpha=2.5,
+ ),
+ collect=dict(
+ # (int) Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ # (float) Noise is required during collection, so "noise" is omitted here and only "noise_sigma" is set.
+ noise_sigma=0.1,
+ ),
+ eval=dict(
+ evaluator=dict(
+ # (int) Evaluate every "eval_freq" training iterations.
+ eval_freq=5000,
+ ),
+ ),
+ other=dict(
+ replay_buffer=dict(
+ # (int) Maximum size of replay buffer.
+ replay_buffer_size=1000000,
+ ),
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ return 'continuous_qac', ['ding.model.template.qac']
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Learn mode init method. Called by ``self.__init__``. Init actor and critic optimizers, algorithm config.
+ """
+ super(TD3BCPolicy, self)._init_learn()
+ self._alpha = self._cfg.learn.alpha
+ # actor and critic optimizer
+ self._optimizer_actor = Adam(
+ self._model.actor.parameters(),
+ lr=self._cfg.learn.learning_rate_actor,
+ grad_clip_type='clip_norm',
+ clip_value=1.0,
+ )
+ self._optimizer_critic = Adam(
+ self._model.critic.parameters(),
+ lr=self._cfg.learn.learning_rate_critic,
+ grad_clip_type='clip_norm',
+ clip_value=1.0,
+ )
+
+ self.noise_sigma = self._cfg.learn.noise_sigma
+ self.noise_range = self._cfg.learn.noise_range
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Forward and backward function of learn mode.
+ Arguments:
+ - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs']
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Including at least actor and critic lr, different losses.
+ """
+ loss_dict = {}
+ data = default_preprocess_learn(
+ data,
+ use_priority=self._cfg.priority,
+ use_priority_IS_weight=self._cfg.priority_IS_weight,
+ ignore_done=self._cfg.learn.ignore_done,
+ use_nstep=False
+ )
+ if self._cuda:
+ data = to_device(data, self._device)
+ # ====================
+ # critic learn forward
+ # ====================
+ self._learn_model.train()
+ self._target_model.train()
+ next_obs = data['next_obs']
+ reward = data['reward']
+ if self._reward_batch_norm:
+ reward = (reward - reward.mean()) / (reward.std() + 1e-8)
+ # current q value
+ q_value = self._learn_model.forward(data, mode='compute_critic')['q_value']
+ q_value_dict = {}
+ if self._twin_critic:
+ q_value_dict['q_value'] = q_value[0].mean()
+ q_value_dict['q_value_twin'] = q_value[1].mean()
+ else:
+ q_value_dict['q_value'] = q_value.mean()
+ # target q value.
+ with torch.no_grad():
+ next_action = self._target_model.forward(next_obs, mode='compute_actor')['action']
+ noise = (torch.randn_like(next_action) *
+ self.noise_sigma).clamp(self.noise_range['min'], self.noise_range['max'])
+ next_action = (next_action + noise).clamp(-1, 1)
+ next_data = {'obs': next_obs, 'action': next_action}
+ target_q_value = self._target_model.forward(next_data, mode='compute_critic')['q_value']
+ if self._twin_critic:
+ # TD3: two critic networks
+ target_q_value = torch.min(target_q_value[0], target_q_value[1]) # find min one as target q value
+ # critic network1
+ td_data = v_1step_td_data(q_value[0], target_q_value, reward, data['done'], data['weight'])
+ critic_loss, td_error_per_sample1 = v_1step_td_error(td_data, self._gamma)
+ loss_dict['critic_loss'] = critic_loss
+ # critic network2(twin network)
+ td_data_twin = v_1step_td_data(q_value[1], target_q_value, reward, data['done'], data['weight'])
+ critic_twin_loss, td_error_per_sample2 = v_1step_td_error(td_data_twin, self._gamma)
+ loss_dict['critic_twin_loss'] = critic_twin_loss
+ td_error_per_sample = (td_error_per_sample1 + td_error_per_sample2) / 2
+ else:
+ # DDPG: single critic network
+ td_data = v_1step_td_data(q_value, target_q_value, reward, data['done'], data['weight'])
+ critic_loss, td_error_per_sample = v_1step_td_error(td_data, self._gamma)
+ loss_dict['critic_loss'] = critic_loss
+ # ================
+ # critic update
+ # ================
+ self._optimizer_critic.zero_grad()
+ for k in loss_dict:
+ if 'critic' in k:
+ loss_dict[k].backward()
+ self._optimizer_critic.step()
+ # ===============================
+ # actor learn forward and update
+ # ===============================
+ # actor updates every ``self._actor_update_freq`` iters
+ if (self._forward_learn_cnt + 1) % self._actor_update_freq == 0:
+ actor_data = self._learn_model.forward(data['obs'], mode='compute_actor')
+ actor_data['obs'] = data['obs']
+ if self._twin_critic:
+ q_value = self._learn_model.forward(actor_data, mode='compute_critic')['q_value'][0]
+ actor_loss = -q_value.mean()
+ else:
+ q_value = self._learn_model.forward(actor_data, mode='compute_critic')['q_value']
+ actor_loss = -q_value.mean()
+
+ # add behavior cloning loss weight(\lambda)
+ lmbda = self._alpha / q_value.abs().mean().detach()
+ # bc_loss = ((actor_data['action'] - data['action'])**2).mean()
+ bc_loss = F.mse_loss(actor_data['action'], data['action'])
+ actor_loss = lmbda * actor_loss + bc_loss
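+ # This implements the TD3+BC actor objective: maximize lambda * Q(s, pi(s)) - (pi(s) - a)^2,
+ # where lambda = alpha / E[|Q(s, a)|] keeps the RL term and the BC term on a comparable
+ # scale regardless of Q-value magnitude (alpha defaults to 2.5, see ``config.learn.alpha``).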
+ loss_dict['actor_loss'] = actor_loss
+ # actor update
+ self._optimizer_actor.zero_grad()
+ actor_loss.backward()
+ self._optimizer_actor.step()
+ # =============
+ # after update
+ # =============
+ loss_dict['total_loss'] = sum(loss_dict.values())
+ self._forward_learn_cnt += 1
+ self._target_model.update(self._learn_model.state_dict())
+ return {
+ 'cur_lr_actor': self._optimizer_actor.defaults['lr'],
+ 'cur_lr_critic': self._optimizer_critic.defaults['lr'],
+ # 'q_value': np.array(q_value).mean(),
+ 'action': data.get('action').mean(),
+ 'priority': td_error_per_sample.abs().tolist(),
+ 'td_error': td_error_per_sample.abs().mean(),
+ **loss_dict,
+ **q_value_dict,
+ }
+
+ def _forward_eval(self, data: dict) -> dict:
+ r"""
+ Overview:
+ Forward function of eval mode, similar to ``self._forward_collect``.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
+ ReturnsKeys
+ - necessary: ``action``
+ - optional: ``logit``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data, mode='compute_actor')
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
diff --git a/DI-engine/ding/policy/td3_vae.py b/DI-engine/ding/policy/td3_vae.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d029c0a9131c4574747fa188422ab82470db097
--- /dev/null
+++ b/DI-engine/ding/policy/td3_vae.py
@@ -0,0 +1,655 @@
+from typing import List, Dict, Any, Tuple, Union
+from collections import namedtuple
+import torch
+import copy
+
+from ding.torch_utils import Adam, to_device
+from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import default_collate, default_decollate
+from .base_policy import Policy
+from .common_utils import default_preprocess_learn
+from .ddpg import DDPGPolicy
+from ding.model.template.vae import VanillaVAE
+from ding.utils import RunningMeanStd
+from torch.nn import functional as F
+
+
+@POLICY_REGISTRY.register('td3-vae')
+class TD3VAEPolicy(DDPGPolicy):
+ r"""
+ Overview:
+ Policy class of the TD3-VAE algorithm: TD3 acting in a VAE-learned latent action space.
+
+ Since DDPG and TD3 share many common things, we can easily derive this TD3-VAE
+ class from the DDPG class by changing ``_actor_update_freq``, ``_twin_critic`` and noise in model wrapper.
+
+ https://arxiv.org/pdf/1802.09477.pdf
+
+ Property:
+ learn_mode, collect_mode, eval_mode
+
+ Config:
+
+ == ==================== ======== ================== ================================= =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ================== ================================= =======================
+ 1 ``type`` str td3 | RL policy register name, refer | this arg is optional,
+ | to registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool True | Whether to use cuda for network |
+ 3 | ``random_`` int 25000 | Number of randomly collected | Default to 25000 for
+ | ``collect_size`` | training samples in replay | DDPG/TD3, 10000 for
+ | | buffer when training starts. | sac.
+ 4 | ``model.twin_`` bool True | Whether to use two critic | Default True for TD3,
+ | ``critic`` | networks or only one. | Clipped Double
+ | | | Q-learning method in
+ | | | TD3 paper.
+ 5 | ``learn.learning`` float 1e-3 | Learning rate for actor |
+ | ``_rate_actor`` | network(aka. policy). |
+ 6 | ``learn.learning`` float 1e-3 | Learning rates for critic |
+ | ``_rate_critic`` | network (aka. Q-network). |
+ 7 | ``learn.actor_`` int 2 | When critic network updates | Default 2 for TD3, 1
+ | ``update_freq`` | once, how many times will actor | for DDPG. Delayed
+ | | network update. | Policy Updates method
+ | | | in TD3 paper.
+ 8 | ``learn.noise`` bool True | Whether to add noise on target | Default True for TD3,
+ | | network's action. | False for DDPG.
+ | | | Target Policy Smoo-
+ | | | thing Regularization
+ | | | in TD3 paper.
+ 9 | ``learn.noise_`` dict | dict(min=-0.5, | Limit for range of target |
+ | ``range`` | max=0.5,) | policy smoothing noise, |
+ | | | aka. noise_clip. |
+ 10 | ``learn.-`` bool False | Determine whether to ignore | Use ignore_done only
+ | ``ignore_done`` | done flag. | in halfcheetah env.
+ 11 | ``learn.-`` float 0.005 | Used for soft update of the | aka. Interpolation
+ | ``target_theta`` | target network. | factor in polyak aver
+ | | | aging for target
+ | | | networks.
+ 12 | ``collect.-`` float 0.1 | Used for add noise during co- | Sample noise from dis
+ | ``noise_sigma`` | llection, through controlling | tribution, Ornstein-
+ | | the sigma of distribution | Uhlenbeck process in
+ | | | DDPG paper, Gaussian
+ | | | process in ours.
+ == ==================== ======== ================== ================================= =======================
+ """
+
+ # You can refer to DDPG's default config for more details.
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='td3-vae',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ # (bool type) on_policy: Determine whether on-policy or off-policy.
+ # on-policy setting influences the behaviour of buffer.
+ # Default False in TD3.
+ on_policy=False,
+ # (bool) Whether use priority(priority sample, IS weight, update priority)
+ # Default False in TD3.
+ priority=False,
+ # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ # (int) Number of training samples(randomly collected) in replay buffer when training starts.
+ # Default 25000 in DDPG/TD3.
+ random_collect_size=25000,
+ # (str) Action space type
+ action_space='continuous', # ['continuous', 'hybrid']
+ # (bool) Whether use batch normalization for reward
+ reward_batch_norm=False,
+ original_action_shape=2,
+ model=dict(
+ # (bool) Whether to use two critic networks or only one.
+ # Clipped Double Q-Learning for Actor-Critic in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
+ # Default True for TD3, False for DDPG.
+ twin_critic=True,
+ ),
+ learn=dict(
+
+ # How many updates(iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=1,
+ # (int) Minibatch size for gradient descent.
+ batch_size=256,
+ # (float) Learning rates for actor network(aka. policy).
+ learning_rate_actor=1e-3,
+ # (float) Learning rates for critic network(aka. Q-network).
+ learning_rate_critic=1e-3,
+ # (bool) Whether ignore done(usually for max step termination env. e.g. pendulum)
+ # Note: Gym wraps the MuJoCo envs by default with TimeLimit environment wrappers.
+ # These limit HalfCheetah, and several other MuJoCo envs, to max length of 1000.
+ # However, interaction with HalfCheetah should always yield done=False; we therefore
+ # replace done==True with done==False when the episode step exceeds the max episode step,
+ # to keep the TD-error computation (``gamma * (1 - done) * next_v + reward``) accurate.
+ ignore_done=False,
+ # (float type) target_theta: Used for soft update of the target network,
+ # aka. Interpolation factor in polyak averaging for target networks.
+ # Default to 0.005.
+ target_theta=0.005,
+ # (float) discount factor for the discounted sum of rewards, aka. gamma.
+ discount_factor=0.99,
+ # (int) When critic network updates once, how many times will actor network update.
+ # Delayed Policy Updates in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
+ # Default 1 for DDPG, 2 for TD3.
+ actor_update_freq=2,
+ # (bool) Whether to add noise on target network's action.
+ # Target Policy Smoothing Regularization in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
+ # Default True for TD3, False for DDPG.
+ noise=True,
+ # (float) Sigma for smoothing noise added to target policy.
+ noise_sigma=0.2,
+ # (dict) Limit for range of target policy smoothing noise, aka. noise_clip.
+ noise_range=dict(
+ min=-0.5,
+ max=0.5,
+ ),
+ ),
+ collect=dict(
+ # n_sample=1,
+ # (int) Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ # (float) Noise is required during collection, so "noise" is omitted here and only "noise_sigma" is set.
+ noise_sigma=0.1,
+ ),
+ eval=dict(
+ evaluator=dict(
+ # (int) Evaluate every "eval_freq" training iterations.
+ eval_freq=5000,
+ ),
+ ),
+ other=dict(
+ replay_buffer=dict(
+ # (int) Maximum size of replay buffer.
+ replay_buffer_size=100000,
+ ),
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ return 'continuous_qac', ['ding.model.template.qac']
+
+ def _init_learn(self) -> None:
+ r"""
+ Overview:
+ Learn mode init method. Called by ``self.__init__``.
+ Init actor and critic optimizers, algorithm config, main and target models.
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ # actor and critic optimizer
+ self._optimizer_actor = Adam(
+ self._model.actor.parameters(),
+ lr=self._cfg.learn.learning_rate_actor,
+ )
+ self._optimizer_critic = Adam(
+ self._model.critic.parameters(),
+ lr=self._cfg.learn.learning_rate_critic,
+ )
+ self._reward_batch_norm = self._cfg.reward_batch_norm
+
+ self._gamma = self._cfg.learn.discount_factor
+ self._actor_update_freq = self._cfg.learn.actor_update_freq
+ self._twin_critic = self._cfg.model.twin_critic # True for TD3, False for DDPG
+
+ # main and target models
+ self._target_model = copy.deepcopy(self._model)
+ if self._cfg.action_space == 'hybrid':
+ self._target_model = model_wrap(self._target_model, wrapper_name='hybrid_argmax_sample')
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='target',
+ update_type='momentum',
+ update_kwargs={'theta': self._cfg.learn.target_theta}
+ )
+ if self._cfg.learn.noise:
+ self._target_model = model_wrap(
+ self._target_model,
+ wrapper_name='action_noise',
+ noise_type='gauss',
+ noise_kwargs={
+ 'mu': 0.0,
+ 'sigma': self._cfg.learn.noise_sigma
+ },
+ noise_range=self._cfg.learn.noise_range
+ )
+ self._learn_model = model_wrap(self._model, wrapper_name='base')
+ if self._cfg.action_space == 'hybrid':
+ self._learn_model = model_wrap(self._learn_model, wrapper_name='hybrid_argmax_sample')
+ self._learn_model.reset()
+ self._target_model.reset()
+
+ self._forward_learn_cnt = 0 # count iterations
+ # action_shape, obs_shape, latent_action_dim, hidden_size_list
+ self._vae_model = VanillaVAE(
+ self._cfg.original_action_shape, self._cfg.model.obs_shape, self._cfg.model.action_shape, [256, 256]
+ )
+ # self._vae_model = VanillaVAE(2, 8, 6, [256, 256])
+
+ self._optimizer_vae = Adam(
+ self._vae_model.parameters(),
+ lr=self._cfg.learn.learning_rate_vae,
+ )
+ self._running_mean_std_predict_loss = RunningMeanStd(epsilon=1e-4)
+ self.c_percentage_bound_lower = -1 * torch.ones([6])
+ self.c_percentage_bound_upper = torch.ones([6])
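+ # NOTE: the hard-coded size 6 assumes a 6-dim latent action space and should match
+ # ``cfg.model.action_shape``. The bounds start at [-1, 1] per latent dim and are replaced
+ # by per-batch percentiles during VAE training (latent space constraint, see ``_forward_learn``);
+ # they clamp the actor output in ``_forward_collect`` and ``_forward_eval``.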
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Forward and backward function of learn mode.
+ Arguments:
+ - data (:obj:`dict`): Dict type data, including at least ['obs', 'action', 'reward', 'next_obs']
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Including at least actor and critic lr, different losses.
+ """
+ # warmup phase
+ if 'warm_up' in data[0].keys() and data[0]['warm_up'] is True:
+ loss_dict = {}
+ data = default_preprocess_learn(
+ data,
+ use_priority=self._cfg.priority,
+ use_priority_IS_weight=self._cfg.priority_IS_weight,
+ ignore_done=self._cfg.learn.ignore_done,
+ use_nstep=False
+ )
+ if self._cuda:
+ data = to_device(data, self._device)
+
+ # ====================
+ # train vae
+ # ====================
+ result = self._vae_model({'action': data['action'], 'obs': data['obs']})
+
+ result['original_action'] = data['action']
+ result['true_residual'] = data['next_obs'] - data['obs']
+
+ vae_loss = self._vae_model.loss_function(result, kld_weight=0.01, predict_weight=0.01) # TODO(pu): weight
+
+ loss_dict['vae_loss'] = vae_loss['loss'].item()
+ loss_dict['reconstruction_loss'] = vae_loss['reconstruction_loss'].item()
+ loss_dict['kld_loss'] = vae_loss['kld_loss'].item()
+ loss_dict['predict_loss'] = vae_loss['predict_loss'].item()
+ self._running_mean_std_predict_loss.update(vae_loss['predict_loss'].unsqueeze(-1).cpu().detach().numpy())
+
+ # vae update
+ self._optimizer_vae.zero_grad()
+ vae_loss['loss'].backward()
+ self._optimizer_vae.step()
+ # For compatibility
+ loss_dict['actor_loss'] = torch.Tensor([0]).item()
+ loss_dict['critic_loss'] = torch.Tensor([0]).item()
+ loss_dict['critic_twin_loss'] = torch.Tensor([0]).item()
+ loss_dict['total_loss'] = torch.Tensor([0]).item()
+ q_value_dict = {}
+ q_value_dict['q_value'] = torch.Tensor([0]).item()
+ q_value_dict['q_value_twin'] = torch.Tensor([0]).item()
+ return {
+ 'cur_lr_actor': self._optimizer_actor.defaults['lr'],
+ 'cur_lr_critic': self._optimizer_critic.defaults['lr'],
+ 'action': torch.Tensor([0]).item(),
+ 'priority': torch.Tensor([0]).item(),
+ 'td_error': torch.Tensor([0]).item(),
+ **loss_dict,
+ **q_value_dict,
+ }
+ else:
+ self._forward_learn_cnt += 1
+ loss_dict = {}
+ q_value_dict = {}
+ data = default_preprocess_learn(
+ data,
+ use_priority=self._cfg.priority,
+ use_priority_IS_weight=self._cfg.priority_IS_weight,
+ ignore_done=self._cfg.learn.ignore_done,
+ use_nstep=False
+ )
+ if data['vae_phase'][0].item() is True:
+ if self._cuda:
+ data = to_device(data, self._device)
+
+ # ====================
+ # train vae
+ # ====================
+ result = self._vae_model({'action': data['action'], 'obs': data['obs']})
+
+ result['original_action'] = data['action']
+ result['true_residual'] = data['next_obs'] - data['obs']
+
+ # latent space constraint (LSC)
+ # NOTE: using tanh is important, update latent_action using z, shape (128,6)
+ data['latent_action'] = torch.tanh(result['z'].clone().detach()) # NOTE: tanh
+ # data['latent_action'] = result['z'].clone().detach()
+ self.c_percentage_bound_lower = data['latent_action'].sort(dim=0)[0][int(
+ result['recons_action'].shape[0] * 0.02
+ ), :] # values, indices
+ self.c_percentage_bound_upper = data['latent_action'].sort(
+ dim=0
+ )[0][int(result['recons_action'].shape[0] * 0.98), :]
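+ # The clamp range is the per-dimension 2% / 98% percentile of this batch's latent codes,
+ # so the actor is later restricted to latent actions the VAE decoder has actually seen.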
+
+ vae_loss = self._vae_model.loss_function(
+ result, kld_weight=0.01, predict_weight=0.01
+ ) # TODO(pu): weight
+
+ loss_dict['vae_loss'] = vae_loss['loss']
+ loss_dict['reconstruction_loss'] = vae_loss['reconstruction_loss']
+ loss_dict['kld_loss'] = vae_loss['kld_loss']
+ loss_dict['predict_loss'] = vae_loss['predict_loss']
+
+ # vae update
+ self._optimizer_vae.zero_grad()
+ vae_loss['loss'].backward()
+ self._optimizer_vae.step()
+
+ return {
+ 'cur_lr_actor': self._optimizer_actor.defaults['lr'],
+ 'cur_lr_critic': self._optimizer_critic.defaults['lr'],
+ # 'q_value': np.array(q_value).mean(),
+ 'action': torch.Tensor([0]).item(),
+ 'priority': torch.Tensor([0]).item(),
+ 'td_error': torch.Tensor([0]).item(),
+ **loss_dict,
+ **q_value_dict,
+ }
+
+ else:
+ # ====================
+ # critic learn forward
+ # ====================
+ self._learn_model.train()
+ self._target_model.train()
+ next_obs = data['next_obs']
+ reward = data['reward']
+
+ # ====================
+ # relabel latent action
+ # ====================
+ if self._cuda:
+ data = to_device(data, self._device)
+ result = self._vae_model({'action': data['action'], 'obs': data['obs']})
+ true_residual = data['next_obs'] - data['obs']
+
+ # Representation shift correction (RSC)
+ for i in range(result['recons_action'].shape[0]):
+ if F.mse_loss(result['prediction_residual'][i],
+ true_residual[i]).item() > 4 * self._running_mean_std_predict_loss.mean:
+ # NOTE: using tanh is important, update latent_action using z
+ data['latent_action'][i] = torch.tanh(result['z'][i].clone().detach()) # NOTE: tanh
+ # data['latent_action'][i] = result['z'][i].clone().detach()
+
+ # update all latent action
+ # data['latent_action'] = torch.tanh(result['z'].clone().detach())
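+ # RSC rationale: samples whose VAE prediction error is far above the running mean (> 4x)
+ # were encoded by an outdated VAE, so their stored latent action is re-encoded with the
+ # current model before being used in the critic/actor updates below.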
+
+ if self._reward_batch_norm:
+ reward = (reward - reward.mean()) / (reward.std() + 1e-8)
+
+ # current q value
+ q_value = self._learn_model.forward(
+ {
+ 'obs': data['obs'],
+ 'action': data['latent_action']
+ }, mode='compute_critic'
+ )['q_value']
+ q_value_dict = {}
+ if self._twin_critic:
+ q_value_dict['q_value'] = q_value[0].mean()
+ q_value_dict['q_value_twin'] = q_value[1].mean()
+ else:
+ q_value_dict['q_value'] = q_value.mean()
+ # target q value.
+ with torch.no_grad():
+ # NOTE: here next_actor_data['action'] is latent action
+ next_actor_data = self._target_model.forward(next_obs, mode='compute_actor')
+ next_actor_data['obs'] = next_obs
+ target_q_value = self._target_model.forward(next_actor_data, mode='compute_critic')['q_value']
+ if self._twin_critic:
+ # TD3: two critic networks
+ target_q_value = torch.min(target_q_value[0], target_q_value[1]) # find min one as target q value
+ # critic network1
+ td_data = v_1step_td_data(q_value[0], target_q_value, reward, data['done'], data['weight'])
+ critic_loss, td_error_per_sample1 = v_1step_td_error(td_data, self._gamma)
+ loss_dict['critic_loss'] = critic_loss
+ # critic network2(twin network)
+ td_data_twin = v_1step_td_data(q_value[1], target_q_value, reward, data['done'], data['weight'])
+ critic_twin_loss, td_error_per_sample2 = v_1step_td_error(td_data_twin, self._gamma)
+ loss_dict['critic_twin_loss'] = critic_twin_loss
+ td_error_per_sample = (td_error_per_sample1 + td_error_per_sample2) / 2
+ else:
+ # DDPG: single critic network
+ td_data = v_1step_td_data(q_value, target_q_value, reward, data['done'], data['weight'])
+ critic_loss, td_error_per_sample = v_1step_td_error(td_data, self._gamma)
+ loss_dict['critic_loss'] = critic_loss
+ # ================
+ # critic update
+ # ================
+ self._optimizer_critic.zero_grad()
+ for k in loss_dict:
+ if 'critic' in k:
+ loss_dict[k].backward()
+ self._optimizer_critic.step()
+ # ===============================
+ # actor learn forward and update
+ # ===============================
+ # actor updates every ``self._actor_update_freq`` iters
+ if (self._forward_learn_cnt + 1) % self._actor_update_freq == 0:
+ # NOTE: actor_data['action] is latent action
+ actor_data = self._learn_model.forward(data['obs'], mode='compute_actor')
+ actor_data['obs'] = data['obs']
+ if self._twin_critic:
+ actor_loss = -self._learn_model.forward(actor_data, mode='compute_critic')['q_value'][0].mean()
+ else:
+ actor_loss = -self._learn_model.forward(actor_data, mode='compute_critic')['q_value'].mean()
+
+ loss_dict['actor_loss'] = actor_loss
+ # actor update
+ self._optimizer_actor.zero_grad()
+ actor_loss.backward()
+ self._optimizer_actor.step()
+ # =============
+ # after update
+ # =============
+ loss_dict['total_loss'] = sum(loss_dict.values())
+ # self._forward_learn_cnt += 1
+ self._target_model.update(self._learn_model.state_dict())
+ if self._cfg.action_space == 'hybrid':
+ action_log_value = -1. # TODO(nyz) better way to viz hybrid action
+ else:
+ action_log_value = data['action'].mean()
+
+ return {
+ 'cur_lr_actor': self._optimizer_actor.defaults['lr'],
+ 'cur_lr_critic': self._optimizer_critic.defaults['lr'],
+ 'action': action_log_value,
+ 'priority': td_error_per_sample.abs().tolist(),
+ 'td_error': td_error_per_sample.abs().mean(),
+ **loss_dict,
+ **q_value_dict,
+ }
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'target_model': self._target_model.state_dict(),
+ 'optimizer_actor': self._optimizer_actor.state_dict(),
+ 'optimizer_critic': self._optimizer_critic.state_dict(),
+ 'vae_model': self._vae_model.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._target_model.load_state_dict(state_dict['target_model'])
+ self._optimizer_actor.load_state_dict(state_dict['optimizer_actor'])
+ self._optimizer_critic.load_state_dict(state_dict['optimizer_critic'])
+ self._vae_model.load_state_dict(state_dict['vae_model'])
+
+ def _init_collect(self) -> None:
+ r"""
+ Overview:
+ Collect mode init method. Called by ``self.__init__``.
+ Init unroll length and collect model with Gaussian action noise for exploration.
+ """
+ self._unroll_len = self._cfg.collect.unroll_len
+ # collect model
+ self._collect_model = model_wrap(
+ self._model,
+ wrapper_name='action_noise',
+ noise_type='gauss',
+ noise_kwargs={
+ 'mu': 0.0,
+ 'sigma': self._cfg.collect.noise_sigma
+ },
+ noise_range=None
+ )
+ if self._cfg.action_space == 'hybrid':
+ self._collect_model = model_wrap(self._collect_model, wrapper_name='hybrid_eps_greedy_multinomial_sample')
+ self._collect_model.reset()
+
+ def _forward_collect(self, data: dict, **kwargs) -> dict:
+ r"""
+ Overview:
+ Forward function of collect mode.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs.
+ ReturnsKeys
+ - necessary: ``action``
+ - optional: ``logit``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._collect_model.eval()
+ with torch.no_grad():
+ output = self._collect_model.forward(data, mode='compute_actor', **kwargs)
+ output['latent_action'] = output['action']
+
+ # latent space constraint (LSC)
+ for i in range(output['action'].shape[-1]):
+ output['action'][:, i].clamp_(
+ self.c_percentage_bound_lower[i].item(), self.c_percentage_bound_upper[i].item()
+ )
+
+ # TODO(pu): decode into original hybrid actions, here data is obs
+ # this is very important to generate self.obs_encoding using in decode phase
+ output['action'] = self._vae_model.decode_with_obs(output['action'], data)['reconstruction_action']
+
+ # NOTE: add noise in the original actions
+ from ding.rl_utils.exploration import GaussianNoise
+ action = output['action']
+ gaussian_noise = GaussianNoise(mu=0.0, sigma=0.1)
+ noise = gaussian_noise(output['action'].shape, output['action'].device)
+ if self._cfg.learn.noise_range is not None:
+ noise = noise.clamp(self._cfg.learn.noise_range['min'], self._cfg.learn.noise_range['max'])
+ action += noise
+ self.action_range = {'min': -1, 'max': 1}
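+ # NOTE: the hard-coded range assumes the env's action space is normalized to [-1, 1];
+ # decoded actions plus exploration noise are clipped back into this range.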
+ if self.action_range is not None:
+ action = action.clamp(self.action_range['min'], self.action_range['max'])
+ output['action'] = action
+
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Generate dict type transition data from inputs.
+ Arguments:
+ - obs (:obj:`Any`): Env observation
+ - model_output (:obj:`dict`): Output of collect model, including at least ['action']
+ - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done'] \
+ (here 'obs' indicates obs after env step, i.e. next_obs).
+ Return:
+ - transition (:obj:`Dict[str, Any]`): Dict type transition data.
+ """
+ if 'latent_action' in model_output.keys():
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'action': model_output['action'],
+ 'latent_action': model_output['latent_action'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ else: # if random collect at first (no latent action available yet)
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'action': model_output['action'],
+ 'latent_action': 999,
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ if self._cfg.action_space == 'hybrid':
+ transition['logit'] = model_output['logit']
+ return transition
+
+ def _get_train_sample(self, data: list) -> Union[None, List[Any]]:
+ return get_train_sample(data, self._unroll_len)
+
+ def _init_eval(self) -> None:
+ r"""
+ Overview:
+ Evaluate mode init method. Called by ``self.__init__``.
+ Init eval model. Unlike learn and collect model, eval model does not need noise.
+ """
+ self._eval_model = model_wrap(self._model, wrapper_name='base')
+ if self._cfg.action_space == 'hybrid':
+ self._eval_model = model_wrap(self._eval_model, wrapper_name='hybrid_argmax_sample')
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: dict) -> dict:
+ r"""
+ Overview:
+ Forward function of eval mode, similar to ``self._forward_collect``.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
+ ReturnsKeys
+ - necessary: ``action``
+ - optional: ``logit``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data, mode='compute_actor')
+ output['latent_action'] = output['action']
+
+ # latent space constraint (LSC)
+ for i in range(output['action'].shape[-1]):
+ output['action'][:, i].clamp_(
+ self.c_percentage_bound_lower[i].item(), self.c_percentage_bound_upper[i].item()
+ )
+
+ # TODO(pu): decode into original hybrid actions, here data is obs
+ # this is very important to generate self.obs_encoding using in decode phase
+ output['action'] = self._vae_model.decode_with_obs(output['action'], data)['reconstruction_action']
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _monitor_vars_learn(self) -> List[str]:
+ r"""
+ Overview:
+ Return the names of the variables to be monitored.
+ Returns:
+ - vars (:obj:`List[str]`): Variables' name list.
+ """
+ ret = [
+ 'cur_lr_actor', 'cur_lr_critic', 'critic_loss', 'actor_loss', 'total_loss', 'q_value', 'q_value_twin',
+ 'action', 'td_error', 'vae_loss', 'reconstruction_loss', 'kld_loss', 'predict_loss'
+ ]
+ if self._twin_critic:
+ ret += ['critic_twin_loss']
+ return ret
diff --git a/DI-engine/ding/policy/tests/test_common_utils.py b/DI-engine/ding/policy/tests/test_common_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..38bf67ed9872055e1c414521505707d79ddf57b5
--- /dev/null
+++ b/DI-engine/ding/policy/tests/test_common_utils.py
@@ -0,0 +1,179 @@
+import unittest
+import pytest
+import numpy as np
+import torch
+import treetensor.torch as ttorch
+
+from ding.policy.common_utils import default_preprocess_learn
+
+shape_test = [
+ [2],
+ [1],
+]
+
+dtype_test = [
+ "int64",
+ "float32",
+]
+
+data_type_test = [
+ "numpy",
+ "torch",
+ "treetensor",
+]
+
+
+def get_action(shape, dtype, class_type):
+ if class_type == "numpy":
+ if dtype == "int64":
+ dtype = np.int64
+ elif dtype == "float32":
+ dtype = np.float32
+ return np.random.randn(*shape).astype(dtype)
+ else:
+ if dtype == "int64":
+ dtype = torch.int64
+ elif dtype == "float32":
+ dtype = torch.float32
+
+ if class_type == "torch":
+ return torch.randn(*shape).type(dtype)
+ elif class_type == "treetensor":
+ return ttorch.randn(*shape).type(dtype)
+
+
+@pytest.mark.unittest
+def test_default_preprocess_learn_action():
+
+ for shape in shape_test:
+ for dtype in dtype_test:
+ for data_type in data_type_test:
+
+ data = [
+ {
+ 'obs': np.random.randn(4, 84, 84),
+ 'action': get_action(shape, dtype, data_type),
+ 'reward': 1.0,
+ 'next_obs': np.random.randn(4, 84, 84),
+ 'done': False,
+ 'weight': 1.0,
+ 'value': 1.0,
+ 'adv': 1.0,
+ } for _ in range(10)
+ ]
+ use_priority_IS_weight = False
+ use_priority = False
+ use_nstep = False
+ ignore_done = False
+ data = default_preprocess_learn(data, use_priority_IS_weight, use_priority, use_nstep, ignore_done)
+
+ assert data['obs'].shape == torch.Size([10, 4, 84, 84])
+ if dtype in ["int64"] and shape[0] == 1:
+ assert data['action'].shape == torch.Size([10])
+ else:
+ assert data['action'].shape == torch.Size([10, *shape])
+ assert data['reward'].shape == torch.Size([10])
+ assert data['next_obs'].shape == torch.Size([10, 4, 84, 84])
+ assert data['done'].shape == torch.Size([10])
+ assert data['weight'].shape == torch.Size([10])
+ assert data['value'].shape == torch.Size([10])
+ assert data['adv'].shape == torch.Size([10])
+
+
+@pytest.mark.unittest
+def test_default_preprocess_learn_reward_done_adv_1d():
+
+ data = [
+ {
+ 'obs': np.random.randn(4, 84, 84),
+ 'action': np.random.randn(2),
+ 'reward': np.array([1.0]),
+ 'next_obs': np.random.randn(4, 84, 84),
+ 'done': False,
+ 'value': np.array([1.0]),
+ 'adv': np.array([1.0]),
+ } for _ in range(10)
+ ]
+ use_priority_IS_weight = False
+ use_priority = False
+ use_nstep = False
+ ignore_done = False
+ data = default_preprocess_learn(data, use_priority_IS_weight, use_priority, use_nstep, ignore_done)
+
+ assert data['reward'].shape == torch.Size([10])
+ assert data['done'].shape == torch.Size([10])
+ assert data['weight'] is None
+ assert data['value'].shape == torch.Size([10])
+ assert data['adv'].shape == torch.Size([10])
+
+
+@pytest.mark.unittest
+def test_default_preprocess_learn_ignore_done():
+ data = [
+ {
+ 'obs': np.random.randn(4, 84, 84),
+ 'action': np.random.randn(2),
+ 'reward': np.array([1.0]),
+ 'next_obs': np.random.randn(4, 84, 84),
+ 'done': True,
+ 'value': np.array([1.0]),
+ 'adv': np.array([1.0]),
+ } for _ in range(10)
+ ]
+ use_priority_IS_weight = False
+ use_priority = False
+ use_nstep = False
+ ignore_done = True
+ data = default_preprocess_learn(data, use_priority_IS_weight, use_priority, use_nstep, ignore_done)
+
+ assert data['done'].dtype == torch.float32
+ assert torch.sum(data['done']) == 0
+
+
+@pytest.mark.unittest
+def test_default_preprocess_learn_use_priority_IS_weight():
+ data = [
+ {
+ 'obs': np.random.randn(4, 84, 84),
+ 'action': np.random.randn(2),
+ 'reward': 1.0,
+ 'next_obs': np.random.randn(4, 84, 84),
+ 'done': False,
+ 'priority_IS': 1.0,
+ 'value': 1.0,
+ 'adv': 1.0,
+ } for _ in range(10)
+ ]
+ use_priority_IS_weight = True
+ use_priority = True
+ use_nstep = False
+ ignore_done = False
+ data = default_preprocess_learn(data, use_priority_IS_weight, use_priority, use_nstep, ignore_done)
+
+ assert data['weight'].shape == torch.Size([10])
+ assert torch.sum(data['weight']) == torch.tensor(10.0)
+
+
+@pytest.mark.unittest
+def test_default_preprocess_learn_nstep():
+ data = [
+ {
+ 'obs': np.random.randn(4, 84, 84),
+ 'action': np.random.randn(2),
+ 'reward': np.array([1.0, 2.0, 0.0]),
+ 'next_obs': np.random.randn(4, 84, 84),
+ 'done': False,
+ 'value': 1.0,
+ 'adv': 1.0,
+ } for _ in range(10)
+ ]
+ use_priority_IS_weight = False
+ use_priority = False
+ use_nstep = True
+ ignore_done = False
+ data = default_preprocess_learn(data, use_priority_IS_weight, use_priority, use_nstep, ignore_done)
+
+ assert data['reward'].shape == torch.Size([3, 10])
+ assert data['reward'][0][0] == torch.tensor(1.0)
+ assert data['reward'][1][0] == torch.tensor(2.0)
+ assert data['reward'][2][0] == torch.tensor(0.0)
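The tests above pin down the collation behaviour of ``default_preprocess_learn``. A minimal usage sketch based on the shapes asserted by these tests (the toy batch below is illustrative):

import numpy as np
import torch
from ding.policy.common_utils import default_preprocess_learn

# a toy batch of two transitions, shaped like the collector output used in the tests above
data = [
    {
        'obs': np.random.randn(4, 84, 84),
        'action': np.array([1], dtype=np.int64),
        'reward': 1.0,
        'next_obs': np.random.randn(4, 84, 84),
        'done': False,
    } for _ in range(2)
]
# positional arguments: use_priority_IS_weight, use_priority, use_nstep, ignore_done
batch = default_preprocess_learn(data, False, False, False, False)
assert batch['obs'].shape == torch.Size([2, 4, 84, 84])
assert batch['action'].shape == torch.Size([2])  # int64 actions of shape [1] are squeezed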
diff --git a/DI-engine/ding/policy/tests/test_cql.py b/DI-engine/ding/policy/tests/test_cql.py
new file mode 100644
index 0000000000000000000000000000000000000000..248653da6ad30501bd429b73df4a988e8be879c4
--- /dev/null
+++ b/DI-engine/ding/policy/tests/test_cql.py
@@ -0,0 +1,103 @@
+import copy
+
+import pytest
+import torch
+from easydict import EasyDict
+from ding.policy.cql import CQLPolicy, DiscreteCQLPolicy
+from ding.utils.data import offline_data_save_type
+from tensorboardX import SummaryWriter
+from ding.model.wrapper.model_wrappers import ArgmaxSampleWrapper, EpsGreedySampleWrapper, TargetNetworkWrapper
+import os
+from typing import List
+from collections import namedtuple
+from ding.utils import deep_merge_dicts
+
+obs_space = 5
+action_space = 3
+
+cfg1 = EasyDict(CQLPolicy.default_config())
+cfg1.model.obs_shape = obs_space
+cfg1.model.action_shape = action_space
+
+cfg2 = copy.deepcopy(cfg1)
+cfg2.learn.auto_alpha = False
+cfg2.learn.log_space = False
+
+cfg3 = EasyDict(DiscreteCQLPolicy.default_config())
+cfg3.model = {}
+cfg3.model.obs_shape = obs_space
+cfg3.model.action_shape = action_space
+
+cfg4 = copy.deepcopy(cfg3)
+cfg4.learn.auto_alpha = False
+
+
+def get_batch(size=8):
+ data = {}
+ for i in range(size):
+ obs = torch.zeros(obs_space)
+ data[i] = obs
+ return data
+
+
+def get_transition(size=20):
+ data = []
+ for i in range(size):
+ sample = {}
+ sample['obs'] = torch.zeros(obs_space)
+ sample['action'] = torch.zeros(action_space)
+ sample['done'] = False
+ sample['next_obs'] = torch.zeros(obs_space)
+ sample['reward'] = torch.Tensor([1.])
+ data.append(sample)
+ return data
+
+
+def get_transition_batch(bs=1):
+ sample = {}
+ sample['obs'] = torch.zeros(bs, obs_space)
+ sample['action'] = torch.zeros(bs, action_space)
+ return sample
+
+
+@pytest.mark.parametrize('cfg', [cfg1, cfg2])
+@pytest.mark.unittest
+def test_cql_continuous(cfg):
+ policy = CQLPolicy(cfg, enable_field=['collect', 'eval', 'learn'])
+ assert type(policy._target_model) == TargetNetworkWrapper
+ q_value = policy._get_q_value(get_transition_batch(cfg.learn.num_actions))
+ assert q_value[0].shape[-1] == 1 and q_value[0].shape[-2] == cfg.learn.num_actions
+ act, log_prob = policy._get_policy_actions(get_transition_batch(cfg.learn.num_actions))
+ assert list(act.shape) == [cfg.learn.num_actions * 10, action_space]
+ sample = get_transition(size=20)
+ out = policy._forward_learn(sample)
+
+
+def get_transition_discrete(size=20):
+ data = []
+ for i in range(size):
+ sample = {}
+ sample['obs'] = torch.zeros(obs_space)
+ sample['action'] = torch.tensor(i % action_space)
+ sample['done'] = False
+ sample['next_obs'] = torch.zeros(obs_space)
+ sample['reward'] = torch.Tensor([1.])
+ data.append(sample)
+ return data
+
+
+@pytest.mark.parametrize('cfg', [cfg3, cfg4])
+@pytest.mark.unittest
+def test_cql_discrete(cfg):
+ policy = DiscreteCQLPolicy(cfg, enable_field=['collect', 'eval', 'learn'])
+ assert type(policy._learn_model) == ArgmaxSampleWrapper
+ assert type(policy._target_model) == TargetNetworkWrapper
+ assert type(policy._collect_model) == EpsGreedySampleWrapper
+ sample = get_transition_batch(bs=20)
+ samples = policy._get_train_sample(sample)
+ assert len(samples['obs']) == 20
+ state = policy._state_dict_learn()
+ policy._load_state_dict_learn(state)
+ sample = get_transition_discrete(size=1)
+ out = policy._forward_learn(sample)
+ out = policy._forward_collect(get_batch(size=8), eps=0.1)
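These cases are tagged with the ``unittest`` pytest marker. A minimal way to run only this file programmatically (the path is assumed to be relative to the checkout containing this diff):

import pytest

# select only the unittest-marked cases in this file and print verbose output
pytest.main(['-m', 'unittest', '-sv', 'DI-engine/ding/policy/tests/test_cql.py'])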
diff --git a/DI-engine/ding/policy/tests/test_r2d3.py b/DI-engine/ding/policy/tests/test_r2d3.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a04eb1b712bafdf7de6417492fca1748565b9a5
--- /dev/null
+++ b/DI-engine/ding/policy/tests/test_r2d3.py
@@ -0,0 +1,137 @@
+import pytest
+import torch
+from easydict import EasyDict
+from ding.policy.r2d3 import R2D3Policy
+from ding.utils.data import offline_data_save_type
+from tensorboardX import SummaryWriter
+from ding.model.wrapper.model_wrappers import ArgmaxSampleWrapper, HiddenStateWrapper, EpsGreedySampleWrapper
+import os
+from typing import List
+from collections import namedtuple
+
+obs_space = 5
+action_space = 4
+
+cfg = dict(
+ cuda=True,
+ on_policy=False,
+ priority=True,
+ priority_IS_weight=True,
+ model=dict(
+ obs_shape=obs_space,
+ action_shape=action_space,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ discount_factor=0.99,
+ burnin_step=2,
+ nstep=5,
+ learn_unroll_len=20,
+ burning_step=5,
+ learn=dict(
+ value_rescale=True,
+ update_per_collect=8,
+ batch_size=64,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ lambda1=1.0, # n-step return
+ lambda2=1.0, # supervised loss
+ lambda3=1e-5, # L2 regularization; it is very important to set the Adam optimizer's optim_type='adamw'.
+ lambda_one_step_td=1, # 1-step return
+ margin_function=0.8, # margin function in JE, here we implement this as a constant
+ per_train_iter_k=0,
+ ignore_done=False,
+ ),
+ collect=dict(
+ n_sample=32,
+ traj_len_inf=True,
+ env_num=8,
+ pho=1 / 4,
+ ),
+ eval=dict(env_num=8, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=100000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=int(1e4),
+ alpha=0.6,
+ beta=0.4,
+ ),
+ ),
+)
+cfg = EasyDict(cfg)
+
+
+def get_batch(size=8):
+ data = {}
+ for i in range(size):
+ obs = torch.zeros(obs_space)
+ data[i] = obs
+ return data
+
+
+def get_transition(size=20):
+ data = []
+ import numpy as np
+ for i in range(size):
+ sample = {}
+ sample['obs'] = torch.zeros(obs_space)
+ sample['action'] = torch.tensor(np.array([int(i % action_space)]))
+ sample['done'] = False
+ sample['prev_state'] = [torch.randn(1, 1, 512) for __ in range(2)]
+ sample['reward'] = torch.Tensor([1.])
+ sample['IS'] = 1.
+ sample['is_expert'] = bool(i % 2)
+ data.append(sample)
+ return data
+
+
+@pytest.mark.parametrize('cfg', [cfg])
+@pytest.mark.unittest
+def test_r2d3(cfg):
+ policy = R2D3Policy(cfg, enable_field=['collect', 'eval'])
+ policy._init_learn()
+ assert type(policy._learn_model) == ArgmaxSampleWrapper
+ assert type(policy._target_model) == HiddenStateWrapper
+ policy._reset_learn()
+ policy._reset_learn([0])
+ state = policy._state_dict_learn()
+ policy._load_state_dict_learn(state)
+ policy._init_collect()
+ assert type(policy._collect_model) == EpsGreedySampleWrapper
+ policy._reset_collect()
+ policy._reset_collect([0])
+ policy._init_eval()
+ assert type(policy._eval_model) == ArgmaxSampleWrapper
+ policy._reset_eval()
+ policy._reset_eval([0])
+ assert policy.default_model()[0] == 'drqn'
+ var = policy._monitor_vars_learn()
+ assert type(var) == list
+ assert sum([type(s) == str for s in var]) == len(var)
+ batch = get_batch(8)
+ out = policy._forward_collect(batch, eps=0.1)
+ assert len(set(out[0].keys()).intersection({'logit', 'prev_state', 'action'})) == 3
+ assert list(out[0]['logit'].shape) == [action_space]
+ timestep = namedtuple('timestep', ['reward', 'done'])
+ ts = timestep(
+ 1.,
+ 0.,
+ )
+ ts = policy._process_transition(batch[0], out[0], ts)
+ assert len(set(ts.keys()).intersection({'prev_state', 'action', 'reward', 'done', 'obs'})) == 5
+ ts = get_transition(64 * policy._sequence_len)
+ sample = policy._get_train_sample(ts)
+ n_traj = len(ts) // policy._sequence_len
+ assert len(sample) == (n_traj + 1 if len(ts) % policy._sequence_len != 0 else n_traj)
+ out = policy._forward_eval(batch)
+ assert len(set(out[0].keys()).intersection({'logit', 'action'})) == 2
+ assert list(out[0]['logit'].shape) == [action_space]
+ for i in range(len(sample)):
+ sample[i]['IS'] = sample[i]['IS'][cfg.burnin_step:]
+ out = policy._forward_learn(sample)
+ policy._value_rescale = False
+ out = policy._forward_learn(sample)
diff --git a/DI-engine/ding/policy/tests/test_stdim.py b/DI-engine/ding/policy/tests/test_stdim.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ff459c9220595e05070aae4d74c73dc2f2db8e6
--- /dev/null
+++ b/DI-engine/ding/policy/tests/test_stdim.py
@@ -0,0 +1,44 @@
+from copy import deepcopy
+import pytest
+import torch
+from easydict import EasyDict
+from ding.model.wrapper.model_wrappers import BaseModelWrapper, MultinomialSampleWrapper
+from ding.policy import PPOSTDIMPolicy
+
+obs_shape = 4
+action_shape = 2
+
+cfg1 = EasyDict(PPOSTDIMPolicy.default_config())
+cfg1.model.obs_shape = obs_shape
+cfg1.model.action_shape = action_shape
+
+cfg2 = deepcopy(cfg1)
+cfg2.action_space = "continuous"
+
+
+def get_transition_discrete(size=64):
+ data = []
+ for i in range(size):
+ sample = {}
+ sample['obs'] = torch.rand(obs_shape)
+ sample['next_obs'] = torch.rand(obs_shape)
+ sample['action'] = torch.tensor([0], dtype=torch.long)
+ sample['value'] = torch.rand(1)
+ sample['logit'] = torch.rand(size=(action_shape, ))
+ sample['done'] = False
+ sample['reward'] = torch.rand(1)
+ data.append(sample)
+ return data
+
+
+@pytest.mark.parametrize('cfg', [cfg1])
+@pytest.mark.unittest
+def test_stdim(cfg):
+ policy = PPOSTDIMPolicy(cfg, enable_field=['collect', 'eval', 'learn'])
+ assert type(policy._learn_model) == BaseModelWrapper
+ assert type(policy._collect_model) == MultinomialSampleWrapper
+ sample = get_transition_discrete(size=64)
+ state = policy._state_dict_learn()
+ policy._load_state_dict_learn(state)
+ sample = get_transition_discrete(size=64)
+ out = policy._forward_learn(sample)
diff --git a/DI-engine/ding/policy/wqmix.py b/DI-engine/ding/policy/wqmix.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a27cdce9431a2f6ecb3d3723f76aac81a895075
--- /dev/null
+++ b/DI-engine/ding/policy/wqmix.py
@@ -0,0 +1,314 @@
+from typing import List, Dict, Any, Tuple, Union, Optional
+from collections import namedtuple
+import torch
+import copy
+
+from ding.torch_utils import RMSprop, to_device
+from ding.rl_utils import v_1step_td_data, v_1step_td_error, get_train_sample
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY
+from ding.utils.data import timestep_collate, default_collate, default_decollate
+from .base_policy import Policy
+from ding.policy.qmix import QMIXPolicy
+
+
+@POLICY_REGISTRY.register('wqmix')
+class WQMIXPolicy(QMIXPolicy):
+ r"""
+ Overview:
+ Policy class of the WQMIX algorithm. WQMIX is a reinforcement learning algorithm modified from QMIX; \
+ the paper is available at https://arxiv.org/abs/2006.10800
+ Interface:
+ _init_learn, _data_preprocess_learn, _forward_learn, _reset_learn, _state_dict_learn, _load_state_dict_learn\
+ _init_collect, _forward_collect, _reset_collect, _process_transition, _init_eval, _forward_eval\
+ _reset_eval, _get_train_sample, default_model
+ Config:
+ == ==================== ======== ============== ======================================== =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============== ======================================== =======================
+ 1 ``type`` str wqmix | RL policy register name, refer to | this arg is optional,
+ | registry ``POLICY_REGISTRY`` | a placeholder
+ 2 ``cuda`` bool True | Whether to use cuda for network | this arg can be diff-
+ | erent from modes
+ 3 ``on_policy`` bool False | Whether the RL algorithm is on-policy
+ | or off-policy
+ 4 ``priority`` bool False | Whether use priority(PER) | priority sample,
+ | update priority
+ 5 | ``priority_`` bool False | Whether use Importance Sampling | IS weight
+ | ``IS_weight`` | Weight to correct biased update.
+ 6 | ``learn.update_`` int 20 | How many updates(iterations) to train | this args can be vary
+ | ``per_collect`` | after collector's one collection. Only | from envs. Bigger val
+ | valid in serial training | means more off-policy
+ 7 | ``learn.target_`` float 0.001 | Target network update momentum | between[0,1]
+ | ``update_theta`` | parameter.
+ 8 | ``learn.discount`` float 0.99 | Reward's future discount factor, aka. | may be 1 when sparse
+ | ``_factor`` | gamma | reward env
+ == ==================== ======== ============== ======================================== =======================
+ """
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='wqmix',
+ # (bool) Whether to use cuda for network.
+ cuda=True,
+ # (bool) Whether the RL algorithm is on-policy or off-policy.
+ on_policy=False,
+ # (bool) Whether use priority(priority sample, IS weight, update priority)
+ priority=False,
+ # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ learn=dict(
+ update_per_collect=20,
+ batch_size=32,
+ learning_rate=0.0005,
+ clip_value=100,
+ # ==============================================================
+ # The following configs is algorithm-specific
+ # ==============================================================
+ # (float) Target network update momentum parameter.
+ # in [0, 1].
+ target_update_theta=0.008,
+ # (float) The discount factor for future rewards,
+ # in [0, 1].
+ discount_factor=0.99,
+ w=0.5, # for OW
+ # w = 0.75, # for CW
+ wqmix_ow=True,
+ ),
+ collect=dict(
+ # (int) Only one of [n_sample, n_episode] should be set
+ # n_episode=32,
+ # (int) Cut trajectories into pieces with length "unroll_len", the length of timesteps
+ # in each forward when training. In WQMIX (as in QMIX), it is greater than 1 because an RNN is used.
+ unroll_len=10,
+ ),
+ eval=dict(),
+ other=dict(
+ eps=dict(
+ # (str) Type of epsilon decay
+ type='exp',
+ # (float) Start value for epsilon decay, in [0, 1].
+ # 0 means that epsilon decay is not used.
+ start=1,
+ # (float) End value for epsilon decay, in [0, 1].
+ end=0.05,
+ # (int) Decay length(env step)
+ decay=50000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=5000,
+ # (int) The maximum reuse times of each data
+ max_reuse=1e+9,
+ max_staleness=1e+9,
+ ),
+ ),
+ )
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ """
+ Overview:
+ Return this algorithm default model setting for demonstration.
+ Returns:
+ - model_info (:obj:`Tuple[str, List[str]]`): model name and mode import_names
+ .. note::
+ The user can define and use a customized network model, but it must obey the same interface definition indicated \
+ by import_names path. For WQMIX, ``ding.model.template.wqmix``
+ """
+ return 'wqmix', ['ding.model.template.wqmix']
+
+ def _init_learn(self) -> None:
+ """
+ Overview:
+ Learn mode init method. Called by ``self.__init__``.
+ Init the learner model of WQMIXPolicy
+ Arguments:
+ .. note::
+
+ The _init_learn method takes the argument from the self._cfg.learn in the config file
+
+ - learning_rate (:obj:`float`): The learning rate of the optimizer
+ - gamma (:obj:`float`): The discount factor
+ - agent_num (:obj:`int`): Since this is a multi-agent algorithm, we need to input the agent num.
+ - batch_size (:obj:`int`): Need batch size info to init hidden_state plugins
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ assert not self._priority and not self._priority_IS_weight, "Priority is not implemented in WQMIX"
+ self._optimizer = RMSprop(
+ params=list(self._model._q_network.parameters()) + list(self._model._mixer.parameters()),
+ lr=self._cfg.learn.learning_rate,
+ alpha=0.99,
+ eps=0.00001
+ )
+ self._gamma = self._cfg.learn.discount_factor
+ self._optimizer_star = RMSprop(
+ params=list(self._model._q_network_star.parameters()) + list(self._model._mixer_star.parameters()),
+ lr=self._cfg.learn.learning_rate,
+ alpha=0.99,
+ eps=0.00001
+ )
+ self._learn_model = model_wrap(
+ self._model,
+ wrapper_name='hidden_state',
+ state_num=self._cfg.learn.batch_size,
+ init_fn=lambda: [None for _ in range(self._cfg.model.agent_num)]
+ )
+ self._learn_model.reset()
+
+ def _data_preprocess_learn(self, data: List[Any]) -> dict:
+ r"""
+ Overview:
+ Preprocess the data to fit the required data format for learning
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]]`): the data collected from collect function
+ Returns:
+ - data (:obj:`Dict[str, Any]`): the processed data, from \
+ [len=B, ele={dict_key: [len=T, ele=Tensor(any_dims)]}] -> {dict_key: Tensor([T, B, any_dims])}
+ """
+ # data preprocess
+ data = timestep_collate(data)
+ if self._cuda:
+ data = to_device(data, self._device)
+ data['weight'] = data.get('weight', None)
+ data['done'] = data['done'].float()
+ return data
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Forward and backward function of learn mode.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \
+ np.ndarray or dict/list combinations.
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Dict type data, a info dict indicated training result, which will be \
+ recorded in text log and tensorboard, values are python scalar or a list of scalars.
+ ArgumentsKeys:
+ - necessary: ``obs``, ``next_obs``, ``action``, ``reward``, ``weight``, ``prev_state``, ``done``
+ ReturnsKeys:
+ - necessary: ``cur_lr``, ``total_loss``
+ - cur_lr (:obj:`float`): Current learning rate
+ - total_loss (:obj:`float`): The calculated loss
+ """
+ data = self._data_preprocess_learn(data)
+ # ====================
+ # forward
+ # ====================
+ self._learn_model.train()
+
+ inputs = {'obs': data['obs'], 'action': data['action']}
+
+ # for hidden_state plugin, we need to reset the main model and target model
+ self._learn_model.reset(state=data['prev_state'][0])
+ total_q = self._learn_model.forward(inputs, single_step=False, q_star=False)['total_q']
+
+ self._learn_model.reset(state=data['prev_state'][0])
+ total_q_star = self._learn_model.forward(inputs, single_step=False, q_star=True)['total_q']
+
+ next_inputs = {'obs': data['next_obs']}
+ self._learn_model.reset(state=data['prev_state'][1]) # TODO(pu)
+ next_logit_detach = self._learn_model.forward(
+ next_inputs, single_step=False, q_star=False
+ )['logit'].clone().detach()
+
+ next_inputs = {'obs': data['next_obs'], 'action': next_logit_detach.argmax(dim=-1)}
+ with torch.no_grad():
+ self._learn_model.reset(state=data['prev_state'][1]) # TODO(pu)
+ target_total_q = self._learn_model.forward(next_inputs, single_step=False, q_star=True)['total_q']
+
+ with torch.no_grad():
+ if data['done'] is not None:
+ target_v = self._gamma * (1 - data['done']) * target_total_q + data['reward']
+ else:
+ target_v = self._gamma * target_total_q + data['reward']
+
+ td_error = (total_q - target_v).clone().detach()
+ data_ = v_1step_td_data(total_q, target_total_q, data['reward'], data['done'], data['weight'])
+ _, td_error_per_sample = v_1step_td_error(data_, self._gamma)
+
+ data_star = v_1step_td_data(total_q_star, target_total_q, data['reward'], data['done'], data['weight'])
+ loss_star, td_error_per_sample_star_ = v_1step_td_error(data_star, self._gamma)
+
+ # our implementation is based on https://github.com/oxwhirl/wqmix
+ # Weighting
+ alpha_to_use = self._cfg.learn.alpha
+ if self._cfg.learn.wqmix_ow: # Optimistically-Weighted
+ ws = torch.full_like(td_error, alpha_to_use)
+ # if td_error < 0, i.e. Q < y_i, then w =1; if not, w = alpha_to_use
+ ws = torch.where(td_error < 0, torch.ones_like(td_error), ws)
+ else: # Centrally-Weighted
+ inputs = {'obs': data['obs']}
+ self._learn_model.reset(state=data['prev_state'][0]) # TODO(pu)
+ logit_detach = self._learn_model.forward(inputs, single_step=False, q_star=False)['logit'].clone().detach()
+ cur_max_actions = logit_detach.argmax(dim=-1)
+ inputs = {'obs': data['obs'], 'action': cur_max_actions}
+ self._learn_model.reset(state=data['prev_state'][0]) # TODO(pu)
+ max_action_qtot = self._learn_model.forward(inputs, single_step=False, q_star=True)['total_q'] # Q_star
+ # Only if the action of each agent is optimal, then the joint action is optimal
+ is_max_action = (data['action'] == cur_max_actions).min(dim=2)[0] # shape (H,B,N) -> (H,B)
+ qtot_larger = target_v > max_action_qtot
+ ws = torch.full_like(td_error, alpha_to_use)
+ # if y_i > Q_star or u = u_star, then w =1; if not, w = alpha_to_use
+ ws = torch.where(is_max_action | qtot_larger, torch.ones_like(td_error), ws)
+
+ if data['weight'] is None:
+ data['weight'] = torch.ones_like(data['reward'])
+ loss_weighted = (ws.detach() * td_error_per_sample * data['weight']).mean()
+
+ # ====================
+ # Q and Q_star update
+ # ====================
+ self._optimizer.zero_grad()
+ self._optimizer_star.zero_grad()
+ loss_weighted.backward(retain_graph=True)
+ loss_star.backward()
+ grad_norm_q = torch.nn.utils.clip_grad_norm_(
+ list(self._model._q_network.parameters()) + list(self._model._mixer.parameters()),
+ self._cfg.learn.clip_value
+ ) # Q
+ grad_norm_q_star = torch.nn.utils.clip_grad_norm_(
+ list(self._model._q_network_star.parameters()) + list(self._model._mixer_star.parameters()),
+ self._cfg.learn.clip_value
+ ) # Q_star
+ self._optimizer.step() # Q update
+ self._optimizer_star.step() # Q_star update
+
+ # =============
+ # after update
+ # =============
+ return {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': loss_weighted.item(),
+ 'total_q': total_q.mean().item() / self._cfg.model.agent_num,
+ 'target_reward_total_q': target_v.mean().item() / self._cfg.model.agent_num,
+ 'target_total_q': target_total_q.mean().item() / self._cfg.model.agent_num,
+ 'grad_norm_q': grad_norm_q,
+ 'grad_norm_q_star': grad_norm_q_star,
+ }
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Return the state_dict of learn mode, usually including model and optimizer.
+ Returns:
+ - state_dict (:obj:`Dict[str, Any]`): the dict of current policy learn state, for saving and restoring.
+ """
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'optimizer': self._optimizer.state_dict(),
+ 'optimizer_star': self._optimizer_star.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ r"""
+ Overview:
+ Load the state_dict variable into policy learn mode.
+ Arguments:
+ - state_dict (:obj:`Dict[str, Any]`): the dict of policy learn state saved before.
+ .. tip::
+ If you want to only load some parts of model, you can simply set the ``strict`` argument in \
+ load_state_dict to ``False``, or refer to ``ding.torch_utils.checkpoint_helper`` for more \
+ complicated operation.
+ """
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._optimizer.load_state_dict(state_dict['optimizer'])
+ self._optimizer_star.load_state_dict(state_dict['optimizer_star'])
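The optimistic (OW) and central (CW) weighting applied in ``_forward_learn`` above can be restated in isolation. A minimal sketch over standalone tensors (not the policy's actual interface):

import torch

def wqmix_weights(td_error: torch.Tensor, alpha: float, ow: bool = True,
                  is_max_action: torch.Tensor = None, qtot_larger: torch.Tensor = None) -> torch.Tensor:
    # td_error = Q_tot - y_i, e.g. of shape (T, B)
    ws = torch.full_like(td_error, alpha)
    if ow:
        # Optimistically-Weighted: full weight whenever the estimate underestimates the target
        return torch.where(td_error < 0, torch.ones_like(td_error), ws)
    # Centrally-Weighted: full weight when the greedy joint action was taken
    # or the target exceeds Q* of the greedy joint action
    return torch.where(is_max_action | qtot_larger, torch.ones_like(td_error), ws)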
diff --git a/DI-engine/ding/reward_model/__init__.py b/DI-engine/ding/reward_model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4538102861be1ae94d6852777a837f2a6d01b182
--- /dev/null
+++ b/DI-engine/ding/reward_model/__init__.py
@@ -0,0 +1,15 @@
+from .base_reward_model import BaseRewardModel, create_reward_model, get_reward_model_cls
+# inverse RL
+from .pdeil_irl_model import PdeilRewardModel
+from .gail_irl_model import GailRewardModel
+from .pwil_irl_model import PwilRewardModel
+from .red_irl_model import RedRewardModel
+from .trex_reward_model import TrexRewardModel
+from .drex_reward_model import DrexRewardModel
+# sparse reward
+from .her_reward_model import HerRewardModel
+# exploration
+from .rnd_reward_model import RndRewardModel
+from .guided_cost_reward_model import GuidedCostRewardModel
+from .ngu_reward_model import RndNGURewardModel, EpisodicNGURewardModel
+from .icm_reward_model import ICMRewardModel
diff --git a/DI-engine/ding/reward_model/base_reward_model.py b/DI-engine/ding/reward_model/base_reward_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..963bacf1d7cbc183a24ff3707033ec2fcce70985
--- /dev/null
+++ b/DI-engine/ding/reward_model/base_reward_model.py
@@ -0,0 +1,142 @@
+from abc import ABC, abstractmethod
+from typing import Dict
+from easydict import EasyDict
+from ditk import logging
+import os
+import copy
+from typing import Any
+from ding.utils import REWARD_MODEL_REGISTRY, import_module, save_file
+
+
+class BaseRewardModel(ABC):
+ """
+ Overview:
+ the base class of reward model
+ Interface:
+ ``default_config``, ``estimate``, ``train``, ``clear_data``, ``collect_data``, ``load_expert_data``
+ """
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ cfg = EasyDict(copy.deepcopy(cls.config))
+ cfg.cfg_type = cls.__name__ + 'Dict'
+ return cfg
+
+ @abstractmethod
+ def estimate(self, data: list) -> Any:
+ """
+ Overview:
+ estimate reward
+ Arguments:
+ - data (:obj:`List`): the list of data used for estimation
+ Returns / Effects:
+ - This can be a side effect function which updates the reward value
+ - If this function returns a value, it is typically the estimated reward (:obj:`Any`)
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def train(self, data) -> None:
+ """
+ Overview:
+ Training the reward model
+ Arguments:
+ - data (:obj:`Any`): Data used for training
+ Effects:
+ - This is mostly a side effect function which updates the reward model
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def collect_data(self, data) -> None:
+ """
+ Overview:
+ Collecting training data in a designated format or with designated transitions.
+ Arguments:
+ - data (:obj:`Any`): Raw training data (e.g. some form of states, actions, obs, etc)
+ Returns / Effects:
+ - This can be a side effect function which updates the data attribute in ``self``
+ """
+ raise NotImplementedError()
+
+ @abstractmethod
+ def clear_data(self) -> None:
+ """
+ Overview:
+ Clearing training data. \
+ This can be a side effect function which clears the data attribute in ``self``
+ """
+ raise NotImplementedError()
+
+ def load_expert_data(self, data) -> None:
+ """
+ Overview:
+ Getting the expert data, usually used in inverse RL reward model
+ Arguments:
+ - data (:obj:`Any`): Expert data
+ Effects:
+ This is mostly a side effect function which updates the expert data attribute (e.g. ``self.expert_data``)
+ """
+ pass
+
+ def reward_deepcopy(self, train_data) -> Any:
+ """
+ Overview:
+ This method deepcopies the reward part of train_data and keeps a shallow copy of the other parts,
+ so that the reward part of the data in the replay buffer is not incorrectly modified.
+ Arguments:
+ - train_data (:obj:`List`): the List of train data in which the reward part will be operated by deepcopy.
+ """
+ train_data_reward_deepcopy = [
+ {k: copy.deepcopy(v) if k == 'reward' else v
+ for k, v in sample.items()} for sample in train_data
+ ]
+ return train_data_reward_deepcopy
+
+ def state_dict(self) -> Dict:
+ # this method should be overridden by the subclass.
+ return {}
+
+ def load_state_dict(self, _state_dict) -> None:
+ # this method should be overridden by the subclass.
+ pass
+
+ def save(self, path: str = None, name: str = 'best'):
+ if path is None:
+ path = self.cfg.exp_name
+ path = os.path.join(path, 'reward_model', 'ckpt')
+ if not os.path.exists(path):
+ try:
+ os.makedirs(path)
+ except FileExistsError:
+ pass
+ path = os.path.join(path, 'ckpt_{}.pth.tar'.format(name))
+ state_dict = self.state_dict()
+ save_file(path, state_dict)
+ logging.info('Saved reward model ckpt in {}'.format(path))
+
+
+def create_reward_model(cfg: dict, device: str, tb_logger: 'SummaryWriter') -> BaseRewardModel: # noqa
+ """
+ Overview:
+ Create a reward estimation model according to the given config.
+ Arguments:
+ - cfg (:obj:`Dict`): Training config
+ - device (:obj:`str`): Device usage, i.e. "cpu" or "cuda"
+ - tb_logger (:obj:`SummaryWriter`): Logger, by default a 'SummaryWriter' used for model summaries
+ Returns:
+ - reward_model (:obj:`BaseRewardModel`): The created reward model instance
+ """
+ cfg = copy.deepcopy(cfg)
+ if 'import_names' in cfg:
+ import_module(cfg.pop('import_names'))
+ if hasattr(cfg, 'reward_model'):
+ reward_model_type = cfg.reward_model.pop('type')
+ else:
+ reward_model_type = cfg.pop('type')
+ return REWARD_MODEL_REGISTRY.build(reward_model_type, cfg, device=device, tb_logger=tb_logger)
+
+
+def get_reward_model_cls(cfg: EasyDict) -> type:
+ import_module(cfg.get('import_names', []))
+ return REWARD_MODEL_REGISTRY.get(cfg.type)
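A minimal sketch of a custom reward model built on ``BaseRewardModel`` and registered through ``REWARD_MODEL_REGISTRY`` (the 'constant' name and the fixed reward value are illustrative assumptions, not part of the library):

import torch
from typing import List
from easydict import EasyDict

from ding.utils import REWARD_MODEL_REGISTRY
from ding.reward_model import BaseRewardModel


@REWARD_MODEL_REGISTRY.register('constant')
class ConstantRewardModel(BaseRewardModel):
    # minimal config consumed by ``default_config``; the fields are illustrative
    config = dict(type='constant', reward_value=1.0)

    def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None:  # noqa
        super().__init__()
        self.cfg = config

    def estimate(self, data: list) -> List[dict]:
        # deepcopy only the reward part so the replay buffer data is not modified in place
        data = self.reward_deepcopy(data)
        for item in data:
            item['reward'] = torch.full((1, ), self.cfg.reward_value)
        return data

    def train(self, data) -> None:
        pass  # a constant reward has nothing to learn

    def collect_data(self, data) -> None:
        pass

    def clear_data(self) -> None:
        pass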
diff --git a/DI-engine/ding/reward_model/drex_reward_model.py b/DI-engine/ding/reward_model/drex_reward_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..645b469088dc06da8444410b4d0df471ff4f5a4b
--- /dev/null
+++ b/DI-engine/ding/reward_model/drex_reward_model.py
@@ -0,0 +1,99 @@
+import copy
+from easydict import EasyDict
+import pickle
+
+from ding.utils import REWARD_MODEL_REGISTRY
+
+from .trex_reward_model import TrexRewardModel
+
+
+@REWARD_MODEL_REGISTRY.register('drex')
+class DrexRewardModel(TrexRewardModel):
+ """
+ Overview:
+ The Drex reward model class (https://arxiv.org/pdf/1907.03976.pdf)
+ Interface:
+ ``estimate``, ``train``, ``load_expert_data``, ``collect_data``, ``clear_data``, \
+ ``__init__``, ``_train``,
+ Config:
+ == ==================== ====== ============= ======================================= ===============
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ====== ============= ======================================= ===============
+ 1 ``type`` str drex | Reward model register name, refer |
+ | to registry ``REWARD_MODEL_REGISTRY`` |
+ 3 | ``learning_rate`` float 0.00001 | learning rate for optimizer |
+ 4 | ``update_per_`` int 100 | Number of updates per collect |
+ | ``collect`` | |
+ 5 | ``batch_size`` int 64 | How many samples in a training batch |
+ 6 | ``hidden_size`` int 128 | Linear model hidden size |
+ 7 | ``num_trajs`` int 0 | Number of downsampled full |
+ | trajectories |
+ 8 | ``num_snippets`` int 6000 | Number of short subtrajectories |
+ | to sample |
+ == ==================== ====== ============= ======================================= ===============
+ """
+ config = dict(
+ # (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
+ type='drex',
+ # (float) The step size of gradient descent.
+ learning_rate=1e-5,
+ # (int) How many updates(iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=100,
+ # (int) How many samples in a training batch.
+ batch_size=64,
+ # (int) Linear model hidden size
+ hidden_size=128,
+ # (int) Number of downsampled full trajectories.
+ num_trajs=0,
+ # (int) Number of short subtrajectories to sample.
+ num_snippets=6000,
+ )
+
+ bc_cfg = None
+
+ def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None: # noqa
+ """
+ Overview:
+ Initialize ``self.`` See ``help(type(self))`` for accurate signature.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Training config
+ - device (:obj:`str`): Device usage, i.e. "cpu" or "cuda"
+ - tb_logger (:obj:`SummaryWriter`): Logger, by default a 'SummaryWriter' used for model summary
+ """
+ super(DrexRewardModel, self).__init__(copy.deepcopy(config), device, tb_logger)
+
+ self.demo_data = []
+ self.load_expert_data()
+
+ def load_expert_data(self) -> None:
+ """
+ Overview:
+ Getting the expert data from ``config.expert_data_path`` attribute in self
+ Effects:
+ This is a side effect function which updates the expert data attribute \
+ (i.e. ``self.expert_data``) with ``fn:concat_state_action_pairs``
+ """
+ super(DrexRewardModel, self).load_expert_data()
+
+ with open(self.cfg.reward_model.offline_data_path + '/suboptimal_data.pkl', 'rb') as f:
+ self.demo_data = pickle.load(f)
+
+ def train(self):
+ self._train()
+ return_dict = self.pred_data(self.demo_data)
+ res, pred_returns = return_dict['real'], return_dict['pred']
+ self._logger.info("real: " + str(res))
+ self._logger.info("pred: " + str(pred_returns))
+
+ info = {
+ "min_snippet_length": self.min_snippet_length,
+ "max_snippet_length": self.max_snippet_length,
+ "len_num_training_obs": len(self.training_obs),
+ "lem_num_labels": len(self.training_labels),
+ "accuracy": self.calc_accuracy(self.reward_model, self.training_obs, self.training_labels),
+ }
+ self._logger.info(
+ "accuracy and comparison:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))
+ )
diff --git a/DI-engine/ding/reward_model/gail_irl_model.py b/DI-engine/ding/reward_model/gail_irl_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..6533e114dd0776fcc337c6e9257dd4a9ef32f706
--- /dev/null
+++ b/DI-engine/ding/reward_model/gail_irl_model.py
@@ -0,0 +1,293 @@
+from typing import List, Dict, Any
+import pickle
+import random
+from collections.abc import Iterable
+from easydict import EasyDict
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from ding.utils import REWARD_MODEL_REGISTRY
+from .base_reward_model import BaseRewardModel
+import torch.nn.functional as F
+from functools import partial
+
+
+def concat_state_action_pairs(iterator):
+ """
+ Overview:
+ Concatenate state and action pairs from input.
+ Arguments:
+ - iterator (:obj:`Iterable`): Iterables with at least ``obs`` and ``action`` tensor keys.
+ Returns:
+ - res (:obj:`Torch.tensor`): State and action pairs.
+ """
+ assert isinstance(iterator, Iterable)
+ res = []
+ for item in iterator:
+ state = item['obs'].flatten() # to allow 3d obs and actions concatenation
+ action = item['action']
+ s_a = torch.cat([state, action.float()], dim=-1)
+ res.append(s_a)
+ return res
+
+
+def concat_state_action_pairs_one_hot(iterator, action_size: int):
+ """
+ Overview:
+ Concatenate state and action pairs from input. Action values are one-hot encoded.
+ Arguments:
+ - iterator (:obj:`Iterable`): Iterables with at least ``obs`` and ``action`` tensor keys.
+ - action_size (:obj:`int`): Size of the discrete action space, used for the one-hot encoding.
+ Returns:
+ - res (:obj:`Torch.tensor`): State and action pairs.
+ """
+ assert isinstance(iterator, Iterable)
+ res = []
+ for item in iterator:
+ state = item['obs'].flatten() # to allow 3d obs and actions concatenation
+ action = item['action']
+ action = torch.Tensor([int(i == action) for i in range(action_size)])
+ s_a = torch.cat([state, action], dim=-1)
+ res.append(s_a)
+ return res
+
+
+class RewardModelNetwork(nn.Module):
+
+ def __init__(self, input_size: int, hidden_size: int, output_size: int) -> None:
+ super(RewardModelNetwork, self).__init__()
+ self.l1 = nn.Linear(input_size, hidden_size)
+ self.l2 = nn.Linear(hidden_size, output_size)
+ self.a1 = nn.Tanh()
+ self.a2 = nn.Sigmoid()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ out = x
+ out = self.l1(out)
+ out = self.a1(out)
+ out = self.l2(out)
+ out = self.a2(out)
+ return out
+
+
+class AtariRewardModelNetwork(nn.Module):
+
+ def __init__(self, input_size: int, action_size: int) -> None:
+ super(AtariRewardModelNetwork, self).__init__()
+ self.input_size = input_size
+ self.action_size = action_size
+ self.conv1 = nn.Conv2d(4, 16, 7, stride=3)
+ self.conv2 = nn.Conv2d(16, 16, 5, stride=2)
+ self.conv3 = nn.Conv2d(16, 16, 3, stride=1)
+ self.conv4 = nn.Conv2d(16, 16, 3, stride=1)
+ self.fc1 = nn.Linear(784, 64)
+ self.fc2 = nn.Linear(64 + self.action_size, 1) # the one-hot action is concatenated to the features, hence the extra action_size inputs
+ self.a = nn.Sigmoid()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ # input: x = [B, 4 x 84 x 84 + self.action_size], last element is action
+ actions = x[:, -self.action_size:] # [B, self.action_size]
+ # get observations
+ x = x[:, :-self.action_size]
+ x = x.reshape([-1] + self.input_size) # [B, 4, 84, 84]
+ x = F.leaky_relu(self.conv1(x))
+ x = F.leaky_relu(self.conv2(x))
+ x = F.leaky_relu(self.conv3(x))
+ x = F.leaky_relu(self.conv4(x))
+ x = x.reshape(-1, 784)
+ x = F.leaky_relu(self.fc1(x))
+ x = torch.cat([x, actions], dim=-1)
+ x = self.fc2(x)
+ r = self.a(x)
+ return r
+
+
+@REWARD_MODEL_REGISTRY.register('gail')
+class GailRewardModel(BaseRewardModel):
+ """
+ Overview:
+ The Gail reward model class (https://arxiv.org/abs/1606.03476)
+ Interface:
+ ``estimate``, ``train``, ``load_expert_data``, ``collect_data``, ``clear_data``, \
+ ``__init__``, ``state_dict``, ``load_state_dict``, ``learn``
+ Config:
+ == ==================== ======== ============= =================================== =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============= =================================== =======================
+ 1 ``type`` str gail | RL policy register name, refer | this arg is optional,
+ | to registry ``POLICY_REGISTRY`` | a placeholder
+ 2 | ``expert_data_`` str expert_data. | Path to the expert dataset | Should be a '.pkl'
+ | ``path`` .pkl | | file
+ 3 | ``learning_rate`` float 0.001 | The step size of gradient descent |
+ 4 | ``update_per_`` int 100 | Number of updates per collect |
+ | ``collect`` | |
+ 5 | ``batch_size`` int 64 | Training batch size |
+ 6 | ``input_size`` int | Size of the input: |
+ | | obs_dim + act_dim |
+ 7 | ``target_new_`` int 64 | Collect steps per iteration |
+ | ``data_count`` | |
+ 8 | ``hidden_size`` int 128 | Linear model hidden size |
+ 9 | ``collect_count`` int 100000 | Expert dataset size | One entry is a (s,a)
+ | | | tuple
+ 10 | ``clear_buffer_`` int 1 | clear buffer per fixed iters | make sure replay
+ | ``per_iters`` | buffer's data count
+ | | isn't too few.
+ | | (code work in entry)
+ == ==================== ======== ============= =================================== =======================
+ """
+ config = dict(
+ # (str) RL policy register name, refer to registry ``POLICY_REGISTRY``.
+ type='gail',
+ # (float) The step size of gradient descent.
+ learning_rate=1e-3,
+ # (int) How many updates(iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=100,
+ # (int) How many samples in a training batch.
+ batch_size=64,
+ # (int) Size of the input: obs_dim + act_dim.
+ input_size=4,
+ # (int) Collect steps per iteration.
+ target_new_data_count=64,
+ # (int) Linear model hidden size.
+ hidden_size=128,
+ # (int) Expert dataset size.
+ collect_count=100000,
+ # (int) Clear buffer per fixed iters.
+ clear_buffer_per_iters=1,
+ )
+
+ def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None: # noqa
+ """
+ Overview:
+ Initialize ``self.`` See ``help(type(self))`` for accurate signature.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Training config
+ - device (:obj:`str`): Device usage, i.e. "cpu" or "cuda"
+ - tb_logger (:obj:`SummaryWriter`): Logger, by default a 'SummaryWriter' used for model summary
+ """
+ super(GailRewardModel, self).__init__()
+ self.cfg = config
+ assert device in ["cpu", "cuda"] or "cuda" in device
+ self.device = device
+ self.tb_logger = tb_logger
+ obs_shape = config.input_size
+ if isinstance(obs_shape, int) or len(obs_shape) == 1:
+ self.reward_model = RewardModelNetwork(config.input_size, config.hidden_size, 1)
+ self.concat_state_action_pairs = concat_state_action_pairs
+ elif len(obs_shape) == 3:
+ action_shape = self.cfg.action_size
+ self.reward_model = AtariRewardModelNetwork(config.input_size, action_shape)
+ self.concat_state_action_pairs = partial(concat_state_action_pairs_one_hot, action_size=action_shape)
+ self.reward_model.to(self.device)
+ self.expert_data = []
+ self.train_data = []
+ self.expert_data_loader = None
+ self.opt = optim.Adam(self.reward_model.parameters(), config.learning_rate)
+ self.train_iter = 0
+
+ self.load_expert_data()
+
+ def load_expert_data(self) -> None:
+ """
+ Overview:
+ Getting the expert data from ``config.data_path`` attribute in self
+ Effects:
+ This is a side effect function which updates the expert data attribute \
+ (i.e. ``self.expert_data``) with ``fn:concat_state_action_pairs``
+ """
+ with open(self.cfg.data_path + '/expert_data.pkl', 'rb') as f:
+ self.expert_data_loader: list = pickle.load(f)
+ self.expert_data = self.concat_state_action_pairs(self.expert_data_loader)
+
+ def state_dict(self) -> Dict[str, Any]:
+ return {
+ 'model': self.reward_model.state_dict(),
+ }
+
+ def load_state_dict(self, state_dict: Dict[str, Any]) -> None:
+ self.reward_model.load_state_dict(state_dict['model'])
+
+ def learn(self, train_data: torch.Tensor, expert_data: torch.Tensor) -> float:
+ """
+ Overview:
+ Helper function for ``train`` which calculates loss for train data and expert data.
+ Arguments:
+ - train_data (:obj:`torch.Tensor`): Data used for training
+ - expert_data (:obj:`torch.Tensor`): Expert data
+ Returns:
+ - Combined loss calculated of reward model from using ``train_data`` and ``expert_data``.
+ """
+ # calculate loss, here are some hyper-param
+ out_1: torch.Tensor = self.reward_model(train_data)
+ loss_1: torch.Tensor = torch.log(out_1 + 1e-8).mean()
+ out_2: torch.Tensor = self.reward_model(expert_data)
+ loss_2: torch.Tensor = torch.log(1 - out_2 + 1e-8).mean()
+ # log(x) with 0 < x < 1 is negative, so minimizing this combined loss pushes the discriminator
+ # output towards 1 on train (policy) data and towards 0 on expert data
+ loss: torch.Tensor = -(loss_1 + loss_2)
+ self.opt.zero_grad()
+ loss.backward()
+ self.opt.step()
+ return loss.item()
+
+ def train(self) -> None:
+ """
+ Overview:
+ Training the Gail reward model. The training data and expert data are randomly sampled with the \
+ batch size designated by the ``batch_size`` attribute in ``self.cfg``, drawn from the \
+ ``train_data`` and ``expert_data`` attributes of ``self``, respectively.
+ Effects:
+ - This is a side effect function which updates the reward model and increments the train iteration count.
+ """
+ for _ in range(self.cfg.update_per_collect):
+ sample_expert_data: list = random.sample(self.expert_data, self.cfg.batch_size)
+ sample_train_data: list = random.sample(self.train_data, self.cfg.batch_size)
+ sample_expert_data = torch.stack(sample_expert_data).to(self.device)
+ sample_train_data = torch.stack(sample_train_data).to(self.device)
+ loss = self.learn(sample_train_data, sample_expert_data)
+ self.tb_logger.add_scalar('reward_model/gail_loss', loss, self.train_iter)
+ self.train_iter += 1
+
+ def estimate(self, data: list) -> List[Dict]:
+ """
+ Overview:
+ Estimate reward by rewriting the reward key in each row of the data.
+ Arguments:
+ - data (:obj:`list`): the list of data used for estimation, with at least \
+ ``obs`` and ``action`` keys.
+ Effects:
+ - This is a side effect function which updates the reward values in place.
+ """
+ # NOTE: deepcopy reward part of data is very important,
+ # otherwise the reward of data in the replay buffer will be incorrectly modified.
+ train_data_augmented = self.reward_deepcopy(data)
+ res = self.concat_state_action_pairs(train_data_augmented)
+ res = torch.stack(res).to(self.device)
+ with torch.no_grad():
+ reward = self.reward_model(res).squeeze(-1).cpu()
+ reward = torch.chunk(reward, reward.shape[0], dim=0)
+ for item, rew in zip(train_data_augmented, reward):
+ item['reward'] = -torch.log(rew + 1e-8)
+
+ return train_data_augmented
+
+ def collect_data(self, data: list) -> None:
+ """
+ Overview:
+ Collecting training data formatted by ``fn:concat_state_action_pairs``.
+ Arguments:
+ - data (:obj:`Any`): Raw training data (e.g. some form of states, actions, obs, etc)
+ Effects:
+ - This is a side effect function which updates the data attribute in ``self``
+ """
+ self.train_data.extend(self.concat_state_action_pairs(data))
+
+ def clear_data(self) -> None:
+ """
+ Overview:
+ Clearing training data. \
+ This is a side effect function which clears the data attribute in ``self``
+ """
+ self.train_data.clear()
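For reference, ``estimate`` above maps the discriminator output ``D(s, a)`` to a reward via ``-log(D(s, a) + 1e-8)``: samples scored close to 0 (expert-like) receive a large reward, while samples scored close to 1 receive a reward near 0. A standalone restatement with toy values:

import torch

def gail_reward(discriminator_output: torch.Tensor) -> torch.Tensor:
    # the same mapping as in GailRewardModel.estimate, applied to a raw tensor
    return -torch.log(discriminator_output + 1e-8)

print(gail_reward(torch.tensor([0.01, 0.5, 0.99])))  # tensor([4.6052, 0.6931, 0.0101])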
diff --git a/DI-engine/ding/reward_model/guided_cost_reward_model.py b/DI-engine/ding/reward_model/guided_cost_reward_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..437e198f5394fcf891a83f8bb8977219c1e02a45
--- /dev/null
+++ b/DI-engine/ding/reward_model/guided_cost_reward_model.py
@@ -0,0 +1,178 @@
+from typing import List, Dict, Any
+from easydict import EasyDict
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+from torch.distributions import Independent, Normal
+
+from ding.utils import REWARD_MODEL_REGISTRY
+from ding.utils.data import default_collate
+from .base_reward_model import BaseRewardModel
+
+
+class GuidedCostNN(nn.Module):
+
+ def __init__(
+ self,
+ input_size,
+ hidden_size=128,
+ output_size=1,
+ ):
+ super(GuidedCostNN, self).__init__()
+ self.net = nn.Sequential(
+ nn.Linear(input_size, hidden_size),
+ nn.ReLU(),
+ nn.Linear(hidden_size, hidden_size),
+ nn.ReLU(),
+ nn.Linear(hidden_size, output_size),
+ )
+
+ def forward(self, x):
+ return self.net(x)
+
+
+@REWARD_MODEL_REGISTRY.register('guided_cost')
+class GuidedCostRewardModel(BaseRewardModel):
+ """
+ Overview:
+ Policy class of Guided cost algorithm. (https://arxiv.org/pdf/1603.00448.pdf)
+ Interface:
+ ``estimate``, ``train``, ``collect_data``, ``clear_data``, \
+ ``__init__``, ``state_dict``, ``load_state_dict``, ``learn``\
+ ``state_dict_reward_model``, ``load_state_dict_reward_model``
+ Config:
+ == ==================== ======== ============= ======================================== ================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============= ======================================== ================
+ 1 ``type`` str guided_cost | Reward model register name, refer |
+ | to registry ``REWARD_MODEL_REGISTRY`` |
+ 2 | ``continuous`` bool True | Whether action is continuous |
+ 3 | ``learning_rate`` float 0.001 | learning rate for optimizer |
+ 4 | ``update_per_`` int 100 | Number of updates per collect |
+ | ``collect`` | |
+ 5 | ``batch_size`` int 64 | Training batch size |
+ 6 | ``hidden_size`` int 128 | Linear model hidden size |
+ 7 | ``action_shape`` int 1 | Action space shape |
+ 8 | ``log_every_n`` int 50 | add loss to log every n iteration |
+ | ``_train`` | |
+ 9 | ``store_model_`` int 100 | save model every n iteration |
+ | ``every_n_train`` |
+ == ==================== ======== ============= ======================================== ================
+
+ """
+
+ config = dict(
+ # (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
+ type='guided_cost',
+ # (float) The step size of gradient descent.
+ learning_rate=1e-3,
+ # (int) Action space shape, such as 1.
+ action_shape=1,
+ # (bool) Whether action is continuous.
+ continuous=True,
+ # (int) How many samples in a training batch.
+ batch_size=64,
+ # (int) Linear model hidden size.
+ hidden_size=128,
+ # (int) How many updates(iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=100,
+ # (int) Add loss to log every n iteration.
+ log_every_n_train=50,
+ # (int) Save model every n iteration.
+ store_model_every_n_train=100,
+ )
+
+ def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None: # noqa
+ super(GuidedCostRewardModel, self).__init__()
+ self.cfg = config
+ self.action_shape = self.cfg.action_shape
+ assert device == "cpu" or device.startswith("cuda")
+ self.device = device
+ self.tb_logger = tb_logger
+ self.reward_model = GuidedCostNN(config.input_size, config.hidden_size)
+ self.reward_model.to(self.device)
+ self.opt = optim.Adam(self.reward_model.parameters(), lr=config.learning_rate)
+
+ def train(self, expert_demo: torch.Tensor, samp: torch.Tensor, iter, step):
+ device_0 = expert_demo[0]['obs'].device
+ device_1 = samp[0]['obs'].device
+ for i in range(len(expert_demo)):
+ expert_demo[i]['prob'] = torch.FloatTensor([1]).to(device_0)
+ if self.cfg.continuous:
+ for i in range(len(samp)):
+ (mu, sigma) = samp[i]['logit']
+ dist = Independent(Normal(mu, sigma), 1)
+ next_action = samp[i]['action']
+ log_prob = dist.log_prob(next_action)
+ samp[i]['prob'] = torch.exp(log_prob).unsqueeze(0).to(device_1)
+ else:
+ for i in range(len(samp)):
+ probs = F.softmax(samp[i]['logit'], dim=-1)
+ prob = probs[samp[i]['action']]
+ samp[i]['prob'] = prob.to(device_1)
+ # Mix the expert data and sample data to train the reward model.
+ samp.extend(expert_demo)
+ expert_demo = default_collate(expert_demo)
+ samp = default_collate(samp)
+ cost_demo = self.reward_model(
+ torch.cat([expert_demo['obs'], expert_demo['action'].float().reshape(-1, self.action_shape)], dim=-1)
+ )
+ cost_samp = self.reward_model(
+ torch.cat([samp['obs'], samp['action'].float().reshape(-1, self.action_shape)], dim=-1)
+ )
+
+ prob = samp['prob'].unsqueeze(-1)
+ loss_IOC = torch.mean(cost_demo) + \
+ torch.log(torch.mean(torch.exp(-cost_samp)/(prob+1e-7)))
+ # UPDATING THE COST FUNCTION
+ self.opt.zero_grad()
+ loss_IOC.backward()
+ self.opt.step()
+ if iter % self.cfg.log_every_n_train == 0:
+ self.tb_logger.add_scalar('reward_model/loss_iter', loss_IOC, iter)
+ self.tb_logger.add_scalar('reward_model/loss_step', loss_IOC, step)
+
+ def estimate(self, data: list) -> List[Dict]:
+ # NOTE: this estimate method of the GCL algorithm is a little different from those in other IRL algorithms,
+ # because its deepcopy is performed before the learner training loop.
+ train_data_augmented = data
+ for i in range(len(train_data_augmented)):
+ with torch.no_grad():
+ reward = self.reward_model(
+ torch.cat([train_data_augmented[i]['obs'], train_data_augmented[i]['action'].float()]).unsqueeze(0)
+ ).squeeze(0)
+ train_data_augmented[i]['reward'] = -reward
+
+ return train_data_augmented
+
+ def collect_data(self, data) -> None:
+ """
+ Overview:
+ Collecting training data; not implemented since the reward model (i.e. online_net) is only trained once. \
+ If online_net is trained continuously, there should be a concrete implementation in this collect_data method.
+ """
+ # if online_net is trained continuously, there should be some implementations in collect_data method
+ pass
+
+ def clear_data(self):
+ """
+ Overview:
+ Clearing training data; not implemented since the reward model (i.e. online_net) is only trained once. \
+ If online_net is trained continuously, there should be a concrete implementation in this clear_data method.
+ """
+ # if online_net is trained continuously, there should be some implementations in clear_data method
+ pass
+
+ def state_dict_reward_model(self) -> Dict[str, Any]:
+ return {
+ 'model': self.reward_model.state_dict(),
+ 'optimizer': self.opt.state_dict(),
+ }
+
+ def load_state_dict_reward_model(self, state_dict: Dict[str, Any]) -> None:
+ self.reward_model.load_state_dict(state_dict['model'])
+ self.opt.load_state_dict(state_dict['optimizer'])
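The objective minimized in ``train`` above can be restated in isolation. A minimal sketch of the guided-cost (IOC) loss over standalone tensors (not the class interface):

import torch

def guided_cost_ioc_loss(cost_demo: torch.Tensor, cost_samp: torch.Tensor,
                         samp_prob: torch.Tensor) -> torch.Tensor:
    # the demo term pulls the cost of expert demonstrations down, while the log-partition term,
    # importance-weighted by the sampling policy's probabilities, pushes the cost of samples up
    return torch.mean(cost_demo) + torch.log(torch.mean(torch.exp(-cost_samp) / (samp_prob + 1e-7)))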
diff --git a/DI-engine/ding/reward_model/her_reward_model.py b/DI-engine/ding/reward_model/her_reward_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..c525f638602a58399e558389b2c24b7c81bbbeb5
--- /dev/null
+++ b/DI-engine/ding/reward_model/her_reward_model.py
@@ -0,0 +1,149 @@
+from typing import List, Dict, Any, Optional, Callable, Tuple
+import copy
+import numpy as np
+import torch
+
+
+class HerRewardModel:
+ """
+ Overview:
+ Hindsight Experience Replay model.
+
+ .. note::
+ - her_strategy (:obj:`str`): Type of strategy that HER uses, should be in ['final', 'future', 'episode']
+ - her_replay_k (:obj:`int`): Number of new episodes generated by an original episode. (Not used in episodic HER)
+ - episode_size (:obj:`int`): Sample how many episodes in one iteration.
+ - sample_per_episode (:obj:`int`): How many new samples are generated from an episode.
+
+ .. note::
+ In HER, we require episode trajectory to change the goals. However, episode lengths are different
+ and may have high variance. As a result, we **recommend** that you only use some transitions in
+ the complete episode by specifying ``episode_size`` and ``sample_per_episode`` in config.
+ Therefore, in one iteration, ``batch_size`` would be ``episode_size`` * ``sample_per_episode``
+ (e.g. ``episode_size=4`` and ``sample_per_episode=32`` give a batch of 128 transitions).
+ """
+
+ def __init__(
+ self,
+ cfg: dict,
+ cuda: bool = False,
+ ) -> None:
+ self._cuda = cuda and torch.cuda.is_available()
+ self._device = 'cuda' if self._cuda else 'cpu'
+ self._her_strategy = cfg.her_strategy
+ assert self._her_strategy in ['final', 'future', 'episode']
+ # `her_replay_k` may not be used in episodic HER, so default set to 1.
+ self._her_replay_k = cfg.get('her_replay_k', 1)
+ self._episode_size = cfg.get('episode_size', None)
+ self._sample_per_episode = cfg.get('sample_per_episode', None)
+
+ def estimate(
+ self,
+ episode: List[Dict[str, Any]],
+ merge_func: Optional[Callable] = None,
+ split_func: Optional[Callable] = None,
+ goal_reward_func: Optional[Callable] = None
+    ) -> List[List[Dict[str, Any]]]:
+ """
+ Overview:
+ Get HER processed episodes from original episodes.
+ Arguments:
+ - episode (:obj:`List[Dict[str, Any]]`): Episode list, each element is a transition.
+ - merge_func (:obj:`Callable`): The merge function to use, default set to None. If None, \
+ then use ``__her_default_merge_func``
+ - split_func (:obj:`Callable`): The split function to use, default set to None. If None, \
+ then use ``__her_default_split_func``
+ - goal_reward_func (:obj:`Callable`): The goal_reward function to use, default set to None. If None, \
+ then use ``__her_default_goal_reward_func``
+ Returns:
+            - new_episodes (:obj:`List[List[Dict[str, Any]]]`): ``her_replay_k`` lists of relabeled transitions
+ """
+ if merge_func is None:
+ merge_func = HerRewardModel.__her_default_merge_func
+ if split_func is None:
+ split_func = HerRewardModel.__her_default_split_func
+ if goal_reward_func is None:
+ goal_reward_func = HerRewardModel.__her_default_goal_reward_func
+ new_episodes = [[] for _ in range(self._her_replay_k)]
+ if self._sample_per_episode is None:
+ # Use complete episode
+ indices = range(len(episode))
+ else:
+ # Use some transitions in one episode
+ indices = np.random.randint(0, len(episode), (self._sample_per_episode))
+ for idx in indices:
+ obs, _, _ = split_func(episode[idx]['obs'])
+ next_obs, _, achieved_goal = split_func(episode[idx]['next_obs'])
+ for k in range(self._her_replay_k):
+ if self._her_strategy == 'final':
+ p_idx = -1
+ elif self._her_strategy == 'episode':
+ p_idx = np.random.randint(0, len(episode))
+ elif self._her_strategy == 'future':
+ p_idx = np.random.randint(idx, len(episode))
+ _, _, new_desired_goal = split_func(episode[p_idx]['next_obs'])
+ timestep = {
+ k: copy.deepcopy(v)
+ for k, v in episode[idx].items() if k not in ['obs', 'next_obs', 'reward']
+ }
+ timestep['obs'] = merge_func(obs, new_desired_goal)
+ timestep['next_obs'] = merge_func(next_obs, new_desired_goal)
+ timestep['reward'] = goal_reward_func(achieved_goal, new_desired_goal).to(self._device)
+ new_episodes[k].append(timestep)
+ return new_episodes
+
+ @staticmethod
+ def __her_default_merge_func(x: Any, y: Any) -> Any:
+ r"""
+ Overview:
+ The function to merge obs in HER timestep
+ Arguments:
+ - x (:obj:`Any`): one of the timestep obs to merge
+ - y (:obj:`Any`): another timestep obs to merge
+ Returns:
+ - ret (:obj:`Any`): the merge obs
+ """
+ # TODO(nyz) dict/list merge_func
+ return torch.cat([x, y], dim=0)
+
+ @staticmethod
+ def __her_default_split_func(x: Any) -> Tuple[Any, Any, Any]:
+ r"""
+ Overview:
+ Split the input into obs, desired goal, and achieved goal.
+ Arguments:
+ - x (:obj:`Any`): The input to split
+ Returns:
+ - obs (:obj:`torch.Tensor`): Original obs.
+            - desired_goal (:obj:`torch.Tensor`): The goal that the agent is desired to achieve.
+            - achieved_goal (:obj:`torch.Tensor`): The goal that the agent has actually achieved.
+ """
+ # TODO(nyz) dict/list split_func
+ # achieved_goal = f(obs), default: f == identical function
+ obs, desired_goal = torch.chunk(x, 2)
+ achieved_goal = obs
+ return obs, desired_goal, achieved_goal
+
+ @staticmethod
+ def __her_default_goal_reward_func(achieved_goal: torch.Tensor, desired_goal: torch.Tensor) -> torch.Tensor:
+ r"""
+ Overview:
+            Compute the goal reward according to whether the achieved_goal matches the desired_goal
+ Arguments:
+ - achieved_goal (:obj:`torch.Tensor`): the achieved goal
+ - desired_goal (:obj:`torch.Tensor`): the desired_goal
+ Returns:
+            - goal_reward (:obj:`torch.Tensor`): the goal reward according to \
+                whether the achieved_goal matches the desired_goal
+ """
+ if (achieved_goal == desired_goal).all():
+ return torch.FloatTensor([1])
+ else:
+ return torch.FloatTensor([0])
+
+ @property
+ def episode_size(self) -> int:
+ return self._episode_size
+
+ @property
+ def sample_per_episode(self) -> int:
+ return self._sample_per_episode
diff --git a/DI-engine/ding/reward_model/icm_reward_model.py b/DI-engine/ding/reward_model/icm_reward_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..9cc6e23e9b0e744b44be8944d3a39c325f712eb7
--- /dev/null
+++ b/DI-engine/ding/reward_model/icm_reward_model.py
@@ -0,0 +1,309 @@
+from typing import Union, Tuple, List, Dict
+from easydict import EasyDict
+
+import random
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from ding.utils import SequenceType, REWARD_MODEL_REGISTRY
+from ding.model import FCEncoder, ConvEncoder
+from ding.torch_utils import one_hot
+from .base_reward_model import BaseRewardModel
+
+
+def collect_states(iterator: list) -> Tuple[list, list, list]:
+ states = []
+ next_states = []
+ actions = []
+ for item in iterator:
+ state = item['obs']
+ next_state = item['next_obs']
+ action = item['action']
+ states.append(state)
+ next_states.append(next_state)
+ actions.append(action)
+ return states, next_states, actions
+
+
+class ICMNetwork(nn.Module):
+ """
+ Intrinsic Curiosity Model (ICM Module)
+ Implementation of:
+ [1] Curiosity-driven Exploration by Self-supervised Prediction
+ Pathak, Agrawal, Efros, and Darrell - UC Berkeley - ICML 2017.
+ https://arxiv.org/pdf/1705.05363.pdf
+ [2] Code implementation reference:
+ https://github.com/pathak22/noreward-rl
+ https://github.com/jcwleo/curiosity-driven-exploration-pytorch
+
+ 1) Embedding observations into a latent space
+ 2) Predicting the action logit given two consecutive embedded observations
+    3) Predicting the next embedded obs, given the embedded former observation and action
+ """
+
+ def __init__(self, obs_shape: Union[int, SequenceType], hidden_size_list: SequenceType, action_shape: int) -> None:
+ super(ICMNetwork, self).__init__()
+ if isinstance(obs_shape, int) or len(obs_shape) == 1:
+ self.feature = FCEncoder(obs_shape, hidden_size_list)
+ elif len(obs_shape) == 3:
+ self.feature = ConvEncoder(obs_shape, hidden_size_list)
+ else:
+ raise KeyError(
+ "not support obs_shape for pre-defined encoder: {}, please customize your own ICM model".
+ format(obs_shape)
+ )
+ self.action_shape = action_shape
+ feature_output = hidden_size_list[-1]
+ self.inverse_net = nn.Sequential(nn.Linear(feature_output * 2, 512), nn.ReLU(), nn.Linear(512, action_shape))
+ self.residual = nn.ModuleList(
+ [
+ nn.Sequential(
+ nn.Linear(action_shape + 512, 512),
+ nn.LeakyReLU(),
+ nn.Linear(512, 512),
+ ) for _ in range(8)
+ ]
+ )
+ self.forward_net_1 = nn.Sequential(nn.Linear(action_shape + feature_output, 512), nn.LeakyReLU())
+ self.forward_net_2 = nn.Linear(action_shape + 512, feature_output)
+
+ def forward(self, state: torch.Tensor, next_state: torch.Tensor,
+ action_long: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ r"""
+ Overview:
+            Use observation, next_observation and action to generate the ICM module's outputs,
+            which are used for parameter updates in the ICMNetwork forward setup.
+ Arguments:
+ - state (:obj:`torch.Tensor`):
+ The current state batch
+ - next_state (:obj:`torch.Tensor`):
+ The next state batch
+ - action_long (:obj:`torch.Tensor`):
+ The action batch
+ Returns:
+ - real_next_state_feature (:obj:`torch.Tensor`):
+ Run with the encoder. Return the real next_state's embedded feature.
+ - pred_next_state_feature (:obj:`torch.Tensor`):
+ Run with the encoder and residual network. Return the predicted next_state's embedded feature.
+ - pred_action_logit (:obj:`torch.Tensor`):
+ Run with the encoder. Return the predicted action logit.
+ Shapes:
+            - state (:obj:`torch.Tensor`): :math:`(B, N)`, where B is the batch size and N is ``obs_shape``
+            - next_state (:obj:`torch.Tensor`): :math:`(B, N)`, where B is the batch size and N is ``obs_shape``
+            - action_long (:obj:`torch.Tensor`): :math:`(B, )`, where B is the batch size
+            - real_next_state_feature (:obj:`torch.Tensor`): :math:`(B, M)`, where B is the batch size
+                and M is the embedded feature size
+            - pred_next_state_feature (:obj:`torch.Tensor`): :math:`(B, M)`, where B is the batch size
+                and M is the embedded feature size
+            - pred_action_logit (:obj:`torch.Tensor`): :math:`(B, A)`, where B is the batch size
+                and A is the ``action_shape``
+ """
+ action = one_hot(action_long, num=self.action_shape)
+ encode_state = self.feature(state)
+ encode_next_state = self.feature(next_state)
+ # get pred action logit
+ concat_state = torch.cat((encode_state, encode_next_state), 1)
+ pred_action_logit = self.inverse_net(concat_state)
+ # ---------------------
+
+ # get pred next state
+ pred_next_state_feature_orig = torch.cat((encode_state, action), 1)
+ pred_next_state_feature_orig = self.forward_net_1(pred_next_state_feature_orig)
+
+ # residual
+ for i in range(4):
+ pred_next_state_feature = self.residual[i * 2](torch.cat((pred_next_state_feature_orig, action), 1))
+ pred_next_state_feature_orig = self.residual[i * 2 + 1](
+ torch.cat((pred_next_state_feature, action), 1)
+ ) + pred_next_state_feature_orig
+ pred_next_state_feature = self.forward_net_2(torch.cat((pred_next_state_feature_orig, action), 1))
+ real_next_state_feature = encode_next_state
+ return real_next_state_feature, pred_next_state_feature, pred_action_logit
+
+
+@REWARD_MODEL_REGISTRY.register('icm')
+class ICMRewardModel(BaseRewardModel):
+ """
+ Overview:
+ The ICM reward model class (https://arxiv.org/pdf/1705.05363.pdf)
+ Interface:
+ ``estimate``, ``train``, ``collect_data``, ``clear_data``, \
+ ``__init__``, ``_train``, ``load_state_dict``, ``state_dict``
+ Config:
+ == ==================== ======== ============= ==================================== =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ======== ============= ==================================== =======================
+ 1 ``type`` str icm | Reward model register name, |
+ | refer to registry |
+ | ``REWARD_MODEL_REGISTRY`` |
+ 2 | ``intrinsic_`` str add | the intrinsic reward type | including add, new
+ | ``reward_type`` | | , or assign
+ 3 | ``learning_rate`` float 0.001 | The step size of gradient descent |
+ 4 | ``obs_shape`` Tuple( 6 | the observation shape |
+ [int,
+ list])
+ 5 | ``action_shape`` int 7 | the action space shape |
+ 6 | ``batch_size`` int 64 | Training batch size |
+ 7 | ``hidden`` list [64, 64, | the MLP layer shape |
+ | ``_size_list`` (int) 128] | |
+ 8 | ``update_per_`` int 100 | Number of updates per collect |
+ | ``collect`` | |
+ 9 | ``reverse_scale`` float 1 | the importance weight of the |
+ | forward and reverse loss |
+ 10 | ``intrinsic_`` float 0.003 | the weight of intrinsic reward | r = w*r_i + r_e
+ ``reward_weight``
+ 11 | ``extrinsic_`` bool True | Whether to normlize
+ ``reward_norm`` | extrinsic reward
+ 12 | ``extrinsic_`` int 1 | the upper bound of the reward
+ ``reward_norm_max`` | normalization
+ 13 | ``clear_buffer`` int 1 | clear buffer per fixed iters | make sure replay
+ ``_per_iters`` | buffer's data count
+ | isn't too few.
+ | (code work in entry)
+ == ==================== ======== ============= ==================================== =======================
+ """
+ config = dict(
+ # (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
+ type='icm',
+ # (str) The intrinsic reward type, including add, new, or assign.
+ intrinsic_reward_type='add',
+ # (float) The step size of gradient descent.
+ learning_rate=1e-3,
+ # (Tuple[int, list]), The observation shape.
+ obs_shape=6,
+ # (int) The action shape, support discrete action only in this version.
+ action_shape=7,
+ # (float) Batch size.
+ batch_size=64,
+ # (list) The MLP layer shape.
+ hidden_size_list=[64, 64, 128],
+ # (int) How many updates(iterations) to train after collector's one collection.
+        # A bigger "update_per_collect" means more off-policy training.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=100,
+ # (float) The importance weight of the forward and reverse loss.
+ reverse_scale=1,
+ # (float) The weight of intrinsic reward.
+ # r = intrinsic_reward_weight * r_i + r_e.
+ intrinsic_reward_weight=0.003, # 1/300
+        # (bool) Whether to normalize extrinsic reward.
+ # Normalize the reward to [0, extrinsic_reward_norm_max].
+ extrinsic_reward_norm=True,
+ # (int) The upper bound of the reward normalization.
+ extrinsic_reward_norm_max=1,
+ # (int) Clear buffer per fixed iters.
+ clear_buffer_per_iters=100,
+ )
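+
+    # A minimal usage sketch (illustrative only; the ``SummaryWriter`` path and the transition dicts are
+    # assumptions, not taken from a specific DI-engine entry):
+    #
+    #     from torch.utils.tensorboard import SummaryWriter
+    #     cfg = EasyDict(ICMRewardModel.config)
+    #     icm = ICMRewardModel(cfg, device='cpu', tb_logger=SummaryWriter('./log'))
+    #     icm.collect_data(transitions)  # list of dicts with 'obs', 'next_obs', 'action', 'reward'
+    #     icm.train()                    # runs ``update_per_collect`` gradient steps once enough data is stored
+    #     new_transitions = icm.estimate(transitions)  # rewrites 'reward' with the intrinsic bonus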
+
+ def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None: # noqa
+ super(ICMRewardModel, self).__init__()
+ self.cfg = config
+ assert device == "cpu" or device.startswith("cuda")
+ self.device = device
+ self.tb_logger = tb_logger
+ self.reward_model = ICMNetwork(config.obs_shape, config.hidden_size_list, config.action_shape)
+ self.reward_model.to(self.device)
+ self.intrinsic_reward_type = config.intrinsic_reward_type
+ assert self.intrinsic_reward_type in ['add', 'new', 'assign']
+ self.train_data = []
+ self.train_states = []
+ self.train_next_states = []
+ self.train_actions = []
+ self.opt = optim.Adam(self.reward_model.parameters(), config.learning_rate)
+ self.ce = nn.CrossEntropyLoss(reduction="mean")
+ self.forward_mse = nn.MSELoss(reduction='none')
+ self.reverse_scale = config.reverse_scale
+ self.res = nn.Softmax(dim=-1)
+ self.estimate_cnt_icm = 0
+ self.train_cnt_icm = 0
+
+ def _train(self) -> None:
+ self.train_cnt_icm += 1
+ train_data_list = [i for i in range(0, len(self.train_states))]
+ train_data_index = random.sample(train_data_list, self.cfg.batch_size)
+ data_states: list = [self.train_states[i] for i in train_data_index]
+ data_states: torch.Tensor = torch.stack(data_states).to(self.device)
+ data_next_states: list = [self.train_next_states[i] for i in train_data_index]
+ data_next_states: torch.Tensor = torch.stack(data_next_states).to(self.device)
+ data_actions: list = [self.train_actions[i] for i in train_data_index]
+ data_actions: torch.Tensor = torch.cat(data_actions).to(self.device)
+
+ real_next_state_feature, pred_next_state_feature, pred_action_logit = self.reward_model(
+ data_states, data_next_states, data_actions
+ )
+ inverse_loss = self.ce(pred_action_logit, data_actions.long())
+ forward_loss = self.forward_mse(pred_next_state_feature, real_next_state_feature.detach()).mean()
+ self.tb_logger.add_scalar('icm_reward/forward_loss', forward_loss, self.train_cnt_icm)
+ self.tb_logger.add_scalar('icm_reward/inverse_loss', inverse_loss, self.train_cnt_icm)
+ action = torch.argmax(self.res(pred_action_logit), -1)
+ accuracy = torch.sum(action == data_actions.squeeze(-1)).item() / data_actions.shape[0]
+ self.tb_logger.add_scalar('icm_reward/action_accuracy', accuracy, self.train_cnt_icm)
+ loss = self.reverse_scale * inverse_loss + forward_loss
+ self.tb_logger.add_scalar('icm_reward/total_loss', loss, self.train_cnt_icm)
+ self.opt.zero_grad()
+ loss.backward()
+ self.opt.step()
+
+ def train(self) -> None:
+ for _ in range(self.cfg.update_per_collect):
+ self._train()
+
+ def estimate(self, data: list) -> List[Dict]:
+ # NOTE: deepcopy reward part of data is very important,
+ # otherwise the reward of data in the replay buffer will be incorrectly modified.
+ train_data_augmented = self.reward_deepcopy(data)
+ states, next_states, actions = collect_states(train_data_augmented)
+ states = torch.stack(states).to(self.device)
+ next_states = torch.stack(next_states).to(self.device)
+ actions = torch.cat(actions).to(self.device)
+ with torch.no_grad():
+ real_next_state_feature, pred_next_state_feature, _ = self.reward_model(states, next_states, actions)
+ raw_icm_reward = self.forward_mse(real_next_state_feature, pred_next_state_feature).mean(dim=1)
+ self.estimate_cnt_icm += 1
+ self.tb_logger.add_scalar('icm_reward/raw_icm_reward_max', raw_icm_reward.max(), self.estimate_cnt_icm)
+ self.tb_logger.add_scalar('icm_reward/raw_icm_reward_mean', raw_icm_reward.mean(), self.estimate_cnt_icm)
+ self.tb_logger.add_scalar('icm_reward/raw_icm_reward_min', raw_icm_reward.min(), self.estimate_cnt_icm)
+ self.tb_logger.add_scalar('icm_reward/raw_icm_reward_std', raw_icm_reward.std(), self.estimate_cnt_icm)
+ icm_reward = (raw_icm_reward - raw_icm_reward.min()) / (raw_icm_reward.max() - raw_icm_reward.min() + 1e-8)
+ self.tb_logger.add_scalar('icm_reward/icm_reward_max', icm_reward.max(), self.estimate_cnt_icm)
+ self.tb_logger.add_scalar('icm_reward/icm_reward_mean', icm_reward.mean(), self.estimate_cnt_icm)
+ self.tb_logger.add_scalar('icm_reward/icm_reward_min', icm_reward.min(), self.estimate_cnt_icm)
+ self.tb_logger.add_scalar('icm_reward/icm_reward_std', icm_reward.std(), self.estimate_cnt_icm)
+ icm_reward = icm_reward.to(self.device)
+ for item, icm_rew in zip(train_data_augmented, icm_reward):
+ if self.intrinsic_reward_type == 'add':
+ if self.cfg.extrinsic_reward_norm:
+ item['reward'] = item[
+ 'reward'] / self.cfg.extrinsic_reward_norm_max + icm_rew * self.cfg.intrinsic_reward_weight
+ else:
+ item['reward'] = item['reward'] + icm_rew * self.cfg.intrinsic_reward_weight
+ elif self.intrinsic_reward_type == 'new':
+ item['intrinsic_reward'] = icm_rew
+ if self.cfg.extrinsic_reward_norm:
+ item['reward'] = item['reward'] / self.cfg.extrinsic_reward_norm_max
+ elif self.intrinsic_reward_type == 'assign':
+ item['reward'] = icm_rew
+
+ return train_data_augmented
+
+ def collect_data(self, data: list) -> None:
+ self.train_data.extend(collect_states(data))
+ states, next_states, actions = collect_states(data)
+ self.train_states.extend(states)
+ self.train_next_states.extend(next_states)
+ self.train_actions.extend(actions)
+
+ def clear_data(self) -> None:
+ self.train_data.clear()
+ self.train_states.clear()
+ self.train_next_states.clear()
+ self.train_actions.clear()
+
+ def state_dict(self) -> Dict:
+ return self.reward_model.state_dict()
+
+ def load_state_dict(self, _state_dict: Dict) -> None:
+ self.reward_model.load_state_dict(_state_dict)
diff --git a/DI-engine/ding/reward_model/ngu_reward_model.py b/DI-engine/ding/reward_model/ngu_reward_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a8758bdb7e14c11bec1cad6564ab2aad6f8a5b9
--- /dev/null
+++ b/DI-engine/ding/reward_model/ngu_reward_model.py
@@ -0,0 +1,543 @@
+import copy
+import random
+from typing import Union, Tuple, Dict, List
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from easydict import EasyDict
+
+from ding.model import FCEncoder, ConvEncoder
+from ding.utils import RunningMeanStd
+from ding.utils import SequenceType, REWARD_MODEL_REGISTRY
+from .base_reward_model import BaseRewardModel
+
+
+def collect_data_and_exclude_null_data_rnd(data_in):
+ res = []
+ for item in data_in:
+ if torch.nonzero(torch.tensor(item['null']).float()).shape[0] != 0: # if have null padding in data
+ # the index of not null data in data_in
+ # not_null_index = torch.nonzero(torch.tensor(item['null']).float()).squeeze(-1)
+ null_start_index = int(torch.nonzero(torch.tensor(item['null']).float()).squeeze(-1)[0])
+ obs = item['obs'][:null_start_index] # exclude the null padding data
+ else:
+ obs = item['obs'] # sequence data
+ res.append(obs)
+ return res
+
+
+def collect_data_rnd(data_in):
+ res = []
+ is_null_list = []
+ for item in data_in:
+ state = item['obs']
+ is_null = item['null']
+ res.append(state)
+ is_null_list.append(is_null)
+ return res, is_null_list
+
+
+def collect_data_and_exclude_null_data_episodic(data_in):
+ obs_list = []
+ action_list = []
+ for item in data_in:
+ if torch.nonzero(torch.tensor(item['null']).float()).shape[0] != 0: # if have null padding in data
+ # the index of not null data in data_in
+ # not_null_index = torch.nonzero(torch.tensor(item['null']).float()).squeeze(-1)
+ null_start_index = int(torch.nonzero(torch.tensor(item['null']).float()).squeeze(-1)[0])
+ obs = item['obs'][:null_start_index] # sequence data
+ action = item['action'][:null_start_index] # exclude the null padding data
+ else:
+ obs = item['obs'] # sequence data
+ action = item['action']
+ obs_list.append(obs)
+ action_list.append(action)
+ return obs_list, action_list
+
+
+def collect_data_episodic(data_in):
+ res = []
+ is_null_list = []
+ for item in data_in:
+ state = item['obs']
+ is_null = item['null']
+ res.append(state)
+ is_null_list.append(is_null)
+ return res, is_null_list
+
+
+class RndNetwork(nn.Module):
+
+ def __init__(self, obs_shape: Union[int, SequenceType], hidden_size_list: SequenceType) -> None:
+ super(RndNetwork, self).__init__()
+ if isinstance(obs_shape, int) or len(obs_shape) == 1:
+ self.target = FCEncoder(obs_shape, hidden_size_list)
+ self.predictor = FCEncoder(obs_shape, hidden_size_list)
+ elif len(obs_shape) == 3:
+ self.target = ConvEncoder(obs_shape, hidden_size_list)
+ self.predictor = ConvEncoder(obs_shape, hidden_size_list)
+ else:
+ raise KeyError(
+ "not support obs_shape for pre-defined encoder: {}, "
+ "please customize your own RND model".format(obs_shape)
+ )
+ for param in self.target.parameters():
+ param.requires_grad = False
+
+ def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ predict_feature = self.predictor(obs)
+ with torch.no_grad():
+ target_feature = self.target(obs)
+ return predict_feature, target_feature
+
+
+@REWARD_MODEL_REGISTRY.register('rnd-ngu')
+class RndNGURewardModel(BaseRewardModel):
+ r"""
+ Overview:
+ inter-episodic/RND reward model for NGU.
+ The corresponding paper is `never give up: learning directed exploration strategies`.
+ """
+ config = dict(
+ type='rnd-ngu',
+ intrinsic_reward_type='add',
+ learning_rate=1e-3,
+ batch_size=64,
+ hidden_size_list=[64, 64, 128],
+ update_per_collect=100,
+ )
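+
+    # Note (added for clarity, illustrative only): besides the defaults above, ``__init__`` also reads
+    # ``obs_shape`` and ``only_use_last_five_frames_for_icm_rnd`` from the config, so a user config must
+    # supply them, e.g. (``SummaryWriter`` assumed as in the other reward models):
+    #
+    #     cfg = EasyDict(dict(RndNGURewardModel.config, obs_shape=4,
+    #                         only_use_last_five_frames_for_icm_rnd=False))
+    #     rnd_model = RndNGURewardModel(cfg, device='cpu', tb_logger=SummaryWriter('./log'))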
+
+ def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None: # noqa
+ super(RndNGURewardModel, self).__init__()
+ self.cfg = config
+ assert device == "cpu" or device.startswith("cuda")
+ self.device = device
+ self.tb_logger = tb_logger
+ self.reward_model = RndNetwork(config.obs_shape, config.hidden_size_list)
+ self.reward_model.to(self.device)
+ self.intrinsic_reward_type = config.intrinsic_reward_type
+ assert self.intrinsic_reward_type in ['add', 'new', 'assign']
+ self.train_data_total = []
+ self.train_data = []
+ self.opt = optim.Adam(self.reward_model.predictor.parameters(), config.learning_rate)
+ self.estimate_cnt_rnd = 0
+ self._running_mean_std_rnd = RunningMeanStd(epsilon=1e-4)
+ self.only_use_last_five_frames = config.only_use_last_five_frames_for_icm_rnd
+
+ def _train(self) -> None:
+ train_data: list = random.sample(list(self.train_data_cur), self.cfg.batch_size)
+
+ train_data: torch.Tensor = torch.stack(train_data).to(self.device)
+
+ predict_feature, target_feature = self.reward_model(train_data)
+ loss = F.mse_loss(predict_feature, target_feature.detach())
+ self.opt.zero_grad()
+ loss.backward()
+ self.opt.step()
+
+ def train(self) -> None:
+ if self.only_use_last_five_frames:
+            # self.train_obs shape: list(list) [batch_size, seq_length, N]
+
+ # stack episode dim
+ self.train_obs = [torch.stack(episode_obs[-5:], dim=0) for episode_obs in self.train_data_total]
+
+ # stack batch dim
+ # way 1
+ if isinstance(self.cfg.obs_shape, int):
+ self.train_data_cur = torch.stack(
+ self.train_obs, dim=0
+ ).view(len(self.train_obs) * len(self.train_obs[0]), self.cfg.obs_shape)
+ else: # len(self.cfg.obs_shape) == 3 for image obs
+ self.train_data_cur = torch.stack(
+ self.train_obs, dim=0
+ ).view(len(self.train_obs) * self.train_obs[0].shape[0], *self.cfg.obs_shape)
+ # way 2
+ # self.train_data_cur = torch.cat(self.train_obs, 0)
+
+ else:
+ self.train_data_cur = sum(self.train_data_total, [])
+ # another implementation way
+ # tmp = []
+ # for i in range(len(self.train_data)):
+ # tmp += self.train_data[i]
+ # self.train_data = tmp
+
+ for _ in range(self.cfg.update_per_collect):
+ self._train()
+
+ def estimate(self, data: list) -> torch.Tensor:
+ """
+ Rewrite the reward key in each row of the data.
+ """
+ obs, is_null = collect_data_rnd(data)
+ if isinstance(obs[0], list): # if obs shape list( list(torch.tensor) )
+ obs = sum(obs, [])
+
+ obs = torch.stack(obs).to(self.device)
+
+ with torch.no_grad():
+ predict_feature, target_feature = self.reward_model(obs)
+ reward = F.mse_loss(predict_feature, target_feature, reduction='none').mean(dim=1)
+ self._running_mean_std_rnd.update(reward.cpu().numpy())
+ # transform to mean 1 std 1
+ reward = 1 + (reward - self._running_mean_std_rnd.mean) / (self._running_mean_std_rnd.std + 1e-11)
+ self.estimate_cnt_rnd += 1
+ self.tb_logger.add_scalar('rnd_reward/rnd_reward_max', reward.max(), self.estimate_cnt_rnd)
+ self.tb_logger.add_scalar('rnd_reward/rnd_reward_mean', reward.mean(), self.estimate_cnt_rnd)
+ self.tb_logger.add_scalar('rnd_reward/rnd_reward_min', reward.min(), self.estimate_cnt_rnd)
+ return reward
+
+ def collect_data(self, data: list) -> None:
+ self.train_data_total.extend(collect_data_and_exclude_null_data_rnd(data))
+
+ def clear_data(self) -> None:
+ self.train_data_total.clear()
+
+ def reward_deepcopy(self, train_data):
+ """
+        This method deepcopies the reward part of train_data while keeping shallow copies of the other parts,
+        to avoid incorrectly modifying the reward of the train_data stored in the replay buffer.
+ """
+ train_data_reward_deepcopy = [
+ {k: copy.deepcopy(v) if k == 'reward' else v
+ for k, v in sample.items()} for sample in train_data
+ ]
+ return train_data_reward_deepcopy
+
+
+class InverseNetwork(nn.Module):
+
+ def __init__(self, obs_shape: Union[int, SequenceType], action_shape, hidden_size_list: SequenceType) -> None:
+ super(InverseNetwork, self).__init__()
+ if isinstance(obs_shape, int) or len(obs_shape) == 1:
+ self.embedding_net = FCEncoder(obs_shape, hidden_size_list)
+ elif len(obs_shape) == 3:
+ self.embedding_net = ConvEncoder(obs_shape, hidden_size_list)
+ else:
+ raise KeyError(
+ "not support obs_shape for pre-defined encoder: {}, please customize your own RND model".
+ format(obs_shape)
+ )
+ self.inverse_net = nn.Sequential(
+ nn.Linear(hidden_size_list[-1] * 2, 512), nn.ReLU(inplace=True), nn.Linear(512, action_shape)
+ )
+
+ def forward(self, inputs: Dict, inference: bool = False) -> Dict:
+ if inference:
+ with torch.no_grad():
+ cur_obs_embedding = self.embedding_net(inputs['obs'])
+ return cur_obs_embedding
+ else:
+ # obs: torch.Tensor, next_obs: torch.Tensor
+ cur_obs_embedding = self.embedding_net(inputs['obs'])
+ next_obs_embedding = self.embedding_net(inputs['next_obs'])
+ # get pred action
+ obs_plus_next_obs = torch.cat([cur_obs_embedding, next_obs_embedding], dim=-1)
+ pred_action_logits = self.inverse_net(obs_plus_next_obs)
+ pred_action_probs = nn.Softmax(dim=-1)(pred_action_logits)
+ return pred_action_logits, pred_action_probs
+
+
+@REWARD_MODEL_REGISTRY.register('episodic')
+class EpisodicNGURewardModel(BaseRewardModel):
+ r"""
+ Overview:
+ Episodic reward model for NGU.
+ The corresponding paper is `never give up: learning directed exploration strategies`.
+ """
+ config = dict(
+ type='episodic',
+ intrinsic_reward_type='add',
+ learning_rate=1e-3,
+ batch_size=64,
+ hidden_size_list=[64, 64, 128],
+ update_per_collect=100,
+        # (bool) Whether to apply the rescale trick to the last non-zero reward
+        # when combining extrinsic and intrinsic rewards.
+        # The rescale trick is only used in:
+        # 1. sparse reward envs such as minigrid, in which the last non-zero reward is a strong positive signal
+        # 2. envs where the last reward of each episode directly reflects the agent's completion of the task,
+        #    e.g. lunarlander
+        # Note that the ngu intrinsic reward is a positive value (max value is 5); in these envs,
+        # the last non-zero reward should not be overwhelmed by intrinsic rewards, so we need to rescale the
+        # original last non-zero extrinsic reward.
+ last_nonzero_reward_rescale=False,
+        # (int) The rescale weight applied to the last non-zero reward; only used when last_nonzero_reward_rescale is True.
+ last_nonzero_reward_weight=1,
+ )
+
+ def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None: # noqa
+ super(EpisodicNGURewardModel, self).__init__()
+ self.cfg = config
+ assert device == "cpu" or device.startswith("cuda")
+ self.device = device
+ self.tb_logger = tb_logger
+ self.episodic_reward_model = InverseNetwork(config.obs_shape, config.action_shape, config.hidden_size_list)
+ self.episodic_reward_model.to(self.device)
+ self.intrinsic_reward_type = config.intrinsic_reward_type
+ assert self.intrinsic_reward_type in ['add', 'new', 'assign']
+ self.train_obs_total = []
+ self.train_action_total = []
+ self.opt = optim.Adam(self.episodic_reward_model.parameters(), config.learning_rate)
+ self.estimate_cnt_episodic = 0
+ self._running_mean_std_episodic_dist = RunningMeanStd(epsilon=1e-4)
+ self._running_mean_std_episodic_reward = RunningMeanStd(epsilon=1e-4)
+ self.only_use_last_five_frames = config.only_use_last_five_frames_for_icm_rnd
+
+ def _train(self) -> None:
+ # sample episode's timestep index
+ train_index = np.random.randint(low=0, high=self.train_obs.shape[0], size=self.cfg.batch_size)
+
+ train_obs: torch.Tensor = self.train_obs[train_index].to(self.device) # shape (self.cfg.batch_size, obs_dim)
+ train_next_obs: torch.Tensor = self.train_next_obs[train_index].to(self.device)
+ train_action: torch.Tensor = self.train_action[train_index].to(self.device)
+
+ train_data = {'obs': train_obs, 'next_obs': train_next_obs}
+ pred_action_logits, pred_action_probs = self.episodic_reward_model(train_data)
+
+ inverse_loss = F.cross_entropy(pred_action_logits, train_action.squeeze(-1))
+ self.opt.zero_grad()
+ inverse_loss.backward()
+ self.opt.step()
+
+ def train(self) -> None:
+ self.train_next_obs_total = copy.deepcopy(self.train_obs_total)
+
+ if self.only_use_last_five_frames:
+ # self.train_obs shape: list(list) [batch_size,seq_length,obs_dim]
+ self.train_obs = [torch.stack(episode_obs[-6:-1], dim=0) for episode_obs in self.train_obs_total]
+ self.train_next_obs = [torch.stack(episode_obs[-5:], dim=0) for episode_obs in self.train_next_obs_total]
+ self.train_action = [
+ torch.stack(episode_action[-6:-1], dim=0) for episode_action in self.train_action_total
+ ]
+ else:
+ self.train_obs = [
+ torch.stack(episode_obs[:-1], dim=0) for episode_obs in self.train_obs_total if len(episode_obs) > 1
+ ]
+ self.train_next_obs = [
+ torch.stack(episode_next_obs[1:], dim=0) for episode_next_obs in self.train_next_obs_total
+ if len(episode_next_obs) > 1
+ ]
+ self.train_action = [
+ torch.stack(episode_action[:-1], dim=0) for episode_action in self.train_action_total
+ if len(episode_action) > 1
+ ]
+
+ # stack batch dim
+ self.train_obs = torch.cat(self.train_obs, 0)
+ self.train_next_obs = torch.cat(self.train_next_obs, 0)
+ self.train_action = torch.cat(self.train_action, 0)
+
+ for _ in range(self.cfg.update_per_collect):
+ self._train()
+
+ def _compute_intrinsic_reward(
+ self,
+ episodic_memory: List,
+ current_controllable_state: torch.Tensor,
+ k=10,
+ kernel_cluster_distance=0.008,
+ kernel_epsilon=0.0001,
+ c=0.001,
+ siminarity_max=8,
+ ) -> torch.Tensor:
+ # this function is modified from https://github.com/Coac/never-give-up/blob/main/embedding_model.py
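+        # Clarifying note (added, not part of the original implementation): this computes the NGU episodic
+        # bonus. The k nearest embedded states are taken from the episodic memory, their Euclidean distances
+        # are normalized by a running mean and shifted by ``kernel_cluster_distance``, the kernel
+        # K(d) = kernel_epsilon / (d + kernel_epsilon) is summed over them, and the bonus is
+        # 1 / (sqrt(sum(K)) + c), set to 0 once the similarity s exceeds ``siminarity_max``.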
+ state_dist = torch.cdist(current_controllable_state.unsqueeze(0), episodic_memory, p=2).squeeze(0).sort()[0][:k]
+ self._running_mean_std_episodic_dist.update(state_dist.cpu().numpy())
+ state_dist = state_dist / (self._running_mean_std_episodic_dist.mean + 1e-11)
+
+ state_dist = torch.clamp(state_dist - kernel_cluster_distance, min=0, max=None)
+ kernel = kernel_epsilon / (state_dist + kernel_epsilon)
+ s = torch.sqrt(torch.clamp(torch.sum(kernel), min=0, max=None)) + c
+
+ if s > siminarity_max:
+ print('s > siminarity_max:', s.max(), s.min())
+ return torch.tensor(0) # NOTE
+ return 1 / s
+ # average value 1/( ( 10* 1e-4/(1+1e-4) )**(1/2)+1e-3 ) = 30
+
+ def estimate(self, data: list) -> torch.Tensor:
+ """
+ Rewrite the reward key in each row of the data.
+ """
+
+ obs, is_null = collect_data_episodic(data)
+ # obs shape list(list()) [batch_size,seq_length,obs_dim]
+ batch_size = len(obs)
+ seq_length = len(obs[0])
+
+ # stack episode dim
+ obs = [torch.stack(episode_obs, dim=0) for episode_obs in obs]
+
+ # stack batch dim
+ # way 0
+ if isinstance(self.cfg.obs_shape, int):
+ obs = torch.stack(obs, dim=0).view(batch_size * seq_length, self.cfg.obs_shape).to(self.device)
+ else: # len(self.cfg.obs_shape) == 3 for image obs
+ obs = torch.stack(obs, dim=0).view(batch_size * seq_length, *self.cfg.obs_shape).to(self.device)
+ # way 2
+ # obs = torch.cat(obs, 0)
+
+ inputs = {'obs': obs, 'is_null': is_null}
+ with torch.no_grad():
+ cur_obs_embedding = self.episodic_reward_model(inputs, inference=True)
+ cur_obs_embedding = cur_obs_embedding.view(batch_size, seq_length, -1)
+ episodic_reward = [[] for _ in range(batch_size)]
+ null_cnt = 0 # the number of null transitions in the whole minibatch
+ for i in range(batch_size):
+ for j in range(seq_length):
+ if j < 10:
+ # if self._running_mean_std_episodic_reward.mean is not None:
+ # episodic_reward[i].append(torch.tensor(self._running_mean_std_episodic_reward.mean).to(self.device))
+ # else:
+ episodic_reward[i].append(torch.tensor(0.).to(self.device))
+ elif j:
+ episodic_memory = cur_obs_embedding[i][:j]
+ reward = self._compute_intrinsic_reward(episodic_memory,
+ cur_obs_embedding[i][j]).to(self.device)
+ episodic_reward[i].append(reward)
+
+ if torch.nonzero(torch.tensor(is_null[i]).float()).shape[0] != 0:
+ # TODO(pu): if have null padding, the episodic_reward should be 0
+ not_null_index = torch.nonzero(torch.tensor(is_null[i]).float()).squeeze(-1)
+ null_start_index = int(torch.nonzero(torch.tensor(is_null[i]).float()).squeeze(-1)[0])
+ # add the number of null transitions in i'th sequence in batch
+ null_cnt = null_cnt + seq_length - null_start_index
+ for k in range(null_start_index, seq_length):
+ episodic_reward[i][k] = torch.tensor(0).to(self.device)
+ # episodic_reward[i][null_start_index:-1]=[torch.tensor(0).to(self.device)
+ # for i in range(seq_length-null_start_index)]
+
+ # list(list(tensor)) -> tensor
+ tmp = [torch.stack(episodic_reward_tmp, dim=0) for episodic_reward_tmp in episodic_reward]
+ # stack batch dim
+ episodic_reward = torch.stack(tmp, dim=0) # TODO(pu): image case
+        episodic_reward = episodic_reward.view(-1)  # e.g. torch.Size([32, 42]) -> torch.Size([32*42])
+
+ episodic_reward_real_mean = sum(episodic_reward) / (
+ batch_size * seq_length - null_cnt
+ ) # TODO(pu): recompute mean
+ self.estimate_cnt_episodic += 1
+ self._running_mean_std_episodic_reward.update(episodic_reward.cpu().numpy())
+
+ self.tb_logger.add_scalar(
+ 'episodic_reward/episodic_reward_max', episodic_reward.max(), self.estimate_cnt_episodic
+ )
+ self.tb_logger.add_scalar(
+ 'episodic_reward/episodic_reward_mean', episodic_reward_real_mean, self.estimate_cnt_episodic
+ )
+ self.tb_logger.add_scalar(
+ 'episodic_reward/episodic_reward_min', episodic_reward.min(), self.estimate_cnt_episodic
+ )
+ self.tb_logger.add_scalar(
+ 'episodic_reward/episodic_reward_std_', episodic_reward.std(), self.estimate_cnt_episodic
+ )
+ # transform to [0,1]: er01
+ episodic_reward = (episodic_reward -
+ episodic_reward.min()) / (episodic_reward.max() - episodic_reward.min() + 1e-11)
+ """1. transform to batch mean1: erbm1"""
+ # episodic_reward = episodic_reward / (episodic_reward.mean() + 1e-11)
+ # the null_padding transition have episodic reward=0,
+ # episodic_reward = episodic_reward / (episodic_reward_real_mean + 1e-11)
+ """2. transform to long-term mean1: erlm1"""
+ # episodic_reward = episodic_reward / self._running_mean_std_episodic_reward.mean
+        """3. transform to mean 0, std 1, which is wrong: rnd_reward is in [1, 5] and the episodic reward
+        should be > 0; otherwise, e.g. when the episodic_reward is -2, a larger rnd_reward would make
+        the total intrinsic reward smaller, which is not correct."""
+ # episodic_reward = (episodic_reward - self._running_mean_std_episodic_reward.mean)
+ # / self._running_mean_std_episodic_reward.std
+ """4. transform to std1, which is not very meaningful"""
+ # episodic_reward = episodic_reward / self._running_mean_std_episodic_reward.std
+
+ return episodic_reward
+
+ def collect_data(self, data: list) -> None:
+ train_obs, train_action = collect_data_and_exclude_null_data_episodic(data)
+ self.train_obs_total.extend(train_obs)
+ self.train_action_total.extend(train_action)
+
+ def clear_data(self) -> None:
+ self.train_obs_total = []
+ self.train_action_total = []
+
+ def fusion_reward(
+ self, train_data, inter_episodic_reward, episodic_reward, nstep, collector_env_num, tb_logger, estimate_cnt
+ ):
+ # NOTE: deepcopy reward part of train_data is very important,
+ # otherwise the reward of train_data in the replay buffer will be incorrectly modified.
+ data = self.reward_deepcopy(train_data)
+ estimate_cnt += 1
+ index_to_beta = {
+ i: 0.3 * torch.sigmoid(torch.tensor(10 * (2 * i - (collector_env_num - 2)) / (collector_env_num - 2)))
+ for i in range(collector_env_num)
+ }
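+        # Clarifying note (added): with the sigmoid schedule above and e.g. collector_env_num = 8, beta ranges
+        # from roughly 0.3 * sigmoid(-10) ~= 0 for env 0 (pure exploitation) up to ~0.3 for the last env
+        # (most exploratory), approximating the beta schedule described in the NGU paper.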
+ batch_size = len(data)
+ seq_length = len(data[0]['reward'])
+ device = data[0]['reward'][0].device
+ intrinsic_reward_type = 'add'
+ intrisic_reward = episodic_reward * torch.clamp(inter_episodic_reward, min=1, max=5)
+ tb_logger.add_scalar('intrinsic_reward/intrinsic_reward_max', intrisic_reward.max(), estimate_cnt)
+ tb_logger.add_scalar('intrinsic_reward/intrinsic_reward_mean', intrisic_reward.mean(), estimate_cnt)
+ tb_logger.add_scalar('intrinsic_reward/intrinsic_reward_min', intrisic_reward.min(), estimate_cnt)
+
+ if not isinstance(data[0], (list, dict)):
+ # not rnn based rl algorithm
+ intrisic_reward = intrisic_reward.to(device)
+ intrisic_reward = torch.chunk(intrisic_reward, intrisic_reward.shape[0], dim=0)
+ for item, rew in zip(data, intrisic_reward):
+ if intrinsic_reward_type == 'add':
+                    item['reward'] += rew * index_to_beta[int(item['beta'])]
+ else:
+ # rnn based rl algorithm
+ intrisic_reward = intrisic_reward.to(device)
+
+ # tensor to tuple
+ intrisic_reward = torch.chunk(intrisic_reward, int(intrisic_reward.shape[0]), dim=0)
+
+ if self.cfg.last_nonzero_reward_weight is None and self.cfg.last_nonzero_reward_rescale:
+ # for minigrid env
+ self.cfg.last_nonzero_reward_weight = seq_length
+
+ # this is for the nstep rl algorithms
+ for i in range(batch_size): # batch_size typically 64
+ for j in range(seq_length): # burnin+unroll_len is the sequence length, e.g. 100=2+98
+ if j < seq_length - nstep:
+ intrinsic_reward = torch.cat(
+ [intrisic_reward[i * seq_length + j + k] for k in range(nstep)], dim=0
+ )
+ # if intrinsic_reward_type == 'add':
+ if not data[i]['null'][j]:
+                            # if data[i]['null'][j] == True, it is null padding data; we only add an
+                            # intrinsic reward to the non-null data
+ if data[i]['done'][j] and self.cfg.last_nonzero_reward_rescale:
+ # if not null data, and data[i]['done'][j]==True, so this is the last nstep transition
+ # in the original data.
+
+                            # means whether to apply the rescale trick to the last non-zero reward
+                            # when combining extrinsic and intrinsic rewards.
+                            # only used in sparse reward envs such as minigrid, in which the last non-zero
+                            # reward is a strong positive signal that should not be overwhelmed by intrinsic rewards.
+ for k in reversed(range(nstep)):
+ # here we want to find the last nonzero reward in the nstep reward list:
+ # data[i]['reward'][j], that is also the last reward in the sequence, here,
+ # we set the sequence length is large enough,
+ # so we can consider the sequence as the whole episode plus null_padding
+
+ # TODO(pu): what should we do if the last reward in the whole episode is zero?
+ if data[i]['reward'][j][k] != 0:
+ # find the last one that is nonzero, and enlarging times
+ last_nonzero_rew = copy.deepcopy(data[i]['reward'][j][k])
+ data[i]['reward'][j][k] = \
+ self.cfg.last_nonzero_reward_weight * last_nonzero_rew + \
+ intrinsic_reward[k] * index_to_beta[int(data[i]['beta'][j])]
+ # substitute the kth reward in the list data[i]['reward'][j] with
+ # times amplified reward
+ break
+ else:
+ data[i]['reward'][j] = data[i]['reward'][j] + intrinsic_reward * index_to_beta[
+ int(data[i]['beta'][j])]
+
+ return data, estimate_cnt
diff --git a/DI-engine/ding/reward_model/pdeil_irl_model.py b/DI-engine/ding/reward_model/pdeil_irl_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..b09416f5c2fade0fbc9f3dfc69be46927c3b679a
--- /dev/null
+++ b/DI-engine/ding/reward_model/pdeil_irl_model.py
@@ -0,0 +1,227 @@
+from typing import List, Dict
+from ditk import logging
+import numpy as np
+import torch
+import pickle
+try:
+ from sklearn.svm import SVC
+except ImportError:
+ SVC = None
+from ding.torch_utils import cov
+from ding.utils import REWARD_MODEL_REGISTRY, one_time_warning
+from .base_reward_model import BaseRewardModel
+
+
+@REWARD_MODEL_REGISTRY.register('pdeil')
+class PdeilRewardModel(BaseRewardModel):
+ """
+ Overview:
+ The Pdeil reward model class (https://arxiv.org/abs/2112.06746)
+ Interface:
+        ``estimate``, ``train``, ``load_expert_data``, ``collect_data``, ``clear_data``, \
+ ``__init__``, ``_train``, ``_batch_mn_pdf``
+ Config:
+ == ==================== ===== ============= ======================================= =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ===== ============= ======================================= =======================
+ 1 ``type`` str pdeil | Reward model register name, refer |
+ | to registry ``REWARD_MODEL_REGISTRY`` |
+ 2 | ``expert_data_`` str expert_data. | Path to the expert dataset | Should be a '.pkl'
+ | ``path`` .pkl | | file
+ 3 | ``discrete_`` bool False | Whether the action is discrete |
+ | ``action`` | |
+ 4 | ``alpha`` float 0.5 | coefficient for Probability |
+ | | Density Estimator |
+ 5 | ``clear_buffer`` int 1 | clear buffer per fixed iters | make sure replay
+ ``_per_iters`` | buffer's data count
+ | isn't too few.
+ | (code work in entry)
+ == ==================== ===== ============= ======================================= =======================
+ """
+ config = dict(
+ # (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
+ type='pdeil',
+ # (str) Path to the expert dataset.
+ # expert_data_path='expert_data.pkl',
+ # (bool) Whether the action is discrete.
+ discrete_action=False,
+ # (float) Coefficient for Probability Density Estimator.
+ # alpha + beta = 1, alpha is in [0,1]
+ # when alpha is close to 0, the estimator has high variance and low bias;
+ # when alpha is close to 1, the estimator has high bias and low variance.
+ alpha=0.5,
+ # (int) Clear buffer per fixed iters.
+ clear_buffer_per_iters=1,
+ )
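+
+    # Clarifying note (added): with rho_1 = p_expert(s), rho_2 = p_agent(s) and rho_3 = p_expert(a|s)
+    # estimated in ``estimate``, the reward is computed as
+    #     reward = rho_1 * rho_3 / (alpha * rho_1 + beta * rho_2),
+    # i.e. the expert state-action density divided by a mixture of expert and agent state densities.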
+
+ def __init__(self, cfg: dict, device, tb_logger: 'SummaryWriter') -> None: # noqa
+ """
+ Overview:
+ Initialize ``self.`` See ``help(type(self))`` for accurate signature.
+ Some rules in naming the attributes of ``self.``:
+
+ - ``e_`` : expert values
+ - ``_sigma_`` : standard division values
+ - ``p_`` : current policy values
+ - ``_s_`` : states
+ - ``_a_`` : actions
+ Arguments:
+ - cfg (:obj:`Dict`): Training config
+ - device (:obj:`str`): Device usage, i.e. "cpu" or "cuda"
+ - tb_logger (:obj:`str`): Logger, defaultly set as 'SummaryWriter' for model summary
+ """
+ super(PdeilRewardModel, self).__init__()
+ try:
+ import scipy.stats as stats
+ self.stats = stats
+ except ImportError:
+ import sys
+ logging.warning("Please install scipy first, such as `pip3 install scipy`.")
+ sys.exit(1)
+ self.cfg: dict = cfg
+ self.e_u_s = None
+ self.e_sigma_s = None
+ if cfg.discrete_action:
+ self.svm = None
+ else:
+ self.e_u_s_a = None
+ self.e_sigma_s_a = None
+ self.p_u_s = None
+ self.p_sigma_s = None
+ self.expert_data = None
+ self.train_data: list = []
+ assert device in ["cpu", "cuda"] or "cuda" in device
+        # pdeil uses the cpu device by default
+ self.device = 'cpu'
+
+ self.load_expert_data()
+ states: list = []
+ actions: list = []
+ for item in self.expert_data:
+ states.append(item['obs'])
+ actions.append(item['action'])
+ states: torch.Tensor = torch.stack(states, dim=0)
+ actions: torch.Tensor = torch.stack(actions, dim=0)
+ self.e_u_s: torch.Tensor = torch.mean(states, axis=0)
+ self.e_sigma_s: torch.Tensor = cov(states, rowvar=False)
+ if self.cfg.discrete_action and SVC is None:
+ one_time_warning("You are using discrete action while the SVC is not installed!")
+ if self.cfg.discrete_action and SVC is not None:
+ self.svm: SVC = SVC(probability=True)
+ self.svm.fit(states.cpu().numpy(), actions.cpu().numpy())
+ else:
+            # state-action concatenation
+ state_actions = torch.cat((states, actions.float()), dim=-1)
+ self.e_u_s_a = torch.mean(state_actions, axis=0)
+ self.e_sigma_s_a = cov(state_actions, rowvar=False)
+
+ def load_expert_data(self) -> None:
+ """
+ Overview:
+ Getting the expert data from ``config['expert_data_path']`` attribute in self.
+ Effects:
+ This is a side effect function which updates the expert data attribute (e.g. ``self.expert_data``)
+ """
+ expert_data_path: str = self.cfg.expert_data_path
+ with open(expert_data_path, 'rb') as f:
+ self.expert_data: list = pickle.load(f)
+
+ def _train(self, states: torch.Tensor) -> None:
+ """
+ Overview:
+            Helper function for ``train`` which estimates the current policy's state distribution (mean and covariance).
+ Arguments:
+ - states (:obj:`torch.Tensor`): current policy states
+ Effects:
+ - Update attributes of ``p_u_s`` and ``p_sigma_s``
+ """
+ # we only need to collect the current policy state
+ self.p_u_s = torch.mean(states, axis=0)
+ self.p_sigma_s = cov(states, rowvar=False)
+
+ def train(self):
+ """
+ Overview:
+ Training the Pdeil reward model.
+ """
+ states = torch.stack([item['obs'] for item in self.train_data], dim=0)
+ self._train(states)
+
+ def _batch_mn_pdf(self, x: np.ndarray, mean: np.ndarray, cov: np.ndarray) -> np.ndarray:
+ """
+ Overview:
+ Get multivariate normal pdf of given np array.
+ """
+ return np.asarray(
+ self.stats.multivariate_normal.pdf(x, mean=mean, cov=cov, allow_singular=False), dtype=np.float32
+ )
+
+ def estimate(self, data: list) -> List[Dict]:
+ """
+ Overview:
+ Estimate reward by rewriting the reward keys.
+ Arguments:
+ - data (:obj:`list`): the list of data used for estimation,\
+ with at least ``obs`` and ``action`` keys.
+ Effects:
+ - This is a side effect function which updates the reward values in place.
+ """
+ # NOTE: deepcopy reward part of data is very important,
+ # otherwise the reward of data in the replay buffer will be incorrectly modified.
+ train_data_augmented = self.reward_deepcopy(data)
+ s = torch.stack([item['obs'] for item in train_data_augmented], dim=0)
+ a = torch.stack([item['action'] for item in train_data_augmented], dim=0)
+ if self.p_u_s is None:
+            print("you need to train your reward model first")
+ for item in train_data_augmented:
+ item['reward'].zero_()
+ else:
+ rho_1 = self._batch_mn_pdf(s.cpu().numpy(), self.e_u_s.cpu().numpy(), self.e_sigma_s.cpu().numpy())
+ rho_1 = torch.from_numpy(rho_1)
+ rho_2 = self._batch_mn_pdf(s.cpu().numpy(), self.p_u_s.cpu().numpy(), self.p_sigma_s.cpu().numpy())
+ rho_2 = torch.from_numpy(rho_2)
+ if self.cfg.discrete_action:
+ rho_3 = self.svm.predict_proba(s.cpu().numpy())[a.cpu().numpy()]
+ rho_3 = torch.from_numpy(rho_3)
+ else:
+ s_a = torch.cat([s, a.float()], dim=-1)
+ rho_3 = self._batch_mn_pdf(
+ s_a.cpu().numpy(),
+ self.e_u_s_a.cpu().numpy(),
+ self.e_sigma_s_a.cpu().numpy()
+ )
+ rho_3 = torch.from_numpy(rho_3)
+ rho_3 = rho_3 / rho_1
+ alpha = self.cfg.alpha
+ beta = 1 - alpha
+ den = rho_1 * rho_3
+ frac = alpha * rho_1 + beta * rho_2
+ if frac.abs().max() < 1e-4:
+ for item in train_data_augmented:
+ item['reward'].zero_()
+ else:
+ reward = den / frac
+ reward = torch.chunk(reward, reward.shape[0], dim=0)
+ for item, rew in zip(train_data_augmented, reward):
+ item['reward'] = rew
+ return train_data_augmented
+
+ def collect_data(self, item: list):
+ """
+ Overview:
+ Collecting training data by iterating data items in the input list
+ Arguments:
+ - data (:obj:`list`): Raw training data (e.g. some form of states, actions, obs, etc)
+ Effects:
+ - This is a side effect function which updates the data attribute in ``self`` by \
+ iterating data items in the input data items' list
+ """
+ self.train_data.extend(item)
+
+ def clear_data(self):
+ """
+ Overview:
+ Clearing training data. \
+ This is a side effect function which clears the data attribute in ``self``
+ """
+ self.train_data.clear()
diff --git a/DI-engine/ding/reward_model/pwil_irl_model.py b/DI-engine/ding/reward_model/pwil_irl_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..8738ee2d81416063d66db9d71cbe6ee30161f126
--- /dev/null
+++ b/DI-engine/ding/reward_model/pwil_irl_model.py
@@ -0,0 +1,259 @@
+from typing import Dict, List
+import math
+import random
+import pickle
+import torch
+
+from ding.utils import REWARD_MODEL_REGISTRY
+from .base_reward_model import BaseRewardModel
+
+
+def collect_state_action_pairs(iterator):
+ # concat state and action
+ """
+ Overview:
+        Collect state and action pairs from the input iterator.
+ Arguments:
+ - iterator (:obj:`Iterable`): Iterables with at least ``obs`` and ``action`` tensor keys.
+ Returns:
+ - res (:obj:`Torch.tensor`): State and action pairs.
+ """
+ res = []
+ for item in iterator:
+ state = item['obs']
+ action = item['action']
+ # s_a = torch.cat([state, action.float()], dim=-1)
+ res.append((state, action))
+ return res
+
+
+@REWARD_MODEL_REGISTRY.register('pwil')
+class PwilRewardModel(BaseRewardModel):
+ """
+ Overview:
+ The Pwil reward model class (https://arxiv.org/pdf/2006.04678.pdf)
+ Interface:
+        ``estimate``, ``train``, ``load_expert_data``, ``collect_data``, ``clear_data``, \
+ ``__init__``, ``_train``, ``_get_state_distance``, ``_get_action_distance``
+ Config:
+ == ================== ===== ============= ======================================= =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ================== ===== ============= ======================================= =======================
+ 1 ``type`` str pwil | Reward model register name, refer |
+ | to registry ``REWARD_MODEL_REGISTRY`` |
+ 2 | ``expert_data_`` str expert_data. | Path to the expert dataset | Should be a '.pkl'
+ | ``path`` .pkl | | file
+ 3 | ``sample_size`` int 1000 | sample data from expert dataset |
+ | with fixed size |
+ 4 | ``alpha`` int 5 | factor alpha |
+ 5 | ``beta`` int 5 | factor beta |
+ 6 | ``s_size`` int 4 | state size |
+ 7 | ``a_size`` int 2 | action size |
+ 8 | ``clear_buffer`` int 1 | clear buffer per fixed iters | make sure replay
+ ``_per_iters`` | buffer's data count
+ | isn't too few.
+ | (code work in entry)
+ == ================== ===== ============= ======================================= =======================
+ Properties:
+ - reward_table (:obj: `Dict`): In this algorithm, reward model is a dictionary.
+ """
+ config = dict(
+ # (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
+ type='pwil',
+ # (str) Path to the expert dataset.
+ # expert_data_path='expert_data.pkl',
+ # (int) Sample data from expert dataset with fixed size.
+ sample_size=1000,
+ # r = alpha * exp((-beta*T/sqrt(|s_size|+ |a_size|))*c_i)
+        # The key idea of this reward is to minimize the Wasserstein distance
+        # between the agent's and the expert's state-action distributions.
+ # (int) Factor alpha.
+ alpha=5,
+ # (int) Factor beta.
+ beta=5,
+        # (int) State size.
+ # s_size=4,
+ # (int) Action size.
+ # a_size=2,
+ # (int) Clear buffer per fixed iters.
+ clear_buffer_per_iters=1,
+ )
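+
+    # A minimal usage sketch (illustrative only; ``expert_data.pkl`` is assumed to be a pickled list of
+    # transition dicts with 'obs' and 'action' tensors, and the logger/config names are assumptions):
+    #
+    #     from easydict import EasyDict
+    #     cfg = EasyDict(dict(PwilRewardModel.config, expert_data_path='expert_data.pkl', s_size=4, a_size=2))
+    #     pwil = PwilRewardModel(cfg, device='cpu', tb_logger=SummaryWriter('./log'))
+    #     pwil.collect_data(transitions)  # also sets self.reward_factor from T, s_size and a_size
+    #     pwil.train()                    # fills self.reward_table via the greedy coupling in ``_train``
+    #     new_transitions = pwil.estimate(transitions)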
+
+ def __init__(self, config: Dict, device: str, tb_logger: 'SummaryWriter') -> None: # noqa
+ """
+ Overview:
+ Initialize ``self.`` See ``help(type(self))`` for accurate signature.
+ Arguments:
+ - cfg (:obj:`Dict`): Training config
+ - device (:obj:`str`): Device usage, i.e. "cpu" or "cuda"
+ - tb_logger (:obj:`str`): Logger, defaultly set as 'SummaryWriter' for model summary
+ """
+ super(PwilRewardModel, self).__init__()
+ self.cfg: Dict = config
+ assert device in ["cpu", "cuda"] or "cuda" in device
+ self.device = device
+ self.expert_data: List[tuple] = []
+ self.train_data: List[tuple] = []
+ # In this algo, model is a dict
+ self.reward_table: Dict = {}
+ self.T: int = 0
+
+ self.load_expert_data()
+
+ def load_expert_data(self) -> None:
+ """
+ Overview:
+ Getting the expert data from ``config['expert_data_path']`` attribute in self
+ Effects:
+ This is a side effect function which updates the expert data attribute (e.g. ``self.expert_data``); \
+ in this algorithm, also the ``self.expert_s``, ``self.expert_a`` for states and actions are updated.
+
+ """
+ with open(self.cfg.expert_data_path, 'rb') as f:
+ self.expert_data = pickle.load(f)
+ print("the data size is:", len(self.expert_data))
+ sample_size = min(self.cfg.sample_size, len(self.expert_data))
+ self.expert_data = random.sample(self.expert_data, sample_size)
+ self.expert_data = [(item['obs'], item['action']) for item in self.expert_data]
+ self.expert_s, self.expert_a = list(zip(*self.expert_data))
+ print('the expert data demonstrations is:', len(self.expert_data))
+
+ def collect_data(self, data: list) -> None:
+ """
+ Overview:
+ Collecting training data formatted by ``fn:concat_state_action_pairs``.
+ Arguments:
+ - data (:obj:`list`): Raw training data (e.g. some form of states, actions, obs, etc)
+ Effects:
+ - This is a side effect function which updates the data attribute in ``self``; \
+ in this algorithm, also the ``s_size``, ``a_size`` for states and actions are updated in the \
+ attribute in ``self.cfg`` Dict; ``reward_factor`` also updated as ``collect_data`` called.
+ """
+ self.train_data.extend(collect_state_action_pairs(data))
+ self.T = len(self.train_data)
+
+ s_size = self.cfg.s_size
+ a_size = self.cfg.a_size
+ beta = self.cfg.beta
+ self.reward_factor = -beta * self.T / math.sqrt(s_size + a_size)
+
+ def train(self) -> None:
+ """
+ Overview:
+ Training the Pwil reward model.
+ """
+ self._train(self.train_data)
+
+ def estimate(self, data: list) -> List[Dict]:
+ """
+ Overview:
+ Estimate reward by rewriting the reward key in each row of the data.
+ Arguments:
+ - data (:obj:`list`): the list of data used for estimation, \
+ with at least ``obs`` and ``action`` keys.
+ Effects:
+ - This is a side effect function which updates the ``reward_table`` with ``(obs,action)`` \
+ tuples from input.
+ """
+ # NOTE: deepcopy reward part of data is very important,
+ # otherwise the reward of data in the replay buffer will be incorrectly modified.
+ train_data_augmented = self.reward_deepcopy(data)
+ for item in train_data_augmented:
+ s = item['obs']
+ a = item['action']
+ if (s, a) in self.reward_table:
+ item['reward'] = self.reward_table[(s, a)]
+ else:
+                # when the (s, a) pair has not been trained, set the reward to the default value (e.g. 0)
+ item['reward'] = torch.zeros_like(item['reward'])
+ return train_data_augmented
+
+ def _get_state_distance(self, s1: list, s2: list) -> torch.Tensor:
+ """
+ Overview:
+ Getting distances of states given 2 state lists. One single state \
+ is of shape ``torch.Size([n])`` (``n`` referred in in-code comments)
+ Arguments:
+ - s1 (:obj:`torch.Tensor list`): the 1st states' list of size M
+ - s2 (:obj:`torch.Tensor list`): the 2nd states' list of size N
+ Returns:
+ - distance (:obj:`torch.Tensor`) Euclidean distance tensor of \
+ the state tensor lists, of size M x N.
+ """
+ # Format the values in the tensors to be of float type
+ s1 = torch.stack(s1).float()
+ s2 = torch.stack(s2).float()
+ M, N = s1.shape[0], s2.shape[0]
+ # Automatically fill in length
+ s1 = s1.view(M, -1)
+ s2 = s2.view(N, -1)
+ # Automatically fill in & format the tensor size to be (MxNxn)
+ s1 = s1.unsqueeze(1).repeat(1, N, 1)
+ s2 = s2.unsqueeze(0).repeat(M, 1, 1)
+ # Return the distance tensor of size MxN
+ return ((s1 - s2) ** 2).mean(dim=-1)
+
+ def _get_action_distance(self, a1: list, a2: list) -> torch.Tensor:
+ # TODO the metric of action distance maybe different from envs
+ """
+ Overview:
+ Getting distances of actions given 2 action lists. One single action \
+ is of shape ``torch.Size([n])`` (``n`` referred in in-code comments)
+ Arguments:
+ - a1 (:obj:`torch.Tensor list`): the 1st actions' list of size M
+ - a2 (:obj:`torch.Tensor list`): the 2nd actions' list of size N
+ Returns:
+ - distance (:obj:`torch.Tensor`) Euclidean distance tensor of \
+ the action tensor lists, of size M x N.
+ """
+ a1 = torch.stack(a1).float()
+ a2 = torch.stack(a2).float()
+ M, N = a1.shape[0], a2.shape[0]
+ a1 = a1.view(M, -1)
+ a2 = a2.view(N, -1)
+ a1 = a1.unsqueeze(1).repeat(1, N, 1)
+ a2 = a2.unsqueeze(0).repeat(M, 1, 1)
+ return ((a1 - a2) ** 2).mean(dim=-1)
+
+ def _train(self, data: list):
+ """
+ Overview:
+            Helper function for ``train``, which finds the minimum-distance expert pairs ``s_e``, ``a_e``.
+ Arguments:
+ - data (:obj:`list`): Raw training data (e.g. some form of states, actions, obs, etc)
+ Effects:
+ - This is a side effect function which updates the ``reward_table`` attribute in ``self`` .
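+        Examples:
+            >>> # Hedged sketch of the cost-to-reward mapping only (alpha, beta, T and sizes are made-up values):
+            >>> alpha, beta, T, s_size, a_size = 5.0, 5.0, 100, 8, 2
+            >>> reward_factor = -beta * T / math.sqrt(s_size + a_size)
+            >>> c = 0.01   # accumulated greedy transport cost for one (s, a) sample
+            >>> reward = alpha * math.exp(reward_factor * c)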
+ """
+ batch_s, batch_a = list(zip(*data))
+ s_distance_matrix = self._get_state_distance(batch_s, self.expert_s)
+ a_distance_matrix = self._get_action_distance(batch_a, self.expert_a)
+ distance_matrix = s_distance_matrix + a_distance_matrix
+ w_e_list = [1 / len(self.expert_data)] * len(self.expert_data)
+ for i, item in enumerate(data):
+ s, a = item
+ w_pi = 1 / self.T
+ c = 0
+ expert_data_idx = torch.arange(len(self.expert_data)).tolist()
+ while w_pi > 0:
+ selected_dist = distance_matrix[i, expert_data_idx]
+ nearest_distance = selected_dist.min().item()
+ nearest_index_selected = selected_dist.argmin().item()
+ nearest_index = expert_data_idx[nearest_index_selected]
+ if w_pi >= w_e_list[nearest_index]:
+ c = c + nearest_distance * w_e_list[nearest_index]
+ w_pi = w_pi - w_e_list[nearest_index]
+ expert_data_idx.pop(nearest_index_selected)
+ else:
+ c = c + w_pi * nearest_distance
+ w_e_list[nearest_index] = w_e_list[nearest_index] - w_pi
+ w_pi = 0
+ reward = self.cfg.alpha * math.exp(self.reward_factor * c)
+ self.reward_table[(s, a)] = torch.FloatTensor([reward])
+
+ def clear_data(self) -> None:
+ """
+ Overview:
+ Clearing training data. \
+ This is a side effect function which clears the data attribute in ``self``
+ """
+ self.train_data.clear()
diff --git a/DI-engine/ding/reward_model/red_irl_model.py b/DI-engine/ding/reward_model/red_irl_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7daeeceec556698be6f516dfdc2fcaae82c9011
--- /dev/null
+++ b/DI-engine/ding/reward_model/red_irl_model.py
@@ -0,0 +1,214 @@
+from typing import Dict, List
+import pickle
+import random
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from ding.utils import REWARD_MODEL_REGISTRY, one_time_warning
+from .base_reward_model import BaseRewardModel
+
+
+class SENet(nn.Module):
+ """support estimation network"""
+
+ def __init__(self, input_size: int, hidden_size: int, output_dims: int) -> None:
+ super(SENet, self).__init__()
+ self.l_1 = nn.Linear(input_size, hidden_size)
+ self.l_2 = nn.Linear(hidden_size, output_dims)
+ self.act = nn.Tanh()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ out = self.l_1(x)
+ out = self.act(out)
+ out = self.l_2(out)
+ out = self.act(out)
+ return out
+
+
+@REWARD_MODEL_REGISTRY.register('red')
+class RedRewardModel(BaseRewardModel):
+ """
+ Overview:
+        The implementation of the reward model in RED (https://arxiv.org/abs/1905.06750)
+ Interface:
+        ``estimate``, ``train``, ``load_expert_data``, ``collect_data``, ``clear_data``, \
+ ``__init__``, ``_train``
+ Config:
+ == ================== ===== ============= ======================================= =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ================== ===== ============= ======================================= =======================
+ 1 ``type`` str red | Reward model register name, refer |
+ | to registry ``REWARD_MODEL_REGISTRY`` |
+ 2 | ``expert_data_`` str expert_data | Path to the expert dataset | Should be a '.pkl'
+ | ``path`` .pkl | | file
+ 3 | ``sample_size`` int 1000 | sample data from expert dataset |
+ | with fixed size |
+ 4 | ``sigma`` int 5 | hyperparameter of r(s,a) | r(s,a) = exp(
+ | -sigma* L(s,a))
+ 5 | ``batch_size`` int 64 | Training batch size |
+ 6 | ``hidden_size`` int 128 | Linear model hidden size |
+ 7 | ``update_per_`` int 100 | Number of updates per collect |
+ | ``collect`` | |
+ 8 | ``clear_buffer`` int 1 | clear buffer per fixed iters | make sure replay
+ ``_per_iters`` | buffer's data count
+ | isn't too few.
+ | (code work in entry)
+ == ================== ===== ============= ======================================= =======================
+ Properties:
+        - online_net (:obj:`SENet`): The reward model, by default initialized once when training begins.
+ """
+ config = dict(
+ # (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
+ type='red',
+ # (int) Linear model input size.
+ # input_size=4,
+ # (int) Sample data from expert dataset with fixed size.
+ sample_size=1000,
+ # (int) Linear model hidden size.
+ hidden_size=128,
+ # (float) The step size of gradient descent.
+ learning_rate=1e-3,
+ # (int) How many updates(iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=100,
+ # (str) Path to the expert dataset
+ # expert_data_path='expert_data.pkl',
+ # (int) How many samples in a training batch.
+ batch_size=64,
+ # (float) Hyperparameter at estimated score of r(s,a).
+ # r(s,a) = exp(-sigma* L(s,a))
+ sigma=0.5,
+ # (int) Clear buffer per fixed iters.
+ clear_buffer_per_iters=1,
+ )
+
+ def __init__(self, config: Dict, device: str, tb_logger: 'SummaryWriter') -> None: # noqa
+ """
+ Overview:
+ Initialize ``self.`` See ``help(type(self))`` for accurate signature.
+ Arguments:
+ - cfg (:obj:`Dict`): Training config
+ - device (:obj:`str`): Device usage, i.e. "cpu" or "cuda"
+            - tb_logger (:obj:`SummaryWriter`): Logger, by default a 'SummaryWriter' used for model summaries
+ """
+ super(RedRewardModel, self).__init__()
+ self.cfg: Dict = config
+ self.expert_data: List[tuple] = []
+ self.device = device
+ assert device in ["cpu", "cuda"] or "cuda" in device
+ self.tb_logger = tb_logger
+ self.target_net: SENet = SENet(config.input_size, config.hidden_size, 1)
+ self.online_net: SENet = SENet(config.input_size, config.hidden_size, 1)
+ self.target_net.to(device)
+ self.online_net.to(device)
+ self.opt: optim.Adam = optim.Adam(self.online_net.parameters(), config.learning_rate)
+ self.train_once_flag = False
+
+ self.load_expert_data()
+
+ def load_expert_data(self) -> None:
+ """
+ Overview:
+ Getting the expert data from ``config['expert_data_path']`` attribute in self.
+ Effects:
+ This is a side effect function which updates the expert data attribute (e.g. ``self.expert_data``)
+ """
+ with open(self.cfg.expert_data_path, 'rb') as f:
+ self.expert_data = pickle.load(f)
+ sample_size = min(len(self.expert_data), self.cfg.sample_size)
+ self.expert_data = random.sample(self.expert_data, sample_size)
+ print('the expert data size is:', len(self.expert_data))
+
+ def _train(self, batch_data: torch.Tensor) -> float:
+ """
+ Overview:
+            Helper function for ``train`` which calculates the support-estimation loss on a batch of expert data.
+ Arguments:
+ - batch_data (:obj:`torch.Tensor`): Data used for training
+ Returns:
+            - loss (:obj:`float`): The MSE between the online network output and the frozen target \
+                network output on ``batch_data``.
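+        Examples:
+            >>> # Minimal sketch of the loss (shapes are made up; networks replaced by random tensors):
+            >>> target = torch.randn(8, 1)                     # stands in for the frozen target-net output
+            >>> hat = torch.randn(8, 1, requires_grad=True)    # stands in for the online-net output
+            >>> loss = ((hat - target) ** 2).mean()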
+ """
+ with torch.no_grad():
+ target = self.target_net(batch_data)
+ hat: torch.Tensor = self.online_net(batch_data)
+ loss: torch.Tensor = ((hat - target) ** 2).mean()
+ self.opt.zero_grad()
+ loss.backward()
+ self.opt.step()
+ return loss.item()
+
+ def train(self) -> None:
+ """
+ Overview:
+            Training the RED reward model. By default, the RED model is trained only once.
+ Effects:
+            - This is a side effect function which updates the reward model and increments the train iteration count.
+ """
+ if self.train_once_flag:
+ one_time_warning('RED model should be trained once, we do not train it anymore')
+ else:
+ for i in range(self.cfg.update_per_collect):
+ sample_batch = random.sample(self.expert_data, self.cfg.batch_size)
+ states_data = []
+ actions_data = []
+ for item in sample_batch:
+ states_data.append(item['obs'])
+ actions_data.append(item['action'])
+ states_tensor: torch.Tensor = torch.stack(states_data).float()
+ actions_tensor: torch.Tensor = torch.stack(actions_data).float()
+ states_actions_tensor: torch.Tensor = torch.cat([states_tensor, actions_tensor], dim=1)
+ states_actions_tensor = states_actions_tensor.to(self.device)
+ loss = self._train(states_actions_tensor)
+ self.tb_logger.add_scalar('reward_model/red_loss', loss, i)
+ self.train_once_flag = True
+
+ def estimate(self, data: list) -> List[Dict]:
+ """
+ Overview:
+ Estimate reward by rewriting the reward key
+ Arguments:
+ - data (:obj:`list`): the list of data used for estimation, \
+ with at least ``obs`` and ``action`` keys.
+ Effects:
+ - This is a side effect function which updates the reward values in place.
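+        Examples:
+            >>> # Sketch of the reward mapping used below (sigma from the default config; tensors are made up):
+            >>> pred, target = torch.randn(4, 1), torch.randn(4, 1)
+            >>> c = ((pred - target) ** 2).mean(dim=1)   # per-sample support-estimation error
+            >>> r = torch.exp(-0.5 * c)                  # r(s, a) = exp(-sigma * L(s, a)), here sigma = 0.5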
+ """
+ # NOTE: deepcopy reward part of data is very important,
+ # otherwise the reward of data in the replay buffer will be incorrectly modified.
+ train_data_augmented = self.reward_deepcopy(data)
+ states_data = []
+ actions_data = []
+ for item in train_data_augmented:
+ states_data.append(item['obs'])
+ actions_data.append(item['action'])
+ states_tensor = torch.stack(states_data).float()
+ actions_tensor = torch.stack(actions_data).float()
+ states_actions_tensor = torch.cat([states_tensor, actions_tensor], dim=1)
+ states_actions_tensor = states_actions_tensor.to(self.device)
+ with torch.no_grad():
+ hat_1 = self.online_net(states_actions_tensor)
+ hat_2 = self.target_net(states_actions_tensor)
+ c = ((hat_1 - hat_2) ** 2).mean(dim=1)
+ r = torch.exp(-self.cfg.sigma * c)
+ for item, rew in zip(train_data_augmented, r):
+ item['reward'] = rew
+ return train_data_augmented
+
+ def collect_data(self, data) -> None:
+ """
+ Overview:
+            Collecting training data; this is a no-op because the reward model (i.e. ``online_net``) is only \
+                trained once. If ``online_net`` were trained continuously, this method would need an implementation.
+ """
+ # if online_net is trained continuously, there should be some implementations in collect_data method
+ pass
+
+ def clear_data(self):
+ """
+ Overview:
+            Clearing training data; this is a no-op because the reward model (i.e. ``online_net``) is only \
+                trained once. If ``online_net`` were trained continuously, this method would need an implementation.
+ """
+ # if online_net is trained continuously, there should be some implementations in clear_data method
+ pass
diff --git a/DI-engine/ding/reward_model/rnd_reward_model.py b/DI-engine/ding/reward_model/rnd_reward_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..00bb1542fd84a353d249506026df7da4d1cc69cc
--- /dev/null
+++ b/DI-engine/ding/reward_model/rnd_reward_model.py
@@ -0,0 +1,235 @@
+from typing import Union, Tuple, List, Dict
+from easydict import EasyDict
+
+import random
+import torch
+import torch.nn as nn
+import torch.optim as optim
+import torch.nn.functional as F
+
+from ding.utils import SequenceType, REWARD_MODEL_REGISTRY
+from ding.model import FCEncoder, ConvEncoder
+from .base_reward_model import BaseRewardModel
+from ding.utils import RunningMeanStd
+from ding.torch_utils.data_helper import to_tensor
+import numpy as np
+
+
+def collect_states(iterator):
+ res = []
+ for item in iterator:
+ state = item['obs']
+ res.append(state)
+ return res
+
+
+class RndNetwork(nn.Module):
+
+ def __init__(self, obs_shape: Union[int, SequenceType], hidden_size_list: SequenceType) -> None:
+ super(RndNetwork, self).__init__()
+ if isinstance(obs_shape, int) or len(obs_shape) == 1:
+ self.target = FCEncoder(obs_shape, hidden_size_list)
+ self.predictor = FCEncoder(obs_shape, hidden_size_list)
+ elif len(obs_shape) == 3:
+ self.target = ConvEncoder(obs_shape, hidden_size_list)
+ self.predictor = ConvEncoder(obs_shape, hidden_size_list)
+ else:
+ raise KeyError(
+ "not support obs_shape for pre-defined encoder: {}, please customize your own RND model".
+ format(obs_shape)
+ )
+ for param in self.target.parameters():
+ param.requires_grad = False
+
+ def forward(self, obs: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ predict_feature = self.predictor(obs)
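+        # Only the predictor is trained; the target network was frozen in __init__ (requires_grad=False).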
+ with torch.no_grad():
+ target_feature = self.target(obs)
+ return predict_feature, target_feature
+
+
+@REWARD_MODEL_REGISTRY.register('rnd')
+class RndRewardModel(BaseRewardModel):
+ """
+ Overview:
+ The RND reward model class (https://arxiv.org/abs/1810.12894v1)
+ Interface:
+ ``estimate``, ``train``, ``collect_data``, ``clear_data``, \
+ ``__init__``, ``_train``, ``load_state_dict``, ``state_dict``
+ Config:
+ == ==================== ===== ============= ======================================= =======================
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ===== ============= ======================================= =======================
+ 1 ``type`` str rnd | Reward model register name, refer |
+ | to registry ``REWARD_MODEL_REGISTRY`` |
+ 2 | ``intrinsic_`` str add | the intrinsic reward type | including add, new
+ | ``reward_type`` | | , or assign
+ 3 | ``learning_rate`` float 0.001 | The step size of gradient descent |
+ 4 | ``batch_size`` int 64 | Training batch size |
+ 5 | ``hidden`` list [64, 64, | the MLP layer shape |
+ | ``_size_list`` (int) 128] | |
+ 6 | ``update_per_`` int 100 | Number of updates per collect |
+ | ``collect`` | |
+ 7 | ``obs_norm`` bool True | Observation normalization |
+ 8 | ``obs_norm_`` int 0 | min clip value for obs normalization |
+ | ``clamp_min``
+ 9 | ``obs_norm_`` int 1 | max clip value for obs normalization |
+ | ``clamp_max``
+ 10 | ``intrinsic_`` float 0.01 | the weight of intrinsic reward | r = w*r_i + r_e
+ ``reward_weight``
+ 11 | ``extrinsic_`` bool True | Whether to normlize extrinsic reward
+ ``reward_norm``
+ 12 | ``extrinsic_`` int 1 | the upper bound of the reward
+ ``reward_norm_max`` | normalization
+ == ==================== ===== ============= ======================================= =======================
+ """
+ config = dict(
+ # (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
+ type='rnd',
+ # (str) The intrinsic reward type, including add, new, or assign.
+ intrinsic_reward_type='add',
+ # (float) The step size of gradient descent.
+ learning_rate=1e-3,
+ # (float) Batch size.
+ batch_size=64,
+ # (list(int)) Sequence of ``hidden_size`` of reward network.
+ # If obs.shape == 1, use MLP layers.
+ # If obs.shape == 3, use conv layer and final dense layer.
+ hidden_size_list=[64, 64, 128],
+ # (int) How many updates(iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=100,
+ # (bool) Observation normalization: transform obs to mean 0, std 1.
+ obs_norm=True,
+ # (int) Min clip value for observation normalization.
+ obs_norm_clamp_min=-1,
+ # (int) Max clip value for observation normalization.
+ obs_norm_clamp_max=1,
+ # Means the relative weight of RND intrinsic_reward.
+ # (float) The weight of intrinsic reward
+ # r = intrinsic_reward_weight * r_i + r_e.
+ intrinsic_reward_weight=0.01,
+        # (bool) Whether to normalize extrinsic reward.
+ # Normalize the reward to [0, extrinsic_reward_norm_max].
+ extrinsic_reward_norm=True,
+ # (int) The upper bound of the reward normalization.
+ extrinsic_reward_norm_max=1,
+ )
+
+ def __init__(self, config: EasyDict, device: str = 'cpu', tb_logger: 'SummaryWriter' = None) -> None: # noqa
+ super(RndRewardModel, self).__init__()
+ self.cfg = config
+ assert device == "cpu" or device.startswith("cuda")
+ self.device = device
+ if tb_logger is None: # TODO
+ from tensorboardX import SummaryWriter
+ tb_logger = SummaryWriter('rnd_reward_model')
+ self.tb_logger = tb_logger
+ self.reward_model = RndNetwork(config.obs_shape, config.hidden_size_list)
+ self.reward_model.to(self.device)
+ self.intrinsic_reward_type = config.intrinsic_reward_type
+ assert self.intrinsic_reward_type in ['add', 'new', 'assign']
+ self.train_obs = []
+ self.opt = optim.Adam(self.reward_model.predictor.parameters(), config.learning_rate)
+ self._running_mean_std_rnd_reward = RunningMeanStd(epsilon=1e-4)
+ self.estimate_cnt_rnd = 0
+ self.train_cnt_icm = 0
+ self._running_mean_std_rnd_obs = RunningMeanStd(epsilon=1e-4)
+
+ def _train(self) -> None:
+ train_data: list = random.sample(self.train_obs, self.cfg.batch_size)
+ train_data: torch.Tensor = torch.stack(train_data).to(self.device)
+ if self.cfg.obs_norm:
+ # Note: observation normalization: transform obs to mean 0, std 1
+ self._running_mean_std_rnd_obs.update(train_data.cpu().numpy())
+ train_data = (train_data - to_tensor(self._running_mean_std_rnd_obs.mean).to(self.device)) / to_tensor(
+ self._running_mean_std_rnd_obs.std
+ ).to(self.device)
+ train_data = torch.clamp(train_data, min=self.cfg.obs_norm_clamp_min, max=self.cfg.obs_norm_clamp_max)
+
+ predict_feature, target_feature = self.reward_model(train_data)
+ loss = F.mse_loss(predict_feature, target_feature.detach())
+ self.tb_logger.add_scalar('rnd_reward/loss', loss, self.train_cnt_icm)
+ self.opt.zero_grad()
+ loss.backward()
+ self.opt.step()
+
+ def train(self) -> None:
+ for _ in range(self.cfg.update_per_collect):
+ self._train()
+ self.train_cnt_icm += 1
+
+ def estimate(self, data: list) -> List[Dict]:
+ """
+ Rewrite the reward key in each row of the data.
+ """
+ # NOTE: deepcopy reward part of data is very important,
+ # otherwise the reward of data in the replay buffer will be incorrectly modified.
+ train_data_augmented = self.reward_deepcopy(data)
+
+ obs = collect_states(train_data_augmented)
+ obs = torch.stack(obs).to(self.device)
+ if self.cfg.obs_norm:
+ # Note: observation normalization: transform obs to mean 0, std 1
+ obs = (obs - to_tensor(self._running_mean_std_rnd_obs.mean
+ ).to(self.device)) / to_tensor(self._running_mean_std_rnd_obs.std).to(self.device)
+ obs = torch.clamp(obs, min=self.cfg.obs_norm_clamp_min, max=self.cfg.obs_norm_clamp_max)
+
+ with torch.no_grad():
+ predict_feature, target_feature = self.reward_model(obs)
+ mse = F.mse_loss(predict_feature, target_feature, reduction='none').mean(dim=1)
+ self._running_mean_std_rnd_reward.update(mse.cpu().numpy())
+
+ # Note: according to the min-max normalization, transform rnd reward to [0,1]
+ rnd_reward = (mse - mse.min()) / (mse.max() - mse.min() + 1e-8)
+
+ # save the rnd_reward statistics into tb_logger
+ self.estimate_cnt_rnd += 1
+ self.tb_logger.add_scalar('rnd_reward/rnd_reward_max', rnd_reward.max(), self.estimate_cnt_rnd)
+ self.tb_logger.add_scalar('rnd_reward/rnd_reward_mean', rnd_reward.mean(), self.estimate_cnt_rnd)
+ self.tb_logger.add_scalar('rnd_reward/rnd_reward_min', rnd_reward.min(), self.estimate_cnt_rnd)
+ self.tb_logger.add_scalar('rnd_reward/rnd_reward_std', rnd_reward.std(), self.estimate_cnt_rnd)
+
+ rnd_reward = rnd_reward.to(self.device)
+ rnd_reward = torch.chunk(rnd_reward, rnd_reward.shape[0], dim=0)
+ """
+        NOTE: The following normalization approach for the extrinsic reward does not seem reasonable,
+        because it compresses the extrinsic reward magnitude, resulting in less informative reward signals.
+ """
+ # rewards = torch.stack([data[i]['reward'] for i in range(len(data))])
+ # rewards = (rewards - torch.min(rewards)) / (torch.max(rewards) - torch.min(rewards))
+
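+        # How the intrinsic reward is combined below (summary of the branches that follow):
+        #   'add':    reward = extrinsic_reward / extrinsic_reward_norm_max + intrinsic_reward_weight * rnd_reward
+        #   'new':    keep the (optionally normalized) extrinsic reward and store rnd_reward as 'intrinsic_reward'
+        #   'assign': reward = rnd_reward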
+ for item, rnd_rew in zip(train_data_augmented, rnd_reward):
+ if self.intrinsic_reward_type == 'add':
+ if self.cfg.extrinsic_reward_norm:
+                    item['reward'] = item['reward'] / self.cfg.extrinsic_reward_norm_max + \
+                        rnd_rew * self.cfg.intrinsic_reward_weight
+ else:
+ item['reward'] = item['reward'] + rnd_rew * self.cfg.intrinsic_reward_weight
+ elif self.intrinsic_reward_type == 'new':
+ item['intrinsic_reward'] = rnd_rew
+ if self.cfg.extrinsic_reward_norm:
+ item['reward'] = item['reward'] / self.cfg.extrinsic_reward_norm_max
+ elif self.intrinsic_reward_type == 'assign':
+ item['reward'] = rnd_rew
+
+ # save the augmented_reward statistics into tb_logger
+ rew = [item['reward'].cpu().numpy() for item in train_data_augmented]
+ self.tb_logger.add_scalar('augmented_reward/reward_max', np.max(rew), self.estimate_cnt_rnd)
+ self.tb_logger.add_scalar('augmented_reward/reward_mean', np.mean(rew), self.estimate_cnt_rnd)
+ self.tb_logger.add_scalar('augmented_reward/reward_min', np.min(rew), self.estimate_cnt_rnd)
+ self.tb_logger.add_scalar('augmented_reward/reward_std', np.std(rew), self.estimate_cnt_rnd)
+ return train_data_augmented
+
+ def collect_data(self, data: list) -> None:
+ self.train_obs.extend(collect_states(data))
+
+ def clear_data(self) -> None:
+ self.train_obs.clear()
+
+ def state_dict(self) -> Dict:
+ return self.reward_model.state_dict()
+
+ def load_state_dict(self, _state_dict: Dict) -> None:
+ self.reward_model.load_state_dict(_state_dict)
diff --git a/DI-engine/ding/reward_model/tests/test_gail_irl_model.py b/DI-engine/ding/reward_model/tests/test_gail_irl_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..bac9d3d9503dce817be5ebcd9f4183c15a0e46d0
--- /dev/null
+++ b/DI-engine/ding/reward_model/tests/test_gail_irl_model.py
@@ -0,0 +1,104 @@
+import pytest
+import torch
+from easydict import EasyDict
+from ding.reward_model.gail_irl_model import GailRewardModel
+from ding.utils.data import offline_data_save_type
+from tensorboardX import SummaryWriter
+import os
+
+obs_space_1d, obs_space_3d = 4, [4, 84, 84]
+expert_data_path_1d, expert_data_path_3d = './expert_data_1d', './expert_data_3d'
+if not os.path.exists('./expert_data_1d'):
+ try:
+ os.mkdir('./expert_data_1d')
+ except FileExistsError:
+ pass
+if not os.path.exists('./expert_data_3d'):
+ try:
+ os.mkdir('./expert_data_3d')
+ except FileExistsError:
+ pass
+device = 'cpu'
+action_space = 3
+
+cfg1 = dict(
+ input_size=obs_space_1d + 1,
+ hidden_size=64,
+ batch_size=5,
+ learning_rate=1e-3,
+ update_per_collect=2,
+ data_path=expert_data_path_1d,
+),
+
+cfg2 = dict(
+ input_size=obs_space_3d,
+ hidden_size=64,
+ batch_size=5,
+ learning_rate=1e-3,
+ update_per_collect=2,
+ data_path=expert_data_path_3d,
+ action_size=action_space,
+),
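+
+# Note: the trailing commas after dict(...) above make cfg1 and cfg2 one-element tuples,
+# i.e. the iterable of parameter sets that pytest.mark.parametrize consumes below.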
+
+# create fake expert dataset
+data_1d = []
+for i in range(20):
+ d = {}
+ d['obs'] = torch.zeros(obs_space_1d)
+ d['action'] = torch.Tensor([1.])
+ data_1d.append(d)
+
+data_3d = []
+for i in range(20):
+ d = {}
+ d['obs'] = torch.zeros(obs_space_3d)
+ d['action'] = torch.Tensor([1.])
+ data_3d.append(d)
+
+
+@pytest.mark.parametrize('cfg', cfg1)
+@pytest.mark.unittest
+def test_dataset_1d(cfg):
+ offline_data_save_type(
+ exp_data=data_1d, expert_data_path=expert_data_path_1d + '/expert_data.pkl', data_type='naive'
+ )
+ data = data_1d
+ cfg = EasyDict(cfg)
+ policy = GailRewardModel(cfg, device, tb_logger=SummaryWriter())
+ policy.load_expert_data()
+ assert len(policy.expert_data) == 20
+ state = policy.state_dict()
+ policy.load_state_dict(state)
+ policy.collect_data(data)
+ assert len(policy.train_data) == 20
+ for _ in range(5):
+ policy.train()
+ train_data_augmented = policy.estimate(data)
+ assert 'reward' in train_data_augmented[0].keys()
+ policy.clear_data()
+ assert len(policy.train_data) == 0
+ os.popen('rm -rf {}'.format(expert_data_path_1d))
+
+
+@pytest.mark.parametrize('cfg', cfg2)
+@pytest.mark.unittest
+def test_dataset_3d(cfg):
+ offline_data_save_type(
+ exp_data=data_3d, expert_data_path=expert_data_path_3d + '/expert_data.pkl', data_type='naive'
+ )
+ data = data_3d
+ cfg = EasyDict(cfg)
+ policy = GailRewardModel(cfg, device, tb_logger=SummaryWriter())
+ policy.load_expert_data()
+ assert len(policy.expert_data) == 20
+ state = policy.state_dict()
+ policy.load_state_dict(state)
+ policy.collect_data(data)
+ assert len(policy.train_data) == 20
+ for _ in range(5):
+ policy.train()
+ train_data_augmented = policy.estimate(data)
+ assert 'reward' in train_data_augmented[0].keys()
+ policy.clear_data()
+ assert len(policy.train_data) == 0
+ os.popen('rm -rf {}'.format(expert_data_path_3d))
diff --git a/DI-engine/ding/reward_model/trex_reward_model.py b/DI-engine/ding/reward_model/trex_reward_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..635dc5e75e648a29582e767b47342f60e014f73e
--- /dev/null
+++ b/DI-engine/ding/reward_model/trex_reward_model.py
@@ -0,0 +1,440 @@
+from copy import deepcopy
+from typing import Tuple, Optional, List, Dict
+from easydict import EasyDict
+import pickle
+import os
+import numpy as np
+
+import torch
+import torch.nn as nn
+import torch.optim as optim
+
+from ding.utils import REWARD_MODEL_REGISTRY
+from ding.utils import SequenceType
+from ding.model.common import FCEncoder
+from ding.utils import build_logger
+from ding.utils.data import default_collate
+
+from .base_reward_model import BaseRewardModel
+from .rnd_reward_model import collect_states
+
+
+class TrexConvEncoder(nn.Module):
+ r"""
+ Overview:
+        The ``Convolution Encoder`` used in models. Used to encode raw 2-dim image observations.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(
+ self,
+ obs_shape: SequenceType,
+ hidden_size_list: SequenceType = [16, 16, 16, 16, 64, 1],
+ activation: Optional[nn.Module] = nn.LeakyReLU()
+ ) -> None:
+ r"""
+ Overview:
+            Init the Trex Convolution Encoder according to arguments. TrexConvEncoder differs \
+                from the ConvEncoder in ``model.common.encoder`` in its stride and kernel size \
+                parameters.
+ Arguments:
+            - obs_shape (:obj:`SequenceType`): Sequence of ``in_channel`` and spatial sizes, e.g. ``[C, H, W]``
+ - hidden_size_list (:obj:`SequenceType`): The collection of ``hidden_size``
+ - activation (:obj:`nn.Module`):
+ The type of activation to use in the conv ``layers``,
+                by default set to ``nn.LeakyReLU()``
+ """
+ super(TrexConvEncoder, self).__init__()
+ self.obs_shape = obs_shape
+ self.act = activation
+ self.hidden_size_list = hidden_size_list
+
+ layers = []
+ kernel_size = [7, 5, 3, 3]
+ stride = [3, 2, 1, 1]
+ input_size = obs_shape[0] # in_channel
+ for i in range(len(kernel_size)):
+ layers.append(nn.Conv2d(input_size, hidden_size_list[i], kernel_size[i], stride[i]))
+ layers.append(self.act)
+ input_size = hidden_size_list[i]
+ layers.append(nn.Flatten())
+ self.main = nn.Sequential(*layers)
+
+ flatten_size = self._get_flatten_size()
+ self.mid = nn.Sequential(
+ nn.Linear(flatten_size, hidden_size_list[-2]), self.act,
+ nn.Linear(hidden_size_list[-2], hidden_size_list[-1])
+ )
+
+ def _get_flatten_size(self) -> int:
+ r"""
+ Overview:
+ Get the encoding size after ``self.main`` to get the number of ``in-features`` to feed to ``nn.Linear``.
+        Returns:
+            - flatten_size (:obj:`int`): The flattened feature size after ``self.main``, obtained via a \
+                dummy forward pass; used as the number of in-features for the first ``nn.Linear``.
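+        Examples:
+            >>> # Illustrative sketch with a made-up obs_shape; a dummy forward pass probes the flatten size:
+            >>> encoder = TrexConvEncoder([4, 84, 84])
+            >>> assert isinstance(encoder._get_flatten_size(), int)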
+ """
+ test_data = torch.randn(1, *self.obs_shape)
+ with torch.no_grad():
+ output = self.main(test_data)
+ return output.shape[1]
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ r"""
+ Overview:
+ Return embedding tensor of the env observation
+ Arguments:
+ - x (:obj:`torch.Tensor`): Env raw observation
+ Returns:
+ - outputs (:obj:`torch.Tensor`): Embedding tensor
+ """
+ x = self.main(x)
+ x = self.mid(x)
+ return x
+
+
+class TrexModel(nn.Module):
+
+ def __init__(self, obs_shape):
+ super(TrexModel, self).__init__()
+ if isinstance(obs_shape, int) or len(obs_shape) == 1:
+ self.encoder = nn.Sequential(FCEncoder(obs_shape, [512, 64]), nn.Linear(64, 1))
+ # Conv Encoder
+ elif len(obs_shape) == 3:
+ self.encoder = TrexConvEncoder(obs_shape)
+ else:
+ raise KeyError(
+ "not support obs_shape for pre-defined encoder: {}, please customize your own Trex model".
+ format(obs_shape)
+ )
+
+ def cum_return(self, traj: torch.Tensor, mode: str = 'sum') -> Tuple[torch.Tensor, torch.Tensor]:
+ '''calculate cumulative return of trajectory'''
+ r = self.encoder(traj)
+ if mode == 'sum':
+ sum_rewards = torch.sum(r)
+ sum_abs_rewards = torch.sum(torch.abs(r))
+ return sum_rewards, sum_abs_rewards
+ elif mode == 'batch':
+ return r, torch.abs(r)
+ else:
+ raise KeyError("not support mode: {}, please choose mode=sum or mode=batch".format(mode))
+
+ def forward(self, traj_i: torch.Tensor, traj_j: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+ '''compute cumulative return for each trajectory and return logits'''
+ cum_r_i, abs_r_i = self.cum_return(traj_i)
+ cum_r_j, abs_r_j = self.cum_return(traj_j)
+ return torch.cat((cum_r_i.unsqueeze(0), cum_r_j.unsqueeze(0)), 0), abs_r_i + abs_r_j
+
+
+@REWARD_MODEL_REGISTRY.register('trex')
+class TrexRewardModel(BaseRewardModel):
+ """
+ Overview:
+ The Trex reward model class (https://arxiv.org/pdf/1904.06387.pdf)
+ Interface:
+        ``estimate``, ``train``, ``load_expert_data``, ``collect_data``, ``clear_data``, \
+        ``__init__``, ``_train``
+ Config:
+ == ==================== ====== ============= ============================================ =============
+ ID Symbol Type Default Value Description Other(Shape)
+ == ==================== ====== ============= ============================================ =============
+ 1 ``type`` str trex | Reward model register name, refer |
+ | to registry ``REWARD_MODEL_REGISTRY`` |
+ 3 | ``learning_rate`` float 0.00001 | learning rate for optimizer |
+ 4 | ``update_per_`` int 100 | Number of updates per collect |
+ | ``collect`` | |
+ 5 | ``num_trajs`` int 0 | Number of downsampled full trajectories |
+ 6 | ``num_snippets`` int 6000 | Number of short subtrajectories to sample |
+ == ==================== ====== ============= ============================================ =============
+ """
+ config = dict(
+ # (str) Reward model register name, refer to registry ``REWARD_MODEL_REGISTRY``.
+ type='trex',
+ # (float) The step size of gradient descent.
+ learning_rate=1e-5,
+ # (int) How many updates(iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=100,
+ # (int) Number of downsampled full trajectories.
+ num_trajs=0,
+ # (int) Number of short subtrajectories to sample.
+ num_snippets=6000,
+ )
+
+ def __init__(self, config: EasyDict, device: str, tb_logger: 'SummaryWriter') -> None: # noqa
+ """
+ Overview:
+ Initialize ``self.`` See ``help(type(self))`` for accurate signature.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Training config
+ - device (:obj:`str`): Device usage, i.e. "cpu" or "cuda"
+            - tb_logger (:obj:`SummaryWriter`): Logger, by default a 'SummaryWriter' used for model summaries
+ """
+ super(TrexRewardModel, self).__init__()
+ self.cfg = config
+ assert device in ["cpu", "cuda"] or "cuda" in device
+ self.device = device
+ self.tb_logger = tb_logger
+ self.reward_model = TrexModel(self.cfg.policy.model.obs_shape)
+ self.reward_model.to(self.device)
+ self.pre_expert_data = []
+ self.train_data = []
+ self.expert_data_loader = None
+ self.opt = optim.Adam(self.reward_model.parameters(), config.reward_model.learning_rate)
+ self.train_iter = 0
+ self.learning_returns = []
+ self.training_obs = []
+ self.training_labels = []
+ self.num_trajs = self.cfg.reward_model.num_trajs
+ self.num_snippets = self.cfg.reward_model.num_snippets
+        # minimum length of the short subtrajectories (snippets) to sample
+ self.min_snippet_length = config.reward_model.min_snippet_length
+        # maximum length of the short subtrajectories (snippets) to sample
+ self.max_snippet_length = config.reward_model.max_snippet_length
+ self.l1_reg = 0
+ self.data_for_save = {}
+ self._logger, self._tb_logger = build_logger(
+ path='./{}/log/{}'.format(self.cfg.exp_name, 'trex_reward_model'), name='trex_reward_model'
+ )
+ self.load_expert_data()
+
+ def load_expert_data(self) -> None:
+ """
+ Overview:
+ Getting the expert data.
+ Effects:
+            This is a side effect function which loads ``self.pre_expert_data`` and ``self.learning_returns`` \
+                from pickle files and then builds the training pairs via ``create_training_data``.
+ """
+ with open(os.path.join(self.cfg.exp_name, 'episodes_data.pkl'), 'rb') as f:
+ self.pre_expert_data = pickle.load(f)
+ with open(os.path.join(self.cfg.exp_name, 'learning_returns.pkl'), 'rb') as f:
+ self.learning_returns = pickle.load(f)
+
+ self.create_training_data()
+ self._logger.info("num_training_obs: {}".format(len(self.training_obs)))
+ self._logger.info("num_labels: {}".format(len(self.training_labels)))
+
+ def create_training_data(self):
+ num_trajs = self.num_trajs
+ num_snippets = self.num_snippets
+ min_snippet_length = self.min_snippet_length
+ max_snippet_length = self.max_snippet_length
+
+ demo_lengths = []
+ for i in range(len(self.pre_expert_data)):
+ demo_lengths.append([len(d) for d in self.pre_expert_data[i]])
+
+ self._logger.info("demo_lengths: {}".format(demo_lengths))
+ max_snippet_length = min(np.min(demo_lengths), max_snippet_length)
+ self._logger.info("min snippet length: {}".format(min_snippet_length))
+ self._logger.info("max snippet length: {}".format(max_snippet_length))
+
+ # collect training data
+ max_traj_length = 0
+ num_bins = len(self.pre_expert_data)
+ assert num_bins >= 2
+
+ # add full trajs (for use on Enduro)
+ si = np.random.randint(6, size=num_trajs)
+ sj = np.random.randint(6, size=num_trajs)
+ step = np.random.randint(3, 7, size=num_trajs)
+ for n in range(num_trajs):
+ # pick two random demonstrations
+ bi, bj = np.random.choice(num_bins, size=(2, ), replace=False)
+ ti = np.random.choice(len(self.pre_expert_data[bi]))
+ tj = np.random.choice(len(self.pre_expert_data[bj]))
+ # create random partial trajs by finding random start frame and random skip frame
+ traj_i = self.pre_expert_data[bi][ti][si[n]::step[n]] # slice(start,stop,step)
+ traj_j = self.pre_expert_data[bj][tj][sj[n]::step[n]]
+
+ label = int(bi <= bj)
+
+ self.training_obs.append((traj_i, traj_j))
+ self.training_labels.append(label)
+ max_traj_length = max(max_traj_length, len(traj_i), len(traj_j))
+
+ # fixed size snippets with progress prior
+ rand_length = np.random.randint(min_snippet_length, max_snippet_length, size=num_snippets)
+ for n in range(num_snippets):
+ # pick two random demonstrations
+ bi, bj = np.random.choice(num_bins, size=(2, ), replace=False)
+ ti = np.random.choice(len(self.pre_expert_data[bi]))
+ tj = np.random.choice(len(self.pre_expert_data[bj]))
+ # create random snippets
+            # find min length of both demos to ensure we can pick a snippet start no earlier
+            # than the one chosen in the less-preferred demo
+ min_length = min(len(self.pre_expert_data[bi][ti]), len(self.pre_expert_data[bj][tj]))
+ if bi < bj: # pick tj snippet to be later than ti
+ ti_start = np.random.randint(min_length - rand_length[n] + 1)
+ # print(ti_start, len(demonstrations[tj]))
+ tj_start = np.random.randint(ti_start, len(self.pre_expert_data[bj][tj]) - rand_length[n] + 1)
+ else: # ti is better so pick later snippet in ti
+ tj_start = np.random.randint(min_length - rand_length[n] + 1)
+ # print(tj_start, len(demonstrations[ti]))
+ ti_start = np.random.randint(tj_start, len(self.pre_expert_data[bi][ti]) - rand_length[n] + 1)
+            # skip every other framestack to reduce size
+ traj_i = self.pre_expert_data[bi][ti][ti_start:ti_start + rand_length[n]:2]
+ traj_j = self.pre_expert_data[bj][tj][tj_start:tj_start + rand_length[n]:2]
+
+ max_traj_length = max(max_traj_length, len(traj_i), len(traj_j))
+ label = int(bi <= bj)
+ self.training_obs.append((traj_i, traj_j))
+ self.training_labels.append(label)
+ self._logger.info(("maximum traj length: {}".format(max_traj_length)))
+ return self.training_obs, self.training_labels
+
+ def _train(self):
+ # check if gpu available
+ device = self.device # torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
+ # Assume that we are on a CUDA machine, then this should print a CUDA device:
+ self._logger.info("device: {}".format(device))
+ training_inputs, training_outputs = self.training_obs, self.training_labels
+ loss_criterion = nn.CrossEntropyLoss()
+
+ cum_loss = 0.0
+ training_data = list(zip(training_inputs, training_outputs))
+ for epoch in range(self.cfg.reward_model.update_per_collect): # todo
+ np.random.shuffle(training_data)
+ training_obs, training_labels = zip(*training_data)
+ for i in range(len(training_labels)):
+
+                # traj_i and traj_j have the same length; however, the length changes as i increases
+ traj_i, traj_j = training_obs[i] # traj_i is a list of array generated by env.step
+ traj_i = np.array(traj_i)
+ traj_j = np.array(traj_j)
+ traj_i = torch.from_numpy(traj_i).float().to(device)
+ traj_j = torch.from_numpy(traj_j).float().to(device)
+
+ # training_labels[i] is a boolean integer: 0 or 1
+ labels = torch.tensor([training_labels[i]]).to(device)
+
+                # forward + zero out gradient + backward + optimize
+ outputs, abs_rewards = self.reward_model.forward(traj_i, traj_j)
+ outputs = outputs.unsqueeze(0)
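+                # T-REX preference loss: cross-entropy over the pair of predicted returns (a Bradley-Terry
+                # style ranking loss), plus an L1 penalty on the reward magnitude (self.l1_reg is 0 here).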
+ loss = loss_criterion(outputs, labels) + self.l1_reg * abs_rewards
+ self.opt.zero_grad()
+ loss.backward()
+ self.opt.step()
+
+ # print stats to see if learning
+ item_loss = loss.item()
+ cum_loss += item_loss
+ if i % 100 == 99:
+ self._logger.info("[epoch {}:{}] loss {}".format(epoch, i, cum_loss))
+ self._logger.info("abs_returns: {}".format(abs_rewards))
+ cum_loss = 0.0
+ self._logger.info("check pointing")
+ if not os.path.exists(os.path.join(self.cfg.exp_name, 'ckpt_reward_model')):
+ os.makedirs(os.path.join(self.cfg.exp_name, 'ckpt_reward_model'))
+ torch.save(self.reward_model.state_dict(), os.path.join(self.cfg.exp_name, 'ckpt_reward_model/latest.pth.tar'))
+ self._logger.info("finished training")
+
+ def train(self):
+ self._train()
+ # print out predicted cumulative returns and actual returns
+ sorted_returns = sorted(self.learning_returns, key=lambda s: s[0])
+ demonstrations = [
+ x for _, x in sorted(zip(self.learning_returns, self.pre_expert_data), key=lambda pair: pair[0][0])
+ ]
+ with torch.no_grad():
+ pred_returns = [self.predict_traj_return(self.reward_model, traj[0]) for traj in demonstrations]
+ for i, p in enumerate(pred_returns):
+ self._logger.info("{} {} {}".format(i, p, sorted_returns[i][0]))
+ info = {
+ "demo_length": [len(d[0]) for d in self.pre_expert_data],
+ "min_snippet_length": self.min_snippet_length,
+ "max_snippet_length": min(np.min([len(d[0]) for d in self.pre_expert_data]), self.max_snippet_length),
+ "len_num_training_obs": len(self.training_obs),
+ "lem_num_labels": len(self.training_labels),
+ "accuracy": self.calc_accuracy(self.reward_model, self.training_obs, self.training_labels),
+ }
+ self._logger.info(
+ "accuracy and comparison:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))
+ )
+
+ def predict_traj_return(self, net, traj):
+ device = self.device
+ # torch.set_printoptions(precision=20)
+ # torch.use_deterministic_algorithms(True)
+ with torch.no_grad():
+ rewards_from_obs = net.cum_return(
+ torch.from_numpy(np.array(traj)).float().to(device), mode='batch'
+ )[0].squeeze().tolist()
+ # rewards_from_obs1 = net.cum_return(torch.from_numpy(np.array([traj[0]])).float().to(device))[0].item()
+ # different precision
+ return sum(rewards_from_obs) # rewards_from_obs is a list of floats
+
+ def calc_accuracy(self, reward_network, training_inputs, training_outputs):
+ device = self.device
+ loss_criterion = nn.CrossEntropyLoss()
+ num_correct = 0.
+ with torch.no_grad():
+ for i in range(len(training_inputs)):
+ label = training_outputs[i]
+ traj_i, traj_j = training_inputs[i]
+ traj_i = np.array(traj_i)
+ traj_j = np.array(traj_j)
+ traj_i = torch.from_numpy(traj_i).float().to(device)
+ traj_j = torch.from_numpy(traj_j).float().to(device)
+
+                # forward to get logits
+ outputs, abs_return = reward_network.forward(traj_i, traj_j)
+ _, pred_label = torch.max(outputs, 0)
+ if pred_label.item() == label:
+ num_correct += 1.
+ return num_correct / len(training_inputs)
+
+ def pred_data(self, data):
+ obs = [default_collate(data[i])['obs'] for i in range(len(data))]
+ res = [torch.sum(default_collate(data[i])['reward']).item() for i in range(len(data))]
+ pred_returns = [self.predict_traj_return(self.reward_model, obs[i]) for i in range(len(obs))]
+ return {'real': res, 'pred': pred_returns}
+
+ def estimate(self, data: list) -> List[Dict]:
+ """
+ Overview:
+ Estimate reward by rewriting the reward key in each row of the data.
+ Arguments:
+ - data (:obj:`list`): the list of data used for estimation, with at least \
+ ``obs`` and ``action`` keys.
+ Effects:
+ - This is a side effect function which updates the reward values in place.
+ """
+ # NOTE: deepcopy reward part of data is very important,
+ # otherwise the reward of data in the replay buffer will be incorrectly modified.
+ train_data_augmented = self.reward_deepcopy(data)
+
+ res = collect_states(train_data_augmented)
+ res = torch.stack(res).to(self.device)
+ with torch.no_grad():
+ sum_rewards, sum_abs_rewards = self.reward_model.cum_return(res, mode='batch')
+
+ for item, rew in zip(train_data_augmented, sum_rewards): # TODO optimise this loop as well ?
+ item['reward'] = rew
+
+ return train_data_augmented
+
+ def collect_data(self, data: list) -> None:
+ """
+ Overview:
+            Collecting training data; this is a no-op for TREX, since the expert demonstrations are \
+                loaded once in ``load_expert_data``.
+        Arguments:
+            - data (:obj:`Any`): Raw training data (e.g. some form of states, actions, obs, etc); unused here.
+ """
+ pass
+
+ def clear_data(self) -> None:
+ """
+ Overview:
+ Clearing training data. \
+ This is a side effect function which clears the data attribute in ``self``
+ """
+ self.training_obs.clear()
+ self.training_labels.clear()
diff --git a/DI-engine/ding/rl_utils/__init__.py b/DI-engine/ding/rl_utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e86f6c17869f36539db6587a5517eb18cb1bf92e
--- /dev/null
+++ b/DI-engine/ding/rl_utils/__init__.py
@@ -0,0 +1,27 @@
+from .exploration import get_epsilon_greedy_fn, create_noise_generator
+from .ppo import ppo_data, ppo_loss, ppo_info, ppo_policy_data, ppo_policy_error, ppo_value_data, ppo_value_error, \
+ ppo_error, ppo_error_continuous, ppo_policy_error_continuous, ppo_data_continuous, ppo_policy_data_continuous
+from .happo import happo_data, happo_policy_data, happo_value_data, happo_loss, happo_policy_loss, happo_info, \
+ happo_error, happo_policy_error, happo_value_error, happo_error_continuous, happo_policy_error_continuous
+from .ppg import ppg_data, ppg_joint_loss, ppg_joint_error
+from .gae import gae_data, gae
+from .a2c import a2c_data, a2c_error, a2c_error_continuous
+from .coma import coma_data, coma_error
+from .td import q_nstep_td_data, q_nstep_td_error, q_1step_td_data, \
+ q_1step_td_error, m_q_1step_td_data, m_q_1step_td_error, td_lambda_data, td_lambda_error, \
+ q_nstep_td_error_with_rescale, v_1step_td_data, v_1step_td_error, v_nstep_td_data, v_nstep_td_error, \
+ generalized_lambda_returns, dist_1step_td_data, dist_1step_td_error, dist_nstep_td_error, dist_nstep_td_data, \
+ nstep_return_data, nstep_return, iqn_nstep_td_data, iqn_nstep_td_error, qrdqn_nstep_td_data, qrdqn_nstep_td_error, \
+ fqf_nstep_td_data, fqf_nstep_td_error, fqf_calculate_fraction_loss, evaluate_quantile_at_action, \
+ q_nstep_sql_td_error, dqfd_nstep_td_error, dqfd_nstep_td_data, q_v_1step_td_error, q_v_1step_td_data, \
+ dqfd_nstep_td_error_with_rescale, discount_cumsum, bdq_nstep_td_error
+from .vtrace import vtrace_loss, compute_importance_weights
+from .upgo import upgo_loss
+from .adder import get_gae, get_gae_with_default_last_value, get_nstep_return_data, get_train_sample
+from .value_rescale import value_transform, value_inv_transform, symlog, inv_symlog
+from .vtrace import vtrace_data, vtrace_error_discrete_action, vtrace_error_continuous_action
+from .beta_function import beta_function_map
+from .retrace import compute_q_retraces
+from .acer import acer_policy_error, acer_value_error, acer_trust_region_update
+from .sampler import ArgmaxSampler, MultinomialSampler, MuSampler, ReparameterizationSampler, HybridStochasticSampler, \
+ HybridDeterminsticSampler
diff --git a/DI-engine/ding/rl_utils/a2c.py b/DI-engine/ding/rl_utils/a2c.py
new file mode 100644
index 0000000000000000000000000000000000000000..09cb199553e41e0b8d3c94cf4e9717f39be5e09b
--- /dev/null
+++ b/DI-engine/ding/rl_utils/a2c.py
@@ -0,0 +1,88 @@
+from collections import namedtuple
+import torch
+import torch.nn.functional as F
+from torch.distributions import Independent, Normal
+
+a2c_data = namedtuple('a2c_data', ['logit', 'action', 'value', 'adv', 'return_', 'weight'])
+a2c_loss = namedtuple('a2c_loss', ['policy_loss', 'value_loss', 'entropy_loss'])
+
+
+def a2c_error(data: namedtuple) -> namedtuple:
+ """
+ Overview:
+ Implementation of A2C(Advantage Actor-Critic) (arXiv:1602.01783) for discrete action space
+ Arguments:
+        - data (:obj:`namedtuple`): a2c input data with fields shown in ``a2c_data``
+ Returns:
+ - a2c_loss (:obj:`namedtuple`): the a2c loss item, all of them are the differentiable 0-dim tensor
+ Shapes:
+ - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
+ - action (:obj:`torch.LongTensor`): :math:`(B, )`
+ - value (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - return (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
+ - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
+ - value_loss (:obj:`torch.FloatTensor`): :math:`()`
+ - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
+ Examples:
+ >>> data = a2c_data(
+ >>> logit=torch.randn(2, 3),
+ >>> action=torch.randint(0, 3, (2, )),
+ >>> value=torch.randn(2, ),
+ >>> adv=torch.randn(2, ),
+ >>> return_=torch.randn(2, ),
+ >>> weight=torch.ones(2, ),
+ >>> )
+ >>> loss = a2c_error(data)
+ """
+ logit, action, value, adv, return_, weight = data
+ if weight is None:
+ weight = torch.ones_like(value)
+ dist = torch.distributions.categorical.Categorical(logits=logit)
+ logp = dist.log_prob(action)
+ entropy_loss = (dist.entropy() * weight).mean()
+ policy_loss = -(logp * adv * weight).mean()
+ value_loss = (F.mse_loss(return_, value, reduction='none') * weight).mean()
+ return a2c_loss(policy_loss, value_loss, entropy_loss)
+
+
+def a2c_error_continuous(data: namedtuple) -> namedtuple:
+ """
+ Overview:
+ Implementation of A2C(Advantage Actor-Critic) (arXiv:1602.01783) for continuous action space
+ Arguments:
+        - data (:obj:`namedtuple`): a2c input data with fields shown in ``a2c_data``
+ Returns:
+ - a2c_loss (:obj:`namedtuple`): the a2c loss item, all of them are the differentiable 0-dim tensor
+ Shapes:
+ - logit (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
+ - action (:obj:`torch.LongTensor`): :math:`(B, N)`
+ - value (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - return (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
+ - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
+ - value_loss (:obj:`torch.FloatTensor`): :math:`()`
+ - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
+ Examples:
+ >>> data = a2c_data(
+ >>> logit={'mu': torch.randn(2, 3), 'sigma': torch.sqrt(torch.randn(2, 3)**2)},
+ >>> action=torch.randn(2, 3),
+ >>> value=torch.randn(2, ),
+ >>> adv=torch.randn(2, ),
+ >>> return_=torch.randn(2, ),
+ >>> weight=torch.ones(2, ),
+ >>> )
+ >>> loss = a2c_error_continuous(data)
+ """
+ logit, action, value, adv, return_, weight = data
+ if weight is None:
+ weight = torch.ones_like(value)
+
+ dist = Independent(Normal(logit['mu'], logit['sigma']), 1)
+ logp = dist.log_prob(action)
+ entropy_loss = (dist.entropy() * weight).mean()
+ policy_loss = -(logp * adv * weight).mean()
+ value_loss = (F.mse_loss(return_, value, reduction='none') * weight).mean()
+ return a2c_loss(policy_loss, value_loss, entropy_loss)
diff --git a/DI-engine/ding/rl_utils/acer.py b/DI-engine/ding/rl_utils/acer.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba83fc93934c13dd5675462d31e74bb02c95e2bf
--- /dev/null
+++ b/DI-engine/ding/rl_utils/acer.py
@@ -0,0 +1,124 @@
+from typing import Tuple, List
+from collections import namedtuple
+import torch
+import torch.nn.functional as F
+EPS = 1e-8
+
+
+def acer_policy_error(
+ q_values: torch.Tensor,
+ q_retraces: torch.Tensor,
+ v_pred: torch.Tensor,
+ target_logit: torch.Tensor,
+ actions: torch.Tensor,
+ ratio: torch.Tensor,
+ c_clip_ratio: float = 10.0
+) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Overview:
+ Get ACER policy loss.
+ Arguments:
+ - q_values (:obj:`torch.Tensor`): Q values
+        - q_retraces (:obj:`torch.Tensor`): Q values calculated by the retrace method
+        - v_pred (:obj:`torch.Tensor`): V values
+        - target_logit (:obj:`torch.Tensor`): The log-probabilities of the new (target) policy
+        - actions (:obj:`torch.Tensor`): The actions in the replay buffer
+        - ratio (:obj:`torch.Tensor`): ratio of the new policy to the behavior policy
+ - c_clip_ratio (:obj:`float`): clip value for ratio
+ Returns:
+ - actor_loss (:obj:`torch.Tensor`): policy loss from q_retrace
+        - bc_loss (:obj:`torch.Tensor`): bias-correction policy loss
+ Shapes:
+ - q_values (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where B is batch size and N is action dim
+ - q_retraces (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
+ - v_pred (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
+        - target_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
+ - actions (:obj:`torch.LongTensor`): :math:`(T, B)`
+ - ratio (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
+ - actor_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
+ - bc_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
+ Examples:
+ >>> q_values=torch.randn(2, 3, 4),
+ >>> q_retraces=torch.randn(2, 3, 1),
+ >>> v_pred=torch.randn(2, 3, 1),
+ >>> target_pi=torch.randn(2, 3, 4),
+ >>> actions=torch.randint(0, 4, (2, 3)),
+ >>> ratio=torch.randn(2, 3, 4),
+ >>> loss = acer_policy_error(q_values, q_retraces, v_pred, target_pi, actions, ratio)
+ """
+ actions = actions.unsqueeze(-1)
+ with torch.no_grad():
+ advantage_retraces = q_retraces - v_pred # shape T,B,1
+ advantage_native = q_values - v_pred # shape T,B,env_action_shape
+ actor_loss = ratio.gather(-1, actions).clamp(max=c_clip_ratio) * advantage_retraces * target_logit.gather(
+ -1, actions
+ ) # shape T,B,1
+
+ # bias correction term, the first target_pi will not calculate gradient flow
+ bias_correction_loss = (1.0-c_clip_ratio/(ratio+EPS)).clamp(min=0.0)*torch.exp(target_logit).detach() * \
+ advantage_native*target_logit # shape T,B,env_action_shape
+ bias_correction_loss = bias_correction_loss.sum(-1, keepdim=True)
+ return actor_loss, bias_correction_loss
+
+
+def acer_value_error(q_values, q_retraces, actions):
+ """
+ Overview:
+ Get ACER critic loss.
+ Arguments:
+ - q_values (:obj:`torch.Tensor`): Q values
+        - q_retraces (:obj:`torch.Tensor`): Q values calculated by the retrace method
+        - actions (:obj:`torch.Tensor`): The actions in the replay buffer
+ Returns:
+ - critic_loss (:obj:`torch.Tensor`): critic loss
+ Shapes:
+ - q_values (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where B is batch size and N is action dim
+ - q_retraces (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
+ - actions (:obj:`torch.LongTensor`): :math:`(T, B)`
+ - critic_loss (:obj:`torch.FloatTensor`): :math:`(T, B, 1)`
+ Examples:
+ >>> q_values=torch.randn(2, 3, 4)
+ >>> q_retraces=torch.randn(2, 3, 1)
+ >>> actions=torch.randint(0, 4, (2, 3))
+ >>> loss = acer_value_error(q_values, q_retraces, actions)
+ """
+ actions = actions.unsqueeze(-1)
+ critic_loss = 0.5 * (q_retraces - q_values.gather(-1, actions)).pow(2)
+ return critic_loss
+
+
+def acer_trust_region_update(
+ actor_gradients: List[torch.Tensor], target_logit: torch.Tensor, avg_logit: torch.Tensor,
+ trust_region_value: float
+) -> List[torch.Tensor]:
+ """
+ Overview:
+        Calculate gradients with a trust region constraint.
+ Arguments:
+        - actor_gradients (:obj:`list(torch.Tensor)`): gradient values for the different parts
+        - target_logit (:obj:`torch.Tensor`): The log-probabilities of the new (target) policy
+        - avg_logit (:obj:`torch.Tensor`): The log-probabilities of the average policy
+ - trust_region_value (:obj:`float`): the range of trust region
+ Returns:
+ - update_gradients (:obj:`list(torch.Tensor)`): gradients with trust region constraint
+ Shapes:
+        - target_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
+        - avg_logit (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
+ - update_gradients (:obj:`list(torch.FloatTensor)`): :math:`(T, B, N)`
+ Examples:
+ >>> actor_gradients=[torch.randn(2, 3, 4)]
+ >>> target_pi=torch.randn(2, 3, 4)
+ >>> avg_pi=torch.randn(2, 3, 4)
+ >>> loss = acer_trust_region_update(actor_gradients, target_pi, avg_pi, 0.1)
+ """
+ with torch.no_grad():
+ KL_gradients = [torch.exp(avg_logit)]
+ update_gradients = []
+    # TODO: there is only one element in this list; more elements may be used in the future
+ actor_gradient = actor_gradients[0]
+ KL_gradient = KL_gradients[0]
+ scale = actor_gradient.mul(KL_gradient).sum(-1, keepdim=True) - trust_region_value
+ scale = torch.div(scale, KL_gradient.mul(KL_gradient).sum(-1, keepdim=True)).clamp(min=0.0)
+ update_gradients.append(actor_gradient - scale * KL_gradient)
+ return update_gradients
diff --git a/DI-engine/ding/rl_utils/adder.py b/DI-engine/ding/rl_utils/adder.py
new file mode 100644
index 0000000000000000000000000000000000000000..26b431b870557d254530406d3ce0429e226419a1
--- /dev/null
+++ b/DI-engine/ding/rl_utils/adder.py
@@ -0,0 +1,240 @@
+from typing import List, Dict, Any, Optional
+from collections import deque
+import copy
+import torch
+
+from ding.utils import list_split, lists_to_dicts
+from ding.rl_utils.gae import gae, gae_data
+
+
+class Adder(object):
+ """
+ Overview:
+ Adder is a component that handles different transformations and calculations for transitions
+        in the Collector Module (data generation and processing), such as GAE, n-step return, transition sampling, etc.
+ Interface:
+ __init__, get_gae, get_gae_with_default_last_value, get_nstep_return_data, get_train_sample
+ """
+
+ @classmethod
+ def get_gae(cls, data: List[Dict[str, Any]], last_value: torch.Tensor, gamma: float, gae_lambda: float,
+ cuda: bool) -> List[Dict[str, Any]]:
+ """
+ Overview:
+ Get GAE advantage for stacked transitions(T timestep, 1 batch). Call ``gae`` for calculation.
+ Arguments:
+ - data (:obj:`list`): Transitions list, each element is a transition dict with at least ['value', 'reward']
+ - last_value (:obj:`torch.Tensor`): The last value(i.e.: the T+1 timestep)
+ - gamma (:obj:`float`): The future discount factor, should be in [0, 1], defaults to 0.99.
+ - gae_lambda (:obj:`float`): GAE lambda parameter, should be in [0, 1], defaults to 0.97, \
+ when lambda -> 0, it induces bias, but when lambda -> 1, it has high variance due to the sum of terms.
+ - cuda (:obj:`bool`): Whether use cuda in GAE computation
+ Returns:
+ - data (:obj:`list`): transitions list like input one, but each element owns extra advantage key 'adv'
+ Examples:
+ >>> B, T = 2, 3 # batch_size, timestep
+ >>> data = [dict(value=torch.randn(B), reward=torch.randn(B)) for _ in range(T)]
+ >>> last_value = torch.randn(B)
+ >>> gamma = 0.99
+ >>> gae_lambda = 0.95
+ >>> cuda = False
+ >>> data = Adder.get_gae(data, last_value, gamma, gae_lambda, cuda)
+ """
+ value = torch.stack([d['value'] for d in data])
+ next_value = torch.stack([d['value'] for d in data][1:] + [last_value])
+ reward = torch.stack([d['reward'] for d in data])
+ if cuda:
+ value = value.cuda()
+ next_value = next_value.cuda()
+ reward = reward.cuda()
+
+ adv = gae(gae_data(value, next_value, reward, None, None), gamma, gae_lambda)
+
+ if cuda:
+ adv = adv.cpu()
+ for i in range(len(data)):
+ data[i]['adv'] = adv[i]
+ return data
+
+ @classmethod
+ def get_gae_with_default_last_value(cls, data: deque, done: bool, gamma: float, gae_lambda: float,
+ cuda: bool) -> List[Dict[str, Any]]:
+ """
+ Overview:
+ Like ``get_gae`` above to get GAE advantage for stacked transitions. However, this function is designed in
+            case ``last_value`` is not passed. If the transition is not done yet, it would use the last value in
+            ``data`` as ``last_value``, discard the last element in ``data`` (i.e. len(data) would decrease by 1),
+            and then call ``get_gae``. Otherwise it would set ``last_value`` to 0.
+ Arguments:
+ - data (:obj:`deque`): Transitions list, each element is a transition dict with \
+ at least['value', 'reward']
+ - done (:obj:`bool`): Whether the transition reaches the end of an episode(i.e. whether the env is done)
+ - gamma (:obj:`float`): The future discount factor, should be in [0, 1], defaults to 0.99.
+ - gae_lambda (:obj:`float`): GAE lambda parameter, should be in [0, 1], defaults to 0.97, \
+ when lambda -> 0, it induces bias, but when lambda -> 1, it has high variance due to the sum of terms.
+ - cuda (:obj:`bool`): Whether use cuda in GAE computation
+ Returns:
+ - data (:obj:`List[Dict[str, Any]]`): transitions list like input one, but each element owns \
+ extra advantage key 'adv'
+ Examples:
+ >>> B, T = 2, 3 # batch_size, timestep
+ >>> data = [dict(value=torch.randn(B), reward=torch.randn(B)) for _ in range(T)]
+ >>> done = False
+ >>> gamma = 0.99
+ >>> gae_lambda = 0.95
+ >>> cuda = False
+ >>> data = Adder.get_gae_with_default_last_value(data, done, gamma, gae_lambda, cuda)
+ """
+ if done:
+ last_value = torch.zeros_like(data[-1]['value'])
+ else:
+ last_data = data.pop()
+ last_value = last_data['value']
+ return cls.get_gae(data, last_value, gamma, gae_lambda, cuda)
+
+ @classmethod
+ def get_nstep_return_data(
+ cls,
+ data: deque,
+ nstep: int,
+ cum_reward=False,
+ correct_terminate_gamma=True,
+ gamma=0.99,
+ ) -> deque:
+ """
+ Overview:
+ Process raw trajectory data by updating keys ['next_obs', 'reward', 'done'] in each dict element of ``data``.
+ Arguments:
+ - data (:obj:`deque`): Transitions list, each element is a transition dict
+ - nstep (:obj:`int`): Number of steps. If it equals 1, return ``data`` directly; \
+ otherwise update each element with its n-step value.
+ - cum_reward (:obj:`bool`): Whether to replace 'reward' with the cumulative discounted n-step reward; \
+ if False, the raw rewards of the next ``nstep`` steps are concatenated instead.
+ - correct_terminate_gamma (:obj:`bool`): Whether to add a 'value_gamma' key, i.e. the discount applied to \
+ the bootstrap value (``gamma ** nstep``, reduced near the end of the trajectory where fewer steps remain).
+ - gamma (:obj:`float`): The future discount factor, should be in [0, 1], defaults to 0.99.
+ Returns:
+ - data (:obj:`deque`): Transitions list like the input one, but each element updated with its n-step value.
+ Examples:
+ >>> data = [dict(
+ >>> obs=torch.randn(B),
+ >>> reward=torch.randn(1),
+ >>> next_obs=torch.randn(B),
+ >>> done=False) for _ in range(T)]
+ >>> nstep = 2
+ >>> data = Adder.get_nstep_return_data(data, nstep)
+ """
+ if nstep == 1:
+ return data
+ fake_reward = torch.zeros(1)
+ next_obs_flag = 'next_obs' in data[0]
+ for i in range(len(data) - nstep):
+ # update keys ['next_obs', 'reward', 'done'] with their n-step value
+ if next_obs_flag:
+ data[i]['next_obs'] = data[i + nstep]['obs'] # do not need deepcopy
+ if cum_reward:
+ data[i]['reward'] = sum([data[i + j]['reward'] * (gamma ** j) for j in range(nstep)])
+ else:
+ data[i]['reward'] = torch.cat([data[i + j]['reward'] for j in range(nstep)])
+ data[i]['done'] = data[i + nstep - 1]['done']
+ if correct_terminate_gamma:
+ data[i]['value_gamma'] = gamma ** nstep
+ for i in range(max(0, len(data) - nstep), len(data)):
+ if next_obs_flag:
+ data[i]['next_obs'] = data[-1]['next_obs'] # do not need deepcopy
+ if cum_reward:
+ data[i]['reward'] = sum([data[i + j]['reward'] * (gamma ** j) for j in range(len(data) - i)])
+ else:
+ data[i]['reward'] = torch.cat(
+ [data[i + j]['reward']
+ for j in range(len(data) - i)] + [fake_reward for _ in range(nstep - (len(data) - i))]
+ )
+ data[i]['done'] = data[-1]['done']
+ if correct_terminate_gamma:
+ data[i]['value_gamma'] = gamma ** (len(data) - i - 1)
+ return data
+
+ @classmethod
+ def get_train_sample(
+ cls,
+ data: List[Dict[str, Any]],
+ unroll_len: int,
+ last_fn_type: str = 'last',
+ null_transition: Optional[dict] = None
+ ) -> List[Dict[str, Any]]:
+ """
+ Overview:
+ Process raw trajectory data into training samples whose unroll length is ``unroll_len``.
+ If ``unroll_len`` equals 1, no processing is needed and ``data`` is returned directly.
+ Otherwise, ``data`` will be split according to ``unroll_len``, the residual part is processed according to
+ ``last_fn_type``, and ``lists_to_dicts`` is called to form the sampled training data.
+ Arguments:
+ - data (:obj:`List[Dict[str, Any]]`): Transitions list, each element is a transition dict
+ - unroll_len (:obj:`int`): Learn training unroll length
+ - last_fn_type (:obj:`str`): The method type name for dealing with last residual data in a traj \
+ after splitting, should be in ['last', 'drop', 'null_padding']
+ - null_transition (:obj:`Optional[dict]`): Dict type null transition, used in ``null_padding``
+ Returns:
+ - data (:obj:`List[Dict[str, Any]]`): Transitions list processed after unrolling
+ """
+ if unroll_len == 1:
+ return data
+ else:
+ # cut data into pieces whose length is unroll_len
+ split_data, residual = list_split(data, step=unroll_len)
+
+ def null_padding():
+ template = copy.deepcopy(residual[0])
+ template['null'] = True
+ if isinstance(template['obs'], dict):
+ template['obs'] = {k: torch.zeros_like(v) for k, v in template['obs'].items()}
+ else:
+ template['obs'] = torch.zeros_like(template['obs'])
+ if 'action' in template:
+ template['action'] = torch.zeros_like(template['action'])
+ template['done'] = True
+ template['reward'] = torch.zeros_like(template['reward'])
+ if 'value_gamma' in template:
+ template['value_gamma'] = 0.
+ null_data = [cls._get_null_transition(template, null_transition) for _ in range(miss_num)]
+ return null_data
+
+ if residual is not None:
+ miss_num = unroll_len - len(residual)
+ if last_fn_type == 'drop':
+ # drop the residual part
+ pass
+ elif last_fn_type == 'last':
+ if len(split_data) > 0:
+ # copy the last ``miss_num`` transitions from split_data's last element and insert them in front of residual
+ last_data = copy.deepcopy(split_data[-1][-miss_num:])
+ split_data.append(last_data + residual)
+ else:
+ # get null transitions using ``null_padding``, and insert behind residual
+ null_data = null_padding()
+ split_data.append(residual + null_data)
+ elif last_fn_type == 'null_padding':
+ # same as the 'last' type when split_data is empty
+ null_data = null_padding()
+ split_data.append(residual + null_data)
+ # collate unroll_len dicts according to keys
+ if len(split_data) > 0:
+ split_data = [lists_to_dicts(d, recursive=True) for d in split_data]
+ return split_data
+
+ @classmethod
+ def _get_null_transition(cls, template: dict, null_transition: Optional[dict] = None) -> dict:
+ """
+ Overview:
+ Get a null transition for padding. If ``null_transition`` is None, return a deepcopy of the input ``template`` instead.
+ Arguments:
+ - template (:obj:`dict`): The template for null transition.
+ - null_transition (:obj:`Optional[dict]`): Dict type null transition, used in ``null_padding``
+ Returns:
+ - null_transition (:obj:`dict`): The deepcopied null transition.
+ """
+ if null_transition is not None:
+ return copy.deepcopy(null_transition)
+ else:
+ return copy.deepcopy(template)
+
+
+get_gae = Adder.get_gae
+get_gae_with_default_last_value = Adder.get_gae_with_default_last_value
+get_nstep_return_data = Adder.get_nstep_return_data
+get_train_sample = Adder.get_train_sample
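+
+
+# A minimal illustrative sketch of calling ``Adder.get_train_sample`` (whose docstring has no Examples
+# section). The transition keys and shapes below are assumptions for demonstration only.
+if __name__ == '__main__':
+ fake_traj = [
+ dict(obs=torch.randn(4), action=torch.tensor(0), reward=torch.zeros(1), done=False) for _ in range(7)
+ ]
+ samples = Adder.get_train_sample(copy.deepcopy(fake_traj), unroll_len=3, last_fn_type='last')
+ print(len(samples), list(samples[0].keys()))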
diff --git a/DI-engine/ding/rl_utils/beta_function.py b/DI-engine/ding/rl_utils/beta_function.py
new file mode 100644
index 0000000000000000000000000000000000000000..4096228321984206ed5ad81789ed7224042a4514
--- /dev/null
+++ b/DI-engine/ding/rl_utils/beta_function.py
@@ -0,0 +1,40 @@
+"""
+Referenced papar
+"""
+import torch
+from typing import Union
+
+beta_function_map = {}
+
+beta_function_map['uniform'] = lambda x: x
+
+# For beta functions, concavity corresponds to risk-averse and convexity to risk-seeking policies
+
+
+# For CPW, eta = 0.71 most closely matches human subjects.
+# This function is locally concave for small values of τ and becomes locally convex for larger values of τ.
+def cpw(x: Union[torch.Tensor, float], eta: float = 0.71) -> Union[torch.Tensor, float]:
+ return (x ** eta) / ((x ** eta + (1 - x) ** eta) ** (1 / eta))
+
+
+beta_function_map['CPW'] = cpw
+
+
+# CVaR is risk-averse
+def CVaR(x: Union[torch.Tensor, float], eta: float = 0.71) -> Union[torch.Tensor, float]:
+ assert eta <= 1.0
+ return x * eta
+
+
+beta_function_map['CVaR'] = CVaR
+
+
+# risk-averse (eta < 0) or risk-seeking (eta > 0)
+def Pow(x: Union[torch.Tensor, float], eta: float = 0.0) -> Union[torch.Tensor, float]:
+ if eta >= 0:
+ return x ** (1 / (1 + eta))
+ else:
+ return 1 - (1 - x) ** (1 / (1 - eta))
+
+
+beta_function_map['Pow'] = Pow
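+
+
+# A brief illustrative sketch of querying the registered beta (risk distortion) functions; the quantile
+# values below are arbitrary demonstration inputs, not values used anywhere in the library.
+if __name__ == '__main__':
+ tau = torch.linspace(0.1, 0.9, 5)
+ for name, fn in beta_function_map.items():
+ print(name, fn(tau))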
diff --git a/DI-engine/ding/rl_utils/coma.py b/DI-engine/ding/rl_utils/coma.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ee1778293e547efa5e1543b31fc687f903a8485
--- /dev/null
+++ b/DI-engine/ding/rl_utils/coma.py
@@ -0,0 +1,60 @@
+from collections import namedtuple
+import torch
+import torch.nn.functional as F
+from ding.rl_utils.td import generalized_lambda_returns
+
+coma_data = namedtuple('coma_data', ['logit', 'action', 'q_value', 'target_q_value', 'reward', 'weight'])
+coma_loss = namedtuple('coma_loss', ['policy_loss', 'q_value_loss', 'entropy_loss'])
+
+
+def coma_error(data: namedtuple, gamma: float, lambda_: float) -> namedtuple:
+ """
+ Overview:
+ Implementation of COMA (Counterfactual Multi-Agent Policy Gradients)
+ Arguments:
+ - data (:obj:`namedtuple`): coma input data with fields shown in ``coma_data``
+ Returns:
+ - coma_loss (:obj:`namedtuple`): the coma loss item, all of them are the differentiable 0-dim tensor
+ Shapes:
+ - logit (:obj:`torch.FloatTensor`): :math:`(T, B, A, N)`, where T is timestep, B is batch size, A is the agent num, and N is \
+ action dim
+ - action (:obj:`torch.LongTensor`): :math:`(T, B, A)`
+ - q_value (:obj:`torch.FloatTensor`): :math:`(T, B, A, N)`
+ - target_q_value (:obj:`torch.FloatTensor`): :math:`(T, B, A, N)`
+ - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
+ - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(T ,B, A)`
+ - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
+ - q_value_loss (:obj:`torch.FloatTensor`): :math:`()`
+ - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
+ Examples:
+ >>> action_dim = 4
+ >>> agent_num = 3
+ >>> data = coma_data(
+ >>> logit=torch.randn(2, 3, agent_num, action_dim),
+ >>> action=torch.randint(0, action_dim, (2, 3, agent_num)),
+ >>> q_value=torch.randn(2, 3, agent_num, action_dim),
+ >>> target_q_value=torch.randn(2, 3, agent_num, action_dim),
+ >>> reward=torch.randn(2, 3),
+ >>> weight=torch.ones(2, 3, agent_num),
+ >>> )
+ >>> loss = coma_error(data, 0.99, 0.99)
+ """
+ logit, action, q_value, target_q_value, reward, weight = data
+ if weight is None:
+ weight = torch.ones_like(action)
+ q_taken = torch.gather(q_value, -1, index=action.unsqueeze(-1)).squeeze(-1)
+ target_q_taken = torch.gather(target_q_value, -1, index=action.unsqueeze(-1)).squeeze(-1)
+ T, B, A = target_q_taken.shape
+ reward = reward.unsqueeze(-1).expand_as(target_q_taken).reshape(T, -1)
+ target_q_taken = target_q_taken.reshape(T, -1)
+ return_ = generalized_lambda_returns(target_q_taken, reward[:-1], gamma, lambda_)
+ return_ = return_.reshape(T - 1, B, A)
+ q_value_loss = (F.mse_loss(return_, q_taken[:-1], reduction='none') * weight[:-1]).mean()
+
+ dist = torch.distributions.categorical.Categorical(logits=logit)
+ logp = dist.log_prob(action)
+ baseline = (torch.softmax(logit, dim=-1) * q_value).sum(-1).detach()
+ adv = (q_taken - baseline).detach()
+ entropy_loss = (dist.entropy() * weight).mean()
+ policy_loss = -(logp * adv * weight).mean()
+ return coma_loss(policy_loss, q_value_loss, entropy_loss)
diff --git a/DI-engine/ding/rl_utils/exploration.py b/DI-engine/ding/rl_utils/exploration.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa296b59d3ad5d78893c2ec281b34c253082cb21
--- /dev/null
+++ b/DI-engine/ding/rl_utils/exploration.py
@@ -0,0 +1,209 @@
+import math
+from abc import ABC, abstractmethod
+from typing import Callable, Union, Optional
+from copy import deepcopy
+from ding.torch_utils.data_helper import to_device
+
+import torch
+
+
+def get_epsilon_greedy_fn(start: float, end: float, decay: int, type_: str = 'exp') -> Callable:
+ """
+ Overview:
+ Generate an epsilon_greedy function with decay, which inputs current timestep and outputs current epsilon.
+ Arguments:
+ - start (:obj:`float`): Epsilon start value. For 'linear', it should be 1.0.
+ - end (:obj:`float`): Epsilon end value.
+ - decay (:obj:`int`): Controls the speed that epsilon decreases from ``start`` to ``end``. \
+ We recommend epsilon decays according to env step rather than iteration.
+ - type_ (:obj:`str`): How epsilon decays; currently supports ['linear', 'exp' (exponential)]
+ Returns:
+ - eps_fn (:obj:`function`): The epsilon greedy function with decay
+ """
+ assert type_ in ['linear', 'exp'], type_
+ if type_ == 'exp':
+ return lambda x: (start - end) * math.exp(-1 * x / decay) + end
+ elif type_ == 'linear':
+
+ def eps_fn(x):
+ if x >= decay:
+ return end
+ else:
+ return (start - end) * (1 - x / decay) + end
+
+ return eps_fn
+
+
+class BaseNoise(ABC):
+ r"""
+ Overview:
+ Base class for action noise
+ Interface:
+ __init__, __call__
+ Examples:
+ >>> noise_generator = OUNoise() # init one type of noise
+ >>> noise = noise_generator(action.shape, action.device) # generate noise
+ """
+
+ def __init__(self) -> None:
+ """
+ Overview:
+ Initialization method
+ """
+ super().__init__()
+
+ @abstractmethod
+ def __call__(self, shape: tuple, device: str) -> torch.Tensor:
+ """
+ Overview:
+ Generate noise according to action tensor's shape, device
+ Arguments:
+ - shape (:obj:`tuple`): size of the action tensor, output noise's size should be the same
+ - device (:obj:`str`): device of the action tensor, output noise's device should be the same as it
+ Returns:
+ - noise (:obj:`torch.Tensor`): generated action noise, \
+ have the same shape and device with the input action tensor
+ """
+ raise NotImplementedError
+
+
+class GaussianNoise(BaseNoise):
+ r"""
+ Overview:
+ Derived class for generating gaussian noise, which satisfies :math:`X \sim N(\mu, \sigma^2)`
+ Interface:
+ __init__, __call__
+ """
+
+ def __init__(self, mu: float = 0.0, sigma: float = 1.0) -> None:
+ """
+ Overview:
+ Initialize :math:`\mu` and :math:`\sigma` in Gaussian Distribution
+ Arguments:
+ - mu (:obj:`float`): :math:`\mu` , mean value
+ - sigma (:obj:`float`): :math:`\sigma`, standard deviation, should be non-negative
+ """
+ super(GaussianNoise, self).__init__()
+ self._mu = mu
+ assert sigma >= 0, "GaussianNoise's sigma should be non-negative."
+ self._sigma = sigma
+
+ def __call__(self, shape: tuple, device: str) -> torch.Tensor:
+ """
+ Overview:
+ Generate gaussian noise according to action tensor's shape, device
+ Arguments:
+ - shape (:obj:`tuple`): size of the action tensor, output noise's size should be the same
+ - device (:obj:`str`): device of the action tensor, output noise's device should be the same as it
+ Returns:
+ - noise (:obj:`torch.Tensor`): generated action noise, \
+ have the same shape and device with the input action tensor
+ """
+ noise = torch.randn(shape, device=device)
+ noise = noise * self._sigma + self._mu
+ return noise
+
+
+class OUNoise(BaseNoise):
+ r"""
+ Overview:
+ Derived class for generating Ornstein-Uhlenbeck process noise.
+ Satisfies :math:`dx_t=\theta(\mu-x_t)dt + \sigma dW_t`,
+ where :math:`W_t` denotes the Wiener process, acting as a random perturbation term.
+ Interface:
+ __init__, reset, __call__
+ """
+
+ def __init__(
+ self,
+ mu: float = 0.0,
+ sigma: float = 0.3,
+ theta: float = 0.15,
+ dt: float = 1e-2,
+ x0: Optional[Union[float, torch.Tensor]] = 0.0,
+ ) -> None:
+ """
+ Overview:
+ Initialize ``_alpha`` :math:`= \theta * dt`,
+ ``_beta`` :math:`= \sigma * \sqrt{dt}`, in Ornstein-Uhlenbeck process
+ Arguments:
+ - mu (:obj:`float`): :math:`\mu` , mean value
+ - sigma (:obj:`float`): :math:`\sigma` , standard deviation of the perturbation noise
+ - theta (:obj:`float`): how strongly the noise reacts to perturbations, \
+ greater value means stronger reaction
+ - dt (:obj:`float`): derivative of time t
+ - x0 (:obj:`float` or :obj:`torch.Tensor`): initial action
+ """
+ super().__init__()
+ self._mu = mu
+ self._alpha = theta * dt
+ self._beta = sigma * math.sqrt(dt)
+ self._x0 = x0
+ self.reset()
+
+ def reset(self) -> None:
+ """
+ Overview:
+ Reset ``_x`` to the initial state ``_x0``
+ """
+ self._x = deepcopy(self._x0)
+
+ def __call__(self, shape: tuple, device: str, mu: Optional[float] = None) -> torch.Tensor:
+ """
+ Overview:
+ Generate Ornstein-Uhlenbeck noise according to action tensor's shape, device
+ Arguments:
+ - shape (:obj:`tuple`): size of the action tensor, output noise's size should be the same
+ - device (:obj:`str`): device of the action tensor, output noise's device should be the same as it
+ - mu (:obj:`float`): new mean value :math:`\mu`, you can set it to ``None`` if you don't need it
+ Returns:
+ - noise (:obj:`torch.Tensor`): generated action noise, \
+ have the same shape and device with the input action tensor
+ """
+ if self._x is None or \
+ (isinstance(self._x, torch.Tensor) and self._x.shape != shape):
+ self._x = torch.zeros(shape)
+ if mu is None:
+ mu = self._mu
+ noise = self._alpha * (mu - self._x) + self._beta * torch.randn(shape)
+ self._x += noise
+ noise = to_device(noise, device)
+ return noise
+
+ @property
+ def x0(self) -> Union[float, torch.Tensor]:
+ """
+ Overview:
+ Get ``self._x0``
+ """
+ return self._x0
+
+ @x0.setter
+ def x0(self, _x0: Union[float, torch.Tensor]) -> None:
+ """
+ Overview:
+ Set ``self._x0`` and reset ``self._x`` to the new ``self._x0`` as well
+ """
+ self._x0 = _x0
+ self.reset()
+
+
+noise_mapping = {'gauss': GaussianNoise, 'ou': OUNoise}
+
+
+def create_noise_generator(noise_type: str, noise_kwargs: dict) -> BaseNoise:
+ """
+ Overview:
+ Given the key (``noise_type``), create a new noise generator instance if the type is registered in
+ ``noise_mapping``, otherwise raise a KeyError. In other words, a derived noise generator must first be
+ registered in ``noise_mapping``, then ``create_noise_generator`` can be called to get the instance object.
+ Arguments:
+ - noise_type (:obj:`str`): the type of noise generator to be created
+ Returns:
+ - noise (:obj:`BaseNoise`): the created new noise generator, should be an instance of one of \
+ noise_mapping's values
+ """
+ if noise_type not in noise_mapping.keys():
+ raise KeyError("unsupported noise type: {}".format(noise_type))
+ else:
+ return noise_mapping[noise_type](**noise_kwargs)
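+
+
+# A minimal illustrative sketch of ``get_epsilon_greedy_fn`` and ``create_noise_generator``, neither of
+# which has an Examples section; the decay horizon and noise kwargs are arbitrary demonstration values.
+if __name__ == '__main__':
+ eps_fn = get_epsilon_greedy_fn(start=0.95, end=0.05, decay=10000, type_='exp')
+ print(eps_fn(0), eps_fn(10000)) # epsilon at the first env step and after one decay constant
+ noise_generator = create_noise_generator('gauss', {'mu': 0.0, 'sigma': 0.1})
+ print(noise_generator((2, 3), 'cpu').shape) # torch.Size([2, 3])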
diff --git a/DI-engine/ding/rl_utils/gae.py b/DI-engine/ding/rl_utils/gae.py
new file mode 100644
index 0000000000000000000000000000000000000000..800fcae35426f9249c3aa1d664256709af11c77f
--- /dev/null
+++ b/DI-engine/ding/rl_utils/gae.py
@@ -0,0 +1,70 @@
+from collections import namedtuple
+import torch
+from ding.hpc_rl import hpc_wrapper
+
+gae_data = namedtuple('gae_data', ['value', 'next_value', 'reward', 'done', 'traj_flag'])
+
+
+def shape_fn_gae(args, kwargs):
+ r"""
+ Overview:
+ Return shape of gae for hpc
+ Returns:
+ shape: [T, B]
+ """
+ if len(args) <= 0:
+ tmp = kwargs['data'].reward.shape
+ else:
+ tmp = args[0].reward.shape
+ return tmp
+
+
+@hpc_wrapper(
+ shape_fn=shape_fn_gae, namedtuple_data=True, include_args=[0, 1, 2], include_kwargs=['data', 'gamma', 'lambda_']
+)
+def gae(data: namedtuple, gamma: float = 0.99, lambda_: float = 0.97) -> torch.FloatTensor:
+ """
+ Overview:
+ Implementation of Generalized Advantage Estimator (arXiv:1506.02438)
+ Arguments:
+ - data (:obj:`namedtuple`): gae input data with fields shown in ``gae_data`` (``done`` and ``traj_flag`` \
+ may be None), which contains some episodes or trajectories data.
+ - gamma (:obj:`float`): the future discount factor, should be in [0, 1], defaults to 0.99.
+ - lambda_ (:obj:`float`): the gae parameter lambda, should be in [0, 1], defaults to 0.97, when lambda -> 0, \
+ it induces bias, but when lambda -> 1, it has high variance due to the sum of terms.
+ Returns:
+ - adv (:obj:`torch.FloatTensor`): the calculated advantage
+ Shapes:
+ - value (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is trajectory length and B is batch size
+ - next_value (:obj:`torch.FloatTensor`): :math:`(T, B)`
+ - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
+ - adv (:obj:`torch.FloatTensor`): :math:`(T, B)`
+ Examples:
+ >>> value = torch.randn(2, 3)
+ >>> next_value = torch.randn(2, 3)
+ >>> reward = torch.randn(2, 3)
+ >>> data = gae_data(value, next_value, reward, None, None)
+ >>> adv = gae(data)
+ """
+ value, next_value, reward, done, traj_flag = data
+ if done is None:
+ done = torch.zeros_like(reward, device=reward.device)
+ if traj_flag is None:
+ traj_flag = done
+ done = done.float()
+ traj_flag = traj_flag.float()
+ if len(value.shape) == len(reward.shape) + 1: # for some marl case: value(T, B, A), reward(T, B)
+ reward = reward.unsqueeze(-1)
+ done = done.unsqueeze(-1)
+ traj_flag = traj_flag.unsqueeze(-1)
+
+ next_value *= (1 - done)
+ delta = reward + gamma * next_value - value
+ factor = gamma * lambda_ * (1 - traj_flag)
+ adv = torch.zeros_like(value)
+ gae_item = torch.zeros_like(value[0])
+
+ for t in reversed(range(reward.shape[0])):
+ gae_item = delta[t] + factor[t] * gae_item
+ adv[t] = gae_item
+ return adv
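+
+
+# An illustrative sketch of the episode-boundary path of ``gae``, where ``done``/``traj_flag`` stop the
+# advantage from bootstrapping across episodes. The shapes and the boundary position below are arbitrary
+# assumptions for demonstration only (the docstring example above covers the ``done=None`` case).
+if __name__ == '__main__':
+ T, B = 4, 2
+ value, next_value, reward = torch.randn(T, B), torch.randn(T, B), torch.randn(T, B)
+ done = torch.zeros(T, B)
+ done[1] = 1. # pretend both trajectories end right after timestep 1
+ adv = gae(gae_data(value, next_value, reward, done, None), 0.99, 0.95)
+ print(adv.shape) # torch.Size([4, 2])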
diff --git a/DI-engine/ding/rl_utils/happo.py b/DI-engine/ding/rl_utils/happo.py
new file mode 100644
index 0000000000000000000000000000000000000000..b37ddc7528485839a78b2c4b3de878510056d337
--- /dev/null
+++ b/DI-engine/ding/rl_utils/happo.py
@@ -0,0 +1,347 @@
+from collections import namedtuple
+from typing import Optional, Tuple
+import torch
+import torch.nn as nn
+from torch.distributions import Independent, Normal
+from ding.hpc_rl import hpc_wrapper
+
+happo_value_data = namedtuple('happo_value_data', ['value_new', 'value_old', 'return_', 'weight'])
+happo_loss = namedtuple('happo_loss', ['policy_loss', 'value_loss', 'entropy_loss'])
+happo_policy_loss = namedtuple('happo_policy_loss', ['policy_loss', 'entropy_loss'])
+happo_info = namedtuple('happo_info', ['approx_kl', 'clipfrac'])
+happo_data = namedtuple(
+ 'happo_data', ['logit_new', 'logit_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight', 'factor']
+)
+happo_policy_data = namedtuple('happo_policy_data', ['logit_new', 'logit_old', 'action', 'adv', 'weight', 'factor'])
+
+
+def happo_error(
+ data: namedtuple,
+ clip_ratio: float = 0.2,
+ use_value_clip: bool = True,
+ dual_clip: Optional[float] = None,
+) -> Tuple[namedtuple, namedtuple]:
+ """
+ Overview:
+ Implementation of the HAPPO (Heterogeneous-Agent PPO) loss, based on Proximal Policy Optimization
+ (arXiv:1707.06347), with value_clip and dual_clip
+ Arguments:
+ - data (:obj:`namedtuple`): the happo input data with fields shown in ``happo_data``
+ - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2
+ - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy
+ - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf),\
+ defaults to 5.0, if you don't want to use it, set this parameter to None
+ Returns:
+ - happo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
+ - happo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
+ Shapes:
+ - logit_new (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
+ - logit_old (:obj:`torch.FloatTensor`): :math:`(B, N)`
+ - action (:obj:`torch.LongTensor`): :math:`(B, )`
+ - value_new (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - value_old (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - return (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
+ - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
+ - value_loss (:obj:`torch.FloatTensor`): :math:`()`
+ - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
+ Examples:
+ >>> action_dim = 4
+ >>> data = happo_data(
+ >>> logit_new=torch.randn(3, action_dim),
+ >>> logit_old=torch.randn(3, action_dim),
+ >>> action=torch.randint(0, action_dim, (3,)),
+ >>> value_new=torch.randn(3),
+ >>> value_old=torch.randn(3),
+ >>> adv=torch.randn(3),
+ >>> return_=torch.randn(3),
+ >>> weight=torch.ones(3),
+ >>> factor=torch.ones(3, 1),
+ >>> )
+ >>> loss, info = happo_error(data)
+
+ .. note::
+
+ adv is already normalized value (adv - adv.mean()) / (adv.std() + 1e-8), and there are many
+ ways to calculate this mean and std, like among data buffer or train batch, so we don't couple
+ this part into happo_error, you can refer to our examples for different ways.
+ """
+ assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format(
+ dual_clip
+ )
+ logit_new, logit_old, action, value_new, value_old, adv, return_, weight, factor = data
+ policy_data = happo_policy_data(logit_new, logit_old, action, adv, weight, factor)
+ policy_output, policy_info = happo_policy_error(policy_data, clip_ratio, dual_clip)
+ value_data = happo_value_data(value_new, value_old, return_, weight)
+ value_loss = happo_value_error(value_data, clip_ratio, use_value_clip)
+
+ return happo_loss(policy_output.policy_loss, value_loss, policy_output.entropy_loss), policy_info
+
+
+def happo_policy_error(
+ data: namedtuple,
+ clip_ratio: float = 0.2,
+ dual_clip: Optional[float] = None,
+) -> Tuple[namedtuple, namedtuple]:
+ '''
+ Overview:
+ Get HAPPO policy loss
+ Arguments:
+ - data (:obj:`namedtuple`): happo input data with fields shown in ``happo_policy_data``
+ - clip_ratio (:obj:`float`): clip value for ratio
+ - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf),\
+ defaults to 5.0, if you don't want to use it, set this parameter to None
+ Returns:
+ - happo_policy_loss (:obj:`namedtuple`): the ppo policy loss item, all of them are the differentiable \
+ 0-dim tensor.
+ - happo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
+ Shapes:
+ - logit_new (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
+ - logit_old (:obj:`torch.FloatTensor`): :math:`(B, N)`
+ - action (:obj:`torch.LongTensor`): :math:`(B, )`
+ - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
+ - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
+ - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
+ Examples:
+ >>> action_dim = 4
+ >>> data = happo_policy_data(
+ >>> logit_new=torch.randn(3, action_dim),
+ >>> logit_old=torch.randn(3, action_dim),
+ >>> action=torch.randint(0, action_dim, (3,)),
+ >>> adv=torch.randn(3),
+ >>> weight=torch.ones(3),
+ >>> factor=torch.ones(3, 1),
+ >>> )
+ >>> loss, info = happo_policy_error(data)
+ '''
+ logit_new, logit_old, action, adv, weight, factor = data
+ if weight is None:
+ weight = torch.ones_like(adv)
+ dist_new = torch.distributions.categorical.Categorical(logits=logit_new)
+ dist_old = torch.distributions.categorical.Categorical(logits=logit_old)
+ logp_new = dist_new.log_prob(action)
+ logp_old = dist_old.log_prob(action)
+ dist_new_entropy = dist_new.entropy()
+ if dist_new_entropy.shape != weight.shape:
+ dist_new_entropy = dist_new.entropy().mean(dim=1)
+ entropy_loss = (dist_new_entropy * weight).mean()
+ # policy_loss
+ ratio = torch.exp(logp_new - logp_old)
+ if ratio.shape != adv.shape:
+ ratio = ratio.mean(dim=1)
+ surr1 = ratio * adv
+ surr2 = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv
+ # shape factor: (B,1) surr1: (B,)
+ clip1 = torch.min(surr1, surr2) * factor.squeeze(1)
+ if dual_clip is not None:
+ clip2 = torch.max(clip1, dual_clip * adv)
+ # only use dual_clip when adv < 0
+ policy_loss = -(torch.where(adv < 0, clip2, clip1) * weight).mean()
+ else:
+ policy_loss = (-clip1 * weight).mean()
+ with torch.no_grad():
+ approx_kl = (logp_old - logp_new).mean().item()
+ clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
+ clipfrac = torch.as_tensor(clipped).float().mean().item()
+ return happo_policy_loss(policy_loss, entropy_loss), happo_info(approx_kl, clipfrac)
+
+
+def happo_value_error(
+ data: namedtuple,
+ clip_ratio: float = 0.2,
+ use_value_clip: bool = True,
+) -> torch.Tensor:
+ '''
+ Overview:
+ Get HAPPO value loss
+ Arguments:
+ - data (:obj:`namedtuple`): happo input data with fields shown in ``happo_value_data``
+ - clip_ratio (:obj:`float`): clip value for ratio
+ - use_value_clip (:obj:`bool`): whether use value clip
+ Returns:
+ - value_loss (:obj:`torch.FloatTensor`): the ppo value loss item, \
+ all of them are the differentiable 0-dim tensor
+ Shapes:
+ - value_new (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size
+ - value_old (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - return (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
+ - value_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
+ Examples:
+ >>> action_dim = 4
+ >>> data = happo_value_data(
+ >>> value_new=torch.randn(3),
+ >>> value_old=torch.randn(3),
+ >>> return_=torch.randn(3),
+ >>> weight=torch.ones(3),
+ >>> )
+ >>> loss = happo_value_error(data)
+ '''
+ value_new, value_old, return_, weight = data
+ if weight is None:
+ weight = torch.ones_like(value_old)
+ # value_loss
+ if use_value_clip:
+ value_clip = value_old + (value_new - value_old).clamp(-clip_ratio, clip_ratio)
+ v1 = (return_ - value_new).pow(2)
+ v2 = (return_ - value_clip).pow(2)
+ value_loss = 0.5 * (torch.max(v1, v2) * weight).mean()
+ else:
+ value_loss = 0.5 * ((return_ - value_new).pow(2) * weight).mean()
+ return value_loss
+
+
+def happo_error_continuous(
+ data: namedtuple,
+ clip_ratio: float = 0.2,
+ use_value_clip: bool = True,
+ dual_clip: Optional[float] = None,
+) -> Tuple[namedtuple, namedtuple]:
+ """
+ Overview:
+ Implementation of the continuous-action HAPPO loss, based on Proximal Policy Optimization
+ (arXiv:1707.06347), with value_clip and dual_clip
+ Arguments:
+ - data (:obj:`namedtuple`): the happo continuous-action input data, laid out like ``happo_data`` but with \
+ ``mu_sigma_new``/``mu_sigma_old`` dicts in place of the logits
+ - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2
+ - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy
+ - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf),\
+ defaults to 5.0, if you don't want to use it, set this parameter to None
+ Returns:
+ - happo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
+ - happo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
+ Shapes:
+ - mu_sigma_new (:obj:`tuple`): :math:`((B, N), (B, N))`, where B is batch size and N is action dim
+ - mu_sigma_old (:obj:`tuple`): :math:`((B, N), (B, N))`, where B is batch size and N is action dim
+ - action (:obj:`torch.LongTensor`): :math:`(B, )`
+ - value_new (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - value_old (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - return (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
+ - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
+ - value_loss (:obj:`torch.FloatTensor`): :math:`()`
+ - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
+ Examples:
+ >>> action_dim = 4
+ >>> # there is no dedicated continuous happo namedtuple, so a plain tuple in the expected field order is used
+ >>> data = (
+ >>> dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2), # mu_sigma_new
+ >>> dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2), # mu_sigma_old
+ >>> torch.randn(3, action_dim), # action
+ >>> torch.randn(3), # value_new
+ >>> torch.randn(3), # value_old
+ >>> torch.randn(3), # adv
+ >>> torch.randn(3), # return_
+ >>> torch.ones(3), # weight
+ >>> torch.ones(3, 1), # factor
+ >>> )
+ >>> loss, info = happo_error_continuous(data)
+
+ .. note::
+
+ adv is already normalized value (adv - adv.mean()) / (adv.std() + 1e-8), and there are many
+ ways to calculate this mean and std, like among data buffer or train batch, so we don't couple
+ this part into happo_error, you can refer to our examples for different ways.
+ """
+ assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format(
+ dual_clip
+ )
+ mu_sigma_new, mu_sigma_old, action, value_new, value_old, adv, return_, weight, factor_batch = data
+ if weight is None:
+ weight = torch.ones_like(adv)
+
+ dist_new = Normal(mu_sigma_new['mu'], mu_sigma_new['sigma'])
+ if len(mu_sigma_old['mu'].shape) == 1:
+ dist_old = Normal(mu_sigma_old['mu'].unsqueeze(-1), mu_sigma_old['sigma'].unsqueeze(-1))
+ else:
+ dist_old = Normal(mu_sigma_old['mu'], mu_sigma_old['sigma'])
+ logp_new = dist_new.log_prob(action)
+ logp_old = dist_old.log_prob(action)
+ entropy_loss = (dist_new.entropy() * weight.unsqueeze(1)).mean()
+
+ # policy_loss
+ ratio = torch.exp(logp_new - logp_old)
+ ratio = torch.prod(ratio, dim=-1)
+ surr1 = ratio * adv
+ surr2 = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv
+ if dual_clip is not None:
+ # shape factor: (B,1) surr1: (B,)
+ policy_loss = (-torch.max(factor_batch.squeeze(1) * torch.min(surr1, surr2), dual_clip * adv) * weight).mean()
+ else:
+ policy_loss = (-factor_batch.squeeze(1) * torch.min(surr1, surr2) * weight).mean()
+ with torch.no_grad():
+ approx_kl = (logp_old - logp_new).mean().item()
+ clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
+ clipfrac = torch.as_tensor(clipped).float().mean().item()
+ # value_loss
+ if use_value_clip:
+ value_clip = value_old + (value_new - value_old).clamp(-clip_ratio, clip_ratio)
+ v1 = (return_ - value_new).pow(2)
+ v2 = (return_ - value_clip).pow(2)
+ value_loss = 0.5 * (torch.max(v1, v2) * weight).mean()
+ else:
+ value_loss = 0.5 * ((return_ - value_new).pow(2) * weight).mean()
+
+ return happo_loss(policy_loss, value_loss, entropy_loss), happo_info(approx_kl, clipfrac)
+
+
+def happo_policy_error_continuous(data: namedtuple,
+ clip_ratio: float = 0.2,
+ dual_clip: Optional[float] = None) -> Tuple[namedtuple, namedtuple]:
+ """
+ Overview:
+ Implementation of Proximal Policy Optimization (arXiv:1707.06347) with dual_clip
+ Arguments:
+ - data (:obj:`namedtuple`): the happo input data with fields shown in ``ppo_policy_data_continuous``
+ - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2
+ - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf),\
+ defaults to 5.0, if you don't want to use it, set this parameter to None
+ Returns:
+ - happo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
+ - happo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
+ Shapes:
+ - mu_sigma_new (:obj:`tuple`): :math:`((B, N), (B, N))`, where B is batch size and N is action dim
+ - mu_sigma_old (:obj:`tuple`): :math:`((B, N), (B, N))`, where B is batch size and N is action dim
+ - action (:obj:`torch.LongTensor`): :math:`(B, )`
+ - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
+ - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
+ - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
+ Examples:
+ >>> action_dim = 4
+ >>> data = ppo_policy_data_continuous(
+ >>> mu_sigma_new=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
+ >>> mu_sigma_old=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
+ >>> action=torch.randn(3, action_dim),
+ >>> adv=torch.randn(3),
+ >>> weight=torch.ones(3),
+ >>> )
+ >>> loss, info = happo_policy_error_continuous(data)
+ """
+ assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format(
+ dual_clip
+ )
+ mu_sigma_new, mu_sigma_old, action, adv, weight = data
+ if weight is None:
+ weight = torch.ones_like(adv)
+
+ dist_new = Independent(Normal(mu_sigma_new['mu'], mu_sigma_new['sigma']), 1)
+ if len(mu_sigma_old['mu'].shape) == 1:
+ dist_old = Independent(Normal(mu_sigma_old['mu'].unsqueeze(-1), mu_sigma_old['sigma'].unsqueeze(-1)), 1)
+ else:
+ dist_old = Independent(Normal(mu_sigma_old['mu'], mu_sigma_old['sigma']), 1)
+ logp_new = dist_new.log_prob(action)
+ logp_old = dist_old.log_prob(action)
+ entropy_loss = (dist_new.entropy() * weight).mean()
+ # policy_loss
+ ratio = torch.exp(logp_new - logp_old)
+ surr1 = ratio * adv
+ surr2 = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv
+ if dual_clip is not None:
+ policy_loss = (-torch.max(torch.min(surr1, surr2), dual_clip * adv) * weight).mean()
+ else:
+ policy_loss = (-torch.min(surr1, surr2) * weight).mean()
+ with torch.no_grad():
+ approx_kl = (logp_old - logp_new).mean().item()
+ clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
+ clipfrac = torch.as_tensor(clipped).float().mean().item()
+ return happo_policy_loss(policy_loss, entropy_loss), happo_info(approx_kl, clipfrac)
diff --git a/DI-engine/ding/rl_utils/isw.py b/DI-engine/ding/rl_utils/isw.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0745f103123f9ff5bf091696867638ddb1c479a
--- /dev/null
+++ b/DI-engine/ding/rl_utils/isw.py
@@ -0,0 +1,59 @@
+from typing import Union
+import torch
+from torch.distributions import Categorical, Independent, Normal
+
+
+def compute_importance_weights(
+ target_output: Union[torch.Tensor, dict],
+ behaviour_output: Union[torch.Tensor, dict],
+ action: torch.Tensor,
+ action_space_type: str = 'discrete',
+ requires_grad: bool = False
+):
+ """
+ Overview:
+ Computing importance sampling weight with given output and action
+ Arguments:
+ - target_output (:obj:`Union[torch.Tensor,dict]`): the output of the current policy network \
+ for the taken action, \
+ usually the network output logit if the action space is discrete, \
+ or a dict containing the parameters of the action distribution if the action space is continuous.
+ - behaviour_output (:obj:`Union[torch.Tensor,dict]`): the output of the behaviour policy network \
+ for the taken action, \
+ usually the network output logit if the action space is discrete, \
+ or a dict containing the parameters of the action distribution if the action space is continuous.
+ - action (:obj:`torch.Tensor`): the chosen action (index for the discrete action space) in the trajectory, \
+ i.e. the behaviour action
+ - action_space_type (:obj:`str`): action space types in ['discrete', 'continuous']
+ - requires_grad (:obj:`bool`): whether requires grad computation
+ Returns:
+ - rhos (:obj:`torch.Tensor`): Importance sampling weight
+ Shapes:
+ - target_output (:obj:`Union[torch.FloatTensor,dict]`): :math:`(T, B, N)`, \
+ where T is timestep, B is batch size and N is action dim
+ - behaviour_output (:obj:`Union[torch.FloatTensor,dict]`): :math:`(T, B, N)`
+ - action (:obj:`torch.LongTensor`): :math:`(T, B)`
+ - rhos (:obj:`torch.FloatTensor`): :math:`(T, B)`
+ Examples:
+ >>> target_output = torch.randn(2, 3, 4)
+ >>> behaviour_output = torch.randn(2, 3, 4)
+ >>> action = torch.randint(0, 4, (2, 3))
+ >>> rhos = compute_importance_weights(target_output, behaviour_output, action)
+ """
+ grad_context = torch.enable_grad() if requires_grad else torch.no_grad()
+ assert isinstance(action, torch.Tensor)
+ assert action_space_type in ['discrete', 'continuous']
+
+ with grad_context:
+ if action_space_type == 'continuous':
+ dist_target = Independent(Normal(loc=target_output['mu'], scale=target_output['sigma']), 1)
+ dist_behaviour = Independent(Normal(loc=behaviour_output['mu'], scale=behaviour_output['sigma']), 1)
+ rhos = dist_target.log_prob(action) - dist_behaviour.log_prob(action)
+ rhos = torch.exp(rhos)
+ return rhos
+ elif action_space_type == 'discrete':
+ dist_target = Categorical(logits=target_output)
+ dist_behaviour = Categorical(logits=behaviour_output)
+ rhos = dist_target.log_prob(action) - dist_behaviour.log_prob(action)
+ rhos = torch.exp(rhos)
+ return rhos
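+
+
+# An illustrative sketch of the continuous-action branch of ``compute_importance_weights``, where both
+# outputs are dicts with 'mu' and 'sigma' (the docstring example only covers the discrete branch);
+# all shapes are arbitrary assumptions for demonstration.
+if __name__ == '__main__':
+ T, B, N = 2, 3, 4
+ target_output = dict(mu=torch.randn(T, B, N), sigma=torch.rand(T, B, N) + 0.1)
+ behaviour_output = dict(mu=torch.randn(T, B, N), sigma=torch.rand(T, B, N) + 0.1)
+ action = torch.randn(T, B, N)
+ rhos = compute_importance_weights(target_output, behaviour_output, action, action_space_type='continuous')
+ print(rhos.shape) # torch.Size([2, 3])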
diff --git a/DI-engine/ding/rl_utils/ppg.py b/DI-engine/ding/rl_utils/ppg.py
new file mode 100644
index 0000000000000000000000000000000000000000..286266e57b60ac88b0006103deccd85889380bec
--- /dev/null
+++ b/DI-engine/ding/rl_utils/ppg.py
@@ -0,0 +1,69 @@
+from typing import Tuple
+from collections import namedtuple
+import torch
+import torch.nn.functional as F
+
+ppg_data = namedtuple('ppg_data', ['logit_new', 'logit_old', 'action', 'value_new', 'value_old', 'return_', 'weight'])
+ppg_joint_loss = namedtuple('ppg_joint_loss', ['auxiliary_loss', 'behavioral_cloning_loss'])
+
+
+def ppg_joint_error(
+ data: namedtuple,
+ clip_ratio: float = 0.2,
+ use_value_clip: bool = True,
+) -> Tuple[namedtuple, namedtuple]:
+ '''
+ Overview:
+ Get PPG joint loss
+ Arguments:
+ - data (:obj:`namedtuple`): ppg input data with fields shown in ``ppg_data``
+ - clip_ratio (:obj:`float`): clip value for ratio
+ - use_value_clip (:obj:`bool`): whether use value clip
+ Returns:
+ - ppg_joint_loss (:obj:`namedtuple`): the ppg loss item, all of them are the differentiable 0-dim tensor
+ Shapes:
+ - logit_new (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
+ - logit_old (:obj:`torch.FloatTensor`): :math:`(B, N)`
+ - action (:obj:`torch.LongTensor`): :math:`(B,)`
+ - value_new (:obj:`torch.FloatTensor`): :math:`(B, 1)`
+ - value_old (:obj:`torch.FloatTensor`): :math:`(B, 1)`
+ - return (:obj:`torch.FloatTensor`): :math:`(B, 1)`
+ - weight (:obj:`torch.FloatTensor`): :math:`(B,)`
+ - auxiliary_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
+ - behavioral_cloning_loss (:obj:`torch.FloatTensor`): :math:`()`
+ Examples:
+ >>> action_dim = 4
+ >>> data = ppg_data(
+ >>> logit_new=torch.randn(3, action_dim),
+ >>> logit_old=torch.randn(3, action_dim),
+ >>> action=torch.randint(0, action_dim, (3,)),
+ >>> value_new=torch.randn(3, 1),
+ >>> value_old=torch.randn(3, 1),
+ >>> return_=torch.randn(3, 1),
+ >>> weight=torch.ones(3),
+ >>> )
+ >>> loss = ppg_joint_error(data)
+ '''
+ logit_new, logit_old, action, value_new, value_old, return_, weight = data
+
+ if weight is None:
+ weight = torch.ones_like(return_)
+
+ # auxiliary_loss
+ if use_value_clip:
+ value_clip = value_old + (value_new - value_old).clamp(-clip_ratio, clip_ratio)
+ v1 = (return_ - value_new).pow(2)
+ v2 = (return_ - value_clip).pow(2)
+ auxiliary_loss = 0.5 * (torch.max(v1, v2) * weight).mean()
+ else:
+ auxiliary_loss = 0.5 * ((return_ - value_new).pow(2) * weight).mean()
+
+ dist_new = torch.distributions.categorical.Categorical(logits=logit_new)
+ dist_old = torch.distributions.categorical.Categorical(logits=logit_old)
+ logp_new = dist_new.log_prob(action)
+ logp_old = dist_old.log_prob(action)
+
+ # behavioral cloning loss
+ behavioral_cloning_loss = F.kl_div(logp_new, logp_old, reduction='batchmean')
+
+ return ppg_joint_loss(auxiliary_loss, behavioral_cloning_loss)
diff --git a/DI-engine/ding/rl_utils/ppo.py b/DI-engine/ding/rl_utils/ppo.py
new file mode 100644
index 0000000000000000000000000000000000000000..29b441b24a21df8bf7e3ebaf6f4b8522e9a4cf36
--- /dev/null
+++ b/DI-engine/ding/rl_utils/ppo.py
@@ -0,0 +1,365 @@
+from collections import namedtuple
+from typing import Optional, Tuple
+import torch
+import torch.nn as nn
+from torch.distributions import Independent, Normal
+from ding.hpc_rl import hpc_wrapper
+
+ppo_data = namedtuple(
+ 'ppo_data', ['logit_new', 'logit_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight']
+)
+ppo_data_continuous = namedtuple(
+ 'ppo_data_continuous',
+ ['mu_sigma_new', 'mu_sigma_old', 'action', 'value_new', 'value_old', 'adv', 'return_', 'weight']
+)
+ppo_policy_data = namedtuple('ppo_policy_data', ['logit_new', 'logit_old', 'action', 'adv', 'weight'])
+ppo_policy_data_continuous = namedtuple(
+ 'ppo_policy_data_continuous', ['mu_sigma_new', 'mu_sigma_old', 'action', 'adv', 'weight']
+)
+ppo_value_data = namedtuple('ppo_value_data', ['value_new', 'value_old', 'return_', 'weight'])
+ppo_loss = namedtuple('ppo_loss', ['policy_loss', 'value_loss', 'entropy_loss'])
+ppo_policy_loss = namedtuple('ppo_policy_loss', ['policy_loss', 'entropy_loss'])
+ppo_info = namedtuple('ppo_info', ['approx_kl', 'clipfrac'])
+
+
+def shape_fn_ppo(args, kwargs):
+ r"""
+ Overview:
+ Return shape of ppo for hpc
+ Returns:
+ shape: [B, N]
+ """
+ if len(args) <= 0:
+ tmp = kwargs['data'].logit_new.shape
+ else:
+ tmp = args[0].logit_new.shape
+ return tmp
+
+
+@hpc_wrapper(
+ shape_fn=shape_fn_ppo,
+ namedtuple_data=True,
+ include_args=[0, 1, 2, 3],
+ include_kwargs=['data', 'clip_ratio', 'use_value_clip', 'dual_clip']
+)
+def ppo_error(
+ data: namedtuple,
+ clip_ratio: float = 0.2,
+ use_value_clip: bool = True,
+ dual_clip: Optional[float] = None
+) -> Tuple[namedtuple, namedtuple]:
+ """
+ Overview:
+ Implementation of Proximal Policy Optimization (arXiv:1707.06347) with value_clip and dual_clip
+ Arguments:
+ - data (:obj:`namedtuple`): the ppo input data with fields shown in ``ppo_data``
+ - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2
+ - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy
+ - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf),\
+ defaults to 5.0, if you don't want to use it, set this parameter to None
+ Returns:
+ - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
+ - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
+ Shapes:
+ - logit_new (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
+ - logit_old (:obj:`torch.FloatTensor`): :math:`(B, N)`
+ - action (:obj:`torch.LongTensor`): :math:`(B, )`
+ - value_new (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - value_old (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - return (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
+ - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
+ - value_loss (:obj:`torch.FloatTensor`): :math:`()`
+ - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
+ Examples:
+ >>> action_dim = 4
+ >>> data = ppo_data(
+ >>> logit_new=torch.randn(3, action_dim),
+ >>> logit_old=torch.randn(3, action_dim),
+ >>> action=torch.randint(0, action_dim, (3,)),
+ >>> value_new=torch.randn(3),
+ >>> value_old=torch.randn(3),
+ >>> adv=torch.randn(3),
+ >>> return_=torch.randn(3),
+ >>> weight=torch.ones(3),
+ >>> )
+ >>> loss, info = ppo_error(data)
+
+ .. note::
+
+ adv is already normalized value (adv - adv.mean()) / (adv.std() + 1e-8), and there are many
+ ways to calculate this mean and std, like among data buffer or train batch, so we don't couple
+ this part into ppo_error, you can refer to our examples for different ways.
+ """
+ assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format(
+ dual_clip
+ )
+ logit_new, logit_old, action, value_new, value_old, adv, return_, weight = data
+ policy_data = ppo_policy_data(logit_new, logit_old, action, adv, weight)
+ policy_output, policy_info = ppo_policy_error(policy_data, clip_ratio, dual_clip)
+ value_data = ppo_value_data(value_new, value_old, return_, weight)
+ value_loss = ppo_value_error(value_data, clip_ratio, use_value_clip)
+
+ return ppo_loss(policy_output.policy_loss, value_loss, policy_output.entropy_loss), policy_info
+
+
+def ppo_policy_error(data: namedtuple,
+ clip_ratio: float = 0.2,
+ dual_clip: Optional[float] = None) -> Tuple[namedtuple, namedtuple]:
+ '''
+ Overview:
+ Get PPO policy loss
+ Arguments:
+ - data (:obj:`namedtuple`): ppo input data with fields shown in ``ppo_policy_data``
+ - clip_ratio (:obj:`float`): clip value for ratio
+ - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf),\
+ defaults to 5.0, if you don't want to use it, set this parameter to None
+ Returns:
+ - ppo_policy_loss (:obj:`namedtuple`): the ppo policy loss item, all of them are the differentiable 0-dim tensor
+ - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
+ Shapes:
+ - logit_new (:obj:`torch.FloatTensor`): :math:`(B, N)`, where B is batch size and N is action dim
+ - logit_old (:obj:`torch.FloatTensor`): :math:`(B, N)`
+ - action (:obj:`torch.LongTensor`): :math:`(B, )`
+ - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
+ - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
+ - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
+ Examples:
+ >>> action_dim = 4
+ >>> data = ppo_policy_data(
+ >>> logit_new=torch.randn(3, action_dim),
+ >>> logit_old=torch.randn(3, action_dim),
+ >>> action=torch.randint(0, action_dim, (3,)),
+ >>> adv=torch.randn(3),
+ >>> weight=torch.ones(3),
+ >>> )
+ >>> loss, info = ppo_policy_error(data)
+ '''
+ logit_new, logit_old, action, adv, weight = data
+ if weight is None:
+ weight = torch.ones_like(adv)
+ dist_new = torch.distributions.categorical.Categorical(logits=logit_new)
+ dist_old = torch.distributions.categorical.Categorical(logits=logit_old)
+ logp_new = dist_new.log_prob(action)
+ logp_old = dist_old.log_prob(action)
+ dist_new_entropy = dist_new.entropy()
+ if dist_new_entropy.shape != weight.shape:
+ dist_new_entropy = dist_new.entropy().mean(dim=1)
+ entropy_loss = (dist_new_entropy * weight).mean()
+ # policy_loss
+ ratio = torch.exp(logp_new - logp_old)
+ if ratio.shape != adv.shape:
+ ratio = ratio.mean(dim=1)
+ surr1 = ratio * adv
+ surr2 = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv
+ if dual_clip is not None:
+ clip1 = torch.min(surr1, surr2)
+ clip2 = torch.max(clip1, dual_clip * adv)
+ # only use dual_clip when adv < 0
+ policy_loss = -(torch.where(adv < 0, clip2, clip1) * weight).mean()
+ else:
+ policy_loss = (-torch.min(surr1, surr2) * weight).mean()
+ with torch.no_grad():
+ approx_kl = (logp_old - logp_new).mean().item()
+ clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
+ clipfrac = torch.as_tensor(clipped).float().mean().item()
+ return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac)
+
+
+def ppo_value_error(
+ data: namedtuple,
+ clip_ratio: float = 0.2,
+ use_value_clip: bool = True,
+) -> torch.Tensor:
+ '''
+ Overview:
+ Get PPO value loss
+ Arguments:
+ - data (:obj:`namedtuple`): ppo input data with fields shown in ``ppo_value_data``
+ - clip_ratio (:obj:`float`): clip value for ratio
+ - use_value_clip (:obj:`bool`): whether use value clip
+ Returns:
+ - value_loss (:obj:`torch.FloatTensor`): the ppo value loss item, \
+ all of them are the differentiable 0-dim tensor
+ Shapes:
+ - value_new (:obj:`torch.FloatTensor`): :math:`(B, )`, where B is batch size
+ - value_old (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - return (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
+ - value_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
+ Examples:
+ >>> action_dim = 4
+ >>> data = ppo_value_data(
+ >>> value_new=torch.randn(3),
+ >>> value_old=torch.randn(3),
+ >>> return_=torch.randn(3),
+ >>> weight=torch.ones(3),
+ >>> )
+ >>> loss = ppo_value_error(data)
+ '''
+ value_new, value_old, return_, weight = data
+ if weight is None:
+ weight = torch.ones_like(value_old)
+ # value_loss
+ if use_value_clip:
+ value_clip = value_old + (value_new - value_old).clamp(-clip_ratio, clip_ratio)
+ v1 = (return_ - value_new).pow(2)
+ v2 = (return_ - value_clip).pow(2)
+ value_loss = 0.5 * (torch.max(v1, v2) * weight).mean()
+ else:
+ value_loss = 0.5 * ((return_ - value_new).pow(2) * weight).mean()
+ return value_loss
+
+
+def ppo_error_continuous(
+ data: namedtuple,
+ clip_ratio: float = 0.2,
+ use_value_clip: bool = True,
+ dual_clip: Optional[float] = None
+) -> Tuple[namedtuple, namedtuple]:
+ """
+ Overview:
+ Implementation of Proximal Policy Optimization (arXiv:1707.06347) with value_clip and dual_clip
+ Arguments:
+ - data (:obj:`namedtuple`): the ppo input data with fields shown in ``ppo_data_continuous``
+ - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2
+ - use_value_clip (:obj:`bool`): whether to use clip in value loss with the same ratio as policy
+ - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf),\
+ defaults to 5.0, if you don't want to use it, set this parameter to None
+ Returns:
+ - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
+ - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
+ Shapes:
+ - mu_sigma_new (:obj:`tuple`): :math:`((B, N), (B, N))`, where B is batch size and N is action dim
+ - mu_sigma_old (:obj:`tuple`): :math:`((B, N), (B, N))`, where B is batch size and N is action dim
+ - action (:obj:`torch.LongTensor`): :math:`(B, )`
+ - value_new (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - value_old (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - return (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
+ - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
+ - value_loss (:obj:`torch.FloatTensor`): :math:`()`
+ - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
+ Examples:
+ >>> action_dim = 4
+ >>> data = ppo_data_continuous(
+ >>> mu_sigma_new= dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
+ >>> mu_sigma_old= dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
+ >>> action=torch.randn(3, action_dim),
+ >>> value_new=torch.randn(3),
+ >>> value_old=torch.randn(3),
+ >>> adv=torch.randn(3),
+ >>> return_=torch.randn(3),
+ >>> weight=torch.ones(3),
+ >>> )
+ >>> loss, info = ppo_error_continuous(data)
+
+ .. note::
+
+ adv is already normalized value (adv - adv.mean()) / (adv.std() + 1e-8), and there are many
+ ways to calculate this mean and std, like among data buffer or train batch, so we don't couple
+ this part into ppo_error, you can refer to our examples for different ways.
+ """
+ assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format(
+ dual_clip
+ )
+ mu_sigma_new, mu_sigma_old, action, value_new, value_old, adv, return_, weight = data
+ if weight is None:
+ weight = torch.ones_like(adv)
+
+ dist_new = Independent(Normal(mu_sigma_new['mu'], mu_sigma_new['sigma']), 1)
+ if len(mu_sigma_old['mu'].shape) == 1:
+ dist_old = Independent(Normal(mu_sigma_old['mu'].unsqueeze(-1), mu_sigma_old['sigma'].unsqueeze(-1)), 1)
+ else:
+ dist_old = Independent(Normal(mu_sigma_old['mu'], mu_sigma_old['sigma']), 1)
+ logp_new = dist_new.log_prob(action)
+ logp_old = dist_old.log_prob(action)
+ entropy_loss = (dist_new.entropy() * weight).mean()
+ # policy_loss
+ ratio = torch.exp(logp_new - logp_old)
+ surr1 = ratio * adv
+ surr2 = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv
+ if dual_clip is not None:
+ policy_loss = (-torch.max(torch.min(surr1, surr2), dual_clip * adv) * weight).mean()
+ else:
+ policy_loss = (-torch.min(surr1, surr2) * weight).mean()
+ with torch.no_grad():
+ approx_kl = (logp_old - logp_new).mean().item()
+ clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
+ clipfrac = torch.as_tensor(clipped).float().mean().item()
+ # value_loss
+ if use_value_clip:
+ value_clip = value_old + (value_new - value_old).clamp(-clip_ratio, clip_ratio)
+ v1 = (return_ - value_new).pow(2)
+ v2 = (return_ - value_clip).pow(2)
+ value_loss = 0.5 * (torch.max(v1, v2) * weight).mean()
+ else:
+ value_loss = 0.5 * ((return_ - value_new).pow(2) * weight).mean()
+
+ return ppo_loss(policy_loss, value_loss, entropy_loss), ppo_info(approx_kl, clipfrac)
+
+
+def ppo_policy_error_continuous(data: namedtuple,
+ clip_ratio: float = 0.2,
+ dual_clip: Optional[float] = None) -> Tuple[namedtuple, namedtuple]:
+ """
+ Overview:
+ Implementation of Proximal Policy Optimization (arXiv:1707.06347) with dual_clip
+ Arguments:
+ - data (:obj:`namedtuple`): the ppo input data with fields shown in ``ppo_policy_data_continuous``
+ - clip_ratio (:obj:`float`): the ppo clip ratio for the constraint of policy update, defaults to 0.2
+ - dual_clip (:obj:`float`): a parameter c mentioned in arXiv:1912.09729 Equ. 5, should be in [1, inf),\
+ defaults to 5.0, if you don't want to use it, set this parameter to None
+ Returns:
+ - ppo_loss (:obj:`namedtuple`): the ppo loss item, all of them are the differentiable 0-dim tensor
+ - ppo_info (:obj:`namedtuple`): the ppo optim information for monitoring, all of them are Python scalar
+ Shapes:
+ - mu_sigma_new (:obj:`tuple`): :math:`((B, N), (B, N))`, where B is batch size and N is action dim
+ - mu_sigma_old (:obj:`tuple`): :math:`((B, N), (B, N))`, where B is batch size and N is action dim
+ - action (:obj:`torch.LongTensor`): :math:`(B, )`
+ - adv (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(B, )`
+ - policy_loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
+ - entropy_loss (:obj:`torch.FloatTensor`): :math:`()`
+ Examples:
+ >>> action_dim = 4
+ >>> data = ppo_policy_data_continuous(
+ >>> mu_sigma_new=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
+ >>> mu_sigma_old=dict(mu=torch.randn(3, action_dim), sigma=torch.randn(3, action_dim)**2),
+ >>> action=torch.randn(3, action_dim),
+ >>> adv=torch.randn(3),
+ >>> weight=torch.ones(3),
+ >>> )
+ >>> loss, info = ppo_policy_error_continuous(data)
+ """
+ assert dual_clip is None or dual_clip > 1.0, "dual_clip value must be greater than 1.0, but get value: {}".format(
+ dual_clip
+ )
+ mu_sigma_new, mu_sigma_old, action, adv, weight = data
+ if weight is None:
+ weight = torch.ones_like(adv)
+
+ dist_new = Independent(Normal(mu_sigma_new['mu'], mu_sigma_new['sigma']), 1)
+ if len(mu_sigma_old['mu'].shape) == 1:
+ dist_old = Independent(Normal(mu_sigma_old['mu'].unsqueeze(-1), mu_sigma_old['sigma'].unsqueeze(-1)), 1)
+ else:
+ dist_old = Independent(Normal(mu_sigma_old['mu'], mu_sigma_old['sigma']), 1)
+ logp_new = dist_new.log_prob(action)
+ logp_old = dist_old.log_prob(action)
+ entropy_loss = (dist_new.entropy() * weight).mean()
+ # policy_loss
+ ratio = torch.exp(logp_new - logp_old)
+ surr1 = ratio * adv
+ surr2 = ratio.clamp(1 - clip_ratio, 1 + clip_ratio) * adv
+ if dual_clip is not None:
+ clip1 = torch.min(surr1, surr2)
+ clip2 = torch.max(clip1, dual_clip * adv)
+ # only apply the dual-clip lower bound when the advantage is negative (arXiv:1912.09729 Eq. 5)
+ policy_loss = (-torch.where(adv < 0, clip2, clip1) * weight).mean()
+ else:
+ policy_loss = (-torch.min(surr1, surr2) * weight).mean()
+ with torch.no_grad():
+ approx_kl = (logp_old - logp_new).mean().item()
+ clipped = ratio.gt(1 + clip_ratio) | ratio.lt(1 - clip_ratio)
+ clipfrac = torch.as_tensor(clipped).float().mean().item()
+ return ppo_policy_loss(policy_loss, entropy_loss), ppo_info(approx_kl, clipfrac)
diff --git a/DI-engine/ding/rl_utils/retrace.py b/DI-engine/ding/rl_utils/retrace.py
new file mode 100644
index 0000000000000000000000000000000000000000..01b1f3f0f337086c49bb858608a9c68308bf7d54
--- /dev/null
+++ b/DI-engine/ding/rl_utils/retrace.py
@@ -0,0 +1,56 @@
+import torch
+import torch.nn.functional as F
+from collections import namedtuple
+from ding.rl_utils.isw import compute_importance_weights
+
+
+def compute_q_retraces(
+ q_values: torch.Tensor,
+ v_pred: torch.Tensor,
+ rewards: torch.Tensor,
+ actions: torch.Tensor,
+ weights: torch.Tensor,
+ ratio: torch.Tensor,
+ gamma: float = 0.9
+) -> torch.Tensor:
+ """
+ Shapes:
+ - q_values (:obj:`torch.Tensor`): :math:`(T + 1, B, N)`, where T is unroll_len, B is batch size, N is discrete \
+ action dim.
+ - v_pred (:obj:`torch.Tensor`): :math:`(T + 1, B, 1)`
+ - rewards (:obj:`torch.Tensor`): :math:`(T, B)`
+ - actions (:obj:`torch.Tensor`): :math:`(T, B)`
+ - weights (:obj:`torch.Tensor`): :math:`(T, B)`
+ - ratio (:obj:`torch.Tensor`): :math:`(T, B, N)`
+ - q_retraces (:obj:`torch.Tensor`): :math:`(T + 1, B, 1)`
+ Examples:
+ >>> T=2
+ >>> B=3
+ >>> N=4
+ >>> q_values=torch.randn(T+1, B, N)
+ >>> v_pred=torch.randn(T+1, B, 1)
+ >>> rewards=torch.randn(T, B)
+ >>> actions=torch.randint(0, N, (T, B))
+ >>> weights=torch.ones(T, B)
+ >>> ratio=torch.randn(T, B, N)
+ >>> q_retraces = compute_q_retraces(q_values, v_pred, rewards, actions, weights, ratio)
+
+ .. note::
+ The q_retrace operation does not require gradient computation; it only performs the forward pass.
+ """
+ T = q_values.size()[0] - 1
+ rewards = rewards.unsqueeze(-1)
+ actions = actions.unsqueeze(-1)
+ weights = weights.unsqueeze(-1)
+ q_retraces = torch.zeros_like(v_pred) # shape (T+1),B,1
+ tmp_retraces = v_pred[-1] # shape B,1
+ q_retraces[-1] = v_pred[-1]
+
+ q_gather = torch.zeros_like(v_pred)
+ q_gather[0:-1] = q_values[0:-1].gather(-1, actions) # shape (T+1),B,1
+ ratio_gather = ratio.gather(-1, actions) # shape T,B,1
+
+ for idx in reversed(range(T)):
+ q_retraces[idx] = rewards[idx] + gamma * weights[idx] * tmp_retraces
+ tmp_retraces = ratio_gather[idx].clamp(max=1.0) * (q_retraces[idx] - q_gather[idx]) + v_pred[idx]
+ return q_retraces # shape (T+1),B,1
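+
+
+if __name__ == '__main__':
+ # Minimal sanity-check sketch (an illustrative addition, not part of the library API): with T=1
+ # the backward recursion bootstraps directly from v_pred[-1], so the retrace target at step 0
+ # must equal r_0 + gamma * weight_0 * v_pred[1], whatever the truncated ratio is.
+ T, B, N = 1, 2, 3
+ q_values, v_pred = torch.randn(T + 1, B, N), torch.randn(T + 1, B, 1)
+ rewards, actions = torch.randn(T, B), torch.randint(0, N, (T, B))
+ weights, ratio = torch.ones(T, B), torch.rand(T, B, N)
+ out = compute_q_retraces(q_values, v_pred, rewards, actions, weights, ratio, gamma=0.9)
+ expected = rewards[0].unsqueeze(-1) + 0.9 * weights[0].unsqueeze(-1) * v_pred[1]
+ assert torch.allclose(out[0], expected)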
diff --git a/DI-engine/ding/rl_utils/sampler.py b/DI-engine/ding/rl_utils/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..8afbd605bbc90b9d611c2402d9d2399cadca4dc9
--- /dev/null
+++ b/DI-engine/ding/rl_utils/sampler.py
@@ -0,0 +1,127 @@
+import torch
+import treetensor.torch as ttorch
+from torch.distributions import Normal, Independent
+
+
+class ArgmaxSampler:
+ '''
+ Overview:
+ Argmax sampler, return the index of the maximum value
+ '''
+
+ def __call__(self, logit: torch.Tensor) -> torch.Tensor:
+ '''
+ Overview:
+ Return the index of the maximum value
+ Arguments:
+ - logit (:obj:`torch.Tensor`): The input tensor
+ Returns:
+ - action (:obj:`torch.Tensor`): The index of the maximum value
+ '''
+ return logit.argmax(dim=-1)
+
+
+class MultinomialSampler:
+ '''
+ Overview:
+ Multinomial sampler, return the index of the sampled value
+ '''
+
+ def __call__(self, logit: torch.Tensor) -> torch.Tensor:
+ '''
+ Overview:
+ Return the index of the sampled value
+ Arguments:
+ - logit (:obj:`torch.Tensor`): The input tensor
+ Returns:
+ - action (:obj:`torch.Tensor`): The index of the sampled value
+ '''
+ dist = torch.distributions.Categorical(logits=logit)
+ return dist.sample()
+
+
+class MuSampler:
+ '''
+ Overview:
+ Mu sampler, return the mu of the input tensor
+ '''
+
+ def __call__(self, logit: ttorch.Tensor) -> torch.Tensor:
+ '''
+ Overview:
+ Return the mu of the input tensor
+ Arguments:
+ - logit (:obj:`ttorch.Tensor`): The input tensor
+ Returns:
+ - action (:obj:`torch.Tensor`): The mu of the input tensor
+ '''
+ return logit.mu
+
+
+class ReparameterizationSampler:
+ '''
+ Overview:
+ Reparameterization sampler, return the reparameterized value of the input tensor
+ '''
+
+ def __call__(self, logit: ttorch.Tensor) -> torch.Tensor:
+ '''
+ Overview:
+ Return the reparameterized value of the input tensor
+ Arguments:
+ - logit (:obj:`ttorch.Tensor`): The input tensor
+ Returns:
+ - action (:obj:`torch.Tensor`): The reparameterized value of the input tensor
+ '''
+ dist = Normal(logit.mu, logit.sigma)
+ dist = Independent(dist, 1)
+ return dist.rsample()
+
+
+class HybridStochasticSampler:
+ '''
+ Overview:
+ Hybrid stochastic sampler, return the sampled action type and the reparameterized action args
+ '''
+
+ def __call__(self, logit: ttorch.Tensor) -> ttorch.Tensor:
+ '''
+ Overview:
+ Return the sampled action type and the reparameterized action args
+ Arguments:
+ - logit (:obj:`ttorch.Tensor`): The input tensor
+ Returns:
+ - action (:obj:`ttorch.Tensor`): The sampled action type and the reparameterized action args
+ '''
+ dist = torch.distributions.Categorical(logits=logit.action_type)
+ action_type = dist.sample()
+ dist = Normal(logit.action_args.mu, logit.action_args.sigma)
+ dist = Independent(dist, 1)
+ action_args = dist.rsample()
+ return ttorch.as_tensor({
+ 'action_type': action_type,
+ 'action_args': action_args,
+ })
+
+
+class HybridDeterminsticSampler:
+ '''
+ Overview:
+ Hybrid deterministic sampler, return the argmax action type and the mu action args
+ '''
+
+ def __call__(self, logit: ttorch.Tensor) -> ttorch.Tensor:
+ '''
+ Overview:
+ Return the argmax action type and the mu action args
+ Arguments:
+ - logit (:obj:`ttorch.Tensor`): The input tensor
+ Returns:
+ - action (:obj:`ttorch.Tensor`): The argmax action type and the mu action args
+ '''
+ action_type = logit.action_type.argmax(dim=-1)
+ action_args = logit.action_args.mu
+ return ttorch.as_tensor({
+ 'action_type': action_type,
+ 'action_args': action_args,
+ })
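+
+
+if __name__ == '__main__':
+ # Illustrative usage sketch (added for clarity; the shapes below are assumptions of this demo only):
+ # discrete logits go through ArgmaxSampler / MultinomialSampler, while continuous policy outputs
+ # carrying (mu, sigma) go through MuSampler / ReparameterizationSampler.
+ logit = torch.randn(4, 6)
+ greedy_action = ArgmaxSampler()(logit) # shape (4, ), long
+ random_action = MultinomialSampler()(logit) # shape (4, ), long
+ cont_logit = ttorch.as_tensor({'mu': torch.zeros(4, 3), 'sigma': torch.ones(4, 3)})
+ det_action = MuSampler()(cont_logit) # shape (4, 3), equals mu
+ sto_action = ReparameterizationSampler()(cont_logit) # shape (4, 3), reparameterized sample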
diff --git a/DI-engine/ding/rl_utils/td.py b/DI-engine/ding/rl_utils/td.py
new file mode 100644
index 0000000000000000000000000000000000000000..1622d2c289e55fbe622d656bf4c563cf1952fa29
--- /dev/null
+++ b/DI-engine/ding/rl_utils/td.py
@@ -0,0 +1,1646 @@
+import copy
+import numpy as np
+from collections import namedtuple
+from typing import Union, Optional, Callable
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ding.hpc_rl import hpc_wrapper
+from ding.rl_utils.value_rescale import value_transform, value_inv_transform
+from ding.torch_utils import to_tensor
+
+q_1step_td_data = namedtuple('q_1step_td_data', ['q', 'next_q', 'act', 'next_act', 'reward', 'done', 'weight'])
+
+
+def discount_cumsum(x, gamma: float = 1.0) -> np.ndarray:
+ assert abs(gamma - 1.) < 1e-5, "gamma should equal 1.0, as in the original decision transformer paper"
+ disc_cumsum = np.zeros_like(x)
+ disc_cumsum[-1] = x[-1]
+ for t in reversed(range(x.shape[0] - 1)):
+ disc_cumsum[t] = x[t] + gamma * disc_cumsum[t + 1]
+ return disc_cumsum
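+# For example (illustrative note): with gamma=1.0, discount_cumsum(np.array([1., 2., 3.])) returns
+# the suffix sums [6., 5., 3.], i.e. the return-to-go targets used in Decision Transformer training.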
+
+
+def q_1step_td_error(
+ data: namedtuple,
+ gamma: float,
+ criterion: torch.nn.modules = nn.MSELoss(reduction='none') # noqa
+) -> torch.Tensor:
+ """
+ Overview:
+ 1-step td_error, supporting both the single-agent and multi-agent cases.
+ Arguments:
+ - data (:obj:`q_1step_td_data`): The input data, q_1step_td_data to calculate loss
+ - gamma (:obj:`float`): Discount factor
+ - criterion (:obj:`torch.nn.modules`): Loss function criterion
+ Returns:
+ - loss (:obj:`torch.Tensor`): 1step td error
+ Shapes:
+ - data (:obj:`q_1step_td_data`): the q_1step_td_data containing\
+ ['q', 'next_q', 'act', 'next_act', 'reward', 'done', 'weight']
+ - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
+ - next_q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
+ - act (:obj:`torch.LongTensor`): :math:`(B, )`
+ - next_act (:obj:`torch.LongTensor`): :math:`(B, )`
+ - reward (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
+ - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
+ Examples:
+ >>> action_dim = 4
+ >>> data = q_1step_td_data(
+ >>> q=torch.randn(3, action_dim),
+ >>> next_q=torch.randn(3, action_dim),
+ >>> act=torch.randint(0, action_dim, (3,)),
+ >>> next_act=torch.randint(0, action_dim, (3,)),
+ >>> reward=torch.randn(3),
+ >>> done=torch.randint(0, 2, (3,)).bool(),
+ >>> weight=torch.ones(3),
+ >>> )
+ >>> loss = q_1step_td_error(data, 0.99)
+ """
+ q, next_q, act, next_act, reward, done, weight = data
+ assert len(act.shape) == 1, act.shape
+ assert len(reward.shape) == 1, reward.shape
+ batch_range = torch.arange(act.shape[0])
+ if weight is None:
+ weight = torch.ones_like(reward)
+ q_s_a = q[batch_range, act]
+ target_q_s_a = next_q[batch_range, next_act]
+ target_q_s_a = gamma * (1 - done) * target_q_s_a + reward
+ return (criterion(q_s_a, target_q_s_a.detach()) * weight).mean()
+
+
+m_q_1step_td_data = namedtuple('m_q_1step_td_data', ['q', 'target_q', 'next_q', 'act', 'reward', 'done', 'weight'])
+
+
+def m_q_1step_td_error(
+ data: namedtuple,
+ gamma: float,
+ tau: float,
+ alpha: float,
+ criterion: torch.nn.modules = nn.MSELoss(reduction='none') # noqa
+) -> torch.Tensor:
+ """
+ Overview:
+ Munchausen td_error for the DQN algorithm, supporting 1-step td error.
+ Arguments:
+ - data (:obj:`m_q_1step_td_data`): The input data, m_q_1step_td_data to calculate loss
+ - gamma (:obj:`float`): Discount factor
+ - tau (:obj:`float`): Entropy factor for Munchausen DQN
+ - alpha (:obj:`float`): Weight of the Munchausen (scaled log-policy) term
+ - criterion (:obj:`torch.nn.modules`): Loss function criterion
+ Returns:
+ - loss (:obj:`torch.Tensor`): 1step td error, 0-dim tensor
+ Shapes:
+ - data (:obj:`m_q_1step_td_data`): the m_q_1step_td_data containing\
+ ['q', 'target_q', 'next_q', 'act', 'reward', 'done', 'weight']
+ - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
+ - target_q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
+ - next_q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
+ - act (:obj:`torch.LongTensor`): :math:`(B, )`
+ - reward (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
+ - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
+ Examples:
+ >>> action_dim = 4
+ >>> data = m_q_1step_td_data(
+ >>> q=torch.randn(3, action_dim),
+ >>> target_q=torch.randn(3, action_dim),
+ >>> next_q=torch.randn(3, action_dim),
+ >>> act=torch.randint(0, action_dim, (3,)),
+ >>> reward=torch.randn(3),
+ >>> done=torch.randint(0, 2, (3,)),
+ >>> weight=torch.ones(3),
+ >>> )
+ >>> loss = m_q_1step_td_error(data, 0.99, 0.01, 0.01)
+ """
+ q, target_q, next_q, act, reward, done, weight = data
+ lower_bound = -1
+ assert len(act.shape) == 1, act.shape
+ assert len(reward.shape) == 1, reward.shape
+ batch_range = torch.arange(act.shape[0])
+ if weight is None:
+ weight = torch.ones_like(reward)
+ q_s_a = q[batch_range, act]
+ # calculate the munchausen addon
+ # replay_log_policy
+ target_v_s = target_q[batch_range].max(1)[0].unsqueeze(-1)
+
+ logsum = torch.logsumexp((target_q - target_v_s) / tau, 1).unsqueeze(-1)
+ log_pi = target_q - target_v_s - tau * logsum
+ act_get = act.unsqueeze(-1)
+ # this equals tau * log_pi for the taken action (i.e. tau_log_pi_a)
+ munchausen_addon = log_pi.gather(1, act_get)
+
+ muchausen_term = alpha * torch.clamp(munchausen_addon, min=lower_bound, max=1)
+
+ # replay_next_log_policy
+ target_v_s_next = next_q[batch_range].max(1)[0].unsqueeze(-1)
+ logsum_next = torch.logsumexp((next_q - target_v_s_next) / tau, 1).unsqueeze(-1)
+ tau_log_pi_next = next_q - target_v_s_next - tau * logsum_next
+ # do stable softmax == replay_next_policy
+ pi_target = F.softmax((next_q - target_v_s_next) / tau, dim=1)
+ target_q_s_a = (gamma * (pi_target * (next_q - tau_log_pi_next) * (1 - done.unsqueeze(-1))).sum(1)).unsqueeze(-1)
+
+ target_q_s_a = reward.unsqueeze(-1) + muchausen_term + target_q_s_a
+ td_error_per_sample = criterion(q_s_a.unsqueeze(-1), target_q_s_a.detach()).squeeze(-1)
+
+ # calculate action_gap and clipfrac
+ with torch.no_grad():
+ top2_q_s = target_q[batch_range].topk(2, dim=1, largest=True, sorted=True)[0]
+ action_gap = (top2_q_s[:, 0] - top2_q_s[:, 1]).mean()
+
+ clipped = munchausen_addon.gt(1) | munchausen_addon.lt(lower_bound)
+ clipfrac = torch.as_tensor(clipped).float()
+
+ return (td_error_per_sample * weight).mean(), td_error_per_sample, action_gap, clipfrac
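+# Note (illustrative summary added for clarity, not an API change): the regression target built
+# above follows the Munchausen DQN construction
+#   y = r + alpha * clip(tau * log pi(a|s), lower_bound, 1)
+#         + gamma * (1 - done) * sum_{a'} pi(a'|s') * (q'(s', a') - tau * log pi(a'|s')),
+# where pi is the softmax of the target network at temperature tau; note that log_pi above already
+# contains the factor tau, since it is computed as q - v - tau * logsumexp((q - v) / tau).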
+
+
+q_v_1step_td_data = namedtuple('q_v_1step_td_data', ['q', 'v', 'act', 'reward', 'done', 'weight'])
+
+
+def q_v_1step_td_error(
+ data: namedtuple, gamma: float, criterion: torch.nn.modules = nn.MSELoss(reduction='none')
+) -> torch.Tensor:
+ # we will use this function in discrete sac algorithm to calculate td error between q and v value.
+ """
+ Overview:
+ 1-step td_error between the q and v values for the (discrete) SAC algorithm.
+ Arguments:
+ - data (:obj:`q_v_1step_td_data`): The input data, q_v_1step_td_data to calculate loss
+ - gamma (:obj:`float`): Discount factor
+ - criterion (:obj:`torch.nn.modules`): Loss function criterion
+ Returns:
+ - loss (:obj:`torch.Tensor`): 1step td error, 0-dim tensor
+ Shapes:
+ - data (:obj:`q_v_1step_td_data`): the q_v_1step_td_data containing\
+ ['q', 'v', 'act', 'reward', 'done', 'weight']
+ - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
+ - v (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - act (:obj:`torch.LongTensor`): :math:`(B, )`
+ - reward (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
+ - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
+ Examples:
+ >>> action_dim = 4
+ >>> data = q_v_1step_td_data(
+ >>> q=torch.randn(3, action_dim),
+ >>> v=torch.randn(3),
+ >>> act=torch.randint(0, action_dim, (3,)),
+ >>> reward=torch.randn(3),
+ >>> done=torch.randint(0, 2, (3,)),
+ >>> weight=torch.ones(3),
+ >>> )
+ >>> loss = q_v_1step_td_error(data, 0.99)
+ """
+ q, v, act, reward, done, weight = data
+ if len(act.shape) == 1:
+ assert len(reward.shape) == 1, reward.shape
+ batch_range = torch.arange(act.shape[0])
+ if weight is None:
+ weight = torch.ones_like(reward)
+ q_s_a = q[batch_range, act]
+ target_q_s_a = gamma * (1 - done) * v + reward
+ else:
+ assert len(reward.shape) == 1, reward.shape
+ batch_range = torch.arange(act.shape[0])
+ actor_range = torch.arange(act.shape[1])
+ batch_actor_range = torch.arange(act.shape[0] * act.shape[1])
+ if weight is None:
+ weight = torch.ones_like(act)
+ temp_q = q.reshape(act.shape[0] * act.shape[1], -1)
+ temp_act = act.reshape(act.shape[0] * act.shape[1])
+ q_s_a = temp_q[batch_actor_range, temp_act]
+ q_s_a = q_s_a.reshape(act.shape[0], act.shape[1])
+ target_q_s_a = gamma * (1 - done).unsqueeze(1) * v + reward.unsqueeze(1)
+ td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())
+ return (td_error_per_sample * weight).mean(), td_error_per_sample
+
+
+def view_similar(x: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
+ size = list(x.shape) + [1 for _ in range(len(target.shape) - len(x.shape))]
+ return x.view(*size)
+
+
+nstep_return_data = namedtuple('nstep_return_data', ['reward', 'next_value', 'done'])
+
+
+def nstep_return(data: namedtuple, gamma: Union[float, list], nstep: int, value_gamma: Optional[torch.Tensor] = None):
+ '''
+ Overview:
+ Calculate the n-step return for DQN-style algorithms, supporting both the single-agent and multi-agent cases.
+ Arguments:
+ - data (:obj:`nstep_return_data`): The input data, nstep_return_data to calculate loss
+ - gamma (:obj:`float`): Discount factor
+ - nstep (:obj:`int`): nstep num
+ - value_gamma (:obj:`torch.Tensor`): Discount factor for value
+ Returns:
+ - return (:obj:`torch.Tensor`): nstep return
+ Shapes:
+ - data (:obj:`nstep_return_data`): the nstep_return_data containing\
+ ['reward', 'next_value', 'done']
+ - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
+ - next_value (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
+ Examples:
+ >>> data = nstep_return_data(
+ >>> reward=torch.randn(3, 3),
+ >>> next_value=torch.randn(3),
+ >>> done=torch.randint(0, 2, (3,)),
+ >>> )
+ >>> loss = nstep_return(data, 0.99, 3)
+ '''
+
+ reward, next_value, done = data
+ assert reward.shape[0] == nstep
+ device = reward.device
+
+ if isinstance(gamma, float):
+ reward_factor = torch.ones(nstep).to(device)
+ for i in range(1, nstep):
+ reward_factor[i] = gamma * reward_factor[i - 1]
+ reward_factor = view_similar(reward_factor, reward)
+ return_tmp = reward.mul(reward_factor).sum(0)
+ if value_gamma is None:
+ return_ = return_tmp + (gamma ** nstep) * next_value * (1 - done)
+ else:
+ return_ = return_tmp + value_gamma * next_value * (1 - done)
+
+ elif isinstance(gamma, list):
+ # if gamma is list, for NGU policy case
+ reward_factor = torch.ones([nstep + 1, done.shape[0]]).to(device)
+ for i in range(1, nstep + 1):
+ reward_factor[i] = torch.stack(gamma, dim=0).to(device) * reward_factor[i - 1]
+ reward_factor = view_similar(reward_factor, reward)
+ return_tmp = reward.mul(reward_factor[:nstep]).sum(0)
+ return_ = return_tmp + reward_factor[nstep] * next_value * (1 - done)
+ else:
+ raise TypeError("The type of gamma should be float or list")
+
+ return return_
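+# Worked mini-example (illustrative): with nstep=2, gamma=0.9, reward=[[1.], [2.]], next_value=[10.]
+# and done=[0.], the scalar-gamma branch gives 1 + 0.9 * 2 + 0.9 ** 2 * 10 = 10.9.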
+
+
+dist_1step_td_data = namedtuple(
+ 'dist_1step_td_data', ['dist', 'next_dist', 'act', 'next_act', 'reward', 'done', 'weight']
+)
+
+
+def dist_1step_td_error(
+ data: namedtuple,
+ gamma: float,
+ v_min: float,
+ v_max: float,
+ n_atom: int,
+) -> torch.Tensor:
+ """
+ Overview:
+ 1-step td_error for distributional q-learning based algorithms (e.g. C51)
+ Arguments:
+ - data (:obj:`dist_1step_td_data`): The input data, dist_1step_td_data to calculate loss
+ - gamma (:obj:`float`): Discount factor
+ - v_min (:obj:`float`): The min value of support
+ - v_max (:obj:`float`): The max value of support
+ - n_atom (:obj:`int`): The num of atom
+ Returns:
+ - loss (:obj:`torch.Tensor`): 1-step td error, 0-dim tensor
+ Shapes:
+ - data (:obj:`dist_1step_td_data`): the dist_1step_td_data containing\
+ ['dist', 'next_dist', 'act', 'next_act', 'reward', 'done', 'weight']
+ - dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)` i.e. [batch_size, action_dim, n_atom]
+ - next_dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)`
+ - act (:obj:`torch.LongTensor`): :math:`(B, )`
+ - next_act (:obj:`torch.LongTensor`): :math:`(B, )`
+ - reward (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
+ - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
+ Examples:
+ >>> dist = torch.randn(4, 3, 51).abs().requires_grad_(True)
+ >>> next_dist = torch.randn(4, 3, 51).abs()
+ >>> act = torch.randint(0, 3, (4,))
+ >>> next_act = torch.randint(0, 3, (4,))
+ >>> reward = torch.randn(4)
+ >>> done = torch.randint(0, 2, (4,))
+ >>> data = dist_1step_td_data(dist, next_dist, act, next_act, reward, done, None)
+ >>> loss = dist_1step_td_error(data, 0.99, -10.0, 10.0, 51)
+ """
+ dist, next_dist, act, next_act, reward, done, weight = data
+ device = reward.device
+ assert len(reward.shape) == 1, reward.shape
+ support = torch.linspace(v_min, v_max, n_atom).to(device)
+ delta_z = (v_max - v_min) / (n_atom - 1)
+
+ if len(act.shape) == 1:
+ reward = reward.unsqueeze(-1)
+ done = done.unsqueeze(-1)
+ batch_size = act.shape[0]
+ batch_range = torch.arange(batch_size)
+ if weight is None:
+ weight = torch.ones_like(reward)
+ next_dist = next_dist[batch_range, next_act].detach()
+ else:
+ reward = reward.unsqueeze(-1).repeat(1, act.shape[1])
+ done = done.unsqueeze(-1).repeat(1, act.shape[1])
+
+ batch_size = act.shape[0] * act.shape[1]
+ batch_range = torch.arange(act.shape[0] * act.shape[1])
+ action_dim = dist.shape[2]
+ dist = dist.reshape(act.shape[0] * act.shape[1], action_dim, -1)
+ reward = reward.reshape(act.shape[0] * act.shape[1], -1)
+ done = done.reshape(act.shape[0] * act.shape[1], -1)
+ next_dist = next_dist.reshape(act.shape[0] * act.shape[1], action_dim, -1)
+
+ next_act = next_act.reshape(act.shape[0] * act.shape[1])
+ next_dist = next_dist[batch_range, next_act].detach()
+ next_dist = next_dist.reshape(act.shape[0] * act.shape[1], -1)
+ act = act.reshape(act.shape[0] * act.shape[1])
+ if weight is None:
+ weight = torch.ones_like(reward)
+ target_z = reward + (1 - done) * gamma * support
+ target_z = target_z.clamp(min=v_min, max=v_max)
+ b = (target_z - v_min) / delta_z
+ l = b.floor().long()
+ u = b.ceil().long()
+ # Fix disappearing probability mass when l = b = u (b is int)
+ l[(u > 0) * (l == u)] -= 1
+ u[(l < (n_atom - 1)) * (l == u)] += 1
+
+ proj_dist = torch.zeros_like(next_dist)
+ offset = torch.linspace(0, (batch_size - 1) * n_atom, batch_size).unsqueeze(1).expand(batch_size,
+ n_atom).long().to(device)
+ proj_dist.view(-1).index_add_(0, (l + offset).view(-1), (next_dist * (u.float() - b)).view(-1))
+ proj_dist.view(-1).index_add_(0, (u + offset).view(-1), (next_dist * (b - l.float())).view(-1))
+
+ log_p = torch.log(dist[batch_range, act])
+
+ loss = -(log_p * proj_dist * weight).sum(-1).mean()
+
+ return loss
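+# Note on the projection above (descriptive comment added for clarity): each target atom
+# reward + gamma * (1 - done) * z_j is clamped to [v_min, v_max] and its probability mass is split
+# between the two neighbouring support atoms l = floor(b) and u = ceil(b) in proportion to their
+# distance from b, which is the standard categorical (C51-style) projection step.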
+
+
+dist_nstep_td_data = namedtuple(
+ 'dist_nstep_td_data', ['dist', 'next_n_dist', 'act', 'next_n_act', 'reward', 'done', 'weight']
+)
+
+
+def shape_fn_dntd(args, kwargs):
+ r"""
+ Overview:
+ Return dntd shape for hpc
+ Returns:
+ shape: [T, B, N, n_atom]
+ """
+ if len(args) <= 0:
+ tmp = [kwargs['data'].reward.shape[0]]
+ tmp.extend(list(kwargs['data'].dist.shape))
+ else:
+ tmp = [args[0].reward.shape[0]]
+ tmp.extend(list(args[0].dist.shape))
+ return tmp
+
+
+@hpc_wrapper(
+ shape_fn=shape_fn_dntd,
+ namedtuple_data=True,
+ include_args=[0, 1, 2, 3],
+ include_kwargs=['data', 'gamma', 'v_min', 'v_max']
+)
+def dist_nstep_td_error(
+ data: namedtuple,
+ gamma: float,
+ v_min: float,
+ v_max: float,
+ n_atom: int,
+ nstep: int = 1,
+ value_gamma: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ """
+ Overview:
+ Multistep (1-step or n-step) td_error for distributional q-learning based algorithms (e.g. C51), supporting\
+ both the single-agent and multi-agent cases.
+ Arguments:
+ - data (:obj:`dist_nstep_td_data`): The input data, dist_nstep_td_data to calculate loss
+ - gamma (:obj:`float`): Discount factor
+ - nstep (:obj:`int`): nstep num, default set to 1
+ Returns:
+ - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
+ Shapes:
+ - data (:obj:`dist_nstep_td_data`): the dist_nstep_td_data containing\
+ ['dist', 'next_n_dist', 'act', 'reward', 'done', 'weight']
+ - dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)` i.e. [batch_size, action_dim, n_atom]
+ - next_n_dist (:obj:`torch.FloatTensor`): :math:`(B, N, n_atom)`
+ - act (:obj:`torch.LongTensor`): :math:`(B, )`
+ - next_n_act (:obj:`torch.LongTensor`): :math:`(B, )`
+ - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
+ - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
+ Examples:
+ >>> dist = torch.randn(4, 3, 51).abs().requires_grad_(True)
+ >>> next_n_dist = torch.randn(4, 3, 51).abs()
+ >>> done = torch.randn(4)
+ >>> action = torch.randint(0, 3, size=(4, ))
+ >>> next_action = torch.randint(0, 3, size=(4, ))
+ >>> reward = torch.randn(5, 4)
+ >>> data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, None)
+ >>> loss, _ = dist_nstep_td_error(data, 0.95, -10.0, 10.0, 51, 5)
+ """
+ dist, next_n_dist, act, next_n_act, reward, done, weight = data
+ device = reward.device
+ reward_factor = torch.ones(nstep).to(device)
+ for i in range(1, nstep):
+ reward_factor[i] = gamma * reward_factor[i - 1]
+ reward = torch.matmul(reward_factor, reward)
+ support = torch.linspace(v_min, v_max, n_atom).to(device)
+ delta_z = (v_max - v_min) / (n_atom - 1)
+ if len(act.shape) == 1:
+ reward = reward.unsqueeze(-1)
+ done = done.unsqueeze(-1)
+ batch_size = act.shape[0]
+ batch_range = torch.arange(batch_size)
+ if weight is None:
+ weight = torch.ones_like(reward)
+ elif isinstance(weight, float):
+ weight = torch.tensor(weight)
+
+ next_n_dist = next_n_dist[batch_range, next_n_act].detach()
+ else:
+ reward = reward.unsqueeze(-1).repeat(1, act.shape[1])
+ done = done.unsqueeze(-1).repeat(1, act.shape[1])
+
+ batch_size = act.shape[0] * act.shape[1]
+ batch_range = torch.arange(act.shape[0] * act.shape[1])
+ action_dim = dist.shape[2]
+ dist = dist.reshape(act.shape[0] * act.shape[1], action_dim, -1)
+ reward = reward.reshape(act.shape[0] * act.shape[1], -1)
+ done = done.reshape(act.shape[0] * act.shape[1], -1)
+ next_n_dist = next_n_dist.reshape(act.shape[0] * act.shape[1], action_dim, -1)
+
+ next_n_act = next_n_act.reshape(act.shape[0] * act.shape[1])
+ next_n_dist = next_n_dist[batch_range, next_n_act].detach()
+ next_n_dist = next_n_dist.reshape(act.shape[0] * act.shape[1], -1)
+ act = act.reshape(act.shape[0] * act.shape[1])
+ if weight is None:
+ weight = torch.ones_like(reward)
+ elif isinstance(weight, float):
+ weight = torch.tensor(weight)
+
+ if value_gamma is None:
+ target_z = reward + (1 - done) * (gamma ** nstep) * support
+ elif isinstance(value_gamma, float):
+ value_gamma = torch.tensor(value_gamma).unsqueeze(-1)
+ target_z = reward + (1 - done) * value_gamma * support
+ else:
+ value_gamma = value_gamma.unsqueeze(-1)
+ target_z = reward + (1 - done) * value_gamma * support
+ target_z = target_z.clamp(min=v_min, max=v_max)
+ b = (target_z - v_min) / delta_z
+ l = b.floor().long()
+ u = b.ceil().long()
+ # Fix disappearing probability mass when l = b = u (b is int)
+ l[(u > 0) * (l == u)] -= 1
+ u[(l < (n_atom - 1)) * (l == u)] += 1
+
+ proj_dist = torch.zeros_like(next_n_dist)
+ offset = torch.linspace(0, (batch_size - 1) * n_atom, batch_size).unsqueeze(1).expand(batch_size,
+ n_atom).long().to(device)
+ proj_dist.view(-1).index_add_(0, (l + offset).view(-1), (next_n_dist * (u.float() - b)).view(-1))
+ proj_dist.view(-1).index_add_(0, (u + offset).view(-1), (next_n_dist * (b - l.float())).view(-1))
+
+ assert (dist[batch_range, act] > 0.0).all(), ("dist act", dist[batch_range, act], "dist:", dist)
+ log_p = torch.log(dist[batch_range, act])
+
+ if len(weight.shape) == 1:
+ weight = weight.unsqueeze(-1)
+
+ td_error_per_sample = -(log_p * proj_dist).sum(-1)
+
+ loss = -(log_p * proj_dist * weight).sum(-1).mean()
+
+ return loss, td_error_per_sample
+
+
+v_1step_td_data = namedtuple('v_1step_td_data', ['v', 'next_v', 'reward', 'done', 'weight'])
+
+
+def v_1step_td_error(
+ data: namedtuple,
+ gamma: float,
+ criterion: torch.nn.modules = nn.MSELoss(reduction='none') # noqa
+) -> torch.Tensor:
+ '''
+ Overview:
+ 1-step td_error for value-based algorithms
+ Arguments:
+ - data (:obj:`v_1step_td_data`): The input data, v_1step_td_data to calculate loss
+ - gamma (:obj:`float`): Discount factor
+ - criterion (:obj:`torch.nn.modules`): Loss function criterion
+ Returns:
+ - loss (:obj:`torch.Tensor`): 1step td error, 0-dim tensor
+ Shapes:
+ - data (:obj:`v_1step_td_data`): the v_1step_td_data containing\
+ ['v', 'next_v', 'reward', 'done', 'weight']
+ - v (:obj:`torch.FloatTensor`): :math:`(B, )` i.e. [batch_size, ]
+ - next_v (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - reward (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
+ - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
+ Examples:
+ >>> v = torch.randn(5).requires_grad_(True)
+ >>> next_v = torch.randn(5)
+ >>> reward = torch.rand(5)
+ >>> done = torch.zeros(5)
+ >>> data = v_1step_td_data(v, next_v, reward, done, None)
+ >>> loss, td_error_per_sample = v_1step_td_error(data, 0.99)
+ '''
+ v, next_v, reward, done, weight = data
+ if weight is None:
+ weight = torch.ones_like(v)
+ if len(v.shape) == len(reward.shape):
+ if done is not None:
+ target_v = gamma * (1 - done) * next_v + reward
+ else:
+ target_v = gamma * next_v + reward
+ else:
+ if done is not None:
+ target_v = gamma * (1 - done).unsqueeze(1) * next_v + reward.unsqueeze(1)
+ else:
+ target_v = gamma * next_v + reward.unsqueeze(1)
+ td_error_per_sample = criterion(v, target_v.detach())
+ return (td_error_per_sample * weight).mean(), td_error_per_sample
+
+
+v_nstep_td_data = namedtuple('v_nstep_td_data', ['v', 'next_n_v', 'reward', 'done', 'weight', 'value_gamma'])
+
+
+def v_nstep_td_error(
+ data: namedtuple,
+ gamma: float,
+ nstep: int = 1,
+ criterion: torch.nn.modules = nn.MSELoss(reduction='none') # noqa
+) -> torch.Tensor:
+ r"""
+ Overview:
+ Multistep (n-step) td_error for value-based algorithms
+ Arguments:
+ - data (:obj:`dist_nstep_td_data`): The input data, v_nstep_td_data to calculate loss
+ - gamma (:obj:`float`): Discount factor
+ - nstep (:obj:`int`): nstep num, default set to 1
+ Returns:
+ - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
+ Shapes:
+ - data (:obj:`dist_nstep_td_data`): The v_nstep_td_data containing\
+ ['v', 'next_n_v', 'reward', 'done', 'weight', 'value_gamma']
+ - v (:obj:`torch.FloatTensor`): :math:`(B, )` i.e. [batch_size, ]
+ - next_v (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
+ - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
+ - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
+ - value_gamma (:obj:`torch.Tensor`): If the remaining data in the buffer is less than n_step\
+ we use value_gamma as the gamma discount value for next_v rather than gamma**n_step
+ Examples:
+ >>> v = torch.randn(5).requires_grad_(True)
+ >>> next_v = torch.randn(5)
+ >>> reward = torch.rand(5, 5)
+ >>> done = torch.zeros(5)
+ >>> data = v_nstep_td_data(v, next_v, reward, done, 0.9, 0.99)
+ >>> loss, td_error_per_sample = v_nstep_td_error(data, 0.99, 5)
+ """
+ v, next_n_v, reward, done, weight, value_gamma = data
+ if weight is None:
+ weight = torch.ones_like(v)
+ target_v = nstep_return(nstep_return_data(reward, next_n_v, done), gamma, nstep, value_gamma)
+ td_error_per_sample = criterion(v, target_v.detach())
+ return (td_error_per_sample * weight).mean(), td_error_per_sample
+
+
+q_nstep_td_data = namedtuple(
+ 'q_nstep_td_data', ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight']
+)
+
+dqfd_nstep_td_data = namedtuple(
+ 'dqfd_nstep_td_data', [
+ 'q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'done_one_step', 'weight', 'new_n_q_one_step',
+ 'next_n_action_one_step', 'is_expert'
+ ]
+)
+
+
+def shape_fn_qntd(args, kwargs):
+ r"""
+ Overview:
+ Return qntd shape for hpc
+ Returns:
+ shape: [T, B, N]
+ """
+ if len(args) <= 0:
+ tmp = [kwargs['data'].reward.shape[0]]
+ tmp.extend(list(kwargs['data'].q.shape))
+ else:
+ tmp = [args[0].reward.shape[0]]
+ tmp.extend(list(args[0].q.shape))
+ return tmp
+
+
+@hpc_wrapper(shape_fn=shape_fn_qntd, namedtuple_data=True, include_args=[0, 1], include_kwargs=['data', 'gamma'])
+def q_nstep_td_error(
+ data: namedtuple,
+ gamma: Union[float, list],
+ nstep: int = 1,
+ cum_reward: bool = False,
+ value_gamma: Optional[torch.Tensor] = None,
+ criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
+) -> torch.Tensor:
+ """
+ Overview:
+ Multistep (1 step or n step) td_error for q-learning based algorithm
+ Arguments:
+ - data (:obj:`q_nstep_td_data`): The input data, q_nstep_td_data to calculate loss
+ - gamma (:obj:`float`): Discount factor
+ - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
+ - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value
+ - criterion (:obj:`torch.nn.modules`): Loss function criterion
+ - nstep (:obj:`int`): nstep num, default set to 1
+ Returns:
+ - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
+ - td_error_per_sample (:obj:`torch.Tensor`): nstep td error, 1-dim tensor
+ Shapes:
+ - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\
+ ['q', 'next_n_q', 'action', 'reward', 'done']
+ - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
+ - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
+ - action (:obj:`torch.LongTensor`): :math:`(B, )`
+ - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
+ - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
+ - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
+ - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
+ Examples:
+ >>> next_q = torch.randn(4, 3)
+ >>> done = torch.randn(4)
+ >>> action = torch.randint(0, 3, size=(4, ))
+ >>> next_action = torch.randint(0, 3, size=(4, ))
+ >>> nstep =3
+ >>> q = torch.randn(4, 3).requires_grad_(True)
+ >>> reward = torch.rand(nstep, 4)
+ >>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
+ >>> loss, td_error_per_sample = q_nstep_td_error(data, 0.95, nstep=nstep)
+ """
+ q, next_n_q, action, next_n_action, reward, done, weight = data
+ if weight is None:
+ weight = torch.ones_like(reward)
+
+ if len(action.shape) == 1: # single agent case
+ action = action.unsqueeze(-1)
+ elif len(action.shape) > 1: # MARL case
+ reward = reward.unsqueeze(-1)
+ weight = weight.unsqueeze(-1)
+ done = done.unsqueeze(-1)
+ if value_gamma is not None:
+ value_gamma = value_gamma.unsqueeze(-1)
+
+ q_s_a = q.gather(-1, action).squeeze(-1)
+
+ target_q_s_a = next_n_q.gather(-1, next_n_action.unsqueeze(-1)).squeeze(-1)
+
+ if cum_reward:
+ if value_gamma is None:
+ target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done)
+ else:
+ target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done)
+ else:
+ target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma)
+ td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())
+ return (td_error_per_sample * weight).mean(), td_error_per_sample
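+# Note (descriptive comment added for clarity): when cum_reward is False, the target is the usual
+# n-step bootstrap sum_{i=0}^{n-1} gamma^i * r_i + gamma^n * (1 - done) * Q'(s_{t+n}, a'), computed
+# by nstep_return above; when cum_reward is True, reward is assumed to already be the accumulated
+# n-step return produced at collection time.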
+
+
+def bdq_nstep_td_error(
+ data: namedtuple,
+ gamma: Union[float, list],
+ nstep: int = 1,
+ cum_reward: bool = False,
+ value_gamma: Optional[torch.Tensor] = None,
+ criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
+) -> torch.Tensor:
+ """
+ Overview:
+ Multistep (1-step or n-step) td_error for the BDQ algorithm, referenced paper "Action Branching \
+ Architectures for Deep Reinforcement Learning", link: https://arxiv.org/pdf/1711.08946.
+ The original paper only provides the 1-step TD-error; here we extend it to the n-step case.
+ Arguments:
+ - data (:obj:`q_nstep_td_data`): The input data, q_nstep_td_data to calculate loss
+ - gamma (:obj:`float`): Discount factor
+ - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
+ - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value
+ - criterion (:obj:`torch.nn.modules`): Loss function criterion
+ - nstep (:obj:`int`): nstep num, default set to 1
+ Returns:
+ - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
+ - td_error_per_sample (:obj:`torch.Tensor`): nstep td error, 1-dim tensor
+ Shapes:
+ - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing \
+ ['q', 'next_n_q', 'action', 'reward', 'done']
+ - q (:obj:`torch.FloatTensor`): :math:`(B, D, N)` i.e. [batch_size, branch_num, action_bins_per_branch]
+ - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, D, N)`
+ - action (:obj:`torch.LongTensor`): :math:`(B, D)`
+ - next_n_action (:obj:`torch.LongTensor`): :math:`(B, D)`
+ - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
+ - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
+ - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
+ Examples:
+ >>> action_per_branch = 3
+ >>> next_q = torch.randn(8, 6, action_per_branch)
+ >>> done = torch.randn(8)
+ >>> action = torch.randint(0, action_per_branch, size=(8, 6))
+ >>> next_action = torch.randint(0, action_per_branch, size=(8, 6))
+ >>> nstep =3
+ >>> q = torch.randn(8, 6, action_per_branch).requires_grad_(True)
+ >>> reward = torch.rand(nstep, 8)
+ >>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
+ >>> loss, td_error_per_sample = bdq_nstep_td_error(data, 0.95, nstep=nstep)
+ """
+ q, next_n_q, action, next_n_action, reward, done, weight = data
+ if weight is None:
+ weight = torch.ones_like(reward)
+ reward = reward.unsqueeze(-1)
+ done = done.unsqueeze(-1)
+ if value_gamma is not None:
+ value_gamma = value_gamma.unsqueeze(-1)
+
+ q_s_a = q.gather(-1, action.unsqueeze(-1)).squeeze(-1)
+ target_q_s_a = next_n_q.gather(-1, next_n_action.unsqueeze(-1)).squeeze(-1)
+
+ if cum_reward:
+ if value_gamma is None:
+ target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done)
+ else:
+ target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done)
+ else:
+ target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma)
+ td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())
+ td_error_per_sample = td_error_per_sample.mean(-1)
+ return (td_error_per_sample * weight).mean(), td_error_per_sample
+
+
+def shape_fn_qntd_rescale(args, kwargs):
+ r"""
+ Overview:
+ Return qntd_rescale shape for hpc
+ Returns:
+ shape: [T, B, N]
+ """
+ if len(args) <= 0:
+ tmp = [kwargs['data'].reward.shape[0]]
+ tmp.extend(list(kwargs['data'].q.shape))
+ else:
+ tmp = [args[0].reward.shape[0]]
+ tmp.extend(list(args[0].q.shape))
+ return tmp
+
+
+@hpc_wrapper(
+ shape_fn=shape_fn_qntd_rescale, namedtuple_data=True, include_args=[0, 1], include_kwargs=['data', 'gamma']
+)
+def q_nstep_td_error_with_rescale(
+ data: namedtuple,
+ gamma: Union[float, list],
+ nstep: int = 1,
+ value_gamma: Optional[torch.Tensor] = None,
+ criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
+ trans_fn: Callable = value_transform,
+ inv_trans_fn: Callable = value_inv_transform,
+) -> torch.Tensor:
+ """
+ Overview:
+ Multistep (1 step or n step) td_error with value rescaling
+ Arguments:
+ - data (:obj:`q_nstep_td_data`): The input data, q_nstep_td_data to calculate loss
+ - gamma (:obj:`float`): Discount factor
+ - nstep (:obj:`int`): nstep num, default set to 1
+ - criterion (:obj:`torch.nn.modules`): Loss function criterion
+ - trans_fn (:obj:`Callable`): Value transform function, defaults to value_transform\
+ (refer to rl_utils/value_rescale.py)
+ - inv_trans_fn (:obj:`Callable`): Value inverse transform function, defaults to value_inv_transform\
+ (refer to rl_utils/value_rescale.py)
+ Returns:
+ - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
+ Shapes:
+ - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\
+ ['q', 'next_n_q', 'action', 'reward', 'done']
+ - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
+ - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
+ - action (:obj:`torch.LongTensor`): :math:`(B, )`
+ - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
+ - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
+ - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
+ Examples:
+ >>> next_q = torch.randn(4, 3)
+ >>> done = torch.randn(4)
+ >>> action = torch.randint(0, 3, size=(4, ))
+ >>> next_action = torch.randint(0, 3, size=(4, ))
+ >>> nstep =3
+ >>> q = torch.randn(4, 3).requires_grad_(True)
+ >>> reward = torch.rand(nstep, 4)
+ >>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
+ >>> loss, _ = q_nstep_td_error_with_rescale(data, 0.95, nstep=nstep)
+ """
+ q, next_n_q, action, next_n_action, reward, done, weight = data
+ assert len(action.shape) == 1, action.shape
+ if weight is None:
+ weight = torch.ones_like(action)
+
+ batch_range = torch.arange(action.shape[0])
+ q_s_a = q[batch_range, action]
+ target_q_s_a = next_n_q[batch_range, next_n_action]
+
+ target_q_s_a = inv_trans_fn(target_q_s_a)
+ target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma)
+ target_q_s_a = trans_fn(target_q_s_a)
+
+ td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())
+ return (td_error_per_sample * weight).mean(), td_error_per_sample
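+# Note (descriptive comment added for clarity): the bootstrapped value is first mapped back to the
+# raw return space with inv_trans_fn, the n-step return is formed there, and the result is mapped
+# back with trans_fn, so the regression itself happens in the transformed (rescaled) value space.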
+
+
+def dqfd_nstep_td_error(
+ data: namedtuple,
+ gamma: float,
+ lambda_n_step_td: float,
+ lambda_supervised_loss: float,
+ margin_function: float,
+ lambda_one_step_td: float = 1.,
+ nstep: int = 1,
+ cum_reward: bool = False,
+ value_gamma: Optional[torch.Tensor] = None,
+ criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
+) -> torch.Tensor:
+ """
+ Overview:
+ Multistep (n-step) td_error + 1-step td_error + supervised margin loss for DQfD
+ Arguments:
+ - data (:obj:`dqfd_nstep_td_data`): The input data, dqfd_nstep_td_data to calculate loss
+ - gamma (:obj:`float`): discount factor
+ - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
+ - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value
+ - criterion (:obj:`torch.nn.modules`): Loss function criterion
+ - nstep (:obj:`int`): nstep num, default set to 1 (DQfD typically uses 10)
+ Returns:
+ - loss (:obj:`torch.Tensor`): Multistep n step td_error + 1 step td_error + supervised margin loss, 0-dim tensor
+ - td_error_per_sample (:obj:`torch.Tensor`): Multistep n step td_error + 1 step td_error\
+ + supervised margin loss, 1-dim tensor
+ Shapes:
+ - data (:obj:`q_nstep_td_data`): the q_nstep_td_data containing\
+ ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight'\
+ , 'new_n_q_one_step', 'next_n_action_one_step', 'is_expert']
+ - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
+ - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
+ - action (:obj:`torch.LongTensor`): :math:`(B, )`
+ - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
+ - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
+ - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
+ - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - new_n_q_one_step (:obj:`torch.FloatTensor`): :math:`(B, N)`
+ - next_n_action_one_step (:obj:`torch.LongTensor`): :math:`(B, )`
+ - is_expert (:obj:`int`) : 0 or 1
+ Examples:
+ >>> next_q = torch.randn(4, 3)
+ >>> done = torch.randn(4)
+ >>> done_1 = torch.randn(4)
+ >>> next_q_one_step = torch.randn(4, 3)
+ >>> action = torch.randint(0, 3, size=(4, ))
+ >>> next_action = torch.randint(0, 3, size=(4, ))
+ >>> next_action_one_step = torch.randint(0, 3, size=(4, ))
+ >>> is_expert = torch.ones((4))
+ >>> nstep = 3
+ >>> q = torch.randn(4, 3).requires_grad_(True)
+ >>> reward = torch.rand(nstep, 4)
+ >>> data = dqfd_nstep_td_data(
+ >>> q, next_q, action, next_action, reward, done, done_1, None,
+ >>> next_q_one_step, next_action_one_step, is_expert
+ >>> )
+ >>> loss, td_error_per_sample, loss_statistics = dqfd_nstep_td_error(
+ >>> data, 0.95, lambda_n_step_td=1, lambda_supervised_loss=1,
+ >>> margin_function=0.8, nstep=nstep
+ >>> )
+ """
+ q, next_n_q, action, next_n_action, reward, done, done_one_step, weight, new_n_q_one_step, next_n_action_one_step, \
+ is_expert = data # set is_expert flag(expert 1, agent 0)
+ assert len(action.shape) == 1, action.shape
+ if weight is None:
+ weight = torch.ones_like(action)
+
+ batch_range = torch.arange(action.shape[0])
+ q_s_a = q[batch_range, action]
+ target_q_s_a = next_n_q[batch_range, next_n_action]
+ target_q_s_a_one_step = new_n_q_one_step[batch_range, next_n_action_one_step]
+
+ # calculate n-step TD-loss
+ if cum_reward:
+ if value_gamma is None:
+ target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done)
+ else:
+ target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done)
+ else:
+ target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma)
+ td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())
+
+ # calculate 1-step TD-loss
+ nstep = 1
+ reward = reward[0].unsqueeze(0) # get the one-step reward
+ value_gamma = None
+ if cum_reward:
+ if value_gamma is None:
+ target_q_s_a_one_step = reward + (gamma ** nstep) * target_q_s_a_one_step * (1 - done_one_step)
+ else:
+ target_q_s_a_one_step = reward + value_gamma * target_q_s_a_one_step * (1 - done_one_step)
+ else:
+ target_q_s_a_one_step = nstep_return(
+ nstep_return_data(reward, target_q_s_a_one_step, done_one_step), gamma, nstep, value_gamma
+ )
+ td_error_one_step_per_sample = criterion(q_s_a, target_q_s_a_one_step.detach())
+ device = q_s_a.device
+ device_cpu = torch.device('cpu')
+ # calculate the supervised loss
+ l = margin_function * torch.ones_like(q).to(device_cpu) # q shape (B, A), action shape (B, )
+ l.scatter_(1, torch.LongTensor(action.unsqueeze(1).to(device_cpu)), torch.zeros_like(q, device=device_cpu))
+ # along dim 1: zero out l at the taken (expert) action, so the margin is only added to other actions
+ JE = is_expert * (torch.max(q + l.to(device), dim=1)[0] - q_s_a)
+
+ return (
+ (
+ (
+ lambda_n_step_td * td_error_per_sample + lambda_one_step_td * td_error_one_step_per_sample +
+ lambda_supervised_loss * JE
+ ) * weight
+ ).mean(), lambda_n_step_td * td_error_per_sample.abs() +
+ lambda_one_step_td * td_error_one_step_per_sample.abs() + lambda_supervised_loss * JE.abs(),
+ (td_error_per_sample.mean(), td_error_one_step_per_sample.mean(), JE.mean())
+ )
+
+
+def dqfd_nstep_td_error_with_rescale(
+ data: namedtuple,
+ gamma: float,
+ lambda_n_step_td: float,
+ lambda_supervised_loss: float,
+ lambda_one_step_td: float,
+ margin_function: float,
+ nstep: int = 1,
+ cum_reward: bool = False,
+ value_gamma: Optional[torch.Tensor] = None,
+ criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
+ trans_fn: Callable = value_transform,
+ inv_trans_fn: Callable = value_inv_transform,
+) -> torch.Tensor:
+ """
+ Overview:
+ Multistep (n-step) td_error + 1-step td_error + supervised margin loss for DQfD
+ Arguments:
+ - data (:obj:`dqfd_nstep_td_data`): The input data, dqfd_nstep_td_data to calculate loss
+ - gamma (:obj:`float`): Discount factor
+ - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
+ - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target q_value
+ - criterion (:obj:`torch.nn.modules`): Loss function criterion
+ - nstep (:obj:`int`): nstep num, default set to 1 (DQfD typically uses 10)
+ Returns:
+ - loss (:obj:`torch.Tensor`): Multistep n step td_error + 1 step td_error + supervised margin loss, 0-dim tensor
+ - td_error_per_sample (:obj:`torch.Tensor`): Multistep n step td_error + 1 step td_error\
+ + supervised margin loss, 1-dim tensor
+ Shapes:
+ - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\
+ ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'weight'\
+ , 'new_n_q_one_step', 'next_n_action_one_step', 'is_expert']
+ - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
+ - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
+ - action (:obj:`torch.LongTensor`): :math:`(B, )`
+ - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
+ - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
+ - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
+ - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
+ - new_n_q_one_step (:obj:`torch.FloatTensor`): :math:`(B, N)`
+ - next_n_action_one_step (:obj:`torch.LongTensor`): :math:`(B, )`
+ - is_expert (:obj:`int`) : 0 or 1
+ """
+ q, next_n_q, action, next_n_action, reward, done, done_one_step, weight, new_n_q_one_step, next_n_action_one_step, \
+ is_expert = data # set is_expert flag(expert 1, agent 0)
+ assert len(action.shape) == 1, action.shape
+ if weight is None:
+ weight = torch.ones_like(action)
+
+ batch_range = torch.arange(action.shape[0])
+ q_s_a = q[batch_range, action]
+
+ target_q_s_a = next_n_q[batch_range, next_n_action]
+ target_q_s_a = inv_trans_fn(target_q_s_a) # rescale
+
+ target_q_s_a_one_step = new_n_q_one_step[batch_range, next_n_action_one_step]
+ target_q_s_a_one_step = inv_trans_fn(target_q_s_a_one_step) # rescale
+
+ # calculate n-step TD-loss
+ if cum_reward:
+ if value_gamma is None:
+ target_q_s_a = reward + (gamma ** nstep) * target_q_s_a * (1 - done)
+ else:
+ target_q_s_a = reward + value_gamma * target_q_s_a * (1 - done)
+ else:
+ # to use value_gamma in n-step TD-loss
+ target_q_s_a = nstep_return(nstep_return_data(reward, target_q_s_a, done), gamma, nstep, value_gamma)
+
+ target_q_s_a = trans_fn(target_q_s_a) # rescale
+ td_error_per_sample = criterion(q_s_a, target_q_s_a.detach())
+
+ # calculate 1-step TD-loss
+ nstep = 1
+ reward = reward[0].unsqueeze(0) # get the one-step reward
+ value_gamma = None # This is very important, to use gamma in 1-step TD-loss
+ if cum_reward:
+ if value_gamma is None:
+ target_q_s_a_one_step = reward + (gamma ** nstep) * target_q_s_a_one_step * (1 - done_one_step)
+ else:
+ target_q_s_a_one_step = reward + value_gamma * target_q_s_a_one_step * (1 - done_one_step)
+ else:
+ target_q_s_a_one_step = nstep_return(
+ nstep_return_data(reward, target_q_s_a_one_step, done_one_step), gamma, nstep, value_gamma
+ )
+
+ target_q_s_a_one_step = trans_fn(target_q_s_a_one_step) # rescale
+ td_error_one_step_per_sample = criterion(q_s_a, target_q_s_a_one_step.detach())
+ device = q_s_a.device
+ device_cpu = torch.device('cpu')
+ # calculate the supervised loss
+ l = margin_function * torch.ones_like(q).to(device_cpu) # q shape (B, A), action shape (B, )
+ l.scatter_(1, torch.LongTensor(action.unsqueeze(1).to(device_cpu)), torch.zeros_like(q, device=device_cpu))
+ # along dim 1: zero out l at the taken (expert) action, so the margin is only added to other actions
+ JE = is_expert * (torch.max(q + l.to(device), dim=1)[0] - q_s_a)
+
+ return (
+ (
+ (
+ lambda_n_step_td * td_error_per_sample + lambda_one_step_td * td_error_one_step_per_sample +
+ lambda_supervised_loss * JE
+ ) * weight
+ ).mean(), lambda_n_step_td * td_error_per_sample.abs() +
+ lambda_one_step_td * td_error_one_step_per_sample.abs() + lambda_supervised_loss * JE.abs(),
+ (td_error_per_sample.mean(), td_error_one_step_per_sample.mean(), JE.mean())
+ )
+
+
+qrdqn_nstep_td_data = namedtuple(
+ 'qrdqn_nstep_td_data', ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'tau', 'weight']
+)
+
+
+def qrdqn_nstep_td_error(
+ data: namedtuple,
+ gamma: float,
+ nstep: int = 1,
+ value_gamma: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ """
+ Overview:
+ Multistep (1-step or n-step) td_error used in QRDQN
+ Arguments:
+ - data (:obj:`qrdqn_nstep_td_data`): The input data, qrdqn_nstep_td_data to calculate loss
+ - gamma (:obj:`float`): Discount factor
+ - nstep (:obj:`int`): nstep num, default set to 1
+ Returns:
+ - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
+ Shapes:
+ - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\
+ ['q', 'next_n_q', 'action', 'reward', 'done']
+ - q (:obj:`torch.FloatTensor`): :math:`(B, N, tau)` i.e. [batch_size, action_dim, num_quantiles]
+ - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N, tau')`
+ - action (:obj:`torch.LongTensor`): :math:`(B, )`
+ - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
+ - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
+ - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
+ Examples:
+ >>> next_q = torch.randn(4, 3, 3)
+ >>> done = torch.randn(4)
+ >>> action = torch.randint(0, 3, size=(4, ))
+ >>> next_action = torch.randint(0, 3, size=(4, ))
+ >>> nstep = 3
+ >>> q = torch.randn(4, 3, 3).requires_grad_(True)
+ >>> reward = torch.rand(nstep, 4)
+ >>> data = qrdqn_nstep_td_data(q, next_q, action, next_action, reward, done, 3, None)
+ >>> loss, td_error_per_sample = qrdqn_nstep_td_error(data, 0.95, nstep=nstep)
+ """
+ q, next_n_q, action, next_n_action, reward, done, tau, weight = data
+
+ assert len(action.shape) == 1, action.shape
+ assert len(next_n_action.shape) == 1, next_n_action.shape
+ assert len(done.shape) == 1, done.shape
+ assert len(q.shape) == 3, q.shape
+ assert len(next_n_q.shape) == 3, next_n_q.shape
+ assert len(reward.shape) == 2, reward.shape
+
+ if weight is None:
+ weight = torch.ones_like(action)
+
+ batch_range = torch.arange(action.shape[0])
+
+ # shape: batch_size x num x 1
+ q_s_a = q[batch_range, action, :].unsqueeze(2)
+ # shape: batch_size x 1 x num
+ target_q_s_a = next_n_q[batch_range, next_n_action, :].unsqueeze(1)
+
+ assert reward.shape[0] == nstep
+ reward_factor = torch.ones(nstep).to(reward)
+ for i in range(1, nstep):
+ reward_factor[i] = gamma * reward_factor[i - 1]
+ # shape: batch_size
+ reward = torch.matmul(reward_factor, reward)
+ # shape: batch_size x 1 x num
+ if value_gamma is None:
+ target_q_s_a = reward.unsqueeze(-1).unsqueeze(-1) + (gamma ** nstep
+ ) * target_q_s_a * (1 - done).unsqueeze(-1).unsqueeze(-1)
+ else:
+ target_q_s_a = reward.unsqueeze(-1).unsqueeze(
+ -1
+ ) + value_gamma.unsqueeze(-1).unsqueeze(-1) * target_q_s_a * (1 - done).unsqueeze(-1).unsqueeze(-1)
+
+ # shape: batch_size x num x num
+ u = F.smooth_l1_loss(target_q_s_a, q_s_a, reduction="none")
+ # shape: batch_size
+ loss = (u * (tau - (target_q_s_a - q_s_a).detach().le(0.).float()).abs()).sum(-1).mean(1)
+
+ return (loss * weight).mean(), loss
+
+
+def q_nstep_sql_td_error(
+ data: namedtuple,
+ gamma: float,
+ alpha: float,
+ nstep: int = 1,
+ cum_reward: bool = False,
+ value_gamma: Optional[torch.Tensor] = None,
+ criterion: torch.nn.modules = nn.MSELoss(reduction='none'),
+) -> torch.Tensor:
+ """
+ Overview:
+ Multistep (1 step or n step) td_error for q-learning based algorithm
+ Arguments:
+ - data (:obj:`q_nstep_td_data`): The input data, q_nstep_sql_td_data to calculate loss
+ - gamma (:obj:`float`): Discount factor
+ - alpha (:obj:`float`): A parameter weighting the entropy term in the soft value equation
+ - cum_reward (:obj:`bool`): Whether to use cumulative nstep reward, which is figured out when collecting data
+ - value_gamma (:obj:`torch.Tensor`): Gamma discount value for target soft_q_value
+ - criterion (:obj:`torch.nn.modules`): Loss function criterion
+ - nstep (:obj:`int`): nstep num, default set to 1
+ Returns:
+ - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
+ - td_error_per_sample (:obj:`torch.Tensor`): nstep td error, 1-dim tensor
+ Shapes:
+ - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\
+ ['q', 'next_n_q', 'action', 'reward', 'done']
+ - q (:obj:`torch.FloatTensor`): :math:`(B, N)` i.e. [batch_size, action_dim]
+ - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, N)`
+ - action (:obj:`torch.LongTensor`): :math:`(B, )`
+ - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
+ - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
+ - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
+ - td_error_per_sample (:obj:`torch.FloatTensor`): :math:`(B, )`
+ Examples:
+ >>> next_q = torch.randn(4, 3)
+ >>> done = torch.randn(4)
+ >>> action = torch.randint(0, 3, size=(4, ))
+ >>> next_action = torch.randint(0, 3, size=(4, ))
+ >>> nstep = 3
+ >>> q = torch.randn(4, 3).requires_grad_(True)
+ >>> reward = torch.rand(nstep, 4)
+ >>> data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
+ >>> loss, td_error_per_sample, record_target_v = q_nstep_sql_td_error(data, 0.95, 1.0, nstep=nstep)
+ """
+ q, next_n_q, action, next_n_action, reward, done, weight = data
+ assert len(action.shape) == 1, action.shape
+ if weight is None:
+ weight = torch.ones_like(action)
+
+ batch_range = torch.arange(action.shape[0])
+ q_s_a = q[batch_range, action]
+ # target_q_s_a = next_n_q[batch_range, next_n_action]
+ target_v = alpha * torch.logsumexp(
+ next_n_q / alpha, 1
+ ) # target_v = alpha * torch.log(torch.sum(torch.exp(next_n_q / alpha), 1))
+ target_v[target_v == float("Inf")] = 20
+ target_v[target_v == float("-Inf")] = -20
+ # For an appropriately chosen alpha these hard-coded bounds can be removed; they only guard
+ # against numerical explosion of the soft value when alpha is poorly chosen.
+ record_target_v = copy.deepcopy(target_v)
+ # print(target_v)
+ if cum_reward:
+ if value_gamma is None:
+ target_v = reward + (gamma ** nstep) * target_v * (1 - done)
+ else:
+ target_v = reward + value_gamma * target_v * (1 - done)
+ else:
+ target_v = nstep_return(nstep_return_data(reward, target_v, done), gamma, nstep, value_gamma)
+ td_error_per_sample = criterion(q_s_a, target_v.detach())
+ return (td_error_per_sample * weight).mean(), td_error_per_sample, record_target_v
+
+
+iqn_nstep_td_data = namedtuple(
+ 'iqn_nstep_td_data', ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'replay_quantiles', 'weight']
+)
+
+
+def iqn_nstep_td_error(
+ data: namedtuple,
+ gamma: float,
+ nstep: int = 1,
+ kappa: float = 1.0,
+ value_gamma: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ """
+ Overview:
+ Multistep (1 step or n step) td_error used in IQN, \
+ referenced paper Implicit Quantile Networks for Distributional Reinforcement Learning \
+
+ Arguments:
+ - data (:obj:`iqn_nstep_td_data`): The input data, iqn_nstep_td_data to calculate loss
+ - gamma (:obj:`float`): Discount factor
+ - nstep (:obj:`int`): nstep num, default set to 1
+ - kappa (:obj:`float`): Threshold of the Huber loss
+ - value_gamma (:obj:`torch.Tensor`): Gamma discount value for the target q_value
+ Returns:
+ - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
+ - td_error_per_sample (:obj:`torch.Tensor`): nstep td error, 1-dim tensor
+ Shapes:
+ - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\
+ ['q', 'next_n_q', 'action', 'reward', 'done']
+ - q (:obj:`torch.FloatTensor`): :math:`(tau, B, N)` i.e. [tau, batch_size, action_dim]
+ - next_n_q (:obj:`torch.FloatTensor`): :math:`(tau', B, N)`
+ - action (:obj:`torch.LongTensor`): :math:`(B, )`
+ - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
+ - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
+ - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
+ Examples:
+ >>> next_q = torch.randn(3, 4, 3)
+ >>> done = torch.randn(4)
+ >>> action = torch.randint(0, 3, size=(4, ))
+ >>> next_action = torch.randint(0, 3, size=(4, ))
+ >>> nstep = 3
+ >>> q = torch.randn(3, 4, 3).requires_grad_(True)
+ >>> replay_quantile = torch.randn([3, 4, 1])
+ >>> reward = torch.rand(nstep, 4)
+ >>> data = iqn_nstep_td_data(q, next_q, action, next_action, reward, done, replay_quantile, None)
+ >>> loss, td_error_per_sample = iqn_nstep_td_error(data, 0.95, nstep=nstep)
+ """
+ q, next_n_q, action, next_n_action, reward, done, replay_quantiles, weight = data
+
+ assert len(action.shape) == 1, action.shape
+ assert len(next_n_action.shape) == 1, next_n_action.shape
+ assert len(done.shape) == 1, done.shape
+ assert len(q.shape) == 3, q.shape
+ assert len(next_n_q.shape) == 3, next_n_q.shape
+ assert len(reward.shape) == 2, reward.shape
+
+ if weight is None:
+ weight = torch.ones_like(action)
+
+ batch_size = done.shape[0]
+ tau = q.shape[0]
+ tau_prime = next_n_q.shape[0]
+
+ action = action.repeat([tau, 1]).unsqueeze(-1)
+ next_n_action = next_n_action.repeat([tau_prime, 1]).unsqueeze(-1)
+
+ # shape: batch_size x tau x 1
+ q_s_a = torch.gather(q, -1, action).permute([1, 0, 2])
+ # shape: batch_size x tau_prime x 1
+ target_q_s_a = torch.gather(next_n_q, -1, next_n_action).permute([1, 0, 2])
+
+ assert reward.shape[0] == nstep
+ device = torch.device("cuda" if reward.is_cuda else "cpu")
+ reward_factor = torch.ones(nstep).to(device)
+ for i in range(1, nstep):
+ reward_factor[i] = gamma * reward_factor[i - 1]
+ reward = torch.matmul(reward_factor, reward)
+ if value_gamma is None:
+ target_q_s_a = reward.unsqueeze(-1) + (gamma ** nstep) * target_q_s_a.squeeze(-1) * (1 - done).unsqueeze(-1)
+ else:
+ target_q_s_a = reward.unsqueeze(-1) + value_gamma.unsqueeze(-1) * target_q_s_a.squeeze(-1) * (1 - done
+ ).unsqueeze(-1)
+ target_q_s_a = target_q_s_a.unsqueeze(-1)
+
+ # shape: batch_size x tau' x tau x 1.
+ bellman_errors = (target_q_s_a[:, :, None, :] - q_s_a[:, None, :, :])
+
+ # The huber loss (see Section 2.3 of the paper) is defined via two cases:
+ huber_loss = torch.where(
+ bellman_errors.abs() <= kappa, 0.5 * bellman_errors ** 2, kappa * (bellman_errors.abs() - 0.5 * kappa)
+ )
+
+ # Reshape replay_quantiles to batch_size x num_tau_samples x 1
+ replay_quantiles = replay_quantiles.reshape([tau, batch_size, 1]).permute([1, 0, 2])
+
+ # shape: batch_size x num_tau_prime_samples x num_tau_samples x 1.
+ replay_quantiles = replay_quantiles[:, None, :, :].repeat([1, tau_prime, 1, 1])
+
+ # shape: batch_size x tau_prime x tau x 1.
+ quantile_huber_loss = (torch.abs(replay_quantiles - ((bellman_errors < 0).float()).detach()) * huber_loss) / kappa
+
+ # shape: batch_size
+ loss = quantile_huber_loss.sum(dim=2).mean(dim=1)[:, 0]
+
+ return (loss * weight).mean(), loss
+
+
+fqf_nstep_td_data = namedtuple(
+ 'fqf_nstep_td_data', ['q', 'next_n_q', 'action', 'next_n_action', 'reward', 'done', 'quantiles_hats', 'weight']
+)
+
+
+def fqf_nstep_td_error(
+ data: namedtuple,
+ gamma: float,
+ nstep: int = 1,
+ kappa: float = 1.0,
+ value_gamma: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+ """
+ Overview:
+ Multistep (1 step or n step) td_error used in FQF, \
+ referenced paper Fully Parameterized Quantile Function for Distributional Reinforcement Learning \
+
+ Arguments:
+ - data (:obj:`fqf_nstep_td_data`): The input data, fqf_nstep_td_data to calculate loss
+ - gamma (:obj:`float`): Discount factor
+ - nstep (:obj:`int`): nstep num, default set to 1
+ - kappa (:obj:`float`): Threshold of the Huber loss
+ - value_gamma (:obj:`torch.Tensor`): Gamma discount value for the target q_value
+ Returns:
+ - loss (:obj:`torch.Tensor`): nstep td error, 0-dim tensor
+ - td_error_per_sample (:obj:`torch.Tensor`): nstep td error, 1-dim tensor
+ Shapes:
+ - data (:obj:`q_nstep_td_data`): The q_nstep_td_data containing\
+ ['q', 'next_n_q', 'action', 'reward', 'done']
+ - q (:obj:`torch.FloatTensor`): :math:`(B, tau, N)` i.e. [batch_size, tau, action_dim]
+ - next_n_q (:obj:`torch.FloatTensor`): :math:`(B, tau', N)`
+ - action (:obj:`torch.LongTensor`): :math:`(B, )`
+ - next_n_action (:obj:`torch.LongTensor`): :math:`(B, )`
+ - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep(nstep)
+ - done (:obj:`torch.BoolTensor`) :math:`(B, )`, whether done in last timestep
+ - quantiles_hats (:obj:`torch.FloatTensor`): :math:`(B, tau)`
+ Examples:
+ >>> next_q = torch.randn(4, 3, 3)
+ >>> done = torch.randn(4)
+ >>> action = torch.randint(0, 3, size=(4, ))
+ >>> next_action = torch.randint(0, 3, size=(4, ))
+ >>> nstep = 3
+ >>> q = torch.randn(4, 3, 3).requires_grad_(True)
+ >>> quantiles_hats = torch.randn([4, 3])
+ >>> reward = torch.rand(nstep, 4)
+ >>> data = fqf_nstep_td_data(q, next_q, action, next_action, reward, done, quantiles_hats, None)
+ >>> loss, td_error_per_sample = fqf_nstep_td_error(data, 0.95, nstep=nstep)
+ """
+ q, next_n_q, action, next_n_action, reward, done, quantiles_hats, weight = data
+
+ assert len(action.shape) == 1, action.shape
+ assert len(next_n_action.shape) == 1, next_n_action.shape
+ assert len(done.shape) == 1, done.shape
+ assert len(q.shape) == 3, q.shape
+ assert len(next_n_q.shape) == 3, next_n_q.shape
+ assert len(reward.shape) == 2, reward.shape
+
+ if weight is None:
+ weight = torch.ones_like(action)
+
+ batch_size = done.shape[0]
+ tau = q.shape[1]
+ tau_prime = next_n_q.shape[1]
+
+ # shape: batch_size x tau x 1
+ q_s_a = evaluate_quantile_at_action(q, action)
+ # shape: batch_size x tau_prime x 1
+ target_q_s_a = evaluate_quantile_at_action(next_n_q, next_n_action)
+
+ assert reward.shape[0] == nstep
+ reward_factor = torch.ones(nstep).to(reward.device)
+ for i in range(1, nstep):
+ reward_factor[i] = gamma * reward_factor[i - 1]
+ reward = torch.matmul(reward_factor, reward) # [batch_size]
+ if value_gamma is None:
+ target_q_s_a = reward.unsqueeze(-1) + (gamma ** nstep) * target_q_s_a.squeeze(-1) * (1 - done).unsqueeze(-1)
+ else:
+ target_q_s_a = reward.unsqueeze(-1) + value_gamma.unsqueeze(-1) * target_q_s_a.squeeze(-1) * (1 - done
+ ).unsqueeze(-1)
+ target_q_s_a = target_q_s_a.unsqueeze(-1)
+
+ # shape: batch_size x tau' x tau x 1.
+ bellman_errors = (target_q_s_a.unsqueeze(2) - q_s_a.unsqueeze(1))
+
+ # shape: batch_size x tau' x tau x 1
+ huber_loss = F.smooth_l1_loss(target_q_s_a.unsqueeze(2), q_s_a.unsqueeze(1), reduction="none")
+
+ # shape: batch_size x num_tau_prime_samples x num_tau_samples x 1.
+ quantiles_hats = quantiles_hats[:, None, :, None].repeat([1, tau_prime, 1, 1])
+
+ # shape: batch_size x tau_prime x tau x 1.
+ quantile_huber_loss = (torch.abs(quantiles_hats - ((bellman_errors < 0).float()).detach()) * huber_loss) / kappa
+
+ # shape: batch_size
+ loss = quantile_huber_loss.sum(dim=2).mean(dim=1)[:, 0]
+
+ return (loss * weight).mean(), loss
+
+
+def evaluate_quantile_at_action(q_s, actions):
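+ """
+ Overview:
+ Gather the quantile values of the given actions, i.e. select ``q_s[b, :, actions[b]]`` for each sample b.
+ Shapes:
+ - q_s (:obj:`torch.FloatTensor`): :math:`(B, num_quantiles, N)`
+ - actions (:obj:`torch.LongTensor`): :math:`(B, )`
+ - return (:obj:`torch.FloatTensor`): :math:`(B, num_quantiles, 1)`
+ """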
+ assert q_s.shape[0] == actions.shape[0]
+
+ batch_size, num_quantiles = q_s.shape[:2]
+
+ # Expand actions into (batch_size, num_quantiles, 1).
+ action_index = actions[:, None, None].expand(batch_size, num_quantiles, 1)
+
+ # Calculate quantile values at specified actions.
+ q_s_a = q_s.gather(dim=2, index=action_index)
+
+ return q_s_a
+
+
+def fqf_calculate_fraction_loss(q_tau_i, q_value, quantiles, actions):
+ """
+ Overview:
+ Calculate the fraction loss in FQF, \
+ referenced paper Fully Parameterized Quantile Function for Distributional Reinforcement Learning \
+
+ Arguments:
+ - q_tau_i (:obj:`torch.FloatTensor`): :math:`(batch_size, num_quantiles-1, action_dim)`
+ - q_value (:obj:`torch.FloatTensor`): :math:`(batch_size, num_quantiles, action_dim)`
+ - quantiles (:obj:`torch.FloatTensor`): :math:`(batch_size, num_quantiles+1)`
+ - actions (:obj:`torch.LongTensor`): :math:`(batch_size, )`
+ Returns:
+ - fraction_loss (:obj:`torch.Tensor`): fraction loss, 0-dim tensor
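+ Examples:
+ >>> # a minimal sketch with random inputs; shapes follow the Arguments above
+ >>> batch_size, num_quantiles, action_dim = 4, 8, 3
+ >>> q_tau_i = torch.randn(batch_size, num_quantiles - 1, action_dim)
+ >>> q_value = torch.randn(batch_size, num_quantiles, action_dim).requires_grad_(True)
+ >>> quantiles = torch.rand(batch_size, num_quantiles + 1).requires_grad_(True)
+ >>> actions = torch.randint(0, action_dim, size=(batch_size, ))
+ >>> fraction_loss = fqf_calculate_fraction_loss(q_tau_i, q_value, quantiles, actions)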
+ """
+ assert q_value.requires_grad
+
+ batch_size = q_value.shape[0]
+ num_quantiles = q_value.shape[1]
+
+ with torch.no_grad():
+ sa_quantiles = evaluate_quantile_at_action(q_tau_i, actions)
+ assert sa_quantiles.shape == (batch_size, num_quantiles - 1, 1)
+ q_s_a_hats = evaluate_quantile_at_action(q_value, actions) # [batch_size, num_quantiles, 1]
+ assert q_s_a_hats.shape == (batch_size, num_quantiles, 1)
+ assert not q_s_a_hats.requires_grad
+
+ # NOTE: Proposition 1 in the paper requires that F^{-1} is non-decreasing.
+ # We relax this requirement and calculate gradients of the quantiles even when
+ # F^{-1} is not non-decreasing.
+
+ values_1 = sa_quantiles - q_s_a_hats[:, :-1]
+ signs_1 = sa_quantiles > torch.cat([q_s_a_hats[:, :1], sa_quantiles[:, :-1]], dim=1)
+ assert values_1.shape == signs_1.shape
+
+ values_2 = sa_quantiles - q_s_a_hats[:, 1:]
+ signs_2 = sa_quantiles < torch.cat([sa_quantiles[:, 1:], q_s_a_hats[:, -1:]], dim=1)
+ assert values_2.shape == signs_2.shape
+
+ gradient_of_taus = (torch.where(signs_1, values_1, -values_1) +
+ torch.where(signs_2, values_2, -values_2)).view(batch_size, num_quantiles - 1)
+ assert not gradient_of_taus.requires_grad
+ assert gradient_of_taus.shape == quantiles[:, 1:-1].shape
+
+ # Gradients of the network parameters and corresponding loss
+ # are calculated using chain rule.
+ fraction_loss = (gradient_of_taus * quantiles[:, 1:-1]).sum(dim=1).mean()
+
+ return fraction_loss
+
+
+td_lambda_data = namedtuple('td_lambda_data', ['value', 'reward', 'weight'])
+
+
+def shape_fn_td_lambda(args, kwargs):
+ r"""
+ Overview:
+ Return td_lambda shape for hpc
+ Returns:
+ shape: [T, B]
+ """
+ if len(args) <= 0:
+ tmp = kwargs['data'].reward.shape[0]
+ else:
+ tmp = args[0].reward.shape
+ return tmp
+
+
+@hpc_wrapper(
+ shape_fn=shape_fn_td_lambda,
+ namedtuple_data=True,
+ include_args=[0, 1, 2],
+ include_kwargs=['data', 'gamma', 'lambda_']
+)
+def td_lambda_error(data: namedtuple, gamma: float = 0.9, lambda_: float = 0.8) -> torch.Tensor:
+ """
+ Overview:
+ Computing TD(lambda) loss given constant gamma and lambda.
+ There is no special handling for the terminal state value;
+ if some state has reached the terminal, just fill in zeros for values and rewards beyond the terminal
+ (*including the terminal state*; values[terminal] should also be 0)
+ Arguments:
+ - data (:obj:`namedtuple`): td_lambda input data with fields ['value', 'reward', 'weight']
+ - gamma (:obj:`float`): Constant discount factor gamma, should be in [0, 1], defaults to 0.9
+ - lambda_ (:obj:`float`): Constant lambda, should be in [0, 1], defaults to 0.8
+ Returns:
+ - loss (:obj:`torch.Tensor`): Computed MSE loss, averaged over the batch
+ Shapes:
+ - value (:obj:`torch.FloatTensor`): :math:`(T+1, B)`, where T is trajectory length and B is batch,\
+ which is the estimation of the state value at step 0 to T
+ - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`, the rewards from time step 0 to T-1
+ - weight (:obj:`torch.FloatTensor` or None): :math:`(B, )`, the training sample weight
+ - loss (:obj:`torch.FloatTensor`): :math:`()`, 0-dim tensor
+ Examples:
+ >>> T, B = 8, 4
+ >>> value = torch.randn(T + 1, B).requires_grad_(True)
+ >>> reward = torch.rand(T, B)
+ >>> loss = td_lambda_error(td_lambda_data(value, reward, None))
+ """
+ value, reward, weight = data
+ if weight is None:
+ weight = torch.ones_like(reward)
+ with torch.no_grad():
+ return_ = generalized_lambda_returns(value, reward, gamma, lambda_)
+ # discard the value at T as it should be considered in the next slice
+ loss = 0.5 * (F.mse_loss(return_, value[:-1], reduction='none') * weight).mean()
+ return loss
+
+
+def generalized_lambda_returns(
+ bootstrap_values: torch.Tensor,
+ rewards: torch.Tensor,
+ gammas: float,
+ lambda_: float,
+ done: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+ r"""
+ Overview:
+ Functional equivalent to trfl.value_ops.generalized_lambda_returns
+ https://github.com/deepmind/trfl/blob/2c07ac22512a16715cc759f0072be43a5d12ae45/trfl/value_ops.py#L74
+ Passing in a number instead of tensor to make the value constant for all samples in batch
+ Arguments:
+ - bootstrap_values (:obj:`torch.Tensor` or :obj:`float`):
+ estimation of the value at step 0 to *T*, of size [T_traj+1, batchsize]
+ - rewards (:obj:`torch.Tensor`): The rewards from step 0 to T-1, of size [T_traj, batchsize]
+ - gammas (:obj:`torch.Tensor` or :obj:`float`):
+ Discount factor for each step (from 0 to T-1), of size [T_traj, batchsize]
+ - lambda_ (:obj:`torch.Tensor` or :obj:`float`): Determining the mix of bootstrapping
+ vs further accumulation of multistep returns at each timestep, of size [T_traj, batchsize]
+ - done (:obj:`torch.Tensor` or :obj:`float`):
+ Whether the episode done at current step (from 0 to T-1), of size [T_traj, batchsize]
+ Returns:
+ - return (:obj:`torch.Tensor`): Computed lambda return value
+ for each state from 0 to T-1, of size [T_traj, batchsize]
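+ Examples:
+ >>> # a minimal sketch with random inputs, shapes as described in Arguments above
+ >>> T, B = 8, 4
+ >>> value = torch.randn(T + 1, B)
+ >>> reward = torch.rand(T, B)
+ >>> return_ = generalized_lambda_returns(value, reward, 0.99, 0.95)
+ >>> assert return_.shape == (T, B)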
+ """
+ if not isinstance(gammas, torch.Tensor):
+ gammas = gammas * torch.ones_like(rewards)
+ if not isinstance(lambda_, torch.Tensor):
+ lambda_ = lambda_ * torch.ones_like(rewards)
+ bootstrap_values_tp1 = bootstrap_values[1:, :]
+ return multistep_forward_view(bootstrap_values_tp1, rewards, gammas, lambda_, done)
+
+
+def multistep_forward_view(
+ bootstrap_values: torch.Tensor,
+ rewards: torch.Tensor,
+ gammas: float,
+ lambda_: float,
+ done: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+ r"""
+ Overview:
+ Same as trfl.sequence_ops.multistep_forward_view
+ Implementing (12.18) in Sutton & Barto
+
+ ```
+ result[T-1] = rewards[T-1] + gammas[T-1] * bootstrap_values[T]
+ for t in 0...T-2 :
+ result[t] = rewards[t] + gammas[t]*(lambdas[t]*result[t+1] + (1-lambdas[t])*bootstrap_values[t+1])
+ ```
+
+ Assuming the first dim of input tensors correspond to the index in batch
+ Arguments:
+ - bootstrap_values (:obj:`torch.Tensor`): Estimation of the value at *step 1 to T*, of size [T_traj, batchsize]
+ - rewards (:obj:`torch.Tensor`): The rewards from 0 to T-1, of size [T_traj, batchsize]
+ - gammas (:obj:`torch.Tensor`): Discount factor for each step (from 0 to T-1), of size [T_traj, batchsize]
+ - lambda_ (:obj:`torch.Tensor`): Determining the mix of bootstrapping vs further accumulation of \
+ multistep returns at each timestep of size [T_traj, batchsize], the element for T-1 is ignored \
+ and effectively set to 0, as there is no information about future rewards.
+ - done (:obj:`torch.Tensor` or :obj:`float`):
+ Whether the episode done at current step (from 0 to T-1), of size [T_traj, batchsize]
+ Returns:
+ - ret (:obj:`torch.Tensor`): Computed lambda return value \
+ for each state from 0 to T-1, of size [T_traj, batchsize]
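+ Examples:
+ >>> # a minimal sketch; bootstrap_values are the value estimates at step 1 to T
+ >>> T, B = 8, 4
+ >>> bootstrap_values = torch.randn(T, B)
+ >>> rewards = torch.rand(T, B)
+ >>> gammas = 0.99 * torch.ones(T, B)
+ >>> lambda_ = 0.95 * torch.ones(T, B)
+ >>> ret = multistep_forward_view(bootstrap_values, rewards, gammas, lambda_)
+ >>> assert ret.shape == (T, B)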
+ """
+ result = torch.empty_like(rewards)
+ if done is None:
+ done = torch.zeros_like(rewards)
+ # Forced cutoff at the last one
+ result[-1, :] = rewards[-1, :] + (1 - done[-1, :]) * gammas[-1, :] * bootstrap_values[-1, :]
+ discounts = gammas * lambda_
+ for t in reversed(range(rewards.size()[0] - 1)):
+ result[t, :] = rewards[t, :] + (1 - done[t, :]) * \
+ (
+ discounts[t, :] * result[t + 1, :] +
+ (gammas[t, :] - discounts[t, :]) * bootstrap_values[t, :]
+ )
+
+ return result
diff --git a/DI-engine/ding/rl_utils/tests/test_a2c.py b/DI-engine/ding/rl_utils/tests/test_a2c.py
new file mode 100644
index 0000000000000000000000000000000000000000..e321db28c45b734a2e2060e358f046d6821981cd
--- /dev/null
+++ b/DI-engine/ding/rl_utils/tests/test_a2c.py
@@ -0,0 +1,53 @@
+import pytest
+from itertools import product
+import numpy as np
+import torch
+from ding.rl_utils import a2c_data, a2c_error, a2c_error_continuous
+
+random_weight = torch.rand(4) + 1
+weight_args = [None, random_weight]
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('weight, ', weight_args)
+def test_a2c(weight):
+ B, N = 4, 32
+ logit = torch.randn(B, N).requires_grad_(True)
+ action = torch.randint(0, N, size=(B, ))
+ value = torch.randn(B).requires_grad_(True)
+ adv = torch.rand(B)
+ return_ = torch.randn(B) * 2
+ data = a2c_data(logit, action, value, adv, return_, weight)
+ loss = a2c_error(data)
+ assert all([l.shape == tuple() for l in loss])
+ assert logit.grad is None
+ assert value.grad is None
+ total_loss = sum(loss)
+ total_loss.backward()
+ assert isinstance(logit.grad, torch.Tensor)
+ assert isinstance(value.grad, torch.Tensor)
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('weight, ', weight_args)
+def test_a2c_continuous(weight):
+ B, N = 4, 32
+ logit = {
+ "mu": torch.randn(B, N).requires_grad_(True),
+ "sigma": torch.exp(torch.randn(B, N)).requires_grad_(True),
+ }
+ action = torch.randn(B, N).requires_grad_(True)
+ value = torch.randn(B).requires_grad_(True)
+ adv = torch.rand(B)
+ return_ = torch.randn(B) * 2
+ data = a2c_data(logit, action, value, adv, return_, weight)
+ loss = a2c_error_continuous(data)
+ assert all([l.shape == tuple() for l in loss])
+ assert logit["mu"].grad is None
+ assert logit["sigma"].grad is None
+ assert value.grad is None
+ total_loss = sum(loss)
+ total_loss.backward()
+ assert isinstance(logit["mu"].grad, torch.Tensor)
+ assert isinstance(logit['sigma'].grad, torch.Tensor)
+ assert isinstance(value.grad, torch.Tensor)
diff --git a/DI-engine/ding/rl_utils/tests/test_adder.py b/DI-engine/ding/rl_utils/tests/test_adder.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bc84e785db24fb2ac1031fdb094af6fd68c95b5
--- /dev/null
+++ b/DI-engine/ding/rl_utils/tests/test_adder.py
@@ -0,0 +1,146 @@
+import pytest
+import copy
+from collections import deque
+import numpy as np
+import torch
+from ding.rl_utils import get_gae, get_gae_with_default_last_value, get_nstep_return_data, get_train_sample
+
+
+@pytest.mark.unittest
+class TestAdder:
+
+ def get_transition(self):
+ return {
+ 'value': torch.randn(1),
+ 'reward': torch.rand(1),
+ 'action': torch.rand(3),
+ 'other': np.random.randint(0, 10, size=(4, )),
+ 'obs': torch.randn(3),
+ 'done': False
+ }
+
+ def get_transition_multi_agent(self):
+ return {
+ 'value': torch.randn(1, 8),
+ 'reward': torch.rand(1, 1),
+ 'action': torch.rand(3),
+ 'other': np.random.randint(0, 10, size=(4, )),
+ 'obs': torch.randn(3),
+ 'done': False
+ }
+
+ def test_get_gae(self):
+ transitions = deque([self.get_transition() for _ in range(10)])
+ last_value = torch.randn(1)
+ output = get_gae(transitions, last_value, gamma=0.99, gae_lambda=0.97, cuda=False)
+ for i in range(len(output)):
+ o = output[i]
+ assert 'adv' in o.keys()
+ for k, v in o.items():
+ if k == 'adv':
+ assert isinstance(v, torch.Tensor)
+ assert v.shape == (1, )
+ else:
+ if k == 'done':
+ assert v == transitions[i][k]
+ else:
+ assert (v == transitions[i][k]).all()
+ output1 = get_gae_with_default_last_value(
+ copy.deepcopy(transitions), True, gamma=0.99, gae_lambda=0.97, cuda=False
+ )
+ for i in range(len(output)):
+ assert output[i]['adv'].ne(output1[i]['adv'])
+
+ data = copy.deepcopy(transitions)
+ data.append({'value': last_value})
+ output2 = get_gae_with_default_last_value(data, False, gamma=0.99, gae_lambda=0.97, cuda=False)
+ for i in range(len(output)):
+ assert output[i]['adv'].eq(output2[i]['adv'])
+
+ def test_get_gae_multi_agent(self):
+ transitions = deque([self.get_transition_multi_agent() for _ in range(10)])
+ last_value = torch.randn(1, 8)
+ output = get_gae(transitions, last_value, gamma=0.99, gae_lambda=0.97, cuda=False)
+ for i in range(len(output)):
+ o = output[i]
+ assert 'adv' in o.keys()
+ for k, v in o.items():
+ if k == 'adv':
+ assert isinstance(v, torch.Tensor)
+ assert v.shape == (
+ 1,
+ 8,
+ )
+ else:
+ if k == 'done':
+ assert v == transitions[i][k]
+ else:
+ assert (v == transitions[i][k]).all()
+ output1 = get_gae_with_default_last_value(
+ copy.deepcopy(transitions), True, gamma=0.99, gae_lambda=0.97, cuda=False
+ )
+ for i in range(len(output)):
+ for j in range(output[i]['adv'].shape[1]):
+ assert output[i]['adv'][0][j].ne(output1[i]['adv'][0][j])
+
+ data = copy.deepcopy(transitions)
+ data.append({'value': last_value})
+ output2 = get_gae_with_default_last_value(data, False, gamma=0.99, gae_lambda=0.97, cuda=False)
+ for i in range(len(output)):
+ for j in range(output[i]['adv'].shape[1]):
+ assert output[i]['adv'][0][j].eq(output2[i]['adv'][0][j])
+
+ def test_get_nstep_return_data(self):
+ nstep = 3
+ data = deque([self.get_transition() for _ in range(10)])
+ output_data = get_nstep_return_data(data, nstep=nstep)
+ assert len(output_data) == 10
+ for i, o in enumerate(output_data):
+ assert o['reward'].shape == (nstep, )
+ if i >= 10 - nstep + 1:
+ assert o['done'] is data[-1]['done']
+ assert o['reward'][-(i - 10 + nstep):].sum() == 0
+
+ data = deque([self.get_transition() for _ in range(12)])
+ output_data = get_nstep_return_data(data, nstep=nstep)
+ assert len(output_data) == 12
+
+ def test_get_train_sample(self):
+ data = [self.get_transition() for _ in range(10)]
+ output = get_train_sample(data, unroll_len=1, last_fn_type='drop')
+ assert len(output) == 10
+
+ output = get_train_sample(data, unroll_len=4, last_fn_type='drop')
+ assert len(output) == 2
+ for o in output:
+ for v in o.values():
+ assert len(v) == 4
+
+ output = get_train_sample(data, unroll_len=4, last_fn_type='null_padding')
+ assert len(output) == 3
+ for o in output:
+ for v in o.values():
+ assert len(v) == 4
+ assert output[-1]['done'] == [False, False, True, True]
+ for i in range(1, 10 % 4 + 1):
+ assert id(output[-1]['obs'][-i]) != id(output[-1]['obs'][0])
+
+ output = get_train_sample(data, unroll_len=4, last_fn_type='last')
+ assert len(output) == 3
+ for o in output:
+ for v in o.values():
+ assert len(v) == 4
+ miss_num = 4 - 10 % 4
+ for i in range(10 % 4):
+ assert id(output[-1]['obs'][i]) != id(output[-2]['obs'][miss_num + i])
+
+ output = get_train_sample(data, unroll_len=11, last_fn_type='last')
+ assert len(output) == 1
+ assert len(output[0]['obs']) == 11
+ assert output[-1]['done'][-1] is True
+ assert output[-1]['done'][0] is False
+ assert id(output[-1]['obs'][-1]) != id(output[-1]['obs'][0])
diff --git a/DI-engine/ding/rl_utils/tests/test_coma.py b/DI-engine/ding/rl_utils/tests/test_coma.py
new file mode 100644
index 0000000000000000000000000000000000000000..51bb38c9c4fb6eaf590c3d842f4802c1d4f77ca8
--- /dev/null
+++ b/DI-engine/ding/rl_utils/tests/test_coma.py
@@ -0,0 +1,40 @@
+import pytest
+from itertools import product
+import numpy as np
+import torch
+from ding.rl_utils import coma_data, coma_error
+
+random_weight = torch.rand(128, 4, 8) + 1
+weight_args = [None, random_weight]
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('weight, ', weight_args)
+def test_coma(weight):
+ T, B, A, N = 128, 4, 8, 32
+ logit = torch.randn(
+ T,
+ B,
+ A,
+ N,
+ ).requires_grad_(True)
+ action = torch.randint(
+ 0, N, size=(
+ T,
+ B,
+ A,
+ )
+ )
+ reward = torch.rand(T, B)
+ q_value = torch.randn(T, B, A, N).requires_grad_(True)
+ target_q_value = torch.randn(T, B, A, N).requires_grad_(True)
+ mask = torch.randint(0, 2, (T, B, A))
+ data = coma_data(logit, action, q_value, target_q_value, reward, weight)
+ loss = coma_error(data, 0.99, 0.95)
+ assert all([l.shape == tuple() for l in loss])
+ assert logit.grad is None
+ assert q_value.grad is None
+ total_loss = sum(loss)
+ total_loss.backward()
+ assert isinstance(logit.grad, torch.Tensor)
+ assert isinstance(q_value.grad, torch.Tensor)
diff --git a/DI-engine/ding/rl_utils/tests/test_exploration.py b/DI-engine/ding/rl_utils/tests/test_exploration.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2b0894e3c97c270768376f64e65ed01778f9d44
--- /dev/null
+++ b/DI-engine/ding/rl_utils/tests/test_exploration.py
@@ -0,0 +1,39 @@
+import pytest
+import torch
+from ding.rl_utils import get_epsilon_greedy_fn, create_noise_generator
+
+
+@pytest.mark.unittest
+def test_eps_greedy():
+ exp_eps = get_epsilon_greedy_fn(start=0.9, end=0.1, decay=100)
+ assert exp_eps(0) == 0.9
+ assert exp_eps(10) > exp_eps(200)
+ lin_eps1 = get_epsilon_greedy_fn(start=1.0, end=0.1, decay=90, type_='linear')
+ assert lin_eps1(9) == 0.91
+ assert lin_eps1(100) == 0.1
+ lin_eps2 = get_epsilon_greedy_fn(start=0.9, end=0.3, decay=20, type_='linear')
+ assert pytest.approx(lin_eps2(9)) == 0.63
+ assert lin_eps2(100) == 0.3
+
+
+@pytest.mark.unittest
+def test_noise():
+ bs, dim = 4, 15
+ logits = torch.Tensor(bs, dim)
+ gauss = create_noise_generator(noise_type='gauss', noise_kwargs={'mu': 0.0, 'sigma': 1.5})
+ g_noise = gauss(logits.shape, logits.device)
+ assert g_noise.shape == logits.shape
+ assert g_noise.device == logits.device
+
+ x0 = torch.rand(bs, dim)
+ ou = create_noise_generator(noise_type='ou', noise_kwargs={'mu': 0.1, 'sigma': 1.0, 'theta': 2.0, 'x0': x0})
+ o_noise1 = ou((bs, dim), x0.device)
+ o_noise2 = ou((bs, dim), x0.device)
+ assert o_noise2.shape == x0.shape
+ assert o_noise2.device == x0.device
+ assert not torch.equal(ou.x0, ou._x) # OUNoise._x is not the same as _x0 after 2 calls
+ assert torch.abs(x0 - ou.x0).max() < 1e-6 # OUNoise._x0 does not change
+ x0 += 0.05
+ ou.x0 = x0
+ assert torch.abs(ou.x0 - x0).max() < 1e-6 and torch.abs(ou.x0 - ou._x).max() < 1e-6
+ o_noise3 = ou(x0.shape, x0.device)
diff --git a/DI-engine/ding/rl_utils/tests/test_gae.py b/DI-engine/ding/rl_utils/tests/test_gae.py
new file mode 100644
index 0000000000000000000000000000000000000000..a945e5686b5262fcc972821be5a2617de4844801
--- /dev/null
+++ b/DI-engine/ding/rl_utils/tests/test_gae.py
@@ -0,0 +1,36 @@
+import pytest
+import torch
+from ding.rl_utils import gae_data, gae
+
+
+@pytest.mark.unittest
+def test_gae():
+ # batch trajectory case
+ T, B = 32, 4
+ value = torch.randn(T, B)
+ next_value = torch.randn(T, B)
+ reward = torch.randn(T, B)
+ done = torch.zeros((T, B))
+ data = gae_data(value, next_value, reward, done, None)
+ adv = gae(data)
+ assert adv.shape == (T, B)
+ # single trajectory case/concat trajectory case
+ T = 24
+ value = torch.randn(T)
+ next_value = torch.randn(T)
+ reward = torch.randn(T)
+ done = torch.zeros((T))
+ data = gae_data(value, next_value, reward, done, None)
+ adv = gae(data)
+ assert adv.shape == (T, )
+
+
+def test_gae_multi_agent():
+ T, B, A = 32, 4, 8
+ value = torch.randn(T, B, A)
+ next_value = torch.randn(T, B, A)
+ reward = torch.randn(T, B)
+ done = torch.zeros(T, B)
+ data = gae_data(value, next_value, reward, done, None)
+ adv = gae(data)
+ assert adv.shape == (T, B, A)
diff --git a/DI-engine/ding/rl_utils/tests/test_happo.py b/DI-engine/ding/rl_utils/tests/test_happo.py
new file mode 100644
index 0000000000000000000000000000000000000000..d82e5a37bc6a445579552a2982ee27fca36cefd8
--- /dev/null
+++ b/DI-engine/ding/rl_utils/tests/test_happo.py
@@ -0,0 +1,71 @@
+import pytest
+from itertools import product
+import numpy as np
+import torch
+
+from ding.rl_utils import happo_data, happo_error, happo_error_continuous
+from ding.rl_utils.ppo import shape_fn_ppo
+
+use_value_clip_args = [True, False]
+dual_clip_args = [None, 5.0]
+random_weight = torch.rand(4) + 1
+weight_args = [None, random_weight]
+factor_args = [torch.rand(4, 1)]
+args = [item for item in product(*[use_value_clip_args, dual_clip_args, weight_args, factor_args])]
+
+
+@pytest.mark.unittest
+def test_shape_fn_ppo():
+ data = happo_data(torch.randn(3, 5, 8), None, None, None, None, None, None, None, None)
+ shape1 = shape_fn_ppo([data], {})
+ shape2 = shape_fn_ppo([], {'data': data})
+ assert shape1 == shape2 == (3, 5, 8)
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('use_value_clip, dual_clip, weight, factor', args)
+def test_happo(use_value_clip, dual_clip, weight, factor):
+ B, N = 4, 32
+ logit_new = torch.randn(B, N).requires_grad_(True)
+ logit_old = logit_new + torch.rand_like(logit_new) * 0.1
+ action = torch.randint(0, N, size=(B, ))
+ value_new = torch.randn(B).requires_grad_(True)
+ value_old = value_new + torch.rand_like(value_new) * 0.1
+ adv = torch.rand(B)
+ return_ = torch.randn(B) * 2
+ data = happo_data(logit_new, logit_old, action, value_new, value_old, adv, return_, weight, factor)
+ loss, info = happo_error(data, use_value_clip=use_value_clip, dual_clip=dual_clip)
+ assert all([l.shape == tuple() for l in loss])
+ assert all([np.isscalar(i) for i in info])
+ assert logit_new.grad is None
+ assert value_new.grad is None
+ total_loss = sum(loss)
+ total_loss.backward()
+ assert isinstance(logit_new.grad, torch.Tensor)
+ assert isinstance(value_new.grad, torch.Tensor)
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('use_value_clip, dual_clip, weight, factor', args)
+def test_happo_error_continous(use_value_clip, dual_clip, weight, factor):
+ B, N = 4, 6
+ mu_sigma_new = {'mu': torch.rand(B, N).requires_grad_(True), 'sigma': torch.rand(B, N).requires_grad_(True)}
+ mu_sigma_old = {
+ 'mu': mu_sigma_new['mu'] + torch.rand_like(mu_sigma_new['mu']) * 0.1,
+ 'sigma': mu_sigma_new['sigma'] + torch.rand_like(mu_sigma_new['sigma']) * 0.1
+ }
+ action = torch.rand(B, N)
+ value_new = torch.randn(B).requires_grad_(True)
+ value_old = value_new + torch.rand_like(value_new) * 0.1
+ adv = torch.rand(B)
+ return_ = torch.randn(B) * 2
+ data = happo_data(mu_sigma_new, mu_sigma_old, action, value_new, value_old, adv, return_, weight, factor)
+ loss, info = happo_error_continuous(data, use_value_clip=use_value_clip, dual_clip=dual_clip)
+ assert all([l.shape == tuple() for l in loss])
+ assert all([np.isscalar(i) for i in info])
+ assert mu_sigma_new['mu'].grad is None
+ assert value_new.grad is None
+ total_loss = sum(loss)
+ total_loss.backward()
+ assert isinstance(mu_sigma_new['mu'].grad, torch.Tensor)
+ assert isinstance(value_new.grad, torch.Tensor)
diff --git a/DI-engine/ding/rl_utils/tests/test_ppg.py b/DI-engine/ding/rl_utils/tests/test_ppg.py
new file mode 100644
index 0000000000000000000000000000000000000000..ffa881662e5f13ff101948279df38a5a3bc9c506
--- /dev/null
+++ b/DI-engine/ding/rl_utils/tests/test_ppg.py
@@ -0,0 +1,45 @@
+import pytest
+import time
+from itertools import product
+import numpy as np
+import torch
+from ding.rl_utils import ppg_data, ppg_joint_error
+
+use_value_clip_args = [True, False]
+random_weight = torch.rand(4) + 1
+weight_args = [None, random_weight]
+args = [item for item in product(*[use_value_clip_args, weight_args])]
+
+
+# due to the numeric instability of this unittest, we rerun it when a sporadic error occurs
+@pytest.mark.parametrize('use_value_clip, weight', args)
+def test_ppg(use_value_clip, weight):
+ error_count = 0
+ while True:
+ torch.manual_seed(time.time())
+ B, N = 4, 32
+ logit_new = torch.randn(B, N).add_(0.1).clamp_(0.1, 0.99)
+ logit_old = logit_new.add_(torch.rand_like(logit_new) * 0.1).clamp_(0.1, 0.99)
+ logit_new.requires_grad_(True)
+ logit_old.requires_grad_(True)
+ action = torch.randint(0, N, size=(B, ))
+ value_new = torch.randn(B).requires_grad_(True)
+ value_old = value_new + torch.rand_like(value_new) * 0.1
+ return_ = torch.randn(B) * 2
+ data = ppg_data(logit_new, logit_old, action, value_new, value_old, return_, weight)
+ loss = ppg_joint_error(data, use_value_clip=use_value_clip)
+ assert all([l.shape == tuple() for l in loss])
+ assert logit_new.grad is None
+ assert value_new.grad is None
+ total_loss = sum(loss)
+ try:
+ total_loss.backward()
+ except RuntimeError as e:
+ print("[ERROR]: {}".format(e))
+ if error_count == 10:
+ break
+ error_count += 1
+ continue
+ assert isinstance(logit_new.grad, torch.Tensor)
+ assert isinstance(value_new.grad, torch.Tensor)
+ break
diff --git a/DI-engine/ding/rl_utils/tests/test_ppo.py b/DI-engine/ding/rl_utils/tests/test_ppo.py
new file mode 100644
index 0000000000000000000000000000000000000000..a72d0e3b1674dc2df290ee4f502187acab9c4db3
--- /dev/null
+++ b/DI-engine/ding/rl_utils/tests/test_ppo.py
@@ -0,0 +1,92 @@
+import pytest
+from itertools import product
+import numpy as np
+import torch
+
+from ding.rl_utils import ppo_data, ppo_error, ppo_error_continuous
+from ding.rl_utils.ppo import shape_fn_ppo
+
+use_value_clip_args = [True, False]
+dual_clip_args = [None, 5.0]
+random_weight = torch.rand(4) + 1
+weight_args = [None, random_weight]
+args = [item for item in product(*[use_value_clip_args, dual_clip_args, weight_args])]
+
+
+@pytest.mark.unittest
+def test_shape_fn_ppo():
+ data = ppo_data(torch.randn(3, 5, 8), None, None, None, None, None, None, None)
+ shape1 = shape_fn_ppo([data], {})
+ shape2 = shape_fn_ppo([], {'data': data})
+ assert shape1 == shape2 == (3, 5, 8)
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('use_value_clip, dual_clip, weight', args)
+def test_ppo(use_value_clip, dual_clip, weight):
+ B, N = 4, 32
+ logit_new = torch.randn(B, N).requires_grad_(True)
+ logit_old = logit_new + torch.rand_like(logit_new) * 0.1
+ action = torch.randint(0, N, size=(B, ))
+ value_new = torch.randn(B).requires_grad_(True)
+ value_old = value_new + torch.rand_like(value_new) * 0.1
+ adv = torch.rand(B)
+ return_ = torch.randn(B) * 2
+ data = ppo_data(logit_new, logit_old, action, value_new, value_old, adv, return_, weight)
+ loss, info = ppo_error(data, use_value_clip=use_value_clip, dual_clip=dual_clip)
+ assert all([l.shape == tuple() for l in loss])
+ assert all([np.isscalar(i) for i in info])
+ assert logit_new.grad is None
+ assert value_new.grad is None
+ total_loss = sum(loss)
+ total_loss.backward()
+ assert isinstance(logit_new.grad, torch.Tensor)
+ assert isinstance(value_new.grad, torch.Tensor)
+
+
+@pytest.mark.unittest
+def test_mappo():
+ B, A, N = 4, 8, 32
+ logit_new = torch.randn(B, A, N).requires_grad_(True)
+ logit_old = logit_new + torch.rand_like(logit_new) * 0.1
+ action = torch.randint(0, N, size=(B, A))
+ value_new = torch.randn(B, A).requires_grad_(True)
+ value_old = value_new + torch.rand_like(value_new) * 0.1
+ adv = torch.rand(B, A)
+ return_ = torch.randn(B, A) * 2
+ data = ppo_data(logit_new, logit_old, action, value_new, value_old, adv, return_, None)
+ loss, info = ppo_error(data)
+ assert all([l.shape == tuple() for l in loss])
+ assert all([np.isscalar(i) for i in info])
+ assert logit_new.grad is None
+ assert value_new.grad is None
+ total_loss = sum(loss)
+ total_loss.backward()
+ assert isinstance(logit_new.grad, torch.Tensor)
+ assert isinstance(value_new.grad, torch.Tensor)
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('use_value_clip, dual_clip, weight', args)
+def test_ppo_error_continous(use_value_clip, dual_clip, weight):
+ B, N = 4, 6
+ mu_sigma_new = {'mu': torch.rand(B, N).requires_grad_(True), 'sigma': torch.rand(B, N).requires_grad_(True)}
+ mu_sigma_old = {
+ 'mu': mu_sigma_new['mu'] + torch.rand_like(mu_sigma_new['mu']) * 0.1,
+ 'sigma': mu_sigma_new['sigma'] + torch.rand_like(mu_sigma_new['sigma']) * 0.1
+ }
+ action = torch.rand(B, N)
+ value_new = torch.randn(B).requires_grad_(True)
+ value_old = value_new + torch.rand_like(value_new) * 0.1
+ adv = torch.rand(B)
+ return_ = torch.randn(B) * 2
+ data = ppo_data(mu_sigma_new, mu_sigma_old, action, value_new, value_old, adv, return_, weight)
+ loss, info = ppo_error_continuous(data, use_value_clip=use_value_clip, dual_clip=dual_clip)
+ assert all([l.shape == tuple() for l in loss])
+ assert all([np.isscalar(i) for i in info])
+ assert mu_sigma_new['mu'].grad is None
+ assert value_new.grad is None
+ total_loss = sum(loss)
+ total_loss.backward()
+ assert isinstance(mu_sigma_new['mu'].grad, torch.Tensor)
+ assert isinstance(value_new.grad, torch.Tensor)
diff --git a/DI-engine/ding/rl_utils/tests/test_retrace.py b/DI-engine/ding/rl_utils/tests/test_retrace.py
new file mode 100644
index 0000000000000000000000000000000000000000..267748c1f12ef7c86115c98ba9dae1918f993036
--- /dev/null
+++ b/DI-engine/ding/rl_utils/tests/test_retrace.py
@@ -0,0 +1,18 @@
+import pytest
+import torch
+from ding.rl_utils import compute_q_retraces
+
+
+@pytest.mark.unittest
+def test_compute_q_retraces():
+ T, B, N = 64, 32, 6
+ q_values = torch.randn(T + 1, B, N)
+ v_pred = torch.randn(T + 1, B, 1)
+ rewards = torch.randn(T, B)
+ ratio = torch.rand(T, B, N) * 0.4 + 0.8
+ assert ratio.max() <= 1.2 and ratio.min() >= 0.8
+ weights = torch.rand(T, B)
+ actions = torch.randint(0, N, size=(T, B))
+ with torch.no_grad():
+ q_retraces = compute_q_retraces(q_values, v_pred, rewards, actions, weights, ratio, gamma=0.99)
+ assert q_retraces.shape == (T + 1, B, 1)
diff --git a/DI-engine/ding/rl_utils/tests/test_td.py b/DI-engine/ding/rl_utils/tests/test_td.py
new file mode 100644
index 0000000000000000000000000000000000000000..c710695cb2939934cf3523db99d5549205f5cdc4
--- /dev/null
+++ b/DI-engine/ding/rl_utils/tests/test_td.py
@@ -0,0 +1,610 @@
+import pytest
+import torch
+from ding.rl_utils import q_nstep_td_data, q_nstep_td_error, q_1step_td_data, q_1step_td_error, td_lambda_data,\
+ td_lambda_error, q_nstep_td_error_with_rescale, dist_1step_td_data, dist_1step_td_error, dist_nstep_td_data,\
+ dqfd_nstep_td_data, dqfd_nstep_td_error, dist_nstep_td_error, v_1step_td_data, v_1step_td_error, v_nstep_td_data,\
+ v_nstep_td_error, q_nstep_sql_td_error, iqn_nstep_td_data, iqn_nstep_td_error,\
+ fqf_nstep_td_data, fqf_nstep_td_error, qrdqn_nstep_td_data, qrdqn_nstep_td_error, bdq_nstep_td_error,\
+ m_q_1step_td_data, m_q_1step_td_error
+from ding.rl_utils.td import shape_fn_dntd, shape_fn_qntd, shape_fn_td_lambda, shape_fn_qntd_rescale
+
+
+@pytest.mark.unittest
+def test_q_nstep_td():
+ batch_size = 4
+ action_dim = 3
+ next_q = torch.randn(batch_size, action_dim)
+ done = torch.randn(batch_size)
+ action = torch.randint(0, action_dim, size=(batch_size, ))
+ next_action = torch.randint(0, action_dim, size=(batch_size, ))
+ for nstep in range(1, 10):
+ q = torch.randn(batch_size, action_dim).requires_grad_(True)
+ reward = torch.rand(nstep, batch_size)
+ data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
+ loss, td_error_per_sample = q_nstep_td_error(data, 0.95, nstep=nstep)
+ assert td_error_per_sample.shape == (batch_size, )
+ assert loss.shape == ()
+ assert q.grad is None
+ loss.backward()
+ assert isinstance(q.grad, torch.Tensor)
+ data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
+ loss, td_error_per_sample = q_nstep_td_error(data, 0.95, nstep=nstep, cum_reward=True)
+ value_gamma = torch.tensor(0.9)
+ data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
+ loss, td_error_per_sample = q_nstep_td_error(data, 0.95, nstep=nstep, cum_reward=True, value_gamma=value_gamma)
+ loss.backward()
+ assert isinstance(q.grad, torch.Tensor)
+
+
+@pytest.mark.unittest
+def test_bdq_nstep_td():
+ batch_size = 8
+ branch_num = 6
+ action_per_branch = 3
+ next_q = torch.randn(batch_size, branch_num, action_per_branch)
+ done = torch.randn(batch_size)
+ action = torch.randint(0, action_per_branch, size=(batch_size, branch_num))
+ next_action = torch.randint(0, action_per_branch, size=(batch_size, branch_num))
+ for nstep in range(1, 10):
+ q = torch.randn(batch_size, branch_num, action_per_branch).requires_grad_(True)
+ reward = torch.rand(nstep, batch_size)
+ data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
+ loss, td_error_per_sample = bdq_nstep_td_error(data, 0.95, nstep=nstep)
+ assert td_error_per_sample.shape == (batch_size, )
+ assert loss.shape == ()
+ assert q.grad is None
+ loss.backward()
+ assert isinstance(q.grad, torch.Tensor)
+ data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
+ loss, td_error_per_sample = bdq_nstep_td_error(data, 0.95, nstep=nstep, cum_reward=True)
+ value_gamma = torch.tensor(0.9)
+ data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
+ loss, td_error_per_sample = bdq_nstep_td_error(
+ data, 0.95, nstep=nstep, cum_reward=True, value_gamma=value_gamma
+ )
+ loss.backward()
+ assert isinstance(q.grad, torch.Tensor)
+
+
+@pytest.mark.unittest
+def test_q_nstep_td_ngu():
+ batch_size = 4
+ action_dim = 3
+ next_q = torch.randn(batch_size, action_dim)
+ done = torch.randn(batch_size)
+ action = torch.randint(0, action_dim, size=(batch_size, ))
+ next_action = torch.randint(0, action_dim, size=(batch_size, ))
+ gamma = [torch.tensor(0.95) for i in range(batch_size)]
+
+ for nstep in range(1, 10):
+ q = torch.randn(batch_size, action_dim).requires_grad_(True)
+ reward = torch.rand(nstep, batch_size)
+ data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
+ loss, td_error_per_sample = q_nstep_td_error(data, gamma, nstep=nstep)
+ assert td_error_per_sample.shape == (batch_size, )
+ assert loss.shape == ()
+ assert q.grad is None
+ loss.backward()
+ assert isinstance(q.grad, torch.Tensor)
+
+
+@pytest.mark.unittest
+def test_dist_1step_td():
+ batch_size = 4
+ action_dim = 3
+ n_atom = 51
+ v_min = -10.0
+ v_max = 10.0
+ dist = torch.randn(batch_size, action_dim, n_atom).abs().requires_grad_(True)
+ next_dist = torch.randn(batch_size, action_dim, n_atom).abs()
+ done = torch.randn(batch_size)
+ action = torch.randint(0, action_dim, size=(batch_size, ))
+ next_action = torch.randint(0, action_dim, size=(batch_size, ))
+ reward = torch.randn(batch_size)
+ data = dist_1step_td_data(dist, next_dist, action, next_action, reward, done, None)
+ loss = dist_1step_td_error(data, 0.95, v_min, v_max, n_atom)
+ assert loss.shape == ()
+ assert dist.grad is None
+ loss.backward()
+ assert isinstance(dist.grad, torch.Tensor)
+
+
+@pytest.mark.unittest
+def test_q_1step_compatible():
+ batch_size = 4
+ action_dim = 3
+ next_q = torch.randn(batch_size, action_dim)
+ done = torch.randn(batch_size)
+ action = torch.randint(0, action_dim, size=(batch_size, ))
+ next_action = torch.randint(0, action_dim, size=(batch_size, ))
+ q = torch.randn(batch_size, action_dim).requires_grad_(True)
+ reward = torch.rand(batch_size)
+ nstep_data = q_nstep_td_data(q, next_q, action, next_action, reward.unsqueeze(0), done, None)
+ onestep_data = q_1step_td_data(q, next_q, action, next_action, reward, done, None)
+ nstep_loss, _ = q_nstep_td_error(nstep_data, 0.99, nstep=1)
+ onestep_loss = q_1step_td_error(onestep_data, 0.99)
+ assert pytest.approx(nstep_loss.item()) == onestep_loss.item()
+
+
+@pytest.mark.unittest
+def test_dist_nstep_td():
+ batch_size = 4
+ action_dim = 3
+ n_atom = 51
+ v_min = -10.0
+ v_max = 10.0
+ nstep = 5
+ dist = torch.randn(batch_size, action_dim, n_atom).abs().requires_grad_(True)
+ next_n_dist = torch.randn(batch_size, action_dim, n_atom).abs()
+ done = torch.randn(batch_size)
+ action = torch.randint(0, action_dim, size=(batch_size, ))
+ next_action = torch.randint(0, action_dim, size=(batch_size, ))
+ reward = torch.randn(nstep, batch_size)
+ data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, None)
+ loss, _ = dist_nstep_td_error(data, 0.95, v_min, v_max, n_atom, nstep)
+ assert loss.shape == ()
+ assert dist.grad is None
+ loss.backward()
+ assert isinstance(dist.grad, torch.Tensor)
+ weight = torch.tensor([0.9])
+ value_gamma = torch.tensor(0.9)
+ data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, weight)
+ loss, _ = dist_nstep_td_error(data, 0.95, v_min, v_max, n_atom, nstep, value_gamma)
+ assert loss.shape == ()
+ loss.backward()
+ assert isinstance(dist.grad, torch.Tensor)
+
+
+@pytest.mark.unittest
+def test_dist_nstep_multi_agent_td():
+ batch_size = 4
+ action_dim = 3
+ agent_num = 2
+ n_atom = 51
+ v_min = -10.0
+ v_max = 10.0
+ nstep = 5
+ dist = torch.randn(batch_size, agent_num, action_dim, n_atom).abs().requires_grad_(True)
+ next_n_dist = torch.randn(batch_size, agent_num, action_dim, n_atom).abs()
+ done = torch.randint(0, 2, (batch_size, ))
+ action = torch.randint(
+ 0, action_dim, size=(
+ batch_size,
+ agent_num,
+ )
+ )
+ next_action = torch.randint(
+ 0, action_dim, size=(
+ batch_size,
+ agent_num,
+ )
+ )
+ reward = torch.randn(nstep, batch_size)
+ data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, None)
+ loss, _ = dist_nstep_td_error(data, 0.95, v_min, v_max, n_atom, nstep)
+ assert loss.shape == ()
+ assert dist.grad is None
+ loss.backward()
+ assert isinstance(dist.grad, torch.Tensor)
+ weight = 0.9
+ value_gamma = 0.9
+ data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, weight)
+ loss, _ = dist_nstep_td_error(data, 0.95, v_min, v_max, n_atom, nstep, value_gamma)
+ assert loss.shape == ()
+ loss.backward()
+ assert isinstance(dist.grad, torch.Tensor)
+ agent_total_loss = 0
+ for i in range(agent_num):
+ data = dist_nstep_td_data(
+ dist[:, i, ], next_n_dist[:, i, ], action[:, i, ], next_action[:, i, ], reward, done, weight
+ )
+ agent_loss, _ = dist_nstep_td_error(data, 0.95, v_min, v_max, n_atom, nstep, value_gamma)
+ agent_total_loss = agent_total_loss + agent_loss
+ agent_average_loss = agent_total_loss / agent_num
+ assert abs(agent_average_loss.item() - loss.item()) < 1e-5
+
+
+@pytest.mark.unittest
+def test_q_nstep_td_with_rescale():
+ batch_size = 4
+ action_dim = 3
+ next_q = torch.randn(batch_size, action_dim)
+ done = torch.randn(batch_size)
+ action = torch.randint(0, action_dim, size=(batch_size, ))
+ next_action = torch.randint(0, action_dim, size=(batch_size, ))
+ for nstep in range(1, 10):
+ q = torch.randn(batch_size, action_dim).requires_grad_(True)
+ reward = torch.rand(nstep, batch_size)
+ data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
+ loss, _ = q_nstep_td_error_with_rescale(data, 0.95, nstep=nstep)
+ assert loss.shape == ()
+ assert q.grad is None
+ loss.backward()
+ assert isinstance(q.grad, torch.Tensor)
+ print(loss)
+
+
+@pytest.mark.unittest
+def test_q_nstep_td_with_rescale_ngu():
+ batch_size = 4
+ action_dim = 3
+ next_q = torch.randn(batch_size, action_dim)
+ done = torch.randn(batch_size)
+ action = torch.randint(0, action_dim, size=(batch_size, ))
+ next_action = torch.randint(0, action_dim, size=(batch_size, ))
+ gamma = [torch.tensor(0.95) for i in range(batch_size)]
+ for nstep in range(1, 10):
+ q = torch.randn(batch_size, action_dim).requires_grad_(True)
+ reward = torch.rand(nstep, batch_size)
+ data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
+ loss, _ = q_nstep_td_error_with_rescale(data, gamma, nstep=nstep)
+ assert loss.shape == ()
+ assert q.grad is None
+ loss.backward()
+ assert isinstance(q.grad, torch.Tensor)
+ print(loss)
+
+
+@pytest.mark.unittest
+def test_qrdqn_nstep_td():
+ batch_size = 4
+ action_dim = 3
+ tau = 3
+ next_q = torch.randn(batch_size, action_dim, tau)
+ done = torch.randn(batch_size)
+ action = torch.randint(0, action_dim, size=(batch_size, ))
+ next_action = torch.randint(0, action_dim, size=(batch_size, ))
+ for nstep in range(1, 10):
+ q = torch.randn(batch_size, action_dim, tau).requires_grad_(True)
+ reward = torch.rand(nstep, batch_size)
+ data = qrdqn_nstep_td_data(q, next_q, action, next_action, reward, done, tau, None)
+ loss, td_error_per_sample = qrdqn_nstep_td_error(data, 0.95, nstep=nstep)
+ assert td_error_per_sample.shape == (batch_size, )
+ assert loss.shape == ()
+ assert q.grad is None
+ loss.backward()
+ assert isinstance(q.grad, torch.Tensor)
+ loss, td_error_per_sample = qrdqn_nstep_td_error(data, 0.95, nstep=nstep, value_gamma=torch.tensor(0.9))
+ assert td_error_per_sample.shape == (batch_size, )
+
+
+@pytest.mark.unittest
+def test_dist_1step_compatible():
+ batch_size = 4
+ action_dim = 3
+ n_atom = 51
+ v_min = -10.0
+ v_max = 10.0
+ dist = torch.randn(batch_size, action_dim, n_atom).abs().requires_grad_(True)
+ next_dist = torch.randn(batch_size, action_dim, n_atom).abs()
+ done = torch.randn(batch_size)
+ action = torch.randint(0, action_dim, size=(batch_size, ))
+ next_action = torch.randint(0, action_dim, size=(batch_size, ))
+ reward = torch.randn(batch_size)
+ onestep_data = dist_1step_td_data(dist, next_dist, action, next_action, reward, done, None)
+ nstep_data = dist_nstep_td_data(dist, next_dist, action, next_action, reward.unsqueeze(0), done, None)
+ onestep_loss = dist_1step_td_error(onestep_data, 0.95, v_min, v_max, n_atom)
+ nstep_loss, _ = dist_nstep_td_error(nstep_data, 0.95, v_min, v_max, n_atom, nstep=1)
+ assert pytest.approx(nstep_loss.item()) == onestep_loss.item()
+
+
+@pytest.mark.unittest
+def test_dist_1step_multi_agent_td():
+ batch_size = 4
+ action_dim = 3
+ agent_num = 2
+ n_atom = 51
+ v_min = -10.0
+ v_max = 10.0
+ dist = torch.randn(batch_size, agent_num, action_dim, n_atom).abs().requires_grad_(True)
+ next_dist = torch.randn(batch_size, agent_num, action_dim, n_atom).abs()
+ done = torch.randint(0, 2, (batch_size, ))
+ action = torch.randint(
+ 0, action_dim, size=(
+ batch_size,
+ agent_num,
+ )
+ )
+ next_action = torch.randint(
+ 0, action_dim, size=(
+ batch_size,
+ agent_num,
+ )
+ )
+ reward = torch.randn(batch_size)
+ data = dist_1step_td_data(dist, next_dist, action, next_action, reward, done, None)
+ loss = dist_1step_td_error(data, 0.95, v_min, v_max, n_atom)
+ assert loss.shape == ()
+ assert dist.grad is None
+ loss.backward()
+ assert isinstance(dist.grad, torch.Tensor)
+ agent_total_loss = 0
+ for i in range(agent_num):
+ data = dist_1step_td_data(
+ dist[:, i, ], next_dist[:, i, ], action[:, i, ], next_action[:, i, ], reward, done, None
+ )
+ agent_loss = dist_1step_td_error(data, 0.95, v_min, v_max, n_atom)
+ agent_total_loss = agent_total_loss + agent_loss
+ agent_average_loss = agent_total_loss / agent_num
+ assert abs(agent_average_loss.item() - loss.item()) < 1e-5
+
+
+@pytest.mark.unittest
+def test_td_lambda():
+ T, B = 8, 4
+ value = torch.randn(T + 1, B).requires_grad_(True)
+ reward = torch.rand(T, B)
+ loss = td_lambda_error(td_lambda_data(value, reward, None))
+ assert loss.shape == ()
+ assert value.grad is None
+ loss.backward()
+ assert isinstance(value.grad, torch.Tensor)
+
+
+@pytest.mark.unittest
+def test_v_1step_td():
+ batch_size = 5
+ v = torch.randn(batch_size).requires_grad_(True)
+ next_v = torch.randn(batch_size)
+ reward = torch.rand(batch_size)
+ done = torch.zeros(batch_size)
+ data = v_1step_td_data(v, next_v, reward, done, None)
+ loss, td_error_per_sample = v_1step_td_error(data, 0.99)
+ assert loss.shape == ()
+ assert v.grad is None
+ loss.backward()
+ assert isinstance(v.grad, torch.Tensor)
+ data = v_1step_td_data(v, next_v, reward, None, None)
+ loss, td_error_per_sample = v_1step_td_error(data, 0.99)
+ loss.backward()
+ assert isinstance(v.grad, torch.Tensor)
+
+
+@pytest.mark.unittest
+def test_v_1step_multi_agent_td():
+ batch_size = 5
+ agent_num = 2
+ v = torch.randn(batch_size, agent_num).requires_grad_(True)
+ next_v = torch.randn(batch_size, agent_num)
+ reward = torch.rand(batch_size)
+ done = torch.zeros(batch_size)
+ data = v_1step_td_data(v, next_v, reward, done, None)
+ loss, td_error_per_sample = v_1step_td_error(data, 0.99)
+ assert loss.shape == ()
+ assert v.grad is None
+ loss.backward()
+ assert isinstance(v.grad, torch.Tensor)
+ data = v_1step_td_data(v, next_v, reward, None, None)
+ loss, td_error_per_sample = v_1step_td_error(data, 0.99)
+ loss.backward()
+ assert isinstance(v.grad, torch.Tensor)
+
+
+@pytest.mark.unittest
+def test_v_nstep_td():
+ batch_size = 5
+ v = torch.randn(batch_size).requires_grad_(True)
+ next_v = torch.randn(batch_size)
+ reward = torch.rand(5, batch_size)
+ done = torch.zeros(batch_size)
+ data = v_nstep_td_data(v, next_v, reward, done, 0.9, 0.99)
+ loss, td_error_per_sample = v_nstep_td_error(data, 0.99, 5)
+ assert loss.shape == ()
+ assert v.grad is None
+ loss.backward()
+ assert isinstance(v.grad, torch.Tensor)
+ data = v_nstep_td_data(v, next_v, reward, done, None, 0.99)
+ loss, td_error_per_sample = v_nstep_td_error(data, 0.99, 5)
+ loss.backward()
+ assert isinstance(v.grad, torch.Tensor)
+
+
+@pytest.mark.unittest
+def test_dqfd_nstep_td():
+ batch_size = 4
+ action_dim = 3
+ next_q = torch.randn(batch_size, action_dim)
+ done = torch.randn(batch_size)
+ done_1 = torch.randn(batch_size)
+ next_q_one_step = torch.randn(batch_size, action_dim)
+ action = torch.randint(0, action_dim, size=(batch_size, ))
+ next_action = torch.randint(0, action_dim, size=(batch_size, ))
+ next_action_one_step = torch.randint(0, action_dim, size=(batch_size, ))
+ is_expert = torch.ones((batch_size))
+ for nstep in range(1, 10):
+ q = torch.randn(batch_size, action_dim).requires_grad_(True)
+ reward = torch.rand(nstep, batch_size)
+ data = dqfd_nstep_td_data(
+ q, next_q, action, next_action, reward, done, done_1, None, next_q_one_step, next_action_one_step, is_expert
+ )
+ loss, td_error_per_sample, loss_statistics = dqfd_nstep_td_error(
+ data, 0.95, lambda_n_step_td=1, lambda_supervised_loss=1, margin_function=0.8, nstep=nstep
+ )
+ assert td_error_per_sample.shape == (batch_size, )
+ assert loss.shape == ()
+ assert q.grad is None
+ loss.backward()
+ assert isinstance(q.grad, torch.Tensor)
+ print(loss)
+
+
+@pytest.mark.unittest
+def test_q_nstep_sql_td():
+ batch_size = 4
+ action_dim = 3
+ next_q = torch.randn(batch_size, action_dim)
+ done = torch.randn(batch_size)
+ action = torch.randint(0, action_dim, size=(batch_size, ))
+ next_action = torch.randint(0, action_dim, size=(batch_size, ))
+ for nstep in range(1, 10):
+ q = torch.randn(batch_size, action_dim).requires_grad_(True)
+ reward = torch.rand(nstep, batch_size)
+ data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
+ loss, td_error_per_sample, record_target_v = q_nstep_sql_td_error(data, 0.95, 1.0, nstep=nstep)
+ assert td_error_per_sample.shape == (batch_size, )
+ assert loss.shape == ()
+ assert q.grad is None
+ loss.backward()
+ assert isinstance(q.grad, torch.Tensor)
+ data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
+ loss, td_error_per_sample, record_target_v = q_nstep_sql_td_error(data, 0.95, 0.5, nstep=nstep, cum_reward=True)
+ value_gamma = torch.tensor(0.9)
+ data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
+ loss, td_error_per_sample, record_target_v = q_nstep_sql_td_error(
+ data, 0.95, 0.5, nstep=nstep, cum_reward=True, value_gamma=value_gamma
+ )
+ loss.backward()
+ assert isinstance(q.grad, torch.Tensor)
+
+
+@pytest.mark.unittest
+def test_iqn_nstep_td():
+ batch_size = 4
+ action_dim = 3
+ tau = 3
+ next_q = torch.randn(tau, batch_size, action_dim)
+ done = torch.randn(batch_size)
+ action = torch.randint(0, action_dim, size=(batch_size, ))
+ next_action = torch.randint(0, action_dim, size=(batch_size, ))
+ for nstep in range(1, 10):
+ q = torch.randn(tau, batch_size, action_dim).requires_grad_(True)
+ replay_quantile = torch.randn([tau, batch_size, 1])
+ reward = torch.rand(nstep, batch_size)
+ data = iqn_nstep_td_data(q, next_q, action, next_action, reward, done, replay_quantile, None)
+ loss, td_error_per_sample = iqn_nstep_td_error(data, 0.95, nstep=nstep)
+ assert td_error_per_sample.shape == (batch_size, )
+ assert loss.shape == ()
+ assert q.grad is None
+ loss.backward()
+ assert isinstance(q.grad, torch.Tensor)
+ loss, td_error_per_sample = iqn_nstep_td_error(data, 0.95, nstep=nstep, value_gamma=torch.tensor(0.9))
+ assert td_error_per_sample.shape == (batch_size, )
+
+
+@pytest.mark.unittest
+def test_fqf_nstep_td():
+ batch_size = 4
+ action_dim = 3
+ tau = 3
+ next_q = torch.randn(batch_size, tau, action_dim)
+ done = torch.randn(batch_size)
+ action = torch.randint(0, action_dim, size=(batch_size, ))
+ next_action = torch.randint(0, action_dim, size=(batch_size, ))
+ for nstep in range(1, 10):
+ q = torch.randn(batch_size, tau, action_dim).requires_grad_(True)
+ quantiles_hats = torch.randn([batch_size, tau])
+ reward = torch.rand(nstep, batch_size)
+ data = fqf_nstep_td_data(q, next_q, action, next_action, reward, done, quantiles_hats, None)
+ loss, td_error_per_sample = fqf_nstep_td_error(data, 0.95, nstep=nstep)
+ assert td_error_per_sample.shape == (batch_size, )
+ assert loss.shape == ()
+ assert q.grad is None
+ loss.backward()
+ assert isinstance(q.grad, torch.Tensor)
+ loss, td_error_per_sample = fqf_nstep_td_error(data, 0.95, nstep=nstep, value_gamma=torch.tensor(0.9))
+ assert td_error_per_sample.shape == (batch_size, )
+
+
+@pytest.mark.unittest
+def test_shape_fn_qntd():
+ batch_size = 4
+ action_dim = 3
+ next_q = torch.randn(batch_size, action_dim)
+ done = torch.randn(batch_size)
+ action = torch.randint(0, action_dim, size=(batch_size, ))
+ next_action = torch.randint(0, action_dim, size=(batch_size, ))
+ for nstep in range(1, 10):
+ q = torch.randn(batch_size, action_dim).requires_grad_(True)
+ reward = torch.rand(nstep, batch_size)
+ data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
+ tmp = shape_fn_qntd([data, 0.95, 1], {})
+ assert tmp[0] == reward.shape[0]
+ assert tmp[1] == q.shape[0]
+ assert tmp[2] == q.shape[1]
+ tmp = shape_fn_qntd([], {'gamma': 0.95, 'nstep': 1, 'data': data})
+ assert tmp[0] == reward.shape[0]
+ assert tmp[1] == q.shape[0]
+ assert tmp[2] == q.shape[1]
+
+
+@pytest.mark.unittest
+def test_shape_fn_dntd():
+ batch_size = 4
+ action_dim = 3
+ n_atom = 51
+ v_min = -10.0
+ v_max = 10.0
+ nstep = 5
+ dist = torch.randn(batch_size, action_dim, n_atom).abs().requires_grad_(True)
+ next_n_dist = torch.randn(batch_size, action_dim, n_atom).abs()
+ done = torch.randn(batch_size)
+ action = torch.randint(0, action_dim, size=(batch_size, ))
+ next_action = torch.randint(0, action_dim, size=(batch_size, ))
+ reward = torch.randn(nstep, batch_size)
+ data = dist_nstep_td_data(dist, next_n_dist, action, next_action, reward, done, None)
+ tmp = shape_fn_dntd([data, 0.9, v_min, v_max, n_atom, nstep], {})
+ assert tmp[0] == reward.shape[0]
+ assert tmp[1] == dist.shape[0]
+ assert tmp[2] == dist.shape[1]
+ assert tmp[3] == n_atom
+ tmp = shape_fn_dntd([], {'data': data, 'gamma': 0.9, 'v_min': v_min, 'v_max': v_max, 'n_atom': n_atom, 'nstep': 5})
+ assert tmp[0] == reward.shape[0]
+ assert tmp[1] == dist.shape[0]
+ assert tmp[2] == dist.shape[1]
+ assert tmp[3] == n_atom
+
+
+@pytest.mark.unittest
+def test_shape_fn_qntd_rescale():
+ batch_size = 4
+ action_dim = 3
+ next_q = torch.randn(batch_size, action_dim)
+ done = torch.randn(batch_size)
+ action = torch.randint(0, action_dim, size=(batch_size, ))
+ next_action = torch.randint(0, action_dim, size=(batch_size, ))
+ for nstep in range(1, 10):
+ q = torch.randn(batch_size, action_dim).requires_grad_(True)
+ reward = torch.rand(nstep, batch_size)
+ data = q_nstep_td_data(q, next_q, action, next_action, reward, done, None)
+ tmp = shape_fn_qntd_rescale([data, 0.95, 1], {})
+ assert tmp[0] == reward.shape[0]
+ assert tmp[1] == q.shape[0]
+ assert tmp[2] == q.shape[1]
+ tmp = shape_fn_qntd_rescale([], {'gamma': 0.95, 'nstep': 1, 'data': data})
+ assert tmp[0] == reward.shape[0]
+ assert tmp[1] == q.shape[0]
+ assert tmp[2] == q.shape[1]
+
+
+@pytest.mark.unittest
+def test_fn_td_lambda():
+ T, B = 8, 4
+ value = torch.randn(T + 1, B).requires_grad_(True)
+ reward = torch.rand(T, B)
+ data = td_lambda_data(value, reward, None)
+ tmp = shape_fn_td_lambda([], {'data': data})
+ assert tmp == reward.shape[0]
+ tmp = shape_fn_td_lambda([data], {})
+ assert tmp == reward.shape
+
+
+@pytest.mark.unittest
+def test_fn_m_q_1step_td_error():
+ batch_size = 128
+ action_dim = 9
+ q = torch.randn(batch_size, action_dim).requires_grad_(True)
+ target_q_current = torch.randn(batch_size, action_dim).requires_grad_(False)
+ target_q_next = torch.randn(batch_size, action_dim).requires_grad_(False)
+ done = torch.randn(batch_size)
+ action = torch.randint(0, action_dim, size=(batch_size, ))
+ reward = torch.randn(batch_size)
+ data = m_q_1step_td_data(q, target_q_current, target_q_next, action, reward, done, None)
+ loss, td_error_per_sample, action_gap, clip_frac = m_q_1step_td_error(data, 0.99, 0.03, 0.6)
+
+ assert loss.shape == ()
+ assert q.grad is None
+ loss.backward()
+ assert isinstance(q.grad, torch.Tensor)
+ assert clip_frac.mean().item() <= 1
+ assert action_gap.item() > 0
+ assert td_error_per_sample.shape == (batch_size, )
diff --git a/DI-engine/ding/rl_utils/tests/test_upgo.py b/DI-engine/ding/rl_utils/tests/test_upgo.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bd96d9c7e5dd189bfb0b8a34d85280c144f299d
--- /dev/null
+++ b/DI-engine/ding/rl_utils/tests/test_upgo.py
@@ -0,0 +1,41 @@
+import pytest
+import torch
+from ding.rl_utils.upgo import upgo_loss, upgo_returns, tb_cross_entropy
+
+
+@pytest.mark.unittest
+def test_upgo():
+ T, B, N, N2 = 4, 8, 5, 7
+
+ # tb_cross_entropy: 3 tests
+ logit = torch.randn(T, B, N, N2).softmax(-1).requires_grad_(True)
+ action = logit.argmax(-1).detach()
+ ce = tb_cross_entropy(logit, action)
+ assert ce.shape == (T, B)
+
+ logit = torch.randn(T, B, N, N2, 2).softmax(-1).requires_grad_(True)
+ action = logit.argmax(-1).detach()
+ with pytest.raises(AssertionError):
+ ce = tb_cross_entropy(logit, action)
+
+ logit = torch.randn(T, B, N).softmax(-1).requires_grad_(True)
+ action = logit.argmax(-1).detach()
+ ce = tb_cross_entropy(logit, action)
+ assert ce.shape == (T, B)
+
+ # upgo_returns
+ rewards = torch.randn(T, B)
+ bootstrap_values = torch.randn(T + 1, B).requires_grad_(True)
+ returns = upgo_returns(rewards, bootstrap_values)
+ assert returns.shape == (T, B)
+
+ # upgo loss
+ rhos = torch.randn(T, B)
+ loss = upgo_loss(logit, rhos, action, rewards, bootstrap_values)
+ assert logit.requires_grad
+ assert bootstrap_values.requires_grad
+ for t in [logit, bootstrap_values]:
+ assert t.grad is None
+ loss.backward()
+ for t in [logit]:
+ assert isinstance(t.grad, torch.Tensor)
diff --git a/DI-engine/ding/rl_utils/tests/test_value_rescale.py b/DI-engine/ding/rl_utils/tests/test_value_rescale.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce5cb0c9b944a80a748807f21348e9ef1a488d76
--- /dev/null
+++ b/DI-engine/ding/rl_utils/tests/test_value_rescale.py
@@ -0,0 +1,49 @@
+import pytest
+import torch
+from ding.rl_utils.value_rescale import value_inv_transform, value_transform, symlog, inv_symlog
+
+
+@pytest.mark.unittest
+class TestValueRescale:
+
+ def test_value_transform(self):
+ for _ in range(10):
+ t = torch.rand((2, 3))
+ assert isinstance(value_transform(t), torch.Tensor)
+ assert value_transform(t).shape == t.shape
+
+ def test_value_inv_transform(self):
+ for _ in range(10):
+ t = torch.rand((2, 3))
+ assert isinstance(value_inv_transform(t), torch.Tensor)
+ assert value_inv_transform(t).shape == t.shape
+
+ def test_trans_inverse(self):
+ for _ in range(10):
+ t = torch.rand((4, 16))
+ diff = value_inv_transform(value_transform(t)) - t
+ assert pytest.approx(diff.abs().max().item(), abs=2e-5) == 0
+
+
+@pytest.mark.unittest
+class TestSymlog:
+
+ def test_symlog(self):
+ for _ in range(10):
+ t = torch.rand((3, 4))
+ assert isinstance(symlog(t), torch.Tensor)
+ assert symlog(t).shape == t.shape
+
+ def test_inv_symlog(self):
+ for _ in range(10):
+ t = torch.rand((3, 4))
+ assert isinstance(inv_symlog(t), torch.Tensor)
+ assert inv_symlog(t).shape == t.shape
+
+ def test_trans_inverse(self):
+ for _ in range(10):
+ t = torch.rand((4, 16))
+ diff = inv_symlog(symlog(t)) - t
+ assert pytest.approx(diff.abs().max().item(), abs=2e-5) == 0
diff --git a/DI-engine/ding/rl_utils/tests/test_vtrace.py b/DI-engine/ding/rl_utils/tests/test_vtrace.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5fa94fa0b08802609441bbd5e04821e75c31eb0
--- /dev/null
+++ b/DI-engine/ding/rl_utils/tests/test_vtrace.py
@@ -0,0 +1,47 @@
+import pytest
+import torch
+from ding.rl_utils import vtrace_data, vtrace_error_discrete_action, vtrace_error_continuous_action
+
+
+@pytest.mark.unittest
+def test_vtrace_discrete_action():
+ T, B, N = 4, 8, 16
+ value = torch.randn(T + 1, B).requires_grad_(True)
+ reward = torch.rand(T, B)
+ target_output = torch.randn(T, B, N).requires_grad_(True)
+ behaviour_output = torch.randn(T, B, N)
+ action = torch.randint(0, N, size=(T, B))
+ data = vtrace_data(target_output, behaviour_output, action, value, reward, None)
+ loss = vtrace_error_discrete_action(data, rho_clip_ratio=1.1)
+ assert all([l.shape == tuple() for l in loss])
+ assert target_output.grad is None
+ assert value.grad is None
+ loss = sum(loss)
+ loss.backward()
+ assert isinstance(target_output.grad, torch.Tensor)
+ assert isinstance(value.grad, torch.Tensor)
+
+
+@pytest.mark.unittest
+def test_vtrace_continuous_action():
+ T, B, N = 4, 8, 16
+ value = torch.randn(T + 1, B).requires_grad_(True)
+ reward = torch.rand(T, B)
+ target_output = {}
+ target_output['mu'] = torch.randn(T, B, N).requires_grad_(True)
+ target_output['sigma'] = torch.exp(torch.randn(T, B, N).requires_grad_(True))
+ behaviour_output = {}
+ behaviour_output['mu'] = torch.randn(T, B, N)
+ behaviour_output['sigma'] = torch.exp(torch.randn(T, B, N))
+ action = torch.randn((T, B, N))
+ data = vtrace_data(target_output, behaviour_output, action, value, reward, None)
+ loss = vtrace_error_continuous_action(data, rho_clip_ratio=1.1)
+ assert all([l.shape == tuple() for l in loss])
+ assert target_output['mu'].grad is None
+ assert target_output['sigma'].grad is None
+ assert value.grad is None
+ loss = sum(loss)
+ loss.backward()
+ assert isinstance(target_output['mu'], torch.Tensor)
+ assert isinstance(target_output['sigma'], torch.Tensor)
+ assert isinstance(value, torch.Tensor)
diff --git a/DI-engine/ding/rl_utils/upgo.py b/DI-engine/ding/rl_utils/upgo.py
new file mode 100644
index 0000000000000000000000000000000000000000..1117f77e0238892977a87b1e3a2ed4f901a3b48f
--- /dev/null
+++ b/DI-engine/ding/rl_utils/upgo.py
@@ -0,0 +1,111 @@
+import torch
+import torch.nn.functional as F
+from ding.hpc_rl import hpc_wrapper
+from .td import generalized_lambda_returns
+
+
+def tb_cross_entropy(logit, label, mask=None):
+ """
+ Overview:
+ Compute the cross entropy loss for label and logit, with mask support
+ Arguments:
+ - logit (:obj:`torch.Tensor`): the logit tensor, of size [T, B, N] or [T, B, N, N2]
+ - label (:obj:`torch.Tensor`): the label tensor, of size [T, B] or [T, B, N]
+ - mask (:obj:`torch.Tensor` or :obj:`None`): the mask tensor, of size [T, B] or [T, B, N]
+ Returns:
+ - ce (:obj:`torch.Tensor`): the computed cross entropy, of size [T, B]
+ Examples:
+ >>> T, B, N, N2 = 4, 8, 5, 7
+ >>> logit = torch.randn(T, B, N, N2).softmax(-1).requires_grad_(True)
+ >>> action = logit.argmax(-1).detach()
+ >>> ce = tb_cross_entropy(logit, action)
+ """
+ assert (len(label.shape) >= 2)
+ T, B = label.shape[:2]
+ # special case: label with an extra action dimension, i.e. of size [T, B, N]
+ if len(label.shape) > 2:
+ assert len(label.shape) == 3
+ s, n = logit.shape[-2:]
+ logit = logit.reshape(-1, n)
+ label = label.reshape(-1)
+ ce = -F.cross_entropy(logit, label, reduction='none')
+ ce = ce.view(T * B, -1)
+ if mask is not None:
+ ce *= mask.reshape(-1, s)
+ ce = ce.sum(dim=1)
+ ce = ce.reshape(T, B)
+ else:
+ label = label.reshape(-1)
+ logit = logit.reshape(-1, logit.shape[-1])
+ ce = -F.cross_entropy(logit, label, reduction='none')
+ ce = ce.reshape(T, B, -1)
+ ce = ce.mean(dim=2)
+ return ce
+
+
+def upgo_returns(rewards: torch.Tensor, bootstrap_values: torch.Tensor) -> torch.Tensor:
+ r"""
+ Overview:
+ Computing UPGO return targets. Also notice there is no special handling for the terminal state.
+ Arguments:
+ - rewards (:obj:`torch.Tensor`): the returns from time step 0 to T-1, \
+ of size [T_traj, batchsize]
+ - bootstrap_values (:obj:`torch.Tensor`): estimation of the state value at step 0 to T, \
+ of size [T_traj+1, batchsize]
+ Returns:
+ - ret (:obj:`torch.Tensor`): Computed lambda return value for each state from 0 to T-1, \
+ of size [T_traj, batchsize]
+ Examples:
+ >>> T, B, N, N2 = 4, 8, 5, 7
+ >>> rewards = torch.randn(T, B)
+ >>> bootstrap_values = torch.randn(T + 1, B).requires_grad_(True)
+ >>> returns = upgo_returns(rewards, bootstrap_values)
+ """
+ # UPGO can be viewed as a lambda return! The trace continues for V_t (i.e. lambda = 1.0) if r_tp1 + V_tp2 >= V_tp1.
+ # Since lambdas[-1, :] is ignored in generalized_lambda_returns, we don't care about bootstrap_values_tp2[-1].
+ lambdas = (rewards + bootstrap_values[1:]) >= bootstrap_values[:-1]
+ lambdas = torch.cat([lambdas[1:], torch.ones_like(lambdas[-1:])], dim=0)
+ return generalized_lambda_returns(bootstrap_values, rewards, 1.0, lambdas)
+
+
+@hpc_wrapper(
+ shape_fn=lambda args: args[0].shape,
+ namedtuple_data=True,
+ include_args=5,
+ include_kwargs=['target_output', 'rhos', 'action', 'rewards', 'bootstrap_values']
+)
+def upgo_loss(
+ target_output: torch.Tensor,
+ rhos: torch.Tensor,
+ action: torch.Tensor,
+ rewards: torch.Tensor,
+ bootstrap_values: torch.Tensor,
+ mask=None
+) -> torch.Tensor:
+ r"""
+ Overview:
+ Computing UPGO loss given constant gamma and lambda. There is no special handling for the terminal state
+ value; if the last state in the trajectory is terminal, just pass 0 as the last element of ``bootstrap_values``.
+ Arguments:
+ - target_output (:obj:`torch.Tensor`): the output computed by the target policy network, \
+ of size [T_traj, batchsize, n_output]
+ - rhos (:obj:`torch.Tensor`): the importance sampling ratio, of size [T_traj, batchsize]
+ - action (:obj:`torch.Tensor`): the action taken, of size [T_traj, batchsize]
+ - rewards (:obj:`torch.Tensor`): the returns from time step 0 to T-1, of size [T_traj, batchsize]
+ - bootstrap_values (:obj:`torch.Tensor`): estimation of the state value at step 0 to T, \
+ of size [T_traj+1, batchsize]
+ Returns:
+ - loss (:obj:`torch.Tensor`): Computed importance sampled UPGO loss, averaged over the samples, of size []
+ Examples:
+ >>> T, B, N, N2 = 4, 8, 5, 7
+ >>> rhos = torch.randn(T, B)
+ >>> loss = upgo_loss(logit, rhos, action, rewards, bootstrap_values)
+ """
+ # discard the value at T as it should be considered in the next slice
+ with torch.no_grad():
+ returns = upgo_returns(rewards, bootstrap_values)
+ advantages = rhos * (returns - bootstrap_values[:-1])
+ metric = tb_cross_entropy(target_output, action, mask)
+ assert (metric.shape == action.shape[:2])
+ losses = advantages * metric
+ return -losses.mean()
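+
+
+if __name__ == "__main__":
+    # Minimal self-check sketch, assuming DI-engine is installed and this file is run as a
+    # module (e.g. ``python -m ding.rl_utils.upgo``); it only illustrates the recursion that
+    # ``upgo_returns`` implements and is not part of the public API:
+    #     G_{T-1} = r_{T-1} + V_T
+    #     G_t     = r_t + G_{t+1}    if r_{t+1} + V_{t+2} >= V_{t+1}
+    #     G_t     = r_t + V_{t+1}    otherwise
+    torch.manual_seed(0)
+    T, B = 6, 3
+    rewards = torch.randn(T, B)
+    values = torch.randn(T + 1, B)
+    # explicit backward recursion as a reference
+    ref = torch.empty(T, B)
+    ref[T - 1] = rewards[T - 1] + values[T]
+    for t in reversed(range(T - 1)):
+        cont = (rewards[t + 1] + values[t + 2] >= values[t + 1]).float()
+        ref[t] = rewards[t] + cont * ref[t + 1] + (1. - cont) * values[t + 1]
+    assert torch.allclose(upgo_returns(rewards, values), ref, atol=1e-5)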
diff --git a/DI-engine/ding/rl_utils/value_rescale.py b/DI-engine/ding/rl_utils/value_rescale.py
new file mode 100644
index 0000000000000000000000000000000000000000..bea95c7df2b0be5768036dbb758e889d067edd68
--- /dev/null
+++ b/DI-engine/ding/rl_utils/value_rescale.py
@@ -0,0 +1,65 @@
+import torch
+
+
+def value_transform(x: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
+ r"""
+ Overview:
+ A function to reduce the scale of the action-value function.
+ :math: `h(x) = sign(x)(\sqrt{|x|+1} - 1) + \epsilon x` .
+ Arguments:
+ - x: (:obj:`torch.Tensor`) The input tensor to be normalized.
+ - eps: (:obj:`float`) The coefficient of the additive regularization term \
+ to ensure h^{-1} is Lipschitz continuous
+ Returns:
+ - (:obj:`torch.Tensor`) Normalized tensor.
+
+ .. note::
+ Observe and Look Further: Achieving Consistent Performance on Atari
+ (https://arxiv.org/abs/1805.11593)
+ """
+ return torch.sign(x) * (torch.sqrt(torch.abs(x) + 1) - 1) + eps * x
+
+
+def value_inv_transform(x: torch.Tensor, eps: float = 1e-2) -> torch.Tensor:
+ r"""
+ Overview:
+ The inverse form of value rescale.
+ :math: `h^{-1}(x) = sign(x)({(\frac{\sqrt{1+4\epsilon(|x|+1+\epsilon)}-1}{2\epsilon})}^2-1)` .
+ Arguments:
+ - x: (:obj:`torch.Tensor`) The input tensor to be unnormalized.
+ - eps: (:obj:`float`) The coefficient of the additive regularization term \
+ to ensure h^{-1} is Lipschitz continuous
+ Returns:
+ - (:obj:`torch.Tensor`) Unnormalized tensor.
+ """
+ return torch.sign(x) * (((torch.sqrt(1 + 4 * eps * (torch.abs(x) + 1 + eps)) - 1) / (2 * eps)) ** 2 - 1)
+
+
+def symlog(x: torch.Tensor) -> torch.Tensor:
+ r"""
+ Overview:
+ A function to normalize the targets.
+ :math: `symlog(x) = sign(x)\ln(|x|+1)` .
+ Arguments:
+ - x: (:obj:`torch.Tensor`) The input tensor to be normalized.
+ Returns:
+ - (:obj:`torch.Tensor`) Normalized tensor.
+
+ .. note::
+ Mastering Diverse Domains through World Models
+ (https://arxiv.org/abs/2301.04104)
+ """
+ return torch.sign(x) * (torch.log(torch.abs(x) + 1))
+
+
+def inv_symlog(x: torch.Tensor) -> torch.Tensor:
+ r"""
+ Overview:
+ The inverse form of symlog.
+ :math: `symexp(x) = sign(x)(\exp{|x|}-1)` .
+ Arguments:
+ - x: (:obj:`torch.Tensor`) The input tensor to be unnormalized.
+ Returns:
+ - (:obj:`torch.Tensor`) Unnormalized tensor.
+ """
+ return torch.sign(x) * (torch.exp(torch.abs(x)) - 1)
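+
+
+if __name__ == "__main__":
+    # Minimal self-check sketch (only torch is needed; not part of the public API):
+    # each transform pair above should invert the other up to floating-point error.
+    x = torch.linspace(-20., 20., steps=101)
+    assert torch.allclose(value_inv_transform(value_transform(x)), x, atol=1e-3)
+    assert torch.allclose(inv_symlog(symlog(x)), x, atol=1e-3)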
diff --git a/DI-engine/ding/rl_utils/vtrace.py b/DI-engine/ding/rl_utils/vtrace.py
new file mode 100644
index 0000000000000000000000000000000000000000..44728fd6cd92362c7a4c2e564f6538bb6844d2dc
--- /dev/null
+++ b/DI-engine/ding/rl_utils/vtrace.py
@@ -0,0 +1,212 @@
+import torch
+import torch.nn.functional as F
+from torch.distributions import Categorical, Independent, Normal
+from collections import namedtuple
+from .isw import compute_importance_weights
+from ding.hpc_rl import hpc_wrapper
+
+
+def vtrace_nstep_return(clipped_rhos, clipped_cs, reward, bootstrap_values, gamma=0.99, lambda_=0.95):
+ """
+ Overview:
+ Computation of vtrace return.
+ Returns:
+ - vtrace_return (:obj:`torch.FloatTensor`): the computed v-trace n-step return
+ Shapes:
+ - clipped_rhos (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep, B is batch size
+ - clipped_cs (:obj:`torch.FloatTensor`): :math:`(T, B)`
+ - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
+ - bootstrap_values (:obj:`torch.FloatTensor`): :math:`(T+1, B)`
+ - vtrace_return (:obj:`torch.FloatTensor`): :math:`(T, B)`
+ """
+ deltas = clipped_rhos * (reward + gamma * bootstrap_values[1:] - bootstrap_values[:-1])
+ factor = gamma * lambda_
+ result = bootstrap_values[:-1].clone()
+ vtrace_item = 0.
+ for t in reversed(range(reward.size()[0])):
+ vtrace_item = deltas[t] + factor * clipped_cs[t] * vtrace_item
+ result[t] += vtrace_item
+ return result
+
+
+def vtrace_advantage(clipped_pg_rhos, reward, return_, bootstrap_values, gamma):
+ """
+ Overview:
+ Computation of vtrace advantage.
+ Returns:
+ - vtrace_advantage (:obj:`torch.FloatTensor`): the computed v-trace advantage
+ Shapes:
+ - clipped_pg_rhos (:obj:`torch.FloatTensor`): :math:`(T, B)`, where T is timestep, B is batch size
+ - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
+ - return (:obj:`torch.FloatTensor`): :math:`(T, B)`
+ - bootstrap_values (:obj:`torch.FloatTensor`): :math:`(T, B)`
+ - vtrace_advantage (:obj:`torch.FloatTensor`): :math:`(T, B)`
+ """
+ return clipped_pg_rhos * (reward + gamma * return_ - bootstrap_values)
+
+
+vtrace_data = namedtuple('vtrace_data', ['target_output', 'behaviour_output', 'action', 'value', 'reward', 'weight'])
+vtrace_loss = namedtuple('vtrace_loss', ['policy_loss', 'value_loss', 'entropy_loss'])
+
+
+def shape_fn_vtrace_discrete_action(args, kwargs):
+ r"""
+ Overview:
+ Return shape of vtrace for hpc
+ Returns:
+ shape: [T, B, N]
+ """
+ if len(args) <= 0:
+ tmp = kwargs['data'].target_output.shape
+ else:
+ tmp = args[0].target_output.shape
+ return tmp
+
+
+@hpc_wrapper(
+ shape_fn=shape_fn_vtrace_discrete_action,
+ namedtuple_data=True,
+ include_args=[0, 1, 2, 3, 4, 5],
+ include_kwargs=['data', 'gamma', 'lambda_', 'rho_clip_ratio', 'c_clip_ratio', 'rho_pg_clip_ratio']
+)
+def vtrace_error_discrete_action(
+ data: namedtuple,
+ gamma: float = 0.99,
+ lambda_: float = 0.95,
+ rho_clip_ratio: float = 1.0,
+ c_clip_ratio: float = 1.0,
+ rho_pg_clip_ratio: float = 1.0
+):
+ """
+ Overview:
+ Implementation of vtrace(IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner\
+ Architectures), (arXiv:1802.01561)
+ Arguments:
+ - data (:obj:`namedtuple`): input data with fields shown in ``vtrace_data``
+ - target_output (:obj:`torch.Tensor`): the output taking the action by the current policy network,\
+ usually this output is network output logit
+ - behaviour_output (:obj:`torch.Tensor`): the output taking the action by the behaviour policy network,\
+ usually this output is network output logit, which is used to produce the trajectory(collector)
+ - action (:obj:`torch.Tensor`): the chosen action(index for the discrete action space) in trajectory,\
+ i.e.: behaviour_action
+ - gamma (:obj:`float`): the future discount factor, defaults to 0.99
+ - lambda_ (:obj:`float`): mix factor between 1-step (lambda_=0) and n-step, defaults to 0.95
+ - rho_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating\
+ the baseline targets (vs)
+ - c_clip_ratio (:obj:`float`): the clipping threshold for importance weights (c) when calculating\
+ the baseline targets (vs)
+ - rho_pg_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating\
+ the policy gradient advantage
+ Returns:
+ - trace_loss (:obj:`namedtuple`): the vtrace loss item, all of them are the differentiable 0-dim tensor
+ Shapes:
+ - target_output (:obj:`torch.FloatTensor`): :math:`(T, B, N)`, where T is timestep, B is batch size and\
+ N is action dim
+ - behaviour_output (:obj:`torch.FloatTensor`): :math:`(T, B, N)`
+ - action (:obj:`torch.LongTensor`): :math:`(T, B)`
+ - value (:obj:`torch.FloatTensor`): :math:`(T+1, B)`
+ - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
+ - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(T, B)`
+ Examples:
+ >>> T, B, N = 4, 8, 16
+ >>> value = torch.randn(T + 1, B).requires_grad_(True)
+ >>> reward = torch.rand(T, B)
+ >>> target_output = torch.randn(T, B, N).requires_grad_(True)
+ >>> behaviour_output = torch.randn(T, B, N)
+ >>> action = torch.randint(0, N, size=(T, B))
+ >>> data = vtrace_data(target_output, behaviour_output, action, value, reward, None)
+ >>> loss = vtrace_error_discrete_action(data, rho_clip_ratio=1.1)
+ """
+ target_output, behaviour_output, action, value, reward, weight = data
+ with torch.no_grad():
+ IS = compute_importance_weights(target_output, behaviour_output, action, 'discrete')
+ rhos = torch.clamp(IS, max=rho_clip_ratio)
+ cs = torch.clamp(IS, max=c_clip_ratio)
+ return_ = vtrace_nstep_return(rhos, cs, reward, value, gamma, lambda_)
+ pg_rhos = torch.clamp(IS, max=rho_pg_clip_ratio)
+ return_t_plus_1 = torch.cat([return_[1:], value[-1:]], 0)
+ adv = vtrace_advantage(pg_rhos, reward, return_t_plus_1, value[:-1], gamma)
+
+ if weight is None:
+ weight = torch.ones_like(reward)
+ dist_target = Categorical(logits=target_output)
+ pg_loss = -(dist_target.log_prob(action) * adv * weight).mean()
+ value_loss = (F.mse_loss(value[:-1], return_, reduction='none') * weight).mean()
+ entropy_loss = (dist_target.entropy() * weight).mean()
+ return vtrace_loss(pg_loss, value_loss, entropy_loss)
+
+
+def vtrace_error_continuous_action(
+ data: namedtuple,
+ gamma: float = 0.99,
+ lambda_: float = 0.95,
+ rho_clip_ratio: float = 1.0,
+ c_clip_ratio: float = 1.0,
+ rho_pg_clip_ratio: float = 1.0
+):
+ """
+ Overview:
+ Implementation of vtrace(IMPALA: Scalable Distributed Deep-RL with Importance Weighted Actor-Learner\
+ Architectures), (arXiv:1802.01561)
+ Arguments:
+ - data (:obj:`namedtuple`): input data with fields shown in ``vtrace_data``
+ - target_output (:obj:`dict{key:torch.Tensor}`): the output taking the action \
+ by the current policy network, usually this output is network output, \
+ which represents the distribution by reparameterization trick.
+ - behaviour_output (:obj:`dict{key:torch.Tensor}`): the output taking the action \
+ by the behaviour policy network, usually this output is network output logit, \
+ which represents the distribution by reparameterization trick.
+ - action (:obj:`torch.Tensor`): the chosen action(index for the discrete action space) in trajectory, \
+ i.e.: behaviour_action
+ - gamma (:obj:`float`): the future discount factor, defaults to 0.99
+ - lambda_ (:obj:`float`): mix factor between 1-step (lambda_=0) and n-step, defaults to 0.95
+ - rho_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating\
+ the baseline targets (vs)
+ - c_clip_ratio (:obj:`float`): the clipping threshold for importance weights (c) when calculating\
+ the baseline targets (vs)
+ - rho_pg_clip_ratio (:obj:`float`): the clipping threshold for importance weights (rho) when calculating\
+ the policy gradient advantage
+ Returns:
+ - trace_loss (:obj:`namedtuple`): the vtrace loss item, all of them are the differentiable 0-dim tensor
+ Shapes:
+ - target_output (:obj:`dict{key:torch.FloatTensor}`): :math:`(T, B, N)`, \
+ where T is timestep, B is batch size and \
+ N is action dim. The keys are usually parameters of reparameterization trick.
+ - behaviour_output (:obj:`dict{key:torch.FloatTensor}`): :math:`(T, B, N)`
+ - action (:obj:`torch.LongTensor`): :math:`(T, B)`
+ - value (:obj:`torch.FloatTensor`): :math:`(T+1, B)`
+ - reward (:obj:`torch.FloatTensor`): :math:`(T, B)`
+ - weight (:obj:`torch.FloatTensor` or :obj:`None`): :math:`(T, B)`
+ Examples:
+ >>> T, B, N = 4, 8, 16
+ >>> value = torch.randn(T + 1, B).requires_grad_(True)
+ >>> reward = torch.rand(T, B)
+ >>> target_output = dict(
+ >>> mu=torch.randn(T, B, N).requires_grad_(True),
+ >>> sigma=torch.exp(torch.randn(T, B, N).requires_grad_(True)),
+ >>> )
+ >>> behaviour_output = dict(
+ >>> mu=torch.randn(T, B, N),
+ >>> sigma=torch.exp(torch.randn(T, B, N)),
+ >>> )
+ >>> action = torch.randn((T, B, N))
+ >>> data = vtrace_data(target_output, behaviour_output, action, value, reward, None)
+ >>> loss = vtrace_error_continuous_action(data, rho_clip_ratio=1.1)
+ """
+ target_output, behaviour_output, action, value, reward, weight = data
+ with torch.no_grad():
+ IS = compute_importance_weights(target_output, behaviour_output, action, 'continuous')
+ rhos = torch.clamp(IS, max=rho_clip_ratio)
+ cs = torch.clamp(IS, max=c_clip_ratio)
+ return_ = vtrace_nstep_return(rhos, cs, reward, value, gamma, lambda_)
+ pg_rhos = torch.clamp(IS, max=rho_pg_clip_ratio)
+ return_t_plus_1 = torch.cat([return_[1:], value[-1:]], 0)
+ adv = vtrace_advantage(pg_rhos, reward, return_t_plus_1, value[:-1], gamma)
+
+ if weight is None:
+ weight = torch.ones_like(reward)
+ dist_target = Independent(Normal(loc=target_output['mu'], scale=target_output['sigma']), 1)
+ pg_loss = -(dist_target.log_prob(action) * adv * weight).mean()
+ value_loss = (F.mse_loss(value[:-1], return_, reduction='none') * weight).mean()
+ entropy_loss = (dist_target.entropy() * weight).mean()
+ return vtrace_loss(pg_loss, value_loss, entropy_loss)
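+
+
+if __name__ == "__main__":
+    # Minimal self-check sketch, assuming DI-engine is installed and this file is run as a
+    # module (e.g. ``python -m ding.rl_utils.vtrace``); not part of the public API. It checks
+    # the recursive ``vtrace_nstep_return`` above against the explicit sum from the IMPALA
+    # paper, v_s = V(x_s) + sum_{t>=s} (gamma * lambda)^{t-s} (prod_{i=s}^{t-1} c_i) * delta_t,
+    # where delta_t = rho_t * (r_t + gamma * V(x_{t+1}) - V(x_t)).
+    torch.manual_seed(0)
+    T, B = 5, 2
+    gamma, lambda_ = 0.99, 0.95
+    clipped_rhos = torch.rand(T, B)
+    clipped_cs = torch.rand(T, B)
+    reward = torch.randn(T, B)
+    value = torch.randn(T + 1, B)
+    deltas = clipped_rhos * (reward + gamma * value[1:] - value[:-1])
+    # explicit double sum as a reference
+    ref = value[:-1].clone()
+    for s in range(T):
+        coeff = torch.ones(B)
+        for t in range(s, T):
+            ref[s] += coeff * deltas[t]
+            coeff = coeff * gamma * lambda_ * clipped_cs[t]
+    out = vtrace_nstep_return(clipped_rhos, clipped_cs, reward, value, gamma, lambda_)
+    assert torch.allclose(out, ref, atol=1e-5)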
diff --git a/DI-engine/ding/scripts/dijob-qbert.yaml b/DI-engine/ding/scripts/dijob-qbert.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..06905584768af488e98b1c615c1ef43f372c5ef0
--- /dev/null
+++ b/DI-engine/ding/scripts/dijob-qbert.yaml
@@ -0,0 +1,228 @@
+apiVersion: diengine.opendilab.org/v1alpha1
+kind: DIJob
+metadata:
+ name: qbert-dqn
+spec:
+ group: xxx
+ priorityClassName: ""
+ cleanPodPolicy: "Running"
+ volumes:
+ - name: cache-volume
+ emptyDir:
+ medium: Memory
+ sizeLimit: 128Mi
+ - name: work-dir
+ hostPath:
+ path: /data/nfs/ding/qbert
+ coordinator:
+ template:
+ spec:
+ containers:
+ - name: coordinator
+ image: diorchestrator/ding:v0.1.0-df39b81c
+ imagePullPolicy: Always
+ env:
+ - name: PYTHONUNBUFFERED
+ value: "1"
+ resources:
+ requests:
+ cpu: 3
+ memory: "10Gi"
+ limits:
+ cpu: 3
+ memory: "10Gi"
+ command: ["/bin/bash", "-c",]
+ args:
+ - |
+ cat <<EOF > qbert_dqn_config_k8s.py
+ from easydict import EasyDict
+
+ qbert_dqn_config = dict(
+ env=dict(
+ collector_env_num=16,
+ collector_episode_num=2,
+ evaluator_env_num=8,
+ evaluator_episode_num=1,
+ stop_value=30000,
+ env_id='QbertNoFrameskip-v4',
+ frame_stack=4,
+ manager=dict(
+ shared_memory=False,
+ ),
+ ),
+ policy=dict(
+ cuda=False,
+ priority=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ batch_size=32,
+ learning_rate=0.0001,
+ learner=dict(
+ learner_num=1,
+ send_policy_freq=1,
+ ),
+ ),
+ collect=dict(
+ n_sample=16,
+ collector=dict(
+ collector_num=2,
+ update_policy_second=3,
+ ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=250000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=400000,
+ enable_track_used_data=True,
+ ),
+ commander=dict(
+ collector_task_space=0,
+ learner_task_space=1,
+ eval_interval=30,
+ ),
+ ),
+ ),
+ )
+ qbert_dqn_config = EasyDict(qbert_dqn_config)
+ main_config = qbert_dqn_config
+
+ qbert_dqn_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqn_command'),
+ learner=dict(type='base', import_names=['ding.worker.learner.base_learner']),
+ collector=dict(
+ type='zergling',
+ import_names=['ding.worker.collector.zergling_parallel_collector'],
+ ),
+ commander=dict(
+ type='solo',
+ import_names=['ding.worker.coordinator.solo_parallel_commander'],
+ ),
+ comm_learner=dict(
+ type='flask_fs',
+ import_names=['ding.worker.learner.comm.flask_fs_learner'],
+ ),
+ comm_collector=dict(
+ type='flask_fs',
+ import_names=['ding.worker.collector.comm.flask_fs_collector'],
+ ),
+ )
+ qbert_dqn_create_config = EasyDict(qbert_dqn_create_config)
+ create_config = qbert_dqn_create_config
+
+ qbert_dqn_system_config = dict(
+ coordinator=dict(
+ operator_server=dict(
+ system_addr='ding-server.ding-system:8080',
+ api_version='/v1alpha1',
+ init_replicas_request=dict(
+ collectors={
+ "replicas": 2,
+ },
+ learners={
+ "gpus": "0",
+ "replicas": 1,
+ },
+ ),
+ collector_target_num=2,
+ learner_target_num=1,
+ ),
+ ),
+ path_data='./data',
+ path_policy='./policy',
+ communication_mode='auto',
+ learner_gpu_num=1,
+ )
+ qbert_dqn_system_config = EasyDict(qbert_dqn_system_config)
+ system_config = qbert_dqn_system_config
+ EOF
+
+ # if code has been changed in the mount path, we have to reinstall ding cli
+ # pip install --no-cache-dir -e .;
+
+ ding -m dist --module config -P k8s -c qbert_dqn_config_k8s.py -s 0;
+ ding -m dist --module coordinator -c qbert_dqn_config_k8s.py.pkl -s 0 --disable-flask-log 0 -cdp $COORDINATOR_PORT
+ ports:
+ - name: coordinator
+ containerPort: 22273
+ volumeMounts:
+ - name: work-dir
+ mountPath: /ding
+ collector:
+ template:
+ spec:
+ containers:
+ - name: collector
+ image: diorchestrator/ding:v0.1.0-df39b81c
+ imagePullPolicy: Always
+ env:
+ - name: PYTHONUNBUFFERED
+ value: "1"
+ resources:
+ requests:
+ cpu: 6
+ memory: "10Gi"
+ limits:
+ cpu: 6
+ memory: "10Gi"
+ command: ["/bin/bash", "-c",]
+ args:
+ - |
+ # if code has been changed in the mount path, we have to reinstall ding cli
+ # pip install --no-cache-dir -e .;
+
+ ding -m dist --module collector -c qbert_dqn_config_k8s.py.pkl -s 0 -clp $COLLECTOR_PORT --disable-flask-log 0
+ ports:
+ - name: collector
+ containerPort: 22270
+ volumeMounts:
+ - name: work-dir
+ mountPath: /ding
+ learner:
+ template:
+ spec:
+ containers:
+ - name: learner
+ image: diorchestrator/ding:v0.1.0-df39b81c
+ imagePullPolicy: Always
+ env:
+ - name: PYTHONUNBUFFERED
+ value: "1"
+ resources:
+ requests:
+ cpu: 3
+ memory: "30Gi"
+ limits:
+ cpu: 3
+ memory: "30Gi"
+ command: ["/bin/bash", "-c",]
+ args:
+ - |
+ # if code has been changed in the mount path, we have to reinstall ding cli
+ # pip install --no-cache-dir -e .;
+
+ ding -m dist --module spawn_learner -c qbert_dqn_config_k8s.py.pkl -s 0 -lp $LEARNER_PORT --disable-flask-log 0
+ ports:
+ - name: learner
+ containerPort: 22271
+ volumeMounts:
+ - name: cache-volume
+ mountPath: /dev/shm
+ - name: work-dir
+ mountPath: /ding
diff --git a/DI-engine/ding/scripts/docker-test-entry.sh b/DI-engine/ding/scripts/docker-test-entry.sh
new file mode 100755
index 0000000000000000000000000000000000000000..3b3e5fca927f0ba4352e471d45628bb10cde5364
--- /dev/null
+++ b/DI-engine/ding/scripts/docker-test-entry.sh
@@ -0,0 +1,9 @@
+#!/usr/bin/env bash
+
+CONTAINER_ID=$(docker run --rm -d opendilab/ding:nightly tail -f /dev/null)
+
+trap "docker rm -f $CONTAINER_ID" EXIT
+
+docker exec $CONTAINER_ID rm -rf /ding &&
+ docker cp $(pwd) ${CONTAINER_ID}:/ding &&
+ docker exec -it $CONTAINER_ID /ding/ding/scripts/docker-test.sh
diff --git a/DI-engine/ding/scripts/docker-test.sh b/DI-engine/ding/scripts/docker-test.sh
new file mode 100755
index 0000000000000000000000000000000000000000..68ad7f63ccbefb0e27008118e2492d3da3ca87f8
--- /dev/null
+++ b/DI-engine/ding/scripts/docker-test.sh
@@ -0,0 +1,11 @@
+#!/usr/bin/env bash
+
+if [ ! -f /.dockerenv ]; then
+ echo "This script should be executed in docker container"
+ exit 1
+fi
+
+pip install --ignore-installed 'PyYAML<6.0'
+pip install -e .[test,k8s] &&
+ ./ding/scripts/install-k8s-tools.sh &&
+ make test
diff --git a/DI-engine/ding/scripts/install-k8s-tools.sh b/DI-engine/ding/scripts/install-k8s-tools.sh
new file mode 100755
index 0000000000000000000000000000000000000000..56800509469a517a11063ab3bec4efdf0e8c8cf0
--- /dev/null
+++ b/DI-engine/ding/scripts/install-k8s-tools.sh
@@ -0,0 +1,28 @@
+#!/usr/bin/env bash
+
+set -e
+
+ROOT_DIR="$(dirname "$0")"
+: ${USE_SUDO:="true"}
+
+# runs the given command as root (detects if we are root already)
+runAsRoot() {
+ local CMD="$*"
+
+ if [ $EUID -ne 0 -a $USE_SUDO = "true" ]; then
+ CMD="sudo $CMD"
+ fi
+
+ $CMD
+}
+
+# install k3d
+curl -s https://raw.githubusercontent.com/rancher/k3d/main/install.sh | TAG=v4.4.8 bash
+
+# install kubectl
+if [[ $(which kubectl) == "" ]]; then
+ echo "Installing kubectl..."
+ curl -LO https://dl.k8s.io/release/v1.21.3/bin/linux/amd64/kubectl
+ chmod +x kubectl
+ runAsRoot mv kubectl /usr/local/bin/kubectl
+fi
diff --git a/DI-engine/ding/scripts/kill.sh b/DI-engine/ding/scripts/kill.sh
new file mode 100755
index 0000000000000000000000000000000000000000..f89f5ed615a798bc4d4a3298686e91871c4ff202
--- /dev/null
+++ b/DI-engine/ding/scripts/kill.sh
@@ -0,0 +1 @@
+ps -ef | grep 'ding' | grep -v grep | awk '{print $2}'|xargs kill -9
diff --git a/DI-engine/ding/scripts/local_parallel.sh b/DI-engine/ding/scripts/local_parallel.sh
new file mode 100644
index 0000000000000000000000000000000000000000..93f45a860df253f81d13ecc9b8b3ae0229917baa
--- /dev/null
+++ b/DI-engine/ding/scripts/local_parallel.sh
@@ -0,0 +1 @@
+ding -m parallel -c $1 -s $2
diff --git a/DI-engine/ding/scripts/local_serial.sh b/DI-engine/ding/scripts/local_serial.sh
new file mode 100644
index 0000000000000000000000000000000000000000..6e434a078c758a06e5cb7d3fe2f099c23475d68f
--- /dev/null
+++ b/DI-engine/ding/scripts/local_serial.sh
@@ -0,0 +1 @@
+ding -m serial -c $1 -s $2
diff --git a/DI-engine/ding/scripts/main_league.sh b/DI-engine/ding/scripts/main_league.sh
new file mode 100755
index 0000000000000000000000000000000000000000..301e7a171a41090021b3781ed2995f30884aec51
--- /dev/null
+++ b/DI-engine/ding/scripts/main_league.sh
@@ -0,0 +1,52 @@
+#!/usr/bin/env bash
+
+BASEDIR=$(dirname "$0")/../entry
+
+kill_descendant_processes() {
+ local pid="$1"
+ local and_self="${2:-false}"
+ if children="$(pgrep -P "$pid")"; then
+ for child in $children; do
+ kill_descendant_processes "$child" true
+ done
+ fi
+ if [[ "$and_self" == true ]]; then
+ kill "$pid"
+ fi
+}
+
+trap "kill_descendant_processes $$" EXIT
+
+ditask --package $BASEDIR \
+ --main main_league.main \
+ --parallel-workers 1 \
+ --protocol tcp \
+ --address 127.0.0.1 \
+ --ports 50515 \
+ --node-ids 0 \
+ --topology alone \
+ --labels league,collect &
+
+# ditask --package $BASEDIR \
+# --main main_league.main \
+# --parallel-workers 3 \
+# --protocol tcp \
+# --address 127.0.0.1 \
+# --ports 50525 \
+# --node-ids 10 \
+# --topology alone \
+# --labels learn \
+# --attach-to tcp://127.0.0.1:50515 &
+
+# ditask --package $BASEDIR \
+# --main main_league.main \
+# --parallel-workers 1 \
+# --address 127.0.0.1 \
+# --protocol tcp \
+# --ports 50535 \
+# --node-ids 20 \
+# --topology alone \
+# --labels evaluate \
+# --attach-to tcp://127.0.0.1:50515,tcp://127.0.0.1:50525,tcp://127.0.0.1:50526,tcp://127.0.0.1:50527 &
+
+sleep 10000
diff --git a/DI-engine/ding/scripts/main_league_slurm.sh b/DI-engine/ding/scripts/main_league_slurm.sh
new file mode 100755
index 0000000000000000000000000000000000000000..a8a5e70359c8151a9b1f48f67795f58a67e0514c
--- /dev/null
+++ b/DI-engine/ding/scripts/main_league_slurm.sh
@@ -0,0 +1,7 @@
+#!/usr/bin/env bash
+
+export LC_ALL=en_US.utf-8
+export LANG=en_US.utf-8
+BASEDIR=$(dirname "$0")
+# srun -p partition_name --quotatype=reserved --mpi=pmi2 -n6 --ntasks-per-node=3 bash ding/scripts/main_league_slurm.sh
+ditask --package $BASEDIR/../entry --main main_league.main --platform slurm --platform-spec '{"tasks":[{"labels":"league,collect","node_ids":10},{"labels":"league,collect","node_ids":11},{"labels":"evaluate","node_ids":20,"attach_to":"$node.10,$node.11"},{"labels":"learn","node_ids":31,"attach_to":"$node.10,$node.11,$node.20"},{"labels":"learn","node_ids":32,"attach_to":"$node.10,$node.11,$node.20"},{"labels":"learn","node_ids":33,"attach_to":"$node.10,$node.11,$node.20"}]}'
diff --git a/DI-engine/ding/scripts/tests/test_parallel_socket.py b/DI-engine/ding/scripts/tests/test_parallel_socket.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f322e05d7a81c6c1629dd0c3cb53937b829f7cb
--- /dev/null
+++ b/DI-engine/ding/scripts/tests/test_parallel_socket.py
@@ -0,0 +1,142 @@
+import sys
+import os
+import time
+from ditk import logging
+import argparse
+import tempfile
+
+from random import random
+from string import ascii_lowercase
+from ding.framework import Parallel
+
+alphabet = [c.encode('ascii') for c in ascii_lowercase]
+
+
+class EasyCounter:
+
+ def __init__(self):
+ self._last = None
+ self._cnt = 0
+
+ def add(self, item):
+ self._last = item
+ self._cnt += 1
+
+ def cnt(self):
+ return self._cnt
+
+ def last(self):
+ return self._last
+
+
+class SockTest:
+
+ # In this class, we define three processes besides the main process:
+ # the receiver, the testee, and the sender.
+ # The testee receives messages from the sender, and periodically sends its own
+ # greeting messages to the receiver.
+ # During the test, we break down the network of the testee, and then observe
+ # what happens to it.
+
+ @classmethod
+ def receiver(cls, epoch, interval):
+ router = Parallel()
+ greets = EasyCounter()
+ router.on("greeting_receiver", lambda msg: greets.add(msg))
+ start_t = time.time()
+ logging.info("receiver start ...")
+
+ for i in range(epoch):
+ while time.time() - start_t < i * interval:
+ time.sleep(0.01)
+
+ if greets.cnt() == 0 or i % 10 != 0:
+ continue
+ last_msg = greets.last()
+ msg_idx, msg_t = last_msg.split("_")[-2:]
+ logging.info(
+ "receiver passed {:.2f} s, received {} msgs. last msg: idx {}, time {} s".format(
+ time.time() - start_t, greets.cnt(), msg_idx, msg_t
+ )
+ )
+
+ logging.info("receiver done! total msg: {}".format(greets.cnt()))
+
+ @classmethod
+ def testee(cls, epoch, interval, data_size):
+ words = b''.join([alphabet[int(random() * 26)] for _ in range(1024 * 1024)]) * data_size
+ print("msg length: {:.4f} MB".format(sys.getsizeof(words) / 1024 / 1024))
+
+ router = Parallel()
+ greets = EasyCounter()
+ router.on("greeting_testee", lambda msg: greets.add(msg))
+ start_t = time.time()
+ logging.info("testee start ...")
+
+ with tempfile.NamedTemporaryFile(prefix="pytmp_", dir="./") as itf:
+ print("testee: write ip address to the tempfile:", itf.name)
+ with open(itf.name, 'w') as ifd:
+ ifd.write("{}\n".format(router.get_ip()))
+
+ for i in range(epoch):
+ while time.time() - start_t < i * interval:
+ time.sleep(0.01)
+
+ if router._retries == 0:
+ router.emit("greeting_receiver", "{}_{}_{:.2f}".format(words, i, time.time() - start_t))
+ elif router._retries == 1:
+ router.emit("greeting_receiver", "recovered_{}_{:.2f}".format(i, time.time() - start_t))
+ else:
+ raise Exception("Failed too many times")
+
+ if greets.cnt() == 0 or i % 10 != 0:
+ continue
+ last_msg = greets.last()
+ msg_idx, msg_t = last_msg.split("_")[-2:]
+ logging.info(
+ "testee passed {:.2f} s, received {} msgs. last msg: idx {}, time {} s".format(
+ time.time() - start_t, greets.cnt(), msg_idx, msg_t
+ )
+ )
+
+ logging.info("testee done! total msg: {} retries: {}".format(greets.cnt(), router._retries))
+
+ @classmethod
+ def sender(cls, epoch, interval, data_size):
+ words = b''.join([alphabet[int(random() * 26)] for _ in range(1024 * 1024)]) * data_size
+ print("msg length: {:.4f} MB".format(sys.getsizeof(words) / 1024 / 1024))
+
+ router = Parallel()
+ start_t = time.time()
+ logging.info("sender start ...")
+
+ for i in range(epoch):
+ while time.time() - start_t < i * interval:
+ time.sleep(0.01)
+
+ router.emit("greeting_testee", "{}_{}_{:.2f}".format(words, i, time.time() - start_t))
+
+ logging.info("sender done!")
+
+ @classmethod
+ def main(cls, epoch=1000, interval=1.0, data_size=1, file="tmp_p1"):
+ router = Parallel()
+ if router.node_id == 0:
+ cls.receiver(epoch, interval)
+ elif router.node_id == 1:
+ cls.testee(epoch, interval, data_size)
+ elif router.node_id == 2:
+ cls.sender(epoch, interval, data_size)
+ else:
+ raise Exception("Invalid node id")
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--epoch', '-t', type=int, default=1200)
+ parser.add_argument('--interval', '-i', type=float, default=0.1)
+ parser.add_argument('--data_size', '-s', type=int, default=1)
+ args = parser.parse_args()
+ Parallel.runner(
+ n_parallel_workers=3, protocol="tcp", topology="mesh", auto_recover=True, max_retries=1
+ )(SockTest.main, args.epoch, args.interval, args.data_size)
diff --git a/DI-engine/ding/scripts/tests/test_parallel_socket.sh b/DI-engine/ding/scripts/tests/test_parallel_socket.sh
new file mode 100644
index 0000000000000000000000000000000000000000..a49e1005012cf2c778b6291a0a563f41a9c003c5
--- /dev/null
+++ b/DI-engine/ding/scripts/tests/test_parallel_socket.sh
@@ -0,0 +1,36 @@
+total_epoch=1200 # the total num of msg
+interval=0.1 # msg send interval
+size=16 # data size (MB)
+test_start_time=20 # network fail time (s)
+test_duration=40 # network fail duration (s)
+output_file="my_test.log" # the python script will write its output into this file
+ip="0.0.0.0"
+
+rm -f pytmp_*
+
+nohup python test_parallel_socket.py -t $total_epoch -i $interval -s $size 1>$output_file 2>&1 &
+
+flag=true
+while $flag
+do
+ for file in `ls`
+ do
+ if [[ $file =~ "pytmp" ]]; then
+ ip=`cat $file`
+ flag=false
+ break
+ fi
+ done
+ sleep 0.1
+done
+echo "get ip: $ip"
+
+sleep $test_start_time
+echo "Network shutsown . . ."
+sudo iptables -A INPUT -p tcp -s $ip --dport 50516 -j DROP
+
+sleep $test_duration
+sudo iptables -D INPUT -p tcp -s $ip --dport 50516 -j DROP
+echo "Network recovered . . ."
+
+
diff --git a/DI-engine/ding/torch_utils/__init__.py b/DI-engine/ding/torch_utils/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..151b4da7e190e5f6d3a739895a9759ce8349012a
--- /dev/null
+++ b/DI-engine/ding/torch_utils/__init__.py
@@ -0,0 +1,14 @@
+from .checkpoint_helper import build_checkpoint_helper, CountVar, auto_checkpoint
+from .data_helper import to_device, to_tensor, to_ndarray, to_list, to_dtype, same_shape, tensor_to_list, \
+ build_log_buffer, CudaFetcher, get_tensor_data, unsqueeze, squeeze, get_null_data, get_shape0, to_item, \
+ zeros_like
+from .distribution import CategoricalPd, CategoricalPdPytorch
+from .metric import levenshtein_distance, hamming_distance
+from .network import *
+from .loss import *
+from .optimizer_helper import Adam, RMSprop, calculate_grad_norm, calculate_grad_norm_without_bias_two_norm
+from .nn_test_helper import is_differentiable
+from .math_helper import cov
+from .dataparallel import DataParallel
+from .reshape_helper import fold_batch, unfold_batch, unsqueeze_repeat
+from .parameter import NonegativeParameter, TanhParameter
diff --git a/DI-engine/ding/torch_utils/backend_helper.py b/DI-engine/ding/torch_utils/backend_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7346b2d792518ff715ebc71941c22314098ee2a
--- /dev/null
+++ b/DI-engine/ding/torch_utils/backend_helper.py
@@ -0,0 +1,12 @@
+import torch
+
+
+def enable_tf32() -> None:
+ """
+ Overview:
+ Enable tf32 on matmul and cudnn for faster computation. This only takes effect on Ampere or newer GPU devices. \
+ For detailed information, please refer to: \
+ https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices.
+ """
+ torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
+ torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
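+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative only): call once at startup, before building any
+    # model. The flags can always be set, but they only speed things up on Ampere or newer GPUs.
+    enable_tf32()
+    print(torch.backends.cuda.matmul.allow_tf32, torch.backends.cudnn.allow_tf32)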
diff --git a/DI-engine/ding/torch_utils/checkpoint_helper.py b/DI-engine/ding/torch_utils/checkpoint_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d600b556ed168c2f382837dc04a782c08d55bae
--- /dev/null
+++ b/DI-engine/ding/torch_utils/checkpoint_helper.py
@@ -0,0 +1,346 @@
+from ditk import logging
+import signal
+import sys
+import traceback
+from typing import Callable
+import torch
+import torch.utils.data # torch1.1.0 compatibility
+from ding.utils import read_file, save_file
+
+logger = logging.getLogger('default_logger')
+
+
+def build_checkpoint_helper(cfg):
+ """
+ Overview:
+ Use config to build checkpoint helper.
+ Arguments:
+ - cfg (:obj:`dict`): ckpt_helper config
+ Returns:
+ - (:obj:`CheckpointHelper`): checkpoint_helper created by this function
+ """
+ return CheckpointHelper()
+
+
+class CheckpointHelper:
+ """
+ Overview:
+ Help to save or load checkpoint by given args.
+ Interfaces:
+ ``__init__``, ``save``, ``load``, ``_remove_prefix``, ``_add_prefix``, ``_load_matched_model_state_dict``
+ """
+
+ def __init__(self):
+ pass
+
+ def _remove_prefix(self, state_dict: dict, prefix: str = 'module.') -> dict:
+ """
+ Overview:
+ Remove prefix in state_dict
+ Arguments:
+ - state_dict (:obj:`dict`): model's state_dict
+ - prefix (:obj:`str`): this prefix will be removed in keys
+ Returns:
+ - new_state_dict (:obj:`dict`): new state_dict after removing prefix
+ """
+ new_state_dict = {}
+ for k, v in state_dict.items():
+ if k.startswith(prefix):
+ new_k = ''.join(k.split(prefix))
+ else:
+ new_k = k
+ new_state_dict[new_k] = v
+ return new_state_dict
+
+ def _add_prefix(self, state_dict: dict, prefix: str = 'module.') -> dict:
+ """
+ Overview:
+ Add prefix in state_dict
+ Arguments:
+ - state_dict (:obj:`dict`): model's state_dict
+ - prefix (:obj:`str`): this prefix will be added in keys
+ Returns:
+ - (:obj:`dict`): new state_dict after adding prefix
+ """
+ return {prefix + k: v for k, v in state_dict.items()}
+
+ def save(
+ self,
+ path: str,
+ model: torch.nn.Module,
+ optimizer: torch.optim.Optimizer = None,
+ last_iter: 'CountVar' = None, # noqa
+ last_epoch: 'CountVar' = None, # noqa
+ last_frame: 'CountVar' = None, # noqa
+ dataset: torch.utils.data.Dataset = None,
+ collector_info: torch.nn.Module = None,
+ prefix_op: str = None,
+ prefix: str = None,
+ ) -> None:
+ """
+ Overview:
+ Save checkpoint by given args
+ Arguments:
+ - path (:obj:`str`): the path of saving checkpoint
+ - model (:obj:`torch.nn.Module`): model to be saved
+ - optimizer (:obj:`torch.optim.Optimizer`): optimizer obj
+ - last_iter (:obj:`CountVar`): iter num, default None
+ - last_epoch (:obj:`CountVar`): epoch num, default None
+ - last_frame (:obj:`CountVar`): frame num, default None
+ - dataset (:obj:`torch.utils.data.Dataset`): dataset, should be replaydataset
+ - collector_info (:obj:`torch.nn.Module`): attr of checkpoint, save collector info
+ - prefix_op (:obj:`str`): should be ['remove', 'add'], process on state_dict
+ - prefix (:obj:`str`): prefix to be processed on state_dict
+ """
+ checkpoint = {}
+ model = model.state_dict()
+ if prefix_op is not None: # remove or add prefix to model.keys()
+ prefix_func = {'remove': self._remove_prefix, 'add': self._add_prefix}
+ if prefix_op not in prefix_func.keys():
+ raise KeyError('invalid prefix_op:{}'.format(prefix_op))
+ else:
+ model = prefix_func[prefix_op](model, prefix)
+ checkpoint['model'] = model
+
+ if optimizer is not None: # save optimizer
+ assert (last_iter is not None or last_epoch is not None)
+ checkpoint['last_iter'] = last_iter.val
+ if last_epoch is not None:
+ checkpoint['last_epoch'] = last_epoch.val
+ if last_frame is not None:
+ checkpoint['last_frame'] = last_frame.val
+ checkpoint['optimizer'] = optimizer.state_dict()
+
+ if dataset is not None:
+ checkpoint['dataset'] = dataset.state_dict()
+ if collector_info is not None:
+ checkpoint['collector_info'] = collector_info.state_dict()
+ save_file(path, checkpoint)
+ logger.info('save checkpoint in {}'.format(path))
+
+ def _load_matched_model_state_dict(self, model: torch.nn.Module, ckpt_state_dict: dict) -> None:
+ """
+ Overview:
+ Load matched model state_dict, and show mismatch keys between model's state_dict and checkpoint's state_dict
+ Arguments:
+ - model (:obj:`torch.nn.Module`): model
+ - ckpt_state_dict (:obj:`dict`): checkpoint's state_dict
+ """
+ assert isinstance(model, torch.nn.Module)
+ diff = {'miss_keys': [], 'redundant_keys': [], 'mismatch_shape_keys': []}
+ model_state_dict = model.state_dict()
+ model_keys = set(model_state_dict.keys())
+ ckpt_keys = set(ckpt_state_dict.keys())
+ diff['miss_keys'] = model_keys - ckpt_keys
+ diff['redundant_keys'] = ckpt_keys - model_keys
+
+ intersection_keys = model_keys.intersection(ckpt_keys)
+ valid_keys = []
+ for k in intersection_keys:
+ if model_state_dict[k].shape == ckpt_state_dict[k].shape:
+ valid_keys.append(k)
+ else:
+ diff['mismatch_shape_keys'].append(
+ '{}\tmodel_shape: {}\tckpt_shape: {}'.format(
+ k, model_state_dict[k].shape, ckpt_state_dict[k].shape
+ )
+ )
+ valid_ckpt_state_dict = {k: v for k, v in ckpt_state_dict.items() if k in valid_keys}
+ model.load_state_dict(valid_ckpt_state_dict, strict=False)
+
+ for n, keys in diff.items():
+ for k in keys:
+ logger.info('{}: {}'.format(n, k))
+
+ def load(
+ self,
+ load_path: str,
+ model: torch.nn.Module,
+ optimizer: torch.optim.Optimizer = None,
+ last_iter: 'CountVar' = None, # noqa
+ last_epoch: 'CountVar' = None, # noqa
+ last_frame: 'CountVar' = None, # noqa
+ lr_schduler: 'Scheduler' = None, # noqa
+ dataset: torch.utils.data.Dataset = None,
+ collector_info: torch.nn.Module = None,
+ prefix_op: str = None,
+ prefix: str = None,
+ strict: bool = True,
+ logger_prefix: str = '',
+ state_dict_mask: list = [],
+ ):
+ """
+ Overview:
+ Load checkpoint by given path
+ Arguments:
+ - load_path (:obj:`str`): checkpoint's path
+ - model (:obj:`torch.nn.Module`): model definition
+ - optimizer (:obj:`torch.optim.Optimizer`): optimizer obj
+ - last_iter (:obj:`CountVar`): iter num, default None
+ - last_epoch (:obj:`CountVar`): epoch num, default None
+ - last_frame (:obj:`CountVar`): frame num, default None
+ - lr_schduler (:obj:`Scheduler`): lr scheduler object (not supported yet)
+ - dataset (:obj:`torch.utils.data.Dataset`): dataset, should be replaydataset
+ - collector_info (:obj:`torch.nn.Module`): attr of checkpoint, save collector info
+ - prefix_op (:obj:`str`): should be ['remove', 'add'], process on state_dict
+ - prefix (:obj:`str`): prefix to be processed on state_dict
+ - strict (:obj:`bool`): args of model.load_state_dict
+ - logger_prefix (:obj:`str`): prefix of logger
+ - state_dict_mask (:obj:`list`): A list containing state_dict keys, \
+ which shouldn't be loaded into model(after prefix op)
+
+ .. note::
+
+ The checkpoint loaded from load_path is a dict, whose format is like '{'model': OrderedDict(), ...}'
+ """
+ # TODO save config
+ # Note: this reduces the initial GPU memory cost and keeps compatibility with CPU-only environments
+ checkpoint = read_file(load_path)
+ state_dict = checkpoint['model']
+ if prefix_op is not None:
+ prefix_func = {'remove': self._remove_prefix, 'add': self._add_prefix}
+ if prefix_op not in prefix_func.keys():
+ raise KeyError('invalid prefix_op:{}'.format(prefix_op))
+ else:
+ state_dict = prefix_func[prefix_op](state_dict, prefix)
+ if len(state_dict_mask) > 0:
+ if strict:
+ logger.info(
+ logger_prefix +
+ '[Warning] non-empty state_dict_mask expects strict=False, but finds strict=True in input argument'
+ )
+ strict = False
+ for m in state_dict_mask:
+ state_dict_keys = list(state_dict.keys())
+ for k in state_dict_keys:
+ if k.startswith(m):
+ state_dict.pop(k) # ignore return value
+ if strict:
+ model.load_state_dict(state_dict, strict=True)
+ else:
+ self._load_matched_model_state_dict(model, state_dict)
+ logger.info(logger_prefix + 'load model state_dict in {}'.format(load_path))
+
+ if dataset is not None:
+ if 'dataset' in checkpoint.keys():
+ dataset.load_state_dict(checkpoint['dataset'])
+ logger.info(logger_prefix + 'load online data in {}'.format(load_path))
+ else:
+ logger.info(logger_prefix + "dataset not in checkpoint, ignore load procedure")
+
+ if optimizer is not None:
+ if 'optimizer' in checkpoint.keys():
+ optimizer.load_state_dict(checkpoint['optimizer'])
+ logger.info(logger_prefix + 'load optimizer in {}'.format(load_path))
+ else:
+ logger.info(logger_prefix + "optimizer not in checkpoint, ignore load procedure")
+
+ if last_iter is not None:
+ if 'last_iter' in checkpoint.keys():
+ last_iter.update(checkpoint['last_iter'])
+ logger.info(
+ logger_prefix + 'load last_iter in {}, current last_iter is {}'.format(load_path, last_iter.val)
+ )
+ else:
+ logger.info(logger_prefix + "last_iter not in checkpoint, ignore load procedure")
+
+ if collector_info is not None:
+ collector_info.load_state_dict(checkpoint['collector_info'])
+ logger.info(logger_prefix + 'load collector info in {}'.format(load_path))
+
+ if lr_schduler is not None:
+ assert (last_iter is not None)
+ raise NotImplementedError
+
+
+class CountVar(object):
+ """
+ Overview:
+ Number counter
+ Interfaces:
+ ``__init__``, ``update``, ``add``
+ Properties:
+ - val (:obj:`int`): the value of the counter
+ """
+
+ def __init__(self, init_val: int) -> None:
+ """
+ Overview:
+ Init the var counter
+ Arguments:
+ - init_val (:obj:`int`): the init value of the counter
+ """
+
+ self._val = init_val
+
+ @property
+ def val(self) -> int:
+ """
+ Overview:
+ Get the var counter
+ """
+
+ return self._val
+
+ def update(self, val: int) -> None:
+ """
+ Overview:
+ Update the var counter
+ Arguments:
+ - val (:obj:`int`): the update value of the counter
+ """
+ self._val = val
+
+ def add(self, add_num: int):
+ """
+ Overview:
+ Add the number to counter
+ Arguments:
+ - add_num (:obj:`int`): the number added to the counter
+ """
+ self._val += add_num
+
+
+def auto_checkpoint(func: Callable) -> Callable:
+ """
+ Overview:
+ Create a wrapper for the given function. The wrapper calls the handle's ``save_checkpoint`` method
+ whenever an exception is raised or a catchable termination signal is received.
+ Arguments:
+ - func(:obj:`Callable`): the function to be wrapped
+ Returns:
+ - wrapper (:obj:`Callable`): the wrapped function
+ """
+ dead_signals = ['SIGILL', 'SIGINT', 'SIGKILL', 'SIGQUIT', 'SIGSEGV', 'SIGSTOP', 'SIGTERM', 'SIGBUS']
+ all_signals = dead_signals + ['SIGUSR1']
+
+ def register_signal_handler(handler):
+ valid_sig = []
+ invalid_sig = []
+ for sig in all_signals:
+ try:
+ sig = getattr(signal, sig)
+ signal.signal(sig, handler)
+ valid_sig.append(sig)
+ except Exception:
+ invalid_sig.append(sig)
+ logger.info('valid sig: ({})\ninvalid sig: ({})'.format(valid_sig, invalid_sig))
+
+ def wrapper(*args, **kwargs):
+ handle = args[0]
+ assert (hasattr(handle, 'save_checkpoint'))
+
+ def signal_handler(signal_num, frame):
+ sig = signal.Signals(signal_num)
+ logger.info("SIGNAL: {}({})".format(sig.name, sig.value))
+ handle.save_checkpoint('ckpt_interrupt.pth.tar')
+ sys.exit(1)
+
+ register_signal_handler(signal_handler)
+ try:
+ return func(*args, **kwargs)
+ except Exception as e:
+ handle.save_checkpoint('ckpt_exception.pth.tar')
+ traceback.print_exc()
+
+ return wrapper
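+
+
+if __name__ == "__main__":
+    # Minimal usage sketch, assuming DI-engine is installed so that ``ding.utils.save_file``
+    # and ``read_file`` resolve to the local-filesystem backend (run as
+    # ``python -m ding.torch_utils.checkpoint_helper``); all names below are illustrative only.
+    import os
+    import tempfile
+    model = torch.nn.Linear(4, 2)
+    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
+    last_iter = CountVar(init_val=100)
+    helper = build_checkpoint_helper(cfg=None)
+    ckpt_path = os.path.join(tempfile.mkdtemp(), 'demo_ckpt.pth.tar')
+    helper.save(ckpt_path, model, optimizer=optimizer, last_iter=last_iter)
+    # restore into a fresh model and counter, then check the iteration counter round-trips
+    restored_model = torch.nn.Linear(4, 2)
+    restored_iter = CountVar(init_val=0)
+    helper.load(ckpt_path, restored_model, optimizer=optimizer, last_iter=restored_iter)
+    assert restored_iter.val == 100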
diff --git a/DI-engine/ding/torch_utils/data_helper.py b/DI-engine/ding/torch_utils/data_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..b906279b1e9f972ae7294facad958e67434c6caf
--- /dev/null
+++ b/DI-engine/ding/torch_utils/data_helper.py
@@ -0,0 +1,756 @@
+from typing import Iterable, Any, Optional, List
+from collections.abc import Sequence
+import numbers
+import time
+import copy
+from threading import Thread
+from queue import Queue
+
+import numpy as np
+import torch
+import treetensor.torch as ttorch
+
+from ding.utils.default_helper import get_shape0
+
+
+def to_device(item: Any, device: str, ignore_keys: list = []) -> Any:
+ """
+ Overview:
+ Transfer data to certain device.
+ Arguments:
+ - item (:obj:`Any`): The item to be transferred.
+ - device (:obj:`str`): The device wanted.
+ - ignore_keys (:obj:`list`): The keys to be ignored in transfer, default set to empty.
+ Returns:
+ - item (:obj:`Any`): The transferred item.
+ Examples:
+ >>> setup_data_dict['module'] = nn.Linear(3, 5)
+ >>> device = 'cuda'
+ >>> cuda_d = to_device(setup_data_dict, device, ignore_keys=['module'])
+ >>> assert cuda_d['module'].weight.device == torch.device('cpu')
+
+ Examples:
+ >>> setup_data_dict['module'] = nn.Linear(3, 5)
+ >>> device = 'cuda'
+ >>> cuda_d = to_device(setup_data_dict, device)
+ >>> assert cuda_d['module'].weight.device == torch.device('cuda:0')
+
+ .. note::
+
+ Now supports item type: :obj:`torch.nn.Module`, :obj:`torch.Tensor`, :obj:`Sequence`, \
+ :obj:`dict`, :obj:`numbers.Integral`, :obj:`numbers.Real`, :obj:`np.ndarray`, :obj:`str` and :obj:`None`.
+
+ """
+ if isinstance(item, torch.nn.Module):
+ return item.to(device)
+ elif isinstance(item, ttorch.Tensor):
+ if 'prev_state' in item:
+ prev_state = to_device(item.prev_state, device)
+ del item.prev_state
+ item = item.to(device)
+ item.prev_state = prev_state
+ return item
+ else:
+ return item.to(device)
+ elif isinstance(item, torch.Tensor):
+ return item.to(device)
+ elif isinstance(item, Sequence):
+ if isinstance(item, str):
+ return item
+ else:
+ return [to_device(t, device) for t in item]
+ elif isinstance(item, dict):
+ new_item = {}
+ for k in item.keys():
+ if k in ignore_keys:
+ new_item[k] = item[k]
+ else:
+ new_item[k] = to_device(item[k], device)
+ return new_item
+ elif isinstance(item, numbers.Integral) or isinstance(item, numbers.Real):
+ return item
+ elif isinstance(item, np.ndarray) or isinstance(item, np.bool_):
+ return item
+ elif item is None or isinstance(item, str):
+ return item
+ elif isinstance(item, torch.distributions.Distribution): # for compatibility
+ return item
+ else:
+ raise TypeError("not support item type: {}".format(type(item)))
+
+
+def to_dtype(item: Any, dtype: type) -> Any:
+ """
+ Overview:
+ Change data to certain dtype.
+ Arguments:
+ - item (:obj:`Any`): The item for changing the dtype.
+ - dtype (:obj:`type`): The type wanted.
+ Returns:
+ - item (:obj:`object`): The item with changed dtype.
+ Examples (tensor):
+ >>> t = torch.randint(0, 10, (3, 5))
+ >>> tfloat = to_dtype(t, torch.float)
+ >>> assert tfloat.dtype == torch.float
+
+ Examples (list):
+ >>> tlist = [torch.randint(0, 10, (3, 5))]
+ >>> tlfloat = to_dtype(tlist, torch.float)
+ >>> assert tlfloat[0].dtype == torch.float
+
+ Examples (dict):
+ >>> tdict = {'t': torch.randint(0, 10, (3, 5))}
+ >>> tdictf = to_dtype(tdict, torch.float)
+ >>> assert tdictf['t'].dtype == torch.float
+
+ .. note::
+
+ Now supports item type: :obj:`torch.Tensor`, :obj:`Sequence`, :obj:`dict`.
+ """
+ if isinstance(item, torch.Tensor):
+ return item.to(dtype=dtype)
+ elif isinstance(item, Sequence):
+ return [to_dtype(t, dtype) for t in item]
+ elif isinstance(item, dict):
+ return {k: to_dtype(item[k], dtype) for k in item.keys()}
+ else:
+ raise TypeError("not support item type: {}".format(type(item)))
+
+
+def to_tensor(
+ item: Any, dtype: Optional[torch.dtype] = None, ignore_keys: list = [], transform_scalar: bool = True
+) -> Any:
+ """
+ Overview:
+ Convert ``numpy.ndarray`` object to ``torch.Tensor``.
+ Arguments:
+ - item (:obj:`Any`): The ``numpy.ndarray`` objects to be converted. It can be exactly a ``numpy.ndarray`` \
+ object or a container (list, tuple or dict) that contains several ``numpy.ndarray`` objects.
+ - dtype (:obj:`torch.dtype`): The type of wanted tensor. If set to ``None``, its dtype will be unchanged.
+ - ignore_keys (:obj:`list`): If the ``item`` is a dict, values whose keys are in ``ignore_keys`` will not \
+ be converted.
+ - transform_scalar (:obj:`bool`): If set to ``True``, a scalar will be also converted to a tensor object.
+ Returns:
+ - item (:obj:`Any`): The converted tensors.
+
+ Examples (scalar):
+ >>> i = 10
+ >>> t = to_tensor(i)
+ >>> assert t.item() == i
+
+ Examples (dict):
+ >>> d = {'i': i}
+ >>> dt = to_tensor(d, torch.int)
+ >>> assert dt['i'].item() == i
+
+ Examples (named tuple):
+ >>> data_type = namedtuple('data_type', ['x', 'y'])
+ >>> inputs = data_type(np.random.random(3), 4)
+ >>> outputs = to_tensor(inputs, torch.float32)
+ >>> assert type(outputs) == data_type
+ >>> assert isinstance(outputs.x, torch.Tensor)
+ >>> assert isinstance(outputs.y, torch.Tensor)
+ >>> assert outputs.x.dtype == torch.float32
+ >>> assert outputs.y.dtype == torch.float32
+
+ .. note::
+
+ Now supports item type: :obj:`dict`, :obj:`list`, :obj:`tuple` and :obj:`None`.
+ """
+
+ def transform(d):
+ if dtype is None:
+ return torch.as_tensor(d)
+ else:
+ return torch.tensor(d, dtype=dtype)
+
+ if isinstance(item, dict):
+ new_data = {}
+ for k, v in item.items():
+ if k in ignore_keys:
+ new_data[k] = v
+ else:
+ new_data[k] = to_tensor(v, dtype, ignore_keys, transform_scalar)
+ return new_data
+ elif isinstance(item, list) or isinstance(item, tuple):
+ if len(item) == 0:
+ return []
+ elif isinstance(item[0], numbers.Integral) or isinstance(item[0], numbers.Real):
+ return transform(item)
+ elif hasattr(item, '_fields'): # namedtuple
+ return type(item)(*[to_tensor(t, dtype) for t in item])
+ else:
+ new_data = []
+ for t in item:
+ new_data.append(to_tensor(t, dtype, ignore_keys, transform_scalar))
+ return new_data
+ elif isinstance(item, np.ndarray):
+ if dtype is None:
+ if item.dtype == np.float64:
+ return torch.FloatTensor(item)
+ else:
+ return torch.from_numpy(item)
+ else:
+ return torch.from_numpy(item).to(dtype)
+ elif isinstance(item, bool) or isinstance(item, str):
+ return item
+ elif np.isscalar(item):
+ if transform_scalar:
+ if dtype is None:
+ return torch.as_tensor(item)
+ else:
+ return torch.as_tensor(item).to(dtype)
+ else:
+ return item
+ elif item is None:
+ return None
+ elif isinstance(item, torch.Tensor):
+ if dtype is None:
+ return item
+ else:
+ return item.to(dtype)
+ else:
+ raise TypeError("not support item type: {}".format(type(item)))
+
+
+def to_ndarray(item: Any, dtype: np.dtype = None) -> Any:
+ """
+ Overview:
+ Convert ``torch.Tensor`` to ``numpy.ndarray``.
+ Arguments:
+ - item (:obj:`Any`): The ``torch.Tensor`` objects to be converted. It can be exactly a ``torch.Tensor`` \
+ object or a container (list, tuple or dict) that contains several ``torch.Tensor`` objects.
+ - dtype (:obj:`np.dtype`): The type of wanted array. If set to ``None``, its dtype will be unchanged.
+ Returns:
+ - item (:obj:`object`): The changed arrays.
+
+ Examples (ndarray):
+ >>> t = torch.randn(3, 5)
+ >>> tarray1 = to_ndarray(t)
+ >>> assert tarray1.shape == (3, 5)
+ >>> assert isinstance(tarray1, np.ndarray)
+
+ Examples (list):
+ >>> t = [torch.randn(5, ) for i in range(3)]
+ >>> tarray1 = to_ndarray(t, np.float32)
+ >>> assert isinstance(tarray1, list)
+ >>> assert tarray1[0].shape == (5, )
+ >>> assert isinstance(tarray1[0], np.ndarray)
+
+ .. note::
+
+ Now supports item type: :obj:`torch.Tensor`, :obj:`dict`, :obj:`list`, :obj:`tuple` and :obj:`None`.
+ """
+
+ def transform(d):
+ if dtype is None:
+ return np.array(d)
+ else:
+ return np.array(d, dtype=dtype)
+
+ if isinstance(item, dict):
+ new_data = {}
+ for k, v in item.items():
+ new_data[k] = to_ndarray(v, dtype)
+ return new_data
+ elif isinstance(item, list) or isinstance(item, tuple):
+ if len(item) == 0:
+ return None
+ elif isinstance(item[0], numbers.Integral) or isinstance(item[0], numbers.Real):
+ return transform(item)
+ elif hasattr(item, '_fields'): # namedtuple
+ return type(item)(*[to_ndarray(t, dtype) for t in item])
+ else:
+ new_data = []
+ for t in item:
+ new_data.append(to_ndarray(t, dtype))
+ return new_data
+ elif isinstance(item, torch.Tensor):
+ if dtype is None:
+ return item.numpy()
+ else:
+ return item.numpy().astype(dtype)
+ elif isinstance(item, np.ndarray):
+ if dtype is None:
+ return item
+ else:
+ return item.astype(dtype)
+ elif isinstance(item, bool) or isinstance(item, str):
+ return item
+ elif np.isscalar(item):
+ if dtype is None:
+ return np.array(item)
+ else:
+ return np.array(item, dtype=dtype)
+ elif item is None:
+ return None
+ else:
+ raise TypeError("not support item type: {}".format(type(item)))
+
+
+def to_list(item: Any) -> Any:
+ """
+ Overview:
+ Convert ``torch.Tensor``, ``numpy.ndarray`` objects to ``list`` objects, and keep their dtypes unchanged.
+ Arguments:
+ - item (:obj:`Any`): The item to be converted.
+ Returns:
+ - item (:obj:`Any`): The list after conversion.
+
+ Examples:
+ >>> data = { \
+ 'tensor': torch.randn(4), \
+ 'list': [True, False, False], \
+ 'tuple': (4, 5, 6), \
+ 'bool': True, \
+ 'int': 10, \
+ 'float': 10., \
+ 'array': np.random.randn(4), \
+ 'str': "asdf", \
+ 'none': None, \
+ } \
+ >>> transformed_data = to_list(data)
+
+ .. note::
+
+ Now supports item type: :obj:`torch.Tensor`, :obj:`numpy.ndarray`, :obj:`dict`, :obj:`list`, \
+ :obj:`tuple` and :obj:`None`.
+ """
+ if item is None:
+ return item
+ elif isinstance(item, torch.Tensor):
+ return item.tolist()
+ elif isinstance(item, np.ndarray):
+ return item.tolist()
+ elif isinstance(item, list) or isinstance(item, tuple):
+ return [to_list(t) for t in item]
+ elif isinstance(item, dict):
+ return {k: to_list(v) for k, v in item.items()}
+ elif np.isscalar(item):
+ return item
+ else:
+ raise TypeError("not support item type: {}".format(type(item)))
+
+
+def tensor_to_list(item: Any) -> Any:
+ """
+ Overview:
+ Convert ``torch.Tensor`` objects to ``list``, and keep their dtypes unchanged.
+ Arguments:
+ - item (:obj:`Any`): The item to be converted.
+ Returns:
+ - item (:obj:`Any`): The lists after conversion.
+
+ Examples (2d-tensor):
+ >>> t = torch.randn(3, 5)
+ >>> tlist1 = tensor_to_list(t)
+ >>> assert len(tlist1) == 3
+ >>> assert len(tlist1[0]) == 5
+
+ Examples (1d-tensor):
+ >>> t = torch.randn(3, )
+ >>> tlist1 = tensor_to_list(t)
+ >>> assert len(tlist1) == 3
+
+ Examples (list):
+ >>> t = [torch.randn(5, ) for i in range(3)]
+ >>> tlist1 = tensor_to_list(t)
+ >>> assert len(tlist1) == 3
+ >>> assert len(tlist1[0]) == 5
+
+ Examples (dict):
+ >>> td = {'t': torch.randn(3, 5)}
+ >>> tdlist1 = tensor_to_list(td)
+ >>> assert len(tdlist1['t']) == 3
+ >>> assert len(tdlist1['t'][0]) == 5
+
+ .. note::
+
+ Now supports item type: :obj:`torch.Tensor`, :obj:`dict`, :obj:`list`, :obj:`tuple` and :obj:`None`.
+ """
+ if item is None:
+ return item
+ elif isinstance(item, torch.Tensor):
+ return item.tolist()
+ elif isinstance(item, list) or isinstance(item, tuple):
+ return [tensor_to_list(t) for t in item]
+ elif isinstance(item, dict):
+ return {k: tensor_to_list(v) for k, v in item.items()}
+ elif np.isscalar(item):
+ return item
+ else:
+ raise TypeError("not support item type: {}".format(type(item)))
+
+
+def to_item(data: Any, ignore_error: bool = True) -> Any:
+ """
+ Overview:
+ Convert data to python native scalar (i.e. data item), and keep their dtypes unchanged.
+ Arguments:
+ - data (:obj:`Any`): The data that needs to be converted.
+ - ignore_error (:obj:`bool`): Whether to ignore the error when the data type is not supported. That is to \
+ say, only data that can be transformed into a Python native scalar will be returned.
+ Returns:
+ - data (:obj:`Any`): Converted data.
+
+ Examples:
+ >>> data = { \
+ 'tensor': torch.randn(1), \
+ 'list': [True, False, torch.randn(1)], \
+ 'tuple': (4, 5, 6), \
+ 'bool': True, \
+ 'int': 10, \
+ 'float': 10., \
+ 'array': np.random.randn(1), \
+ 'str': "asdf", \
+ 'none': None, \
+ }
+ >>> new_data = to_item(data)
+ >>> assert np.isscalar(new_data['tensor'])
+ >>> assert np.isscalar(new_data['array'])
+ >>> assert np.isscalar(new_data['list'][-1])
+
+ .. note::
+
+ Now supports item type: :obj:`np.ndarray`, :obj:`torch.Tensor`, :obj:`ttorch.Tensor`, \
+ :obj:`bool`, :obj:`str`, :obj:`dict`, :obj:`list`, :obj:`tuple` and :obj:`None`.
+ """
+ if data is None:
+ return data
+ elif isinstance(data, bool) or isinstance(data, str):
+ return data
+ elif np.isscalar(data):
+ return data
+ elif isinstance(data, np.ndarray) or isinstance(data, torch.Tensor) or isinstance(data, ttorch.Tensor):
+ return data.item()
+ elif isinstance(data, list) or isinstance(data, tuple):
+ return [to_item(d) for d in data]
+ elif isinstance(data, dict):
+ new_data = {}
+ for k, v in data.items():
+ if ignore_error:
+ try:
+ new_data[k] = to_item(v)
+ except (ValueError, RuntimeError):
+ pass
+ else:
+ new_data[k] = to_item(v)
+ return new_data
+ else:
+ raise TypeError("not support data type: {}".format(data))
+
+
+def same_shape(data: list) -> bool:
+ """
+ Overview:
+ Judge whether all data elements in a list have the same shapes.
+ Arguments:
+ - data (:obj:`list`): The list of data.
+ Returns:
+ - same (:obj:`bool`): Whether the list of data all have the same shape.
+
+ Examples:
+ >>> tlist = [torch.randn(3, 5) for i in range(5)]
+ >>> assert same_shape(tlist)
+ >>> tlist = [torch.randn(3, 5), torch.randn(4, 5)]
+ >>> assert not same_shape(tlist)
+ """
+ assert (isinstance(data, list))
+ shapes = [t.shape for t in data]
+ return len(set(shapes)) == 1
+
+
+class LogDict(dict):
+ """
+ Overview:
+ Derived from ``dict``. Would convert ``torch.Tensor`` to ``list`` for convenient logging.
+ Interfaces:
+ ``_transform``, ``__setitem__``, ``update``.
+ """
+
+ def _transform(self, data: Any) -> Any:
+ """
+ Overview:
+ Convert tensor objects to lists for better logging.
+ Arguments:
+ - data (:obj:`Any`): The input data to be converted.
+ Returns:
+ - new_data (:obj:`Any`): The converted data (a ``list`` if the input is a ``torch.Tensor``).
+ """
+ if isinstance(data, torch.Tensor):
+ new_data = data.tolist()
+ else:
+ new_data = data
+ return new_data
+
+ def __setitem__(self, key: Any, value: Any) -> None:
+ """
+ Overview:
+ Override the ``__setitem__`` function of built-in dict.
+ Arguments:
+ - key (:obj:`Any`): The key of the data item.
+ - value (:obj:`Any`): The value of the data item.
+ """
+ new_value = self._transform(value)
+ super().__setitem__(key, new_value)
+
+ def update(self, data: dict) -> None:
+ """
+ Overview:
+ Override the ``update`` function of built-in dict.
+ Arguments:
+ - data (:obj:`dict`): The dict for updating current object.
+ """
+ for k, v in data.items():
+ self.__setitem__(k, v)
+
+
+def build_log_buffer() -> LogDict:
+ """
+ Overview:
+ Build log buffer, a subclass of dict, which can convert the input data into log format.
+ Returns:
+ - log_buffer (:obj:`LogDict`): Log buffer dict.
+ Examples:
+ >>> log_buffer = build_log_buffer()
+ >>> log_buffer['not_tensor'] = torch.randn(3)
+ >>> assert isinstance(log_buffer['not_tensor'], list)
+ >>> assert len(log_buffer['not_tensor']) == 3
+ >>> log_buffer.update({'not_tensor': 4, 'a': 5})
+ >>> assert log_buffer['not_tensor'] == 4
+ """
+ return LogDict()
+
+
+class CudaFetcher(object):
+ """
+ Overview:
+ Fetch data from source, and transfer it to a specified device.
+ Interfaces:
+ ``__init__``, ``__next__``, ``run``, ``close``.
+ """
+
+ def __init__(self, data_source: Iterable, device: str, queue_size: int = 4, sleep: float = 0.1) -> None:
+ """
+ Overview:
+ Initialize the CudaFetcher object using the given arguments.
+ Arguments:
+ - data_source (:obj:`Iterable`): The iterable data source.
+ - device (:obj:`str`): The device to put data to, such as "cuda:0".
+ - queue_size (:obj:`int`): The internal size of queue, such as 4.
+ - sleep (:obj:`float`): Sleeping time when the internal queue is full.
+ """
+ self._source = data_source
+ self._queue = Queue(maxsize=queue_size)
+ self._stream = torch.cuda.Stream()
+ self._producer_thread = Thread(target=self._producer, args=(), name='cuda_fetcher_producer')
+ self._sleep = sleep
+ self._device = device
+
+ def __next__(self) -> Any:
+ """
+ Overview:
+ Response to the request for data. Return one data item from the internal queue.
+ Returns:
+ - item (:obj:`Any`): The data item on the required device.
+ """
+ return self._queue.get()
+
+ def run(self) -> None:
+ """
+ Overview:
+ Start ``producer`` thread: Keep fetching data from source, change the device, and put into \
+ ``queue`` for request.
+ Examples:
+ >>> timer = EasyTimer()
+ >>> dataloader = iter([torch.randn(3, 3) for _ in range(10)])
+ >>> dataloader = CudaFetcher(dataloader, device='cuda', sleep=0.1)
+ >>> dataloader.run()
+ >>> data = next(dataloader)
+ """
+ self._end_flag = False
+ self._producer_thread.start()
+
+ def close(self) -> None:
+ """
+ Overview:
+ Stop ``producer`` thread by setting ``end_flag`` to ``True`` .
+ """
+ self._end_flag = True
+
+ def _producer(self) -> None:
+ """
+ Overview:
+ Keep fetching data from source, change the device, and put into ``queue`` for request.
+ """
+
+ with torch.cuda.stream(self._stream):
+ while not self._end_flag:
+ if self._queue.full():
+ time.sleep(self._sleep)
+ else:
+ data = next(self._source)
+ data = to_device(data, self._device)
+ self._queue.put(data)
+
+
+def get_tensor_data(data: Any) -> Any:
+ """
+ Overview:
+ Get pure tensor data from the given data (without disturbing grad computation graph).
+ Arguments:
+ - data (:obj:`Any`): The original data. It can be exactly a tensor or a container (Sequence or dict).
+ Returns:
+ - output (:obj:`Any`): The output data.
+ Examples:
+ >>> a = { \
+ 'tensor': torch.tensor([1, 2, 3.], requires_grad=True), \
+ 'list': [torch.tensor([1, 2, 3.], requires_grad=True) for _ in range(2)], \
+ 'none': None \
+ }
+ >>> tensor_a = get_tensor_data(a)
+ >>> assert not tensor_a['tensor'].requires_grad
+ >>> for t in tensor_a['list']:
+ >>> assert not t.requires_grad
+ """
+ if isinstance(data, torch.Tensor):
+ return data.data.clone()
+ elif data is None:
+ return None
+ elif isinstance(data, Sequence):
+ return [get_tensor_data(d) for d in data]
+ elif isinstance(data, dict):
+ return {k: get_tensor_data(v) for k, v in data.items()}
+ else:
+ raise TypeError("not support type in get_tensor_data: {}".format(type(data)))
+
+
+def unsqueeze(data: Any, dim: int = 0) -> Any:
+ """
+ Overview:
+ Unsqueeze the tensor data.
+ Arguments:
+ - data (:obj:`Any`): The original data. It can be exactly a tensor or a container (Sequence or dict).
+ - dim (:obj:`int`): The dimension to be unsqueezed.
+ Returns:
+ - output (:obj:`Any`): The output data.
+
+ Examples (tensor):
+ >>> t = torch.randn(3, 3)
+ >>> tt = unsqueeze(t, dim=0)
+ >>> assert tt.shape == torch.Size([1, 3, 3])
+
+ Examples (list):
+ >>> t = [torch.randn(3, 3)]
+ >>> tt = unsqueeze(t, dim=0)
+ >>> assert tt[0].shape == torch.Size([1, 3, 3])
+
+ Examples (dict):
+ >>> t = {"t": torch.randn(3, 3)}
+ >>> tt = unsqueeze(t, dim=0)
+ >>> assert tt["t"].shape == torch.Shape([1, 3, 3])
+ """
+ if isinstance(data, torch.Tensor):
+ return data.unsqueeze(dim)
+ elif isinstance(data, Sequence):
+ return [unsqueeze(d, dim) for d in data]
+ elif isinstance(data, dict):
+ return {k: unsqueeze(v, dim) for k, v in data.items()}
+ else:
+ raise TypeError("not support type in unsqueeze: {}".format(type(data)))
+
+
+def squeeze(data: Any, dim: int = 0) -> Any:
+ """
+ Overview:
+ Squeeze the tensor data.
+ Arguments:
+ - data (:obj:`Any`): The original data. It can be exactly a tensor or a container (Sequence or dict).
+ - dim (:obj:`int`): The dimension to be Squeezed.
+ Returns:
+ - output (:obj:`Any`): The output data.
+
+ Examples (tensor):
+ >>> t = torch.randn(1, 3, 3)
+ >>> tt = squeeze(t, dim=0)
+ >>> assert tt.shape == torch.Size([3, 3])
+
+ Examples (list):
+ >>> t = [torch.randn(1, 3, 3)]
+ >>> tt = squeeze(t, dim=0)
+ >>> assert tt[0].shape == torch.Size([3, 3])
+
+ Examples (dict):
+ >>> t = {"t": torch.randn(1, 3, 3)}
+ >>> tt = squeeze(t, dim=0)
+ >>> assert tt["t"].shape == torch.Shape([3, 3])
+ """
+ if isinstance(data, torch.Tensor):
+ return data.squeeze(dim)
+ elif isinstance(data, Sequence):
+ return [squeeze(d, dim) for d in data]
+ elif isinstance(data, dict):
+ return {k: squeeze(v, dim) for k, v in data.items()}
+ else:
+ raise TypeError("not support type in squeeze: {}".format(type(data)))
+
+
+def get_null_data(template: Any, num: int) -> List[Any]:
+ """
+ Overview:
+ Get null data given an input template.
+ Arguments:
+ - template (:obj:`Any`): The template data.
+ - num (:obj:`int`): The number of null data items to generate.
+ Returns:
+ - output (:obj:`List[Any]`): The generated null data.
+
+ Examples:
+ >>> temp = {'obs': [1, 2, 3], 'action': 1, 'done': False, 'reward': torch.tensor(1.)}
+ >>> null_data = get_null_data(temp, 2)
+ >>> assert len(null_data) == 2
+ >>> assert null_data[0]['null'] and null_data[0]['done']
+ """
+ ret = []
+ for _ in range(num):
+ data = copy.deepcopy(template)
+ data['null'] = True
+ data['done'] = True
+ data['reward'].zero_()
+ ret.append(data)
+ return ret
+
+
+def zeros_like(h: Any) -> Any:
+ """
+ Overview:
+ Generate zero-tensors like the input data.
+ Arguments:
+ - h (:obj:`Any`): The original data. It can be exactly a tensor or a container (Sequence or dict).
+ Returns:
+ - output (:obj:`Any`): The output zero-tensors.
+
+ Examples (tensor):
+ >>> t = torch.randn(3, 3)
+ >>> tt = zeros_like(t)
+ >>> assert tt.shape == torch.Size([3, 3])
+ >>> assert torch.sum(torch.abs(tt)) < 1e-8
+
+ Examples (list):
+ >>> t = [torch.randn(3, 3)]
+ >>> tt = zeros_like(t)
+ >>> assert tt[0].shape == torch.Size([3, 3])
+ >>> assert torch.sum(torch.abs(tt[0])) < 1e-8
+
+ Examples (dict):
+ >>> t = {"t": torch.randn(3, 3)}
+ >>> tt = zeros_like(t)
+ >>> assert tt["t"].shape == torch.Shape([3, 3])
+ >>> assert torch.sum(torch.abs(tt["t"])) < 1e-8
+ """
+ if isinstance(h, torch.Tensor):
+ return torch.zeros_like(h)
+ elif isinstance(h, (list, tuple)):
+ return [zeros_like(t) for t in h]
+ elif isinstance(h, dict):
+ return {k: zeros_like(v) for k, v in h.items()}
+ else:
+ raise TypeError("not support type: {}".format(h))
diff --git a/DI-engine/ding/torch_utils/dataparallel.py b/DI-engine/ding/torch_utils/dataparallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4ea14f767d8082cf109a0e5cc65e20d103ab44c
--- /dev/null
+++ b/DI-engine/ding/torch_utils/dataparallel.py
@@ -0,0 +1,35 @@
+import torch
+import torch.nn as nn
+
+
+class DataParallel(nn.DataParallel):
+ """
+ Overview:
+ A wrapper class for nn.DataParallel.
+ Interfaces:
+ ``__init__``, ``parameters``
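+ Examples:
+ >>> # A minimal sketch; without visible GPUs this behaves as a plain wrapper around the module.
+ >>> model = DataParallel(nn.Linear(3, 5))
+ >>> params = list(model.parameters())
+ >>> assert len(params) == 2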
+ """
+
+ def __init__(self, module, device_ids=None, output_device=None, dim=0):
+ """
+ Overview:
+ Initialize the DataParallel object.
+ Arguments:
+ - module (:obj:`nn.Module`): The module to be parallelized.
+ - device_ids (:obj:`list`): The list of GPU ids.
+ - output_device (:obj:`int`): The output GPU id.
+ - dim (:obj:`int`): The dimension to be parallelized.
+ """
+ super().__init__(module, device_ids=device_ids, output_device=output_device, dim=dim)
+ self.module = module
+
+ def parameters(self, recurse: bool = True):
+ """
+ Overview:
+ Return the parameters of the module.
+ Arguments:
+ - recurse (:obj:`bool`): Whether to return the parameters of the submodules.
+ Returns:
+ - params (:obj:`generator`): The generator of the parameters.
+ """
+ return self.module.parameters(recurse=recurse)
diff --git a/DI-engine/ding/torch_utils/distribution.py b/DI-engine/ding/torch_utils/distribution.py
new file mode 100644
index 0000000000000000000000000000000000000000..f68ef6fba015b82c500098ecf9d513feae14b542
--- /dev/null
+++ b/DI-engine/ding/torch_utils/distribution.py
@@ -0,0 +1,275 @@
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+from typing import Tuple, Dict
+
+import torch
+import numpy as np
+import torch.nn.functional as F
+
+
+class Pd(object):
+ """
+ Overview:
+ Abstract class for parameterizable probability distributions and sampling functions.
+ Interfaces:
+ ``neglogp``, ``entropy``, ``noise_mode``, ``mode``, ``sample``
+
+ .. tip::
+
+ In derived classes, ``logits`` should be stored as an attribute member of the class.
+ """
+
+ def neglogp(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Calculate cross_entropy between input x and logits
+ Arguments:
+ - x (:obj:`torch.Tensor`): the input tensor
+ Return:
+ - cross_entropy (:obj:`torch.Tensor`): the returned cross_entropy loss
+ """
+ raise NotImplementedError
+
+ def entropy(self) -> torch.Tensor:
+ """
+ Overview:
+ Calculate the softmax entropy of logits
+ Arguments:
+ - reduction (:obj:`str`): support [None, 'mean'], default set to 'mean'
+ Returns:
+ - entropy (:obj:`torch.Tensor`): the calculated entropy
+ """
+ raise NotImplementedError
+
+ def noise_mode(self):
+ """
+ Overview:
+ Add noise to logits. This method is designed for randomness
+ """
+ raise NotImplementedError
+
+ def mode(self):
+ """
+ Overview:
+ Return the logits argmax result. This method is designed for deterministic action selection.
+ """
+ raise NotImplementedError
+
+ def sample(self):
+ """
+ Overview:
+ Sample from the distribution defined by the logits (via softmax). This method is designed for multinomial sampling.
+ """
+ raise NotImplementedError
+
+
+class CategoricalPd(Pd):
+ """
+ Overview:
+ Categorical probability distribution sampler
+ Interfaces:
+ ``__init__``, ``neglogp``, ``entropy``, ``noise_mode``, ``mode``, ``sample``
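+ Examples:
+ >>> # A minimal sketch with random logits of shape (B, N) = (4, 6).
+ >>> pd = CategoricalPd(torch.randn(4, 6))
+ >>> action = pd.sample()
+ >>> assert action.shape == (4, )
+ >>> loss = pd.neglogp(action)
+ >>> entropy = pd.entropy()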
+ """
+
+ def __init__(self, logits: torch.Tensor = None) -> None:
+ """
+ Overview:
+ Initialize the Pd object with the given logits
+ Arguments:
+ - logits (:obj:`torch.Tensor`): logits to sample from
+ """
+ self.update_logits(logits)
+
+ def update_logits(self, logits: torch.Tensor) -> None:
+ """
+ Overview:
+ Update logits
+ Arguments:
+ - logits (:obj:`torch.Tensor`): logits to update
+ """
+ self.logits = logits
+
+ def neglogp(self, x, reduction: str = 'mean') -> torch.Tensor:
+ """
+ Overview:
+ Calculate cross_entropy between input x and logits
+ Arguments:
+ - x (:obj:`torch.Tensor`): the input tensor
+ - reduction (:obj:`str`): support [None, 'mean'], default set to mean
+ Return:
+ - cross_entropy (:obj:`torch.Tensor`): the returned cross_entropy loss
+ """
+ return F.cross_entropy(self.logits, x, reduction=reduction)
+
+ def entropy(self, reduction: str = 'mean') -> torch.Tensor:
+ """
+ Overview:
+ Calculate the softmax entropy of logits
+ Arguments:
+ - reduction (:obj:`str`): support [None, 'mean'], default set to mean
+ Returns:
+ - entropy (:obj:`torch.Tensor`): the calculated entropy
+ """
+ a = self.logits - self.logits.max(dim=-1, keepdim=True)[0]
+ ea = torch.exp(a)
+ z = ea.sum(dim=-1, keepdim=True)
+ p = ea / z
+ entropy = (p * (torch.log(z) - a)).sum(dim=-1)
+ assert (reduction in [None, 'mean'])
+ if reduction is None:
+ return entropy
+ elif reduction == 'mean':
+ return entropy.mean()
+
+ def noise_mode(self, viz: bool = False) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
+ """
+ Overview:
+ Add noise to the logits and return the argmax of the noised logits.
+ Arguments:
+ - viz (:obj:`bool`): Whether to return the numpy form of logits, noise and noise_logits; \
+ Short for ``visualize`` . (Because tensor type cannot visualize in tb or text log)
+ Returns:
+ - result (:obj:`torch.Tensor`): noised logits
+ - viz_feature (:obj:`Dict[str, np.ndarray]`): ndarray type data for visualization.
+ """
+ u = torch.rand_like(self.logits)
+ u = -torch.log(-torch.log(u))
+ noise_logits = self.logits + u
+ result = noise_logits.argmax(dim=-1)
+ if viz:
+ viz_feature = {}
+ viz_feature['logits'] = self.logits.data.cpu().numpy()
+ viz_feature['noise'] = u.data.cpu().numpy()
+ viz_feature['noise_logits'] = noise_logits.data.cpu().numpy()
+ return result, viz_feature
+ else:
+ return result
+
+ def mode(self, viz: bool = False) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
+ """
+ Overview:
+ Return the argmax of the logits.
+ Arguments:
+ - viz (:obj:`bool`): Whether to return the numpy form of logits, noise and noise_logits; \
+ Short for ``visualize`` . (Because tensor type cannot visualize in tb or text log)
+ Returns:
+ - result (:obj:`torch.Tensor`): the logits argmax result
+ - viz_feature (:obj:`Dict[str, np.ndarray]`): ndarray type data for visualization.
+ """
+ result = self.logits.argmax(dim=-1)
+ if viz:
+ viz_feature = {}
+ viz_feature['logits'] = self.logits.data.cpu().numpy()
+ return result, viz_feature
+ else:
+ return result
+
+ def sample(self, viz: bool = False) -> Tuple[torch.Tensor, Dict[str, np.ndarray]]:
+ """
+ Overview:
+ Sample from the distribution defined by the logits (via softmax).
+ Arguments:
+ - viz (:obj:`bool`): Whether to return the numpy form of logits, noise and noise_logits; \
+ Short for ``visualize`` . (Because tensor type cannot visualize in tb or text log)
+ Returns:
+ - result (:obj:`torch.Tensor`): the logits sampled result
+ - viz_feature (:obj:`Dict[str, np.ndarray]`): ndarray type data for visualization.
+ """
+ p = torch.softmax(self.logits, dim=1)
+ result = torch.multinomial(p, 1).squeeze(1)
+ if viz:
+ viz_feature = {}
+ viz_feature['logits'] = self.logits.data.cpu().numpy()
+ return result, viz_feature
+ else:
+ return result
+
+
+class CategoricalPdPytorch(torch.distributions.Categorical):
+ """
+ Overview:
+ Wrapped ``torch.distributions.Categorical``
+
+ Interfaces:
+ ``__init__``, ``update_logits``, ``update_probs``, ``sample``, ``neglogp``, ``mode``, ``entropy``
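+ Examples:
+ >>> # A minimal sketch: build the distribution from a probability tensor of shape (B, N).
+ >>> probs = torch.softmax(torch.randn(4, 6), dim=-1)
+ >>> pd = CategoricalPdPytorch(probs)
+ >>> action = pd.sample()
+ >>> assert action.shape == (4, )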
+ """
+
+ def __init__(self, probs: torch.Tensor = None) -> None:
+ """
+ Overview:
+ Initialize the CategoricalPdPytorch object.
+ Arguments:
+ - probs (:obj:`torch.Tensor`): The tensor of probabilities.
+ """
+ if probs is not None:
+ self.update_probs(probs)
+
+ def update_logits(self, logits: torch.Tensor) -> None:
+ """
+ Overview:
+ Update logits
+ Arguments:
+ - logits (:obj:`torch.Tensor`): logits to update
+ """
+ super().__init__(logits=logits)
+
+ def update_probs(self, probs: torch.Tensor) -> None:
+ """
+ Overview:
+ Update probs
+ Arguments:
+ - probs (:obj:`torch.Tensor`): probs to update
+ """
+ super().__init__(probs=probs)
+
+ def sample(self) -> torch.Tensor:
+ """
+ Overview:
+ Sample from the wrapped categorical distribution
+ Return:
+ - result (:obj:`torch.Tensor`): the sampled index result
+ """
+ return super().sample()
+
+ def neglogp(self, actions: torch.Tensor, reduction: str = 'mean') -> torch.Tensor:
+ """
+ Overview:
+ Calculate cross_entropy between input actions and logits
+ Arguments:
+ - actions (:obj:`torch.Tensor`): the input action tensor
+ - reduction (:obj:`str`): support ['none', 'mean'], default set to 'mean'
+ Return:
+ - cross_entropy (:obj:`torch.Tensor`): the returned cross_entropy loss
+ """
+ neglogp = super().log_prob(actions)
+ assert (reduction in ['none', 'mean'])
+ if reduction == 'none':
+ return neglogp
+ elif reduction == 'mean':
+ return neglogp.mean(dim=0)
+
+ def mode(self) -> torch.Tensor:
+ """
+ Overview:
+ Return logits argmax result
+ Return:
+ - result(:obj:`torch.Tensor`): the logits argmax result
+ """
+ return self.probs.argmax(dim=-1)
+
+ def entropy(self, reduction: str = None) -> torch.Tensor:
+ """
+ Overview:
+ Calculate the softmax entropy of logits
+ Arguments:
+ - reduction (:obj:`str`): support [None, 'mean'], default set to ``None``
+ Returns:
+ - entropy (:obj:`torch.Tensor`): the calculated entropy
+ """
+ entropy = super().entropy()
+ assert (reduction in [None, 'mean'])
+ if reduction is None:
+ return entropy
+ elif reduction == 'mean':
+ return entropy.mean()
diff --git a/DI-engine/ding/torch_utils/loss/__init__.py b/DI-engine/ding/torch_utils/loss/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8668e53b9855546a1e7fb8d97c79ef17ca3f1d3b
--- /dev/null
+++ b/DI-engine/ding/torch_utils/loss/__init__.py
@@ -0,0 +1,3 @@
+from .cross_entropy_loss import LabelSmoothCELoss, SoftFocalLoss, build_ce_criterion
+from .multi_logits_loss import MultiLogitsLoss
+from .contrastive_loss import ContrastiveLoss
diff --git a/DI-engine/ding/torch_utils/loss/contrastive_loss.py b/DI-engine/ding/torch_utils/loss/contrastive_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..94ef62f8de7bdcdfa196d330a5ad0ec599b92db4
--- /dev/null
+++ b/DI-engine/ding/torch_utils/loss/contrastive_loss.py
@@ -0,0 +1,139 @@
+from typing import Union
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ding.utils import SequenceType
+
+
+class ContrastiveLoss(nn.Module):
+ """
+ Overview:
+ The class for contrastive learning losses. Only InfoNCE loss is supported currently. \
+ Code Reference: https://github.com/rdevon/DIM. Paper Reference: https://arxiv.org/abs/1808.06670.
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(
+ self,
+ x_size: Union[int, SequenceType],
+ y_size: Union[int, SequenceType],
+ heads: SequenceType = [1, 1],
+ encode_shape: int = 64,
+ loss_type: str = "infoNCE", # Only the InfoNCE loss is available now.
+ temperature: float = 1.0,
+ ) -> None:
+ """
+ Overview:
+ Initialize the ContrastiveLoss object using the given arguments.
+ Arguments:
+ - x_size (:obj:`Union[int, SequenceType]`): input shape for x, both the obs shape and the encoding shape \
+ are supported.
+ - y_size (:obj:`Union[int, SequenceType]`): Input shape for y, both the obs shape and the encoding shape \
+ are supported.
+ - heads (:obj:`SequenceType`): A list of 2 int elems, ``heads[0]`` for x and ``heads[1]`` for y. \
+ Used in multi-head, global-local, local-local MI maximization process.
+ - encode_shape (:obj:`int`): The dimension of the encoder hidden state.
+ - loss_type (:obj:`str`): Only the InfoNCE loss is available now.
+ - temperature (:obj:`float`): The parameter to adjust the ``log_softmax``.
+ """
+ super(ContrastiveLoss, self).__init__()
+ assert len(heads) == 2, "Expected length of 2, but got: {}".format(len(heads))
+ assert loss_type.lower() in ["infonce"]
+
+ self._type = loss_type.lower()
+ self._encode_shape = encode_shape
+ self._heads = heads
+ self._x_encoder = self._create_encoder(x_size, heads[0])
+ self._y_encoder = self._create_encoder(y_size, heads[1])
+ self._temperature = temperature
+
+ def _create_encoder(self, obs_size: Union[int, SequenceType], heads: int) -> nn.Module:
+ """
+ Overview:
+ Create the encoder for the input obs.
+ Arguments:
+ - obs_size (:obj:`Union[int, SequenceType]`): input shape for x, both the obs shape and the encoding shape \
+ are supported. If the obs_size is an int, it means the obs is a 1D vector. If the obs_size is a list \
+ such as [1, 16, 16], it means the obs is a 3D image with shape [1, 16, 16].
+ - heads (:obj:`int`): The number of heads.
+ Returns:
+ - encoder (:obj:`nn.Module`): The encoder module.
+ Examples:
+ >>> obs_size = 16 # a 1D obs; alternatively obs_size = [1, 16, 16] for a 3D image obs
+ >>> heads = 1
+ >>> encoder = self._create_encoder(obs_size, heads)
+ """
+ from ding.model import ConvEncoder, FCEncoder
+
+ if isinstance(obs_size, int):
+ obs_size = [obs_size]
+ assert len(obs_size) in [1, 3]
+
+ if len(obs_size) == 1:
+ hidden_size_list = [128, 128, self._encode_shape * heads]
+ encoder = FCEncoder(obs_size[0], hidden_size_list)
+ else:
+ hidden_size_list = [32, 64, 64, self._encode_shape * heads]
+ if obs_size[-1] >= 36:
+ encoder = ConvEncoder(obs_size, hidden_size_list)
+ else:
+ encoder = ConvEncoder(obs_size, hidden_size_list, kernel_size=[4, 3, 2], stride=[2, 1, 1])
+ return encoder
+
+ def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Computes the noise contrastive estimation-based loss, a.k.a. infoNCE.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input x, both raw obs and encoding are supported.
+ - y (:obj:`torch.Tensor`): The input y, both raw obs and encoding are supported.
+ Returns:
+ - loss (:obj:`torch.Tensor`): The calculated loss value.
+ Examples:
+ >>> x_dim = [3, 16]
+ >>> encode_shape = 16
+ >>> x = torch.randn(*x_dim)
+ >>> y = x ** 2 + 0.01 * torch.randn(*x_dim)
+ >>> estimator = ContrastiveLoss(x_dim[1], x_dim[1], encode_shape=encode_shape)
+ >>> loss = estimator.forward(x, y)
+ Examples:
+ >>> x_dim = [3, 1, 16, 16]
+ >>> encode_shape = 16
+ >>> x = torch.randn(*x_dim)
+ >>> y = x ** 2 + 0.01 * torch.randn(*x_dim)
+ >>> estimator = ContrastiveLoss(x_dim[1:], x_dim[1:], encode_shape=encode_shape)
+ >>> loss = estimator.forward(x, y)
+ """
+
+ N = x.size(0)
+ x_heads, y_heads = self._heads
+ x = self._x_encoder.forward(x).view(N, x_heads, self._encode_shape)
+ y = self._y_encoder.forward(y).view(N, y_heads, self._encode_shape)
+
+ x_n = x.view(-1, self._encode_shape)
+ y_n = y.view(-1, self._encode_shape)
+
+ # Use inner product to obtain positive samples.
+ # [N, x_heads, encode_dim] * [N, encode_dim, y_heads] -> [N, x_heads, y_heads]
+ u_pos = torch.matmul(x, y.permute(0, 2, 1)).unsqueeze(2)
+ # Use outer product to obtain all sample permutations.
+ # [N * x_heads, encode_dim] X [encode_dim, N * y_heads] -> [N * x_heads, N * y_heads]
+ u_all = torch.mm(y_n, x_n.t()).view(N, y_heads, N, x_heads).permute(0, 2, 3, 1)
+
+ # Mask the diagonal part to obtain the negative samples, with all diagonal entries set to -10.
+ mask = torch.eye(N)[:, :, None, None].to(x.device)
+ n_mask = 1 - mask
+ u_neg = (n_mask * u_all) - (10. * (1 - n_mask))
+ u_neg = u_neg.view(N, N * x_heads, y_heads).unsqueeze(dim=1).expand(-1, x_heads, -1, -1)
+
+ # Concatenate positive and negative samples and apply log softmax.
+ pred_lgt = torch.cat([u_pos, u_neg], dim=2)
+ pred_log = F.log_softmax(pred_lgt * self._temperature, dim=2)
+
+ # The positive score is the first element of the log softmax.
+ loss = -pred_log[:, :, 0, :].mean()
+ return loss
diff --git a/DI-engine/ding/torch_utils/loss/cross_entropy_loss.py b/DI-engine/ding/torch_utils/loss/cross_entropy_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cdcb969d79fcf9c5a8880265693652d1a0c99af
--- /dev/null
+++ b/DI-engine/ding/torch_utils/loss/cross_entropy_loss.py
@@ -0,0 +1,106 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from typing import Any, Optional
+
+
+class LabelSmoothCELoss(nn.Module):
+ """
+ Overview:
+ Label smooth cross entropy loss.
+ Interfaces:
+ ``__init__``, ``forward``.
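+ Examples:
+ >>> # A minimal sketch: 4 samples, 6 classes, smoothing ratio 0.1.
+ >>> criterion = LabelSmoothCELoss(0.1)
+ >>> logits = torch.randn(4, 6)
+ >>> labels = torch.LongTensor([0, 1, 2, 3])
+ >>> loss = criterion(logits, labels)
+ >>> assert loss.shape == ()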
+ """
+
+ def __init__(self, ratio: float) -> None:
+ """
+ Overview:
+ Initialize the LabelSmoothCELoss object using the given arguments.
+ Arguments:
+ - ratio (:obj:`float`): The ratio of label-smoothing (the value is in [0, 1]). If the ratio is larger, the \
+ extent of label smoothing is larger.
+ """
+ super().__init__()
+ self.ratio = ratio
+
+ def forward(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
+ """
+ Overview:
+ Calculate label smooth cross entropy loss.
+ Arguments:
+ - logits (:obj:`torch.Tensor`): Predicted logits.
+ - labels (:obj:`torch.LongTensor`): Ground truth.
+ Returns:
+ - loss (:obj:`torch.Tensor`): Calculated loss.
+ """
+ B, N = logits.shape
+ val = float(self.ratio) / (N - 1)
+ one_hot = torch.full_like(logits, val)
+ one_hot.scatter_(1, labels.unsqueeze(1), 1 - val)
+ logits = F.log_softmax(logits, dim=1)
+ return -torch.sum(logits * (one_hot.detach())) / B
+
+
+class SoftFocalLoss(nn.Module):
+ """
+ Overview:
+ Soft focal loss.
+ Interfaces:
+ ``__init__``, ``forward``.
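+ Examples:
+ >>> # A minimal sketch mirroring the unit test shipped in this diff.
+ >>> criterion = SoftFocalLoss()
+ >>> logits = torch.randn(4, 6)
+ >>> labels = torch.LongTensor([0, 1, 2, 3])
+ >>> loss = criterion(logits, labels)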
+ """
+
+ def __init__(
+ self, gamma: int = 2, weight: Any = None, size_average: bool = True, reduce: Optional[bool] = None
+ ) -> None:
+ """
+ Overview:
+ Initialize the SoftFocalLoss object using the given arguments.
+ Arguments:
+ - gamma (:obj:`int`): The extent of focus on hard samples. A smaller ``gamma`` will lead to more focus on \
+ easy samples, while a larger ``gamma`` will lead to more focus on hard samples.
+ - weight (:obj:`Any`): The weight for loss of each class.
+ - size_average (:obj:`bool`): By default, the losses are averaged over each loss element in the batch. \
+ Note that for some losses, there are multiple elements per sample. If the field ``size_average`` is \
+ set to ``False``, the losses are instead summed for each minibatch. Ignored when ``reduce`` is \
+ ``False``.
+ - reduce (:obj:`Optional[bool]`): By default, the losses are averaged or summed over observations for \
+ each minibatch depending on size_average. When ``reduce`` is ``False``, returns a loss for each batch \
+ element instead and ignores ``size_average``.
+ """
+ super().__init__()
+ self.gamma = gamma
+ self.nll_loss = torch.nn.NLLLoss2d(weight, size_average, reduce=reduce)
+
+ def forward(self, inputs: torch.Tensor, targets: torch.LongTensor) -> torch.Tensor:
+ """
+ Overview:
+ Calculate soft focal loss.
+ Arguments:
+ - inputs (:obj:`torch.Tensor`): Predicted logits.
+ - targets (:obj:`torch.LongTensor`): Ground truth labels.
+ Returns:
+ - loss (:obj:`torch.Tensor`): Calculated loss.
+ """
+ return self.nll_loss((1 - F.softmax(inputs, 1)) ** self.gamma * F.log_softmax(inputs, 1), targets)
+
+
+def build_ce_criterion(cfg: dict) -> nn.Module:
+ """
+ Overview:
+ Get a cross entropy loss instance according to given config.
+ Arguments:
+ - cfg (:obj:`dict`) : Config dict. It contains:
+ - type (:obj:`str`): Type of loss function, now supports ['cross_entropy', 'label_smooth_ce', \
+ 'soft_focal_loss'].
+ - kwargs (:obj:`dict`): Arguments for the corresponding loss function.
+ Returns:
+ - loss (:obj:`nn.Module`): loss function instance
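+ Examples:
+ >>> # A hedged sketch: ``cfg`` is assumed to support attribute access, e.g. an ``easydict.EasyDict``.
+ >>> from easydict import EasyDict
+ >>> cfg = EasyDict({'type': 'label_smooth_ce', 'kwargs': {'smooth_ratio': 0.1}})
+ >>> criterion = build_ce_criterion(cfg)
+ >>> assert isinstance(criterion, LabelSmoothCELoss)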
+ """
+ if cfg.type == 'cross_entropy':
+ return nn.CrossEntropyLoss()
+ elif cfg.type == 'label_smooth_ce':
+ return LabelSmoothCELoss(cfg.kwargs.smooth_ratio)
+ elif cfg.type == 'soft_focal_loss':
+ return SoftFocalLoss()
+ else:
+ raise ValueError("invalid criterion type:{}".format(cfg.type))
diff --git a/DI-engine/ding/torch_utils/loss/multi_logits_loss.py b/DI-engine/ding/torch_utils/loss/multi_logits_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..86716cd632bf259976d95bf07e67471de0f12889
--- /dev/null
+++ b/DI-engine/ding/torch_utils/loss/multi_logits_loss.py
@@ -0,0 +1,164 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ding.torch_utils.network import one_hot
+
+
+class MultiLogitsLoss(nn.Module):
+ """
+ Overview:
+ Loss for multiple logits: each row of logits is matched to a label through a Hungarian-style \
+ assignment on the pairwise cross-entropy costs, and the matched losses are averaged.
+ Interfaces:
+ ``__init__``, ``forward``.
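+ Examples:
+ >>> # A minimal sketch mirroring the unit test in this diff: 4 logit rows matched against 4 labels.
+ >>> criterion = MultiLogitsLoss(criterion='cross_entropy')
+ >>> logits = torch.randn(4, 8)
+ >>> label = torch.LongTensor([0, 1, 3, 2])
+ >>> loss = criterion(logits, label)
+ >>> assert loss.shape == ()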
+ """
+
+ def __init__(self, criterion: str = None, smooth_ratio: float = 0.1) -> None:
+ """
+ Overview:
+ Initialization method, use cross_entropy as default criterion.
+ Arguments:
+ - criterion (:obj:`str`): Criterion type, supports ['cross_entropy', 'label_smooth_ce'].
+ - smooth_ratio (:obj:`float`): Smoothing ratio for label smoothing.
+ """
+ super(MultiLogitsLoss, self).__init__()
+ if criterion is None:
+ criterion = 'cross_entropy'
+ assert (criterion in ['cross_entropy', 'label_smooth_ce'])
+ self.criterion = criterion
+ if self.criterion == 'label_smooth_ce':
+ self.ratio = smooth_ratio
+
+ def _label_process(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.LongTensor:
+ """
+ Overview:
+ Process the label according to the criterion.
+ Arguments:
+ - logits (:obj:`torch.Tensor`): Predicted logits.
+ - labels (:obj:`torch.LongTensor`): Ground truth.
+ Returns:
+ - ret (:obj:`torch.LongTensor`): Processed label.
+ """
+ N = logits.shape[1]
+ if self.criterion == 'cross_entropy':
+ return one_hot(labels, num=N)
+ elif self.criterion == 'label_smooth_ce':
+ val = float(self.ratio) / (N - 1)
+ ret = torch.full_like(logits, val)
+ ret.scatter_(1, labels.unsqueeze(1), 1 - val)
+ return ret
+
+ def _nll_loss(self, nlls: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
+ """
+ Overview:
+ Calculate the negative log likelihood loss.
+ Arguments:
+ - nlls (:obj:`torch.Tensor`): Negative log likelihood loss.
+ - labels (:obj:`torch.LongTensor`): Ground truth.
+ Returns:
+ - ret (:obj:`torch.Tensor`): Calculated loss.
+ """
+ ret = (-nlls * (labels.detach()))
+ return ret.sum(dim=1)
+
+ def _get_metric_matrix(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
+ """
+ Overview:
+ Calculate the metric matrix.
+ Arguments:
+ - logits (:obj:`torch.Tensor`): Predicted logits.
+ - labels (:obj:`torch.LongTensor`): Ground truth.
+ Returns:
+ - metric (:obj:`torch.Tensor`): Calculated metric matrix.
+ """
+ M, N = logits.shape
+ labels = self._label_process(logits, labels)
+ logits = F.log_softmax(logits, dim=1)
+ metric = []
+ for i in range(M):
+ logit = logits[i]
+ logit = logit.repeat(M).reshape(M, N)
+ metric.append(self._nll_loss(logit, labels))
+ return torch.stack(metric, dim=0)
+
+ def _match(self, matrix: torch.Tensor):
+ """
+ Overview:
+ Match the metric matrix.
+ Arguments:
+ - matrix (:obj:`torch.Tensor`): Metric matrix.
+ Returns:
+ - index (:obj:`np.ndarray`): Matched index.
+ """
+ mat = matrix.clone().detach().to('cpu').numpy()
+ mat = -mat # maximize
+ M = mat.shape[0]
+ index = np.full(M, -1, dtype=np.int32) # -1 means no matched link yet
+ lx = mat.max(axis=1)
+ ly = np.zeros(M, dtype=np.float32)
+ visx = np.zeros(M, dtype=np.bool_)
+ visy = np.zeros(M, dtype=np.bool_)
+
+ def has_augmented_path(t, binary_distance_matrix):
+ # What is changed? visx, visy, distance_matrix, index
+ # What is changed within this function? visx, visy, index
+ visx[t] = True
+ for i in range(M):
+ if not visy[i] and binary_distance_matrix[t, i]:
+ visy[i] = True
+ if index[i] == -1 or has_augmented_path(index[i], binary_distance_matrix):
+ index[i] = t
+ return True
+ return False
+
+ for i in range(M):
+ while True:
+ visx.fill(False)
+ visy.fill(False)
+ distance_matrix = self._get_distance_matrix(lx, ly, mat, M)
+ binary_distance_matrix = np.abs(distance_matrix) < 1e-4
+ if has_augmented_path(i, binary_distance_matrix):
+ break
+ masked_distance_matrix = distance_matrix[:, ~visy][visx]
+ if 0 in masked_distance_matrix.shape: # empty matrix
+ raise RuntimeError("match error, matrix: {}".format(matrix))
+ else:
+ d = masked_distance_matrix.min()
+ lx[visx] -= d
+ ly[visy] += d
+ return index
+
+ @staticmethod
+ def _get_distance_matrix(lx: np.ndarray, ly: np.ndarray, mat: np.ndarray, M: int) -> np.ndarray:
+ """
+ Overview:
+ Get distance matrix.
+ Arguments:
+ - lx (:obj:`np.ndarray`): The row potentials (labels) of the matching.
+ - ly (:obj:`np.ndarray`): The column potentials (labels) of the matching.
+ - mat (:obj:`np.ndarray`): The (negated) metric matrix.
+ - M (:obj:`int`): The size of the square matrix.
+ Returns:
+ - nret (:obj:`np.ndarray`): The distance matrix ``lx[:, None] + ly[None, :] - mat``.
+ """
+ nlx = np.broadcast_to(lx, [M, M]).T
+ nly = np.broadcast_to(ly, [M, M])
+ nret = nlx + nly - mat
+ return nret
+
+ def forward(self, logits: torch.Tensor, labels: torch.LongTensor) -> torch.Tensor:
+ """
+ Overview:
+ Calculate multiple logits loss.
+ Arguments:
+ - logits (:obj:`torch.Tensor`): Predicted logits, whose shape must be 2-dim, like (B, N).
+ - labels (:obj:`torch.LongTensor`): Ground truth.
+ Returns:
+ - loss (:obj:`torch.Tensor`): Calculated loss.
+ """
+ assert (len(logits.shape) == 2)
+ metric_matrix = self._get_metric_matrix(logits, labels)
+ index = self._match(metric_matrix)
+ loss = []
+ for i in range(metric_matrix.shape[0]):
+ loss.append(metric_matrix[index[i], i])
+ return sum(loss) / len(loss)
diff --git a/DI-engine/ding/torch_utils/loss/tests/test_contrastive_loss.py b/DI-engine/ding/torch_utils/loss/tests/test_contrastive_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..41506bc476059fd8d4e00211fab182bf33d01e49
--- /dev/null
+++ b/DI-engine/ding/torch_utils/loss/tests/test_contrastive_loss.py
@@ -0,0 +1,51 @@
+import pytest
+import numpy as np
+import torch
+from torch.utils.data import TensorDataset, DataLoader
+from ding.torch_utils.loss.contrastive_loss import ContrastiveLoss
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('noise', [0.1, 1.0])
+@pytest.mark.parametrize('dims', [16, [1, 16, 16], [1, 40, 40]])
+def test_infonce_loss(noise, dims):
+ print_loss = False
+ batch_size = 128
+ N_batch = 3
+ if isinstance(dims, int):
+ x_dim = [batch_size * N_batch, dims]
+ else:
+ x_dim = [batch_size * N_batch] + dims
+
+ encode_shape = 16
+ x = np.random.normal(0, 1, size=x_dim)
+ y = x ** 2 + noise * np.random.normal(0, 1, size=x_dim)
+
+ estimator = ContrastiveLoss(dims, dims, encode_shape=encode_shape)
+ dataset = TensorDataset(torch.Tensor(x), torch.Tensor(y))
+ dataloader = DataLoader(dataset, batch_size=batch_size)
+ optimizer = torch.optim.Adam(estimator.parameters(), lr=3e-4)
+
+ for epoch in range(3):
+ train_loss = 0.
+ test_loss = 0.
+ for inputs in dataloader:
+ x, y = inputs
+ optimizer.zero_grad()
+ loss = estimator.forward(x, y)
+ loss.backward()
+ optimizer.step()
+ train_loss += loss.item()
+
+ with torch.no_grad():
+ for inputs in dataloader:
+ x, y = inputs
+ outputs = estimator.forward(x, y)
+ test_loss += outputs.item()
+
+ if print_loss:
+ print(
+ "epoch {}: test_loss: {:.4f}, \t test_loss: {:.4f}".format(
+ epoch, train_loss / N_batch, test_loss / N_batch
+ )
+ )
diff --git a/DI-engine/ding/torch_utils/loss/tests/test_cross_entropy_loss.py b/DI-engine/ding/torch_utils/loss/tests/test_cross_entropy_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..e311ffdb9b33cf7aecca877f3642cebe7067c884
--- /dev/null
+++ b/DI-engine/ding/torch_utils/loss/tests/test_cross_entropy_loss.py
@@ -0,0 +1,28 @@
+import pytest
+import torch
+import torch.nn as nn
+
+from ding.torch_utils import LabelSmoothCELoss, SoftFocalLoss
+
+
+@pytest.mark.unittest
+class TestLabelSmoothCE:
+
+ def test_label_smooth_ce_loss(self):
+ logits = torch.randn(4, 6)
+ labels = torch.LongTensor([i for i in range(4)])
+ criterion1 = LabelSmoothCELoss(0)
+ criterion2 = nn.CrossEntropyLoss()
+ assert (torch.abs(criterion1(logits, labels) - criterion2(logits, labels)) < 1e-6)
+
+
+@pytest.mark.unittest
+class TestSoftFocalLoss:
+
+ def test_soft_focal_loss(self):
+ logits = torch.randn(4, 6)
+ labels = torch.LongTensor([i for i in range(4)])
+ criterion = SoftFocalLoss()
+ loss = criterion(logits, labels)
+ assert loss.shape == ()
+ loss_value = loss.item()
diff --git a/DI-engine/ding/torch_utils/loss/tests/test_multi_logits_loss.py b/DI-engine/ding/torch_utils/loss/tests/test_multi_logits_loss.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a6a983f6870063ef715de6be50b2bbc4ff35a2e
--- /dev/null
+++ b/DI-engine/ding/torch_utils/loss/tests/test_multi_logits_loss.py
@@ -0,0 +1,16 @@
+import pytest
+import torch
+from ding.torch_utils import MultiLogitsLoss
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('criterion_type', ['cross_entropy', 'label_smooth_ce'])
+def test_multi_logits_loss(criterion_type):
+ logits = torch.randn(4, 8).requires_grad_(True)
+ label = torch.LongTensor([0, 1, 3, 2])
+ criterion = MultiLogitsLoss(criterion=criterion_type)
+ loss = criterion(logits, label)
+ assert loss.shape == ()
+ assert logits.grad is None
+ loss.backward()
+ assert isinstance(logits.grad, torch.Tensor)
diff --git a/DI-engine/ding/torch_utils/lr_scheduler.py b/DI-engine/ding/torch_utils/lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..09d341430a3542a46b38cf06a002c7dde53c5418
--- /dev/null
+++ b/DI-engine/ding/torch_utils/lr_scheduler.py
@@ -0,0 +1,60 @@
+from functools import partial
+import math
+
+import torch.optim
+from torch.optim.lr_scheduler import LambdaLR
+
+
+def get_lr_ratio(epoch: int, warmup_epochs: int, learning_rate: float, lr_decay_epochs: int, min_lr: float) -> float:
+ """
+ Overview:
+ Get learning rate ratio for each epoch.
+ Arguments:
+ - epoch (:obj:`int`): Current epoch.
+ - warmup_epochs (:obj:`int`): Warmup epochs.
+ - learning_rate (:obj:`float`): Learning rate.
+ - lr_decay_epochs (:obj:`int`): Learning rate decay epochs.
+ - min_lr (:obj:`float`): Minimum learning rate.
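+ Examples:
+ >>> # The ratio grows linearly during warmup and is clipped at min_lr / learning_rate after decay.
+ >>> ratio = get_lr_ratio(epoch=0, warmup_epochs=5, learning_rate=1e-3, lr_decay_epochs=100, min_lr=6e-5)
+ >>> assert ratio == 0.0
+ >>> ratio = get_lr_ratio(epoch=200, warmup_epochs=5, learning_rate=1e-3, lr_decay_epochs=100, min_lr=6e-5)
+ >>> assert abs(ratio - 0.06) < 1e-9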
+ """
+
+ # 1) linear warmup for warmup_epochs.
+ if epoch < warmup_epochs:
+ return epoch / warmup_epochs
+ # 2) if epoch > lr_decay_epochs, return the minimum learning rate ratio
+ if epoch > lr_decay_epochs:
+ return min_lr / learning_rate
+ # 3) in between, use cosine decay down to min learning rate
+ decay_ratio = (epoch - warmup_epochs) / (lr_decay_epochs - warmup_epochs)
+ assert 0 <= decay_ratio <= 1
+ coefficient = 0.5 * (1.0 + math.cos(math.pi * decay_ratio))
+ return (min_lr + coefficient * (learning_rate - min_lr)) / learning_rate
+
+
+def cos_lr_scheduler(
+ optimizer: torch.optim.Optimizer,
+ learning_rate: float,
+ warmup_epochs: float = 5,
+ lr_decay_epochs: float = 100,
+ min_lr: float = 6e-5
+) -> torch.optim.lr_scheduler.LambdaLR:
+ """
+ Overview:
+ Cosine learning rate scheduler.
+ Arguments:
+ - optimizer (:obj:`torch.optim.Optimizer`): Optimizer.
+ - learning_rate (:obj:`float`): Learning rate.
+ - warmup_epochs (:obj:`float`): Warmup epochs.
+ - lr_decay_epochs (:obj:`float`): Learning rate decay epochs.
+ - min_lr (:obj:`float`): Minimum learning rate.
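+ Examples:
+ >>> # A minimal sketch: build the scheduler and step it once per epoch.
+ >>> net = torch.nn.Linear(3, 5)
+ >>> optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)
+ >>> scheduler = cos_lr_scheduler(optimizer, learning_rate=1e-3)
+ >>> scheduler.step() # epoch 1: still inside the linear warmup stage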
+ """
+
+ return LambdaLR(
+ optimizer,
+ partial(
+ get_lr_ratio,
+ warmup_epochs=warmup_epochs,
+ lr_decay_epochs=lr_decay_epochs,
+ min_lr=min_lr,
+ learning_rate=learning_rate
+ )
+ )
diff --git a/DI-engine/ding/torch_utils/math_helper.py b/DI-engine/ding/torch_utils/math_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..751c8075979ed11a037614ce9537b0830e59a307
--- /dev/null
+++ b/DI-engine/ding/torch_utils/math_helper.py
@@ -0,0 +1,76 @@
+from typing import Optional
+import torch
+
+
+def cov(
+ x: torch.Tensor,
+ rowvar: bool = False,
+ bias: bool = False,
+ ddof: Optional[int] = None,
+ aweights: Optional[torch.Tensor] = None
+) -> torch.Tensor:
+ """
+ Overview:
+ Estimates covariance matrix like ``numpy.cov``.
+ Arguments:
+ - x (:obj:`torch.Tensor`): A 1-D or 2-D tensor containing multiple variables and observations. By default \
+ (``rowvar=False``), each row of ``x`` is a single observation and each column a variable.
+ - rowvar (:obj:`bool`): If ``rowvar`` is True, each row represents a variable, with observations in the \
+ columns. Otherwise (the default), each column represents a variable, while the rows contain observations.
+ - bias (:obj:`bool`): Default normalization (False) is by dividing ``N - 1``, where ``N`` is the number of \
+ observations given (unbiased estimate). If ``bias`` is ``True``, then normalization is by ``N``.
+ - ddof (:obj:`Optional[int]`): If ``ddof`` is not ``None``, it implies that the argument ``bias`` is \
+ overridden. Note that ``ddof=1`` will return the unbiased estimate (equals to ``bias=False``), and \
+ ``ddof=0`` will return the biased estimation (equals to ``bias=True``).
+ - aweights (:obj:`Optional[torch.Tensor]`): 1-D tensor of observation vector weights. These relative weights \
+ are typically large for observations considered “important” and smaller for observations considered less \
+ “important”. If ``ddof=0``, the tensor of weights can be used to assign weights to observation vectors.
+ Returns:
+ - cov_mat (:obj:`torch.Tensor`): Covariance matrix calculated.
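+ Examples:
+ >>> # A minimal sketch: 100 observations of 3 variables (rows are observations by default).
+ >>> x = torch.randn(100, 3)
+ >>> c = cov(x)
+ >>> assert c.shape == (3, 3)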
+ """
+ if x.dim() == 1 and rowvar:
+ raise NotImplementedError
+ # ensure at least 2D
+ if x.dim() == 1:
+ x = x.view(-1, 1)
+
+ # treat each column as a data point, each row as a variable
+ if rowvar and x.shape[0] != 1:
+ x = x.t()
+
+ if ddof is None:
+ if bias == 0:
+ ddof = 1
+ else:
+ ddof = 0
+
+ w = aweights
+ if w is not None:
+ if not torch.is_tensor(w):
+ w = torch.tensor(w, dtype=torch.float)
+ w_sum = torch.sum(w)
+ avg = torch.sum(x * (w / w_sum)[:, None], 0)
+ else:
+ avg = torch.mean(x, 0)
+
+ # Determine the normalization
+ if w is None:
+ fact = x.shape[0] - ddof
+ elif ddof == 0:
+ fact = w_sum
+ # elif aweights is None:
+ # fact = w_sum - ddof
+ else:
+ fact = w_sum - ddof * torch.sum(w * w) / w_sum
+
+ xm = x.sub(avg.expand_as(x))
+
+ if w is None:
+ X_T = xm.t()
+ else:
+ X_T = torch.mm(torch.diag(w), xm).t()
+
+ c = torch.mm(X_T, xm)
+ c = c / fact
+
+ return c.squeeze()
diff --git a/DI-engine/ding/torch_utils/metric.py b/DI-engine/ding/torch_utils/metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..75554c211da72f79d4f00b6927951a88d6fc1eb9
--- /dev/null
+++ b/DI-engine/ding/torch_utils/metric.py
@@ -0,0 +1,80 @@
+import torch
+from typing import Optional, Callable
+
+
+def levenshtein_distance(
+ pred: torch.LongTensor,
+ target: torch.LongTensor,
+ pred_extra: Optional[torch.Tensor] = None,
+ target_extra: Optional[torch.Tensor] = None,
+ extra_fn: Optional[Callable] = None
+) -> torch.FloatTensor:
+ """
+ Overview:
+ Levenshtein Distance, i.e. Edit Distance.
+ Arguments:
+ - pred (:obj:`torch.LongTensor`): The first tensor to calculate the distance, shape: (N1, ) (N1 >= 0).
+ - target (:obj:`torch.LongTensor`): The second tensor to calculate the distance, shape: (N2, ) (N2 >= 0).
+ - pred_extra (:obj:`Optional[torch.Tensor]`): Extra tensor to calculate the distance, only works when \
+ ``extra_fn`` is not ``None``.
+ - target_extra (:obj:`Optional[torch.Tensor]`): Extra tensor to calculate the distance, only works when \
+ ``extra_fn`` is not ``None``.
+ - extra_fn (:obj:`Optional[Callable]`): The distance function for ``pred_extra`` and \
+ ``target_extra``. If set to ``None``, this distance will not be considered.
+ Returns:
+ - distance (:obj:`torch.FloatTensor`): distance(scalar), shape: (1, ).
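+ Examples:
+ >>> # A minimal sketch: one substitution turns ``pred`` into ``target``.
+ >>> pred = torch.LongTensor([1, 2, 3])
+ >>> target = torch.LongTensor([1, 2, 4])
+ >>> distance = levenshtein_distance(pred, target)
+ >>> assert distance.item() == 1.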
+ """
+ assert (isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor))
+ assert (pred.dtype == torch.long and target.dtype == torch.long), '{}\t{}'.format(pred.dtype, target.dtype)
+ assert (pred.device == target.device)
+ assert (type(pred_extra) == type(target_extra))
+ if not extra_fn:
+ assert (not pred_extra)
+ N1, N2 = pred.shape[0], target.shape[0]
+ assert (N1 >= 0 and N2 >= 0)
+ if N1 == 0 or N2 == 0:
+ distance = max(N1, N2)
+ else:
+ dp_array = torch.zeros(N1, N2).float()
+ if extra_fn:
+ if pred[0] == target[0]:
+ extra = extra_fn(pred_extra[0], target_extra[0])
+ else:
+ extra = 1.
+ dp_array[0, :] = torch.arange(0, N2) + extra
+ dp_array[:, 0] = torch.arange(0, N1) + extra
+ else:
+ dp_array[0, :] = torch.arange(0, N2)
+ dp_array[:, 0] = torch.arange(0, N1)
+ for i in range(1, N1):
+ for j in range(1, N2):
+ if pred[i] == target[j]:
+ if extra_fn:
+ dp_array[i, j] = dp_array[i - 1, j - 1] + extra_fn(pred_extra[i], target_extra[j])
+ else:
+ dp_array[i, j] = dp_array[i - 1, j - 1]
+ else:
+ dp_array[i, j] = min(dp_array[i - 1, j] + 1, dp_array[i, j - 1] + 1, dp_array[i - 1, j - 1] + 1)
+ distance = dp_array[N1 - 1, N2 - 1]
+ return torch.FloatTensor([distance]).to(pred.device)
+
+
+def hamming_distance(pred: torch.LongTensor, target: torch.LongTensor, weight=1.) -> torch.FloatTensor:
+ """
+ Overview:
+ Hamming Distance.
+ Arguments:
+ - pred (:obj:`torch.LongTensor`): Pred input, a boolean vector (0 or 1).
+ - target (:obj:`torch.LongTensor`): Target input, a boolean vector (0 or 1).
+ - weight (:obj:`float`): The scalar weight to multiply the distance by.
+ Returns:
+ - distance (:obj:`torch.FloatTensor`): The distance per batch sample, shape :math:`(B, )`.
+ Shapes:
+ - pred & target (:obj:`torch.LongTensor`): shape :math:`(B, N)`, \
+ where B is the batch size and N is the dimension.
+ """
+ assert (isinstance(pred, torch.Tensor) and isinstance(target, torch.Tensor))
+ assert (pred.dtype == torch.long and target.dtype == torch.long)
+ assert (pred.device == target.device)
+ assert (pred.shape == target.shape)
+ return pred.ne(target).sum(dim=1).float().mul_(weight)
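+
+# Illustrative usage (a sketch, not part of the original file); the inputs below are examples only.
+# >>> levenshtein_distance(torch.LongTensor([1, 2, 3]), torch.LongTensor([1, 3]))
+# tensor([1.])                            # one deletion turns [1, 2, 3] into [1, 3]
+# >>> pred = torch.LongTensor([[0, 1, 1, 0]])
+# >>> target = torch.LongTensor([[0, 1, 0, 1]])
+# >>> hamming_distance(pred, target)      # two mismatching positions in the single batch entry
+# tensor([2.])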
diff --git a/DI-engine/ding/torch_utils/model_helper.py b/DI-engine/ding/torch_utils/model_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..05a83ab773c9eeaaf351d2cc6dd2ab25d9144775
--- /dev/null
+++ b/DI-engine/ding/torch_utils/model_helper.py
@@ -0,0 +1,17 @@
+import torch
+
+
+def get_num_params(model: torch.nn.Module) -> int:
+ """
+ Overview:
+ Return the number of parameters in the model.
+ Arguments:
+ - model (:obj:`torch.nn.Module`): The model object to calculate the parameter number.
+ Returns:
+ - n_params (:obj:`int`): The calculated number of parameters.
+ Examples:
+ >>> model = torch.nn.Linear(3, 5)
+ >>> num = get_num_params(model)
+ >>> assert num == 15
+ """
+ return sum(p.numel() for p in model.parameters())
diff --git a/DI-engine/ding/torch_utils/network/__init__.py b/DI-engine/ding/torch_utils/network/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..dda50c339e54f838e3c6aef195d73ec180e55f72
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/__init__.py
@@ -0,0 +1,15 @@
+from .activation import build_activation, Swish
+from .res_block import ResBlock, ResFCBlock
+from .nn_module import fc_block, conv2d_block, one_hot, deconv2d_block, BilinearUpsample, NearestUpsample, \
+ binary_encode, NoiseLinearLayer, noise_block, MLP, Flatten, normed_linear, normed_conv2d, conv1d_block
+from .normalization import build_normalization
+from .rnn import get_lstm, sequence_mask
+from .soft_argmax import SoftArgmax
+from .transformer import Transformer, ScaledDotProductAttention
+from .scatter_connection import ScatterConnection
+from .resnet import resnet18, ResNet
+from .gumbel_softmax import GumbelSoftmax
+from .gtrxl import GTrXL, GRUGatingUnit
+from .popart import PopArt
+#from .dreamer import Conv2dSame, DreamerLayerNorm, ActionHead, DenseHead
+from .merge import GatingType, SumMerge, VectorMerge
diff --git a/DI-engine/ding/torch_utils/network/activation.py b/DI-engine/ding/torch_utils/network/activation.py
new file mode 100644
index 0000000000000000000000000000000000000000..b3c8fcda4c84fc42516897a53e24ece441efd128
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/activation.py
@@ -0,0 +1,168 @@
+import math
+from collections.abc import Callable
+
+import torch
+import torch.nn as nn
+
+
+class Lambda(nn.Module):
+ """
+ Overview:
+ A custom lambda module for constructing custom layers.
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(self, f: Callable):
+ """
+ Overview:
+ Initialize the lambda module with a given function.
+ Arguments:
+ - f (:obj:`Callable`): a python function
+ """
+ super(Lambda, self).__init__()
+ self.f = f
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Compute the function of the input tensor.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor.
+ """
+ return self.f(x)
+
+
+class GLU(nn.Module):
+ """
+ Overview:
+ Gated Linear Unit (GLU), a specific type of activation function, which is first proposed in
+ [Language Modeling with Gated Convolutional Networks](https://arxiv.org/pdf/1612.08083.pdf).
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(self, input_dim: int, output_dim: int, context_dim: int, input_type: str = 'fc') -> None:
+ """
+ Overview:
+ Initialize the GLU module.
+ Arguments:
+ - input_dim (:obj:`int`): The dimension of the input tensor.
+ - output_dim (:obj:`int`): The dimension of the output tensor.
+ - context_dim (:obj:`int`): The dimension of the context tensor.
+ - input_type (:obj:`str`): The type of input, now supports ['fc', 'conv2d']
+ """
+ super(GLU, self).__init__()
+ assert (input_type in ['fc', 'conv2d'])
+ if input_type == 'fc':
+ self.layer1 = nn.Linear(context_dim, input_dim)
+ self.layer2 = nn.Linear(input_dim, output_dim)
+ elif input_type == 'conv2d':
+ self.layer1 = nn.Conv2d(context_dim, input_dim, 1, 1, 0)
+ self.layer2 = nn.Conv2d(input_dim, output_dim, 1, 1, 0)
+
+ def forward(self, x: torch.Tensor, context: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Compute the GLU transformation of the input tensor.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor.
+ - context (:obj:`torch.Tensor`): The context tensor.
+ Returns:
+ - x (:obj:`torch.Tensor`): The output tensor after GLU transformation.
+ """
+ gate = self.layer1(context)
+ gate = torch.sigmoid(gate)
+ x = gate * x
+ x = self.layer2(x)
+ return x
+
+
+class Swish(nn.Module):
+ """
+ Overview:
+ Swish activation function, which is a smooth, non-monotonic activation function. For more details, please refer
+ to [Searching for Activation Functions](https://arxiv.org/pdf/1710.05941.pdf).
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(self):
+ """
+ Overview:
+ Initialize the Swish module.
+ """
+ super(Swish, self).__init__()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Compute the Swish transformation of the input tensor.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor.
+ Returns:
+ - x (:obj:`torch.Tensor`): The output tensor after Swish transformation.
+ """
+ return x * torch.sigmoid(x)
+
+
+class GELU(nn.Module):
+ """
+ Overview:
+ Gaussian Error Linear Units (GELU) activation function, which is widely used in NLP models like GPT, BERT.
+ For more details, please refer to the original paper: https://arxiv.org/pdf/1606.08415.pdf.
+ Interfaces:
+ ``__init__``, ``forward``.
+ """
+
+ def __init__(self):
+ """
+ Overview:
+ Initialize the GELU module.
+ """
+ super(GELU, self).__init__()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Compute the GELU transformation of the input tensor.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor.
+ Returns:
+ - x (:obj:`torch.Tensor`): The output tensor after GELU transformation.
+ """
+ return 0.5 * x * (1.0 + torch.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * torch.pow(x, 3.0))))
+
+
+def build_activation(activation: str, inplace: bool = None) -> nn.Module:
+ """
+ Overview:
+ Build and return the activation module according to the given type.
+ Arguments:
+ - activation (:obj:`str`): The type of activation module, now supports \
+ ['relu', 'glu', 'prelu', 'swish', 'gelu', 'tanh', 'sigmoid', 'softplus', 'elu', 'square', 'identity'].
+ - inplace (:obj:`Optional[bool]`): Whether to execute the operation in-place in the activation, \
+ defaults to None (treated as False). Currently only 'relu' supports the inplace argument.
+ Returns:
+ - act_func (:obj:`nn.Module`): The corresponding activation module.
+ """
+ if inplace is not None:
+ assert activation == 'relu', 'inplace argument is not compatible with {}'.format(activation)
+ else:
+ inplace = False
+ act_func = {
+ 'relu': nn.ReLU(inplace=inplace),
+ 'glu': GLU,
+ 'prelu': nn.PReLU(),
+ 'swish': Swish(),
+ 'gelu': GELU(),
+ "tanh": nn.Tanh(),
+ "sigmoid": nn.Sigmoid(),
+ "softplus": nn.Softplus(),
+ "elu": nn.ELU(),
+ "square": Lambda(lambda x: x ** 2),
+ "identity": Lambda(lambda x: x),
+ }
+ if activation.lower() in act_func.keys():
+ return act_func[activation]
+ else:
+ raise KeyError("invalid key for activation: {}".format(activation))
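+
+# Illustrative usage (a sketch, not part of the original file): note that 'glu' maps to the GLU class
+# itself (it needs extra constructor arguments), while the other keys map to ready-to-use modules.
+# >>> act = build_activation('relu', inplace=True)
+# >>> isinstance(act, nn.ReLU)
+# True
+# >>> build_activation('swish')(torch.zeros(2, 3)).shape
+# torch.Size([2, 3])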
diff --git a/DI-engine/ding/torch_utils/network/diffusion.py b/DI-engine/ding/torch_utils/network/diffusion.py
new file mode 100755
index 0000000000000000000000000000000000000000..deb95c9022f0b603fa6968d5baa09c409748421b
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/diffusion.py
@@ -0,0 +1,661 @@
+from typing import Union, List, Dict
+from collections import namedtuple
+import numpy as np
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from ding.utils import list_split, MODEL_REGISTRY, squeeze, SequenceType
+
+
+def extract(a, t, x_shape):
+ """
+ Overview:
+ Extract the values of ``a`` at the indices given by ``t``, then reshape them so that they can be \
+ broadcast against a tensor of shape ``x_shape``.
+ Arguments:
+ - a (:obj:`torch.Tensor`): The input tensor to gather from.
+ - t (:obj:`torch.Tensor`): The index tensor.
+ - x_shape (:obj:`tuple`): The target shape to broadcast against.
+ """
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1, ) * (len(x_shape) - 1)))
+
+
+def cosine_beta_schedule(timesteps: int, s: float = 0.008, dtype=torch.float32):
+ """
+ Overview:
+ Cosine beta schedule, as proposed in https://openreview.net/forum?id=-NEXDKk8gZ.
+ Arguments:
+ - timesteps (:obj:`int`): The number of diffusion timesteps.
+ - s (:obj:`float`): A small offset that prevents beta from being too small near t = 0.
+ - dtype (:obj:`torch.dtype`): The dtype of the returned beta tensor.
+ Returns:
+ - betas (:obj:`torch.Tensor`): The beta schedule, a tensor of shape (timesteps, ) computed by the cosine rule.
+ """
+ steps = timesteps + 1
+ x = np.linspace(0, steps, steps)
+ alphas_cumprod = np.cos(((x / steps) + s) / (1 + s) * np.pi * 0.5) ** 2
+ alphas_cumprod = alphas_cumprod / alphas_cumprod[0]
+ betas = 1 - (alphas_cumprod[1:] / alphas_cumprod[:-1])
+ betas_clipped = np.clip(betas, a_min=0, a_max=0.999)
+ return torch.tensor(betas_clipped, dtype=dtype)
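+
+# Illustrative usage (a sketch, not part of the original file); the number of timesteps is an example only.
+# >>> betas = cosine_beta_schedule(1000)
+# >>> betas.shape                         # one beta per diffusion timestep, clipped at 0.999
+# torch.Size([1000])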
+
+
+def apply_conditioning(x, conditions, action_dim):
+ """
+ Overview:
+ add condition into x
+ Arguments:
+ - x (:obj:`torch.Tensor`): input tensor
+ - conditions (:obj:`dict`): condition dict, key is timestep, value is condition
+ - action_dim (:obj:`int`): action dim
+ """
+ for t, val in conditions.items():
+ x[:, t, action_dim:] = val.clone()
+ return x
+
+
+class DiffusionConv1d(nn.Module):
+ """
+ Overview:
+ Conv1d with activation and normalization for diffusion models.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ padding: int,
+ activation: nn.Module = None,
+ n_groups: int = 8
+ ) -> None:
+ """
+ Overview:
+ Create a 1-dim convolution layer with activation and normalization. This Conv1d uses GroupNorm, \
+ so the input is temporarily expanded by one dimension when computing the normalization.
+ Arguments:
+ - in_channels (:obj:`int`): Number of channels in the input tensor
+ - out_channels (:obj:`int`): Number of channels in the output tensor
+ - kernel_size (:obj:`int`): Size of the convolving kernel
+ - padding (:obj:`int`): Zero-padding added to both sides of the input
+ - activation (:obj:`nn.Module`): The optional activation function.
+ - n_groups (:obj:`int`): The number of groups used by the GroupNorm layer.
+ """
+ super().__init__()
+ self.conv1 = nn.Conv1d(in_channels, out_channels, kernel_size=kernel_size, padding=padding)
+ self.norm = nn.GroupNorm(n_groups, out_channels)
+ self.act = activation
+
+ def forward(self, inputs) -> torch.Tensor:
+ """
+ Overview:
+ Compute the 1D convolution of the input tensor.
+ Arguments:
+ - inputs (:obj:`torch.Tensor`): The input tensor.
+ Returns:
+ - out (:obj:`torch.Tensor`): The output tensor.
+ """
+ x = self.conv1(inputs)
+ # [batch, channels, horizon] -> [batch, channels, 1, horizon]
+ x = x.unsqueeze(-2)
+ x = self.norm(x)
+ # [batch, channels, 1, horizon] -> [batch, channels, horizon]
+ x = x.squeeze(-2)
+ out = self.act(x)
+ return out
+
+
+class SinusoidalPosEmb(nn.Module):
+ """
+ Overview:
+ Class for computing sinusoidal position embeddings.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(self, dim: int) -> None:
+ """
+ Overview:
+ Initialization of the SinusoidalPosEmb class.
+ Arguments:
+ - dim (:obj:`int`): The dimension of the embedding.
+ """
+
+ super().__init__()
+ self.dim = dim
+
+ def forward(self, x) -> torch.Tensor:
+ """
+ Overview:
+ Compute the sinusoidal position embedding.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor.
+ Returns:
+ - emb (:obj:`torch.Tensor`): The output embedding tensor.
+ """
+
+ device = x.device
+ half_dim = self.dim // 2
+ emb = math.log(10000) / (half_dim - 1)
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
+ emb = x[:, None] * emb[None, :]
+ emb = torch.cat((emb.sin(), emb.cos()), dim=1)
+ return emb
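+
+# Illustrative usage (a sketch, not part of the original file); the shapes are examples only.
+# >>> emb = SinusoidalPosEmb(32)
+# >>> t = torch.arange(16).float()        # e.g. a batch of 16 diffusion timesteps
+# >>> emb(t).shape                        # half sin features, half cos features
+# torch.Size([16, 32])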
+
+
+class Residual(nn.Module):
+ """
+ Overview:
+ Basic Residual block
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(self, fn):
+ """
+ Overview:
+ Initialization of Residual class
+ Arguments:
+ - fn (:obj:`nn.Module`): function of residual block
+ """
+
+ super().__init__()
+ self.fn = fn
+
+ def forward(self, x, *arg, **kwargs):
+ """
+ Overview:
+ compute residual block
+ Arguments:
+ - x (:obj:`torch.Tensor`): input tensor
+ """
+
+ return self.fn(x, *arg, **kwargs) + x
+
+
+class LayerNorm(nn.Module):
+ """
+ Overview:
+ LayerNorm that normalizes over dim = 1, because the temporal input x has shape [batch, dim, horizon].
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(self, dim, eps=1e-5) -> None:
+ """
+ Overview:
+ Initialization of LayerNorm class
+ Arguments:
+ - dim (:obj:`int`): dimension of input
+ - eps (:obj:`float`): eps of LayerNorm
+ """
+
+ super().__init__()
+ self.eps = eps
+ self.g = nn.Parameter(torch.ones(1, dim, 1))
+ self.b = nn.Parameter(torch.zeros(1, dim, 1))
+
+ def forward(self, x):
+ """
+ Overview:
+ compute LayerNorm
+ Arguments:
+ - x (:obj:`torch.Tensor`): input tensor
+ """
+
+ var = torch.var(x, dim=1, unbiased=False, keepdim=True)
+ mean = torch.mean(x, dim=1, keepdim=True)
+ return (x - mean) / (var + self.eps).sqrt() * self.g + self.b
+
+
+class PreNorm(nn.Module):
+ """
+ Overview:
+ PreNorm that applies LayerNorm over dim = 1 before the wrapped function, because the temporal input x \
+ has shape [batch, dim, horizon].
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(self, dim, fn) -> None:
+ """
+ Overview:
+ Initialization of PreNorm class
+ Arguments:
+ - dim (:obj:`int`): dimension of input
+ - fn (:obj:`nn.Module`): function of residual block
+ """
+
+ super().__init__()
+ self.fn = fn
+ self.norm = LayerNorm(dim)
+
+ def forward(self, x):
+ """
+ Overview:
+ compute PreNorm
+ Arguments:
+ - x (:obj:`torch.Tensor`): input tensor
+ """
+ x = self.norm(x)
+ return self.fn(x)
+
+
+class LinearAttention(nn.Module):
+ """
+ Overview:
+ Linear Attention head
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(self, dim, heads=4, dim_head=32) -> None:
+ """
+ Overview:
+ Initialization of LinearAttention class
+ Arguments:
+ - dim (:obj:`int`): dimension of input
+ - heads (:obj:`int`): heads of attention
+ - dim_head (:obj:`int`): dim of head
+ """
+ super().__init__()
+ self.scale = dim_head ** -0.5
+ self.heads = heads
+ hidden_dim = dim_head * heads
+ self.to_qkv = nn.Conv1d(dim, hidden_dim * 3, 1, bias=False)
+ self.to_out = nn.Conv1d(hidden_dim, dim, 1)
+
+ def forward(self, x):
+ """
+ Overview:
+ compute LinearAttention
+ Arguments:
+ - x (:obj:`torch.Tensor`): input tensor
+ """
+ qkv = self.to_qkv(x).chunk(3, dim=1)
+ q, k, v = map(lambda t: t.reshape(t.shape[0], self.heads, -1, t.shape[-1]), qkv)
+ q = q * self.scale
+ k = k.softmax(dim=-1)
+ context = torch.einsum('b h d n, b h e n -> b h d e', k, v)
+
+ out = torch.einsum('b h d e, b h d n -> b h e n', context, q)
+ out = out.reshape(out.shape[0], -1, out.shape[-1])
+ return self.to_out(out)
+
+
+class ResidualTemporalBlock(nn.Module):
+ """
+ Overview:
+ Residual block of temporal
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(
+ self, in_channels: int, out_channels: int, embed_dim: int, kernel_size: int = 5, mish: bool = True
+ ) -> None:
+ """
+ Overview:
+ Initialization of ResidualTemporalBlock class
+ Arguments:
+ - in_channels (:obj:`int`): The number of input channels.
+ - out_channels (:obj:`int`): The number of output channels.
+ - embed_dim (:obj:`int`): The dimension of the embedding layer.
+ - kernel_size (:obj:`int`): The kernel size of the conv1d layers.
+ - mish (:obj:`bool`): Whether to use Mish (instead of SiLU) as the activation function.
+ """
+ super().__init__()
+ if mish:
+ act = nn.Mish()
+ else:
+ act = nn.SiLU()
+ self.blocks = nn.ModuleList(
+ [
+ DiffusionConv1d(in_channels, out_channels, kernel_size, kernel_size // 2, act),
+ DiffusionConv1d(out_channels, out_channels, kernel_size, kernel_size // 2, act),
+ ]
+ )
+ self.time_mlp = nn.Sequential(
+ act,
+ nn.Linear(embed_dim, out_channels),
+ )
+ self.residual_conv = nn.Conv1d(in_channels, out_channels, 1) \
+ if in_channels != out_channels else nn.Identity()
+
+ def forward(self, x, t):
+ """
+ Overview:
+ Compute the residual block.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor.
+ - t (:obj:`torch.Tensor`): The time embedding tensor.
+ """
+ out = self.blocks[0](x) + self.time_mlp(t).unsqueeze(-1)
+ out = self.blocks[1](out)
+ return out + self.residual_conv(x)
+
+
+class DiffusionUNet1d(nn.Module):
+ """
+ Overview:
+ Diffusion unet for 1d vector data
+ Interfaces:
+ ``__init__``, ``forward``, ``get_pred``
+ """
+
+ def __init__(
+ self,
+ transition_dim: int,
+ dim: int = 32,
+ dim_mults: SequenceType = [1, 2, 4, 8],
+ returns_condition: bool = False,
+ condition_dropout: float = 0.1,
+ calc_energy: bool = False,
+ kernel_size: int = 5,
+ attention: bool = False,
+ ) -> None:
+ """
+ Overview:
+ Initialization of DiffusionUNet1d class
+ Arguments:
+ - transition_dim (:obj:`int`): The dimension of the transition, i.e. obs_dim + action_dim.
+ - dim (:obj:`int`): The base dimension of the layers.
+ - dim_mults (:obj:`SequenceType`): The multipliers of ``dim`` for each resolution level.
+ - returns_condition (:obj:`bool`): Whether to use the return as a condition.
+ - condition_dropout (:obj:`float`): The dropout rate of the returns condition.
+ - calc_energy (:obj:`bool`): Whether to compute the energy-based gradient instead of the direct output.
+ - kernel_size (:obj:`int`): The kernel size of the conv1d layers.
+ - attention (:obj:`bool`): Whether to use attention.
+ """
+ super().__init__()
+ dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
+ in_out = list(zip(dims[:-1], dims[1:]))
+
+ if calc_energy:
+ mish = False
+ act = nn.SiLU()
+ else:
+ mish = True
+ act = nn.Mish()
+
+ self.time_dim = dim
+ self.returns_dim = dim
+
+ self.time_mlp = nn.Sequential(
+ SinusoidalPosEmb(dim),
+ nn.Linear(dim, dim * 4),
+ act,
+ nn.Linear(dim * 4, dim),
+ )
+
+ self.returns_condition = returns_condition
+ self.condition_dropout = condition_dropout
+ self.cale_energy = calc_energy
+
+ if self.returns_condition:
+ self.returns_mlp = nn.Sequential(
+ nn.Linear(1, dim),
+ act,
+ nn.Linear(dim, dim * 4),
+ act,
+ nn.Linear(dim * 4, dim),
+ )
+ self.mask_dist = torch.distributions.Bernoulli(probs=1 - self.condition_dropout)
+ embed_dim = 2 * dim
+ else:
+ embed_dim = dim
+
+ self.downs = nn.ModuleList([])
+ self.ups = nn.ModuleList([])
+ num_resolution = len(in_out)
+
+ for ind, (dim_in, dim_out) in enumerate(in_out):
+ is_last = ind >= (num_resolution - 1)
+ self.downs.append(
+ nn.ModuleList(
+ [
+ ResidualTemporalBlock(dim_in, dim_out, embed_dim, kernel_size, mish=mish),
+ ResidualTemporalBlock(dim_out, dim_out, embed_dim, kernel_size, mish=mish),
+ Residual(PreNorm(dim_out, LinearAttention(dim_out))) if attention else nn.Identity(),
+ nn.Conv1d(dim_out, dim_out, 3, 2, 1) if not is_last else nn.Identity()
+ ]
+ )
+ )
+
+ mid_dim = dims[-1]
+ self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim, kernel_size, mish)
+ self.mid_atten = Residual(PreNorm(mid_dim, LinearAttention(mid_dim))) if attention else nn.Identity()
+ self.mid_block2 = ResidualTemporalBlock(mid_dim, mid_dim, embed_dim, kernel_size, mish)
+
+ for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
+ is_last = ind >= (num_resolution - 1)
+ self.ups.append(
+ nn.ModuleList(
+ [
+ ResidualTemporalBlock(dim_out * 2, dim_in, embed_dim, kernel_size, mish=mish),
+ ResidualTemporalBlock(dim_in, dim_in, embed_dim, kernel_size, mish=mish),
+ Residual(PreNorm(dim_in, LinearAttention(dim_in))) if attention else nn.Identity(),
+ nn.ConvTranspose1d(dim_in, dim_in, 4, 2, 1) if not is_last else nn.Identity()
+ ]
+ )
+ )
+
+ self.final_conv = nn.Sequential(
+ DiffusionConv1d(dim, dim, kernel_size=kernel_size, padding=kernel_size // 2, activation=act),
+ nn.Conv1d(dim, transition_dim, 1),
+ )
+
+ def forward(self, x, cond, time, returns=None, use_dropout: bool = True, force_dropout: bool = False):
+ """
+ Overview:
+ Compute the diffusion UNet forward pass.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The noisy trajectory.
+ - cond (:obj:`tuple`): [ (time, state), ... ], where state is the initial state of the env and time is 0.
+ - time (:obj:`torch.Tensor`): The timestep of the diffusion step.
+ - returns (:obj:`torch.Tensor`): The condition returns of the trajectory, i.e. the normalized return.
+ - use_dropout (:obj:`bool`): Whether to use the returns-condition mask.
+ - force_dropout (:obj:`bool`): Whether to drop the returns condition entirely.
+ """
+ if self.cale_energy:
+ x_inp = x
+
+ # [batch, horizon, transition ] -> [batch, transition , horizon]
+ x = x.transpose(1, 2)
+ t = self.time_mlp(time)
+
+ if self.returns_condition:
+ assert returns is not None
+ returns_embed = self.returns_mlp(returns)
+ if use_dropout:
+ mask = self.mask_dist.sample(sample_shape=(returns_embed.size(0), 1)).to(returns_embed.device)
+ returns_embed = mask * returns_embed
+ if force_dropout:
+ returns_embed = 0 * returns_embed
+ t = torch.cat([t, returns_embed], dim=-1)
+
+ h = []
+
+ for resnet, resnet2, atten, downsample in self.downs:
+ x = resnet(x, t)
+ x = resnet2(x, t)
+ x = atten(x)
+ h.append(x)
+ x = downsample(x)
+
+ x = self.mid_block1(x, t)
+ x = self.mid_atten(x)
+ x = self.mid_block2(x, t)
+
+ for resnet, resnet2, atten, upsample in self.ups:
+ x = torch.cat((x, h.pop()), dim=1)
+ x = resnet(x, t)
+ x = resnet2(x, t)
+ x = atten(x)
+ x = upsample(x)
+
+ x = self.final_conv(x)
+ # [batch, transition , horizon] -> [batch, horizon, transition ]
+ x = x.transpose(1, 2)
+
+ if self.cale_energy:
+ # Energy function
+ energy = ((x - x_inp) ** 2).mean()
+ grad = torch.autograd.grad(outputs=energy, inputs=x_inp, create_graph=True)
+ return grad[0]
+ else:
+ return x
+
+ def get_pred(self, x, cond, time, returns=None, use_dropout: bool = True, force_dropout: bool = False):
+ """
+ Overview:
+ Compute the diffusion UNet prediction (the same as ``forward`` but without the energy branch).
+ Arguments:
+ - x (:obj:`torch.Tensor`): The noisy trajectory.
+ - cond (:obj:`tuple`): [ (time, state), ... ], where state is the initial state of the env and time is 0.
+ - time (:obj:`torch.Tensor`): The timestep of the diffusion step.
+ - returns (:obj:`torch.Tensor`): The condition returns of the trajectory, i.e. the normalized return.
+ - use_dropout (:obj:`bool`): Whether to use the returns-condition mask.
+ - force_dropout (:obj:`bool`): Whether to drop the returns condition entirely.
+ """
+ # [batch, horizon, transition ] -> [batch, transition , horizon]
+ x = x.transpose(1, 2)
+ t = self.time_mlp(time)
+
+ if self.returns_condition:
+ assert returns is not None
+ returns_embed = self.returns_mlp(returns)
+ if use_dropout:
+ mask = self.mask_dist.sample(sample_shape=(returns_embed.size(0), 1)).to(returns_embed.device)
+ returns_embed = mask * returns_embed
+ if force_dropout:
+ returns_embed = 0 * returns_embed
+ t = torch.cat([t, returns_embed], dim=-1)
+
+ h = []
+
+ # each entry of ``self.downs`` also contains an (optional) attention block, so unpack all four modules
+ for resnet, resnet2, atten, downsample in self.downs:
+ x = resnet(x, t)
+ x = resnet2(x, t)
+ x = atten(x)
+ h.append(x)
+ x = downsample(x)
+
+ x = self.mid_block1(x, t)
+ x = self.mid_atten(x)
+ x = self.mid_block2(x, t)
+
+ for resnet, resnet2, atten, upsample in self.ups:
+ x = torch.cat((x, h.pop()), dim=1)
+ x = resnet(x, t)
+ x = resnet2(x, t)
+ x = atten(x)
+ x = upsample(x)
+
+ x = self.final_conv(x)
+ # [batch, transition , horizon] -> [batch, horizon, transition ]
+ x = x.transpose(1, 2)
+ return x
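+
+# Illustrative usage (a sketch, not part of the original file): the dimensions below are assumptions
+# chosen so that the horizon (32) can be halved at every resolution level.
+# >>> model = DiffusionUNet1d(transition_dim=14, dim=32)
+# >>> x = torch.randn(4, 32, 14)          # (batch, horizon, obs_dim + action_dim)
+# >>> t = torch.randint(0, 1000, (4, ))   # one diffusion timestep per sample
+# >>> model(x, cond=None, time=t).shape   # the output keeps the shape of the input trajectory
+# torch.Size([4, 32, 14])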
+
+
+class TemporalValue(nn.Module):
+ """
+ Overview:
+ temporal net for value function
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(
+ self,
+ horizon: int,
+ transition_dim: int,
+ dim: int = 32,
+ time_dim: int = None,
+ out_dim: int = 1,
+ kernel_size: int = 5,
+ dim_mults: SequenceType = [1, 2, 4, 8],
+ ) -> None:
+ """
+ Overview:
+ Initialization of TemporalValue class
+ Arguments:
+ - horizon (:obj:`int`): The horizon of the trajectory.
+ - transition_dim (:obj:`int`): The dimension of the transition, i.e. obs_dim + action_dim.
+ - dim (:obj:`int`): The base dimension of the layers.
+ - time_dim (:obj:`int`): The dimension of the time embedding.
+ - out_dim (:obj:`int`): The dimension of the output.
+ - kernel_size (:obj:`int`): The kernel size of the conv1d layers.
+ - dim_mults (:obj:`SequenceType`): The multipliers of ``dim`` for each resolution level.
+ """
+ super().__init__()
+ dims = [transition_dim, *map(lambda m: dim * m, dim_mults)]
+ in_out = list(zip(dims[:-1], dims[1:]))
+
+ time_dim = time_dim or dim
+ self.time_mlp = nn.Sequential(
+ SinusoidalPosEmb(dim),
+ nn.Linear(dim, dim * 4),
+ nn.Mish(),
+ nn.Linear(dim * 4, dim),
+ )
+ self.blocks = nn.ModuleList([])
+
+ for ind, (dim_in, dim_out) in enumerate(in_out):
+ self.blocks.append(
+ nn.ModuleList(
+ [
+ ResidualTemporalBlock(dim_in, dim_out, kernel_size=kernel_size, embed_dim=time_dim),
+ ResidualTemporalBlock(dim_out, dim_out, kernel_size=kernel_size, embed_dim=time_dim),
+ nn.Conv1d(dim_out, dim_out, 3, 2, 1)
+ ]
+ )
+ )
+
+ horizon = horizon // 2
+
+ mid_dim = dims[-1]
+ mid_dim_2 = mid_dim // 2
+ mid_dim_3 = mid_dim // 4
+
+ self.mid_block1 = ResidualTemporalBlock(mid_dim, mid_dim_2, kernel_size=kernel_size, embed_dim=time_dim)
+ self.mid_down1 = nn.Conv1d(mid_dim_2, mid_dim_2, 3, 2, 1)
+
+ horizon = horizon // 2
+ self.mid_block2 = ResidualTemporalBlock(mid_dim_2, mid_dim_3, kernel_size=kernel_size, embed_dim=time_dim)
+ self.mid_down2 = nn.Conv1d(mid_dim_3, mid_dim_3, 3, 2, 1)
+ horizon = horizon // 2
+
+ fc_dim = mid_dim_3 * max(horizon, 1)
+ self.final_block = nn.Sequential(
+ nn.Linear(fc_dim + time_dim, fc_dim // 2),
+ nn.Mish(),
+ nn.Linear(fc_dim // 2, out_dim),
+ )
+
+ def forward(self, x, cond, time, *args):
+ """
+ Overview:
+ Compute the temporal value forward pass.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The noisy trajectory.
+ - cond (:obj:`tuple`): [ (time, state), ... ], where state is the initial state of the env and time is 0.
+ - time (:obj:`torch.Tensor`): The timestep of the diffusion step.
+ """
+ # [batch, horizon, transition ] -> [batch, transition , horizon]
+ x = x.transpose(1, 2)
+ t = self.time_mlp(time)
+ for resnet, resnet2, downsample in self.blocks:
+ x = resnet(x, t)
+ x = resnet2(x, t)
+ x = downsample(x)
+
+ x = self.mid_block1(x, t)
+ x = self.mid_down1(x)
+
+ x = self.mid_block2(x, t)
+ x = self.mid_down2(x)
+ x = x.view(len(x), -1)
+ out = self.final_block(torch.cat([x, t], dim=-1))
+ return out
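+
+# Illustrative usage (a sketch, not part of the original file): one scalar value per trajectory,
+# assuming the horizon is large enough for the repeated stride-2 downsampling (32 here).
+# >>> value = TemporalValue(horizon=32, transition_dim=14, dim=32)
+# >>> x = torch.randn(4, 32, 14)
+# >>> t = torch.randint(0, 1000, (4, ))
+# >>> value(x, None, t).shape
+# torch.Size([4, 1])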
diff --git a/DI-engine/ding/torch_utils/network/dreamer.py b/DI-engine/ding/torch_utils/network/dreamer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7c1597e54d11b45aa8fd107a8d0a4412201d43b
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/dreamer.py
@@ -0,0 +1,937 @@
+import math
+import numpy as np
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch import distributions as torchd
+from ding.torch_utils import MLP
+from ding.rl_utils import symlog, inv_symlog
+
+
+class Conv2dSame(torch.nn.Conv2d):
+ """
+ Overview:
+ Conv2dSame Network for dreamerv3.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def calc_same_pad(self, i, k, s, d):
+ """
+ Overview:
+ Calculate the same padding size.
+ Arguments:
+ - i (:obj:`int`): Input size.
+ - k (:obj:`int`): Kernel size.
+ - s (:obj:`int`): Stride size.
+ - d (:obj:`int`): Dilation size.
+ """
+ return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)
+
+ def forward(self, x):
+ """
+ Overview:
+ compute the forward of Conv2dSame.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Input tensor.
+ """
+ ih, iw = x.size()[-2:]
+ pad_h = self.calc_same_pad(i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0])
+ pad_w = self.calc_same_pad(i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1])
+
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])
+
+ ret = F.conv2d(
+ x,
+ self.weight,
+ self.bias,
+ self.stride,
+ self.padding,
+ self.dilation,
+ self.groups,
+ )
+ return ret
+
+
+class DreamerLayerNorm(nn.Module):
+ """
+ Overview:
+ DreamerLayerNorm Network for dreamerv3.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(self, ch, eps=1e-03):
+ """
+ Overview:
+ Init the DreamerLayerNorm class.
+ Arguments:
+ - ch (:obj:`int`): Input channel.
+ - eps (:obj:`float`): Epsilon.
+ """
+
+ super(DreamerLayerNorm, self).__init__()
+ self.norm = torch.nn.LayerNorm(ch, eps=eps)
+
+ def forward(self, x):
+ """
+ Overview:
+ compute the forward of DreamerLayerNorm.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Input tensor.
+ """
+
+ x = x.permute(0, 2, 3, 1)
+ x = self.norm(x)
+ x = x.permute(0, 3, 1, 2)
+ return x
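+
+# Illustrative usage (a sketch, not part of the original file); the shapes are examples only.
+# >>> conv = Conv2dSame(3, 16, kernel_size=3, stride=2)
+# >>> conv(torch.randn(2, 3, 7, 7)).shape    # "same" padding keeps ceil(7 / 2) = 4 spatial size
+# torch.Size([2, 16, 4, 4])
+# >>> norm = DreamerLayerNorm(16)
+# >>> norm(torch.randn(2, 16, 4, 4)).shape   # normalizes over channels via permute
+# torch.Size([2, 16, 4, 4])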
+
+
+class DenseHead(nn.Module):
+ """
+ Overview:
+ DenseHead Network for value head, reward head, and discount head of dreamerv3.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(
+ self,
+ inp_dim,
+ shape, # (255,)
+ layer_num,
+ units, # 512
+ act='SiLU',
+ norm='LN',
+ dist='normal',
+ std=1.0,
+ outscale=1.0,
+ device='cpu',
+ ):
+ """
+ Overview:
+ Init the DenseHead class.
+ Arguments:
+ - inp_dim (:obj:`int`): Input dimension.
+ - shape (:obj:`tuple`): Output shape.
+ - layer_num (:obj:`int`): Number of layers.
+ - units (:obj:`int`): Number of units.
+ - act (:obj:`str`): Activation function.
+ - norm (:obj:`str`): Normalization function.
+ - dist (:obj:`str`): Distribution function.
+ - std (:obj:`float`): Standard deviation.
+ - outscale (:obj:`float`): Output scale.
+ - device (:obj:`str`): Device.
+ """
+
+ super(DenseHead, self).__init__()
+ self._shape = (shape, ) if isinstance(shape, int) else shape
+ if len(self._shape) == 0:
+ self._shape = (1, )
+ self._layer_num = layer_num
+ self._units = units
+ self._act = getattr(torch.nn, act)()
+ self._norm = norm
+ self._dist = dist
+ self._std = std
+ self._device = device
+
+ self.mlp = MLP(
+ inp_dim,
+ self._units,
+ self._units,
+ self._layer_num,
+ layer_fn=nn.Linear,
+ activation=self._act,
+ norm_type=self._norm
+ )
+ self.mlp.apply(weight_init)
+
+ self.mean_layer = nn.Linear(self._units, np.prod(self._shape))
+ self.mean_layer.apply(uniform_weight_init(outscale))
+
+ if self._std == "learned":
+ self.std_layer = nn.Linear(self._units, np.prod(self._shape))
+ self.std_layer.apply(uniform_weight_init(outscale))
+
+ def forward(self, features):
+ """
+ Overview:
+ compute the forward of DenseHead.
+ Arguments:
+ - features (:obj:`torch.Tensor`): Input tensor.
+ """
+
+ x = features
+ out = self.mlp(x) # (batch, time, _units=512)
+ mean = self.mean_layer(out) # (batch, time, 255)
+ if self._std == "learned":
+ std = self.std_layer(out)
+ else:
+ std = self._std
+ if self._dist == "normal":
+ return ContDist(torchd.independent.Independent(torchd.normal.Normal(mean, std), len(self._shape)))
+ elif self._dist == "huber":
+ return ContDist(torchd.independent.Independent(UnnormalizedHuber(mean, std, 1.0), len(self._shape)))
+ elif self._dist == "binary":
+ return Bernoulli(torchd.independent.Independent(torchd.bernoulli.Bernoulli(logits=mean), len(self._shape)))
+ elif self._dist == "twohot_symlog":
+ return TwoHotDistSymlog(logits=mean, device=self._device)
+ raise NotImplementedError(self._dist)
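+
+# Illustrative usage (a sketch, not part of the original file): with the 'twohot_symlog' distribution the
+# head returns a TwoHotDistSymlog whose mean is a single symlog-decoded value per step; the call signature
+# of ``MLP`` is assumed from ``ding.torch_utils`` and the shapes below are examples only.
+# >>> head = DenseHead(inp_dim=512, shape=(255, ), layer_num=2, units=512, dist='twohot_symlog')
+# >>> dist = head(torch.randn(8, 16, 512))   # (batch, time, feature)
+# >>> dist.mean().shape
+# torch.Size([8, 16, 1])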
+
+
+class ActionHead(nn.Module):
+ """
+ Overview:
+ ActionHead Network for action head of dreamerv3.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(
+ self,
+ inp_dim,
+ size,
+ layers,
+ units,
+ act='ELU',
+ norm='LayerNorm',
+ dist="trunc_normal",
+ init_std=0.0,
+ min_std=0.1,
+ max_std=1.0,
+ temp=0.1,
+ outscale=1.0,
+ unimix_ratio=0.01,
+ ):
+ """
+ Overview:
+ Initialize the ActionHead class.
+ Arguments:
+ - inp_dim (:obj:`int`): Input dimension.
+ - size (:obj:`int`): Output size.
+ - layers (:obj:`int`): Number of layers.
+ - units (:obj:`int`): Number of units.
+ - act (:obj:`str`): Activation function.
+ - norm (:obj:`str`): Normalization function.
+ - dist (:obj:`str`): Distribution function.
+ - init_std (:obj:`float`): Initial standard deviation.
+ - min_std (:obj:`float`): Minimum standard deviation.
+ - max_std (:obj:`float`): Maximum standard deviation.
+ - temp (:obj:`float`): Temperature.
+ - outscale (:obj:`float`): Output scale.
+ - unimix_ratio (:obj:`float`): Unimix ratio.
+ """
+ super(ActionHead, self).__init__()
+ self._size = size
+ self._layers = layers
+ self._units = units
+ self._dist = dist
+ self._act = getattr(torch.nn, act)
+ self._norm = getattr(torch.nn, norm)
+ self._min_std = min_std
+ self._max_std = max_std
+ self._init_std = init_std
+ self._unimix_ratio = unimix_ratio
+ self._temp = temp() if callable(temp) else temp
+
+ pre_layers = []
+ for index in range(self._layers):
+ pre_layers.append(nn.Linear(inp_dim, self._units, bias=False))
+ pre_layers.append(self._norm(self._units, eps=1e-03))
+ pre_layers.append(self._act())
+ if index == 0:
+ inp_dim = self._units
+ self._pre_layers = nn.Sequential(*pre_layers)
+ self._pre_layers.apply(weight_init)
+
+ if self._dist in ["tanh_normal", "tanh_normal_5", "normal", "trunc_normal"]:
+ self._dist_layer = nn.Linear(self._units, 2 * self._size)
+ self._dist_layer.apply(uniform_weight_init(outscale))
+
+ elif self._dist in ["normal_1", "onehot", "onehot_gumbel"]:
+ self._dist_layer = nn.Linear(self._units, self._size)
+ self._dist_layer.apply(uniform_weight_init(outscale))
+
+ def forward(self, features):
+ """
+ Overview:
+ compute the forward of ActionHead.
+ Arguments:
+ - features (:obj:`torch.Tensor`): Input tensor.
+ """
+
+ x = features
+ x = self._pre_layers(x)
+ if self._dist == "tanh_normal":
+ x = self._dist_layer(x)
+ mean, std = torch.split(x, 2, -1)
+ mean = torch.tanh(mean)
+ std = F.softplus(std + self._init_std) + self._min_std
+ dist = torchd.normal.Normal(mean, std)
+ dist = torchd.transformed_distribution.TransformedDistribution(dist, TanhBijector())
+ dist = torchd.independent.Independent(dist, 1)
+ dist = SampleDist(dist)
+ elif self._dist == "tanh_normal_5":
+ x = self._dist_layer(x)
+ mean, std = torch.split(x, 2, -1)
+ mean = 5 * torch.tanh(mean / 5)
+ std = F.softplus(std + 5) + 5
+ dist = torchd.normal.Normal(mean, std)
+ dist = torchd.transformed_distribution.TransformedDistribution(dist, TanhBijector())
+ dist = torchd.independent.Independent(dist, 1)
+ dist = SampleDist(dist)
+ elif self._dist == "normal":
+ x = self._dist_layer(x)
+ mean, std = torch.split(x, [self._size] * 2, -1)
+ std = (self._max_std - self._min_std) * torch.sigmoid(std + 2.0) + self._min_std
+ dist = torchd.normal.Normal(torch.tanh(mean), std)
+ dist = ContDist(torchd.independent.Independent(dist, 1))
+ elif self._dist == "normal_1":
+ x = self._dist_layer(x)
+ dist = torchd.normal.Normal(mean, 1)
+ dist = ContDist(torchd.independent.Independent(dist, 1))
+ elif self._dist == "trunc_normal":
+ x = self._dist_layer(x)
+ mean, std = torch.split(x, [self._size] * 2, -1)
+ mean = torch.tanh(mean)
+ std = 2 * torch.sigmoid(std / 2) + self._min_std
+ dist = SafeTruncatedNormal(mean, std, -1, 1)
+ dist = ContDist(torchd.independent.Independent(dist, 1))
+ elif self._dist == "onehot":
+ x = self._dist_layer(x)
+ dist = OneHotDist(x, unimix_ratio=self._unimix_ratio)
+ elif self._dist == "onehot_gumble":
+ x = self._dist_layer(x)
+ temp = self._temp
+ dist = ContDist(torchd.gumbel.Gumbel(x, 1 / temp))
+ else:
+ raise NotImplementedError(self._dist)
+ return dist
+
+
+class SampleDist:
+ """
+ Overview:
+ A kind of sample Dist for ActionHead of dreamerv3.
+ Interfaces:
+ ``__init__``, ``mean``, ``mode``, ``entropy``
+ """
+
+ def __init__(self, dist, samples=100):
+ """
+ Overview:
+ Initialize the SampleDist class.
+ Arguments:
+ - dist (:obj:`torch.Tensor`): Distribution.
+ - samples (:obj:`int`): Number of samples.
+ """
+
+ self._dist = dist
+ self._samples = samples
+
+ def mean(self):
+ """
+ Overview:
+ Calculate the mean of the distribution.
+ """
+
+ samples = self._dist.sample(self._samples)
+ return torch.mean(samples, 0)
+
+ def mode(self):
+ """
+ Overview:
+ Calculate the mode of the distribution.
+ """
+
+ sample = self._dist.sample(self._samples)
+ logprob = self._dist.log_prob(sample)
+ return sample[torch.argmax(logprob)][0]
+
+ def entropy(self):
+ """
+ Overview:
+ Calculate the entropy of the distribution.
+ """
+
+ sample = self._dist.sample(self._samples)
+ logprob = self._dist.log_prob(sample)
+ return -torch.mean(logprob, 0)
+
+
+class OneHotDist(torchd.one_hot_categorical.OneHotCategorical):
+ """
+ Overview:
+ A kind of onehot Dist for dreamerv3.
+ Interfaces:
+ ``__init__``, ``mode``, ``sample``
+ """
+
+ def __init__(self, logits=None, probs=None, unimix_ratio=0.0):
+ """
+ Overview:
+ Initialize the OneHotDist class.
+ Arguments:
+ - logits (:obj:`torch.Tensor`): Logits.
+ - probs (:obj:`torch.Tensor`): Probabilities.
+ - unimix_ratio (:obj:`float`): Unimix ratio.
+ """
+
+ if logits is not None and unimix_ratio > 0.0:
+ probs = F.softmax(logits, dim=-1)
+ probs = probs * (1.0 - unimix_ratio) + unimix_ratio / probs.shape[-1]
+ logits = torch.log(probs)
+ super().__init__(logits=logits, probs=None)
+ else:
+ super().__init__(logits=logits, probs=probs)
+
+ def mode(self):
+ """
+ Overview:
+ Calculate the mode of the distribution.
+ """
+
+ _mode = F.one_hot(torch.argmax(super().logits, axis=-1), super().logits.shape[-1])
+ return _mode.detach() + super().logits - super().logits.detach()
+
+ def sample(self, sample_shape=(), seed=None):
+ """
+ Overview:
+ Sample from the distribution.
+ Arguments:
+ - sample_shape (:obj:`tuple`): Sample shape.
+ - seed (:obj:`int`): Seed.
+ """
+
+ if seed is not None:
+ raise ValueError('need to check')
+ sample = super().sample(sample_shape)
+ probs = super().probs
+ while len(probs.shape) < len(sample.shape):
+ probs = probs[None]
+ sample += probs - probs.detach()
+ return sample
+
+
+class TwoHotDistSymlog:
+ """
+ Overview:
+ A kind of twohotsymlog Dist for dreamerv3.
+ Interfaces:
+ ``__init__``, ``mode``, ``mean``, ``log_prob``, ``log_prob_target``
+ """
+
+ def __init__(self, logits=None, low=-20.0, high=20.0, device='cpu'):
+ """
+ Overview:
+ Initialize the TwoHotDistSymlog class.
+ Arguments:
+ - logits (:obj:`torch.Tensor`): Logits.
+ - low (:obj:`float`): Low.
+ - high (:obj:`float`): High.
+ - device (:obj:`str`): Device.
+ """
+
+ self.logits = logits
+ self.probs = torch.softmax(logits, -1)
+ self.buckets = torch.linspace(low, high, steps=255).to(device)
+ self.width = (self.buckets[-1] - self.buckets[0]) / 255
+
+ def mean(self):
+ """
+ Overview:
+ Calculate the mean of the distribution.
+ """
+
+ _mean = self.probs * self.buckets
+ return inv_symlog(torch.sum(_mean, dim=-1, keepdim=True))
+
+ def mode(self):
+ """
+ Overview:
+ Calculate the mode of the distribution.
+ """
+
+ _mode = self.probs * self.buckets
+ return inv_symlog(torch.sum(_mode, dim=-1, keepdim=True))
+
+ # Inside OneHotCategorical, log_prob is calculated using only max element in targets
+ def log_prob(self, x):
+ """
+ Overview:
+ Calculate the log probability of the distribution.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Input tensor.
+ """
+
+ x = symlog(x)
+ # x(time, batch, 1)
+ below = torch.sum((self.buckets <= x[..., None]).to(torch.int32), dim=-1) - 1
+ above = len(self.buckets) - torch.sum((self.buckets > x[..., None]).to(torch.int32), dim=-1)
+ below = torch.clip(below, 0, len(self.buckets) - 1)
+ above = torch.clip(above, 0, len(self.buckets) - 1)
+ equal = (below == above)
+
+ dist_to_below = torch.where(equal, 1, torch.abs(self.buckets[below] - x))
+ dist_to_above = torch.where(equal, 1, torch.abs(self.buckets[above] - x))
+ total = dist_to_below + dist_to_above
+ weight_below = dist_to_above / total
+ weight_above = dist_to_below / total
+ target = (
+ F.one_hot(below, num_classes=len(self.buckets)) * weight_below[..., None] +
+ F.one_hot(above, num_classes=len(self.buckets)) * weight_above[..., None]
+ )
+ log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True)
+ target = target.squeeze(-2)
+
+ return (target * log_pred).sum(-1)
+
+ def log_prob_target(self, target):
+ """
+ Overview:
+ Calculate the log probability of the target.
+ Arguments:
+ - target (:obj:`torch.Tensor`): Target tensor.
+ """
+
+ log_pred = self.logits - torch.logsumexp(self.logits, -1, keepdim=True)
+ return (target * log_pred).sum(-1)
+
+
+class SymlogDist:
+ """
+ Overview:
+ A kind of Symlog Dist for dreamerv3.
+ Interfaces:
+ ``__init__``, ``entropy``, ``mode``, ``mean``, ``log_prob``
+ """
+
+ def __init__(self, mode, dist='mse', aggregation='sum', tol=1e-8, dim_to_reduce=[-1, -2, -3]):
+ """
+ Overview:
+ Initialize the SymlogDist class.
+ Arguments:
+ - mode (:obj:`torch.Tensor`): Mode.
+ - dist (:obj:`str`): Distribution function.
+ - aggregation (:obj:`str`): Aggregation function.
+ - tol (:obj:`float`): Tolerance.
+ - dim_to_reduce (:obj:`list`): Dimension to reduce.
+ """
+ self._mode = mode
+ self._dist = dist
+ self._aggregation = aggregation
+ self._tol = tol
+ self._dim_to_reduce = dim_to_reduce
+
+ def mode(self):
+ """
+ Overview:
+ Calculate the mode of the distribution.
+ """
+
+ return inv_symlog(self._mode)
+
+ def mean(self):
+ """
+ Overview:
+ Calculate the mean of the distribution.
+ """
+
+ return inv_symlog(self._mode)
+
+ def log_prob(self, value):
+ """
+ Overview:
+ Calculate the log probability of the distribution.
+ Arguments:
+ - value (:obj:`torch.Tensor`): Input tensor.
+ """
+
+ assert self._mode.shape == value.shape
+ if self._dist == 'mse':
+ distance = (self._mode - symlog(value)) ** 2.0
+ distance = torch.where(distance < self._tol, 0, distance)
+ elif self._dist == 'abs':
+ distance = torch.abs(self._mode - symlog(value))
+ distance = torch.where(distance < self._tol, 0, distance)
+ else:
+ raise NotImplementedError(self._dist)
+ if self._aggregation == 'mean':
+ loss = distance.mean(self._dim_to_reduce)
+ elif self._aggregation == 'sum':
+ loss = distance.sum(self._dim_to_reduce)
+ else:
+ raise NotImplementedError(self._aggregation)
+ return -loss
+
+
+class ContDist:
+ """
+ Overview:
+ A kind of ordinary Dist for dreamerv3.
+ Interfaces:
+ ``__init__``, ``entropy``, ``mode``, ``sample``, ``log_prob``
+ """
+
+ def __init__(self, dist=None):
+ """
+ Overview:
+ Initialize the ContDist class.
+ Arguments:
+ - dist (:obj:`torch.Tensor`): Distribution.
+ """
+
+ super().__init__()
+ self._dist = dist
+ self.mean = dist.mean
+
+ def __getattr__(self, name):
+ """
+ Overview:
+ Get attribute.
+ Arguments:
+ - name (:obj:`str`): Attribute name.
+ """
+
+ return getattr(self._dist, name)
+
+ def entropy(self):
+ """
+ Overview:
+ Calculate the entropy of the distribution.
+ """
+
+ return self._dist.entropy()
+
+ def mode(self):
+ """
+ Overview:
+ Calculate the mode of the distribution.
+ """
+
+ return self._dist.mean
+
+ def sample(self, sample_shape=()):
+ """
+ Overview:
+ Sample from the distribution.
+ Arguments:
+ - sample_shape (:obj:`tuple`): Sample shape.
+ """
+
+ return self._dist.rsample(sample_shape)
+
+ def log_prob(self, x):
+ return self._dist.log_prob(x)
+
+
+class Bernoulli:
+ """
+ Overview:
+ A kind of Bernoulli Dist for dreamerv3.
+ Interfaces:
+ ``__init__``, ``entropy``, ``mode``, ``sample``, ``log_prob``
+ """
+
+ def __init__(self, dist=None):
+ """
+ Overview:
+ Initialize the Bernoulli distribution.
+ Arguments:
+ - dist (:obj:`torch.Tensor`): Distribution.
+ """
+
+ super().__init__()
+ self._dist = dist
+ self.mean = dist.mean
+
+ def __getattr__(self, name):
+ """
+ Overview:
+ Get attribute.
+ Arguments:
+ - name (:obj:`str`): Attribute name.
+ """
+
+ return getattr(self._dist, name)
+
+ def entropy(self):
+ """
+ Overview:
+ Calculate the entropy of the distribution.
+ """
+ return self._dist.entropy()
+
+ def mode(self):
+ """
+ Overview:
+ Calculate the mode of the distribution.
+ """
+
+ _mode = torch.round(self._dist.mean)
+ return _mode.detach() + self._dist.mean - self._dist.mean.detach()
+
+ def sample(self, sample_shape=()):
+ """
+ Overview:
+ Sample from the distribution.
+ Arguments:
+ - sample_shape (:obj:`tuple`): Sample shape.
+ """
+
+ return self._dist.rsample(sample_shape)
+
+ def log_prob(self, x):
+ """
+ Overview:
+ Calculate the log probability of the distribution.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Input tensor.
+ """
+
+ _logits = self._dist.base_dist.logits
+ log_probs0 = -F.softplus(_logits)
+ log_probs1 = -F.softplus(-_logits)
+
+ return log_probs0 * (1 - x) + log_probs1 * x
+
+
+class UnnormalizedHuber(torchd.normal.Normal):
+ """
+ Overview:
+ A kind of UnnormalizedHuber Dist for dreamerv3.
+ Interfaces:
+ ``__init__``, ``mode``, ``log_prob``
+ """
+
+ def __init__(self, loc, scale, threshold=1, **kwargs):
+ """
+ Overview:
+ Initialize the UnnormalizedHuber class.
+ Arguments:
+ - loc (:obj:`torch.Tensor`): Location.
+ - scale (:obj:`torch.Tensor`): Scale.
+ - threshold (:obj:`float`): Threshold.
+ """
+ super().__init__(loc, scale, **kwargs)
+ self._threshold = threshold
+
+ def log_prob(self, event):
+ """
+ Overview:
+ Calculate the log probability of the distribution.
+ Arguments:
+ - event (:obj:`torch.Tensor`): Event.
+ """
+
+ return -(torch.sqrt((event - self.mean) ** 2 + self._threshold ** 2) - self._threshold)
+
+ def mode(self):
+ """
+ Overview:
+ Calculate the mode of the distribution.
+ """
+
+ return self.mean
+
+
+class SafeTruncatedNormal(torchd.normal.Normal):
+ """
+ Overview:
+ A kind of SafeTruncatedNormal Dist for dreamerv3.
+ Interfaces:
+ ``__init__``, ``sample``
+ """
+
+ def __init__(self, loc, scale, low, high, clip=1e-6, mult=1):
+ """
+ Overview:
+ Initialize the SafeTruncatedNormal class.
+ Arguments:
+ - loc (:obj:`torch.Tensor`): Location.
+ - scale (:obj:`torch.Tensor`): Scale.
+ - low (:obj:`float`): Low.
+ - high (:obj:`float`): High.
+ - clip (:obj:`float`): Clip.
+ - mult (:obj:`float`): Mult.
+ """
+
+ super().__init__(loc, scale)
+ self._low = low
+ self._high = high
+ self._clip = clip
+ self._mult = mult
+
+ def sample(self, sample_shape):
+ """
+ Overview:
+ Sample from the distribution.
+ Arguments:
+ - sample_shape (:obj:`tuple`): Sample shape.
+ """
+
+ event = super().sample(sample_shape)
+ if self._clip:
+ clipped = torch.clip(event, self._low + self._clip, self._high - self._clip)
+ event = event - event.detach() + clipped.detach()
+ if self._mult:
+ event *= self._mult
+ return event
+
+
+class TanhBijector(torchd.Transform):
+ """
+ Overview:
+ A kind of TanhBijector Dist for dreamerv3.
+ Interfaces:
+ ``__init__``, ``_forward``, ``_inverse``, ``_forward_log_det_jacobian``
+ """
+
+ def __init__(self, validate_args=False, name='tanh'):
+ """
+ Overview:
+ Initialize the TanhBijector class.
+ Arguments:
+ - validate_args (:obj:`bool`): Validate arguments.
+ - name (:obj:`str`): Name.
+ """
+
+ super().__init__()
+
+ def _forward(self, x):
+ """
+ Overview:
+ Calculate the forward of the distribution.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Input tensor.
+ """
+
+ return torch.tanh(x)
+
+ def _inverse(self, y):
+ """
+ Overview:
+ Calculate the inverse of the distribution.
+ Arguments:
+ - y (:obj:`torch.Tensor`): Input tensor.
+ """
+
+ y = torch.where((torch.abs(y) <= 1.), torch.clamp(y, -0.99999997, 0.99999997), y)
+ y = torch.atanh(y)
+ return y
+
+ def _forward_log_det_jacobian(self, x):
+ """
+ Overview:
+ Calculate the forward log det jacobian of the distribution.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Input tensor.
+ """
+
+ log2 = math.log(2.0)
+ return 2.0 * (log2 - x - F.softplus(-2.0 * x))
+
+
+def static_scan(fn, inputs, start):
+ """
+ Overview:
+ Static scan function.
+ Arguments:
+ - fn (:obj:`function`): Function.
+ - inputs (:obj:`tuple`): Inputs.
+ - start (:obj:`torch.Tensor`): Start tensor.
+ """
+
+ last = start # {logit, stoch, deter:[batch_size, self._deter]}
+ indices = range(inputs[0].shape[0])
+ flag = True
+ for index in indices:
+ inp = lambda x: (_input[x] for _input in inputs) # inputs:(action:(time, batch, 6), embed:(time, batch, 4096))
+ last = fn(last, *inp(index)) # post, prior
+ if flag:
+ if isinstance(last, dict):
+ outputs = {key: value.clone().unsqueeze(0) for key, value in last.items()}
+ else:
+ outputs = []
+ for _last in last:
+ if isinstance(_last, dict):
+ outputs.append({key: value.clone().unsqueeze(0) for key, value in _last.items()})
+ else:
+ outputs.append(_last.clone().unsqueeze(0))
+ flag = False
+ else:
+ if isinstance(last, dict):
+ for key in last.keys():
+ outputs[key] = torch.cat([outputs[key], last[key].unsqueeze(0)], dim=0)
+ else:
+ for j in range(len(outputs)):
+ if isinstance(last[j], dict):
+ for key in last[j].keys():
+ outputs[j][key] = torch.cat([outputs[j][key], last[j][key].unsqueeze(0)], dim=0)
+ else:
+ outputs[j] = torch.cat([outputs[j], last[j].unsqueeze(0)], dim=0)
+ if isinstance(last, dict):
+ outputs = [outputs]
+ return outputs
+
+
+def weight_init(m):
+ """
+ Overview:
+ weight_init for Linear, Conv2d, ConvTranspose2d, and LayerNorm.
+ Arguments:
+ - m (:obj:`torch.nn`): Module.
+ """
+
+ if isinstance(m, nn.Linear):
+ in_num = m.in_features
+ out_num = m.out_features
+ denoms = (in_num + out_num) / 2.0
+ scale = 1.0 / denoms
+ std = np.sqrt(scale) / 0.87962566103423978
+ nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0)
+ if hasattr(m.bias, 'data'):
+ m.bias.data.fill_(0.0)
+ elif isinstance(m, nn.Conv2d) or isinstance(m, nn.ConvTranspose2d):
+ space = m.kernel_size[0] * m.kernel_size[1]
+ in_num = space * m.in_channels
+ out_num = space * m.out_channels
+ denoms = (in_num + out_num) / 2.0
+ scale = 1.0 / denoms
+ std = np.sqrt(scale) / 0.87962566103423978
+ nn.init.trunc_normal_(m.weight.data, mean=0.0, std=std, a=-2.0, b=2.0)
+ if hasattr(m.bias, 'data'):
+ m.bias.data.fill_(0.0)
+ elif isinstance(m, nn.LayerNorm):
+ m.weight.data.fill_(1.0)
+ if hasattr(m.bias, 'data'):
+ m.bias.data.fill_(0.0)
+
+
+def uniform_weight_init(given_scale):
+ """
+ Overview:
+ weight_init for Linear and LayerNorm.
+ Arguments:
+ - given_scale (:obj:`float`): Given scale.
+ """
+
+ def f(m):
+ if isinstance(m, nn.Linear):
+ in_num = m.in_features
+ out_num = m.out_features
+ denoms = (in_num + out_num) / 2.0
+ scale = given_scale / denoms
+ limit = np.sqrt(3 * scale)
+ nn.init.uniform_(m.weight.data, a=-limit, b=limit)
+ if hasattr(m.bias, 'data'):
+ m.bias.data.fill_(0.0)
+ elif isinstance(m, nn.LayerNorm):
+ m.weight.data.fill_(1.0)
+ if hasattr(m.bias, 'data'):
+ m.bias.data.fill_(0.0)
+
+ return f
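+
+# Illustrative usage (a sketch, not part of the original file): both initializers are meant to be
+# applied through ``nn.Module.apply``, as done inside DenseHead above.
+# >>> layer = nn.Linear(16, 4)
+# >>> _ = layer.apply(weight_init)                 # truncated-normal weights, zero bias
+# >>> _ = layer.apply(uniform_weight_init(1.0))    # uniform weights scaled by the given factor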
diff --git a/DI-engine/ding/torch_utils/network/gtrxl.py b/DI-engine/ding/torch_utils/network/gtrxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..16ac7702c7cd9e8c9b0b122bf2fef48cf93c0487
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/gtrxl.py
@@ -0,0 +1,641 @@
+"""
+Overview:
+ This file implements the core modules of GTrXL Transformer as described in
+ "Stabilizing Transformer for Reinforcement Learning" (https://arxiv.org/abs/1910.06764).
+"""
+from typing import Optional, Dict, List
+import warnings
+import numpy as np
+import torch
+import torch.nn as nn
+from ding.torch_utils.network.nn_module import fc_block, build_normalization, F
+
+
+class PositionalEmbedding(nn.Module):
+ """
+ Overview:
+ The PositionalEmbedding module implements the positional embedding used in the vanilla Transformer model.
+ Interfaces:
+ ``__init__``, ``forward``
+
+ .. note::
+ This implementation is adapted from https://github.com/kimiyoung/transformer-xl/blob/ \
+ master/pytorch/mem_transformer.py
+ """
+
+ def __init__(self, embedding_dim: int):
+ """
+ Overview:
+ Initialize the PositionalEmbedding module.
+ Arguments:
+ - embedding_dim: (:obj:`int`): The dimensionality of the embeddings.
+ """
+
+ super(PositionalEmbedding, self).__init__()
+ self.embedding_dim = embedding_dim
+ inv_freq = 1 / (10000 ** (torch.arange(0.0, embedding_dim, 2.0) / embedding_dim)) # (embedding_dim / 2)
+ self.register_buffer('inv_freq', inv_freq)
+
+ def forward(self, pos_seq: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Compute positional embedding given a sequence of positions.
+ Arguments:
+ - pos_seq (:obj:`torch.Tensor`): The positional sequence, \
+ typically a 1D tensor of integers in the form of [seq_len-1, seq_len-2, ..., 1, 0],
+ Returns:
+ - pos_embedding (:obj:`torch.Tensor`): The computed positional embeddings. \
+ The shape of the tensor is (seq_len, 1, embedding_dim).
+ """
+
+ sinusoid_inp = torch.outer(pos_seq, self.inv_freq)
+ # For position embedding, the order of sin/cos is negligible.
+ # This is because tokens are consumed by the matrix multiplication which is permutation-invariant.
+ pos_embedding = torch.cat([sinusoid_inp.sin(), sinusoid_inp.cos()], dim=-1)
+ return pos_embedding.unsqueeze(1)
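+
+# Illustrative usage (a sketch, not part of the original file): positions are usually passed in
+# decreasing order, as in Transformer-XL; the values below are examples only.
+# >>> pe = PositionalEmbedding(embedding_dim=128)
+# >>> pos_seq = torch.arange(9, -1, -1.0)     # [9., 8., ..., 0.]
+# >>> pe(pos_seq).shape                       # (seq_len, 1, embedding_dim)
+# torch.Size([10, 1, 128])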
+
+
+class GRUGatingUnit(torch.nn.Module):
+ """
+ Overview:
+ The GRUGatingUnit module implements the GRU gating mechanism used in the GTrXL model.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(self, input_dim: int, bg: float = 2.):
+ """
+ Overview:
+ Initialize the GRUGatingUnit module.
+ Arguments:
+ - input_dim (:obj:`int`): The dimensionality of the input.
+ - bg (:obj:`float`): The gate bias. By setting bg > 0 we can explicitly initialize the gating mechanism to \
+ be close to the identity map. This can greatly improve the learning speed and stability since it \
+ initializes the agent close to a Markovian policy (ignore attention at the beginning).
+ """
+
+ super(GRUGatingUnit, self).__init__()
+ self.Wr = torch.nn.Linear(input_dim, input_dim, bias=False)
+ self.Ur = torch.nn.Linear(input_dim, input_dim, bias=False)
+ self.Wz = torch.nn.Linear(input_dim, input_dim, bias=False)
+ self.Uz = torch.nn.Linear(input_dim, input_dim, bias=False)
+ self.Wg = torch.nn.Linear(input_dim, input_dim, bias=False)
+ self.Ug = torch.nn.Linear(input_dim, input_dim, bias=False)
+ self.bg = nn.Parameter(torch.full([input_dim], bg)) # bias
+ self.sigmoid = torch.nn.Sigmoid()
+ self.tanh = torch.nn.Tanh()
+
+ def forward(self, x: torch.Tensor, y: torch.Tensor):
+ """
+ Overview:
+ Compute the output value using the GRU gating mechanism.
+ Arguments:
+ - x: (:obj:`torch.Tensor`): The first input tensor.
+ - y: (:obj:`torch.Tensor`): The second input tensor. \
+ x and y should have the same shape and their last dimension should match the input_dim.
+ Returns:
+ - g: (:obj:`torch.Tensor`): The output of the GRU gating mechanism. \
+ The shape of g matches the shapes of x and y.
+ """
+
+ r = self.sigmoid(self.Wr(y) + self.Ur(x))
+ z = self.sigmoid(self.Wz(y) + self.Uz(x) - self.bg)
+ h = self.tanh(self.Wg(y) + self.Ug(torch.mul(r, x))) # element wise multiplication
+ g = torch.mul(1 - z, x) + torch.mul(z, h)
+ return g # x.shape == y.shape == g.shape
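+
+# Illustrative usage (a sketch, not part of the original file); the shapes are examples only.
+# >>> gate = GRUGatingUnit(input_dim=64, bg=2.)
+# >>> x = torch.randn(5, 8, 64)               # e.g. (cur_seq, batch, embedding_dim)
+# >>> y = torch.randn(5, 8, 64)               # e.g. the attention or feed-forward output
+# >>> gate(x, y).shape
+# torch.Size([5, 8, 64])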
+
+
+class Memory:
+ """
+ Overview:
+ A class that stores the context used to add memory to Transformer.
+ Interfaces:
+ ``__init__``, ``init``, ``update``, ``get``, ``to``
+
+ .. note::
+ For details, refer to Transformer-XL: https://arxiv.org/abs/1901.02860
+ """
+
+ def __init__(
+ self,
+ memory_len: int = 20,
+ batch_size: int = 64,
+ embedding_dim: int = 256,
+ layer_num: int = 3,
+ memory: Optional[torch.Tensor] = None
+ ) -> None:
+ """
+ Overview:
+ Initialize the Memory module.
+ Arguments:
+ - memory_len (:obj:`int`): The dimension of memory, i.e., how many past observations to use as memory.
+ - batch_size (:obj:`int`): The dimension of each batch.
+ - embedding_dim (:obj:`int`): The dimension of embedding, which is the dimension of a single observation \
+ after embedding.
+ - layer_num (:obj:`int`): The number of transformer layers.
+ - memory (:obj:`Optional[torch.Tensor]`): The initial memory. Default is None.
+ """
+
+ super(Memory, self).__init__()
+ self.embedding_dim = embedding_dim
+ self.bs = batch_size
+ self.layer_num = layer_num
+ self.memory_len = memory_len
+ self.memory = None
+ self.init(memory)
+
+ def init(self, memory: Optional[torch.Tensor] = None):
+ """
+ Overview:
+ Initialize memory with an input list of tensors or create it automatically given its dimensions.
+ Arguments:
+ - memory (:obj:`Optional[torch.Tensor]`): The input memory tensor with shape \
+ (layer_num + 1, memory_len, bs, embedding_dim), where memory_len is the length of memory, bs is the \
+ batch size and embedding_dim is the dimension of embedding.
+ """
+
+ if memory is not None:
+ self.memory = memory
+ layer_num_plus1, self.memory_len, self.bs, self.embedding_dim = memory.shape
+ self.layer_num = layer_num_plus1 - 1
+ else:
+ self.memory = torch.zeros(
+ self.layer_num + 1, self.memory_len, self.bs, self.embedding_dim, dtype=torch.float
+ )
+
+ def update(self, hidden_state: List[torch.Tensor]):
+ """
+ Overview:
+ Update the memory given a sequence of hidden states.
+ Example for single layer:
+ memory_len=3, hidden_size_len=2, bs=3
+
+         m00 m01 m02       h00 h01 h02                  m20 m21 m22
+     m = m10 m11 m12   h = h10 h11 h12   =>   new_m  =  h00 h01 h02
+         m20 m21 m22                                    h10 h11 h12
+ Arguments:
+ - hidden_state: (:obj:`List[torch.Tensor]`): The hidden states to update the memory. \
+ Each tensor in the list has shape (cur_seq, bs, embedding_dim), where cur_seq \
+ is the length of the sequence.
+ Returns:
+ - memory: (:obj:`Optional[torch.Tensor]`): The updated memory, with shape \
+ (layer_num + 1, memory_len, bs, embedding_dim).
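+ Example:
+ >>> # Illustrative usage sketch; the sizes here are arbitrary assumptions.
+ >>> mem = Memory(memory_len=3, batch_size=2, embedding_dim=4, layer_num=1)
+ >>> hidden_state = [torch.randn(5, 2, 4) for _ in range(2)]  # layer_num + 1 tensors of (cur_seq, bs, embedding_dim)
+ >>> assert mem.update(hidden_state).shape == (2, 3, 2, 4)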
+ """
+
+ if self.memory is None or hidden_state is None:
+ raise ValueError('Failed to update memory: memory and hidden_state must not be None') # TODO add support of no memory
+ sequence_len = hidden_state[0].shape[0]
+ with torch.no_grad():
+ new_memory = []
+ end = self.memory_len + sequence_len
+ beg = max(0, end - self.memory_len)
+ for i in range(self.layer_num + 1):
+ m = self.memory[i]
+ h = hidden_state[i]
+ cat = torch.cat([m, h], dim=0)
+ new_memory.append(cat[beg:end].detach())
+ new_memory = torch.stack(new_memory, dim=0)
+ self.memory = new_memory
+ return new_memory
+
+ def get(self):
+ """
+ Overview:
+ Get the current memory.
+ Returns:
+ - memory: (:obj:`Optional[torch.Tensor]`): The current memory, \
+ with shape (layer_num + 1, memory_len, bs, embedding_dim).
+ """
+
+ return self.memory
+
+ def to(self, device: str = 'cpu'):
+ """
+ Overview:
+ Move the current memory to the specified device.
+ Arguments:
+ - device (:obj:`str`): The device to move the memory to. Default is 'cpu'.
+ """
+
+ self.memory = self.memory.to(device)
+
+
+class AttentionXL(torch.nn.Module):
+ """
+ Overview:
+ An implementation of the Attention mechanism used in the TransformerXL model.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(self, input_dim: int, head_dim: int, head_num: int, dropout: nn.Module) -> None:
+ """
+ Overview:
+ Initialize the AttentionXL module.
+ Arguments:
+ - input_dim (:obj:`int`): The dimensionality of the input features.
+ - head_dim (:obj:`int`): The dimensionality of each attention head.
+ - head_num (:obj:`int`): The number of attention heads.
+ - dropout (:obj:`nn.Module`): The dropout layer to use.
+ """
+
+ super(AttentionXL, self).__init__()
+ self.head_num = head_num
+ self.head_dim = head_dim
+ self.dropout = dropout
+ self.attention_kv = fc_block(input_dim, head_dim * head_num * 2) # key, value
+ self.attention_q = fc_block(input_dim, head_dim * head_num) # query (not computed with past hidden states)
+ self.project = fc_block(head_dim * head_num, input_dim) # project attention output back to input_dim
+ self.project_pos = fc_block(input_dim, head_dim * head_num) # project the positional embedding
+ self.scale = 1 / (head_dim ** 0.5) # for scaled dot product attention
+
+ def _rel_shift(self, x: torch.Tensor, zero_upper: bool = False) -> torch.Tensor:
+ """
+ Overview:
+ Perform a relative shift operation on the attention score matrix.
+ Example:
+ a00 a01 a02      0 a00 a01 a02       0  a00 a01      a02  0  a10      a02  0   0
+ a10 a11 a12  =>  0 a10 a11 a12  =>  a02  0  a10  =>  a11 a12  0   =>  a11 a12  0
+ a20 a21 a22      0 a20 a21 a22      a11 a12  0       a20 a21 a22      a20 a21 a22
+                                     a20 a21 a22
+ 1) Append one "column" of zeros to the left
+ 2) Reshape the matrix from [3 x 4] into [4 x 3]
+ 3) Remove the first "row"
+ 4) Mask out the upper triangle (optional)
+
+ .. note::
+ See the following material for better understanding:
+ https://github.com/kimiyoung/transformer-xl/issues/8
+ https://arxiv.org/pdf/1901.02860.pdf (Appendix B)
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor with shape (bs, head_num, cur_seq, full_seq).
+ - zero_upper (:obj:`bool`): If True, the upper-right triangle of the matrix is set to zero.
+ Returns:
+ - x (:obj:`torch.Tensor`): The input tensor after the relative shift operation, \
+ with shape (bs, head_num, cur_seq, full_seq).
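+ Example:
+ >>> # Illustrative sketch; the scores follow the (bs, head_num, cur_seq, full_seq) layout assumed above.
+ >>> attn = AttentionXL(input_dim=8, head_dim=4, head_num=2, dropout=nn.Dropout(0.))
+ >>> scores = torch.randn(1, 2, 3, 3)
+ >>> assert attn._rel_shift(scores).shape == scores.shape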
+ """
+
+ x_padded = F.pad(x, [1, 0]) # step 1
+ x_padded = x_padded.view(x.size(0), x.size(1), x.size(3) + 1, x.size(2)) # step 2
+ x = x_padded[:, :, 1:].view_as(x) # step 3
+ if zero_upper:
+ ones = torch.ones((x.size(2), x.size(3))).unsqueeze(0).unsqueeze(0)
+ x = x * torch.tril(ones.to(x.device), x.size(3) - x.size(2)) # step 4
+ return x
+
+ def forward(
+ self,
+ inputs: torch.Tensor,
+ pos_embedding: torch.Tensor,
+ full_input: torch.Tensor,
+ u: torch.nn.Parameter,
+ v: torch.nn.Parameter,
+ mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+ Overview:
+ Compute the forward pass for the AttentionXL module.
+ Arguments:
+ - inputs (:obj:`torch.Tensor`): The attention input with shape (cur_seq, bs, input_dim).
+ - pos_embedding (:obj:`torch.Tensor`): The positional embedding with shape (full_seq, 1, input_dim).
+ - full_input (:obj:`torch.Tensor`): The concatenated memory and input tensor with shape \
+ (full_seq, bs, input_dim).
+ - u (:obj:`torch.nn.Parameter`): The content parameter with shape (head_num, head_dim).
+ - v (:obj:`torch.nn.Parameter`): The position parameter with shape (head_num, head_dim).
+ - mask (:obj:`Optional[torch.Tensor]`): The attention mask with shape (cur_seq, full_seq, 1). \
+ If None, no masking is applied.
+ Returns:
+ - output (:obj:`torch.Tensor`): The output of the attention mechanism with shape (cur_seq, bs, input_dim).
+ """
+
+ bs, cur_seq, full_seq = inputs.shape[1], inputs.shape[0], full_input.shape[0]
+ prev_seq = full_seq - cur_seq
+
+ kv = self.attention_kv(full_input)
+ key, value = torch.chunk(kv, 2, dim=-1) # full_seq x bs x num_head*dim_head
+ query = self.attention_q(inputs) # cur_seq x bs x num_head*dim_head
+ r = self.project_pos(pos_embedding) # full_seq x 1 x num_head*dim_head
+
+ key = key.view(full_seq, bs, self.head_num, self.head_dim)
+ query = query.view(cur_seq, bs, self.head_num, self.head_dim)
+ value = value.view(cur_seq + prev_seq, bs, self.head_num, self.head_dim)
+ r = r.view(full_seq, self.head_num, self.head_dim)
+
+ # (query + u) * key^T
+ q_u = query + u
+ content_attn = q_u.permute(1, 2, 0, 3) @ key.permute(1, 2, 3, 0) # bs x head_num x cur_seq x full_seq
+
+ # (query + v) * R^T
+ q_v = query + v
+ position_attn = q_v.permute(1, 2, 0, 3) @ r.permute(1, 2, 0) # bs x head_num x cur_seq x full_seq
+ position_attn = self._rel_shift(position_attn)
+
+ attn = content_attn + position_attn # bs x head_num x cur_seq x full_seq
+ attn.mul_(self.scale)
+
+ # fills float('-inf') where mask is True to let softmax ignore those positions.
+ if mask is not None and mask.any().item():
+ mask = mask.permute(2, 0, 1).unsqueeze(1) # 1 x 1 x cur_seq x full_seq
+ assert mask.shape[2:] == attn.shape[2:] # check shape of mask
+ attn = attn.masked_fill(mask, -float("inf")).type_as(attn)
+
+ attn = F.softmax(attn, dim=-1)
+ attn = self.dropout(attn)
+
+ # multiply softmax output by value
+ attn_vec = attn @ value.permute(1, 2, 0, 3)
+ attn_vec = attn_vec.permute(2, 0, 1, 3)
+
+ attn_vec = attn_vec.contiguous().view(cur_seq, bs, self.head_num * self.head_dim)
+ # cur_seq x bs x head_num * head_dim
+ output = self.dropout(self.project(attn_vec)) # cur_seq x bs x input_dim
+ return output
+
+
+class GatedTransformerXLLayer(torch.nn.Module):
+ """
+ Overview:
+ This class implements the attention layer of GTrXL (Gated Transformer-XL).
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ head_dim: int,
+ hidden_dim: int,
+ head_num: int,
+ mlp_num: int,
+ dropout: nn.Module,
+ activation: nn.Module,
+ gru_gating: bool = True,
+ gru_bias: float = 2.
+ ) -> None:
+ """
+ Overview:
+ Initialize GatedTransformerXLLayer.
+ Arguments:
+ - input_dim (:obj:`int`): The dimension of the input tensor.
+ - head_dim (:obj:`int`): The dimension of each head in the multi-head attention.
+ - hidden_dim (:obj:`int`): The dimension of the hidden layer in the MLP.
+ - head_num (:obj:`int`): The number of heads for the multi-head attention.
+ - mlp_num (:obj:`int`): The number of MLP layers in the attention layer.
+ - dropout (:obj:`nn.Module`): The dropout module used in the MLP and attention layers.
+ - activation (:obj:`nn.Module`): The activation function to be used in the MLP layers.
+ - gru_gating (:obj:`bool`, optional): Whether to use GRU gates. If False, replace GRU gates with \
+ residual connections. Default is True.
+ - gru_bias (:obj:`float`, optional): The bias of the GRU gate. Default is 2.
+ """
+
+ super(GatedTransformerXLLayer, self).__init__()
+ self.dropout = dropout
+ self.gating = gru_gating
+ if self.gating is True:
+ self.gate1 = GRUGatingUnit(input_dim, gru_bias)
+ self.gate2 = GRUGatingUnit(input_dim, gru_bias)
+ self.attention = AttentionXL(
+ input_dim,
+ head_dim,
+ head_num,
+ dropout,
+ )
+ layers = []
+ dims = [input_dim] + [hidden_dim] * (mlp_num - 1) + [input_dim]
+ for i in range(mlp_num):
+ layers.append(fc_block(dims[i], dims[i + 1], activation=activation))
+ if i != mlp_num - 1:
+ layers.append(self.dropout)
+ layers.append(self.dropout)
+ self.mlp = nn.Sequential(*layers)
+ self.layernorm1 = build_normalization('LN')(input_dim)
+ self.layernorm2 = build_normalization('LN')(input_dim)
+ self.activation = activation
+
+ def forward(
+ self,
+ inputs: torch.Tensor,
+ pos_embedding: torch.Tensor,
+ u: torch.nn.Parameter,
+ v: torch.nn.Parameter,
+ memory: torch.Tensor,
+ mask: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+ """
+ Overview:
+ Compute forward pass of GTrXL layer.
+ Arguments:
+ - inputs (:obj:`torch.Tensor`): The attention input tensor of shape (cur_seq, bs, input_dim).
+ - pos_embedding (:obj:`torch.Tensor`): The positional embedding tensor of shape (full_seq, 1, input_dim).
+ - u (:obj:`torch.nn.Parameter`): The content parameter tensor of shape (head_num, head_dim).
+ - v (:obj:`torch.nn.Parameter`): The position parameter tensor of shape (head_num, head_dim).
+ - memory (:obj:`torch.Tensor`): The memory tensor of shape (prev_seq, bs, input_dim).
+ - mask (:obj:`Optional[torch.Tensor]`): The attention mask tensor of shape (cur_seq, full_seq, 1).
+ Default is None.
+ Returns:
+ - output (:obj:`torch.Tensor`): layer output of shape (cur_seq, bs, input_dim)
+ """
+
+ # concat memory with input across sequence dimension
+ full_input = torch.cat([memory, inputs], dim=0) # full_seq x bs x input_dim
+ x1 = self.layernorm1(full_input)
+ a1 = self.dropout(self.attention(inputs, pos_embedding, x1, u, v, mask=mask))
+ a1 = self.activation(a1) # RELU after attention
+ o1 = self.gate1(inputs, a1) if self.gating else inputs + a1
+ x2 = self.layernorm2(o1)
+ m2 = self.dropout(self.mlp(x2))
+ o2 = self.gate2(o1, m2) if self.gating else o1 + m2
+ return o2
+
+
+class GTrXL(nn.Module):
+ """
+ Overview:
+ GTrXL Transformer implementation as described in "Stabilizing Transformer for Reinforcement Learning"
+ (https://arxiv.org/abs/1910.06764).
+ Interfaces:
+ ``__init__``, ``forward``, ``reset_memory``, ``get_memory``
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ head_dim: int = 128,
+ embedding_dim: int = 256,
+ head_num: int = 2,
+ mlp_num: int = 2,
+ layer_num: int = 3,
+ memory_len: int = 64,
+ dropout_ratio: float = 0.,
+ activation: nn.Module = nn.ReLU(),
+ gru_gating: bool = True,
+ gru_bias: float = 2.,
+ use_embedding_layer: bool = True,
+ ) -> None:
+ """Overview:
+ Init GTrXL Model.
+ Arguments:
+ - input_dim (:obj:`int`): The dimension of the input observation.
+ - head_dim (:obj:`int`, optional): The dimension of each head. Default is 128.
+ - embedding_dim (:obj:`int`, optional): The dimension of the embedding. Default is 256.
+ - head_num (:obj:`int`, optional): The number of heads for multi-head attention. Default is 2.
+ - mlp_num (:obj:`int`, optional): The number of MLP layers in the attention layer. Default is 2.
+ - layer_num (:obj:`int`, optional): The number of transformer layers. Default is 3.
+ - memory_len (:obj:`int`, optional): The length of memory. Default is 64.
+ - dropout_ratio (:obj:`float`, optional): The dropout ratio. Default is 0.
+ - activation (:obj:`nn.Module`, optional): The activation function. Default is nn.ReLU().
+ - gru_gating (:obj:`bool`, optional): If False, replace GRU gates with residual connections. \
+ Default is True.
+ - gru_bias (:obj:`float`, optional): The GRU gate bias. Default is 2.0.
+ - use_embedding_layer (:obj:`bool`, optional): If False, don't use input embedding layer. Default is True.
+ Raises:
+ - AssertionError: If `embedding_dim` is not an even number.
+ """
+
+ super(GTrXL, self).__init__()
+ assert embedding_dim % 2 == 0, 'embedding_dim={} should be even'.format(embedding_dim)
+ self.head_num = head_num
+ self.head_dim = head_dim
+ self.layer_num = layer_num
+ if isinstance(input_dim, list):
+ input_dim = np.prod(input_dim)
+ self.use_embedding_layer = use_embedding_layer
+ if use_embedding_layer:
+ self.embedding = fc_block(input_dim, embedding_dim, activation=activation)
+ self.activation = activation
+ self.pos_embedding = PositionalEmbedding(embedding_dim)
+ # memory to save hidden states of past segments
+ # it will be initialized in the forward method to get its size dynamically
+ self.memory = None
+ self.memory_len = memory_len
+ layers = []
+ dims = [embedding_dim] + [embedding_dim] * layer_num
+ self.dropout = nn.Dropout(dropout_ratio) if dropout_ratio > 0 else nn.Identity()
+ for i in range(layer_num):
+ layers.append(
+ GatedTransformerXLLayer(
+ dims[i], head_dim, embedding_dim, head_num, mlp_num, self.dropout, self.activation, gru_gating,
+ gru_bias
+ )
+ )
+ self.layers = nn.Sequential(*layers)
+ self.embedding_dim = embedding_dim
+ # u and v are the parameters to compute global content bias and global positional bias
+ self.u, self.v = (
+ torch.nn.Parameter(torch.zeros(self.head_num, self.head_dim)),
+ torch.nn.Parameter(torch.zeros(self.head_num, self.head_dim)),
+ )
+ self.att_mask = {} # create an attention mask for each different seq_len, in this way we don't need to create a
+ # new one each time we call the forward method
+ self.pos_embedding_dict = {} # create a pos embedding for each different seq_len
+
+ def reset_memory(self, batch_size: Optional[int] = None, state: Optional[torch.Tensor] = None):
+ """
+ Overview:
+ Clear or set the memory of GTrXL.
+ Arguments:
+ - batch_size (:obj:`Optional[int]`): The batch size. Default is None.
+ - state (:obj:`Optional[torch.Tensor]`): The input memory with shape \
+ (layer_num + 1, memory_len, bs, embedding_dim). Default is None.
+ """
+
+ self.memory = Memory(memory_len=self.memory_len, layer_num=self.layer_num, embedding_dim=self.embedding_dim)
+ if batch_size is not None:
+ self.memory = Memory(self.memory_len, batch_size, self.embedding_dim, self.layer_num)
+ elif state is not None:
+ self.memory.init(state)
+
+ def get_memory(self):
+ """
+ Overview:
+ Returns the memory of GTrXL.
+ Returns:
+ - memory (:obj:`Optional[torch.Tensor]`): The output memory or None if memory has not been initialized. \
+ The shape is (layer_num + 1, memory_len, bs, embedding_dim).
+ """
+
+ if self.memory is None:
+ return None
+ else:
+ return self.memory.get()
+
+ def forward(self, x: torch.Tensor, batch_first: bool = False, return_mem: bool = True) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ Performs a forward pass on the GTrXL.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor with shape (seq_len, bs, input_size).
+ - batch_first (:obj:`bool`, optional): If the input data has shape (bs, seq_len, input_size), \
+ set this parameter to True to transpose along the first and second dimension and obtain shape \
+ (seq_len, bs, input_size). This does not affect the output memory. Default is False.
+ - return_mem (:obj:`bool`, optional): If False, the returned dict contains only the ``logit`` key and \
+ no memory. Default is True.
+ Returns:
+ - x (:obj:`Dict[str, torch.Tensor]`): A dictionary containing the transformer output of shape \
+ (seq_len, bs, embedding_dim) under the key ``logit`` and the memory of shape \
+ (layer_num + 1, memory_len, bs, embedding_dim) under the key ``memory``.
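+ Example:
+ >>> # Illustrative usage sketch; all sizes below are arbitrary assumptions.
+ >>> model = GTrXL(input_dim=10, head_dim=4, embedding_dim=16, head_num=2, mlp_num=2, layer_num=2, memory_len=8)
+ >>> obs = torch.randn(6, 3, 10)  # (seq_len, bs, input_dim)
+ >>> out = model(obs)
+ >>> assert out['logit'].shape == (6, 3, 16)
+ >>> assert out['memory'].shape == (3, 8, 3, 16)  # (layer_num + 1, memory_len, bs, embedding_dim)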
+ """
+
+ if batch_first:
+ x = torch.transpose(x, 1, 0) # bs x cur_seq x input_dim -> cur_seq x bs x input_dim
+ cur_seq, bs = x.shape[:2]
+ memory = None if self.memory is None else self.memory.get()
+ if memory is None:
+ self.reset_memory(bs) # (layer_num+1) x memory_len x batch_size x embedding_dim
+ elif memory.shape[-2] != bs or memory.shape[-1] != self.embedding_dim:
+ warnings.warn(
+ "Memory {} and Input {} dimensions don't match,"
+ " this will cause the memory to be initialized to fit your input!".format(
+ list(memory.shape[-2:]), [x.shape[-2]] + [self.embedding_dim]
+ )
+ )
+ self.reset_memory(bs)
+ self.memory.to(x.device)
+ memory = self.memory.get()
+
+ if self.use_embedding_layer:
+ x = self.dropout(self.embedding(x))
+ prev_seq = self.memory_len
+ full_seq = cur_seq + prev_seq
+
+ if cur_seq in self.att_mask.keys():
+ attn_mask = self.att_mask[cur_seq]
+ else:
+ attn_mask = (
+ torch.triu(
+ torch.ones((cur_seq, full_seq)),
+ diagonal=1 + prev_seq, # fixed in train, eval, collect
+ ).bool().unsqueeze(-1).to(x.device)
+ ) # cur_seq x full_seq x 1
+ self.att_mask[cur_seq] = attn_mask
+
+ if cur_seq in self.pos_embedding_dict.keys():
+ pos_embedding = self.pos_embedding_dict[cur_seq]
+ else:
+ pos_ips = torch.arange(full_seq - 1, -1, -1.0, dtype=torch.float) # full_seq
+ pos_embedding = self.pos_embedding(pos_ips.to(x.device))
+ self.pos_embedding_dict[cur_seq] = pos_embedding
+ pos_embedding = self.dropout(pos_embedding) # full_seq x 1 x embedding_dim
+
+ hidden_state = [x]
+ out = x
+ for i in range(self.layer_num):
+ layer = self.layers[i]
+ out = layer(
+ out,
+ pos_embedding,
+ self.u,
+ self.v,
+ mask=attn_mask,
+ memory=memory[i], # (layer_num+1) x memory_len x batch_size x embedding_dim
+ ) # cur_seq x bs x embedding_dim
+ hidden_state.append(out.clone())
+
+ out = self.dropout(out)
+ self.memory.update(hidden_state) # (layer_num+1) x memory_len x batch_size x embedding_dim
+
+ if batch_first:
+ out = torch.transpose(out, 1, 0) # cur_seq x bs x embedding_dim -> bs x cur_seq x embedding_dim
+ if return_mem:
+ output = {"logit": out, "memory": memory} # return the content of the memory before the last update
+ else:
+ output = {"logit": out}
+ return output
diff --git a/DI-engine/ding/torch_utils/network/gumbel_softmax.py b/DI-engine/ding/torch_utils/network/gumbel_softmax.py
new file mode 100644
index 0000000000000000000000000000000000000000..fea761210326d6b3067a5e69f89a193638a51d54
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/gumbel_softmax.py
@@ -0,0 +1,62 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class GumbelSoftmax(nn.Module):
+ """
+ Overview:
+ An `nn.Module` that computes GumbelSoftmax.
+ Interfaces:
+ ``__init__``, ``forward``, ``gumbel_softmax_sample``
+
+ .. note::
+ For more information on GumbelSoftmax, refer to the paper [Categorical Reparameterization \
+ with Gumbel-Softmax](https://arxiv.org/abs/1611.01144).
+ """
+
+ def __init__(self) -> None:
+ """
+ Overview:
+ Initialize the `GumbelSoftmax` module.
+ """
+ super(GumbelSoftmax, self).__init__()
+
+ def gumbel_softmax_sample(self, x: torch.Tensor, temperature: float, eps: float = 1e-8) -> torch.Tensor:
+ """
+ Overview:
+ Draw a sample from the Gumbel-Softmax distribution.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Input tensor.
+ - temperature (:obj:`float`): Non-negative scalar controlling the sharpness of the distribution.
+ - eps (:obj:`float`): Small number to prevent division by zero, default is `1e-8`.
+ Returns:
+ - output (:obj:`torch.Tensor`): Sample from Gumbel-Softmax distribution.
+ """
+ U = torch.rand(x.shape)
+ U = U.to(x.device)
+ y = x - torch.log(-torch.log(U + eps) + eps)
+ return F.softmax(y / temperature, dim=1)
+
+ def forward(self, x: torch.Tensor, temperature: float = 1.0, hard: bool = False) -> torch.Tensor:
+ """
+ Overview:
+ Forward pass for the `GumbelSoftmax` module.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Unnormalized log-probabilities.
+ - temperature (:obj:`float`): Non-negative scalar controlling the sharpness of the distribution.
+ - hard (:obj:`bool`): If `True`, returns one-hot encoded labels. Default is `False`.
+ Returns:
+ - output (:obj:`torch.Tensor`): Sample from Gumbel-Softmax distribution.
+ Shapes:
+ - x: its shape is :math:`(B, N)`, where `B` is the batch size and `N` is the number of classes.
+ - y: its shape is :math:`(B, N)`, where `B` is the batch size and `N` is the number of classes.
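+ Example:
+ >>> # Illustrative usage sketch; the batch size and class count are arbitrary assumptions.
+ >>> m = GumbelSoftmax()
+ >>> logits = torch.randn(4, 6)
+ >>> y_soft = m(logits, temperature=1.0)
+ >>> y_hard = m(logits, temperature=1.0, hard=True)
+ >>> assert y_soft.shape == y_hard.shape == (4, 6)
+ >>> assert torch.allclose(y_hard.sum(dim=1), torch.ones(4))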
+ """
+ y = self.gumbel_softmax_sample(x, temperature)
+ if hard:
+ y_hard = torch.zeros_like(x)
+ y_hard[torch.arange(0, x.shape[0]), y.max(1)[1]] = 1
+ # The detach call treats (y_hard - y) as a constant, so the forward pass returns the
+ # one-hot y_hard while gradients flow only through the soft sample y.
+ y = (y_hard - y).detach() + y
+ return y
diff --git a/DI-engine/ding/torch_utils/network/merge.py b/DI-engine/ding/torch_utils/network/merge.py
new file mode 100644
index 0000000000000000000000000000000000000000..25d89885dddb5f344f4c8e4be1d70a4d0bbd9eb0
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/merge.py
@@ -0,0 +1,400 @@
+"""
+This file provides an implementation of several different neural network modules that are used for merging and
+ transforming input data in various ways. The following components can be used when we are dealing with
+ data from multiple modalities, or when we need to merge multiple intermediate embedded representations in
+the forward process of a model.
+
+The main classes defined in this code are:
+
+ - BilinearGeneral: This class implements a bilinear transformation layer that applies a bilinear transformation to
+ incoming data, as described in "Multiplicative Interactions and Where to Find Them" (ICLR 2020,
+ https://openreview.net/forum?id=rylnK6VtDH). The transformation involves two input features and an output
+ feature, and also includes a bias term.
+
+ - TorchBilinearCustomized: This class implements a bilinear layer similar to the one provided by PyTorch
+ (torch.nn.Bilinear), but with additional customizations. This class can be used as an alternative to the
+ BilinearGeneral class.
+
+ - TorchBilinear: This class is a simple wrapper around the PyTorch's built-in nn.Bilinear module. It provides the
+ same functionality as PyTorch's nn.Bilinear but within the structure of the current module.
+
+ - FiLM: This class implements a Feature-wise Linear Modulation (FiLM) layer. FiLM layers apply an affine
+ transformation to the input data, conditioned on some additional context information.
+
+ - GatingType: This is an enumeration class that defines different types of gating mechanisms that can be used in
+ the modules.
+
+ - SumMerge: This class provides a simple summing mechanism to merge input streams.
+
+ - VectorMerge: This class implements a more complex merging mechanism for vector streams.
+ The streams are first transformed using layer normalization, a ReLU activation, and a linear layer.
+ Then they are merged either by simple summing or by using a gating mechanism.
+
+The implementation of these classes involves PyTorch and Numpy libraries, and the classes use PyTorch's nn.Module as
+the base class, making them compatible with PyTorch's neural network modules and functionalities.
+These modules can be useful building blocks in more complex deep learning architectures.
+"""
+
+import enum
+import math
+from collections import OrderedDict
+from typing import List, Dict, Tuple
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+
+
+class BilinearGeneral(nn.Module):
+ """
+ Overview:
+ Bilinear implementation as in: Multiplicative Interactions and Where to Find Them,
+ ICLR 2020, https://openreview.net/forum?id=rylnK6VtDH.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(self, in1_features: int, in2_features: int, out_features: int):
+ """
+ Overview:
+ Initialize the Bilinear layer.
+ Arguments:
+ - in1_features (:obj:`int`): The size of each first input sample.
+ - in2_features (:obj:`int`): The size of each second input sample.
+ - out_features (:obj:`int`): The size of each output sample.
+ """
+
+ super(BilinearGeneral, self).__init__()
+ # Initialize the weight matrices W and U, and the bias vectors V and b
+ self.W = nn.Parameter(torch.Tensor(out_features, in1_features, in2_features))
+ self.U = nn.Parameter(torch.Tensor(out_features, in2_features))
+ self.V = nn.Parameter(torch.Tensor(out_features, in1_features))
+ self.b = nn.Parameter(torch.Tensor(out_features))
+ self.in1_features = in1_features
+ self.in2_features = in2_features
+ self.out_features = out_features
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ """
+ Overview:
+ Initialize the parameters of the Bilinear layer.
+ """
+
+ stdv = 1. / np.sqrt(self.in1_features)
+ self.W.data.uniform_(-stdv, stdv)
+ self.U.data.uniform_(-stdv, stdv)
+ self.V.data.uniform_(-stdv, stdv)
+ self.b.data.uniform_(-stdv, stdv)
+
+ def forward(self, x: torch.Tensor, z: torch.Tensor):
+ """
+ Overview:
+ Compute the bilinear function.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The first input tensor.
+ - z (:obj:`torch.Tensor`): The second input tensor.
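+ Example:
+ >>> # Illustrative usage sketch; the feature sizes are arbitrary assumptions.
+ >>> layer = BilinearGeneral(in1_features=4, in2_features=6, out_features=8)
+ >>> x, z = torch.randn(10, 4), torch.randn(10, 6)
+ >>> assert layer(x, z).shape == (10, 8)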
+ """
+
+ # Compute the bilinear function
+ # x^T W z
+ out_W = torch.einsum('bi,kij,bj->bk', x, self.W, z)
+ # U z
+ out_U = z.matmul(self.U.t())
+ # V x
+ out_V = x.matmul(self.V.t())
+ # x^T W z + U z + V x + b
+ out = out_W + out_U + out_V + self.b
+ return out
+
+
+class TorchBilinearCustomized(nn.Module):
+ """
+ Overview:
+ Customized Torch Bilinear implementation.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(self, in1_features: int, in2_features: int, out_features: int):
+ """
+ Overview:
+ Initialize the Bilinear layer.
+ Arguments:
+ - in1_features (:obj:`int`): The size of each first input sample.
+ - in2_features (:obj:`int`): The size of each second input sample.
+ - out_features (:obj:`int`): The size of each output sample.
+ """
+
+ super(TorchBilinearCustomized, self).__init__()
+ self.in1_features = in1_features
+ self.in2_features = in2_features
+ self.out_features = out_features
+ self.weight = nn.Parameter(torch.Tensor(out_features, in1_features, in2_features))
+ self.bias = nn.Parameter(torch.Tensor(out_features))
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ """
+ Overview:
+ Initialize the parameters of the Bilinear layer.
+ """
+
+ bound = 1 / math.sqrt(self.in1_features)
+ nn.init.uniform_(self.weight, -bound, bound)
+ nn.init.uniform_(self.bias, -bound, bound)
+
+ def forward(self, x, z):
+ """
+ Overview:
+ Compute the bilinear function.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The first input tensor.
+ - z (:obj:`torch.Tensor`): The second input tensor.
+ """
+
+ # Using torch.einsum for the bilinear operation
+ out = torch.einsum('bi,oij,bj->bo', x, self.weight, z) + self.bias
+ return out.squeeze(-1)
+
+
+"""
+Overview:
+ Implementation of the Bilinear layer as in PyTorch:
+ https://pytorch.org/docs/stable/generated/torch.nn.Bilinear.html#torch.nn.Bilinear
+Arguments:
+ - in1_features (:obj:`int`): The size of each first input sample.
+ - in2_features (:obj:`int`): The size of each second input sample.
+ - out_features (:obj:`int`): The size of each output sample.
+ - bias (:obj:`bool`): If set to False, the layer will not learn an additive bias. Default: ``True``.
+"""
+TorchBilinear = nn.Bilinear
+
+
+class FiLM(nn.Module):
+ """
+ Overview:
+ Feature-wise Linear Modulation (FiLM) Layer.
+ This layer applies feature-wise affine transformation based on context.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(self, feature_dim: int, context_dim: int):
+ """
+ Overview:
+ Initialize the FiLM layer.
+ Arguments:
+ - feature_dim (:obj:`int`): The dimension of the input feature vector.
+ - context_dim (:obj:`int`): The dimension of the input context vector.
+ """
+
+ super(FiLM, self).__init__()
+ # Define the fully connected layer for context
+ # The output dimension is twice the feature dimension for gamma and beta
+ self.context_layer = nn.Linear(context_dim, 2 * feature_dim)
+
+ def forward(self, feature: torch.Tensor, context: torch.Tensor):
+ """
+ Overview:
+ Forward propagation.
+ Arguments:
+ - feature (:obj:`torch.Tensor`): The input feature, with shape (batch_size, feature_dim).
+ - context (:obj:`torch.Tensor`): The input context, with shape (batch_size, context_dim).
+ Returns:
+ - conditioned_feature (:obj:`torch.Tensor`): The output feature after FiLM, with shape (batch_size, feature_dim).
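+ Example:
+ >>> # Illustrative usage sketch; the feature and context sizes are arbitrary assumptions.
+ >>> film = FiLM(feature_dim=8, context_dim=4)
+ >>> feature, context = torch.randn(16, 8), torch.randn(16, 4)
+ >>> assert film(feature, context).shape == (16, 8)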
+ """
+
+ # Pass context through the fully connected layer
+ out = self.context_layer(context)
+ # Split the output into two parts: gamma and beta
+ # The dimension for splitting is 1 (feature dimension)
+ gamma, beta = torch.split(out, out.shape[1] // 2, dim=1)
+ # Apply feature-wise affine transformation
+ conditioned_feature = gamma * feature + beta
+ return conditioned_feature
+
+
+class GatingType(enum.Enum):
+ """
+ Overview:
+ Enum class defining different types of tensor gating and aggregation in modules.
+ """
+ NONE = 'none'
+ GLOBAL = 'global'
+ POINTWISE = 'pointwise'
+
+
+class SumMerge(nn.Module):
+ """
+ Overview:
+ A PyTorch module that merges a list of tensors by computing their sum. All input tensors must have the same
+ size. This module can work with any type of tensor (vector, units or visual).
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def forward(self, tensors: List[Tensor]) -> Tensor:
+ """
+ Overview:
+ Forward pass of the SumMerge module, which sums the input tensors.
+ Arguments:
+ - tensors (:obj:`List[Tensor]`): List of input tensors to be summed. All tensors must have the same size.
+ Returns:
+ - summed (:obj:`Tensor`): Tensor resulting from the sum of all input tensors.
+ """
+ # stack the tensors along the first dimension
+ stacked = torch.stack(tensors, dim=0)
+
+ # compute the sum along the first dimension
+ summed = torch.sum(stacked, dim=0)
+ # summed = sum(tensors)
+ return summed
+
+
+class VectorMerge(nn.Module):
+ """
+ Overview:
+ Merges multiple vector streams. Streams are first transformed through layer normalization, relu, and linear
+ layers, then summed. They don't need to have the same size. Gating can also be used before the sum.
+ Interfaces:
+ ``__init__``, ``encode``, ``_compute_gate``, ``forward``
+
+ .. note::
+ For more details about the gating types, please refer to the GatingType enum class.
+ """
+
+ def __init__(
+ self,
+ input_sizes: Dict[str, int],
+ output_size: int,
+ gating_type: GatingType = GatingType.NONE,
+ use_layer_norm: bool = True,
+ ):
+ """
+ Overview:
+ Initialize the `VectorMerge` module.
+ Arguments:
+ - input_sizes (:obj:`Dict[str, int]`): A dictionary mapping input names to their sizes. \
+ The size is a single integer for 1D inputs. A non-positive size marks a scalar (0D per-sample) input, \
+ which is unsqueezed to size 1 before the linear layer.
+ - output_size (:obj:`int`): The size of the output vector.
+ - gating_type (:obj:`GatingType`): The type of gating mechanism to use. Default is `GatingType.NONE`.
+ - use_layer_norm (:obj:`bool`): Whether to use layer normalization. Default is `True`.
+ """
+ super().__init__()
+ self._input_sizes = OrderedDict(input_sizes)
+ self._output_size = output_size
+ self._gating_type = gating_type
+ self._use_layer_norm = use_layer_norm
+
+ if self._use_layer_norm:
+ self._layer_norms = nn.ModuleDict()
+ else:
+ self._layer_norms = None
+
+ self._linears = nn.ModuleDict()
+ for name, size in self._input_sizes.items():
+ linear_input_size = size if size > 0 else 1
+ if self._use_layer_norm:
+ self._layer_norms[name] = nn.LayerNorm(linear_input_size)
+ self._linears[name] = nn.Linear(linear_input_size, self._output_size)
+
+ self._gating_linears = nn.ModuleDict()
+ if self._gating_type is GatingType.GLOBAL:
+ self.gate_size = 1
+ elif self._gating_type is GatingType.POINTWISE:
+ self.gate_size = self._output_size
+ elif self._gating_type is GatingType.NONE:
+ self._gating_linears = None
+ else:
+ raise ValueError(f'Gating type {self._gating_type} is not supported')
+
+ if self._gating_linears is not None:
+ if len(self._input_sizes) == 2:
+ # more efficient than the general version below
+ for name, size in self._input_sizes.items():
+ gate_input_size = size if size > 0 else 1
+ gating_layer = nn.Linear(gate_input_size, self.gate_size)
+ torch.nn.init.normal_(gating_layer.weight, std=0.005)
+ torch.nn.init.constant_(gating_layer.bias, 0.0)
+ self._gating_linears[name] = gating_layer
+ else:
+ for name, size in self._input_sizes.items():
+ gate_input_size = size if size > 0 else 1
+ gating_layer = nn.Linear(gate_input_size, len(self._input_sizes) * self.gate_size)
+ torch.nn.init.normal_(gating_layer.weight, std=0.005)
+ torch.nn.init.constant_(gating_layer.bias, 0.0)
+ self._gating_linears[name] = gating_layer
+
+ def encode(self, inputs: Dict[str, Tensor]) -> Tuple[List[Tensor], List[Tensor]]:
+ """
+ Overview:
+ Encode the input tensors using layer normalization, relu, and linear transformations.
+ Arguments:
+ - inputs (:obj:`Dict[str, Tensor]`): The input tensors.
+ Returns:
+ - gates (:obj:`List[Tensor]`): The gate tensors after transformations.
+ - outputs (:obj:`List[Tensor]`): The output tensors after transformations.
+ """
+ gates, outputs = [], []
+ for name, size in self._input_sizes.items():
+ feature = inputs[name]
+ if size <= 0 and feature.dim() == 1:
+ feature = feature.unsqueeze(-1)
+ feature = feature.to(torch.float32)
+ if self._use_layer_norm and name in self._layer_norms:
+ feature = self._layer_norms[name](feature)
+ feature = F.relu(feature)
+ gates.append(feature)
+ outputs.append(self._linears[name](feature))
+ return gates, outputs
+
+ def _compute_gate(
+ self,
+ init_gate: List[Tensor],
+ ) -> List[Tensor]:
+ """
+ Overview:
+ Compute the gate values based on the initial gate values.
+ Arguments:
+ - init_gate (:obj:`List[Tensor]`): The initial gate values.
+ Returns:
+ - gate (:obj:`List[Tensor]`): The computed gate values.
+ """
+ if len(self._input_sizes) == 2:
+ gate = [self._gating_linears[name](y) for name, y in zip(self._input_sizes.keys(), init_gate)]
+ gate = sum(gate)
+ sigmoid = torch.sigmoid(gate)
+ gate = [sigmoid, 1.0 - sigmoid]
+ else:
+ gate = [self._gating_linears[name](y) for name, y in zip(self._input_sizes.keys(), init_gate)]
+ gate = sum(gate)
+ gate = gate.reshape([-1, len(self._input_sizes), self.gate_size])
+ gate = F.softmax(gate, dim=1)
+ assert gate.shape[1] == len(self._input_sizes)
+ gate = [gate[:, i] for i in range(len(self._input_sizes))]
+ return gate
+
+ def forward(self, inputs: Dict[str, Tensor]) -> Tensor:
+ """
+ Overview:
+ Forward pass through the VectorMerge module.
+ Arguments:
+ - inputs (:obj:`Dict[str, Tensor]`): The input tensors.
+ Returns:
+ - output (:obj:`Tensor`): The output tensor after passing through the module.
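+ Example:
+ >>> # Illustrative usage sketch; the stream names and sizes are arbitrary assumptions.
+ >>> merge = VectorMerge({'scalar': 8, 'entity': 16}, output_size=32, gating_type=GatingType.GLOBAL)
+ >>> inputs = {'scalar': torch.randn(4, 8), 'entity': torch.randn(4, 16)}
+ >>> assert merge(inputs).shape == (4, 32)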
+ """
+ gates, outputs = self.encode(inputs)
+ if len(outputs) == 1:
+ # Special case of 1-D inputs that do not need any gating.
+ output = outputs[0]
+ elif self._gating_type is GatingType.NONE:
+ output = sum(outputs)
+ else:
+ gate = self._compute_gate(gates)
+ data = [g * d for g, d in zip(gate, outputs)]
+ output = sum(data)
+ return output
diff --git a/DI-engine/ding/torch_utils/network/nn_module.py b/DI-engine/ding/torch_utils/network/nn_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..64a21edfe473b4ebe6aad881ac53a680f93ed26e
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/nn_module.py
@@ -0,0 +1,790 @@
+import math
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.nn.init import xavier_normal_, kaiming_normal_, orthogonal_
+from typing import Union, Tuple, List, Callable
+from ding.compatibility import torch_ge_131
+
+from .normalization import build_normalization
+
+
+def weight_init_(weight: torch.Tensor, init_type: str = "xavier", activation: str = None) -> None:
+ """
+ Overview:
+ Initialize weight according to the specified type.
+ Arguments:
+ - weight (:obj:`torch.Tensor`): The weight that needs to be initialized.
+ - init_type (:obj:`str`, optional): The type of initialization to implement, \
+ supports ["xavier", "kaiming", "orthogonal"].
+ - activation (:obj:`str`, optional): The activation function name. Recommended to use only with \
+ ['relu', 'leaky_relu'].
+ """
+
+ def xavier_init(weight, *args):
+ xavier_normal_(weight)
+
+ def kaiming_init(weight, activation):
+ assert activation is not None
+ if hasattr(activation, "negative_slope"):
+ kaiming_normal_(weight, a=activation.negative_slope)
+ else:
+ kaiming_normal_(weight, a=0)
+
+ def orthogonal_init(weight, *args):
+ orthogonal_(weight)
+
+ if init_type is None:
+ return
+ init_type_dict = {"xavier": xavier_init, "kaiming": kaiming_init, "orthogonal": orthogonal_init}
+ if init_type in init_type_dict:
+ init_type_dict[init_type](weight, activation)
+ else:
+ raise KeyError("Invalid Value in init type: {}".format(init_type))
+
+
+def sequential_pack(layers: List[nn.Module]) -> nn.Sequential:
+ """
+ Overview:
+ Pack the layers in the input list to a `nn.Sequential` module.
+ If there is a convolutional layer in module, an extra attribute `out_channels` will be added
+ to the module and set to the out_channel of the conv layer.
+ Arguments:
+ - layers (:obj:`List[nn.Module]`): The input list of layers.
+ Returns:
+ - seq (:obj:`nn.Sequential`): Packed sequential container.
+ """
+ assert isinstance(layers, list)
+ seq = nn.Sequential(*layers)
+ for item in reversed(layers):
+ if isinstance(item, nn.Conv2d) or isinstance(item, nn.ConvTranspose2d):
+ seq.out_channels = item.out_channels
+ break
+ elif isinstance(item, nn.Conv1d):
+ seq.out_channels = item.out_channels
+ break
+ return seq
+
+
+def conv1d_block(
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ padding: int = 0,
+ dilation: int = 1,
+ groups: int = 1,
+ activation: nn.Module = None,
+ norm_type: str = None
+) -> nn.Sequential:
+ """
+ Overview:
+ Create a 1-dimensional convolution layer with activation and normalization.
+ Arguments:
+ - in_channels (:obj:`int`): Number of channels in the input tensor.
+ - out_channels (:obj:`int`): Number of channels in the output tensor.
+ - kernel_size (:obj:`int`): Size of the convolving kernel.
+ - stride (:obj:`int`, optional): Stride of the convolution. Default is 1.
+ - padding (:obj:`int`, optional): Zero-padding added to both sides of the input. Default is 0.
+ - dilation (:obj:`int`, optional): Spacing between kernel elements. Default is 1.
+ - groups (:obj:`int`, optional): Number of blocked connections from input channels to output channels. \
+ Default is 1.
+ - activation (:obj:`nn.Module`, optional): The optional activation function.
+ - norm_type (:obj:`str`, optional): Type of the normalization.
+ Returns:
+ - block (:obj:`nn.Sequential`): A sequential list containing the torch layers of the 1-dimensional \
+ convolution layer.
+
+ .. note::
+ Conv1d (https://pytorch.org/docs/stable/generated/torch.nn.Conv1d.html#torch.nn.Conv1d)
+ """
+ block = []
+ block.append(nn.Conv1d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups))
+ if norm_type is not None:
+ block.append(build_normalization(norm_type, dim=1)(out_channels))
+ if activation is not None:
+ block.append(activation)
+ return sequential_pack(block)
+
+
+def conv2d_block(
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ padding: int = 0,
+ dilation: int = 1,
+ groups: int = 1,
+ pad_type: str = 'zero',
+ activation: nn.Module = None,
+ norm_type: str = None,
+ num_groups_for_gn: int = 1,
+ bias: bool = True
+) -> nn.Sequential:
+ """
+ Overview:
+ Create a 2-dimensional convolution layer with activation and normalization.
+ Arguments:
+ - in_channels (:obj:`int`): Number of channels in the input tensor.
+ - out_channels (:obj:`int`): Number of channels in the output tensor.
+ - kernel_size (:obj:`int`): Size of the convolving kernel.
+ - stride (:obj:`int`, optional): Stride of the convolution. Default is 1.
+ - padding (:obj:`int`, optional): Zero-padding added to both sides of the input. Default is 0.
+ - dilation (:obj:`int`): Spacing between kernel elements.
+ - groups (:obj:`int`, optional): Number of blocked connections from input channels to output channels. \
+ Default is 1.
+ - pad_type (:obj:`str`, optional): The way to add padding, one of ['zero', 'reflect', 'replication']. \
+ Default is 'zero'.
+ - activation (:obj:`nn.Module`): the optional activation function.
+ - norm_type (:obj:`str`): The type of the normalization, now support ['BN', 'LN', 'IN', 'GN', 'SyncBN'], \
+ default set to None, which means no normalization.
+ - num_groups_for_gn (:obj:`int`): Number of groups for GroupNorm.
+ - bias (:obj:`bool`): whether to add a learnable bias to the nn.Conv2d. Default is True.
+ Returns:
+ - block (:obj:`nn.Sequential`): A sequential list containing the torch layers of the 2-dimensional \
+ convolution layer.
+
+ .. note::
+ Conv2d (https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html#torch.nn.Conv2d)
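+
+ Example:
+ >>> # Illustrative usage sketch; the shapes are arbitrary assumptions.
+ >>> block = conv2d_block(3, 16, kernel_size=3, stride=1, padding=1, activation=nn.ReLU())
+ >>> assert block(torch.randn(2, 3, 32, 32)).shape == (2, 16, 32, 32)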
+ """
+ block = []
+ assert pad_type in ['zero', 'reflect', 'replication'], "invalid padding type: {}".format(pad_type)
+ if pad_type == 'zero':
+ pass
+ elif pad_type == 'reflect':
+ block.append(nn.ReflectionPad2d(padding))
+ padding = 0
+ elif pad_type == 'replication':
+ block.append(nn.ReplicationPad2d(padding))
+ padding = 0
+ block.append(
+ nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ bias=bias
+ )
+ )
+ if norm_type is not None:
+ if norm_type == 'LN':
+ # LN is implemented as GroupNorm with 1 group.
+ block.append(nn.GroupNorm(1, out_channels))
+ elif norm_type == 'GN':
+ block.append(nn.GroupNorm(num_groups_for_gn, out_channels))
+ elif norm_type in ['BN', 'IN', 'SyncBN']:
+ block.append(build_normalization(norm_type, dim=2)(out_channels))
+ else:
+ raise KeyError(
+ "Invalid value in norm_type: {}. The valid norm_type are "
+ "BN, LN, IN, GN and SyncBN.".format(norm_type)
+ )
+
+ if activation is not None:
+ block.append(activation)
+ return sequential_pack(block)
+
+
+def deconv2d_block(
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ padding: int = 0,
+ output_padding: int = 0,
+ groups: int = 1,
+ activation: nn.Module = None,
+ norm_type: str = None
+) -> nn.Sequential:
+ """
+ Overview:
+ Create a 2-dimensional transpose convolution layer with activation and normalization.
+ Arguments:
+ - in_channels (:obj:`int`): Number of channels in the input tensor.
+ - out_channels (:obj:`int`): Number of channels in the output tensor.
+ - kernel_size (:obj:`int`): Size of the convolving kernel.
+ - stride (:obj:`int`, optional): Stride of the convolution. Default is 1.
+ - padding (:obj:`int`, optional): Zero-padding added to both sides of the input. Default is 0.
+ - output_padding (:obj:`int`, optional): Additional size added to one side of the output shape. Default is 0.
+ - groups (:obj:`int`, optional): Number of blocked connections from input channels to output channels. \
+ Default is 1.
+ - activation (:obj:`nn.Module`, optional): The optional activation function.
+ - norm_type (:obj:`str`, optional): Type of the normalization.
+ Returns:
+ - block (:obj:`nn.Sequential`): A sequential list containing the torch layers of the 2-dimensional \
+ transpose convolution layer.
+
+ .. note::
+
+ ConvTranspose2d (https://pytorch.org/docs/master/generated/torch.nn.ConvTranspose2d.html)
+ """
+ block = [
+ nn.ConvTranspose2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ output_padding=output_padding,
+ groups=groups
+ )
+ ]
+ if norm_type is not None:
+ block.append(build_normalization(norm_type, dim=2)(out_channels))
+ if activation is not None:
+ block.append(activation)
+ return sequential_pack(block)
+
+
+def fc_block(
+ in_channels: int,
+ out_channels: int,
+ activation: nn.Module = None,
+ norm_type: str = None,
+ use_dropout: bool = False,
+ dropout_probability: float = 0.5
+) -> nn.Sequential:
+ """
+ Overview:
+ Create a fully-connected block with activation, normalization, and dropout.
+ Optional normalization can be done to the dim 1 (across the channels).
+ x -> fc -> norm -> act -> dropout -> out
+ Arguments:
+ - in_channels (:obj:`int`): Number of channels in the input tensor.
+ - out_channels (:obj:`int`): Number of channels in the output tensor.
+ - activation (:obj:`nn.Module`, optional): The optional activation function.
+ - norm_type (:obj:`str`, optional): Type of the normalization.
+ - use_dropout (:obj:`bool`, optional): Whether to use dropout in the fully-connected block. Default is False.
+ - dropout_probability (:obj:`float`, optional): Probability of an element to be zeroed in the dropout. \
+ Default is 0.5.
+ Returns:
+ - block (:obj:`nn.Sequential`): A sequential list containing the torch layers of the fully-connected block.
+
+ .. note::
+
+ You can refer to nn.linear (https://pytorch.org/docs/master/generated/torch.nn.Linear.html).
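+
+ Example:
+ >>> # Illustrative usage sketch; the channel sizes are arbitrary assumptions.
+ >>> block = fc_block(32, 64, activation=nn.ReLU())
+ >>> assert block(torch.randn(4, 32)).shape == (4, 64)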
+ """
+ block = []
+ block.append(nn.Linear(in_channels, out_channels))
+ if norm_type is not None:
+ block.append(build_normalization(norm_type, dim=1)(out_channels))
+ if activation is not None:
+ block.append(activation)
+ if use_dropout:
+ block.append(nn.Dropout(dropout_probability))
+ return sequential_pack(block)
+
+
+def normed_linear(
+ in_features: int,
+ out_features: int,
+ bias: bool = True,
+ device=None,
+ dtype=None,
+ scale: float = 1.0
+) -> nn.Linear:
+ """
+ Overview:
+ Create a nn.Linear module but with normalized fan-in init.
+ Arguments:
+ - in_features (:obj:`int`): Number of features in the input tensor.
+ - out_features (:obj:`int`): Number of features in the output tensor.
+ - bias (:obj:`bool`, optional): Whether to add a learnable bias to the nn.Linear. Default is True.
+ - device (:obj:`torch.device`, optional): The device to put the created module on. Default is None.
+ - dtype (:obj:`torch.dtype`, optional): The desired data type of created module. Default is None.
+ - scale (:obj:`float`, optional): The scale factor for initialization. Default is 1.0.
+ Returns:
+ - out (:obj:`nn.Linear`): A nn.Linear module with normalized fan-in init.
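+ Example:
+ >>> # Illustrative sketch; verifies that each weight row is rescaled to norm ``scale``.
+ >>> layer = normed_linear(16, 8, scale=1.0)
+ >>> assert torch.allclose(layer.weight.norm(dim=1, p=2), torch.ones(8))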
+ """
+
+ out = nn.Linear(in_features, out_features, bias)
+
+ out.weight.data *= scale / out.weight.norm(dim=1, p=2, keepdim=True)
+ if bias:
+ out.bias.data.zero_()
+ return out
+
+
+def normed_conv2d(
+ in_channels: int,
+ out_channels: int,
+ kernel_size: Union[int, Tuple[int, int]],
+ stride: Union[int, Tuple[int, int]] = 1,
+ padding: Union[int, Tuple[int, int]] = 0,
+ dilation: Union[int, Tuple[int, int]] = 1,
+ groups: int = 1,
+ bias: bool = True,
+ padding_mode: str = 'zeros',
+ device=None,
+ dtype=None,
+ scale: float = 1
+) -> nn.Conv2d:
+ """
+ Overview:
+ Create a nn.Conv2d module but with normalized fan-in init.
+ Arguments:
+ - in_channels (:obj:`int`): Number of channels in the input tensor.
+ - out_channels (:obj:`int`): Number of channels in the output tensor.
+ - kernel_size (:obj:`Union[int, Tuple[int, int]]`): Size of the convolving kernel.
+ - stride (:obj:`Union[int, Tuple[int, int]]`, optional): Stride of the convolution. Default is 1.
+ - padding (:obj:`Union[int, Tuple[int, int]]`, optional): Zero-padding added to both sides of the input. \
+ Default is 0.
+ - dilation (:`Union[int, Tuple[int, int]]`, optional): Spacing between kernel elements. Default is 1.
+ - groups (:obj:`int`, optional): Number of blocked connections from input channels to output channels. \
+ Default is 1.
+ - bias (:obj:`bool`, optional): Whether to add a learnable bias to the nn.Conv2d. Default is True.
+ - padding_mode (:obj:`str`, optional): The type of padding algorithm to use. Default is 'zeros'.
+ - device (:obj:`torch.device`, optional): The device to put the created module on. Default is None.
+ - dtype (:obj:`torch.dtype`, optional): The desired data type of created module. Default is None.
+ - scale (:obj:`float`, optional): The scale factor for initialization. Default is 1.
+ Returns:
+ - out (:obj:`nn.Conv2d`): A nn.Conv2d module with normalized fan-in init.
+ """
+
+ out = nn.Conv2d(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ groups,
+ bias,
+ padding_mode,
+ )
+ out.weight.data *= scale / out.weight.norm(dim=(1, 2, 3), p=2, keepdim=True)
+ if bias:
+ out.bias.data.zero_()
+ return out
+
+
+def MLP(
+ in_channels: int,
+ hidden_channels: int,
+ out_channels: int,
+ layer_num: int,
+ layer_fn: Callable = None,
+ activation: nn.Module = None,
+ norm_type: str = None,
+ use_dropout: bool = False,
+ dropout_probability: float = 0.5,
+ output_activation: bool = True,
+ output_norm: bool = True,
+ last_linear_layer_init_zero: bool = False
+):
+ """
+ Overview:
+ Create a multi-layer perceptron using fully-connected blocks with activation, normalization, and dropout,
+ optional normalization can be done to the dim 1 (across the channels).
+ x -> fc -> norm -> act -> dropout -> out
+ Arguments:
+ - in_channels (:obj:`int`): Number of channels in the input tensor.
+ - hidden_channels (:obj:`int`): Number of channels in the hidden tensor.
+ - out_channels (:obj:`int`): Number of channels in the output tensor.
+ - layer_num (:obj:`int`): Number of layers.
+ - layer_fn (:obj:`Callable`, optional): Layer function.
+ - activation (:obj:`nn.Module`, optional): The optional activation function.
+ - norm_type (:obj:`str`, optional): The type of the normalization.
+ - use_dropout (:obj:`bool`, optional): Whether to use dropout in the fully-connected block. Default is False.
+ - dropout_probability (:obj:`float`, optional): Probability of an element to be zeroed in the dropout. \
+ Default is 0.5.
+ - output_activation (:obj:`bool`, optional): Whether to use activation in the output layer. If True, \
+ we use the same activation as front layers. Default is True.
+ - output_norm (:obj:`bool`, optional): Whether to use normalization in the output layer. If True, \
+ we use the same normalization as front layers. Default is True.
+ - last_linear_layer_init_zero (:obj:`bool`, optional): Whether to use zero initializations for the last \
+ linear layer (including w and b), which can provide stable zero outputs in the beginning, \
+ usually used in the policy network in RL settings.
+ Returns:
+ - block (:obj:`nn.Sequential`): A sequential list containing the torch layers of the multi-layer perceptron.
+
+ .. note::
+ you can refer to nn.linear (https://pytorch.org/docs/master/generated/torch.nn.Linear.html).
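+
+ Example:
+ >>> # Illustrative usage sketch; the channel sizes are arbitrary assumptions.
+ >>> mlp = MLP(16, 32, 8, layer_num=3, activation=nn.ReLU())
+ >>> assert mlp(torch.randn(4, 16)).shape == (4, 8)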
+ """
+ assert layer_num >= 0, layer_num
+ if layer_num == 0:
+ return sequential_pack([nn.Identity()])
+
+ channels = [in_channels] + [hidden_channels] * (layer_num - 1) + [out_channels]
+ if layer_fn is None:
+ layer_fn = nn.Linear
+ block = []
+ for i, (in_channels, out_channels) in enumerate(zip(channels[:-2], channels[1:-1])):
+ block.append(layer_fn(in_channels, out_channels))
+ if norm_type is not None:
+ block.append(build_normalization(norm_type, dim=1)(out_channels))
+ if activation is not None:
+ block.append(activation)
+ if use_dropout:
+ block.append(nn.Dropout(dropout_probability))
+
+ # The last layer
+ in_channels = channels[-2]
+ out_channels = channels[-1]
+ block.append(layer_fn(in_channels, out_channels))
+ """
+ In the final layer of a neural network, whether to use normalization and activation are typically determined
+ based on user specifications. These specifications depend on the problem at hand and the desired properties of
+ the model's output.
+ """
+ if output_norm is True:
+ # The last layer uses the same norm as front layers.
+ if norm_type is not None:
+ block.append(build_normalization(norm_type, dim=1)(out_channels))
+ if output_activation is True:
+ # The last layer uses the same activation as front layers.
+ if activation is not None:
+ block.append(activation)
+ if use_dropout:
+ block.append(nn.Dropout(dropout_probability))
+
+ if last_linear_layer_init_zero:
+ # Locate the last linear layer and initialize its weights and biases to 0.
+ for _, layer in enumerate(reversed(block)):
+ if isinstance(layer, nn.Linear):
+ nn.init.zeros_(layer.weight)
+ nn.init.zeros_(layer.bias)
+ break
+
+ return sequential_pack(block)
+
+
+class ChannelShuffle(nn.Module):
+ """
+ Overview:
+ Apply channel shuffle to the input tensor. For more details about the channel shuffle,
+ please refer to the 'ShuffleNet' paper: https://arxiv.org/abs/1707.01083
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(self, group_num: int) -> None:
+ """
+ Overview:
+ Initialize the ChannelShuffle class.
+ Arguments:
+ - group_num (:obj:`int`): The number of groups to exchange.
+ """
+ super().__init__()
+ self.group_num = group_num
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Forward pass through the ChannelShuffle module.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor.
+ Returns:
+ - x (:obj:`torch.Tensor`): The shuffled input tensor.
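+ Example:
+ >>> # Illustrative usage sketch; the channel count must be divisible by group_num.
+ >>> shuffle = ChannelShuffle(group_num=2)
+ >>> assert shuffle(torch.randn(2, 8, 4, 4)).shape == (2, 8, 4, 4)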
+ """
+ b, c, h, w = x.shape
+ g = self.group_num
+ assert (c % g == 0)
+ x = x.view(b, g, c // g, h, w).permute(0, 2, 1, 3, 4).contiguous().view(b, c, h, w)
+ return x
+
+
+def one_hot(val: torch.LongTensor, num: int, num_first: bool = False) -> torch.FloatTensor:
+ """
+ Overview:
+ Convert a torch.LongTensor to one-hot encoding. This implementation can be slightly faster than
+ ``torch.nn.functional.one_hot``.
+ Arguments:
+ - val (:obj:`torch.LongTensor`): Each element contains the state to be encoded, the range should be [0, num-1]
+ - num (:obj:`int`): Number of states of the one-hot encoding
+ - num_first (:obj:`bool`, optional): If False, the one-hot encoding is added as the last dimension; otherwise, \
+ it is added as the first dimension. Default is False.
+ Returns:
+ - one_hot (:obj:`torch.FloatTensor`): The one-hot encoded tensor.
+ Example:
+ >>> one_hot(2*torch.ones([2,2]).long(),3)
+ tensor([[[0., 0., 1.],
+ [0., 0., 1.]],
+ [[0., 0., 1.],
+ [0., 0., 1.]]])
+ >>> one_hot(2*torch.ones([2,2]).long(),3,num_first=True)
+ tensor([[[0., 0.], [1., 0.]],
+ [[0., 1.], [0., 0.]],
+ [[1., 0.], [0., 1.]]])
+ """
+ assert (isinstance(val, torch.Tensor)), type(val)
+ assert val.dtype == torch.long
+ assert (len(val.shape) >= 1)
+ old_shape = val.shape
+ val_reshape = val.reshape(-1, 1)
+ ret = torch.zeros(val_reshape.shape[0], num, device=val.device)
+ # Remember the positions where the original value in val is -1.
+ # A value of -1 should be encoded as an all-zeros vector, so the corresponding entry of
+ # index_neg_one is set to 1 and is used below to zero out the rows produced by
+ # ret.scatter_(1, val_reshape, 1).
+ index_neg_one = torch.eq(val_reshape, -1).float()
+ if index_neg_one.sum() != 0: # if -1 exists in val
+ # convert the original value -1 to 0
+ val_reshape = torch.where(
+ val_reshape != -1, val_reshape,
+ torch.zeros(val_reshape.shape, device=val.device).long()
+ )
+ try:
+ ret.scatter_(1, val_reshape, 1)
+ if index_neg_one.sum() != 0: # if -1 exists in val
+ ret = ret * (1 - index_neg_one) # change -1's encoding from [1,0,...,0] to [0,0,...,0]
+ except RuntimeError:
+ raise RuntimeError('value: {}\nnum: {}\t:val_shape: {}\n'.format(val_reshape, num, val_reshape.shape))
+ if num_first:
+ return ret.permute(1, 0).reshape(num, *old_shape)
+ else:
+ return ret.reshape(*old_shape, num)
+
+
+class NearestUpsample(nn.Module):
+ """
+ Overview:
+ This module upsamples the input to the given scale_factor using the nearest mode.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(self, scale_factor: Union[float, List[float]]) -> None:
+ """
+ Overview:
+ Initialize the NearestUpsample class.
+ Arguments:
+ - scale_factor (:obj:`Union[float, List[float]]`): The multiplier for the spatial size.
+ """
+ super(NearestUpsample, self).__init__()
+ self.scale_factor = scale_factor
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Return the upsampled input tensor.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor.
+ Returns:
+ - upsample(:obj:`torch.Tensor`): The upsampled input tensor.
+ """
+ return F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')
+
+
+class BilinearUpsample(nn.Module):
+ """
+ Overview:
+ This module upsamples the input to the given scale_factor using the bilinear mode.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(self, scale_factor: Union[float, List[float]]) -> None:
+ """
+ Overview:
+ Initialize the BilinearUpsample class.
+ Arguments:
+ - scale_factor (:obj:`Union[float, List[float]]`): The multiplier for the spatial size.
+ """
+ super(BilinearUpsample, self).__init__()
+ self.scale_factor = scale_factor
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Return the upsampled input tensor.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor.
+ Returns:
+            - upsample (:obj:`torch.Tensor`): The upsampled input tensor.
+ """
+ return F.interpolate(x, scale_factor=self.scale_factor, mode='bilinear', align_corners=False)
+
+
+def binary_encode(y: torch.Tensor, max_val: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Convert elements in a tensor to its binary representation.
+ Arguments:
+ - y (:obj:`torch.Tensor`): The tensor to be converted into its binary representation.
+ - max_val (:obj:`torch.Tensor`): The maximum value of the elements in the tensor.
+ Returns:
+ - binary (:obj:`torch.Tensor`): The input tensor in its binary representation.
+ Example:
+ >>> binary_encode(torch.tensor([3,2]),torch.tensor(8))
+ tensor([[0, 0, 1, 1],[0, 0, 1, 0]])
+ """
+ assert (max_val > 0)
+ x = y.clamp(0, max_val)
+ L = int(math.log(max_val, 2)) + 1
+ binary = []
+ one = torch.ones_like(x)
+ zero = torch.zeros_like(x)
+ for i in range(L):
+ num = 1 << (L - i - 1) # 2**(L-i-1)
+ bit = torch.where(x >= num, one, zero)
+ x -= bit * num
+ binary.append(bit)
+ return torch.stack(binary, dim=1)
+
+
+class NoiseLinearLayer(nn.Module):
+ """
+ Overview:
+ This is a linear layer with random noise.
+ Interfaces:
+ ``__init__``, ``reset_noise``, ``reset_parameters``, ``forward``
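+    Examples:
+        A minimal usage sketch; the feature sizes below are illustrative, not prescribed by the layer:
+        >>> layer = NoiseLinearLayer(128, 64)
+        >>> x = torch.randn(4, 128)
+        >>> y = layer(x)  # noisy weights are used because modules are in training mode by default
+        >>> y.shape
+        torch.Size([4, 64])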
+ """
+
+    def __init__(self, in_channels: int, out_channels: int, sigma0: float = 0.4) -> None:
+ """
+ Overview:
+ Initialize the NoiseLinearLayer class.
+ Arguments:
+ - in_channels (:obj:`int`): Number of channels in the input tensor.
+ - out_channels (:obj:`int`): Number of channels in the output tensor.
+            - sigma0 (:obj:`float`, optional): Default noise volume when initializing NoiseLinearLayer. \
+ Default is 0.4.
+ """
+ super(NoiseLinearLayer, self).__init__()
+ self.in_channels = in_channels
+ self.out_channels = out_channels
+ self.weight_mu = nn.Parameter(torch.Tensor(out_channels, in_channels))
+ self.weight_sigma = nn.Parameter(torch.Tensor(out_channels, in_channels))
+ self.bias_mu = nn.Parameter(torch.Tensor(out_channels))
+ self.bias_sigma = nn.Parameter(torch.Tensor(out_channels))
+ self.register_buffer("weight_eps", torch.empty(out_channels, in_channels))
+ self.register_buffer("bias_eps", torch.empty(out_channels))
+ self.sigma0 = sigma0
+ self.reset_parameters()
+ self.reset_noise()
+
+ def _scale_noise(self, size: Union[int, Tuple]):
+ """
+ Overview:
+ Scale the noise.
+ Arguments:
+ - size (:obj:`Union[int, Tuple]`): The size of the noise.
+ """
+
+ x = torch.randn(size)
+ x = x.sign().mul(x.abs().sqrt())
+ return x
+
+ def reset_noise(self):
+ """
+ Overview:
+ Reset the noise settings in the layer.
+ """
+ is_cuda = self.weight_mu.is_cuda
+ in_noise = self._scale_noise(self.in_channels).to(torch.device("cuda" if is_cuda else "cpu"))
+ out_noise = self._scale_noise(self.out_channels).to(torch.device("cuda" if is_cuda else "cpu"))
+ self.weight_eps = out_noise.ger(in_noise)
+ self.bias_eps = out_noise
+
+ def reset_parameters(self):
+ """
+ Overview:
+ Reset the parameters in the layer.
+ """
+ stdv = 1. / math.sqrt(self.in_channels)
+ self.weight_mu.data.uniform_(-stdv, stdv)
+ self.bias_mu.data.uniform_(-stdv, stdv)
+
+ std_weight = self.sigma0 / math.sqrt(self.in_channels)
+ self.weight_sigma.data.fill_(std_weight)
+ std_bias = self.sigma0 / math.sqrt(self.out_channels)
+ self.bias_sigma.data.fill_(std_bias)
+
+ def forward(self, x: torch.Tensor):
+ """
+ Overview:
+ Perform the forward pass with noise.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor.
+ Returns:
+ - output (:obj:`torch.Tensor`): The output tensor with noise.
+ """
+ if self.training:
+ return F.linear(
+ x,
+ self.weight_mu + self.weight_sigma * self.weight_eps,
+ self.bias_mu + self.bias_sigma * self.bias_eps,
+ )
+ else:
+ return F.linear(x, self.weight_mu, self.bias_mu)
+
+
+def noise_block(
+ in_channels: int,
+ out_channels: int,
+        activation: nn.Module = None,
+ norm_type: str = None,
+ use_dropout: bool = False,
+ dropout_probability: float = 0.5,
+ sigma0: float = 0.4
+):
+ """
+ Overview:
+ Create a fully-connected noise layer with activation, normalization, and dropout.
+        Optional normalization is applied over dim 1 (across the channels).
+ Arguments:
+ - in_channels (:obj:`int`): Number of channels in the input tensor.
+ - out_channels (:obj:`int`): Number of channels in the output tensor.
+        - activation (:obj:`nn.Module`, optional): The optional activation function. Default is None.
+ - norm_type (:obj:`str`, optional): Type of normalization. Default is None.
+ - use_dropout (:obj:`bool`, optional): Whether to use dropout in the fully-connected block.
+ - dropout_probability (:obj:`float`, optional): Probability of an element to be zeroed in the dropout. \
+ Default is 0.5.
+        - sigma0 (:obj:`float`, optional): The default noise volume used when initializing NoiseLinearLayer. \
+ Default is 0.4.
+ Returns:
+ - block (:obj:`nn.Sequential`): A sequential list containing the torch layers of the fully-connected block.
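+    Examples:
+        A usage sketch with illustrative sizes; the chosen activation instance is an assumption of this example:
+        >>> block = noise_block(64, 32, activation=nn.ReLU())
+        >>> x = torch.randn(4, 64)
+        >>> block(x).shape
+        torch.Size([4, 32])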
+ """
+ block = [NoiseLinearLayer(in_channels, out_channels, sigma0=sigma0)]
+ if norm_type is not None:
+ block.append(build_normalization(norm_type, dim=1)(out_channels))
+ if activation is not None:
+ block.append(activation)
+ if use_dropout:
+ block.append(nn.Dropout(dropout_probability))
+ return sequential_pack(block)
+
+
+class NaiveFlatten(nn.Module):
+ """
+ Overview:
+ This module is a naive implementation of the flatten operation.
+ Interfaces:
+ ``__init__``, ``forward``
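+    Examples:
+        An illustrative call with an arbitrary input shape:
+        >>> flatten = NaiveFlatten()  # flatten all dimensions except the first one
+        >>> flatten(torch.randn(4, 3, 8, 8)).shape
+        torch.Size([4, 192])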
+ """
+
+ def __init__(self, start_dim: int = 1, end_dim: int = -1) -> None:
+ """
+ Overview:
+ Initialize the NaiveFlatten class.
+ Arguments:
+ - start_dim (:obj:`int`, optional): The first dimension to flatten. Default is 1.
+ - end_dim (:obj:`int`, optional): The last dimension to flatten. Default is -1.
+ """
+ super(NaiveFlatten, self).__init__()
+ self.start_dim = start_dim
+ self.end_dim = end_dim
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Perform the flatten operation on the input tensor.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor.
+ Returns:
+ - output (:obj:`torch.Tensor`): The flattened output tensor.
+ """
+ if self.end_dim != -1:
+ return x.view(*x.shape[:self.start_dim], -1, *x.shape[self.end_dim + 1:])
+ else:
+ return x.view(*x.shape[:self.start_dim], -1)
+
+
+if torch_ge_131():
+ Flatten = nn.Flatten
+else:
+ Flatten = NaiveFlatten
diff --git a/DI-engine/ding/torch_utils/network/normalization.py b/DI-engine/ding/torch_utils/network/normalization.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6d4c8df3215ed3dd061184faaf64346383dfb3c
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/normalization.py
@@ -0,0 +1,36 @@
+from typing import Optional
+import torch.nn as nn
+
+
+def build_normalization(norm_type: str, dim: Optional[int] = None) -> nn.Module:
+ """
+ Overview:
+ Construct the corresponding normalization module. For beginners,
+ refer to [this article](https://zhuanlan.zhihu.com/p/34879333) to learn more about batch normalization.
+ Arguments:
+ - norm_type (:obj:`str`): Type of the normalization. Currently supports ['BN', 'LN', 'IN', 'SyncBN'].
+ - dim (:obj:`Optional[int]`): Dimension of the normalization, applicable when norm_type is in ['BN', 'IN'].
+ Returns:
+        - norm_func (:obj:`nn.Module`): The corresponding normalization module class.
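+    Examples:
+        A small usage sketch; the norm type and feature size are chosen only for illustration:
+        >>> norm = build_normalization('BN', dim=1)(64)  # returns the class, then instantiate it
+        >>> isinstance(norm, nn.BatchNorm1d)
+        True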
+ """
+ if dim is None:
+ key = norm_type
+ else:
+ if norm_type in ['BN', 'IN']:
+ key = norm_type + str(dim)
+ elif norm_type in ['LN', 'SyncBN']:
+ key = norm_type
+ else:
+            raise NotImplementedError("dim argument is not supported when creating {} normalization".format(norm_type))
+ norm_func = {
+ 'BN1': nn.BatchNorm1d,
+ 'BN2': nn.BatchNorm2d,
+ 'LN': nn.LayerNorm,
+ 'IN1': nn.InstanceNorm1d,
+ 'IN2': nn.InstanceNorm2d,
+ 'SyncBN': nn.SyncBatchNorm,
+ }
+ if key in norm_func.keys():
+ return norm_func[key]
+ else:
+ raise KeyError("invalid norm type: {}".format(key))
diff --git a/DI-engine/ding/torch_utils/network/popart.py b/DI-engine/ding/torch_utils/network/popart.py
new file mode 100644
index 0000000000000000000000000000000000000000..e01406a57ac96f8c19220c06e698901c95f61c9e
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/popart.py
@@ -0,0 +1,125 @@
+"""
+Implementation of ``POPART`` algorithm for reward rescale.
+
+
+POPART is an adaptive normalization algorithm to normalize the targets used in the learning updates.
+The two main components in POPART are:
+**ART**: to update scale and shift such that the return is appropriately normalized,
+**POP**: to preserve the outputs of the unnormalized function when we change the scale and shift.
+
+"""
+from typing import Optional, Union, Dict
+import math
+import torch
+import torch.nn as nn
+
+
+class PopArt(nn.Module):
+ """
+ Overview:
+        A linear layer with PopArt normalization. This class implements a linear transformation followed by
+        PopArt normalization, a method to automatically adapt the contribution of each task to the agent's
+        updates in multi-task learning, as described in the paper `Multi-task Deep Reinforcement Learning with
+        PopArt`.
+
+ Interfaces:
+ ``__init__``, ``reset_parameters``, ``forward``, ``update_parameters``
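+    Examples:
+        A rough usage sketch; the feature sizes and the way the statistics are consumed afterwards are
+        assumptions of this example, not requirements of the class:
+        >>> head = PopArt(input_features=64, output_features=1)
+        >>> out = head(torch.randn(8, 64))          # {'pred': ..., 'unnormalized_pred': ...}
+        >>> target = torch.randn(8, 1)
+        >>> stats = head.update_parameters(target)  # refresh mu/sigma and rescale weight/bias
+        >>> normalized_target = (target - stats['new_mean']) / stats['new_std']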
+ """
+
+ def __init__(
+ self,
+ input_features: Union[int, None] = None,
+ output_features: Union[int, None] = None,
+ beta: float = 0.5
+ ) -> None:
+ """
+ Overview:
+ Initialize the class with input features, output features, and the beta parameter.
+ Arguments:
+ - input_features (:obj:`Union[int, None]`): The size of each input sample.
+ - output_features (:obj:`Union[int, None]`): The size of each output sample.
+ - beta (:obj:`float`): The parameter for moving average.
+ """
+ super(PopArt, self).__init__()
+
+ self.beta = beta
+ self.input_features = input_features
+ self.output_features = output_features
+ # Initialize the linear layer parameters, weight and bias.
+ self.weight = nn.Parameter(torch.Tensor(output_features, input_features))
+ self.bias = nn.Parameter(torch.Tensor(output_features))
+        # Register buffers for the normalization statistics, which are not model parameters.
+        # They store the scale and shift of the target values and are used later for unnormalization.
+ self.register_buffer('mu', torch.zeros(output_features, requires_grad=False))
+ self.register_buffer('sigma', torch.ones(output_features, requires_grad=False))
+ self.register_buffer('v', torch.ones(output_features, requires_grad=False))
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ """
+ Overview:
+ Reset the parameters including weights and bias using kaiming_uniform_ and uniform_ initialization.
+ """
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
+ if self.bias is not None:
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
+ bound = 1 / math.sqrt(fan_in)
+ nn.init.uniform_(self.bias, -bound, bound)
+
+ def forward(self, x: torch.Tensor) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+            Implement the forward computation of the linear layer and return both the normalized and the
+            unnormalized output of the layer.
+ Arguments:
+ - x (:obj:`torch.Tensor`): Input tensor which is to be normalized.
+ Returns:
+            - output (:obj:`Dict[str, torch.Tensor]`): A dictionary containing 'pred' and 'unnormalized_pred'.
+ """
+ normalized_output = x.mm(self.weight.t())
+ normalized_output += self.bias.unsqueeze(0).expand_as(normalized_output)
+        # Unnormalize the output with the stored scale and shift.
+ with torch.no_grad():
+ output = normalized_output * self.sigma + self.mu
+
+ return {'pred': normalized_output.squeeze(1), 'unnormalized_pred': output.squeeze(1)}
+
+ def update_parameters(self, value: torch.Tensor) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ Update the normalization parameters based on the given value and return the new mean and
+ standard deviation after the update.
+ Arguments:
+ - value (:obj:`torch.Tensor`): The tensor to be used for updating parameters.
+ Returns:
+            - update_results (:obj:`Dict[str, torch.Tensor]`): A dictionary containing 'new_mean' and 'new_std'.
+ """
+ # Tensor device conversion of the normalization parameters.
+ self.mu = self.mu.to(value.device)
+ self.sigma = self.sigma.to(value.device)
+ self.v = self.v.to(value.device)
+
+ old_mu = self.mu
+ old_std = self.sigma
+
+        # Calculate the first and second raw moments (mean and mean of squares) of the target value:
+ batch_mean = torch.mean(value, 0)
+ batch_v = torch.mean(torch.pow(value, 2), 0)
+ batch_mean[torch.isnan(batch_mean)] = self.mu[torch.isnan(batch_mean)]
+ batch_v[torch.isnan(batch_v)] = self.v[torch.isnan(batch_v)]
+ batch_mean = (1 - self.beta) * self.mu + self.beta * batch_mean
+ batch_v = (1 - self.beta) * self.v + self.beta * batch_v
+ batch_std = torch.sqrt(batch_v - (batch_mean ** 2))
+        # Clip the standard deviation to reject outliers.
+ batch_std = torch.clamp(batch_std, min=1e-4, max=1e+6)
+        # Replace NaN values with the old values.
+ batch_std[torch.isnan(batch_std)] = self.sigma[torch.isnan(batch_std)]
+
+ self.mu = batch_mean
+ self.v = batch_v
+ self.sigma = batch_std
+        # Update the weight and bias with the old and new statistics so that unnormalized outputs are preserved.
+ self.weight.data = (self.weight.data.t() * old_std / self.sigma).t()
+ self.bias.data = (old_std * self.bias.data + old_mu - self.mu) / self.sigma
+
+ return {'new_mean': batch_mean, 'new_std': batch_std}
diff --git a/DI-engine/ding/torch_utils/network/res_block.py b/DI-engine/ding/torch_utils/network/res_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..14223f940c0efa71208e607b9ca41207d4a39fbe
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/res_block.py
@@ -0,0 +1,152 @@
+from typing import Union
+
+import torch
+import torch.nn as nn
+
+from .nn_module import conv2d_block, fc_block
+
+
+class ResBlock(nn.Module):
+ """
+ Overview:
+ Residual Block with 2D convolution layers, including 3 types:
+ basic block:
+ input channel: C
+ x -> 3*3*C -> norm -> act -> 3*3*C -> norm -> act -> out
+ \__________________________________________/+
+ bottleneck block:
+ x -> 1*1*(1/4*C) -> norm -> act -> 3*3*(1/4*C) -> norm -> act -> 1*1*C -> norm -> act -> out
+ \_____________________________________________________________________________/+
+ downsample block: used in EfficientZero
+ input channel: C
+ x -> 3*3*C -> norm -> act -> 3*3*C -> norm -> act -> out
+ \__________________ 3*3*C ____________________/+
+        For more details, please refer to the paper `Deep Residual Learning for Image Recognition
+        <https://arxiv.org/abs/1512.03385>`_.
+ Interfaces:
+ ``__init__``, ``forward``
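+    Examples:
+        A minimal sketch of the basic block; the input shape is illustrative:
+        >>> block = ResBlock(in_channels=64, norm_type='BN', res_type='basic')
+        >>> x = torch.randn(2, 64, 8, 8)
+        >>> block(x).shape
+        torch.Size([2, 64, 8, 8])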
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ activation: nn.Module = nn.ReLU(),
+ norm_type: str = 'BN',
+ res_type: str = 'basic',
+ bias: bool = True,
+ out_channels: Union[int, None] = None,
+ ) -> None:
+ """
+ Overview:
+ Init the 2D convolution residual block.
+ Arguments:
+ - in_channels (:obj:`int`): Number of channels in the input tensor.
+ - activation (:obj:`nn.Module`): The optional activation function.
+ - norm_type (:obj:`str`): Type of the normalization, default set to 'BN'(Batch Normalization), \
+ supports ['BN', 'LN', 'IN', 'GN', 'SyncBN', None].
+ - res_type (:obj:`str`): Type of residual block, supports ['basic', 'bottleneck', 'downsample']
+ - bias (:obj:`bool`): Whether to add a learnable bias to the conv2d_block. default set to True.
+ - out_channels (:obj:`int`): Number of channels in the output tensor, default set to None, \
+ which means out_channels = in_channels.
+ """
+ super(ResBlock, self).__init__()
+ self.act = activation
+        assert res_type in ['basic', 'bottleneck', 'downsample'], \
+            'residual type only supports basic, bottleneck and downsample, not: {}'.format(res_type)
+ self.res_type = res_type
+ if out_channels is None:
+ out_channels = in_channels
+ if self.res_type == 'basic':
+ self.conv1 = conv2d_block(
+ in_channels, out_channels, 3, 1, 1, activation=self.act, norm_type=norm_type, bias=bias
+ )
+ self.conv2 = conv2d_block(
+ out_channels, out_channels, 3, 1, 1, activation=None, norm_type=norm_type, bias=bias
+ )
+ elif self.res_type == 'bottleneck':
+ self.conv1 = conv2d_block(
+ in_channels, out_channels, 1, 1, 0, activation=self.act, norm_type=norm_type, bias=bias
+ )
+ self.conv2 = conv2d_block(
+ out_channels, out_channels, 3, 1, 1, activation=self.act, norm_type=norm_type, bias=bias
+ )
+ self.conv3 = conv2d_block(
+ out_channels, out_channels, 1, 1, 0, activation=None, norm_type=norm_type, bias=bias
+ )
+ elif self.res_type == 'downsample':
+ self.conv1 = conv2d_block(
+ in_channels, out_channels, 3, 2, 1, activation=self.act, norm_type=norm_type, bias=bias
+ )
+ self.conv2 = conv2d_block(
+ out_channels, out_channels, 3, 1, 1, activation=None, norm_type=norm_type, bias=bias
+ )
+ self.conv3 = conv2d_block(in_channels, out_channels, 3, 2, 1, activation=None, norm_type=None, bias=bias)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+            Return the residual block output.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor.
+ Returns:
+ - x (:obj:`torch.Tensor`): The resblock output tensor.
+ """
+ identity = x
+ x = self.conv1(x)
+ x = self.conv2(x)
+ if self.res_type == 'bottleneck':
+ x = self.conv3(x)
+ elif self.res_type == 'downsample':
+ identity = self.conv3(identity)
+ x = self.act(x + identity)
+ return x
+
+
+class ResFCBlock(nn.Module):
+ """
+ Overview:
+ Residual Block with 2 fully connected layers.
+ x -> fc1 -> norm -> act -> fc2 -> norm -> act -> out
+ \_____________________________________/+
+
+ Interfaces:
+ ``__init__``, ``forward``
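+    Examples:
+        An illustrative forward pass; the channel size is arbitrary:
+        >>> block = ResFCBlock(in_channels=256)
+        >>> block(torch.randn(4, 256)).shape
+        torch.Size([4, 256])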
+ """
+
+ def __init__(
+ self, in_channels: int, activation: nn.Module = nn.ReLU(), norm_type: str = 'BN', dropout: float = None
+ ):
+ """
+ Overview:
+ Init the fully connected layer residual block.
+ Arguments:
+ - in_channels (:obj:`int`): The number of channels in the input tensor.
+ - activation (:obj:`nn.Module`): The optional activation function.
+ - norm_type (:obj:`str`): The type of the normalization, default set to 'BN'.
+ - dropout (:obj:`float`): The dropout rate, default set to None.
+ """
+ super(ResFCBlock, self).__init__()
+ self.act = activation
+ if dropout is not None:
+ self.dropout = nn.Dropout(dropout)
+ else:
+ self.dropout = None
+ self.fc1 = fc_block(in_channels, in_channels, activation=self.act, norm_type=norm_type)
+ self.fc2 = fc_block(in_channels, in_channels, activation=None, norm_type=norm_type)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+            Return the output of the residual block.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor.
+ Returns:
+ - x (:obj:`torch.Tensor`): The resblock output tensor.
+ """
+ identity = x
+ x = self.fc1(x)
+ x = self.fc2(x)
+ x = self.act(x + identity)
+ if self.dropout is not None:
+ x = self.dropout(x)
+ return x
diff --git a/DI-engine/ding/torch_utils/network/resnet.py b/DI-engine/ding/torch_utils/network/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..643f3533550a2c0276ddc9ccecd28c12a53bae00
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/resnet.py
@@ -0,0 +1,956 @@
+"""
+This implementation of ResNet is a slightly modified version of `https://github.com/rwightman/pytorch-image-models.git`
+"""
+from typing import List, Callable, Optional, Tuple, Type, Dict, Union
+import math
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .nn_module import Flatten
+
+
+def to_2tuple(item: int) -> tuple:
+ """
+ Overview:
+ Convert a scalar to a 2-tuple or return the item if it's not a scalar.
+ Arguments:
+ - item (:obj:`int`): An item to be converted to a 2-tuple.
+ Returns:
+ - (:obj:`tuple`): A 2-tuple of the item.
+ """
+ if np.isscalar(item):
+ return (item, item)
+ else:
+ return item
+
+
+# Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution
+def get_same_padding(x: int, k: int, s: int, d: int) -> int:
+ """
+ Overview:
+ Calculate asymmetric TensorFlow-like 'SAME' padding for a convolution.
+ Arguments:
+ - x (:obj:`int`): The size of the input.
+ - k (:obj:`int`): The size of the kernel.
+ - s (:obj:`int`): The stride of the convolution.
+ - d (:obj:`int`): The dilation of the convolution.
+ Returns:
+ - (:obj:`int`): The size of the padding.
+ """
+ return max((math.ceil(x / s) - 1) * s + (k - 1) * d + 1 - x, 0)
+
+
+# Dynamically pad input x with 'SAME' padding for conv with specified args
+def pad_same(x, k: List[int], s: List[int], d: List[int] = (1, 1), value: float = 0):
+ """
+ Overview:
+ Dynamically pad input x with 'SAME' padding for conv with specified args.
+ Arguments:
+ - x (:obj:`Tensor`): The input tensor.
+ - k (:obj:`List[int]`): The size of the kernel.
+ - s (:obj:`List[int]`): The stride of the convolution.
+ - d (:obj:`List[int]`): The dilation of the convolution.
+ - value (:obj:`float`): Value to fill the padding.
+ Returns:
+ - (:obj:`Tensor`): The padded tensor.
+ """
+ ih, iw = x.size()[-2:]
+ pad_h, pad_w = get_same_padding(ih, k[0], s[0], d[0]), get_same_padding(iw, k[1], s[1], d[1])
+ if pad_h > 0 or pad_w > 0:
+ x = F.pad(x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2], value=value)
+ return x
+
+
+def avg_pool2d_same(
+ x,
+ kernel_size: List[int],
+ stride: List[int],
+ padding: List[int] = (0, 0),
+ ceil_mode: bool = False,
+ count_include_pad: bool = True
+):
+ """
+ Overview:
+ Apply average pooling with 'SAME' padding on the input tensor.
+ Arguments:
+ - x (:obj:`Tensor`): The input tensor.
+ - kernel_size (:obj:`List[int]`): The size of the kernel.
+ - stride (:obj:`List[int]`): The stride of the convolution.
+ - padding (:obj:`List[int]`): The size of the padding.
+ - ceil_mode (:obj:`bool`): When True, will use ceil instead of floor to compute the output shape.
+ - count_include_pad (:obj:`bool`): When True, will include the zero-padding in the averaging calculation.
+ Returns:
+ - (:obj:`Tensor`): The tensor after average pooling.
+ """
+ # FIXME how to deal with count_include_pad vs not for external padding?
+ x = pad_same(x, kernel_size, stride)
+ return F.avg_pool2d(x, kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
+
+
+class AvgPool2dSame(nn.AvgPool2d):
+ """
+ Overview:
+ Tensorflow-like 'SAME' wrapper for 2D average pooling.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(
+ self,
+ kernel_size: int,
+ stride: Optional[Tuple[int, int]] = None,
+ padding: int = 0,
+ ceil_mode: bool = False,
+ count_include_pad: bool = True
+ ) -> None:
+ """
+ Overview:
+ Initialize the AvgPool2dSame with given arguments.
+ Arguments:
+ - kernel_size (:obj:`int`): The size of the window to take an average over.
+ - stride (:obj:`Optional[Tuple[int, int]]`): The stride of the window. If None, default to kernel_size.
+ - padding (:obj:`int`): Implicit zero padding to be added on both sides.
+ - ceil_mode (:obj:`bool`): When True, will use `ceil` instead of `floor` to compute the output shape.
+ - count_include_pad (:obj:`bool`): When True, will include the zero-padding in the averaging calculation.
+ """
+ kernel_size = to_2tuple(kernel_size)
+ stride = to_2tuple(stride)
+ super(AvgPool2dSame, self).__init__(kernel_size, stride, (0, 0), ceil_mode, count_include_pad)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Forward pass of the AvgPool2dSame.
+ Argument:
+ - x (:obj:`torch.Tensor`): Input tensor.
+ Returns:
+ - (:obj:`torch.Tensor`): Output tensor after average pooling.
+ """
+ x = pad_same(x, self.kernel_size, self.stride)
+ return F.avg_pool2d(x, self.kernel_size, self.stride, self.padding, self.ceil_mode, self.count_include_pad)
+
+
+def _create_pool(num_features: int,
+ num_classes: int,
+ pool_type: str = 'avg',
+ use_conv: bool = False) -> Tuple[nn.Module, int]:
+ """
+ Overview:
+ Create a global pooling layer based on the given arguments.
+ Arguments:
+ - num_features (:obj:`int`): Number of input features.
+ - num_classes (:obj:`int`): Number of output classes.
+ - pool_type (:obj:`str`): Type of the pooling operation. Defaults to 'avg'.
+ - use_conv (:obj:`bool`): Whether to use convolutional layer after pooling. Defaults to False.
+ Returns:
+ - (:obj:`Tuple[nn.Module, int]`): The created global pooling layer and the number of pooled features.
+ """
+ flatten_in_pool = not use_conv # flatten when we use a Linear layer after pooling
+ if not pool_type:
+ assert num_classes == 0 or use_conv, \
+ 'Pooling can only be disabled if classifier is also removed or conv classifier is used'
+ flatten_in_pool = False # disable flattening if pooling is pass-through (no pooling)
+ assert flatten_in_pool
+ global_pool = nn.AdaptiveAvgPool2d(1)
+ num_pooled_features = num_features * 1
+ return global_pool, num_pooled_features
+
+
+def _create_fc(num_features: int, num_classes: int, use_conv: bool = False) -> nn.Module:
+ """
+ Overview:
+ Create a fully connected layer based on the given arguments.
+ Arguments:
+ - num_features (:obj:`int`): Number of input features.
+ - num_classes (:obj:`int`): Number of output classes.
+ - use_conv (:obj:`bool`): Whether to use convolutional layer. Defaults to False.
+ Returns:
+ - (:obj:`nn.Module`): The created fully connected layer.
+ """
+ if num_classes <= 0:
+ fc = nn.Identity() # pass-through (no classifier)
+ elif use_conv:
+ fc = nn.Conv2d(num_features, num_classes, 1, bias=True)
+ else:
+ # use nn.Linear for simplification
+ fc = nn.Linear(num_features, num_classes, bias=True)
+ return fc
+
+
+def create_classifier(num_features: int,
+ num_classes: int,
+ pool_type: str = 'avg',
+ use_conv: bool = False) -> Tuple[nn.Module, nn.Module]:
+ """
+ Overview:
+ Create a classifier with global pooling layer and fully connected layer.
+ Arguments:
+ - num_features (:obj:`int`): The number of features.
+ - num_classes (:obj:`int`): The number of classes for the final classification.
+ - pool_type (:obj:`str`): The type of pooling to use; 'avg' for Average Pooling.
+ - use_conv (:obj:`bool`): Whether to use convolution or not.
+ Returns:
+ - global_pool (:obj:`nn.Module`): The created global pooling layer.
+ - fc (:obj:`nn.Module`): The created fully connected layer.
+ """
+ assert pool_type == 'avg'
+ global_pool, num_pooled_features = _create_pool(num_features, num_classes, pool_type, use_conv=use_conv)
+ fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
+ return global_pool, fc
+
+
+class ClassifierHead(nn.Module):
+ """
+ Overview:
+ Classifier head with configurable global pooling and dropout.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(
+ self,
+ in_chs: int,
+ num_classes: int,
+ pool_type: str = 'avg',
+ drop_rate: float = 0.,
+ use_conv: bool = False
+ ) -> None:
+ """
+ Overview:
+ Initialize the ClassifierHead with given arguments.
+ Arguments:
+ - in_chs (:obj:`int`): Number of input channels.
+ - num_classes (:obj:`int`): Number of classes for the final classification.
+ - pool_type (:obj:`str`): The type of pooling to use; 'avg' for Average Pooling.
+ - drop_rate (:obj:`float`): The dropout rate.
+ - use_conv (:obj:`bool`): Whether to use convolution or not.
+ """
+ super(ClassifierHead, self).__init__()
+ self.drop_rate = drop_rate
+ self.global_pool, num_pooled_features = _create_pool(in_chs, num_classes, pool_type, use_conv=use_conv)
+ self.fc = _create_fc(num_pooled_features, num_classes, use_conv=use_conv)
+ self.flatten = Flatten(1) if use_conv and pool_type else nn.Identity()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Forward pass of the ClassifierHead.
+ Argument:
+ - x (:obj:`torch.Tensor`): Input tensor.
+ Returns:
+ - (:obj:`torch.Tensor`): Output tensor after classification.
+ """
+ x = self.global_pool(x)
+ if self.drop_rate:
+ x = F.dropout(x, p=float(self.drop_rate), training=self.training)
+ x = self.fc(x)
+ x = self.flatten(x)
+ return x
+
+
+def create_attn(layer: nn.Module, plane: int) -> None:
+ """
+ Overview:
+        Placeholder for creating an attention mechanism. The current implementation does not support attention
+        and always returns None.
+ Arguments:
+ - layer (:obj:`nn.Module`): The layer where the attention is to be applied.
+ - plane (:obj:`int`): The plane on which the attention is to be applied.
+ Returns:
+ - None
+ """
+ return None
+
+
+def get_padding(kernel_size: int, stride: int, dilation: int = 1) -> int:
+ """
+ Overview:
+ Compute the padding based on the kernel size, stride and dilation.
+ Arguments:
+ - kernel_size (:obj:`int`): The size of the kernel.
+ - stride (:obj:`int`): The stride of the convolution.
+ - dilation (:obj:`int`): The dilation factor.
+ Returns:
+ - padding (:obj:`int`): The computed padding.
+ """
+ padding = ((stride - 1) + dilation * (kernel_size - 1)) // 2
+ return padding
+
+
+class BasicBlock(nn.Module):
+ """
+ Overview:
+ The basic building block for models like ResNet. This class extends pytorch's Module class.
+ It represents a standard block of layers including two convolutions, batch normalization,
+ an optional attention mechanism, and activation functions.
+ Interfaces:
+ ``__init__``, ``forward``, ``zero_init_last_bn``
+ Properties:
+        - expansion (:obj:`int`): Specifies the expansion factor for the planes of the conv layers.
+ """
+ expansion = 1
+
+ def __init__(
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ downsample: Callable = None,
+ cardinality: int = 1,
+ base_width: int = 64,
+ reduce_first: int = 1,
+ dilation: int = 1,
+ first_dilation: int = None,
+ act_layer: Callable = nn.ReLU,
+ norm_layer: Callable = nn.BatchNorm2d,
+ attn_layer: Callable = None,
+ aa_layer: Callable = None,
+ drop_block: Callable = None,
+ drop_path: Callable = None
+ ) -> None:
+ """
+ Overview:
+ Initialize the BasicBlock with given parameters.
+ Arguments:
+ - inplanes (:obj:`int`): Number of input channels.
+ - planes (:obj:`int`): Number of output channels.
+ - stride (:obj:`int`): The stride of the convolutional layer.
+ - downsample (:obj:`Callable`): Function for downsampling the inputs.
+ - cardinality (:obj:`int`): Group size for grouped convolution.
+ - base_width (:obj:`int`): Base width of the convolutions.
+ - reduce_first (:obj:`int`): Reduction factor for first convolution of each block.
+ - dilation (:obj:`int`): Spacing between kernel points.
+ - first_dilation (:obj:`int`): First dilation value.
+ - act_layer (:obj:`Callable`): Function for activation layer.
+ - norm_layer (:obj:`Callable`): Function for normalization layer.
+ - attn_layer (:obj:`Callable`): Function for attention layer.
+ - aa_layer (:obj:`Callable`): Function for anti-aliasing layer.
+ - drop_block (:obj:`Callable`): Method for dropping block.
+ - drop_path (:obj:`Callable`): Method for dropping path.
+ """
+ super(BasicBlock, self).__init__()
+
+ assert cardinality == 1, 'BasicBlock only supports cardinality of 1'
+ assert base_width == 64, 'BasicBlock does not support changing base width'
+ first_planes = planes // reduce_first
+ outplanes = planes * self.expansion
+ first_dilation = first_dilation or dilation
+ use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation)
+
+ self.conv1 = nn.Conv2d(
+ inplanes,
+ first_planes,
+ kernel_size=3,
+ stride=1 if use_aa else stride,
+ padding=first_dilation,
+ dilation=first_dilation,
+ bias=False
+ )
+ self.bn1 = norm_layer(first_planes)
+ self.act1 = act_layer(inplace=True)
+ self.aa = aa_layer(channels=first_planes, stride=stride) if use_aa else None
+
+ self.conv2 = nn.Conv2d(first_planes, outplanes, kernel_size=3, padding=dilation, dilation=dilation, bias=False)
+ self.bn2 = norm_layer(outplanes)
+
+ self.se = create_attn(attn_layer, outplanes)
+
+ self.act2 = act_layer(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ self.drop_block = drop_block
+ self.drop_path = drop_path
+
+ def zero_init_last_bn(self) -> None:
+ """
+ Overview:
+ Initialize the batch normalization layer with zeros.
+ """
+ nn.init.zeros_(self.bn2.weight)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Defines the computation performed at every call.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor.
+ Returns:
+ - output (:obj:`torch.Tensor`): The output tensor after passing through the BasicBlock.
+ """
+ shortcut = x
+
+ x = self.conv1(x)
+ x = self.bn1(x)
+ if self.drop_block is not None:
+ x = self.drop_block(x)
+ x = self.act1(x)
+ if self.aa is not None:
+ x = self.aa(x)
+
+ x = self.conv2(x)
+ x = self.bn2(x)
+ if self.drop_block is not None:
+ x = self.drop_block(x)
+
+ if self.se is not None:
+ x = self.se(x)
+
+ if self.drop_path is not None:
+ x = self.drop_path(x)
+
+ if self.downsample is not None:
+ shortcut = self.downsample(shortcut)
+ x += shortcut
+ x = self.act2(x)
+
+ return x
+
+
+class Bottleneck(nn.Module):
+ """
+ Overview:
+        The Bottleneck class is a building block used to construct ResNet networks, following the standard PyTorch
+        ResNet design. The block is composed of several layers, including convolutional layers, normalization
+        layers, activation layers, and optional attention, anti-aliasing, and dropout layers.
+ Interfaces:
+ ``__init__``, ``forward``, ``zero_init_last_bn``
+ Properties:
+ expansion, inplanes, planes, stride, downsample, cardinality, base_width, reduce_first, dilation, \
+ first_dilation, act_layer, norm_layer, attn_layer, aa_layer, drop_block, drop_path
+
+ """
+ expansion = 4
+
+ def __init__(
+ self,
+ inplanes: int,
+ planes: int,
+ stride: int = 1,
+ downsample: Optional[nn.Module] = None,
+ cardinality: int = 1,
+ base_width: int = 64,
+ reduce_first: int = 1,
+ dilation: int = 1,
+ first_dilation: Optional[int] = None,
+ act_layer: Type[nn.Module] = nn.ReLU,
+ norm_layer: Type[nn.Module] = nn.BatchNorm2d,
+ attn_layer: Optional[Type[nn.Module]] = None,
+ aa_layer: Optional[Type[nn.Module]] = None,
+ drop_block: Callable = None,
+ drop_path: Callable = None
+ ) -> None:
+ """
+ Overview:
+ Initialize the Bottleneck class with various parameters.
+
+ Arguments:
+ - inplanes (:obj:`int`): The number of input planes.
+ - planes (:obj:`int`): The number of output planes.
+ - stride (:obj:`int`, optional): The stride size, defaults to 1.
+ - downsample (:obj:`nn.Module`, optional): The downsample method, defaults to None.
+ - cardinality (:obj:`int`, optional): The size of the group convolutions, defaults to 1.
+ - base_width (:obj:`int`, optional): The base width, defaults to 64.
+ - reduce_first (:obj:`int`, optional): The first reduction factor, defaults to 1.
+ - dilation (:obj:`int`, optional): The dilation factor, defaults to 1.
+ - first_dilation (:obj:`int`, optional): The first dilation factor, defaults to None.
+ - act_layer (:obj:`Type[nn.Module]`, optional): The activation layer type, defaults to nn.ReLU.
+ - norm_layer (:obj:`Type[nn.Module]`, optional): The normalization layer type, defaults to nn.BatchNorm2d.
+ - attn_layer (:obj:`Type[nn.Module]`, optional): The attention layer type, defaults to None.
+ - aa_layer (:obj:`Type[nn.Module]`, optional): The anti-aliasing layer type, defaults to None.
+ - drop_block (:obj:`Callable`): The dropout block, defaults to None.
+ - drop_path (:obj:`Callable`): The drop path, defaults to None.
+ """
+ super(Bottleneck, self).__init__()
+
+ width = int(math.floor(planes * (base_width / 64)) * cardinality)
+ first_planes = width // reduce_first
+ outplanes = planes * self.expansion
+ first_dilation = first_dilation or dilation
+ use_aa = aa_layer is not None and (stride == 2 or first_dilation != dilation)
+
+ self.conv1 = nn.Conv2d(inplanes, first_planes, kernel_size=1, bias=False)
+ self.bn1 = norm_layer(first_planes)
+ self.act1 = act_layer(inplace=True)
+
+ self.conv2 = nn.Conv2d(
+ first_planes,
+ width,
+ kernel_size=3,
+ stride=1 if use_aa else stride,
+ padding=first_dilation,
+ dilation=first_dilation,
+ groups=cardinality,
+ bias=False
+ )
+ self.bn2 = norm_layer(width)
+ self.act2 = act_layer(inplace=True)
+ self.aa = aa_layer(channels=width, stride=stride) if use_aa else None
+
+ self.conv3 = nn.Conv2d(width, outplanes, kernel_size=1, bias=False)
+ self.bn3 = norm_layer(outplanes)
+
+ self.se = create_attn(attn_layer, outplanes)
+
+ self.act3 = act_layer(inplace=True)
+ self.downsample = downsample
+ self.stride = stride
+ self.dilation = dilation
+ self.drop_block = drop_block
+ self.drop_path = drop_path
+
+ def zero_init_last_bn(self) -> None:
+ """
+ Overview:
+ Initialize the last batch normalization layer with zero.
+ """
+ nn.init.zeros_(self.bn3.weight)
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Defines the computation performed at every call.
+ Arguments:
+ - x (:obj:`Tensor`): The input tensor.
+ Returns:
+ - x (:obj:`Tensor`): The output tensor resulting from the computation.
+ """
+ shortcut = x
+
+ x = self.conv1(x)
+ x = self.bn1(x)
+ if self.drop_block is not None:
+ x = self.drop_block(x)
+ x = self.act1(x)
+
+ x = self.conv2(x)
+ x = self.bn2(x)
+ if self.drop_block is not None:
+ x = self.drop_block(x)
+ x = self.act2(x)
+ if self.aa is not None:
+ x = self.aa(x)
+
+ x = self.conv3(x)
+ x = self.bn3(x)
+ if self.drop_block is not None:
+ x = self.drop_block(x)
+
+ if self.se is not None:
+ x = self.se(x)
+
+ if self.drop_path is not None:
+ x = self.drop_path(x)
+
+ if self.downsample is not None:
+ shortcut = self.downsample(shortcut)
+ x += shortcut
+ x = self.act3(x)
+
+ return x
+
+
+def downsample_conv(
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ dilation: int = 1,
+ first_dilation: int = None,
+ norm_layer: Type[nn.Module] = None
+) -> nn.Sequential:
+ """
+ Overview:
+ Create a sequential module for downsampling that includes a convolution layer and a normalization layer.
+ Arguments:
+ - in_channels (:obj:`int`): The number of input channels.
+ - out_channels (:obj:`int`): The number of output channels.
+ - kernel_size (:obj:`int`): The size of the kernel.
+ - stride (:obj:`int`, optional): The stride size, defaults to 1.
+ - dilation (:obj:`int`, optional): The dilation factor, defaults to 1.
+ - first_dilation (:obj:`int`, optional): The first dilation factor, defaults to None.
+ - norm_layer (:obj:`Type[nn.Module]`, optional): The normalization layer type, defaults to nn.BatchNorm2d.
+ Returns:
+ - nn.Sequential: A sequence of layers performing downsampling through convolution.
+ """
+ norm_layer = norm_layer or nn.BatchNorm2d
+ kernel_size = 1 if stride == 1 and dilation == 1 else kernel_size
+ first_dilation = (first_dilation or dilation) if kernel_size > 1 else 1
+ p = get_padding(kernel_size, stride, first_dilation)
+
+ return nn.Sequential(
+ *[
+ nn.Conv2d(
+ in_channels, out_channels, kernel_size, stride=stride, padding=p, dilation=first_dilation, bias=False
+ ),
+ norm_layer(out_channels)
+ ]
+ )
+
+
+def downsample_avg(
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ dilation: int = 1,
+ first_dilation: int = None,
+ norm_layer: Type[nn.Module] = None
+) -> nn.Sequential:
+ """
+ Overview:
+ Create a sequential module for downsampling that includes an average pooling layer, a convolution layer,
+ and a normalization layer.
+ Arguments:
+ - in_channels (:obj:`int`): The number of input channels.
+ - out_channels (:obj:`int`): The number of output channels.
+ - kernel_size (:obj:`int`): The size of the kernel.
+ - stride (:obj:`int`, optional): The stride size, defaults to 1.
+ - dilation (:obj:`int`, optional): The dilation factor, defaults to 1.
+ - first_dilation (:obj:`int`, optional): The first dilation factor, defaults to None.
+ - norm_layer (:obj:`Type[nn.Module]`, optional): The normalization layer type, defaults to nn.BatchNorm2d.
+ Returns:
+ - nn.Sequential: A sequence of layers performing downsampling through average pooling.
+ """
+ norm_layer = norm_layer or nn.BatchNorm2d
+ avg_stride = stride if dilation == 1 else 1
+ if stride == 1 and dilation == 1:
+ pool = nn.Identity()
+ else:
+ avg_pool_fn = AvgPool2dSame if avg_stride == 1 and dilation > 1 else nn.AvgPool2d
+ pool = avg_pool_fn(2, avg_stride, ceil_mode=True, count_include_pad=False)
+
+ return nn.Sequential(
+ *[pool,
+ nn.Conv2d(in_channels, out_channels, 1, stride=1, padding=0, bias=False),
+ norm_layer(out_channels)]
+ )
+
+
+def drop_blocks(drop_block_rate: float = 0.) -> List[None]:
+ """
+ Overview:
+ Generate a list of None values based on the drop block rate.
+ Arguments:
+ - drop_block_rate (:obj:`float`, optional): The drop block rate, defaults to 0.
+ Returns:
+ - List[None]: A list of None values.
+ """
+ assert drop_block_rate == 0., drop_block_rate
+ return [None for _ in range(4)]
+
+
+def make_blocks(
+ block_fn: Type[nn.Module],
+ channels: List[int],
+ block_repeats: List[int],
+ inplanes: int,
+ reduce_first: int = 1,
+ output_stride: int = 32,
+ down_kernel_size: int = 1,
+ avg_down: bool = False,
+ drop_block_rate: float = 0.,
+ drop_path_rate: float = 0.,
+ **kwargs
+) -> Tuple[List[Tuple[str, nn.Module]], List[Dict[str, Union[int, str]]]]:
+ """
+ Overview:
+ Create a list of blocks for the network, with each block having a given number of repeats. Also, create a
+ feature info list that contains information about the output of each block.
+ Arguments:
+ - block_fn (:obj:`Type[nn.Module]`): The type of block to use.
+ - channels (:obj:`List[int]`): The list of output channels for each block.
+ - block_repeats (:obj:`List[int]`): The list of number of repeats for each block.
+ - inplanes (:obj:`int`): The number of input planes.
+ - reduce_first (:obj:`int`, optional): The first reduction factor, defaults to 1.
+ - output_stride (:obj:`int`, optional): The total stride of the network, defaults to 32.
+ - down_kernel_size (:obj:`int`, optional): The size of the downsample kernel, defaults to 1.
+ - avg_down (:obj:`bool`, optional): Whether to use average pooling for downsampling, defaults to False.
+ - drop_block_rate (:obj:`float`, optional): The drop block rate, defaults to 0.
+ - drop_path_rate (:obj:`float`, optional): The drop path rate, defaults to 0.
+ - **kwargs: Additional keyword arguments.
+ Returns:
+ - Tuple[List[Tuple[str, nn.Module]], List[Dict[str, Union[int, str]]]]: \
+ A tuple that includes a list of blocks for the network and a feature info list.
+ """
+ stages = []
+ feature_info = []
+ net_num_blocks = sum(block_repeats)
+ net_block_idx = 0
+ net_stride = 4
+ dilation = prev_dilation = 1
+ for stage_idx, (planes, num_blocks, db) in enumerate(zip(channels, block_repeats, drop_blocks(drop_block_rate))):
+ stage_name = f'layer{stage_idx + 1}' # never liked this name, but weight compat requires it
+ stride = 1 if stage_idx == 0 else 2
+ if net_stride >= output_stride:
+ dilation *= stride
+ stride = 1
+ else:
+ net_stride *= stride
+
+ downsample = None
+ if stride != 1 or inplanes != planes * block_fn.expansion:
+ down_kwargs = dict(
+ in_channels=inplanes,
+ out_channels=planes * block_fn.expansion,
+ kernel_size=down_kernel_size,
+ stride=stride,
+ dilation=dilation,
+ first_dilation=prev_dilation,
+ norm_layer=kwargs.get('norm_layer')
+ )
+ downsample = downsample_avg(**down_kwargs) if avg_down else downsample_conv(**down_kwargs)
+
+ block_kwargs = dict(reduce_first=reduce_first, dilation=dilation, drop_block=db, **kwargs)
+ blocks = []
+ for block_idx in range(num_blocks):
+ downsample = downsample if block_idx == 0 else None
+ stride = stride if block_idx == 0 else 1
+ block_dpr = drop_path_rate * net_block_idx / (net_num_blocks - 1) # stochastic depth linear decay rule
+ blocks.append(
+ block_fn(
+ inplanes, planes, stride, downsample, first_dilation=prev_dilation, drop_path=None, **block_kwargs
+ )
+ )
+ prev_dilation = dilation
+ inplanes = planes * block_fn.expansion
+ net_block_idx += 1
+
+ stages.append((stage_name, nn.Sequential(*blocks)))
+ feature_info.append(dict(num_chs=inplanes, reduction=net_stride, module=stage_name))
+
+ return stages, feature_info
+
+
+class ResNet(nn.Module):
+ """
+ Overview:
+ Implements ResNet, ResNeXt, SE-ResNeXt, and SENet models. This implementation supports various modifications
+ based on the v1c, v1d, v1e, and v1s variants included in the MXNet Gluon ResNetV1b model. For more details
+ about the variants and options, please refer to the 'Bag of Tricks' paper: https://arxiv.org/pdf/1812.01187.
+ Interfaces:
+ ``__init__``, ``forward``, ``zero_init_last_bn``, ``get_classifier``
+ """
+
+ def __init__(
+ self,
+ block: nn.Module,
+ layers: List[int],
+ num_classes: int = 1000,
+ in_chans: int = 3,
+ cardinality: int = 1,
+ base_width: int = 64,
+ stem_width: int = 64,
+ stem_type: str = '',
+ replace_stem_pool: bool = False,
+ output_stride: int = 32,
+ block_reduce_first: int = 1,
+ down_kernel_size: int = 1,
+ avg_down: bool = False,
+ act_layer: nn.Module = nn.ReLU,
+ norm_layer: nn.Module = nn.BatchNorm2d,
+ aa_layer: Optional[nn.Module] = None,
+ drop_rate: float = 0.0,
+ drop_path_rate: float = 0.0,
+ drop_block_rate: float = 0.0,
+ global_pool: str = 'avg',
+ zero_init_last_bn: bool = True,
+ block_args: Optional[dict] = None
+ ) -> None:
+ """
+ Overview:
+ Initialize the ResNet model with given block, layers and other configuration options.
+ Arguments:
+ - block (:obj:`nn.Module`): Class for the residual block.
+ - layers (:obj:`List[int]`): Numbers of layers in each block.
+ - num_classes (:obj:`int`, optional): Number of classification classes. Default is 1000.
+ - in_chans (:obj:`int`, optional): Number of input (color) channels. Default is 3.
+ - cardinality (:obj:`int`, optional): Number of convolution groups for 3x3 conv in Bottleneck. Default is 1.
+ - base_width (:obj:`int`, optional): Factor determining bottleneck channels. Default is 64.
+ - stem_width (:obj:`int`, optional): Number of channels in stem convolutions. Default is 64.
+ - stem_type (:obj:`str`, optional): The type of stem. Default is ''.
+ - replace_stem_pool (:obj:`bool`, optional): Whether to replace stem pooling. Default is False.
+ - output_stride (:obj:`int`, optional): Output stride of the network. Default is 32.
+ - block_reduce_first (:obj:`int`, optional): Reduction factor for first convolution output width of \
+ residual blocks. Default is 1.
+ - down_kernel_size (:obj:`int`, optional): Kernel size of residual block downsampling path. Default is 1.
+            - avg_down (:obj:`bool`, optional): Whether to use average pooling for projection skip connection \
+              between stages/downsample. Default is False.
+ - act_layer (:obj:`nn.Module`, optional): Activation layer. Default is nn.ReLU.
+ - norm_layer (:obj:`nn.Module`, optional): Normalization layer. Default is nn.BatchNorm2d.
+ - aa_layer (:obj:`Optional[nn.Module]`, optional): Anti-aliasing layer. Default is None.
+ - drop_rate (:obj:`float`, optional): Dropout probability before classifier, for training. Default is 0.0.
+ - drop_path_rate (:obj:`float`, optional): Drop path rate. Default is 0.0.
+ - drop_block_rate (:obj:`float`, optional): Drop block rate. Default is 0.0.
+ - global_pool (:obj:`str`, optional): Global pooling type. Default is 'avg'.
+ - zero_init_last_bn (:obj:`bool`, optional): Whether to initialize last batch normalization with zero. \
+ Default is True.
+ - block_args (:obj:`Optional[dict]`, optional): Additional arguments for block. Default is None.
+ """
+ block_args = block_args or dict()
+ assert output_stride in (8, 16, 32)
+ self.num_classes = num_classes
+ self.drop_rate = drop_rate
+ super(ResNet, self).__init__()
+
+ # Stem
+ deep_stem = 'deep' in stem_type
+ inplanes = stem_width * 2 if deep_stem else 64
+ if deep_stem:
+ stem_chs = (stem_width, stem_width)
+ if 'tiered' in stem_type:
+ stem_chs = (3 * (stem_width // 4), stem_width)
+ self.conv1 = nn.Sequential(
+ *[
+ nn.Conv2d(in_chans, stem_chs[0], 3, stride=2, padding=1, bias=False),
+ norm_layer(stem_chs[0]),
+ act_layer(inplace=True),
+ nn.Conv2d(stem_chs[0], stem_chs[1], 3, stride=1, padding=1, bias=False),
+ norm_layer(stem_chs[1]),
+ act_layer(inplace=True),
+ nn.Conv2d(stem_chs[1], inplanes, 3, stride=1, padding=1, bias=False)
+ ]
+ )
+ else:
+ self.conv1 = nn.Conv2d(in_chans, inplanes, kernel_size=7, stride=2, padding=3, bias=False)
+ self.bn1 = norm_layer(inplanes)
+ self.act1 = act_layer(inplace=True)
+ self.feature_info = [dict(num_chs=inplanes, reduction=2, module='act1')]
+
+ # Stem Pooling
+ if replace_stem_pool:
+ self.maxpool = nn.Sequential(
+ *filter(
+ None, [
+ nn.Conv2d(inplanes, inplanes, 3, stride=1 if aa_layer else 2, padding=1, bias=False),
+ aa_layer(channels=inplanes, stride=2) if aa_layer else None,
+ norm_layer(inplanes),
+ act_layer(inplace=True)
+ ]
+ )
+ )
+ else:
+ if aa_layer is not None:
+ self.maxpool = nn.Sequential(
+ *[nn.MaxPool2d(kernel_size=3, stride=1, padding=1),
+ aa_layer(channels=inplanes, stride=2)]
+ )
+ else:
+ self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)
+
+ # Feature Blocks
+ channels = [64, 128, 256, 512]
+ stage_modules, stage_feature_info = make_blocks(
+ block,
+ channels,
+ layers,
+ inplanes,
+ cardinality=cardinality,
+ base_width=base_width,
+ output_stride=output_stride,
+ reduce_first=block_reduce_first,
+ avg_down=avg_down,
+ down_kernel_size=down_kernel_size,
+ act_layer=act_layer,
+ norm_layer=norm_layer,
+ aa_layer=aa_layer,
+ drop_block_rate=drop_block_rate,
+ drop_path_rate=drop_path_rate,
+ **block_args
+ )
+ for stage in stage_modules:
+ self.add_module(*stage) # layer1, layer2, etc
+ self.feature_info.extend(stage_feature_info)
+
+ # Head (Pooling and Classifier)
+ self.num_features = 512 * block.expansion
+ self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
+
+ self.init_weights(zero_init_last_bn=zero_init_last_bn)
+
+ def init_weights(self, zero_init_last_bn: bool = True) -> None:
+ """
+ Overview:
+ Initialize the weights in the model.
+ Arguments:
+ - zero_init_last_bn (:obj:`bool`, optional): Whether to initialize last batch normalization with zero.
+ Default is True.
+ """
+ for n, m in self.named_modules():
+ if isinstance(m, nn.Conv2d):
+ nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
+ elif isinstance(m, nn.BatchNorm2d):
+ nn.init.ones_(m.weight)
+ nn.init.zeros_(m.bias)
+ if zero_init_last_bn:
+ for m in self.modules():
+ if hasattr(m, 'zero_init_last_bn'):
+ m.zero_init_last_bn()
+
+ def get_classifier(self) -> nn.Module:
+ """
+ Overview:
+ Get the classifier module from the model.
+ Returns:
+ - classifier (:obj:`nn.Module`): The classifier module in the model.
+ """
+ return self.fc
+
+ def reset_classifier(self, num_classes: int, global_pool: str = 'avg') -> None:
+ """
+ Overview:
+ Reset the classifier with a new number of classes and pooling type.
+ Arguments:
+ - num_classes (:obj:`int`): New number of classification classes.
+ - global_pool (:obj:`str`, optional): New global pooling type. Default is 'avg'.
+ """
+ self.num_classes = num_classes
+ self.global_pool, self.fc = create_classifier(self.num_features, self.num_classes, pool_type=global_pool)
+
+ def forward_features(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Forward pass through the feature layers of the model.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor.
+ Returns:
+ - x (:obj:`torch.Tensor`): The output tensor after passing through feature layers.
+ """
+ x = self.conv1(x)
+ x = self.bn1(x)
+ x = self.act1(x)
+ x = self.maxpool(x)
+
+ x = self.layer1(x)
+ x = self.layer2(x)
+ x = self.layer3(x)
+ x = self.layer4(x)
+ return x
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Full forward pass through the model.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor.
+ Returns:
+ - x (:obj:`torch.Tensor`): The output tensor after passing through the model.
+ """
+ x = self.forward_features(x)
+ x = self.global_pool(x)
+ x = x.view(x.shape[0], -1)
+ if self.drop_rate:
+ x = F.dropout(x, p=float(self.drop_rate), training=self.training)
+ x = self.fc(x)
+ return x
+
+
+def resnet18() -> nn.Module:
+ """
+ Overview:
+ Creates a ResNet18 model.
+ Returns:
+ - model (:obj:`nn.Module`): ResNet18 model.
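+    Examples:
+        An illustrative forward pass, assuming the usual ImageNet-style input shape:
+        >>> model = resnet18()
+        >>> model(torch.randn(2, 3, 224, 224)).shape
+        torch.Size([2, 1000])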
+ """
+ return ResNet(block=BasicBlock, layers=[2, 2, 2, 2])
diff --git a/DI-engine/ding/torch_utils/network/rnn.py b/DI-engine/ding/torch_utils/network/rnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..e24bd7e468e409bbdd7ed4de36caf4dd483c2163
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/rnn.py
@@ -0,0 +1,362 @@
+from typing import Optional, Union, List, Tuple, Dict
+import math
+import torch
+import torch.nn as nn
+import treetensor.torch as ttorch
+
+import ding
+from ding.torch_utils.network.normalization import build_normalization
+if ding.enable_hpc_rl:
+ from hpc_rll.torch_utils.network.rnn import LSTM as HPCLSTM
+else:
+ HPCLSTM = None
+
+
+def is_sequence(data):
+ """
+ Overview:
+ Determines if the input data is of type list or tuple.
+ Arguments:
+ - data: The input data to be checked.
+ Returns:
+ - boolean: True if the input is a list or a tuple, False otherwise.
+ """
+ return isinstance(data, list) or isinstance(data, tuple)
+
+
+def sequence_mask(lengths: torch.Tensor, max_len: Optional[int] = None) -> torch.BoolTensor:
+ """
+ Overview:
+ Generates a boolean mask for a batch of sequences with differing lengths.
+ Arguments:
+ - lengths (:obj:`torch.Tensor`): A tensor with the lengths of each sequence. Shape could be (n, 1) or (n).
+ - max_len (:obj:`int`, optional): The padding size. If max_len is None, the padding size is the max length of \
+ sequences.
+ Returns:
+ - masks (:obj:`torch.BoolTensor`): A boolean mask tensor. The mask has the same device as lengths.
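+    Examples:
+        An illustrative call; note that the padding size is capped at the true maximum length:
+        >>> mask = sequence_mask(torch.tensor([1, 3]), max_len=4)
+        >>> mask.shape
+        torch.Size([2, 3])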
+ """
+ if len(lengths.shape) == 1:
+ lengths = lengths.unsqueeze(dim=1)
+ bz = lengths.numel()
+ if max_len is None:
+ max_len = lengths.max()
+ else:
+ max_len = min(max_len, lengths.max())
+ return torch.arange(0, max_len).type_as(lengths).repeat(bz, 1).lt(lengths).to(lengths.device)
+
+
+class LSTMForwardWrapper(object):
+ """
+ Overview:
+ Class providing methods to use before and after the LSTM `forward` method.
+ Wraps the LSTM `forward` method.
+ Interfaces:
+ ``_before_forward``, ``_after_forward``
+ """
+
+ def _before_forward(self, inputs: torch.Tensor, prev_state: Union[None, List[Dict]]) -> torch.Tensor:
+ """
+ Overview:
+ Preprocesses the inputs and previous states before the LSTM `forward` method.
+ Arguments:
+ - inputs (:obj:`torch.Tensor`): Input vector of the LSTM cell. Shape: [seq_len, batch_size, input_size]
+ - prev_state (:obj:`Union[None, List[Dict]]`): Previous state tensor. Shape: [num_directions*num_layers, \
+                batch_size, hidden_size]. If None, prev_state will be initialized to all zeros.
+ Returns:
+ - prev_state (:obj:`torch.Tensor`): Preprocessed previous state for the LSTM batch.
+ """
+ assert hasattr(self, 'num_layers')
+ assert hasattr(self, 'hidden_size')
+ seq_len, batch_size = inputs.shape[:2]
+ if prev_state is None:
+ num_directions = 1
+ zeros = torch.zeros(
+ num_directions * self.num_layers,
+ batch_size,
+ self.hidden_size,
+ dtype=inputs.dtype,
+ device=inputs.device
+ )
+ prev_state = (zeros, zeros)
+ elif is_sequence(prev_state):
+ if len(prev_state) != batch_size:
+ raise RuntimeError(
+ "prev_state number is not equal to batch_size: {}/{}".format(len(prev_state), batch_size)
+ )
+ num_directions = 1
+ zeros = torch.zeros(
+ num_directions * self.num_layers, 1, self.hidden_size, dtype=inputs.dtype, device=inputs.device
+ )
+ state = []
+ for prev in prev_state:
+ if prev is None:
+ state.append([zeros, zeros])
+ else:
+ if isinstance(prev, (Dict, ttorch.Tensor)):
+ state.append([v for v in prev.values()])
+ else:
+ state.append(prev)
+ state = list(zip(*state))
+ prev_state = [torch.cat(t, dim=1) for t in state]
+ elif isinstance(prev_state, dict):
+ prev_state = list(prev_state.values())
+ else:
+ raise TypeError("not support prev_state type: {}".format(type(prev_state)))
+ return prev_state
+
+ def _after_forward(self,
+ next_state: Tuple[torch.Tensor],
+ list_next_state: bool = False) -> Union[List[Dict], Dict[str, torch.Tensor]]:
+ """
+ Overview:
+ Post-processes the next_state after the LSTM `forward` method.
+ Arguments:
+ - next_state (:obj:`Tuple[torch.Tensor]`): Tuple containing the next state (h, c).
+ - list_next_state (:obj:`bool`, optional): Determines the format of the returned next_state. \
+ If True, returns next_state in list format. Default is False.
+ Returns:
+ - next_state(:obj:`Union[List[Dict], Dict[str, torch.Tensor]]`): The post-processed next_state.
+ """
+ if list_next_state:
+ h, c = next_state
+ batch_size = h.shape[1]
+ next_state = [torch.chunk(h, batch_size, dim=1), torch.chunk(c, batch_size, dim=1)]
+ next_state = list(zip(*next_state))
+ next_state = [{k: v for k, v in zip(['h', 'c'], item)} for item in next_state]
+ else:
+ next_state = {k: v for k, v in zip(['h', 'c'], next_state)}
+ return next_state
+
+
+class LSTM(nn.Module, LSTMForwardWrapper):
+ """
+ Overview:
+ Implementation of an LSTM cell with Layer Normalization (LN).
+ Interfaces:
+ ``__init__``, ``forward``
+
+ .. note::
+
+ For a primer on LSTM, refer to https://zhuanlan.zhihu.com/p/32085405.
+ """
+
+ def __init__(
+ self,
+ input_size: int,
+ hidden_size: int,
+ num_layers: int,
+ norm_type: Optional[str] = None,
+ dropout: float = 0.
+ ) -> None:
+ """
+ Overview:
+ Initialize LSTM cell parameters.
+ Arguments:
+ - input_size (:obj:`int`): Size of the input vector.
+ - hidden_size (:obj:`int`): Size of the hidden state vector.
+ - num_layers (:obj:`int`): Number of LSTM layers.
+ - norm_type (:obj:`Optional[str]`): Normalization type, default is None.
+ - dropout (:obj:`float`): Dropout rate, default is 0.
+ """
+ super(LSTM, self).__init__()
+ self.input_size = input_size
+ self.hidden_size = hidden_size
+ self.num_layers = num_layers
+
+ norm_func = build_normalization(norm_type)
+ self.norm = nn.ModuleList([norm_func(hidden_size * 4) for _ in range(2 * num_layers)])
+ self.wx = nn.ParameterList()
+ self.wh = nn.ParameterList()
+ dims = [input_size] + [hidden_size] * num_layers
+ for l in range(num_layers):
+ self.wx.append(nn.Parameter(torch.zeros(dims[l], dims[l + 1] * 4)))
+ self.wh.append(nn.Parameter(torch.zeros(hidden_size, hidden_size * 4)))
+ self.bias = nn.Parameter(torch.zeros(num_layers, hidden_size * 4))
+ self.use_dropout = dropout > 0.
+ if self.use_dropout:
+ self.dropout = nn.Dropout(dropout)
+ self._init()
+
+ def _init(self):
+ """
+ Overview:
+ Initialize the parameters of the LSTM cell.
+ """
+
+ gain = math.sqrt(1. / self.hidden_size)
+ for l in range(self.num_layers):
+ torch.nn.init.uniform_(self.wx[l], -gain, gain)
+ torch.nn.init.uniform_(self.wh[l], -gain, gain)
+ if self.bias is not None:
+ torch.nn.init.uniform_(self.bias[l], -gain, gain)
+
+ def forward(self,
+ inputs: torch.Tensor,
+ prev_state: torch.Tensor,
+ list_next_state: bool = True) -> Tuple[torch.Tensor, Union[torch.Tensor, list]]:
+ """
+ Overview:
+ Compute output and next state given previous state and input.
+ Arguments:
+ - inputs (:obj:`torch.Tensor`): Input vector of cell, size [seq_len, batch_size, input_size].
+ - prev_state (:obj:`torch.Tensor`): Previous state, \
+ size [num_directions*num_layers, batch_size, hidden_size].
+ - list_next_state (:obj:`bool`): Whether to return next_state in list format, default is True.
+ Returns:
+ - x (:obj:`torch.Tensor`): Output from LSTM.
+ - next_state (:obj:`Union[torch.Tensor, list]`): Hidden state from LSTM.
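+        Examples:
+            A minimal illustrative sketch (hypothetical sizes):
+            >>> lstm = LSTM(input_size=4, hidden_size=8, num_layers=2, norm_type='LN')
+            >>> inputs = torch.randn(6, 3, 4)  # (seq_len, batch_size, input_size)
+            >>> x, next_state = lstm(inputs, prev_state=None, list_next_state=True)
+            >>> x.shape
+            torch.Size([6, 3, 8])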
+ """
+ seq_len, batch_size = inputs.shape[:2]
+ prev_state = self._before_forward(inputs, prev_state)
+
+ H, C = prev_state
+ x = inputs
+ next_state = []
+ for l in range(self.num_layers):
+ h, c = H[l], C[l]
+ new_x = []
+ for s in range(seq_len):
+ gate = self.norm[l * 2](torch.matmul(x[s], self.wx[l])
+ ) + self.norm[l * 2 + 1](torch.matmul(h, self.wh[l]))
+ if self.bias is not None:
+ gate += self.bias[l]
+ gate = list(torch.chunk(gate, 4, dim=1))
+ i, f, o, u = gate
+ i = torch.sigmoid(i)
+ f = torch.sigmoid(f)
+ o = torch.sigmoid(o)
+ u = torch.tanh(u)
+ c = f * c + i * u
+ h = o * torch.tanh(c)
+ new_x.append(h)
+ next_state.append((h, c))
+ x = torch.stack(new_x, dim=0)
+ if self.use_dropout and l != self.num_layers - 1:
+ x = self.dropout(x)
+ next_state = [torch.stack(t, dim=0) for t in zip(*next_state)]
+
+ next_state = self._after_forward(next_state, list_next_state)
+ return x, next_state
+
+
+class PytorchLSTM(nn.LSTM, LSTMForwardWrapper):
+ """
+ Overview:
+        Wrapper class for PyTorch's nn.LSTM that formats the input and output. For more details on nn.LSTM,
+ refer to https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html#torch.nn.LSTM
+ Interfaces:
+ ``forward``
+ """
+
+ def forward(self,
+ inputs: torch.Tensor,
+ prev_state: torch.Tensor,
+ list_next_state: bool = True) -> Tuple[torch.Tensor, Union[torch.Tensor, list]]:
+ """
+ Overview:
+ Executes nn.LSTM.forward with preprocessed input.
+ Arguments:
+ - inputs (:obj:`torch.Tensor`): Input vector of cell, size [seq_len, batch_size, input_size].
+ - prev_state (:obj:`torch.Tensor`): Previous state, size [num_directions*num_layers, batch_size, \
+ hidden_size].
+ - list_next_state (:obj:`bool`): Whether to return next_state in list format, default is True.
+ Returns:
+ - output (:obj:`torch.Tensor`): Output from LSTM.
+ - next_state (:obj:`Union[torch.Tensor, list]`): Hidden state from LSTM.
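+        Examples:
+            A minimal illustrative sketch (hypothetical sizes):
+            >>> lstm = PytorchLSTM(input_size=4, hidden_size=8, num_layers=2)
+            >>> inputs = torch.randn(6, 3, 4)  # (seq_len, batch_size, input_size)
+            >>> output, next_state = lstm(inputs, prev_state=None)
+            >>> output.shape
+            torch.Size([6, 3, 8])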
+ """
+ prev_state = self._before_forward(inputs, prev_state)
+ output, next_state = nn.LSTM.forward(self, inputs, prev_state)
+ next_state = self._after_forward(next_state, list_next_state)
+ return output, next_state
+
+
+class GRU(nn.GRUCell, LSTMForwardWrapper):
+ """
+ Overview:
+ This class extends the `torch.nn.GRUCell` and `LSTMForwardWrapper` classes, and formats inputs and outputs
+ accordingly.
+ Interfaces:
+ ``__init__``, ``forward``
+ Properties:
+ hidden_size, num_layers
+
+ .. note::
+        For further details, refer to the official PyTorch documentation for ``torch.nn.GRUCell``.
+
+ """
+
+ def __init__(self, input_size: int, hidden_size: int, num_layers: int) -> None:
+ """
+ Overview:
+ Initialize the GRU class with input size, hidden size, and number of layers.
+ Arguments:
+ - input_size (:obj:`int`): The size of the input vector.
+ - hidden_size (:obj:`int`): The size of the hidden state vector.
+ - num_layers (:obj:`int`): The number of GRU layers.
+ """
+ super(GRU, self).__init__(input_size, hidden_size)
+ self.hidden_size = hidden_size
+ self.num_layers = num_layers
+
+ def forward(self,
+ inputs: torch.Tensor,
+ prev_state: Optional[torch.Tensor] = None,
+ list_next_state: bool = True) -> Tuple[torch.Tensor, Union[torch.Tensor, List]]:
+ """
+ Overview:
+            Wrap the `nn.GRUCell.forward` method.
+ Arguments:
+ - inputs (:obj:`torch.Tensor`): Input vector of cell, tensor of size [seq_len, batch_size, input_size].
+ - prev_state (:obj:`Optional[torch.Tensor]`): None or tensor of \
+ size [num_directions*num_layers, batch_size, hidden_size].
+ - list_next_state (:obj:`bool`): Whether to return next_state in list format (default is True).
+ Returns:
+ - output (:obj:`torch.Tensor`): Output from GRU.
+ - next_state (:obj:`torch.Tensor` or :obj:`list`): Hidden state from GRU.
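+        Examples:
+            A minimal illustrative sketch (hypothetical sizes; this cell processes a single step, so seq_len is 1):
+            >>> gru = GRU(input_size=4, hidden_size=8, num_layers=1)
+            >>> inputs = torch.randn(1, 3, 4)  # (seq_len, batch_size, input_size)
+            >>> x, next_state = gru(inputs, prev_state=None)
+            >>> x.shape
+            torch.Size([1, 3, 8])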
+ """
+ # for compatibility
+ prev_state, _ = self._before_forward(inputs, prev_state)
+ inputs, prev_state = inputs.squeeze(0), prev_state.squeeze(0)
+ next_state = nn.GRUCell.forward(self, inputs, prev_state)
+ next_state = next_state.unsqueeze(0)
+ x = next_state
+ # for compatibility
+ next_state = self._after_forward([next_state, next_state.clone()], list_next_state)
+ return x, next_state
+
+
+def get_lstm(
+ lstm_type: str,
+ input_size: int,
+ hidden_size: int,
+ num_layers: int = 1,
+ norm_type: str = 'LN',
+ dropout: float = 0.,
+ seq_len: Optional[int] = None,
+ batch_size: Optional[int] = None
+) -> Union[LSTM, PytorchLSTM]:
+ """
+ Overview:
+ Build and return the corresponding LSTM cell based on the provided parameters.
+ Arguments:
+ - lstm_type (:obj:`str`): Version of RNN cell. Supported options are ['normal', 'pytorch', 'hpc', 'gru'].
+ - input_size (:obj:`int`): Size of the input vector.
+ - hidden_size (:obj:`int`): Size of the hidden state vector.
+ - num_layers (:obj:`int`): Number of LSTM layers (default is 1).
+ - norm_type (:obj:`str`): Type of normalization (default is 'LN').
+ - dropout (:obj:`float`): Dropout rate (default is 0.0).
+ - seq_len (:obj:`Optional[int]`): Sequence length (default is None).
+ - batch_size (:obj:`Optional[int]`): Batch size (default is None).
+ Returns:
+ - lstm (:obj:`Union[LSTM, PytorchLSTM]`): The corresponding LSTM cell.
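+    Examples:
+        A minimal illustrative sketch (hypothetical sizes):
+        >>> lstm = get_lstm('normal', input_size=4, hidden_size=8)
+        >>> inputs = torch.randn(6, 3, 4)  # (seq_len, batch_size, input_size)
+        >>> output, next_state = lstm(inputs, prev_state=None)
+        >>> output.shape
+        torch.Size([6, 3, 8])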
+ """
+ assert lstm_type in ['normal', 'pytorch', 'hpc', 'gru']
+ if lstm_type == 'normal':
+ return LSTM(input_size, hidden_size, num_layers, norm_type, dropout=dropout)
+ elif lstm_type == 'pytorch':
+ return PytorchLSTM(input_size, hidden_size, num_layers, dropout=dropout)
+ elif lstm_type == 'hpc':
+ return HPCLSTM(seq_len, batch_size, input_size, hidden_size, num_layers, norm_type, dropout).cuda()
+ elif lstm_type == 'gru':
+ assert num_layers == 1
+ return GRU(input_size, hidden_size, num_layers)
diff --git a/DI-engine/ding/torch_utils/network/scatter_connection.py b/DI-engine/ding/torch_utils/network/scatter_connection.py
new file mode 100644
index 0000000000000000000000000000000000000000..d596f3aa1c74375b2774631959961f6857b201ff
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/scatter_connection.py
@@ -0,0 +1,121 @@
+import torch
+import torch.nn as nn
+from typing import Tuple, List
+from ding.hpc_rl import hpc_wrapper
+
+
+def shape_fn_scatter_connection(args, kwargs) -> List[int]:
+ """
+ Overview:
+ Return the shape of scatter_connection for HPC.
+ Arguments:
+ - args (:obj:`Tuple`): The arguments passed to the scatter_connection function.
+ - kwargs (:obj:`Dict`): The keyword arguments passed to the scatter_connection function.
+ Returns:
+ - shape (:obj:`List[int]`): A list representing the shape of scatter_connection, \
+ in the form of [B, M, N, H, W, scatter_type].
+ """
+ if len(args) <= 1:
+ tmp = list(kwargs['x'].shape)
+ else:
+ tmp = list(args[1].shape) # args[0] is __main__.ScatterConnection object
+ if len(args) <= 2:
+ tmp.extend(kwargs['spatial_size'])
+ else:
+ tmp.extend(args[2])
+ tmp.append(args[0].scatter_type)
+ return tmp
+
+
+class ScatterConnection(nn.Module):
+ """
+ Overview:
+        Scatter features to their corresponding spatial locations. In AlphaStar, each entity is embedded into a
+        tensor, and these tensors are scattered into a feature map of the given spatial size.
+ Interfaces:
+ ``__init__``, ``forward``, ``xy_forward``
+ """
+
+ def __init__(self, scatter_type: str) -> None:
+ """
+ Overview:
+ Initialize the ScatterConnection object.
+ Arguments:
+            - scatter_type (:obj:`str`): The scatter type, which decides the behavior when two entities have the \
+                same location. It can be either 'add' or 'cover'. If 'add', the values of entities that share a \
+                location are summed. If 'cover', the later entity overwrites the earlier one.
+ """
+ super(ScatterConnection, self).__init__()
+ self.scatter_type = scatter_type
+ assert self.scatter_type in ['cover', 'add']
+
+ @hpc_wrapper(
+ shape_fn=shape_fn_scatter_connection,
+ namedtuple_data=False,
+ include_args=[0, 2],
+ include_kwargs=['x', 'location'],
+ is_cls_method=True
+ )
+ def forward(self, x: torch.Tensor, spatial_size: Tuple[int, int], location: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Scatter input tensor 'x' into a spatial feature map.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor of shape `(B, M, N)`, where `B` is the batch size, `M` \
+ is the number of entities, and `N` is the dimension of entity attributes.
+ - spatial_size (:obj:`Tuple[int, int]`): The size `(H, W)` of the spatial feature map into which 'x' \
+ will be scattered, where `H` is the height and `W` is the width.
+ - location (:obj:`torch.Tensor`): The tensor of locations of shape `(B, M, 2)`. \
+ Each location should be (y, x).
+ Returns:
+ - output (:obj:`torch.Tensor`): The scattered feature map of shape `(B, N, H, W)`.
+ Note:
+            When locations overlap, 'cover' mode results in a loss of information, so 'add' mode is often used
+            as a workaround.
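+        Examples:
+            A minimal illustrative sketch (hypothetical sizes):
+            >>> scatter_conn = ScatterConnection('add')
+            >>> x = torch.randn(2, 5, 16)  # (B, M, N)
+            >>> y_coord = torch.randint(0, 4, (2, 5))
+            >>> x_coord = torch.randint(0, 6, (2, 5))
+            >>> location = torch.stack([y_coord, x_coord], dim=2)  # each location is (y, x)
+            >>> output = scatter_conn(x=x, spatial_size=(4, 6), location=location)
+            >>> output.shape
+            torch.Size([2, 16, 4, 6])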
+ """
+ device = x.device
+ B, M, N = x.shape
+ x = x.permute(0, 2, 1)
+ H, W = spatial_size
+ index = location[:, :, 1] + location[:, :, 0] * W
+ index = index.unsqueeze(dim=1).repeat(1, N, 1)
+ output = torch.zeros(size=(B, N, H, W), device=device).view(B, N, H * W)
+ if self.scatter_type == 'cover':
+ output.scatter_(dim=2, index=index, src=x)
+ elif self.scatter_type == 'add':
+ output.scatter_add_(dim=2, index=index, src=x)
+ output = output.view(B, N, H, W)
+ return output
+
+ def xy_forward(
+            self, x: torch.Tensor, spatial_size: Tuple[int, int], coord_x: torch.Tensor, coord_y: torch.Tensor
+ ) -> torch.Tensor:
+ """
+ Overview:
+ Scatter input tensor 'x' into a spatial feature map using separate x and y coordinates.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor of shape `(B, M, N)`, where `B` is the batch size, `M` \
+ is the number of entities, and `N` is the dimension of entity attributes.
+ - spatial_size (:obj:`Tuple[int, int]`): The size `(H, W)` of the spatial feature map into which 'x' \
+ will be scattered, where `H` is the height and `W` is the width.
+ - coord_x (:obj:`torch.Tensor`): The x-coordinates tensor of shape `(B, M)`.
+ - coord_y (:obj:`torch.Tensor`): The y-coordinates tensor of shape `(B, M)`.
+ Returns:
+ - output (:obj:`torch.Tensor`): The scattered feature map of shape `(B, N, H, W)`.
+ Note:
+            When locations overlap, 'cover' mode results in a loss of information, so 'add' mode is often used
+            as a workaround.
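+        Examples:
+            A minimal illustrative sketch (hypothetical sizes; note that the index is computed as \
+                ``coord_x * W + coord_y``, so ``coord_x`` is bounded by ``H`` and ``coord_y`` by ``W``):
+            >>> scatter_conn = ScatterConnection('add')
+            >>> x = torch.randn(2, 5, 16)  # (B, M, N)
+            >>> coord_x = torch.randint(0, 4, (2, 5))
+            >>> coord_y = torch.randint(0, 6, (2, 5))
+            >>> output = scatter_conn.xy_forward(x, spatial_size=(4, 6), coord_x=coord_x, coord_y=coord_y)
+            >>> output.shape
+            torch.Size([2, 16, 4, 6])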
+ """
+ device = x.device
+ B, M, N = x.shape
+ x = x.permute(0, 2, 1)
+ H, W = spatial_size
+ index = (coord_x * W + coord_y).long()
+ index = index.unsqueeze(dim=1).repeat(1, N, 1)
+ output = torch.zeros(size=(B, N, H, W), device=device).view(B, N, H * W)
+ if self.scatter_type == 'cover':
+ output.scatter_(dim=2, index=index, src=x)
+ elif self.scatter_type == 'add':
+ output.scatter_add_(dim=2, index=index, src=x)
+ output = output.view(B, N, H, W)
+ return output
diff --git a/DI-engine/ding/torch_utils/network/soft_argmax.py b/DI-engine/ding/torch_utils/network/soft_argmax.py
new file mode 100644
index 0000000000000000000000000000000000000000..166d0bb8f6fe17b16bcbdd7f4bd139cf5c692999
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/soft_argmax.py
@@ -0,0 +1,59 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+class SoftArgmax(nn.Module):
+ """
+ Overview:
+ A neural network module that computes the SoftArgmax operation (essentially a 2-dimensional spatial softmax),
+ which is often used for location regression tasks. It converts a feature map (such as a heatmap) into precise
+ coordinate locations.
+ Interfaces:
+ ``__init__``, ``forward``
+
+ .. note::
+        SoftArgmax (also known as spatial soft-argmax) computes the expectation of pixel coordinates under a
+        spatial softmax distribution, which makes coordinate regression differentiable.
+ """
+
+ def __init__(self):
+ """
+ Overview:
+ Initialize the SoftArgmax module.
+ """
+ super(SoftArgmax, self).__init__()
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Perform the forward pass of the SoftArgmax operation.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor, typically a heatmap representing predicted locations.
+ Returns:
+ - location (:obj:`torch.Tensor`): The predicted coordinates as a result of the SoftArgmax operation.
+ Shapes:
+ - x: :math:`(B, C, H, W)`, where `B` is the batch size, `C` is the number of channels, \
+ and `H` and `W` represent height and width respectively.
+ - location: :math:`(B, 2)`, where `B` is the batch size and 2 represents the coordinates (height, width).
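+        Examples:
+            A minimal illustrative sketch (hypothetical sizes):
+            >>> model = SoftArgmax()
+            >>> heatmap = torch.full((4, 1, 16, 16), -1e8)
+            >>> heatmap[:, 0, 3, 5] = 1.
+            >>> location = model(heatmap)  # approximately [[3., 5.], ...] for each sample
+            >>> location.shape
+            torch.Size([4, 2])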
+ """
+ # Unpack the dimensions of the input tensor
+ B, C, H, W = x.shape
+ device, dtype = x.device, x.dtype
+ # Ensure the input tensor has a single channel
+ assert C == 1, "Input tensor should have only one channel"
+ # Create a meshgrid for the height (h_kernel) and width (w_kernel)
+ h_kernel = torch.arange(0, H, device=device).to(dtype)
+ h_kernel = h_kernel.view(1, 1, H, 1).repeat(1, 1, 1, W)
+
+ w_kernel = torch.arange(0, W, device=device).to(dtype)
+ w_kernel = w_kernel.view(1, 1, 1, W).repeat(1, 1, H, 1)
+
+ # Apply the softmax function across the spatial dimensions (height and width)
+ x = F.softmax(x.view(B, C, -1), dim=-1).view(B, C, H, W)
+ # Compute the expected values for height and width by multiplying the probability map by the meshgrids
+ h = (x * h_kernel).sum(dim=[1, 2, 3]) # Sum over the channel, height, and width dimensions
+ w = (x * w_kernel).sum(dim=[1, 2, 3]) # Sum over the channel, height, and width dimensions
+
+ # Stack the height and width coordinates along a new dimension to form the final output tensor
+ return torch.stack([h, w], dim=1)
diff --git a/DI-engine/ding/torch_utils/network/tests/test_activation.py b/DI-engine/ding/torch_utils/network/tests/test_activation.py
new file mode 100644
index 0000000000000000000000000000000000000000..5071d766f39b0947e5da7ae751702d0a70d8471f
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/tests/test_activation.py
@@ -0,0 +1,46 @@
+import pytest
+import torch
+from ding.torch_utils import build_activation
+
+
+@pytest.mark.unittest
+class TestActivation:
+
+ def test(self):
+ act_type = 'relu'
+ act = build_activation(act_type, inplace=True)
+ act_type = 'prelu'
+ act = build_activation(act_type)
+ with pytest.raises(AssertionError):
+ act = build_activation(act_type, inplace=True)
+ with pytest.raises(KeyError):
+ act = build_activation('xxxlu')
+ act_type = 'glu'
+ input_dim = 50
+ output_dim = 150
+ context_dim = 200
+ act = build_activation(act_type
+ )(input_dim=input_dim, output_dim=output_dim, context_dim=context_dim, input_type='fc')
+ batch_size = 10
+ inputs = torch.rand(batch_size, input_dim).requires_grad_(True)
+ context = torch.rand(batch_size, context_dim).requires_grad_(True)
+ output = act(inputs, context)
+ assert output.shape == (batch_size, output_dim)
+ assert act.layer1.weight.grad is None
+ loss = output.mean()
+ loss.backward()
+ assert isinstance(inputs.grad, torch.Tensor)
+ assert isinstance(act.layer1.weight.grad, torch.Tensor)
+
+ act = build_activation(act_type)(
+ input_dim=input_dim, output_dim=output_dim, context_dim=context_dim, input_type='conv2d'
+ )
+ size = 16
+ inputs = torch.rand(batch_size, input_dim, size, size)
+ context = torch.rand(batch_size, context_dim, size, size)
+ output = act(inputs, context)
+ assert output.shape == (batch_size, output_dim, size, size)
+ assert act.layer1.weight.grad is None
+ loss = output.mean()
+ loss.backward()
+ assert isinstance(act.layer1.weight.grad, torch.Tensor)
diff --git a/DI-engine/ding/torch_utils/network/tests/test_diffusion.py b/DI-engine/ding/torch_utils/network/tests/test_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9794a726b0d6f6d16d4deb8e7883dc1193764ee
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/tests/test_diffusion.py
@@ -0,0 +1,29 @@
+import pytest
+import torch
+from ding.torch_utils.network.diffusion import DiffusionUNet1d, TemporalValue
+
+batch_size = 2
+transition_dim = 10
+dim = 8
+dim_mults = [1, 2, 4]
+horizon = 4
+
+
+@pytest.mark.unittest
+class TestDiffusionNet:
+
+ def test_DiffusionNet1d(self):
+ diffusion = DiffusionUNet1d(transition_dim, dim, dim_mults)
+ input = torch.rand(batch_size, horizon, transition_dim)
+ t = torch.randint(0, horizon, (batch_size, )).long()
+ cond = {t: torch.randn(batch_size, 2) for t in range(horizon)}
+ output = diffusion(input, cond, time=t)
+ assert output.shape == (batch_size, horizon, transition_dim)
+
+ def test_TemporalValue(self):
+ value = TemporalValue(horizon, transition_dim, dim, dim_mults=dim_mults)
+ input = torch.rand(batch_size, horizon, transition_dim)
+ t = torch.randint(0, horizon, (batch_size, )).long()
+ cond = {t: torch.randn(batch_size, 2) for t in range(horizon)}
+ output = value(input, cond, time=t)
+ assert output.shape == (batch_size, 1)
diff --git a/DI-engine/ding/torch_utils/network/tests/test_dreamer.py b/DI-engine/ding/torch_utils/network/tests/test_dreamer.py
new file mode 100644
index 0000000000000000000000000000000000000000..accfb1d8c5d6c1810a7afe5ea530921c1f2b8a90
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/tests/test_dreamer.py
@@ -0,0 +1,73 @@
+import pytest
+from easydict import EasyDict
+import torch
+from torch import distributions as torchd
+from itertools import product
+from ding.torch_utils.network.dreamer import DenseHead, SampleDist, OneHotDist, TwoHotDistSymlog, \
+ SymlogDist, ContDist, Bernoulli, UnnormalizedHuber, weight_init, uniform_weight_init
+
+# arguments
+shape = [255, (255, ), ()]
+# to do
+# dist = ['normal', 'huber', 'binary', 'twohot_symlog']
+dist = ['twohot_symlog']
+args = list(product(*[shape, dist]))
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('shape, dist', args)
+def test_DenseHead(shape, dist):
+ in_dim, layer_num, units, B, time = 1536, 2, 512, 16, 64
+ head = DenseHead(in_dim, shape, layer_num, units, dist=dist)
+ x = torch.randn(B, time, in_dim)
+ a = torch.randn(B, time, 1)
+ y = head(x)
+ assert y.mode().shape == (B, time, 1)
+ assert y.log_prob(a).shape == (B, time)
+
+
+B, time = 16, 64
+mean = torch.randn(B, time, 255)
+std = 1.0
+a = torch.randn(B, time, 1) # or torch.randn(B, time, 255)
+sample_shape = torch.Size([])
+
+
+@pytest.mark.unittest
+def test_ContDist():
+ dist_origin = torchd.normal.Normal(mean, std)
+ dist = torchd.independent.Independent(dist_origin, 1)
+ dist_new = ContDist(dist)
+ assert dist_new.mode().shape == (B, time, 255)
+ assert dist_new.log_prob(a).shape == (B, time)
+ assert dist_origin.log_prob(a).shape == (B, time, 255)
+ assert dist_new.sample().shape == (B, time, 255)
+
+
+@pytest.mark.unittest
+def test_UnnormalizedHuber():
+ dist_origin = UnnormalizedHuber(mean, std)
+ dist = torchd.independent.Independent(dist_origin, 1)
+ dist_new = ContDist(dist)
+ assert dist_new.mode().shape == (B, time, 255)
+ assert dist_new.log_prob(a).shape == (B, time)
+ assert dist_origin.log_prob(a).shape == (B, time, 255)
+ assert dist_new.sample().shape == (B, time, 255)
+
+
+@pytest.mark.unittest
+def test_Bernoulli():
+ dist_origin = torchd.bernoulli.Bernoulli(logits=mean)
+ dist = torchd.independent.Independent(dist_origin, 1)
+ dist_new = Bernoulli(dist)
+ assert dist_new.mode().shape == (B, time, 255)
+ assert dist_new.log_prob(a).shape == (B, time, 255)
+ # to do
+ # assert dist_new.sample().shape == (B, time, 255)
+
+
+@pytest.mark.unittest
+def test_TwoHotDistSymlog():
+ dist = TwoHotDistSymlog(logits=mean)
+ assert dist.mode().shape == (B, time, 1)
+ assert dist.log_prob(a).shape == (B, time)
diff --git a/DI-engine/ding/torch_utils/network/tests/test_gtrxl.py b/DI-engine/ding/torch_utils/network/tests/test_gtrxl.py
new file mode 100644
index 0000000000000000000000000000000000000000..4219da7417f75bb83052b8409449831e2b833b85
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/tests/test_gtrxl.py
@@ -0,0 +1,107 @@
+import pytest
+import torch
+
+from ding.torch_utils import GTrXL, GRUGatingUnit
+
+
+@pytest.mark.unittest
+class TestGTrXL:
+
+ def test_GTrXl(self):
+ dim_size = 128
+ seq_len = 64
+ bs = 32
+ embedding_dim = 256
+ layer_num = 5
+ mem_len = 40
+ # input shape: cur_seq x bs x input_dim
+ memory = [None, torch.rand(layer_num + 1, mem_len, bs, embedding_dim)]
+ batch_first = [False, True]
+ for i in range(2):
+ m = memory[i]
+ bf = batch_first[i]
+ model = GTrXL(
+ input_dim=dim_size,
+ head_dim=2,
+ embedding_dim=embedding_dim,
+ memory_len=mem_len,
+ head_num=2,
+ mlp_num=2,
+ layer_num=layer_num,
+ )
+
+ input = torch.rand(seq_len, bs, dim_size)
+ if bf:
+ input = torch.transpose(input, 1, 0)
+ input.requires_grad_(True)
+ if m is None:
+ model.reset_memory(batch_size=bs)
+ else:
+ model.reset_memory(state=m)
+ output = model(input, batch_first=bf)
+                mse_loss = torch.nn.MSELoss()
+                target = torch.randn(output['logit'].shape)
+ loss = mse_loss(output['logit'], target)
+ assert input.grad is None
+ loss.backward()
+ assert isinstance(input.grad, torch.Tensor)
+ if bf is False:
+ assert output['logit'].shape == (seq_len, bs, embedding_dim)
+ else:
+ assert output['logit'].shape == (bs, seq_len, embedding_dim)
+ assert output['memory'].shape == (layer_num + 1, mem_len, bs, embedding_dim)
+ memory_out = output['memory']
+ if m is not None:
+ assert torch.all(torch.eq(memory_out, m))
+
+ def test_memory(self):
+ dim_size = 128
+ seq_len = 4
+ bs = 16
+ embedding_dim = 128
+ layer_num = 3
+ mem_len = 8
+ model = GTrXL(
+ input_dim=dim_size,
+ head_dim=2,
+ embedding_dim=embedding_dim,
+ memory_len=mem_len,
+ head_num=2,
+ mlp_num=2,
+ layer_num=layer_num,
+ )
+ memories = []
+ outs = []
+ for i in range(4):
+ input = torch.rand(seq_len, bs, dim_size)
+ output = model(input)
+ memories.append(output['memory'])
+ outs.append(output['logit'])
+ # first returned memory should be a zero matrix
+ assert sum(memories[0].flatten()) == 0
+ # last layer of second memory is equal to the output of the first input in its last 4 positions
+ assert torch.all(torch.eq(memories[1][-1][4:], outs[0]))
+ assert sum(memories[1][-1][:4].flatten()) == 0
+ # last layer of third memory is equal to the output of the second input in its last 4 positions
+ assert torch.all(torch.eq(memories[2][-1][4:], outs[1]))
+ # last layer of third memory is equal to the output of the first input in its first 4 positions
+ assert torch.all(torch.eq(memories[2][-1][:4], outs[0]))
+ # last layer of 4th memory is equal to the output of the second input in its first 4 positions
+ # and the third input in its last 4 positions
+ assert torch.all(torch.eq(memories[3][-1][4:], outs[2]))
+ assert torch.all(torch.eq(memories[3][-1][:4], outs[1]))
+
+ def test_gru(self):
+ input_dim = 32
+ gru = GRUGatingUnit(input_dim, 1.)
+ x = torch.rand((4, 12, 32))
+ y = torch.rand((4, 12, 32))
+ out = gru(x, y)
+ assert out.shape == x.shape
+ gru = GRUGatingUnit(input_dim, 100000.) # set high bias to check 'out' is similar to the first input 'x'
+ # In GTrXL the bias is initialized with a value high enough such that information coming from the second input
+        # 'y' is partially ignored so as to produce behavior more similar to an MDP, thus giving less importance to
+ # past information
+ out = gru(x, y)
+ torch.testing.assert_close(out, x)
diff --git a/DI-engine/ding/torch_utils/network/tests/test_gumbel_softmax.py b/DI-engine/ding/torch_utils/network/tests/test_gumbel_softmax.py
new file mode 100644
index 0000000000000000000000000000000000000000..9168f06a594d73305f6602c5ddf01985d5493377
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/tests/test_gumbel_softmax.py
@@ -0,0 +1,26 @@
+import numpy as np
+import pytest
+import torch
+
+from ding.torch_utils.network import GumbelSoftmax, gumbel_softmax
+
+
+@pytest.mark.unittest
+class TestGumbelSoftmax:
+
+ def test(self):
+ B = 4
+ N = 10
+ model = GumbelSoftmax()
+ # data case 1
+ for _ in range(N):
+ data = torch.rand((4, 10))
+ data = torch.log(data)
+ gumbelsoftmax = model(data, hard=False)
+ assert gumbelsoftmax.shape == (B, N)
+ # data case 2
+ for _ in range(N):
+ data = torch.rand((4, 10))
+ data = torch.log(data)
+ gumbelsoftmax = model(data, hard=True)
+ assert gumbelsoftmax.shape == (B, N)
diff --git a/DI-engine/ding/torch_utils/network/tests/test_merge.py b/DI-engine/ding/torch_utils/network/tests/test_merge.py
new file mode 100644
index 0000000000000000000000000000000000000000..58f2ed34e412eaf0f15731b79da2ad9b902755f6
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/tests/test_merge.py
@@ -0,0 +1,46 @@
+import pytest
+import torch
+from ding.torch_utils.network import GatingType, SumMerge, VectorMerge
+
+
+@pytest.mark.unittest
+def test_SumMerge():
+ input_shape = (3, 5)
+ input = [torch.rand(input_shape).requires_grad_(True) for i in range(4)]
+ sum_merge = SumMerge()
+
+ output = sum_merge(input)
+ assert output.shape == (3, 5)
+ loss = output.mean()
+ loss.backward()
+ assert isinstance(input[0].grad, torch.Tensor)
+
+
+@pytest.mark.unittest
+def test_VectorMerge():
+ input_sizes = {'in1': 3, 'in2': 16, 'in3': 27}
+ output_size = 512
+ input_dict = {}
+ for k, v in input_sizes.items():
+ input_dict[k] = torch.rand((64, v)).requires_grad_(True)
+
+ vector_merge = VectorMerge(input_sizes, output_size, GatingType.NONE)
+ output = vector_merge(input_dict)
+ assert output.shape == (64, output_size)
+ loss = output.mean()
+ loss.backward()
+ assert isinstance(input_dict['in1'].grad, torch.Tensor)
+
+ vector_merge = VectorMerge(input_sizes, output_size, GatingType.GLOBAL)
+ output = vector_merge(input_dict)
+ assert output.shape == (64, output_size)
+ loss = output.mean()
+ loss.backward()
+ assert isinstance(input_dict['in1'].grad, torch.Tensor)
+
+ vector_merge = VectorMerge(input_sizes, output_size, GatingType.POINTWISE)
+ output = vector_merge(input_dict)
+ assert output.shape == (64, output_size)
+ loss = output.mean()
+ loss.backward()
+ assert isinstance(input_dict['in1'].grad, torch.Tensor)
diff --git a/DI-engine/ding/torch_utils/network/tests/test_nn_module.py b/DI-engine/ding/torch_utils/network/tests/test_nn_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..8fdc7845ee77ab1cc4e5867da8a9d13c77386ca5
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/tests/test_nn_module.py
@@ -0,0 +1,240 @@
+import pytest
+import torch
+from torch.testing import assert_allclose
+
+from ding.torch_utils import build_activation
+from ding.torch_utils.network.nn_module import MLP, conv1d_block, conv2d_block, fc_block, deconv2d_block, \
+ ChannelShuffle, one_hot, NearestUpsample, BilinearUpsample, binary_encode, weight_init_, NaiveFlatten, \
+ normed_linear, normed_conv2d
+
+batch_size = 2
+in_channels = 2
+hidden_channels = 3
+out_channels = 3
+H = 2
+W = 3
+kernel_size = 2
+stride = 1
+padding = 0
+dilation = 1
+groups = 1
+init_type = ['xavier', 'kaiming', 'orthogonal']
+act = build_activation('relu')
+norm_type = 'BN'
+
+
+@pytest.mark.unittest
+class TestNnModule:
+
+ def run_model(self, input, model):
+ output = model(input)
+ loss = output.mean()
+ loss.backward()
+ assert isinstance(
+ input.grad,
+ torch.Tensor,
+ )
+ return output
+
+ def test_weight_init(self):
+ weight = torch.zeros(2, 3)
+ for init_type in ['xavier', 'orthogonal']:
+ weight_init_(weight, init_type)
+ for act in [torch.nn.LeakyReLU(), torch.nn.ReLU()]:
+ weight_init_(weight, 'kaiming', act)
+ with pytest.raises(KeyError):
+ weight_init_(weight, 'xxx')
+
+ def test_mlp(self):
+ layer_num = 3
+ input_tensor = torch.rand(batch_size, in_channels).requires_grad_(True)
+
+ for output_activation in [True, False]:
+ for output_norm in [True, False]:
+ for activation in [torch.nn.ReLU(), torch.nn.LeakyReLU(), torch.nn.Tanh(), None]:
+ for norm_type in ["LN", "BN", None]:
+ # Test case 1: MLP without last linear layer initialized to 0.
+ model = MLP(
+ in_channels,
+ hidden_channels,
+ out_channels,
+ layer_num,
+ activation=activation,
+ norm_type=norm_type,
+ output_activation=output_activation,
+ output_norm=output_norm
+ )
+ output_tensor = self.run_model(input_tensor, model)
+ assert output_tensor.shape == (batch_size, out_channels)
+
+ # Test case 2: MLP with last linear layer initialized to 0.
+ model = MLP(
+ in_channels,
+ hidden_channels,
+ out_channels,
+ layer_num,
+ activation=activation,
+ norm_type=norm_type,
+ output_activation=output_activation,
+ output_norm=output_norm,
+ last_linear_layer_init_zero=True
+ )
+ output_tensor = self.run_model(input_tensor, model)
+ assert output_tensor.shape == (batch_size, out_channels)
+ last_linear_layer = None
+ for layer in reversed(model):
+ if isinstance(layer, torch.nn.Linear):
+ last_linear_layer = layer
+ break
+ assert_allclose(last_linear_layer.weight, torch.zeros_like(last_linear_layer.weight))
+ assert_allclose(last_linear_layer.bias, torch.zeros_like(last_linear_layer.bias))
+
+ def test_conv1d_block(self):
+ length = 2
+ input = torch.rand(batch_size, in_channels, length).requires_grad_(True)
+ block = conv1d_block(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ activation=act,
+ norm_type=norm_type
+ )
+ output = self.run_model(input, block)
+ output_length = (length - kernel_size + 2 * padding // stride) + 1
+ assert output.shape == (batch_size, out_channels, output_length)
+
+ def test_conv2d_block(self):
+ input = torch.rand(batch_size, in_channels, H, W).requires_grad_(True)
+ for pad_type in ['zero', 'reflect', 'replication']:
+ block = conv2d_block(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ pad_type=pad_type,
+ activation=act,
+ norm_type=norm_type
+ )
+ output = self.run_model(input, block)
+ output_H = (H - kernel_size + 2 * padding // stride) + 1
+ output_W = (W - kernel_size + 2 * padding // stride) + 1
+ assert output.shape == (batch_size, out_channels, output_H, output_W)
+
+ def test_deconv2d_block(self):
+ input = torch.rand(batch_size, in_channels, H, W).requires_grad_(True)
+ output_padding = 0
+ block = deconv2d_block(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride=1,
+ padding=0,
+ output_padding=output_padding,
+ groups=1,
+ activation=act,
+ norm_type=norm_type
+ )
+ output = self.run_model(input, block)
+ output_H = (H - 1) * stride + output_padding - 2 * padding + kernel_size
+ output_W = (W - 1) * stride + output_padding - 2 * padding + kernel_size
+ assert output.shape == (batch_size, out_channels, output_H, output_W)
+
+ def test_fc_block(self):
+ input = torch.rand(batch_size, in_channels).requires_grad_(True)
+ for use_dropout in [True, False]:
+ block = fc_block(
+ in_channels,
+ out_channels,
+ activation=act,
+ norm_type=norm_type,
+ use_dropout=use_dropout,
+ dropout_probability=0.5
+ )
+ output = self.run_model(input, block)
+ assert output.shape == (batch_size, out_channels)
+
+ def test_normed_linear(self):
+ input = torch.rand(batch_size, in_channels).requires_grad_(True)
+ block = normed_linear(in_channels, out_channels, scale=1)
+ r = block.weight.norm(dim=None, keepdim=False) * block.weight.norm(dim=None, keepdim=False)
+ assert r.item() < out_channels + 0.01
+ assert r.item() > out_channels - 0.01
+ output = self.run_model(input, block)
+ assert output.shape == (batch_size, out_channels)
+
+ def test_normed_conv2d(self):
+ input = torch.rand(batch_size, in_channels, H, W).requires_grad_(True)
+ block = normed_conv2d(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding,
+ dilation=dilation,
+ groups=groups,
+ scale=1
+ )
+ r = block.weight.norm(dim=(1, 2, 3), p=2, keepdim=True)[0, 0, 0, 0]
+ assert r.item() < 1.01
+ assert r.item() > 0.99
+ output = self.run_model(input, block)
+ output_H = (H - kernel_size + 2 * padding // stride) + 1
+ output_W = (W - kernel_size + 2 * padding // stride) + 1
+ assert output.shape == (batch_size, out_channels, output_H, output_W)
+
+ def test_channel_shuffle(self):
+ group_num = 2
+ input = torch.rand(batch_size, in_channels, H, W).requires_grad_(True)
+ channel_shuffle = ChannelShuffle(group_num)
+ output = self.run_model(input, channel_shuffle)
+ assert output.shape == input.shape
+
+ def test_one_hot(self):
+ M = 2
+ N = 2
+ max_num = 3
+ input = torch.ones(M, N).long()
+ output = one_hot(input, max_num, num_first=False)
+ assert output.sum() == input.numel()
+ assert output.shape == (M, N, max_num)
+ output = one_hot(input, max_num, num_first=True)
+ assert output.sum() == input.numel()
+ assert output.shape == (max_num, M, N)
+ with pytest.raises(RuntimeError):
+ _ = one_hot(torch.arange(0, max_num), max_num - 1)
+
+ def test_upsample(self):
+ scale_factor = 2
+ input = torch.rand(batch_size, in_channels, H, W).requires_grad_(True)
+ model = NearestUpsample(scale_factor)
+ output = self.run_model(input, model)
+ assert output.shape == (batch_size, in_channels, 2 * H, 2 * W)
+ model = BilinearUpsample(scale_factor)
+ output = self.run_model(input, model)
+ assert output.shape == (batch_size, in_channels, 2 * H, 2 * W)
+
+ def test_binary_encode(self):
+ input = torch.tensor([4])
+ max_val = torch.tensor(8)
+ output = binary_encode(input, max_val)
+ assert torch.equal(output, torch.tensor([[0, 1, 0, 0]]))
+
+ @pytest.mark.tmp
+ def test_flatten(self):
+ inputs = torch.randn(4, 3, 8, 8)
+ model1 = NaiveFlatten()
+ output1 = model1(inputs)
+ assert output1.shape == (4, 3 * 8 * 8)
+ model2 = NaiveFlatten(1, 2)
+ output2 = model2(inputs)
+ assert output2.shape == (4, 3 * 8, 8)
+ model3 = NaiveFlatten(1, 3)
+        output3 = model3(inputs)
+        assert output3.shape == (4, 3 * 8 * 8)
diff --git a/DI-engine/ding/torch_utils/network/tests/test_normalization.py b/DI-engine/ding/torch_utils/network/tests/test_normalization.py
new file mode 100644
index 0000000000000000000000000000000000000000..655d863fa57c78f02bc70208405d8e2627a188a7
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/tests/test_normalization.py
@@ -0,0 +1,42 @@
+import pytest
+import torch
+from ding.torch_utils import build_normalization
+
+num_features = 2
+batch_size = 2
+H, W = 2, 3
+
+
+@pytest.mark.unittest
+class TestNormalization:
+
+ def validate(self, input, norm):
+ output = norm(input)
+ loss = output.mean()
+ loss.backward()
+ assert output.shape == input.shape
+ assert isinstance(input.grad, torch.Tensor)
+
+ def test(self):
+ with pytest.raises(KeyError):
+ norm = build_normalization('XXXN')
+ input1d = torch.rand(batch_size, num_features).requires_grad_(True)
+ input2d = torch.rand(batch_size, num_features, H, W).requires_grad_(True)
+
+ norm_type = 'BN'
+ norm = build_normalization(norm_type, dim=1)(num_features)
+ self.validate(input1d, norm)
+
+ norm = build_normalization(norm_type, dim=2)(num_features)
+ self.validate(input2d, norm)
+
+ norm_type = 'LN'
+ norm = build_normalization(norm_type)(input1d.shape[1:])
+ self.validate(input1d, norm)
+
+ norm = build_normalization(norm_type)(input2d.shape[2:])
+ self.validate(input2d, norm)
+
+ norm_type = 'IN'
+ norm = build_normalization(norm_type, dim=2)(num_features)
+ self.validate(input2d, norm)
diff --git a/DI-engine/ding/torch_utils/network/tests/test_popart.py b/DI-engine/ding/torch_utils/network/tests/test_popart.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ea694f49d5ce9e45efd6ae7e16a42f3c64099f1
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/tests/test_popart.py
@@ -0,0 +1,36 @@
+import pytest
+import torch
+from ding.torch_utils import PopArt
+
+batch_size = 4
+input_features = 16
+output_features = 4
+
+
+@pytest.mark.unittest
+class TestPopArt:
+
+ def test_popart(self):
+ input = torch.rand((batch_size, input_features)).requires_grad_(True)
+ model = PopArt(input_features, output_features)
+ output = model(input)
+ loss = output['pred'].mean()
+ loss.backward()
+ assert isinstance(input.grad, torch.Tensor)
+
+ # validate the shape of parameters and outputs
+ assert output['pred'].shape == (batch_size, output_features)
+ assert output['unnormalized_pred'].shape == (batch_size, output_features)
+ assert model.mu.shape == torch.Size([output_features])
+ assert model.sigma.shape == torch.Size([output_features])
+ assert model.v.shape == torch.Size([output_features])
+
+ # validate the normalization
+ assert torch.all(torch.abs(output['pred']) <= 1)
+
+ model.update_parameters(torch.rand(batch_size, output_features))
+
+ # validate the non-empty of parameters
+ assert not torch.all(torch.isnan(model.mu))
+ assert not torch.all(torch.isnan(model.sigma))
+ assert not torch.all(torch.isnan(model.v))
diff --git a/DI-engine/ding/torch_utils/network/tests/test_res_block.py b/DI-engine/ding/torch_utils/network/tests/test_res_block.py
new file mode 100644
index 0000000000000000000000000000000000000000..26b0b946b3908ab7e69f424d32ef91978677ec24
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/tests/test_res_block.py
@@ -0,0 +1,36 @@
+import torch
+import pytest
+from ding.torch_utils.network import ResBlock, ResFCBlock
+
+batch_size = 2
+in_channels = 2
+H, W = 2, 3
+activation = torch.nn.ReLU()
+norm_type = 'BN'
+res_type = ['basic', 'bottleneck', 'downsample']
+res_type_classic = ['basic', 'bottleneck']
+
+
+@pytest.mark.unittest
+class TestResBlock:
+
+    def test_res_block(self):
+ input = torch.rand(batch_size, in_channels, 2, 3).requires_grad_(True)
+ for r in res_type:
+ for norm_type in ['BN', 'LN', 'IN', 'GN', None]:
+ model = ResBlock(in_channels, activation, norm_type, r)
+ output = model(input)
+ loss = output.mean()
+ loss.backward()
+ if r in res_type_classic:
+ assert output.shape == input.shape
+ assert isinstance(input.grad, torch.Tensor)
+
+ def test_res_fc_block(self):
+ input = torch.rand(batch_size, in_channels).requires_grad_(True)
+ model = ResFCBlock(in_channels, activation, norm_type)
+ output = model(input)
+ loss = output.mean()
+ loss.backward()
+ assert output.shape == input.shape
+ assert isinstance(input.grad, torch.Tensor)
diff --git a/DI-engine/ding/torch_utils/network/tests/test_resnet.py b/DI-engine/ding/torch_utils/network/tests/test_resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..75e613bf064a99f85a47bc1591edf46ae2d28cae
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/tests/test_resnet.py
@@ -0,0 +1,105 @@
+import pytest
+import torch
+from ding.torch_utils.network import resnet18
+from ding.torch_utils.network.resnet \
+ import ResNet, BasicBlock, Bottleneck, AvgPool2dSame, avg_pool2d_same, ClassifierHead
+from itertools import product
+
+
+@pytest.mark.unittest
+def test_resnet18():
+ model = resnet18()
+ print(model)
+ inputs = torch.randn(4, 3, 224, 224)
+ outputs = model(inputs)
+ assert outputs.shape == (4, 1000)
+
+
+stem_type = ['', 'deep', 'deep,tiered']
+replace_stem_pool = [True, False]
+avg_down = [True, False]
+block = [BasicBlock]
+layers = [[2, 2, 2, 2]]
+zero_init_last_bn = [True, False]
+output_stride = [8, 32]
+num_classes = [0, 1000]
+args = [
+ item for item in
+ product(*[stem_type, replace_stem_pool, avg_down, block, layers, zero_init_last_bn, output_stride, num_classes])
+]
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize(
+ 'stem_type, replace_stem_pool, avg_down, block, layers, zero_init_last_bn, output_stride, num_classes', args
+)
+def test_resnet(stem_type, replace_stem_pool, avg_down, block, layers, zero_init_last_bn, output_stride, num_classes):
+ model = ResNet(
+ stem_type=stem_type,
+ replace_stem_pool=replace_stem_pool,
+ avg_down=avg_down,
+ block=block,
+ layers=layers,
+ output_stride=output_stride,
+ num_classes=num_classes,
+ drop_rate=0. if stem_type == 'deep' else 0.05
+ )
+ model.init_weights(zero_init_last_bn=zero_init_last_bn)
+ inputs = torch.randn(4, 3, 224, 224).requires_grad_(True)
+ outputs = model(inputs)
+ assert outputs.shape == (4, num_classes if num_classes > 0 else 512)
+ mse_loss = torch.nn.MSELoss()
+ target = torch.randn(outputs.shape)
+ loss = mse_loss(outputs, target)
+ assert inputs.grad is None
+ loss.backward()
+ assert isinstance(inputs.grad, torch.Tensor)
+
+ model.reset_classifier(num_classes=183)
+ inputs = torch.randn(4, 3, 224, 224).requires_grad_(True)
+ outputs = model(inputs)
+ assert outputs.shape == (4, 183)
+ target = torch.randn(outputs.shape)
+ loss = mse_loss(outputs, target)
+ assert inputs.grad is None
+ loss.backward()
+ assert isinstance(inputs.grad, torch.Tensor)
+
+ clf = model.get_classifier()
+ outputs = model.forward_features(x=inputs)
+
+
+@pytest.mark.unittest
+def test_avg_pool2d_same():
+ x = torch.randn(4, 4, 4, 4).requires_grad_(True)
+ avg_pool2d_same(x=x, kernel_size=(2, 2), stride=(2, 2))
+
+
+inplanes = [4]
+planes = [1]
+args_btl = [item for item in product(*[inplanes, planes])]
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('inplanes, planes', args_btl)
+def test_Bottleneck(inplanes, planes):
+ model = Bottleneck(inplanes=inplanes, planes=planes)
+ x = torch.randn(4, 4, 4, 4).requires_grad_(True)
+ outputs = model(x)
+ assert outputs.shape == (4, 4, 4, 4)
+ model.zero_init_last_bn()
+
+
+in_chs = [1]
+num_classes = [0, 1]
+drop_rate = [0, 0.05]
+args_cls = [item for item in product(*[in_chs, num_classes, drop_rate])]
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('in_chs, num_classes, drop_rate', args_cls)
+def test_ClassifierHead(in_chs, num_classes, drop_rate):
+ model = ClassifierHead(in_chs=in_chs, num_classes=num_classes, drop_rate=drop_rate)
+ inputs = torch.randn(1, 1, 1, 1).requires_grad_(True)
+ outputs = model(inputs)
+ assert outputs.shape == (1, 1, 1, 1)
diff --git a/DI-engine/ding/torch_utils/network/tests/test_rnn.py b/DI-engine/ding/torch_utils/network/tests/test_rnn.py
new file mode 100644
index 0000000000000000000000000000000000000000..a77a4db5fe26142c9512cad5fea93631fa4675af
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/tests/test_rnn.py
@@ -0,0 +1,69 @@
+import pytest
+import torch
+from ding.torch_utils import get_lstm, sequence_mask
+
+
+@pytest.mark.unittest
+class TestLstm:
+
+ def test(self):
+ seq_len = 2
+ batch_size = 3
+ input_size = 2
+ hidden_size = 3
+ num_layers = 2
+ norm_type = 'LN'
+ dropout = 0.1
+ input = torch.rand(seq_len, batch_size, input_size).requires_grad_(True)
+ # abnormal case
+ lstm = get_lstm('normal', input_size, hidden_size, num_layers, norm_type, dropout)
+ prev_state = torch.randn(4)
+ with pytest.raises(TypeError):
+ _, _ = lstm(input, prev_state, list_next_state=True)
+ with pytest.raises(RuntimeError):
+ _, _ = lstm(input, [[] for _ in range(batch_size + 1)], list_next_state=True)
+ # normal case
+ lstm_type = ['normal', 'pytorch']
+ for l in lstm_type:
+ lstm = get_lstm(l, input_size, hidden_size, num_layers, norm_type, dropout)
+ prev_state = None
+ output, prev_state = lstm(input, prev_state, list_next_state=True)
+ loss = output.mean()
+ loss.backward()
+ assert output.shape == (seq_len, batch_size, hidden_size)
+ assert len(prev_state) == batch_size
+ assert prev_state[0]['h'].shape == (num_layers, 1, hidden_size)
+ assert isinstance(input.grad, torch.Tensor)
+
+ prev_state = None
+ for s in range(seq_len):
+ input_step = input[s:s + 1]
+ output, prev_state = lstm(input_step, prev_state, list_next_state=True)
+ assert output.shape == (1, batch_size, hidden_size)
+ assert len(prev_state) == batch_size
+ assert prev_state[0]['h'].shape == (num_layers, 1, hidden_size)
+ assert isinstance(input.grad, torch.Tensor)
+
+ prev_state = None
+ for s in range(seq_len):
+ input_step = input[s:s + 1]
+ output, prev_state = lstm(input_step, prev_state, list_next_state=False)
+ assert output.shape == (1, batch_size, hidden_size)
+ assert len(prev_state) == 2
+ assert prev_state['h'].shape == (num_layers, batch_size, hidden_size)
+ assert isinstance(input.grad, torch.Tensor)
+
+ randns = torch.randn(num_layers, 1, hidden_size)
+ prev_state = [None for _ in range(batch_size)]
+ prev_state[0] = {'h': randns, 'c': randns}
+ output, prev_state = lstm(input, prev_state, list_next_state=True)
+
+
+@pytest.mark.unittest
+def test_sequence_mask():
+ lengths = torch.LongTensor([0, 4, 3, 1, 2])
+ masks = sequence_mask(lengths)
+ assert masks.shape == (5, 4)
+ assert masks.dtype == torch.bool
+ masks = sequence_mask(lengths, max_len=3)
+ assert masks.shape == (5, 3)
diff --git a/DI-engine/ding/torch_utils/network/tests/test_scatter.py b/DI-engine/ding/torch_utils/network/tests/test_scatter.py
new file mode 100644
index 0000000000000000000000000000000000000000..a79e10737684d0b3ef971f4dead2eb141219454b
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/tests/test_scatter.py
@@ -0,0 +1,45 @@
+import pytest
+import torch
+from ding.torch_utils import ScatterConnection
+
+
+@pytest.mark.unittest
+class TestScatterConnection:
+
+ def test_naive(self):
+ for scatter_type in ['add', 'cover']:
+ model = ScatterConnection(scatter_type)
+ B, M, N = 2, 24, 32
+ H, W = 2, 3
+ input = torch.rand(B, M, N).requires_grad_(True)
+ h = torch.randint(
+ low=0, high=H, size=(
+ B,
+ M,
+ )
+ ).unsqueeze(dim=2)
+ w = torch.randint(
+ low=0, high=W, size=(
+ B,
+ M,
+ )
+ ).unsqueeze(dim=2)
+ location = torch.cat([h, w], dim=2)
+ output = model(x=input, spatial_size=(H, W), location=location)
+ loss = output.mean()
+ loss.backward()
+ assert isinstance(input.grad, torch.Tensor)
+
+ def test_xy_forward(self):
+ for scatter_type in ['add', 'cover']:
+ model = ScatterConnection(scatter_type)
+ B, M, N = 10, 20, 3
+ spatial_size = (13, 17)
+ input = torch.randn(size=(B, M, N)).requires_grad_(True)
+ coord_x = torch.randint(low=0, high=13, size=(B, M))
+ coord_y = torch.randint(low=0, high=17, size=(B, M))
+ output = model.xy_forward(input, spatial_size, coord_x, coord_y)
+ loss = output.mean()
+ loss.backward()
+ assert isinstance(input.grad, torch.Tensor)
+ assert output.shape == (B, N, *spatial_size)
diff --git a/DI-engine/ding/torch_utils/network/tests/test_soft_argmax.py b/DI-engine/ding/torch_utils/network/tests/test_soft_argmax.py
new file mode 100644
index 0000000000000000000000000000000000000000..d80a964a53be693bc33e0bb9a5d2a4621eaa921d
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/tests/test_soft_argmax.py
@@ -0,0 +1,42 @@
+import numpy as np
+import pytest
+import torch
+
+from ding.torch_utils.network import SoftArgmax
+
+
+@pytest.mark.unittest
+class TestSoftArgmax:
+
+ def test(self):
+ H, W = (48, 64)
+ B = 4
+ N = 10
+ model = SoftArgmax()
+ # data case 1
+ for _ in range(N):
+ test_h = np.random.randint(0, H, size=(B, ))
+ test_w = np.random.randint(0, W, size=(B, ))
+ test_location = torch.LongTensor([test_h, test_w]).permute(1, 0).contiguous()
+ assert test_location.shape == (B, 2)
+ data = torch.full((B, 1, H, W), -1e8)
+ for idx, (h, w) in enumerate(test_location):
+ data[idx, 0, h, w] = 1
+
+ pred_location = model(data)
+ assert pred_location.shape == (B, 2)
+ assert torch.abs(pred_location - test_location).sum() < 1e-6
+ # data case 2
+ pseudo_gauss_kernel = torch.FloatTensor([1, 3, 1, 3, 5, 3, 1, 3, 1]).reshape(3, 3)
+ for _ in range(N):
+ test_h = np.random.randint(1, H - 1, size=(B, ))
+ test_w = np.random.randint(1, W - 1, size=(B, ))
+ test_location = torch.LongTensor([test_h, test_w]).permute(1, 0).contiguous()
+ assert test_location.shape == (B, 2)
+ data = torch.full((B, 1, H, W), -1e8)
+ for idx, (h, w) in enumerate(test_location):
+ data[idx, 0, h - 1:h + 2, w - 1:w + 2] = pseudo_gauss_kernel
+
+ pred_location = model(data)
+ assert pred_location.shape == (B, 2)
+ assert torch.abs(pred_location - test_location).sum() < 1e-4
diff --git a/DI-engine/ding/torch_utils/network/tests/test_transformer.py b/DI-engine/ding/torch_utils/network/tests/test_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..813af0d025d266599c37a47a02062929ff5e3ed8
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/tests/test_transformer.py
@@ -0,0 +1,31 @@
+import pytest
+import torch
+
+from ding.torch_utils import Transformer
+
+
+@pytest.mark.unittest
+class TestTransformer:
+
+ def test(self):
+ batch_size = 2
+ num_entries = 2
+ C = 2
+ masks = [None, torch.rand(batch_size, num_entries).round().bool()]
+ for mask in masks:
+ output_dim = 4
+ model = Transformer(
+ input_dim=C,
+ head_dim=2,
+ hidden_dim=3,
+ output_dim=output_dim,
+ head_num=2,
+ mlp_num=2,
+ layer_num=2,
+ )
+ input = torch.rand(batch_size, num_entries, C).requires_grad_(True)
+ output = model(input, mask)
+ loss = output.mean()
+ loss.backward()
+ assert isinstance(input.grad, torch.Tensor)
+ assert output.shape == (batch_size, num_entries, output_dim)
diff --git a/DI-engine/ding/torch_utils/network/transformer.py b/DI-engine/ding/torch_utils/network/transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a508b3909b14a3ebca95abe497891f62e9e1ed7
--- /dev/null
+++ b/DI-engine/ding/torch_utils/network/transformer.py
@@ -0,0 +1,263 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import math
+from typing import List, Optional, Tuple
+
+from .nn_module import fc_block, build_normalization
+
+
+class Attention(nn.Module):
+ """
+ Overview:
+        For each entry embedding, compute attention over all entries and aggregate the results as the output.
+ Interfaces:
+ ``__init__``, ``split``, ``forward``
+ """
+
+ def __init__(self, input_dim: int, head_dim: int, output_dim: int, head_num: int, dropout: nn.Module) -> None:
+ """
+ Overview:
+ Initialize the Attention module with the provided dimensions and dropout layer.
+ Arguments:
+ - input_dim (:obj:`int`): The dimension of the input.
+ - head_dim (:obj:`int`): The dimension of each head in the multi-head attention mechanism.
+ - output_dim (:obj:`int`): The dimension of the output.
+ - head_num (:obj:`int`): The number of heads in the multi-head attention mechanism.
+ - dropout (:obj:`nn.Module`): The dropout layer used in the attention mechanism.
+ """
+ super(Attention, self).__init__()
+ self.head_num = head_num
+ self.head_dim = head_dim
+ self.dropout = dropout
+ self.attention_pre = fc_block(input_dim, head_dim * head_num * 3) # query, key, value
+ self.project = fc_block(head_dim * head_num, output_dim)
+
+ def split(self, x: torch.Tensor, T: bool = False) -> List[torch.Tensor]:
+ """
+ Overview:
+ Split the input to get multi-head queries, keys, and values.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The tensor to be split, which could be a query, key, or value.
+ - T (:obj:`bool`, optional): If True, transpose the output tensors. Defaults to False.
+ Returns:
+ - x (:obj:`List[torch.Tensor]`): A list of output tensors for each head.
+ """
+ B, N = x.shape[:2]
+ x = x.view(B, N, self.head_num, self.head_dim)
+ x = x.permute(0, 2, 1, 3).contiguous() # B, head_num, N, head_dim
+ if T:
+ x = x.permute(0, 1, 3, 2).contiguous()
+ return x
+
+ def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ """
+ Overview:
+ Compute the attention from the input tensor.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor for the forward computation.
+            - mask (:obj:`Optional[torch.Tensor]`, optional): Optional mask to exclude invalid entries. \
+                Defaults to None.
+ Returns:
+ - attention (:obj:`torch.Tensor`): The computed attention tensor.
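+        Examples:
+            A minimal illustrative sketch (hypothetical sizes):
+            >>> attention = Attention(input_dim=16, head_dim=8, output_dim=16, head_num=2, dropout=nn.Dropout(0.))
+            >>> x = torch.randn(4, 10, 16)  # (B, N, input_dim)
+            >>> out = attention(x)
+            >>> out.shape
+            torch.Size([4, 10, 16])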
+ """
+ assert (len(x.shape) == 3)
+ B, N = x.shape[:2]
+ x = self.attention_pre(x)
+ query, key, value = torch.chunk(x, 3, dim=2)
+ query, key, value = self.split(query), self.split(key, T=True), self.split(value)
+
+ score = torch.matmul(query, key) # B, head_num, N, N
+ score /= math.sqrt(self.head_dim)
+ if mask is not None:
+ # inplace modification for reasonable softmax
+ score.masked_fill_(~mask, value=-1e9)
+
+ score = F.softmax(score, dim=-1)
+ score = self.dropout(score)
+ attention = torch.matmul(score, value) # B, head_num, N, head_dim
+
+ attention = attention.permute(0, 2, 1, 3).contiguous() # B, N, head_num, head_dim
+ attention = self.project(attention.view(B, N, -1)) # B, N, output_dim
+ return attention
+
+
+class TransformerLayer(nn.Module):
+ """
+ Overview:
+        In the Transformer layer, attention over the entries is computed first, followed by a feedforward layer.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(
+ self, input_dim: int, head_dim: int, hidden_dim: int, output_dim: int, head_num: int, mlp_num: int,
+ dropout: nn.Module, activation: nn.Module
+ ) -> None:
+ """
+ Overview:
+ Initialize the TransformerLayer with the provided dimensions, dropout layer, and activation function.
+ Arguments:
+ - input_dim (:obj:`int`): The dimension of the input.
+ - head_dim (:obj:`int`): The dimension of each head in the multi-head attention mechanism.
+ - hidden_dim (:obj:`int`): The dimension of the hidden layer in the MLP (Multi-Layer Perceptron).
+ - output_dim (:obj:`int`): The dimension of the output.
+ - head_num (:obj:`int`): The number of heads in the multi-head attention mechanism.
+ - mlp_num (:obj:`int`): The number of layers in the MLP.
+ - dropout (:obj:`nn.Module`): The dropout layer used in the attention mechanism.
+ - activation (:obj:`nn.Module`): The activation function used in the MLP.
+ """
+ super(TransformerLayer, self).__init__()
+ self.attention = Attention(input_dim, head_dim, output_dim, head_num, dropout)
+ self.layernorm1 = build_normalization('LN')(output_dim)
+ self.dropout = dropout
+ layers = []
+ dims = [output_dim] + [hidden_dim] * (mlp_num - 1) + [output_dim]
+ for i in range(mlp_num):
+ layers.append(fc_block(dims[i], dims[i + 1], activation=activation))
+ if i != mlp_num - 1:
+ layers.append(self.dropout)
+ layers.append(self.dropout)
+ self.mlp = nn.Sequential(*layers)
+ self.layernorm2 = build_normalization('LN')(output_dim)
+
+ def forward(self, inputs: Tuple[torch.Tensor, torch.Tensor]) -> Tuple[torch.Tensor, torch.Tensor]:
+ """
+ Overview:
+ Compute the forward pass through the Transformer layer.
+ Arguments:
+            - inputs (:obj:`Tuple[torch.Tensor, torch.Tensor]`): A tuple containing the input tensor `x` and \
+                the mask tensor.
+        Returns:
+            - output (:obj:`Tuple[torch.Tensor, torch.Tensor]`): A tuple containing the predicted value tensor and \
+                the mask tensor.
+ """
+ x, mask = inputs
+ a = self.dropout(self.attention(x, mask))
+ x = self.layernorm1(x + a)
+ m = self.dropout(self.mlp(x))
+ x = self.layernorm2(x + m)
+ return x, mask
+
+
+class Transformer(nn.Module):
+ """
+ Overview:
+ Implementation of the Transformer model.
+
+ .. note::
+ For more details, refer to "Attention is All You Need": http://arxiv.org/abs/1706.03762.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(
+ self,
+ input_dim: int,
+ head_dim: int = 128,
+ hidden_dim: int = 1024,
+ output_dim: int = 256,
+ head_num: int = 2,
+ mlp_num: int = 2,
+ layer_num: int = 3,
+ dropout_ratio: float = 0.,
+ activation: nn.Module = nn.ReLU(),
+ ):
+ """
+ Overview:
+ Initialize the Transformer with the provided dimensions, dropout layer, activation function,
+ and layer numbers.
+ Arguments:
+ - input_dim (:obj:`int`): The dimension of the input.
+ - head_dim (:obj:`int`): The dimension of each head in the multi-head attention mechanism.
+ - hidden_dim (:obj:`int`): The dimension of the hidden layer in the MLP (Multi-Layer Perceptron).
+ - output_dim (:obj:`int`): The dimension of the output.
+ - head_num (:obj:`int`): The number of heads in the multi-head attention mechanism.
+ - mlp_num (:obj:`int`): The number of layers in the MLP.
+ - layer_num (:obj:`int`): The number of Transformer layers.
+ - dropout_ratio (:obj:`float`): The dropout ratio for the dropout layer.
+ - activation (:obj:`nn.Module`): The activation function used in the MLP.
+ """
+ super(Transformer, self).__init__()
+ self.embedding = fc_block(input_dim, output_dim, activation=activation)
+ self.act = activation
+ layers = []
+ dims = [output_dim] + [output_dim] * layer_num
+ self.dropout = nn.Dropout(dropout_ratio)
+ for i in range(layer_num):
+ layers.append(
+ TransformerLayer(dims[i], head_dim, hidden_dim, dims[i + 1], head_num, mlp_num, self.dropout, self.act)
+ )
+ self.main = nn.Sequential(*layers)
+
+ def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
+ """
+ Overview:
+ Perform the forward pass through the Transformer.
+ Arguments:
+ - x (:obj:`torch.Tensor`): The input tensor, with shape `(B, N, C)`, where `B` is batch size, \
+ `N` is the number of entries, and `C` is the feature dimension.
+ - mask (:obj:`Optional[torch.Tensor]`, optional): The mask tensor (bool), used to mask out invalid \
+ entries in attention. It has shape `(B, N)`, where `B` is batch size and `N` is number of \
+ entries. Defaults to None.
+ Returns:
+ - x (:obj:`torch.Tensor`): The output tensor from the Transformer.
+ """
+ if mask is not None:
+ mask = mask.unsqueeze(dim=1).repeat(1, mask.shape[1], 1).unsqueeze(dim=1)
+ x = self.embedding(x)
+ x = self.dropout(x)
+ x, mask = self.main((x, mask))
+ return x
+
+
+class ScaledDotProductAttention(nn.Module):
+ """
+ Overview:
+ Implementation of Scaled Dot Product Attention, a key component of Transformer models.
+ This class computes the dot product of the query and key tensors, scales it by the square root of the
+ key dimension (d_k), applies softmax and dropout for regularization, and multiplies the result with the value tensor.
+ Interfaces:
+ ``__init__``, ``forward``
+ """
+
+ def __init__(self, d_k: int, dropout: float = 0.0) -> None:
+ """
+ Overview:
+ Initialize the ScaledDotProductAttention module with the dimension of the key vector and the dropout rate.
+ Arguments:
+ - d_k (:obj:`int`): The dimension of the key vector. This will be used to scale the dot product of the \
+ query and key.
+ - dropout (:obj:`float`, optional): The dropout rate to be applied after the softmax operation. \
+ Defaults to 0.0.
+ """
+ super(ScaledDotProductAttention, self).__init__()
+ self.d_k = d_k
+ self.dropout = nn.Dropout(dropout)
+
+ def forward(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ mask: Optional[torch.Tensor] = None
+ ) -> torch.Tensor:
+ """
+ Overview:
+ Perform the Scaled Dot Product Attention operation on the query, key and value tensors.
+ Arguments:
+ - q (:obj:`torch.Tensor`): The query tensor.
+ - k (:obj:`torch.Tensor`): The key tensor.
+ - v (:obj:`torch.Tensor`): The value tensor.
+ - mask (:obj:`Optional[torch.Tensor]`): An optional mask tensor to be applied on the attention scores.
+ Defaults to None.
+ Returns:
+ - output (:obj:`torch.Tensor`): The output tensor after the attention operation.
+ """
+ attn = torch.matmul(q / (self.d_k ** 0.5), k.transpose(2, 3))
+ if mask is not None:
+ # inplace modification for reasonable softmax
+ attn.masked_fill_(~mask, -1e9)
+ attn = self.dropout(F.softmax(attn, dim=-1))
+ output = torch.matmul(attn, v)
+ return output
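# Editor's note: a minimal usage sketch for the Transformer defined above, added for illustration
# only. The import path is an assumption (this hunk does not show the file header); adjust it to
# wherever these classes are exposed in the package.
import torch
from ding.torch_utils.network.transformer import Transformer  # assumed path

B, N, C = 4, 16, 32                        # batch size, number of entries, feature dimension
x = torch.randn(B, N, C)
mask = torch.ones(B, N, dtype=torch.bool)  # True marks valid entries
mask[:, -3:] = False                       # treat the last 3 entries as padding

model = Transformer(input_dim=C, head_dim=64, hidden_dim=128, output_dim=64,
                    head_num=2, mlp_num=2, layer_num=2, dropout_ratio=0.1)
y = model(x, mask)                         # (B, N, 64); padded entries receive near-zero attention weight
assert y.shape == (B, N, 64)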
diff --git a/DI-engine/ding/torch_utils/nn_test_helper.py b/DI-engine/ding/torch_utils/nn_test_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..d62ebc6ccfeb80c6ab106145abac3b4eacc8e341
--- /dev/null
+++ b/DI-engine/ding/torch_utils/nn_test_helper.py
@@ -0,0 +1,46 @@
+from typing import Union, List
+import torch
+
+
+def is_differentiable(
+ loss: torch.Tensor, model: Union[torch.nn.Module, List[torch.nn.Module]], print_instead: bool = False
+) -> None:
+ """
+ Overview:
+ Check whether the model/models are differentiable with respect to the given loss. First assert that every
+ parameter's grad is None, then backpropagate the loss, and finally assert that every grad is a torch.Tensor.
+ Arguments:
+ - loss (:obj:`torch.Tensor`): loss tensor of the model
+ - model (:obj:`Union[torch.nn.Module, List[torch.nn.Module]]`): model or models to be checked
+ - print_instead (:obj:`bool`): Whether to print module's final grad result, \
+ instead of asserting. Default set to ``False``.
+ """
+ assert isinstance(loss, torch.Tensor)
+ if isinstance(model, list):
+ for m in model:
+ assert isinstance(m, torch.nn.Module)
+ for k, p in m.named_parameters():
+ assert p.grad is None, k
+ elif isinstance(model, torch.nn.Module):
+ for k, p in model.named_parameters():
+ assert p.grad is None, k
+ else:
+ raise TypeError('model must be list or nn.Module')
+
+ loss.backward()
+
+ if isinstance(model, list):
+ for m in model:
+ for k, p in m.named_parameters():
+ if print_instead:
+ if not isinstance(p.grad, torch.Tensor):
+ print(k, "grad is:", p.grad)
+ else:
+ assert isinstance(p.grad, torch.Tensor), k
+ elif isinstance(model, torch.nn.Module):
+ for k, p in model.named_parameters():
+ if print_instead:
+ if not isinstance(p.grad, torch.Tensor):
+ print(k, "grad is:", p.grad)
+ else:
+ assert isinstance(p.grad, torch.Tensor), k
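# Editor's note: a short sketch of how is_differentiable is typically used in a unit test, added
# for illustration: build a model with fresh (grad-free) parameters, compute a scalar loss, and let
# the helper assert that backpropagation reaches every parameter.
import torch
from ding.torch_utils.nn_test_helper import is_differentiable

model = torch.nn.Sequential(torch.nn.Linear(8, 16), torch.nn.ReLU(), torch.nn.Linear(16, 4))
loss = model(torch.randn(2, 8)).sum()
is_differentiable(loss, model)  # raises AssertionError if any parameter receives no gradient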
diff --git a/DI-engine/ding/torch_utils/optimizer_helper.py b/DI-engine/ding/torch_utils/optimizer_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4d351cca251f9a0eb80250a47f360472080bb23
--- /dev/null
+++ b/DI-engine/ding/torch_utils/optimizer_helper.py
@@ -0,0 +1,878 @@
+import torch
+import math
+from torch.nn.utils import clip_grad_norm_, clip_grad_value_
+from typing import Union, Iterable, Tuple, Callable, List
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import pdb
+import numpy as np
+import copy
+import random
+
+inf = math.inf
+
+
+def calculate_grad_norm(model: torch.nn.Module, norm_type=2) -> float:
+ """
+ Overview:
+ Calculate the gradient norm of the parameters whose grads are not None in the model.
+ Arguments:
+ - model (:obj:`torch.nn.Module`): the model whose gradient norm is calculated
+ - norm_type (:obj:`int` or `'inf'`): the order of the norm, default set to 2
+ """
+ parameters = list(filter(lambda p: p.grad is not None, model.parameters()))
+ if parameters == []:
+ return 0
+ if norm_type == 'inf':
+ total_norm = max(p.grad.data.abs().max() for p in parameters)
+ return float(total_norm)
+ else:
+ total_norm = 0
+ for p in parameters:
+ param_norm = p.grad.data.norm(norm_type)
+ total_norm += param_norm.item() ** norm_type
+ total_norm = total_norm ** (1. / norm_type)
+ return float(total_norm)
+
+
+def calculate_grad_norm_without_bias_two_norm(model: torch.nn.Module) -> float:
+ """
+ Overview:
+ Calculate the 2-norm of the gradients of all non-bias parameters in the model that require grad.
+ Arguments:
+ - model (:obj:`torch.nn.Module`): the model whose gradient norm is calculated
+ """
+ _list = []
+ for name, param in model.named_parameters():
+ if 'bias' not in name and param.requires_grad:
+ if param.grad is None:
+ return 0
+ _list.append(param.grad.data.norm(2).item() ** 2)
+ return float(sum(_list) ** (1. / 2))
+
+
+def grad_ignore_norm(parameters, max_norm, norm_type=2):
+ """
+ Overview:
+ Zero all gradients of an iterable of parameters if their total norm exceeds ``max_norm``.
+ Arguments:
+ - parameters (:obj:`Iterable`): an iterable of torch.Tensor
+ - max_norm (:obj:`float`): the max norm of the gradients
+ - norm_type (:obj:`float`): 2.0 means use norm2 to clip
+ """
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ parameters = list(filter(lambda p: p.grad is not None, parameters))
+ max_norm = float(max_norm)
+ norm_type = float(norm_type)
+ if norm_type == inf:
+ total_norm = max(p.grad.data.abs().max() for p in parameters)
+ else:
+ total_norm = 0
+ for p in parameters:
+ param_norm = p.grad.data.norm(norm_type)
+ total_norm += param_norm.item() ** norm_type
+ total_norm = total_norm ** (1. / norm_type)
+ clip_coef = max_norm / (total_norm + 1e-6)
+ if clip_coef < 1:
+ for p in parameters:
+ p.grad.zero_()
+ return total_norm
+
+
+def grad_ignore_value(parameters, clip_value):
+ """
+ Overview:
+ Zero all gradients of an iterable of parameters if any gradient value reaches ``clip_value``.
+ Arguments:
+ - parameters (:obj:`Iterable`): an iterable of torch.Tensor
+ - clip_value (:obj:`float`): the value to start clipping
+ """
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ clip_value = float(clip_value)
+ flag = False
+ for p in filter(lambda p: p.grad is not None, parameters):
+ val = p.grad.data.abs().max()
+ if val >= clip_value:
+ flag = True
+ break
+ if flag:
+ for p in filter(lambda p: p.grad is not None, parameters):
+ p.grad.data.zero_()
+
+
+class Adam(torch.optim.Adam):
+ """
+ Overview:
+ Rewritten Adam optimizer to support more features (gradient clipping and gradient ignoring).
+ Interfaces:
+ ``__init__``, ``step``, ``_state_init``, ``get_grad``
+ """
+
+ def __init__(
+ self,
+ params: Iterable,
+ lr: float = 1e-3,
+ betas: Tuple[float, float] = (0.9, 0.999),
+ eps: float = 1e-8,
+ weight_decay: float = 0,
+ amsgrad: bool = False,
+ optim_type: str = 'adam',
+ grad_clip_type: str = None,
+ clip_value: Union[float, None] = None,
+ clip_coef: float = 5,
+ clip_norm_type: float = 2.0,
+ clip_momentum_timestep: int = 100,
+ grad_norm_type: str = None,
+ grad_ignore_type: str = None,
+ ignore_value: Union[float, None] = None,
+ ignore_coef: float = 5,
+ ignore_norm_type: float = 2.0,
+ ignore_momentum_timestep: int = 100,
+ ):
+ """
+ Overview:
+ Initialize the refactored Adam optimizer.
+ Arguments:
+ - params (:obj:`iterable`): an iterable of ``torch.Tensor`` s or ``dict`` s, \
+ specifying which tensors should be optimized
+ - lr (:obj:`float`): learning rate, default set to 1e-3
+ - betas (:obj:`Tuple[float, float]`): coefficients used for computing running averages of gradient and its\
+ square, default set to (0.9, 0.999))
+ - eps (:obj:`float`): term added to the denominator to improve numerical stability, default set to 1e-8
+ - weight_decay (:obj:`float`): weight decay coefficient, default set to 0
+ - amsgrad (:obj:`bool`): whether to use the AMSGrad variant of this algorithm from the paper\
+ On the Convergence of Adam and Beyond
+ - optim_type (:obj:`str`): optimizer type, supports ["adam", "adamw"]
+ - grad_clip_type (:obj:`str`): support [None, 'clip_momentum', 'clip_value', 'clip_norm', \
+ 'clip_momentum_norm']
+ - clip_value (:obj:`float`): the value to start clipping
+ - clip_coef (:obj:`float`): the clipping coefficient
+ - clip_norm_type (:obj:`float`): 2.0 means use norm2 to clip
+ - clip_momentum_timestep (:obj:`int`): the number of steps after which momentum clipping starts
+ - grad_ignore_type (:obj:`str`): support [None, 'ignore_momentum', 'ignore_value', 'ignore_norm', \
+ 'ignore_momentum_norm']
+ - ignore_value (:obj:`float`): the value to start ignoring
+ - ignore_coef (:obj:`float`): the ignoring coefficient
+ - ignore_norm_type (:obj:`float`): 2.0 means use norm2 to ignore
+ - ignore_momentum_timestep (:obj:`int`): the number of steps after which momentum ignoring starts
+
+ """
+
+ self._support_type = {
+ 'optim': ['adam', 'adamw'],
+ 'grad_clip': [None, 'clip_momentum', 'clip_value', 'clip_norm', 'clip_momentum_norm'],
+ 'grad_norm': [None],
+ 'grad_ignore': [None, 'ignore_momentum', 'ignore_value', 'ignore_norm', 'ignore_momentum_norm'],
+ }
+
+ assert optim_type in self._support_type['optim']
+ assert grad_clip_type in self._support_type['grad_clip']
+ assert grad_norm_type in self._support_type['grad_norm']
+ assert grad_ignore_type in self._support_type['grad_ignore']
+ if grad_clip_type:
+ assert clip_value is not None
+ if grad_ignore_type:
+ assert ignore_value is not None
+
+ self._optim_type = optim_type
+ self._grad_clip_type = grad_clip_type
+ self._grad_norm_type = grad_norm_type
+ self._grad_ignore_type = grad_ignore_type
+ self._clip_value = clip_value
+ self._clip_norm_type = clip_norm_type
+ self._clip_coef = clip_coef
+ self._ignore_value = ignore_value
+ self._ignore_norm_type = ignore_norm_type
+ self._ignore_coef = ignore_coef
+ self._clip_momentum_timestep = clip_momentum_timestep
+ self._ignore_momentum_timestep = ignore_momentum_timestep
+
+ if self._optim_type == 'adamw':
+ self._weight_decay = weight_decay
+ super(Adam, self).__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=0, amsgrad=amsgrad)
+ elif self._optim_type == 'adam':
+ super(Adam, self).__init__(params, lr=lr, betas=betas, eps=eps, weight_decay=weight_decay, amsgrad=amsgrad)
+ else:
+ raise NotImplementedError(
+ "optimizer type {} is not implemented, support type is {}".format(
+ self._optim_type, self._support_type['optim']
+ )
+ )
+
+ def _state_init(self, p, amsgrad):
+ """
+ Overview:
+ Initialize the state of the optimizer
+ Arguments:
+ - p (:obj:`torch.Tensor`): the parameter to be optimized
+ - amsgrad (:obj:`bool`): whether to use the AMSGrad variant of this algorithm from the paper\
+ On the Convergence of Adam and Beyond
+ """
+ state = self.state[p]
+ state['thre_exp_avg_sq'] = torch.zeros_like(p.data, device=p.data.device)
+ # others
+ if torch.__version__ < "1.12.0":
+ state['step'] = 0
+ # TODO
+ # wait for torch to upgrade to 1.4; 1.3.1 did not support memory format for state['step'] = 0
+ else:
+ state['step'] = torch.zeros((1,), dtype=torch.float, device=p.device) \
+ if self.defaults['capturable'] else torch.tensor(0.)
+
+ state['exp_avg'] = torch.zeros_like(p.data)
+ # Exponential moving average of squared gradient values
+ state['exp_avg_sq'] = torch.zeros_like(p.data)
+ if amsgrad:
+ # Maintains max of all exp. moving avg. of sq. grad. values
+ state['max_exp_avg_sq'] = torch.zeros_like(p.data)
+
+ def step(self, closure: Union[Callable, None] = None):
+ """
+ Overview:
+ Performs a single optimization step
+ Arguments:
+ - closure (:obj:`callable`): A closure that reevaluates the model and returns the loss, default set to None
+ """
+ # clipping
+ new_params = [
+ t for group in self.param_groups for t in group['params'] if t.requires_grad and t.grad is not None
+ ]
+ if self._grad_clip_type == 'clip_value':
+ clip_grad_value_(new_params, self._clip_value)
+ elif self._grad_clip_type == 'clip_norm':
+ clip_grad_norm_(new_params, self._clip_value, self._clip_norm_type)
+ elif self._grad_clip_type == 'clip_momentum':
+ '''
+ This implementation mimics the clipping used by OpenAI, quote:
+ 'Gradients are additionally clipped per parameter to be within between ±5√v
+ where v is the running estimate of the second moment of the (unclipped) gradient'
+ '''
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ state = self.state[p]
+ if len(state) == 0:
+ self._state_init(p, group['amsgrad'])
+ grad = p.grad.data
+ # should we use same beta group?
+ beta1, beta2 = group['betas']
+ bias_correction2 = 1 - beta2 ** state['step']
+ state['thre_exp_avg_sq'].mul_(beta2).addcmul_(1 - beta2, grad, grad)
+ if state['step'] >= self._clip_momentum_timestep: # initial value is inaccurate
+ flag = grad.abs(
+ ) > (state['thre_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)) * self._clip_coef
+ grad.mul_(~flag).add_(
+ ((state['thre_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)) *
+ self._clip_coef).mul_(flag)
+ )
+ elif self._grad_clip_type == 'clip_momentum_norm':
+ # might have multi param_group, we should calculate each group differently.
+ for group in self.param_groups:
+ total_norm = 0
+ total_momentum_norm = 0
+ step = inf
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ state = self.state[p]
+ if len(state) == 0:
+ self._state_init(p, group['amsgrad'])
+ grad = p.grad.data
+ # should we use same beta group?
+ beta1, beta2 = group['betas']
+ bias_correction2 = 1 - beta2 ** state['step']
+ state['thre_exp_avg_sq'].mul_(beta2).addcmul_(1 - beta2, grad, grad)
+ # sum total_norm
+ param_norm = grad.norm(self._clip_norm_type)
+ total_norm += param_norm.item() ** self._clip_norm_type
+
+ # sum momentum_norm
+ momentum = ((state['thre_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)) *
+ self._clip_coef).norm(self._clip_norm_type)
+ total_momentum_norm += momentum.item() ** self._clip_norm_type
+ step = min(step, state['step'])
+ if step > self._clip_momentum_timestep:
+ total_norm = total_norm ** (1. / self._clip_norm_type)
+ total_momentum_norm = total_momentum_norm ** (1. / self._clip_norm_type)
+ clip_coef = total_momentum_norm / (total_norm + 1e-6)
+ if clip_coef < 1:
+ for p in group['params']:
+ p.grad.data.mul_(clip_coef)
+
+ if self._grad_ignore_type == 'ignore_value':
+ grad_ignore_value(new_params, self._ignore_value)
+ elif self._grad_ignore_type == 'ignore_norm':
+ grad_ignore_norm(new_params, self._ignore_value, self._ignore_norm_type)
+ elif self._grad_ignore_type == 'ignore_momentum':
+ flag = False
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ state = self.state[p]
+ if len(state) == 0:
+ self._state_init(p, group['amsgrad'])
+ grad = p.grad.data
+ # should we use same beta group?
+ beta1, beta2 = group['betas']
+ bias_correction2 = 1 - beta2 ** state['step']
+ state['thre_exp_avg_sq'].mul_(beta2).addcmul_(1 - beta2, grad, grad)
+ if state['step'] >= self._ignore_momentum_timestep: # initial value is inaccurate
+ if grad.abs() > (state['thre_exp_avg_sq'].sqrt() /
+ math.sqrt(bias_correction2)) * self._ignore_coef:
+ flag = True
+ break
+ else:
+ continue
+ break
+
+ if flag:
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ p.grad.zero_()
+ elif self._grad_ignore_type == 'ignore_momentum_norm':
+ # might have multi param_group, we should calculate each group differently.
+ step = inf
+ for group in self.param_groups:
+ total_norm = 0
+ total_momentum_norm = 0
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ state = self.state[p]
+ if len(state) == 0:
+ self._state_init(p, group['amsgrad'])
+ grad = p.grad.data
+ # should we use same beta group?
+ beta1, beta2 = group['betas']
+ bias_correction2 = 1 - beta2 ** state['step']
+ state['thre_exp_avg_sq'].mul_(beta2).addcmul_(1 - beta2, grad, grad)
+ # sum total_norm
+ param_norm = grad.norm(self._ignore_norm_type)
+ total_norm += param_norm.item() ** self._ignore_norm_type
+
+ # sum momentum_norm
+ momentum = ((state['thre_exp_avg_sq'].sqrt() / math.sqrt(bias_correction2)) *
+ self._ignore_coef).norm(self._ignore_norm_type)
+ total_momentum_norm += momentum.item() ** self._ignore_norm_type
+ step = min(step, state['step'])
+
+ if step > self._ignore_momentum_timestep:
+ total_norm = total_norm ** (1. / self._ignore_norm_type)
+ total_momentum_norm = total_momentum_norm ** (1. / self._ignore_norm_type)
+ ignore_coef = total_momentum_norm / (total_norm + 1e-6)
+ if ignore_coef < 1:
+ for p in group['params']:
+ p.grad.zero_()
+
+ # Adam optim type
+ if self._optim_type == 'adamw':
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ p.data = p.data.add(-self._weight_decay * group['lr'], p.data)
+ return super().step(closure=closure)
+ elif self._optim_type == 'adam':
+ return super().step(closure=closure)
+
+ def get_grad(self) -> float:
+ total_norm = 0.
+ params = [t for group in self.param_groups for t in group['params'] if t.requires_grad and t.grad is not None]
+ for p in params:
+ param_norm = p.grad.data.norm(self._clip_norm_type)
+ total_norm += param_norm.item() ** self._clip_norm_type
+ return total_norm
+
+
+class RMSprop(torch.optim.RMSprop):
+ r"""
+ Overview:
+ Rewritten RMSprop optimizer to support more features (gradient clipping and gradient ignoring).
+ Interfaces:
+ ``__init__``, ``step``, ``_state_init``, ``get_grad``
+ """
+
+ def __init__(
+ self,
+ params: Iterable,
+ lr: float = 1e-2,
+ alpha: float = 0.99,
+ eps: float = 1e-8,
+ weight_decay: float = 0,
+ momentum: float = 0,
+ centered: bool = False,
+ grad_clip_type: str = None,
+ clip_value: Union[float, None] = None,
+ clip_coef: float = 5,
+ clip_norm_type: float = 2.0,
+ clip_momentum_timestep: int = 100,
+ grad_norm_type: str = None,
+ grad_ignore_type: str = None,
+ ignore_value: Union[float, None] = None,
+ ignore_coef: float = 5,
+ ignore_norm_type: float = 2.0,
+ ignore_momentum_timestep: int = 100,
+ ):
+ """
+ Overview:
+ Initialize the refactored RMSprop optimizer.
+ Arguments:
+ - params (:obj:`iterable`): an iterable of ``torch.Tensor`` s or ``dict`` s, \
+ specifying which tensors should be optimized
+ - lr (:obj:`float`): learning rate, default set to 1e-2
+ - alpha (:obj:`float`): smoothing constant, default set to 0.99
+ - eps (:obj:`float`): term added to the denominator to improve numerical stability, default set to 1e-8
+ - weight_decay (:obj:`float`): weight decay coefficient, default set to 0
+ - momentum (:obj:`float`): momentum factor, default set to 0
+ - centered (:obj:`bool`): if True, compute the centered RMSprop, \
+ in which the gradient is normalized by an estimation of its variance
+ - grad_clip_type (:obj:`str`): support [None, 'clip_momentum', 'clip_value', 'clip_norm', \
+ 'clip_momentum_norm']
+ - clip_value (:obj:`float`): the value to start clipping
+ - clip_coef (:obj:`float`): the clipping coefficient
+ - clip_norm_type (:obj:`float`): 2.0 means use norm2 to clip
+ - clip_momentum_timestep (:obj:`int`): the number of steps after which momentum clipping starts
+ - grad_ignore_type (:obj:`str`): support [None, 'ignore_momentum', 'ignore_value', 'ignore_norm', \
+ 'ignore_momentum_norm']
+ - ignore_value (:obj:`float`): the value to start ignoring
+ - ignore_coef (:obj:`float`): the ignoring coefficient
+ - ignore_norm_type (:obj:`float`): 2.0 means use norm2 to ignore
+ - ignore_momentum_timestep (:obj:`int`): the number of steps after which momentum ignoring starts
+ """
+
+ self._support_type = {
+ 'grad_clip': [None, 'clip_momentum', 'clip_value', 'clip_norm', 'clip_momentum_norm'],
+ 'grad_norm': [None],
+ 'grad_ignore': [None, 'ignore_momentum', 'ignore_value', 'ignore_norm', 'ignore_momentum_norm'],
+ }
+
+ assert grad_clip_type in self._support_type['grad_clip']
+ assert grad_norm_type in self._support_type['grad_norm']
+ assert grad_ignore_type in self._support_type['grad_ignore']
+ if grad_clip_type:
+ assert clip_value is not None
+ if grad_ignore_type:
+ assert ignore_value is not None
+
+ self._grad_clip_type = grad_clip_type
+ self._grad_norm_type = grad_norm_type
+ self._grad_ignore_type = grad_ignore_type
+ self._clip_value = clip_value
+ self._clip_norm_type = clip_norm_type
+ self._clip_coef = clip_coef
+ self._ignore_value = ignore_value
+ self._ignore_norm_type = ignore_norm_type
+ self._ignore_coef = ignore_coef
+ self._clip_momentum_timestep = clip_momentum_timestep
+ self._ignore_momentum_timestep = ignore_momentum_timestep
+
+ super(RMSprop, self).__init__(
+ params, lr=lr, alpha=alpha, eps=eps, weight_decay=weight_decay, momentum=momentum, centered=centered
+ )
+
+ def _state_init(self, p, momentum, centered):
+ """
+ Overview:
+ Initialize the state of the optimizer
+ Arguments:
+ - p (:obj:`torch.Tensor`): the parameter to be optimized
+ - momentum (:obj:`float`): the momentum coefficient
+ - centered (:obj:`bool`): if True, compute the centered RMSprop, \
+ the gradient is normalized by an estimation of its variance
+ """
+
+ state = self.state[p]
+ state['step'] = 0
+ state['thre_square_avg'] = torch.zeros_like(p.data, device=p.data.device)
+ state['square_avg'] = torch.zeros_like(p.data, device=p.data.device)
+ if momentum:
+ state['momentum_buffer'] = torch.zeros_like(p.data, device=p.data.device)
+ if centered:
+ state['grad_avg'] = torch.zeros_like(p.data, device=p.data.device)
+
+ def step(self, closure: Union[Callable, None] = None):
+ """
+ Overview:
+ Performs a single optimization step
+ Arguments:
+ - closure (:obj:`callable`): A closure that reevaluates the model and returns the loss, default set to None
+ """
+ # clipping
+ new_params = [
+ t for group in self.param_groups for t in group['params'] if t.requires_grad and t.grad is not None
+ ]
+ if self._grad_clip_type == 'clip_value':
+ clip_grad_value_(new_params, self._clip_value)
+ elif self._grad_clip_type == 'clip_norm':
+ clip_grad_norm_(new_params, self._clip_value, self._clip_norm_type)
+ elif self._grad_clip_type == 'clip_momentum':
+ '''
+ This implementation mimics the clipping used by OpenAI, quote:
+ 'Gradients are additionally clipped per parameter to be within between ±5√v
+ where v is the running estimate of the second moment of the (unclipped) gradient'
+ '''
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ state = self.state[p]
+ if len(state) == 0:
+ self._state_init(p, group['momentum'], group['centered'])
+ grad = p.grad.data
+ # beta1, beta2 = group['betas']
+ alpha = group['alpha']
+ state['thre_square_avg'].mul_(alpha).addcmul_(1 - alpha, grad, grad)
+ if state['step'] >= self._clip_momentum_timestep: # initial value is inaccurate
+ flag = grad.abs() > state['thre_square_avg'].sqrt() * self._clip_coef
+ grad.mul_(~flag).add_((state['thre_square_avg'].sqrt() * self._clip_coef).mul_(flag))
+ elif self._grad_clip_type == 'clip_momentum_norm':
+ # might have multi param_group, we should calculate each group differently.
+ for group in self.param_groups:
+ total_norm = 0
+ total_momentum_norm = 0
+ step = inf
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ state = self.state[p]
+ if len(state) == 0:
+ self._state_init(p, group['momentum'], group['centered'])
+ grad = p.grad.data
+ alpha = group['alpha']
+ state['thre_square_avg'].mul_(alpha).addcmul_(1 - alpha, grad, grad)
+ # sum total_norm
+ param_norm = grad.norm(self._clip_norm_type)
+ total_norm += param_norm.item() ** self._clip_norm_type
+
+ # sum momentum_norm
+ momentum = (state['thre_square_avg'].sqrt() * self._clip_coef).norm(self._clip_norm_type)
+ total_momentum_norm += momentum.item() ** self._clip_norm_type
+ step = min(step, state['step'])
+ if step > self._clip_momentum_timestep:
+ total_norm = total_norm ** (1. / self._clip_norm_type)
+ total_momentum_norm = total_momentum_norm ** (1. / self._clip_norm_type)
+ clip_coef = total_momentum_norm / (total_norm + 1e-6)
+ if clip_coef < 1:
+ for p in group['params']:
+ p.grad.data.mul_(clip_coef)
+
+ if self._grad_ignore_type == 'ignore_value':
+ grad_ignore_value(new_params, self._ignore_value)
+ elif self._grad_ignore_type == 'ignore_norm':
+ grad_ignore_norm(new_params, self._ignore_value, self._ignore_norm_type)
+ elif self._grad_ignore_type == 'ignore_momentum':
+ flag = False
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ state = self.state[p]
+ if len(state) == 0:
+ self._state_init(p, group['momentum'], group['centered'])
+ grad = p.grad.data
+ alpha = group['alpha']
+ state['thre_square_avg'].mul_(alpha).addcmul_(1 - alpha, grad, grad)
+ if state['step'] >= self._ignore_momentum_timestep: # initial value is inaccurate
+ if grad.abs() > state['thre_square_avg'].sqrt() * self._ignore_coef:
+ flag = True
+ break
+ else:
+ continue
+ break
+
+ if flag:
+ for group in self.param_groups:
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ p.grad.zero_()
+ elif self._grad_ignore_type == 'ignore_momentum_norm':
+ # might have multi param_group, we should calculate each group differently.
+ step = inf
+ for group in self.param_groups:
+ total_norm = 0
+ total_momentum_norm = 0
+ for p in group['params']:
+ if p.grad is None:
+ continue
+ state = self.state[p]
+ if len(state) == 0:
+ self._state_init(p, group['momentum'], group['centered'])
+ grad = p.grad.data
+ alpha = group['alpha']
+ state['thre_square_avg'].mul_(alpha).addcmul_(1 - alpha, grad, grad)
+ # sum total_norm
+ param_norm = grad.norm(self._ignore_norm_type)
+ total_norm += param_norm.item() ** self._ignore_norm_type
+
+ # sum momentum_norm
+ momentum = (state['thre_square_avg'].sqrt() * self._ignore_coef).norm(self._ignore_norm_type)
+ total_momentum_norm += momentum.item() ** self._ignore_norm_type
+ step = min(step, state['step'])
+
+ if step > self._ignore_momentum_timestep:
+ total_norm = total_norm ** (1. / self._ignore_norm_type)
+ total_momentum_norm = total_momentum_norm ** (1. / self._ignore_norm_type)
+ ignore_coef = total_momentum_norm / (total_norm + 1e-6)
+ if ignore_coef < 1:
+ for p in group['params']:
+ p.grad.zero_()
+
+ return super().step(closure=closure)
+
+ def get_grad(self) -> float:
+ """
+ Overview:
+ Calculate the total gradient norm of the parameters whose grads are not None in the model.
+ """
+
+ total_norm = 0.
+ params = [t for group in self.param_groups for t in group['params'] if t.requires_grad and t.grad is not None]
+ for p in params:
+ param_norm = p.grad.data.norm(self._clip_norm_type)
+ total_norm += param_norm.item() ** self._clip_norm_type
+ return total_norm
+
+
+class PCGrad():
+ """
+ Overview:
+ PCGrad optimizer for multi-task learning, which projects away conflicting gradient components.
+ See the paper "Gradient Surgery for Multi-Task Learning": https://arxiv.org/pdf/2001.06782.pdf
+ Interfaces:
+ ``__init__``, ``zero_grad``, ``step``, ``pc_backward``
+ Properties:
+ - optimizer (:obj:`torch.optim`): the optimizer to be used
+ """
+
+ def __init__(self, optimizer, reduction='mean'):
+ """
+ Overview:
+ Initialization of PCGrad optimizer
+ Arguments:
+ - optimizer (:obj:`torch.optim`): the optimizer to be used
+ - reduction (:obj:`str`): the reduction method, support ['mean', 'sum']
+ """
+
+ self._optim, self._reduction = optimizer, reduction
+
+ @property
+ def optimizer(self):
+ """
+ Overview:
+ get the optimizer
+ """
+
+ return self._optim
+
+ def zero_grad(self):
+ """
+ Overview:
+ clear the gradient of the parameters
+ """
+
+ return self._optim.zero_grad(set_to_none=True)
+
+ def step(self):
+ """
+ Overview:
+ update the parameters with the gradient
+ """
+
+ return self._optim.step()
+
+ def pc_backward(self, objectives):
+ """
+ Overview:
+ compute the PCGrad-adjusted gradients for the given objectives and write them back to the parameters
+ Arguments:
+ - objectives: a list of objectives
+ """
+
+ grads, shapes, has_grads = self._pack_grad(objectives)
+ pc_grad = self._project_conflicting(grads, has_grads)
+ pc_grad = self._unflatten_grad(pc_grad, shapes[0])
+ self._set_grad(pc_grad)
+ return
+
+ def _project_conflicting(self, grads, has_grads, shapes=None):
+ """
+ Overview:
+ project the conflicting gradient to the orthogonal space
+ Arguments:
+ - grads (:obj:`list`): a list of the gradient of the parameters
+ - has_grads (:obj:`list`): a list of masks representing whether each parameter has a gradient
+ - shapes (:obj:`list`): a list of the shape of the parameters
+ """
+
+ shared = torch.stack(has_grads).prod(0).bool()
+ pc_grad, num_task = copy.deepcopy(grads), len(grads)
+ for g_i in pc_grad:
+ random.shuffle(grads)
+ for g_j in grads:
+ g_i_g_j = torch.dot(g_i, g_j)
+ if g_i_g_j < 0:
+ g_i -= (g_i_g_j) * g_j / (g_j.norm() ** 2)
+ merged_grad = torch.zeros_like(grads[0]).to(grads[0].device)
+ if self._reduction == 'mean':
+ merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).mean(dim=0)
+ elif self._reduction == 'sum':
+ merged_grad[shared] = torch.stack([g[shared] for g in pc_grad]).sum(dim=0)
+ else:
+ raise KeyError("invalid reduction method")
+
+ merged_grad[~shared] = torch.stack([g[~shared] for g in pc_grad]).sum(dim=0)
+ return merged_grad
+
+ def _set_grad(self, grads):
+ """
+ Overview:
+ set the modified gradients to the network
+ Arguments:
+ - grads (:obj:`list`): a list of the gradient of the parameters
+ """
+
+ idx = 0
+ for group in self._optim.param_groups:
+ for p in group['params']:
+ # if p.grad is None: continue
+ p.grad = grads[idx]
+ idx += 1
+ return
+
+ def _pack_grad(self, objectives):
+ """
+ Overview:
+ pack the gradient of the parameters of the network for each objective
+ Arguments:
+ - objectives: a list of objectives
+ Returns:
+ - grad: a list of the gradient of the parameters
+ - shape: a list of the shape of the parameters
+ - has_grad: a list of masks representing whether each parameter has a gradient
+ """
+
+ grads, shapes, has_grads = [], [], []
+ for obj in objectives:
+ self._optim.zero_grad(set_to_none=True)
+ obj.backward(retain_graph=True)
+ grad, shape, has_grad = self._retrieve_grad()
+ grads.append(self._flatten_grad(grad, shape))
+ has_grads.append(self._flatten_grad(has_grad, shape))
+ shapes.append(shape)
+ return grads, shapes, has_grads
+
+ def _unflatten_grad(self, grads, shapes):
+ """
+ Overview:
+ unflatten the gradient of the parameters of the network
+ Arguments:
+ - grads (:obj:`list`): a list of the gradient of the parameters
+ - shapes (:obj:`list`): a list of the shape of the parameters
+ """
+
+ unflatten_grad, idx = [], 0
+ for shape in shapes:
+ length = np.prod(shape)
+ unflatten_grad.append(grads[idx:idx + length].view(shape).clone())
+ idx += length
+ return unflatten_grad
+
+ def _flatten_grad(self, grads, shapes):
+ """
+ Overview:
+ flatten the gradient of the parameters of the network
+ Arguments:
+ - grads (:obj:`list`): a list of the gradient of the parameters
+ - shapes (:obj:`list`): a list of the shape of the parameters
+ """
+
+ flatten_grad = torch.cat([g.flatten() for g in grads])
+ return flatten_grad
+
+ def _retrieve_grad(self):
+ """
+ Overview:
+ get the gradient of the parameters of the network with specific objective
+ Returns:
+ - grad: a list of the gradient of the parameters
+ - shape: a list of the shape of the parameters
+ - has_grad: a list of masks representing whether each parameter has a gradient
+ """
+
+ grad, shape, has_grad = [], [], []
+ for group in self._optim.param_groups:
+ for p in group['params']:
+ # if p.grad is None: continue
+ # tackle the multi-head scenario
+ if p.grad is None:
+ shape.append(p.shape)
+ grad.append(torch.zeros_like(p).to(p.device))
+ has_grad.append(torch.zeros_like(p).to(p.device))
+ continue
+ shape.append(p.grad.shape)
+ grad.append(p.grad.clone())
+ has_grad.append(torch.ones_like(p).to(p.device))
+ return grad, shape, has_grad
+
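# Editor's note: a minimal multi-task sketch for the PCGrad wrapper above, for illustration only;
# the tiny network and the two losses are placeholders. The point is the call order: wrap a base
# optimizer, call pc_backward with one loss per task, then step.
import torch
from ding.torch_utils.optimizer_helper import PCGrad

net = torch.nn.Linear(10, 2)
out = net(torch.randn(4, 10))
loss_a = out[:, 0].pow(2).mean()        # placeholder task-A loss
loss_b = (1 - out[:, 1]).pow(2).mean()  # placeholder task-B loss

opt = PCGrad(torch.optim.Adam(net.parameters(), lr=1e-3))
opt.pc_backward([loss_a, loss_b])       # projects conflicting gradient components before writing grads back
opt.step()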
+
+def configure_weight_decay(model: nn.Module, weight_decay: float) -> List:
+ """
+ Overview:
+ Separating out all parameters of the model into two buckets: those that will experience
+ weight decay for regularization and those that won't (biases, and layer-norm or embedding weights).
+ Arguments:
+ - model (:obj:`nn.Module`): the given PyTorch model.
+ - weight_decay (:obj:`float`): weight decay value for optimizer.
+ Returns:
+ - optim groups (:obj:`List`): the parameter groups to be set in the latter optimizer.
+ """
+ decay = set()
+ no_decay = set()
+ whitelist_weight_modules = (torch.nn.Linear, )
+ blacklist_weight_modules = (torch.nn.LayerNorm, torch.nn.Embedding)
+ for mn, m in model.named_modules():
+ for pn, p in m.named_parameters():
+ fpn = '%s.%s' % (mn, pn) if mn else pn # full param name
+ # Because named_modules and named_parameters are recursive
+ # we will see the same tensors p many times. But doing it this way
+ # allows us to know which parent module any tensor p belongs to.
+ if pn.endswith('bias'):
+ # all biases will not be decayed
+ no_decay.add(fpn)
+ elif pn.endswith('weight') and isinstance(m, whitelist_weight_modules):
+ # weights of whitelist modules will be weight decayed
+ decay.add(fpn)
+ elif pn.endswith('weight') and isinstance(m, blacklist_weight_modules):
+ # weights of blacklist modules will NOT be weight decayed
+ no_decay.add(fpn)
+ else:
+ decay.add(fpn)
+
+ decay = decay - no_decay
+ # validate that we considered every parameter
+ param_dict = {pn: p for pn, p in model.named_parameters()}
+ union_params = decay | no_decay
+ assert len(
+ param_dict.keys() - union_params) == 0, "parameters %s were not separated into either decay/no_decay set!" \
+ % (str(param_dict.keys() - union_params),)
+
+ optim_groups = [
+ {
+ "params": [param_dict[pn] for pn in sorted(list(decay))],
+ "weight_decay": weight_decay
+ },
+ {
+ "params": [param_dict[pn] for pn in sorted(list(no_decay))],
+ "weight_decay": 0.0
+ },
+ ]
+
+ return optim_groups
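# Editor's note: a sketch combining two utilities from this file, for illustration only:
# configure_weight_decay builds decay / no-decay parameter groups, and the refactored Adam above
# adds gradient-norm clipping on top of the standard update. The values here are illustrative.
import torch
from ding.torch_utils.optimizer_helper import Adam, configure_weight_decay

model = torch.nn.Sequential(torch.nn.Linear(16, 32), torch.nn.LayerNorm(32), torch.nn.Linear(32, 4))
groups = configure_weight_decay(model, weight_decay=1e-2)   # biases and LayerNorm weights get zero decay
optimizer = Adam(groups, lr=1e-3, grad_clip_type='clip_norm', clip_value=0.5)

loss = model(torch.randn(8, 16)).pow(2).mean()
optimizer.zero_grad()
loss.backward()
optimizer.step()  # gradients are rescaled so their total 2-norm does not exceed 0.5, then Adam updates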
diff --git a/DI-engine/ding/torch_utils/parameter.py b/DI-engine/ding/torch_utils/parameter.py
new file mode 100644
index 0000000000000000000000000000000000000000..08da7feb766b5cee908326a6018524f56ec2177f
--- /dev/null
+++ b/DI-engine/ding/torch_utils/parameter.py
@@ -0,0 +1,89 @@
+from typing import Optional
+import torch
+from torch import nn
+from torch.distributions.transforms import TanhTransform
+
+
+class NonegativeParameter(nn.Module):
+ """
+ Overview:
+ This module will output a non-negative parameter during the forward process.
+ Interfaces:
+ ``__init__``, ``forward``, ``set_data``.
+ """
+
+ def __init__(self, data: Optional[torch.Tensor] = None, requires_grad: bool = True, delta: float = 1e-8):
+ """
+ Overview:
+ Initialize the NonegativeParameter object using the given arguments.
+ Arguments:
+ - data (:obj:`Optional[torch.Tensor]`): The initial value of generated parameter. If set to ``None``, the \
+ default value is 0.
+ - requires_grad (:obj:`bool`): Whether this parameter requires grad.
+ - delta (:obj:`float`): The small positive offset added inside the log for numerical stability.
+ """
+ super().__init__()
+ if data is None:
+ data = torch.zeros(1)
+ self.log_data = nn.Parameter(torch.log(data + delta), requires_grad=requires_grad)
+
+ def forward(self) -> torch.Tensor:
+ """
+ Overview:
+ Output the non-negative parameter during the forward process.
+ Returns:
+ parameter (:obj:`torch.Tensor`): The generated parameter.
+ """
+ return torch.exp(self.log_data)
+
+ def set_data(self, data: torch.Tensor) -> None:
+ """
+ Overview:
+ Set the value of the non-negative parameter.
+ Arguments:
+ data (:obj:`torch.Tensor`): The new value of the non-negative parameter.
+ """
+ self.log_data = nn.Parameter(torch.log(data + 1e-8), requires_grad=self.log_data.requires_grad)
+
+
+class TanhParameter(nn.Module):
+ """
+ Overview:
+ This module will output a tanh parameter during the forward process.
+ Interfaces:
+ ``__init__``, ``forward``, ``set_data``.
+ """
+
+ def __init__(self, data: Optional[torch.Tensor] = None, requires_grad: bool = True):
+ """
+ Overview:
+ Initialize the TanhParameter object using the given arguments.
+ Arguments:
+ - data (:obj:`Optional[torch.Tensor]`): The initial value of generated parameter. If set to ``None``, the \
+ default value is 0.
+ - requires_grad (:obj:`bool`): Whether this parameter requires grad.
+ """
+ super().__init__()
+ if data is None:
+ data = torch.zeros(1)
+ self.transform = TanhTransform(cache_size=1)
+
+ self.data_inv = nn.Parameter(self.transform.inv(data), requires_grad=requires_grad)
+
+ def forward(self) -> torch.Tensor:
+ """
+ Overview:
+ Output the tanh parameter during the forward process.
+ Returns:
+ parameter (:obj:`torch.Tensor`): The generated parameter.
+ """
+ return self.transform(self.data_inv)
+
+ def set_data(self, data: torch.Tensor) -> None:
+ """
+ Overview:
+ Set the value of the tanh parameter.
+ Arguments:
+ data (:obj:`torch.Tensor`): The new value of the tanh parameter.
+ """
+ self.data_inv = nn.Parameter(self.transform.inv(data), requires_grad=self.data_inv.requires_grad)
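# Editor's note: a brief sketch of the two constrained parameters defined above, for illustration.
# Both store an unconstrained tensor internally and map it through exp / tanh on forward, so
# gradient descent cannot push the value outside its valid range.
import torch
from ding.torch_utils.parameter import NonegativeParameter, TanhParameter

sigma = NonegativeParameter(torch.tensor([0.5]))
assert sigma().item() > 0                  # always non-negative, initialized near 0.5

rho = TanhParameter(torch.tensor([0.3]))
assert -1 < rho().item() < 1               # always in (-1, 1), initialized near 0.3
rho.set_data(torch.tensor([-0.9]))         # re-initialize to a new value inside the range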
diff --git a/DI-engine/ding/torch_utils/reshape_helper.py b/DI-engine/ding/torch_utils/reshape_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..49a3e3b8d21ff7f7f07aff121695c08d8e3157cd
--- /dev/null
+++ b/DI-engine/ding/torch_utils/reshape_helper.py
@@ -0,0 +1,91 @@
+from typing import Tuple, Union
+
+from torch import Tensor, Size
+
+
+def fold_batch(x: Tensor, nonbatch_ndims: int = 1) -> Tuple[Tensor, Size]:
+ """
+ Overview:
+ :math:`(T, B, X) \rightarrow (T*B, X)`\
+ Fold the first (ndim - nonbatch_ndims) dimensions of a tensor as batch dimension.\
+ This operation is similar to `torch.flatten` but provides an inverse function
+ `unfold_batch` to restore the folded dimensions.
+
+ Arguments:
+ - x (:obj:`torch.Tensor`): the tensor to fold
+ - nonbatch_ndims (:obj:`int`): the number of trailing dimensions that are not folded into the
+ batch dimension.
+
+ Returns:
+ - x (:obj:`torch.Tensor`): the folded tensor
+ - batch_dims: the folded dimensions of the original tensor, which can be used to
+ reverse the operation
+
+ Examples:
+ >>> x = torch.ones(10, 20, 5, 4, 8)
+ >>> x, batch_dim = fold_batch(x, 2)
+ >>> x.shape == (1000, 4, 8)
+ >>> batch_dim == (10, 20, 5)
+
+ """
+ if nonbatch_ndims > 0:
+ batch_dims = x.shape[:-nonbatch_ndims]
+ x = x.view(-1, *(x.shape[-nonbatch_ndims:]))
+ return x, batch_dims
+ else:
+ batch_dims = x.shape
+ x = x.view(-1)
+ return x, batch_dims
+
+
+def unfold_batch(x: Tensor, batch_dims: Union[Size, Tuple]) -> Tensor:
+ """
+ Overview:
+ Unfold the batch dimension of a tensor.
+
+ Arguments:
+ - x (:obj:`torch.Tensor`): the tensor to unfold
+ - batch_dims (:obj:`torch.Size`): the dimensions that are folded
+
+ Returns:
+ - x (:obj:`torch.Tensor`): the original unfolded tensor
+
+ Examples:
+ >>> x = torch.ones(10, 20, 5, 4, 8)
+ >>> x, batch_dim = fold_batch(x, 2)
+ >>> x.shape == (1000, 4, 8)
+ >>> batch_dim == (10, 20, 5)
+ >>> x = unfold_batch(x, batch_dim)
+ >>> x.shape == (10, 20, 5, 4, 8)
+ """
+ return x.view(*batch_dims, *x.shape[1:])
+
+
+def unsqueeze_repeat(x: Tensor, repeat_times: int, unsqueeze_dim: int = 0) -> Tensor:
+ """
+ Overview:
+ Unsqueeze the tensor on `unsqueeze_dim` and then repeat in this dimension for `repeat_times` times.\
+ This is useful for preprocessing the input to a model ensemble.
+
+ Arguments:
+ - x (:obj:`torch.Tensor`): the tensor to squeeze and repeat
+ - repeat_times (:obj:`int`): the number of times the tensor is repeated
+ - unsqueeze_dim (:obj:`int`): the unsqueezed dimension
+
+ Returns:
+ - x (:obj:`torch.Tensor`): the unsqueezed and repeated tensor
+
+ Examples:
+ >>> x = torch.ones(64, 6)
+ >>> x = unsqueeze_repeat(x, 4)
+ >>> x.shape == (4, 64, 6)
+
+ >>> x = torch.ones(64, 6)
+ >>> x = unsqueeze_repeat(x, 4, -1)
+ >>> x.shape == (64, 6, 4)
+ """
+ assert -1 <= unsqueeze_dim <= len(x.shape), f'unsqueeze_dim should be from {-1} to {len(x.shape)}'
+ x = x.unsqueeze(unsqueeze_dim)
+ repeats = [1] * len(x.shape)
+ repeats[unsqueeze_dim] *= repeat_times
+ return x.repeat(*repeats)
diff --git a/DI-engine/ding/torch_utils/tests/test_backend_helper.py b/DI-engine/ding/torch_utils/tests/test_backend_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5988846936e8a41588d9b39394d64e7ffb4baf6
--- /dev/null
+++ b/DI-engine/ding/torch_utils/tests/test_backend_helper.py
@@ -0,0 +1,21 @@
+import pytest
+import torch
+
+from ding.torch_utils.backend_helper import enable_tf32
+
+
+@pytest.mark.cudatest
+class TestBackendHelper:
+
+ def test_tf32(self):
+ r"""
+ Overview:
+ Test the tf32.
+ """
+ enable_tf32()
+ net = torch.nn.Linear(3, 4)
+ x = torch.randn(1, 3)
+ y = torch.sum(net(x))
+ net.zero_grad()
+ y.backward()
+ assert net.weight.grad is not None
diff --git a/DI-engine/ding/torch_utils/tests/test_ckpt_helper.py b/DI-engine/ding/torch_utils/tests/test_ckpt_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..f397a4a66128e1f89f9267f268fd2f90df383d25
--- /dev/null
+++ b/DI-engine/ding/torch_utils/tests/test_ckpt_helper.py
@@ -0,0 +1,184 @@
+import os
+import time
+
+import pytest
+import torch
+import torch.nn as nn
+import uuid
+
+from ding.torch_utils.checkpoint_helper import auto_checkpoint, build_checkpoint_helper, CountVar
+from ding.utils import read_file, save_file
+
+
+class DstModel(nn.Module):
+
+ def __init__(self):
+ super(DstModel, self).__init__()
+ self.fc1 = nn.Linear(3, 3)
+ self.fc2 = nn.Linear(3, 8)
+ self.fc_dst = nn.Linear(3, 6)
+
+
+class SrcModel(nn.Module):
+
+ def __init__(self):
+ super(SrcModel, self).__init__()
+ self.fc1 = nn.Linear(3, 3)
+ self.fc2 = nn.Linear(3, 8)
+ self.fc_src = nn.Linear(3, 7)
+
+
+class HasStateDict(object):
+
+ def __init__(self, name):
+ self._name = name
+ self._state_dict = name + str(uuid.uuid4())
+
+ def state_dict(self):
+ old = self._state_dict
+ self._state_dict = self._name + str(uuid.uuid4())
+ return old
+
+ def load_state_dict(self, state_dict):
+ self._state_dict = state_dict
+
+
+@pytest.mark.unittest
+class TestCkptHelper:
+
+ def test_load_model(self):
+ path = 'model.pt'
+ os.popen('rm -rf ' + path)
+ time.sleep(1)
+
+ dst_model = DstModel()
+ src_model = SrcModel()
+ ckpt_state_dict = {'model': src_model.state_dict()}
+ torch.save(ckpt_state_dict, path)
+
+ ckpt_helper = build_checkpoint_helper({})
+ with pytest.raises(RuntimeError):
+ ckpt_helper.load(path, dst_model, strict=True)
+
+ ckpt_helper.load(path, dst_model, strict=False)
+ assert torch.abs(dst_model.fc1.weight - src_model.fc1.weight).max() < 1e-6
+ assert torch.abs(dst_model.fc1.bias - src_model.fc1.bias).max() < 1e-6
+
+ dst_model = DstModel()
+ src_model = SrcModel()
+ assert torch.abs(dst_model.fc1.weight - src_model.fc1.weight).max() > 1e-6
+ src_optimizer = HasStateDict('src_optimizer')
+ dst_optimizer = HasStateDict('dst_optimizer')
+ src_last_epoch = CountVar(11)
+ dst_last_epoch = CountVar(5)
+ src_last_iter = CountVar(110)
+ dst_last_iter = CountVar(50)
+ src_dataset = HasStateDict('src_dataset')
+ dst_dataset = HasStateDict('dst_dataset')
+ src_collector_info = HasStateDict('src_collect_info')
+ dst_collector_info = HasStateDict('dst_collect_info')
+ ckpt_helper.save(
+ path,
+ src_model,
+ optimizer=src_optimizer,
+ dataset=src_dataset,
+ collector_info=src_collector_info,
+ last_iter=src_last_iter,
+ last_epoch=src_last_epoch,
+ prefix_op='remove',
+ prefix="f"
+ )
+ ckpt_helper.load(
+ path,
+ dst_model,
+ dataset=dst_dataset,
+ optimizer=dst_optimizer,
+ last_iter=dst_last_iter,
+ last_epoch=dst_last_epoch,
+ collector_info=dst_collector_info,
+ strict=False,
+ state_dict_mask=['fc1'],
+ prefix_op='add',
+ prefix="f"
+ )
+ assert dst_dataset.state_dict().startswith('src')
+ assert dst_optimizer.state_dict().startswith('src')
+ assert dst_collector_info.state_dict().startswith('src')
+ assert dst_last_iter.val == 110
+ for k, v in dst_model.named_parameters():
+ assert k.startswith('fc')
+ print('==dst', dst_model.fc2.weight)
+ print('==src', src_model.fc2.weight)
+ assert torch.abs(dst_model.fc2.weight - src_model.fc2.weight).max() < 1e-6
+ assert torch.abs(dst_model.fc1.weight - src_model.fc1.weight).max() > 1e-6
+
+ checkpoint = read_file(path)
+ checkpoint.pop('dataset')
+ checkpoint.pop('optimizer')
+ checkpoint.pop('last_iter')
+ save_file(path, checkpoint)
+ ckpt_helper.load(
+ path,
+ dst_model,
+ dataset=dst_dataset,
+ optimizer=dst_optimizer,
+ last_iter=dst_last_iter,
+ last_epoch=dst_last_epoch,
+ collector_info=dst_collector_info,
+ strict=True,
+ state_dict_mask=['fc1'],
+ prefix_op='add',
+ prefix="f"
+ )
+ with pytest.raises(NotImplementedError):
+ ckpt_helper.load(
+ path,
+ dst_model,
+ strict=False,
+ lr_schduler='lr_scheduler',
+ last_iter=dst_last_iter,
+ )
+
+ with pytest.raises(KeyError):
+ ckpt_helper.save(path, src_model, prefix_op='key_error', prefix="f")
+ ckpt_helper.load(path, dst_model, strict=False, prefix_op='key_error', prefix="f")
+
+ os.popen('rm -rf ' + path + '*')
+
+
+@pytest.mark.unittest
+def test_count_var():
+ var = CountVar(0)
+ var.add(5)
+ assert var.val == 5
+ var.update(3)
+ assert var.val == 3
+
+
+@pytest.mark.unittest
+def test_auto_checkpoint():
+
+ class AutoCkptCls:
+
+ def __init__(self):
+ pass
+
+ @auto_checkpoint
+ def start(self):
+ for i in range(10):
+ if i < 5:
+ time.sleep(0.2)
+ else:
+ raise Exception("There is an exception")
+ break
+
+ def save_checkpoint(self, ckpt_path):
+ print('Checkpoint is saved successfully in {}!'.format(ckpt_path))
+
+ auto_ckpt = AutoCkptCls()
+ auto_ckpt.start()
+
+
+if __name__ == '__main__':
+ test = TestCkptHelper()
+ test.test_load_model()
diff --git a/DI-engine/ding/torch_utils/tests/test_data_helper.py b/DI-engine/ding/torch_utils/tests/test_data_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b61d96dd9f39b66889bc04de36b9e4a0fbcce4e
--- /dev/null
+++ b/DI-engine/ding/torch_utils/tests/test_data_helper.py
@@ -0,0 +1,250 @@
+import pytest
+from collections import namedtuple
+import numpy as np
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader
+import treetensor.torch as ttorch
+
+from ding.torch_utils import CudaFetcher, to_device, to_dtype, to_tensor, to_ndarray, to_list, \
+ tensor_to_list, same_shape, build_log_buffer, get_tensor_data, to_item
+from ding.utils import EasyTimer
+
+
+@pytest.fixture(scope='function')
+def setup_data_dict():
+ return {
+ 'tensor': torch.randn(4),
+ 'list': [True, False, False],
+ 'tuple': (4, 5, 6),
+ 'bool': True,
+ 'int': 10,
+ 'float': 10.,
+ 'array': np.random.randn(4),
+ 'str': "asdf",
+ 'none': None,
+ }
+
+
+@pytest.mark.unittest
+class TestDataFunction:
+
+ def test_to_dtype(self):
+ t = torch.randint(0, 10, (3, 5))
+ tfloat = to_dtype(t, torch.float)
+ assert tfloat.dtype == torch.float
+ tlist = [t]
+ tlfloat = to_dtype(tlist, torch.float)
+ assert tlfloat[0].dtype == torch.float
+ tdict = {'t': t}
+ tdictf = to_dtype(tdict, torch.float)
+ assert tdictf['t'].dtype == torch.float
+ with pytest.raises(TypeError):
+ to_dtype(EasyTimer(), torch.float)
+
+ def test_to_tensor(self, setup_data_dict):
+ i = 10
+ t = to_tensor(i)
+ assert t.item() == i
+ d = {'i': i}
+ dt = to_tensor(d, torch.int)
+ assert dt['i'].item() == i
+ with pytest.raises(TypeError):
+ _ = to_tensor({1, 2}, torch.int)
+
+ data_type = namedtuple('data_type', ['x', 'y'])
+ inputs = data_type(np.random.random(3), 4)
+ outputs = to_tensor(inputs, torch.float32)
+ assert type(outputs) == data_type
+ assert isinstance(outputs.x, torch.Tensor)
+ assert isinstance(outputs.y, torch.Tensor)
+ assert outputs.x.dtype == torch.float32
+ assert outputs.y.dtype == torch.float32
+
+ transformed_tensor = to_tensor(setup_data_dict)
+ with pytest.raises(TypeError):
+ to_tensor(EasyTimer(), torch.float)
+
+ def test_to_ndarray(self, setup_data_dict):
+ t = torch.randn(3, 5)
+ tarray1 = to_ndarray(t)
+ assert tarray1.shape == (3, 5)
+ assert isinstance(tarray1, np.ndarray)
+
+ t = [torch.randn(5, ) for i in range(3)]
+ tarray1 = to_ndarray(t, np.float32)
+ assert isinstance(tarray1, list)
+ assert tarray1[0].shape == (5, )
+ assert isinstance(tarray1[0], np.ndarray)
+
+ transformed_array = to_ndarray(setup_data_dict)
+ with pytest.raises(TypeError):
+ to_ndarray(EasyTimer(), np.float32)
+
+ def test_to_list(self, setup_data_dict):
+ # tensor_to_list
+ t = torch.randn(3, 5)
+ tlist1 = tensor_to_list(t)
+ assert len(tlist1) == 3
+ assert len(tlist1[0]) == 5
+
+ t = torch.randn(3, )
+ tlist1 = tensor_to_list(t)
+ assert len(tlist1) == 3
+
+ t = [torch.randn(5, ) for i in range(3)]
+ tlist1 = tensor_to_list(t)
+ assert len(tlist1) == 3
+ assert len(tlist1[0]) == 5
+
+ td = {'t': t}
+ tdlist1 = tensor_to_list(td)
+ assert len(tdlist1['t']) == 3
+ assert len(tdlist1['t'][0]) == 5
+
+ tback = to_tensor(tlist1, torch.float)
+ for i in range(3):
+ assert (tback[i] == t[i]).all()
+
+ with pytest.raises(TypeError):
+ tensor_to_list(EasyTimer())
+
+ # to_list
+ transformed_list = to_list(setup_data_dict)
+ with pytest.raises(TypeError):
+ to_ndarray(EasyTimer())
+
+ def test_to_item(self):
+ data = {
+ 'tensor': torch.randn(1),
+ 'list': [True, False, torch.randn(1)],
+ 'tuple': (4, 5, 6),
+ 'bool': True,
+ 'int': 10,
+ 'float': 10.,
+ 'array': np.random.randn(1),
+ 'str': "asdf",
+ 'none': None,
+ }
+ assert not np.isscalar(data['tensor'])
+ assert not np.isscalar(data['array'])
+ assert not np.isscalar(data['list'][-1])
+ new_data = to_item(data)
+ assert np.isscalar(new_data['tensor'])
+ assert np.isscalar(new_data['array'])
+ assert np.isscalar(new_data['list'][-1])
+
+ data = ttorch.randn({'a': 1})
+ new_data = to_item(data)
+ assert np.isscalar(new_data.a)
+
+ with pytest.raises((ValueError, RuntimeError)):
+ to_item({'a': torch.randn(4), 'b': torch.rand(1)}, ignore_error=False)
+ output = to_item({'a': torch.randn(4), 'b': torch.rand(1)}, ignore_error=True)
+ assert 'a' not in output
+ assert 'b' in output
+
+ def test_same_shape(self):
+ tlist = [torch.randn(3, 5) for i in range(5)]
+ assert same_shape(tlist)
+ tlist = [torch.randn(3, 5), torch.randn(4, 5)]
+ assert not same_shape(tlist)
+
+ def test_get_tensor_data(self):
+ a = {
+ 'tensor': torch.tensor([1, 2, 3.], requires_grad=True),
+ 'list': [torch.tensor([1, 2, 3.], requires_grad=True) for _ in range(2)],
+ 'none': None
+ }
+ tensor_a = get_tensor_data(a)
+ assert not tensor_a['tensor'].requires_grad
+ for t in tensor_a['list']:
+ assert not t.requires_grad
+ with pytest.raises(TypeError):
+ get_tensor_data(EasyTimer())
+
+
+@pytest.mark.unittest
+def test_log_dict():
+ log_buffer = build_log_buffer()
+ log_buffer['not_tensor'] = torch.randn(3)
+ assert isinstance(log_buffer['not_tensor'], list)
+ assert len(log_buffer['not_tensor']) == 3
+ log_buffer.update({'not_tensor': 4, 'a': 5})
+ assert log_buffer['not_tensor'] == 4
+
+
+@pytest.mark.cudatest
+class TestCudaFetcher:
+
+ def get_dataloader(self):
+
+ class Dataset(object):
+
+ def __init__(self):
+ self.data = torch.randn(2560, 2560)
+
+ def __len__(self):
+ return 100
+
+ def __getitem__(self, idx):
+ return self.data
+
+ return DataLoader(Dataset(), batch_size=3)
+
+ def get_model(self):
+
+ class Model(nn.Module):
+
+ def __init__(self):
+ super(Model, self).__init__()
+ self.main = [nn.Linear(2560, 2560) for _ in range(100)]
+ self.main = nn.Sequential(*self.main)
+
+ def forward(self, x):
+ x = self.main(x)
+ return x
+
+ return Model()
+
+ def test_naive(self):
+ model = self.get_model()
+ model.cuda()
+ timer = EasyTimer()
+ dataloader = iter(self.get_dataloader())
+ dataloader = CudaFetcher(dataloader, device='cuda', sleep=0.1)
+ dataloader.run()
+
+ count = 0
+ while True:
+ with timer:
+ data = next(dataloader)
+ model(data)
+ print('count {}, run_time: {}'.format(count, timer.value))
+ count += 1
+ if count == 10:
+ break
+
+ dataloader.close()
+
+
+@pytest.mark.cudatest
+def test_to_device_cuda(setup_data_dict):
+ setup_data_dict['module'] = nn.Linear(3, 5)
+ device = 'cuda'
+ cuda_d = to_device(setup_data_dict, device, ignore_keys=['module'])
+ assert cuda_d['module'].weight.device == torch.device('cpu')
+ other = EasyTimer()
+ with pytest.raises(TypeError):
+ to_device(other)
+
+
+@pytest.mark.unittest
+def test_to_device_cpu(setup_data_dict):
+ setup_data_dict['module'] = nn.Linear(3, 5)
+ device = 'cpu'
+ cuda_d = to_device(setup_data_dict, device, ignore_keys=['module'])
+ assert cuda_d['module'].weight.device == torch.device('cpu')
+ other = EasyTimer()
+ with pytest.raises(TypeError):
+ to_device(other)
diff --git a/DI-engine/ding/torch_utils/tests/test_distribution.py b/DI-engine/ding/torch_utils/tests/test_distribution.py
new file mode 100644
index 0000000000000000000000000000000000000000..a080a3100e646671db214bf90303b9cb2547efba
--- /dev/null
+++ b/DI-engine/ding/torch_utils/tests/test_distribution.py
@@ -0,0 +1,66 @@
+import pytest
+import torch
+
+from ding.torch_utils.distribution import Pd, CategoricalPd, CategoricalPdPytorch
+
+
+@pytest.mark.unittest
+class TestProbDistribution:
+
+ def test_Pd(self):
+ pd = Pd()
+ with pytest.raises(NotImplementedError):
+ pd.neglogp(torch.randn(5, ))
+ with pytest.raises(NotImplementedError):
+ pd.noise_mode()
+ with pytest.raises(NotImplementedError):
+ pd.mode()
+ with pytest.raises(NotImplementedError):
+ pd.sample()
+
+ def test_CatePD(self):
+ pd = CategoricalPd()
+ logit1 = torch.randn(3, 5, requires_grad=True)
+ logit2 = torch.randint(5, (3, ), dtype=torch.int64)
+
+ pd.update_logits(logit1)
+ entropy = pd.neglogp(logit2)
+ assert entropy.requires_grad
+ assert entropy.shape == torch.Size([])
+
+ entropy = pd.entropy()
+ assert entropy.requires_grad
+ assert entropy.shape == torch.Size([])
+ entropy = pd.entropy(reduction=None)
+ assert entropy.requires_grad
+ assert entropy.shape == torch.Size([3])
+
+ ret = pd.sample()
+ assert ret.shape == torch.Size([3])
+ ret = pd.sample(viz=True)
+ assert ret[0].shape == torch.Size([3])
+
+ ret = pd.mode()
+ assert ret.shape == torch.Size([3])
+ ret = pd.mode(viz=True)
+ assert ret[0].shape == torch.Size([3])
+
+ ret = pd.noise_mode()
+ assert ret.shape == torch.Size([3])
+ ret = pd.noise_mode(viz=True)
+ assert ret[0].shape == torch.Size([3])
+
+ pd = CategoricalPdPytorch()
+ pd.update_logits(logit1)
+
+ ret = pd.sample()
+ assert ret.shape == torch.Size([3])
+ ret = pd.mode()
+ assert ret.shape == torch.Size([3])
+
+ entropy = pd.entropy(reduction='mean')
+ assert entropy.requires_grad
+ assert entropy.shape == torch.Size([])
+ entropy = pd.entropy(reduction=None)
+ assert entropy.requires_grad
+ assert entropy.shape == torch.Size([3])
diff --git a/DI-engine/ding/torch_utils/tests/test_feature_merge.py b/DI-engine/ding/torch_utils/tests/test_feature_merge.py
new file mode 100644
index 0000000000000000000000000000000000000000..41cfc57f5c901ac8efff4c96aa6c6f696ba2fae2
--- /dev/null
+++ b/DI-engine/ding/torch_utils/tests/test_feature_merge.py
@@ -0,0 +1,132 @@
+import pytest
+import torch
+from ding.torch_utils.network.merge import TorchBilinearCustomized, TorchBilinear, BilinearGeneral, FiLM
+
+
+@pytest.mark.unittest
+def test_torch_bilinear_customized():
+ batch_size = 10
+ in1_features = 20
+ in2_features = 30
+ out_features = 40
+ bilinear_customized = TorchBilinearCustomized(in1_features, in2_features, out_features)
+ x = torch.randn(batch_size, in1_features)
+ z = torch.randn(batch_size, in2_features)
+ out = bilinear_customized(x, z)
+ assert out.shape == (batch_size, out_features), "Output shape does not match expected shape."
+
+
+@pytest.mark.unittest
+def test_torch_bilinear():
+ batch_size = 10
+ in1_features = 20
+ in2_features = 30
+ out_features = 40
+ torch_bilinear = TorchBilinear(in1_features, in2_features, out_features)
+ x = torch.randn(batch_size, in1_features)
+ z = torch.randn(batch_size, in2_features)
+ out = torch_bilinear(x, z)
+ assert out.shape == (batch_size, out_features), "Output shape does not match expected shape."
+
+
+@pytest.mark.unittest
+def test_bilinear_consistency():
+ batch_size = 10
+ in1_features = 20
+ in2_features = 30
+ out_features = 40
+
+ # Initialize weights and biases with set values
+ weight = torch.randn(out_features, in1_features, in2_features)
+ bias = torch.randn(out_features)
+
+ # Create and initialize TorchBilinearCustomized and TorchBilinear models
+ bilinear_customized = TorchBilinearCustomized(in1_features, in2_features, out_features)
+ bilinear_customized.weight.data = weight.clone()
+ bilinear_customized.bias.data = bias.clone()
+
+ torch_bilinear = TorchBilinear(in1_features, in2_features, out_features)
+ torch_bilinear.weight.data = weight.clone()
+ torch_bilinear.bias.data = bias.clone()
+
+ # Provide same input to both models
+ x = torch.randn(batch_size, in1_features)
+ z = torch.randn(batch_size, in2_features)
+
+ # Compute outputs
+ out_bilinear_customized = bilinear_customized(x, z)
+ out_torch_bilinear = torch_bilinear(x, z)
+
+ # Compute the mean squared error between outputs
+ mse = torch.mean((out_bilinear_customized - out_torch_bilinear) ** 2)
+
+ print(f"Mean Squared Error between outputs: {mse.item()}")
+
+ # Check if outputs are the same
+    # The exact-equality check below is intentionally left disabled; the two implementations
+    # are not guaranteed to produce numerically identical outputs, so only the MSE is reported.
+    # assert torch.allclose(out_bilinear_customized, out_torch_bilinear)
+
+@pytest.mark.unittest
+def test_bilinear_general():
+ """
+ Overview:
+ Test for the `BilinearGeneral` class.
+ """
+ # Define the input dimensions and batch size
+ in1_features = 20
+ in2_features = 30
+ out_features = 40
+ batch_size = 10
+
+ # Create a BilinearGeneral instance
+ bilinear_general = BilinearGeneral(in1_features, in2_features, out_features)
+
+ # Create random inputs
+ input1 = torch.randn(batch_size, in1_features)
+ input2 = torch.randn(batch_size, in2_features)
+
+ # Perform forward pass
+ output = bilinear_general(input1, input2)
+
+ # Check output shape
+ assert output.shape == (batch_size, out_features), "Output shape does not match expected shape."
+
+ # Check parameter shapes
+ assert bilinear_general.W.shape == (
+ out_features, in1_features, in2_features
+ ), "Weight W shape does not match expected shape."
+ assert bilinear_general.U.shape == (out_features, in2_features), "Weight U shape does not match expected shape."
+ assert bilinear_general.V.shape == (out_features, in1_features), "Weight V shape does not match expected shape."
+ assert bilinear_general.b.shape == (out_features, ), "Bias shape does not match expected shape."
+
+ # Check parameter types
+ assert isinstance(bilinear_general.W, torch.nn.Parameter), "Weight W is not an instance of torch.nn.Parameter."
+ assert isinstance(bilinear_general.U, torch.nn.Parameter), "Weight U is not an instance of torch.nn.Parameter."
+ assert isinstance(bilinear_general.V, torch.nn.Parameter), "Weight V is not an instance of torch.nn.Parameter."
+ assert isinstance(bilinear_general.b, torch.nn.Parameter), "Bias is not an instance of torch.nn.Parameter."
+
+
+@pytest.mark.unittest
+def test_film_forward():
+ # Set the feature and context dimensions
+ feature_dim = 128
+ context_dim = 256
+
+ # Initialize the FiLM layer
+ film_layer = FiLM(feature_dim, context_dim)
+
+ # Create random feature and context vectors
+ feature = torch.randn((32, feature_dim)) # batch size is 32
+ context = torch.randn((32, context_dim)) # batch size is 32
+
+ # Forward propagation
+ conditioned_feature = film_layer(feature, context)
+
+ # Check the output shape
+ assert conditioned_feature.shape == feature.shape, \
+ f'Expected output shape {feature.shape}, but got {conditioned_feature.shape}'
+
+ # Check that the output is different from the input
+ assert not torch.all(torch.eq(feature, conditioned_feature)), \
+ 'The output feature is the same as the input feature'
diff --git a/DI-engine/ding/torch_utils/tests/test_lr_scheduler.py b/DI-engine/ding/torch_utils/tests/test_lr_scheduler.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ba52d9e1f376721afcb3411d3f77c5bde8a7a03
--- /dev/null
+++ b/DI-engine/ding/torch_utils/tests/test_lr_scheduler.py
@@ -0,0 +1,20 @@
+import pytest
+import torch
+from torch.optim import Adam
+
+from ding.torch_utils.lr_scheduler import cos_lr_scheduler
+
+
+@pytest.mark.unittest
+class TestLRSchedulerHelper:
+
+ def test_cos_lr_scheduler(self):
+ r"""
+ Overview:
+ Test the cos lr scheduler.
+ """
+ net = torch.nn.Linear(3, 4)
+ opt = Adam(net.parameters(), lr=1e-2)
+ scheduler = cos_lr_scheduler(opt, learning_rate=1e-2, min_lr=6e-5)
+ scheduler.step(101)
+ assert opt.param_groups[0]['lr'] == 6e-5
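
For context, a minimal sketch of how cos_lr_scheduler might be driven from a training loop. It assumes only the call signature exercised by the test above (an optimizer plus learning_rate and min_lr) and a per-epoch scheduler.step(epoch); any other arguments or internals of the scheduler are not assumed here.

import torch
from torch.optim import Adam

from ding.torch_utils.lr_scheduler import cos_lr_scheduler

net = torch.nn.Linear(3, 4)
opt = Adam(net.parameters(), lr=1e-2)
# Cosine-anneal the learning rate from 1e-2 down towards 6e-5, as in the test above.
scheduler = cos_lr_scheduler(opt, learning_rate=1e-2, min_lr=6e-5)

for epoch in range(100):
    # ... run one training epoch on net with opt here (omitted) ...
    scheduler.step(epoch)  # the test above checks lr == min_lr once the schedule passes step 100
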
diff --git a/DI-engine/ding/torch_utils/tests/test_math_helper.py b/DI-engine/ding/torch_utils/tests/test_math_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e0c3236076f81c8e94dc63c4ed1e090e693a2bf
--- /dev/null
+++ b/DI-engine/ding/torch_utils/tests/test_math_helper.py
@@ -0,0 +1,46 @@
+import numpy as np
+import pytest
+import torch
+
+from ding.torch_utils.math_helper import cov
+
+
+@pytest.mark.unittest
+class TestMathHelper:
+
+ def test_cov(self):
+ r'''
+ Overview:
+            Test the cov function (covariance) against numpy.cov
+ '''
+ # test 1D
+ # test dtype and rowvar
+ x1 = np.array([1, 2, 3])
+ cov1 = np.cov(x1, rowvar=False)
+ x1_tensor = torch.FloatTensor(x1)
+ cov1_tensor = cov(x1_tensor, rowvar=False).detach().numpy()
+        assert (np.abs(cov1 - cov1_tensor) < 1e-6).all()
+
+ # test 2D
+ x2 = np.array([[0., 2.], [1., 1.], [2., 0.]]).T
+ cov2 = np.cov(x2, rowvar=True)
+ x2_tensor = torch.FloatTensor(x2)
+ cov2_tensor = cov(x2_tensor, rowvar=True).detach().numpy()
+        assert (np.abs(cov2 - cov2_tensor) < 1e-6).all()
+
+ # test bias
+ cov3 = np.cov(x2, rowvar=True, bias=True)
+ cov3_tensor = cov(x2_tensor, rowvar=True, bias=True).detach().numpy()
+        assert (np.abs(cov3 - cov3_tensor) < 1e-6).all()
+
+ # test ddof
+ aweights = np.array([1., 2., 3.])
+ cov4 = np.cov(x2, rowvar=True, ddof=0, aweights=aweights)
+ cov4_tensor = cov(x2_tensor, rowvar=True, ddof=0, aweights=aweights).detach().numpy()
+        assert (np.abs(cov4 - cov4_tensor) < 1e-6).all()
+
+ # test aweights
+ cov5 = np.cov(x2, rowvar=True, aweights=aweights)
+ aweights_tensor = torch.FloatTensor(aweights)
+ cov5_tensor = cov(x2_tensor, rowvar=True, aweights=aweights_tensor).detach().numpy()
+        assert (np.abs(cov5 - cov5_tensor) < 1e-6).all()
diff --git a/DI-engine/ding/torch_utils/tests/test_metric.py b/DI-engine/ding/torch_utils/tests/test_metric.py
new file mode 100644
index 0000000000000000000000000000000000000000..f1a80ddec684589605c3427bfbd6ab5442792acd
--- /dev/null
+++ b/DI-engine/ding/torch_utils/tests/test_metric.py
@@ -0,0 +1,54 @@
+import random
+
+import pytest
+import torch
+
+from ding.torch_utils.metric import levenshtein_distance, hamming_distance
+
+
+@pytest.mark.unittest
+class TestMetric():
+
+ def test_levenshtein_distance(self):
+ r'''
+ Overview:
+ Test the Levenshtein Distance
+ '''
+ pred = torch.LongTensor([1, 4, 6, 4, 1])
+ target1 = torch.LongTensor([1, 6, 4, 4, 1])
+ distance = levenshtein_distance(pred, target1)
+ assert (distance.item() == 2)
+
+ target2 = torch.LongTensor([])
+ distance = levenshtein_distance(pred, target2)
+ assert (distance.item() == 5)
+
+ target3 = torch.LongTensor([6, 4, 1])
+ distance = levenshtein_distance(pred, target3)
+ assert (distance.item() == 2)
+ target3 = torch.LongTensor([6, 4, 1])
+ distance = levenshtein_distance(pred, target3, pred, target3, extra_fn=lambda x, y: x + y)
+ assert distance.item() == 13
+ target4 = torch.LongTensor([1, 4, 1])
+ distance = levenshtein_distance(pred, target4, pred, target4, extra_fn=lambda x, y: x + y)
+ assert distance.item() == 14
+
+ def test_hamming_distance(self):
+ r'''
+ Overview:
+ Test the Hamming Distance
+ '''
+ base = torch.zeros(8).long()
+ index = [i for i in range(8)]
+ for i in range(2):
+ pred_idx = random.sample(index, 4)
+ target_idx = random.sample(index, 4)
+ pred = base.clone()
+ pred[pred_idx] = 1
+ target = base.clone()
+ target[target_idx] = 1
+ pred = pred.unsqueeze(0)
+ target = target.unsqueeze(0)
+ distance = hamming_distance(pred, target)
+ diff = len(set(pred_idx).union(set(target_idx)) - set(pred_idx).intersection(set(target_idx)))
+ assert (distance.item() == diff)
diff --git a/DI-engine/ding/torch_utils/tests/test_model_helper.py b/DI-engine/ding/torch_utils/tests/test_model_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2dd72e54e65b408d510060cfa307b7fa73f08ad
--- /dev/null
+++ b/DI-engine/ding/torch_utils/tests/test_model_helper.py
@@ -0,0 +1,19 @@
+import pytest
+import torch
+
+from ding.torch_utils.model_helper import get_num_params
+
+
+@pytest.mark.unittest
+class TestModelHelper:
+
+ def test_model_helper(self):
+ r"""
+ Overview:
+ Test the model helper.
+ """
+ net = torch.nn.Linear(3, 4, bias=False)
+ assert get_num_params(net) == 12
+
+ net = torch.nn.Conv2d(3, 3, kernel_size=3, bias=False)
+ assert get_num_params(net) == 81
diff --git a/DI-engine/ding/torch_utils/tests/test_nn_test_helper.py b/DI-engine/ding/torch_utils/tests/test_nn_test_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc249a330190c9366d3532b5bf828bb71a5d95c9
--- /dev/null
+++ b/DI-engine/ding/torch_utils/tests/test_nn_test_helper.py
@@ -0,0 +1,37 @@
+import pytest
+import torch
+import torch.nn as nn
+
+from ding.torch_utils.nn_test_helper import is_differentiable
+
+
+@pytest.mark.unittest
+def test_is_differentiable():
+
+ class LinearNet(nn.Module):
+
+ def __init__(self, features_in=1, features_out=1):
+ super().__init__()
+ self.linear = nn.Linear(features_in, features_out)
+ self._init_weight()
+
+ def forward(self, x):
+ return self.linear(x)
+
+ def _init_weight(self):
+ nn.init.constant_(self.linear.weight, val=1)
+ nn.init.constant_(self.linear.bias, val=0)
+
+ net = LinearNet()
+ mse_fn = nn.L1Loss()
+ net._init_weight()
+ x = torch.FloatTensor([120])
+ target_value = torch.FloatTensor([2])
+ target_value.requires_grad = True
+ loss = mse_fn(net(x), target_value)
+ assert is_differentiable(loss, net) is None
+ with pytest.raises(AssertionError):
+ value = net(x).detach()
+ target_value = torch.FloatTensor([2])
+ target_value.requires_grad = False
+ is_differentiable(loss, net)
diff --git a/DI-engine/ding/torch_utils/tests/test_optimizer.py b/DI-engine/ding/torch_utils/tests/test_optimizer.py
new file mode 100644
index 0000000000000000000000000000000000000000..389346fe545a175698330a0ff0b8390919e89902
--- /dev/null
+++ b/DI-engine/ding/torch_utils/tests/test_optimizer.py
@@ -0,0 +1,197 @@
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from ding.torch_utils.optimizer_helper import Adam, RMSprop, calculate_grad_norm, \
+ calculate_grad_norm_without_bias_two_norm, PCGrad, configure_weight_decay
+import pytest
+import time
+
+
+class LinearNet(nn.Module):
+
+ def __init__(self, features_in=1, features_out=1):
+ super().__init__()
+ self.linear = nn.Linear(features_in, features_out)
+ self._init_weight()
+
+ def forward(self, x):
+ return self.linear(x)
+
+ def _init_weight(self):
+ nn.init.constant_(self.linear.weight, val=1)
+ nn.init.constant_(self.linear.bias, val=0)
+
+
+def try_optim_with(tname, t, optim_t):
+ net = LinearNet()
+ mse_fn = nn.L1Loss()
+ if tname == 'grad_clip':
+ if optim_t == 'rmsprop':
+ optimizer = RMSprop(
+ net.parameters(),
+ grad_clip_type=t,
+ clip_value=0.000001,
+ clip_norm_type=1.2,
+ lr=0.1,
+ clip_momentum_timestep=2,
+ ignore_momentum_timestep=2,
+ clip_coef=0.5
+ )
+ else:
+ optimizer = Adam(
+ net.parameters(),
+ grad_clip_type=t,
+ clip_value=0.000001,
+ clip_norm_type=1.2,
+ lr=0.1,
+ optim_type=optim_t,
+ clip_momentum_timestep=2,
+ ignore_momentum_timestep=2,
+ clip_coef=0.5
+ )
+ if tname == 'grad_ignore':
+ if optim_t == 'rmsprop':
+ optimizer = RMSprop(
+ net.parameters(),
+ grad_ignore_type=t,
+ clip_value=0.000001,
+ ignore_value=0.000001,
+ ignore_norm_type=1.2,
+ lr=0.1,
+ clip_momentum_timestep=2,
+ ignore_momentum_timestep=2,
+ )
+ else:
+ optimizer = Adam(
+ net.parameters(),
+ grad_ignore_type=t,
+ clip_value=0.000001,
+ ignore_value=0.000001,
+ ignore_norm_type=1.2,
+ lr=0.1,
+ optim_type=optim_t,
+ clip_momentum_timestep=2,
+ ignore_momentum_timestep=2,
+ ignore_coef=0.01
+ )
+    # network input and target label
+ x = torch.FloatTensor([120])
+ x.requires_grad = True
+ target_value = torch.FloatTensor([2])
+ target_value.requires_grad = True
+    # loss computation and optimization steps
+ for _ in range(10):
+ predict = net(x)
+ loss = mse_fn(predict, target_value)
+ loss.backward()
+ optimizer.step()
+ if t is not None and 'ignore' not in t:
+ assert optimizer.get_grad() != 0.
+ for _ in range(10):
+ target_value = torch.FloatTensor([_ ** 2])
+ target_value.requires_grad = True
+ predict = net(x)
+ loss = mse_fn(predict, target_value)
+ loss.backward()
+ optimizer.step()
+
+ if t is None:
+ print("weight without optimizer clip:" + str(net.linear.weight))
+ else:
+ print("weight with optimizer {} of type: {} is ".format(tname, t) + str(net.linear.weight))
+
+ weight = net.linear.weight
+ return weight
+
+
+@pytest.mark.unittest
+class TestAdam:
+
+ def test_naive(self):
+ support_type = {
+ 'optim': ['adam', 'adamw'],
+ 'grad_clip': [None, 'clip_momentum', 'clip_value', 'clip_norm', 'clip_momentum_norm'],
+ 'grad_norm': [None],
+ 'grad_ignore': [None, 'ignore_momentum', 'ignore_value', 'ignore_norm', 'ignore_momentum_norm'],
+ }
+
+ for optim_t in support_type['optim']:
+ for tname in ['grad_clip', 'grad_ignore']:
+ for t in support_type[tname]:
+ try_optim_with(tname=tname, t=t, optim_t=optim_t)
+
+
+@pytest.mark.unittest
+class TestRMSprop:
+
+ def test_naive(self):
+ support_type = {
+ 'grad_clip': [None, 'clip_momentum', 'clip_value', 'clip_norm', 'clip_momentum_norm'],
+ 'grad_norm': [None],
+ 'grad_ignore': [None, 'ignore_momentum', 'ignore_value', 'ignore_norm', 'ignore_momentum_norm'],
+ }
+
+ for tname in ['grad_clip', 'grad_ignore']:
+ for t in support_type[tname]:
+ try_optim_with(tname=tname, t=t, optim_t='rmsprop')
+
+
+@pytest.mark.unittest
+class Test_calculate_grad_norm_with_without_bias:
+
+ def test_two_functions(self):
+ net = LinearNet()
+ mse_fn = nn.L1Loss()
+ optimizer = Adam(net.parameters(), )
+ x = torch.FloatTensor([120])
+ x.requires_grad = True
+ target_value = torch.FloatTensor([2])
+ target_value.requires_grad = True
+ for _ in range(10):
+ predict = net(x)
+ loss = mse_fn(predict, target_value)
+ loss.backward()
+ optimizer.step()
+ inf_norm = calculate_grad_norm(model=net, norm_type='inf')
+ two_norm = calculate_grad_norm(model=net)
+ two_norm_nobias = float(calculate_grad_norm_without_bias_two_norm(model=net))
+ one_norm = calculate_grad_norm(model=net, norm_type=1)
+ assert isinstance(two_norm, float)
+ assert isinstance(inf_norm, float)
+ assert isinstance(one_norm, float)
+ assert isinstance(two_norm_nobias, float)
+
+
+@pytest.mark.unittest
+class TestPCGrad:
+
+ def naive_test(self):
+ x, y = torch.randn(2, 3), torch.randn(2, 4)
+ net = LinearNet(3, 4)
+ y_pred = net(x)
+ pc_adam = PCGrad(optim.Adam(net.parameters()))
+ pc_adam.zero_grad()
+ loss1_fn, loss2_fn = nn.L1Loss(), nn.MSELoss()
+ loss1, loss2 = loss1_fn(y_pred, y), loss2_fn(y_pred, y)
+
+ pc_adam.pc_backward([loss1, loss2])
+ for p in net.parameters():
+ assert isinstance(p, torch.Tensor)
+
+
+@pytest.mark.unittest
+class TestWeightDecay:
+
+ def test_wd(self):
+ net = nn.Sequential(nn.Linear(3, 4), nn.LayerNorm(4))
+ x = torch.randn(1, 3)
+ group_params = configure_weight_decay(model=net, weight_decay=1e-4)
+ assert group_params[0]['weight_decay'] == 1e-4
+ assert group_params[1]['weight_decay'] == 0
+ assert len(group_params[0]['params']) == 1
+ assert len(group_params[1]['params']) == 3
+ opt = Adam(group_params, lr=1e-2)
+ opt.zero_grad()
+ y = torch.sum(net(x))
+ y.backward()
+ opt.step()
diff --git a/DI-engine/ding/torch_utils/tests/test_parameter.py b/DI-engine/ding/torch_utils/tests/test_parameter.py
new file mode 100644
index 0000000000000000000000000000000000000000..3467c6582980e422a70f69461dc140fc6230215a
--- /dev/null
+++ b/DI-engine/ding/torch_utils/tests/test_parameter.py
@@ -0,0 +1,25 @@
+import unittest
+import pytest
+import torch
+from ding.torch_utils.parameter import NonegativeParameter, TanhParameter
+
+
+@pytest.mark.unittest
+def test_nonegative_parameter():
+ nonegative_parameter = NonegativeParameter(torch.tensor([2.0, 3.0]))
+ assert torch.sum(torch.abs(nonegative_parameter() - torch.tensor([2.0, 3.0]))) == 0
+ nonegative_parameter.set_data(torch.tensor(1))
+ assert nonegative_parameter() == 1
+
+
+@pytest.mark.unittest
+def test_tanh_parameter():
+ tanh_parameter = TanhParameter(torch.tensor([0.5, -0.2]))
+ assert torch.isclose(tanh_parameter() - torch.tensor([0.5, -0.2]), torch.zeros(2), atol=1e-6).all()
+ tanh_parameter.set_data(torch.tensor(0.3))
+ assert tanh_parameter() == 0.3
+
+
+if __name__ == "__main__":
+ test_nonegative_parameter()
+ test_tanh_parameter()
diff --git a/DI-engine/ding/torch_utils/tests/test_reshape_helper.py b/DI-engine/ding/torch_utils/tests/test_reshape_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ba539ac947d48e73575183c91871cd36df2516b
--- /dev/null
+++ b/DI-engine/ding/torch_utils/tests/test_reshape_helper.py
@@ -0,0 +1,35 @@
+import pytest
+import torch
+from ding.torch_utils.reshape_helper import fold_batch, unfold_batch, unsqueeze_repeat
+
+
+@pytest.mark.unittest
+def test_fold_unfold_batch():
+ T, B, C, H, W = 10, 20, 3, 255, 255
+ data = torch.randn(T, B, C, H, W)
+ data, batch_dim = fold_batch(data, nonbatch_ndims=3)
+ assert data.shape == (T * B, C, H, W) and batch_dim == (T, B)
+ data = unfold_batch(data, batch_dim)
+ assert data.shape == (T, B, C, H, W)
+
+ T, B, N = 10, 20, 100
+ data = torch.randn(T, B, N)
+ data, batch_dim = fold_batch(data, nonbatch_ndims=1)
+ assert data.shape == (T * B, N) and batch_dim == (T, B)
+ data = unfold_batch(data, batch_dim)
+ assert data.shape == (T, B, N)
+
+
+@pytest.mark.unittest
+def test_unsqueeze_repeat():
+ T, B, C, H, W = 10, 20, 3, 255, 255
+ repeat_times = 4
+ data = torch.randn(T, B, C, H, W)
+ ensembled_data = unsqueeze_repeat(data, repeat_times)
+ assert ensembled_data.shape == (repeat_times, T, B, C, H, W)
+
+ ensembled_data = unsqueeze_repeat(data, repeat_times, -1)
+ assert ensembled_data.shape == (T, B, C, H, W, repeat_times)
+
+ ensembled_data = unsqueeze_repeat(data, repeat_times, 2)
+ assert ensembled_data.shape == (T, B, repeat_times, C, H, W)
diff --git a/DI-engine/ding/utils/__init__.py b/DI-engine/ding/utils/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c2cb1c326321fa6a63397f30b828b8e23934d85
--- /dev/null
+++ b/DI-engine/ding/utils/__init__.py
@@ -0,0 +1,41 @@
+import ding
+from .collection_helper import iter_mapping
+from .compression_helper import get_data_compressor, get_data_decompressor, CloudPickleWrapper
+from .default_helper import override, dicts_to_lists, lists_to_dicts, squeeze, default_get, error_wrapper, list_split, \
+ LimitedSpaceContainer, deep_merge_dicts, set_pkg_seed, flatten_dict, one_time_warning, split_data_generator, \
+ RunningMeanStd, make_key_as_identifier, remove_illegal_item
+from .design_helper import SingletonMetaclass
+from .file_helper import read_file, save_file, remove_file
+from .import_helper import try_import_ceph, try_import_mc, try_import_link, import_module, try_import_redis, \
+ try_import_rediscluster
+from .k8s_helper import get_operator_server_kwargs, exist_operator_server, DEFAULT_K8S_COLLECTOR_PORT, \
+ DEFAULT_K8S_LEARNER_PORT, DEFAULT_K8S_AGGREGATOR_SLAVE_PORT, DEFAULT_K8S_COORDINATOR_PORT, pod_exec_command, \
+ K8sLauncher
+from .lock_helper import LockContext, LockContextType, get_file_lock, get_rw_file_lock
+from .log_helper import build_logger, pretty_print, LoggerFactory
+from .log_writer_helper import DistributedWriter
+from .orchestrator_launcher import OrchestratorLauncher
+from .profiler_helper import Profiler, register_profiler
+from .registry_factory import registries, POLICY_REGISTRY, ENV_REGISTRY, LEARNER_REGISTRY, COMM_LEARNER_REGISTRY, \
+ SERIAL_COLLECTOR_REGISTRY, PARALLEL_COLLECTOR_REGISTRY, COMM_COLLECTOR_REGISTRY, \
+ COMMANDER_REGISTRY, LEAGUE_REGISTRY, PLAYER_REGISTRY, MODEL_REGISTRY, ENV_MANAGER_REGISTRY, ENV_WRAPPER_REGISTRY, \
+ REWARD_MODEL_REGISTRY, BUFFER_REGISTRY, DATASET_REGISTRY, SERIAL_EVALUATOR_REGISTRY, MQ_REGISTRY, \
+ WORLD_MODEL_REGISTRY, STOCHASTIC_OPTIMIZER_REGISTRY
+from .scheduler_helper import Scheduler
+from .segment_tree import SumSegmentTree, MinSegmentTree, SegmentTree
+from .slurm_helper import find_free_port_slurm, node_to_host, node_to_partition
+from .system_helper import get_ip, get_pid, get_task_uid, PropagatingThread, find_free_port
+from .time_helper import build_time_helper, EasyTimer, WatchDog
+from .type_helper import SequenceType
+from .render_helper import render, fps, get_env_fps, render_env
+from .fast_copy import fastcopy
+from .bfs_helper import get_vi_sequence
+from .normalizer_helper import DatasetNormalizer
+
+if ding.enable_linklink: # False as default
+ from .linklink_dist_helper import get_rank, get_world_size, dist_mode, dist_init, dist_finalize, \
+ allreduce, broadcast, DistContext, allreduce_async, synchronize
+else:
+ from .pytorch_ddp_dist_helper import get_rank, get_world_size, dist_mode, dist_init, dist_finalize, \
+ allreduce, broadcast, DDPContext, allreduce_async, synchronize, reduce_data, broadcast_object_list, \
+ to_ddp_config, allreduce_data
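
For context, a small hypothetical sketch of how the backend switch above is consumed: downstream code imports the distributed helpers from ding.utils, and either the linklink or the PyTorch-DDP implementation is bound at import time by ding.enable_linklink (False by default). Calling dist_init() with no arguments and the exact reduction semantics of allreduce are assumptions for illustration, not a verified API description.

import torch

from ding.utils import dist_init, dist_finalize, get_rank, get_world_size, allreduce

dist_init()        # assumed to set up the backend selected above (requires a proper distributed launch)
x = torch.ones(4)
allreduce(x)       # backend-dependent all-reduce across ranks
if get_rank() == 0:
    print(get_world_size(), x)
dist_finalize()
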
diff --git a/DI-engine/ding/utils/autolog/__init__.py b/DI-engine/ding/utils/autolog/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8945ebaf3c8dadfaf6795782429ac34929f62784
--- /dev/null
+++ b/DI-engine/ding/utils/autolog/__init__.py
@@ -0,0 +1,8 @@
+from .base import TimeMode
+from .data import RangedData, TimeRangedData
+from .model import LoggedModel
+from .time_ctl import BaseTime, NaturalTime, TickTime, TimeProxy
+from .value import LoggedValue
+
+if __name__ == "__main__":
+ pass
diff --git a/DI-engine/ding/utils/autolog/base.py b/DI-engine/ding/utils/autolog/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..88a7fbef99413be7dc360da43f5f502b453a485a
--- /dev/null
+++ b/DI-engine/ding/utils/autolog/base.py
@@ -0,0 +1,24 @@
+from enum import unique, IntEnum
+from typing import TypeVar, Union
+
+_LOGGED_VALUE__PROPERTY_NAME = '__property_name__'
+_LOGGED_MODEL__PROPERTIES = '__properties__'
+_LOGGED_MODEL__PROPERTY_ATTR_PREFIX = '_property_'
+
+_TimeType = TypeVar('_TimeType', bound=Union[float, int])
+_ValueType = TypeVar('_ValueType')
+
+
+@unique
+class TimeMode(IntEnum):
+ """
+ Overview:
+        Mode used to decide the time format returned by the range_values function.
+
+ ABSOLUTE: use absolute time
+ RELATIVE_LIFECYCLE: use relative time based on property's lifecycle
+ RELATIVE_CURRENT_TIME: use relative time based on current time
+ """
+ ABSOLUTE = 0
+ RELATIVE_LIFECYCLE = 1
+ RELATIVE_CURRENT_TIME = 2
diff --git a/DI-engine/ding/utils/autolog/data.py b/DI-engine/ding/utils/autolog/data.py
new file mode 100644
index 0000000000000000000000000000000000000000..e611b97f432a018f0874e0f871ea639ff7fda65d
--- /dev/null
+++ b/DI-engine/ding/utils/autolog/data.py
@@ -0,0 +1,318 @@
+import pickle
+from abc import abstractmethod, ABCMeta
+from collections import deque
+from threading import Lock
+from typing import TypeVar, Iterable, List, Tuple, Union
+
+from .time_ctl import BaseTime
+
+_Tp = TypeVar('_Tp')
+
+
+class RangedData(metaclass=ABCMeta):
+ """
+ Overview:
+ A data structure that can store data for a period of time.
+ Interfaces:
+ ``__init__``, ``append``, ``extend``, ``current``, ``history``, ``expire``, ``__bool__``, ``_get_time``.
+ Properties:
+ - expire (:obj:`float`): The expire time.
+ """
+
+ def __init__(self, expire: float, use_pickle: bool = False):
+ """
+ Overview:
+ Initialize the RangedData object.
+ Arguments:
+ - expire (:obj:`float`): The expire time of the data.
+ - use_pickle (:obj:`bool`): Whether to use pickle to serialize the data.
+ """
+
+ self.__expire = expire
+ self.__use_pickle = use_pickle
+ self.__check_expire()
+
+ self.__data_max_id = 0
+ self.__data_items = {}
+ self.__data_lock = Lock()
+
+ self.__last_item = None
+ self.__queue = deque()
+ self.__lock = Lock()
+
+ def __check_expire(self):
+ """
+ Overview:
+ Check the expire time.
+ """
+
+ if isinstance(self.__expire, (int, float)):
+ if self.__expire <= 0:
+ raise ValueError(
+ "Expire should be greater than 0, but {actual} found.".format(actual=repr(self.__expire))
+ )
+ else:
+ raise TypeError(
+ 'Expire should be int or float, but {actual} found.'.format(actual=type(self.__expire).__name__)
+ )
+
+ def __registry_data_item(self, data: _Tp) -> int:
+ """
+ Overview:
+            Register the data item.
+ Arguments:
+ - data (:obj:`_Tp`): The data item.
+ """
+
+ with self.__data_lock:
+ self.__data_max_id += 1
+ if self.__use_pickle:
+ self.__data_items[self.__data_max_id] = pickle.dumps(data)
+ else:
+ self.__data_items[self.__data_max_id] = data
+
+ return self.__data_max_id
+
+ def __get_data_item(self, data_id: int) -> _Tp:
+ """
+ Overview:
+ Get the data item.
+ Arguments:
+ - data_id (:obj:`int`): The data id.
+ """
+
+ with self.__data_lock:
+ if self.__use_pickle:
+ return pickle.loads(self.__data_items[data_id])
+ else:
+ return self.__data_items[data_id]
+
+ def __remove_data_item(self, data_id: int):
+ """
+ Overview:
+ Remove the data item.
+ Arguments:
+ - data_id (:obj:`int`): The data id.
+ """
+
+ with self.__data_lock:
+ del self.__data_items[data_id]
+
+ def __check_time(self, time_: float):
+ """
+ Overview:
+ Check the time.
+ Arguments:
+ - time_ (:obj:`float`): The time.
+ """
+
+ if self.__queue:
+ _time, _ = self.__queue[-1]
+ if time_ < _time:
+ raise ValueError(
+                    "Time {time} is invalid because it is earlier than the last time {last_time}.".format(
+ time=repr(time_), last_time=repr(_time)
+ )
+ )
+
+ def __append_item(self, time_: float, data: _Tp):
+ """
+ Overview:
+ Append the data item.
+ Arguments:
+ - time_ (:obj:`float`): The time.
+ - data (:obj:`_Tp`): The data item.
+ """
+
+ self.__queue.append((time_, self.__registry_data_item(data)))
+
+ def __flush_history(self):
+ """
+ Overview:
+ Flush the history data.
+ """
+
+ _time = self._get_time()
+ _limit_time = _time - self.__expire
+ while self.__queue:
+ _head_time, _head_id = self.__queue.popleft()
+ if _head_time >= _limit_time:
+ self.__queue.appendleft((_head_time, _head_id))
+ break
+ else:
+ if self.__last_item:
+ _last_time, _last_id = self.__last_item
+ self.__remove_data_item(_last_id)
+
+ self.__last_item = (_head_time, _head_id)
+
+ def __append(self, time_: float, data: _Tp):
+ """
+ Overview:
+ Append the data.
+ """
+
+ self.__check_time(time_)
+ self.__append_item(time_, data)
+ self.__flush_history()
+
+ def __current(self):
+ """
+ Overview:
+ Get the current data.
+ """
+
+ if self.__queue:
+ _tail_time, _tail_id = self.__queue.pop()
+ self.__queue.append((_tail_time, _tail_id))
+ return self.__get_data_item(_tail_id)
+ elif self.__last_item:
+ _last_time, _last_id = self.__last_item
+ return self.__get_data_item(_last_id)
+ else:
+ raise ValueError("This range is empty.")
+
+ def __history_yield(self):
+ """
+ Overview:
+ Yield the history data.
+ """
+
+ _time = self._get_time()
+ _limit_time = _time - self.__expire
+ _latest_time, _latest_id = None, None
+
+ if self.__last_item:
+ _latest_time, _latest_id = _last_time, _last_id = self.__last_item
+ yield max(_last_time, _limit_time), self.__get_data_item(_last_id)
+
+ for _item_time, _item_id in self.__queue:
+ _latest_time, _latest_id = _item_time, _item_id
+ yield _item_time, self.__get_data_item(_item_id)
+
+ if _latest_time is not None and _latest_time < _time:
+ yield _time, self.__get_data_item(_latest_id)
+
+ def __history(self):
+ """
+ Overview:
+ Get the history data.
+ """
+
+ return list(self.__history_yield())
+
+ def append(self, data: _Tp):
+ """
+ Overview:
+ Append the data.
+ """
+
+ with self.__lock:
+ self.__flush_history()
+ _time = self._get_time()
+ self.__append(_time, data)
+ return self
+
+ def extend(self, iter_: Iterable[_Tp]):
+ """
+ Overview:
+ Extend the data.
+ """
+
+ with self.__lock:
+ self.__flush_history()
+ _time = self._get_time()
+ for item in iter_:
+ self.__append(_time, item)
+ return self
+
+ def current(self) -> _Tp:
+ """
+ Overview:
+ Get the current data.
+ """
+
+ with self.__lock:
+ self.__flush_history()
+ return self.__current()
+
+ def history(self) -> List[Tuple[Union[int, float], _Tp]]:
+ """
+ Overview:
+ Get the history data.
+ """
+
+ with self.__lock:
+ self.__flush_history()
+ return self.__history()
+
+ @property
+ def expire(self) -> float:
+ """
+ Overview:
+ Get the expire time.
+ """
+
+ with self.__lock:
+ self.__flush_history()
+ return self.__expire
+
+ def __bool__(self):
+ """
+ Overview:
+ Check whether the range is empty.
+ """
+
+ with self.__lock:
+ self.__flush_history()
+            return bool(self.__queue or self.__last_item)
+
+ @abstractmethod
+ def _get_time(self) -> float:
+ """
+ Overview:
+ Get the current time.
+ """
+
+ raise NotImplementedError
+
+
+class TimeRangedData(RangedData):
+ """
+ Overview:
+        A RangedData implementation whose timeline is driven by an external BaseTime object.
+ Interfaces:
+ ``__init__``, ``_get_time``, ``append``, ``extend``, ``current``, ``history``, ``expire``, ``__bool__``.
+ Properties:
+ - time (:obj:`BaseTime`): The time.
+ - expire (:obj:`float`): The expire time.
+ """
+
+ def __init__(self, time_: BaseTime, expire: float):
+ """
+ Overview:
+ Initialize the TimeRangedData object.
+ Arguments:
+ - time_ (:obj:`BaseTime`): The time.
+ - expire (:obj:`float`): The expire time.
+ """
+
+ RangedData.__init__(self, expire)
+ self.__time = time_
+
+ def _get_time(self) -> float:
+ """
+ Overview:
+ Get the current time.
+ """
+
+ return self.__time.time()
+
+ @property
+ def time(self):
+ """
+ Overview:
+ Get the time.
+ """
+
+ return self.__time
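
For context, a minimal usage sketch of TimeRangedData driven by a TickTime clock, matching the pattern that the autolog unit tests later in this diff exercise: entries older than expire ticks are flushed lazily whenever the structure is accessed.

from ding.utils.autolog import TimeRangedData, TickTime

data = TimeRangedData(TickTime(), expire=5)
data.append(233)           # recorded at tick 0
data.time.step()           # advance the tick clock to 1
data.extend([2, 3, 5, 7])  # all recorded at tick 1

assert data.current() == 7
assert data.history() == [(0, 233), (1, 2), (1, 3), (1, 5), (1, 7)]
assert data.expire == 5
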
diff --git a/DI-engine/ding/utils/autolog/model.py b/DI-engine/ding/utils/autolog/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c58bb6544a5fe6c467b44aadc3646cbf020eddd
--- /dev/null
+++ b/DI-engine/ding/utils/autolog/model.py
@@ -0,0 +1,301 @@
+from abc import ABCMeta
+from typing import TypeVar, Union, List, Any
+
+from .base import _LOGGED_MODEL__PROPERTIES, _LOGGED_MODEL__PROPERTY_ATTR_PREFIX, _TimeType, TimeMode, \
+ _LOGGED_VALUE__PROPERTY_NAME
+from .data import TimeRangedData
+from .time_ctl import BaseTime, TimeProxy
+from .value import LoggedValue
+
+_TimeObjectType = TypeVar('_TimeObjectType', bound=BaseTime)
+
+
+class _LoggedModelMeta(ABCMeta):
+ """
+ Overview:
+ Metaclass of LoggedModel, used to find all LoggedValue properties and register them.
+ Interfaces:
+ ``__init__``
+ """
+
+ def __init__(cls, name: str, bases: tuple, namespace: dict):
+
+ super().__init__(name, bases, namespace)
+
+ _properties = []
+ for k, v in namespace.items():
+ if isinstance(v, LoggedValue):
+ setattr(v, _LOGGED_VALUE__PROPERTY_NAME, k)
+ _properties.append(k)
+
+ setattr(cls, _LOGGED_MODEL__PROPERTIES, _properties)
+
+
+class LoggedModel(metaclass=_LoggedModelMeta):
+ """
+ Overview:
+        A model with a timeline (integer time, such as 1st, 2nd, 3rd, which can also be modeled as a kind
+        of self-defined discrete time, such as the TickTime implementation). Several values that are
+        associated with each other can be maintained together by using LoggedModel.
+
+ Example:
+ Define AvgList model like this
+
+ >>> from ding.utils.autolog import LoggedValue, LoggedModel
+ >>> class AvgList(LoggedModel):
+ >>> value = LoggedValue(float)
+ >>> __property_names = ['value']
+ >>>
+ >>> def __init__(self, time_: BaseTime, expire: Union[int, float]):
+ >>> LoggedModel.__init__(self, time_, expire)
+        >>>         # Attention: the original value must be set in the __init__ function, or it will not
+        >>>         # be activated, and the timeline of this value will also be unexpectedly affected.
+ >>> self.value = 0.0
+ >>> self.__register()
+ >>>
+ >>> def __register(self):
+ >>> def __avg_func(prop_name: str) -> float: # function to calculate average value of properties
+ >>> records = self.range_values[prop_name]()
+ >>> (_start_time, _), _ = records[0]
+ >>> (_, _end_time), _ = records[-1]
+ >>>
+ >>> _duration = _end_time - _start_time
+ >>> _sum = sum([_value * (_end_time - _begin_time) for (_begin_time, _end_time), _value in records])
+ >>>
+ >>> return _sum / _duration
+ >>>
+ >>> for _prop_name in self.__property_names:
+ >>> self.register_attribute_value('avg', _prop_name, partial(__avg_func, prop_name=_prop_name))
+
+ Use it like this
+
+ >>> from ding.utils.autolog import NaturalTime, TimeMode
+ >>>
+ >>> if __name__ == "__main__":
+ >>> _time = NaturalTime()
+ >>> ll = AvgList(_time, expire=10)
+ >>>
+ >>> # just do something here ...
+ >>>
+ >>> print(ll.range_values['value']()) # original range_values function in LoggedModel of last 10 secs
+ >>> print(ll.range_values['value'](TimeMode.ABSOLUTE)) # use absolute time
+ >>> print(ll.avg['value']()) # average value of last 10 secs
+
+ Interfaces:
+ ``__init__``, ``time``, ``expire``, ``fixed_time``, ``current_time``, ``freeze``, ``unfreeze``, \
+ ``register_attribute_value``, ``__getattr__``, ``get_property_attribute``
+
+ Property:
+ - time (:obj:`BaseTime`): The time.
+ - expire (:obj:`float`): The expire time.
+ """
+
+ def __init__(self, time_: _TimeObjectType, expire: _TimeType):
+ """
+ Overview:
+ Initialize the LoggedModel object using the given arguments.
+ Arguments:
+ - time_ (:obj:`BaseTime`): The time.
+ - expire (:obj:`float`): The expire time.
+ """
+
+ self.__time = time_
+ self.__time_proxy = TimeProxy(self.__time, frozen=False)
+ self.__init_time = self.__time_proxy.time()
+ self.__expire = expire
+
+ self.__methods = {}
+        self.__prop2attr = {}  # used to find the registered attribute list according to property name
+
+ self.__init_properties()
+ self.__register_default_funcs()
+
+ @property
+ def __properties(self) -> List[str]:
+ """
+ Overview:
+ Get all property names.
+ """
+
+ return getattr(self, _LOGGED_MODEL__PROPERTIES)
+
+ def __get_property_ranged_data(self, name: str) -> TimeRangedData:
+ """
+ Overview:
+ Get ranged data of one property.
+ Arguments:
+ - name (:obj:`str`): The property name.
+ """
+
+ return getattr(self, _LOGGED_MODEL__PROPERTY_ATTR_PREFIX + name)
+
+ def __init_properties(self):
+ """
+ Overview:
+ Initialize all properties.
+ """
+
+ for name in self.__properties:
+ setattr(
+ self, _LOGGED_MODEL__PROPERTY_ATTR_PREFIX + name,
+ TimeRangedData(self.__time_proxy, expire=self.__expire)
+ )
+
+ def __get_range_values_func(self, name: str):
+ """
+ Overview:
+ Get range_values function of one property.
+ Arguments:
+ - name (:obj:`str`): The property name.
+ """
+
+ def _func(mode: TimeMode = TimeMode.RELATIVE_LIFECYCLE):
+ _current_time = self.__time_proxy.time()
+ _result = self.__get_property_ranged_data(name).history()
+
+ if mode == TimeMode.RELATIVE_LIFECYCLE:
+ _result = [(_time - self.__init_time, _data) for _time, _data in _result]
+ elif mode == TimeMode.RELATIVE_CURRENT_TIME:
+ _result = [(_time - _current_time, _data) for _time, _data in _result]
+
+ _ranges = []
+ for i in range(0, len(_result) - 1):
+ _this_time, _this_data = _result[i]
+ _next_time, _next_data = _result[i + 1]
+ _ranges.append(((_this_time, _next_time), _this_data))
+
+ return _ranges
+
+ return _func
+
+ def __register_default_funcs(self):
+ """
+ Overview:
+ Register default functions.
+ """
+
+ for name in self.__properties:
+ self.register_attribute_value('range_values', name, self.__get_range_values_func(name))
+
+ @property
+ def time(self) -> _TimeObjectType:
+ """
+ Overview:
+            Get the original time object passed in; methods such as step() can be called via this property.
+
+ Returns:
+ BaseTime: time object used by this model
+ """
+ return self.__time
+
+ @property
+ def expire(self) -> _TimeType:
+ """
+ Overview:
+ Get expire time
+
+ Returns:
+            int or float: the time span after which old value records expire
+ """
+ return self.__expire
+
+ def fixed_time(self) -> Union[float, int]:
+ """
+ Overview:
+ Get fixed time (will be frozen time if time proxy is frozen)
+ This feature can be useful when adding value replay feature (in the future)
+
+ Returns:
+ int or float: fixed time
+ """
+ return self.__time_proxy.time()
+
+ def current_time(self) -> Union[float, int]:
+ """
+ Overview:
+            Get current time (the real time, regardless of whether the time proxy is frozen)
+
+ Returns:
+ int or float: current time
+ """
+ return self.__time_proxy.current_time()
+
+ def freeze(self):
+ """
+ Overview:
+ Freeze time proxy object.
+ This feature can be useful when adding value replay feature (in the future)
+ """
+ self.__time_proxy.freeze()
+
+ def unfreeze(self):
+ """
+ Overview:
+ Unfreeze time proxy object.
+ This feature can be useful when adding value replay feature (in the future)
+ """
+ self.__time_proxy.unfreeze()
+
+ def register_attribute_value(self, attribute_name: str, property_name: str, value: Any):
+ """
+ Overview:
+ Register a new attribute for one of the values. Example can be found in overview of class.
+ Arguments:
+ - attribute_name (:obj:`str`): name of attribute
+ - property_name (:obj:`str`): name of property
+ - value (:obj:`Any`): value of attribute
+ """
+ self.__methods[attribute_name] = self.__methods.get(attribute_name, {})
+ self.__methods[attribute_name][property_name] = value
+ if attribute_name == "range_values":
+ # "range_values" is not added to ``self.__prop2attr``
+ return
+ self.__prop2attr[property_name] = self.__prop2attr.get(property_name, [])
+ self.__prop2attr[property_name].append(attribute_name)
+
+ def __getattr__(self, attribute_name: str) -> Any:
+ """
+ Overview:
+ Support all methods registered.
+
+ Arguments:
+ attribute_name (str): name of attribute
+
+ Return:
+            An indexable object that can return the attribute value for a given property name.
+
+ Example:
+ >>> ll = AvgList(NaturalTime(), expire=10)
+            >>> ll.range_values['value']  # get 'range_values' attribute of 'value' property, it should be a function
+ """
+ if attribute_name in self.__methods.keys():
+ _attributes = self.__methods[attribute_name]
+
+ class _Cls:
+
+ def __getitem__(self, property_name: str):
+ if property_name in _attributes.keys():
+ return _attributes[property_name]
+ else:
+ raise KeyError(
+ "Attribute {attr_name} for property {prop_name} not found.".format(
+ attr_name=repr(attribute_name),
+ prop_name=repr(property_name),
+ )
+ )
+
+ return _Cls()
+ else:
+ raise KeyError("Attribute {name} not found.".format(name=repr(attribute_name)))
+
+ def get_property_attribute(self, property_name: str) -> List[str]:
+ """
+ Overview:
+ Find all registered attributes (except common "range_values" attribute, since "range_values" is not
+ added to ``self.__prop2attr``) of one given property.
+ Arguments:
+ - property_name (:obj:`str`): name of property to query attributes
+ Returns:
+ - attr_list (:obj:`List[str]`): the registered attributes list of the input property
+ """
+ return self.__prop2attr[property_name]
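
For context, a small sketch of how the attribute registry above is consumed, reusing the hypothetical AvgList class from the LoggedModel docstring: register_attribute_value fills the method table served via __getattr__ and the property-to-attribute map served via get_property_attribute, while the built-in 'range_values' attribute is deliberately excluded from the latter.

import time

from ding.utils.autolog import NaturalTime, TimeMode

ll = AvgList(NaturalTime(), expire=10)  # AvgList as defined in the docstring example above
ll.value = 1.0
time.sleep(0.1)  # let some wall-clock time pass, mirroring "just do something here ..."

history = ll.range_values['value']()                    # default attribute: relative-lifecycle time ranges
absolute = ll.range_values['value'](TimeMode.ABSOLUTE)  # same ranges with absolute timestamps
average = ll.avg['value']()                             # custom attribute added via register_attribute_value
assert ll.get_property_attribute('value') == ['avg']    # 'range_values' is deliberately not reported here
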
diff --git a/DI-engine/ding/utils/autolog/tests/__init__.py b/DI-engine/ding/utils/autolog/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/ding/utils/autolog/tests/test_data.py b/DI-engine/ding/utils/autolog/tests/test_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf51dd058ff3765b30bf2a6e782eec01a8879b68
--- /dev/null
+++ b/DI-engine/ding/utils/autolog/tests/test_data.py
@@ -0,0 +1,96 @@
+import pytest
+
+from ding.utils.autolog import TimeRangedData, NaturalTime, TickTime
+
+
+@pytest.mark.unittest
+class TestAutologRangedData:
+
+ def test_expire(self):
+ data = TimeRangedData(NaturalTime(), expire=5)
+ assert data.expire == 5
+
+ with pytest.raises(ValueError):
+ TimeRangedData(NaturalTime(), expire=-1)
+
+ with pytest.raises(TypeError):
+ TimeRangedData(NaturalTime(), expire='5')
+
+ def test_bool(self):
+ data = TimeRangedData(TickTime(), expire=5)
+ assert not data
+
+ data.append(233)
+ assert data
+
+ data.time.step()
+ data.extend([2, 3, 5, 7])
+ assert data
+
+ data.time.step(4)
+ assert data
+
+ data.time.step(1)
+ assert data
+
+ data.time.step(1)
+ assert data
+
+ data.time.step(1)
+ assert data
+
+ data.time.step(10)
+ assert data
+
+ def test_current(self):
+ data = TimeRangedData(TickTime(), expire=5)
+ with pytest.raises(ValueError):
+ _ = data.current()
+
+ data.append(233)
+ assert data.current() == 233
+
+ data.time.step()
+ data.extend([2, 3, 5, 7])
+ assert data.current() == 7
+
+ data.time.step(4)
+ assert data.current() == 7
+
+ data.time.step(1)
+ assert data.current() == 7
+
+ data.time.step(1)
+ assert data.current() == 7
+
+ data.time.step(1)
+ assert data.current() == 7
+
+ data.time.step(10)
+ assert data.current() == 7
+
+ def test_history(self):
+ data = TimeRangedData(TickTime(), expire=5)
+ assert data.history() == []
+
+ data.append(233)
+ assert data.history() == [(0, 233)]
+
+ data.time.step()
+ data.extend([2, 3, 5, 7])
+ assert data.history() == [(0, 233), (1, 2), (1, 3), (1, 5), (1, 7)]
+
+ data.time.step(4)
+ assert data.history() == [(0, 233), (1, 2), (1, 3), (1, 5), (1, 7), (5, 7)]
+
+ data.time.step(1)
+ assert data.history() == [(1, 233), (1, 2), (1, 3), (1, 5), (1, 7), (6, 7)]
+
+ data.time.step(1)
+ assert data.history() == [(2, 7), (7, 7)]
+
+ data.time.step(1)
+ assert data.history() == [(3, 7), (8, 7)]
+
+ data.time.step(10)
+ assert data.history() == [(13, 7), (18, 7)]
diff --git a/DI-engine/ding/utils/autolog/tests/test_model.py b/DI-engine/ding/utils/autolog/tests/test_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..834ce9095e2a762247166c46e0e29f89a1712053
--- /dev/null
+++ b/DI-engine/ding/utils/autolog/tests/test_model.py
@@ -0,0 +1,432 @@
+import time
+from functools import partial
+from typing import Union
+
+import pytest
+
+from ding.utils.autolog import LoggedModel, LoggedValue, TickTime, NaturalTime, TimeMode
+
+
+# noinspection DuplicatedCode
+@pytest.mark.unittest
+class TestAutologModel:
+
+ def __get_demo_class(self):
+ # noinspection DuplicatedCode
+ class _TickModel(LoggedModel):
+ in_time = LoggedValue(float)
+ out_time = LoggedValue(float)
+ __thruput_property_names = ['in_time', 'out_time']
+
+ def __init__(self, time_: 'BaseTime', expire: Union[int, float]): # noqa
+ LoggedModel.__init__(self, time_, expire)
+ self.__register()
+
+ def __register(self):
+
+ def __avg_func(prop_name: str) -> float:
+ records = self.range_values[prop_name]()
+ _sum = sum([_value for (_begin_time, _end_time), _value in records])
+ return _sum / self.expire
+
+ for _prop_name in self.__thruput_property_names:
+ self.register_attribute_value('thruput', _prop_name, partial(__avg_func, _prop_name))
+ self.register_attribute_value(
+ 'reversed_name', _prop_name, partial(lambda name: name[::-1], _prop_name)
+ )
+
+ return _TickModel
+
+ def test_getter_and_setter(self):
+ _class = self.__get_demo_class()
+
+ _time = TickTime()
+ _tick_monitor = _class(_time, expire=5)
+
+ with pytest.raises(ValueError):
+ _ = _tick_monitor.in_time
+ with pytest.raises(ValueError):
+ _ = _tick_monitor.out_time
+
+ _tick_monitor.in_time = 2.0
+ assert _tick_monitor.in_time == 2.0
+
+ with pytest.raises(TypeError):
+ _tick_monitor.in_time = None
+ assert _tick_monitor.in_time == 2.0
+
+ def test_property_getter(self):
+ _class = self.__get_demo_class()
+
+ _time = TickTime()
+ _tick_monitor = _class(_time, expire=5)
+
+ assert _tick_monitor.reversed_name['in_time']() == 'emit_ni'
+ assert _tick_monitor.reversed_name['out_time']() == 'emit_tuo'
+
+ with pytest.raises(KeyError):
+ _tick_monitor.reversed_name['property_not_exist']()
+ with pytest.raises(KeyError):
+ _tick_monitor.reversed_nam['in_time']()
+
+ def test_time(self):
+ _class = self.__get_demo_class()
+
+ _time = TickTime()
+ _tick_monitor = _class(_time, expire=5)
+
+ assert id(_tick_monitor.time) == id(_time)
+ assert _tick_monitor.fixed_time() == 0
+ assert _tick_monitor.current_time() == 0
+
+ _tick_monitor.freeze()
+ _time.step()
+ assert _tick_monitor.fixed_time() == 0
+ assert _tick_monitor.current_time() == 1
+
+ _tick_monitor.unfreeze()
+ assert _tick_monitor.fixed_time() == 1
+ assert _tick_monitor.current_time() == 1
+
+ def test_expire(self):
+ _class = self.__get_demo_class()
+
+ _time = TickTime()
+ _tick_monitor = _class(_time, expire=5)
+
+ assert _tick_monitor.expire == 5
+
+ def test_with_tick_time(self):
+ _class = self.__get_demo_class()
+
+ _time = TickTime()
+ _tick_monitor = _class(_time, expire=5)
+
+ _assert_results = [
+ (0.0, 0.0),
+ (0.2, 0.4),
+ (0.6, 1.2),
+ (1.2, 2.4),
+ (2.0, 4.0),
+ (3.0, 6.0),
+ (4.2, 8.4),
+ (5.4, 10.8),
+ (6.6, 13.2),
+ (7.8, 15.6),
+ ]
+
+ for i in range(0, 10):
+ _tick_monitor.in_time = 1.0 * i
+ _tick_monitor.out_time = 2.0 * i
+ _time.step()
+
+ _thin, _thout = _tick_monitor.thruput['in_time'](), _tick_monitor.thruput['out_time']()
+ _exp_thin, _exp_thout = _assert_results[i]
+
+ assert _thin == _exp_thin
+ assert _thout == _exp_thout
+
+ def test_with_natural_time(self):
+ _class = self.__get_demo_class()
+
+ _time = NaturalTime()
+ _tick_monitor = _class(_time, expire=5)
+
+ _assert_results = [
+ (0.0, 0.0),
+ (0.2, 0.4),
+ (0.6, 1.2),
+ (1.2, 2.4),
+ (2.0, 4.0),
+ (3.0, 6.0),
+ (4.0, 8.0),
+ (5.0, 10.0),
+ (6.0, 12.0),
+ (7.0, 14.0),
+ ]
+
+ for i in range(0, 10):
+ _tick_monitor.in_time = 1.0 * i
+ _tick_monitor.out_time = 2.0 * i
+ time.sleep(1.0)
+
+ _thin, _thout = _tick_monitor.thruput['in_time'](), _tick_monitor.thruput['out_time']()
+ _exp_thin, _exp_thout = _assert_results[i]
+
+ assert abs(_thin - _exp_thin) < 0.1
+ assert abs(_thout - _exp_thout) < 0.1
+
+ def test_double_model(self):
+ _class = self.__get_demo_class()
+
+ _time = TickTime()
+ _tick_monitor_1 = _class(_time, expire=5)
+ _tick_monitor_2 = _class(_time, expire=5)
+
+ _assert_results_1 = [
+ (0.0, 0.0),
+ (0.2, 0.4),
+ (0.6, 1.2),
+ (1.2, 2.4),
+ (2.0, 4.0),
+ (3.0, 6.0),
+ (4.2, 8.4),
+ (5.4, 10.8),
+ (6.6, 13.2),
+ (7.8, 15.6),
+ ]
+ _assert_results_2 = [
+ (0.0, 0.0), (0.4, 0.8), (1.2, 2.4), (2.4, 4.8), (4.0, 8.0), (6.0, 12.0), (8.4, 16.8), (10.8, 21.6),
+ (13.2, 26.4), (15.6, 31.2)
+ ]
+
+ for i in range(0, 10):
+ _tick_monitor_1.in_time = 1.0 * i
+ _tick_monitor_1.out_time = 2.0 * i
+ _tick_monitor_2.in_time = 2.0 * i
+ _tick_monitor_2.out_time = 4.0 * i
+
+ _time.step()
+
+ _thin_1, _thout_1 = _tick_monitor_1.thruput['in_time'](), _tick_monitor_1.thruput['out_time']()
+ _exp_thin_1, _exp_thout_1 = _assert_results_1[i]
+
+ _thin_2, _thout_2 = _tick_monitor_2.thruput['in_time'](), _tick_monitor_2.thruput['out_time']()
+ _exp_thin_2, _exp_thout_2 = _assert_results_2[i]
+
+ assert (_thin_1, _thout_1) == (_exp_thin_1, _exp_thout_1)
+ assert (_thin_2, _thout_2) == (_exp_thin_2, _exp_thout_2)
+
+ def test_range_values_default(self):
+ _class = self.__get_demo_class()
+
+ _time = TickTime()
+ _tick_monitor = _class(_time, expire=5)
+
+ _assert_results = [
+ ([((0, 1), 0.0)], [((0, 1), 0.0)]),
+ ([((0, 1), 0.0), ((1, 2), 1.0)], [((0, 1), 0.0), ((1, 2), 2.0)]),
+ ([((0, 1), 0.0), ((1, 2), 1.0), ((2, 3), 2.0)], [((0, 1), 0.0), ((1, 2), 2.0), ((2, 3), 4.0)]),
+ (
+ [((0, 1), 0.0), ((1, 2), 1.0), ((2, 3), 2.0),
+ ((3, 4), 3.0)], [((0, 1), 0.0), ((1, 2), 2.0), ((2, 3), 4.0), ((3, 4), 6.0)]
+ ),
+ (
+ [((0, 1), 0.0), ((1, 2), 1.0), ((2, 3), 2.0), ((3, 4), 3.0),
+ ((4, 5), 4.0)], [((0, 1), 0.0), ((1, 2), 2.0), ((2, 3), 4.0), ((3, 4), 6.0), ((4, 5), 8.0)]
+ ),
+ (
+ [((1, 1), 0.0), ((1, 2), 1.0), ((2, 3), 2.0), ((3, 4), 3.0), ((4, 5), 4.0), ((5, 6), 5.0)], [
+ ((1, 1), 0.0), ((1, 2), 2.0), ((2, 3), 4.0), ((3, 4), 6.0), ((4, 5), 8.0), ((5, 6), 10.0)
+ ]
+ ),
+ (
+ [((2, 2), 1.0), ((2, 3), 2.0), ((3, 4), 3.0), ((4, 5), 4.0), ((5, 6), 5.0), ((6, 7), 6.0)], [
+ ((2, 2), 2.0), ((2, 3), 4.0), ((3, 4), 6.0), ((4, 5), 8.0), ((5, 6), 10.0), ((6, 7), 12.0)
+ ]
+ ),
+ (
+ [((3, 3), 2.0), ((3, 4), 3.0), ((4, 5), 4.0), ((5, 6), 5.0), ((6, 7), 6.0), ((7, 8), 7.0)], [
+ ((3, 3), 4.0), ((3, 4), 6.0), ((4, 5), 8.0), ((5, 6), 10.0), ((6, 7), 12.0), ((7, 8), 14.0)
+ ]
+ ),
+ (
+ [((4, 4), 3.0), ((4, 5), 4.0), ((5, 6), 5.0), ((6, 7), 6.0), ((7, 8), 7.0), ((8, 9), 8.0)], [
+ ((4, 4), 6.0), ((4, 5), 8.0), ((5, 6), 10.0), ((6, 7), 12.0), ((7, 8), 14.0), ((8, 9), 16.0)
+ ]
+ ),
+ (
+ [((5, 5), 4.0), ((5, 6), 5.0), ((6, 7), 6.0), ((7, 8), 7.0), ((8, 9), 8.0), ((9, 10), 9.0)], [
+ ((5, 5), 8.0), ((5, 6), 10.0), ((6, 7), 12.0), ((7, 8), 14.0), ((8, 9), 16.0), ((9, 10), 18.0)
+ ]
+ ),
+ ]
+
+ for i in range(0, 10):
+ _tick_monitor.in_time = 1.0 * i
+ _tick_monitor.out_time = 2.0 * i
+ _time.step()
+
+ _thin, _thout = _tick_monitor.range_values['in_time'](), _tick_monitor.range_values['out_time']()
+ _exp_thin, _exp_thout = _assert_results[i]
+
+ assert (_thin, _thout) == (_exp_thin, _exp_thout)
+
+ def test_range_values_absolute(self):
+ _class = self.__get_demo_class()
+
+ _time = TickTime(1)
+ _tick_monitor = _class(_time, expire=5)
+
+ _assert_results = [
+ ([((1, 2), 0.0)], [((1, 2), 0.0)]),
+ ([((1, 2), 0.0), ((2, 3), 1.0)], [((1, 2), 0.0), ((2, 3), 2.0)]),
+ ([((1, 2), 0.0), ((2, 3), 1.0), ((3, 4), 2.0)], [((1, 2), 0.0), ((2, 3), 2.0), ((3, 4), 4.0)]),
+ (
+ [((1, 2), 0.0), ((2, 3), 1.0), ((3, 4), 2.0),
+ ((4, 5), 3.0)], [((1, 2), 0.0), ((2, 3), 2.0), ((3, 4), 4.0), ((4, 5), 6.0)]
+ ),
+ (
+ [((1, 2), 0.0), ((2, 3), 1.0), ((3, 4), 2.0), ((4, 5), 3.0),
+ ((5, 6), 4.0)], [((1, 2), 0.0), ((2, 3), 2.0), ((3, 4), 4.0), ((4, 5), 6.0), ((5, 6), 8.0)]
+ ),
+ (
+ [((2, 2), 0.0), ((2, 3), 1.0), ((3, 4), 2.0), ((4, 5), 3.0), ((5, 6), 4.0), ((6, 7), 5.0)], [
+ ((2, 2), 0.0), ((2, 3), 2.0), ((3, 4), 4.0), ((4, 5), 6.0), ((5, 6), 8.0), ((6, 7), 10.0)
+ ]
+ ),
+ (
+ [((3, 3), 1.0), ((3, 4), 2.0), ((4, 5), 3.0), ((5, 6), 4.0), ((6, 7), 5.0), ((7, 8), 6.0)], [
+ ((3, 3), 2.0), ((3, 4), 4.0), ((4, 5), 6.0), ((5, 6), 8.0), ((6, 7), 10.0), ((7, 8), 12.0)
+ ]
+ ),
+ (
+ [((4, 4), 2.0), ((4, 5), 3.0), ((5, 6), 4.0), ((6, 7), 5.0), ((7, 8), 6.0), ((8, 9), 7.0)], [
+ ((4, 4), 4.0), ((4, 5), 6.0), ((5, 6), 8.0), ((6, 7), 10.0), ((7, 8), 12.0), ((8, 9), 14.0)
+ ]
+ ),
+ (
+ [((5, 5), 3.0), ((5, 6), 4.0), ((6, 7), 5.0), ((7, 8), 6.0), ((8, 9), 7.0), ((9, 10), 8.0)], [
+ ((5, 5), 6.0), ((5, 6), 8.0), ((6, 7), 10.0), ((7, 8), 12.0), ((8, 9), 14.0), ((9, 10), 16.0)
+ ]
+ ),
+ (
+ [((6, 6), 4.0), ((6, 7), 5.0), ((7, 8), 6.0), ((8, 9), 7.0), ((9, 10), 8.0), ((10, 11), 9.0)], [
+ ((6, 6), 8.0), ((6, 7), 10.0), ((7, 8), 12.0), ((8, 9), 14.0), ((9, 10), 16.0), ((10, 11), 18.0)
+ ]
+ ),
+ ]
+
+ for i in range(0, 10):
+ _tick_monitor.in_time = 1.0 * i
+ _tick_monitor.out_time = 2.0 * i
+ _time.step()
+
+ _thin = _tick_monitor.range_values['in_time'](TimeMode.ABSOLUTE)
+ _thout = _tick_monitor.range_values['out_time'](TimeMode.ABSOLUTE)
+ _exp_thin, _exp_thout = _assert_results[i]
+
+ assert (_thin, _thout) == (_exp_thin, _exp_thout)
+
+ def test_range_values_lifecycle(self):
+ _class = self.__get_demo_class()
+
+ _time = TickTime(1)
+ _tick_monitor = _class(_time, expire=5)
+
+ _assert_results = [
+ ([((0, 1), 0.0)], [((0, 1), 0.0)]),
+ ([((0, 1), 0.0), ((1, 2), 1.0)], [((0, 1), 0.0), ((1, 2), 2.0)]),
+ ([((0, 1), 0.0), ((1, 2), 1.0), ((2, 3), 2.0)], [((0, 1), 0.0), ((1, 2), 2.0), ((2, 3), 4.0)]),
+ (
+ [((0, 1), 0.0), ((1, 2), 1.0), ((2, 3), 2.0),
+ ((3, 4), 3.0)], [((0, 1), 0.0), ((1, 2), 2.0), ((2, 3), 4.0), ((3, 4), 6.0)]
+ ),
+ (
+ [((0, 1), 0.0), ((1, 2), 1.0), ((2, 3), 2.0), ((3, 4), 3.0),
+ ((4, 5), 4.0)], [((0, 1), 0.0), ((1, 2), 2.0), ((2, 3), 4.0), ((3, 4), 6.0), ((4, 5), 8.0)]
+ ),
+ (
+ [((1, 1), 0.0), ((1, 2), 1.0), ((2, 3), 2.0), ((3, 4), 3.0), ((4, 5), 4.0), ((5, 6), 5.0)], [
+ ((1, 1), 0.0), ((1, 2), 2.0), ((2, 3), 4.0), ((3, 4), 6.0), ((4, 5), 8.0), ((5, 6), 10.0)
+ ]
+ ),
+ (
+ [((2, 2), 1.0), ((2, 3), 2.0), ((3, 4), 3.0), ((4, 5), 4.0), ((5, 6), 5.0), ((6, 7), 6.0)], [
+ ((2, 2), 2.0), ((2, 3), 4.0), ((3, 4), 6.0), ((4, 5), 8.0), ((5, 6), 10.0), ((6, 7), 12.0)
+ ]
+ ),
+ (
+ [((3, 3), 2.0), ((3, 4), 3.0), ((4, 5), 4.0), ((5, 6), 5.0), ((6, 7), 6.0), ((7, 8), 7.0)], [
+ ((3, 3), 4.0), ((3, 4), 6.0), ((4, 5), 8.0), ((5, 6), 10.0), ((6, 7), 12.0), ((7, 8), 14.0)
+ ]
+ ),
+ (
+ [((4, 4), 3.0), ((4, 5), 4.0), ((5, 6), 5.0), ((6, 7), 6.0), ((7, 8), 7.0), ((8, 9), 8.0)], [
+ ((4, 4), 6.0), ((4, 5), 8.0), ((5, 6), 10.0), ((6, 7), 12.0), ((7, 8), 14.0), ((8, 9), 16.0)
+ ]
+ ),
+ (
+ [((5, 5), 4.0), ((5, 6), 5.0), ((6, 7), 6.0), ((7, 8), 7.0), ((8, 9), 8.0), ((9, 10), 9.0)], [
+ ((5, 5), 8.0), ((5, 6), 10.0), ((6, 7), 12.0), ((7, 8), 14.0), ((8, 9), 16.0), ((9, 10), 18.0)
+ ]
+ ),
+ ]
+
+ for i in range(0, 10):
+ _tick_monitor.in_time = 1.0 * i
+ _tick_monitor.out_time = 2.0 * i
+ _time.step()
+
+ # print('(', _tick_monitor.range_values['in_time'](TimeMode.RELATIVE_LIFECYCLE), ',',
+ # _tick_monitor.range_values['out_time'](TimeMode.RELATIVE_LIFECYCLE), '),')
+
+ _thin = _tick_monitor.range_values['in_time'](TimeMode.RELATIVE_LIFECYCLE)
+ _thout = _tick_monitor.range_values['out_time'](TimeMode.RELATIVE_LIFECYCLE)
+ _exp_thin, _exp_thout = _assert_results[i]
+
+ assert (_thin, _thout) == (_exp_thin, _exp_thout)
+
+ def test_range_values_current(self):
+ _class = self.__get_demo_class()
+
+ _time = TickTime(1)
+ _tick_monitor = _class(_time, expire=5)
+
+ _assert_results = [
+ ([((-1, 0), 0.0)], [((-1, 0), 0.0)]),
+ ([((-2, -1), 0.0), ((-1, 0), 1.0)], [((-2, -1), 0.0), ((-1, 0), 2.0)]),
+ ([((-3, -2), 0.0), ((-2, -1), 1.0), ((-1, 0), 2.0)], [((-3, -2), 0.0), ((-2, -1), 2.0), ((-1, 0), 4.0)]),
+ (
+ [((-4, -3), 0.0), ((-3, -2), 1.0), ((-2, -1), 2.0),
+ ((-1, 0), 3.0)], [((-4, -3), 0.0), ((-3, -2), 2.0), ((-2, -1), 4.0), ((-1, 0), 6.0)]
+ ),
+ (
+ [((-5, -4), 0.0), ((-4, -3), 1.0), ((-3, -2), 2.0), ((-2, -1), 3.0),
+ ((-1, 0), 4.0)], [((-5, -4), 0.0), ((-4, -3), 2.0), ((-3, -2), 4.0), ((-2, -1), 6.0), ((-1, 0), 8.0)]
+ ),
+ (
+ [((-5, -5), 0.0), ((-5, -4), 1.0), ((-4, -3), 2.0), ((-3, -2), 3.0), ((-2, -1), 4.0), ((-1, 0), 5.0)], [
+ ((-5, -5), 0.0), ((-5, -4), 2.0), ((-4, -3), 4.0), ((-3, -2), 6.0), ((-2, -1), 8.0),
+ ((-1, 0), 10.0)
+ ]
+ ),
+ (
+ [((-5, -5), 1.0), ((-5, -4), 2.0), ((-4, -3), 3.0), ((-3, -2), 4.0), ((-2, -1), 5.0), ((-1, 0), 6.0)], [
+ ((-5, -5), 2.0), ((-5, -4), 4.0), ((-4, -3), 6.0), ((-3, -2), 8.0), ((-2, -1), 10.0),
+ ((-1, 0), 12.0)
+ ]
+ ),
+ (
+ [((-5, -5), 2.0), ((-5, -4), 3.0), ((-4, -3), 4.0), ((-3, -2), 5.0), ((-2, -1), 6.0), ((-1, 0), 7.0)], [
+ ((-5, -5), 4.0), ((-5, -4), 6.0), ((-4, -3), 8.0), ((-3, -2), 10.0), ((-2, -1), 12.0),
+ ((-1, 0), 14.0)
+ ]
+ ),
+ (
+ [((-5, -5), 3.0), ((-5, -4), 4.0), ((-4, -3), 5.0), ((-3, -2), 6.0), ((-2, -1), 7.0), ((-1, 0), 8.0)], [
+ ((-5, -5), 6.0), ((-5, -4), 8.0), ((-4, -3), 10.0), ((-3, -2), 12.0), ((-2, -1), 14.0),
+ ((-1, 0), 16.0)
+ ]
+ ),
+ (
+ [((-5, -5), 4.0), ((-5, -4), 5.0), ((-4, -3), 6.0), ((-3, -2), 7.0), ((-2, -1), 8.0), ((-1, 0), 9.0)], [
+ ((-5, -5), 8.0), ((-5, -4), 10.0), ((-4, -3), 12.0), ((-3, -2), 14.0), ((-2, -1), 16.0),
+ ((-1, 0), 18.0)
+ ]
+ ),
+ ]
+
+ for i in range(0, 10):
+ _tick_monitor.in_time = 1.0 * i
+ _tick_monitor.out_time = 2.0 * i
+ _time.step()
+
+ # print('(', _tick_monitor.range_values['in_time'](TimeMode.RELATIVE_CURRENT_TIME), ',',
+ # _tick_monitor.range_values['out_time'](TimeMode.RELATIVE_CURRENT_TIME), '),')
+
+ _thin = _tick_monitor.range_values['in_time'](TimeMode.RELATIVE_CURRENT_TIME)
+ _thout = _tick_monitor.range_values['out_time'](TimeMode.RELATIVE_CURRENT_TIME)
+ _exp_thin, _exp_thout = _assert_results[i]
+
+ assert (_thin, _thout) == (_exp_thin, _exp_thout)
diff --git a/DI-engine/ding/utils/autolog/tests/test_time.py b/DI-engine/ding/utils/autolog/tests/test_time.py
new file mode 100644
index 0000000000000000000000000000000000000000..506e40689d674c7b2863c7c8a7a95147be89ffd2
--- /dev/null
+++ b/DI-engine/ding/utils/autolog/tests/test_time.py
@@ -0,0 +1,131 @@
+import time
+from unittest.mock import Mock
+
+import pytest
+
+from ding.utils.autolog import TickTime, NaturalTime, TimeProxy
+
+
+class TestNaturalTime:
+
+ @pytest.mark.unittest
+ def test_natural_time(self):
+ _time = NaturalTime()
+ assert abs(_time.time() - time.time()) < 0.2
+
+ @pytest.mark.benchmark
+ def test_natural_time_for_100k_times(self):
+ for i in range(0, 100000):
+ _time = NaturalTime()
+ assert abs(_time.time() - time.time()) < 0.2
+
+ @pytest.mark.unittest
+ def test_natural_time_with_mad_system(self):
+ _time_func, time.time = time.time, Mock(side_effect=[1.5, 1.8, 2.0, 2.0, 1.75, 1.9, 2.2])
+
+ try:
+ _time = NaturalTime()
+ assert _time.time() == 1.5
+ assert _time.time() == 1.8
+ assert _time.time() == 2.0
+ assert _time.time() == 2.0
+ assert _time.time() == 2.0
+ assert _time.time() == 2.0
+ assert _time.time() == 2.2
+ finally:
+ time.time = _time_func
+
+
+class TestTickTime:
+
+ @pytest.mark.unittest
+ def test_tick_bare(self):
+ _time = TickTime()
+ assert _time.time() == 0
+ assert _time.step() == 1
+ assert _time.time() == 1
+ assert _time.step(2) == 3
+ assert _time.time() == 3
+
+ with pytest.raises(TypeError):
+ _time.step(0.9)
+
+ with pytest.raises(ValueError):
+ _time.step(0)
+
+ @pytest.mark.unittest
+ def test_tick_init(self):
+ _time = TickTime(3)
+ assert _time.time() == 3
+ assert _time.step() == 4
+ assert _time.time() == 4
+ assert _time.step(2) == 6
+ assert _time.time() == 6
+
+ with pytest.raises(TypeError):
+ _time.step(0.9)
+
+ with pytest.raises(ValueError):
+ _time.step(0)
+
+
+class TestTimeProxy:
+
+ @pytest.mark.unittest
+ def test_time_proxy_for_tick_time(self):
+ _time = TickTime()
+ _proxy = TimeProxy(_time)
+
+ assert _proxy.time() == 0
+ assert _proxy.current_time() == 0
+ assert not _proxy.is_frozen
+
+ _time.step()
+ assert _proxy.time() == 1
+ assert _proxy.current_time() == 1
+ assert not _proxy.is_frozen
+
+ _proxy.freeze()
+ _time.step(2)
+ assert _proxy.time() == 1
+ assert _proxy.current_time() == 3
+ assert _proxy.is_frozen
+
+ _time.step()
+ assert _proxy.time() == 1
+ assert _proxy.current_time() == 4
+ assert _proxy.is_frozen
+
+ _proxy.unfreeze()
+ assert _proxy.time() == 4
+ assert _proxy.current_time() == 4
+ assert not _proxy.is_frozen
+
+ @pytest.mark.unittest
+ def test_time_proxy_frozen_for_tick_time(self):
+ _time = TickTime()
+ _proxy = TimeProxy(_time, frozen=True)
+
+ assert _proxy.time() == 0
+ assert _proxy.current_time() == 0
+ assert _proxy.is_frozen
+
+ _time.step()
+ assert _proxy.time() == 0
+ assert _proxy.current_time() == 1
+ assert _proxy.is_frozen
+
+ _time.step(2)
+ assert _proxy.time() == 0
+ assert _proxy.current_time() == 3
+ assert _proxy.is_frozen
+
+ _time.step()
+ assert _proxy.time() == 0
+ assert _proxy.current_time() == 4
+ assert _proxy.is_frozen
+
+ _proxy.unfreeze()
+ assert _proxy.time() == 4
+ assert _proxy.current_time() == 4
+ assert not _proxy.is_frozen
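The mocked-clock test above exercises a guarantee that is easy to miss when skimming the diff: NaturalTime clamps backward jumps of the system clock so its readings never decrease. A minimal standalone sketch of the same property (no mocking, only repeated sampling) could look like this:

```python
# Sketch only: NaturalTime readings are monotonically non-decreasing, even if the
# underlying wall clock were to jump backwards (as in the mocked test above).
from ding.utils.autolog import NaturalTime

natural = NaturalTime()
samples = [natural.time() for _ in range(5)]
assert all(a <= b for a, b in zip(samples, samples[1:]))
```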
diff --git a/DI-engine/ding/utils/autolog/time_ctl.py b/DI-engine/ding/utils/autolog/time_ctl.py
new file mode 100644
index 0000000000000000000000000000000000000000..110753e4cff1f47b81ad3b0992a5eab0a379bc75
--- /dev/null
+++ b/DI-engine/ding/utils/autolog/time_ctl.py
@@ -0,0 +1,225 @@
+import time
+from abc import ABCMeta, abstractmethod
+from typing import Union
+
+from ..lock_helper import LockContext, LockContextType
+
+
+class BaseTime(metaclass=ABCMeta):
+ """
+ Overview:
+ Abstract time interface
+ Interfaces:
+ ``time``
+ """
+
+ @abstractmethod
+ def time(self) -> Union[int, float]:
+ """
+ Overview:
+ Get time information
+
+ Returns:
+ - time(:obj:`float, int`): time information
+ """
+ raise NotImplementedError
+
+
+class NaturalTime(BaseTime):
+ """
+ Overview:
+ Natural time object
+ Interfaces:
+ ``__init__``, ``time``
+ Example:
+ >>> from ding.utils.autolog.time_ctl import NaturalTime
+ >>> time_ = NaturalTime()
+ """
+
+ def __init__(self):
+ self.__last_time = None
+
+ def time(self) -> float:
+ """
+ Overview:
+ Get current natural time (float format, unix timestamp)
+
+ Returns:
+ - time(:obj:`float`): unix timestamp
+
+ Example:
+ >>> from ding.utils.autolog.time_ctl import NaturalTime
+ >>> time_ = NaturalTime()
+ >>> time_.time()
+ 1603896383.8811457
+ """
+ _current_time = time.time()
+ if self.__last_time is not None:
+ _current_time = max(_current_time, self.__last_time)
+
+ self.__last_time = _current_time
+ return _current_time
+
+
+class TickTime(BaseTime):
+ """
+ Overview:
+ Tick time object
+ Interfaces:
+ ``__init__``, ``step``, ``time``
+ Example:
+ >>> from ding.utils.autolog.time_ctl import TickTime
+ >>> time_ = TickTime()
+ """
+
+ def __init__(self, init: int = 0):
+ """
+ Overview:
+ Constructor of TickTime
+
+ Arguments:
+ - init (:obj:`int`): initial time, default is 0
+ """
+ self.__tick_time = init
+
+ def step(self, delta: int = 1) -> int:
+ """
+ Overview:
+ Step the time forward for this TickTime
+
+ Arguments:
+ - delta (:obj:`int`): steps to step forward, default is 1
+
+ Returns:
+ - time (:obj:`int`): new time after stepping
+
+ Example:
+ >>> from ding.utils.autolog.time_ctl import TickTime
+ >>> time_ = TickTime(0)
+ >>> time_.step()
+ 1
+ >>> time_.step(2)
+ 3
+ """
+ if not isinstance(delta, int):
+ raise TypeError("Delta should be positive int, but {actual} found.".format(actual=type(delta).__name__))
+ elif delta < 1:
+ raise ValueError("Delta should be no less than 1, but {actual} found.".format(actual=repr(delta)))
+ else:
+ self.__tick_time += delta
+ return self.__tick_time
+
+ def time(self) -> int:
+ """
+ Overview:
+ Get current tick time
+
+ Returns:
+ int: current tick time
+
+ Example:
+ >>> from ding.utils.autolog.time_ctl import TickTime
+ >>> time_ = TickTime(0)
+ >>> time_.step()
+ >>> time_.time()
+ 1
+ """
+ return self.__tick_time
+
+
+class TimeProxy(BaseTime):
+ """
+ Overview:
+ Proxy of a time object. It can freeze time, which is sometimes useful for reproducibility.
+ This object is thread-safe, and freeze and unfreeze operations are strictly ordered.
+ Interfaces:
+ ``__init__``, ``freeze``, ``unfreeze``, ``time``, ``current_time``
+ Example:
+ >>> from ding.utils.autolog.time_ctl import TickTime, TimeProxy
+ >>> tick_time_ = TickTime()
+ >>> time_ = TimeProxy(tick_time_)
+ >>> tick_time_.step()
+ >>> print(tick_time_.time(), time_.time(), time_.current_time())
+ 1 1 1
+ >>> time_.freeze()
+ >>> tick_time_.step()
+ >>> print(tick_time_.time(), time_.time(), time_.current_time())
+ 2 1 2
+ >>> time_.unfreeze()
+ >>> print(tick_time_.time(), time_.time(), time_.current_time())
+ 2 2 2
+ """
+
+ def __init__(self, time_: BaseTime, frozen: bool = False, lock_type: LockContextType = LockContextType.THREAD_LOCK):
+ """
+ Overview:
+ Constructor for Time proxy
+
+ Arguments:
+ - time_ (:obj:`BaseTime`): another time object it based on
+ - frozen (:obj:`bool`): this object will be frozen immediately if true, otherwise not, default is False
+ - lock_type (:obj:`LockContextType`): type of the lock, default is THREAD_LOCK
+ """
+ self.__time = time_
+ self.__current_time = self.__time.time()
+
+ self.__frozen = frozen
+ self.__lock = LockContext(lock_type)
+ self.__frozen_lock = LockContext(lock_type)
+ if self.__frozen:
+ self.__frozen_lock.acquire()
+
+ @property
+ def is_frozen(self) -> bool:
+ """
+ Overview:
+ Get if this time proxy object is frozen
+
+ Returns:
+ bool: true if it is frozen, otherwise false
+ """
+ with self.__lock:
+ return self.__frozen
+
+ def freeze(self):
+ """
+ Overview:
+ Freeze this time proxy
+ """
+ with self.__lock:
+ self.__frozen_lock.acquire()
+ self.__frozen = True
+ self.__current_time = self.__time.time()
+
+ def unfreeze(self):
+ """
+ Overview:
+ Unfreeze this time proxy
+ """
+ with self.__lock:
+ self.__frozen = False
+ self.__frozen_lock.release()
+
+ def time(self) -> Union[int, float]:
+ """
+ Overview:
+ Get time (may be frozen time)
+
+ Returns:
+ int or float: the time
+ """
+ with self.__lock:
+ if self.__frozen:
+ return self.__current_time
+ else:
+ return self.__time.time()
+
+ def current_time(self) -> Union[int, float]:
+ """
+ Overview:
+ Get current time (will not be frozen time)
+
+ Returns:
+ int or float: current time
+ """
+ return self.__time.time()
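To complement the TickTime-based examples in the docstrings above, here is a hedged, single-thread sketch of the same freeze/unfreeze semantics over NaturalTime: while frozen, ``time()`` keeps returning the snapshot taken at ``freeze()``, whereas ``current_time()`` follows the real clock.

```python
# Minimal single-thread sketch of TimeProxy freezing over a wall-clock time source.
import time
from ding.utils.autolog.time_ctl import NaturalTime, TimeProxy

proxy = TimeProxy(NaturalTime())
proxy.freeze()
t0 = proxy.time()
time.sleep(0.05)
assert proxy.time() == t0          # frozen reading does not drift
assert proxy.current_time() >= t0  # the underlying clock keeps moving
proxy.unfreeze()
assert proxy.time() > t0           # unfrozen readings follow the wall clock again
```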
diff --git a/DI-engine/ding/utils/autolog/value.py b/DI-engine/ding/utils/autolog/value.py
new file mode 100644
index 0000000000000000000000000000000000000000..98510a036ad75b2ed171c71013a4803a93615fce
--- /dev/null
+++ b/DI-engine/ding/utils/autolog/value.py
@@ -0,0 +1,77 @@
+from typing import Type
+
+from .base import _LOGGED_VALUE__PROPERTY_NAME, _LOGGED_MODEL__PROPERTY_ATTR_PREFIX, _ValueType
+from .data import TimeRangedData
+
+
+class LoggedValue:
+ """
+ Overview:
+ LoggedValue can be used as a property in LoggedModel, since it implements the __get__ and __set__ methods.
+ Instances of this class are associated with their owner LoggedModel instance; all the LoggedValue
+ properties of one LoggedModel share a single time object (defined in time_ctl), so that the timeline can
+ be managed consistently.
+ Interfaces:
+ ``__init__``, ``__get__``, ``__set__``
+ Properties:
+ - __property_name (:obj:`str`): The name of the property.
+ """
+
+ def __init__(self, type_: Type[_ValueType] = object):
+ """
+ Overview:
+ Initialize the LoggedValue object.
+ Interfaces:
+ ``__init__``
+ """
+
+ self.__type = type_
+
+ @property
+ def __property_name(self):
+ """
+ Overview:
+ Get the name of the property.
+ """
+
+ return getattr(self, _LOGGED_VALUE__PROPERTY_NAME)
+
+ def __get_ranged_data(self, instance) -> TimeRangedData:
+ """
+ Overview:
+ Get the ranged data.
+ Interfaces:
+ ``__get_ranged_data``
+ """
+
+ return getattr(instance, _LOGGED_MODEL__PROPERTY_ATTR_PREFIX + self.__property_name)
+
+ def __get__(self, instance, owner):
+ """
+ Overview:
+ Get the value.
+ Arguments:
+ - instance (:obj:`LoggedModel`): The owner LoggedModel instance.
+ - owner (:obj:`type`): The owner class.
+ """
+
+ return self.__get_ranged_data(instance).current()
+
+ def __set__(self, instance, value: _ValueType):
+ """
+ Overview:
+ Set the value.
+ Arguments:
+ - instance (:obj:`LoggedModel`): The owner LoggedModel instance.
+ - value (:obj:`_ValueType`): The value to set.
+ """
+
+ if isinstance(value, self.__type):
+ return self.__get_ranged_data(instance).append(value)
+ else:
+ raise TypeError(
+ 'New value should be {expect}, but {actual} found.'.format(
+ expect=self.__type.__name__,
+ actual=type(value).__name__,
+ )
+ )
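The tests earlier in this diff build their monitor class through a helper that is not shown here. As a rough, hypothetical sketch of the intended usage (the ``TickMonitor`` name and its two properties are illustrative, and it assumes the ``LoggedModel`` base class exported from ``ding.utils.autolog`` wires the descriptors as described in the docstring above):

```python
# Hypothetical sketch: LoggedValue descriptors on a LoggedModel subclass.
# TickMonitor, in_time and out_time are illustrative names, not part of this diff.
from ding.utils.autolog import LoggedValue, LoggedModel, TickTime

class TickMonitor(LoggedModel):
    in_time = LoggedValue(float)    # type-checked on assignment via __set__
    out_time = LoggedValue(float)

clock = TickTime()
monitor = TickMonitor(clock, expire=5)
monitor.in_time = 1.0               # appended to the underlying TimeRangedData
monitor.out_time = 2.0
clock.step()
print(monitor.in_time, monitor.out_time)   # reads back the most recent values
```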
diff --git a/DI-engine/ding/utils/bfs_helper.py b/DI-engine/ding/utils/bfs_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..948bba5cf9a435560123a6e0174bf915a8f96ce9
--- /dev/null
+++ b/DI-engine/ding/utils/bfs_helper.py
@@ -0,0 +1,70 @@
+import numpy as np
+import torch
+from gym import Env
+from typing import Tuple, List
+
+
+def get_vi_sequence(env: Env, observation: np.ndarray) -> Tuple[np.ndarray, List]:
+ """
+ Overview:
+ Given an instance of the maze environment and the current observation, use the Breadth-First-Search (BFS) \
+ algorithm to plan an optimal path and record the result.
+ Arguments:
+ - env (:obj:`Env`): The instance of the maze environment.
+ - observation (:obj:`np.ndarray`): The current observation.
+ Returns:
+ - output (:obj:`Tuple[np.ndarray, List]`): The BFS result. ``output[0]`` contains the BFS map after each \
+ iteration and ``output[1]`` contains the optimal actions before reaching the finishing point.
+ """
+ xy = np.where(observation[Ellipsis, -1] == 1)
+ start_x, start_y = xy[0][0], xy[1][0]
+ target_location = env.target_location
+ nav_map = env.nav_map
+ current_points = [target_location]
+ chosen_actions = {target_location: 0}
+ visited_points = {target_location: True}
+ vi_sequence = []
+
+ vi_map = np.full((env.size, env.size), fill_value=env.n_action, dtype=np.int32)
+
+ found_start = False
+ while current_points and not found_start:
+ next_points = []
+ for point_x, point_y in current_points:
+ for (action, (next_point_x, next_point_y)) in [(0, (point_x - 1, point_y)), (1, (point_x, point_y - 1)),
+ (2, (point_x + 1, point_y)), (3, (point_x, point_y + 1))]:
+
+ if (next_point_x, next_point_y) in visited_points:
+ continue
+
+ if not (0 <= next_point_x < len(nav_map) and 0 <= next_point_y < len(nav_map[next_point_x])):
+ continue
+
+ if nav_map[next_point_x][next_point_y] == 'x':
+ continue
+
+ next_points.append((next_point_x, next_point_y))
+ visited_points[(next_point_x, next_point_y)] = True
+ chosen_actions[(next_point_x, next_point_y)] = action
+ vi_map[next_point_x, next_point_y] = action
+
+ if next_point_x == start_x and next_point_y == start_y:
+ found_start = True
+ vi_sequence.append(vi_map.copy())
+ current_points = next_points
+ track_back = []
+ if found_start:
+ cur_x, cur_y = start_x, start_y
+ while cur_x != target_location[0] or cur_y != target_location[1]:
+ act = vi_sequence[-1][cur_x, cur_y]
+ track_back.append((torch.FloatTensor(env.process_states([cur_x, cur_y], env.get_maze_map())), act))
+ if act == 0:
+ cur_x += 1
+ elif act == 1:
+ cur_y += 1
+ elif act == 2:
+ cur_x -= 1
+ elif act == 3:
+ cur_y -= 1
+
+ return np.array(vi_sequence), track_back
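Since ``get_vi_sequence`` only reads a handful of attributes from the environment, it can be exercised without the full maze env. The toy stand-in below is illustrative only; its layout, ``process_states`` and ``get_maze_map`` are minimal placeholders for what the real maze environment provides.

```python
# Illustrative toy environment exposing only the attributes get_vi_sequence touches.
import numpy as np
from ding.utils.bfs_helper import get_vi_sequence

class ToyMaze:
    size = 4
    n_action = 4
    nav_map = ["....", ".x..", "....", "...."]   # 'x' marks a wall
    target_location = (3, 3)

    def get_maze_map(self):
        return self.nav_map

    def process_states(self, xy, maze_map):
        return np.array(xy, dtype=np.float32)    # placeholder state encoding

env = ToyMaze()
obs = np.zeros((env.size, env.size, 2), dtype=np.float32)
obs[0, 0, -1] = 1.0                              # last channel marks the start cell
vi_maps, track_back = get_vi_sequence(env, obs)
print(vi_maps.shape)                             # (num_bfs_iterations, 4, 4)
print([int(act) for _, act in track_back])       # greedy actions from start to goal
```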
diff --git a/DI-engine/ding/utils/collection_helper.py b/DI-engine/ding/utils/collection_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c8caed6b4d146ba6f9c640c7d9c2c3db5ea9a60
--- /dev/null
+++ b/DI-engine/ding/utils/collection_helper.py
@@ -0,0 +1,23 @@
+from typing import Iterable, TypeVar, Callable
+
+_IterType = TypeVar('_IterType')
+_IterTargetType = TypeVar('_IterTargetType')
+
+
+def iter_mapping(iter_: Iterable[_IterType], mapping: Callable[[_IterType], _IterTargetType]):
+ """
+ Overview:
+ Lazily map each element of the input iterable with the given callable.
+ Arguments:
+ - iter_ (:obj:`Iterable[_IterType]`): The iterable to be mapped
+ - mapping (:obj:`Callable[[_IterType], _IterTargetType]`): The callable applied to each iterated element.
+ Return:
+ - (:obj:`iter_mapping object`): Iteration results
+ Example:
+ >>> iterable_list = [1, 2, 3, 4, 5]
+ >>> _iter = iter_mapping(iterable_list, lambda x: x ** 2)
+ >>> print(list(_iter))
+ [1, 4, 9, 16, 25]
+ """
+ for item in iter_:
+ yield mapping(item)
diff --git a/DI-engine/ding/utils/compression_helper.py b/DI-engine/ding/utils/compression_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..71eeef25b07ffcadf679ce5d5e61935df0bf8e83
--- /dev/null
+++ b/DI-engine/ding/utils/compression_helper.py
@@ -0,0 +1,240 @@
+from typing import Any, ByteString, Callable
+import pickle
+import cloudpickle
+import zlib
+import numpy as np
+
+
+class CloudPickleWrapper:
+ """
+ Overview:
+ CloudPickleWrapper makes more Python objects picklable (e.g. an object containing a lambda expression).
+ Interfaces:
+ ``__init__``, ``__getstate__``, ``__setstate__``.
+ """
+
+ def __init__(self, data: Any) -> None:
+ """
+ Overview:
+ Initialize the CloudPickleWrapper using the given arguments.
+ Arguments:
+ - data (:obj:`Any`): The object to be dumped.
+ """
+ self.data = data
+
+ def __getstate__(self) -> bytes:
+ """
+ Overview:
+ Get the state of the CloudPickleWrapper.
+ Returns:
+ - data (:obj:`bytes`): The dumped byte-like result.
+ """
+
+ return cloudpickle.dumps(self.data)
+
+ def __setstate__(self, data: bytes) -> None:
+ """
+ Overview:
+ Set the state of the CloudPickleWrapper.
+ Arguments:
+ - data (:obj:`bytes`): The dumped byte-like result.
+ """
+
+ if isinstance(data, (tuple, list, np.ndarray)): # pickle is faster
+ self.data = pickle.loads(data)
+ else:
+ self.data = cloudpickle.loads(data)
+
+
+def dummy_compressor(data: Any) -> Any:
+ """
+ Overview:
+ Return the raw input data.
+ Arguments:
+ - data (:obj:`Any`): The input data of the compressor.
+ Returns:
+ - output (:obj:`Any`): This compressor will exactly return the input data.
+ """
+ return data
+
+
+def zlib_data_compressor(data: Any) -> bytes:
+ """
+ Overview:
+ Take the input data and return the compressed data (zlib compressor) in binary format.
+ Arguments:
+ - data (:obj:`Any`): The input data of the compressor.
+ Returns:
+ - output (:obj:`bytes`): The compressed byte-like result.
+ Examples:
+ >>> zlib_data_compressor("Hello")
+ b'x\x9ck`\x99\xca\xc9\x00\x01=\xac\x1e\xa999\xf9S\xf4\x00%L\x04j'
+ """
+ return zlib.compress(pickle.dumps(data))
+
+
+def lz4_data_compressor(data: Any) -> bytes:
+ """
+ Overview:
+ Return the compressed data (lz4 compressor). The compressor outputs in binary format.
+ Arguments:
+ - data (:obj:`Any`): The input data of the compressor.
+ Returns:
+ - output (:obj:`bytes`): The compressed byte-like result.
+ Examples:
+ >>> lz4.block.compress(pickle.dumps("Hello"))
+ b'\x14\x00\x00\x00R\x80\x04\x95\t\x00\x01\x00\x90\x8c\x05Hello\x94.'
+ """
+ try:
+ import lz4.block
+ except ImportError:
+ from ditk import logging
+ import sys
+ logging.warning("Please install lz4 first, such as `pip3 install lz4`")
+ sys.exit(1)
+ return lz4.block.compress(pickle.dumps(data))
+
+
+def jpeg_data_compressor(data: np.ndarray) -> bytes:
+ """
+ Overview:
+ To reduce memory usage, we can choose to store the JPEG string of an image instead of the numpy array in \
+ the buffer. This function encodes the observation numpy array into a JPEG string.
+ Arguments:
+ - data (:obj:`np.ndarray`): The observation numpy array.
+ Returns:
+ - img_str (:obj:`bytes`): The compressed byte-like result.
+ """
+ try:
+ import cv2
+ except ImportError:
+ from ditk import logging
+ import sys
+ logging.warning("Please install opencv-python first.")
+ sys.exit(1)
+ img_str = cv2.imencode('.jpg', data)[1].tobytes()
+
+ return img_str
+
+
+_COMPRESSORS_MAP = {
+ 'lz4': lz4_data_compressor,
+ 'zlib': zlib_data_compressor,
+ 'jpeg': jpeg_data_compressor,
+ 'none': dummy_compressor,
+}
+
+
+def get_data_compressor(name: str):
+ """
+ Overview:
+ Get the data compressor according to the input name.
+ Arguments:
+ - name(:obj:`str`): Name of the compressor, support ``['lz4', 'zlib', 'jpeg', 'none']``
+ Return:
+ - compressor (:obj:`Callable`): Corresponding data_compressor, taking input data returning compressed data.
+ Example:
+ >>> compress_fn = get_data_compressor('lz4')
+ >>> compressed_data = compress_fn(input_data)
+ """
+ return _COMPRESSORS_MAP[name]
+
+
+def dummy_decompressor(data: Any) -> Any:
+ """
+ Overview:
+ Return the input data.
+ Arguments:
+ - data (:obj:`Any`): The input data of the decompressor.
+ Returns:
+ - output (:obj:`Any`): The decompressed result, which is exactly the input.
+ """
+ return data
+
+
+def lz4_data_decompressor(compressed_data: bytes) -> Any:
+ """
+ Overview:
+ Return the decompressed original data (lz4 decompressor).
+ Arguments:
+ - compressed_data (:obj:`bytes`): The input data of the decompressor.
+ Returns:
+ - output (:obj:`Any`): The decompressed object.
+ """
+ try:
+ import lz4.block
+ except ImportError:
+ from ditk import logging
+ import sys
+ logging.warning("Please install lz4 first, such as `pip3 install lz4`")
+ sys.exit(1)
+ return pickle.loads(lz4.block.decompress(compressed_data))
+
+
+def zlib_data_decompressor(compressed_data: bytes) -> Any:
+ """
+ Overview:
+ Return the decompressed original data (zlib decompressor).
+ Arguments:
+ - compressed_data (:obj:`bytes`): The input data of the decompressor.
+ Returns:
+ - output (:obj:`Any`): The decompressed object.
+ """
+ return pickle.loads(zlib.decompress(compressed_data))
+
+
+def jpeg_data_decompressor(compressed_data: bytes, gray_scale=False) -> np.ndarray:
+ """
+ Overview:
+ To reduce memory usage, we can choose to store the JPEG string of an image instead of the numpy array in the \
+ buffer. This function decodes the observation numpy array from the JPEG string.
+ Arguments:
+ - compressed_data (:obj:`bytes`): The jpeg strings.
+ - gray_scale (:obj:`bool`): If the observation is gray, ``gray_scale=True``,
+ if the observation is RGB, ``gray_scale=False``.
+ Returns:
+ - arr (:obj:`np.ndarray`): The decompressed numpy array.
+ """
+ try:
+ import cv2
+ except ImportError:
+ from ditk import logging
+ import sys
+ logging.warning("Please install opencv-python first.")
+ sys.exit(1)
+ nparr = np.frombuffer(compressed_data, np.uint8)
+ if gray_scale:
+ arr = cv2.imdecode(nparr, cv2.IMREAD_GRAYSCALE)
+ arr = np.expand_dims(arr, -1)
+ else:
+ arr = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
+
+ return arr
+
+
+_DECOMPRESSORS_MAP = {
+ 'lz4': lz4_data_decompressor,
+ 'zlib': zlib_data_decompressor,
+ 'jpeg': jpeg_data_decompressor,
+ 'none': dummy_decompressor,
+}
+
+
+def get_data_decompressor(name: str) -> Callable:
+ """
+ Overview:
+ Get the data decompressor according to the input name.
+ Arguments:
+ - name(:obj:`str`): Name of the decompressor, support ``['lz4', 'zlib', 'jpeg', 'none']``
+
+ .. note::
+
+ For all the decompressors, the input of a bytes-like object is required.
+
+ Returns:
+ - decompressor (:obj:`Callable`): Corresponding data decompressor.
+ Examples:
+ >>> decompress_fn = get_data_decompressor('lz4')
+ >>> origin_data = decompress_fn(compressed_data)
+ """
+ return _DECOMPRESSORS_MAP[name]
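A quick round trip through the registry with the dependency-free ``zlib`` codec (a sketch; the sample payload below is arbitrary):

```python
# Sketch: compress/decompress round trip using the zlib entry of the registry.
import numpy as np
from ding.utils.compression_helper import get_data_compressor, get_data_decompressor

sample = {'obs': np.zeros((4, 84, 84), dtype=np.uint8), 'reward': 1.0}
compress_fn = get_data_compressor('zlib')       # pickle.dumps + zlib.compress
decompress_fn = get_data_decompressor('zlib')   # zlib.decompress + pickle.loads

packed = compress_fn(sample)                    # bytes
restored = decompress_fn(packed)
assert np.array_equal(restored['obs'], sample['obs']) and restored['reward'] == 1.0
```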
diff --git a/DI-engine/ding/utils/data/__init__.py b/DI-engine/ding/utils/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6544dd8f67bbea4b748126ed3301f5497fa96c96
--- /dev/null
+++ b/DI-engine/ding/utils/data/__init__.py
@@ -0,0 +1,4 @@
+from .collate_fn import diff_shape_collate, default_collate, default_decollate, timestep_collate, ttorch_collate
+from .dataloader import AsyncDataLoader
+from .dataset import NaiveRLDataset, D4RLDataset, HDF5Dataset, BCODataset, \
+ create_dataset, hdf5_save, offline_data_save_type
diff --git a/DI-engine/ding/utils/data/base_dataloader.py b/DI-engine/ding/utils/data/base_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..d19bd9fcde4bab6b285be350e05776c9e272fd3d
--- /dev/null
+++ b/DI-engine/ding/utils/data/base_dataloader.py
@@ -0,0 +1,66 @@
+from typing import Optional, Callable, List, Any, Iterable
+import torch
+
+
+def example_get_data_fn() -> Any:
+ """
+ Overview:
+ Get data from file or other middleware
+ .. note::
+ staticmethod or static function, all the operation is on CPU
+ """
+ # 1. read data from file or other middleware
+ # 2. data post-processing(e.g.: normalization, to tensor)
+ # 3. return data
+ pass
+
+
+class IDataLoader:
+ """
+ Overview:
+ Base class of data loader
+ Interfaces:
+ ``__init__``, ``__next__``, ``__iter__``, ``_get_data``, ``close``
+ """
+
+ def __next__(self, batch_size: Optional[int] = None) -> torch.Tensor:
+ """
+ Overview:
+ Get one batch data
+ Arguments:
+ - batch_size (:obj:`Optional[int]`): The batch size can be specified per iteration; \
+ if batch_size is None, the default batch size value is used
+ """
+ # get one batch train data
+ if batch_size is None:
+ batch_size = self._batch_size
+ data = self._get_data(batch_size)
+ return self._collate_fn(data)
+
+ def __iter__(self) -> Iterable:
+ """
+ Overview:
+ Get data iterator
+ """
+
+ return self
+
+ def _get_data(self, batch_size: Optional[int] = None) -> List[torch.Tensor]:
+ """
+ Overview:
+ Get one batch data
+ Arguments:
+ - batch_size (:obj:`Optional[int]`): The batch size can be specified per iteration; \
+ if batch_size is None, the default batch size value is used
+ """
+
+ raise NotImplementedError
+
+ def close(self) -> None:
+ """
+ Overview:
+ Close data loader
+ """
+
+ # release resource
+ pass
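``IDataLoader.__next__`` only relies on ``self._batch_size``, ``self._collate_fn`` and ``_get_data``, so a toy subclass can be wired up as follows (a sketch; the ``InMemoryDataLoader`` name and its fields are illustrative, not part of this diff):

```python
# Illustrative minimal IDataLoader subclass backed by an in-memory list.
import torch
from ding.utils.data.base_dataloader import IDataLoader
from ding.utils.data.collate_fn import default_collate

class InMemoryDataLoader(IDataLoader):

    def __init__(self, data, batch_size):
        self._data = data
        self._batch_size = batch_size
        self._collate_fn = default_collate
        self._index = 0

    def _get_data(self, batch_size=None):
        batch_size = batch_size or self._batch_size
        start = self._index
        self._index = (self._index + batch_size) % len(self._data)
        return [self._data[(start + i) % len(self._data)] for i in range(batch_size)]

loader = InMemoryDataLoader([torch.randn(3) for _ in range(8)], batch_size=4)
batch = next(loader)        # default_collate stacks 4 tensors into shape (4, 3)
```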
diff --git a/DI-engine/ding/utils/data/collate_fn.py b/DI-engine/ding/utils/data/collate_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..5397a9c4508d0af043bd06f3917fca6d3229de3f
--- /dev/null
+++ b/DI-engine/ding/utils/data/collate_fn.py
@@ -0,0 +1,344 @@
+from collections.abc import Sequence, Mapping
+from typing import List, Dict, Union, Any
+
+import torch
+import treetensor.torch as ttorch
+import re
+import collections.abc as container_abcs
+from ding.compatibility import torch_ge_131
+
+int_classes = int
+string_classes = (str, bytes)
+np_str_obj_array_pattern = re.compile(r'[SaUO]')
+
+default_collate_err_msg_format = (
+ "default_collate: batch must contain tensors, numpy arrays, numbers, "
+ "dicts or lists; found {}"
+)
+
+
+def ttorch_collate(x, json: bool = False, cat_1dim: bool = True):
+ """
+ Overview:
+ Collates a list of tensors or nested dictionaries of tensors into a single tensor or nested \
+ dictionary of tensors.
+
+ Arguments:
+ - x : The input list of tensors or nested dictionaries of tensors.
+ - json (:obj:`bool`): If True, converts the output to JSON format. Defaults to False.
+ - cat_1dim (:obj:`bool`): If True, concatenates tensors with shape (B, 1) along the last dimension. \
+ Defaults to True.
+
+ Returns:
+ The collated output tensor or nested dictionary of tensors.
+
+ Examples:
+ >>> # case 1: Collate a list of tensors
+ >>> tensors = [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6]), torch.tensor([7, 8, 9])]
+ >>> collated = ttorch_collate(tensors)
+ collated = torch.tensor([[1, 2, 3], [4, 5, 6], [7, 8, 9]])
+ >>> # case 2: Collate a nested dictionary of tensors
+ >>> nested_dict = {
+ 'a': torch.tensor([1, 2, 3]),
+ 'b': torch.tensor([4, 5, 6]),
+ 'c': torch.tensor([7, 8, 9])
+ }
+ >>> collated = ttorch_collate(nested_dict)
+ collated = {
+ 'a': torch.tensor([1, 2, 3]),
+ 'b': torch.tensor([4, 5, 6]),
+ 'c': torch.tensor([7, 8, 9])
+ }
+ >>> # case 3: Collate a list of nested dictionaries of tensors
+ >>> nested_dicts = [
+ {'a': torch.tensor([1, 2, 3]), 'b': torch.tensor([4, 5, 6])},
+ {'a': torch.tensor([7, 8, 9]), 'b': torch.tensor([10, 11, 12])}
+ ]
+ >>> collated = ttorch_collate(nested_dicts)
+ collated = {
+ 'a': torch.tensor([[1, 2, 3], [7, 8, 9]]),
+ 'b': torch.tensor([[4, 5, 6], [10, 11, 12]])
+ }
+ """
+
+ def inplace_fn(t):
+ for k in t.keys():
+ if isinstance(t[k], torch.Tensor):
+ if len(t[k].shape) == 2 and t[k].shape[1] == 1: # reshape (B, 1) -> (B)
+ t[k] = t[k].squeeze(-1)
+ else:
+ inplace_fn(t[k])
+
+ x = ttorch.stack(x)
+ if cat_1dim:
+ inplace_fn(x)
+ if json:
+ x = x.json()
+ return x
+
+
+def default_collate(batch: Sequence,
+ cat_1dim: bool = True,
+ ignore_prefix: list = ['collate_ignore']) -> Union[torch.Tensor, Mapping, Sequence]:
+ """
+ Overview:
+ Put each data field into a tensor with outer dimension batch size.
+
+ Arguments:
+ - batch (:obj:`Sequence`): A data sequence, whose length is batch size, whose element is one piece of data.
+ - cat_1dim (:obj:`bool`): Whether to concatenate tensors with shape (B, 1) to (B), defaults to True.
+ - ignore_prefix (:obj:`list`): A list of prefixes to ignore when collating dictionaries, \
+ defaults to ['collate_ignore'].
+
+ Returns:
+ - ret (:obj:`Union[torch.Tensor, Mapping, Sequence]`): the collated data, with batch size into each data \
+ field. The return dtype depends on the original element dtype, can be [torch.Tensor, Mapping, Sequence].
+
+ Example:
+ >>> # a list with B tensors shaped (m, n) -->> a tensor shaped (B, m, n)
+ >>> a = [torch.zeros(2,3) for _ in range(4)]
+ >>> default_collate(a).shape
+ torch.Size([4, 2, 3])
+ >>>
+ >>> # a list with B lists, each list contains m elements -->> a list of m tensors, each with shape (B, )
+ >>> a = [[0 for __ in range(3)] for _ in range(4)]
+ >>> default_collate(a)
+ [tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0])]
+ >>>
+ >>> # a list with B dicts, whose values are tensors shaped :math:`(m, n)` -->>
+ >>> # a dict whose values are tensors with shape :math:`(B, m, n)`
+ >>> a = [{i: torch.zeros(i,i+1) for i in range(2, 4)} for _ in range(4)]
+ >>> print(a[0][2].shape, a[0][3].shape)
+ torch.Size([2, 3]) torch.Size([3, 4])
+ >>> b = default_collate(a)
+ >>> print(b[2].shape, b[3].shape)
+ torch.Size([4, 2, 3]) torch.Size([4, 3, 4])
+ """
+
+ if isinstance(batch, ttorch.Tensor):
+ return batch.json()
+
+ elem = batch[0]
+ elem_type = type(elem)
+ if isinstance(elem, torch.Tensor):
+ out = None
+ if torch_ge_131() and torch.utils.data.get_worker_info() is not None:
+ # If we're in a background process, directly concatenate into a
+ # shared memory tensor to avoid an extra copy
+ numel = sum([x.numel() for x in batch])
+ storage = elem.storage()._new_shared(numel)
+ out = elem.new(storage)
+ if elem.shape == (1, ) and cat_1dim:
+ # reshape (B, 1) -> (B)
+ return torch.cat(batch, 0, out=out)
+ # return torch.stack(batch, 0, out=out)
+ else:
+ return torch.stack(batch, 0, out=out)
+ elif isinstance(elem, ttorch.Tensor):
+ return ttorch_collate(batch, json=True, cat_1dim=cat_1dim)
+ elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
+ and elem_type.__name__ != 'string_':
+ if elem_type.__name__ == 'ndarray':
+ # array of string classes and object
+ if np_str_obj_array_pattern.search(elem.dtype.str) is not None:
+ raise TypeError(default_collate_err_msg_format.format(elem.dtype))
+ return default_collate([torch.as_tensor(b) for b in batch], cat_1dim=cat_1dim)
+ elif elem.shape == (): # scalars
+ return torch.as_tensor(batch)
+ elif isinstance(elem, float):
+ return torch.tensor(batch, dtype=torch.float32)
+ elif isinstance(elem, int_classes):
+ dtype = torch.bool if isinstance(elem, bool) else torch.int64
+ return torch.tensor(batch, dtype=dtype)
+ elif isinstance(elem, string_classes):
+ return batch
+ elif isinstance(elem, container_abcs.Mapping):
+ ret = {}
+ for key in elem:
+ if any([key.startswith(t) for t in ignore_prefix]):
+ ret[key] = [d[key] for d in batch]
+ else:
+ ret[key] = default_collate([d[key] for d in batch], cat_1dim=cat_1dim)
+ return ret
+ elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
+ return elem_type(*(default_collate(samples, cat_1dim=cat_1dim) for samples in zip(*batch)))
+ elif isinstance(elem, container_abcs.Sequence):
+ transposed = zip(*batch)
+ return [default_collate(samples, cat_1dim=cat_1dim) for samples in transposed]
+
+ raise TypeError(default_collate_err_msg_format.format(elem_type))
+
+
+def timestep_collate(batch: List[Dict[str, Any]]) -> Dict[str, Union[torch.Tensor, list]]:
+ """
+ Overview:
+ Collates a batch of timestepped data fields into tensors with the outer dimension being the batch size. \
+ Each timestepped data field is represented as a tensor with shape [T, B, any_dims], where T is the length \
+ of the sequence, B is the batch size, and any_dims represents the shape of the tensor at each timestep.
+
+ Arguments:
+ - batch(:obj:`List[Dict[str, Any]]`): A list of dictionaries with length B, where each dictionary represents \
+ a timestepped data field. Each dictionary contains a key-value pair, where the key is the name of the \
+ data field and the value is a sequence of torch.Tensor objects with any shape.
+
+ Returns:
+ - ret(:obj:`Dict[str, Union[torch.Tensor, list]]`): The collated data, with the timestep and batch size \
+ incorporated into each data field. The shape of each data field is [T, B, dim1, dim2, ...].
+
+ Examples:
+ >>> batch = [
+ {'data0': [torch.tensor([1, 2, 3]), torch.tensor([4, 5, 6])]},
+ {'data1': [torch.tensor([7, 8, 9]), torch.tensor([10, 11, 12])]}
+ ]
+ >>> collated_data = timestep_collate(batch)
+ >>> print(collated_data['data'].shape)
+ torch.Size([2, 2, 3])
+ """
+
+ def stack(data):
+ if isinstance(data, container_abcs.Mapping):
+ return {k: stack(data[k]) for k in data}
+ elif isinstance(data, container_abcs.Sequence) and isinstance(data[0], torch.Tensor):
+ return torch.stack(data)
+ else:
+ return data
+
+ elem = batch[0]
+ assert isinstance(elem, (container_abcs.Mapping, list)), type(elem)
+ if isinstance(batch[0], list): # new pipeline + treetensor
+ prev_state = [[b[i].get('prev_state') for b in batch] for i in range(len(batch[0]))]
+ batch_data = ttorch.stack([ttorch_collate(b) for b in batch]) # (B, T, *)
+ del batch_data.prev_state
+ batch_data = batch_data.transpose(1, 0)
+ batch_data.prev_state = prev_state
+ else:
+ prev_state = [b.pop('prev_state') for b in batch]
+ batch_data = default_collate(batch) # -> {some_key: T lists}, each list is [B, some_dim]
+ batch_data = stack(batch_data) # -> {some_key: [T, B, some_dim]}
+ transformed_prev_state = list(zip(*prev_state))
+ batch_data['prev_state'] = transformed_prev_state
+ # Append prev_state back, to avoid the bug where multiple batches share the same data
+ for i in range(len(batch)):
+ batch[i]['prev_state'] = prev_state[i]
+ return batch_data
+
+
+def diff_shape_collate(batch: Sequence) -> Union[torch.Tensor, Mapping, Sequence]:
+ """
+ Overview:
+ Collates a batch of data with different shapes.
+ This function is similar to `default_collate`, but it allows tensors in the batch to have `None` values, \
+ which is common in StarCraft observations.
+
+ Arguments:
+ - batch (:obj:`Sequence`): A sequence of data, where each element is a piece of data.
+
+ Returns:
+ - ret (:obj:`Union[torch.Tensor, Mapping, Sequence]`): The collated data, with the batch size applied \
+ to each data field. The return type depends on the original element type and can be a torch.Tensor, \
+ Mapping, or Sequence.
+
+ Examples:
+ >>> # a list with B tensors shaped (m, n) -->> a tensor shaped (B, m, n)
+ >>> a = [torch.zeros(2,3) for _ in range(4)]
+ >>> diff_shape_collate(a).shape
+ torch.Size([4, 2, 3])
+ >>>
+ >>> # a list with B lists, each list contains m elements -->> a list of m tensors, each with shape (B, )
+ >>> a = [[0 for __ in range(3)] for _ in range(4)]
+ >>> diff_shape_collate(a)
+ [tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0]), tensor([0, 0, 0, 0])]
+ >>>
+ >>> # a list with B dicts, whose values are tensors shaped :math:`(m, n)` -->>
+ >>> # a dict whose values are tensors with shape :math:`(B, m, n)`
+ >>> a = [{i: torch.zeros(i,i+1) for i in range(2, 4)} for _ in range(4)]
+ >>> print(a[0][2].shape, a[0][3].shape)
+ torch.Size([2, 3]) torch.Size([3, 4])
+ >>> b = diff_shape_collate(a)
+ >>> print(b[2].shape, b[3].shape)
+ torch.Size([4, 2, 3]) torch.Size([4, 3, 4])
+ """
+ elem = batch[0]
+ elem_type = type(elem)
+ if any([isinstance(elem, type(None)) for elem in batch]):
+ return batch
+ elif isinstance(elem, torch.Tensor):
+ shapes = [e.shape for e in batch]
+ if len(set(shapes)) != 1:
+ return batch
+ else:
+ return torch.stack(batch, 0)
+ elif elem_type.__module__ == 'numpy' and elem_type.__name__ != 'str_' \
+ and elem_type.__name__ != 'string_':
+ if elem_type.__name__ == 'ndarray':
+ return diff_shape_collate([torch.as_tensor(b) for b in batch]) # todo
+ elif elem.shape == (): # scalars
+ return torch.as_tensor(batch)
+ elif isinstance(elem, float):
+ return torch.tensor(batch, dtype=torch.float32)
+ elif isinstance(elem, int_classes):
+ dtype = torch.bool if isinstance(elem, bool) else torch.int64
+ return torch.tensor(batch, dtype=dtype)
+ elif isinstance(elem, Mapping):
+ return {key: diff_shape_collate([d[key] for d in batch]) for key in elem}
+ elif isinstance(elem, tuple) and hasattr(elem, '_fields'): # namedtuple
+ return elem_type(*(diff_shape_collate(samples) for samples in zip(*batch)))
+ elif isinstance(elem, Sequence):
+ transposed = zip(*batch)
+ return [diff_shape_collate(samples) for samples in transposed]
+
+ raise TypeError('not support element type: {}'.format(elem_type))
+
+
+def default_decollate(
+ batch: Union[torch.Tensor, Sequence, Mapping],
+ ignore: List[str] = ['prev_state', 'prev_actor_state', 'prev_critic_state']
+) -> List[Any]:
+ """
+ Overview:
+ Split collated data back along the batch dimension into individual samples, which is the reverse \
+ operation of ``default_collate``.
+
+ Arguments:
+ - batch (:obj:`Union[torch.Tensor, Sequence, Mapping]`): The collated data batch. It can be a tensor, \
+ sequence, or mapping.
+ - ignore(:obj:`List[str]`): A list of names to be ignored. Only applicable if the input ``batch`` is a \
+ dictionary. If a key is in this list, its value will remain the same without decollation. Defaults to \
+ ['prev_state', 'prev_actor_state', 'prev_critic_state'].
+
+ Returns:
+ - ret (:obj:`List[Any]`): A list with B elements, where B is the batch size.
+
+ Examples:
+ >>> batch = {
+ 'a': [
+ [1, 2, 3],
+ [4, 5, 6]
+ ],
+ 'b': [
+ [7, 8, 9],
+ [10, 11, 12]
+ ]}
+ >>> default_decollate(batch)
+ [
+ {'a': [1, 2, 3], 'b': [7, 8, 9]},
+ {'a': [4, 5, 6], 'b': [10, 11, 12]},
+ ]
+ """
+ if isinstance(batch, torch.Tensor):
+ batch = torch.split(batch, 1, dim=0)
+ # Squeeze if the original batch's shape is like (B, dim1, dim2, ...);
+ # otherwise, directly return the list.
+ if len(batch[0].shape) > 1:
+ batch = [elem.squeeze(0) for elem in batch]
+ return list(batch)
+ elif isinstance(batch, Sequence):
+ return list(zip(*[default_decollate(e) for e in batch]))
+ elif isinstance(batch, Mapping):
+ tmp = {k: v if k in ignore else default_decollate(v) for k, v in batch.items()}
+ B = len(list(tmp.values())[0])
+ return [{k: tmp[k][i] for k in tmp.keys()} for i in range(B)]
+ elif isinstance(batch, torch.distributions.Distribution): # For compatibility
+ return [None for _ in range(batch.batch_shape[0])]
+
+ raise TypeError("Not supported batch type: {}".format(type(batch)))
diff --git a/DI-engine/ding/utils/data/dataloader.py b/DI-engine/ding/utils/data/dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b6ffeec8384801f5a9496de37e615d394ce6a5c
--- /dev/null
+++ b/DI-engine/ding/utils/data/dataloader.py
@@ -0,0 +1,363 @@
+from typing import Iterable, Callable, Optional, Any, Union
+import time
+import platform
+import threading
+import queue
+
+import torch
+import torch.multiprocessing as tm
+from ding.torch_utils import to_device
+from ding.utils import LockContext, LockContextType
+from .base_dataloader import IDataLoader
+from .collate_fn import default_collate
+
+
+class AsyncDataLoader(IDataLoader):
+ """
+ Overview:
+ An asynchronous dataloader.
+ Interfaces:
+ ``__init__``, ``__iter__``, ``__next__``, ``_get_data``, ``_async_loop``, ``_worker_loop``, ``_cuda_loop``, \
+ ``_get_data``, ``close``
+ """
+
+ def __init__(
+ self,
+ data_source: Union[Callable, dict],
+ batch_size: int,
+ device: str,
+ chunk_size: Optional[int] = None,
+ collate_fn: Optional[Callable] = None,
+ num_workers: int = 0
+ ) -> None:
+ """
+ Overview:
+ Init dataloader with input parameters.
+ If ``data_source`` is ``dict``, data will only be processed in ``get_data_thread`` and put into
+ ``async_train_queue``.
+ If ``data_source`` is ``Callable``, data will be processed by calling the provided functions, and this
+ falls into two cases:
+
+ - ``num_workers`` == 0 or 1: Only the main worker processes it and puts it into ``async_train_queue``.
+ - ``num_workers`` > 1: The main worker divides a job into several pieces and pushes every piece into \
+ ``job_queue``; then slave workers fetch and execute the jobs; finally they push the processed data \
+ into ``async_train_queue``.
+
+ At the last step, if ``device`` contains "cuda", data in ``async_train_queue`` will be transferred to
+ ``cuda_queue`` for the user to access.
+ Arguments:
+ - data_source (:obj:`Union[Callable, dict]`): The data source, e.g. function to be implemented(Callable), \
+ replay buffer's real data(dict), etc.
+ - batch_size (:obj:`int`): Batch size.
+ - device (:obj:`str`): Device.
+ - chunk_size (:obj:`int`): The size of a chunked piece in a batch, should exactly divide ``batch_size``, \
+ only function when there are more than 1 worker.
+ - collate_fn (:obj:`Callable`): The function which is used to collate batch size into each data field.
+ - num_workers (:obj:`int`): Number of extra workers. \
+ 0 or 1 means only 1 main worker and no extra ones, i.e. multiprocessing is disabled. \
+ More than 1 means multiple workers implemented by multiprocessing process data respectively.
+ """
+ self.data_source = data_source
+ self.batch_size = batch_size
+ self.device = device
+ self.use_cuda = 'cuda' in self.device
+ if self.use_cuda:
+ self.stream = torch.cuda.Stream()
+ if chunk_size is None:
+ self.chunk_size = 1
+ else:
+ self.chunk_size = chunk_size
+ assert self.batch_size >= self.chunk_size and self.batch_size % self.chunk_size == 0, '{}/{}'.format(
+ self.batch_size, self.chunk_size
+ )
+ if collate_fn is None:
+ self.collate_fn = default_collate
+ else:
+ self.collate_fn = collate_fn
+ self.num_workers = num_workers
+ if self.num_workers < 0:
+ raise ValueError(
+ '"num_workers" should be non-negative; '
+ 'Use num_workers = 0 or 1 to disable multiprocessing.'
+ )
+ # Up to "2 * num_workers" pieces of data will be stored in dataloader, waiting for learner to get.
+ # Up to "2 * num_workers" jobs will be stored in dataloader, waiting for slave process to get and accomplish.
+ queue_maxsize = max(1, self.num_workers) * 2
+ self.queue_maxsize = queue_maxsize
+
+ # For multiprocessing: Use ``spawn`` on Windows, ``fork`` on other platforms.
+ context_str = 'spawn' if platform.system().lower() == 'windows' else 'fork'
+ self.mp_context = tm.get_context(context_str)
+ self.manager = self.mp_context.Manager()
+ # ``async_train_queue`` is the queue to store processed data.
+ # User can directly access data if don't use cuda; Otherwise, user will access data from ``cuda_queue``.
+ self.async_train_queue = self.mp_context.Queue(maxsize=queue_maxsize)
+ self.end_flag = False
+
+ # Multiprocessing workers: If num_workers > 1, more than 1 worker are to process data.
+ if self.num_workers > 1:
+ self.batch_id = self.mp_context.Value('i', 0)
+ self.cur_batch = self.mp_context.Value('i', 0)
+ if self.batch_size != self.chunk_size:
+ # job_result {batch_id: result_list} is used to temporarily store processed results.
+ self.job_result = self.manager.dict()
+ self.job_result_lock = LockContext(type_=LockContextType.PROCESS_LOCK)
+ self.job_queue = self.mp_context.Queue(maxsize=queue_maxsize)
+ self.worker = [
+ self.mp_context.Process(
+ target=self._worker_loop, args=(), name='dataloader_worker{}_{}'.format(i, time.time())
+ ) for i in range(self.num_workers)
+ ]
+ for w in self.worker:
+ w.daemon = True
+ w.start()
+ print('Using {} workers to load data'.format(self.num_workers))
+
+ # Parent and child pipes. Used by ``async_process`` and ``get_data_thread`` to coordinate.
+ p, c = self.mp_context.Pipe()
+
+ # Async process (Main worker): Process data if num_workers <= 1; Assign job to other workers if num_workers > 1.
+ self.async_process = self.mp_context.Process(target=self._async_loop, args=(p, c))
+ self.async_process.daemon = True
+ self.async_process.start()
+
+ # Get data thread: Get data from ``data_source`` and send it to ``async_process``.
+ self.get_data_thread = threading.Thread(target=self._get_data, args=(p, c))
+ self.get_data_thread.daemon = True
+ self.get_data_thread.start()
+
+ # Cuda thread: If use cuda, data in ``async_train_queue`` will be transferred to ``cuda_queue``;
+ # Then user will access data from ``cuda_queue``.
+ if self.use_cuda:
+ self.cuda_queue = queue.Queue(maxsize=queue_maxsize)
+ self.cuda_thread = threading.Thread(target=self._cuda_loop, args=(), name='dataloader_cuda')
+ self.cuda_thread.daemon = True
+ self.cuda_thread.start()
+
+ def __iter__(self) -> Iterable:
+ """
+ Overview:
+ Return the iterable self as an iterator.
+ Returns:
+ - self (:obj:`Iterable`): Self as an iterator.
+ """
+ return self
+
+ def _get_data(self, p: tm.multiprocessing.connection, c: tm.multiprocessing.connection) -> None:
+ """
+ Overview:
+ Init dataloader with input parameters. Will run as a thread through ``self.get_data_thread``.
+ Arguments:
+ - p (:obj:`tm.multiprocessing.connection`): Parent connection.
+ - c (:obj:`tm.multiprocessing.connection`): Child connection.
+ """
+ c.close() # Close unused c, only use p
+ while not self.end_flag:
+ if not p.poll(timeout=0.2):
+ time.sleep(0.01)
+ continue
+ try:
+ cmd = p.recv()
+ except EOFError:
+ break
+ if cmd == 'get_data':
+ # Main worker asks for data.
+ data = self.data_source(self.batch_size)
+ # ``data`` can be a list of callables, e.g. functions to read data from files, so we can divide
+ # this job into pieces, assign them to slave workers and accomplish the jobs asynchronously.
+ # But if we get a list of dicts, the data has already been processed and can be used
+ # directly, so we put it directly into async_train_queue and wait for it
+ # to be accessed by a user, e.g. the learner.
+ if isinstance(data[0], dict):
+ data = self.collate_fn(data)
+ self.async_train_queue.put(data)
+ p.send('pass')
+ else:
+ p.send(data)
+ p.close()
+
+ def _async_loop(self, p: tm.multiprocessing.connection, c: tm.multiprocessing.connection) -> None:
+ """
+ Overview:
+ Main worker process. Run through ``self.async_process``.
+ Firstly, get data from ``self.get_data_thread``.
+ If multiple workers, put data in ``self.job_queue`` for further multiprocessing operation;
+ If only one worker, process data and put directly into ``self.async_train_queue``.
+ Arguments:
+ - p (:obj:`tm.multiprocessing.connection`): Parent connection.
+ - c (:obj:`tm.multiprocessing.connection`): Child connection.
+ """
+ torch.set_num_threads(1)
+ p.close() # Close unused p, only use c
+ while not self.end_flag:
+ if self.num_workers > 1:
+ # Multiple workers: Put jobs (chunked data) into job_queue
+ if self.job_queue.full():
+ time.sleep(0.001)
+ else:
+ # Get data from ``_get_data`` thread.
+ c.send('get_data')
+ data = c.recv()
+ if isinstance(data, str) and data == 'pass':
+ continue
+ # Get data to be processed, chunk it into pieces and put them into job_queue.
+ chunk_num = self.batch_size // self.chunk_size
+ with self.batch_id.get_lock():
+ for i in range(chunk_num):
+ start, end = i * self.chunk_size, (i + 1) * self.chunk_size
+ self.job_queue.put({'batch_id': self.batch_id.value, 'job': data[start:end]})
+ self.batch_id.value = (self.batch_id.value + 1) % self.queue_maxsize # Increment batch_id
+ time.sleep(0.001)
+ else:
+ # Only one worker: Process data and directly put it into async_train_queue
+ if self.async_train_queue.full():
+ time.sleep(0.001)
+ else:
+ c.send('get_data')
+ data = c.recv()
+ if isinstance(data, str) and data == 'pass':
+ continue
+ data = [fn() for fn in data] # Implement functions in list ``data``.
+ data = self.collate_fn(data)
+ self.async_train_queue.put(data)
+ c.close()
+
+ def _worker_loop(self) -> None:
+ """
+ Overview:
+ Worker process. Run through each element in list ``self.worker``.
+ Get data job from ``self.job_queue``, process it and then put into ``self.async_train_queue``.
+ Only function when ``self.num_workers`` > 1, which means using multiprocessing.
+ """
+ while not self.end_flag:
+ if self.job_queue.empty() or self.async_train_queue.full():
+ # No job left to be done, or no space left to store finished jobs.
+ time.sleep(0.01)
+ continue
+ else:
+ try:
+ element = self.job_queue.get()
+ except (ConnectionResetError, ConnectionRefusedError) as e:
+ break
+ batch_id, job = element['batch_id'], element['job']
+ # Process the assigned data.
+ data = [fn() for fn in job] # Only function-type job will arrive here, dict-type will not
+ if len(data) == self.batch_size == self.chunk_size:
+ # Data not chunked: Finishing the assigned piece means finishing a whole batch.
+ data = self.collate_fn(data)
+ while batch_id != self.cur_batch.value:
+ time.sleep(0.01)
+ self.async_train_queue.put(data)
+ # Directly update cur_batch, since a whole batch is finished
+ with self.cur_batch.get_lock():
+ self.cur_batch.value = (self.cur_batch.value + 1) % self.queue_maxsize
+ else:
+ # Data chunked: Must wait for all chunked pieces in a batch to be accomplished.
+ finish_flag = False # indicate whether a whole batch is accomplished
+ with self.job_result_lock:
+ if batch_id not in self.job_result:
+ # The first one in a batch
+ self.job_result[batch_id] = data
+ elif len(self.job_result[batch_id]) + len(data) == self.batch_size:
+ # The last one in a batch
+ data += self.job_result.pop(batch_id)
+ assert batch_id not in self.job_result
+ finish_flag = True
+ else:
+ # Middle pieces in a batch
+ self.job_result[batch_id] += data
+ if finish_flag:
+ data = self.collate_fn(data)
+ while batch_id != self.cur_batch.value:
+ time.sleep(0.01)
+ self.async_train_queue.put(data)
+ with self.cur_batch.get_lock():
+ self.cur_batch.value = (self.cur_batch.value + 1) % self.queue_maxsize
+ # If ``self.end_flag`` is True, clear and close job_queue, because _worker_loop gets jobs from job_queue.
+ while not self.job_queue.empty():
+ try:
+ _ = self.job_queue.get()
+ except Exception as e:
+ break
+ self.job_queue.close()
+ self.job_queue.join_thread()
+
+ def _cuda_loop(self) -> None:
+ """
+ Overview:
+ Only when using cuda, would this be run as a thread through ``self.cuda_thread``.
+ Get data from ``self.async_train_queue``, change its device and put it into ``self.cuda_queue``
+ """
+ with torch.cuda.stream(self.stream):
+ while not self.end_flag:
+ if self.async_train_queue.empty() or self.cuda_queue.full():
+ time.sleep(0.01)
+ else:
+ data = self.async_train_queue.get()
+ data = to_device(data, self.device)
+ self.cuda_queue.put(data)
+ # If ``self.end_flag`` is True, clear and close async_train_queue,
+ # because _cuda_loop gets data from async_train_queue.
+ while not self.async_train_queue.empty():
+ _ = self.async_train_queue.get()
+ self.async_train_queue.close()
+ self.async_train_queue.join_thread()
+
+ def __next__(self) -> Any:
+ """
+ Overview:
+ Return next data in the iterator. If use cuda, get from ``self.cuda_queue``;
+ Otherwise, get from ``self.async_train_queue``.
+ Returns:
+ - data (:obj:`torch.Tensor`): Next data in the dataloader iterator.
+ """
+ while not self.end_flag:
+ if self.use_cuda:
+ if self.cuda_queue.empty():
+ time.sleep(0.01)
+ else:
+ data = self.cuda_queue.get(timeout=60)
+ self.cuda_queue.task_done()
+ return data
+ else:
+ if self.async_train_queue.empty():
+ time.sleep(0.01)
+ else:
+ return self.async_train_queue.get()
+ # If ``self.end_flag`` is True, clear and close either 1) or 2):
+ # 1) cuda_queue. Because the user gets data from cuda_queue, and async_train_queue is closed by cuda_loop.
+ # 2) async_train_queue. Because the user gets data from async_train_queue.
+ if self.use_cuda:
+ while not self.cuda_queue.empty():
+ _ = self.cuda_queue.get()
+ self.cuda_queue.task_done()
+ self.cuda_queue.join()
+ else:
+ while not self.async_train_queue.empty():
+ _ = self.async_train_queue.get()
+ self.async_train_queue.close()
+ self.async_train_queue.join_thread()
+
+ def __del__(self) -> None:
+ """
+ Overview:
+ Delete this dataloader.
+ """
+ self.close()
+
+ def close(self) -> None:
+ """
+ Overview:
+ Delete this dataloader. First set ``end_flag`` to True, which means different processes/threads
+ will clear and close all data queues; Then all processes will be terminated and joined.
+ """
+ if self.end_flag:
+ return
+ self.end_flag = True
+ self.async_process.terminate()
+ self.async_process.join()
+ if self.num_workers > 1:
+ for w in self.worker:
+ w.terminate()
+ w.join()
+ print('Del AsyncDataLoader')
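As a hedged usage sketch of the single-worker, CPU-only path: the ``data_source`` below returns already-processed dicts, so they are collated in the get-data thread and pushed straight into ``async_train_queue`` (the sample fields are illustrative).

```python
# Sketch: AsyncDataLoader fed by a callable that returns already-processed dicts.
import torch
from ding.utils.data import AsyncDataLoader

def data_source(batch_size):
    return [{'obs': torch.randn(4), 'action': torch.tensor([1])} for _ in range(batch_size)]

if __name__ == '__main__':      # guard needed because the loader spawns helper processes
    loader = AsyncDataLoader(data_source, batch_size=8, device='cpu', num_workers=0)
    batch = next(loader)        # {'obs': tensor of shape (8, 4), 'action': tensor of shape (8,)}
    loader.close()
```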
diff --git a/DI-engine/ding/utils/data/dataset.py b/DI-engine/ding/utils/data/dataset.py
new file mode 100755
index 0000000000000000000000000000000000000000..40d7831001bebe09616a0181917185bbcb7af2b7
--- /dev/null
+++ b/DI-engine/ding/utils/data/dataset.py
@@ -0,0 +1,1510 @@
+from typing import List, Dict, Tuple
+from ditk import logging
+from copy import deepcopy
+from easydict import EasyDict
+from torch.utils.data import Dataset
+from dataclasses import dataclass
+
+import pickle
+import easydict
+import torch
+import numpy as np
+
+from ding.utils.bfs_helper import get_vi_sequence
+from ding.utils import DATASET_REGISTRY, import_module, DatasetNormalizer
+from ding.rl_utils import discount_cumsum
+
+
+@dataclass
+class DatasetStatistics:
+ """
+ Overview:
+ Dataset statistics.
+ """
+ mean: np.ndarray # obs
+ std: np.ndarray # obs
+ action_bounds: np.ndarray
+
+
+@DATASET_REGISTRY.register('naive')
+class NaiveRLDataset(Dataset):
+ """
+ Overview:
+ Naive RL dataset, which is used for offline RL algorithms.
+ Interfaces:
+ ``__init__``, ``__len__``, ``__getitem__``
+ """
+
+ def __init__(self, cfg) -> None:
+ """
+ Overview:
+ Initialization method.
+ Arguments:
+ - cfg (:obj:`dict`): Config dict.
+ """
+
+ assert type(cfg) in [str, EasyDict], "invalid cfg type: {}".format(type(cfg))
+ if isinstance(cfg, EasyDict):
+ self._data_path = cfg.policy.collect.data_path
+ elif isinstance(cfg, str):
+ self._data_path = cfg
+ with open(self._data_path, 'rb') as f:
+ self._data: List[Dict[str, torch.Tensor]] = pickle.load(f)
+
+ def __len__(self) -> int:
+ """
+ Overview:
+ Get the length of the dataset.
+ """
+
+ return len(self._data)
+
+ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ Get the item of the dataset.
+ """
+
+ return self._data[idx]
+
+
+@DATASET_REGISTRY.register('d4rl')
+class D4RLDataset(Dataset):
+ """
+ Overview:
+ D4RL dataset, which is used for offline RL algorithms.
+ Interfaces:
+ ``__init__``, ``__len__``, ``__getitem__``
+ Properties:
+ - mean (:obj:`np.ndarray`): Mean of the dataset.
+ - std (:obj:`np.ndarray`): Std of the dataset.
+ - action_bounds (:obj:`np.ndarray`): Action bounds of the dataset.
+ - statistics (:obj:`dict`): Statistics of the dataset.
+ """
+
+ def __init__(self, cfg: dict) -> None:
+ """
+ Overview:
+ Initialization method.
+ Arguments:
+ - cfg (:obj:`dict`): Config dict.
+ """
+
+ import gym
+ try:
+ import d4rl # register d4rl environments with OpenAI Gym
+ except ImportError:
+ import sys
+ logging.warning("not found d4rl env, please install it, refer to https://github.com/rail-berkeley/d4rl")
+ sys.exit(1)
+
+ # Init parameters
+ data_path = cfg.policy.collect.get('data_path', None)
+ env_id = cfg.env.env_id
+
+ # Create the environment
+ if data_path:
+ d4rl.set_dataset_path(data_path)
+ env = gym.make(env_id)
+ dataset = d4rl.qlearning_dataset(env)
+ self._cal_statistics(dataset, env)
+ try:
+ if cfg.env.norm_obs.use_norm and cfg.env.norm_obs.offline_stats.use_offline_stats:
+ dataset = self._normalize_states(dataset)
+ except (KeyError, AttributeError):
+ # do not normalize
+ pass
+ self._data = []
+ self._load_d4rl(dataset)
+
+ @property
+ def data(self) -> List:
+ return self._data
+
+ def __len__(self) -> int:
+ """
+ Overview:
+ Get the length of the dataset.
+ """
+
+ return len(self._data)
+
+ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ Get the item of the dataset.
+ """
+
+ return self._data[idx]
+
+ def _load_d4rl(self, dataset: Dict[str, np.ndarray]) -> None:
+ """
+ Overview:
+ Load the d4rl dataset.
+ Arguments:
+ - dataset (:obj:`Dict[str, np.ndarray]`): The d4rl dataset.
+ """
+
+ for i in range(len(dataset['observations'])):
+ trans_data = {}
+ trans_data['obs'] = torch.from_numpy(dataset['observations'][i])
+ trans_data['next_obs'] = torch.from_numpy(dataset['next_observations'][i])
+ trans_data['action'] = torch.from_numpy(dataset['actions'][i])
+ trans_data['reward'] = torch.tensor(dataset['rewards'][i])
+ trans_data['done'] = dataset['terminals'][i]
+ self._data.append(trans_data)
+
+ def _cal_statistics(self, dataset, env, eps=1e-3, add_action_buffer=True):
+ """
+ Overview:
+ Calculate the statistics of the dataset.
+ Arguments:
+ - dataset (:obj:`Dict[str, np.ndarray]`): The d4rl dataset.
+ - env (:obj:`gym.Env`): The environment.
+ - eps (:obj:`float`): Epsilon.
+ """
+
+ self._mean = dataset['observations'].mean(0)
+ self._std = dataset['observations'].std(0) + eps
+ action_max = dataset['actions'].max(0)
+ action_min = dataset['actions'].min(0)
+ if add_action_buffer:
+ action_buffer = 0.05 * (action_max - action_min)
+ action_max = (action_max + action_buffer).clip(max=env.action_space.high)
+ action_min = (action_min - action_buffer).clip(min=env.action_space.low)
+ self._action_bounds = np.stack([action_min, action_max], axis=0)
+
+ def _normalize_states(self, dataset):
+ """
+ Overview:
+ Normalize the states.
+ Arguments:
+ - dataset (:obj:`Dict[str, np.ndarray]`): The d4rl dataset.
+ """
+
+ dataset['observations'] = (dataset['observations'] - self._mean) / self._std
+ dataset['next_observations'] = (dataset['next_observations'] - self._mean) / self._std
+ return dataset
+
+ @property
+ def mean(self):
+ """
+ Overview:
+ Get the mean of the dataset.
+ """
+
+ return self._mean
+
+ @property
+ def std(self):
+ """
+ Overview:
+ Get the std of the dataset.
+ """
+
+ return self._std
+
+ @property
+ def action_bounds(self) -> np.ndarray:
+ """
+ Overview:
+ Get the action bounds of the dataset.
+ """
+
+ return self._action_bounds
+
+ @property
+ def statistics(self) -> dict:
+ """
+ Overview:
+ Get the statistics of the dataset.
+ """
+
+ return DatasetStatistics(mean=self.mean, std=self.std, action_bounds=self.action_bounds)
+
+
+@DATASET_REGISTRY.register('hdf5')
+class HDF5Dataset(Dataset):
+ """
+ Overview:
+ HDF5 dataset is saved in hdf5 format, which is used for offline RL algorithms.
+ The hdf5 format is a common format for storing large numerical arrays.
+ For more details, please refer to https://support.hdfgroup.org/HDF5/.
+ Interfaces:
+ ``__init__``, ``__len__``, ``__getitem__``
+ Properties:
+ - mean (:obj:`np.ndarray`): Mean of the dataset.
+ - std (:obj:`np.ndarray`): Std of the dataset.
+ - action_bounds (:obj:`np.ndarray`): Action bounds of the dataset.
+ - statistics (:obj:`dict`): Statistics of the dataset.
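+ Examples:
+ A minimal usage sketch; the config keys below are assumptions inferred from the fields read in
+ ``__init__`` and the file path is a placeholder, not part of DI-engine's default configs:
+ >>> from easydict import EasyDict
+ >>> cfg = EasyDict(dict(policy=dict(collect=dict(data_path='./expert_demos.hdf5'))))
+ >>> dataset = HDF5Dataset(cfg) # requires the optional ``h5py`` dependency
+ >>> len(dataset), dataset[0]['obs'].shape # each item is a dict of transition arrays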
+ """
+
+ def __init__(self, cfg: dict) -> None:
+ """
+ Overview:
+ Initialization method.
+ Arguments:
+ - cfg (:obj:`dict`): Config dict.
+ """
+
+ try:
+ import h5py
+ except ImportError:
+ import sys
+ logging.warning("not found h5py package, please install it trough `pip install h5py ")
+ sys.exit(1)
+ data_path = cfg.policy.collect.get('data_path', None)
+ if 'dataset' in cfg:
+ self.context_len = cfg.dataset.context_len
+ else:
+ self.context_len = 0
+ data = h5py.File(data_path, 'r')
+ self._load_data(data)
+ self._cal_statistics()
+ try:
+ if cfg.env.norm_obs.use_norm and cfg.env.norm_obs.offline_stats.use_offline_stats:
+ self._normalize_states()
+ except (KeyError, AttributeError):
+ # do not normalize
+ pass
+
+ def __len__(self) -> int:
+ """
+ Overview:
+ Get the length of the dataset.
+ """
+
+ return len(self._data['obs']) - self.context_len
+
+ def __getitem__(self, idx: int) -> Dict[str, torch.Tensor]:
+ """
+ Overview:
+ Get the item of the dataset.
+ Arguments:
+ - idx (:obj:`int`): The index of the dataset.
+ """
+
+ if self.context_len == 0: # for other offline RL algorithms
+ return {k: self._data[k][idx] for k in self._data.keys()}
+ else: # for decision transformer
+ block_size = self.context_len
+ done_idx = idx + block_size
+ idx = done_idx - block_size
+ states = torch.as_tensor(
+ np.array(self._data['obs'][idx:done_idx]), dtype=torch.float32
+ ).view(block_size, -1)
+ actions = torch.as_tensor(self._data['action'][idx:done_idx], dtype=torch.long)
+ rtgs = torch.as_tensor(self._data['reward'][idx:done_idx, 0], dtype=torch.float32)
+ timesteps = torch.as_tensor(range(idx, done_idx), dtype=torch.int64)
+ traj_mask = torch.ones(self.context_len, dtype=torch.long)
+ return timesteps, states, actions, rtgs, traj_mask
+
+ def _load_data(self, dataset: Dict[str, np.ndarray]) -> None:
+ """
+ Overview:
+ Load the dataset.
+ Arguments:
+ - dataset (:obj:`Dict[str, np.ndarray]`): The dataset.
+ """
+
+ self._data = {}
+ for k in dataset.keys():
+ logging.info(f'Load {k} data.')
+ self._data[k] = dataset[k][:]
+
+ def _cal_statistics(self, eps: float = 1e-3):
+ """
+ Overview:
+ Calculate the statistics of the dataset.
+ Arguments:
+ - eps (:obj:`float`): Epsilon.
+ """
+
+ self._mean = self._data['obs'].mean(0)
+ self._std = self._data['obs'].std(0) + eps
+ action_max = self._data['action'].max(0)
+ action_min = self._data['action'].min(0)
+ # widen the empirical action range by a 5% buffer on each side, mirroring ``D4RLDataset._cal_statistics``
+ buffer = 0.05 * (action_max - action_min)
+ action_max = action_max.astype(float) + buffer
+ action_min = action_min.astype(float) - buffer
+ self._action_bounds = np.stack([action_min, action_max], axis=0)
+
+ def _normalize_states(self):
+ """
+ Overview:
+ Normalize the states.
+ """
+
+ self._data['obs'] = (self._data['obs'] - self._mean) / self._std
+ self._data['next_obs'] = (self._data['next_obs'] - self._mean) / self._std
+
+ @property
+ def mean(self):
+ """
+ Overview:
+ Get the mean of the dataset.
+ """
+
+ return self._mean
+
+ @property
+ def std(self):
+ """
+ Overview:
+ Get the std of the dataset.
+ """
+
+ return self._std
+
+ @property
+ def action_bounds(self) -> np.ndarray:
+ """
+ Overview:
+ Get the action bounds of the dataset.
+ """
+
+ return self._action_bounds
+
+ @property
+ def statistics(self) -> dict:
+ """
+ Overview:
+ Get the statistics of the dataset.
+ """
+
+ return DatasetStatistics(mean=self.mean, std=self.std, action_bounds=self.action_bounds)
+
+
+@DATASET_REGISTRY.register('d4rl_trajectory')
+class D4RLTrajectoryDataset(Dataset):
+ """
+ Overview:
+ D4RL trajectory dataset, which is used for offline RL algorithms.
+ Interfaces:
+ ``__init__``, ``__len__``, ``__getitem__``
+ """
+
+ # from infos.py from official d4rl github repo
+ REF_MIN_SCORE = {
+ 'halfcheetah': -280.178953,
+ 'walker2d': 1.629008,
+ 'hopper': -20.272305,
+ }
+
+ REF_MAX_SCORE = {
+ 'halfcheetah': 12135.0,
+ 'walker2d': 4592.3,
+ 'hopper': 3234.3,
+ }
+
+ # calculated from d4rl datasets
+ D4RL_DATASET_STATS = {
+ 'halfcheetah-medium-v2': {
+ 'state_mean': [
+ -0.06845773756504059, 0.016414547339081764, -0.18354906141757965, -0.2762460708618164,
+ -0.34061527252197266, -0.09339715540409088, -0.21321271359920502, -0.0877423882484436,
+ 5.173007488250732, -0.04275195300579071, -0.036108363419771194, 0.14053793251514435,
+ 0.060498327016830444, 0.09550975263118744, 0.06739100068807602, 0.005627387668937445,
+ 0.013382787816226482
+ ],
+ 'state_std': [
+ 0.07472999393939972, 0.3023499846458435, 0.30207309126853943, 0.34417077898979187, 0.17619241774082184,
+ 0.507205605506897, 0.2567007839679718, 0.3294812738895416, 1.2574149370193481, 0.7600541710853577,
+ 1.9800915718078613, 6.565362453460693, 7.466367721557617, 4.472222805023193, 10.566964149475098,
+ 5.671932697296143, 7.4982590675354
+ ]
+ },
+ 'halfcheetah-medium-replay-v2': {
+ 'state_mean': [
+ -0.12880703806877136, 0.3738119602203369, -0.14995987713336945, -0.23479078710079193,
+ -0.2841278612613678, -0.13096535205841064, -0.20157982409000397, -0.06517726927995682,
+ 3.4768247604370117, -0.02785065770149231, -0.015035249292850494, 0.07697279006242752,
+ 0.01266712136566639, 0.027325302362442017, 0.02316424623131752, 0.010438721626996994,
+ -0.015839405357837677
+ ],
+ 'state_std': [
+ 0.17019015550613403, 1.284424901008606, 0.33442774415016174, 0.3672759234905243, 0.26092398166656494,
+ 0.4784106910228729, 0.3181420564651489, 0.33552637696266174, 2.0931615829467773, 0.8037433624267578,
+ 1.9044333696365356, 6.573209762573242, 7.572863578796387, 5.069749355316162, 9.10555362701416,
+ 6.085654258728027, 7.25300407409668
+ ]
+ },
+ 'halfcheetah-medium-expert-v2': {
+ 'state_mean': [
+ -0.05667462572455406, 0.024369969964027405, -0.061670560389757156, -0.22351515293121338,
+ -0.2675151228904724, -0.07545716315507889, -0.05809682980179787, -0.027675075456500053,
+ 8.110626220703125, -0.06136331334710121, -0.17986927926540375, 0.25175222754478455, 0.24186332523822784,
+ 0.2519369423389435, 0.5879552960395813, -0.24090635776519775, -0.030184272676706314
+ ],
+ 'state_std': [
+ 0.06103534251451492, 0.36054104566574097, 0.45544400811195374, 0.38476887345314026, 0.2218363732099533,
+ 0.5667523741722107, 0.3196682929992676, 0.2852923572063446, 3.443821907043457, 0.6728139519691467,
+ 1.8616976737976074, 9.575807571411133, 10.029894828796387, 5.903450012207031, 12.128185272216797,
+ 6.4811787605285645, 6.378620147705078
+ ]
+ },
+ 'walker2d-medium-v2': {
+ 'state_mean': [
+ 1.218966007232666, 0.14163373410701752, -0.03704913705587387, -0.13814310729503632, 0.5138224363327026,
+ -0.04719110205769539, -0.47288352251052856, 0.042254164814949036, 2.3948874473571777,
+ -0.03143199160695076, 0.04466355964541435, -0.023907244205474854, -0.1013401448726654,
+ 0.09090937674045563, -0.004192637279629707, -0.12120571732521057, -0.5497063994407654
+ ],
+ 'state_std': [
+ 0.12311358004808426, 0.3241879940032959, 0.11456084251403809, 0.2623065710067749, 0.5640279054641724,
+ 0.2271878570318222, 0.3837319612503052, 0.7373676896095276, 1.2387926578521729, 0.798020601272583,
+ 1.5664079189300537, 1.8092705011367798, 3.025604248046875, 4.062486171722412, 1.4586567878723145,
+ 3.7445690631866455, 5.5851287841796875
+ ]
+ },
+ 'walker2d-medium-replay-v2': {
+ 'state_mean': [
+ 1.209364652633667, 0.13264022767543793, -0.14371201395988464, -0.2046516090631485, 0.5577612519264221,
+ -0.03231537342071533, -0.2784661054611206, 0.19130706787109375, 1.4701707363128662,
+ -0.12504704296588898, 0.0564953051507473, -0.09991033375263214, -0.340340256690979, 0.03546293452382088,
+ -0.08934258669614792, -0.2992438077926636, -0.5984178185462952
+ ],
+ 'state_std': [
+ 0.11929835379123688, 0.3562574088573456, 0.25852200388908386, 0.42075422406196594, 0.5202291011810303,
+ 0.15685082972049713, 0.36770978569984436, 0.7161387801170349, 1.3763766288757324, 0.8632221817970276,
+ 2.6364643573760986, 3.0134117603302, 3.720684051513672, 4.867283821105957, 2.6681625843048096,
+ 3.845186948776245, 5.4768385887146
+ ]
+ },
+ 'walker2d-medium-expert-v2': {
+ 'state_mean': [
+ 1.2294334173202515, 0.16869689524173737, -0.07089081406593323, -0.16197483241558075,
+ 0.37101927399635315, -0.012209027074277401, -0.42461398243904114, 0.18986578285694122,
+ 3.162475109100342, -0.018092676997184753, 0.03496946766972542, -0.013921679928898811,
+ -0.05937029421329498, -0.19549426436424255, -0.0019200450042262673, -0.062483321875333786,
+ -0.27366524934768677
+ ],
+ 'state_std': [
+ 0.09932824969291687, 0.25981399416923523, 0.15062759816646576, 0.24249176681041718, 0.6758718490600586,
+ 0.1650741547346115, 0.38140663504600525, 0.6962361335754395, 1.3501490354537964, 0.7641991376876831,
+ 1.534574270248413, 2.1785972118377686, 3.276582717895508, 4.766193866729736, 1.1716983318328857,
+ 4.039782524108887, 5.891613960266113
+ ]
+ },
+ 'hopper-medium-v2': {
+ 'state_mean': [
+ 1.311279058456421, -0.08469521254301071, -0.5382719039916992, -0.07201576232910156, 0.04932365566492081,
+ 2.1066856384277344, -0.15017354488372803, 0.008783451281487942, -0.2848185896873474,
+ -0.18540096282958984, -0.28461286425590515
+ ],
+ 'state_std': [
+ 0.17790751159191132, 0.05444620922207832, 0.21297138929367065, 0.14530418813228607, 0.6124444007873535,
+ 0.8517446517944336, 1.4515252113342285, 0.6751695871353149, 1.5362390279769897, 1.616074562072754,
+ 5.607253551483154
+ ]
+ },
+ 'hopper-medium-replay-v2': {
+ 'state_mean': [
+ 1.2305138111114502, -0.04371410980820656, -0.44542956352233887, -0.09370097517967224,
+ 0.09094487875699997, 1.3694725036621094, -0.19992674887180328, -0.022861352190375328,
+ -0.5287045240402222, -0.14465883374214172, -0.19652697443962097
+ ],
+ 'state_std': [
+ 0.1756512075662613, 0.0636928603053093, 0.3438323438167572, 0.19566889107227325, 0.5547984838485718,
+ 1.051029920578003, 1.158307671546936, 0.7963128685951233, 1.4802359342575073, 1.6540331840515137,
+ 5.108601093292236
+ ]
+ },
+ 'hopper-medium-expert-v2': {
+ 'state_mean': [
+ 1.3293815851211548, -0.09836531430482864, -0.5444297790527344, -0.10201650857925415,
+ 0.02277466468513012, 2.3577215671539307, -0.06349576264619827, -0.00374026270583272,
+ -0.1766270101070404, -0.11862941086292267, -0.12097819894552231
+ ],
+ 'state_std': [
+ 0.17012375593185425, 0.05159067362546921, 0.18141433596611023, 0.16430604457855225, 0.6023368239402771,
+ 0.7737284898757935, 1.4986555576324463, 0.7483318448066711, 1.7953159809112549, 2.0530025959014893,
+ 5.725032806396484
+ ]
+ },
+ }
+
+ def __init__(self, cfg: dict) -> None:
+ """
+ Overview:
+ Initialization method.
+ Arguments:
+ - cfg (:obj:`dict`): Config dict.
+ """
+
+ dataset_path = cfg.dataset.data_dir_prefix
+ rtg_scale = cfg.dataset.rtg_scale
+ self.context_len = cfg.dataset.context_len
+ self.env_type = cfg.dataset.env_type
+
+ if 'hdf5' in dataset_path: # for mujoco env
+ try:
+ import h5py
+ import collections
+ except ImportError:
+ import sys
+ logging.warning("not found h5py package, please install it trough `pip install h5py ")
+ sys.exit(1)
+ dataset = h5py.File(dataset_path, 'r')
+
+ N = dataset['rewards'].shape[0]
+ data_ = collections.defaultdict(list)
+
+ use_timeouts = False
+ if 'timeouts' in dataset:
+ use_timeouts = True
+
+ episode_step = 0
+ paths = []
+ for i in range(N):
+ done_bool = bool(dataset['terminals'][i])
+ if use_timeouts:
+ final_timestep = dataset['timeouts'][i]
+ else:
+ final_timestep = (episode_step == 1000 - 1)
+ for k in ['observations', 'actions', 'rewards', 'terminals']:
+ data_[k].append(dataset[k][i])
+ if done_bool or final_timestep:
+ episode_step = 0
+ episode_data = {}
+ for k in data_:
+ episode_data[k] = np.array(data_[k])
+ paths.append(episode_data)
+ data_ = collections.defaultdict(list)
+ episode_step += 1
+
+ self.trajectories = paths
+
+ # calculate state mean and variance and returns_to_go for all traj
+ states = []
+ for traj in self.trajectories:
+ traj_len = traj['observations'].shape[0]
+ states.append(traj['observations'])
+ # calculate returns to go and rescale them
+ traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale
+
+ # used for input normalization
+ states = np.concatenate(states, axis=0)
+ self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6
+
+ # normalize states
+ for traj in self.trajectories:
+ traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std
+
+ elif 'pkl' in dataset_path:
+ if 'dqn' in dataset_path:
+ # load dataset
+ with open(dataset_path, 'rb') as f:
+ self.trajectories = pickle.load(f)
+
+ if isinstance(self.trajectories[0], list):
+ # for our collected dataset, e.g. cartpole/lunarlander case
+ trajectories_tmp = []
+
+ original_keys = ['obs', 'next_obs', 'action', 'reward']
+ keys = ['observations', 'next_observations', 'actions', 'rewards']
+ trajectories_tmp = [
+ {
+ key: np.stack(
+ [
+ self.trajectories[eps_index][transition_index][o_key]
+ for transition_index in range(len(self.trajectories[eps_index]))
+ ],
+ axis=0
+ )
+ for key, o_key in zip(keys, original_keys)
+ } for eps_index in range(len(self.trajectories))
+ ]
+ self.trajectories = trajectories_tmp
+
+ states = []
+ for traj in self.trajectories:
+ # traj_len = traj['observations'].shape[0]
+ states.append(traj['observations'])
+ # calculate returns to go and rescale them
+ traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale
+
+ # used for input normalization
+ states = np.concatenate(states, axis=0)
+ self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6
+
+ # normalize states
+ for traj in self.trajectories:
+ traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std
+ else:
+ # load dataset
+ with open(dataset_path, 'rb') as f:
+ self.trajectories = pickle.load(f)
+
+ states = []
+ for traj in self.trajectories:
+ states.append(traj['observations'])
+ # calculate returns to go and rescale them
+ traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale
+
+ # used for input normalization
+ states = np.concatenate(states, axis=0)
+ self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6
+
+ # normalize states
+ for traj in self.trajectories:
+ traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std
+ else:
+ # -- load data from memory (make more efficient)
+ obss = []
+ actions = []
+ returns = [0]
+ done_idxs = []
+ stepwise_returns = []
+
+ transitions_per_buffer = np.zeros(50, dtype=int)
+ num_trajectories = 0
+ while len(obss) < cfg.dataset.num_steps:
+ buffer_num = np.random.choice(np.arange(50 - cfg.dataset.num_buffers, 50), 1)[0]
+ i = transitions_per_buffer[buffer_num]
+ frb = FixedReplayBuffer(
+ data_dir=cfg.dataset.data_dir_prefix + '/1/replay_logs',
+ replay_suffix=buffer_num,
+ observation_shape=(84, 84),
+ stack_size=4,
+ update_horizon=1,
+ gamma=0.99,
+ observation_dtype=np.uint8,
+ batch_size=32,
+ replay_capacity=100000
+ )
+ if frb._loaded_buffers:
+ done = False
+ curr_num_transitions = len(obss)
+ trajectories_to_load = cfg.dataset.trajectories_per_buffer
+ while not done:
+ states, ac, ret, next_states, next_action, next_reward, terminal, indices = \
+ frb.sample_transition_batch(batch_size=1, indices=[i])
+ states = states.transpose((0, 3, 1, 2))[0] # (1, 84, 84, 4) --> (4, 84, 84)
+ obss.append(states)
+ actions.append(ac[0])
+ stepwise_returns.append(ret[0])
+ if terminal[0]:
+ done_idxs.append(len(obss))
+ returns.append(0)
+ if trajectories_to_load == 0:
+ done = True
+ else:
+ trajectories_to_load -= 1
+ returns[-1] += ret[0]
+ i += 1
+ if i >= 100000:
+ obss = obss[:curr_num_transitions]
+ actions = actions[:curr_num_transitions]
+ stepwise_returns = stepwise_returns[:curr_num_transitions]
+ returns[-1] = 0
+ i = transitions_per_buffer[buffer_num]
+ done = True
+ num_trajectories += (cfg.dataset.trajectories_per_buffer - trajectories_to_load)
+ transitions_per_buffer[buffer_num] = i
+
+ actions = np.array(actions)
+ returns = np.array(returns)
+ stepwise_returns = np.array(stepwise_returns)
+ done_idxs = np.array(done_idxs)
+
+ # -- create reward-to-go dataset
+ start_index = 0
+ rtg = np.zeros_like(stepwise_returns)
+ for i in done_idxs:
+ i = int(i)
+ curr_traj_returns = stepwise_returns[start_index:i]
+ for j in range(i - 1, start_index - 1, -1): # start from i-1
+ rtg_j = curr_traj_returns[j - start_index:i - start_index]
+ rtg[j] = sum(rtg_j)
+ start_index = i
+
+ # -- create timestep dataset
+ start_index = 0
+ timesteps = np.zeros(len(actions) + 1, dtype=int)
+ for i in done_idxs:
+ i = int(i)
+ timesteps[start_index:i + 1] = np.arange(i + 1 - start_index)
+ start_index = i + 1
+
+ self.obss = obss
+ self.actions = actions
+ self.done_idxs = done_idxs
+ self.rtgs = rtg
+ self.timesteps = timesteps
+ # return obss, actions, returns, done_idxs, rtg, timesteps
+
+ def get_max_timestep(self) -> int:
+ """
+ Overview:
+ Get the max timestep of the dataset.
+ """
+
+ return max(self.timesteps)
+
+ def get_state_stats(self) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Overview:
+ Get the state mean and std of the dataset.
+ """
+
+ return deepcopy(self.state_mean), deepcopy(self.state_std)
+
+ def get_d4rl_dataset_stats(self, env_d4rl_name: str) -> Dict[str, list]:
+ """
+ Overview:
+ Get the d4rl dataset stats.
+ Arguments:
+ - env_d4rl_name (:obj:`str`): The d4rl env name.
+ """
+
+ return self.D4RL_DATASET_STATS[env_d4rl_name]
+
+ def __len__(self) -> int:
+ """
+ Overview:
+ Get the length of the dataset.
+ """
+
+ if self.env_type != 'atari':
+ return len(self.trajectories)
+ else:
+ return len(self.obss) - self.context_len
+
+ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+ """
+ Overview:
+ Get the item of the dataset.
+ Arguments:
+ - idx (:obj:`int`): The index of the dataset.
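+ Examples:
+ An illustrative case with hypothetical numbers: for a non-atari trajectory of length 2 and
+ ``context_len=4``, the returned tensors are zero-padded to length 4 and ``traj_mask`` is
+ ``[1, 1, 0, 0]``, so the padded steps can be masked out downstream.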
+ """
+
+ if self.env_type != 'atari':
+ traj = self.trajectories[idx]
+ traj_len = traj['observations'].shape[0]
+
+ if traj_len > self.context_len:
+ # sample random index to slice trajectory
+ si = np.random.randint(0, traj_len - self.context_len)
+
+ states = torch.from_numpy(traj['observations'][si:si + self.context_len])
+ actions = torch.from_numpy(traj['actions'][si:si + self.context_len])
+ returns_to_go = torch.from_numpy(traj['returns_to_go'][si:si + self.context_len])
+ timesteps = torch.arange(start=si, end=si + self.context_len, step=1)
+
+ # all ones since no padding
+ traj_mask = torch.ones(self.context_len, dtype=torch.long)
+
+ else:
+ padding_len = self.context_len - traj_len
+
+ # padding with zeros
+ states = torch.from_numpy(traj['observations'])
+ states = torch.cat(
+ [states, torch.zeros(([padding_len] + list(states.shape[1:])), dtype=states.dtype)], dim=0
+ )
+
+ actions = torch.from_numpy(traj['actions'])
+ actions = torch.cat(
+ [actions, torch.zeros(([padding_len] + list(actions.shape[1:])), dtype=actions.dtype)], dim=0
+ )
+
+ returns_to_go = torch.from_numpy(traj['returns_to_go'])
+ returns_to_go = torch.cat(
+ [
+ returns_to_go,
+ torch.zeros(([padding_len] + list(returns_to_go.shape[1:])), dtype=returns_to_go.dtype)
+ ],
+ dim=0
+ )
+
+ timesteps = torch.arange(start=0, end=self.context_len, step=1)
+
+ traj_mask = torch.cat(
+ [torch.ones(traj_len, dtype=torch.long),
+ torch.zeros(padding_len, dtype=torch.long)], dim=0
+ )
+ return timesteps, states, actions, returns_to_go, traj_mask
+ else: # mean cost less than 0.001s
+ block_size = self.context_len
+ done_idx = idx + block_size
+ for i in self.done_idxs:
+ if i > idx: # first done_idx greater than idx
+ done_idx = min(int(i), done_idx)
+ break
+ idx = done_idx - block_size
+ states = torch.as_tensor(
+ np.array(self.obss[idx:done_idx]), dtype=torch.float32
+ ).view(block_size, -1) # (block_size, 4*84*84)
+ states = states / 255.
+ actions = torch.as_tensor(self.actions[idx:done_idx], dtype=torch.long).unsqueeze(1) # (block_size, 1)
+ rtgs = torch.as_tensor(self.rtgs[idx:done_idx], dtype=torch.float32).unsqueeze(1)
+ timesteps = torch.as_tensor(self.timesteps[idx:idx + 1], dtype=torch.int64).unsqueeze(1)
+ traj_mask = torch.ones(self.context_len, dtype=torch.long)
+ return timesteps, states, actions, rtgs, traj_mask
+
+
+@DATASET_REGISTRY.register('d4rl_diffuser')
+class D4RLDiffuserDataset(Dataset):
+ """
+ Overview:
+ D4RL diffuser dataset, which is used for offline RL algorithms.
+ Interfaces:
+ ``__init__``, ``__len__``, ``__getitem__``
+ """
+
+ def __init__(self, dataset_path: str, context_len: int, rtg_scale: float) -> None:
+ """
+ Overview:
+ Initialization method of D4RLDiffuserDataset.
+ Arguments:
+ - dataset_path (:obj:`str`): The dataset path.
+ - context_len (:obj:`int`): The length of the context.
+ - rtg_scale (:obj:`float`): The scale of the returns to go.
+ """
+
+ self.context_len = context_len
+
+ # load dataset
+ with open(dataset_path, 'rb') as f:
+ self.trajectories = pickle.load(f)
+
+ if isinstance(self.trajectories[0], list):
+ # for our collected dataset, e.g. cartpole/lunarlander case
+ trajectories_tmp = []
+
+ original_keys = ['obs', 'next_obs', 'action', 'reward']
+ keys = ['observations', 'next_observations', 'actions', 'rewards']
+ for key, o_key in zip(keys, original_keys):
+ trajectories_tmp = [
+ {
+ key: np.stack(
+ [
+ self.trajectories[eps_index][transition_index][o_key]
+ for transition_index in range(len(self.trajectories[eps_index]))
+ ],
+ axis=0
+ )
+ } for eps_index in range(len(self.trajectories))
+ ]
+ self.trajectories = trajectories_tmp
+
+ states = []
+ for traj in self.trajectories:
+ traj_len = traj['observations'].shape[0]
+ states.append(traj['observations'])
+ # calculate returns to go and rescale them
+ traj['returns_to_go'] = discount_cumsum(traj['rewards'], 1.0) / rtg_scale
+
+ # used for input normalization
+ states = np.concatenate(states, axis=0)
+ self.state_mean, self.state_std = np.mean(states, axis=0), np.std(states, axis=0) + 1e-6
+
+ # normalize states
+ for traj in self.trajectories:
+ traj['observations'] = (traj['observations'] - self.state_mean) / self.state_std
+
+
+class FixedReplayBuffer(object):
+ """
+ Overview:
+ Object composed of a list of OutOfGraphReplayBuffers.
+ Interfaces:
+ ``__init__``, ``get_transition_elements``, ``sample_transition_batch``
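+ Examples:
+ A hedged sketch of how this class is constructed for Atari data (see ``D4RLTrajectoryDataset``); the
+ keyword arguments follow dopamine's ``OutOfGraphReplayBuffer`` and the directory path is a placeholder:
+ >>> frb = FixedReplayBuffer(
+ ... data_dir='./dqn_replay/Breakout/1/replay_logs', replay_suffix=0, observation_shape=(84, 84),
+ ... stack_size=4, update_horizon=1, gamma=0.99, observation_dtype=np.uint8, batch_size=32,
+ ... replay_capacity=100000)
+ >>> if frb._loaded_buffers:
+ ... batch = frb.sample_transition_batch(batch_size=1, indices=[0])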
+ """
+
+ def __init__(self, data_dir, replay_suffix, *args, **kwargs): # pylint: disable=keyword-arg-before-vararg
+ """
+ Overview:
+ Initialize the FixedReplayBuffer class.
+ Arguments:
+ - data_dir (:obj:`str`): Log directory from which to load the replay buffer.
+ - replay_suffix (:obj:`int`): If not None, then only load the replay buffer \
+ corresponding to the specific suffix in data directory.
+ - args (:obj:`list`): Arbitrary extra arguments.
+ - kwargs (:obj:`dict`): Arbitrary keyword arguments.
+
+ """
+
+ self._args = args
+ self._kwargs = kwargs
+ self._data_dir = data_dir
+ self._loaded_buffers = False
+ self.add_count = np.array(0)
+ self._replay_suffix = replay_suffix
+ if not self._loaded_buffers:
+ if replay_suffix is not None:
+ assert replay_suffix >= 0, 'Please pass a non-negative replay suffix'
+ self.load_single_buffer(replay_suffix)
+ else:
+ pass
+ # self._load_replay_buffers(num_buffers=50)
+
+ def load_single_buffer(self, suffix):
+ """
+ Overview:
+ Load a single replay buffer.
+ Arguments:
+ - suffix (:obj:`int`): The suffix of the replay buffer.
+ """
+
+ replay_buffer = self._load_buffer(suffix)
+ if replay_buffer is not None:
+ self._replay_buffers = [replay_buffer]
+ self.add_count = replay_buffer.add_count
+ self._num_replay_buffers = 1
+ self._loaded_buffers = True
+
+ def _load_buffer(self, suffix):
+ """
+ Overview:
+ Loads a OutOfGraphReplayBuffer replay buffer.
+ Arguments:
+ - suffix (:obj:`int`): The suffix of the replay buffer.
+ """
+
+ try:
+ from dopamine.replay_memory import circular_replay_buffer
+ STORE_FILENAME_PREFIX = circular_replay_buffer.STORE_FILENAME_PREFIX
+ # pytype: disable=attribute-error
+ replay_buffer = circular_replay_buffer.OutOfGraphReplayBuffer(*self._args, **self._kwargs)
+ replay_buffer.load(self._data_dir, suffix)
+ # pytype: enable=attribute-error
+ return replay_buffer
+ # except tf.errors.NotFoundError:
+ except Exception:
+ raise IOError('can not load the replay buffer with suffix {} from {}'.format(suffix, self._data_dir))
+
+ def get_transition_elements(self):
+ """
+ Overview:
+ Returns the transition elements.
+ """
+
+ return self._replay_buffers[0].get_transition_elements()
+
+ def sample_transition_batch(self, batch_size=None, indices=None):
+ """
+ Overview:
+ Returns a batch of transitions (including any extra contents).
+ Arguments:
+ - batch_size (:obj:`int`): The batch size.
+ - indices (:obj:`list`): The indices of the batch.
+ """
+
+ buffer_index = np.random.randint(self._num_replay_buffers)
+ return self._replay_buffers[buffer_index].sample_transition_batch(batch_size=batch_size, indices=indices)
+
+
+class PCDataset(Dataset):
+ """
+ Overview:
+ Dataset for Procedure Cloning.
+ Interfaces:
+ ``__init__``, ``__len__``, ``__getitem__``
+ """
+
+ def __init__(self, all_data):
+ """
+ Overview:
+ Initialization method of PCDataset.
+ Arguments:
+ - all_data (:obj:`tuple`): The tuple of all data.
+ """
+
+ self._data = all_data
+
+ def __getitem__(self, item):
+ """
+ Overview:
+ Get the item of the dataset.
+ Arguments:
+ - item (:obj:`int`): The index of the dataset.
+ """
+
+ return {'obs': self._data[0][item], 'bfs_in': self._data[1][item], 'bfs_out': self._data[2][item]}
+
+ def __len__(self):
+ """
+ Overview:
+ Get the length of the dataset.
+ """
+
+ return self._data[0].shape[0]
+
+
+def load_bfs_datasets(train_seeds=1, test_seeds=5):
+ """
+ Overview:
+ Load BFS datasets.
+ Arguments:
+ - train_seeds (:obj:`int`): The number of train seeds.
+ - test_seeds (:obj:`int`): The number of test seeds.
+ """
+
+ from dizoo.maze.envs import Maze
+
+ def load_env(seed):
+ ccc = easydict.EasyDict({'size': 16})
+ e = Maze(ccc)
+ e.seed(seed)
+ e.reset()
+ return e
+
+ envs = [load_env(i) for i in range(train_seeds + test_seeds)]
+
+ observations_train = []
+ observations_test = []
+ bfs_input_maps_train = []
+ bfs_input_maps_test = []
+ bfs_output_maps_train = []
+ bfs_output_maps_test = []
+ for idx, env in enumerate(envs):
+ if idx < train_seeds:
+ observations = observations_train
+ bfs_input_maps = bfs_input_maps_train
+ bfs_output_maps = bfs_output_maps_train
+ else:
+ observations = observations_test
+ bfs_input_maps = bfs_input_maps_test
+ bfs_output_maps = bfs_output_maps_test
+
+ start_obs = env.process_states(env._get_obs(), env.get_maze_map())
+ _, track_back = get_vi_sequence(env, start_obs)
+ env_observations = torch.stack([track_back[i][0] for i in range(len(track_back))], dim=0)
+
+ for i in range(env_observations.shape[0]):
+ bfs_sequence, _ = get_vi_sequence(env, env_observations[i].numpy().astype(np.int32)) # [L, W, W]
+ bfs_input_map = env.n_action * np.ones([env.size, env.size], dtype=np.int64)
+
+ for j in range(bfs_sequence.shape[0]):
+ bfs_input_maps.append(torch.from_numpy(bfs_input_map))
+ bfs_output_maps.append(torch.from_numpy(bfs_sequence[j]))
+ observations.append(env_observations[i])
+ bfs_input_map = bfs_sequence[j]
+
+ train_data = PCDataset(
+ (
+ torch.stack(observations_train, dim=0),
+ torch.stack(bfs_input_maps_train, dim=0),
+ torch.stack(bfs_output_maps_train, dim=0),
+ )
+ )
+ test_data = PCDataset(
+ (
+ torch.stack(observations_test, dim=0),
+ torch.stack(bfs_input_maps_test, dim=0),
+ torch.stack(bfs_output_maps_test, dim=0),
+ )
+ )
+
+ return train_data, test_data
+
+
+@DATASET_REGISTRY.register('bco')
+class BCODataset(Dataset):
+ """
+ Overview:
+ Dataset for Behavioral Cloning from Observation.
+ Interfaces:
+ ``__init__``, ``__len__``, ``__getitem__``
+ Properties:
+ - obs (:obj:`np.ndarray`): The observation array.
+ - action (:obj:`np.ndarray`): The action array.
+ """
+
+ def __init__(self, data=None):
+ """
+ Overview:
+ Initialization method of BCODataset.
+ Arguments:
+ - data (:obj:`dict`): The data dict.
+ """
+
+ if data is None:
+ raise ValueError('Dataset can not be empty!')
+ else:
+ self._data = data
+
+ def __len__(self):
+ """
+ Overview:
+ Get the length of the dataset.
+ """
+
+ return len(self._data['obs'])
+
+ def __getitem__(self, idx):
+ """
+ Overview:
+ Get the item of the dataset.
+ Arguments:
+ - idx (:obj:`int`): The index of the dataset.
+ """
+
+ return {k: self._data[k][idx] for k in self._data.keys()}
+
+ @property
+ def obs(self):
+ """
+ Overview:
+ Get the observation array.
+ """
+
+ return self._data['obs']
+
+ @property
+ def action(self):
+ """
+ Overview:
+ Get the action array.
+ """
+
+ return self._data['action']
+
+
+@DATASET_REGISTRY.register('diffuser_traj')
+class SequenceDataset(torch.utils.data.Dataset):
+ """
+ Overview:
+ Dataset for diffuser.
+ Interfaces:
+ ``__init__``, ``__len__``, ``__getitem__``
+ """
+
+ def __init__(self, cfg):
+ """
+ Overview:
+ Initialization method of SequenceDataset.
+ Arguments:
+ - cfg (:obj:`dict`): The config dict.
+ """
+
+ import gym
+
+ env_id = cfg.env.env_id
+ data_path = cfg.policy.collect.get('data_path', None)
+ env = gym.make(env_id)
+
+ dataset = env.get_dataset()
+
+ self.returns_scale = cfg.env.returns_scale
+ self.horizon = cfg.env.horizon
+ self.max_path_length = cfg.env.max_path_length
+ self.discount = cfg.policy.learn.discount_factor
+ self.discounts = self.discount ** np.arange(self.max_path_length)[:, None]
+ self.use_padding = cfg.env.use_padding
+ self.include_returns = cfg.env.include_returns
+ self.env_id = cfg.env.env_id
+ itr = self.sequence_dataset(env, dataset)
+ self.n_episodes = 0
+
+ fields = {}
+ for k in dataset.keys():
+ if 'metadata' in k:
+ continue
+ fields[k] = []
+ fields['path_lengths'] = []
+
+ for i, episode in enumerate(itr):
+ path_length = len(episode['observations'])
+ assert path_length <= self.max_path_length
+ fields['path_lengths'].append(path_length)
+ for key, val in episode.items():
+ if key not in fields:
+ fields[key] = []
+ if val.ndim < 2:
+ val = np.expand_dims(val, axis=-1)
+ shape = (self.max_path_length, val.shape[-1])
+ arr = np.zeros(shape, dtype=np.float32)
+ arr[:path_length] = val
+ fields[key].append(arr)
+ if episode['terminals'].any() and cfg.env.termination_penalty and 'timeouts' in episode:
+ assert not episode['timeouts'].any(), 'Penalized a timeout episode for early termination'
+ fields['rewards'][-1][path_length - 1] += cfg.env.termination_penalty
+ self.n_episodes += 1
+
+ for k in fields.keys():
+ fields[k] = np.array(fields[k])
+
+ self.normalizer = DatasetNormalizer(fields, cfg.policy.normalizer, path_lengths=fields['path_lengths'])
+ self.indices = self.make_indices(fields['path_lengths'], self.horizon)
+
+ self.observation_dim = cfg.env.obs_dim
+ self.action_dim = cfg.env.action_dim
+ self.fields = fields
+ self.normalize()
+ self.normed = False
+ if cfg.env.normed:
+ self.vmin, self.vmax = self._get_bounds()
+ self.normed = True
+
+ # shapes = {key: val.shape for key, val in self.fields.items()}
+ # print(f'[ datasets/mujoco ] Dataset fields: {shapes}')
+
+ def sequence_dataset(self, env, dataset=None):
+ """
+ Overview:
+ Sequence the dataset.
+ Arguments:
+ - env (:obj:`gym.Env`): The gym env.
+ - dataset (:obj:`dict`): The raw dataset dict returned by ``env.get_dataset()``.
+ """
+
+ import collections
+ N = dataset['rewards'].shape[0]
+ if 'maze2d' in env.spec.id:
+ dataset = self.maze2d_set_terminals(env, dataset)
+ data_ = collections.defaultdict(list)
+
+ # The newer version of the dataset adds an explicit
+ # timeouts field. Keep the old method for backwards compatibility.
+ use_timeouts = 'timeouts' in dataset
+
+ episode_step = 0
+ for i in range(N):
+ done_bool = bool(dataset['terminals'][i])
+ if use_timeouts:
+ final_timestep = dataset['timeouts'][i]
+ else:
+ final_timestep = (episode_step == env._max_episode_steps - 1)
+
+ for k in dataset:
+ if 'metadata' in k:
+ continue
+ data_[k].append(dataset[k][i])
+
+ if done_bool or final_timestep:
+ episode_step = 0
+ episode_data = {}
+ for k in data_:
+ episode_data[k] = np.array(data_[k])
+ if 'maze2d' in env.spec.id:
+ episode_data = self.process_maze2d_episode(episode_data)
+ yield episode_data
+ data_ = collections.defaultdict(list)
+
+ episode_step += 1
+
+ def maze2d_set_terminals(self, env, dataset):
+ """
+ Overview:
+ Set the terminals for maze2d.
+ Arguments:
+ - env (:obj:`gym.Env`): The gym env.
+ - dataset (:obj:`dict`): The dataset dict.
+ """
+
+ goal = env.get_target()
+ threshold = 0.5
+
+ xy = dataset['observations'][:, :2]
+ distances = np.linalg.norm(xy - goal, axis=-1)
+ at_goal = distances < threshold
+ timeouts = np.zeros_like(dataset['timeouts'])
+
+ # timeout at time t iff
+ # at goal at time t and
+ # not at goal at time t + 1
+ timeouts[:-1] = at_goal[:-1] * ~at_goal[1:]
+
+ timeout_steps = np.where(timeouts)[0]
+ path_lengths = timeout_steps[1:] - timeout_steps[:-1]
+
+ print(
+ f'[ utils/preprocessing ] Segmented {env.spec.id} | {len(path_lengths)} paths | '
+ f'min length: {path_lengths.min()} | max length: {path_lengths.max()}'
+ )
+
+ dataset['timeouts'] = timeouts
+ return dataset
+
+ def process_maze2d_episode(self, episode):
+ """
+ Overview:
+ Process the maze2d episode and add a ``next_observations`` field to it.
+ Arguments:
+ - episode (:obj:`dict`): The episode dict.
+ """
+
+ assert 'next_observations' not in episode
+ length = len(episode['observations'])
+ next_observations = episode['observations'][1:].copy()
+ for key, val in episode.items():
+ episode[key] = val[:-1]
+ episode['next_observations'] = next_observations
+ return episode
+
+ def normalize(self, keys=['observations', 'actions']):
+ """
+ Overview:
+ Normalize the dataset fields that will be predicted by the diffusion model.
+ Arguments:
+ - keys (:obj:`list`): The list of keys.
+ """
+
+ for key in keys:
+ array = self.fields[key].reshape(self.n_episodes * self.max_path_length, -1)
+ normed = self.normalizer.normalize(array, key)
+ self.fields[f'normed_{key}'] = normed.reshape(self.n_episodes, self.max_path_length, -1)
+
+ def make_indices(self, path_lengths, horizon):
+ """
+ Overview:
+ Make indices for sampling from dataset. Each index maps to a datapoint.
+ Arguments:
+ - path_lengths (:obj:`np.ndarray`): The path length array.
+ - horizon (:obj:`int`): The horizon.
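+ Examples:
+ An illustrative trace with hypothetical numbers: given ``path_lengths=[5]``, ``horizon=3``,
+ ``max_path_length=10`` and ``use_padding=False``, ``max_start = min(min(5 - 1, 10 - 3), 5 - 3) = 2``,
+ so the produced indices are ``[(0, 0, 3), (0, 1, 4)]`` (episode index, start, end).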
+ """
+
+ indices = []
+ for i, path_length in enumerate(path_lengths):
+ max_start = min(path_length - 1, self.max_path_length - horizon)
+ if not self.use_padding:
+ max_start = min(max_start, path_length - horizon)
+ for start in range(max_start):
+ end = start + horizon
+ indices.append((i, start, end))
+ indices = np.array(indices)
+ return indices
+
+ def get_conditions(self, observations):
+ """
+ Overview:
+ Get the conditions on current observation for planning.
+ Arguments:
+ - observations (:obj:`np.ndarray`): The observation array.
+ """
+
+ if 'maze2d' in self.env_id:
+ return {'condition_id': [0, self.horizon - 1], 'condition_val': [observations[0], observations[-1]]}
+ else:
+ return {'condition_id': [0], 'condition_val': [observations[0]]}
+
+ def __len__(self):
+ """
+ Overview:
+ Get the length of the dataset.
+ """
+
+ return len(self.indices)
+
+ def _get_bounds(self):
+ """
+ Overview:
+ Get the bounds of the dataset.
+ """
+
+ print('[ datasets/sequence ] Getting value dataset bounds...', end=' ', flush=True)
+ vmin = np.inf
+ vmax = -np.inf
+ for i in range(len(self.indices)):
+ value = self.__getitem__(i)['returns'].item()
+ vmin = min(value, vmin)
+ vmax = max(value, vmax)
+ print('✓')
+ return vmin, vmax
+
+ def normalize_value(self, value):
+ """
+ Overview:
+ Normalize the value.
+ Arguments:
+ - value (:obj:`np.ndarray`): The value array.
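+ Examples:
+ A worked example with hypothetical bounds: if ``self.vmin = 0`` and ``self.vmax = 10``, then
+ ``normalize_value(2.5)`` first maps the value to ``0.25`` in ``[0, 1]`` and returns ``-0.5`` in ``[-1, 1]``.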
+ """
+
+ # [0, 1]
+ normed = (value - self.vmin) / (self.vmax - self.vmin)
+ # [-1, 1]
+ normed = normed * 2 - 1
+ return normed
+
+ def __getitem__(self, idx, eps=1e-4):
+ """
+ Overview:
+ Get the item of the dataset.
+ Arguments:
+ - idx (:obj:`int`): The index of the dataset.
+ - eps (:obj:`float`): The epsilon.
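+ Examples:
+ For illustration with hypothetical numbers: when ``include_returns`` is True, ``discount=0.99`` and the
+ remaining rewards from ``start`` are ``[1, 1]``, the conditioning return is
+ ``(1 * 0.99 ** 0 + 1 * 0.99 ** 1) / returns_scale = 1.99 / returns_scale`` (before optional value
+ normalization).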
+ """
+
+ path_ind, start, end = self.indices[idx]
+
+ observations = self.fields['normed_observations'][path_ind, start:end]
+ actions = self.fields['normed_actions'][path_ind, start:end]
+ done = self.fields['terminals'][path_ind, start:end]
+
+ # conditions = self.get_conditions(observations)
+ trajectories = np.concatenate([actions, observations], axis=-1)
+
+ if self.include_returns:
+ rewards = self.fields['rewards'][path_ind, start:]
+ discounts = self.discounts[:len(rewards)]
+ returns = (discounts * rewards).sum()
+ if self.normed:
+ returns = self.normalize_value(returns)
+ returns = np.array([returns / self.returns_scale], dtype=np.float32)
+ batch = {
+ 'trajectories': trajectories,
+ 'returns': returns,
+ 'done': done,
+ 'action': actions,
+ }
+ else:
+ batch = {
+ 'trajectories': trajectories,
+ 'done': done,
+ 'action': actions,
+ }
+
+ batch.update(self.get_conditions(observations))
+ return batch
+
+
+def hdf5_save(exp_data, expert_data_path):
+ """
+ Overview:
+ Save the data to hdf5.
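+ Arguments:
+ - exp_data (:obj:`list`): List of transition dicts with ``obs``, ``action``, ``reward``, ``done`` and \
+ ``next_obs`` entries.
+ - expert_data_path (:obj:`str`): Output path; any ``.pkl`` suffix is stripped and ``_demos.hdf5`` is appended.
+ Examples:
+ A minimal sketch with dummy data; the tensor shapes are illustrative only:
+ >>> exp_data = [dict(obs=torch.zeros(4), action=torch.zeros(1), reward=torch.zeros(1), done=False,
+ ... next_obs=torch.zeros(4))]
+ >>> hdf5_save(exp_data, './cartpole_expert.pkl') # writes ./cartpole_expert_demos.hdf5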
+ """
+
+ try:
+ import h5py
+ except ImportError:
+ import sys
+ logging.warning("not found h5py package, please install it trough 'pip install h5py' ")
+ sys.exit(1)
+ dataset = h5py.File('%s_demos.hdf5' % expert_data_path.replace('.pkl', ''), 'w')
+ dataset.create_dataset('obs', data=np.array([d['obs'].numpy() for d in exp_data]), compression='gzip')
+ dataset.create_dataset('action', data=np.array([d['action'].numpy() for d in exp_data]), compression='gzip')
+ dataset.create_dataset('reward', data=np.array([d['reward'].numpy() for d in exp_data]), compression='gzip')
+ dataset.create_dataset('done', data=np.array([d['done'] for d in exp_data]), compression='gzip')
+ dataset.create_dataset('next_obs', data=np.array([d['next_obs'].numpy() for d in exp_data]), compression='gzip')
+
+
+def naive_save(exp_data, expert_data_path):
+ """
+ Overview:
+ Save the data to pickle.
+ """
+
+ with open(expert_data_path, 'wb') as f:
+ pickle.dump(exp_data, f)
+
+
+def offline_data_save_type(exp_data, expert_data_path, data_type='naive'):
+ """
+ Overview:
+ Save the offline data.
+ """
+
+ globals()[data_type + '_save'](exp_data, expert_data_path)
+
+
+def create_dataset(cfg, **kwargs) -> Dataset:
+ """
+ Overview:
+ Create dataset.
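+ Arguments:
+ - cfg (:obj:`dict`): Config dict; ``cfg.policy.collect.data_type`` selects the registered dataset class.
+ Examples:
+ A minimal sketch; the config values are placeholders rather than a complete DI-engine config:
+ >>> cfg = dict(policy=dict(collect=dict(data_type='hdf5', data_path='./expert_demos.hdf5')))
+ >>> dataset = create_dataset(cfg)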
+ """
+
+ cfg = EasyDict(cfg)
+ import_module(cfg.get('import_names', []))
+ return DATASET_REGISTRY.build(cfg.policy.collect.data_type, cfg=cfg, **kwargs)
diff --git a/DI-engine/ding/utils/data/structure/__init__.py b/DI-engine/ding/utils/data/structure/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cc58828a61489e4b269786d3e40dfd149f7e9d6
--- /dev/null
+++ b/DI-engine/ding/utils/data/structure/__init__.py
@@ -0,0 +1,2 @@
+from .cache import Cache
+from .lifo_deque import LifoDeque
diff --git a/DI-engine/ding/utils/data/structure/cache.py b/DI-engine/ding/utils/data/structure/cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..836261e6159c156cc4b1dfe6a9b352c86cef4c6e
--- /dev/null
+++ b/DI-engine/ding/utils/data/structure/cache.py
@@ -0,0 +1,142 @@
+import time
+from queue import Queue
+from threading import Thread
+from typing import Any
+
+from ding.utils import LockContext, LockContextType
+
+
+class Cache:
+ """
+ Overview:
+ Data cache for reducing concurrency pressure, with a timeout and a full-queue eject mechanism
+ Interfaces:
+ ``__init__``, ``push_data``, ``get_cached_data_iter``, ``run``, ``close``
+ Properties:
+ remain_data_count
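+ Examples:
+ A minimal single-thread sketch; in practice ``push_data`` is called from producer threads while a
+ consumer thread iterates over ``get_cached_data_iter()``:
+ >>> cache = Cache(maxlen=8, timeout=4)
+ >>> cache.run() # start the timeout monitor thread
+ >>> cache.push_data({'data': 0})
+ >>> data_iter = cache.get_cached_data_iter() # yields data until the 'STOP' flag is received
+ >>> cache.close() # sends 'STOP' so that iteration over data_iter terminates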
+ """
+
+ def __init__(self, maxlen: int, timeout: float, monitor_interval: float = 1.0, _debug: bool = False) -> None:
+ """
+ Overview:
+ Initialize the cache object.
+ Arguments:
+ - maxlen (:obj:`int`): Maximum length of the cache queue.
+ - timeout (:obj:`float`): Maximum number of seconds that data can remain in the cache.
+ - monitor_interval (:obj:`float`): Interval (in seconds) at which the timeout monitor thread checks the cache.
+ - _debug (:obj:`bool`): Whether to use debug mode or not, which enables debug print info.
+ """
+ assert maxlen > 0
+ self.maxlen = maxlen
+ self.timeout = timeout
+ self.monitor_interval = monitor_interval
+ self.debug = _debug
+ # two separate receive and send queue for reducing interaction frequency and interference
+ self.receive_queue = Queue(maxlen)
+ self.send_queue = Queue(maxlen)
+ self.receive_lock = LockContext(type_=LockContextType.THREAD_LOCK)
+ self._timeout_thread = Thread(target=self._timeout_monitor)
+ # the bool flag for gracefully shutting down the timeout monitor thread
+ self._timeout_thread_flag = True
+
+ def push_data(self, data: Any) -> None:
+ """
+ Overview:
+ Push data into the receive queue; if the receive queue is full (after the push), then move all the data
+ in the receive queue into the send queue.
+ Arguments:
+ - data (:obj:`Any`): The data which needs to be added into receive queue
+
+ .. tip::
+ thread-safe
+ """
+ with self.receive_lock:
+ # Push the data item and current time together into queue
+ self.receive_queue.put([data, time.time()])
+ if self.receive_queue.full():
+ self.dprint('receive_queue full, flush all data to send_queue, current len: {}'.format(self.receive_queue.qsize()))
+ while not self.receive_queue.empty():
+ # Only send raw data to send queue
+ self.send_queue.put(self.receive_queue.get()[0])
+
+ def get_cached_data_iter(self) -> 'callable_iterator': # noqa
+ """
+ Overview:
+ Get the iterator of the send queue. Once data is pushed into the send queue, it can be accessed through
+ this iterator. 'STOP' is the end flag of this iterator.
+ Returns:
+ - iterator (:obj:`callable_iterator`) The send queue iterator.
+ """
+ return iter(self.send_queue.get, 'STOP')
+
+ def _timeout_monitor(self) -> None:
+ """
+ Overview:
+ The workflow of the timeout monitor thread.
+ """
+ # Loop until the flag is set to False
+ while self._timeout_thread_flag:
+ # A fixed check interval
+ time.sleep(self.monitor_interval)
+ with self.receive_lock:
+ # For a non-empty receive_queue, check items from head to tail (access only, no pop) until finding
+ # the first one that has not timed out
+ while not self.receive_queue.empty():
+ # Check how long the oldest data has stayed in the receive_queue; if it exceeds the timeout, it is ejected and True is returned
+ is_timeout = self._warn_if_timeout()
+ if not is_timeout:
+ break
+
+ def _warn_if_timeout(self) -> bool:
+ """
+ Overview:
+ Check whether the oldest data in the receive queue has timed out; if so, move it to the send queue.
+ Returns:
+ - result (:obj:`bool`): Whether the oldest cached data has timed out.
+ """
+ wait_time = time.time() - self.receive_queue.queue[0][1]
+ if wait_time >= self.timeout:
+ self.dprint(
+ 'exceed the maximum wait time, eject from the cache (wait_time/timeout: {}/{})'.format(
+ wait_time, self.timeout
+ )
+ )
+ self.send_queue.put(self.receive_queue.get()[0])
+ return True
+ else:
+ return False
+
+ def run(self) -> None:
+ """
+ Overview:
+ Launch the cache internal thread, e.g. timeout monitor thread.
+ """
+ self._timeout_thread.start()
+
+ def close(self) -> None:
+ """
+ Overview:
+ Shut down the cache internal thread and send the end flag to send queue's iterator.
+ """
+ self._timeout_thread_flag = False
+ self.send_queue.put('STOP')
+
+ def dprint(self, s: str) -> None:
+ """
+ Overview:
+ In debug mode, print debug str.
+ Arguments:
+ - s (:obj:`str`): Debug info to be printed.
+ """
+ if self.debug:
+ print('[CACHE] ' + s)
+
+ @property
+ def remain_data_count(self) -> int:
+ """
+ Overview:
+ Return the number of data items remaining in the receive queue.
+ Returns:
+ - count (:obj:`int`): The size of the receive queue.
+ """
+ return self.receive_queue.qsize()
diff --git a/DI-engine/ding/utils/data/structure/lifo_deque.py b/DI-engine/ding/utils/data/structure/lifo_deque.py
new file mode 100644
index 0000000000000000000000000000000000000000..b18c4a0608de7e7b887a9cc77ef23537ccbc6603
--- /dev/null
+++ b/DI-engine/ding/utils/data/structure/lifo_deque.py
@@ -0,0 +1,15 @@
+from queue import LifoQueue
+from collections import deque
+
+
+class LifoDeque(LifoQueue):
+ """
+ Overview:
+ Like LifoQueue, but automatically replaces the oldest data when the queue is full.
+ Interfaces:
+ ``_init``, ``_put``, ``_get``
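+ Examples:
+ A minimal sketch of the overwrite-on-full behaviour:
+ >>> q = LifoDeque(maxsize=2)
+ >>> for i in range(3):
+ ... q.put(i) # never blocks; the oldest item (0) is silently dropped
+ >>> q.get() # LIFO order
+ 2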
+ """
+
+ def _init(self, maxsize):
+ self.maxsize = maxsize + 1
+ self.queue = deque(maxlen=maxsize)
diff --git a/DI-engine/ding/utils/data/tests/dataloader_speed/experiment_dataloader_speed.py b/DI-engine/ding/utils/data/tests/dataloader_speed/experiment_dataloader_speed.py
new file mode 100644
index 0000000000000000000000000000000000000000..b94b1ce1acd94d926fb4c99b5a5cb0ae09a26ac0
--- /dev/null
+++ b/DI-engine/ding/utils/data/tests/dataloader_speed/experiment_dataloader_speed.py
@@ -0,0 +1,219 @@
+import time
+import torch
+import torch.nn as nn
+from torch.utils.data import DataLoader, Dataset
+from functools import partial
+from itertools import product
+import os.path as osp
+import os
+import random
+
+from ding.utils import EasyTimer, read_file
+from ding.utils.data import AsyncDataLoader
+
+exp_times = 10
+max_iter = 50
+num_workers = 8
+use_cuda = True
+
+# read_file_time, process_time, batch_size, chunk_size, env_name
+env_args = [
+ (0.0008, 0.005, 128, 32, "small"),
+ (0.0008, 0.05, 64, 16, "middle"),
+ (0.6, 0.2, 4, 1, "big16"),
+ (2, 0.25, 4, 1, "big64"),
+]
+data_infer_ratio_args = [1, 2, 4]
+
+args = [item for item in product(*[env_args, data_infer_ratio_args])]
+
+out_str_list = []
+
+
+class MyDataset(Dataset):
+
+ def __init__(self, file_time, process_time, batch_size, name):
+ self.data = torch.randn(256, 256)
+ self.file_time = file_time
+ self.process_time = process_time
+ self.batch_size = batch_size
+ self.path = osp.join(osp.dirname(__file__), "../traj_files/{}/data".format(name))
+ self.file_list = os.listdir(self.path)
+ self.file_sequence = random.sample(range(0, len(self.file_list)), len(self.file_list))
+ self.i = 0
+
+ def __len__(self):
+ return self.batch_size * max_iter * 2
+
+ def __getitem__(self, idx):
+ try:
+ s = read_file(osp.join(self.path, self.file_list[self.file_sequence[self.i]]))
+ except:
+ print("file read meets an error")
+ time.sleep(self.file_time)
+ self.i = (self.i + 1) % len(self.file_list)
+ time.sleep(self.process_time)
+ return [self.data, idx]
+
+
+class MyModel(nn.Module):
+
+ def __init__(self, infer_time):
+ super().__init__()
+ self.main = [nn.Linear(256, 256) for _ in range(10)]
+ self.main = nn.Sequential(*self.main)
+ self.infer_time = infer_time
+
+ def forward(self, x):
+ idx = x[1]
+ # No real infer here.
+ time.sleep(self.infer_time)
+ return [x, idx]
+
+
+def get_data_source(dataset):
+
+ def data_source_fn(batch_size):
+ return [partial(dataset.__getitem__, idx=i) for i in range(batch_size)]
+
+ return data_source_fn
+
+
+def entry(env, read_infer_ratio, use_cuda):
+ file_time, process_time, batch_size, chunk_size, data_name = env[0], env[1], env[2], env[3], env[4]
+ data_time = file_time + process_time
+ infer_time = data_time * (batch_size / num_workers) * 1.05 / read_infer_ratio
+ out_str = '\n===== each_data: {:.4f}({}), infer: {:.4f}, read/infer: {:.4f}, \
+ batch_size: {}, chunk_size: {} ====='.format(
+ data_time, data_name, infer_time, read_infer_ratio, batch_size, chunk_size
+ )
+ out_str_list.append(out_str)
+ print(out_str)
+
+ model = MyModel(infer_time)
+ if use_cuda:
+ model.cuda()
+ timer = EasyTimer()
+
+ # ### Our DataLoader ####
+ total_sum_time_list = []
+ total_data_time_list = []
+ total_infer_time_list = []
+ for _ in range(exp_times):
+ print('\t----- Our DataLoader -----')
+ dataset = MyDataset(file_time, process_time, batch_size, data_name)
+ data_source = get_data_source(dataset)
+ device = 'cuda' if use_cuda else 'cpu'
+ our_dataloader = AsyncDataLoader(
+ data_source, batch_size, device, num_workers=num_workers, chunk_size=chunk_size
+ )
+ iter = 0
+ total_data_time = 0.
+ total_infer_time = 0.
+ total_sum_time = 0.
+ while True:
+ with timer:
+ data = next(our_dataloader)
+ data_time = timer.value
+ with timer:
+ with torch.no_grad():
+ _, idx = model(data)
+ infer_time = timer.value
+ sum_time = data_time + infer_time
+ if iter > 5: # ignore start-5-iter time
+ total_data_time += data_time
+ total_infer_time += infer_time
+ print(
+ '\t\titer {:0>2d}, sum_time: {:.4f}, data_time: {:.4f}, infer_time: {:.4f}'.format(
+ iter, sum_time, data_time, infer_time
+ )
+ )
+ iter += 1
+ if iter == max_iter:
+ break
+ total_sum_time = total_data_time + total_infer_time
+ out_str = '\ttotal_sum_time: {:.4f}, total_data_time: {:.4f}, \
+ total_infer_time: {:.4f}, data/sum: {:.4f}'.format(
+ total_sum_time, total_data_time, total_infer_time, total_data_time / total_sum_time
+ )
+ # out_str_list.append(out_str)
+ print(out_str)
+ our_dataloader.__del__()
+ torch.cuda.empty_cache()
+
+ total_sum_time_list.append(total_sum_time)
+ total_data_time_list.append(total_data_time)
+ total_infer_time_list.append(total_infer_time)
+ total_sum_time = sum(total_sum_time_list) / len(total_sum_time_list)
+ total_data_time = sum(total_data_time_list) / len(total_data_time_list)
+ total_infer_time = sum(total_infer_time_list) / len(total_infer_time_list)
+ out_str = '\t(Our DataLoader {} average) total_sum_time: {:.4f}, \
+ total_data_time: {:.4f}, total_infer_time: {:.4f}, data/sum: {:.4f}'.format(
+ exp_times, total_sum_time, total_data_time, total_infer_time, total_data_time / total_sum_time
+ )
+ out_str_list.append(out_str)
+ print(out_str)
+
+ # ### PyTorch DataLoader ####
+ for real_num_workers in [0, 8]:
+ total_sum_time_list = []
+ total_data_time_list = []
+ total_infer_time_list = []
+ for _ in range(exp_times):
+ print('\t----- PyTorch DataLoader (num_workers = {}) -----'.format(real_num_workers))
+ dataset = MyDataset(file_time, process_time, batch_size, data_name)
+ torch_dataloader = DataLoader(dataset, batch_size, num_workers=real_num_workers)
+ torch_dataloader_iter = torch_dataloader.__iter__()
+ iter = 0
+ total_data_time = 0.
+ total_infer_time = 0.
+ total_sum_time = 0.
+ while True:
+ with timer:
+ data = next(torch_dataloader_iter)[0]
+ if use_cuda:
+ data = data.cuda()
+ data_time = timer.value
+ with timer:
+ with torch.no_grad():
+ _, idx = model(data)
+ infer_time = timer.value
+ sum_time = data_time + infer_time
+ if iter > 5: # ignore start-5-iter time
+ total_data_time += data_time
+ total_infer_time += infer_time
+ print(
+ '\t\titer {:0>2d}, sum_time: {:.4f}, data_time: {:.4f}, infer_time: {:.4f}'.format(
+ iter, sum_time, data_time, infer_time
+ )
+ )
+ iter += 1
+ if iter == max_iter:
+ break
+ total_sum_time = total_data_time + total_infer_time
+ out_str = '\ttotal_sum_time: {:.4f}, total_data_time: {:.4f}, \
+ total_infer_time: {:.4f}, data/sum: {:.4f}'.format(
+ total_sum_time, total_data_time, total_infer_time, total_data_time / total_sum_time
+ )
+ # out_str_list.append(out_str)
+ print(out_str)
+ torch.cuda.empty_cache()
+
+ total_sum_time_list.append(total_sum_time)
+ total_data_time_list.append(total_data_time)
+ total_infer_time_list.append(total_infer_time)
+ total_sum_time = sum(total_sum_time_list) / len(total_sum_time_list)
+ total_data_time = sum(total_data_time_list) / len(total_data_time_list)
+ total_infer_time = sum(total_infer_time_list) / len(total_infer_time_list)
+ out_str = '\t(PyTorch DataLoader baseline {} average) total_sum_time: {:.4f}, \
+ total_data_time: {:.4f}, total_infer_time: {:.4f}, data/sum: {:.4f}'.format(
+ exp_times, total_sum_time, total_data_time, total_infer_time, total_data_time / total_sum_time
+ )
+ out_str_list.append(out_str)
+ print(out_str)
+
+
+if __name__ == "__main__":
+ for env, read_infer_ratio in args:
+ entry(env, read_infer_ratio, use_cuda=use_cuda)
+ print("\n".join(out_str_list))
diff --git a/DI-engine/ding/utils/data/tests/test_cache.py b/DI-engine/ding/utils/data/tests/test_cache.py
new file mode 100644
index 0000000000000000000000000000000000000000..05f3e6f471127c2e6a9f1b9e2fa443442a9e42a6
--- /dev/null
+++ b/DI-engine/ding/utils/data/tests/test_cache.py
@@ -0,0 +1,58 @@
+import threading
+import time
+from threading import Thread
+
+import numpy as np
+import pytest
+
+from ding.utils.data.structure import Cache
+
+
+@pytest.mark.unittest
+class TestCache:
+ cache = Cache(16, 4, monitor_interval=1.0, _debug=True)
+ send_count = 0
+ produce_count = 0
+
+ def producer(self, id):
+ time.sleep(1)
+ begin_time = time.time()
+ count = 0
+ while time.time() - begin_time < 20:
+ t = np.random.randint(1, 6)
+ time.sleep(t)
+ print('[PRODUCER] thread {} use {} second to produce a data'.format(id, t))
+ self.cache.push_data({'data': []})
+ count += 1
+ print('[PRODUCER] thread {} finish job, total produce {} data'.format(id, count))
+ self.produce_count += count
+
+ def consumer(self):
+ for data in self.cache.get_cached_data_iter():
+ self.send_count += 1
+ print('[CONSUMER] cache send {}'.format(self.send_count))
+
+ def test(self):
+ producer_num = 8
+
+ self.cache.run()
+ threadings = [Thread(target=self.producer, args=(i, )) for i in range(producer_num)]
+ for t in threadings:
+ t.start()
+
+ consumer_thread = Thread(target=self.consumer)
+ consumer_thread.start()
+
+ for t in threadings:
+ t.join()
+
+ # wait timeout mechanism to clear the cache
+ time.sleep(4 + 1 + 0.1)
+
+ assert (self.cache.remain_data_count == 0)
+ assert (self.send_count == self.produce_count)
+
+ self.cache.close()
+ # wait the cache internal thread close and the consumer_thread get 'STOP' signal
+ time.sleep(1 + 0.5)
+ assert (not consumer_thread.is_alive())
diff --git a/DI-engine/ding/utils/data/tests/test_collate_fn.py b/DI-engine/ding/utils/data/tests/test_collate_fn.py
new file mode 100644
index 0000000000000000000000000000000000000000..83611377c12db856d594219d1da87152422ce3fa
--- /dev/null
+++ b/DI-engine/ding/utils/data/tests/test_collate_fn.py
@@ -0,0 +1,174 @@
+import pytest
+from collections import namedtuple
+import random
+import numpy as np
+import torch
+from ding.utils.data import timestep_collate, default_collate, default_decollate, diff_shape_collate
+
+B, T = 4, 3
+
+
+@pytest.mark.unittest
+class TestTimestepCollate:
+
+ def get_data(self):
+ data = {
+ 'obs': [torch.randn(4) for _ in range(T)],
+ 'reward': [torch.FloatTensor([0]) for _ in range(T)],
+ 'done': [False for _ in range(T)],
+ 'prev_state': [(torch.randn(3), torch.randn(3)) for _ in range(T)],
+ 'action': [[torch.randn(3), torch.randn(5)] for _ in range(T)],
+ }
+ return data
+
+ def get_multi_shape_state_data(self):
+ data = {
+ 'obs': [torch.randn(4) for _ in range(T)],
+ 'reward': [torch.FloatTensor([0]) for _ in range(T)],
+ 'done': [False for _ in range(T)],
+ 'prev_state': [
+ [(torch.randn(3), torch.randn(5)), (torch.randn(4), ), (torch.randn(5), torch.randn(6))]
+ for _ in range(T)
+ ],
+ 'action': [[torch.randn(3), torch.randn(5)] for _ in range(T)],
+ }
+ return data
+
+ def test(self):
+ batch = timestep_collate([self.get_data() for _ in range(B)])
+ assert isinstance(batch, dict)
+ assert set(batch.keys()) == set(['obs', 'reward', 'done', 'prev_state', 'action'])
+ assert batch['obs'].shape == (T, B, 4)
+ assert batch['reward'].shape == (T, B)
+ assert batch['done'].shape == (T, B) and batch['done'].dtype == torch.bool
+ assert isinstance(batch['prev_state'], list)
+ assert len(batch['prev_state']) == T and len(batch['prev_state'][0]) == B
+ assert isinstance(batch['action'], list) and len(batch['action']) == T
+ assert batch['action'][0][0].shape == (B, 3)
+ assert batch['action'][0][1].shape == (B, 5)
+
+ # hidden_state might contain multi prev_states with different shapes
+ batch = timestep_collate([self.get_multi_shape_state_data() for _ in range(B)])
+ assert isinstance(batch, dict)
+ assert set(batch.keys()) == set(['obs', 'reward', 'done', 'prev_state', 'action'])
+ assert batch['obs'].shape == (T, B, 4)
+ assert batch['reward'].shape == (T, B)
+ assert batch['done'].shape == (T, B) and batch['done'].dtype == torch.bool
+ assert isinstance(batch['prev_state'], list)
+ print(batch['prev_state'][0][0])
+ assert len(batch['prev_state']) == T and len(batch['prev_state'][0]
+ ) == B and len(batch['prev_state'][0][0]) == 3
+ assert isinstance(batch['action'], list) and len(batch['action']) == T
+ assert batch['action'][0][0].shape == (B, 3)
+ assert batch['action'][0][1].shape == (B, 5)
+
+
+@pytest.mark.unittest
+class TestDefaultCollate:
+
+ def test_numpy(self):
+ data = [np.random.randn(4, 3).astype(np.float64) for _ in range(5)]
+ data = default_collate(data)
+ assert data.shape == (5, 4, 3)
+ assert data.dtype == torch.float64
+ data = [float(np.random.randn(1)[0]) for _ in range(6)]
+ data = default_collate(data)
+ assert data.shape == (6, )
+ assert data.dtype == torch.float32
+ with pytest.raises(TypeError):
+ default_collate([np.array(['str']) for _ in range(3)])
+
+ def test_basic(self):
+ data = [random.random() for _ in range(3)]
+ data = default_collate(data)
+ assert data.shape == (3, )
+ assert data.dtype == torch.float32
+ data = [random.randint(0, 10) for _ in range(3)]
+ data = default_collate(data)
+ assert data.shape == (3, )
+ assert data.dtype == torch.int64
+ data = ['str' for _ in range(4)]
+ data = default_collate(data)
+ assert len(data) == 4
+ assert all([s == 'str' for s in data])
+ T = namedtuple('T', ['x', 'y'])
+ data = [T(1, 2) for _ in range(4)]
+ data = default_collate(data)
+ assert isinstance(data, T)
+ assert data.x.shape == (4, ) and data.x.eq(1).sum() == 4
+ assert data.y.shape == (4, ) and data.y.eq(2).sum() == 4
+ with pytest.raises(TypeError):
+ default_collate([object() for _ in range(4)])
+
+ data = [{'collate_ignore_data': random.random()} for _ in range(4)]
+ data = default_collate(data)
+ assert isinstance(data, dict)
+ assert len(data['collate_ignore_data']) == 4
+
+
+@pytest.mark.unittest
+class TestDefaultDecollate:
+
+ def test(self):
+ with pytest.raises(TypeError):
+ default_decollate([object() for _ in range(4)])
+ data = torch.randn(4, 3, 5)
+ data = default_decollate(data)
+ print([d.shape for d in data])
+ assert len(data) == 4 and all([d.shape == (3, 5) for d in data])
+ data = [torch.randn(8, 2, 4), torch.randn(8, 5)]
+ data = default_decollate(data)
+ assert len(data) == 8 and all([d[0].shape == (2, 4) and d[1].shape == (5, ) for d in data])
+ data = {
+ 'logit': torch.randn(4, 13),
+ 'action': torch.randint(0, 13, size=(4, )),
+ 'prev_state': [(torch.zeros(3, 1, 12), torch.zeros(3, 1, 12)) for _ in range(4)],
+ }
+ data = default_decollate(data)
+ assert len(data) == 4 and isinstance(data, list)
+ assert all([d['logit'].shape == (13, ) for d in data])
+ assert all([d['action'].shape == (1, ) for d in data])
+ assert all([len(d['prev_state']) == 2 and d['prev_state'][0].shape == (3, 1, 12) for d in data])
+
+
+@pytest.mark.unittest
+class TestDiffShapeCollate:
+
+ def test(self):
+ with pytest.raises(TypeError):
+ diff_shape_collate([object() for _ in range(4)])
+ data = [
+ {
+ 'item1': torch.randn(4),
+ 'item2': None,
+ 'item3': torch.randn(3),
+ 'item4': np.random.randn(5, 6)
+ },
+ {
+ 'item1': torch.randn(5),
+ 'item2': torch.randn(6),
+ 'item3': torch.randn(3),
+ 'item4': np.random.randn(5, 6)
+ },
+ ]
+ data = diff_shape_collate(data)
+ assert isinstance(data['item1'], list) and len(data['item1']) == 2
+ assert isinstance(data['item2'], list) and len(data['item2']) == 2 and data['item2'][0] is None
+ assert data['item3'].shape == (2, 3)
+ assert data['item4'].shape == (2, 5, 6)
+ data = [
+ {
+ 'item1': 1,
+ 'item2': 3,
+ 'item3': 2.0
+ },
+ {
+ 'item1': None,
+ 'item2': 4,
+ 'item3': 2.0
+ },
+ ]
+ data = diff_shape_collate(data)
+ assert isinstance(data['item1'], list) and len(data['item1']) == 2 and data['item1'][1] is None
+ assert data['item2'].shape == (2, ) and data['item2'].dtype == torch.int64
+ assert data['item3'].shape == (2, ) and data['item3'].dtype == torch.float32
diff --git a/DI-engine/ding/utils/data/tests/test_dataloader.py b/DI-engine/ding/utils/data/tests/test_dataloader.py
new file mode 100644
index 0000000000000000000000000000000000000000..9fc78113dfd5557f93a482ade63202a071c7754e
--- /dev/null
+++ b/DI-engine/ding/utils/data/tests/test_dataloader.py
@@ -0,0 +1,104 @@
+import pytest
+import threading
+import time
+import torch
+import torch.nn as nn
+from functools import partial
+from itertools import product
+
+from ding.utils import EasyTimer
+from ding.utils.data import AsyncDataLoader
+
+batch_size_args = [3, 6]
+num_workers_args = [0, 4]
+chunk_size_args = [1, 3]
+args = [item for item in product(*[batch_size_args, num_workers_args, chunk_size_args])]
+unittest_args = [item for item in product(*[[3], [2], [1]])]
+
+
+class Dataset(object):
+
+ def __init__(self):
+ self.data = torch.randn(256, 256)
+
+ def __len__(self):
+ return 100
+
+ def __getitem__(self, idx):
+ time.sleep(0.5)
+ return [self.data, idx]
+
+
+class TestAsyncDataLoader:
+
+ def get_data_source(self):
+ dataset = Dataset()
+
+ def data_source_fn(batch_size):
+ return [partial(dataset.__getitem__, idx=i) for i in range(batch_size)]
+
+ return data_source_fn
+
+ def get_model(self):
+
+ class Model(nn.Module):
+
+ def __init__(self):
+ super(Model, self).__init__()
+ self.main = [nn.Linear(256, 256) for _ in range(10)]
+ self.main = nn.Sequential(*self.main)
+
+ def forward(self, x):
+ idx = x[1]
+ x = self.main(x[0])
+ time.sleep(1)
+ return [x, idx]
+
+ return Model()
+
+ # @pytest.mark.unittest
+ @pytest.mark.parametrize('batch_size, num_workers, chunk_size', unittest_args)
+ def test_cpu(self, batch_size, num_workers, chunk_size):
+ self.entry(batch_size, num_workers, chunk_size, use_cuda=False)
+
+ @pytest.mark.cudatest
+ @pytest.mark.parametrize('batch_size, num_workers, chunk_size', args)
+ def test_gpu(self, batch_size, num_workers, chunk_size):
+ self.entry(batch_size, num_workers, chunk_size, use_cuda=True)
+ torch.cuda.empty_cache()
+
+ def entry(self, batch_size, num_workers, chunk_size, use_cuda):
+ model = self.get_model()
+ if use_cuda:
+ model.cuda()
+ timer = EasyTimer()
+ data_source = self.get_data_source()
+ device = 'cuda' if use_cuda else 'cpu'
+ dataloader = AsyncDataLoader(data_source, batch_size, device, num_workers=num_workers, chunk_size=chunk_size)
+ count = 0
+ total_data_time = 0.
+ while True:
+ with timer:
+ data = next(dataloader)
+ data_time = timer.value
+            if count > 2:  # ignore the data time of the first 3 iterations (warm-up)
+ total_data_time += data_time
+ with timer:
+ with torch.no_grad():
+ _, idx = model(data)
+ if use_cuda:
+ idx = idx.cpu()
+ sorted_idx = torch.sort(idx)[0]
+ assert sorted_idx.eq(torch.arange(batch_size)).sum() == batch_size, idx
+ model_time = timer.value
+ print('count {}, data_time: {}, model_time: {}'.format(count, data_time, model_time))
+ count += 1
+ if count == 10:
+ break
+ if num_workers < 1:
+ assert total_data_time <= 7 * batch_size * 0.5 + 7 * 0.01 - 7 * 1
+ else:
+ assert total_data_time <= 7 * 0.008
+ dataloader.__del__()
+ time.sleep(0.5)
+ assert len(threading.enumerate()) <= 2, threading.enumerate()
diff --git a/DI-engine/ding/utils/data/tests/test_dataset.py b/DI-engine/ding/utils/data/tests/test_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..423d5899a78e6a86f51a2e82e590d31744dc143d
--- /dev/null
+++ b/DI-engine/ding/utils/data/tests/test_dataset.py
@@ -0,0 +1,85 @@
+import pytest
+import torch
+from easydict import EasyDict
+import os
+from ding.utils.data import offline_data_save_type, create_dataset, NaiveRLDataset, D4RLDataset, HDF5Dataset
+
+cfg1 = dict(policy=dict(collect=dict(
+ data_type='naive',
+ data_path='./expert.pkl',
+), ))
+
+cfg2 = dict(
+ env=dict(norm_obs=dict(use_norm=True, offline_stats=dict(use_offline_stats=True))),
+ policy=dict(collect=dict(data_type='hdf5', data_path='./expert_demos.hdf5')),
+)
+
+cfg3 = dict(env=dict(env_id='hopper-expert-v0'), policy=dict(collect=dict(data_type='d4rl', ), ))
+
+cfgs = [cfg1, cfg2] # cfg3
+unittest_args = ['naive', 'hdf5']
+
+# fake transition & data
+transition = {}
+transition['obs'] = torch.zeros((3, 1))
+transition['next_obs'] = torch.zeros((3, 1))
+transition['action'] = torch.zeros((1, 1))
+transition['reward'] = torch.tensor((1, ))
+transition['done'] = False
+transition['collect_iter'] = 0
+
+fake_data = [transition for i in range(32)]
+expert_data_path = './expert.pkl'
+
+
+@pytest.mark.parametrize('data_type', unittest_args)
+@pytest.mark.unittest
+def test_offline_data_save_type(data_type):
+ offline_data_save_type(exp_data=fake_data, expert_data_path=expert_data_path, data_type=data_type)
+
+
+@pytest.mark.parametrize('cfg', cfgs)
+@pytest.mark.unittest
+def test_dataset(cfg):
+ cfg = EasyDict(cfg)
+ create_dataset(cfg)
+
+
+@pytest.mark.parametrize('cfg', [cfg1])
+@pytest.mark.unittest
+def test_NaiveRLDataset(cfg):
+ cfg = EasyDict(cfg)
+ NaiveRLDataset(cfg)
+ dataset = NaiveRLDataset(expert_data_path)
+ assert type(len(dataset)) == int
+ assert dataset[0] is not None
+
+
+# @pytest.mark.parametrize('cfg', [cfg3])
+# @pytest.mark.unittest
+# def test_D4RLDataset(cfg):
+# cfg = EasyDict(cfg)
+# dataset = D4RLDataset(cfg)
+
+
+@pytest.mark.parametrize('cfg', [cfg2])
+@pytest.mark.unittest
+def test_HDF5Dataset(cfg):
+ cfg = EasyDict(cfg)
+ dataset = HDF5Dataset(cfg)
+ assert dataset.mean is not None and dataset.std[0] is not None
+ assert dataset._data['obs'].mean(0)[0] == 0
+ assert type(len(dataset)) == int
+ assert dataset[0] is not None
+
+
+@pytest.fixture(scope="session", autouse=True)
+def cleanup(request):
+
+ def remove_test_dir():
+ if os.path.exists('./expert.pkl'):
+ os.remove('./expert.pkl')
+ if os.path.exists('./expert_demos.hdf5'):
+ os.remove('./expert_demos.hdf5')
+
+ request.addfinalizer(remove_test_dir)
diff --git a/DI-engine/ding/utils/default_helper.py b/DI-engine/ding/utils/default_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..1881ca6cc0837a265ec5a054990387b213de4272
--- /dev/null
+++ b/DI-engine/ding/utils/default_helper.py
@@ -0,0 +1,654 @@
+from typing import Union, Mapping, List, NamedTuple, Tuple, Callable, Optional, Any, Dict
+import copy
+from ditk import logging
+import random
+from functools import lru_cache # in python3.9, we can change to cache
+import numpy as np
+import torch
+import treetensor.torch as ttorch
+
+
+def get_shape0(data: Union[List, Dict, torch.Tensor, ttorch.Tensor]) -> int:
+ """
+ Overview:
+ Get shape[0] of data's torch tensor or treetensor
+ Arguments:
+ - data (:obj:`Union[List,Dict,torch.Tensor,ttorch.Tensor]`): data to be analysed
+ Returns:
+        - shape[0] (:obj:`int`): The first dimension length of the data, usually the batch size.
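+    Examples:
+        >>> # A brief usage sketch; the tensor shapes below are made up for illustration.
+        >>> get_shape0(torch.randn(4, 3))
+        4
+        >>> get_shape0({'obs': torch.randn(8, 5), 'action': torch.randn(8, 2)})
+        8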
+ """
+ if isinstance(data, list) or isinstance(data, tuple):
+ return get_shape0(data[0])
+ elif isinstance(data, dict):
+ for k, v in data.items():
+ return get_shape0(v)
+ elif isinstance(data, torch.Tensor):
+ return data.shape[0]
+ elif isinstance(data, ttorch.Tensor):
+
+ def fn(t):
+ item = list(t.values())[0]
+ if np.isscalar(item[0]):
+ return item[0]
+ else:
+ return fn(item)
+
+ return fn(data.shape)
+ else:
+ raise TypeError("Error in getting shape0, not support type: {}".format(data))
+
+
+def lists_to_dicts(
+ data: Union[List[Union[dict, NamedTuple]], Tuple[Union[dict, NamedTuple]]],
+ recursive: bool = False,
+) -> Union[Mapping[object, object], NamedTuple]:
+ """
+ Overview:
+ Transform a list of dicts to a dict of lists.
+ Arguments:
+ - data (:obj:`Union[List[Union[dict, NamedTuple]], Tuple[Union[dict, NamedTuple]]]`):
+            The list of dicts to be transformed
+        - recursive (:obj:`bool`): Whether to recursively transform dict elements
+    Returns:
+        - newdata (:obj:`Union[Mapping[object, object], NamedTuple]`): The resulting dict of lists
+ Example:
+ >>> from ding.utils import *
+ >>> lists_to_dicts([{1: 1, 10: 3}, {1: 2, 10: 4}])
+ {1: [1, 2], 10: [3, 4]}
+ """
+ if len(data) == 0:
+ raise ValueError("empty data")
+ if isinstance(data[0], dict):
+ if recursive:
+ new_data = {}
+ for k in data[0].keys():
+ if isinstance(data[0][k], dict) and k != 'prev_state':
+ tmp = [data[b][k] for b in range(len(data))]
+ new_data[k] = lists_to_dicts(tmp)
+ else:
+ new_data[k] = [data[b][k] for b in range(len(data))]
+ else:
+ new_data = {k: [data[b][k] for b in range(len(data))] for k in data[0].keys()}
+ elif isinstance(data[0], tuple) and hasattr(data[0], '_fields'): # namedtuple
+ new_data = type(data[0])(*list(zip(*data)))
+ else:
+ raise TypeError("not support element type: {}".format(type(data[0])))
+ return new_data
+
+
+def dicts_to_lists(data: Mapping[object, List[object]]) -> List[Mapping[object, object]]:
+ """
+ Overview:
+ Transform a dict of lists to a list of dicts.
+
+ Arguments:
+        - data (:obj:`Mapping[object, list]`): The dict of lists to be transformed
+
+    Returns:
+        - newdata (:obj:`List[Mapping[object, object]]`): The resulting list of dicts
+
+ Example:
+ >>> from ding.utils import *
+ >>> dicts_to_lists({1: [1, 2], 10: [3, 4]})
+ [{1: 1, 10: 3}, {1: 2, 10: 4}]
+ """
+ new_data = [v for v in data.values()]
+ new_data = [{k: v for k, v in zip(data.keys(), t)} for t in list(zip(*new_data))]
+ return new_data
+
+
+def override(cls: type) -> Callable[[
+ Callable,
+], Callable]:
+ """
+ Overview:
+ Annotation for documenting method overrides.
+
+ Arguments:
+ - cls (:obj:`type`): The superclass that provides the overridden method. If this
+ cls does not actually have the method, an error is raised.
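+    Examples:
+        >>> # A usage sketch with made-up classes, only to illustrate the decorator.
+        >>> class Base:
+        ...     def step(self):
+        ...         pass
+        >>> class Child(Base):
+        ...     @override(Base)
+        ...     def step(self):
+        ...         return 1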
+ """
+
+ def check_override(method: Callable) -> Callable:
+ if method.__name__ not in dir(cls):
+ raise NameError("{} does not override any method of {}".format(method, cls))
+ return method
+
+ return check_override
+
+
+def squeeze(data: object) -> object:
+ """
+ Overview:
+ Squeeze data from tuple, list or dict to single object
+ Arguments:
+ - data (:obj:`object`): data to be squeezed
+ Example:
+ >>> a = (4, )
+ >>> a = squeeze(a)
+ >>> print(a)
+        4
+ """
+ if isinstance(data, tuple) or isinstance(data, list):
+ if len(data) == 1:
+ return data[0]
+ else:
+ return tuple(data)
+ elif isinstance(data, dict):
+ if len(data) == 1:
+ return list(data.values())[0]
+ return data
+
+
+default_get_set = set()
+
+
+def default_get(
+ data: dict,
+ name: str,
+ default_value: Optional[Any] = None,
+ default_fn: Optional[Callable] = None,
+ judge_fn: Optional[Callable] = None
+) -> Any:
+ """
+    Overview:
+        Get the value of ``name`` from ``data``. If ``name`` exists in ``data``, return ``data[name]``; \
+        otherwise, return the default value generated by ``default_fn`` (or ``default_value`` directly), \
+        optionally checked by ``judge_fn``, and log a one-time warning recorded in ``default_get_set``.
+    Arguments:
+        - data (:obj:`dict`): Data input dictionary
+        - name (:obj:`str`): Key name
+        - default_value (:obj:`Optional[Any]`): Default value returned when ``name`` is missing
+        - default_fn (:obj:`Optional[Callable]`): Function used to generate the default value; takes \
+            precedence over ``default_value``
+        - judge_fn (:obj:`Optional[Callable]`): Function used to check whether the default value is legal
+    Returns:
+        - ret (:obj:`Any`): The value of ``data[name]``, or the default value when ``name`` is missing
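+    Examples:
+        >>> # A usage sketch; the keys and values here are hypothetical.
+        >>> default_get({'lr': 0.001}, 'lr', default_value=0.01)
+        0.001
+        >>> default_get({'lr': 0.001}, 'batch_size', default_value=32)  # logs a one-time warning
+        32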
+ """
+ if name in data:
+ return data[name]
+ else:
+ assert default_value is not None or default_fn is not None
+ value = default_fn() if default_fn is not None else default_value
+ if judge_fn:
+            assert judge_fn(value), "default value({}) is not accepted by judge_fn".format(type(value))
+ if name not in default_get_set:
+ logging.warning("{} use default value {}".format(name, value))
+ default_get_set.add(name)
+ return value
+
+
+def list_split(data: list, step: int) -> List[list]:
+ """
+ Overview:
+ Split list of data by step.
+ Arguments:
+        - data(:obj:`list`): The list of data to be split
+        - step(:obj:`int`): The chunk size used for splitting
+    Returns:
+        - ret(:obj:`list`): List of split data chunks.
+        - residual(:obj:`list`): Residual list. This value is ``None`` when ``len(data)`` is divisible by ``step``.
+ Example:
+ >>> list_split([1,2,3,4],2)
+ ([[1, 2], [3, 4]], None)
+ >>> list_split([1,2,3,4],3)
+ ([[1, 2, 3]], [4])
+ """
+ if len(data) < step:
+ return [], data
+ ret = []
+ divide_num = len(data) // step
+ for i in range(divide_num):
+ start, end = i * step, (i + 1) * step
+ ret.append(data[start:end])
+ if divide_num * step < len(data):
+ residual = data[divide_num * step:]
+ else:
+ residual = None
+ return ret, residual
+
+
+def error_wrapper(fn, default_ret, warning_msg=""):
+ """
+ Overview:
+        Wrap the function so that any exception raised inside it is caught and the default_ret is returned instead
+    Arguments:
+        - fn (:obj:`Callable`): the function to be wrapped
+ - default_ret (:obj:`obj`): the default return when an Exception occurred in the function
+ Returns:
+ - wrapper (:obj:`Callable`): the wrapped function
+ Examples:
+        >>> # Used to check for FakeLink (refer to utils.linklink_dist_helper.py)
+ >>> def get_rank(): # Get the rank of linklink model, return 0 if use FakeLink.
+ >>> if is_fake_link:
+ >>> return 0
+ >>> return error_wrapper(link.get_rank, 0)()
+ """
+
+ def wrapper(*args, **kwargs):
+ try:
+ ret = fn(*args, **kwargs)
+ except Exception as e:
+ ret = default_ret
+ if warning_msg != "":
+ one_time_warning(warning_msg, "\ndefault_ret = {}\terror = {}".format(default_ret, e))
+ return ret
+
+ return wrapper
+
+
+class LimitedSpaceContainer:
+ """
+ Overview:
+        A simple container simulating a limited amount of space (e.g., available worker slots).
+    Interfaces:
+        ``__init__``, ``get_residual_space``, ``acquire_space``, ``release_space``,
+        ``increase_space``, ``decrease_space``
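+    Examples:
+        >>> # A minimal usage sketch.
+        >>> container = LimitedSpaceContainer(0, 2)
+        >>> container.acquire_space(), container.acquire_space(), container.acquire_space()
+        (True, True, False)
+        >>> container.release_space()
+        >>> container.acquire_space()
+        True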
+ """
+
+ def __init__(self, min_val: int, max_val: int) -> None:
+ """
+ Overview:
+ Set ``min_val`` and ``max_val`` of the container, also set ``cur`` to ``min_val`` for initialization.
+ Arguments:
+ - min_val (:obj:`int`): Min volume of the container, usually 0.
+ - max_val (:obj:`int`): Max volume of the container.
+ """
+ self.min_val = min_val
+ self.max_val = max_val
+ assert (max_val >= min_val)
+ self.cur = self.min_val
+
+ def get_residual_space(self) -> int:
+ """
+ Overview:
+            Get all residual pieces of space and set ``cur`` to ``max_val``.
+        Returns:
+            - ret (:obj:`int`): Residual space, calculated as ``max_val - cur`` before the update.
+ """
+ ret = self.max_val - self.cur
+ self.cur = self.max_val
+ return ret
+
+ def acquire_space(self) -> bool:
+ """
+ Overview:
+            Try to acquire one piece of space. If one is available, return True; otherwise return False.
+ Returns:
+ - flag (:obj:`bool`): Whether there is any piece of residual space.
+ """
+ if self.cur < self.max_val:
+ self.cur += 1
+ return True
+ else:
+ return False
+
+ def release_space(self) -> None:
+ """
+ Overview:
+ Release only one piece of space. Decrement ``cur``, but ensure it won't be negative.
+ """
+ self.cur = max(self.min_val, self.cur - 1)
+
+ def increase_space(self) -> None:
+ """
+ Overview:
+ Increase one piece in space. Increment ``max_val``.
+ """
+ self.max_val += 1
+
+ def decrease_space(self) -> None:
+ """
+ Overview:
+ Decrease one piece in space. Decrement ``max_val``.
+ """
+ self.max_val -= 1
+
+
+def deep_merge_dicts(original: dict, new_dict: dict) -> dict:
+ """
+ Overview:
+ Merge two dicts by calling ``deep_update``
+ Arguments:
+        - original (:obj:`dict`): The base dict.
+        - new_dict (:obj:`dict`): The dict whose values take precedence during merging.
+    Returns:
+        - merged_dict (:obj:`dict`): A new dict created by deeply merging ``new_dict`` into ``original``.
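+    Example:
+        >>> # A usage sketch with made-up keys.
+        >>> deep_merge_dicts({'a': {'b': 1}}, {'a': {'c': 2}, 'd': 3})
+        {'a': {'b': 1, 'c': 2}, 'd': 3}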
+ """
+ original = original or {}
+ new_dict = new_dict or {}
+ merged = copy.deepcopy(original)
+ if new_dict: # if new_dict is neither empty dict nor None
+ deep_update(merged, new_dict, True, [])
+ return merged
+
+
+def deep_update(
+ original: dict,
+ new_dict: dict,
+ new_keys_allowed: bool = False,
+ whitelist: Optional[List[str]] = None,
+ override_all_if_type_changes: Optional[List[str]] = None
+):
+ """
+ Overview:
+ Update original dict with values from new_dict recursively.
+ Arguments:
+ - original (:obj:`dict`): Dictionary with default values.
+ - new_dict (:obj:`dict`): Dictionary with values to be updated
+ - new_keys_allowed (:obj:`bool`): Whether new keys are allowed.
+ - whitelist (:obj:`Optional[List[str]]`):
+ List of keys that correspond to dict
+ values where new subkeys can be introduced. This is only at the top
+ level.
+ - override_all_if_type_changes(:obj:`Optional[List[str]]`):
+ List of top level
+ keys with value=dict, for which we always simply override the
+ entire value (:obj:`dict`), if the "type" key in that value dict changes.
+
+ .. note::
+
+ If new key is introduced in new_dict, then if new_keys_allowed is not
+ True, an error will be thrown. Further, for sub-dicts, if the key is
+ in the whitelist, then new subkeys can be introduced.
+ """
+ whitelist = whitelist or []
+ override_all_if_type_changes = override_all_if_type_changes or []
+
+ for k, value in new_dict.items():
+ if k not in original and not new_keys_allowed:
+            raise RuntimeError("Unknown config parameter `{}`. Base config has: {}.".format(k, original.keys()))
+
+ # Both original value and new one are dicts.
+ if isinstance(original.get(k), dict) and isinstance(value, dict):
+            # Check the old "type" vs the new one. If different, override the entire value.
+ if k in override_all_if_type_changes and \
+ "type" in value and "type" in original[k] and \
+ value["type"] != original[k]["type"]:
+ original[k] = value
+ # Whitelisted key -> ok to add new subkeys.
+ elif k in whitelist:
+ deep_update(original[k], value, True)
+ # Non-whitelisted key.
+ else:
+ deep_update(original[k], value, new_keys_allowed)
+ # Original value not a dict OR new value not a dict:
+ # Override entire value.
+ else:
+ original[k] = value
+ return original
+
+
+def flatten_dict(data: dict, delimiter: str = "/") -> dict:
+ """
+ Overview:
+ Flatten the dict, see example
+ Arguments:
+ - data (:obj:`dict`): Original nested dict
+ - delimiter (str): Delimiter of the keys of the new dict
+ Returns:
+ - data (:obj:`dict`): Flattened nested dict
+ Example:
+ >>> a
+ {'a': {'b': 100}}
+ >>> flatten_dict(a)
+ {'a/b': 100}
+ """
+ data = copy.deepcopy(data)
+ while any(isinstance(v, dict) for v in data.values()):
+ remove = []
+ add = {}
+ for key, value in data.items():
+ if isinstance(value, dict):
+ for subkey, v in value.items():
+ add[delimiter.join([key, subkey])] = v
+ remove.append(key)
+ data.update(add)
+ for k in remove:
+ del data[k]
+ return data
+
+
+def set_pkg_seed(seed: int, use_cuda: bool = True) -> None:
+ """
+ Overview:
+        Side effect function to set the seed for ``random``, ``numpy.random`` and ``torch``. \
+        This is usually used in entry scripts to set the random seed for all packages and instances.
+    Arguments:
+        - seed(:obj:`int`): The random seed to set
+        - use_cuda(:obj:`bool`): Whether to also seed CUDA (only effective when CUDA is available)
+ Examples:
+ >>> # ../entry/xxxenv_xxxpolicy_main.py
+ >>> ...
+ # Set random seed for all package and instance
+ >>> collector_env.seed(seed)
+ >>> evaluator_env.seed(seed, dynamic_seed=False)
+ >>> set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+ >>> ...
+ # Set up RL Policy, etc.
+ >>> ...
+
+ """
+ random.seed(seed)
+ np.random.seed(seed)
+ torch.manual_seed(seed)
+ if use_cuda and torch.cuda.is_available():
+ torch.cuda.manual_seed(seed)
+
+
+@lru_cache()
+def one_time_warning(warning_msg: str) -> None:
+ """
+ Overview:
+ Print warning message only once.
+ Arguments:
+ - warning_msg (:obj:`str`): Warning message.
+ """
+
+ logging.warning(warning_msg)
+
+
+def split_fn(data, indices, start, end):
+ """
+ Overview:
+ Split data by indices
+ Arguments:
+        - data (:obj:`Union[List, Dict, torch.Tensor, ttorch.Tensor]`): data to be split by the given indices
+ - indices (:obj:`np.ndarray`): indices to split
+ - start (:obj:`int`): start index
+ - end (:obj:`int`): end index
+ """
+
+ if data is None:
+ return None
+ elif isinstance(data, list):
+ return [split_fn(d, indices, start, end) for d in data]
+ elif isinstance(data, dict):
+ return {k1: split_fn(v1, indices, start, end) for k1, v1 in data.items()}
+ elif isinstance(data, str):
+ return data
+ else:
+ return data[indices[start:end]]
+
+
+def split_data_generator(data: dict, split_size: int, shuffle: bool = True) -> dict:
+ """
+ Overview:
+ Split data into batches
+ Arguments:
+        - data (:obj:`dict`): data to be split into batches
+        - split_size (:obj:`int`): the size of each batch
+        - shuffle (:obj:`bool`): whether to shuffle the data before splitting
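+    Examples:
+        >>> # A usage sketch; the field names and tensor shapes are made up for illustration.
+        >>> data = {'obs': torch.randn(4, 8), 'action': torch.randn(4, 2)}
+        >>> for batch in split_data_generator(data, 2, shuffle=False):
+        ...     print(batch['obs'].shape)
+        torch.Size([2, 8])
+        torch.Size([2, 8])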
+ """
+
+ assert isinstance(data, dict), type(data)
+ length = []
+ for k, v in data.items():
+ if v is None:
+ continue
+ elif k in ['prev_state', 'prev_actor_state', 'prev_critic_state']:
+ length.append(len(v))
+ elif isinstance(v, list) or isinstance(v, tuple):
+ if isinstance(v[0], str):
+ # some buffer data contains useless string infos, such as 'buffer_id',
+ # which should not be split, so we just skip it
+ continue
+ else:
+ length.append(get_shape0(v[0]))
+ elif isinstance(v, dict):
+ length.append(len(v[list(v.keys())[0]]))
+ else:
+ length.append(len(v))
+ assert len(length) > 0
+ # assert len(set(length)) == 1, "data values must have the same length: {}".format(length)
+ # if continuous action, data['logit'] is list of length 2
+ length = length[0]
+ assert split_size >= 1
+ if shuffle:
+ indices = np.random.permutation(length)
+ else:
+ indices = np.arange(length)
+ for i in range(0, length, split_size):
+ if i + split_size > length:
+ i = length - split_size
+ batch = split_fn(data, indices, i, i + split_size)
+ yield batch
+
+
+class RunningMeanStd(object):
+ """
+ Overview:
+        Running estimator that incrementally updates the mean, variance and count from data batches
+ Interfaces:
+ ``__init__``, ``update``, ``reset``, ``new_shape``
+ Properties:
+ - ``mean``, ``std``, ``_epsilon``, ``_shape``, ``_mean``, ``_var``, ``_count``
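+    Examples:
+        >>> # A minimal usage sketch; the update data is random, so the exact statistics will vary.
+        >>> rms = RunningMeanStd(shape=(3, ))
+        >>> rms.update(np.random.randn(16, 3))
+        >>> rms.mean.shape, rms.std.shape
+        (torch.Size([3]), torch.Size([3]))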
+ """
+
+ def __init__(self, epsilon=1e-4, shape=(), device=torch.device('cpu')):
+ """
+ Overview:
+            Initialize ``self``. See ``help(type(self))`` for the accurate signature; set up the properties.
+        Arguments:
+            - epsilon (:obj:`float`): small constant used to initialize the count and stabilize the std output
+            - shape (:obj:`tuple`): the shape of the tracked mean and variance arrays
+            - device (:obj:`torch.device`): the device on which the ``mean`` and ``std`` tensors are returned
+ """
+ self._epsilon = epsilon
+ self._shape = shape
+ self._device = device
+ self.reset()
+
+ def update(self, x):
+ """
+ Overview:
+            Update the running mean, variance and count with a new batch of data.
+        Arguments:
+            - x (:obj:`np.ndarray`): the batch of data, with the batch dimension first
+ """
+ batch_mean = np.mean(x, axis=0)
+ batch_var = np.var(x, axis=0)
+ batch_count = x.shape[0]
+
+ new_count = batch_count + self._count
+ mean_delta = batch_mean - self._mean
+ new_mean = self._mean + mean_delta * batch_count / new_count
+        # this method of calculating the new variance might be numerically unstable
+ m_a = self._var * self._count
+ m_b = batch_var * batch_count
+ m2 = m_a + m_b + np.square(mean_delta) * self._count * batch_count / new_count
+ new_var = m2 / new_count
+ self._mean = new_mean
+ self._var = new_var
+ self._count = new_count
+
+ def reset(self):
+ """
+ Overview:
+            Reset the running statistics: ``_mean``, ``_var`` and ``_count``
+ """
+ if len(self._shape) > 0:
+ self._mean = np.zeros(self._shape, 'float32')
+ self._var = np.ones(self._shape, 'float32')
+ else:
+ self._mean, self._var = 0., 1.
+ self._count = self._epsilon
+
+ @property
+ def mean(self) -> np.ndarray:
+ """
+ Overview:
+ Property ``mean`` gotten from ``self._mean``
+ """
+ if np.isscalar(self._mean):
+ return self._mean
+ else:
+ return torch.FloatTensor(self._mean).to(self._device)
+
+ @property
+ def std(self) -> np.ndarray:
+ """
+ Overview:
+ Property ``std`` calculated from ``self._var`` and the epsilon value of ``self._epsilon``
+ """
+ std = np.sqrt(self._var + 1e-8)
+ if np.isscalar(std):
+ return std
+ else:
+ return torch.FloatTensor(std).to(self._device)
+
+ @staticmethod
+ def new_shape(obs_shape, act_shape, rew_shape):
+ """
+ Overview:
+            Get new shape of observation, action, and reward; in this case unchanged.
+ Arguments:
+ obs_shape (:obj:`Any`), act_shape (:obj:`Any`), rew_shape (:obj:`Any`)
+ Returns:
+ obs_shape (:obj:`Any`), act_shape (:obj:`Any`), rew_shape (:obj:`Any`)
+ """
+ return obs_shape, act_shape, rew_shape
+
+
+def make_key_as_identifier(data: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Overview:
+        Make each key of the dict a legal Python identifier string so that it is
+        compatible with Python magic methods such as ``__getattr__``.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): The original dict data.
+ Return:
+ - new_data (:obj:`Dict[str, Any]`): The new dict data with legal identifier keys.
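+    Examples:
+        >>> # A usage sketch; the keys below are hypothetical.
+        >>> make_key_as_identifier({'1step_reward': 1.0, 'loss.total': 2.0})
+        {'_1step_reward': 1.0, 'loss_total': 2.0}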
+ """
+
+ def legalization(s: str) -> str:
+ if s[0].isdigit():
+ s = '_' + s
+ return s.replace('.', '_')
+
+ new_data = {}
+ for k in data:
+ new_k = legalization(k)
+ new_data[new_k] = data[k]
+ return new_data
+
+
+def remove_illegal_item(data: Dict[str, Any]) -> Dict[str, Any]:
+ """
+ Overview:
+        Remove illegal items (e.g., ``str`` values) from the dict info, since they are not compatible with Tensor.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): The original dict data.
+ Return:
+        - new_data (:obj:`Dict[str, Any]`): The new dict data without illegal items.
+ """
+ new_data = {}
+ for k, v in data.items():
+ if isinstance(v, str):
+ continue
+ new_data[k] = data[k]
+ return new_data
diff --git a/DI-engine/ding/utils/design_helper.py b/DI-engine/ding/utils/design_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..24805218ff6475555efe87cdd81e5c16b4e628f2
--- /dev/null
+++ b/DI-engine/ding/utils/design_helper.py
@@ -0,0 +1,24 @@
+from abc import ABCMeta
+
+
+# ABCMeta is a subclass of type; extending ABCMeta makes this metaclass compatible with classes
+# that extend ABC
+class SingletonMetaclass(ABCMeta):
+ """
+ Overview:
+        Singleton metaclass: every class that uses it returns the same instance on repeated instantiation
+ Interfaces:
+ ``__call__``
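+    Examples:
+        >>> # A usage sketch with a made-up class.
+        >>> class GlobalConfig(metaclass=SingletonMetaclass):
+        ...     def __init__(self, value):
+        ...         self.value = value
+        >>> GlobalConfig(1) is GlobalConfig(2)
+        True
+        >>> GlobalConfig(2).value  # the first construction wins
+        1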
+ """
+ instances = {}
+
+ def __call__(cls: type, *args, **kwargs) -> object:
+ """
+ Overview:
+            Return the singleton instance of ``cls``, creating it on the first call
+ """
+
+ if cls not in SingletonMetaclass.instances:
+ SingletonMetaclass.instances[cls] = super(SingletonMetaclass, cls).__call__(*args, **kwargs)
+ cls.instance = SingletonMetaclass.instances[cls]
+ return SingletonMetaclass.instances[cls]
diff --git a/DI-engine/ding/utils/fake_linklink.py b/DI-engine/ding/utils/fake_linklink.py
new file mode 100644
index 0000000000000000000000000000000000000000..5998030b3689cbb16706cda74d7cfa892dec9d67
--- /dev/null
+++ b/DI-engine/ding/utils/fake_linklink.py
@@ -0,0 +1,34 @@
+from collections import namedtuple
+
+
+class FakeClass:
+ """
+ Overview:
+ Fake class.
+ """
+
+ def __init__(self, *args, **kwargs):
+ pass
+
+
+class FakeNN:
+ """
+ Overview:
+ Fake nn class.
+ """
+
+ SyncBatchNorm2d = FakeClass
+
+
+class FakeLink:
+ """
+ Overview:
+ Fake link class.
+ """
+
+ nn = FakeNN()
+ syncbnVarMode_t = namedtuple("syncbnVarMode_t", "L2")(L2=None)
+ allreduceOp_t = namedtuple("allreduceOp_t", ['Sum', 'Max'])
+
+
+link = FakeLink()
diff --git a/DI-engine/ding/utils/fast_copy.py b/DI-engine/ding/utils/fast_copy.py
new file mode 100644
index 0000000000000000000000000000000000000000..cf4185ecbd33fa06bb983d813fd74f9968cfae7c
--- /dev/null
+++ b/DI-engine/ding/utils/fast_copy.py
@@ -0,0 +1,96 @@
+import torch
+import numpy as np
+from typing import Any, List
+
+
+class _FastCopy:
+ """
+ Overview:
+ The idea of this class comes from this article \
+ https://newbedev.com/what-is-a-fast-pythonic-way-to-deepcopy-just-data-from-a-python-dict-or-list.
+ We use recursive calls to copy each object that needs to be copied, which will be 5x faster \
+ than copy.deepcopy.
+ Interfaces:
+ ``__init__``, ``_copy_list``, ``_copy_dict``, ``_copy_tensor``, ``_copy_ndarray``, ``copy``.
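+    Examples:
+        >>> # A usage sketch; the nested structure is made up for illustration.
+        >>> src = {'obs': torch.randn(2, 3), 'mask': np.zeros(4), 'tag': 'demo'}
+        >>> dst = fastcopy.copy(src)
+        >>> dst['obs'] is src['obs'], dst['tag']
+        (False, 'demo')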
+ """
+
+ def __init__(self):
+ """
+ Overview:
+ Initialize the _FastCopy object.
+ """
+
+ dispatch = {}
+ dispatch[list] = self._copy_list
+ dispatch[dict] = self._copy_dict
+ dispatch[torch.Tensor] = self._copy_tensor
+ dispatch[np.ndarray] = self._copy_ndarray
+ self.dispatch = dispatch
+
+    def _copy_list(self, l: List) -> List:
+ """
+ Overview:
+ Copy the list.
+ Arguments:
+ - l (:obj:`List`): The list to be copied.
+ """
+
+ ret = l.copy()
+ for idx, item in enumerate(ret):
+ cp = self.dispatch.get(type(item))
+ if cp is not None:
+ ret[idx] = cp(item)
+ return ret
+
+ def _copy_dict(self, d: dict) -> dict:
+ """
+ Overview:
+ Copy the dict.
+ Arguments:
+ - d (:obj:`dict`): The dict to be copied.
+ """
+
+ ret = d.copy()
+ for key, value in ret.items():
+ cp = self.dispatch.get(type(value))
+ if cp is not None:
+ ret[key] = cp(value)
+
+ return ret
+
+ def _copy_tensor(self, t: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ Copy the tensor.
+ Arguments:
+ - t (:obj:`torch.Tensor`): The tensor to be copied.
+ """
+
+ return t.clone()
+
+ def _copy_ndarray(self, a: np.ndarray) -> np.ndarray:
+ """
+ Overview:
+ Copy the ndarray.
+ Arguments:
+ - a (:obj:`np.ndarray`): The ndarray to be copied.
+ """
+
+ return np.copy(a)
+
+ def copy(self, sth: Any) -> Any:
+ """
+ Overview:
+ Copy the object.
+ Arguments:
+ - sth (:obj:`Any`): The object to be copied.
+ """
+
+ cp = self.dispatch.get(type(sth))
+ if cp is None:
+ return sth
+ else:
+ return cp(sth)
+
+
+fastcopy = _FastCopy()
diff --git a/DI-engine/ding/utils/file_helper.py b/DI-engine/ding/utils/file_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..b14c42de79518ce5eabcb0fba4f031c462521d39
--- /dev/null
+++ b/DI-engine/ding/utils/file_helper.py
@@ -0,0 +1,343 @@
+import io
+from ditk import logging
+import os
+import pickle
+import time
+from functools import lru_cache
+from typing import Union
+
+import torch
+
+from .import_helper import try_import_ceph, try_import_redis, try_import_rediscluster, try_import_mc
+from .lock_helper import get_file_lock
+
+_memcached = None
+_redis_cluster = None
+
+if os.environ.get('DI_STORE', 'off').lower() == 'on':
+ print('Enable DI-store')
+ from di_store import Client
+
+ di_store_config_path = os.environ.get("DI_STORE_CONFIG_PATH", './di_store.yaml')
+ di_store_client = Client(di_store_config_path)
+
+ def save_to_di_store(data):
+ return di_store_client.put(data)
+
+ def read_from_di_store(object_ref):
+ data = di_store_client.get(object_ref)
+ di_store_client.delete(object_ref)
+ return data
+else:
+ save_to_di_store = read_from_di_store = None
+
+
+@lru_cache()
+def get_ceph_package():
+ return try_import_ceph()
+
+
+@lru_cache()
+def get_redis_package():
+ return try_import_redis()
+
+
+@lru_cache()
+def get_rediscluster_package():
+ return try_import_rediscluster()
+
+
+@lru_cache()
+def get_mc_package():
+ return try_import_mc()
+
+
+def read_from_ceph(path: str) -> object:
+ """
+ Overview:
+ Read file from ceph
+ Arguments:
+ - path (:obj:`str`): File path in ceph, start with ``"s3://"``
+ Returns:
+ - (:obj:`data`): Deserialized data
+ """
+ value = get_ceph_package().Get(path)
+ if not value:
+ raise FileNotFoundError("File({}) doesn't exist in ceph".format(path))
+
+ return pickle.loads(value)
+
+
+@lru_cache()
+def _get_redis(host='localhost', port=6379):
+ """
+ Overview:
+        Create (and cache) a StrictRedis client for the given host and port
+ Arguments:
+ - host (:obj:`str`): Host string
+ - port (:obj:`int`): Port number
+ Returns:
+ - (:obj:`Redis(object)`): Redis object with given ``host``, ``port``, and ``db=0``
+ """
+ return get_redis_package().StrictRedis(host=host, port=port, db=0)
+
+
+def read_from_redis(path: str) -> object:
+ """
+ Overview:
+ Read file from redis
+ Arguments:
+        - path (:obj:`str`): File path in redis, could be a string key
+ Returns:
+ - (:obj:`data`): Deserialized data
+ """
+ return pickle.loads(_get_redis().get(path))
+
+
+def _ensure_rediscluster(startup_nodes=[{"host": "127.0.0.1", "port": "7000"}]):
+ """
+ Overview:
+        Ensure the module-level ``RedisCluster`` client is initialized.
+    Arguments:
+        - startup_nodes (:obj:`list` of :obj:`dict`): Each node dict contains
+            - host (:obj:`str`): Host string
+            - port (:obj:`int`): Port number
+    Returns:
+        - None. The ``RedisCluster`` object is stored in the global ``_redis_cluster``, \
+            with ``decode_responses=False`` by default.
+ """
+ global _redis_cluster
+ if _redis_cluster is None:
+ _redis_cluster = get_rediscluster_package().RedisCluster(startup_nodes=startup_nodes, decode_responses=False)
+ return
+
+
+def read_from_rediscluster(path: str) -> object:
+ """
+ Overview:
+ Read file from rediscluster
+ Arguments:
+        - path (:obj:`str`): File path in rediscluster, could be a string key
+ Returns:
+ - (:obj:`data`): Deserialized data
+ """
+ _ensure_rediscluster()
+ value_bytes = _redis_cluster.get(path)
+ value = pickle.loads(value_bytes)
+ return value
+
+
+def read_from_file(path: str) -> object:
+ """
+ Overview:
+ Read file from local file system
+ Arguments:
+ - path (:obj:`str`): File path in local file system
+ Returns:
+ - (:obj:`data`): Deserialized data
+ """
+ with open(path, "rb") as f:
+ value = pickle.load(f)
+
+ return value
+
+
+def _ensure_memcached():
+ """
+ Overview:
+ Ensures memcache usage
+ Returns:
+ - (:obj:`MemcachedClient instance`): MemcachedClient's class instance built with current \
+ memcached_client's ``server_list.conf`` and ``client.conf`` files
+ """
+ global _memcached
+ if _memcached is None:
+ server_list_config_file = "/mnt/lustre/share/memcached_client/server_list.conf"
+ client_config_file = "/mnt/lustre/share/memcached_client/client.conf"
+ _memcached = get_mc_package().MemcachedClient.GetInstance(server_list_config_file, client_config_file)
+ return
+
+
+def read_from_mc(path: str, flush=False) -> object:
+ """
+ Overview:
+ Read file from memcache, file must be saved by `torch.save()`
+ Arguments:
+ - path (:obj:`str`): File path in local system
+ Returns:
+ - (:obj:`data`): Deserialized data
+ """
+ _ensure_memcached()
+ while True:
+ try:
+ value = get_mc_package().pyvector()
+ if flush:
+ _memcached.Get(path, value, get_mc_package().MC_READ_THROUGH)
+ return
+ else:
+ _memcached.Get(path, value)
+ value_buf = get_mc_package().ConvertBuffer(value)
+ value_str = io.BytesIO(value_buf)
+ value_str = torch.load(value_str, map_location='cpu')
+ return value_str
+ except Exception:
+ print('read mc failed, retry...')
+ time.sleep(0.01)
+
+
+def read_from_path(path: str):
+ """
+ Overview:
+        Read file from ceph if it is available, otherwise fall back to the local file system
+ Arguments:
+ - path (:obj:`str`): File path in ceph, start with ``"s3://"``, or use local file system
+ Returns:
+ - (:obj:`data`): Deserialized data
+ """
+ if get_ceph_package() is None:
+ logging.info(
+ "You do not have ceph installed! Loading local file!"
+ " If you are not testing locally, something is wrong!"
+ )
+ return read_from_file(path)
+ else:
+ return read_from_ceph(path)
+
+
+def save_file_ceph(path, data):
+ """
+ Overview:
+ Save pickle dumped data file to ceph
+ Arguments:
+        - path (:obj:`str`): File path in ceph, starting with ``"s3://"``; the local file system is used otherwise
+ - data (:obj:`Any`): Could be dict, list or tensor etc.
+ """
+ data = pickle.dumps(data)
+ save_path = os.path.dirname(path)
+ file_name = os.path.basename(path)
+ ceph = get_ceph_package()
+ if ceph is not None:
+ if hasattr(ceph, 'save_from_string'):
+ ceph.save_from_string(save_path, file_name, data)
+ elif hasattr(ceph, 'put'):
+ ceph.put(os.path.join(save_path, file_name), data)
+ else:
+ raise RuntimeError('ceph can not save file, check your ceph installation')
+ else:
+ size = len(data)
+ if save_path == 'do_not_save':
+ logging.info(
+ "You do not have ceph installed! ignored file {} of size {}!".format(file_name, size) +
+ " If you are not testing locally, something is wrong!"
+ )
+ return
+ p = os.path.join(save_path, file_name)
+ with open(p, 'wb') as f:
+ logging.info(
+ "You do not have ceph installed! Saving as local file at {} of size {}!".format(p, size) +
+ " If you are not testing locally, something is wrong!"
+ )
+ f.write(data)
+
+
+def save_file_redis(path, data):
+ """
+ Overview:
+ Save pickle dumped data file to redis
+ Arguments:
+ - path (:obj:`str`): File path (could be a string key) in redis
+ - data (:obj:`Any`): Could be dict, list or tensor etc.
+ """
+ _get_redis().set(path, pickle.dumps(data))
+
+
+def save_file_rediscluster(path, data):
+ """
+ Overview:
+ Save pickle dumped data file to rediscluster
+ Arguments:
+ - path (:obj:`str`): File path (could be a string key) in redis
+ - data (:obj:`Any`): Could be dict, list or tensor etc.
+ """
+ _ensure_rediscluster()
+ data = pickle.dumps(data)
+ _redis_cluster.set(path, data)
+ return
+
+
+def read_file(path: str, fs_type: Union[None, str] = None, use_lock: bool = False) -> object:
+ """
+ Overview:
+ Read file from path
+ Arguments:
+ - path (:obj:`str`): The path of file to read
+        - fs_type (:obj:`str` or :obj:`None`): The file system type, support ``{'normal', 'ceph', 'mc'}``
+        - use_lock (:obj:`bool`): Whether to use a file lock; only effective for the local ('normal') file system
+ """
+ if fs_type is None:
+ if path.lower().startswith('s3'):
+ fs_type = 'ceph'
+ elif get_mc_package() is not None:
+ fs_type = 'mc'
+ else:
+ fs_type = 'normal'
+ assert fs_type in ['normal', 'ceph', 'mc']
+ if fs_type == 'ceph':
+ data = read_from_path(path)
+ elif fs_type == 'normal':
+ if use_lock:
+ with get_file_lock(path, 'read'):
+ data = torch.load(path, map_location='cpu')
+ else:
+ data = torch.load(path, map_location='cpu')
+ elif fs_type == 'mc':
+ data = read_from_mc(path)
+ return data
+
+
+def save_file(path: str, data: object, fs_type: Union[None, str] = None, use_lock: bool = False) -> None:
+ """
+ Overview:
+ Save data to file of path
+ Arguments:
+ - path (:obj:`str`): The path of file to save to
+ - data (:obj:`object`): The data to save
+        - fs_type (:obj:`str` or :obj:`None`): The file system type, support ``{'normal', 'ceph', 'mc'}``
+        - use_lock (:obj:`bool`): Whether to use a file lock; only effective for the local ('normal') file system
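+    Examples:
+        >>> # A local file system usage sketch; the path and payload are hypothetical.
+        >>> save_file('./demo_ckpt.pth', {'step': 10}, fs_type='normal')
+        >>> read_file('./demo_ckpt.pth', fs_type='normal')
+        {'step': 10}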
+ """
+ if fs_type is None:
+ if path.lower().startswith('s3'):
+ fs_type = 'ceph'
+ elif get_mc_package() is not None:
+ fs_type = 'mc'
+ else:
+ fs_type = 'normal'
+ assert fs_type in ['normal', 'ceph', 'mc']
+ if fs_type == 'ceph':
+ save_file_ceph(path, data)
+ elif fs_type == 'normal':
+ if use_lock:
+ with get_file_lock(path, 'write'):
+ torch.save(data, path)
+ else:
+ torch.save(data, path)
+ elif fs_type == 'mc':
+ torch.save(data, path)
+ read_from_mc(path, flush=True)
+
+
+def remove_file(path: str, fs_type: Union[None, str] = None) -> None:
+ """
+ Overview:
+ Remove file
+ Arguments:
+ - path (:obj:`str`): The path of file you want to remove
+ - fs_type (:obj:`str` or :obj:`None`): The file system type, support ``{'normal', 'ceph'}``
+ """
+ if fs_type is None:
+ fs_type = 'ceph' if path.lower().startswith('s3') else 'normal'
+ assert fs_type in ['normal', 'ceph']
+ if fs_type == 'ceph':
+ os.popen("aws s3 rm --recursive {}".format(path))
+ elif fs_type == 'normal':
+ os.popen("rm -rf {}".format(path))
diff --git a/DI-engine/ding/utils/import_helper.py b/DI-engine/ding/utils/import_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbb757dae2660afb535e73b77437e0f8e9f240b0
--- /dev/null
+++ b/DI-engine/ding/utils/import_helper.py
@@ -0,0 +1,107 @@
+import importlib
+from typing import List
+
+import ding
+from .default_helper import one_time_warning
+
+
+def try_import_ceph():
+ """
+ Overview:
+        Try to import the ceph module; if it fails, return ``None``
+
+ Returns:
+ - (:obj:`Module`): Imported module, or ``None`` when ceph not found
+ """
+ try:
+ import ceph
+ client = ceph.S3Client()
+ return client
+ except ModuleNotFoundError as e:
+ try:
+ from petrel_client.client import Client
+ client = Client(conf_path='~/petreloss.conf')
+ return client
+ except ModuleNotFoundError as e:
+ one_time_warning("You have not installed ceph package! DI-engine has changed to some alternatives.")
+ ceph = None
+ return ceph
+
+
+def try_import_mc():
+ """
+ Overview:
+        Try to import the mc module; if it fails, return ``None``
+
+ Returns:
+ - (:obj:`Module`): Imported module, or ``None`` when mc not found
+ """
+ try:
+ import mc
+ except ModuleNotFoundError as e:
+ # one_time_warning("You have not installed memcache package! DI-engine has changed to some alternatives.")
+ mc = None
+ return mc
+
+
+def try_import_redis():
+ """
+ Overview:
+        Try to import the redis module; if it fails, return ``None``
+
+ Returns:
+ - (:obj:`Module`): Imported module, or ``None`` when redis not found
+ """
+ try:
+ import redis
+ except ModuleNotFoundError as e:
+ one_time_warning("You have not installed redis package! DI-engine has changed to some alternatives.")
+ redis = None
+ return redis
+
+
+def try_import_rediscluster():
+ """
+ Overview:
+        Try to import the rediscluster module; if it fails, return ``None``
+
+ Returns:
+ - (:obj:`Module`): Imported module, or ``None`` when rediscluster not found
+ """
+ try:
+ import rediscluster
+ except ModuleNotFoundError as e:
+ one_time_warning("You have not installed rediscluster package! DI-engine has changed to some alternatives.")
+ rediscluster = None
+ return rediscluster
+
+
+def try_import_link():
+ """
+ Overview:
+        Try to import the linklink module; if it fails, import ``ding.utils.fake_linklink`` instead
+
+ Returns:
+ - (:obj:`Module`): Imported module (may be ``fake_linklink``)
+ """
+ if ding.enable_linklink:
+ try:
+ import linklink as link
+ except ModuleNotFoundError as e:
+ one_time_warning("You have not installed linklink package! DI-engine has changed to some alternatives.")
+ from .fake_linklink import link
+ else:
+ from .fake_linklink import link
+
+ return link
+
+
+def import_module(modules: List[str]) -> None:
+ """
+ Overview:
+        Import several modules given a list of module names
+    Arguments:
+        - modules (:obj:`List[str]`): List of module names to import
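+    Examples:
+        >>> # A usage sketch; any importable module names work here.
+        >>> import_module(['os', 'sys'])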
+ """
+ for name in modules:
+ importlib.import_module(name)
diff --git a/DI-engine/ding/utils/k8s_helper.py b/DI-engine/ding/utils/k8s_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..e30bba497df0f9c9b31408370c0d6ce0ed127dfd
--- /dev/null
+++ b/DI-engine/ding/utils/k8s_helper.py
@@ -0,0 +1,244 @@
+import os
+import json
+from typing import Tuple
+from easydict import EasyDict
+import yaml
+import subprocess
+from enum import Enum, unique
+from ding.interaction.base import split_http_address
+from .default_helper import one_time_warning
+
+DEFAULT_NAMESPACE = 'default'
+DEFAULT_POD_NAME = 'dijob-example-coordinator'
+DEFAULT_API_VERSION = '/v1alpha1'
+
+DEFAULT_K8S_COLLECTOR_PORT = 22270
+DEFAULT_K8S_LEARNER_PORT = 22271
+DEFAULT_K8S_AGGREGATOR_SLAVE_PORT = 22272
+DEFAULT_K8S_COORDINATOR_PORT = 22273
+DEFAULT_K8S_AGGREGATOR_MASTER_PORT = 22273
+
+
+def get_operator_server_kwargs(cfg: EasyDict) -> dict:
+ """
+ Overview:
+ Get kwarg dict from config file
+ Arguments:
+        - cfg (:obj:`EasyDict`): System config
+    Returns:
+        - result (:obj:`dict`): Containing ``api_version``, ``namespace``, ``name``, ``port``, ``host``.
+ """
+
+ namespace = os.environ.get('KUBERNETES_POD_NAMESPACE', DEFAULT_NAMESPACE)
+ name = os.environ.get('KUBERNETES_POD_NAME', DEFAULT_POD_NAME)
+ url = cfg.get('system_addr', None) or os.environ.get('KUBERNETES_SERVER_URL', None)
+    assert url, 'please set environment variable KUBERNETES_SERVER_URL in Kubernetes platform.'
+ api_version = cfg.get('api_version', None) \
+ or os.environ.get('KUBERNETES_SERVER_API_VERSION', DEFAULT_API_VERSION)
+ try:
+ host, port = url.split(":")[0], int(url.split(":")[1])
+ except Exception as e:
+ host, port, _, _ = split_http_address(url)
+
+ return {
+ 'api_version': api_version,
+ 'namespace': namespace,
+ 'name': name,
+ 'host': host,
+ 'port': port,
+ }
+
+
+def exist_operator_server() -> bool:
+ """
+ Overview:
+ Check if the 'KUBERNETES_SERVER_URL' environment variable exists.
+ """
+
+ return 'KUBERNETES_SERVER_URL' in os.environ
+
+
+def pod_exec_command(kubeconfig: str, name: str, namespace: str, cmd: str) -> Tuple[int, str]:
+ """
+ Overview:
+ Execute command in pod
+ Arguments:
+        - kubeconfig (:obj:`str`): The path of the kubeconfig file
+        - name (:obj:`str`): The name of the pod
+        - namespace (:obj:`str`): The namespace of the pod
+        - cmd (:obj:`str`): The command to execute in the pod
+ """
+
+ try:
+ from kubernetes import config
+ from kubernetes.client import CoreV1Api
+ from kubernetes.client.rest import ApiException
+ from kubernetes.stream import stream
+ except ModuleNotFoundError as e:
+ one_time_warning("You have not installed kubernetes package! Please try 'pip install DI-engine[k8s]'.")
+ exit(-1)
+
+ config.load_kube_config(config_file=kubeconfig)
+ core_v1 = CoreV1Api()
+ resp = None
+ try:
+ resp = core_v1.read_namespaced_pod(name=name, namespace=namespace)
+ except ApiException as e:
+ if e.status != 404:
+ return -1, "Unknown error: %s" % e
+ if not resp:
+ return -1, f"Pod {name} does not exist."
+ if resp.status.phase != 'Running':
+ return -1, f"Pod {name} is not in Running."
+ exec_command = ['/bin/sh', '-c', cmd]
+ resp = stream(
+ core_v1.connect_get_namespaced_pod_exec,
+ name,
+ namespace,
+ command=exec_command,
+ stderr=False,
+ stdin=False,
+ stdout=True,
+ tty=False
+ )
+ resp = resp.replace("\'", "\"") \
+ .replace('None', 'null') \
+ .replace(': False', ': 0') \
+ .replace(': True', ': 1') \
+ .replace('"^([+-]?[0-9.]+)([eEinumkKMGTP]*[-+]?[0-9]*)$"', '\\"^([+-]?[0-9.]+)([eEinumkKMGTP]*[-+]?[0-9]*)$\\"')
+ resp = json.loads(resp)
+ return resp['code'], resp['message']
+
+
+@unique
+class K8sType(Enum):
+ Local = 1
+ K3s = 2
+
+
+class K8sLauncher(object):
+ """
+ Overview:
+ object to manage the K8s cluster
+ Interfaces:
+ ``__init__``, ``_load``, ``create_cluster``, ``_check_k3d_tools``, ``delete_cluster``, ``preload_images``
+ """
+
+ def __init__(self, config_path: str) -> None:
+ """
+ Overview:
+ Initialize the K8sLauncher object.
+ Arguments:
+ - config_path (:obj:`str`): The path of the config file.
+ """
+
+ self.name = None
+ self.servers = 1
+ self.agents = 0
+ self.type = K8sType.Local
+ self._images = []
+
+ self._load(config_path)
+ self._check_k3d_tools()
+
+ def _load(self, config_path: str) -> None:
+ """
+ Overview:
+ Load the config file.
+ Arguments:
+ - config_path (:obj:`str`): The path of the config file.
+ """
+
+ with open(config_path, 'r') as f:
+ data = yaml.safe_load(f)
+ self.name = data.get('name') if data.get('name') else self.name
+ if data.get('servers'):
+ if type(data.get('servers')) is not int:
+ raise TypeError(f"servers' type is expected int, actual {type(data.get('servers'))}")
+ self.servers = data.get('servers')
+ if data.get('agents'):
+ if type(data.get('agents')) is not int:
+ raise TypeError(f"agents' type is expected int, actual {type(data.get('agents'))}")
+ self.agents = data.get('agents')
+ if data.get('type'):
+ if data.get('type') == 'k3s':
+ self.type = K8sType.K3s
+ elif data.get('type') == 'local':
+ self.type = K8sType.Local
+ else:
+ raise ValueError(f"no type found for {data.get('type')}")
+ if data.get('preload_images'):
+ if type(data.get('preload_images')) is not list:
+ raise TypeError(f"preload_images' type is expected list, actual {type(data.get('preload_images'))}")
+ self._images = data.get('preload_images')
+
+ def _check_k3d_tools(self) -> None:
+ """
+ Overview:
+ Check if the k3d tools exist.
+ """
+
+ if self.type != K8sType.K3s:
+ return
+ args = ['which', 'k3d']
+ proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ out, _ = proc.communicate()
+ if out.decode('utf-8') == '':
+ raise FileNotFoundError(
+ "No k3d tools found, please install by executing ./ding/scripts/install-k8s-tools.sh"
+ )
+
+ def create_cluster(self) -> None:
+ """
+ Overview:
+ Create the k8s cluster.
+ """
+
+ print('Creating k8s cluster...')
+ if self.type != K8sType.K3s:
+ return
+ args = ['k3d', 'cluster', 'create', f'{self.name}', f'--servers={self.servers}', f'--agents={self.agents}']
+ proc = subprocess.Popen(args, stderr=subprocess.PIPE)
+ _, err = proc.communicate()
+ err_str = err.decode('utf-8').strip()
+ if err_str != '' and 'WARN' not in err_str:
+ if 'already exists' in err_str:
+ print('K8s cluster already exists')
+ else:
+ raise RuntimeError(f'Failed to create cluster {self.name}: {err_str}')
+
+ # preload images
+ self.preload_images(self._images)
+
+ def delete_cluster(self) -> None:
+ """
+ Overview:
+ Delete the k8s cluster.
+ """
+
+ print('Deleting k8s cluster...')
+ if self.type != K8sType.K3s:
+ return
+ args = ['k3d', 'cluster', 'delete', f'{self.name}']
+ proc = subprocess.Popen(args, stderr=subprocess.PIPE)
+ _, err = proc.communicate()
+ err_str = err.decode('utf-8').strip()
+ if err_str != '' and 'WARN' not in err_str and \
+ 'NotFound' not in err_str:
+ raise RuntimeError(f'Failed to delete cluster {self.name}: {err_str}')
+
+ def preload_images(self, images: list) -> None:
+ """
+ Overview:
+ Preload images.
+ """
+
+ if self.type != K8sType.K3s or len(images) == 0:
+ return
+ args = ['k3d', 'image', 'import', f'--cluster={self.name}']
+ args += images
+
+ proc = subprocess.Popen(args, stderr=subprocess.PIPE)
+ _, err = proc.communicate()
+ err_str = err.decode('utf-8').strip()
+ if err_str != '' and 'WARN' not in err_str:
+ raise RuntimeError(f'Failed to preload images: {err_str}')
diff --git a/DI-engine/ding/utils/linklink_dist_helper.py b/DI-engine/ding/utils/linklink_dist_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..36fffa19a0dd339408fbcc1ce6db571c9184c2b4
--- /dev/null
+++ b/DI-engine/ding/utils/linklink_dist_helper.py
@@ -0,0 +1,227 @@
+from functools import lru_cache
+from typing import Callable, Tuple, List, Any
+
+import numpy as np
+import torch
+
+from .default_helper import error_wrapper
+from .fake_linklink import FakeLink
+from .import_helper import try_import_link
+
+
+@lru_cache()
+def get_link():
+ return try_import_link()
+
+
+@lru_cache()
+def is_fake_link():
+ return isinstance(get_link(), FakeLink)
+
+
+def get_rank() -> int:
+ """
+ Overview:
+        Get the rank of the ``linklink`` model; return 0 if using ``FakeLink``.
+
+ .. note::
+ Reference ``import_helper.try_import_link`` and ``linklink.get_rank``.
+ """
+ if is_fake_link():
+ return 0
+ return error_wrapper(get_link().get_rank, 0, "[WARNING]: call linklink error, return default_ret.")()
+
+
+def get_world_size() -> int:
+ """
+ Overview:
+        Get the ``world_size`` of the ``linklink`` model; return 1 if using ``FakeLink``.
+
+ .. note::
+ Reference ``import_helper.try_import_link`` and ``linklink.get_world_size``.
+ """
+ if is_fake_link():
+ return 1
+ return error_wrapper(get_link().get_world_size, 1, "[WARNING]: call linklink error, return default_ret.")()
+
+
+def broadcast(value: torch.Tensor, rank: int) -> None:
+ """
+ Overview:
+ Use ``linklink.broadcast`` and raise error when using ``FakeLink``
+ Arguments:
+        - value (:obj:`obj`): the value to broadcast
+ - rank (:obj:`int`): the rank to broadcast on
+ """
+ if is_fake_link():
+ raise NotImplementedError
+ get_link().broadcast(value, rank)
+
+
+def allreduce(data: torch.Tensor, op: str = 'sum') -> None:
+ """
+ Overview:
+ Call ``linklink.allreduce`` on the data
+ Arguments:
+ - data (:obj:`obj`): the data to reduce
+ - op (:obj:`str`): the operation to perform on data, support ``['sum', 'max']``
+ """
+ link_op_map = {'sum': get_link().allreduceOp_t.Sum, 'max': get_link().allreduceOp_t.Max}
+ if op not in link_op_map.keys():
+ raise KeyError("not support allreduce op type: {}".format(op))
+ else:
+ link_op = link_op_map[op]
+ if is_fake_link():
+ return data
+ get_link().allreduce(data, reduce_op=link_op)
+ if op == 'sum':
+ data.div_(get_world_size())
+
+
+def allreduce_async(data: torch.Tensor, op: str = 'sum') -> None:
+ """
+ Overview:
+ Call ``linklink.allreduce_async`` on the data
+ Arguments:
+ - data (:obj:`obj`): the data to reduce
+ - op (:obj:`str`): the operation to perform on data, support ``['sum', 'max']``
+ """
+ link_op_map = {'sum': get_link().allreduceOp_t.Sum, 'max': get_link().allreduceOp_t.Max}
+ if op not in link_op_map.keys():
+ raise KeyError("not support allreduce op type: {}".format(op))
+ else:
+ link_op = link_op_map[op]
+ if is_fake_link():
+ return data
+ if op == 'sum':
+ data.div_(get_world_size())
+ get_link().allreduce_async(data, reduce_op=link_op)
+
+
+def get_group(group_size: int) -> List:
+ """
+ Overview:
+        Split the world into groups of ``group_size`` ranks each and return the group the current rank belongs to
+    Arguments:
+        - group_size (:obj:`int`): The number of ranks in each group
+ """
+ rank = get_rank()
+ world_size = get_world_size()
+ if group_size is None:
+ group_size = world_size
+ assert (world_size % group_size == 0)
+ return simple_group_split(world_size, rank, world_size // group_size)
+
+
+def dist_mode(func: Callable) -> Callable:
+ """
+ Overview:
+        Wrap the function so that the distributed backend is initialized before each call and finalized afterwards
+ Arguments:
+ - func (:obj:`Callable`): the function to wrap
+ """
+
+ def wrapper(*args, **kwargs):
+ dist_init()
+ func(*args, **kwargs)
+ dist_finalize()
+
+ return wrapper
+
+
+def dist_init(method: str = 'slurm', device_id: int = 0) -> Tuple[int, int]:
+ """
+ Overview:
+ Init the distribution
+ Arguments:
+        - method (:obj:`str`): Support ``['slurm', 'single_node']``
+ - device_id (:obj:`int`): Default device when using ``single_node`` method
+ """
+ get_link().initialize()
+ world_size = get_link().get_world_size()
+ rank = get_link().get_rank()
+
+ if method == 'slurm':
+ # proc_id = int(os.environ['SLURM_PROCID'])
+ # ntasks = int(os.environ['SLURM_NTASKS'])
+ # node_list = os.environ['SLURM_NODELIST']
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(rank % num_gpus)
+ elif method == 'single_node':
+ torch.cuda.set_device(device_id)
+
+ return rank, world_size
+
+
+def dist_finalize() -> None:
+ """
+ Overview:
+ Finalize ``linklink``, see ``linklink.finalize()``
+ """
+ get_link().finalize()
+
+
+class DistContext:
+ """
+ Overview:
+ A context manager for ``linklink`` distribution
+ Interfaces:
+ ``__init__``, ``__enter__``, ``__exit__``
+ """
+
+ def __init__(self) -> None:
+ """
+ Overview:
+ Initialize the ``DistContext``
+ """
+
+ pass
+
+ def __enter__(self) -> None:
+ """
+ Overview:
+ Initialize ``linklink`` distribution
+ """
+
+ dist_init()
+
+ def __exit__(self, *args, **kwargs) -> Any:
+ """
+ Overview:
+ Finalize ``linklink`` distribution
+        Arguments:
+ - args (:obj:`Tuple`): The arguments passed to the ``__exit__`` function.
+ - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__exit__`` function.
+ """
+
+ dist_finalize()
+
+
+def simple_group_split(world_size: int, rank: int, num_groups: int) -> List:
+ """
+ Overview:
+ Split the group according to ``worldsize``, ``rank`` and ``num_groups``
+ Arguments:
+ - world_size (:obj:`int`): The world size
+ - rank (:obj:`int`): The rank
+ - num_groups (:obj:`int`): The number of groups
+ .. note::
+        With invalid input, ``np.split`` raises ``array split does not result in an equal division``.
+ """
+
+ groups = []
+ rank_list = np.split(np.arange(world_size), num_groups)
+ rank_list = [list(map(int, x)) for x in rank_list]
+ for i in range(num_groups):
+ groups.append(get_link().new_group(rank_list[i]))
+ group_size = world_size // num_groups
+ return groups[rank // group_size]
+
+
+def synchronize():
+ """
+ Overview:
+ Synchronize the process
+ """
+
+ get_link().synchronize()
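+
+# Editor's note: the block below is an illustrative sketch added for documentation and is
+# not part of the original file. It assumes a working ``linklink`` backend (or the
+# fake-link fallback) and the ``allreduce`` helper defined earlier in this file.
+#
+#     >>> with DistContext():                    # dist_init() on enter, dist_finalize() on exit
+#     ...     grad = torch.ones(4) * get_rank()  # each rank holds a different tensor
+#     ...     allreduce(grad, op='sum')          # in-place sum across ranks, then divided by world size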
diff --git a/DI-engine/ding/utils/loader/__init__.py b/DI-engine/ding/utils/loader/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..08324700276433f15f7d090c8ee189cf3036a7e3
--- /dev/null
+++ b/DI-engine/ding/utils/loader/__init__.py
@@ -0,0 +1,11 @@
+from .base import Loader
+from .collection import collection, CollectionError, length, length_is, contains, tuple_, cofilter, tpselector
+from .dict import DictError, dict_
+from .exception import CompositeStructureError
+from .mapping import mapping, MappingError, mpfilter, mpkeys, mpvalues, mpitems, item, item_or
+from .norm import norm, normfunc, lnot, land, lor, lin, lis, lisnot, lsum, lcmp
+from .number import interval, numeric, negative, positive, plus, minus, minus_with, multi, divide, divide_with, power, \
+ power_with, msum, mmulti, mcmp, is_negative, is_positive, non_negative, non_positive
+from .string import enum, rematch, regrep
+from .types import is_type, to_type, is_callable, prop, method, fcall, fpartial
+from .utils import keep, optional, check_only, raw, check
diff --git a/DI-engine/ding/utils/loader/base.py b/DI-engine/ding/utils/loader/base.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd55bc8b621ad7e3f695f7d6d7469d91f0faf060
--- /dev/null
+++ b/DI-engine/ding/utils/loader/base.py
@@ -0,0 +1,257 @@
+from abc import abstractmethod
+from typing import TypeVar, Callable, Any
+
+CAPTURE_EXCEPTIONS = (Exception, )
+_ValueType = TypeVar('_ValueType')
+
+
+def _to_exception(exception) -> Callable[[Any], Exception]:
+ """
+ Overview:
+ Convert exception to callable exception.
+ Arguments:
+ - exception (:obj:`Exception`): The exception to be converted.
+ """
+
+ if hasattr(exception, '__call__'):
+ return exception
+ elif isinstance(exception, Exception):
+ return lambda v: exception
+ elif isinstance(exception, str):
+ return lambda v: ValueError(exception)
+ else:
+ raise TypeError(
+ 'Unknown type of exception, func, exception or str expected but {actual} found.'.format(
+ actual=repr(type(exception).__name__)
+ )
+ )
+
+
+def _to_loader(value) -> 'ILoaderClass':
+ """
+ Overview:
+ Convert value to loader.
+ Arguments:
+ - value (:obj:`Any`): The value to be converted.
+ """
+
+ if isinstance(value, ILoaderClass):
+ return value
+ elif isinstance(value, tuple):
+ if len(value) == 2:
+ _predict, _exception = value
+ _load = None
+ elif len(value) == 3:
+ _predict, _load, _exception = value
+ else:
+ raise ValueError('Tuple\'s length should be 2 or 3, but {actual} found.'.format(actual=repr(len(value))))
+
+ _exception = _to_exception(_exception)
+
+ def _load_tuple(value_):
+ if not _predict(value_):
+ raise _exception(value_)
+
+ return (_load or (lambda v: v))(value_)
+
+ return _to_loader(_load_tuple)
+ elif isinstance(value, type):
+
+ def _load_type(value_):
+ if not isinstance(value_, value):
+ raise TypeError(
+ 'type does not match, {expect} expected but {actual} found'.format(
+ expect=repr(value.__name__), actual=repr(type(value_).__name__)
+ )
+ )
+ return value_
+
+ return _to_loader(_load_type)
+ elif hasattr(value, '__call__'):
+
+ class _Loader(ILoaderClass):
+
+ def _load(self, value_):
+ return value(value_)
+
+ return _Loader()
+ elif isinstance(value, bool):
+ return _to_loader((lambda v: value, ValueError('assertion false')))
+ elif value is None:
+ return _to_loader(
+ (
+ lambda v: v is None, lambda v:
+ TypeError('type does not match, None expected but {actual} found'.format(actual=repr(type(v).__name__)))
+ )
+ )
+ else:
+ return _to_loader(lambda v: value)
+
+
+Loader = _to_loader
+
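+# Editor's note (illustrative sketch, not part of the original source): ``Loader`` accepts
+# several input shapes, mirroring the branches of ``_to_loader`` above.
+#
+#     >>> Loader(int)(1)                                    # a type -> isinstance check
+#     1
+#     >>> Loader((lambda x: x > 0, 'positive required'))(2)                 # (predicate, error)
+#     2
+#     >>> Loader((lambda x: x > 0, lambda x: -x, 'positive required'))(2)   # (predicate, transform, error)
+#     -2
+#     >>> Loader(None)(None)                                # None -> "is None" check
+#     >>> Loader([1, 2, 3])('anything')                     # any other value -> constant loader
+#     [1, 2, 3]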
+
+def _reset_exception(loader, eg: Callable[[Any, Exception], Exception]):
+ """
+ Overview:
+ Reset exception of loader.
+ """
+
+ loader = Loader(loader)
+
+ def _load(value):
+ try:
+ return loader(value)
+ except CAPTURE_EXCEPTIONS as err:
+ raise eg(value, err)
+
+ return Loader(_load)
+
+
+class ILoaderClass:
+ """
+ Overview:
+ Base class of loader.
+ Interfaces:
+ ``__init__``, ``_load``, ``load``, ``check``, ``__call__``, ``__and__``, ``__or__``, ``__rshift__``
+ """
+
+ @abstractmethod
+ def _load(self, value: _ValueType) -> _ValueType:
+ """
+ Overview:
+ Load the value.
+ Arguments:
+ - value (:obj:`_ValueType`): The value to be loaded.
+ """
+
+ raise NotImplementedError
+
+ def __load(self, value: _ValueType) -> _ValueType:
+ """
+ Overview:
+ Load the value.
+ Arguments:
+ - value (:obj:`_ValueType`): The value to be loaded.
+ """
+
+ return self._load(value)
+
+ def __check(self, value: _ValueType) -> bool:
+ """
+ Overview:
+ Check whether the value is valid.
+ Arguments:
+ - value (:obj:`_ValueType`): The value to be checked.
+ """
+
+ try:
+ self._load(value)
+ except CAPTURE_EXCEPTIONS:
+ return False
+ else:
+ return True
+
+ def load(self, value: _ValueType) -> _ValueType:
+ """
+ Overview:
+ Load the value.
+ Arguments:
+ - value (:obj:`_ValueType`): The value to be loaded.
+ """
+
+ return self.__load(value)
+
+ def check(self, value: _ValueType) -> bool:
+ """
+ Overview:
+ Check whether the value is valid.
+ Arguments:
+ - value (:obj:`_ValueType`): The value to be checked.
+ """
+
+ return self.__check(value)
+
+ def __call__(self, value: _ValueType) -> _ValueType:
+ """
+ Overview:
+ Load the value.
+ Arguments:
+ - value (:obj:`_ValueType`): The value to be loaded.
+ """
+
+ return self.__load(value)
+
+ def __and__(self, other) -> 'ILoaderClass':
+ """
+ Overview:
+ Combine two loaders.
+ Arguments:
+ - other (:obj:`ILoaderClass`): The other loader.
+ """
+
+ def _load(value: _ValueType) -> _ValueType:
+ self.load(value)
+ return Loader(other).load(value)
+
+ return Loader(_load)
+
+ def __rand__(self, other) -> 'ILoaderClass':
+ """
+ Overview:
+ Combine two loaders.
+ Arguments:
+ - other (:obj:`ILoaderClass`): The other loader.
+ """
+
+ return Loader(other) & self
+
+ def __or__(self, other) -> 'ILoaderClass':
+ """
+ Overview:
+ Combine two loaders.
+ Arguments:
+ - other (:obj:`ILoaderClass`): The other loader.
+ """
+
+ def _load(value: _ValueType) -> _ValueType:
+ try:
+ return self.load(value)
+ except CAPTURE_EXCEPTIONS:
+ return Loader(other).load(value)
+
+ return Loader(_load)
+
+ def __ror__(self, other) -> 'ILoaderClass':
+ """
+ Overview:
+ Combine two loaders.
+ Arguments:
+ - other (:obj:`ILoaderClass`): The other loader.
+ """
+
+ return Loader(other) | self
+
+ def __rshift__(self, other) -> 'ILoaderClass':
+ """
+ Overview:
+ Combine two loaders.
+ Arguments:
+ - other (:obj:`ILoaderClass`): The other loader.
+ """
+
+ def _load(value: _ValueType) -> _ValueType:
+ _return_value = self.load(value)
+ return _to_loader(other).load(_return_value)
+
+ return Loader(_load)
+
+ def __rrshift__(self, other) -> 'ILoaderClass':
+ """
+ Overview:
+ Combine two loaders.
+ Arguments:
+ - other (:obj:`ILoaderClass`): The other loader.
+ """
+
+ return Loader(other) >> self
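+
+# Editor's note (illustrative sketch, not part of the original source): the operators above
+# compose loaders. ``&`` requires both sides to pass, ``|`` falls back to the right side on
+# failure, and ``>>`` pipes the loaded value into the next loader.
+#
+#     >>> non_negative_int = Loader(int) & (lambda x: x >= 0, 'non-negative required')
+#     >>> non_negative_int(3)
+#     3
+#     >>> (Loader(int) | str)('abc')
+#     'abc'
+#     >>> (Loader(int) >> (lambda x: str(x + 1)))(41)
+#     '42'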
diff --git a/DI-engine/ding/utils/loader/collection.py b/DI-engine/ding/utils/loader/collection.py
new file mode 100644
index 0000000000000000000000000000000000000000..770e6c6c64829dc8541e57ae1da3862395e58ef8
--- /dev/null
+++ b/DI-engine/ding/utils/loader/collection.py
@@ -0,0 +1,175 @@
+from typing import Optional, List, Tuple, Callable, Any
+
+from .base import ILoaderClass, Loader, CAPTURE_EXCEPTIONS
+from .exception import CompositeStructureError
+from .types import method
+
+COLLECTION_ERROR_ITEM = Tuple[int, Exception]
+COLLECTION_ERRORS = List[COLLECTION_ERROR_ITEM]
+
+
+class CollectionError(CompositeStructureError):
+ """
+ Overview:
+ Collection error.
+ Interfaces:
+ ``__init__``, ``errors``
+ Properties:
+ ``errors``
+ """
+
+ def __init__(self, errors: COLLECTION_ERRORS):
+ """
+ Overview:
+ Initialize the CollectionError.
+ Arguments:
+ - errors (:obj:`COLLECTION_ERRORS`): The errors.
+ """
+
+ self.__errors = list(errors or [])
+ CompositeStructureError.__init__(
+ self, '{count} error(s) found in collection.'.format(count=len(self.__errors))
+ )
+
+ @property
+ def errors(self) -> COLLECTION_ERRORS:
+ """
+ Overview:
+ Get the errors.
+ """
+
+ return self.__errors
+
+
+def collection(loader, type_back: bool = True) -> ILoaderClass:
+ """
+ Overview:
+ Create a collection loader.
+ Arguments:
+ - loader (:obj:`ILoaderClass`): The loader.
+ - type_back (:obj:`bool`): Whether to convert the type back.
+ """
+
+ loader = Loader(loader)
+
+ def _load(value):
+ _result = []
+ _errors = []
+
+ for index, item in enumerate(value):
+ try:
+ _return = loader.load(item)
+ except CAPTURE_EXCEPTIONS as err:
+ _errors.append((index, err))
+ else:
+ _result.append(_return)
+
+ if _errors:
+ raise CollectionError(_errors)
+
+ if type_back:
+ _result = type(value)(_result)
+ return _result
+
+ return method('__iter__') & Loader(_load)
+
+
+def tuple_(*loaders) -> ILoaderClass:
+ """
+ Overview:
+ Create a tuple loader.
+ Arguments:
+ - loaders (:obj:`tuple`): The loaders.
+ """
+
+ loaders = [Loader(loader) for loader in loaders]
+
+ def _load(value: tuple):
+ return tuple([loader(item) for loader, item in zip(loaders, value)])
+
+ return tuple & length_is(len(loaders)) & Loader(_load)
+
+
+def length(min_length: Optional[int] = None, max_length: Optional[int] = None) -> ILoaderClass:
+ """
+ Overview:
+ Create a length loader.
+ Arguments:
+ - min_length (:obj:`int`): The minimum length.
+ - max_length (:obj:`int`): The maximum length.
+ """
+
+ def _load(value):
+ _length = len(value)
+ if min_length is not None and _length < min_length:
+ raise ValueError(
+ 'minimum length is {expect}, but {actual} found'.format(expect=repr(min_length), actual=repr(_length))
+ )
+ if max_length is not None and _length > max_length:
+ raise ValueError(
+ 'maximum length is {expect}, but {actual} found'.format(expect=repr(max_length), actual=repr(_length))
+ )
+
+ return value
+
+ return method('__len__') & Loader(_load)
+
+
+def length_is(length_: int) -> ILoaderClass:
+ """
+ Overview:
+ Create a length loader.
+ Arguments:
+ - length_ (:obj:`int`): The length.
+ """
+
+ return length(min_length=length_, max_length=length_)
+
+
+def contains(content) -> ILoaderClass:
+ """
+ Overview:
+ Create a contains loader.
+ Arguments:
+ - content (:obj:`Any`): The content.
+ """
+
+ def _load(value):
+ if content not in value:
+ raise ValueError('{content} not found in value'.format(content=repr(content)))
+
+ return value
+
+ return method('__contains__') & Loader(_load)
+
+
+def cofilter(checker: Callable[[Any], bool], type_back: bool = True) -> ILoaderClass:
+ """
+ Overview:
+ Create a cofilter loader.
+ Arguments:
+ - checker (:obj:`Callable[[Any], bool]`): The checker.
+ - type_back (:obj:`bool`): Whether to convert the type back.
+ """
+
+ def _load(value):
+ _result = [item for item in value if checker(item)]
+ if type_back:
+ _result = type(value)(_result)
+ return _result
+
+ return method('__iter__') & Loader(_load)
+
+
+def tpselector(*indices) -> ILoaderClass:
+ """
+ Overview:
+ Create a tuple selector loader.
+ Arguments:
+ - indices (:obj:`tuple`): The indices.
+ """
+
+ def _load(value: tuple):
+ return tuple([value[index] for index in indices])
+
+ return tuple & Loader(_load)
diff --git a/DI-engine/ding/utils/loader/dict.py b/DI-engine/ding/utils/loader/dict.py
new file mode 100644
index 0000000000000000000000000000000000000000..a14d3ff9f876c37f3a473c82abf7babcc03ee5cf
--- /dev/null
+++ b/DI-engine/ding/utils/loader/dict.py
@@ -0,0 +1,66 @@
+from typing import Mapping
+
+from .base import Loader, CAPTURE_EXCEPTIONS, ILoaderClass
+from .exception import CompositeStructureError
+
+DICT_ERRORS = Mapping[str, Exception]
+
+
+class DictError(CompositeStructureError):
+ """
+ Overview:
+ Dict error.
+ Interfaces:
+ ``__init__``, ``errors``
+ Properties:
+ ``errors``
+ """
+
+ def __init__(self, errors: DICT_ERRORS):
+ """
+ Overview:
+ Initialize the DictError.
+ Arguments:
+ - errors (:obj:`DICT_ERRORS`): The errors.
+ """
+
+ self.__error = errors
+
+ @property
+ def errors(self) -> DICT_ERRORS:
+ """
+ Overview:
+ Get the errors.
+ """
+
+ return self.__error
+
+
+def dict_(**kwargs) -> ILoaderClass:
+ """
+ Overview:
+ Create a dict loader.
+ Arguments:
+ - kwargs (:obj:`Mapping[str, ILoaderClass]`): The loaders.
+ """
+
+ kwargs = [(k, Loader(v)) for k, v in kwargs.items()]
+
+ def _load(value):
+ _errors = {}
+ _results = {}
+
+ for k, vl in kwargs:
+ try:
+ v = vl(value)
+ except CAPTURE_EXCEPTIONS as err:
+ _errors[k] = err
+ else:
+ _results[k] = v
+
+ if not _errors:
+ return _results
+ else:
+ raise DictError(_errors)
+
+ return Loader(_load)
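+
+# Editor's note (illustrative sketch, not part of the original source): ``dict_`` builds a
+# new dict in which each key is produced by its own loader applied to the whole input value.
+# ``item`` from ``.mapping`` can be used in place of the plain lambdas shown here.
+#
+#     >>> swap = dict_(a=lambda d: d['b'], b=lambda d: d['a'])
+#     >>> swap({'a': 1, 'b': 2})
+#     {'a': 2, 'b': 1}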
diff --git a/DI-engine/ding/utils/loader/exception.py b/DI-engine/ding/utils/loader/exception.py
new file mode 100644
index 0000000000000000000000000000000000000000..9358f1c85e11370e31db7c47b4c6be5c7c3c4b5b
--- /dev/null
+++ b/DI-engine/ding/utils/loader/exception.py
@@ -0,0 +1,27 @@
+from abc import ABCMeta, abstractmethod
+from typing import List, Union, Tuple
+
+INDEX_TYPING = Union[int, str]
+ERROR_ITEM_TYPING = Tuple[INDEX_TYPING, Exception]
+ERROR_ITEMS = List[ERROR_ITEM_TYPING]
+
+
+class CompositeStructureError(ValueError, metaclass=ABCMeta):
+ """
+ Overview:
+ Composite structure error.
+ Interfaces:
+ ``__init__``, ``errors``
+ Properties:
+ ``errors``
+ """
+
+ @property
+ @abstractmethod
+ def errors(self) -> ERROR_ITEMS:
+ """
+ Overview:
+ Get the errors.
+ """
+
+ raise NotImplementedError
diff --git a/DI-engine/ding/utils/loader/mapping.py b/DI-engine/ding/utils/loader/mapping.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3993c2366af2e10e1416fbc47f6c5385561a8d5
--- /dev/null
+++ b/DI-engine/ding/utils/loader/mapping.py
@@ -0,0 +1,178 @@
+from typing import List, Tuple, Callable, Any
+
+from .base import ILoaderClass, Loader, CAPTURE_EXCEPTIONS
+from .exception import CompositeStructureError
+from .types import method
+from .utils import raw
+
+MAPPING_ERROR_ITEM = Tuple[str, Exception]
+MAPPING_ERRORS = List[MAPPING_ERROR_ITEM]
+
+
+class MappingError(CompositeStructureError):
+ """
+ Overview:
+ Mapping error.
+ Interfaces:
+ ``__init__``, ``key_errors``, ``value_errors``, ``errors``
+ """
+
+ def __init__(self, key_errors: MAPPING_ERRORS, value_errors: MAPPING_ERRORS):
+ """
+ Overview:
+ Initialize the MappingError.
+ Arguments:
+ - key_errors (:obj:`MAPPING_ERRORS`): The key errors.
+ - value_errors (:obj:`MAPPING_ERRORS`): The value errors.
+ """
+
+ self.__key_errors = list(key_errors or [])
+ self.__value_errors = list(value_errors or [])
+ self.__errors = self.__key_errors + self.__value_errors
+
+ def key_errors(self) -> MAPPING_ERRORS:
+ """
+ Overview:
+ Get the key errors.
+ """
+
+ return self.__key_errors
+
+ def value_errors(self) -> MAPPING_ERRORS:
+ """
+ Overview:
+ Get the value errors.
+ """
+
+ return self.__value_errors
+
+ def errors(self) -> MAPPING_ERRORS:
+ """
+ Overview:
+ Get the errors.
+ """
+
+ return self.__errors
+
+
+def mapping(key_loader, value_loader, type_back: bool = True) -> ILoaderClass:
+ """
+ Overview:
+ Create a mapping loader.
+ Arguments:
+ - key_loader (:obj:`ILoaderClass`): The key loader.
+ - value_loader (:obj:`ILoaderClass`): The value loader.
+ - type_back (:obj:`bool`): Whether to convert the type back.
+ """
+
+ key_loader = Loader(key_loader)
+ value_loader = Loader(value_loader)
+
+ def _load(value):
+ _key_errors = []
+ _value_errors = []
+ _result = {}
+ for key_, value_ in value.items():
+ key_error, value_error = None, None
+ key_result, value_result = None, None
+
+ try:
+ key_result = key_loader(key_)
+ except CAPTURE_EXCEPTIONS as err:
+ key_error = err
+
+ try:
+ value_result = value_loader(value_)
+ except CAPTURE_EXCEPTIONS as err:
+ value_error = err
+
+ if not key_error and not value_error:
+ _result[key_result] = value_result
+ else:
+ if key_error:
+ _key_errors.append((key_, key_error))
+ if value_error:
+ _value_errors.append((key_, value_error))
+
+ if not _key_errors and not _value_errors:
+ if type_back:
+ _result = type(value)(_result)
+ return _result
+ else:
+ raise MappingError(_key_errors, _value_errors)
+
+ return method('items') & Loader(_load)
+
+
+def mpfilter(check: Callable[[Any, Any], bool], type_back: bool = True) -> ILoaderClass:
+ """
+ Overview:
+ Create a mapping filter loader.
+ Arguments:
+ - check (:obj:`Callable[[Any, Any], bool]`): The check function.
+ - type_back (:obj:`bool`): Whether to convert the type back.
+ """
+
+ def _load(value):
+ _result = {key_: value_ for key_, value_ in value.items() if check(key_, value_)}
+
+ if type_back:
+ _result = type(value)(_result)
+ return _result
+
+ return method('items') & Loader(_load)
+
+
+def mpkeys() -> ILoaderClass:
+ """
+ Overview:
+ Create a mapping keys loader.
+ """
+
+ return method('items') & method('keys') & Loader(lambda v: set(v.keys()))
+
+
+def mpvalues() -> ILoaderClass:
+ """
+ Overview:
+ Create a mapping values loader.
+ """
+
+ return method('items') & method('values') & Loader(lambda v: set(v.values()))
+
+
+def mpitems() -> ILoaderClass:
+ """
+ Overview:
+ Create a mapping items loader.
+ """
+
+ return method('items') & Loader(lambda v: set([(key, value) for key, value in v.items()]))
+
+
+_INDEX_PRECHECK = method('__getitem__')
+
+
+def item(key) -> ILoaderClass:
+ """
+ Overview:
+ Create an item loader.
+ Arguments:
+ - key (:obj:`Any`): The key.
+ """
+
+ return _INDEX_PRECHECK & Loader(
+ (lambda v: key in v.keys(), lambda v: v[key], KeyError('key {key} not found'.format(key=repr(key))))
+ )
+
+
+def item_or(key, default) -> ILoaderClass:
+ """
+ Overview:
+ Create an ``item_or`` loader that falls back to ``default`` when ``key`` is missing.
+ Arguments:
+ - key (:obj:`Any`): The key.
+ - default (:obj:`Any`): The default value.
+ """
+
+ return _INDEX_PRECHECK & (item(key) | raw(default))
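+
+# Editor's note (illustrative sketch, not part of the original source): ``mapping``
+# validates keys and values independently, while ``item``/``item_or`` extract one entry.
+#
+#     >>> mapping(str, int)({'a': 1, 'b': 2})
+#     {'a': 1, 'b': 2}
+#     >>> item('a')({'a': 1})
+#     1
+#     >>> item_or('a', 0)({'b': 2})
+#     0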
diff --git a/DI-engine/ding/utils/loader/norm.py b/DI-engine/ding/utils/loader/norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..af142ed4e609c38850ef93d30dbf4b95c6f66906
--- /dev/null
+++ b/DI-engine/ding/utils/loader/norm.py
@@ -0,0 +1,535 @@
+import operator
+from abc import abstractmethod
+from functools import wraps
+from typing import Callable, Any
+
+from .base import ILoaderClass
+
+
+def _callable_to_norm(func: Callable[[Any], Any]) -> 'INormClass':
+ """
+ Overview:
+ Convert callable to norm.
+ Arguments:
+ - func (:obj:`Callable[[Any], Any]`): The callable to be converted.
+ """
+
+ class _Norm(INormClass):
+
+ def _call(self, value):
+ return func(value)
+
+ return _Norm()
+
+
+def norm(value) -> 'INormClass':
+ """
+ Overview:
+ Convert value to norm.
+ Arguments:
+ - value (:obj:`Any`): The value to be converted.
+ """
+
+ if isinstance(value, INormClass):
+ return value
+ elif isinstance(value, ILoaderClass):
+ return _callable_to_norm(value)
+ else:
+ return _callable_to_norm(lambda v: value)
+
+
+def normfunc(func):
+ """
+ Overview:
+ Convert function to norm function.
+ Arguments:
+ - func (:obj:`Callable[[Any], Any]`): The function to be converted.
+ """
+
+ @wraps(func)
+ def _new_func(*args_norm, **kwargs_norm):
+ args_norm = [norm(item) for item in args_norm]
+ kwargs_norm = {key: norm(value) for key, value in kwargs_norm.items()}
+
+ def _callable(v):
+ args = [item(v) for item in args_norm]
+ kwargs = {key: value(v) for key, value in kwargs_norm.items()}
+ return func(*args, **kwargs)
+
+ return _callable_to_norm(_callable)
+
+ return _new_func
+
+
+UNARY_FUNC = Callable[[Any], Any]
+BINARY_FUNC = Callable[[Any, Any], Any]
+
+
+def _unary(a: 'INormClass', func: UNARY_FUNC) -> 'INormClass':
+ """
+ Overview:
+ Create a unary norm.
+ Arguments:
+ - a (:obj:`INormClass`): The norm.
+ - func (:obj:`UNARY_FUNC`): The function.
+ """
+
+ return _callable_to_norm(lambda v: func(a(v)))
+
+
+def _binary(a: 'INormClass', b: 'INormClass', func: BINARY_FUNC) -> 'INormClass':
+ """
+ Overview:
+ Create a binary norm.
+ Arguments:
+ - a (:obj:`INormClass`): The first norm.
+ - b (:obj:`INormClass`): The second norm.
+ - func (:obj:`BINARY_FUNC`): The function.
+ """
+ return _callable_to_norm(lambda v: func(a(v), b(v)))
+
+
+def _binary_reducing(func: BINARY_FUNC, zero):
+ """
+ Overview:
+ Create a binary reducing norm.
+ Arguments:
+ - func (:obj:`BINARY_FUNC`): The function.
+ - zero (:obj:`Any`): The zero value.
+ """
+
+ @wraps(func)
+ def _new_func(*args) -> 'INormClass':
+ _sum = norm(zero)
+ for item in args:
+ _sum = _binary(_sum, norm(item), func)
+ return _sum
+
+ return _new_func
+
+
+class INormClass:
+ """
+ Overview:
+ The norm class.
+ Interfaces:
+ ``__call__``, ``__add__``, ``__radd__``, ``__sub__``, ``__rsub__``, ``__mul__``, ``__rmul__``, ``__matmul__``,
+ ``__rmatmul__``, ``__truediv__``, ``__rtruediv__``, ``__floordiv__``, ``__rfloordiv__``, ``__mod__``,
+ ``__rmod__``, ``__pow__``, ``__rpow__``, ``__lshift__``, ``__rlshift__``, ``__rshift__``, ``__rrshift__``,
+ ``__and__``, ``__rand__``, ``__or__``, ``__ror__``, ``__xor__``, ``__rxor__``, ``__invert__``, ``__pos__``,
+ ``__neg__``, ``__eq__``, ``__ne__``, ``__lt__``, ``__le__``, ``__gt__``, ``__ge__``
+ """
+
+ @abstractmethod
+ def _call(self, value):
+ """
+ Overview:
+ Call the norm.
+ Arguments:
+ - value (:obj:`Any`): The value to be normalized.
+ """
+
+ raise NotImplementedError
+
+ def __call__(self, value):
+ """
+ Overview:
+ Call the norm.
+ Arguments:
+ - value (:obj:`Any`): The value to be normalized.
+ """
+
+ return self._call(value)
+
+ def __add__(self, other):
+ """
+ Overview:
+ Add the norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return _binary(self, norm(other), operator.__add__)
+
+ def __radd__(self, other):
+ """
+ Overview:
+ Add the norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return norm(other) + self
+
+ def __sub__(self, other):
+ """
+ Overview:
+ Subtract the norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return _binary(self, norm(other), operator.__sub__)
+
+ def __rsub__(self, other):
+ """
+ Overview:
+ Subtract the norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return norm(other) - self
+
+ def __mul__(self, other):
+ """
+ Overview:
+ Multiply the norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return _binary(self, norm(other), operator.__mul__)
+
+ def __rmul__(self, other):
+ """
+ Overview:
+ Multiply the norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return norm(other) * self
+
+ def __matmul__(self, other):
+ """
+ Overview:
+ Matrix multiply the norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return _binary(self, norm(other), operator.__matmul__)
+
+ def __rmatmul__(self, other):
+ """
+ Overview:
+ Matrix multiply the norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return norm(other) @ self
+
+ def __truediv__(self, other):
+ """
+ Overview:
+ Divide the norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return _binary(self, norm(other), operator.__truediv__)
+
+ def __rtruediv__(self, other):
+ """
+ Overview:
+ Divide the norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return norm(other) / self
+
+ def __floordiv__(self, other):
+ """
+ Overview:
+ Floor divide the norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return _binary(self, norm(other), operator.__floordiv__)
+
+ def __rfloordiv__(self, other):
+ """
+ Overview:
+ Floor divide the norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return norm(other) // self
+
+ def __mod__(self, other):
+ """
+ Overview:
+ Mod the norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return _binary(self, norm(other), operator.__mod__)
+
+ def __rmod__(self, other):
+ """
+ Overview:
+ Mod the norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return norm(other) % self
+
+ def __pow__(self, power, modulo=None):
+ """
+ Overview:
+ Power the norm.
+ Arguments:
+ - power (:obj:`Any`): The power.
+ - modulo (:obj:`Any`): The modulo.
+ """
+
+ return _binary(self, norm(power), operator.__pow__)
+
+ def __rpow__(self, other):
+ """
+ Overview:
+ Power the norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return norm(other) ** self
+
+ def __lshift__(self, other):
+ """
+ Overview:
+ Lshift the norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return _binary(self, norm(other), operator.__lshift__)
+
+ def __rlshift__(self, other):
+ """
+ Overview:
+ Lshift the norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return norm(other) << self
+
+ def __rshift__(self, other):
+ """
+ Overview:
+ Rshift the norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return _binary(self, norm(other), operator.__rshift__)
+
+ def __rrshift__(self, other):
+ """
+ Overview:
+ Rshift the norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return norm(other) >> self
+
+ def __and__(self, other):
+ """
+ Overview:
+ Apply the ``&`` operator between the norm and the other norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return _binary(self, norm(other), operator.__and__)
+
+ def __rand__(self, other):
+ """
+ Overview:
+ Apply the ``&`` operator between the norm and the other norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return norm(other) & self
+
+ def __or__(self, other):
+ """
+ Overview:
+ Apply the ``|`` operator between the norm and the other norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return _binary(self, norm(other), operator.__or__)
+
+ def __ror__(self, other):
+ """
+ Overview:
+ Apply the ``|`` operator between the norm and the other norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return norm(other) | self
+
+ def __xor__(self, other):
+ """
+ Overview:
+ Apply the ``^`` operator between the norm and the other norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return _binary(self, norm(other), operator.__xor__)
+
+ def __rxor__(self, other):
+ """
+ Overview:
+ Apply the ``^`` operator between the norm and the other norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return norm(other) ^ self
+
+ def __invert__(self):
+ """
+ Overview:
+ Invert the norm.
+ """
+
+ return _unary(self, operator.__invert__)
+
+ def __pos__(self):
+ """
+ Overview:
+ Apply the unary ``+`` operator to the norm.
+ """
+
+ return _unary(self, operator.__pos__)
+
+ def __neg__(self):
+ """
+ Overview:
+ Negate the norm (unary ``-``).
+ """
+
+ return _unary(self, operator.__neg__)
+
+ # Attention: DO NOT USE LINKING COMPARE OPERATORS, IT WILL CAUSE ERROR.
+ def __eq__(self, other):
+ """
+ Overview:
+ Compare the norm if they are equal.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return _binary(self, norm(other), operator.__eq__)
+
+ def __ne__(self, other):
+ """
+ Overview:
+ Compare the norm if they are not equal.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return _binary(self, norm(other), operator.__ne__)
+
+ def __lt__(self, other):
+ """
+ Overview:
+ Compare the norm if it is less than the other norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return _binary(self, norm(other), operator.__lt__)
+
+ def __le__(self, other):
+ """
+ Overview:
+ Compare the norm if it is less than or equal to the other norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return _binary(self, norm(other), operator.__le__)
+
+ def __gt__(self, other):
+ """
+ Overview:
+ Compare the norm if it is greater than the other norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return _binary(self, norm(other), operator.__gt__)
+
+ def __ge__(self, other):
+ """
+ Overview:
+ Compare the norm if it is greater than or equal to the other norm.
+ Arguments:
+ - other (:obj:`Any`): The other norm.
+ """
+
+ return _binary(self, norm(other), operator.__ge__)
+
+
+lnot = normfunc(lambda x: not x)
+land = _binary_reducing(lambda x, y: x and y, True)
+lor = _binary_reducing(lambda x, y: x or y, False)
+
+lin = normfunc(operator.__contains__)
+lis = normfunc(operator.is_)
+lisnot = normfunc(operator.is_not)
+
+lsum = _binary_reducing(lambda x, y: x + y, 0)
+
+_COMPARE_OPERATORS = {
+ '!=': operator.__ne__,
+ '==': operator.__eq__,
+ '<': operator.__lt__,
+ '<=': operator.__le__,
+ '>': operator.__gt__,
+ '>=': operator.__ge__,
+}
+
+
+@normfunc
+def lcmp(first, *items):
+ """
+ Overview:
+ Compare the items.
+ Arguments:
+ - first (:obj:`Any`): The first item.
+ - items (:obj:`Any`): The other items.
+ """
+
+ if len(items) % 2 == 1:
+ raise ValueError('Count of items should be odd number but {number} found.'.format(number=len(items) + 1))
+
+ ops, items = items[0::2], items[1::2]
+ for op in ops:
+ if op not in _COMPARE_OPERATORS.keys():
+ raise KeyError('Invalid compare operator - {op}.'.format(op=repr(op)))
+
+ _last = first
+ for op, item in zip(ops, items):
+ if not _COMPARE_OPERATORS[op](_last, item):
+ return False
+ _last = item
+
+ return True
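+
+# Editor's note (illustrative sketch, not part of the original source): norms are lazily
+# evaluated expressions over the loaded value; ``norm(...)`` lifts a loader into this
+# expression algebra. ``item`` from ``.mapping`` is used here only for illustration.
+#
+#     >>> _norm = norm(item('a')) + 2 * norm(item('b'))
+#     >>> _norm({'a': 1, 'b': 3})
+#     7
+#     >>> _check = lcmp(1, '<', norm(item('a')), '<=', 10)
+#     >>> _check({'a': 5})
+#     True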
diff --git a/DI-engine/ding/utils/loader/number.py b/DI-engine/ding/utils/loader/number.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9fdc59a16ca962ac50f80cf9a69a7ac5dd74833
--- /dev/null
+++ b/DI-engine/ding/utils/loader/number.py
@@ -0,0 +1,361 @@
+import math
+import operator
+from typing import Optional, Union, Callable, Any
+
+from .base import Loader, ILoaderClass
+from .utils import keep, check_only
+
+NUMBER_TYPES = (int, float)
+NUMBER_TYPING = Union[int, float]
+
+
+def numeric(int_ok: bool = True, float_ok: bool = True, inf_ok: bool = True) -> ILoaderClass:
+ """
+ Overview:
+ Create a numeric loader.
+ Arguments:
+ - int_ok (:obj:`bool`): Whether int is allowed.
+ - float_ok (:obj:`bool`): Whether float is allowed.
+ - inf_ok (:obj:`bool`): Whether inf is allowed.
+ """
+
+ if not int_ok and not float_ok:
+ raise ValueError('Either int or float should be allowed.')
+
+ def _load(value) -> NUMBER_TYPING:
+ if isinstance(value, NUMBER_TYPES):
+ if math.isnan(value):
+ raise ValueError('nan is not a numeric value')
+ if isinstance(value, int) and not int_ok:
+ raise TypeError('int is not allowed but {actual} found'.format(actual=repr(value)))
+ if isinstance(value, float) and not float_ok:
+ raise TypeError('float is not allowed but {actual} found'.format(actual=repr(value)))
+ if math.isinf(value) and not inf_ok:
+ raise ValueError('inf is not allowed but {actual} found'.format(actual=repr(value)))
+
+ return value
+ else:
+ raise TypeError(
+ 'numeric value should be either int or float, but {actual} found'.format(
+ actual=repr(type(value).__name__)
+ )
+ )
+
+ return Loader(_load)
+
+
+def interval(
+ left: Optional[NUMBER_TYPING] = None,
+ right: Optional[NUMBER_TYPING] = None,
+ left_ok: bool = True,
+ right_ok: bool = True,
+ eps=0.0
+) -> ILoaderClass:
+ """
+ Overview:
+ Create an interval loader.
+ Arguments:
+ - left (:obj:`Optional[NUMBER_TYPING]`): The left bound.
+ - right (:obj:`Optional[NUMBER_TYPING]`): The right bound.
+ - left_ok (:obj:`bool`): Whether left bound is allowed.
+ - right_ok (:obj:`bool`): Whether right bound is allowed.
+ - eps (:obj:`float`): The epsilon.
+ """
+
+ if left is None:
+ left = -math.inf
+ if right is None:
+ right = +math.inf
+ if left > right:
+ raise ValueError(
+ "Left bound should no more than right bound, but {left} > {right}.".format(
+ left=repr(left), right=repr(right)
+ )
+ )
+ eps = math.fabs(eps)
+
+ def _value_compare_with_eps(a, b) -> int:
+ if math.fabs(a - b) <= eps:
+ return 0
+ elif a < b:
+ return -1
+ else:
+ return 1
+
+ def _load(value) -> NUMBER_TYPING:
+ _left_check = _value_compare_with_eps(value, left)
+ if _left_check < 0:
+ raise ValueError(
+ 'value should be no less than {left} but {value} found'.format(left=repr(left), value=repr(value))
+ )
+ elif not left_ok and _left_check == 0:
+ raise ValueError(
+ 'value should not be equal to left bound {left} but {value} found'.format(
+ left=repr(left), value=repr(value)
+ )
+ )
+
+ _right_check = _value_compare_with_eps(value, right)
+ if _right_check > 0:
+ raise ValueError(
+ 'value should be no more than {right} but {value} found'.format(right=repr(right), value=repr(value))
+ )
+ elif not right_ok and _right_check == 0:
+ raise ValueError(
+ 'value should not be equal to right bound {right} but {value} found'.format(
+ right=repr(right), value=repr(value)
+ )
+ )
+
+ return value
+
+ return Loader(_load)
+
+
+def is_negative() -> ILoaderClass:
+ """
+ Overview:
+ Create a loader that checks the value is negative.
+ """
+
+ return Loader((lambda x: x < 0, lambda x: ValueError('negative required but {value} found'.format(value=repr(x)))))
+
+
+def is_positive() -> ILoaderClass:
+ """
+ Overview:
+ Create a loader that checks the value is positive.
+ """
+
+ return Loader((lambda x: x > 0, lambda x: ValueError('positive required but {value} found'.format(value=repr(x)))))
+
+
+def non_negative() -> ILoaderClass:
+ """
+ Overview:
+ Create a loader that checks the value is non-negative.
+ """
+
+ return Loader(
+ (lambda x: x >= 0, lambda x: ValueError('non-negative required but {value} found'.format(value=repr(x))))
+ )
+
+
+def non_positive() -> ILoaderClass:
+ """
+ Overview:
+ Create a loader that checks the value is non-positive.
+ """
+
+ return Loader(
+ (lambda x: x <= 0, lambda x: ValueError('non-positive required but {value} found'.format(value=repr(x))))
+ )
+
+
+def negative() -> ILoaderClass:
+ """
+ Overview:
+ Create a loader that negates the value (unary ``-``).
+ """
+
+ return Loader(lambda x: -x)
+
+
+def positive() -> ILoaderClass:
+ """
+ Overview:
+ Create a loader that applies the unary ``+`` operator to the value.
+ """
+
+ return Loader(lambda x: +x)
+
+
+def _math_binary(func: Callable[[Any, Any], Any], attachment) -> ILoaderClass:
+ """
+ Overview:
+ Create a math binary loader.
+ Arguments:
+ - func (:obj:`Callable[[Any, Any], Any]`): The function.
+ - attachment (:obj:`Any`): The attachment.
+ """
+
+ return Loader(lambda x: func(x, Loader(attachment)(x)))
+
+
+def plus(addend) -> ILoaderClass:
+ """
+ Overview:
+ Create a plus loader.
+ Arguments:
+ - addend (:obj:`Any`): The addend.
+ """
+
+ return _math_binary(lambda x, y: x + y, addend)
+
+
+def minus(subtrahend) -> ILoaderClass:
+ """
+ Overview:
+ Create a minus loader.
+ Arguments:
+ - subtrahend (:obj:`Any`): The subtrahend.
+ """
+
+ return _math_binary(lambda x, y: x - y, subtrahend)
+
+
+def minus_with(minuend) -> ILoaderClass:
+ """
+ Overview:
+ Create a reversed-subtraction loader that computes ``minuend - value``.
+ Arguments:
+ - minuend (:obj:`Any`): The minuend.
+ """
+
+ return _math_binary(lambda x, y: y - x, minuend)
+
+
+def multi(multiplier) -> ILoaderClass:
+ """
+ Overview:
+ Create a multi loader.
+ Arguments:
+ - multiplier (:obj:`Any`): The multiplier.
+ """
+
+ return _math_binary(lambda x, y: x * y, multiplier)
+
+
+def divide(divisor) -> ILoaderClass:
+ """
+ Overview:
+ Create a divide loader.
+ Arguments:
+ - divisor (:obj:`Any`): The divisor.
+ """
+
+ return _math_binary(lambda x, y: x / y, divisor)
+
+
+def divide_with(dividend) -> ILoaderClass:
+ """
+ Overview:
+ Create a reversed-division loader that computes ``dividend / value``.
+ Arguments:
+ - dividend (:obj:`Any`): The dividend.
+ """
+
+ return _math_binary(lambda x, y: y / x, dividend)
+
+
+def power(index) -> ILoaderClass:
+ """
+ Overview:
+ Create a power loader.
+ Arguments:
+ - index (:obj:`Any`): The index.
+ """
+
+ return _math_binary(lambda x, y: x ** y, index)
+
+
+def power_with(base) -> ILoaderClass:
+ """
+ Overview:
+ Create a reversed-power loader that computes ``base ** value``.
+ Arguments:
+ - base (:obj:`Any`): The base.
+ """
+
+ return _math_binary(lambda x, y: y ** x, base)
+
+
+def msum(*items) -> ILoaderClass:
+ """
+ Overview:
+ Create a sum loader.
+ Arguments:
+ - items (:obj:`tuple`): The items.
+ """
+
+ def _load(value):
+ return sum([item(value) for item in items])
+
+ return Loader(_load)
+
+
+def mmulti(*items) -> ILoaderClass:
+ """
+ Overview:
+ Create a product loader that multiplies the results of all ``items``.
+ Arguments:
+ - items (:obj:`tuple`): The items.
+ """
+
+ def _load(value):
+ _result = 1
+ for item in items:
+ _result *= item(value)
+ return _result
+
+ return Loader(_load)
+
+
+_COMPARE_OPERATORS = {
+ '!=': operator.__ne__,
+ '==': operator.__eq__,
+ '<': operator.__lt__,
+ '<=': operator.__le__,
+ '>': operator.__gt__,
+ '>=': operator.__ge__,
+}
+
+
+def _msinglecmp(first, op, second) -> ILoaderClass:
+ """
+ Overview:
+ Create a single compare loader.
+ Arguments:
+ - first (:obj:`Any`): The first item.
+ - op (:obj:`str`): The operator.
+ - second (:obj:`Any`): The second item.
+ """
+
+ first = Loader(first)
+ second = Loader(second)
+
+ if op in _COMPARE_OPERATORS.keys():
+ return Loader(
+ (
+ lambda x: _COMPARE_OPERATORS[op](first(x), second(x)), lambda x: ValueError(
+ 'comparison failed for {first} {op} {second}'.format(
+ first=repr(first(x)),
+ second=repr(second(x)),
+ op=op,
+ )
+ )
+ )
+ )
+ else:
+ raise KeyError('Invalid compare operator - {op}.'.format(op=repr(op)))
+
+
+def mcmp(first, *items) -> ILoaderClass:
+ """
+ Overview:
+ Create a multi compare loader.
+ Arguments:
+ - first (:obj:`Any`): The first item.
+ - items (:obj:`tuple`): The items.
+ """
+
+ if len(items) % 2 == 1:
+ raise ValueError('Count of items should be odd number but {number} found.'.format(number=len(items) + 1))
+
+ ops, items = items[0::2], items[1::2]
+
+ _result = keep()
+ for first, op, second in zip((first, ) + items[:-1], ops, items):
+ _result &= _msinglecmp(first, op, second)
+
+ return check_only(_result)
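+
+# Editor's note (illustrative sketch, not part of the original source): the numeric loaders
+# compose with the interval/arithmetic helpers above.
+#
+#     >>> (numeric() & interval(0, 1))(0.5)
+#     0.5
+#     >>> (plus(1) >> interval(2, 3))(1.5)
+#     2.5
+#     >>> mcmp(0, '<=', keep(), '<', 10)(7)
+#     7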
diff --git a/DI-engine/ding/utils/loader/string.py b/DI-engine/ding/utils/loader/string.py
new file mode 100644
index 0000000000000000000000000000000000000000..16d4827cc4665d90da15f6e55f300871a3c489c3
--- /dev/null
+++ b/DI-engine/ding/utils/loader/string.py
@@ -0,0 +1,112 @@
+import re
+from functools import wraps
+from itertools import islice
+from typing import Callable, Union, Pattern
+
+from .base import Loader, ILoaderClass
+
+STRING_PROCESSOR = Callable[[str], str]
+
+
+def enum(*items, case_sensitive: bool = True) -> ILoaderClass:
+ """
+ Overview:
+ Create an enum loader.
+ Arguments:
+ - items (:obj:`Iterable[str]`): The items.
+ - case_sensitive (:obj:`bool`): Whether case sensitive.
+ """
+
+ def _case_sensitive(func: STRING_PROCESSOR) -> STRING_PROCESSOR:
+ if case_sensitive:
+ return func
+ else:
+
+ @wraps(func)
+ def _new_func(value: str) -> str:
+ return func(value).lower()
+
+ return _new_func
+
+ @_case_sensitive
+ def _item_process(value):
+ return str(value)
+
+ item_set = set([_item_process(item) for item in items])
+
+ def _load(value: str):
+ real_value = _item_process(value)
+ if real_value not in item_set:
+ raise ValueError('unknown enum value {value}'.format(value=repr(value)))
+
+ return real_value
+
+ return Loader(_load)
+
+
+def _to_regexp(regexp) -> Pattern:
+ """
+ Overview:
+ Convert regexp to re.Pattern.
+ Arguments:
+ - regexp (:obj:`Union[str, re.Pattern]`): The regexp.
+ """
+
+ if isinstance(regexp, Pattern):
+ return regexp
+ elif isinstance(regexp, str):
+ return re.compile(regexp)
+ else:
+ raise TypeError(
+ 'Regexp should be either str or re.Pattern but {actual} found.'.format(actual=repr(type(regexp).__name__))
+ )
+
+
+def rematch(regexp: Union[str, Pattern]) -> ILoaderClass:
+ """
+ Overview:
+ Create a rematch loader.
+ Arguments:
+ - regexp (:obj:`Union[str, re.Pattern]`): The regexp.
+ """
+
+ regexp = _to_regexp(regexp)
+
+ def _load(value: str):
+ if not regexp.fullmatch(value):
+ raise ValueError(
+ 'fully match with regexp {pattern} expected but {actual} found'.format(
+ pattern=repr(regexp.pattern),
+ actual=repr(value),
+ )
+ )
+
+ return value
+
+ return Loader(_load)
+
+
+def regrep(regexp: Union[str, Pattern], group: int = 0) -> ILoaderClass:
+ """
+ Overview:
+ Create a regrep loader.
+ Arguments:
+ - regexp (:obj:`Union[str, re.Pattern]`): The regexp.
+ - group (:obj:`int`): The group.
+ """
+
+ regexp = _to_regexp(regexp)
+
+ def _load(value: str):
+ results = list(islice(regexp.finditer(value), 1))
+ if results:
+ return results[0][group]
+ else:
+ raise ValueError(
+ 'no match of regexp {pattern} found in {actual}'.format(
+ pattern=repr(regexp.pattern),
+ actual=repr(value),
+ )
+ )
+
+ return Loader(_load)
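+
+# Editor's note (illustrative sketch, not part of the original source): the string loaders
+# above in action.
+#
+#     >>> enum('red', 'green', 'blue', case_sensitive=False)('RED')
+#     'red'
+#     >>> rematch(r'[a-z]+\d')('abc1')
+#     'abc1'
+#     >>> regrep(r'\d+')('version 42 beta')
+#     '42'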
diff --git a/DI-engine/ding/utils/loader/tests/__init__.py b/DI-engine/ding/utils/loader/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..737dfcebde5ac29ee64f4a48afa35e8ce56ebf26
--- /dev/null
+++ b/DI-engine/ding/utils/loader/tests/__init__.py
@@ -0,0 +1,2 @@
+from .loader import *
+from .test_cartpole_dqn_serial_config_loader import test_main_config, test_create_config
diff --git a/DI-engine/ding/utils/loader/tests/loader/__init__.py b/DI-engine/ding/utils/loader/tests/loader/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad6f25ba543aa569836db95835bb49ffc57bdb6d
--- /dev/null
+++ b/DI-engine/ding/utils/loader/tests/loader/__init__.py
@@ -0,0 +1,9 @@
+from .test_base import TestConfigLoaderBase
+from .test_collection import TestConfigLoaderCollection
+from .test_dict import TestConfigLoaderDict
+from .test_mapping import TestConfigLoaderMapping
+from .test_norm import TestConfigLoaderNorm
+from .test_number import TestConfigLoaderNumber
+from .test_string import TestConfigLoaderString
+from .test_types import TestConfigLoaderTypes
+from .test_utils import TestConfigLoaderUtils
diff --git a/DI-engine/ding/utils/loader/tests/loader/test_base.py b/DI-engine/ding/utils/loader/tests/loader/test_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..b88baa3d4296bb19c4ac446c165676bb7d188f5b
--- /dev/null
+++ b/DI-engine/ding/utils/loader/tests/loader/test_base.py
@@ -0,0 +1,176 @@
+import pytest
+
+from ding.utils.loader import Loader
+
+
+@pytest.mark.unittest
+class TestConfigLoaderBase:
+
+ def test_load(self):
+ _loader = Loader(int)
+ assert _loader.load(1) == 1
+ with pytest.raises(TypeError):
+ _loader.load('string')
+
+ def test_check(self):
+ _loader = Loader(int)
+ assert _loader.check(1)
+ assert not _loader.check('string')
+
+ def test_call(self):
+ _loader = Loader(int)
+ assert _loader(1) == 1
+ with pytest.raises(TypeError):
+ _loader('string')
+
+ def test_or(self):
+ _loader = Loader(int) | str
+ assert _loader(1) == 1
+ assert _loader('string') == 'string'
+ with pytest.raises(TypeError):
+ _loader([])
+
+ assert _loader.check(1)
+ assert _loader.check('string')
+ assert not _loader.check([])
+
+ def test_ror(self):
+ _loader = (lambda v: v < 0, 'Negative number expected.') | Loader(int)
+
+ assert _loader(-1) == -1
+ assert _loader(1) == 1
+ assert _loader(-1.0) == -1.0
+ with pytest.raises(TypeError):
+ _loader(1.0)
+
+ assert _loader.check(-1)
+ assert _loader.check(1)
+ assert _loader.check(-1.0)
+ assert not _loader.check(1.0)
+
+ # noinspection DuplicatedCode
+ def test_and(self):
+ _loader = Loader(int) & (lambda x: x >= 0, 'non-negative number required')
+
+ assert _loader(1) == 1
+ with pytest.raises(TypeError):
+ _loader(1.0)
+ with pytest.raises(ValueError):
+ _loader(-1)
+ with pytest.raises(TypeError):
+ _loader(-1.0)
+
+ assert _loader.check(1)
+ assert not _loader.check(1.0)
+ assert not _loader.check(-1)
+ assert not _loader.check(-1.0)
+
+ # noinspection DuplicatedCode
+ def test_rand(self):
+ _loader = (lambda x: x >= 0, 'non-negative number required') & Loader(int)
+
+ assert _loader(1) == 1
+ with pytest.raises(TypeError):
+ _loader(1.0)
+ with pytest.raises(ValueError):
+ _loader(-1)
+ with pytest.raises(ValueError):
+ _loader(-1.0)
+
+ assert _loader.check(1)
+ assert not _loader.check(1.0)
+ assert not _loader.check(-1)
+ assert not _loader.check(-1.0)
+
+ def test_tuple_2(self):
+ _loader = Loader((lambda x: x > 0, 'value error'))
+ assert _loader(1) == 1
+ with pytest.raises(ValueError):
+ assert _loader(0)
+
+ _loader = Loader((lambda x: x > 0, RuntimeError('runtime error')))
+ assert _loader(1) == 1
+ with pytest.raises(RuntimeError):
+ assert _loader(0)
+
+ _loader = Loader((lambda x: x > 0, lambda x: SystemError('system error, value is {v}'.format(v=repr(x)))))
+ assert _loader(1) == 1
+ with pytest.raises(SystemError):
+ assert _loader(0)
+
+ def test_tuple_3(self):
+ _loader = Loader((lambda x: x > 0, lambda x: x + 1, 'value error'))
+ assert _loader(1) == 2
+ assert _loader(0.5) == 1.5
+ with pytest.raises(ValueError):
+ assert _loader(0)
+
+ _loader = Loader((lambda x: x > 0, lambda x: -x, RuntimeError('runtime error')))
+ assert _loader(1) == -1
+ assert _loader(0.5) == -0.5
+ with pytest.raises(RuntimeError):
+ assert _loader(0)
+
+ _loader = Loader(
+ (lambda x: x > 0, lambda x: x ** 2, lambda x: SystemError('system error, value is {v}'.format(v=repr(x))))
+ )
+ assert _loader(1) == 1
+ assert _loader(0.5) == 0.25
+ with pytest.raises(SystemError):
+ assert _loader(0)
+
+ def test_tuple_invalid(self):
+ with pytest.raises(ValueError):
+ Loader(())
+ with pytest.raises(ValueError):
+ Loader((lambda x: x > 0, ))
+ with pytest.raises(ValueError):
+ Loader((lambda x: x > 0, lambda x: x + 1, 'value error', None))
+
+ def test_bool(self):
+ _loader = Loader(int) & True
+ assert _loader(1) == 1
+ with pytest.raises(TypeError):
+ _loader(None)
+
+ assert _loader.check(1)
+ assert not _loader.check(None)
+
+ _loader = Loader(int) & False
+ with pytest.raises(ValueError):
+ _loader(1)
+ with pytest.raises(TypeError):
+ _loader(None)
+
+ assert not _loader.check(1)
+ assert not _loader.check(None)
+
+ _loader = Loader(int) | True
+ assert _loader(1) == 1
+ assert _loader(None) is None
+
+ assert _loader.check(1)
+ assert _loader.check(None)
+
+ _loader = Loader(int) | False
+ assert _loader(1) == 1
+ with pytest.raises(ValueError):
+ _loader(None)
+
+ assert _loader.check(1)
+ assert not _loader.check(None)
+
+ def test_none(self):
+ _loader = Loader(int) | None
+ assert _loader(1) == 1
+ assert _loader(None) is None
+ with pytest.raises(TypeError):
+ _loader('string')
+
+ assert _loader.check(1)
+ assert _loader.check(None)
+ assert not _loader.check('string')
+
+ def test_raw_loader(self):
+ _loader = Loader([1, 2, 3])
+ assert _loader(None) == [1, 2, 3]
diff --git a/DI-engine/ding/utils/loader/tests/loader/test_collection.py b/DI-engine/ding/utils/loader/tests/loader/test_collection.py
new file mode 100644
index 0000000000000000000000000000000000000000..91c8e41daa01074b5633980bc1c7c418b254b388
--- /dev/null
+++ b/DI-engine/ding/utils/loader/tests/loader/test_collection.py
@@ -0,0 +1,150 @@
+import pytest
+
+from ding.utils.loader import Loader, collection, contains, length_is, length, tuple_, CollectionError, cofilter, \
+ tpselector, plus, minus, interval, negative, to_type, optional, check_only
+
+
+@pytest.mark.unittest
+class TestConfigLoaderCollection:
+
+ def test_collection(self):
+ _loader = collection(Loader(int) | str)
+ assert _loader([1]) == [1]
+ assert _loader([1, 'string']) == [1, 'string']
+ assert _loader({1, 'string'}) == {1, 'string'}
+ assert _loader((1, 'string')) == (1, 'string')
+ with pytest.raises(TypeError):
+ _loader(1)
+ with pytest.raises(TypeError):
+ _loader(None)
+ with pytest.raises(CollectionError) as ei:
+ _loader([None, 1, 'string', 290384.23])
+
+ err = ei.value
+ assert len(err.errors) == 2
+ assert [index for index, _ in err.errors] == [0, 3]
+ assert [type(item) for _, item in err.errors] == [TypeError, TypeError]
+
+ def test_collection_map(self):
+ _loader = collection(
+ ((Loader(int) | float) >> plus(1) >> negative()) | (str >> (to_type(int) | to_type(float)))
+ )
+ assert _loader([1, 2, -3.0, '1', '2.0']) == [-2, -3, 2.0, 1, 2.0]
+ assert [type(item) for item in _loader([1, 2, -3.0, '1', '2.0'])] == [int, int, float, int, float]
+
+ def test_tuple(self):
+ _loader = tuple_(int, optional(float), plus(1) >> interval(2, 3), minus(1) >> interval(-4, -3))
+ assert _loader((1, 2.3, 1.2, -2.5)) == (1, 2.3, 2.2, -3.5)
+ assert _loader((10, None, 2, -3)) == (10, None, 3, -4)
+ with pytest.raises(TypeError):
+ _loader((10.1, 9238.2, 1.2, -2.5))
+ with pytest.raises(ValueError):
+ _loader((10, 9238.2, 4.2, -2.5))
+
+ # noinspection DuplicatedCode
+ def test_length_min_length(self):
+ _loader = length(min_length=2)
+ assert _loader('ab') == 'ab'
+ assert _loader('abcdefg') == 'abcdefg'
+ assert _loader([1, 2]) == [1, 2]
+ assert _loader([1, 2, 3, 4, 5, 6, 7]) == [1, 2, 3, 4, 5, 6, 7]
+ with pytest.raises(ValueError):
+ _loader('a')
+ with pytest.raises(ValueError):
+ _loader([1])
+ assert _loader('abcdefghij') == 'abcdefghij'
+ assert _loader([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]
+ with pytest.raises(TypeError):
+ _loader(1)
+
+ # noinspection DuplicatedCode
+ def test_length_max_length(self):
+ _loader = length(max_length=7)
+ assert _loader('ab') == 'ab'
+ assert _loader('abcdefg') == 'abcdefg'
+ assert _loader([1, 2]) == [1, 2]
+ assert _loader([1, 2, 3, 4, 5, 6, 7]) == [1, 2, 3, 4, 5, 6, 7]
+ assert _loader('a') == 'a'
+ assert _loader([1]) == [1]
+ with pytest.raises(ValueError):
+ _loader('abcdefghij')
+ with pytest.raises(ValueError):
+ _loader([1, 2, 3, 4, 5, 6, 7, 8, 9, 0])
+ with pytest.raises(TypeError):
+ _loader(1)
+
+ # noinspection DuplicatedCode
+ def test_length_both_length(self):
+ _loader = length(min_length=2, max_length=7)
+ assert _loader('ab') == 'ab'
+ assert _loader('abcdefg') == 'abcdefg'
+ assert _loader([1, 2]) == [1, 2]
+ assert _loader([1, 2, 3, 4, 5, 6, 7]) == [1, 2, 3, 4, 5, 6, 7]
+ with pytest.raises(ValueError):
+ _loader('a')
+ with pytest.raises(ValueError):
+ _loader([1])
+ with pytest.raises(ValueError):
+ _loader('abcdefghij')
+ with pytest.raises(ValueError):
+ _loader([1, 2, 3, 4, 5, 6, 7, 8, 9, 0])
+ with pytest.raises(TypeError):
+ _loader(1)
+
+ def test_length_is(self):
+ _loader = length_is(10)
+ assert _loader('abcdefghij') == 'abcdefghij'
+ assert _loader([1, 2, 3, 4, 5, 6, 7, 8, 9, 0]) == [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]
+ with pytest.raises(ValueError):
+ _loader('abcdefg')
+ with pytest.raises(ValueError):
+ _loader('abcdefghijk')
+ with pytest.raises(ValueError):
+ _loader([1, 2, 3, 4])
+ with pytest.raises(ValueError):
+ _loader([1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1])
+ with pytest.raises(TypeError):
+ _loader(1)
+
+ def test_contains(self):
+ _loader = contains('item') & list & collection(str)
+ assert _loader(['item']) == ['item']
+ assert _loader(['item', 'string_1', 'string_2']) == ['item', 'string_1', 'string_2']
+ with pytest.raises(TypeError):
+ _loader(('item', ))
+ with pytest.raises(TypeError):
+ _loader(('item', 'string_1', 'string_2'))
+ with pytest.raises(CollectionError) as ei:
+ _loader(['item', 1, [1, 2]])
+ err = ei.value
+ assert len(err.errors) == 2
+ assert [index for index, _ in err.errors] == [1, 2]
+ assert [type(item) for _, item in err.errors] == [TypeError, TypeError]
+
+ with pytest.raises(ValueError):
+ _loader(['itemx'])
+ with pytest.raises(TypeError):
+ _loader(1)
+
+ def test_cofilter(self):
+ _loader = cofilter(lambda x: x > 0)
+
+ assert _loader([1, 2, -1, 3, -2]) == [1, 2, 3]
+ assert _loader((1, 2, -1, 3, -2)) == (1, 2, 3)
+ assert _loader({1, 2, -1, 3, -2}) == {1, 2, 3}
+
+ def test_cofilter_complex_case_1(self):
+ _loader = list & check_only(
+ (cofilter(lambda x: x == 1) >> length_is(5)) & (cofilter(lambda x: x == 0) >> length_is(5))
+ )
+
+ assert _loader([1, 0, 0, 1, 0, 1, 0, 0, 1, 1]) == [1, 0, 0, 1, 0, 1, 0, 0, 1, 1]
+ assert _loader([1, 2, -3, 0, 0, 1, 0, 1, 0, 0, 1, 1]) == [1, 2, -3, 0, 0, 1, 0, 1, 0, 0, 1, 1]
+ with pytest.raises(ValueError):
+ _loader([1, 0, 0, 1, 0, 1, 0, 0, 1, -1])
+
+ def test_tpselector(self):
+ _loader = tpselector(0, 2)
+
+ assert _loader((1, 2, 3)) == (1, 3)
+ assert _loader((int, None, {}, 4)) == (int, {})
diff --git a/DI-engine/ding/utils/loader/tests/loader/test_dict.py b/DI-engine/ding/utils/loader/tests/loader/test_dict.py
new file mode 100644
index 0000000000000000000000000000000000000000..15fa27a9dd09236075025a66f4011d985554b851
--- /dev/null
+++ b/DI-engine/ding/utils/loader/tests/loader/test_dict.py
@@ -0,0 +1,30 @@
+import pytest
+
+from ding.utils.loader import dict_, DictError, item, norm, msum, keep
+
+
+@pytest.mark.unittest
+class TestConfigLoaderDict:
+
+ def test_dict(self):
+ _loader = dict_(b=item('a'), a=item('b'))
+ assert _loader({'a': 1, 'b': 2}) == {'a': 2, 'b': 1}
+ assert _loader({'a': 4, 'b': [1, 2]}) == {'a': [1, 2], 'b': 4}
+
+ with pytest.raises(DictError) as ei:
+ _loader({'a': 1, 'bb': 2})
+ err = ei.value
+ assert set(err.errors.keys()) == {'a'}
+ assert isinstance(err.errors['a'], KeyError)
+
+ def test_dict_complex_case_1(self):
+ _loader = dict_(
+ real=msum(item('a'), item('b')),
+ result=item('sum') | item('result'),
+ ) >> dict_(
+ real=item('real') >> keep(),
+ result=item('result') >> keep(),
+ correct=norm(item('real')) == norm(item('result')),
+ )
+ assert _loader({'a': 1, 'b': 2, 'result': 3}) == {'real': 3, 'result': 3, 'correct': True}
+ assert _loader({'a': 2, 'b': 2, 'sum': 3}) == {'real': 4, 'result': 3, 'correct': False}
diff --git a/DI-engine/ding/utils/loader/tests/loader/test_mapping.py b/DI-engine/ding/utils/loader/tests/loader/test_mapping.py
new file mode 100644
index 0000000000000000000000000000000000000000..7501c61194b909e66a9c5a145f037a9df9441f87
--- /dev/null
+++ b/DI-engine/ding/utils/loader/tests/loader/test_mapping.py
@@ -0,0 +1,53 @@
+import pytest
+
+from ding.utils.loader import mapping, MappingError, mpfilter, mpkeys, mpvalues, mpitems, item, item_or, is_type, \
+ optional
+
+
+@pytest.mark.unittest
+class TestConfigLoaderMapping:
+
+ def test_mapping(self):
+ _loader = mapping(str, optional(is_type(int) | float))
+ assert _loader({'sdfjk': 1}) == {'sdfjk': 1}
+ assert _loader({'a': 1, 'b': 2.4, 'c': None}) == {'a': 1, 'b': 2.4, 'c': None}
+ with pytest.raises(MappingError) as ei:
+ _loader({'a': 1, 345: 'sdjfhk', 'b': [], None: 389450})
+ err = ei.value
+ assert len(err.key_errors()) == 2
+ assert len(err.value_errors()) == 2
+ assert len(err.errors()) == 4
+ assert {key for key, _ in err.key_errors()} == {345, None}
+ assert {key for key, _ in err.value_errors()} == {345, 'b'}
+
+ with pytest.raises(TypeError):
+ _loader(1)
+ with pytest.raises(TypeError):
+ _loader([])
+
+ def test_mpfilter(self):
+ _loader = mpfilter(lambda k, v: k in {'a', 'b', 'sum'})
+ assert _loader({'a': 1, 'b': 2, 'sum': 3, 'sdk': 4}) == {'a': 1, 'b': 2, 'sum': 3}
+
+ def test_mpkeys(self):
+ _loader = mpkeys()
+ assert _loader({'a': 1, 'b': 2, 'sum': 3, 'sdk': 4}) == {'a', 'b', 'sum', 'sdk'}
+
+ def test_mpvalues(self):
+ _loader = mpvalues()
+ assert _loader({'a': 1, 'b': 2, 'sum': 3, 'sdk': 4}) == {1, 2, 3, 4}
+
+ def test_mpitems(self):
+ _loader = mpitems()
+ assert _loader({'a': 1, 'b': 2, 'sum': 3, 'sdk': 4}) == {('a', 1), ('b', 2), ('sum', 3), ('sdk', 4)}
+
+ def test_item(self):
+ _loader = item('a') | item('b')
+ assert _loader({'a': 1}) == 1
+ assert _loader({'b': 2}) == 2
+ assert _loader({'a': 3, 'b': -2}) == 3
+
+ def test_item_or(self):
+ _loader = item_or('a', 0)
+ assert _loader({'a': 1}) == 1
+ assert _loader({'b': 2}) == 0
diff --git a/DI-engine/ding/utils/loader/tests/loader/test_norm.py b/DI-engine/ding/utils/loader/tests/loader/test_norm.py
new file mode 100644
index 0000000000000000000000000000000000000000..c88260f211bdfeed942fcbf4e15234608933ec99
--- /dev/null
+++ b/DI-engine/ding/utils/loader/tests/loader/test_norm.py
@@ -0,0 +1,365 @@
+from ditk import logging
+
+import pytest
+
+from ding.utils.loader import Loader, interval, item, norm, lin, lis, lisnot, lsum, lcmp, normfunc
+
+
+@pytest.mark.unittest
+class TestConfigLoaderNorm:
+
+ def test_add(self):
+ _norm = norm(item('a')) + 2
+ assert _norm({'a': 2}) == 4
+
+ _norm = 3 + norm(item('a'))
+ assert _norm({'a': 2}) == 5
+
+ _norm = norm(item('a')) + norm(item('b'))
+ assert _norm({'a': 2, 'b': 4}) == 6
+
+ def test_sub(self):
+ _norm = norm(item('a')) - 2
+ assert _norm({'a': 2}) == 0
+
+ _norm = 3 - norm(item('a'))
+ assert _norm({'a': 2}) == 1
+
+ _norm = norm(item('a')) - norm(item('b'))
+ assert _norm({'a': 2, 'b': 4}) == -2
+
+ def test_mul(self):
+ _norm = norm(item('a')) * 2
+ assert _norm({'a': 2}) == 4
+
+ _norm = 3 * norm(item('a'))
+ assert _norm({'a': 2}) == 6
+
+ _norm = norm(item('a')) * norm(item('b'))
+ assert _norm({'a': 2, 'b': 4}) == 8
+
+ def test_matmul(self):
+ # TODO: complete this part
+ logging.warning('Testing of matmul for norm not implemented.')
+
+ def test_truediv(self):
+ _norm = norm(item('a')) / 2
+ assert _norm({'a': 3}) == 1.5
+
+ _norm = 3 / norm(item('a'))
+ assert _norm({'a': 2}) == 1.5
+
+ _norm = norm(item('a')) / norm(item('b'))
+ assert _norm({'a': 2.1, 'b': 4.2}) == 0.5
+
+ def test_floordiv(self):
+ _norm = norm(item('a')) // 2
+ assert _norm({'a': 3}) == 1
+
+ _norm = 3 // norm(item('a'))
+ assert _norm({'a': 2}) == 1
+
+ _norm = norm(item('a')) // norm(item('b'))
+ assert _norm({'a': 10.5, 'b': 4.2}) == 2
+
+ def test_mod(self):
+ _norm = norm(item('a')) % 3
+ assert _norm({'a': 2}) == 2
+ assert _norm({'a': 4}) == 1
+
+ _norm = 4 % norm(item('a'))
+ assert _norm({'a': 2}) == 0
+ assert _norm({'a': 3}) == 1
+
+ _norm = norm(item('a')) % norm(item('b'))
+ assert _norm({'a': 3, 'b': 2}) == 1
+ assert _norm({'a': 5, 'b': 3}) == 2
+
+ def test_pow(self):
+ _norm = norm(item('a')) ** 3
+ assert _norm({'a': 2}) == 8
+ assert _norm({'a': 4}) == 64
+
+ _norm = 4 ** norm(item('a'))
+ assert _norm({'a': 2}) == 16
+ assert _norm({'a': 3}) == 64
+
+ _norm = norm(item('a')) ** norm(item('b'))
+ assert _norm({'a': 3, 'b': 2}) == 9
+ assert _norm({'a': 5, 'b': 3}) == 125
+
+ def test_lshift(self):
+ _norm = norm(item('a')) << 3
+ assert _norm({'a': 2}) == 16
+ assert _norm({'a': 4}) == 32
+
+ _norm = 4 << norm(item('a'))
+ assert _norm({'a': 2}) == 16
+ assert _norm({'a': 3}) == 32
+
+ _norm = norm(item('a')) << norm(item('b'))
+ assert _norm({'a': 3, 'b': 2}) == 12
+ assert _norm({'a': 5, 'b': 3}) == 40
+
+ def test_rshift(self):
+ _norm = norm(item('a')) >> 3
+ assert _norm({'a': 283}) == 35
+ assert _norm({'a': 47}) == 5
+
+ _norm = 47 >> norm(item('a'))
+ assert _norm({'a': 2}) == 11
+ assert _norm({'a': 3}) == 5
+
+ _norm = norm(item('a')) >> norm(item('b'))
+ assert _norm({'a': 37, 'b': 2}) == 9
+ assert _norm({'a': 529, 'b': 5}) == 16
+
+ def test_and(self):
+ _norm = norm(item('a')) & 9
+ assert _norm({'a': 15}) == 9
+ assert _norm({'a': 1}) == 1
+
+ _norm = 11 & norm(item('a'))
+ assert _norm({'a': 15}) == 11
+ assert _norm({'a': 7}) == 3
+
+ _norm = norm(item('a')) & norm(item('b'))
+ assert _norm({'a': 15, 'b': 11}) == 11
+ assert _norm({'a': 9, 'b': 1}) == 1
+
+ def test_or(self):
+ _norm = norm(item('a')) | 9
+ assert _norm({'a': 15}) == 15
+ assert _norm({'a': 83}) == 91
+
+ _norm = 11 | norm(item('a'))
+ assert _norm({'a': 15}) == 15
+ assert _norm({'a': 17}) == 27
+
+ _norm = norm(item('a')) | norm(item('b'))
+ assert _norm({'a': 5, 'b': 11}) == 15
+ assert _norm({'a': 9, 'b': 3}) == 11
+
+ def test_xor(self):
+ _norm = norm(item('a')) ^ 9
+ assert _norm({'a': 15}) == 6
+ assert _norm({'a': 83}) == 90
+
+ _norm = 11 ^ norm(item('a'))
+ assert _norm({'a': 15}) == 4
+ assert _norm({'a': 17}) == 26
+
+ _norm = norm(item('a')) ^ norm(item('b'))
+ assert _norm({'a': 5, 'b': 11}) == 14
+ assert _norm({'a': 9, 'b': 3}) == 10
+
+ def test_invert(self):
+ _norm = ~norm(item('a'))
+ assert _norm({'a': 15}) == -16
+ assert _norm({'a': -2348}) == 2347
+
+ def test_pos(self):
+ _norm = +norm(item('a'))
+ assert _norm({'a': 15}) == 15
+ assert _norm({'a': -2348}) == -2348
+
+ def test_neg(self):
+ _norm = -norm(item('a'))
+ assert _norm({'a': 15}) == -15
+ assert _norm({'a': -2348}) == 2348
+
+ def test_eq(self):
+ _norm = norm(item('a')) == 2
+ assert _norm({'a': 2})
+ assert not _norm({'a': 3})
+
+ _norm = 2 == norm(item('a'))
+ assert _norm({'a': 2})
+ assert not _norm({'a': 3})
+
+ _norm = norm(item('a')) == norm(item('b'))
+ assert _norm({'a': 2, 'b': 2})
+ assert not _norm({'a': 2, 'b': 3})
+
+ def test_ne(self):
+ _norm = norm(item('a')) != 2
+ assert not _norm({'a': 2})
+ assert _norm({'a': 3})
+
+ _norm = 2 != norm(item('a'))
+ assert not _norm({'a': 2})
+ assert _norm({'a': 3})
+
+ _norm = norm(item('a')) != norm(item('b'))
+ assert not _norm({'a': 2, 'b': 2})
+ assert _norm({'a': 2, 'b': 3})
+
+ def test_lt(self):
+ _norm = norm(item('a')) < 2
+ assert _norm({'a': 1})
+ assert not _norm({'a': 2})
+ assert not _norm({'a': 3})
+
+ _norm = 2 < norm(item('a'))
+ assert not _norm({'a': 1})
+ assert not _norm({'a': 2})
+ assert _norm({'a': 3})
+
+ _norm = norm(item('a')) < norm(item('b'))
+ assert _norm({'a': 1, 'b': 2})
+ assert not _norm({'a': 2, 'b': 2})
+ assert not _norm({'a': 3, 'b': 2})
+
+ def test_le(self):
+ _norm = norm(item('a')) <= 2
+ assert _norm({'a': 1})
+ assert _norm({'a': 2})
+ assert not _norm({'a': 3})
+
+ _norm = 2 <= norm(item('a'))
+ assert not _norm({'a': 1})
+ assert _norm({'a': 2})
+ assert _norm({'a': 3})
+
+ _norm = norm(item('a')) <= norm(item('b'))
+ assert _norm({'a': 1, 'b': 2})
+ assert _norm({'a': 2, 'b': 2})
+ assert not _norm({'a': 3, 'b': 2})
+
+ def test_gt(self):
+ _norm = norm(item('a')) > 2
+ assert not _norm({'a': 1})
+ assert not _norm({'a': 2})
+ assert _norm({'a': 3})
+
+ _norm = 2 > norm(item('a'))
+ assert _norm({'a': 1})
+ assert not _norm({'a': 2})
+ assert not _norm({'a': 3})
+
+ _norm = norm(item('a')) > norm(item('b'))
+ assert not _norm({'a': 1, 'b': 2})
+ assert not _norm({'a': 2, 'b': 2})
+ assert _norm({'a': 3, 'b': 2})
+
+ def test_ge(self):
+ _norm = norm(item('a')) >= 2
+ assert not _norm({'a': 1})
+ assert _norm({'a': 2})
+ assert _norm({'a': 3})
+
+ _norm = 2 >= norm(item('a'))
+ assert _norm({'a': 1})
+ assert _norm({'a': 2})
+ assert not _norm({'a': 3})
+
+ _norm = norm(item('a')) >= norm(item('b'))
+ assert not _norm({'a': 1, 'b': 2})
+ assert _norm({'a': 2, 'b': 2})
+ assert _norm({'a': 3, 'b': 2})
+
+ def test_lin(self):
+ _norm = lin(norm(item('a')), 'string')
+ assert _norm({'a': ['string', 1, 2]})
+ assert not _norm({'a': ['strng', 1, 2]})
+
+ _norm = lin([1, 2, 3], norm(item('a')))
+ assert _norm({'a': 1})
+ assert not _norm({'a': 4})
+
+ _norm = lin(norm(item('a')), norm(item('b')))
+ assert _norm({'a': [1, 2], 'b': 1})
+ assert not _norm({'a': [1, 2], 'b': 3})
+
+ def test_lis(self):
+ _norm = lis(norm(item('a')), 'string')
+ assert _norm({'a': 'string'})
+ assert not _norm({'a': ['strng', 1, 2]})
+
+ _norm = lis(None, norm(item('a')))
+ assert _norm({'a': None})
+ assert not _norm({'a': 4})
+
+ _norm = lis(norm(item('a')), norm(item('b')))
+ assert _norm({'a': 1, 'b': 1})
+ assert not _norm({'a': [1, 2], 'b': 3})
+
+ def test_lisnot(self):
+ _norm = lisnot(norm(item('a')), 'string')
+ assert not _norm({'a': 'string'})
+ assert _norm({'a': ['strng', 1, 2]})
+
+ _norm = lisnot(None, norm(item('a')))
+ assert not _norm({'a': None})
+ assert _norm({'a': 4})
+
+ _norm = lisnot(norm(item('a')), norm(item('b')))
+ assert not _norm({'a': 1, 'b': 1})
+ assert _norm({'a': [1, 2], 'b': 3})
+
+ def test_lsum(self):
+ _norm = lsum(1, 2, norm(item('a') | item('b')), norm(item('c')))
+ assert _norm({'a': 1, 'c': 10}) == 14
+ assert _norm({'b': 20, 'c': 100}) == 123
+ assert _norm({'b': 20, 'a': 30, 'c': -1}) == 32
+
+ def test_lcmp(self):
+ _norm = lcmp(2, '<', norm(item('a')), "<=", 5)
+ assert not _norm({'a': 1})
+ assert not _norm({'a': 2})
+ assert _norm({'a': 3})
+ assert _norm({'a': 4})
+ assert _norm({'a': 5})
+ assert not _norm({'a': 6})
+
+ _norm = lcmp(2, '>=', norm(item('b')), '>', -1)
+ assert not _norm({'b': -2})
+ assert not _norm({'b': -1})
+ assert _norm({'b': 0})
+ assert _norm({'b': 1})
+ assert _norm({'b': 2})
+ assert not _norm({'b': 3})
+
+ _norm = lcmp(2, '!=', norm(item('c')), '==', 1)
+ assert _norm({'c': 1})
+ assert not _norm({'c': 2})
+
+    def test_lcmp_invalid(self):
+        with pytest.raises(ValueError):
+            lcmp(2, '<', norm(item('a')), "<=")
+
+    def test_norm_in_loader(self):
+        _loader = Loader(norm(item('a')) + norm(item('b'))) >> interval(1, 3)
+        assert _loader({'a': 2, 'b': -1}) == 1
+        assert _loader({'a': 1, 'b': 1}) == 2
+        with pytest.raises(ValueError):
+            _loader({'a': 0, 'b': 0})
+        with pytest.raises(ValueError):
+            _loader({'a': 0, 'b': 10})
+        with pytest.raises(KeyError):
+            _loader({'a': 0, 'bb': 2})
diff --git a/DI-engine/ding/utils/loader/tests/loader/test_number.py b/DI-engine/ding/utils/loader/tests/loader/test_number.py
new file mode 100644
index 0000000000000000000000000000000000000000..c57d9b194f6c6aec99d18982e1fb199b37fbf37e
--- /dev/null
+++ b/DI-engine/ding/utils/loader/tests/loader/test_number.py
@@ -0,0 +1,669 @@
+import math
+
+import pytest
+
+from ding.utils.loader import item, item_or, numeric, interval, negative, plus, minus, minus_with, multi, divide, \
+ divide_with, power, power_with, positive, msum, mmulti, mcmp, is_positive, is_negative, non_positive, \
+ non_negative, keep
+
+
+@pytest.mark.unittest
+class TestConfigLoaderNumber:
+ # noinspection DuplicatedCode
+ def test_numeric_plain(self):
+ _loader = numeric()
+
+ assert _loader(1) == 1
+ assert _loader(1.0) == 1.0
+ with pytest.raises(TypeError):
+ _loader('1')
+ with pytest.raises(TypeError):
+ _loader('-1.0')
+ assert _loader(math.inf) == math.inf
+ assert _loader(-float('inf')) == -math.inf
+ with pytest.raises(TypeError):
+ _loader('inf')
+ with pytest.raises(TypeError):
+ _loader('-inf')
+ with pytest.raises(ValueError):
+ _loader(math.nan)
+ with pytest.raises(TypeError):
+ _loader('nan')
+ with pytest.raises(TypeError):
+ _loader(None)
+ with pytest.raises(TypeError):
+ _loader('styring')
+ with pytest.raises(TypeError):
+ _loader('-abcdef12345')
+ with pytest.raises(TypeError):
+ _loader('i n f')
+
+ # noinspection DuplicatedCode
+ def test_numeric_int_ban(self):
+ _loader = numeric(int_ok=False)
+ with pytest.raises(TypeError):
+ _loader(1)
+ assert _loader(1.0) == 1.0
+ with pytest.raises(TypeError):
+ _loader('1')
+ with pytest.raises(TypeError):
+ _loader('-1.0')
+ assert _loader(math.inf) == math.inf
+ assert _loader(-float('inf')) == -math.inf
+ with pytest.raises(TypeError):
+ _loader('inf')
+ with pytest.raises(TypeError):
+ _loader('-inf')
+ with pytest.raises(ValueError):
+ _loader(math.nan)
+ with pytest.raises(TypeError):
+ _loader('nan')
+ with pytest.raises(TypeError):
+ _loader(None)
+ with pytest.raises(TypeError):
+ _loader('styring')
+ with pytest.raises(TypeError):
+ _loader('-abcdef12345')
+ with pytest.raises(TypeError):
+ _loader('i n f')
+
+ # noinspection DuplicatedCode
+ def test_numeric_float_ban(self):
+ _loader = numeric(float_ok=False)
+ assert _loader(1) == 1
+ with pytest.raises(TypeError):
+ _loader(1.0)
+ with pytest.raises(TypeError):
+ _loader('1')
+ with pytest.raises(TypeError):
+ _loader('-1.0')
+ with pytest.raises(TypeError):
+ _loader(math.inf)
+ with pytest.raises(TypeError):
+ _loader(-float('inf'))
+ with pytest.raises(TypeError):
+ _loader('inf')
+ with pytest.raises(TypeError):
+ _loader('-inf')
+ with pytest.raises(ValueError):
+ _loader(math.nan)
+ with pytest.raises(TypeError):
+ _loader('nan')
+ with pytest.raises(TypeError):
+ _loader(None)
+ with pytest.raises(TypeError):
+ _loader('styring')
+ with pytest.raises(TypeError):
+ _loader('-abcdef12345')
+ with pytest.raises(TypeError):
+ _loader('i n f')
+
+ def test_numeric_double_ban(self):
+ with pytest.raises(ValueError):
+ numeric(int_ok=False, float_ok=False)
+
+ # noinspection DuplicatedCode
+ def test_numeric_inf_ban(self):
+ _loader = numeric(inf_ok=False)
+ assert _loader(1) == 1
+ assert _loader(1.0) == 1.0
+ with pytest.raises(TypeError):
+ _loader('1')
+ with pytest.raises(TypeError):
+ _loader('-1.0')
+ with pytest.raises(ValueError):
+ _loader(math.inf)
+ with pytest.raises(ValueError):
+ _loader(-float('inf'))
+ with pytest.raises(TypeError):
+ _loader('inf')
+ with pytest.raises(TypeError):
+ _loader('-inf')
+ with pytest.raises(ValueError):
+ _loader(math.nan)
+ with pytest.raises(TypeError):
+ _loader('nan')
+ with pytest.raises(TypeError):
+ _loader(None)
+ with pytest.raises(TypeError):
+ _loader('styring')
+ with pytest.raises(TypeError):
+ _loader('-abcdef12345')
+ with pytest.raises(TypeError):
+ _loader('i n f')
+
+ def test_interval_common(self):
+ _loader = interval(1, 3.5)
+ with pytest.raises(ValueError):
+ _loader(0.5)
+ assert _loader(1.0) == 1.0
+ assert _loader(1.5) == 1.5
+ assert _loader(3.5) == 3.5
+ with pytest.raises(ValueError):
+ _loader(4.0)
+
+ def test_interval_all(self):
+ _loader = interval()
+ assert _loader(0.5) == 0.5
+ assert _loader(1.0) == 1.0
+ assert _loader(1.5) == 1.5
+ assert _loader(3.5) == 3.5
+ assert _loader(4.0) == 4.0
+
+ # noinspection DuplicatedCode
+ def test_interval_left_open(self):
+ _loader = interval(1.0, left_ok=False)
+ with pytest.raises(ValueError):
+ _loader(0.5)
+ with pytest.raises(ValueError):
+ _loader(0.9)
+ with pytest.raises(ValueError):
+ _loader(0.999)
+ with pytest.raises(ValueError):
+ _loader(1.0)
+ assert _loader(1.001) == 1.001
+ assert _loader(1.1) == 1.1
+ assert _loader(1.5) == 1.5
+ assert _loader(3.5) == 3.5
+ assert _loader(4.0) == 4.0
+
+ # noinspection DuplicatedCode
+ def test_interval_left_open_eps(self):
+ _loader = interval(1.0, left_ok=False, eps=0.01)
+ with pytest.raises(ValueError):
+ _loader(0.5)
+ with pytest.raises(ValueError):
+ _loader(0.9)
+ with pytest.raises(ValueError):
+ _loader(0.999)
+ with pytest.raises(ValueError):
+ _loader(1.0)
+ with pytest.raises(ValueError):
+ _loader(1.001)
+ assert _loader(1.1) == 1.1
+ assert _loader(1.5) == 1.5
+ assert _loader(3.5) == 3.5
+ assert _loader(4.0) == 4.0
+
+ # noinspection DuplicatedCode
+ def test_interval_left_close(self):
+ _loader = interval(1.0)
+ with pytest.raises(ValueError):
+ _loader(0.5)
+ with pytest.raises(ValueError):
+ _loader(0.9)
+ with pytest.raises(ValueError):
+ _loader(0.999)
+ assert _loader(1.0) == 1.0
+ assert _loader(1.001) == 1.001
+ assert _loader(1.1) == 1.1
+ assert _loader(1.5) == 1.5
+ assert _loader(3.5) == 3.5
+ assert _loader(4.0) == 4.0
+
+ # noinspection DuplicatedCode
+ def test_interval_left_close_eps(self):
+ _loader = interval(1.0, eps=0.01)
+ with pytest.raises(ValueError):
+ _loader(0.5)
+ with pytest.raises(ValueError):
+ _loader(0.9)
+ assert _loader(0.999) == 0.999
+ assert _loader(1.0) == 1.0
+ assert _loader(1.001) == 1.001
+ assert _loader(1.1) == 1.1
+ assert _loader(1.5) == 1.5
+ assert _loader(3.5) == 3.5
+ assert _loader(4.0) == 4.0
+
+ def test_interval_right_open(self):
+ _loader = interval(right=3.5, right_ok=False)
+ assert _loader(0.5) == 0.5
+ assert _loader(1.0) == 1.0
+ assert _loader(1.5) == 1.5
+ assert _loader(3.4) == 3.4
+ assert _loader(3.499) == 3.499
+ with pytest.raises(ValueError):
+ _loader(3.5)
+ with pytest.raises(ValueError):
+ _loader(3.501)
+ with pytest.raises(ValueError):
+ _loader(3.6)
+ with pytest.raises(ValueError):
+ _loader(4.0)
+
+ # noinspection DuplicatedCode
+ def test_interval_right_open_eps(self):
+ _loader = interval(right=3.5, right_ok=False, eps=0.01)
+ assert _loader(0.5) == 0.5
+ assert _loader(1.0) == 1.0
+ assert _loader(1.5) == 1.5
+ assert _loader(3.4) == 3.4
+ with pytest.raises(ValueError):
+ _loader(3.499)
+ with pytest.raises(ValueError):
+ _loader(3.5)
+ with pytest.raises(ValueError):
+ _loader(3.501)
+ with pytest.raises(ValueError):
+ _loader(3.6)
+ with pytest.raises(ValueError):
+ _loader(4.0)
+
+ def test_interval_right_close(self):
+ _loader = interval(right=3.5)
+ assert _loader(0.5) == 0.5
+ assert _loader(1.0) == 1.0
+ assert _loader(1.5) == 1.5
+ assert _loader(3.4) == 3.4
+ assert _loader(3.499) == 3.499
+ assert _loader(3.5) == 3.5
+ with pytest.raises(ValueError):
+ _loader(3.501)
+ with pytest.raises(ValueError):
+ _loader(3.6)
+ with pytest.raises(ValueError):
+ _loader(4.0)
+
+ def test_interval_right_close_eps(self):
+ _loader = interval(right=3.5, eps=0.01)
+ assert _loader(0.5) == 0.5
+ assert _loader(1.0) == 1.0
+ assert _loader(1.5) == 1.5
+ assert _loader(3.4) == 3.4
+ assert _loader(3.499) == 3.499
+ assert _loader(3.5) == 3.5
+ assert _loader(3.501) == 3.501
+ with pytest.raises(ValueError):
+ _loader(3.6)
+ with pytest.raises(ValueError):
+ _loader(4.0)
+
+ # noinspection DuplicatedCode
+ def test_interval_both_open_open(self):
+ _loader = interval(1.0, 3.5, left_ok=False, right_ok=False)
+ with pytest.raises(ValueError):
+ _loader(0.5)
+ with pytest.raises(ValueError):
+ _loader(0.9)
+ with pytest.raises(ValueError):
+ _loader(0.999)
+ with pytest.raises(ValueError):
+ _loader(1.0)
+ assert _loader(1.001) == 1.001
+ assert _loader(1.1) == 1.1
+ assert _loader(1.5) == 1.5
+ assert _loader(3.4) == 3.4
+ assert _loader(3.499) == 3.499
+ with pytest.raises(ValueError):
+ _loader(3.5)
+ with pytest.raises(ValueError):
+ _loader(3.501)
+ with pytest.raises(ValueError):
+ _loader(3.6)
+ with pytest.raises(ValueError):
+ _loader(4.0)
+
+ # noinspection DuplicatedCode
+ def test_interval_both_open_open_eps(self):
+ _loader = interval(1.0, 3.5, left_ok=False, right_ok=False, eps=0.01)
+ with pytest.raises(ValueError):
+ _loader(0.5)
+ with pytest.raises(ValueError):
+ _loader(0.9)
+ with pytest.raises(ValueError):
+ _loader(0.999)
+ with pytest.raises(ValueError):
+ _loader(1.0)
+ with pytest.raises(ValueError):
+ _loader(1.001)
+ assert _loader(1.1) == 1.1
+ assert _loader(1.5) == 1.5
+ assert _loader(3.4) == 3.4
+ with pytest.raises(ValueError):
+ _loader(3.499)
+ with pytest.raises(ValueError):
+ _loader(3.5)
+ with pytest.raises(ValueError):
+ _loader(3.501)
+ with pytest.raises(ValueError):
+ _loader(3.6)
+ with pytest.raises(ValueError):
+ _loader(4.0)
+
+ # noinspection DuplicatedCode
+ def test_interval_both_open_close(self):
+ _loader = interval(1.0, 3.5, left_ok=False)
+ with pytest.raises(ValueError):
+ _loader(0.5)
+ with pytest.raises(ValueError):
+ _loader(0.9)
+ with pytest.raises(ValueError):
+ _loader(0.999)
+ with pytest.raises(ValueError):
+ _loader(1.0)
+ assert _loader(1.001) == 1.001
+ assert _loader(1.1) == 1.1
+ assert _loader(1.5) == 1.5
+ assert _loader(3.4) == 3.4
+ assert _loader(3.499) == 3.499
+ assert _loader(3.5) == 3.5
+ with pytest.raises(ValueError):
+ _loader(3.501)
+ with pytest.raises(ValueError):
+ _loader(3.6)
+ with pytest.raises(ValueError):
+ _loader(4.0)
+
+ # noinspection DuplicatedCode
+ def test_interval_both_open_close_eps(self):
+ _loader = interval(1.0, 3.5, left_ok=False, eps=0.01)
+ with pytest.raises(ValueError):
+ _loader(0.5)
+ with pytest.raises(ValueError):
+ _loader(0.9)
+ with pytest.raises(ValueError):
+ _loader(0.999)
+ with pytest.raises(ValueError):
+ _loader(1.0)
+ with pytest.raises(ValueError):
+ _loader(1.001)
+ assert _loader(1.1) == 1.1
+ assert _loader(1.5) == 1.5
+ assert _loader(3.4) == 3.4
+ assert _loader(3.499) == 3.499
+ assert _loader(3.5) == 3.5
+ assert _loader(3.501) == 3.501
+ with pytest.raises(ValueError):
+ _loader(3.6)
+ with pytest.raises(ValueError):
+ _loader(4.0)
+
+ def test_interval_both_close_open(self):
+ _loader = interval(1.0, 3.5, right_ok=False)
+ with pytest.raises(ValueError):
+ _loader(0.5)
+ with pytest.raises(ValueError):
+ _loader(0.9)
+ with pytest.raises(ValueError):
+ _loader(0.999)
+ assert _loader(1.0) == 1.0
+ assert _loader(1.001) == 1.001
+ assert _loader(1.1) == 1.1
+ assert _loader(1.5) == 1.5
+ assert _loader(3.4) == 3.4
+ assert _loader(3.499) == 3.499
+ with pytest.raises(ValueError):
+ _loader(3.5)
+ with pytest.raises(ValueError):
+ _loader(3.501)
+ with pytest.raises(ValueError):
+ _loader(3.6)
+ with pytest.raises(ValueError):
+ _loader(4.0)
+
+ # noinspection DuplicatedCode
+ def test_interval_both_close_open_eps(self):
+ _loader = interval(1.0, 3.5, right_ok=False, eps=0.01)
+ with pytest.raises(ValueError):
+ _loader(0.5)
+ with pytest.raises(ValueError):
+ _loader(0.9)
+ assert _loader(0.999) == 0.999
+ assert _loader(1.0) == 1.0
+ assert _loader(1.001) == 1.001
+ assert _loader(1.1) == 1.1
+ assert _loader(1.5) == 1.5
+ assert _loader(3.4) == 3.4
+ with pytest.raises(ValueError):
+ _loader(3.499)
+ with pytest.raises(ValueError):
+ _loader(3.5)
+ with pytest.raises(ValueError):
+ _loader(3.501)
+ with pytest.raises(ValueError):
+ _loader(3.6)
+ with pytest.raises(ValueError):
+ _loader(4.0)
+
+ # noinspection DuplicatedCode
+ def test_interval_both_close_close(self):
+ _loader = interval(1.0, 3.5)
+ with pytest.raises(ValueError):
+ _loader(0.5)
+ with pytest.raises(ValueError):
+ _loader(0.9)
+ with pytest.raises(ValueError):
+ _loader(0.999)
+ assert _loader(1.0) == 1.0
+ assert _loader(1.001) == 1.001
+ assert _loader(1.1) == 1.1
+ assert _loader(1.5) == 1.5
+ assert _loader(3.4) == 3.4
+ assert _loader(3.499) == 3.499
+ assert _loader(3.5) == 3.5
+ with pytest.raises(ValueError):
+ _loader(3.501)
+ with pytest.raises(ValueError):
+ _loader(3.6)
+ with pytest.raises(ValueError):
+ _loader(4.0)
+
+ # noinspection DuplicatedCode
+ def test_interval_both_close_close_eps(self):
+ _loader = interval(1.0, 3.5, eps=0.01)
+ with pytest.raises(ValueError):
+ _loader(0.5)
+ with pytest.raises(ValueError):
+ _loader(0.9)
+ assert _loader(0.999) == 0.999
+ assert _loader(1.0) == 1.0
+ assert _loader(1.001) == 1.001
+ assert _loader(1.1) == 1.1
+ assert _loader(1.5) == 1.5
+ assert _loader(3.4) == 3.4
+ assert _loader(3.499) == 3.499
+ assert _loader(3.5) == 3.5
+ assert _loader(3.501) == 3.501
+ with pytest.raises(ValueError):
+ _loader(3.6)
+ with pytest.raises(ValueError):
+ _loader(4.0)
+
+ def test_interval_invalid(self):
+ with pytest.raises(ValueError):
+ interval(1.0, 0.9)
+
+ def test_interval_complex_1(self):
+ _loader = float & (interval(1, 4, left_ok=False, right_ok=False) | interval(10.2, 13.4, eps=0.01))
+ with pytest.raises(ValueError):
+ _loader(0.9)
+ with pytest.raises(ValueError):
+ _loader(1.0)
+ assert _loader(1.1) == 1.1
+ with pytest.raises(TypeError):
+ _loader(2)
+ assert _loader(2.0) == 2.0
+ assert _loader(3.9) == 3.9
+ with pytest.raises(ValueError):
+ _loader(4.0)
+ with pytest.raises(ValueError):
+ _loader(4.1)
+ with pytest.raises(ValueError):
+ _loader(10.1)
+ assert _loader(10.199) == 10.199
+ assert _loader(10.2) == 10.2
+ with pytest.raises(TypeError):
+ _loader(11)
+ assert _loader(11.0) == 11.0
+ assert _loader(13.4) == 13.4
+ assert _loader(13.401) == 13.401
+ with pytest.raises(ValueError):
+ _loader(13.5)
+ with pytest.raises(TypeError):
+ _loader(None)
+ with pytest.raises(TypeError):
+ _loader('string')
+
+ def test_negative(self):
+ _loader = negative()
+ assert _loader(1) == -1
+ assert _loader(-2) == 2
+
+ def test_positive(self):
+ _loader = positive()
+ assert _loader(1) == 1
+ assert _loader(0) == 0
+ assert _loader(-1) == -1
+
+ def test_plus(self):
+ _loader = plus(1)
+ assert _loader(1) == 2
+ assert _loader(-2) == -1
+
+ _loader = plus(negative())
+ assert _loader(1) == 0
+ assert _loader(-2) == 0
+
+ def test_minus(self):
+ _loader = minus(2)
+ assert _loader(1) == -1
+ assert _loader(-2) == -4
+
+ _loader = minus(negative())
+ assert _loader(1) == 2
+ assert _loader(-2) == -4
+
+ def test_minus_with(self):
+ _loader = minus_with(2)
+ assert _loader(1) == 1
+ assert _loader(-2) == 4
+
+ _loader = minus_with(negative())
+ assert _loader(1) == -2
+ assert _loader(-2) == 4
+
+ def test_multi(self):
+ _loader = multi(2)
+ assert _loader(1) == 2
+ assert _loader(-2) == -4
+
+ _loader = multi(keep())
+ assert _loader(1) == 1
+ assert _loader(-2) == 4
+ assert _loader(-3) == 9
+
+ def test_divide(self):
+ _loader = divide(2)
+ assert _loader(1) == 0.5
+ assert _loader(-2) == -1
+
+ _loader = divide(negative())
+ assert _loader(1) == -1
+ assert _loader(-2) == -1
+
+ def test_divide_with(self):
+ _loader = divide_with(2)
+ assert _loader(1) == 2
+ assert _loader(-2) == -1
+
+ _loader = divide_with(negative())
+ assert _loader(1) == -1
+ assert _loader(-2) == -1
+
+ def test_power(self):
+ _loader = power(2)
+ assert _loader(1) == 1
+ assert _loader(-2) == 4
+
+ _loader = power(keep()) >> power(keep())
+ assert _loader(2) == 256
+ assert _loader(3) == 443426488243037769948249630619149892803
+
+ def test_power_with(self):
+ _loader = power_with(2)
+ assert _loader(1) == 2
+ assert _loader(-2) == 0.25
+
+ _loader = power_with(minus(1)) >> power_with(minus(1))
+ assert _loader(3) == 5764801
+ assert _loader(4) == int(
+ '14134776518227074636666380005943348126619871175004951664972849610340958208'
+ '000000000000000000000000000000000000000000000000000000000000000000000000000000000'
+ )
+
+ def test_msum(self):
+ _loader = msum(item('a'), item('b'), item_or('c', 0))
+ assert _loader({'a': 1, 'b': 3}) == 4
+ assert _loader({'a': -2, 'b': 5, 'c': 20}) == 23
+
+ def test_mmulti(self):
+ _loader = mmulti(item('a'), item('b'), item_or('c', 1))
+ assert _loader({'a': 1, 'b': 3}) == 3
+ assert _loader({'a': -2, 'b': 5, 'c': 3}) == -30
+
+ def test_mcmp(self):
+ _loader = mcmp(1)
+ assert _loader(1) == 1
+
+ _loader = mcmp(1, '<', item('a'), '>=', item('b'))
+ assert _loader({'a': 2, 'b': 1}) == {'a': 2, 'b': 1}
+ assert _loader({'a': 2, 'b': 2}) == {'a': 2, 'b': 2}
+ with pytest.raises(ValueError):
+ _loader({'a': 2, 'b': 3})
+ with pytest.raises(ValueError):
+ _loader({'a': 1, 'b': 0})
+
+ _loader = mcmp(1, '==', keep())
+ assert _loader(1) == 1
+ with pytest.raises(ValueError):
+ _loader(2)
+
+ _loader = mcmp(1, '!=', keep())
+ assert _loader(2) == 2
+ with pytest.raises(ValueError):
+ _loader(1)
+
+ _loader = mcmp(1, '>', item('a'), '<=', item('b'))
+ assert _loader({'a': 0, 'b': 1}) == {'a': 0, 'b': 1}
+ assert _loader({'a': 0, 'b': 0}) == {'a': 0, 'b': 0}
+ with pytest.raises(ValueError):
+ _loader({'a': 0, 'b': -1})
+ with pytest.raises(ValueError):
+ _loader({'a': 1, 'b': 2})
+
+ def test_mcmp_invalid(self):
+ with pytest.raises(ValueError):
+ mcmp(1, '>', item('a'), '<=', item('b'), '==')
+ with pytest.raises(KeyError):
+ mcmp(1, '>', item('a'), '*=', item('b'))
+
+ def test_is_positive(self):
+ _loader = is_positive()
+ assert _loader(1) == 1
+ with pytest.raises(ValueError):
+ _loader(0)
+ with pytest.raises(ValueError):
+ _loader(-1)
+
+ def test_is_negative(self):
+ _loader = is_negative()
+ with pytest.raises(ValueError):
+ _loader(1)
+ with pytest.raises(ValueError):
+ _loader(0)
+ assert _loader(-1) == -1
+
+ def test_non_positive(self):
+ _loader = non_positive()
+ with pytest.raises(ValueError):
+ _loader(1)
+ assert _loader(0) == 0
+ assert _loader(-1) == -1
+
+ def test_non_negative(self):
+ _loader = non_negative()
+ assert _loader(1) == 1
+ assert _loader(0) == 0
+ with pytest.raises(ValueError):
+ _loader(-1)
diff --git a/DI-engine/ding/utils/loader/tests/loader/test_string.py b/DI-engine/ding/utils/loader/tests/loader/test_string.py
new file mode 100644
index 0000000000000000000000000000000000000000..cdf046c1d088441970a9715bdd2ee41973c8563b
--- /dev/null
+++ b/DI-engine/ding/utils/loader/tests/loader/test_string.py
@@ -0,0 +1,102 @@
+import re
+
+import pytest
+
+from ding.utils.loader import enum, rematch, regrep, to_type
+
+
+@pytest.mark.unittest
+class TestConfigLoaderString:
+
+ def test_enum_plain(self):
+ _loader = enum('red', 'green', 'blue', 'yellow')
+ assert _loader('red') == 'red'
+ assert _loader('green') == 'green'
+ assert _loader('blue') == 'blue'
+ assert _loader('yellow') == 'yellow'
+ with pytest.raises(ValueError):
+ _loader(int)
+ with pytest.raises(ValueError):
+ _loader('Red')
+ with pytest.raises(ValueError):
+ _loader('YELLOW')
+ with pytest.raises(ValueError):
+ _loader(1)
+ with pytest.raises(ValueError):
+ _loader(None)
+
+ def test_enum_case_insensitive(self):
+ _loader = enum('red', 'green', 'blue', 'yellow', case_sensitive=False)
+ assert _loader('red') == 'red'
+ assert _loader('green') == 'green'
+ assert _loader('blue') == 'blue'
+ assert _loader('yellow') == 'yellow'
+ with pytest.raises(ValueError):
+ _loader(int)
+ assert _loader('Red') == 'red'
+ assert _loader('YELLOW') == 'yellow'
+ with pytest.raises(ValueError):
+ _loader(1)
+ with pytest.raises(ValueError):
+ _loader(None)
+
+ def test_enum_complex_case_1(self):
+ _loader = (lambda x: str(x).strip()) >> enum('red', 'green', 'blue', 'yellow', case_sensitive=False)
+ assert _loader('red') == 'red'
+ assert _loader('green') == 'green'
+ assert _loader('blue') == 'blue'
+ assert _loader('yellow') == 'yellow'
+ assert _loader(' yellow ') == 'yellow'
+ with pytest.raises(ValueError):
+ _loader(int)
+ assert _loader('Red') == 'red'
+ assert _loader('YELLOW') == 'yellow'
+ assert _loader(' YelloW ') == 'yellow'
+ with pytest.raises(ValueError):
+ _loader(1)
+ with pytest.raises(ValueError):
+ _loader(None)
+
+ # noinspection DuplicatedCode
+ def test_rematch_str(self):
+ _loader = to_type(str) >> str.strip >> str.lower >> rematch('[0-9a-z_]+@([0-9a-z]+.)+[0-9a-z]+')
+ assert _loader('hansbug@buaa.edu.cn') == 'hansbug@buaa.edu.cn'
+ assert _loader(' hansbug@BUAA.EDU.CN\t') == 'hansbug@buaa.edu.cn'
+ with pytest.raises(ValueError):
+ _loader(' hansbug.buaa.edu.cn')
+ with pytest.raises(ValueError):
+ _loader(' hansbug@cn')
+ with pytest.raises(ValueError):
+ _loader(' hansbug@buaa.edu..cn')
+
+ # noinspection DuplicatedCode
+ def test_rematch_pattern(self):
+ _loader = to_type(str) >> str.strip >> str.lower >> rematch(re.compile('[0-9a-z_]+@([0-9a-z]+.)+[0-9a-z]+'))
+ assert _loader('hansbug@buaa.edu.cn') == 'hansbug@buaa.edu.cn'
+ assert _loader(' hansbug@BUAA.EDU.CN\t') == 'hansbug@buaa.edu.cn'
+ with pytest.raises(ValueError):
+ _loader(' hansbug.buaa.edu.cn')
+ with pytest.raises(ValueError):
+ _loader(' hansbug@cn')
+ with pytest.raises(ValueError):
+ _loader(' hansbug@buaa.edu..cn')
+
+ def test_rematch_invalid(self):
+ with pytest.raises(TypeError):
+ _loader = rematch(1)
+
+ def test_regrep(self):
+ _loader = to_type(str) >> str.lower >> regrep('[0-9a-z_]+@([0-9a-z]+.)+[0-9a-z]+')
+ assert _loader('hansbug@buaa.edu.cn') == 'hansbug@buaa.edu.cn'
+ assert _loader(' hansbug@BUAA.EDU.CN\t') == 'hansbug@buaa.edu.cn'
+ assert _loader('This is my email hansbug@buaa.edu.cn, thanks~~') == 'hansbug@buaa.edu.cn'
+ with pytest.raises(ValueError):
+ _loader('this is hansbug.buaa.edu.cn')
+
+ def test_regrep_group(self):
+ _loader = to_type(str) >> str.lower >> regrep('[0-9a-z_]+@(([0-9a-z]+.)+[0-9a-z]+)', group=1)
+ assert _loader('hansbug@buaa.edu.cn') == 'buaa.edu.cn'
+ assert _loader(' hansbug@BUAA.EDU.CN\t') == 'buaa.edu.cn'
+ assert _loader('This is my email hansbug@buaa.edu.cn, thanks~~') == 'buaa.edu.cn'
+ with pytest.raises(ValueError):
+ _loader(' @buaa.edu.cn')
diff --git a/DI-engine/ding/utils/loader/tests/loader/test_types.py b/DI-engine/ding/utils/loader/tests/loader/test_types.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ac6ae225cf3226d2568ceeca44cf535548da8bf
--- /dev/null
+++ b/DI-engine/ding/utils/loader/tests/loader/test_types.py
@@ -0,0 +1,110 @@
+import pytest
+from easydict import EasyDict
+
+from ding.utils.loader import interval, negative, is_type, to_type, prop, method, fcall, is_callable, fpartial, keep
+
+
+@pytest.mark.unittest
+class TestConfigLoaderTypes:
+
+ def test_is_type(self):
+ _loader = is_type(float) | is_type(int)
+ assert _loader(1) == 1
+ assert _loader(2.5) == 2.5
+ with pytest.raises(TypeError):
+ _loader(None)
+
+ # noinspection PyTypeChecker
+ def test_is_type_invalid(self):
+ with pytest.raises(TypeError):
+ is_type(lambda x: x + 1)
+
+ def test_to_type_float(self):
+ _loader = keep() >> to_type(float)
+ assert _loader(1) == 1.0
+ assert isinstance(_loader(1), float)
+ assert _loader(2.0) == 2.0
+ assert isinstance(_loader(2.0), float)
+
+ def test_to_type_str(self):
+ _loader = keep() >> to_type(str)
+ assert _loader(1) == '1'
+ assert _loader(2.0) == '2.0'
+ assert _loader(None) == 'None'
+
+ def test_to_type_float_str(self):
+ _loader = keep() >> to_type(float) >> to_type(str)
+ assert _loader(1) == '1.0'
+ assert _loader(2.0) == '2.0'
+ with pytest.raises(TypeError):
+ _loader(None)
+
+ def test_is_callable(self):
+ _loader = is_callable()
+ assert _loader(lambda x: 1)
+ assert _loader(str) == str
+ assert _loader(str.lower) == str.lower
+ with pytest.raises(TypeError):
+ _loader(1)
+
+ def test_prop(self):
+ t1 = EasyDict({'x': 1, 'y': 2, 'z': 'string'})
+ t2 = EasyDict({'x': 'str'})
+ t3 = EasyDict({'z': -1, 'y': 'sss'})
+
+ _loader = prop('x') >> int
+ assert _loader(t1) == 1
+ with pytest.raises(TypeError):
+ _loader(t2)
+ with pytest.raises(AttributeError):
+ _loader(t3)
+
+ _loader = (prop('x') >> str) | (prop('y') >> str) | (prop('z') >> str)
+ assert _loader(t1) == 'string'
+ assert _loader(t2) == 'str'
+ assert _loader(t3) == 'sss'
+
+ def test_method(self):
+ t1 = 'STRING'
+ t2 = 2
+ t3 = EasyDict({'lower': 1})
+
+ _loader = method('lower')
+ assert _loader(t1)() == 'string'
+ with pytest.raises(TypeError):
+ _loader(t2)
+ with pytest.raises(TypeError):
+ _loader(t3)
+
+ def test_fcall(self):
+ _loader = fcall('STRING')
+ assert _loader(lambda x: len(x)) == 6
+ assert _loader(str.lower) == 'string'
+
+ def test_fpartial(self):
+ _loader = fpartial(x=2)
+
+ def _func_1(x, y):
+ return x + y
+
+ def _func_2(x, y):
+ return x * y
+
+ assert _loader(_func_1)(y=6) == 8
+ assert _loader(_func_2)(y=6) == 12
+
+ def test_func_complex_case_1(self):
+ _loader = fpartial(x=1) >> ((fcall(y=1) >> interval(0, None)) | (fcall(y=2) >> interval(None, 0) >> negative()))
+
+ def _func_1(x, y):
+ return x + y
+
+ def _func_2(x, y):
+ return 5 * x - 4 * y
+
+ def _func_3(x, y):
+ return -5 * x - 4 * y
+
+ assert _loader(_func_1) == 2
+ assert _loader(_func_2) == 1
+ assert _loader(_func_3) == 13
diff --git a/DI-engine/ding/utils/loader/tests/loader/test_utils.py b/DI-engine/ding/utils/loader/tests/loader/test_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d832045c9c6e833a07adf928424a946a1aaea402
--- /dev/null
+++ b/DI-engine/ding/utils/loader/tests/loader/test_utils.py
@@ -0,0 +1,48 @@
+import pytest
+
+from ding.utils.loader import Loader, interval, to_type, is_type, keep, optional, check_only, raw, check
+
+
+@pytest.mark.unittest
+class TestConfigLoaderUtils:
+
+ def test_keep(self):
+ _loader = keep()
+ assert _loader(1) == 1
+ assert _loader(2) == 2
+ assert _loader(None) is None
+
+ def test_raw(self):
+ _loader = raw(233)
+ assert _loader(1) == 233
+ assert _loader(2) == 233
+
+ def test_optional(self):
+ _loader = optional(Loader(int) | float)
+ assert _loader(1) == 1
+ assert _loader(2.0) == 2.0
+ assert _loader(None) is None
+ with pytest.raises(TypeError):
+ _loader('string')
+
+    def test_check_only(self):
+        tonumber = to_type(int) | to_type(float)
+        _loader = tonumber >> check_only((lambda x: x + 1) >> interval(1, 2))
+        assert _loader(1) == 1  # check_only keeps the converted value, not the transformed one
+        with pytest.raises(ValueError):
+            _loader(2)
+
+ def test_check(self):
+ _loader = check(is_type(int) | is_type(float))
+ assert _loader(1)
+ assert _loader(1.2)
+ assert not _loader('sjhkj')
+
+ def test_complex_case_1(self):
+ tonumber = to_type(int) | to_type(float)
+ _loader = tonumber >> check_only(
+ ((lambda x: x + 1) >> interval(1, 2)) | ((lambda x: x - 1) >> interval(-2, -1))
+ )
+ assert _loader(1) == 1
+ assert _loader(-1) == -1
+ with pytest.raises(ValueError):
+ _loader(2)
+ with pytest.raises(ValueError):
+ _loader(-2.0)
diff --git a/DI-engine/ding/utils/loader/tests/test_cartpole_dqn_serial_config_loader.py b/DI-engine/ding/utils/loader/tests/test_cartpole_dqn_serial_config_loader.py
new file mode 100644
index 0000000000000000000000000000000000000000..56dda2b1a907990262326f99e07ba790d5613f6b
--- /dev/null
+++ b/DI-engine/ding/utils/loader/tests/test_cartpole_dqn_serial_config_loader.py
@@ -0,0 +1,87 @@
+import math
+
+import pytest
+
+from dizoo.classic_control.cartpole.config import cartpole_dqn_config, cartpole_dqn_create_config
+from ding.utils.loader import dict_, is_type, to_type, collection, interval, is_positive, mcmp, enum, item, raw, \
+ check_only
+from ding.utils import pretty_print
+
+
+@pytest.mark.unittest
+def test_main_config():
+ element_loader = dict_(
+ env=item('env') >> dict_(
+ collector_env_num=item('collector_env_num') >> is_type(int) >> interval(1, 32),
+ evaluator_env_num=item('evaluator_env_num') >> is_type(int) >> interval(1, 32),
+ ),
+ policy=item('policy') >> dict_(
+ type=item('type') | raw('dqn') >> is_type(str),
+ cuda=item('cuda') >> is_type(bool),
+ on_policy=item('on_policy') | raw(False) >> is_type(bool),
+ priority=item('priority') | raw(False) >> is_type(bool),
+ model=item('model') >> dict_(
+ obs_dim=item('obs_shape') >> (is_type(int) | collection(int)),
+ action_dim=item('action_shape') >> (is_type(int) | collection(int)),
+ hidden_size_list=item('encoder_hidden_size_list') >> is_type(list),
+ dueling=item('dueling') >> is_type(bool),
+ ),
+ learn=item('learn') >> dict_(
+ multi_gpu=item('multi_gpu') | raw(False) >> is_type(bool),
+ update_per_collect=item('update_per_collect') | raw(1) >> (is_type(int) & interval(1, 500)),
+ batch_size=item('batch_size') | raw(64) >> (is_type(int) & interval(1, 128)),
+ learning_rate=item('learning_rate') | raw(0.001) >> interval(0.0001, 0.01),
+ target_update_freq=item('target_update_freq') | raw(200) >> (is_type(int) & interval(100, 2000)),
+ discount_factor=item('discount_factor') | raw(0.99) >> interval(0.9, 1.0),
+ nstep=item('nstep') | raw(1) >> (is_type(int) & interval(1, 10)),
+ ignore_done=item('ignore_done') | raw(False) >> is_type(bool),
+ ),
+ collect=item('collect') >> dict_(
+ n_sample=item('n_sample') | raw(20) >> is_type(int) >> interval(8, 128),
+ n_episode=item('n_episode') | raw(10) >> is_type(int) >> interval(2, 10),
+ unroll_len=item('unroll_len') | raw(1) >> is_type(int) >> interval(1, 200),
+ nstep=item('nstep') | raw(1) >> (is_type(int) & interval(1, 10)),
+ ),
+ other=item('other') >> dict_(
+ eps=item('eps') >> dict_(
+ type=item('type') >> enum('linear', 'exp'),
+ start=item('start') >> interval(0.0, 1.0, left_ok=False),
+ end=item('end') >> interval(0.0, 1.0, right_ok=False),
+ decay=item('decay') >> (is_type(int) | (is_type(float) >> to_type(int))) >> is_positive(),
+ ),
+ replay_buffer=item('replay_buffer') >>
+ dict_(replay_buffer_size=item('replay_buffer_size') >> is_type(int) >> interval(1, math.inf), ),
+ ),
+ ),
+ )
+ learn_nstep = item('policy') >> item('learn') >> item('nstep')
+ collect_nstep = item('policy') >> item('collect') >> item('nstep')
+ relation_loader = check_only(
+ dict_(
+ nstep_check=mcmp(learn_nstep, "==", collect_nstep),
+ eps_check=item('policy') >> item('other') >> item('eps') >> mcmp(item('start'), ">=", item('end')),
+ )
+ )
+ cartpole_dqn_main_loader = element_loader >> relation_loader
+
+ output = cartpole_dqn_main_loader(cartpole_dqn_config)
+ pretty_print(output, direct_print=True)
+
+
+@pytest.mark.unittest
+def test_create_config():
+ element_loader = dict_(
+ env=item('env') >> dict_(
+ import_names=item('import_names') >> collection(str),
+ type=item('type') >> is_type(str),
+ ),
+ env_manager=item('env_manager') >> dict_(
+ type=item('type') >> enum('base', 'subprocess', 'async_subprocess'),
+ shared_memory=item('shared_memory') | raw(True) >> is_type(bool),
+ ),
+ policy=item('policy') >> dict_(type=item('type') >> is_type(str), ),
+ )
+ cartpole_dqn_create_loader = element_loader
+
+ output = cartpole_dqn_create_loader(cartpole_dqn_create_config)
+ pretty_print(output, direct_print=True)
diff --git a/DI-engine/ding/utils/loader/types.py b/DI-engine/ding/utils/loader/types.py
new file mode 100644
index 0000000000000000000000000000000000000000..6039395ca60220dd53bd61fb77b1cb117d830499
--- /dev/null
+++ b/DI-engine/ding/utils/loader/types.py
@@ -0,0 +1,95 @@
+from functools import partial
+
+from .base import Loader, ILoaderClass, _reset_exception
+from .utils import check_only
+
+
+def is_type(type_: type) -> ILoaderClass:
+ """
+ Overview:
+        Create a type-checking loader, which requires the input to be an instance of ``type_``.
+ Arguments:
+ - type_ (:obj:`type`): The type.
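+    Examples:
+        >>> # Illustrative usage (mirrors the accompanying unit tests): other types raise TypeError.
+        >>> assert is_type(int)(1) == 1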
+ """
+
+ if isinstance(type_, type):
+ return Loader(type_)
+ else:
+ raise TypeError('Type variable expected but {actual} found.'.format(actual=repr(type(type_).__name__)))
+
+
+def to_type(type_: type) -> ILoaderClass:
+ """
+ Overview:
+        Create a type-conversion loader, which casts the input value to ``type_``.
+ Arguments:
+ - type_ (:obj:`type`): The type.
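+    Examples:
+        >>> # Illustrative usage (mirrors the unit tests): the value is converted rather than checked.
+        >>> assert to_type(str)(1) == '1'
+        >>> assert to_type(float)(1) == 1.0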
+ """
+
+ return Loader(lambda v: type_(v))
+
+
+def is_callable() -> ILoaderClass:
+ """
+ Overview:
+        Create a loader that checks whether the input object is callable.
+ """
+
+ return _reset_exception(
+ check_only(prop('__call__')),
+ lambda v, e: TypeError('callable expected but {func} not found'.format(func=repr('__call__')))
+ )
+
+
+def prop(attr_name: str) -> ILoaderClass:
+ """
+ Overview:
+        Create an attribute loader, which reads attribute ``attr_name`` from the input object.
+ Arguments:
+ - attr_name (:obj:`str`): The attribute name.
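+    Examples:
+        >>> # Illustrative usage: the named attribute is read from the loaded value.
+        >>> assert prop('imag')(1 + 2j) == 2.0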
+ """
+
+ return Loader(
+ (
+ lambda v: hasattr(v, attr_name), lambda v: getattr(v, attr_name),
+ AttributeError('attribute {name} expected but not found'.format(name=repr(attr_name)))
+ )
+ )
+
+
+def method(method_name: str) -> ILoaderClass:
+ """
+ Overview:
+        Create a method loader, which reads method ``method_name`` from the input and checks it is callable.
+ Arguments:
+ - method_name (:obj:`str`): The method name.
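+    Examples:
+        >>> # Illustrative usage (mirrors the unit tests): the bound method is returned and can be called later.
+        >>> assert method('lower')('STRING')() == 'string'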
+ """
+
+ return _reset_exception(
+ prop(method_name) >> is_callable(), lambda v, e:
+        TypeError('type {type} does not support {func}'.format(type=repr(type(v).__name__), func=repr(method_name)))
+ )
+
+
+def fcall(*args, **kwargs) -> ILoaderClass:
+ """
+ Overview:
+        Create a loader that calls the input function with ``args`` and ``kwargs`` and returns its result.
+ Arguments:
+ - args (:obj:`Tuple[Any]`): The args.
+ - kwargs (:obj:`Dict[str, Any]`): The kwargs.
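+    Examples:
+        >>> # Illustrative usage (mirrors the unit tests): the loaded function is called with the stored arguments.
+        >>> assert fcall('STRING')(str.lower) == 'string'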
+ """
+
+ return Loader(lambda v: v(*args, **kwargs))
+
+
+def fpartial(*args, **kwargs) -> ILoaderClass:
+ """
+ Overview:
+        Create a loader that wraps the input function into a ``functools.partial`` with ``args`` and ``kwargs``.
+ Arguments:
+ - args (:obj:`Tuple[Any]`): The args.
+ - kwargs (:obj:`Dict[str, Any]`): The kwargs.
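+    Examples:
+        >>> # Illustrative usage (mirrors the unit tests): the loaded function is partially bound first.
+        >>> assert fpartial(x=2)(lambda x, y: x + y)(y=6) == 8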
+ """
+
+ return Loader(lambda v: partial(v, *args, **kwargs))
diff --git a/DI-engine/ding/utils/loader/utils.py b/DI-engine/ding/utils/loader/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..140bbf033db0d6d97c9c1d7ef12ea2f18df29cbb
--- /dev/null
+++ b/DI-engine/ding/utils/loader/utils.py
@@ -0,0 +1,52 @@
+from .base import Loader, ILoaderClass
+
+
+def keep() -> ILoaderClass:
+ """
+ Overview:
+ Create a keep loader.
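+    Examples:
+        >>> # Illustrative usage (mirrors the unit tests): the input is passed through unchanged.
+        >>> assert keep()(1) == 1
+        >>> assert keep()(None) is None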
+ """
+
+ return Loader(lambda v: v)
+
+
+def raw(value) -> ILoaderClass:
+ """
+ Overview:
+ Create a raw loader.
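+    Examples:
+        >>> # Illustrative usage (mirrors the unit tests): the input is ignored and the stored value returned.
+        >>> assert raw(233)(1) == 233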
+ """
+
+ return Loader(lambda v: value)
+
+
+def optional(loader) -> ILoaderClass:
+ """
+ Overview:
+        Create an optional loader, which accepts ``None`` in addition to the values accepted by ``loader``.
+ Arguments:
+ - loader (:obj:`ILoaderClass`): The loader.
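+    Examples:
+        >>> # Illustrative usage (mirrors the unit tests): ``None`` is accepted as well.
+        >>> assert optional(Loader(int))(1) == 1
+        >>> assert optional(Loader(int))(None) is None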
+ """
+
+ return Loader(loader) | None
+
+
+def check_only(loader) -> ILoaderClass:
+ """
+ Overview:
+        Create a check-only loader, which validates the input with ``loader`` but returns the original value.
+ Arguments:
+ - loader (:obj:`ILoaderClass`): The loader.
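+    Examples:
+        >>> # Illustrative usage: the inner loader only validates, the original value is returned.
+        >>> assert check_only(Loader(lambda x: x + 1))(1) == 1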
+ """
+
+ return Loader(loader) & keep()
+
+
+def check(loader) -> ILoaderClass:
+ """
+ Overview:
+        Create a check loader, which returns a bool indicating whether ``loader`` accepts the input.
+ Arguments:
+ - loader (:obj:`ILoaderClass`): The loader.
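+    Examples:
+        >>> # Illustrative usage (mirrors the unit tests): a bool is returned instead of raising.
+        >>> assert check(Loader(int))(1)
+        >>> assert not check(Loader(int))('abc')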
+ """
+
+ return Loader(lambda x: Loader(loader).check(x))
diff --git a/DI-engine/ding/utils/lock_helper.py b/DI-engine/ding/utils/lock_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..368cac35f8aacb443cd76d6b51cd76915107b30c
--- /dev/null
+++ b/DI-engine/ding/utils/lock_helper.py
@@ -0,0 +1,182 @@
+import os
+import multiprocessing
+import threading
+import platform
+from enum import Enum, unique
+
+from pathlib import Path
+if platform.system().lower() != 'windows':
+ import fcntl
+else:
+ fcntl = None
+
+
+@unique
+class LockContextType(Enum):
+ """
+ Overview:
+ Enum to express the type of the lock.
+ """
+ THREAD_LOCK = 1
+ PROCESS_LOCK = 2
+
+
+_LOCK_TYPE_MAPPING = {
+ LockContextType.THREAD_LOCK: threading.Lock,
+ LockContextType.PROCESS_LOCK: multiprocessing.Lock,
+}
+
+
+class LockContext(object):
+ """
+ Overview:
+ Generate a LockContext in order to make sure the thread safety.
+
+ Interfaces:
+ ``__init__``, ``__enter__``, ``__exit__``.
+
+ Example:
+ >>> with LockContext() as lock:
+ >>> print("Do something here.")
+ """
+
+ def __init__(self, type_: LockContextType = LockContextType.THREAD_LOCK):
+ """
+ Overview:
+ Init the lock according to the given type.
+
+ Arguments:
+ - type_ (:obj:`LockContextType`): The type of lock to be used. Defaults to LockContextType.THREAD_LOCK.
+ """
+ self.lock = _LOCK_TYPE_MAPPING[type_]()
+
+ def acquire(self):
+ """
+ Overview:
+ Acquires the lock.
+ """
+ self.lock.acquire()
+
+ def release(self):
+ """
+ Overview:
+ Releases the lock.
+ """
+ self.lock.release()
+
+ def __enter__(self):
+ """
+ Overview:
+ Enters the context and acquires the lock.
+ """
+ self.lock.acquire()
+
+ def __exit__(self, *args, **kwargs):
+ """
+ Overview:
+ Exits the context and releases the lock.
+ Arguments:
+ - args (:obj:`Tuple`): The arguments passed to the ``__exit__`` function.
+ - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__exit__`` function.
+ """
+ self.lock.release()
+
+
+rw_lock_mapping = {}
+
+
+def get_rw_file_lock(name: str, op: str):
+ """
+ Overview:
+ Get generated file lock with name and operator
+ Arguments:
+ - name (:obj:`str`): Lock's name.
+ - op (:obj:`str`): Assigned operator, i.e. ``read`` or ``write``.
+ Returns:
+ - (:obj:`RWLockFairD`): Generated rwlock
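+    Examples:
+        >>> # Illustrative sketch, assuming the optional ``readerwriterlock`` package is installed
+        >>> # and that the generated lock is used as a context manager (the file name is hypothetical).
+        >>> with get_rw_file_lock('shared_model.ckpt', 'write'):
+        ...     pass  # write the file here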
+ """
+ assert op in ['read', 'write']
+ try:
+ from readerwriterlock import rwlock
+ except ImportError:
+ import sys
+ from ditk import logging
+ logging.warning("Please install readerwriterlock first, such as `pip3 install readerwriterlock`.")
+ sys.exit(1)
+ if name not in rw_lock_mapping:
+ rw_lock_mapping[name] = rwlock.RWLockFairD()
+ lock = rw_lock_mapping[name]
+ if op == 'read':
+ return lock.gen_rlock()
+ elif op == 'write':
+ return lock.gen_wlock()
+
+
+class FcntlContext:
+ """
+ Overview:
+ A context manager that acquires an exclusive lock on a file using fcntl. \
+ This is useful for preventing multiple processes from running the same code.
+
+ Interfaces:
+ ``__init__``, ``__enter__``, ``__exit__``.
+
+ Example:
+ >>> lock_path = "/path/to/lock/file"
+ >>> with FcntlContext(lock_path) as lock:
+ >>> # Perform operations while the lock is held
+
+ """
+
+ def __init__(self, lock_path: str) -> None:
+ """
+ Overview:
+ Initialize the LockHelper object.
+
+ Arguments:
+ - lock_path (:obj:`str`): The path to the lock file.
+ """
+ self.lock_path = lock_path
+ self.f = None
+
+ def __enter__(self) -> None:
+ """
+ Overview:
+ Acquires the lock and opens the lock file in write mode. \
+ If the lock file does not exist, it is created.
+ """
+ assert self.f is None, self.lock_path
+ self.f = open(self.lock_path, 'w')
+ fcntl.flock(self.f.fileno(), fcntl.LOCK_EX)
+
+ def __exit__(self, *args, **kwargs) -> None:
+ """
+ Overview:
+ Closes the file and releases any resources used by the lock_helper object.
+ Arguments:
+ - args (:obj:`Tuple`): The arguments passed to the ``__exit__`` function.
+ - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__exit__`` function.
+ """
+ self.f.close()
+ self.f = None
+
+
+def get_file_lock(name: str, op: str) -> FcntlContext:
+ """
+ Overview:
+        Acquires a file lock for the specified file.
+
+ Arguments:
+ - name (:obj:`str`): The name of the file.
+ - op (:obj:`str`): The operation to perform on the file lock.
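+    Examples:
+        >>> # Illustrative sketch (hypothetical path); a ``<name>.lock`` file is created next to the target.
+        >>> with get_file_lock('./buffer.pkl', 'write'):
+        ...     pass  # access './buffer.pkl' exclusively here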
+ """
+ if fcntl is None:
+ return get_rw_file_lock(name, op)
+ else:
+ lock_name = name + '.lock'
+ if not os.path.isfile(lock_name):
+ try:
+ Path(lock_name).touch()
+            except Exception:
+ pass
+ return FcntlContext(lock_name)
diff --git a/DI-engine/ding/utils/log_helper.py b/DI-engine/ding/utils/log_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a966532acbe21fcc4e8ed431c5e07f12ab52353
--- /dev/null
+++ b/DI-engine/ding/utils/log_helper.py
@@ -0,0 +1,174 @@
+import json
+from ditk import logging
+import os
+from typing import Optional, Tuple, Union, Dict, Any
+
+import ditk.logging
+import numpy as np
+import yaml
+from hbutils.system import touch
+from tabulate import tabulate
+
+from .log_writer_helper import DistributedWriter
+
+
+def build_logger(
+ path: str,
+ name: Optional[str] = None,
+ need_tb: bool = True,
+ need_text: bool = True,
+ text_level: Union[int, str] = logging.INFO
+) -> Tuple[Optional[logging.Logger], Optional['SummaryWriter']]: # noqa
+ """
+ Overview:
+ Build text logger and tensorboard logger.
+ Arguments:
+        - path (:obj:`str`): The directory where the text logger and tensorboard logger save their files
+        - name (:obj:`str`): The logger file name
+        - need_tb (:obj:`bool`): Whether a ``SummaryWriter`` instance should be created and returned
+        - need_text (:obj:`bool`): Whether a ``logging.Logger`` instance should be created and returned
+        - text_level (:obj:`int` or :obj:`str`): Logging level of ``logging.Logger``, default set to ``logging.INFO``
+    Returns:
+        - logger (:obj:`Optional[logging.Logger]`): Text logger that also prints to the terminal, \
+            only returned when ``need_text``.
+        - tb_logger (:obj:`Optional['SummaryWriter']`): Tensorboard writer saving output to tensorboard files, \
+            only returned when ``need_tb``.
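+    Examples:
+        >>> # Illustrative usage: creates './log/default_logger.txt' and './log/default_tb_logger'.
+        >>> logger, tb_logger = build_logger('./log', name='default')
+        >>> logger.info('hello world')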
+ """
+ if name is None:
+ name = 'default'
+ logger = LoggerFactory.create_logger(path, name=name, level=text_level) if need_text else None
+ tb_name = name + '_tb_logger'
+ tb_logger = TBLoggerFactory.create_logger(os.path.join(path, tb_name)) if need_tb else None
+ return logger, tb_logger
+
+
+class TBLoggerFactory(object):
+ """
+ Overview:
+ TBLoggerFactory is a factory class for ``SummaryWriter``.
+ Interfaces:
+ ``create_logger``
+ Properties:
+ - ``tb_loggers`` (:obj:`Dict[str, SummaryWriter]`): A dict that stores ``SummaryWriter`` instances.
+ """
+
+ tb_loggers = {}
+
+ @classmethod
+ def create_logger(cls: type, logdir: str) -> DistributedWriter:
+ if logdir in cls.tb_loggers:
+ return cls.tb_loggers[logdir]
+ tb_logger = DistributedWriter(logdir)
+ cls.tb_loggers[logdir] = tb_logger
+ return tb_logger
+
+
+class LoggerFactory(object):
+ """
+ Overview:
+ LoggerFactory is a factory class for ``logging.Logger``.
+ Interfaces:
+ ``create_logger``, ``get_tabulate_vars``, ``get_tabulate_vars_hor``
+ """
+
+ @classmethod
+ def create_logger(cls, path: str, name: str = 'default', level: Union[int, str] = logging.INFO) -> logging.Logger:
+ """
+ Overview:
+ Create logger using logging
+ Arguments:
+ - name (:obj:`str`): Logger's name
+ - path (:obj:`str`): Logger's save dir
+ - level (:obj:`int` or :obj:`str`): Used to set the level. Reference: ``Logger.setLevel`` method.
+ Returns:
+ - (:obj:`logging.Logger`): new logging logger
+ """
+ ditk.logging.try_init_root(level)
+
+ logger_name = f'{name}_logger'
+ logger_file_path = os.path.join(path, f'{logger_name}.txt')
+ touch(logger_file_path)
+
+ logger = ditk.logging.getLogger(logger_name, level, [logger_file_path])
+ logger.get_tabulate_vars = LoggerFactory.get_tabulate_vars
+ logger.get_tabulate_vars_hor = LoggerFactory.get_tabulate_vars_hor
+
+ return logger
+
+ @staticmethod
+ def get_tabulate_vars(variables: Dict[str, Any]) -> str:
+ """
+ Overview:
+ Get the text description in tabular form of all vars
+ Arguments:
+            - variables (:obj:`Dict[str, Any]`): The name-value pairs of the variables to display.
+ Returns:
+ - string (:obj:`str`): Text description in tabular form of all vars
+ """
+ headers = ["Name", "Value"]
+ data = []
+ for k, v in variables.items():
+ data.append([k, "{:.6f}".format(v)])
+ s = "\n" + tabulate(data, headers=headers, tablefmt='grid')
+ return s
+
+ @staticmethod
+ def get_tabulate_vars_hor(variables: Dict[str, Any]) -> str:
+ """
+ Overview:
+ Get the text description in tabular form of all vars
+ Arguments:
+            - variables (:obj:`Dict[str, Any]`): The name-value pairs of the variables to display.
+        Returns:
+            - string (:obj:`str`): Text description in tabular form of all vars
+ """
+
+ column_to_divide = 5 # which includes the header "Name & Value"
+
+ datak = []
+ datav = []
+
+ divide_count = 0
+ for k, v in variables.items():
+ if divide_count == 0 or divide_count >= (column_to_divide - 1):
+ datak.append("Name")
+ datav.append("Value")
+ if divide_count >= (column_to_divide - 1):
+ divide_count = 0
+ divide_count += 1
+
+ datak.append(k)
+ if not isinstance(v, str) and np.isscalar(v):
+ datav.append("{:.6f}".format(v))
+ else:
+ datav.append(v)
+
+ s = "\n"
+ row_number = len(datak) // column_to_divide + 1
+ for row_id in range(row_number):
+ item_start = row_id * column_to_divide
+ item_end = (row_id + 1) * column_to_divide
+ if (row_id + 1) * column_to_divide > len(datak):
+ item_end = len(datak)
+ data = [datak[item_start:item_end], datav[item_start:item_end]]
+ s = s + tabulate(data, tablefmt='grid') + "\n"
+
+ return s
+
+
+def pretty_print(result: dict, direct_print: bool = True) -> str:
+ """
+ Overview:
+ Print a dict ``result`` in a pretty way
+ Arguments:
+ - result (:obj:`dict`): The result to print
+ - direct_print (:obj:`bool`): Whether to print directly
+ Returns:
+ - string (:obj:`str`): The pretty-printed result in str format
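+    Examples:
+        >>> # Illustrative usage: keys whose value is None are dropped before dumping to yaml.
+        >>> s = pretty_print({'a': 1, 'b': None}, direct_print=False)
+        >>> assert s.strip() == 'a: 1'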
+ """
+ result = result.copy()
+ out = {}
+ for k, v in result.items():
+ if v is not None:
+ out[k] = v
+ cleaned = json.dumps(out)
+ string = yaml.safe_dump(json.loads(cleaned), default_flow_style=False)
+ if direct_print:
+ print(string)
+ return string
diff --git a/DI-engine/ding/utils/log_writer_helper.py b/DI-engine/ding/utils/log_writer_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f8a1c511511abcbf7f713fdfd704f257fbd4a32
--- /dev/null
+++ b/DI-engine/ding/utils/log_writer_helper.py
@@ -0,0 +1,171 @@
+from typing import TYPE_CHECKING
+
+from tensorboardX import SummaryWriter
+
+if TYPE_CHECKING:
+ # TYPE_CHECKING is always False at runtime, but mypy will evaluate the contents of this block.
+ # So if you import this module within TYPE_CHECKING, you will get code hints and other benefits.
+ # Here is a good answer on stackoverflow:
+ # https://stackoverflow.com/questions/39740632/python-type-hinting-without-cyclic-imports
+ from ding.framework import Parallel
+
+
+class DistributedWriter(SummaryWriter):
+ """
+ Overview:
+        A simple subclass of SummaryWriter that, in multi-process mode, writes through a single writer process.
+ The best way is to use it in conjunction with the ``router`` to take advantage of the message \
+ and event components of the router (see ``writer.plugin``).
+ Interfaces:
+ ``get_instance``, ``plugin``, ``initialize``, ``__del__``
+ """
+ root = None
+
+ def __init__(self, *args, **kwargs):
+ """
+ Overview:
+ Initialize the DistributedWriter object.
+ Arguments:
+ - args (:obj:`Tuple`): The arguments passed to the ``__init__`` function of the parent class, \
+ SummaryWriter.
+ - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__init__`` function of the parent class, \
+ SummaryWriter.
+ """
+
+ self._default_writer_to_disk = kwargs.get("write_to_disk") if "write_to_disk" in kwargs else True
+ # We need to write data to files lazily, so we should not use file writer in __init__,
+ # On the contrary, we will initialize the file writer when the user calls the
+ # add_* function for the first time
+ kwargs["write_to_disk"] = False
+ super().__init__(*args, **kwargs)
+ self._in_parallel = False
+ self._router = None
+ self._is_writer = False
+ self._lazy_initialized = False
+
+ @classmethod
+ def get_instance(cls, *args, **kwargs) -> "DistributedWriter":
+ """
+ Overview:
+            Get an instance and set the root-level instance on the first call with arguments. If ``args`` and
+            ``kwargs`` are both empty, this method returns the root instance.
+ Arguments:
+ - args (:obj:`Tuple`): The arguments passed to the ``__init__`` function of the parent class, \
+ SummaryWriter.
+ - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__init__`` function of the parent class, \
+ SummaryWriter.
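+        Examples:
+            >>> # Illustrative usage, assuming no root instance has been created yet (the path is hypothetical).
+            >>> writer = DistributedWriter.get_instance('./log')
+            >>> assert DistributedWriter.get_instance() is writer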
+ """
+ if args or kwargs:
+ ins = cls(*args, **kwargs)
+ if cls.root is None:
+ cls.root = ins
+ return ins
+ else:
+ return cls.root
+
+ def plugin(self, router: "Parallel", is_writer: bool = False) -> "DistributedWriter":
+ """
+ Overview:
+            Plug in ``router``, so that when this writer is used with an active router, write requests are \
+            automatically sent to the main writer instead of being written to disk locally. This way, data from \
+            multiple processes can be collected and written into one file.
+ Arguments:
+ - router (:obj:`Parallel`): The router to be plugged in.
+ - is_writer (:obj:`bool`): Whether this writer is the main writer.
+ Examples:
+ >>> DistributedWriter().plugin(router, is_writer=True)
+ """
+ if router.is_active:
+ self._in_parallel = True
+ self._router = router
+ self._is_writer = is_writer
+ if is_writer:
+ self.initialize()
+ self._lazy_initialized = True
+ router.on("distributed_writer", self._on_distributed_writer)
+ return self
+
+ def _on_distributed_writer(self, fn_name: str, *args, **kwargs):
+ """
+ Overview:
+ This method is called when the router receives a request to write data.
+ Arguments:
+ - fn_name (:obj:`str`): The name of the function to be called.
+ - args (:obj:`Tuple`): The arguments passed to the function to be called.
+ - kwargs (:obj:`Dict`): The keyword arguments passed to the function to be called.
+ """
+
+ if self._is_writer:
+ getattr(self, fn_name)(*args, **kwargs)
+
+ def initialize(self):
+ """
+ Overview:
+ Initialize the file writer.
+ """
+ self.close()
+ self._write_to_disk = self._default_writer_to_disk
+ self._get_file_writer()
+ self._lazy_initialized = True
+
+ def __del__(self):
+ """
+ Overview:
+ Close the file writer.
+ """
+ self.close()
+
+
+def enable_parallel(fn_name, fn):
+ """
+ Overview:
+ Decorator to enable parallel writing.
+ Arguments:
+ - fn_name (:obj:`str`): The name of the function to be called.
+ - fn (:obj:`Callable`): The function to be called.
+ """
+
+ def _parallel_fn(self: DistributedWriter, *args, **kwargs):
+ if not self._lazy_initialized:
+ self.initialize()
+ if self._in_parallel and not self._is_writer:
+ self._router.emit("distributed_writer", fn_name, *args, **kwargs)
+ else:
+ fn(self, *args, **kwargs)
+
+ return _parallel_fn
+
+
+ready_to_parallel_fns = [
+ 'add_audio',
+ 'add_custom_scalars',
+ 'add_custom_scalars_marginchart',
+ 'add_custom_scalars_multilinechart',
+ 'add_embedding',
+ 'add_figure',
+ 'add_graph',
+ 'add_graph_deprecated',
+ 'add_histogram',
+ 'add_histogram_raw',
+ 'add_hparams',
+ 'add_image',
+ 'add_image_with_boxes',
+ 'add_images',
+ 'add_mesh',
+ 'add_onnx_graph',
+ 'add_openvino_graph',
+ 'add_pr_curve',
+ 'add_pr_curve_raw',
+ 'add_scalar',
+ 'add_scalars',
+ 'add_text',
+ 'add_video',
+]
+for fn_name in ready_to_parallel_fns:
+ if hasattr(DistributedWriter, fn_name):
+ setattr(DistributedWriter, fn_name, enable_parallel(fn_name, getattr(DistributedWriter, fn_name)))
+
+# Examples:
+# In main, `distributed_writer.plugin(task.router, is_writer=True)`,
+# In middleware, `distributed_writer.record()`
+distributed_writer = DistributedWriter()
diff --git a/DI-engine/ding/utils/normalizer_helper.py b/DI-engine/ding/utils/normalizer_helper.py
new file mode 100755
index 0000000000000000000000000000000000000000..0fc914f30ef0db301ae806962d61c4248884064d
--- /dev/null
+++ b/DI-engine/ding/utils/normalizer_helper.py
@@ -0,0 +1,493 @@
+from typing import Tuple
+
+import numpy as np
+
+
+class DatasetNormalizer:
+ """
+ Overview:
+ The `DatasetNormalizer` class provides functionality to normalize and unnormalize data in a dataset.
+ It takes a dataset as input and applies a normalizer function to each key in the dataset.
+
+ Interfaces:
+ ``__init__``, ``__repr__``, ``normalize``, ``unnormalize``.
+ """
+
+ def __init__(self, dataset: np.ndarray, normalizer: str, path_lengths: list = None):
+ """
+ Overview:
+ Initialize the NormalizerHelper object.
+
+ Arguments:
+ - dataset (:obj:`np.ndarray`): The dataset to be normalized.
+ - normalizer (:obj:`str`): The type of normalizer to be used. Can be a string representing the name of \
+ the normalizer class.
+ - path_lengths (:obj:`list`): The length of the paths in the dataset. Defaults to None.
+ """
+ dataset = flatten(dataset, path_lengths)
+
+ self.observation_dim = dataset['observations'].shape[1]
+ self.action_dim = dataset['actions'].shape[1]
+
+ if isinstance(normalizer, str):
+ normalizer = eval(normalizer)
+
+ self.normalizers = {}
+ for key, val in dataset.items():
+ try:
+ self.normalizers[key] = normalizer(val)
+ except:
+ print(f'[ utils/normalization ] Skipping {key} | {normalizer}')
+ # key: normalizer(val)
+ # for key, val in dataset.items()
+
+ def __repr__(self) -> str:
+ """
+ Overview:
+ Returns a string representation of the NormalizerHelper object. \
+ The string representation includes the key-value pairs of the normalizers \
+ stored in the NormalizerHelper object.
+ Returns:
+ - ret (:obj:`str`): A string representation of the NormalizerHelper object.
+ """
+ string = ''
+ for key, normalizer in self.normalizers.items():
+ string += f'{key}: {normalizer}\n'
+ return string
+
+ def normalize(self, x: np.ndarray, key: str) -> np.ndarray:
+ """
+ Overview:
+ Normalize the input data using the specified key.
+
+ Arguments:
+ - x (:obj:`np.ndarray`): The input data to be normalized.
+ - key (:obj:`str`): The key to identify the normalizer.
+
+ Returns:
+ - ret (:obj:`np.ndarray`): The normalized value of the input data.
+ """
+ return self.normalizers[key].normalize(x)
+
+ def unnormalize(self, x: np.ndarray, key: str) -> np.ndarray:
+ """
+ Overview:
+ Unnormalizes the given value `x` using the specified `key`.
+
+ Arguments:
+ - x (:obj:`np.ndarray`): The value to be unnormalized.
+ - key (:obj:`str`): The key to identify the normalizer.
+
+ Returns:
+ - ret (:obj:`np.ndarray`): The unnormalized value.
+ """
+ return self.normalizers[key].unnormalize(x)
+
+
+def flatten(dataset: dict, path_lengths: list) -> dict:
+ """
+ Overview:
+ Flattens a dataset of { key: [ n_episodes x max_path_length x dim ] } \
+ to { key : [ sum(path_lengths) x dim ] }
+
+ Arguments:
+ - dataset (:obj:`dict`): The dataset to be flattened.
+ - path_lengths (:obj:`list`): A list of path lengths for each episode.
+
+ Returns:
+ - flattened (:obj:`dict`): The flattened dataset.
+ """
+ flattened = {}
+ for key, xs in dataset.items():
+ assert len(xs) == len(path_lengths)
+ if key == 'path_lengths':
+ continue
+ flattened[key] = np.concatenate([x[:length] for x, length in zip(xs, path_lengths)], axis=0)
+ return flattened
+
+
+class Normalizer:
+ """
+ Overview:
+ Parent class; subclass it by defining the ``normalize`` and ``unnormalize`` methods.
+
+ Interfaces:
+ ``__init__``, ``__repr__``, ``normalize``, ``unnormalize``.
+ """
+
+ def __init__(self, X):
+ """
+ Overview:
+ Initialize the Normalizer object.
+ Arguments:
+ - X (:obj:`np.ndarray`): The data to be normalized.
+ """
+
+ self.X = X.astype(np.float32)
+ self.mins = X.min(axis=0)
+ self.maxs = X.max(axis=0)
+
+ def __repr__(self) -> str:
+ """
+ Overview:
+ Returns a string representation of the Normalizer object.
+ Returns:
+ - ret (:obj:`str`): A string representation of the Normalizer object.
+ """
+
+ return (
+ f"""[ Normalizer ] dim: {self.mins.size}\n -: """
+ f"""{np.round(self.mins, 2)}\n +: {np.round(self.maxs, 2)}\n"""
+ )
+
+ def normalize(self, *args, **kwargs):
+ """
+ Overview:
+ Normalize the input data.
+ Arguments:
+ - args (:obj:`list`): The arguments passed to the ``normalize`` function.
+ - kwargs (:obj:`dict`): The keyword arguments passed to the ``normalize`` function.
+ """
+
+ raise NotImplementedError()
+
+ def unnormalize(self, *args, **kwargs):
+ """
+ Overview:
+ Unnormalize the input data.
+ Arguments:
+ - args (:obj:`list`): The arguments passed to the ``unnormalize`` function.
+ - kwargs (:obj:`dict`): The keyword arguments passed to the ``unnormalize`` function.
+ """
+
+ raise NotImplementedError()
+
+
+class GaussianNormalizer(Normalizer):
+ """
+ Overview:
+ A class that normalizes data to zero mean and unit variance.
+
+ Interfaces:
+ ``__init__``, ``__repr__``, ``normalize``, ``unnormalize``.
+ """
+
+ def __init__(self, *args, **kwargs):
+ """
+ Overview:
+ Initialize the GaussianNormalizer object.
+ Arguments:
+ - args (:obj:`list`): The arguments passed to the ``__init__`` function of the parent class, \
+ i.e., the Normalizer class.
+ - kwargs (:obj:`dict`): The keyword arguments passed to the ``__init__`` function of the parent class, \
+ i.e., the Normalizer class.
+ """
+
+ super().__init__(*args, **kwargs)
+ self.means = self.X.mean(axis=0)
+ self.stds = self.X.std(axis=0)
+ self.z = 1
+
+ def __repr__(self) -> str:
+ """
+ Overview:
+ Returns a string representation of the GaussianNormalizer object.
+ Returns:
+ - ret (:obj:`str`): A string representation of the GaussianNormalizer object.
+ """
+
+ return (
+ f"""[ Normalizer ] dim: {self.mins.size}\n """
+ f"""means: {np.round(self.means, 2)}\n """
+ f"""stds: {np.round(self.z * self.stds, 2)}\n"""
+ )
+
+ def normalize(self, x: np.ndarray) -> np.ndarray:
+ """
+ Overview:
+ Normalize the input data.
+
+ Arguments:
+ - x (:obj:`np.ndarray`): The input data to be normalized.
+
+ Returns:
+ - ret (:obj:`np.ndarray`): The normalized data.
+ """
+ return (x - self.means) / self.stds
+
+ def unnormalize(self, x: np.ndarray) -> np.ndarray:
+ """
+ Overview:
+ Unnormalize the input data.
+
+ Arguments:
+ - x (:obj:`np.ndarray`): The input data to be unnormalized.
+
+ Returns:
+ - ret (:obj:`np.ndarray`): The unnormalized data.
+ """
+ return x * self.stds + self.means
+
+
+class CDFNormalizer(Normalizer):
+ """
+ Overview:
+ A class that makes training data uniform (over each dimension) by transforming it with marginal CDFs.
+
+ Interfaces:
+ ``__init__``, ``__repr__``, ``normalize``, ``unnormalize``.
+ """
+
+ def __init__(self, X):
+ """
+ Overview:
+ Initialize the CDFNormalizer object.
+ Arguments:
+ - X (:obj:`np.ndarray`): The data to be normalized.
+ """
+
+ super().__init__(atleast_2d(X))
+ self.dim = self.X.shape[1]
+ self.cdfs = [CDFNormalizer1d(self.X[:, i]) for i in range(self.dim)]
+
+ def __repr__(self) -> str:
+ """
+ Overview:
+ Returns a string representation of the CDFNormalizer object.
+ Returns:
+ - ret (:obj:`str`): A string representation of the CDFNormalizer object.
+ """
+
+ return f'[ CDFNormalizer ] dim: {self.mins.size}\n' + ' | '.join(
+ f'{i:3d}: {cdf}' for i, cdf in enumerate(self.cdfs)
+ )
+
+ def wrap(self, fn_name: str, x: np.ndarray) -> np.ndarray:
+ """
+ Overview:
+ Wraps the given function name and applies it to the input data.
+
+ Arguments:
+ - fn_name (:obj:`str`): The name of the function to be applied.
+ - x (:obj:`np.ndarray`): The input data.
+
+ Returns:
+ - ret: The output of the function applied to the input data.
+ """
+ shape = x.shape
+ # reshape to 2d
+ x = x.reshape(-1, self.dim)
+ out = np.zeros_like(x)
+ for i, cdf in enumerate(self.cdfs):
+ fn = getattr(cdf, fn_name)
+ out[:, i] = fn(x[:, i])
+ return out.reshape(shape)
+
+ def normalize(self, x: np.ndarray) -> np.ndarray:
+ """
+ Overview:
+ Normalizes the input data.
+
+ Arguments:
+ - x (:obj:`np.ndarray`): The input data.
+
+ Returns:
+ - ret (:obj:`np.ndarray`): The normalized data.
+ """
+ return self.wrap('normalize', x)
+
+ def unnormalize(self, x: np.ndarray) -> np.ndarray:
+ """
+ Overview:
+ Unnormalizes the input data.
+
+ Arguments:
+ - x (:obj:`np.ndarray`): The input data.
+
+ Returns:
+ - ret (:obj:`np.ndarray`): The unnormalized data.
+ """
+ return self.wrap('unnormalize', x)
+
+
+class CDFNormalizer1d:
+ """
+ Overview:
+ CDF normalizer for a single dimension. This class provides methods to normalize and unnormalize data \
+ using the Cumulative Distribution Function (CDF) approach.
+ Interfaces:
+ ``__init__``, ``__repr__``, ``normalize``, ``unnormalize``.
+ """
+
+ def __init__(self, X: np.ndarray):
+ """
+ Overview:
+ Initialize the CDFNormalizer1d object.
+ Arguments:
+ - X (:obj:`np.ndarray`): The data to be normalized.
+ """
+
+ import scipy.interpolate as interpolate
+ assert X.ndim == 1
+ self.X = X.astype(np.float32)
+ if self.X.max() == self.X.min():
+ self.constant = True
+ else:
+ self.constant = False
+ quantiles, cumprob = empirical_cdf(self.X)
+ self.fn = interpolate.interp1d(quantiles, cumprob)
+ self.inv = interpolate.interp1d(cumprob, quantiles)
+
+ self.xmin, self.xmax = quantiles.min(), quantiles.max()
+ self.ymin, self.ymax = cumprob.min(), cumprob.max()
+
+ def __repr__(self) -> str:
+ """
+ Overview:
+ Returns a string representation of the CDFNormalizer1d object.
+ """
+
+ return f'[{np.round(self.xmin, 2):.4f}, {np.round(self.xmax, 2):.4f}]'
+
+ def normalize(self, x: np.ndarray) -> np.ndarray:
+ """
+ Overview:
+ Normalize the input data.
+
+ Arguments:
+ - x (:obj:`np.ndarray`): The data to be normalized.
+
+ Returns:
+ - ret (:obj:`np.ndarray`): The normalized data.
+ """
+ if self.constant:
+ return x
+
+ x = np.clip(x, self.xmin, self.xmax)
+ # [ 0, 1 ]
+ y = self.fn(x)
+ # [ -1, 1 ]
+ y = 2 * y - 1
+ return y
+
+ def unnormalize(self, x: np.ndarray, eps: float = 1e-4) -> np.ndarray:
+ """
+ Overview:
+ Unnormalize the input data.
+
+ Arguments:
+ - x (:obj:`np.ndarray`): The data to be unnormalized.
+ - eps (:obj:`float`): A small value used for numerical stability. Defaults to 1e-4.
+
+ Returns:
+ - ret (:obj:`np.ndarray`): The unnormalized data.
+ """
+ # [ -1, 1 ] --> [ 0, 1 ]
+ if self.constant:
+ return x
+
+ x = (x + 1) / 2.
+
+ if (x < self.ymin - eps).any() or (x > self.ymax + eps).any():
+ print(
+ f"""[ dataset/normalization ] Warning: out of range in unnormalize: """
+ f"""[{x.min()}, {x.max()}] | """
+ f"""x : [{self.xmin}, {self.xmax}] | """
+ f"""y: [{self.ymin}, {self.ymax}]"""
+ )
+
+ x = np.clip(x, self.ymin, self.ymax)
+
+ y = self.inv(x)
+ return y
+
+
+def empirical_cdf(sample: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ """
+ Overview:
+ Compute the empirical cumulative distribution function (CDF) of a given sample.
+
+ Arguments:
+ - sample (:obj:`np.ndarray`): The input sample for which to compute the empirical CDF.
+
+ Returns:
+ - quantiles (:obj:`np.ndarray`): The unique values in the sample.
+ - cumprob (:obj:`np.ndarray`): The cumulative probabilities corresponding to the quantiles.
+
+ References:
+ - Stack Overflow: https://stackoverflow.com/a/33346366
+ """
+
+ # find the unique values and their corresponding counts
+ quantiles, counts = np.unique(sample, return_counts=True)
+
+ # take the cumulative sum of the counts and divide by the sample size to
+ # get the cumulative probabilities between 0 and 1
+ cumprob = np.cumsum(counts).astype(np.double) / sample.size
+
+ return quantiles, cumprob
+
+
+def atleast_2d(x: np.ndarray) -> np.ndarray:
+ """
+ Overview:
+ Ensure that the input array has at least two dimensions.
+
+ Arguments:
+ - x (:obj:`np.ndarray`): The input array.
+
+ Returns:
+ - ret (:obj:`np.ndarray`): The input array with at least two dimensions.
+ """
+ if x.ndim < 2:
+ x = x[:, None]
+ return x
+
+
+class LimitsNormalizer(Normalizer):
+ """
+ Overview:
+ A class that normalizes and unnormalizes values within specified limits. \
+ This class maps values within the range [xmin, xmax] to the range [-1, 1].
+
+ Interfaces:
+ ``__init__``, ``__repr__``, ``normalize``, ``unnormalize``.
+ """
+
+ def normalize(self, x: np.ndarray) -> np.ndarray:
+ """
+ Overview:
+ Normalizes the input values.
+
+ Arguments:
+ - x (:obj:`np.ndarray`): The input values to be normalized.
+
+ Returns:
+ - ret (:obj:`np.ndarray`): The normalized values.
+
+ """
+ # [ 0, 1 ]
+ x = (x - self.mins) / (self.maxs - self.mins)
+ # [ -1, 1 ]
+ x = 2 * x - 1
+ return x
+
+ def unnormalize(self, x: np.ndarray, eps: float = 1e-4) -> np.ndarray:
+ """
+ Overview:
+ Unnormalizes the input values.
+
+ Arguments:
+ - x (:obj:`np.ndarray`): The input values to be unnormalized.
+ - eps (:obj:`float`): A small value used for clipping. Defaults to 1e-4.
+
+ Returns:
+ - ret (:obj:`np.ndarray`): The unnormalized values.
+
+ """
+ if x.max() > 1 + eps or x.min() < -1 - eps:
+ # print(f'[ datasets/mujoco ] Warning: sample out of range | ({x.min():.4f}, {x.max():.4f})')
+ x = np.clip(x, -1, 1)
+
+ # [ -1, 1 ] --> [ 0, 1 ]
+ x = (x + 1) / 2.
+
+ return x * (self.maxs - self.mins) + self.mins
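
To make the round-trip behaviour of these normalizers concrete, here is a small self-contained sketch using synthetic data (not from the source) with ``GaussianNormalizer`` and ``LimitsNormalizer``.

```python
import numpy as np

from ding.utils.normalizer_helper import GaussianNormalizer, LimitsNormalizer

# Synthetic data, purely for illustration.
X = np.random.uniform(low=-2.0, high=5.0, size=(1000, 3)).astype(np.float32)

gauss = GaussianNormalizer(X)
z = gauss.normalize(X)                       # zero mean, unit variance per dimension
assert np.allclose(z.mean(axis=0), 0.0, atol=1e-4)
assert np.allclose(gauss.unnormalize(z), X, atol=1e-4)

limits = LimitsNormalizer(X)
y = limits.normalize(X)                      # mapped into [-1, 1]
assert y.min() >= -1.0 and y.max() <= 1.0
assert np.allclose(limits.unnormalize(y), X, atol=1e-3)
```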
diff --git a/DI-engine/ding/utils/orchestrator_launcher.py b/DI-engine/ding/utils/orchestrator_launcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..69324ecc081960374b0206611003f8a3fd490a2f
--- /dev/null
+++ b/DI-engine/ding/utils/orchestrator_launcher.py
@@ -0,0 +1,157 @@
+import subprocess
+import time
+from ding.utils import K8sLauncher
+from .default_helper import one_time_warning
+
+
+class OrchestratorLauncher(object):
+ """
+ Overview:
+ Object to manage di-orchestrator in an existing k8s cluster
+ Interfaces:
+ ``__init__``, ``create_orchestrator``, ``delete_orchestrator``
+ """
+
+ def __init__(
+ self,
+ version: str,
+ name: str = 'di-orchestrator',
+ cluster: K8sLauncher = None,
+ registry: str = 'diorchestrator',
+ cert_manager_version: str = 'v1.3.1',
+ cert_manager_registry: str = 'quay.io/jetstack'
+ ) -> None:
+ """
+ Overview:
+ Initialize the OrchestratorLauncher object.
+ Arguments:
+ - version (:obj:`str`): The version of di-orchestrator.
+ - name (:obj:`str`): The name of di-orchestrator.
+ - cluster (:obj:`K8sLauncher`): The k8s cluster to deploy di-orchestrator.
+ - registry (:obj:`str`): The docker registry to pull images.
+ - cert_manager_version (:obj:`str`): The version of cert-manager.
+ - cert_manager_registry (:obj:`str`): The docker registry to pull cert-manager images.
+ """
+
+ self.name = name
+ self.version = version
+ self.cluster = cluster
+ self.registry = registry
+ self.cert_manager_version = cert_manager_version
+ self.cert_manager_registry = cert_manager_registry
+
+ self._namespace = 'di-system'
+ self._webhook = 'di-webhook'
+ self._cert_manager_namespace = 'cert-manager'
+ self._cert_manager_webhook = 'cert-manager-webhook'
+
+ self.installer = 'https://raw.githubusercontent.com/opendilab/' + \
+ f'DI-orchestrator/{self.version}/config/di-manager.yaml'
+ self.cert_manager = 'https://github.com/jetstack/' + \
+ f'cert-manager/releases/download/{self.cert_manager_version}/cert-manager.yaml'
+
+ self._images = [
+ f'{self.registry}/di-operator:{self.version}',
+ f'{self.registry}/di-webhook:{self.version}',
+ f'{self.registry}/di-server:{self.version}',
+ f'{self.cert_manager_registry}/cert-manager-cainjector:{self.cert_manager_version}',
+ f'{self.cert_manager_registry}/cert-manager-controller:{self.cert_manager_version}',
+ f'{self.cert_manager_registry}/cert-manager-webhook:{self.cert_manager_version}',
+ ]
+
+ self._check_kubectl_tools()
+
+ def _check_kubectl_tools(self) -> None:
+ """
+ Overview:
+ Check whether the kubectl tool is installed.
+ """
+
+ args = ['which', 'kubectl']
+ proc = subprocess.Popen(args, stdout=subprocess.PIPE, stderr=subprocess.PIPE)
+ out, _ = proc.communicate()
+ if out.decode('utf-8') == '':
+ raise FileNotFoundError(
+ "No kubectl tools found, please install by executing ./ding/scripts/install-k8s-tools.sh"
+ )
+
+ def create_orchestrator(self) -> None:
+ """
+ Overview:
+ Create di-orchestrator in k8s cluster.
+ """
+
+ print('Creating orchestrator...')
+ if self.cluster is not None:
+ self.cluster.preload_images(self._images)
+
+ # create and wait for cert-manager to be available
+ create_components_from_config(self.cert_manager)
+ wait_to_be_ready(self._cert_manager_namespace, self._cert_manager_webhook)
+
+ # create and wait for di-orchestrator to be available
+ create_components_from_config(self.installer)
+ wait_to_be_ready(self._namespace, self._webhook)
+
+ def delete_orchestrator(self) -> None:
+ """
+ Overview:
+ Delete di-orchestrator in k8s cluster.
+ """
+
+ print('Deleting orchestrator...')
+ for item in [self.cert_manager, self.installer]:
+ args = ['kubectl', 'delete', '-f', f'{item}']
+ proc = subprocess.Popen(args, stderr=subprocess.PIPE)
+ _, err = proc.communicate()
+ err_str = err.decode('utf-8').strip()
+ if err_str != '' and 'WARN' not in err_str and \
+ 'NotFound' not in err_str:
+ raise RuntimeError(f'Failed to delete di-orchestrator: {err_str}')
+
+
+def create_components_from_config(config: str) -> None:
+ """
+ Overview:
+ Create components from config file.
+ Arguments:
+ - config (:obj:`str`): The config file.
+ """
+
+ args = ['kubectl', 'create', '-f', f'{config}']
+ proc = subprocess.Popen(args, stderr=subprocess.PIPE)
+ _, err = proc.communicate()
+ err_str = err.decode('utf-8').strip()
+ if err_str != '' and 'WARN' not in err_str:
+ if 'already exists' in err_str:
+ print(f'Components already exist: {config}')
+ else:
+ raise RuntimeError(f'Failed to launch components: {err_str}')
+
+
+def wait_to_be_ready(namespace: str, component: str, timeout: int = 120) -> None:
+ """
+ Overview:
+ Wait for the component to be ready.
+ Arguments:
+ - namespace (:obj:`str`): The namespace of the component.
+ - component (:obj:`str`): The name of the component.
+ - timeout (:obj:`int`): The timeout of waiting.
+ """
+
+ try:
+ from kubernetes import config, client, watch
+ except ModuleNotFoundError:
+ one_time_warning("You have not installed kubernetes package! Please try 'pip install DI-engine[k8s]'.")
+ exit(-1)
+
+ config.load_kube_config()
+ appv1 = client.AppsV1Api()
+ w = watch.Watch()
+ for event in w.stream(appv1.list_namespaced_deployment, namespace, timeout_seconds=timeout):
+ # print("Event: %s %s %s" % (event['type'], event['object'].kind, event['object'].metadata.name))
+ if event['object'].metadata.name.startswith(component) and \
+ event['object'].status.ready_replicas is not None and \
+ event['object'].status.ready_replicas >= 1:
+ print(f'component {component} is ready for serving')
+ w.stop()
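
A hedged sketch of how this launcher is typically paired with ``K8sLauncher``: the config path and version tag are placeholders, and the sketch assumes ``K8sLauncher`` accepts a config path and exposes ``create_cluster``/``delete_cluster`` as in the DI-engine k8s helper.

```python
from ding.utils import K8sLauncher
from ding.utils.orchestrator_launcher import OrchestratorLauncher

# Placeholder config path and version tag; both depend on your environment.
cluster = K8sLauncher('./k8s-config.yaml')
cluster.create_cluster()

launcher = OrchestratorLauncher('v1.1.3', cluster=cluster)
launcher.create_orchestrator()
# ... submit DIJob workloads to the cluster here ...
launcher.delete_orchestrator()
cluster.delete_cluster()
```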
diff --git a/DI-engine/ding/utils/profiler_helper.py b/DI-engine/ding/utils/profiler_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..96c2a1a076c108c16fadf1d906deddd532cb7449
--- /dev/null
+++ b/DI-engine/ding/utils/profiler_helper.py
@@ -0,0 +1,76 @@
+import atexit
+import pstats
+import io
+import cProfile
+import os
+
+
+def register_profiler(write_profile, pr, folder_path):
+ atexit.register(write_profile, pr, folder_path)
+
+
+class Profiler:
+ """
+ Overview:
+ A class for profiling code execution. Call ``profile()`` once to enable profiling; the results are
+ written to files when the program exits.
+
+ Interfaces:
+ ``__init__``, ``mkdir``, ``write_profile``, ``profile``.
+ """
+
+ def __init__(self):
+ """
+ Overview:
+ Initialize the Profiler object.
+ """
+
+ self.pr = cProfile.Profile()
+
+ def mkdir(self, directory: str):
+ """
+ Overview:
+ Create a directory if it doesn't exist.
+
+ Arguments:
+ - directory (:obj:`str`): The path of the directory to be created.
+ """
+ if not os.path.exists(directory):
+ os.makedirs(directory)
+
+ def write_profile(self, pr: cProfile.Profile, folder_path: str):
+ """
+ Overview:
+ Write the profiling results to files.
+
+ Arguments:
+ - pr (:obj:`cProfile.Profile`): The profiler object containing the profiling results.
+ - folder_path (:obj:`str`): The path of the folder where the profiling files will be saved.
+ """
+ pr.disable()
+ s_tottime = io.StringIO()
+ s_cumtime = io.StringIO()
+
+ ps = pstats.Stats(pr, stream=s_tottime).sort_stats('tottime')
+ ps.print_stats()
+ with open(folder_path + "/profile_tottime.txt", 'w+') as f:
+ f.write(s_tottime.getvalue())
+
+ ps = pstats.Stats(pr, stream=s_cumtime).sort_stats('cumtime')
+ ps.print_stats()
+ with open(folder_path + "/profile_cumtime.txt", 'w+') as f:
+ f.write(s_cumtime.getvalue())
+
+ pr.dump_stats(folder_path + "/profile.prof")
+
+ def profile(self, folder_path="./tmp"):
+ """
+ Overview:
+ Enable profiling and save the results to files.
+
+ Arguments:
+ - folder_path (:obj:`str`): The path of the folder where the profiling files will be saved. \
+ Defaults to "./tmp".
+ """
+ self.mkdir(folder_path)
+ self.pr.enable()
+ register_profiler(self.write_profile, self.pr, folder_path)
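
A minimal sketch of the intended workflow: call ``profile()`` once at program start; the results are written via the registered ``atexit`` hook when the process exits. The folder name and workload are illustrative.

```python
from ding.utils.profiler_helper import Profiler

def expensive_loop():
    total = 0
    for i in range(1_000_000):
        total += i * i
    return total

if __name__ == "__main__":
    profiler = Profiler()
    profiler.profile(folder_path="./profile_output")  # registers an atexit hook
    expensive_loop()
    # On normal interpreter exit, profile_tottime.txt, profile_cumtime.txt and
    # profile.prof appear under ./profile_output.
```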
diff --git a/DI-engine/ding/utils/pytorch_ddp_dist_helper.py b/DI-engine/ding/utils/pytorch_ddp_dist_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..13d9e1e299214a82629d1fe745dbd9934960b423
--- /dev/null
+++ b/DI-engine/ding/utils/pytorch_ddp_dist_helper.py
@@ -0,0 +1,264 @@
+from typing import Callable, Tuple, List, Any, Union
+from easydict import EasyDict
+
+import os
+import numpy as np
+import torch
+import torch.distributed as dist
+
+from .default_helper import error_wrapper
+
+# from .slurm_helper import get_master_addr
+
+
+def get_rank() -> int:
+ """
+ Overview:
+ Get the rank of current process in total world_size
+ """
+ # return int(os.environ.get('SLURM_PROCID', 0))
+ return error_wrapper(dist.get_rank, 0)()
+
+
+def get_world_size() -> int:
+ """
+ Overview:
+ Get the world_size (the total number of processes in data-parallel training)
+ """
+ # return int(os.environ.get('SLURM_NTASKS', 1))
+ return error_wrapper(dist.get_world_size, 1)()
+
+
+broadcast = dist.broadcast
+allgather = dist.all_gather
+broadcast_object_list = dist.broadcast_object_list
+
+
+def allreduce(x: torch.Tensor) -> None:
+ """
+ Overview:
+ All reduce the tensor ``x`` in the world
+ Arguments:
+ - x (:obj:`torch.Tensor`): the tensor to be reduced
+ """
+
+ dist.all_reduce(x)
+ x.div_(get_world_size())
+
+
+def allreduce_async(name: str, x: torch.Tensor) -> None:
+ """
+ Overview:
+ All reduce the tensor ``x`` in the world asynchronously
+ Arguments:
+ - name (:obj:`str`): the name of the tensor
+ - x (:obj:`torch.Tensor`): the tensor to be reduced
+ """
+
+ x.div_(get_world_size())
+ dist.all_reduce(x, async_op=True)
+
+
+def reduce_data(x: Union[int, float, torch.Tensor], dst: int) -> Union[int, float, torch.Tensor]:
+ """
+ Overview:
+ Reduce the tensor ``x`` to the destination process ``dst``
+ Arguments:
+ - x (:obj:`Union[int, float, torch.Tensor]`): the tensor to be reduced
+ - dst (:obj:`int`): the destination process
+ """
+
+ if np.isscalar(x):
+ x_tensor = torch.as_tensor([x]).cuda()
+ dist.reduce(x_tensor, dst)
+ return x_tensor.item()
+ elif isinstance(x, torch.Tensor):
+ dist.reduce(x, dst)
+ return x
+ else:
+ raise TypeError("not supported type: {}".format(type(x)))
+
+
+def allreduce_data(x: Union[int, float, torch.Tensor], op: str) -> Union[int, float, torch.Tensor]:
+ """
+ Overview:
+ All reduce the tensor ``x`` in the world
+ Arguments:
+ - x (:obj:`Union[int, float, torch.Tensor]`): the tensor to be reduced
+ - op (:obj:`str`): the operation to perform on data, support ``['sum', 'avg']``
+ """
+
+ assert op in ['sum', 'avg'], op
+ if np.isscalar(x):
+ x_tensor = torch.as_tensor([x]).cuda()
+ dist.all_reduce(x_tensor)
+ if op == 'avg':
+ x_tensor.div_(get_world_size())
+ return x_tensor.item()
+ elif isinstance(x, torch.Tensor):
+ dist.all_reduce(x)
+ if op == 'avg':
+ x.div_(get_world_size())
+ return x
+ else:
+ raise TypeError("not supported type: {}".format(type(x)))
+
+
+synchronize = torch.cuda.synchronize
+
+
+def get_group(group_size: int) -> List:
+ """
+ Overview:
+ Split the world into groups of ``group_size`` processes each and return the group of the current rank
+ Arguments:
+ - group_size (:obj:`int`) the ``group_size``
+ """
+ rank = get_rank()
+ world_size = get_world_size()
+ if group_size is None:
+ group_size = world_size
+ assert (world_size % group_size == 0)
+ return simple_group_split(world_size, rank, world_size // group_size)
+
+
+def dist_mode(func: Callable) -> Callable:
+ """
+ Overview:
+ Wrap the function so that it initializes the distributed setting before each call and finalizes it afterwards
+ Arguments:
+ - func (:obj:`Callable`): the function to be wrapped
+ """
+
+ def wrapper(*args, **kwargs):
+ dist_init()
+ func(*args, **kwargs)
+ dist_finalize()
+
+ return wrapper
+
+
+def dist_init(backend: str = 'nccl',
+ addr: str = None,
+ port: str = None,
+ rank: int = None,
+ world_size: int = None) -> Tuple[int, int]:
+ """
+ Overview:
+ Initialize the distributed training setting
+ Arguments:
+ - backend (:obj:`str`): The backend of the distributed training, support ``['nccl', 'gloo']``
+ - addr (:obj:`str`): The address of the master node
+ - port (:obj:`str`): The port of the master node
+ - rank (:obj:`int`): The rank of current process
+ - world_size (:obj:`int`): The total number of processes
+ """
+
+ assert backend in ['nccl', 'gloo'], backend
+ os.environ['MASTER_ADDR'] = addr or os.environ.get('MASTER_ADDR', "localhost")
+ os.environ['MASTER_PORT'] = port or os.environ.get('MASTER_PORT', "10314") # hard-code
+
+ if rank is None:
+ local_id = os.environ.get('SLURM_LOCALID', os.environ.get('RANK', None))
+ if local_id is None:
+ raise RuntimeError("please indicate rank explicitly in dist_init method")
+ else:
+ rank = int(local_id)
+ if world_size is None:
+ ntasks = os.environ.get('SLURM_NTASKS', os.environ.get('WORLD_SIZE', None))
+ if ntasks is None:
+ raise RuntimeError("please indicate world_size explicitly in dist_init method")
+ else:
+ world_size = int(ntasks)
+
+ dist.init_process_group(backend=backend, rank=rank, world_size=world_size)
+
+ num_gpus = torch.cuda.device_count()
+ torch.cuda.set_device(rank % num_gpus)
+ world_size = get_world_size()
+ rank = get_rank()
+ return rank, world_size
+
+
+def dist_finalize() -> None:
+ """
+ Overview:
+ Finalize distributed training resources
+ """
+ # This operation usually hangs out so we ignore it temporally.
+ # dist.destroy_process_group()
+ pass
+
+
+class DDPContext:
+ """
+ Overview:
+ A context manager for PyTorch DDP distributed training
+ Interfaces:
+ ``__init__``, ``__enter__``, ``__exit__``
+ """
+
+ def __init__(self) -> None:
+ """
+ Overview:
+ Initialize the ``DDPContext``
+ """
+
+ pass
+
+ def __enter__(self) -> None:
+ """
+ Overview:
+ Initialize the distributed training setting
+ """
+
+ dist_init()
+
+ def __exit__(self, *args, **kwargs) -> Any:
+ """
+ Overview:
+ Finalize the distributed training setting
+ """
+
+ dist_finalize()
+
+
+def simple_group_split(world_size: int, rank: int, num_groups: int) -> List:
+ """
+ Overview:
+ Split the processes into groups according to ``world_size``, ``rank`` and ``num_groups``
+ Arguments:
+ - world_size (:obj:`int`): The world size
+ - rank (:obj:`int`): The rank
+ - num_groups (:obj:`int`): The number of groups
+
+ .. note::
+ With faulty input, raise ``array split does not result in an equal division``
+ """
+ groups = []
+ rank_list = np.split(np.arange(world_size), num_groups)
+ rank_list = [list(map(int, x)) for x in rank_list]
+ for i in range(num_groups):
+ groups.append(dist.new_group(rank_list[i]))
+ group_size = world_size // num_groups
+ return groups[rank // group_size]
+
+
+def to_ddp_config(cfg: EasyDict) -> EasyDict:
+ """
+ Overview:
+ Convert the config to ddp config
+ Arguments:
+ - cfg (:obj:`EasyDict`): The config to be converted
+ """
+
+ w = get_world_size()
+ if 'batch_size' in cfg.policy:
+ cfg.policy.batch_size = int(np.ceil(cfg.policy.batch_size / w))
+ if 'batch_size' in cfg.policy.learn:
+ cfg.policy.learn.batch_size = int(np.ceil(cfg.policy.learn.batch_size / w))
+ if 'n_sample' in cfg.policy.collect:
+ cfg.policy.collect.n_sample = int(np.ceil(cfg.policy.collect.n_sample / w))
+ if 'n_episode' in cfg.policy.collect:
+ cfg.policy.collect.n_episode = int(np.ceil(cfg.policy.collect.n_episode / w))
+ return cfg
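
To illustrate how these helpers fit together, here is a hedged sketch of a DDP entry point. It assumes one process per GPU launched with torchrun (or equivalent ``RANK``/``WORLD_SIZE`` environment variables), a CUDA-capable node, and the nccl backend; it uses only the functions defined above.

```python
from ding.utils.pytorch_ddp_dist_helper import (
    DDPContext, allreduce_data, get_rank, get_world_size
)

def main():
    with DDPContext():  # calls dist_init() on enter, dist_finalize() on exit
        rank, world_size = get_rank(), get_world_size()
        # Each rank contributes its own scalar; 'avg' divides the sum by world_size.
        local_loss = float(rank + 1)
        mean_loss = allreduce_data(local_loss, 'avg')
        if rank == 0:
            print(f'world_size={world_size}, mean loss={mean_loss}')

if __name__ == '__main__':
    main()
```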
diff --git a/DI-engine/ding/utils/registry.py b/DI-engine/ding/utils/registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..1d55041ffbc91cbaef4ac80eb0ede358cb0b02d3
--- /dev/null
+++ b/DI-engine/ding/utils/registry.py
@@ -0,0 +1,165 @@
+import inspect
+import os
+from collections import OrderedDict
+from typing import Optional, Iterable, Callable
+
+_innest_error = True
+
+_DI_ENGINE_REG_TRACE_IS_ON = os.environ.get('DIENGINEREGTRACE', 'OFF').upper() == 'ON'
+
+
+class Registry(dict):
+ """
+ Overview:
+ A helper class for managing module registration; it extends a dictionary
+ and provides a register function.
+ Interfaces:
+ ``__init__``, ``register``, ``get``, ``build``, ``query``, ``query_details``
+ Examples:
+ creating a registry:
+ >>> some_registry = Registry({"default": default_module})
+
+ There are two ways of registering new modules:
+ 1) The normal way is to call the register function directly:
+ >>> def foo():
+ >>> ...
+ some_registry.register("foo_module", foo)
+ 2) Used as a decorator when declaring the module:
+ >>> @some_registry.register("foo_module")
+ >>> @some_registry.register("foo_modeul_nickname")
+ >>> def foo():
+ >>> ...
+
+ Accessing a module is just like using a dictionary, e.g.:
+ >>> f = some_registry["foo_module"]
+ """
+
+ def __init__(self, *args, **kwargs) -> None:
+ """
+ Overview:
+ Initialize the Registry object.
+ Arguments:
+ - args (:obj:`Tuple`): The arguments passed to the ``__init__`` function of the parent class, \
+ dict.
+ - kwargs (:obj:`Dict`): The keyword arguments passed to the ``__init__`` function of the parent class, \
+ dict.
+ """
+
+ super(Registry, self).__init__(*args, **kwargs)
+ self.__trace__ = dict()
+
+ def register(
+ self,
+ module_name: Optional[str] = None,
+ module: Optional[Callable] = None,
+ force_overwrite: bool = False
+ ) -> Callable:
+ """
+ Overview:
+ Register the module.
+ Arguments:
+ - module_name (:obj:`Optional[str]`): The name of the module.
+ - module (:obj:`Optional[Callable]`): The module to be registered.
+ - force_overwrite (:obj:`bool`): Whether to overwrite the module with the same name.
+ """
+
+ if _DI_ENGINE_REG_TRACE_IS_ON:
+ frame = inspect.stack()[1][0]
+ info = inspect.getframeinfo(frame)
+ filename = info.filename
+ lineno = info.lineno
+ # used as function call
+ if module is not None:
+ assert module_name is not None
+ Registry._register_generic(self, module_name, module, force_overwrite)
+ if _DI_ENGINE_REG_TRACE_IS_ON:
+ self.__trace__[module_name] = (filename, lineno)
+ return
+
+ # used as decorator
+ def register_fn(fn: Callable) -> Callable:
+ if module_name is None:
+ name = fn.__name__
+ else:
+ name = module_name
+ Registry._register_generic(self, name, fn, force_overwrite)
+ if _DI_ENGINE_REG_TRACE_IS_ON:
+ self.__trace__[name] = (filename, lineno)
+ return fn
+
+ return register_fn
+
+ @staticmethod
+ def _register_generic(module_dict: dict, module_name: str, module: Callable, force_overwrite: bool = False) -> None:
+ """
+ Overview:
+ Register the module.
+ Arguments:
+ - module_dict (:obj:`dict`): The dict to store the module.
+ - module_name (:obj:`str`): The name of the module.
+ - module (:obj:`Callable`): The module to be registered.
+ - force_overwrite (:obj:`bool`): Whether to overwrite the module with the same name.
+ """
+
+ if not force_overwrite:
+ assert module_name not in module_dict, module_name
+ module_dict[module_name] = module
+
+ def get(self, module_name: str) -> Callable:
+ """
+ Overview:
+ Get the module.
+ Arguments:
+ - module_name (:obj:`str`): The name of the module.
+ """
+
+ return self[module_name]
+
+ def build(self, obj_type: str, *obj_args, **obj_kwargs) -> object:
+ """
+ Overview:
+ Build the object.
+ Arguments:
+ - obj_type (:obj:`str`): The type of the object.
+ - obj_args (:obj:`Tuple`): The arguments passed to the object.
+ - obj_kwargs (:obj:`Dict`): The keyword arguments passed to the object.
+ """
+
+ try:
+ build_fn = self[obj_type]
+ return build_fn(*obj_args, **obj_kwargs)
+ except Exception as e:
+ # get build_fn fail
+ if isinstance(e, KeyError):
+ raise KeyError("not support buildable-object type: {}".format(obj_type))
+ # build_fn execution fail
+ global _innest_error
+ if _innest_error:
+ argspec = inspect.getfullargspec(build_fn)
+ message = 'for {}(alias={})'.format(build_fn, obj_type)
+ message += '\nExpected args are:{}'.format(argspec)
+ message += '\nGiven args are:{}/{}'.format(obj_args, obj_kwargs.keys())
+ message += '\nGiven args details are:{}/{}'.format(obj_args, obj_kwargs)
+ _innest_error = False
+ raise e
+
+ def query(self) -> Iterable:
+ """
+ Overview:
+ Get all registered module names.
+ """
+
+ return self.keys()
+
+ def query_details(self, aliases: Optional[Iterable] = None) -> OrderedDict:
+ """
+ Overview:
+ Get the details of the registered modules.
+ Arguments:
+ - aliases (:obj:`Optional[Iterable]`): The aliases of the modules.
+ """
+
+ assert _DI_ENGINE_REG_TRACE_IS_ON, "please exec 'export DIENGINEREGTRACE=ON' first"
+ if aliases is None:
+ aliases = self.keys()
+ return OrderedDict((alias, self.__trace__[alias]) for alias in aliases)
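
A short, self-contained sketch of both registration styles described in the docstring; the registry and class names below are made up for illustration.

```python
from ding.utils.registry import Registry

MODEL_EXAMPLE_REGISTRY = Registry()

# 1) decorator style
@MODEL_EXAMPLE_REGISTRY.register('tiny_mlp')
class TinyMLP:

    def __init__(self, hidden_size: int = 32):
        self.hidden_size = hidden_size

# 2) plain function-call style
class TinyCNN:
    pass

MODEL_EXAMPLE_REGISTRY.register('tiny_cnn', TinyCNN)

model = MODEL_EXAMPLE_REGISTRY.build('tiny_mlp', hidden_size=64)
assert isinstance(model, TinyMLP) and model.hidden_size == 64
assert set(MODEL_EXAMPLE_REGISTRY.query()) >= {'tiny_mlp', 'tiny_cnn'}
```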
diff --git a/DI-engine/ding/utils/registry_factory.py b/DI-engine/ding/utils/registry_factory.py
new file mode 100644
index 0000000000000000000000000000000000000000..87d43fd627b8cfee1d2f5e41dedf7b0350c48d9f
--- /dev/null
+++ b/DI-engine/ding/utils/registry_factory.py
@@ -0,0 +1,45 @@
+from .registry import Registry
+
+POLICY_REGISTRY = Registry()
+ENV_REGISTRY = Registry()
+ENV_WRAPPER_REGISTRY = Registry()
+LEARNER_REGISTRY = Registry()
+COMM_LEARNER_REGISTRY = Registry()
+SERIAL_COLLECTOR_REGISTRY = Registry()
+PARALLEL_COLLECTOR_REGISTRY = Registry()
+COMM_COLLECTOR_REGISTRY = Registry()
+BUFFER_REGISTRY = Registry()
+COMMANDER_REGISTRY = Registry()
+LEAGUE_REGISTRY = Registry()
+PLAYER_REGISTRY = Registry()
+MODEL_REGISTRY = Registry()
+ENV_MANAGER_REGISTRY = Registry()
+REWARD_MODEL_REGISTRY = Registry()
+DATASET_REGISTRY = Registry()
+SERIAL_EVALUATOR_REGISTRY = Registry()
+MQ_REGISTRY = Registry()
+WORLD_MODEL_REGISTRY = Registry()
+STOCHASTIC_OPTIMIZER_REGISTRY = Registry()
+
+registries = {
+ 'policy': POLICY_REGISTRY,
+ 'env': ENV_REGISTRY,
+ 'env_wrapper': ENV_WRAPPER_REGISTRY,
+ 'model': MODEL_REGISTRY,
+ 'reward_model': REWARD_MODEL_REGISTRY,
+ 'learner': LEARNER_REGISTRY,
+ 'serial_collector': SERIAL_COLLECTOR_REGISTRY,
+ 'parallel_collector': PARALLEL_COLLECTOR_REGISTRY,
+ 'env_manager': ENV_MANAGER_REGISTRY,
+ 'comm_learner': COMM_LEARNER_REGISTRY,
+ 'comm_collector': COMM_COLLECTOR_REGISTRY,
+ 'commander': COMMANDER_REGISTRY,
+ 'league': LEAGUE_REGISTRY,
+ 'player': PLAYER_REGISTRY,
+ 'buffer': BUFFER_REGISTRY,
+ 'dataset': DATASET_REGISTRY,
+ 'serial_evaluator': SERIAL_EVALUATOR_REGISTRY,
+ 'message_queue': MQ_REGISTRY,
+ 'world_model': WORLD_MODEL_REGISTRY,
+ 'stochastic_optimizer': STOCHASTIC_OPTIMIZER_REGISTRY,
+}
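
For example, a custom policy can make itself discoverable by name through ``POLICY_REGISTRY``; the class below is purely illustrative (a real policy would subclass DI-engine's ``Policy`` and be referenced from a config by the registered name).

```python
from ding.utils.registry_factory import POLICY_REGISTRY

# Purely illustrative placeholder class and name.
@POLICY_REGISTRY.register('my_toy_policy')
class MyToyPolicy:

    def __init__(self, cfg=None):
        self.cfg = cfg

policy = POLICY_REGISTRY.build('my_toy_policy', cfg={'lr': 1e-3})
assert isinstance(policy, MyToyPolicy)
```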
diff --git a/DI-engine/ding/utils/render_helper.py b/DI-engine/ding/utils/render_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..11aed759412df0f44d89757be4b1ce7000a9a799
--- /dev/null
+++ b/DI-engine/ding/utils/render_helper.py
@@ -0,0 +1,76 @@
+from typing import TYPE_CHECKING, Optional
+from numpy import ndarray
+
+if TYPE_CHECKING:
+ from ding.envs import BaseEnv, BaseEnvManager
+
+
+def render_env(env, render_mode: Optional[str] = 'rgb_array') -> "ndarray":
+ """
+ Overview:
+ Render the environment's current frame.
+ Arguments:
+ - env (:obj:`gym.Env`): DI-engine env instance.
+ - render_mode (:obj:`str`): Render mode.
+ Returns:
+ - frame (:obj:`numpy.ndarray`): [H * W * C]
+ """
+ if hasattr(env, 'sim'):
+ # mujoco: the mujoco frame is upside-down by default
+ return env.sim.render(camera_name='track', height=128, width=128)[::-1]
+ else:
+ # other envs
+ return env.render(mode=render_mode)
+
+
+def render(env: "BaseEnv", render_mode: Optional[str] = 'rgb_array') -> "ndarray":
+ """
+ Overview:
+ Render the environment's current frame.
+ Arguments:
+ - env (:obj:`BaseEnv`): DI-engine env instance.
+ - render_mode (:obj:`str`): Render mode.
+ Returns:
+ - frame (:obj:`numpy.ndarray`): [H * W * C]
+ """
+ gym_env = env._env
+ return render_env(gym_env, render_mode=render_mode)
+
+
+def get_env_fps(env) -> "int":
+ """
+ Overview:
+ Get the environment's fps.
+ Arguments:
+ - env (:obj:`gym.Env`): DI-engine env instance.
+ Returns:
+ - fps (:obj:`int`).
+ """
+
+ if hasattr(env, 'model'):
+ # mujoco
+ fps = 1 / env.model.opt.timestep
+ elif hasattr(env, 'env') and 'video.frames_per_second' in env.env.metadata.keys():
+ # classic control
+ fps = env.env.metadata['video.frames_per_second']
+ else:
+ # atari and other envs
+ fps = 30
+ return fps
+
+
+def fps(env_manager: "BaseEnvManager") -> "int":
+ """
+ Overview:
+ Get the fps of the environments inside the env manager.
+ Arguments:
+ - env_manager (:obj:`BaseEnvManager`): DI-engine env manager instance.
+ Returns:
+ - fps (:obj:`int`).
+ """
+ try:
+ # env_ref is a ding gym environment
+ gym_env = env_manager.env_ref._env
+ return get_env_fps(gym_env)
+ except:
+ return 30
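
A hedged sketch of grabbing frames with these helpers, assuming the classic gym API (gym <= 0.25, with ``render(mode='rgb_array')`` and 4-tuple ``step`` returns) that DI-engine targets; the env id is a placeholder.

```python
import gym

from ding.utils.render_helper import render_env, get_env_fps

env = gym.make('CartPole-v1')  # placeholder env id
env.reset()
frames = []
for _ in range(50):
    frames.append(render_env(env))  # H x W x C rgb_array frame
    _, _, done, _ = env.step(env.action_space.sample())
    if done:
        env.reset()
print('frames:', len(frames), 'fps for saving:', get_env_fps(env))
```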
diff --git a/DI-engine/ding/utils/scheduler_helper.py b/DI-engine/ding/utils/scheduler_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..d37ce97c5297a8d28009f398dede892220d83967
--- /dev/null
+++ b/DI-engine/ding/utils/scheduler_helper.py
@@ -0,0 +1,177 @@
+from .default_helper import deep_merge_dicts
+from easydict import EasyDict
+
+
+class Scheduler(object):
+ """
+ Overview:
+ Update learning parameters when the trueskill metric has stopped improving.
+ For example, models often benefit from reducing the entropy weight once the learning process stagnates.
+ This scheduler reads a metric quantity and, if no improvement is seen for a 'patience' number of epochs,
+ the corresponding parameter is increased or decreased, depending on the 'schedule_mode'.
+ Arguments:
+ - schedule_flag (:obj:`bool`): Indicates whether to use scheduler in training pipeline.
+ Default: False
+ - schedule_mode (:obj:`str`): One of 'reduce', 'add', 'multi', 'div'. The schedule_mode
+ decides the way of updating the parameters. Default: 'reduce'.
+ - factor (:obj:`float`) : Amount (greater than 0) by which the parameter will be
+ increased/decreased. Default: 0.05
+ - change_range (:obj:`list`): Indicates the minimum and maximum value
+ the parameter can reach respectively. Default: [-1,1]
+ - threshold (:obj:`float`): Threshold for measuring the new optimum,
+ to only focus on significant changes. Default: 1e-4.
+ - optimize_mode (:obj:`str`): One of 'min', 'max', which indicates the sign of
+ optimization objective. Dynamic_threshold = last_metrics + threshold in `max`
+ mode or last_metrics - threshold in `min` mode. Default: 'min'
+ - patience (:obj:`int`): Number of epochs with no improvement after which
+ the parameter will be updated. For example, if `patience = 2`, then we
+ will ignore the first 2 epochs with no improvement, and will only update
+ the parameter after the 3rd epoch if the metrics still hasn't improved then.
+ Default: 10.
+ - cooldown (:obj:`int`): Number of epochs to wait before resuming
+ normal operation after the parameter has been updated. Default: 0.
+ Interfaces:
+ __init__, update_param, step
+ Property:
+ in_cooldown, is_better
+ """
+
+ config = dict(
+ schedule_flag=False,
+ schedule_mode='reduce',
+ factor=0.05,
+ change_range=[-1, 1],
+ threshold=1e-4,
+ optimize_mode='min',
+ patience=10,
+ cooldown=0,
+ )
+
+ def __init__(self, merged_scheduler_config: EasyDict) -> None:
+ """
+ Overview:
+ Initialize the scheduler.
+ Arguments:
+ - merged_scheduler_config (:obj:`EasyDict`): the scheduler config, which merges the user
+ config and the default config
+ """
+
+ schedule_mode = merged_scheduler_config.schedule_mode
+ factor = merged_scheduler_config.factor
+ change_range = merged_scheduler_config.change_range
+ threshold = merged_scheduler_config.threshold
+ optimize_mode = merged_scheduler_config.optimize_mode
+ patience = merged_scheduler_config.patience
+ cooldown = merged_scheduler_config.cooldown
+
+ assert schedule_mode in [
+ 'reduce', 'add', 'multi', 'div'
+ ], 'The schedule mode should be one of [\'reduce\', \'add\', \'multi\',\'div\']'
+ self.schedule_mode = schedule_mode
+
+ assert isinstance(factor, (float, int)), 'The factor should be a float/int number '
+ assert factor > 0, 'The factor should be greater than 0'
+ self.factor = float(factor)
+
+ assert isinstance(change_range,
+ list) and len(change_range) == 2, 'The change_range should be a list with 2 float numbers'
+ assert (isinstance(change_range[0], (float, int))) and (
+ isinstance(change_range[1], (float, int))
+ ), 'The change_range should be a list with 2 float/int numbers'
+ assert change_range[0] < change_range[1], 'The first num should be smaller than the second num'
+ self.change_range = change_range
+
+ assert isinstance(threshold, (float, int)), 'The threshold should be a float/int number'
+ self.threshold = threshold
+
+ assert optimize_mode in ['min', 'max'], 'The optimize_mode should be one of [\'min\', \'max\']'
+ self.optimize_mode = optimize_mode
+
+ assert isinstance(patience, int), 'The patience should be an integer greater than or equal to 0'
+ assert patience >= 0, 'The patience should be an integer greater than or equal to 0'
+ self.patience = patience
+
+ assert isinstance(cooldown, int), 'The cooldown should be an integer greater than or equal to 0'
+ assert cooldown >= 0, 'The cooldown should be an integer greater than or equal to 0'
+ self.cooldown = cooldown
+ self.cooldown_counter = cooldown
+
+ self.last_metrics = None
+ self.bad_epochs_num = 0
+
+ def step(self, metrics: float, param: float) -> float:
+ """
+ Overview:
+ Decides whether to update the scheduled parameter
+ Arguments:
+ - metrics (:obj:`float`): current input metrics
+ - param (:obj:`float`): parameter need to be updated
+ Returns:
+ - step_param (:obj:`float`): parameter after one step
+ """
+ assert isinstance(metrics, float), 'The metrics should be converted to a float number'
+ cur_metrics = metrics
+
+ if self.is_better(cur_metrics):
+ self.bad_epochs_num = 0
+ else:
+ self.bad_epochs_num += 1
+ self.last_metrics = cur_metrics
+
+ if self.in_cooldown:
+ self.cooldown_counter -= 1
+ self.bad_epochs_num = 0 # ignore any bad epochs in cooldown
+
+ if self.bad_epochs_num > self.patience:
+ param = self.update_param(param)
+ self.cooldown_counter = self.cooldown
+ self.bad_epochs_num = 0
+ return param
+
+ def update_param(self, param: float) -> float:
+ """
+ Overview:
+ update the scheduling parameter
+ Arguments:
+ - param (:obj:`float`): parameter need to be updated
+ Returns:
+ - updated param (:obj:`float`): parameter after updating
+ """
+ schedule_fn = {
+ 'reduce': lambda x, y, z: max(x - y, z[0]),
+ 'add': lambda x, y, z: min(x + y, z[1]),
+ 'multi': lambda x, y, z: min(x * y, z[1]) if y >= 1 else max(x * y, z[0]),
+ 'div': lambda x, y, z: max(x / y, z[0]) if y >= 1 else min(x / y, z[1]),
+ }
+
+ schedule_mode_list = list(schedule_fn.keys())
+
+ if self.schedule_mode in schedule_mode_list:
+ return schedule_fn[self.schedule_mode](param, self.factor, self.change_range)
+ else:
+ raise KeyError("invalid schedule_mode({}) in {}".format(self.schedule_mode, schedule_mode_list))
+
+ @property
+ def in_cooldown(self) -> bool:
+ """
+ Overview:
+ Checks whether the scheduler is in the cooldown period. If in cooldown, the scheduler
+ will ignore any bad epochs.
+ """
+ return self.cooldown_counter > 0
+
+ def is_better(self, cur: float) -> bool:
+ """
+ Overview:
+ Checks whether the current metric is better than the last metric with respect to the threshold.
+ Arguments:
+ - cur (:obj:`float`): current metrics
+ """
+ if self.last_metrics is None:
+ return True
+
+ elif self.optimize_mode == 'min':
+ return cur < self.last_metrics - self.threshold
+
+ elif self.optimize_mode == 'max':
+ return cur > self.last_metrics + self.threshold
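
A small sketch of driving the scheduler by hand with the default config; the metric sequence and the entropy-weight use case are synthetic illustrations.

```python
from easydict import EasyDict

from ding.utils.scheduler_helper import Scheduler

cfg = EasyDict(Scheduler.config)
cfg.schedule_flag = True
cfg.schedule_mode = 'reduce'     # decrease the parameter when progress stalls
cfg.factor = 0.01
cfg.change_range = [0.0, 0.05]
cfg.patience = 2

scheduler = Scheduler(cfg)
entropy_weight = 0.05
# Synthetic, non-improving metric: after `patience` bad epochs the weight drops.
for metric in [1.0, 1.0, 1.0, 1.0, 1.0]:
    entropy_weight = scheduler.step(float(metric), entropy_weight)
print(entropy_weight)  # ~0.04 after enough stagnating epochs
```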
diff --git a/DI-engine/ding/utils/segment_tree.py b/DI-engine/ding/utils/segment_tree.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c87280ab4393d73cad38f17a745829a1cd88f16
--- /dev/null
+++ b/DI-engine/ding/utils/segment_tree.py
@@ -0,0 +1,284 @@
+from functools import partial, lru_cache
+from typing import Callable, Optional
+
+import numpy as np
+
+import ding
+from .default_helper import one_time_warning
+
+
+@lru_cache()
+def njit():
+ """
+ Overview:
+ Decorator to compile a function using numba.
+ """
+
+ try:
+ if ding.enable_numba:
+ import numba
+ from numba import njit as _njit
+ version = numba.__version__
+ middle_version = version.split(".")[1]
+ if int(middle_version) < 53:
+ _njit = partial # noqa
+ one_time_warning(
+ "Due to your numba version <= 0.53.0, DI-engine disables it. And you can install \
+ numba==0.53.0 if you want to speed up something"
+ )
+ else:
+ _njit = partial
+ except ImportError:
+ one_time_warning("If you want to use numba to speed up segment tree, please install numba first")
+ _njit = partial
+ return _njit
+
+
+class SegmentTree:
+ """
+ Overview:
+ Segment tree data structure, implemented by the tree-like array. Only the leaf nodes are real value,
+ non-leaf nodes are to do some operations on its left and right child.
+ Interfaces:
+ ``__init__``, ``reduce``, ``__setitem__``, ``__getitem__``
+ """
+
+ def __init__(self, capacity: int, operation: Callable, neutral_element: Optional[float] = None) -> None:
+ """
+ Overview:
+ Initialize the segment tree. Tree's root node is at index 1.
+ Arguments:
+ - capacity (:obj:`int`): Capacity of the tree (the number of the leaf nodes), should be a power of 2.
+ - operation (:obj:`function`): The operation function to construct the tree, e.g. sum, max, min, etc.
+ - neutral_element (:obj:`float` or :obj:`None`): The value of the neutral element, which is used to init \
+ all nodes value in the tree.
+ """
+ assert capacity > 0 and capacity & (capacity - 1) == 0
+ self.capacity = capacity
+ self.operation = operation
+ # Set neutral value(initial value) for all elements.
+ if neutral_element is None:
+ if operation == 'sum':
+ neutral_element = 0.
+ elif operation == 'min':
+ neutral_element = np.inf
+ elif operation == 'max':
+ neutral_element = -np.inf
+ else:
+ raise ValueError("operation argument should be in min, max, sum (built in python functions).")
+ self.neutral_element = neutral_element
+ # Index 1 is the root; Index ranging in [capacity, 2 * capacity - 1] are the leaf nodes.
+ # For each parent node with index i, left child is value[2*i] and right child is value[2*i+1].
+ self.value = np.full([capacity * 2], neutral_element)
+ self._compile()
+
+ def reduce(self, start: int = 0, end: Optional[int] = None) -> float:
+ """
+ Overview:
+ Reduce the tree in range ``[start, end)``
+ Arguments:
+ - start (:obj:`int`): Start index(relative index, the first leaf node is 0), default set to 0
+ - end (:obj:`int` or :obj:`None`): End index(relative index), default set to ``self.capacity``
+ Returns:
+ - reduce_result (:obj:`float`): The reduce result value, which is dependent on data type and operation
+ """
+ # TODO(nyz) check if directly reduce from the array(value) can be faster
+ if end is None:
+ end = self.capacity
+ assert (start < end)
+ # Change to absolute leaf index by adding capacity.
+ start += self.capacity
+ end += self.capacity
+ return _reduce(self.value, start, end, self.neutral_element, self.operation)
+
+ def __setitem__(self, idx: int, val: float) -> None:
+ """
+ Overview:
+ Set ``leaf[idx] = val``; Then update the related nodes.
+ Arguments:
+ - idx (:obj:`int`): Leaf node index(relative index), should add ``capacity`` to change to absolute index.
+ - val (:obj:`float`): The value that will be assigned to ``leaf[idx]``.
+ """
+ assert (0 <= idx < self.capacity), idx
+ # ``idx`` should add ``capacity`` to change to absolute index.
+ _setitem(self.value, idx + self.capacity, val, self.operation)
+
+ def __getitem__(self, idx: int) -> float:
+ """
+ Overview:
+ Get ``leaf[idx]``
+ Arguments:
+ - idx (:obj:`int`): Leaf node ``index(relative index)``, add ``capacity`` to change to absolute index.
+ Returns:
+ - val (:obj:`float`): The value of ``leaf[idx]``
+ """
+ assert (0 <= idx < self.capacity)
+ return self.value[idx + self.capacity]
+
+ def _compile(self) -> None:
+ """
+ Overview:
+ Compile the functions using numba.
+ """
+
+ f64 = np.array([0, 1], dtype=np.float64)
+ f32 = np.array([0, 1], dtype=np.float32)
+ i64 = np.array([0, 1], dtype=np.int64)
+ for d in [f64, f32, i64]:
+ _setitem(d, 0, 3.0, 'sum')
+ _reduce(d, 0, 1, 0.0, 'min')
+ _find_prefixsum_idx(d, 1, 0.5, 0.0)
+
+
+class SumSegmentTree(SegmentTree):
+ """
+ Overview:
+ Sum segment tree, which is inherited from ``SegmentTree``. Init by passing ``operation='sum'``.
+ Interfaces:
+ ``__init__``, ``find_prefixsum_idx``
+ """
+
+ def __init__(self, capacity: int) -> None:
+ """
+ Overview:
+ Init sum segment tree by passing ``operation='sum'``
+ Arguments:
+ - capacity (:obj:`int`): Capacity of the tree (the number of the leaf nodes).
+ """
+ super(SumSegmentTree, self).__init__(capacity, operation='sum')
+
+ def find_prefixsum_idx(self, prefixsum: float, trust_caller: bool = True) -> int:
+ """
+ Overview:
+ Find the highest non-zero index i, sum_{j}leaf[j] <= ``prefixsum`` (where 0 <= j < i)
+ and sum_{j}leaf[j] > ``prefixsum`` (where 0 <= j < i+1)
+ Arguments:
+ - prefixsum (:obj:`float`): The target prefixsum.
+ - trust_caller (:obj:`bool`): Whether to trust the caller and skip checking (via ``reduce``) \
+ that this tree's sum is not smaller than the input ``prefixsum``.
+ Default set to True.
+ Returns:
+ - idx (:obj:`int`): Eligible index.
+ """
+ if not trust_caller:
+ assert 0 <= prefixsum <= self.reduce() + 1e-5, prefixsum
+ return _find_prefixsum_idx(self.value, self.capacity, prefixsum, self.neutral_element)
+
+
+class MinSegmentTree(SegmentTree):
+ """
+ Overview:
+ Min segment tree, which is inherited from ``SegmentTree``. Init by passing ``operation='min'``.
+ Interfaces:
+ ``__init__``
+ """
+
+ def __init__(self, capacity: int) -> None:
+ """
+ Overview:
+ Initialize min segment tree by passing ``operation='min'``
+ Arguments:
+ - capacity (:obj:`int`): Capacity of the tree (the number of the leaf nodes).
+ """
+ super(MinSegmentTree, self).__init__(capacity, operation='min')
+
+
+@njit()
+def _setitem(tree: np.ndarray, idx: int, val: float, operation: str) -> None:
+ """
+ Overview:
+ Set ``tree[idx] = val``; Then update the related nodes.
+ Arguments:
+ - tree (:obj:`np.ndarray`): The tree array.
+ - idx (:obj:`int`): The index of the leaf node.
+ - val (:obj:`float`): The value that will be assigned to ``leaf[idx]``.
+ - operation (:obj:`str`): The operation function to construct the tree, e.g. sum, max, min, etc.
+ """
+
+ tree[idx] = val
+ # Update from specified node to the root node
+ while idx > 1:
+ idx = idx >> 1 # To parent node idx
+ left, right = tree[2 * idx], tree[2 * idx + 1]
+ if operation == 'sum':
+ tree[idx] = left + right
+ elif operation == 'min':
+ tree[idx] = min([left, right])
+
+
+@njit()
+def _reduce(tree: np.ndarray, start: int, end: int, neutral_element: float, operation: str) -> float:
+ """
+ Overview:
+ Reduce the tree in range ``[start, end)``
+ Arguments:
+ - tree (:obj:`np.ndarray`): The tree array.
+ - start (:obj:`int`): Start index(relative index, the first leaf node is 0).
+ - end (:obj:`int`): End index(relative index).
+ - neutral_element (:obj:`float`): The value of the neutral element, which is used to init \
+ all nodes value in the tree.
+ - operation (:obj:`str`): The operation function to construct the tree, e.g. sum, max, min, etc.
+ """
+
+ # Nodes in [start, end) will be aggregated
+ result = neutral_element
+ while start < end:
+ if start & 1:
+ # If current start node (tree[start]) is a right child node, operate on start node and increase start by 1
+ if operation == 'sum':
+ result = result + tree[start]
+ elif operation == 'min':
+ result = min([result, tree[start]])
+ start += 1
+ if end & 1:
+ # If current end node (tree[end - 1]) is right child node, decrease end by 1 and operate on end node
+ end -= 1
+ if operation == 'sum':
+ result = result + tree[end]
+ elif operation == 'min':
+ result = min([result, tree[end]])
+ # Both start and end transform to respective parent node
+ start = start >> 1
+ end = end >> 1
+ return result
+
+
+@njit()
+def _find_prefixsum_idx(tree: np.ndarray, capacity: int, prefixsum: float, neutral_element: float) -> int:
+ """
+ Overview:
+ Find the highest non-zero index i, sum_{j}leaf[j] <= ``prefixsum`` (where 0 <= j < i)
+ and sum_{j}leaf[j] > ``prefixsum`` (where 0 <= j < i+1)
+ Arguments:
+ - tree (:obj:`np.ndarray`): The tree array.
+ - capacity (:obj:`int`): Capacity of the tree (the number of the leaf nodes).
+ - prefixsum (:obj:`float`): The target prefixsum.
+ - neutral_element (:obj:`float`): The value of the neutral element, which is used to init \
+ all nodes value in the tree.
+ """
+
+ # The function is to find a non-leaf node's index which satisfies:
+ # self.value[idx] > input prefixsum and self.value[idx + 1] <= input prefixsum
+ # In other words, we can assume that there are intervals: [num_0, num_1), [num_1, num_2), ... [num_k, num_k+1),
+ # the function is to find input prefixsum falls in which interval and return the interval's index.
+ idx = 1 # start from root node
+ while idx < capacity:
+ child_base = 2 * idx
+ if tree[child_base] > prefixsum:
+ idx = child_base
+ else:
+ prefixsum -= tree[child_base]
+ idx = child_base + 1
+ # Special case: The last element of ``self.value`` is neutral_element(0),
+ # and caller wants to ``find_prefixsum_idx(root_value)``.
+ # However, input prefixsum should be smaller than root_value.
+ if idx == 2 * capacity - 1 and tree[idx] == neutral_element:
+ tmp = idx
+ while tmp >= capacity and tree[tmp] == neutral_element:
+ tmp -= 1
+ if tmp != capacity:
+ idx = tmp
+ else:
+ raise ValueError("All elements in tree are the neutral_element(0), can't find non-zero element")
+ assert (tree[idx] != neutral_element)
+ return idx - capacity
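
The sketch below shows the prioritized-sampling pattern these trees are built for: priorities go into a ``SumSegmentTree`` and ``find_prefixsum_idx`` maps a uniform random mass to a leaf index. The capacity and priority values are arbitrary.

```python
import numpy as np

from ding.utils.segment_tree import MinSegmentTree, SumSegmentTree

capacity = 8                          # must be a power of 2
sum_tree = SumSegmentTree(capacity)
min_tree = MinSegmentTree(capacity)

priorities = [0.1, 0.5, 0.2, 1.0, 0.05, 0.7, 0.3, 0.15]
for i, p in enumerate(priorities):
    sum_tree[i] = p
    min_tree[i] = p

total = sum_tree.reduce()             # sum of all priorities
print('min priority:', min_tree.reduce())

# Sample indices proportionally to priority.
rng = np.random.default_rng(0)
mass = rng.uniform(0, total, size=4)
indices = [sum_tree.find_prefixsum_idx(float(m)) for m in mass]
print(indices)
```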
diff --git a/DI-engine/ding/utils/slurm_helper.py b/DI-engine/ding/utils/slurm_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..c03b3e94631dafad7368d528146f0dc473f1dc2a
--- /dev/null
+++ b/DI-engine/ding/utils/slurm_helper.py
@@ -0,0 +1,123 @@
+import os
+import subprocess
+from typing import Optional, Dict, Tuple
+
+MANAGER_NODE_TABLE = {
+ '10.198.8': '10.198.8.31',
+ '10.198.6': '10.198.6.31',
+ '10.5.38': '10.5.38.31',
+ '10.5.39': '10.5.38.31',
+ '10.5.36': '10.5.36.31',
+ '10.5.37': '10.5.36.31',
+ '10.10.30': '10.10.30.91',
+}
+
+
+def get_ip() -> str:
+ """
+ Overview:
+ Get the ip of the current node
+ """
+
+ assert os.environ.get('SLURMD_NODENAME'), 'not found SLURMD_NODENAME env variable'
+ # expecting nodename to be like: 'SH-IDC1-10-5-36-64'
+ nodename = os.environ.get('SLURMD_NODENAME', '')
+ myaddr = '.'.join(nodename.split('-')[-4:])
+ return myaddr
+
+
+def get_manager_node_ip(node_ip: Optional[str] = None) -> str:
+ """
+ Overview:
+ Look up the manager node of the slurm cluster and return the node ip
+ Arguments:
+ - node_ip (:obj:`Optional[str]`): The ip of the current node
+ """
+ if 'SLURM_JOB_ID' not in os.environ:
+ from ditk import logging
+ logging.error(
+ 'We are not running on slurm! \'auto\' for manager_ip or '
+ 'coordinator_ip is only intended for running on multiple slurm clusters'
+ )
+ return '127.0.0.1'
+ node_ip = node_ip or get_ip()
+ learner_manager_ip_prefix = '.'.join(node_ip.split('.')[0:3])
+
+ if learner_manager_ip_prefix in MANAGER_NODE_TABLE:
+ return MANAGER_NODE_TABLE[learner_manager_ip_prefix]
+ else:
+ raise KeyError("Cluster not found, please add it to the MANAGER_NODE_TABLE in {}".format(__file__))
+
+
+# get all info of cluster
+def get_cls_info() -> Dict[str, list]:
+ """
+ Overview:
+ Get the cluster info
+ """
+
+ ret_dict = {}
+ info = subprocess.getoutput('sinfo -Nh').split('\n')
+ for line in info:
+ line = line.strip().split()
+ if len(line) != 4:
+ continue
+ node, _, partition, state = line
+ if partition not in ret_dict:
+ ret_dict[partition] = []
+ assert node not in ret_dict[partition]
+ if state in ['idle', 'mix']:
+ ret_dict[partition].append(node)
+
+ return ret_dict
+
+
+def node_to_partition(target_node: str) -> str:
+ """
+ Overview:
+ Get the partition of the target node
+ Arguments:
+ - target_node (:obj:`str`): The target node
+ """
+
+ info = subprocess.getoutput('sinfo -Nh').split('\n')
+ for line in info:
+ line = line.strip().split()
+ if len(line) != 4:
+ continue
+ node, _, partition, state = line
+ if node == target_node:
+ return partition
+ raise RuntimeError("not found target_node: {}".format(target_node))
+
+
+def node_to_host(node: str) -> str:
+ """
+ Overview:
+ Get the host of the node
+ Arguments:
+ - node (:obj:`str`): The node
+ """
+
+ return '.'.join(node.split('-')[-4:])
+
+
+def find_free_port_slurm(node: str) -> int:
+ """
+ Overview:
+ Find a free port on the node
+ Arguments:
+ - node (:obj:`str`): The node
+ """
+
+ partition = node_to_partition(node)
+ if partition == 'spring_scheduler':
+ comment = '--comment=spring-submit'
+ else:
+ comment = ''
+ output = subprocess.getoutput(
+ "srun -p {} -w {} {} python -c \"from ding.utils import find_free_port; print('port' + str(find_free_port(0)))\"" # noqa
+ .format(partition, node, comment)
+ )
+ port = output.split('port')[-1]
+ return int(port)
diff --git a/DI-engine/ding/utils/system_helper.py b/DI-engine/ding/utils/system_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..915ef380e9134fc3e51acea4b1aa2fcc0bed0dca
--- /dev/null
+++ b/DI-engine/ding/utils/system_helper.py
@@ -0,0 +1,87 @@
+import os
+import socket
+import time
+import uuid
+from contextlib import closing
+from threading import Thread
+from typing import Any
+
+
+def get_ip() -> str:
+ """
+ Overview:
+ Get the ip (host) of the current machine via socket
+ Returns:
+ - ip(:obj:`str`): The corresponding ip
+ """
+ # beware: may return 127.0.0.1 on some slurm nodes
+ myname = socket.getfqdn(socket.gethostname())
+ myaddr = socket.gethostbyname(myname)
+
+ return myaddr
+
+
+def get_pid() -> int:
+ """
+ Overview:
+ Get the current process id via ``os.getpid``
+ """
+ return os.getpid()
+
+
+def get_task_uid() -> str:
+ """
+ Overview:
+ Get a unique task uid, composed of a random ``uuid4`` and the tail of the current timestamp
+ """
+ return '{}_{}'.format(str(uuid.uuid4()), str(time.time())[-6:])
+
+
+class PropagatingThread(Thread):
+ """
+ Overview:
+ Subclass of Thread that propagates an exception raised in the thread to the caller of ``join``
+ Interfaces:
+ ``run``, ``join``
+ Examples:
+ >>> def func():
+ >>> raise Exception()
+ >>> t = PropagatingThread(target=func, args=())
+ >>> t.start()
+ >>> t.join()
+ """
+
+ def run(self) -> None:
+ """
+ Overview:
+ Run the thread
+ """
+
+ self.exc = None
+ try:
+ self.ret = self._target(*self._args, **self._kwargs)
+ except Exception as e:
+ self.exc = e
+
+ def join(self) -> Any:
+ """
+ Overview:
+ Join the thread
+ """
+
+ super(PropagatingThread, self).join()
+ if self.exc:
+ raise RuntimeError('Exception in thread({})'.format(id(self))) from self.exc
+ return self.ret
+
+
+def find_free_port(host: str) -> int:
+ """
+ Overview:
+ Ask the OS for a free port and return it
+ Arguments:
+ - host (:obj:`str`): The host (currently not used; the socket binds to all interfaces)
+ """
+ with closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
+ s.bind(('', 0))
+ return s.getsockname()[1]
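+
+
+# A minimal usage sketch of the helpers above (``find_free_port`` binds to all interfaces,
+# so the ``host`` argument only serves as an interface placeholder here):
+#
+#     port = find_free_port('')                      # ask the OS for an unused port
+#     t = PropagatingThread(target=lambda: port + 1)
+#     t.start()
+#     assert t.join() == port + 1                    # ``join`` returns the result or re-raises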
diff --git a/DI-engine/ding/utils/tests/config/k8s-config.yaml b/DI-engine/ding/utils/tests/config/k8s-config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..a9cc028a89627f9d2479e3c98ca69176b4115cde
--- /dev/null
+++ b/DI-engine/ding/utils/tests/config/k8s-config.yaml
@@ -0,0 +1,7 @@
+type: k3s # k3s or local
+name: di-cluster
+servers: 1 # # of k8s masters
+agents: 0 # # of k8s nodes
+preload_images:
+- busybox:latest
+- hello-world:latest
diff --git a/DI-engine/ding/utils/tests/test_bfs_helper.py b/DI-engine/ding/utils/tests/test_bfs_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f095907a08245c3af17966edf6db42436295643
--- /dev/null
+++ b/DI-engine/ding/utils/tests/test_bfs_helper.py
@@ -0,0 +1,26 @@
+import easydict
+import numpy
+import pytest
+
+from ding.utils import get_vi_sequence
+from dizoo.maze.envs.maze_env import Maze
+
+
+@pytest.mark.unittest
+class TestBFSHelper:
+
+ def test_bfs(self):
+
+ def load_env(seed):
+ ccc = easydict.EasyDict({'size': 16})
+ e = Maze(ccc)
+ e.seed(seed)
+ e.reset()
+ return e
+
+ env = load_env(314)
+ start_obs = env.process_states(env._get_obs(), env.get_maze_map())
+ vi_sequence, track_back = get_vi_sequence(env, start_obs)
+ assert vi_sequence.shape[1:] == (16, 16)
+ assert track_back[0][0].shape == (16, 16, 3)
+ assert isinstance(track_back[0][1], numpy.int32)
diff --git a/DI-engine/ding/utils/tests/test_collection_helper.py b/DI-engine/ding/utils/tests/test_collection_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..cb618eebb9a672614082cf421ad9557ed9539273
--- /dev/null
+++ b/DI-engine/ding/utils/tests/test_collection_helper.py
@@ -0,0 +1,13 @@
+import pytest
+
+from ding.utils.collection_helper import iter_mapping
+
+
+@pytest.mark.unittest
+class TestCollectionHelper:
+
+ def test_iter_mapping(self):
+ _iter = iter_mapping([1, 2, 3, 4, 5], lambda x: x ** 2)
+
+ assert not isinstance(_iter, list)
+ assert list(_iter) == [1, 4, 9, 16, 25]
diff --git a/DI-engine/ding/utils/tests/test_compression_helper.py b/DI-engine/ding/utils/tests/test_compression_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..494e86914bb0c3064532fbec0c0678c9b8c8f78b
--- /dev/null
+++ b/DI-engine/ding/utils/tests/test_compression_helper.py
@@ -0,0 +1,28 @@
+import random
+import numpy as np
+from ding.utils.compression_helper import get_data_compressor, get_data_decompressor
+
+import pytest
+
+
+@pytest.mark.unittest
+class TestCompression():
+
+ def get_step_data(self):
+ return {'input': [random.randint(10, 100) for i in range(100)]}
+
+ def testnaive(self):
+ compress_names = ['lz4', 'zlib', 'none']
+ for s in compress_names:
+ compressor = get_data_compressor(s)
+ decompressor = get_data_decompressor(s)
+ data = self.get_step_data()
+ assert data == decompressor(compressor(data))
+
+ def test_arr_to_st(self):
+ data = np.random.randint(0, 255, (96, 96, 3), dtype=np.uint8)
+ compress_names = ['jpeg']
+ for s in compress_names:
+ compressor = get_data_compressor(s)
+ decompressor = get_data_decompressor(s)
+ assert data.shape == decompressor(compressor(data)).shape
diff --git a/DI-engine/ding/utils/tests/test_config_helper.py b/DI-engine/ding/utils/tests/test_config_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe33fcaad40a80d9cf6e717e88a46411c2b182c8
--- /dev/null
+++ b/DI-engine/ding/utils/tests/test_config_helper.py
@@ -0,0 +1,57 @@
+import pytest
+import os
+import copy
+from easydict import EasyDict
+
+from ding.config import read_config_directly, save_config
+from ding.utils.default_helper import deep_merge_dicts, flatten_dict, deep_update
+
+
+@pytest.mark.unittest
+class TestConfigHelper():
+
+ def test_flatten_dict(self):
+ dict1 = {'a': {'aa': {'aaa': 'data - aaa'}, 'ab': 'data - ab'}}
+ dict2 = {'a/ab': 'data - ab', 'a/aa/aaa': 'data - aaa'}
+ assert flatten_dict(dict1) == dict2
+
+ def test_deep_merge_dicts(self):
+ dict1 = {'a': {'aa': 'aa1', 'ab': 'ab2'}, 'b': {'bb': 'bb2'}}
+ dict2 = {'a': {'aa': 'aa2', 'ac': 'ab1'}, 'b': {'ba': 'ba2'}, 'c': {}}
+ merged = {'a': {'aa': 'aa2', 'ab': 'ab2', 'ac': 'ab1'}, 'b': {'bb': 'bb2', 'ba': 'ba2'}, 'c': {}}
+ assert deep_merge_dicts(dict1, dict2) == merged
+ with pytest.raises(RuntimeError):
+ deep_update(dict1, dict2, new_keys_allowed=False)
+
+ def test_config(self):
+ import tempfile
+
+ # Test whether save and read is reversible.
+ old_config = EasyDict(
+ {
+ "aa": 1,
+ "bb": 0.0001,
+ "cc": None,
+ "dd": "string",
+ "ee": ["11", "22"],
+ "ff": {
+ "correct": 11
+ }
+ }
+ )
+ cfg_path = tempfile.mktemp(suffix=".py")
+ save_config(old_config, cfg_path)
+ assert os.path.exists(cfg_path)
+ config = read_config_directly(cfg_path)["exp_config"]
+
+ def assert_equal(item1, item2):
+ if isinstance(item1, list):
+ for item11, item22 in zip(item1, item2):
+ assert_equal(item11, item22)
+ elif isinstance(item1, dict):
+ for item11, item22 in zip(item1.values(), item2.values()):
+ assert_equal(item11, item22)
+ else:
+ assert item1 == item2
+
+ assert_equal(config, old_config)
diff --git a/DI-engine/ding/utils/tests/test_default_helper.py b/DI-engine/ding/utils/tests/test_default_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..c48b1d05aefd11bbd5cb4c528161f73fb0aa8efb
--- /dev/null
+++ b/DI-engine/ding/utils/tests/test_default_helper.py
@@ -0,0 +1,302 @@
+from collections import namedtuple
+
+import numpy as np
+import pytest
+import torch
+import treetensor.torch as ttorch
+
+from ding.utils.default_helper import lists_to_dicts, dicts_to_lists, squeeze, default_get, override, error_wrapper, \
+ list_split, LimitedSpaceContainer, set_pkg_seed, deep_merge_dicts, deep_update, flatten_dict, RunningMeanStd, \
+ one_time_warning, split_data_generator, get_shape0
+
+
+@pytest.mark.unittest
+class TestDefaultHelper():
+
+ def test_get_shape0(self):
+ a = {
+ 'a': {
+ 'b': torch.randn(4, 3)
+ },
+ 'c': {
+ 'd': torch.randn(4)
+ },
+ }
+ b = [a, a]
+ c = (a, a)
+ d = {
+ 'a': {
+ 'b': ["a", "b", "c", "d"]
+ },
+ 'c': {
+ 'd': torch.randn(4)
+ },
+ }
+ a = ttorch.as_tensor(a)
+ assert get_shape0(a) == 4
+ assert get_shape0(b) == 4
+ assert get_shape0(c) == 4
+ with pytest.raises(Exception) as e_info:
+ assert get_shape0(d) == 4
+
+ def test_lists_to_dicts(self):
+ set_pkg_seed(12)
+ with pytest.raises(ValueError):
+ lists_to_dicts([])
+ with pytest.raises(TypeError):
+ lists_to_dicts([1])
+ assert lists_to_dicts([{1: 1, 10: 3}, {1: 2, 10: 4}]) == {1: [1, 2], 10: [3, 4]}
+ T = namedtuple('T', ['location', 'race'])
+ data = [T({'x': 1, 'y': 2}, 'zerg') for _ in range(3)]
+ output = lists_to_dicts(data)
+ assert isinstance(output, T) and output.__class__ == T
+ assert len(output.location) == 3
+ data = [{'value': torch.randn(1), 'obs': {'scalar': torch.randn(4)}} for _ in range(3)]
+ output = lists_to_dicts(data, recursive=True)
+ assert isinstance(output, dict)
+ assert len(output['value']) == 3
+ assert len(output['obs']['scalar']) == 3
+
+ def test_dicts_to_lists(self):
+ assert dicts_to_lists({1: [1, 2], 10: [3, 4]}) == [{1: 1, 10: 3}, {1: 2, 10: 4}]
+
+ def test_squeeze(self):
+ assert squeeze((4, )) == 4
+ assert squeeze({'a': 4}) == 4
+ assert squeeze([1, 3]) == (1, 3)
+ data = np.random.randn(3)
+ output = squeeze(data)
+ assert (output == data).all()
+
+ def test_default_get(self):
+ assert default_get({}, 'a', default_value=1, judge_fn=lambda x: x < 2) == 1
+ assert default_get({}, 'a', default_fn=lambda: 1, judge_fn=lambda x: x < 2) == 1
+ with pytest.raises(AssertionError):
+ default_get({}, 'a', default_fn=lambda: 1, judge_fn=lambda x: x < 0)
+ assert default_get({'val': 1}, 'val', default_value=2) == 1
+
+ def test_override(self):
+
+ class foo(object):
+
+ def fun(self):
+ raise NotImplementedError
+
+ class foo1(foo):
+
+ @override(foo)
+ def fun(self):
+ return "a"
+
+ with pytest.raises(NameError):
+
+ class foo2(foo):
+
+ @override(foo)
+ def func(self):
+ pass
+
+ with pytest.raises(NotImplementedError):
+ foo().fun()
+ foo1().fun()
+
+ def test_error_wrapper(self):
+
+ def good_ret(a, b=1):
+ return a + b
+
+ wrap_good_ret = error_wrapper(good_ret, 0)
+ assert good_ret(1) == wrap_good_ret(1)
+
+ def bad_ret(a, b=0):
+ return a / b
+
+ wrap_bad_ret = error_wrapper(bad_ret, 0)
+ assert wrap_bad_ret(1) == 0
+ wrap_bad_ret_with_customized_log = error_wrapper(bad_ret, 0, 'customized_information')
+
+ def test_list_split(self):
+ data = [i for i in range(10)]
+ output, residual = list_split(data, step=4)
+ assert len(output) == 2
+ assert output[1] == [4, 5, 6, 7]
+ assert residual == [8, 9]
+ output, residual = list_split(data, step=5)
+ assert len(output) == 2
+ assert output[1] == [5, 6, 7, 8, 9]
+ assert residual is None
+
+
+@pytest.mark.unittest
+class TestLimitedSpaceContainer():
+
+ def test_container(self):
+ container = LimitedSpaceContainer(0, 5)
+ first = container.acquire_space()
+ assert first
+ assert container.cur == 1
+ left = container.get_residual_space()
+ assert left == 4
+ assert container.cur == container.max_val == 5
+ no_space = container.acquire_space()
+ assert not no_space
+ container.increase_space()
+ six = container.acquire_space()
+ assert six
+ for i in range(6):
+ container.release_space()
+ assert container.cur == 5 - i
+ container.decrease_space()
+ assert container.max_val == 5
+
+
+@pytest.mark.unittest
+class TestDict:
+
+ def test_deep_merge_dicts(self):
+ dict1 = {
+ 'a': 3,
+ 'b': {
+ 'c': 3,
+ 'd': {
+ 'e': 6,
+ 'f': 5,
+ }
+ }
+ }
+ dict2 = {
+ 'b': {
+ 'c': 5,
+ 'd': 6,
+ 'g': 4,
+ }
+ }
+ new_dict = deep_merge_dicts(dict1, dict2)
+ assert new_dict['a'] == 3
+ assert isinstance(new_dict['b'], dict)
+ assert new_dict['b']['c'] == 5
+ assert new_dict['b']['c'] == 5
+ assert new_dict['b']['g'] == 4
+
+ def test_deep_update(self):
+ dict1 = {
+ 'a': 3,
+ 'b': {
+ 'c': 3,
+ 'd': {
+ 'e': 6,
+ 'f': 5,
+ },
+ 'z': 4,
+ }
+ }
+ dict2 = {
+ 'b': {
+ 'c': 5,
+ 'd': 6,
+ 'g': 4,
+ }
+ }
+ with pytest.raises(RuntimeError):
+ new1 = deep_update(dict1, dict2, new_keys_allowed=False)
+ new2 = deep_update(dict1, dict2, new_keys_allowed=False, whitelist=['b'])
+ assert new2['a'] == 3
+ assert new2['b']['c'] == 5
+ assert new2['b']['d'] == 6
+ assert new2['b']['g'] == 4
+ assert new2['b']['z'] == 4
+
+ dict1 = {
+ 'a': 3,
+ 'b': {
+ 'type': 'old',
+ 'z': 4,
+ }
+ }
+ dict2 = {
+ 'b': {
+ 'type': 'new',
+ 'c': 5,
+ }
+ }
+ new3 = deep_update(dict1, dict2, new_keys_allowed=True, whitelist=[], override_all_if_type_changes=['b'])
+ assert new3['a'] == 3
+ assert new3['b']['type'] == 'new'
+ assert new3['b']['c'] == 5
+ assert 'z' not in new3['b']
+
+ def test_flatten_dict(self):
+ dict = {
+ 'a': 3,
+ 'b': {
+ 'c': 3,
+ 'd': {
+ 'e': 6,
+ 'f': 5,
+ },
+ 'z': 4,
+ }
+ }
+ flat = flatten_dict(dict)
+ assert flat['a'] == 3
+ assert flat['b/c'] == 3
+ assert flat['b/d/e'] == 6
+ assert flat['b/d/f'] == 5
+ assert flat['b/z'] == 4
+
+ def test_one_time_warning(self):
+ one_time_warning('test_one_time_warning')
+
+ def test_running_mean_std(self):
+ running = RunningMeanStd()
+ running.reset()
+ running.update(np.arange(1, 10))
+ assert running.mean == pytest.approx(5, abs=1e-4)
+ assert running.std == pytest.approx(2.582030, abs=1e-6)
+ running.update(np.arange(2, 11))
+ assert running.mean == pytest.approx(5.5, abs=1e-4)
+ assert running.std == pytest.approx(2.629981, abs=1e-6)
+ running.reset()
+ running.update(np.arange(1, 10))
+ assert pytest.approx(running.mean, abs=1e-4) == 5
+ assert running.mean == pytest.approx(5, abs=1e-4)
+ assert running.std == pytest.approx(2.582030, abs=1e-6)
+ new_shape = running.new_shape((2, 4), (3, ), (1, ))
+ assert isinstance(new_shape, tuple) and len(new_shape) == 3
+
+ running = RunningMeanStd(shape=(4, ))
+ running.reset()
+ running.update(np.random.random((10, 4)))
+ assert isinstance(running.mean, torch.Tensor) and running.mean.shape == (4, )
+ assert isinstance(running.std, torch.Tensor) and running.std.shape == (4, )
+
+ def test_split_data_generator(self):
+
+ def get_data():
+ return {
+ 'obs': torch.randn(5),
+ 'action': torch.randint(0, 10, size=(1, )),
+ 'prev_state': [None, None],
+ 'info': {
+ 'other_obs': torch.randn(5)
+ },
+ }
+
+ data = [get_data() for _ in range(4)]
+ data = lists_to_dicts(data)
+ data['obs'] = torch.stack(data['obs'])
+ data['action'] = torch.stack(data['action'])
+ data['info'] = {'other_obs': torch.stack([t['other_obs'] for t in data['info']])}
+ assert len(data['obs']) == 4
+ data['NoneKey'] = None
+ generator = split_data_generator(data, 3)
+ generator_result = list(generator)
+ assert len(generator_result) == 2
+ assert generator_result[0]['NoneKey'] is None
+ assert len(generator_result[0]['obs']) == 3
+ assert generator_result[0]['info']['other_obs'].shape == (3, 5)
+ assert generator_result[1]['NoneKey'] is None
+ assert len(generator_result[1]['obs']) == 3
+ assert generator_result[1]['info']['other_obs'].shape == (3, 5)
+
+ generator = split_data_generator(data, 3, shuffle=False)
diff --git a/DI-engine/ding/utils/tests/test_design_helper.py b/DI-engine/ding/utils/tests/test_design_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..52d029d7372406906a0280828b2fd2c591467286
--- /dev/null
+++ b/DI-engine/ding/utils/tests/test_design_helper.py
@@ -0,0 +1,37 @@
+import random
+
+import pytest
+
+from ding.utils import SingletonMetaclass
+
+
+@pytest.mark.unittest
+def test_singleton():
+ global count
+ count = 0
+
+ class A(object, metaclass=SingletonMetaclass):
+
+ def __init__(self, t):
+ self.t = t
+ self.p = random.randint(0, 10)
+ global count
+ count += 1
+
+ obj = [A(i) for i in range(3)]
+ assert count == 1
+ assert all([o.t == 0 for o in obj])
+ assert all([o.p == obj[0].p for o in obj])
+ assert all([id(o) == id(obj[0]) for o in obj])
+ assert id(A.instance) == id(obj[0])
+
+ # subclass test
+ class B(A):
+ pass
+
+ obj = [B(i) for i in range(3, 6)]
+ assert count == 2
+ assert all([o.t == 3 for o in obj])
+ assert all([o.p == obj[0].p for o in obj])
+ assert all([id(o) == id(obj[0]) for o in obj])
+ assert id(B.instance) == id(obj[0])
diff --git a/DI-engine/ding/utils/tests/test_file_helper.py b/DI-engine/ding/utils/tests/test_file_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..da81d835080007cae5efbd6c4eceff72409adaf9
--- /dev/null
+++ b/DI-engine/ding/utils/tests/test_file_helper.py
@@ -0,0 +1,34 @@
+import pytest
+import random
+import pickle
+
+from ding.utils.file_helper import read_file, read_from_file, remove_file, save_file, read_from_path, save_file_ceph
+
+
+@pytest.mark.unittest
+def test_normal_file():
+ data1 = {'a': [random.randint(0, 100) for i in range(100)]}
+ save_file('./f', data1)
+ data2 = read_file("./f")
+ assert (data2 == data1)
+ with open("./f1", "wb") as f1:
+ pickle.dump(data1, f1)
+ data3 = read_from_file("./f1")
+ assert (data3 == data1)
+ data4 = read_from_path("./f1")
+ assert (data4 == data1)
+ save_file_ceph("./f2", data1)
+ assert (data1 == read_from_file("./f2"))
+ # test lock
+ save_file('./f3', data1, use_lock=True)
+ data_read = read_file('./f3', use_lock=True)
+ assert isinstance(data_read, dict)
+
+ remove_file("./f")
+ remove_file("./f1")
+ remove_file("./f2")
+ remove_file("./f3")
+ remove_file('./f.lock')
+ remove_file('./f2.lock')
+ remove_file('./f3.lock')
+ remove_file('./name.txt')
diff --git a/DI-engine/ding/utils/tests/test_import_helper.py b/DI-engine/ding/utils/tests/test_import_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ed4e1e71e30048f2cb9fa6580884b66e981adc0
--- /dev/null
+++ b/DI-engine/ding/utils/tests/test_import_helper.py
@@ -0,0 +1,17 @@
+import pytest
+
+import ding
+from ding.utils.import_helper import try_import_ceph, try_import_mc, try_import_redis, try_import_rediscluster, \
+ try_import_link, import_module
+
+
+@pytest.mark.unittest
+def test_try_import():
+ try_import_ceph()
+ try_import_mc()
+ try_import_redis()
+ try_import_rediscluster()
+ try_import_link()
+ import_module(['ding.utils'])
+ ding.enable_linklink = True
+ try_import_link()
diff --git a/DI-engine/ding/utils/tests/test_k8s_launcher.py b/DI-engine/ding/utils/tests/test_k8s_launcher.py
new file mode 100644
index 0000000000000000000000000000000000000000..6145c17d1b8e7cd96e90f7821d60ff45c31f53ca
--- /dev/null
+++ b/DI-engine/ding/utils/tests/test_k8s_launcher.py
@@ -0,0 +1,76 @@
+import os
+import subprocess
+import pytest
+
+from ding.utils import K8sLauncher, OrchestratorLauncher
+
+try:
+ from kubernetes import config, client, watch
+except ImportError:
+ _test_mark = pytest.mark.ignore
+else:
+ _test_mark = pytest.mark.envtest
+
+
+@_test_mark
+def test_operate_k8s_cluster():
+ cluster_name = 'test-k8s-launcher'
+ config_path = os.path.join(os.path.dirname(__file__), 'config', 'k8s-config.yaml')
+ launcher = K8sLauncher(config_path)
+ launcher.name = cluster_name
+
+ # create cluster
+ launcher.create_cluster()
+
+ # check that cluster is successfully created
+ config.load_kube_config()
+ current_context = config.list_kube_config_contexts()[1]
+ assert current_context['context']['cluster'].startswith(f"k3d-{cluster_name}")
+ subprocess.run('kubectl create ns di-system', shell=True)
+
+ # create orchestrator
+ olauncher = OrchestratorLauncher('v1.1.3', cluster=launcher)
+ olauncher.create_orchestrator()
+
+ # check orchestrator is successfully created
+ expected_deployments, expected_crds = 2, 1
+ appv1 = client.AppsV1Api()
+ ret = appv1.list_namespaced_deployment("di-system")
+ assert len(ret.items) == expected_deployments
+
+ # check crds are installed
+ extensionv1 = client.ApiextensionsV1Api()
+ ret = extensionv1.list_custom_resource_definition()
+ found = 0
+ for crd in ret.items:
+ found = found + 1 if crd.metadata.name == 'aggregatorconfigs.diengine.opendilab.org' else found
+ found = found + 1 if crd.metadata.name == 'dijobs.diengine.opendilab.org' else found
+ assert found == expected_crds
+
+ # delete orchestrator
+ olauncher.delete_orchestrator()
+
+ # sleep for a few seconds and check crds are deleted
+ timeout = 10
+ deleted_crds = 0
+ w = watch.Watch()
+ for event in w.stream(extensionv1.list_custom_resource_definition, timeout_seconds=timeout):
+ if event['type'] == "DELETED":
+ deleted_crds += 1
+ if deleted_crds == expected_crds:
+ w.stop()
+ ret = extensionv1.list_custom_resource_definition()
+ found = 0
+ for crd in ret.items:
+ found = found + 1 if crd.metadata.name == 'dijobs.diengine.opendilab.org' else found
+ assert found == 0
+
+ # delete cluster
+ launcher.delete_cluster()
+ try:
+ config.load_kube_config()
+ except Exception:
+ print("No k8s cluster found, skipped...")
+ else:
+ current_context = config.list_kube_config_contexts()[1]
+ assert not current_context['context']['cluster'].startswith(f"k3d-{cluster_name}")
diff --git a/DI-engine/ding/utils/tests/test_lock.py b/DI-engine/ding/utils/tests/test_lock.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab0391e32e49172a1b0e559d2e6ecaecdb01da70
--- /dev/null
+++ b/DI-engine/ding/utils/tests/test_lock.py
@@ -0,0 +1,37 @@
+import pytest
+import numpy as np
+from collections import deque
+
+from ding.utils import LockContext, LockContextType, get_rw_file_lock
+
+
+@pytest.mark.unittest
+def test_usage():
+ lock = LockContext(LockContextType.PROCESS_LOCK)
+ queue = deque(maxlen=10)
+ data = np.random.randn(4)
+ with lock:
+ queue.append(np.copy(data))
+ with lock:
+ output = queue.popleft()
+ assert (output == data).all()
+ lock.acquire()
+ queue.append(np.copy(data))
+ lock.release()
+ lock.acquire()
+ output = queue.popleft()
+ lock.release()
+ assert (output == data).all()
+
+
+@pytest.mark.unittest
+def test_get_rw_file_lock():
+ path = 'tmp.npy'
+ # TODO real read-write case
+ read_lock = get_rw_file_lock(path, 'read')
+ write_lock = get_rw_file_lock(path, 'write')
+ with write_lock:
+ np.save(path, np.random.randint(0, 1, size=(3, 4)))
+ with read_lock:
+ data = np.load(path)
+ assert data.shape == (3, 4)
diff --git a/DI-engine/ding/utils/tests/test_log_helper.py b/DI-engine/ding/utils/tests/test_log_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..21b78015f3472a6c1ac7aeb62427ef03aa5074cd
--- /dev/null
+++ b/DI-engine/ding/utils/tests/test_log_helper.py
@@ -0,0 +1,55 @@
+import random
+import pytest
+from easydict import EasyDict
+from ditk import logging
+
+from ding.utils.log_helper import build_logger, pretty_print
+from ding.utils.file_helper import remove_file
+
+cfg = EasyDict(
+ {
+ 'env': {},
+ 'env_num': 4,
+ 'common': {
+ 'save_path': "./summary_log",
+ 'load_path': '',
+ 'name': 'fakeLog',
+ 'only_evaluate': False,
+ },
+ 'logger': {
+ 'print_freq': 10,
+ 'save_freq': 200,
+ 'eval_freq': 200,
+ },
+ 'data': {
+ 'train': {},
+ 'eval': {},
+ },
+ 'learner': {
+ 'log_freq': 100,
+ },
+ }
+)
+
+
+@pytest.mark.unittest
+class TestLogger:
+
+ def test_pretty_print(self):
+ pretty_print(cfg)
+
+ def test_logger(self):
+ logger, tb_logger = build_logger(cfg.common.save_path, name="fake_test", need_tb=True, text_level=logging.DEBUG)
+ variables = {'aa': 3.0, 'bb': 4, 'cc': 3e4}
+ # text logger
+ logger.info("I'm an info")
+ logger.debug("I'm a bug")
+ logger.error("I'm an error")
+ logger.info(logger.get_tabulate_vars(variables))
+ # tensorboard logger
+ for i in range(10):
+ new_vars = {k: v * (i + random.random()) for k, v in variables.items()}
+ for k, v in new_vars.items():
+ tb_logger.add_scalar(k, v, i)
+ remove_file(cfg.common.save_path)
+ tb_logger.close()
diff --git a/DI-engine/ding/utils/tests/test_log_writer_helper.py b/DI-engine/ding/utils/tests/test_log_writer_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c45d5c95f4b2cdfa38e33412c3ff62442ac7804
--- /dev/null
+++ b/DI-engine/ding/utils/tests/test_log_writer_helper.py
@@ -0,0 +1,39 @@
+import pytest
+import time
+import tempfile
+import shutil
+import os
+from os import path
+from ding.framework import Parallel
+from ding.framework.task import task
+from ding.utils import DistributedWriter
+
+
+def main_distributed_writer(tempdir):
+ with task.start():
+ time.sleep(task.router.node_id * 1) # Sleep 0 and 1, write to different files
+
+ tblogger = DistributedWriter(tempdir).plugin(task.router, is_writer=(task.router.node_id == 0))
+
+ def _add_scalar(ctx):
+ n = 10
+ for i in range(n):
+ tblogger.add_scalar(str(task.router.node_id), task.router.node_id, ctx.total_step * n + i)
+
+ task.use(_add_scalar)
+ task.use(lambda _: time.sleep(0.2))
+ task.run(max_step=3)
+
+ time.sleep(0.3 + (1 - task.router.node_id) * 2)
+
+
+@pytest.mark.unittest
+def test_distributed_writer():
+ tempdir = path.join(tempfile.gettempdir(), "tblogger")
+ try:
+ Parallel.runner(n_parallel_workers=2)(main_distributed_writer, tempdir)
+ assert path.exists(tempdir)
+ assert len(os.listdir(tempdir)) == 1
+ finally:
+ if path.exists(tempdir):
+ shutil.rmtree(tempdir)
diff --git a/DI-engine/ding/utils/tests/test_normalizer_helper.py b/DI-engine/ding/utils/tests/test_normalizer_helper.py
new file mode 100755
index 0000000000000000000000000000000000000000..d3339a00b40ba8bbfc4671caf87d2b56490a9670
--- /dev/null
+++ b/DI-engine/ding/utils/tests/test_normalizer_helper.py
@@ -0,0 +1,38 @@
+import easydict
+import numpy
+import pytest
+
+from ding.utils.normalizer_helper import DatasetNormalizer
+
+
+# TODO(nyz): fix unittest bugs
+@pytest.mark.tmp
+class TestNormalizerHelper:
+
+ def test_normalizer(self):
+ x = numpy.random.randn(10)
+ mean = x.mean()
+ std = x.std()
+ mins = x.min()
+ maxs = x.max()
+ normalizer = DatasetNormalizer({'test': x}, 'GaussianNormalizer', 10)
+ test = numpy.random.randn(1)
+ normal_test = normalizer.normalize(test, 'test')
+ unnormal_test = normalizer.unnormalize(normal_test, 'test')
+ assert unnormal_test == test
+ assert normal_test == (test - mean) / std
+
+ normalizer = DatasetNormalizer({'test': x}, 'LimitsNormalizer', 10)
+ test = numpy.random.randn(1)
+ normal_test1 = (test - mins) / (maxs - mins)
+ normal_test1 = 2 * normal_test1 - 1
+ normal_test = normalizer.normalize(test, 'test')
+ unnormal_test = normalizer.unnormalize(normal_test, 'test')
+ assert unnormal_test == test
+ assert normal_test == normal_test1
+
+ normalizer = DatasetNormalizer({'test': x}, 'CDFNormalizer', 10)
+ test = numpy.random.randn(1)
+ normal_test = normalizer.normalize(test, 'test')
+ unnormal_test = normalizer.unnormalize(normal_test, 'test')
+ assert unnormal_test == test
diff --git a/DI-engine/ding/utils/tests/test_profiler_helper.py b/DI-engine/ding/utils/tests/test_profiler_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd1b075cb2cd6b70fbbdd1dce013e0fcb6fcdaff
--- /dev/null
+++ b/DI-engine/ding/utils/tests/test_profiler_helper.py
@@ -0,0 +1,42 @@
+from easydict import EasyDict
+import pytest
+import unittest
+from unittest import mock
+from unittest.mock import patch
+import pathlib as pl
+import os
+import shutil
+
+from ding.utils.profiler_helper import Profiler, register_profiler
+
+
+@pytest.mark.unittest
+class TestProfilerModule:
+
+ def assertIsFile(self, path):
+ if not pl.Path(path).resolve().is_file():
+ raise AssertionError("File does not exist: %s" % str(path))
+
+ def test(self):
+ profiler = Profiler()
+
+ def register_mock(write_profile, pr, folder_path):
+ profiler.write_profile(pr, folder_path)
+
+ def clean_up(dir):
+ if os.path.exists(dir):
+ shutil.rmtree(dir)
+
+ dir = "./tmp_test/"
+ clean_up(dir)
+
+ with patch('ding.utils.profiler_helper.register_profiler', register_mock):
+ profiler.profile(dir)
+ file_path = os.path.join(dir, "profile_tottime.txt")
+ self.assertIsFile(file_path)
+ file_path = os.path.join(dir, "profile_cumtime.txt")
+ self.assertIsFile(file_path)
+ file_path = os.path.join(dir, "profile.prof")
+ self.assertIsFile(file_path)
+
+ clean_up(dir)
diff --git a/DI-engine/ding/utils/tests/test_registry.py b/DI-engine/ding/utils/tests/test_registry.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1035916da3ddd34df6faec870b0cab38a2a21de
--- /dev/null
+++ b/DI-engine/ding/utils/tests/test_registry.py
@@ -0,0 +1,27 @@
+import pytest
+from ding.utils.registry import Registry
+
+
+@pytest.mark.unittest
+def test_registry():
+ TEST_REGISTRY = Registry()
+
+ @TEST_REGISTRY.register('a')
+ class A:
+ pass
+
+ instance = TEST_REGISTRY.build('a')
+ assert isinstance(instance, A)
+
+ with pytest.raises(AssertionError):
+
+ @TEST_REGISTRY.register('a')
+ class A1:
+ pass
+
+ @TEST_REGISTRY.register('a', force_overwrite=True)
+ class A2:
+ pass
+
+ instance = TEST_REGISTRY.build('a')
+ assert isinstance(instance, A2)
diff --git a/DI-engine/ding/utils/tests/test_scheduler_helper.py b/DI-engine/ding/utils/tests/test_scheduler_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..052b4db8ec69676be0f9d28cd2388e7f45143b8d
--- /dev/null
+++ b/DI-engine/ding/utils/tests/test_scheduler_helper.py
@@ -0,0 +1,112 @@
+from easydict import EasyDict
+import pytest
+from ding.utils import Scheduler
+from dizoo.league_demo.league_demo_ppo_config import league_demo_ppo_config
+
+
+@pytest.mark.unittest
+class TestSchedulerModule():
+
+ test_merged_scheduler_config = dict(
+ schedule_flag=False,
+ schedule_mode='reduce',
+ factor=0.05,
+ change_range=[-1, 1],
+ threshold=1e-4,
+ optimize_mode='min',
+ patience=1,
+ cooldown=0,
+ )
+ test_merged_scheduler_config = EasyDict(test_merged_scheduler_config)
+ test_policy_config = EasyDict(league_demo_ppo_config.policy)
+ test_policy_config_param = test_policy_config.learn.entropy_weight
+
+ def test_init_factor(self):
+ self.test_merged_scheduler_config.factor = 'hello_test'
+ with pytest.raises(AssertionError) as excinfo:
+ test_scheduler = Scheduler(self.test_merged_scheduler_config)
+ assert 'float/int' in str(excinfo.value)
+
+ self.test_merged_scheduler_config.factor = 0
+ with pytest.raises(AssertionError) as excinfo:
+ test_scheduler = Scheduler(self.test_merged_scheduler_config)
+ assert 'greater than 0' in str(excinfo.value)
+
+ # recover the correct value for later test function
+ self.test_merged_scheduler_config.factor = 0.05
+
+ def test_init_change_range(self):
+ self.test_merged_scheduler_config.change_range = 0
+ with pytest.raises(AssertionError) as excinfo:
+ test_scheduler = Scheduler(self.test_merged_scheduler_config)
+ assert 'list' in str(excinfo.value)
+
+ self.test_merged_scheduler_config.change_range = [0, 'hello_test']
+ with pytest.raises(AssertionError) as excinfo:
+ test_scheduler = Scheduler(self.test_merged_scheduler_config)
+ assert 'float' in str(excinfo.value)
+
+ self.test_merged_scheduler_config.change_range = [0, -1]
+ with pytest.raises(AssertionError) as excinfo:
+ test_scheduler = Scheduler(self.test_merged_scheduler_config)
+ assert 'smaller' in str(excinfo.value)
+
+ # recover the correct value for later test function
+ self.test_merged_scheduler_config.change_range = [-1, 1]
+
+ def test_init_patience(self):
+ self.test_merged_scheduler_config.patience = "hello_test"
+ with pytest.raises(AssertionError) as excinfo:
+ test_scheduler = Scheduler(self.test_merged_scheduler_config)
+ assert 'integer' in str(excinfo.value)
+
+ self.test_merged_scheduler_config.patience = -1
+ with pytest.raises(AssertionError) as excinfo:
+ test_scheduler = Scheduler(self.test_merged_scheduler_config)
+ assert 'greater' in str(excinfo.value)
+
+ # recover the correct value for later test function
+ self.test_merged_scheduler_config.patience = 1
+
+ def test_is_better(self):
+ test_scheduler = Scheduler(self.test_merged_scheduler_config)
+ assert test_scheduler.is_better(-1) is True
+
+ test_scheduler.last_metrics = 1
+ assert test_scheduler.is_better(0.5) is True
+
+ def test_in_cooldown(self):
+ self.test_merged_scheduler_config.cooldown_counter = 0
+ test_scheduler = Scheduler(self.test_merged_scheduler_config)
+ assert test_scheduler.in_cooldown is False
+
+ def test_step(self):
+
+ self.test_merged_scheduler_config.cooldown = 1
+
+ test_scheduler = Scheduler(self.test_merged_scheduler_config)
+ assert test_scheduler.cooldown_counter == 1
+ test_scheduler.last_metrics = 1.0
+
+ old_param = self.test_policy_config.learn.entropy_weight
+
+ # good epoch; maximum cooldown length is 1
+ self.test_policy_config_param = test_scheduler.step(0.9, self.test_policy_config_param)
+ assert self.test_policy_config_param == old_param
+ assert test_scheduler.cooldown_counter == 0
+ assert test_scheduler.last_metrics == 0.9
+ assert test_scheduler.bad_epochs_num == 0
+
+ # first bad epoch in cooldown period
+ self.test_policy_config_param = test_scheduler.step(0.899999, self.test_policy_config_param)
+ assert self.test_policy_config_param == old_param
+ assert test_scheduler.cooldown_counter == 0
+ assert test_scheduler.last_metrics == 0.899999
+ assert test_scheduler.bad_epochs_num == 1
+
+ # first bad epoch after cooldown
+ self.test_policy_config_param = test_scheduler.step(0.899998, self.test_policy_config_param)
+ assert self.test_policy_config_param == old_param - self.test_merged_scheduler_config.factor
+ assert test_scheduler.cooldown_counter == 1
+ assert test_scheduler.last_metrics == 0.899998
+ assert test_scheduler.bad_epochs_num == 0
diff --git a/DI-engine/ding/utils/tests/test_segment_tree.py b/DI-engine/ding/utils/tests/test_segment_tree.py
new file mode 100644
index 0000000000000000000000000000000000000000..70739df254aa6d32df3c81a212fe51a32c04b4ea
--- /dev/null
+++ b/DI-engine/ding/utils/tests/test_segment_tree.py
@@ -0,0 +1,87 @@
+import numpy as np
+import pytest
+
+import ding
+ding.enable_numba = False # noqa
+from ding.utils import SumSegmentTree, MinSegmentTree # noqa
+
+
+@pytest.mark.unittest
+class TestSumSegmentTree:
+
+ def test_create(self):
+ with pytest.raises(AssertionError):
+ tree = SumSegmentTree(capacity=13)
+
+ tree = SumSegmentTree(capacity=16)
+ assert (tree.operation == 'sum')
+ assert (tree.neutral_element == 0.)
+ assert (max(tree.value) == 0.)
+ assert (min(tree.value) == 0.)
+
+ def test_set_get_item(self):
+ tree = SumSegmentTree(capacity=4)
+ elements = [1, 5, 4, 7]
+ get_result = []
+ for idx, val in enumerate(elements):
+ tree[idx] = val
+ get_result.append(tree[idx])
+
+ assert (elements == get_result)
+ assert (tree.reduce() == sum(elements))
+ assert (tree.reduce(0, 3) == sum(elements[:3]))
+ assert (tree.reduce(0, 2) == sum(elements[:2]))
+ assert (tree.reduce(0, 1) == sum(elements[:1]))
+ assert (tree.reduce(1, 3) == sum(elements[1:3]))
+ assert (tree.reduce(1, 2) == sum(elements[1:2]))
+ assert (tree.reduce(2, 3) == sum(elements[2:3]))
+
+ with pytest.raises(AssertionError):
+ tree.reduce(2, 2)
+
+ def test_find_prefixsum_idx(self):
+ tree = SumSegmentTree(capacity=8)
+ elements = [0, 0.1, 0.5, 0, 0, 0.2, 0.8, 0]
+ for idx, val in enumerate(elements):
+ tree[idx] = val
+ with pytest.raises(AssertionError):
+ tree.find_prefixsum_idx(tree.reduce() + 1e-4, trust_caller=False)
+ with pytest.raises(AssertionError):
+ tree.find_prefixsum_idx(-1e-6, trust_caller=False)
+
+ assert (tree.find_prefixsum_idx(0) == 1)
+ assert (tree.find_prefixsum_idx(0.09) == 1)
+ assert (tree.find_prefixsum_idx(0.1) == 2)
+ assert (tree.find_prefixsum_idx(0.59) == 2)
+ assert (tree.find_prefixsum_idx(0.6) == 5)
+ assert (tree.find_prefixsum_idx(0.799) == 5)
+ assert (tree.find_prefixsum_idx(0.8) == 6)
+ assert (tree.find_prefixsum_idx(tree.reduce()) == 6)
+
+
+@pytest.mark.unittest
+class TestMinSegmentTree:
+
+ def test_create(self):
+ tree = MinSegmentTree(capacity=16)
+ assert (tree.operation == 'min')
+ assert (tree.neutral_element == np.inf)
+ assert (max(tree.value) == np.inf)
+ assert (min(tree.value) == np.inf)
+
+ def test_set_get_item(self):
+ tree = MinSegmentTree(capacity=4)
+ elements = [1, -10, 10, 7]
+ get_result = []
+ for idx, val in enumerate(elements):
+ tree[idx] = val
+ get_result.append(tree[idx])
+
+ assert (elements == get_result)
+ assert (tree.reduce() == min(elements))
+ assert (tree.reduce(0, 3) == min(elements[:3]))
+ assert (tree.reduce(0, 2) == min(elements[:2]))
+ assert (tree.reduce(0, 1) == min(elements[:1]))
+ assert (tree.reduce(1, 3) == min(elements[1:3]))
+ assert (tree.reduce(1, 2) == min(elements[1:2]))
+ assert (tree.reduce(2, 3) == min(elements[2:3]))
diff --git a/DI-engine/ding/utils/tests/test_system_helper.py b/DI-engine/ding/utils/tests/test_system_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..00b40593e1f9bbefe7b1acb1a4356fe9e734fbd3
--- /dev/null
+++ b/DI-engine/ding/utils/tests/test_system_helper.py
@@ -0,0 +1,15 @@
+import pytest
+
+from ding.utils.system_helper import get_ip, get_pid, get_task_uid
+
+
+@pytest.mark.unittest
+class TestSystemHelper():
+
+ def test_get(self):
+ try:
+ get_ip()
+ except:
+ pass
+ assert isinstance(get_pid(), int)
+ assert isinstance(get_task_uid(), str)
diff --git a/DI-engine/ding/utils/tests/test_time_helper.py b/DI-engine/ding/utils/tests/test_time_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f7e52aa2c23b5f0d85a1343b1cdd1ee2334ed0a
--- /dev/null
+++ b/DI-engine/ding/utils/tests/test_time_helper.py
@@ -0,0 +1,74 @@
+import pytest
+import numpy as np
+import time
+from ding.utils.time_helper import build_time_helper, WatchDog, TimeWrapperTime, EasyTimer
+
+
+@pytest.mark.unittest
+class TestTimeHelper:
+
+ def test_naive(self):
+
+ class NaiveObject(object):
+ pass
+
+ cfg = NaiveObject()
+ setattr(cfg, 'common', NaiveObject())
+ setattr(cfg.common, 'time_wrapper_type', 'time')
+ with pytest.raises(RuntimeError):
+ time_handle = build_time_helper()
+ with pytest.raises(KeyError):
+ build_time_helper(cfg=None, wrapper_type="not_implement")
+ time_handle = build_time_helper(cfg)
+ time_handle = build_time_helper(wrapper_type='cuda')
+ # wrapper_type='cuda' but cuda is not available
+ assert issubclass(time_handle, TimeWrapperTime)
+ time_handle = build_time_helper(wrapper_type='time')
+
+ @time_handle.wrapper
+ def func1(x):
+ return x + 1
+
+ def func2(x):
+ return x + 1
+
+ # usage 1
+ ret, t = func1(3)
+ assert np.isscalar(t)
+ assert func1(4)[0] == func2(4)
+
+ # usage 2
+ time_handle.start_time()
+ _ = func2(3)
+ t = time_handle.end_time()
+ assert np.isscalar(t)
+
+ # test time_lag and restart
+ time_handle.start_time()
+ time.sleep(0.5)
+ time_handle.start_time()
+ time.sleep(1)
+ t = time_handle.end_time()
+ assert np.isscalar(t)
+ # time_lag is bigger than 1e-3
+ # assert abs(t-1) < 1e-3
+ assert abs(t - 1) < 1e-2
+
+ timer = EasyTimer()
+ with timer:
+ tmp = np.random.random(size=(4, 100))
+ tmp = tmp ** 2
+ value = timer.value
+ assert isinstance(value, float)
+
+
+@pytest.mark.unittest
+class TestWatchDog:
+
+ def test_naive(self):
+ watchdog = WatchDog(3)
+ watchdog.start()
+ time.sleep(2)
+ with pytest.raises(TimeoutError):
+ time.sleep(2)
+ watchdog.stop()
diff --git a/DI-engine/ding/utils/time_helper.py b/DI-engine/ding/utils/time_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..06498e4546cfc3248cee0b18113fd1edfb3585ce
--- /dev/null
+++ b/DI-engine/ding/utils/time_helper.py
@@ -0,0 +1,185 @@
+import signal
+import time
+from typing import Any, Callable
+
+import torch
+from easydict import EasyDict
+from .time_helper_base import TimeWrapper
+from .time_helper_cuda import get_cuda_time_wrapper
+
+
+def build_time_helper(cfg: EasyDict = None, wrapper_type: str = None) -> Callable[[], 'TimeWrapper']:
+ """
+ Overview:
+ Build the time helper
+
+ Arguments:
+ - cfg (:obj:`dict`):
+ The config dict, which is a multilevel dict with top-level domains like
+ evaluate, common, model, train etc., and each domain
+ has its own sub-domains.
+ - wrapper_type (:obj:`str`): The type of wrapper returned, support ``['time', 'cuda']``
+
+ Returns:
+ - time_wrapper (:obj:`TimeWrapper`):
+ The corresponding time wrapper. Reference: ``ding.utils.timehelper.TimeWrapperTime``
+ and ``ding.utils.timehelper.get_cuda_time_wrapper``.
+ """
+ # Note: wrapper_type has higher priority
+ if wrapper_type is not None:
+ time_wrapper_type = wrapper_type
+ elif cfg is not None:
+ time_wrapper_type = cfg.common.time_wrapper_type
+ else:
+ raise RuntimeError('Either wrapper_type or cfg should be provided.')
+
+ if time_wrapper_type == 'time':
+ return TimeWrapperTime
+ elif time_wrapper_type == 'cuda':
+ if torch.cuda.is_available():
+ # lazy initialize to make code runnable locally
+ return get_cuda_time_wrapper()
+ else:
+ return TimeWrapperTime
+ else:
+ raise KeyError('invalid time_wrapper_type: {}'.format(time_wrapper_type))
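+
+
+# A minimal usage sketch of ``build_time_helper`` (the 'time' wrapper is used here so the
+# example also runs without a cuda device); the returned class works as a decorator or via
+# explicit ``start_time``/``end_time`` calls:
+#
+#     time_handle = build_time_helper(wrapper_type='time')
+#
+#     @time_handle.wrapper
+#     def func(x):
+#         return x + 1
+#
+#     ret, t = func(3)              # (result, elapsed seconds)
+#     time_handle.start_time()
+#     _ = func(3)
+#     t2 = time_handle.end_time()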
+
+
+class EasyTimer:
+ """
+ Overview:
+ A decent timer wrapper that can be used easily.
+
+ Interfaces:
+ ``__init__``, ``__enter__``, ``__exit__``
+
+ Example:
+ >>> wait_timer = EasyTimer()
+ >>> with wait_timer:
+ >>> func(...)
+ >>> time_ = wait_timer.value # in seconds
+ """
+
+ def __init__(self, cuda=True):
+ """
+ Overview:
+ Init class EasyTimer
+
+ Arguments:
+ - cuda (:obj:`bool`): Whether to build timer with cuda type
+ """
+ if torch.cuda.is_available() and cuda:
+ time_wrapper_type = "cuda"
+ else:
+ time_wrapper_type = "time"
+ self._timer = build_time_helper(wrapper_type=time_wrapper_type)
+ self.value = 0.0
+
+ def __enter__(self):
+ """
+ Overview:
+ Enter timer, start timing
+ """
+ self.value = 0.0
+ self._timer.start_time()
+
+ def __exit__(self, *args):
+ """
+ Overview:
+ Exit timer, stop timing
+ """
+ self.value = self._timer.end_time()
+
+
+class TimeWrapperTime(TimeWrapper):
+ """
+ Overview:
+ A time wrapper class that inherits from the ``TimeWrapper`` class
+
+ Interfaces:
+ ``start_time``, ``end_time``
+ """
+
+ # overwrite
+ @classmethod
+ def start_time(cls):
+ """
+ Overview:
+ Implement and override the ``start_time`` method in ``TimeWrapper`` class
+ """
+ cls.start = time.time()
+
+ # overwrite
+ @classmethod
+ def end_time(cls):
+ """
+ Overview:
+ Implement and override the ``end_time`` method in ``TimeWrapper`` class
+
+ Returns:
+ - time(:obj:`float`): The time elapsed between ``start_time`` and ``end_time``
+ """
+ cls.end = time.time()
+ return cls.end - cls.start
+
+
+class WatchDog(object):
+ """
+ Overview:
+ Simple watchdog timer to detect timeouts
+
+ Arguments:
+ - timeout (:obj:`int`): Timeout value of the watchdog, in seconds.
+
+ .. note::
+ If it is not stopped before the timeout expires, a ``TimeoutError`` is raised.
+
+ Interfaces:
+ ``start``, ``stop``
+
+ Examples:
+ >>> watchdog = WatchDog(x) # x is a timeout value
+ >>> ...
+ >>> watchdog.start()
+ >>> ... # Some function that should finish before the timeout
+ >>> watchdog.stop()
+
+ """
+
+ def __init__(self, timeout: int = 1):
+ """
+ Overview:
+ Initialize watchdog with ``timeout`` value.
+ Arguments:
+ - timeout (:obj:`int`): Timeout value of the watchdog, in seconds.
+ """
+
+ self._timeout = timeout + 1
+ self._failed = False
+
+ def start(self):
+ """
+ Overview:
+ Start watchdog.
+ """
+ signal.signal(signal.SIGALRM, self._event)
+ signal.alarm(self._timeout)
+
+ @staticmethod
+ def _event(signum: Any, frame: Any):
+ """
+ Overview:
+ Event handler for watchdog.
+ Arguments:
+ - signum (:obj:`Any`): Signal number.
+ - frame (:obj:`Any`): Current stack frame.
+ """
+
+ raise TimeoutError()
+
+ def stop(self):
+ """
+ Overview:
+ Stop the watchdog: cancel the pending alarm with ``alarm(0)`` and restore the default ``SIGALRM`` handler.
+ """
+ signal.alarm(0)
+ signal.signal(signal.SIGALRM, signal.SIG_DFL)
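+
+
+# Usage sketch for ``WatchDog``: it relies on ``signal.alarm``, so it only works in the main
+# thread of the main interpreter and with whole-second granularity
+# (``do_blocking_work`` below is a hypothetical placeholder):
+#
+#     watchdog = WatchDog(3)
+#     watchdog.start()
+#     try:
+#         do_blocking_work()   # should finish within ~3 seconds, otherwise TimeoutError
+#     finally:
+#         watchdog.stop()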
diff --git a/DI-engine/ding/utils/time_helper_base.py b/DI-engine/ding/utils/time_helper_base.py
new file mode 100644
index 0000000000000000000000000000000000000000..86f58d0fe8be6a19814858e51d5105b6a19fb1da
--- /dev/null
+++ b/DI-engine/ding/utils/time_helper_base.py
@@ -0,0 +1,41 @@
+class TimeWrapper(object):
+ """
+ Overview:
+ Abstract base class that defines the ``TimeWrapper`` interface
+
+ Interfaces:
+ ``wrapper``, ``start_time``, ``end_time``
+ """
+
+ @classmethod
+ def wrapper(cls, fn):
+ """
+ Overview:
+ Classmethod decorator that wraps a function and additionally returns its running time
+ Arguments:
+ - fn (:obj:`function`): The function to be wrapped and timed
+ """
+
+ def time_func(*args, **kwargs):
+ cls.start_time()
+ ret = fn(*args, **kwargs)
+ t = cls.end_time()
+ return ret, t
+
+ return time_func
+
+ @classmethod
+ def start_time(cls):
+ """
+ Overview:
+ Abstract classmethod, start timing
+ """
+ raise NotImplementedError
+
+ @classmethod
+ def end_time(cls):
+ """
+ Overview:
+ Abstract classmethod, stop timing
+ """
+ raise NotImplementedError
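+
+
+# A minimal sketch of a concrete subclass (hypothetical, for illustration), mirroring how
+# ``TimeWrapperTime`` in ``time_helper.py`` fills in the two abstract classmethods:
+#
+#     import time
+#
+#     class TimeWrapperPerfCounter(TimeWrapper):
+#
+#         @classmethod
+#         def start_time(cls):
+#             cls._start = time.perf_counter()
+#
+#         @classmethod
+#         def end_time(cls):
+#             return time.perf_counter() - cls._start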
diff --git a/DI-engine/ding/utils/time_helper_cuda.py b/DI-engine/ding/utils/time_helper_cuda.py
new file mode 100644
index 0000000000000000000000000000000000000000..51ea5e925a15d82ca4f2956f8b089a01df6e6ec1
--- /dev/null
+++ b/DI-engine/ding/utils/time_helper_cuda.py
@@ -0,0 +1,59 @@
+from typing import Callable
+import torch
+from .time_helper_base import TimeWrapper
+
+
+def get_cuda_time_wrapper() -> Callable[[], 'TimeWrapper']:
+ """
+ Overview:
+ Return the ``TimeWrapperCuda`` class; building it lazily inside this function keeps the module importable when no cuda device is available
+
+ Returns:
+ - TimeWrapperCuda(:obj:`class`): See ``TimeWrapperCuda`` class
+
+ .. note::
+ Must use ``torch.cuda.synchronize()`` to get accurate timings.
+
+ """
+
+ # TODO find a way to autodoc the class within method
+ class TimeWrapperCuda(TimeWrapper):
+ """
+ Overview:
+ A time wrapper class that inherits from the ``TimeWrapper`` class
+
+ Notes:
+ Must use ``torch.cuda.synchronize()`` to get accurate timings.
+
+ Interfaces:
+ ``start_time``, ``end_time``
+ """
+ # cls variable is initialized on loading this class
+ start_record = torch.cuda.Event(enable_timing=True)
+ end_record = torch.cuda.Event(enable_timing=True)
+
+ # overwrite
+ @classmethod
+ def start_time(cls):
+ """
+ Overview:
+ Implement and override the ``start_time`` method in ``TimeWrapper`` class
+ """
+ torch.cuda.synchronize()
+ cls.start = cls.start_record.record()
+
+ # overwrite
+ @classmethod
+ def end_time(cls):
+ """
+ Overview:
+ Implement and override the ``end_time`` method in ``TimeWrapper`` class
+ Returns:
+ - time(:obj:`float`): The time between ``start_time`` and ``end_time``
+ """
+ cls.end = cls.end_record.record()
+ torch.cuda.synchronize()
+ return cls.start_record.elapsed_time(cls.end_record) / 1000
+
+ return TimeWrapperCuda
diff --git a/DI-engine/ding/utils/type_helper.py b/DI-engine/ding/utils/type_helper.py
new file mode 100644
index 0000000000000000000000000000000000000000..54ee10ec160491e299b61d83e44967aa6cfb57e7
--- /dev/null
+++ b/DI-engine/ding/utils/type_helper.py
@@ -0,0 +1,5 @@
+from collections import namedtuple
+from typing import List, Tuple, TypeVar
+
+SequenceType = TypeVar('SequenceType', List, Tuple, namedtuple)
+Tensor = TypeVar('torch.Tensor')
diff --git a/DI-engine/ding/worker/__init__.py b/DI-engine/ding/worker/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4c5b415b1c891c6ae39f3c5b65c4fe4abfb7c58
--- /dev/null
+++ b/DI-engine/ding/worker/__init__.py
@@ -0,0 +1,5 @@
+from .collector import *
+from .learner import *
+from .replay_buffer import *
+from .coordinator import *
+from .adapter import *
diff --git a/DI-engine/ding/worker/adapter/__init__.py b/DI-engine/ding/worker/adapter/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6ab4283f73bd6eaee774cf7b363afa3a63a3764
--- /dev/null
+++ b/DI-engine/ding/worker/adapter/__init__.py
@@ -0,0 +1 @@
+from .learner_aggregator import LearnerAggregator
diff --git a/DI-engine/ding/worker/adapter/learner_aggregator.py b/DI-engine/ding/worker/adapter/learner_aggregator.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f2c2a2af69b8bdf8dd9f5f31d63a8fb416526be
--- /dev/null
+++ b/DI-engine/ding/worker/adapter/learner_aggregator.py
@@ -0,0 +1,314 @@
+from typing import Union, Optional
+import traceback
+import numbers
+import copy
+import time
+from functools import reduce
+from threading import Thread
+from easydict import EasyDict
+
+from ding.interaction import Master, Slave, TaskFail
+from ding.interaction.master.task import TaskStatus
+from ding.utils import build_logger, get_operator_server_kwargs, exist_operator_server
+from ..coordinator.operator_server import OperatorServer
+
+
+class LearnerAggregatorSlave(Slave):
+ """
+ Overview:
+ A slave whose master is the coordinator.
+ """
+
+ def __init__(self, *args, callback_fn: Optional[dict] = None, **kwargs) -> None:
+ """
+ Overview:
+ Init callback functions additionally. Callback functions are methods in ``LearnerAggregator``.
+ As for the callback mechanism, you can refer to ``worker/learner/comm/flask_fs_learner.py`` for help.
+ """
+ super().__init__(*args, **kwargs)
+ self._callback_fn = callback_fn
+
+ def _process_task(self, task: dict) -> Union[dict, TaskFail]:
+ """
+ Overview:
+ Process a task according to input task info dict, which is passed in by coordinator's master.
+ For each type of task, you can refer to corresponding callback function in
+ ``LearnerAggregator`` for details.
+ Arguments:
+ - task (:obj:`dict`): Task dict. Must contain key "name".
+ Returns:
+ - result (:obj:`Union[dict, TaskFail]`): Task result dict, or task fail exception.
+ """
+ task_name = task['name']
+ if task_name == 'resource':
+ return self._callback_fn['deal_with_get_resource']()
+ elif task_name == 'learner_start_task':
+ return self._callback_fn['deal_with_learner_start'](task)
+ elif task_name == 'learner_get_data_task':
+ return self._callback_fn['deal_with_get_data'](task)
+ elif task_name == 'learner_learn_task':
+ return self._callback_fn['deal_with_learn'](task)
+ else:
+ raise TaskFail(result={'message': 'task name error'}, message='illegal learner task <{}>'.format(task_name))
+
+
+class LearnerAggregator(object):
+ """
+ Overview:
+ Aggregate multiple learners.
+ Interfaces:
+ __init__, start, close, merge_info
+ """
+
+ def __init__(self, cfg: dict) -> None:
+ """
+ Overview:
+ Init method.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config dict.
+ """
+ self._cfg = cfg
+ callback_fn = {
+ 'deal_with_get_resource': self.deal_with_get_resource,
+ 'deal_with_learner_start': self.deal_with_learner_start,
+ 'deal_with_get_data': self.deal_with_get_data,
+ 'deal_with_learn': self.deal_with_learn,
+ }
+ host, port = cfg.slave.host, cfg.slave.port
+ self._slave = LearnerAggregatorSlave(host, port, callback_fn=callback_fn)
+ self._logger, _ = build_logger(path='./log', name='learner_aggregator', need_tb=False)
+ self._end_flag = True
+ self._max_retry_second = 60
+
+ # ``_world_size`` indicates how many learners are connected;
+ # And ``_learner_connection`` lists those connections in dict type.
+ self._world_size = 0
+ self._learner_connection = {}
+
+ # create operator server
+ if exist_operator_server():
+ # get from default or env vars
+ server_kwargs = get_operator_server_kwargs(EasyDict({}))
+ self._operator_server = OperatorServer(**server_kwargs)
+ self._operator_server.set_worker_type('aggregator')
+ else:
+ self._operator_server = None
+
+ # failed connection
+ self._failed_learner_conn = set()
+
+ def start(self) -> None:
+ """
+ Overview:
+ Start the aggregator. Set up a master and build connections with all learners within max retry time.
+ """
+ self._end_flag = False
+ try:
+ self._slave.start()
+ except Exception as e:
+ self._logger.error(
+ "learner_aggregator slave start error:\n" + ''.join(traceback.format_tb(e.__traceback__)) + repr(e)
+ )
+ return
+ try:
+ self._master = Master(self._cfg.master.host, self._cfg.master.port)
+ self._master.start()
+ self._master.ping()
+ except Exception as e:
+ self._logger.error(
+ "learner_aggregator master start error:\n" + ''.join(traceback.format_tb(e.__traceback__)) + repr(e)
+ )
+ return
+ self._world_size = 0
+ for _, (learner_id, learner_host, learner_port) in self._cfg.learner.items():
+ self._new_connection_learner(learner_id, learner_host, int(learner_port))
+
+ if self._operator_server:
+ self._init_conn_flag = False
+ # create sync learner thread
+ self._period_sync_with_server_thread = Thread(
+ target=self._period_sync_with_server, name="period_sync", daemon=True
+ )
+ self._period_sync_with_server_thread.start()
+ start_time = time.time()
+ while time.time() - start_time <= self._max_retry_second and not self._end_flag:
+ if not self._init_conn_flag:
+ time.sleep(0.2)
+
+ # Exceeds max retry time and no learner connection found.
+ if len(self._learner_connection) == 0:
+ self._logger.error("learner_aggregator master max retries failed")
+ else:
+ self._logger.info("learner aggregator is started")
+
+ def close(self) -> None:
+ """
+ Overview:
+ Close aggregator slave, connections with learners, and master.
+ """
+ if self._end_flag:
+ return
+ self._end_flag = True
+ try:
+ self._slave.close()
+ for _, conn in self._learner_connection.items():
+ conn.disconnect()
+ assert not conn.is_connected
+ self._master.close()
+ except Exception: # Ignore close exception.
+ pass
+
+ def deal_with_get_resource(self) -> dict:
+ return {'gpu': self._world_size}
+
+ def deal_with_learner_start(self, task: dict) -> dict:
+ if len(self._learner_connection) == 0:
+ raise TaskFail(message='no connected learner', result={'message': 'no connected learner'})
+ name = task['name']
+ start_task = {}
+ for k, v in self._learner_connection.items():
+ start_task[k] = v.new_task({'name': name, 'task_info': task['task_info']})
+ start_task[k].start()
+ for k, v in start_task.items():
+ v.join()
+ task_status = [v.status for v in start_task.values()]
+ if any([s != TaskStatus.COMPLETED for s in task_status]):
+ # TODO(nyz) dynamic learner gpu add/remove
+ message = "one of learner can't start_task"
+ raise TaskFail(message=message, result={'message': message})
+ return {'message': 'learner task has started'}
+
+ def deal_with_get_data(self, task: dict) -> dict:
+ data_task = {}
+ for k, v in self._learner_connection.items():
+ data_task[k] = v.new_task({'name': task['name']})
+ data_task[k].start()
+ for k, v in data_task.items():
+ v.join()
+ # TODO deal with task fail
+ self._data_demand = {k: v.result for k, v in data_task.items()}
+ demand_list = list(self._data_demand.values())
+ # Merge data demand info by adding up all learners' demand batch size.
+ merged_demand = copy.deepcopy(demand_list[0])
+ merged_demand['batch_size'] = sum([d['batch_size'] for d in demand_list])
+ return merged_demand
+
+ def deal_with_learn(self, task: dict) -> dict:
+ learn_task = {}
+ merged_data = task['data']
+ # Split training data for each learner according to ``self._data_demand``.
+ split_data = []
+ start = 0
+ for item in self._data_demand.values():
+ end = item['batch_size'] + start
+ split_data.append(merged_data[start:end])
+ start = end
+ for (k, v), d in zip(self._learner_connection.items(), split_data):
+ learn_task[k] = v.new_task({'name': task['name'], 'data': d})
+ learn_task[k].start()
+ for k, v in learn_task.items():
+ v.join()
+ # TODO deal with task fail
+ info_list = [v.result for v in learn_task.values()]
+ # Merge learn info through ``merge_info`` method.
+ merged_info = self.merge_info(info_list)
+ return merged_info
+
+ @staticmethod
+ def merge_info(info: list) -> dict:
+ homogeneous_keys = ['learner_step', 'buffer_id', 'task_id', 'learner_done']
+ elem = info[0]
+ if elem is None:
+ return info
+ elif isinstance(elem, numbers.Integral) or isinstance(elem, str) or isinstance(elem, float):
+ return info
+ elif isinstance(elem, list) or isinstance(elem, tuple):
+ return list(reduce(lambda x, y: x + y, info))
+ elif isinstance(elem, dict):
+ ret = {}
+ for k in elem.keys():
+ if k in homogeneous_keys:
+ ret[k] = elem[k]
+ else:
+ ret[k] = LearnerAggregator.merge_info([e[k] for e in info])
+ return ret
+ else:
+ raise TypeError("not support type: {}".format(type(elem)))
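+
+ # A small worked example of ``merge_info`` (values are illustrative only):
+ #
+ # merge_info([
+ # {'learner_step': 10, 'cur_lr': 1e-3, 'priority': [0.1, 0.2]},
+ # {'learner_step': 10, 'cur_lr': 2e-3, 'priority': [0.3]},
+ # ]) == {'learner_step': 10, 'cur_lr': [1e-3, 2e-3], 'priority': [0.1, 0.2, 0.3]}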
+
+ def _new_connection_learner(self, learner_id: str, learner_host: str, learner_port: int) -> None:
+ start_time = time.time()
+ conn = None
+ while time.time() - start_time <= self._max_retry_second and not self._end_flag:
+ try:
+ if conn is None or not conn.is_connected:
+ conn = self._master.new_connection(learner_id, learner_host, learner_port)
+ conn.connect()
+ assert conn.is_connected
+ self._learner_connection[learner_id] = conn
+ self._world_size += 1
+ break
+ except Exception as e:
+ self._logger.error(
+ f"learner({learner_id}) connection start error:\n" + ''.join(traceback.format_tb(e.__traceback__)) +
+ repr(e) + '\nAuto Retry...'
+ )
+ time.sleep(2)
+
+ if learner_id in self._learner_connection:
+ self._logger.info(f"Succeed to connect to learner({learner_id})")
+ else:
+ self._logger.info(f"Fail to connect to learner({learner_id})")
+ self._failed_learner_conn.add(learner_id)
+
+ def _update_connection_learner(self, cur_learners) -> None:
+ conn_learners = list(self._learner_connection.keys())
+ new_c = set(cur_learners) - set(conn_learners)
+ del_c = set(conn_learners) - (set(cur_learners) | self._failed_learner_conn)
+ # conns which have terminated in server side, clear up
+ self._failed_learner_conn = self._failed_learner_conn & set(cur_learners)
+
+ # connect to each new learner
+ for learner_id in new_c:
+ learner_host, learner_port = learner_id.split(':')
+ self._new_connection_learner(learner_id, learner_host, int(learner_port))
+
+ for learner_id in del_c:
+ if learner_id in conn_learners:
+ if self._learner_connection[learner_id].is_connected:
+ conn = self._learner_connection.pop(learner_id)
+ conn.disconnect()
+ assert not conn.is_connected
+ else:
+ # Skip the disconnect operation, since the pod will be terminated by the server;
+ # just drop the connection.
+ self._learner_connection.pop(learner_id)
+
+ def _period_sync_with_server(self) -> None:
+ while not self._end_flag:
+ # First: send failed list to notify server which replicas are failed, then terminate such replicas.
+ if len(self._failed_learner_conn) > 0:
+ learner_conn = []
+ for replica_conn in self._failed_learner_conn:
+ dns_name = replica_conn.split(":")[0]
+ pod_name_list = dns_name.split(".")[:-1]
+ pod_name = ".".join(pod_name_list)
+ if pod_name not in learner_conn:
+ learner_conn.append(pod_name)
+ success, _, message, _ = self._operator_server.post_replicas_failed(learners=list(learner_conn))
+ if success:
+ # do not update learner instantly, update at /GET replicas
+ self._failed_learner_conn.clear()
+ else:
+ self._logger.error("Failed to send failed list to server, message: {}".format(message))
+
+ # get list from server
+ success, _, message, data = self._operator_server.get_replicas()
+ if success:
+ cur_learners = data["learners"]
+ # self._logger.info("current list:", cur_learners)
+ self._update_connection_learner(cur_learners)
+ self._init_conn_flag = True
+ else:
+ self._logger.error("Failed to sync with server, message: {}".format(message))
+
+ time.sleep(3)
diff --git a/DI-engine/ding/worker/adapter/tests/test_learner_aggregator.py b/DI-engine/ding/worker/adapter/tests/test_learner_aggregator.py
new file mode 100644
index 0000000000000000000000000000000000000000..511d8de7ef876b3473fb06cf030085766f1d371f
--- /dev/null
+++ b/DI-engine/ding/worker/adapter/tests/test_learner_aggregator.py
@@ -0,0 +1,105 @@
+from ding.worker.adapter.learner_aggregator import LearnerAggregator
+from typing import Union
+import numpy as np
+import pytest
+from easydict import EasyDict
+
+from ding.interaction import Master, Slave, TaskFail
+from ding.interaction.master.task import TaskStatus
+from ding.utils import build_logger
+
+
+class LearnerSlave(Slave):
+
+ def __init__(self, id: int, *args, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self.batch_size = 32
+ self.learner_step = np.random.randint(100 * id, 100 * id + 100)
+ self.buffer_id = "buffer_" + str(np.random.randint(10 * id, 10 * id + 10))
+ self.task_id = "task_" + str(np.random.randint(10 * id, 10 * id + 10))
+ self.learner_done = True if np.random.rand() < 0.5 else False
+
+ def _process_task(self, task: dict) -> Union[dict, TaskFail]:
+ task_name = task['name']
+ if task_name == 'resource':
+ return {'gpu': 1}
+ elif task_name == 'learner_start_task':
+ return {'message': 'learner task has started'}
+ elif task_name == 'learner_get_data_task':
+ return {'batch_size': self.batch_size}
+ elif task_name == 'learner_learn_task':
+ return {
+ 'learner_step': self.learner_step,
+ 'buffer_id': self.buffer_id,
+ 'task_id': self.task_id,
+ 'learner_done': self.learner_done,
+ 'a_list': [1, 2],
+ }
+ else:
+ raise TaskFail(result={'message': 'task name error'}, message='illegal learner task <{}>'.format(task_name))
+
+
+@pytest.mark.unittest
+def test_learner_aggregator():
+ learner_slaves = [LearnerSlave(i, '0.0.0.0', 19900 + i) for i in range(4)]
+ for learner_slave in learner_slaves:
+ learner_slave.start()
+ la_cfg = EasyDict(
+ master=dict(
+ host='0.0.0.0',
+ port=19999,
+ ),
+ slave=dict(
+ host='0.0.0.0',
+ port=18800,
+ ),
+ learner=dict(
+ learner0=('learner0', '0.0.0.0', 19900),
+ learner1=('learner1', '0.0.0.0', 19901),
+ learner2=('learner2', '0.0.0.0', 19902),
+ learner3=('learner3', '0.0.0.0', 19903),
+ )
+ )
+ learner_aggregator = LearnerAggregator(la_cfg)
+ learner_aggregator.start()
+ with Master('0.0.0.0', 18888) as master: # coordinator master
+ master.ping() # True if master launch success, otherwise False
+ with master.new_connection('with_la_slave', '0.0.0.0', 18800) as conn:
+ assert conn.is_connected
+ assert 'with_la_slave' in master
+
+ task = conn.new_task({'name': 'resource'})
+ task.start().join()
+ assert task.result == {'gpu': 4}
+ assert task.status == TaskStatus.COMPLETED
+
+ task = conn.new_task({'name': 'learner_start_task', 'task_info': {}})
+ task.start().join()
+ assert task.result == {'message': 'learner task has started'}
+ assert task.status == TaskStatus.COMPLETED
+
+ task = conn.new_task({'name': 'learner_get_data_task', 'task_info': {}})
+ task.start().join()
+ sum_batch_size = sum([learner.batch_size for learner in learner_slaves])
+ assert task.result['batch_size'] == sum_batch_size
+ assert task.status == TaskStatus.COMPLETED
+
+ task = conn.new_task({'name': 'learner_learn_task', 'data': [i for i in range(sum_batch_size)]})
+ task.start().join()
+ assert task.result['learner_step'] == learner_slaves[0].learner_step
+ assert task.result['buffer_id'] == learner_slaves[0].buffer_id
+ assert task.result['task_id'] == learner_slaves[0].task_id
+ assert task.result['learner_done'] == learner_slaves[0].learner_done
+ assert task.result['a_list'] == [1, 2] * 4
+ assert task.status == TaskStatus.COMPLETED
+
+ task = conn.new_task({'name': 'fake_task', 'task_info': {}})
+ task.start().join()
+ assert task.status == TaskStatus.FAILED
+ assert task.result == {'message': 'task name error'}
+
+ assert learner_aggregator.deal_with_get_resource()['gpu'] == len(learner_slaves)
+
+ learner_aggregator.close()
+ for learner_slave in learner_slaves:
+ learner_slave.close()
diff --git a/DI-engine/ding/worker/collector/__init__.py b/DI-engine/ding/worker/collector/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ccfb172607b6a194b30b60a9687426766a1103e
--- /dev/null
+++ b/DI-engine/ding/worker/collector/__init__.py
@@ -0,0 +1,18 @@
+# serial
+from .base_serial_collector import ISerialCollector, create_serial_collector, get_serial_collector_cls, \
+ to_tensor_transitions
+
+from .sample_serial_collector import SampleSerialCollector
+from .episode_serial_collector import EpisodeSerialCollector
+from .battle_episode_serial_collector import BattleEpisodeSerialCollector
+from .battle_sample_serial_collector import BattleSampleSerialCollector
+
+from .base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor, create_serial_evaluator
+from .interaction_serial_evaluator import InteractionSerialEvaluator
+from .battle_interaction_serial_evaluator import BattleInteractionSerialEvaluator
+from .metric_serial_evaluator import MetricSerialEvaluator, IMetric
+# parallel
+from .base_parallel_collector import BaseParallelCollector, create_parallel_collector, get_parallel_collector_cls
+from .zergling_parallel_collector import ZerglingParallelCollector
+from .marine_parallel_collector import MarineParallelCollector
+from .comm import BaseCommCollector, FlaskFileSystemCollector, create_comm_collector, NaiveCollector
diff --git a/DI-engine/ding/worker/collector/base_parallel_collector.py b/DI-engine/ding/worker/collector/base_parallel_collector.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff9d28700fd5eea8b2034d03de3d97893a2eab7e
--- /dev/null
+++ b/DI-engine/ding/worker/collector/base_parallel_collector.py
@@ -0,0 +1,243 @@
+from typing import Any, Union, Tuple
+from abc import ABC, abstractmethod
+import sys
+from ditk import logging
+import copy
+from collections import namedtuple
+from functools import partial
+from easydict import EasyDict
+import torch
+
+from ding.policy import Policy
+from ding.envs import BaseEnvManager
+from ding.utils.autolog import LoggedValue, LoggedModel, TickTime
+from ding.utils import build_logger, EasyTimer, get_task_uid, import_module, pretty_print, PARALLEL_COLLECTOR_REGISTRY
+from ding.torch_utils import build_log_buffer, to_tensor, to_ndarray
+
+
+class BaseParallelCollector(ABC):
+ """
+ Overview:
+ Abstract baseclass for collector.
+ Interfaces:
+ __init__, info, error, debug, get_finish_info, start, close, _setup_timer, _setup_logger, _iter_after_hook,
+ _policy_inference, _env_step, _process_timestep, _finish_task, _update_policy, _start_thread, _join_thread
+ Property:
+ policy
+ """
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ cfg = EasyDict(copy.deepcopy(cls.config))
+ cfg.cfg_type = cls.__name__ + 'Dict'
+ return cfg
+
+ def __init__(self, cfg: EasyDict) -> None:
+ """
+ Overview:
+ Initialization method.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config dict
+ """
+ self._cfg = cfg
+ self._eval_flag = cfg.eval_flag
+ self._prefix = 'EVALUATOR' if self._eval_flag else 'COLLECTOR'
+ self._collector_uid = get_task_uid()
+ self._logger, self._monitor, self._log_buffer = self._setup_logger()
+ self._end_flag = False
+ self._setup_timer()
+ self._iter_count = 0
+ self.info("\nCFG INFO:\n{}".format(pretty_print(cfg, direct_print=False)))
+
+ def info(self, s: str) -> None:
+ self._logger.info("[{}({})]: {}".format(self._prefix, self._collector_uid, s))
+
+ def debug(self, s: str) -> None:
+ self._logger.debug("[{}({})]: {}".format(self._prefix, self._collector_uid, s))
+
+ def error(self, s: str) -> None:
+ self._logger.error("[{}({})]: {}".format(self._prefix, self._collector_uid, s))
+
+ def _setup_timer(self) -> None:
+ """
+ Overview:
+ Set up TimeWrapper for the base collector. TimeWrapper is a convenient timer wrapper that is easy to use.
+ You can refer to ``ding/utils/time_helper.py``.
+
+ Note:
+ - _policy_inference (:obj:`Callable`): The wrapped method that records the policy inference time.
+ - _env_step (:obj:`Callable`): The wrapped method that records the environment step time.
+ """
+ self._timer = EasyTimer()
+
+ def policy_wrapper(fn):
+
+ def wrapper(*args, **kwargs):
+ with self._timer:
+ ret = fn(*args, **kwargs)
+ self._log_buffer['policy_time'] = self._timer.value
+ return ret
+
+ return wrapper
+
+ def env_wrapper(fn):
+
+ def wrapper(*args, **kwargs):
+ with self._timer:
+ ret = fn(*args, **kwargs)
+ size = sys.getsizeof(ret) / (1024 * 1024) # MB
+ self._log_buffer['env_time'] = self._timer.value
+ self._log_buffer['timestep_size'] = size
+ self._log_buffer['norm_env_time'] = self._timer.value / size
+ return ret
+
+ return wrapper
+
+ self._policy_inference = policy_wrapper(self._policy_inference)
+ self._env_step = env_wrapper(self._env_step)
+
+ def _setup_logger(self) -> Tuple[logging.Logger, 'TickMonitor', 'LogDict']: # noqa
+ """
+ Overview:
+ Set up the logger for the base collector. It includes a text logger, a tick monitor and a log buffer dict.
+ Returns:
+ - logger (:obj:`logging.Logger`): logger that displays terminal output
+ - monitor (:obj:`TickMonitor`): monitor that records related info of one interaction with the env
+ - log_buffer (:obj:`LogDict`): log buffer dict
+ """
+ path = './{}/log/{}'.format(self._cfg.exp_name, self._prefix.lower())
+ name = '{}'.format(self._collector_uid)
+ logger, _ = build_logger(path, name, need_tb=False)
+ monitor = TickMonitor(TickTime(), expire=self._cfg.print_freq * 2)
+ log_buffer = build_log_buffer()
+ return logger, monitor, log_buffer
+
+ def start(self) -> None:
+ self._end_flag = False
+ self._update_policy()
+ self._start_thread()
+ while not self._end_flag:
+ obs = self._env_manager.ready_obs
+ obs = to_tensor(obs, dtype=torch.float32)
+ action = self._policy_inference(obs)
+ action = to_ndarray(action)
+ timestep = self._env_step(action)
+ timestep = to_tensor(timestep, dtype=torch.float32)
+ self._process_timestep(timestep)
+ self._iter_after_hook()
+ if self._env_manager.done:
+ break
+
+ def close(self) -> None:
+ if self._end_flag:
+ return
+ self._end_flag = True
+ self._join_thread()
+
+ def _iter_after_hook(self):
+ # log_buffer -> tick_monitor -> monitor.step
+ for k, v in self._log_buffer.items():
+ setattr(self._monitor, k, v)
+ self._monitor.time.step()
+ # Print info
+ if self._iter_count % self._cfg.print_freq == 0:
+ self.debug('{}TimeStep{}{}'.format('=' * 35, self._iter_count, '=' * 35))
+ # tick_monitor -> var_dict
+ var_dict = {}
+ for k in self._log_buffer:
+ for attr in self._monitor.get_property_attribute(k):
+ k_attr = k + '_' + attr
+ var_dict[k_attr] = getattr(self._monitor, attr)[k]()
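+ # e.g. the 'policy_time' key with its registered 'avg' attribute is reported as 'policy_time_avg'.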
+ self._logger.debug(self._logger.get_tabulate_vars_hor(var_dict))
+ self._log_buffer.clear()
+ self._iter_count += 1
+
+ @abstractmethod
+ def get_finish_info(self) -> dict:
+ raise NotImplementedError
+
+ @abstractmethod
+ def __repr__(self) -> str:
+ raise NotImplementedError
+
+ @abstractmethod
+ def _policy_inference(self, obs: Any) -> Any:
+ raise NotImplementedError
+
+ @abstractmethod
+ def _env_step(self, action: Any) -> Any:
+ raise NotImplementedError
+
+ @abstractmethod
+ def _process_timestep(self, timestep: namedtuple) -> None:
+ raise NotImplementedError
+
+ @abstractmethod
+ def _update_policy(self) -> None:
+ raise NotImplementedError
+
+ def _start_thread(self) -> None:
+ pass
+
+ def _join_thread(self) -> None:
+ pass
+
+ @property
+ def policy(self) -> Policy:
+ return self._policy
+
+ @policy.setter
+ def policy(self, _policy: Policy) -> None:
+ self._policy = _policy
+
+ @property
+ def env_manager(self) -> BaseEnvManager:
+ return self._env_manager
+
+ @env_manager.setter
+ def env_manager(self, _env_manager: BaseEnvManager) -> None:
+ self._env_manager = _env_manager
+
+
+def create_parallel_collector(cfg: EasyDict) -> BaseParallelCollector:
+ import_module(cfg.get('import_names', []))
+ return PARALLEL_COLLECTOR_REGISTRY.build(cfg.type, cfg=cfg)
+
+
+def get_parallel_collector_cls(cfg: EasyDict) -> type:
+ import_module(cfg.get('import_names', []))
+ return PARALLEL_COLLECTOR_REGISTRY.get(cfg.type)
+
+
+class TickMonitor(LoggedModel):
+ """
+ Overview:
+ TickMonitor monitors related info of one interaction with the env.
+ The info includes: policy_time, env_time, norm_env_time, timestep_size...
+ These info variables are first recorded in ``log_buffer``; then in ``self._iter_after_hook`` the variables
+ in this monitor are updated from ``log_buffer`` and printed to the text logger and tensorboard logger.
+ Interface:
+ __init__, fixed_time, current_time, freeze, unfreeze, register_attribute_value, __getattr__
+ Property:
+ time, expire
+ """
+ policy_time = LoggedValue(float)
+ env_time = LoggedValue(float)
+ timestep_size = LoggedValue(float)
+ norm_env_time = LoggedValue(float)
+
+ def __init__(self, time_: 'BaseTime', expire: Union[int, float]): # noqa
+ LoggedModel.__init__(self, time_, expire)
+ self.__register()
+
+ def __register(self):
+
+ def __avg_func(prop_name: str) -> float:
+ records = self.range_values[prop_name]()
+ _list = [_value for (_begin_time, _end_time), _value in records]
+ return sum(_list) / len(_list) if len(_list) != 0 else 0
+
+ self.register_attribute_value('avg', 'policy_time', partial(__avg_func, prop_name='policy_time'))
+ self.register_attribute_value('avg', 'env_time', partial(__avg_func, prop_name='env_time'))
+ self.register_attribute_value('avg', 'timestep_size', partial(__avg_func, prop_name='timestep_size'))
+ self.register_attribute_value('avg', 'norm_env_time', partial(__avg_func, prop_name='norm_env_time'))
diff --git a/DI-engine/ding/worker/collector/base_serial_collector.py b/DI-engine/ding/worker/collector/base_serial_collector.py
new file mode 100644
index 0000000000000000000000000000000000000000..caa88410af33c8fb74d93883c0c92fe9c89f5453
--- /dev/null
+++ b/DI-engine/ding/worker/collector/base_serial_collector.py
@@ -0,0 +1,229 @@
+from abc import ABC, abstractmethod, abstractproperty
+from typing import List, Dict, Any, Optional, Union
+from collections import namedtuple
+from easydict import EasyDict
+import copy
+
+from ding.envs import BaseEnvManager
+from ding.utils import SERIAL_COLLECTOR_REGISTRY, import_module
+from ding.torch_utils import to_tensor
+
+INF = float("inf")
+
+
+class ISerialCollector(ABC):
+ """
+ Overview:
+ Abstract baseclass for serial collector.
+ Interfaces:
+ default_config, reset_env, reset_policy, reset, collect
+ Property:
+ envstep
+ """
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ """
+ Overview:
+ Get collector's default config. We merge collector's default config with other default configs\
+ and user's config to get the final config.
+ Return:
+ cfg: (:obj:`EasyDict`): collector's default config
+ """
+ cfg = EasyDict(copy.deepcopy(cls.config))
+ cfg.cfg_type = cls.__name__ + 'Dict'
+ return cfg
+
+ @abstractmethod
+ def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
+ """
+ Overview:
+ Reset the collector's environment. In some cases, the collector needs to use the same policy to collect \
+ data in different environments. We can use reset_env to reset the environment.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def reset_policy(self, _policy: Optional[namedtuple] = None) -> None:
+ """
+ Overview:
+ Reset the collector's policy. In some cases, the collector needs to work in the same environment but use a\
+ different policy to collect data. We can use reset_policy to reset the policy.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None:
+ """
+ Overview:
+ Reset the collector's policy and environment, and use the new policy and environment to collect data.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def collect(self, per_collect_target: Any) -> List[Any]:
+ """
+ Overview:
+ Collect and return the corresponding data according to the specified target. \
+ The target has different definitions in episode mode and sample mode.
+ """
+ raise NotImplementedError
+
+ @abstractproperty
+ def envstep(self) -> int:
+ """
+ Overview:
+ Get the total envstep num.
+ """
+ raise NotImplementedError
+
+
+def create_serial_collector(cfg: EasyDict, **kwargs) -> ISerialCollector:
+ """
+ Overview:
+ Create a specific collector instance based on the config.
+ """
+ import_module(cfg.get('import_names', []))
+ return SERIAL_COLLECTOR_REGISTRY.build(cfg.type, cfg=cfg, **kwargs)
+
+
+def get_serial_collector_cls(cfg: EasyDict) -> type:
+ """
+ Overview:
+ Get the specific collector class according to the config.
+ """
+ assert hasattr(cfg, 'type'), "{}-{}-{}".format(type(cfg), cfg.keys(), cfg['type'])
+ import_module(cfg.get('import_names', []))
+ return SERIAL_COLLECTOR_REGISTRY.get(cfg.type)
+
+
+class CachePool(object):
+ """
+ Overview:
+ CachePool is the repository of cache items.
+ Interfaces:
+ __init__, update, __getitem__, reset
+ """
+
+ def __init__(self, name: str, env_num: int, deepcopy: bool = False) -> None:
+ """
+ Overview:
+ Initialization method.
+ Arguments:
+ - name (:obj:`str`): name of cache
+ - env_num (:obj:`int`): number of environments
+ - deepcopy (:obj:`bool`): whether to deepcopy data
+ """
+ self._pool = [None for _ in range(env_num)]
+ # TODO(nyz) whether must use deepcopy
+ self._deepcopy = deepcopy
+
+ def update(self, data: Union[Dict[int, Any], list]) -> None:
+ """
+ Overview:
+ Update elements in cache pool.
+ Arguments:
+ - data (:obj:`Union[Dict[int, Any], list]`): A dict, or a list of dicts, containing index-value pairs \
+ to update. Each key is an index in the cache pool and each value is the new element.
+ """
+ if isinstance(data, dict):
+ data = [data]
+ for index in range(len(data)):
+ for i in data[index].keys():
+ d = data[index][i]
+ if self._deepcopy:
+ copy_d = copy.deepcopy(d)
+ else:
+ copy_d = d
+ if index == 0:
+ self._pool[i] = [copy_d]
+ else:
+ self._pool[i].append(copy_d)
+
+ def __getitem__(self, idx: int) -> Any:
+ """
+ Overview:
+ Get item in cache pool.
+ Arguments:
+ - idx (:obj:`int`): The index of the item we need to get.
+ Return:
+ - item (:obj:`Any`): The item we get.
+ """
+ data = self._pool[idx]
+ if data is not None and len(data) == 1:
+ data = data[0]
+ return data
+
+ def reset(self, idx: int) -> None:
+ """
+ Overview:
+ Reset the cache pool.
+ Arguments:
+ - idx (:obj:`int`): The index of the position we need to reset.
+ """
+ self._pool[idx] = None
+
+
+class TrajBuffer(list):
+ """
+ Overview:
+ TrajBuffer is used to store traj_len pieces of transitions.
+ Interfaces:
+ __init__, append
+ """
+
+ def __init__(self, maxlen: int, *args, deepcopy: bool = False, **kwargs) -> None:
+ """
+ Overview:
+ Initialize the TrajBuffer.
+ Arguments:
+ - maxlen (:obj:`int`): The maximum length of trajectory buffer.
+ - deepcopy (:obj:`bool`): Whether to deepcopy data when performing operations.
+ """
+ self._maxlen = maxlen
+ self._deepcopy = deepcopy
+ super().__init__(*args, **kwargs)
+
+ def append(self, data: Any) -> None:
+ """
+ Overview:
+ Append data to trajBuffer.
+ """
+ if self._maxlen is not None:
+ while len(self) >= self._maxlen:
+ del self[0]
+ if self._deepcopy:
+ data = copy.deepcopy(data)
+ super().append(data)
+
+
+def to_tensor_transitions(data: List[Dict[str, Any]], shallow_copy_next_obs: bool = True) -> List[Dict[str, Any]]:
+ """
+ Overview:
+ Transform the original transitions returned from the env into tensor format.
+ Argument:
+ - data (:obj:`List[Dict[str, Any]]`): The data that will be transformed to tensor.
+ - shallow_copy_next_obs (:obj:`bool`): Whether to shallow copy next_obs. Default: True.
+ Return:
+ - data (:obj:`List[Dict[str, Any]]`): The transformed tensor-like data.
+
+ .. tip::
+ In order to save memory, if there is next_obs in the passed data, we apply special \
+ treatment to next_obs so that the next_obs of each state in the data fragment is \
+ a reference to the next state's obs, and the next_obs of the last state is its own next_obs. \
+ Besides, we set transform_scalar to False to avoid the extra ``.item()`` operation.
+ """
+ if 'next_obs' not in data[0]:
+ return to_tensor(data, transform_scalar=False)
+ else:
+ # to_tensor will assign the separate memory to next_obs, if shallow_copy_next_obs is True,
+ # we can add ignore_keys to avoid this data copy for saving memory of next_obs.
+ if shallow_copy_next_obs:
+ data = to_tensor(data, ignore_keys=['next_obs'], transform_scalar=False)
+ for i in range(len(data) - 1):
+ data[i]['next_obs'] = data[i + 1]['obs']
+ data[-1]['next_obs'] = to_tensor(data[-1]['next_obs'], transform_scalar=False)
+ return data
+ else:
+ data = to_tensor(data, transform_scalar=False)
+ return data
diff --git a/DI-engine/ding/worker/collector/base_serial_evaluator.py b/DI-engine/ding/worker/collector/base_serial_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5a31319ece00154338635cb9114941fe9b5ad7a
--- /dev/null
+++ b/DI-engine/ding/worker/collector/base_serial_evaluator.py
@@ -0,0 +1,220 @@
+from typing import Any, Optional, Callable, Tuple
+from abc import ABC, abstractmethod
+from collections import namedtuple, deque
+from easydict import EasyDict
+import copy
+import numpy as np
+import torch
+
+from ding.utils import SERIAL_EVALUATOR_REGISTRY, import_module, lists_to_dicts
+from ding.torch_utils import to_tensor, to_ndarray, tensor_to_list
+
+
+class ISerialEvaluator(ABC):
+ """
+ Overview:
+ Basic interface class for serial evaluator.
+ Interfaces:
+ reset, reset_policy, reset_env, close, should_eval, eval
+ Property:
+ env, policy
+ """
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ """
+ Overview:
+ Get evaluator's default config. We merge evaluator's default config with other default configs\
+ and user's config to get the final config.
+ Return:
+ cfg: (:obj:`EasyDict`): evaluator's default config
+ """
+ cfg = EasyDict(copy.deepcopy(cls.config))
+ cfg.cfg_type = cls.__name__ + 'Dict'
+ return cfg
+
+ @abstractmethod
+ def reset_env(self, _env: Optional[Any] = None) -> None:
+ raise NotImplementedError
+
+ @abstractmethod
+ def reset_policy(self, _policy: Optional[namedtuple] = None) -> None:
+ raise NotImplementedError
+
+ @abstractmethod
+ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[Any] = None) -> None:
+ raise NotImplementedError
+
+ @abstractmethod
+ def close(self) -> None:
+ raise NotImplementedError
+
+ @abstractmethod
+ def should_eval(self, train_iter: int) -> bool:
+ raise NotImplementedError
+
+ @abstractmethod
+ def eval(
+ self,
+ save_ckpt_fn: Callable = None,
+ train_iter: int = -1,
+ envstep: int = -1,
+ n_episode: Optional[int] = None
+ ) -> Any:
+ raise NotImplementedError
+
+
+def create_serial_evaluator(cfg: EasyDict, **kwargs) -> ISerialEvaluator:
+ """
+ Overview:
+ Create a specific evaluator instance based on the config.
+ """
+ import_module(cfg.get('import_names', []))
+ if 'type' not in cfg:
+ cfg.type = 'interaction'
+ return SERIAL_EVALUATOR_REGISTRY.build(cfg.type, cfg=cfg, **kwargs)
+
+
+class VectorEvalMonitor(object):
+ """
+ Overview:
+ In some cases, different environments in the evaluator may collect episodes of different lengths. For example, \
+ suppose we want to collect 12 episodes in the evaluator but only have 5 environments; if we did not do \
+ anything, we would likely get more short episodes than long ones. As a result, \
+ the average reward would be biased and possibly inaccurate. We use VectorEvalMonitor to solve this problem.
+ Interfaces:
+ __init__, is_finished, update_info, update_reward, update_video, get_video, get_episode_return, \
+ get_latest_reward, get_current_episode, get_episode_info
+ """
+
+ def __init__(self, env_num: int, n_episode: int) -> None:
+ """
+ Overview:
+ Init method. According to the number of episodes and the number of environments, determine how many \
+ episodes need to be opened for each environment, and initialize the reward, info and other \
+ information.
+ Arguments:
+ - env_num (:obj:`int`): the number of environments used in the evaluation
+ - n_episode (:obj:`int`): the total number of episodes to evaluate
+ """
+ assert n_episode >= env_num, "n_episode < env_num, please decrease the number of eval env"
+ self._env_num = env_num
+ self._n_episode = n_episode
+ each_env_episode = [n_episode // env_num for _ in range(env_num)]
+ for i in range(n_episode % env_num):
+ each_env_episode[i] += 1
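+ # e.g. env_num=5, n_episode=12 (hypothetical values) -> each_env_episode = [3, 3, 2, 2, 2]:
+ # the first n_episode % env_num environments run one extra episode.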
+ self._video = {
+ env_id: deque([[] for _ in range(maxlen)], maxlen=maxlen)
+ for env_id, maxlen in enumerate(each_env_episode)
+ }
+ self._reward = {env_id: deque(maxlen=maxlen) for env_id, maxlen in enumerate(each_env_episode)}
+ self._info = {env_id: deque(maxlen=maxlen) for env_id, maxlen in enumerate(each_env_episode)}
+
+ def is_finished(self) -> bool:
+ """
+ Overview:
+ Determine whether the evaluator has completed the work.
+ Return:
+ - result: (:obj:`bool`): whether the evaluator has completed the work
+ """
+ return all([len(v) == v.maxlen for v in self._reward.values()])
+
+ def update_info(self, env_id: int, info: Any) -> None:
+ """
+ Overview:
+ Update the information of the environment indicated by env_id.
+ Arguments:
+ - env_id: (:obj:`int`): the id of the environment we need to update information
+ - info: (:obj:`Any`): the information we need to update
+ """
+ info = tensor_to_list(info)
+ self._info[env_id].append(info)
+
+ def update_reward(self, env_id: int, reward: Any) -> None:
+ """
+ Overview:
+ Update the reward indicated by env_id.
+ Arguments:
+ - env_id: (:obj:`int`): the id of the environment we need to update the reward
+ - reward: (:obj:`Any`): the reward we need to update
+ """
+ if isinstance(reward, torch.Tensor):
+ reward = reward.item()
+ self._reward[env_id].append(reward)
+
+ def update_video(self, imgs):
+ for env_id, img in imgs.items():
+ if len(self._reward[env_id]) == self._reward[env_id].maxlen:
+ continue
+ self._video[env_id][len(self._reward[env_id])].append(img)
+
+ def get_video(self):
+ """
+ Overview:
+ Convert list of videos into [N, T, C, H, W] tensor, containing
+ worst, median, best evaluation trajectories for video logging.
+ """
+ videos = sum([list(v) for v in self._video.values()], [])
+ videos = [np.transpose(np.stack(video, 0), [0, 3, 1, 2]) for video in videos]
+ sortarg = np.argsort(self.get_episode_return())
+ # worst, median(s), best
+ if len(sortarg) == 1:
+ idxs = [sortarg[0]]
+ elif len(sortarg) == 2:
+ idxs = [sortarg[0], sortarg[-1]]
+ elif len(sortarg) == 3:
+ idxs = [sortarg[0], sortarg[len(sortarg) // 2], sortarg[-1]]
+ else:
+ # TensorboardX pad the number of videos to even numbers with black frames,
+ # therefore providing even number of videos prevents black frames being rendered.
+ idxs = [sortarg[0], sortarg[len(sortarg) // 2 - 1], sortarg[len(sortarg) // 2], sortarg[-1]]
+ videos = [videos[idx] for idx in idxs]
+ # pad videos to the same length with last frames
+ max_length = max(video.shape[0] for video in videos)
+ for i in range(len(videos)):
+ if videos[i].shape[0] < max_length:
+ padding = np.tile([videos[i][-1]], (max_length - videos[i].shape[0], 1, 1, 1))
+ videos[i] = np.concatenate([videos[i], padding], 0)
+ videos = np.stack(videos, 0)
+ assert len(videos.shape) == 5, 'Need [N, T, C, H, W] input tensor for video logging!'
+ return videos
+
+ def get_episode_return(self) -> list:
+ """
+ Overview:
+ Gather the episode returns recorded for all environments into a single list.
+ """
+ return sum([list(v) for v in self._reward.values()], []) # sum(iterable, start)
+
+ def get_latest_reward(self, env_id: int) -> int:
+ """
+ Overview:
+ Get the latest reward of a certain environment.
+ Arguments:
+ - env_id: (:obj:`int`): the id of the environment we need to get reward.
+ """
+ return self._reward[env_id][-1]
+
+ def get_current_episode(self) -> int:
+ """
+ Overview:
+ Get the number of episodes finished so far, i.e. which episode the evaluator is currently executing.
+ """
+ return sum([len(v) for v in self._reward.values()])
+
+ def get_episode_info(self) -> dict:
+ """
+ Overview:
+ Get all episode information, such as total return of one episode.
+ """
+ if len(self._info[0]) == 0:
+ return None
+ else:
+ total_info = sum([list(v) for v in self._info.values()], [])
+ total_info = lists_to_dicts(total_info)
+ new_dict = {}
+ for k in total_info.keys():
+ if np.isscalar(total_info[k][0]):
+ new_dict[k + '_mean'] = np.mean(total_info[k])
+ total_info.update(new_dict)
+ return total_info
diff --git a/DI-engine/ding/worker/collector/battle_episode_serial_collector.py b/DI-engine/ding/worker/collector/battle_episode_serial_collector.py
new file mode 100644
index 0000000000000000000000000000000000000000..6609adcaea8f0b970881978d7cad1a73bc8b983c
--- /dev/null
+++ b/DI-engine/ding/worker/collector/battle_episode_serial_collector.py
@@ -0,0 +1,339 @@
+from typing import Optional, Any, List, Tuple
+from collections import namedtuple, deque
+from easydict import EasyDict
+import numpy as np
+import torch
+
+from ding.envs import BaseEnvManager
+from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, dicts_to_lists
+from ding.torch_utils import to_tensor, to_ndarray
+from .base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF, to_tensor_transitions
+
+
+@SERIAL_COLLECTOR_REGISTRY.register('episode_1v1')
+class BattleEpisodeSerialCollector(ISerialCollector):
+ """
+ Overview:
+ Episode collector (n_episode) for a two-policy battle
+ Interfaces:
+ __init__, reset, reset_env, reset_policy, collect, close
+ Property:
+ envstep
+ """
+
+ config = dict(deepcopy_obs=False, transform_obs=False, collect_print_freq=100, get_train_sample=False)
+
+ def __init__(
+ self,
+ cfg: EasyDict,
+ env: BaseEnvManager = None,
+ policy: List[namedtuple] = None,
+ tb_logger: 'SummaryWriter' = None, # noqa
+ exp_name: Optional[str] = 'default_experiment',
+ instance_name: Optional[str] = 'collector'
+ ) -> None:
+ """
+ Overview:
+ Initialization method.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config dict
+ - env (:obj:`BaseEnvManager`): the subclass of vectorized env_manager(BaseEnvManager)
+ - policy (:obj:`List[namedtuple]`): the api namedtuple of collect_mode policy
+ - tb_logger (:obj:`SummaryWriter`): tensorboard handle
+ """
+ self._exp_name = exp_name
+ self._instance_name = instance_name
+ self._collect_print_freq = cfg.collect_print_freq
+ self._deepcopy_obs = cfg.deepcopy_obs
+ self._transform_obs = cfg.transform_obs
+ self._cfg = cfg
+ self._timer = EasyTimer()
+ self._end_flag = False
+
+ if tb_logger is not None:
+ self._logger, _ = build_logger(
+ path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False
+ )
+ self._tb_logger = tb_logger
+ else:
+ self._logger, self._tb_logger = build_logger(
+ path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name
+ )
+ self._traj_len = float("inf")
+ self.reset(policy, env)
+
+ def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
+ """
+ Overview:
+ Reset the environment.
+ If _env is None, reset the old environment.
+ If _env is not None, replace the old environment in the collector with the new passed \
+ in environment and launch.
+ Arguments:
+ - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
+ env_manager(BaseEnvManager)
+ """
+ if _env is not None:
+ self._env = _env
+ self._env.launch()
+ self._env_num = self._env.env_num
+ else:
+ self._env.reset()
+
+ def reset_policy(self, _policy: Optional[List[namedtuple]] = None) -> None:
+ """
+ Overview:
+ Reset the policy.
+ If _policy is None, reset the old policy.
+ If _policy is not None, replace the old policy in the collector with the new passed in policy.
+ Arguments:
+ - policy (:obj:`Optional[List[namedtuple]]`): the api namedtuple of collect_mode policy
+ """
+ assert hasattr(self, '_env'), "please set env first"
+ if _policy is not None:
+ assert len(_policy) == 2, "1v1 episode collector needs 2 policy, but found {}".format(len(_policy))
+ self._policy = _policy
+ self._default_n_episode = _policy[0].get_attribute('cfg').collect.get('n_episode', None)
+ self._unroll_len = _policy[0].get_attribute('unroll_len')
+ self._on_policy = _policy[0].get_attribute('cfg').on_policy
+ self._traj_len = INF
+ self._logger.debug(
+ 'Set default n_episode mode(n_episode({}), env_num({}), traj_len({}))'.format(
+ self._default_n_episode, self._env_num, self._traj_len
+ )
+ )
+ for p in self._policy:
+ p.reset()
+
+ def reset(self, _policy: Optional[List[namedtuple]] = None, _env: Optional[BaseEnvManager] = None) -> None:
+ """
+ Overview:
+ Reset the environment and policy.
+ If _env is None, reset the old environment.
+ If _env is not None, replace the old environment in the collector with the new passed \
+ in environment and launch.
+ If _policy is None, reset the old policy.
+ If _policy is not None, replace the old policy in the collector with the new passed in policy.
+ Arguments:
+ - policy (:obj:`Optional[List[namedtuple]]`): the api namedtuple of collect_mode policy
+ - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
+ env_manager(BaseEnvManager)
+ """
+ if _env is not None:
+ self.reset_env(_env)
+ if _policy is not None:
+ self.reset_policy(_policy)
+
+ self._obs_pool = CachePool('obs', self._env_num, deepcopy=self._deepcopy_obs)
+ self._policy_output_pool = CachePool('policy_output', self._env_num)
+ # _traj_buffer is {env_id: {policy_id: TrajBuffer}}, is used to store traj_len pieces of transitions
+ self._traj_buffer = {
+ env_id: {policy_id: TrajBuffer(maxlen=self._traj_len)
+ for policy_id in range(2)}
+ for env_id in range(self._env_num)
+ }
+ self._env_info = {env_id: {'time': 0., 'step': 0} for env_id in range(self._env_num)}
+
+ self._episode_info = []
+ self._total_envstep_count = 0
+ self._total_episode_count = 0
+ self._total_duration = 0
+ self._last_train_iter = 0
+ self._end_flag = False
+
+ def _reset_stat(self, env_id: int) -> None:
+ """
+ Overview:
+ Reset the collector's state for the given env_id, including the traj_buffer, obs_pool, policy_output_pool\
+ and env_info. You can refer to base_serial_collector\
+ for more details.
+ Arguments:
+ - env_id (:obj:`int`): the id where we need to reset the collector's state
+ """
+ for i in range(2):
+ self._traj_buffer[env_id][i].clear()
+ self._obs_pool.reset(env_id)
+ self._policy_output_pool.reset(env_id)
+ self._env_info[env_id] = {'time': 0., 'step': 0}
+
+ @property
+ def envstep(self) -> int:
+ """
+ Overview:
+ Get the total envstep count.
+ Return:
+ - envstep (:obj:`int`): the total envstep count
+ """
+ return self._total_envstep_count
+
+ def close(self) -> None:
+ """
+ Overview:
+ Close the collector. If end_flag is False, close the environment, flush the tb_logger\
+ and close the tb_logger.
+ """
+ if self._end_flag:
+ return
+ self._end_flag = True
+ self._env.close()
+ self._tb_logger.flush()
+ self._tb_logger.close()
+
+ def __del__(self) -> None:
+ """
+ Overview:
+ Execute the close command and close the collector. __del__ is automatically called to \
+ destroy the collector instance when the collector finishes its work
+ """
+ self.close()
+
+ def collect(self,
+ n_episode: Optional[int] = None,
+ train_iter: int = 0,
+ policy_kwargs: Optional[dict] = None) -> Tuple[List[Any], List[Any]]:
+ """
+ Overview:
+ Collect `n_episode` episodes of data with policy_kwargs, using policies that have been trained for `train_iter` iterations
+ Arguments:
+ - n_episode (:obj:`int`): the number of collecting data episode
+ - train_iter (:obj:`int`): the number of training iteration
+ - policy_kwargs (:obj:`dict`): the keyword args for policy forward
+ Returns:
+ - return_data (:obj:`Tuple[List, List]`): A tuple with training sample(data) and episode info, \
+ the former is a list of collected episodes if get_train_sample is False, \
+ otherwise a list of train samples split by unroll_len.
+ """
+ if n_episode is None:
+ if self._default_n_episode is None:
+ raise RuntimeError("Please specify collect n_episode")
+ else:
+ n_episode = self._default_n_episode
+ assert n_episode >= self._env_num, "Please make sure n_episode >= env_num"
+ if policy_kwargs is None:
+ policy_kwargs = {}
+ collected_episode = 0
+ return_data = [[] for _ in range(2)]
+ return_info = [[] for _ in range(2)]
+ ready_env_id = set()
+ remain_episode = n_episode
+
+ while True:
+ with self._timer:
+ # Get current env obs.
+ obs = self._env.ready_obs
+ new_available_env_id = set(obs.keys()).difference(ready_env_id)
+ ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode]))
+ remain_episode -= min(len(new_available_env_id), remain_episode)
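+ # Only admit as many newly ready envs as episodes remain, so no more than n_episode episodes start in total.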
+ obs = {env_id: obs[env_id] for env_id in ready_env_id}
+ # Policy forward.
+ self._obs_pool.update(obs)
+ if self._transform_obs:
+ obs = to_tensor(obs, dtype=torch.float32)
+ obs = dicts_to_lists(obs)
+ policy_output = [p.forward(obs[i], **policy_kwargs) for i, p in enumerate(self._policy)]
+ self._policy_output_pool.update(policy_output)
+ # Interact with env.
+ actions = {}
+ for env_id in ready_env_id:
+ actions[env_id] = []
+ for output in policy_output:
+ actions[env_id].append(output[env_id]['action'])
+ actions = to_ndarray(actions)
+ timesteps = self._env.step(actions)
+
+ # TODO(nyz) this duration may be inaccurate in async env
+ interaction_duration = self._timer.value / len(timesteps)
+
+ # TODO(nyz) vectorize this for loop
+ for env_id, timestep in timesteps.items():
+ self._env_info[env_id]['step'] += 1
+ self._total_envstep_count += 1
+ with self._timer:
+ for policy_id, policy in enumerate(self._policy):
+ policy_timestep_data = [d[policy_id] if not isinstance(d, bool) else d for d in timestep]
+ policy_timestep = type(timestep)(*policy_timestep_data)
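+ # Index every per-policy field of the shared timestep by policy_id, keeping shared booleans
+ # (e.g. the done flag) as-is, to obtain this policy's view of the timestep.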
+ transition = self._policy[policy_id].process_transition(
+ self._obs_pool[env_id][policy_id], self._policy_output_pool[env_id][policy_id],
+ policy_timestep
+ )
+ transition['collect_iter'] = train_iter
+ self._traj_buffer[env_id][policy_id].append(transition)
+ # prepare data
+ if timestep.done:
+ transitions = to_tensor_transitions(
+ self._traj_buffer[env_id][policy_id], not self._deepcopy_obs
+ )
+ if self._cfg.get_train_sample:
+ train_sample = self._policy[policy_id].get_train_sample(transitions)
+ return_data[policy_id].extend(train_sample)
+ else:
+ return_data[policy_id].append(transitions)
+ self._traj_buffer[env_id][policy_id].clear()
+
+ self._env_info[env_id]['time'] += self._timer.value + interaction_duration
+
+ # If env is done, record episode info and reset
+ if timestep.done:
+ self._total_episode_count += 1
+ info = {
+ 'reward0': timestep.info[0]['eval_episode_return'],
+ 'reward1': timestep.info[1]['eval_episode_return'],
+ 'time': self._env_info[env_id]['time'],
+ 'step': self._env_info[env_id]['step'],
+ }
+ collected_episode += 1
+ self._episode_info.append(info)
+ for i, p in enumerate(self._policy):
+ p.reset([env_id])
+ self._reset_stat(env_id)
+ ready_env_id.remove(env_id)
+ for policy_id in range(2):
+ return_info[policy_id].append(timestep.info[policy_id])
+ if collected_episode >= n_episode:
+ break
+ # log
+ self._output_log(train_iter)
+ return return_data, return_info
+
+ def _output_log(self, train_iter: int) -> None:
+ """
+ Overview:
+ Print the output log information. You can refer to Docs/Best Practice/How to understand\
+ training generated folders/Serial mode/log/collector for more details.
+ Arguments:
+ - train_iter (:obj:`int`): the number of training iteration.
+ """
+ if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0:
+ self._last_train_iter = train_iter
+ episode_count = len(self._episode_info)
+ envstep_count = sum([d['step'] for d in self._episode_info])
+ duration = sum([d['time'] for d in self._episode_info])
+ episode_return0 = [d['reward0'] for d in self._episode_info]
+ episode_return1 = [d['reward1'] for d in self._episode_info]
+ self._total_duration += duration
+ info = {
+ 'episode_count': episode_count,
+ 'envstep_count': envstep_count,
+ 'avg_envstep_per_episode': envstep_count / episode_count,
+ 'avg_envstep_per_sec': envstep_count / duration,
+ 'avg_episode_per_sec': episode_count / duration,
+ 'collect_time': duration,
+ 'reward0_mean': np.mean(episode_return0),
+ 'reward0_std': np.std(episode_return0),
+ 'reward0_max': np.max(episode_return0),
+ 'reward0_min': np.min(episode_return0),
+ 'reward1_mean': np.mean(episode_return1),
+ 'reward1_std': np.std(episode_return1),
+ 'reward1_max': np.max(episode_return1),
+ 'reward1_min': np.min(episode_return1),
+ 'total_envstep_count': self._total_envstep_count,
+ 'total_episode_count': self._total_episode_count,
+ 'total_duration': self._total_duration,
+ }
+ self._episode_info.clear()
+ self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])))
+ for k, v in info.items():
+ self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
+ if k in ['total_envstep_count']:
+ continue
+ self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count)
diff --git a/DI-engine/ding/worker/collector/battle_interaction_serial_evaluator.py b/DI-engine/ding/worker/collector/battle_interaction_serial_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..9d4700ca5740ebea6fd06013d36a5d541329ed90
--- /dev/null
+++ b/DI-engine/ding/worker/collector/battle_interaction_serial_evaluator.py
@@ -0,0 +1,277 @@
+from typing import List, Dict, Any, Optional, Callable, Tuple
+from collections import namedtuple, deque
+from easydict import EasyDict
+from functools import reduce
+import copy
+import numpy as np
+import torch
+
+from ding.utils import build_logger, EasyTimer, deep_merge_dicts, lists_to_dicts, dicts_to_lists, \
+ SERIAL_EVALUATOR_REGISTRY
+from ding.envs import BaseEnvManager
+from ding.torch_utils import to_tensor, to_ndarray, tensor_to_list, to_item
+from .base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor
+
+
+@SERIAL_EVALUATOR_REGISTRY.register('battle_interaction')
+class BattleInteractionSerialEvaluator(ISerialEvaluator):
+ """
+ Overview:
+ Multiple player battle evaluator class.
+ Interfaces:
+ __init__, reset, reset_policy, reset_env, close, should_eval, eval
+ Property:
+ env, policy
+ """
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ """
+ Overview:
+ Get evaluator's default config. We merge evaluator's default config with other default configs\
+ and user's config to get the final config.
+ Return:
+ cfg: (:obj:`EasyDict`): evaluator's default config
+ """
+ cfg = EasyDict(copy.deepcopy(cls.config))
+ cfg.cfg_type = cls.__name__ + 'Dict'
+ return cfg
+
+ config = dict(
+ # Evaluate every "eval_freq" training iterations.
+ eval_freq=50,
+ )
+
+ def __init__(
+ self,
+ cfg: dict,
+ env: BaseEnvManager = None,
+ policy: List[namedtuple] = None,
+ tb_logger: 'SummaryWriter' = None, # noqa
+ exp_name: Optional[str] = 'default_experiment',
+ instance_name: Optional[str] = 'evaluator',
+ ) -> None:
+ """
+ Overview:
+ Init method. Load config and use ``self._cfg`` setting to build common serial evaluator components,
+ e.g. logger helper, timer.
+ Policy is not initialized here, but set afterwards through ``reset`` or ``reset_policy``.
+ Arguments:
+ - cfg (:obj:`EasyDict`)
+ """
+ self._cfg = cfg
+ self._exp_name = exp_name
+ self._instance_name = instance_name
+ if tb_logger is not None:
+ self._logger, _ = build_logger(
+ path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False
+ )
+ self._tb_logger = tb_logger
+ else:
+ self._logger, self._tb_logger = build_logger(
+ path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name
+ )
+ self.reset(policy, env)
+
+ self._timer = EasyTimer()
+ self._default_n_episode = cfg.n_episode
+ self._stop_value = cfg.stop_value
+
+ def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
+ """
+ Overview:
+ Reset the evaluator's environment. In some cases, the evaluator needs to use the same policy in different \
+ environments. We can use reset_env to reset the environment.
+ If _env is None, reset the old environment.
+ If _env is not None, replace the old environment in the evaluator with the \
+ new passed in environment and launch.
+ Arguments:
+ - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
+ env_manager(BaseEnvManager)
+ """
+ if _env is not None:
+ self._env = _env
+ self._env.launch()
+ self._env_num = self._env.env_num
+ else:
+ self._env.reset()
+
+ def reset_policy(self, _policy: Optional[List[namedtuple]] = None) -> None:
+ """
+ Overview:
+ Reset the evaluator's policy. In some cases, the evaluator needs to work in the same environment but use a\
+ different policy. We can use reset_policy to reset the policy.
+ If _policy is None, reset the old policy.
+ If _policy is not None, replace the old policy in the evaluator with the new passed in policy.
+ Arguments:
+ - policy (:obj:`Optional[List[namedtuple]]`): the api namedtuple of eval_mode policy
+ """
+ assert hasattr(self, '_env'), "please set env first"
+ if _policy is not None:
+ assert len(_policy) > 1, "battle evaluator needs more than 1 policy, but found {}".format(len(_policy))
+ self._policy = _policy
+ self._policy_num = len(self._policy)
+ for p in self._policy:
+ p.reset()
+
+ def reset(self, _policy: Optional[List[namedtuple]] = None, _env: Optional[BaseEnvManager] = None) -> None:
+ """
+ Overview:
+ Reset evaluator's policy and environment. Use new policy and environment to collect data.
+ If _env is None, reset the old environment.
+ If _env is not None, replace the old environment in the evaluator with the new passed in \
+ environment and launch.
+ If _policy is None, reset the old policy.
+ If _policy is not None, replace the old policy in the evaluator with the new passed in policy.
+ Arguments:
+ - policy (:obj:`Optional[List[namedtuple]]`): the api namedtuple of eval_mode policy
+ - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
+ env_manager(BaseEnvManager)
+ """
+ if _env is not None:
+ self.reset_env(_env)
+ if _policy is not None:
+ self.reset_policy(_policy)
+ self._max_episode_return = float("-inf")
+ self._last_eval_iter = 0
+ self._end_flag = False
+
+ def close(self) -> None:
+ """
+ Overview:
+ Close the evaluator. If end_flag is False, close the environment, flush the tb_logger\
+ and close the tb_logger.
+ """
+ if self._end_flag:
+ return
+ self._end_flag = True
+ self._env.close()
+ self._tb_logger.flush()
+ self._tb_logger.close()
+
+ def __del__(self):
+ """
+ Overview:
+ Execute the close command and close the evaluator. __del__ is automatically called \
+ to destroy the evaluator instance when the evaluator finishes its work
+ """
+ self.close()
+
+ def should_eval(self, train_iter: int) -> bool:
+ """
+ Overview:
+ Determine whether to start evaluation: return True if at least ``eval_freq`` training iterations\
+ have passed since the last evaluation (or if this is the first iteration).
+ """
+ if (train_iter - self._last_eval_iter) < self._cfg.eval_freq and train_iter != 0:
+ return False
+ self._last_eval_iter = train_iter
+ return True
+
+ def eval(
+ self,
+ save_ckpt_fn: Callable = None,
+ train_iter: int = -1,
+ envstep: int = -1,
+ n_episode: Optional[int] = None
+ ) -> Tuple[bool, List[dict]]:
+ '''
+ Overview:
+ Evaluate policy and store the best policy based on whether it reaches the highest historical reward.
+ Arguments:
+ - save_ckpt_fn (:obj:`Callable`): Saving ckpt function, which will be triggered by getting the best reward.
+ - train_iter (:obj:`int`): Current training iteration.
+ - envstep (:obj:`int`): Current env interaction step.
+ - n_episode (:obj:`int`): Number of evaluation episodes.
+ Returns:
+ - stop_flag (:obj:`bool`): Whether this training program can be ended.
+ - return_info (:obj:`list`): Environment information of each finished episode.
+ '''
+ if n_episode is None:
+ n_episode = self._default_n_episode
+ assert n_episode is not None, "please indicate eval n_episode"
+ envstep_count = 0
+ info = {}
+ # TODO replace return_info with episode_info (validated by the league demo case)
+ return_info = [[] for _ in range(self._policy_num)]
+ eval_monitor = VectorEvalMonitor(self._env.env_num, n_episode)
+ self._env.reset()
+ for p in self._policy:
+ p.reset()
+
+ with self._timer:
+ while not eval_monitor.is_finished():
+ obs = self._env.ready_obs
+ ready_env_id = obs.keys()
+ obs = to_tensor(obs, dtype=torch.float32)
+ obs = dicts_to_lists(obs)
+ policy_output = [p.forward(obs[i]) for i, p in enumerate(self._policy)]
+ actions = {}
+ for env_id in ready_env_id:
+ actions[env_id] = []
+ for output in policy_output:
+ actions[env_id].append(output[env_id]['action'])
+ actions = to_ndarray(actions)
+ timesteps = self._env.step(actions)
+ timesteps = to_tensor(timesteps, dtype=torch.float32)
+ for env_id, t in timesteps.items():
+ if t.done:
+ # Env reset is done by env_manager automatically.
+ for p in self._policy:
+ p.reset([env_id])
+ # policy0 is regarded as the main policy by default
+ reward = t.info[0]['eval_episode_return']
+ if 'episode_info' in t.info[0]:
+ eval_monitor.update_info(env_id, t.info[0]['episode_info'])
+ eval_monitor.update_reward(env_id, reward)
+ for policy_id in range(self._policy_num):
+ return_info[policy_id].append(t.info[policy_id])
+ self._logger.info(
+ "[EVALUATOR]env {} finish episode, final reward: {}, current episode: {}".format(
+ env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode()
+ )
+ )
+ envstep_count += 1
+ duration = self._timer.value
+ episode_return = eval_monitor.get_episode_return()
+ info = {
+ 'train_iter': train_iter,
+ 'ckpt_name': 'iteration_{}.pth.tar'.format(train_iter),
+ 'episode_count': n_episode,
+ 'envstep_count': envstep_count,
+ 'avg_envstep_per_episode': envstep_count / n_episode,
+ 'evaluate_time': duration,
+ 'avg_envstep_per_sec': envstep_count / duration,
+ 'avg_time_per_episode': n_episode / duration,
+ 'reward_mean': np.mean(episode_return),
+ 'reward_std': np.std(episode_return),
+ 'reward_max': np.max(episode_return),
+ 'reward_min': np.min(episode_return),
+ # 'each_reward': episode_return,
+ }
+ episode_info = eval_monitor.get_episode_info()
+ if episode_info is not None:
+ info.update(episode_info)
+ self._logger.info(self._logger.get_tabulate_vars_hor(info))
+ # self._logger.info(self._logger.get_tabulate_vars(info))
+ for k, v in info.items():
+ if k in ['train_iter', 'ckpt_name', 'each_reward']:
+ continue
+ if not np.isscalar(v):
+ continue
+ self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
+ self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep)
+ episode_return = np.mean(episode_return)
+ if episode_return > self._max_episode_return:
+ if save_ckpt_fn:
+ save_ckpt_fn('ckpt_best.pth.tar')
+ self._max_episode_return = episode_return
+ stop_flag = episode_return >= self._stop_value and train_iter > 0
+ if stop_flag:
+ self._logger.info(
+ "[DI-engine serial pipeline] " +
+ "Current episode_return: {} is greater than stop_value: {}".format(episode_return, self._stop_value) +
+ ", so your RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details."
+ )
+ return_info = to_item(return_info)
+ return stop_flag, return_info
diff --git a/DI-engine/ding/worker/collector/battle_sample_serial_collector.py b/DI-engine/ding/worker/collector/battle_sample_serial_collector.py
new file mode 100644
index 0000000000000000000000000000000000000000..dffc43f5f79b711df43cd4acbd6b8d594a437d4a
--- /dev/null
+++ b/DI-engine/ding/worker/collector/battle_sample_serial_collector.py
@@ -0,0 +1,353 @@
+from typing import Optional, Any, List, Tuple
+from collections import namedtuple
+from easydict import EasyDict
+import numpy as np
+import torch
+
+from ding.envs import BaseEnvManager
+from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, dicts_to_lists, one_time_warning
+from ding.torch_utils import to_tensor, to_ndarray
+from .base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF, to_tensor_transitions
+
+
+@SERIAL_COLLECTOR_REGISTRY.register('sample_1v1')
+class BattleSampleSerialCollector(ISerialCollector):
+ """
+ Overview:
+ Sample collector (n_sample) for a multiple-policy (n vs n) battle
+ Interfaces:
+ __init__, reset, reset_env, reset_policy, collect, close
+ Property:
+ envstep
+ """
+
+ config = dict(deepcopy_obs=False, transform_obs=False, collect_print_freq=100)
+
+ def __init__(
+ self,
+ cfg: EasyDict,
+ env: BaseEnvManager = None,
+ policy: List[namedtuple] = None,
+ tb_logger: 'SummaryWriter' = None, # noqa
+ exp_name: Optional[str] = 'default_experiment',
+ instance_name: Optional[str] = 'collector'
+ ) -> None:
+ """
+ Overview:
+ Initialization method.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config dict
+ - env (:obj:`BaseEnvManager`): the subclass of vectorized env_manager(BaseEnvManager)
+ - policy (:obj:`List[namedtuple]`): the api namedtuple of collect_mode policy
+ - tb_logger (:obj:`SummaryWriter`): tensorboard handle
+ """
+ self._exp_name = exp_name
+ self._instance_name = instance_name
+ self._collect_print_freq = cfg.collect_print_freq
+ self._deepcopy_obs = cfg.deepcopy_obs
+ self._transform_obs = cfg.transform_obs
+ self._cfg = cfg
+ self._timer = EasyTimer()
+ self._end_flag = False
+
+ if tb_logger is not None:
+ self._logger, _ = build_logger(
+ path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False
+ )
+ self._tb_logger = tb_logger
+ else:
+ self._logger, self._tb_logger = build_logger(
+ path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name
+ )
+ self._traj_len = float("inf")
+ self.reset(policy, env)
+
+ def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
+ """
+ Overview:
+ Reset the environment.
+ If _env is None, reset the old environment.
+ If _env is not None, replace the old environment in the collector with the new passed \
+ in environment and launch.
+ Arguments:
+ - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
+ env_manager(BaseEnvManager)
+ """
+ if _env is not None:
+ self._env = _env
+ self._env.launch()
+ self._env_num = self._env.env_num
+ else:
+ self._env.reset()
+
+ def reset_policy(self, _policy: Optional[List[namedtuple]] = None) -> None:
+ """
+ Overview:
+ Reset the policy.
+ If _policy is None, reset the old policy.
+ If _policy is not None, replace the old policy in the collector with the new passed in policy.
+ Arguments:
+ - policy (:obj:`Optional[List[namedtuple]]`): the api namedtuple of collect_mode policy
+ """
+ assert hasattr(self, '_env'), "please set env first"
+ if _policy is not None:
+ assert len(_policy) > 1, "battle sample collector needs more than 1 policy, but found {}".format(
+ len(_policy)
+ )
+ self._policy = _policy
+ self._policy_num = len(self._policy)
+ self._default_n_sample = _policy[0].get_attribute('cfg').collect.get('n_sample', None)
+ self._unroll_len = _policy[0].get_attribute('unroll_len')
+ self._on_policy = _policy[0].get_attribute('cfg').on_policy
+ self._policy_collect_data = [
+ getattr(self._policy[i], 'collect_data', True) for i in range(self._policy_num)
+ ]
+ if self._default_n_sample is not None:
+ self._traj_len = max(
+ self._unroll_len,
+ self._default_n_sample // self._env_num + int(self._default_n_sample % self._env_num != 0)
+ )
+ self._logger.debug(
+ 'Set default n_sample mode(n_sample({}), env_num({}), traj_len({}))'.format(
+ self._default_n_sample, self._env_num, self._traj_len
+ )
+ )
+ else:
+ self._traj_len = INF
+ for p in self._policy:
+ p.reset()
+
+ def reset(self, _policy: Optional[List[namedtuple]] = None, _env: Optional[BaseEnvManager] = None) -> None:
+ """
+ Overview:
+ Reset the environment and policy.
+ If _env is None, reset the old environment.
+            If _env is not None, replace the old environment in the collector with the newly passed-in \
+                environment and launch it.
+ If _policy is None, reset the old policy.
+ If _policy is not None, replace the old policy in the collector with the new passed in policy.
+ Arguments:
+ - policy (:obj:`Optional[List[namedtuple]]`): the api namedtuple of collect_mode policy
+ - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
+ env_manager(BaseEnvManager)
+ """
+ if _env is not None:
+ self.reset_env(_env)
+ if _policy is not None:
+ self.reset_policy(_policy)
+
+ self._obs_pool = CachePool('obs', self._env_num, deepcopy=self._deepcopy_obs)
+ self._policy_output_pool = CachePool('policy_output', self._env_num)
+ # _traj_buffer is {env_id: {policy_id: TrajBuffer}}, is used to store traj_len pieces of transitions
+ self._traj_buffer = {
+ env_id: {policy_id: TrajBuffer(maxlen=self._traj_len)
+ for policy_id in range(self._policy_num)}
+ for env_id in range(self._env_num)
+ }
+ self._env_info = {env_id: {'time': 0., 'step': 0, 'train_sample': 0} for env_id in range(self._env_num)}
+
+ self._episode_info = []
+ self._total_envstep_count = 0
+ self._total_episode_count = 0
+ self._total_train_sample_count = 0
+ self._total_duration = 0
+ self._last_train_iter = 0
+ self._end_flag = False
+
+ def _reset_stat(self, env_id: int) -> None:
+ """
+ Overview:
+            Reset the collector's state for the given env_id, including the traj_buffer, obs_pool, \
+            policy_output_pool and env_info. You can refer to base_serial_collector for more details.
+ Arguments:
+ - env_id (:obj:`int`): the id where we need to reset the collector's state
+ """
+        for i in range(self._policy_num):
+ self._traj_buffer[env_id][i].clear()
+ self._obs_pool.reset(env_id)
+ self._policy_output_pool.reset(env_id)
+ self._env_info[env_id] = {'time': 0., 'step': 0, 'train_sample': 0}
+
+ @property
+ def envstep(self) -> int:
+ """
+ Overview:
+            Get the total envstep count.
+ Return:
+ - envstep (:obj:`int`): the total envstep count
+ """
+ return self._total_envstep_count
+
+ def close(self) -> None:
+ """
+ Overview:
+ Close the collector. If end_flag is False, close the environment, flush the tb_logger\
+ and close the tb_logger.
+ """
+ if self._end_flag:
+ return
+ self._end_flag = True
+ self._env.close()
+ self._tb_logger.flush()
+ self._tb_logger.close()
+
+ def __del__(self) -> None:
+ """
+ Overview:
+ Execute the close command and close the collector. __del__ is automatically called to \
+ destroy the collector instance when the collector finishes its work
+ """
+ self.close()
+
+ def collect(
+ self,
+ n_sample: Optional[int] = None,
+ train_iter: int = 0,
+ drop_extra: bool = True,
+ policy_kwargs: Optional[dict] = None
+ ) -> Tuple[List[Any], List[Any]]:
+ """
+ Overview:
+            Collect `n_sample` train samples with policy_kwargs, using policies that have been \
+            trained for `train_iter` iterations.
+ Arguments:
+ - n_sample (:obj:`int`): The number of collecting data sample.
+ - train_iter (:obj:`int`): The number of training iteration when calling collect method.
+ - drop_extra (:obj:`bool`): Whether to drop extra return_data more than `n_sample`.
+ - policy_kwargs (:obj:`dict`): The keyword args for policy forward.
+        Returns:
+            - return_data (:obj:`List`): A list of training-sample lists, one per policy.
+            - return_info (:obj:`List`): A list of episode info dicts, one per policy.
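+        Examples:
+            A hedged call sketch; ``eps`` below is a hypothetical exploration parameter forwarded to the policies:
+
+            >>> data, info = collector.collect(n_sample=128, train_iter=train_iter, policy_kwargs={'eps': 0.1})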
+ """
+ if n_sample is None:
+ if self._default_n_sample is None:
+ raise RuntimeError("Please specify collect n_sample")
+ else:
+ n_sample = self._default_n_sample
+ if n_sample % self._env_num != 0:
+ one_time_warning(
+                "Please make sure n_sample is divisible by env_num: {}/{}, ".format(n_sample, self._env_num) +
+                "otherwise it may cause convergence problems in a few algorithms"
+ )
+ if policy_kwargs is None:
+ policy_kwargs = {}
+ collected_sample = [0 for _ in range(self._policy_num)]
+ return_data = [[] for _ in range(self._policy_num)]
+ return_info = [[] for _ in range(self._policy_num)]
+
+ while any([c < n_sample for i, c in enumerate(collected_sample) if self._policy_collect_data[i]]):
+ with self._timer:
+ # Get current env obs.
+ obs = self._env.ready_obs
+ # Policy forward.
+ self._obs_pool.update(obs)
+ if self._transform_obs:
+ obs = to_tensor(obs, dtype=torch.float32)
+ obs = dicts_to_lists(obs)
+ policy_output = [p.forward(obs[i], **policy_kwargs) for i, p in enumerate(self._policy)]
+ self._policy_output_pool.update(policy_output)
+ # Interact with env.
+ actions = {}
+ for policy_output_item in policy_output:
+ for env_id, output in policy_output_item.items():
+ if env_id not in actions:
+ actions[env_id] = []
+ actions[env_id].append(output['action'])
+ actions = to_ndarray(actions)
+ timesteps = self._env.step(actions)
+
+ # TODO(nyz) this duration may be inaccurate in async env
+ interaction_duration = self._timer.value / len(timesteps)
+
+ # TODO(nyz) vectorize this for loop
+ for env_id, timestep in timesteps.items():
+ self._env_info[env_id]['step'] += 1
+ self._total_envstep_count += 1
+ with self._timer:
+ for policy_id, policy in enumerate(self._policy):
+ if not self._policy_collect_data[policy_id]:
+ continue
+ policy_timestep_data = [d[policy_id] if not isinstance(d, bool) else d for d in timestep]
+ policy_timestep = type(timestep)(*policy_timestep_data)
+ transition = self._policy[policy_id].process_transition(
+ self._obs_pool[env_id][policy_id], self._policy_output_pool[env_id][policy_id],
+ policy_timestep
+ )
+ transition['collect_iter'] = train_iter
+ self._traj_buffer[env_id][policy_id].append(transition)
+ # prepare data
+ if timestep.done or len(self._traj_buffer[env_id][policy_id]) == self._traj_len:
+ transitions = to_tensor_transitions(
+ self._traj_buffer[env_id][policy_id], not self._deepcopy_obs
+ )
+ train_sample = self._policy[policy_id].get_train_sample(transitions)
+ return_data[policy_id].extend(train_sample)
+ self._total_train_sample_count += len(train_sample)
+ self._env_info[env_id]['train_sample'] += len(train_sample)
+ collected_sample[policy_id] += len(train_sample)
+ self._traj_buffer[env_id][policy_id].clear()
+
+ self._env_info[env_id]['time'] += self._timer.value + interaction_duration
+
+ # If env is done, record episode info and reset
+ if timestep.done:
+ self._total_episode_count += 1
+ info = {
+ 'time': self._env_info[env_id]['time'],
+ 'step': self._env_info[env_id]['step'],
+ 'train_sample': self._env_info[env_id]['train_sample'],
+ }
+ for i in range(self._policy_num):
+ info['reward{}'.format(i)] = timestep.info[i]['eval_episode_return']
+ self._episode_info.append(info)
+ for i, p in enumerate(self._policy):
+ p.reset([env_id])
+ self._reset_stat(env_id)
+                    for policy_id in range(self._policy_num):
+ return_info[policy_id].append(timestep.info[policy_id])
+ # log
+ self._output_log(train_iter)
+        if drop_extra:
+            return_data = [r[:n_sample] for r in return_data]
+ return return_data, return_info
+
+ def _output_log(self, train_iter: int) -> None:
+ """
+ Overview:
+ Print the output log information. You can refer to Docs/Best Practice/How to understand\
+ training generated folders/Serial mode/log/collector for more details.
+ Arguments:
+ - train_iter (:obj:`int`): the number of training iteration.
+ """
+ if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0:
+ self._last_train_iter = train_iter
+ episode_count = len(self._episode_info)
+ envstep_count = sum([d['step'] for d in self._episode_info])
+ duration = sum([d['time'] for d in self._episode_info])
+ episode_return = []
+ for i in range(self._policy_num):
+ episode_return_item = [d['reward{}'.format(i)] for d in self._episode_info]
+ episode_return.append(episode_return_item)
+ self._total_duration += duration
+ info = {
+ 'episode_count': episode_count,
+ 'envstep_count': envstep_count,
+ 'avg_envstep_per_episode': envstep_count / episode_count,
+ 'avg_envstep_per_sec': envstep_count / duration,
+ 'avg_episode_per_sec': episode_count / duration,
+ 'collect_time': duration,
+ 'total_envstep_count': self._total_envstep_count,
+ 'total_episode_count': self._total_episode_count,
+ 'total_duration': self._total_duration,
+ }
+ for k, fn in {'mean': np.mean, 'std': np.std, 'max': np.max, 'min': np.min}.items():
+ for i in range(self._policy_num):
+ # such as reward0_mean
+ info['reward{}_{}'.format(i, k)] = fn(episode_return[i])
+ self._episode_info.clear()
+ self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])))
+ for k, v in info.items():
+ self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
+ if k in ['total_envstep_count']:
+ continue
+ self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count)
diff --git a/DI-engine/ding/worker/collector/comm/__init__.py b/DI-engine/ding/worker/collector/comm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0afedd0890ff7a06f139618565e3084012ca68e1
--- /dev/null
+++ b/DI-engine/ding/worker/collector/comm/__init__.py
@@ -0,0 +1,3 @@
+from .base_comm_collector import BaseCommCollector, create_comm_collector
+from .flask_fs_collector import FlaskFileSystemCollector
+from .utils import NaiveCollector # for test
diff --git a/DI-engine/ding/worker/collector/comm/base_comm_collector.py b/DI-engine/ding/worker/collector/comm/base_comm_collector.py
new file mode 100644
index 0000000000000000000000000000000000000000..242051f36d3f61ca74435a17116c2c0cdb50c208
--- /dev/null
+++ b/DI-engine/ding/worker/collector/comm/base_comm_collector.py
@@ -0,0 +1,117 @@
+from abc import ABC, abstractmethod
+from typing import Any
+from easydict import EasyDict
+
+from ding.utils import get_task_uid, import_module, COMM_COLLECTOR_REGISTRY
+from ..base_parallel_collector import create_parallel_collector, BaseParallelCollector
+
+
+class BaseCommCollector(ABC):
+ """
+ Overview:
+        Abstract base class for comm collectors.
+ Interfaces:
+        __init__, get_policy_update_info, send_metadata, send_stepdata,
+ start, close, _create_collector
+ Property:
+ collector_uid
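+    Examples:
+        A hedged sketch of the subclassing contract; ``MyCommCollector`` and its method bodies are purely \
+        illustrative and not part of this module:
+
+        >>> @COMM_COLLECTOR_REGISTRY.register('my_comm')
+        ... class MyCommCollector(BaseCommCollector):
+        ...     def get_policy_update_info(self, path):
+        ...         ...  # read and return the latest policy state_dict found at ``path``
+        ...     def send_metadata(self, metadata):
+        ...         ...  # push ``metadata`` somewhere the collector slave can retrieve it
+        ...     def send_stepdata(self, stepdata):
+        ...         ...  # persist ``stepdata`` so that the learner can load it later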
+ """
+
+ def __init__(self, cfg):
+ """
+ Overview:
+ Initialization method.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config dict
+ """
+ self._cfg = cfg
+ self._end_flag = True
+ self._collector_uid = get_task_uid()
+
+ @abstractmethod
+ def get_policy_update_info(self, path: str) -> Any:
+ """
+ Overview:
+ Get policy information in corresponding path.
+ Will be registered in base collector.
+ Arguments:
+ - path (:obj:`str`): path to policy update information.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def send_metadata(self, metadata: Any) -> None:
+ """
+ Overview:
+ Store meta data in queue, which will be retrieved by callback function "deal_with_collector_data"
+ in collector slave, then will be sent to coordinator.
+ Will be registered in base collector.
+ Arguments:
+ - metadata (:obj:`Any`): meta data.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def send_stepdata(self, stepdata: Any) -> None:
+ """
+ Overview:
+ Save step data in corresponding path.
+ Will be registered in base collector.
+ Arguments:
+ - stepdata (:obj:`Any`): step data.
+ """
+ raise NotImplementedError
+
+ def start(self) -> None:
+ """
+ Overview:
+ Start comm collector.
+ """
+ self._end_flag = False
+
+ def close(self) -> None:
+ """
+ Overview:
+ Close comm collector.
+ """
+ self._end_flag = True
+
+ @property
+ def collector_uid(self) -> str:
+ return self._collector_uid
+
+ def _create_collector(self, task_info: dict) -> BaseParallelCollector:
+ """
+ Overview:
+ Receive ``task_info`` passed from coordinator and create a collector.
+ Arguments:
+            - task_info (:obj:`dict`): Task info dict passed from the coordinator.
+ Returns:
+ - collector (:obj:`BaseParallelCollector`): Created base collector.
+ Note:
+            The methods 'send_metadata', 'send_stepdata', 'get_policy_update_info', together with the policy,
+            are set here. The reason why they are set here rather than in the base collector is that they highly
+            depend on the specific task. Only after the task info is passed from the coordinator to the comm
+            collector through the collector slave can they be determined and initialized.
+ """
+ collector_cfg = EasyDict(task_info['collector_cfg'])
+ collector = create_parallel_collector(collector_cfg)
+ for item in ['send_metadata', 'send_stepdata', 'get_policy_update_info']:
+ setattr(collector, item, getattr(self, item))
+ return collector
+
+
+def create_comm_collector(cfg: EasyDict) -> BaseCommCollector:
+ """
+ Overview:
+        Given the key (comm_collector_name), create a new comm collector instance if the key is among
+        comm_map's values, otherwise raise a KeyError. In other words, a derived comm collector must first be
+        registered, and only then can ``create_comm_collector`` be called to get an instance.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Collector config. Necessary keys: [import_names, comm_collector_type].
+ Returns:
+ - collector (:obj:`BaseCommCollector`): The created new comm collector, should be an instance of one of \
+ comm_map's values.
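+    Examples:
+        A hedged sketch, assuming the registered ``flask_fs`` type and placeholder host/port/path values:
+
+        >>> cfg = EasyDict({'type': 'flask_fs', 'import_names': [], 'host': '0.0.0.0', 'port': 12345,
+        ...                 'path_policy': './policy', 'path_data': './data'})
+        >>> comm_collector = create_comm_collector(cfg)
+        >>> comm_collector.start()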
+ """
+ import_module(cfg.get('import_names', []))
+ return COMM_COLLECTOR_REGISTRY.build(cfg.type, cfg=cfg)
diff --git a/DI-engine/ding/worker/collector/comm/flask_fs_collector.py b/DI-engine/ding/worker/collector/comm/flask_fs_collector.py
new file mode 100644
index 0000000000000000000000000000000000000000..db9f0be7c3ee472bb71b6cdc1bad16de7f889ada
--- /dev/null
+++ b/DI-engine/ding/worker/collector/comm/flask_fs_collector.py
@@ -0,0 +1,235 @@
+import os
+import time
+from typing import Union, Dict, Callable
+from queue import Queue
+from threading import Thread
+
+from ding.utils import read_file, save_file, COMM_COLLECTOR_REGISTRY
+from ding.utils.file_helper import save_to_di_store
+from ding.interaction import Slave, TaskFail
+from .base_comm_collector import BaseCommCollector
+
+
+class CollectorSlave(Slave):
+ """
+ Overview:
+ A slave, whose master is coordinator.
+ Used to pass message between comm collector and coordinator.
+ Interfaces:
+ __init__, _process_task
+ """
+
+ # override
+ def __init__(self, *args, callback_fn: Dict[str, Callable], **kwargs) -> None:
+ """
+ Overview:
+ Init callback functions additionally. Callback functions are methods in comm collector.
+ """
+ super().__init__(*args, **kwargs)
+ self._callback_fn = callback_fn
+ self._current_task_info = None
+
+ def _process_task(self, task: dict) -> Union[dict, TaskFail]:
+ """
+ Overview:
+ Process a task according to input task info dict, which is passed in by master coordinator.
+ For each type of task, you can refer to corresponding callback function in comm collector for details.
+ Arguments:
+            - task (:obj:`dict`): Task dict. Must contain key "name".
+ Returns:
+ - result (:obj:`Union[dict, TaskFail]`): Task result dict, or task fail exception.
+ """
+ task_name = task['name']
+ if task_name == 'resource':
+ return self._callback_fn['deal_with_resource']()
+ elif task_name == 'collector_start_task':
+ self._current_task_info = task['task_info']
+ self._callback_fn['deal_with_collector_start'](self._current_task_info)
+ return {'message': 'collector task has started'}
+ elif task_name == 'collector_data_task':
+ data = self._callback_fn['deal_with_collector_data']()
+ data['buffer_id'] = self._current_task_info['buffer_id']
+ data['task_id'] = self._current_task_info['task_id']
+ return data
+ elif task_name == 'collector_close_task':
+ data = self._callback_fn['deal_with_collector_close']()
+ data['task_id'] = self._current_task_info['task_id']
+ return data
+ else:
+ raise TaskFail(
+ result={'message': 'task name error'}, message='illegal collector task <{}>'.format(task_name)
+ )
+
+
+@COMM_COLLECTOR_REGISTRY.register('flask_fs')
+class FlaskFileSystemCollector(BaseCommCollector):
+ """
+ Overview:
+        An implementation of comm collector, using flask and the file system.
+ Interfaces:
+ __init__, deal_with_resource, deal_with_collector_start, deal_with_collector_data, deal_with_collector_close,\
+ get_policy_update_info, send_stepdata, send_metadata, start, close
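+    Examples:
+        A minimal configuration sketch; the host, port and paths below are placeholders rather than defaults:
+
+        >>> cfg = EasyDict({'host': '0.0.0.0', 'port': 12345, 'path_policy': './policy', 'path_data': './data'})
+        >>> comm_collector = FlaskFileSystemCollector(cfg)
+        >>> comm_collector.start()
+        >>> comm_collector.close()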
+ """
+
+ # override
+ def __init__(self, cfg: dict) -> None:
+ """
+ Overview:
+ Initialization method.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config dict
+ """
+ BaseCommCollector.__init__(self, cfg)
+ host, port = cfg.host, cfg.port
+ self._callback_fn = {
+ 'deal_with_resource': self.deal_with_resource,
+ 'deal_with_collector_start': self.deal_with_collector_start,
+ 'deal_with_collector_data': self.deal_with_collector_data,
+ 'deal_with_collector_close': self.deal_with_collector_close,
+ }
+ self._slave = CollectorSlave(host, port, callback_fn=self._callback_fn)
+
+ self._path_policy = cfg.path_policy
+ self._path_data = cfg.path_data
+ if not os.path.exists(self._path_data):
+ try:
+ os.mkdir(self._path_data)
+ except Exception as e:
+ pass
+ self._metadata_queue = Queue(8)
+ self._collector_close_flag = False
+ self._collector = None
+
+ def deal_with_resource(self) -> dict:
+ """
+ Overview:
+ Callback function in ``CollectorSlave``. Return how many resources are needed to start current collector.
+ Returns:
+ - resource (:obj:`dict`): Resource info dict, including ['gpu', 'cpu'].
+ """
+ return {'gpu': 1, 'cpu': 20}
+
+ def deal_with_collector_start(self, task_info: dict) -> None:
+ """
+ Overview:
+ Callback function in ``CollectorSlave``.
+ Create a collector and start a collector thread of the created one.
+ Arguments:
+ - task_info (:obj:`dict`): Task info dict.
+ Note:
+            In the ``_create_collector`` method of the base class ``BaseCommCollector``, the methods
+            'send_metadata', 'send_stepdata', 'get_policy_update_info', as well as the policy, are set.
+ You can refer to it for details.
+ """
+ self._collector_close_flag = False
+ self._collector = self._create_collector(task_info)
+ self._collector_thread = Thread(target=self._collector.start, args=(), daemon=True, name='collector_start')
+ self._collector_thread.start()
+
+ def deal_with_collector_data(self) -> dict:
+ """
+ Overview:
+ Callback function in ``CollectorSlave``. Get data sample dict from ``_metadata_queue``,
+ which will be sent to coordinator afterwards.
+ Returns:
+ - data (:obj:`Any`): Data sample dict.
+ """
+ while True:
+ if not self._metadata_queue.empty():
+ data = self._metadata_queue.get()
+ break
+ else:
+ time.sleep(0.1)
+ return data
+
+ def deal_with_collector_close(self) -> dict:
+ self._collector_close_flag = True
+ finish_info = self._collector.get_finish_info()
+ self._collector.close()
+ self._collector_thread.join()
+ del self._collector_thread
+ self._collector = None
+ return finish_info
+
+ # override
+ def get_policy_update_info(self, path: str) -> dict:
+ """
+ Overview:
+ Get policy information in corresponding path.
+ Arguments:
+ - path (:obj:`str`): path to policy update information.
+ """
+ if self._collector_close_flag:
+ return
+ if self._path_policy not in path:
+ path = os.path.join(self._path_policy, path)
+ return read_file(path, use_lock=True)
+
+ # override
+ def send_stepdata(self, path: str, stepdata: list) -> None:
+ """
+ Overview:
+ Save collector's step data in corresponding path.
+ Arguments:
+ - path (:obj:`str`): Path to save data.
+ - stepdata (:obj:`Any`): Data of one step.
+ """
+ if save_to_di_store:
+ if self._collector_close_flag:
+ return b'0' * 20 # return an object reference that doesn't exist
+ object_ref = save_to_di_store(stepdata)
+ # print('send_stepdata:', path, 'object ref:', object_ref, 'len:', len(stepdata))
+ return object_ref
+
+ if self._collector_close_flag:
+ return
+ name = os.path.join(self._path_data, path)
+ save_file(name, stepdata, use_lock=False)
+
+ # override
+ def send_metadata(self, metadata: dict) -> None:
+ """
+ Overview:
+            Store the metadata dict in the queue, which will be retrieved by the callback function
+            "deal_with_collector_data" in the collector slave and then sent to the coordinator.
+ Arguments:
+ - metadata (:obj:`Any`): meta data.
+ """
+ if self._collector_close_flag:
+ return
+ necessary_metadata_keys = set(['data_id', 'policy_iter'])
+ necessary_info_keys = set(['collector_done', 'cur_episode', 'cur_sample', 'cur_step'])
+ assert necessary_metadata_keys.issubset(set(metadata.keys())
+ ) or necessary_info_keys.issubset(set(metadata.keys()))
+ while True:
+ if not self._metadata_queue.full():
+ self._metadata_queue.put(metadata)
+ break
+ else:
+ time.sleep(0.1)
+
+ def start(self) -> None:
+ """
+ Overview:
+ Start comm collector itself and the collector slave.
+ """
+ BaseCommCollector.start(self)
+ self._slave.start()
+
+ def close(self) -> None:
+ """
+ Overview:
+ Close comm collector itself and the collector slave.
+ """
+ if self._end_flag:
+ return
+ total_sleep_count = 0
+ while self._collector is not None and total_sleep_count < 10:
+ self._collector.info("please first close collector")
+ time.sleep(1)
+ total_sleep_count += 1
+ self._slave.close()
+ BaseCommCollector.close(self)
+
+ def __del__(self) -> None:
+ self.close()
diff --git a/DI-engine/ding/worker/collector/comm/tests/test_collector_with_coordinator.py b/DI-engine/ding/worker/collector/comm/tests/test_collector_with_coordinator.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e4e310cf3449bf818f7f9efc7351d766a6b174e
--- /dev/null
+++ b/DI-engine/ding/worker/collector/comm/tests/test_collector_with_coordinator.py
@@ -0,0 +1,87 @@
+import pytest
+import os
+import time
+from threading import Thread
+from multiprocessing import Process
+import torch
+
+from ding.worker import Coordinator, create_comm_collector
+from ding.worker.learner.comm import NaiveLearner
+from ding.utils import lists_to_dicts
+from ding.interaction.slave import Slave, TaskFail
+from ding.config import compile_config_parallel
+from ding.config.utils import parallel_test_main_config, parallel_test_create_config, parallel_test_system_config
+
+DATA_PREFIX = 'SLAVE_COLLECTOR_DATA_COLLECTOR_TEST'
+
+
+@pytest.fixture(scope='function')
+def setup_config():
+ return compile_config_parallel(
+ parallel_test_main_config, create_cfg=parallel_test_create_config, system_cfg=parallel_test_system_config
+ )
+
+
+@pytest.fixture(scope='function')
+def setup_collector(setup_config):
+ collector = {}
+ for k, v in setup_config.system.items():
+ if 'collector' in k:
+ collector[k] = create_comm_collector(v)
+ collector[k].start()
+ yield collector
+    time.sleep(1)  # wait so that the comm collector does not receive the close signal before the collector is closed
+ for a in collector.values():
+ a.close()
+
+
+@pytest.fixture(scope='function')
+def setup_learner(setup_config):
+ cfg = setup_config.system.coordinator.learner
+ learner = {}
+ for _, (name, host, port) in cfg.items():
+ learner[name] = NaiveLearner(host, port, prefix=DATA_PREFIX)
+ learner[name].start()
+ yield learner
+ time.sleep(1)
+ for l in learner.values():
+ l.close()
+
+
+@pytest.mark.unittest
+class TestCollectorWithCoordinator:
+
+ def test_naive(self, setup_config, setup_collector, setup_learner):
+ os.popen('rm -rf {}*'.format(DATA_PREFIX))
+ os.popen('rm -rf env_*_*')
+ os.popen('rm -rf test.pth')
+ assert len(setup_collector) == len(setup_config.system.coordinator.collector)
+ try:
+ coordinator = Coordinator(setup_config)
+ coordinator.start()
+ while True:
+ if setup_collector['collector0']._collector is not None:
+ break
+ time.sleep(0.5)
+ torch.save(
+ {
+ 'model': setup_collector['collector0']._collector.policy.state_dict()['model'],
+ 'iter': 0
+ }, 'test.pth'
+ )
+ while True:
+ commander = coordinator._commander
+ if commander._learner_task_finish_count >= 1 and commander._collector_task_finish_count >= 2:
+ break
+ time.sleep(0.5)
+ coordinator.close()
+ except Exception as e:
+ os.popen('rm -rf {}*'.format(DATA_PREFIX))
+ assert False, e
+
+ assert len(coordinator._replay_buffer) == 0
+ learner_task_ids = [i for i in coordinator._historical_task if 'learner' in i]
+ for i in learner_task_ids:
+ assert len(coordinator._commander._learner_info[i]) == 5
+ os.popen('rm -rf {}*'.format(DATA_PREFIX))
+ os.popen('rm -rf env_*_*')
diff --git a/DI-engine/ding/worker/collector/comm/utils.py b/DI-engine/ding/worker/collector/comm/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c32a6d9378a812f5b2801641a605f260db4a90e
--- /dev/null
+++ b/DI-engine/ding/worker/collector/comm/utils.py
@@ -0,0 +1,65 @@
+import torch
+from ding.interaction.slave import Slave, TaskFail
+
+
+class NaiveCollector(Slave):
+ """
+ Overview:
+        A naive collector slave used for tests, whose master is the coordinator.
+        Used to pass messages between the comm collector and the coordinator.
+ Interfaces:
+ _process_task, _get_timestep
+ """
+
+ def __init__(self, *args, prefix='', **kwargs):
+ super().__init__(*args, **kwargs)
+ self._prefix = prefix
+
+ def _process_task(self, task):
+ """
+ Overview:
+ Process a task according to input task info dict, which is passed in by master coordinator.
+ For each type of task, you can refer to corresponding callback function in comm collector for details.
+ Arguments:
+            - task (:obj:`dict`): Task dict. Must contain key "name".
+ Returns:
+ - result (:obj:`Union[dict, TaskFail]`): Task result dict, or task fail exception.
+ """
+ task_name = task['name']
+ if task_name == 'resource':
+ return {'cpu': '20', 'gpu': '1'}
+ elif task_name == 'collector_start_task':
+ self.count = 0
+ self.task_info = task['task_info']
+ return {'message': 'collector task has started'}
+ elif task_name == 'collector_data_task':
+ self.count += 1
+ data_id = './{}_{}_{}'.format(self._prefix, self.task_info['task_id'], self.count)
+ torch.save(self._get_timestep(), data_id)
+ data = {'data_id': data_id, 'buffer_id': self.task_info['buffer_id'], 'unroll_split_begin': 0}
+ data['task_id'] = self.task_info['task_id']
+ if self.count == 20:
+ return {
+ 'task_id': self.task_info['task_id'],
+ 'collector_done': True,
+ 'cur_episode': 1,
+ 'cur_step': 314,
+ 'cur_sample': 314,
+ }
+ else:
+ return data
+ else:
+ raise TaskFail(
+ result={'message': 'task name error'}, message='illegal collector task <{}>'.format(task_name)
+ )
+
+ def _get_timestep(self):
+ return [
+ {
+ 'obs': torch.rand(4),
+ 'next_obs': torch.randn(4),
+ 'reward': torch.randint(0, 2, size=(3, )).float(),
+ 'action': torch.randint(0, 2, size=(1, )),
+ 'done': False,
+ }
+ ]
diff --git a/DI-engine/ding/worker/collector/episode_serial_collector.py b/DI-engine/ding/worker/collector/episode_serial_collector.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fca2283f8383857fcaf0e6469495aef01a880cd
--- /dev/null
+++ b/DI-engine/ding/worker/collector/episode_serial_collector.py
@@ -0,0 +1,327 @@
+from typing import Optional, Any, List
+from collections import namedtuple
+from easydict import EasyDict
+import numpy as np
+import torch
+
+from ding.envs import BaseEnvManager
+from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY
+from ding.torch_utils import to_tensor, to_ndarray
+from .base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF, to_tensor_transitions
+
+
+@SERIAL_COLLECTOR_REGISTRY.register('episode')
+class EpisodeSerialCollector(ISerialCollector):
+ """
+ Overview:
+        Episode collector (n_episode)
+ Interfaces:
+ __init__, reset, reset_env, reset_policy, collect, close
+ Property:
+ envstep
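+    Examples:
+        A minimal usage sketch (hedged): ``cfg``, ``env_manager`` and the collect-mode ``policy`` are assumed \
+        to be created elsewhere.
+
+        >>> collector = EpisodeSerialCollector(cfg, env=env_manager, policy=policy)
+        >>> episodes = collector.collect(n_episode=8, train_iter=0)
+        >>> collector.close()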
+ """
+
+ config = dict(
+ deepcopy_obs=False, transform_obs=False, collect_print_freq=100, get_train_sample=False, reward_shaping=False
+ )
+
+ def __init__(
+ self,
+ cfg: EasyDict,
+ env: BaseEnvManager = None,
+ policy: namedtuple = None,
+ tb_logger: 'SummaryWriter' = None, # noqa
+ exp_name: Optional[str] = 'default_experiment',
+ instance_name: Optional[str] = 'collector'
+ ) -> None:
+ """
+ Overview:
+ Initialization method.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config dict
+ - env (:obj:`BaseEnvManager`): the subclass of vectorized env_manager(BaseEnvManager)
+ - policy (:obj:`namedtuple`): the api namedtuple of collect_mode policy
+ - tb_logger (:obj:`SummaryWriter`): tensorboard handle
+ """
+ self._exp_name = exp_name
+ self._instance_name = instance_name
+ self._collect_print_freq = cfg.collect_print_freq
+ self._deepcopy_obs = cfg.deepcopy_obs
+ self._transform_obs = cfg.transform_obs
+ self._cfg = cfg
+ self._timer = EasyTimer()
+ self._end_flag = False
+
+ if tb_logger is not None:
+ self._logger, _ = build_logger(
+ path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False
+ )
+ self._tb_logger = tb_logger
+ else:
+ self._logger, self._tb_logger = build_logger(
+ path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name
+ )
+ self.reset(policy, env)
+
+ def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
+ """
+ Overview:
+ Reset the environment.
+ If _env is None, reset the old environment.
+            If _env is not None, replace the old environment in the collector with the newly passed-in \
+                environment and launch it.
+ Arguments:
+ - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
+ env_manager(BaseEnvManager)
+ """
+ if _env is not None:
+ self._env = _env
+ self._env.launch()
+ self._env_num = self._env.env_num
+ else:
+ self._env.reset()
+
+ def reset_policy(self, _policy: Optional[namedtuple] = None) -> None:
+ """
+ Overview:
+ Reset the policy.
+ If _policy is None, reset the old policy.
+ If _policy is not None, replace the old policy in the collector with the new passed in policy.
+ Arguments:
+ - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy
+ """
+ assert hasattr(self, '_env'), "please set env first"
+ if _policy is not None:
+ self._policy = _policy
+ self._policy_cfg = self._policy.get_attribute('cfg')
+ self._default_n_episode = _policy.get_attribute('n_episode')
+ self._unroll_len = _policy.get_attribute('unroll_len')
+ self._on_policy = _policy.get_attribute('on_policy')
+ self._traj_len = INF
+ self._logger.debug(
+ 'Set default n_episode mode(n_episode({}), env_num({}), traj_len({}))'.format(
+ self._default_n_episode, self._env_num, self._traj_len
+ )
+ )
+ self._policy.reset()
+
+ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None:
+ """
+ Overview:
+ Reset the environment and policy.
+ If _env is None, reset the old environment.
+            If _env is not None, replace the old environment in the collector with the newly passed-in \
+                environment and launch it.
+ If _policy is None, reset the old policy.
+ If _policy is not None, replace the old policy in the collector with the new passed in policy.
+ Arguments:
+ - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy
+ - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
+ env_manager(BaseEnvManager)
+ """
+ if _env is not None:
+ self.reset_env(_env)
+ if _policy is not None:
+ self.reset_policy(_policy)
+
+ self._obs_pool = CachePool('obs', self._env_num, deepcopy=self._deepcopy_obs)
+ self._policy_output_pool = CachePool('policy_output', self._env_num)
+ # _traj_buffer is {env_id: TrajBuffer}, is used to store traj_len pieces of transitions
+ self._traj_buffer = {env_id: TrajBuffer(maxlen=self._traj_len) for env_id in range(self._env_num)}
+ self._env_info = {env_id: {'time': 0., 'step': 0} for env_id in range(self._env_num)}
+
+ self._episode_info = []
+ self._total_envstep_count = 0
+ self._total_episode_count = 0
+ self._total_duration = 0
+ self._last_train_iter = 0
+ self._end_flag = False
+
+ def _reset_stat(self, env_id: int) -> None:
+ """
+ Overview:
+            Reset the collector's state for the given env_id, including the traj_buffer, obs_pool, \
+            policy_output_pool and env_info. You can refer to base_serial_collector for more details.
+ Arguments:
+ - env_id (:obj:`int`): the id where we need to reset the collector's state
+ """
+ self._traj_buffer[env_id].clear()
+ self._obs_pool.reset(env_id)
+ self._policy_output_pool.reset(env_id)
+ self._env_info[env_id] = {'time': 0., 'step': 0}
+
+ @property
+ def envstep(self) -> int:
+ """
+ Overview:
+            Get the total envstep count.
+ Return:
+ - envstep (:obj:`int`): the total envstep count
+ """
+ return self._total_envstep_count
+
+ def close(self) -> None:
+ """
+ Overview:
+ Close the collector. If end_flag is False, close the environment, flush the tb_logger\
+ and close the tb_logger.
+ """
+ if self._end_flag:
+ return
+ self._end_flag = True
+ self._env.close()
+ self._tb_logger.flush()
+ self._tb_logger.close()
+
+ def __del__(self) -> None:
+ """
+ Overview:
+ Execute the close command and close the collector. __del__ is automatically called to \
+ destroy the collector instance when the collector finishes its work
+ """
+ self.close()
+
+ def collect(self,
+ n_episode: Optional[int] = None,
+ train_iter: int = 0,
+ policy_kwargs: Optional[dict] = None) -> List[Any]:
+ """
+ Overview:
+            Collect `n_episode` episodes of data with policy_kwargs, using a policy that has been \
+            trained for `train_iter` iterations.
+ Arguments:
+ - n_episode (:obj:`int`): the number of collecting data episode
+ - train_iter (:obj:`int`): the number of training iteration
+ - policy_kwargs (:obj:`dict`): the keyword args for policy forward
+ Returns:
+ - return_data (:obj:`List`): A list containing collected episodes if not get_train_sample, otherwise, \
+ return train_samples split by unroll_len.
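+        Examples:
+            A hedged call sketch; whether full episodes or unroll-split train samples are returned depends on \
+            ``cfg.get_train_sample``:
+
+            >>> return_data = collector.collect(n_episode=8, train_iter=train_iter)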
+ """
+ if n_episode is None:
+ if self._default_n_episode is None:
+ raise RuntimeError("Please specify collect n_episode")
+ else:
+ n_episode = self._default_n_episode
+        assert n_episode >= self._env_num, \
+            "Please make sure n_episode >= env_num: {}/{}".format(n_episode, self._env_num)
+ if policy_kwargs is None:
+ policy_kwargs = {}
+ collected_episode = 0
+ return_data = []
+ ready_env_id = set()
+ remain_episode = n_episode
+
+ while True:
+ with self._timer:
+ # Get current env obs.
+ obs = self._env.ready_obs
+ new_available_env_id = set(obs.keys()).difference(ready_env_id)
+ ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode]))
+ remain_episode -= min(len(new_available_env_id), remain_episode)
+ obs = {env_id: obs[env_id] for env_id in ready_env_id}
+ # Policy forward.
+ self._obs_pool.update(obs)
+ if self._transform_obs:
+ obs = to_tensor(obs, dtype=torch.float32)
+ policy_output = self._policy.forward(obs, **policy_kwargs)
+ self._policy_output_pool.update(policy_output)
+ # Interact with env.
+ actions = {env_id: output['action'] for env_id, output in policy_output.items()}
+ actions = to_ndarray(actions)
+ timesteps = self._env.step(actions)
+
+ # TODO(nyz) this duration may be inaccurate in async env
+ interaction_duration = self._timer.value / len(timesteps)
+
+ # TODO(nyz) vectorize this for loop
+ for env_id, timestep in timesteps.items():
+ with self._timer:
+ if timestep.info.get('abnormal', False):
+ # If there is an abnormal timestep, reset all the related variables(including this env).
+ # suppose there is no reset param, just reset this env
+ self._env.reset({env_id: None})
+ self._policy.reset([env_id])
+ self._reset_stat(env_id)
+                        self._logger.info('Env{} returns an abnormal step, its info is {}'.format(env_id, timestep.info))
+ continue
+ transition = self._policy.process_transition(
+ self._obs_pool[env_id], self._policy_output_pool[env_id], timestep
+ )
+ # ``train_iter`` passed in from ``serial_entry``, indicates current collecting model's iteration.
+ transition['collect_iter'] = train_iter
+ self._traj_buffer[env_id].append(transition)
+ self._env_info[env_id]['step'] += 1
+ self._total_envstep_count += 1
+ # prepare data
+ if timestep.done:
+ transitions = to_tensor_transitions(self._traj_buffer[env_id], not self._deepcopy_obs)
+ if self._cfg.reward_shaping:
+ self._env.reward_shaping(env_id, transitions)
+ if self._cfg.get_train_sample:
+ train_sample = self._policy.get_train_sample(transitions)
+ return_data.extend(train_sample)
+ else:
+ return_data.append(transitions)
+ self._traj_buffer[env_id].clear()
+
+ self._env_info[env_id]['time'] += self._timer.value + interaction_duration
+
+ # If env is done, record episode info and reset
+ if timestep.done:
+ self._total_episode_count += 1
+ reward = timestep.info['eval_episode_return']
+ info = {
+ 'reward': reward,
+ 'time': self._env_info[env_id]['time'],
+ 'step': self._env_info[env_id]['step'],
+ }
+ collected_episode += 1
+ self._episode_info.append(info)
+ self._policy.reset([env_id])
+ self._reset_stat(env_id)
+ ready_env_id.remove(env_id)
+ if collected_episode >= n_episode:
+ break
+ # log
+ self._output_log(train_iter)
+ return return_data
+
+ def _output_log(self, train_iter: int) -> None:
+ """
+ Overview:
+ Print the output log information. You can refer to Docs/Best Practice/How to understand\
+ training generated folders/Serial mode/log/collector for more details.
+ Arguments:
+ - train_iter (:obj:`int`): the number of training iteration.
+ """
+ if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0:
+ self._last_train_iter = train_iter
+ episode_count = len(self._episode_info)
+ envstep_count = sum([d['step'] for d in self._episode_info])
+ duration = sum([d['time'] for d in self._episode_info])
+ episode_return = [d['reward'] for d in self._episode_info]
+ self._total_duration += duration
+ info = {
+ 'episode_count': episode_count,
+ 'envstep_count': envstep_count,
+ 'avg_envstep_per_episode': envstep_count / episode_count,
+ 'avg_envstep_per_sec': envstep_count / duration,
+ 'avg_episode_per_sec': episode_count / duration,
+ 'collect_time': duration,
+ 'reward_mean': np.mean(episode_return),
+ 'reward_std': np.std(episode_return),
+ 'reward_max': np.max(episode_return),
+ 'reward_min': np.min(episode_return),
+ 'total_envstep_count': self._total_envstep_count,
+ 'total_episode_count': self._total_episode_count,
+ 'total_duration': self._total_duration,
+ # 'each_reward': episode_return,
+ }
+ self._episode_info.clear()
+ self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])))
+ for k, v in info.items():
+ if k in ['each_reward']:
+ continue
+ self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
+ if k in ['total_envstep_count']:
+ continue
+ self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count)
diff --git a/DI-engine/ding/worker/collector/interaction_serial_evaluator.py b/DI-engine/ding/worker/collector/interaction_serial_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..09893ca5257b4dc011d84115adb48390a66af073
--- /dev/null
+++ b/DI-engine/ding/worker/collector/interaction_serial_evaluator.py
@@ -0,0 +1,321 @@
+from typing import Optional, Callable, Tuple, Dict, List
+from collections import namedtuple
+import numpy as np
+import torch
+
+from ding.envs import BaseEnvManager
+from ding.torch_utils import to_tensor, to_ndarray, to_item
+from ding.utils import build_logger, EasyTimer, SERIAL_EVALUATOR_REGISTRY
+from ding.utils import get_world_size, get_rank, broadcast_object_list
+from .base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor
+
+
+@SERIAL_EVALUATOR_REGISTRY.register('interaction')
+class InteractionSerialEvaluator(ISerialEvaluator):
+ """
+ Overview:
+ Interaction serial evaluator class, policy interacts with env.
+ Interfaces:
+ __init__, reset, reset_policy, reset_env, close, should_eval, eval
+ Property:
+ env, policy
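+    Examples:
+        A hedged sketch of the typical serial-pipeline usage; ``cfg``, ``env_manager``, the eval-mode ``policy``, \
+        ``learner`` and ``collector`` are assumed to come from the surrounding entry code.
+
+        >>> evaluator = InteractionSerialEvaluator(cfg, env=env_manager, policy=policy)
+        >>> if evaluator.should_eval(learner.train_iter):
+        ...     stop, episode_info = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)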
+ """
+
+ config = dict(
+ # (int) Evaluate every "eval_freq" training iterations.
+ eval_freq=1000,
+ render=dict(
+ # Tensorboard video render is disabled by default.
+ render_freq=-1,
+ mode='train_iter',
+ ),
+ # (str) File path for visualize environment information.
+ figure_path=None,
+ )
+
+ def __init__(
+ self,
+ cfg: dict,
+ env: BaseEnvManager = None,
+ policy: namedtuple = None,
+ tb_logger: 'SummaryWriter' = None, # noqa
+ exp_name: Optional[str] = 'default_experiment',
+ instance_name: Optional[str] = 'evaluator',
+ ) -> None:
+ """
+ Overview:
+ Init method. Load config and use ``self._cfg`` setting to build common serial evaluator components, \
+ e.g. logger helper, timer.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Configuration EasyDict.
+ """
+ self._cfg = cfg
+ self._exp_name = exp_name
+ self._instance_name = instance_name
+
+ # Logger (Monitor will be initialized in policy setter)
+        # Only the rank == 0 process needs monitor and tb_logger; others only need text_logger for terminal output.
+ if get_rank() == 0:
+ if tb_logger is not None:
+ self._logger, _ = build_logger(
+ './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
+ )
+ self._tb_logger = tb_logger
+ else:
+ self._logger, self._tb_logger = build_logger(
+ './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name
+ )
+ else:
+ self._logger, self._tb_logger = None, None # for close elegantly
+ self.reset(policy, env)
+
+ self._timer = EasyTimer()
+ self._default_n_episode = cfg.n_episode
+ self._stop_value = cfg.stop_value
+ # only one freq
+ self._render = cfg.render
+ assert self._render.mode in ('envstep', 'train_iter'), 'mode should be envstep or train_iter'
+
+ def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
+ """
+ Overview:
+            Reset evaluator's environment. In some cases, we need the evaluator to use the same policy in \
+            different environments. We can use reset_env to reset the environment.
+ If _env is None, reset the old environment.
+            If _env is not None, replace the old environment in the evaluator with the \
+            newly passed-in environment and launch it.
+ Arguments:
+ - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
+ env_manager(BaseEnvManager)
+ """
+ if _env is not None:
+ self._env = _env
+ self._env.launch()
+ self._env_num = self._env.env_num
+ else:
+ self._env.reset()
+
+ def reset_policy(self, _policy: Optional[namedtuple] = None) -> None:
+ """
+ Overview:
+            Reset evaluator's policy. In some cases, we need the evaluator to work in the same environment but \
+            use a different policy. We can use reset_policy to reset the policy.
+ If _policy is None, reset the old policy.
+ If _policy is not None, replace the old policy in the evaluator with the new passed in policy.
+ Arguments:
+ - policy (:obj:`Optional[namedtuple]`): the api namedtuple of eval_mode policy
+ """
+ assert hasattr(self, '_env'), "please set env first"
+ if _policy is not None:
+ self._policy = _policy
+ self._policy_cfg = self._policy.get_attribute('cfg')
+ self._policy.reset()
+
+ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None:
+ """
+ Overview:
+ Reset evaluator's policy and environment. Use new policy and environment to collect data.
+ If _env is None, reset the old environment.
+            If _env is not None, replace the old environment in the evaluator with the newly passed-in \
+            environment and launch it.
+ If _policy is None, reset the old policy.
+ If _policy is not None, replace the old policy in the evaluator with the new passed in policy.
+ Arguments:
+ - policy (:obj:`Optional[namedtuple]`): the api namedtuple of eval_mode policy
+ - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
+ env_manager(BaseEnvManager)
+ """
+ if _env is not None:
+ self.reset_env(_env)
+ if _policy is not None:
+ self.reset_policy(_policy)
+ if self._policy_cfg.type == 'dreamer_command':
+ self._states = None
+ self._resets = np.array([False for i in range(self._env_num)])
+ self._max_episode_return = float("-inf")
+ self._last_eval_iter = -1
+ self._end_flag = False
+ self._last_render_iter = -1
+
+ def close(self) -> None:
+ """
+ Overview:
+ Close the evaluator. If end_flag is False, close the environment, flush the tb_logger\
+ and close the tb_logger.
+ """
+ if self._end_flag:
+ return
+ self._end_flag = True
+ self._env.close()
+ if self._tb_logger:
+ self._tb_logger.flush()
+ self._tb_logger.close()
+
+ def __del__(self):
+ """
+ Overview:
+ Execute the close command and close the evaluator. __del__ is automatically called \
+ to destroy the evaluator instance when the evaluator finishes its work
+ """
+ self.close()
+
+ def should_eval(self, train_iter: int) -> bool:
+ """
+ Overview:
+            Determine whether to start evaluation: return True if the number of training iterations since the \
+            last evaluation has reached ``eval_freq``, otherwise return False.
+ """
+ if train_iter == self._last_eval_iter:
+ return False
+ if (train_iter - self._last_eval_iter) < self._cfg.eval_freq and train_iter != 0:
+ return False
+ self._last_eval_iter = train_iter
+ return True
+
+ def _should_render(self, envstep, train_iter):
+ if self._render.render_freq == -1:
+ return False
+ iter = envstep if self._render.mode == 'envstep' else train_iter
+ if (iter - self._last_render_iter) < self._render.render_freq:
+ return False
+ self._last_render_iter = iter
+ return True
+
+ def eval(
+ self,
+ save_ckpt_fn: Callable = None,
+ train_iter: int = -1,
+ envstep: int = -1,
+ n_episode: Optional[int] = None,
+ force_render: bool = False,
+ policy_kwargs: Optional[Dict] = {},
+ ) -> Tuple[bool, Dict[str, List]]:
+ '''
+ Overview:
+ Evaluate policy and store the best policy based on whether it reaches the highest historical reward.
+ Arguments:
+ - save_ckpt_fn (:obj:`Callable`): Saving ckpt function, which will be triggered by getting the best reward.
+ - train_iter (:obj:`int`): Current training iteration.
+ - envstep (:obj:`int`): Current env interaction step.
+ - n_episode (:obj:`int`): Number of evaluation episodes.
+ Returns:
+ - stop_flag (:obj:`bool`): Whether this training program can be ended.
+ - episode_info (:obj:`Dict[str, List]`): Current evaluation episode information.
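+        Examples:
+            A hedged sketch of how the return values are usually consumed in a training loop:
+
+            >>> stop, episode_info = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+            >>> if stop:
+            ...     print('training can be terminated')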
+ '''
+        # the evaluator only works on rank 0
+ stop_flag = False
+ if get_rank() == 0:
+ if n_episode is None:
+ n_episode = self._default_n_episode
+ assert n_episode is not None, "please indicate eval n_episode"
+ envstep_count = 0
+ info = {}
+ eval_monitor = VectorEvalMonitor(self._env.env_num, n_episode)
+ self._env.reset()
+ self._policy.reset()
+
+ # force_render overwrite frequency constraint
+ render = force_render or self._should_render(envstep, train_iter)
+
+ with self._timer:
+ while not eval_monitor.is_finished():
+ obs = self._env.ready_obs
+ obs = to_tensor(obs, dtype=torch.float32)
+
+ # update videos
+ if render:
+ eval_monitor.update_video(self._env.ready_imgs)
+
+ if self._policy_cfg.type == 'dreamer_command':
+ policy_output = self._policy.forward(
+ obs, **policy_kwargs, reset=self._resets, state=self._states
+ )
+ #self._states = {env_id: output['state'] for env_id, output in policy_output.items()}
+ self._states = [output['state'] for output in policy_output.values()]
+ else:
+ policy_output = self._policy.forward(obs, **policy_kwargs)
+ actions = {i: a['action'] for i, a in policy_output.items()}
+ actions = to_ndarray(actions)
+ timesteps = self._env.step(actions)
+ timesteps = to_tensor(timesteps, dtype=torch.float32)
+ for env_id, t in timesteps.items():
+ if t.info.get('abnormal', False):
+ # If there is an abnormal timestep, reset all the related variables(including this env).
+ self._policy.reset([env_id])
+ continue
+ if self._policy_cfg.type == 'dreamer_command':
+ self._resets[env_id] = t.done
+ if t.done:
+ # Env reset is done by env_manager automatically.
+ if 'figure_path' in self._cfg and self._cfg.figure_path is not None:
+ self._env.enable_save_figure(env_id, self._cfg.figure_path)
+ self._policy.reset([env_id])
+ reward = t.info['eval_episode_return']
+ saved_info = {'eval_episode_return': t.info['eval_episode_return']}
+ if 'episode_info' in t.info:
+ saved_info.update(t.info['episode_info'])
+ eval_monitor.update_info(env_id, saved_info)
+ eval_monitor.update_reward(env_id, reward)
+ self._logger.info(
+ "[EVALUATOR]env {} finish episode, final reward: {:.4f}, current episode: {}".format(
+ env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode()
+ )
+ )
+ envstep_count += 1
+ duration = self._timer.value
+ episode_return = eval_monitor.get_episode_return()
+ info = {
+ 'train_iter': train_iter,
+ 'ckpt_name': 'iteration_{}.pth.tar'.format(train_iter),
+ 'episode_count': n_episode,
+ 'envstep_count': envstep_count,
+ 'avg_envstep_per_episode': envstep_count / n_episode,
+ 'evaluate_time': duration,
+ 'avg_envstep_per_sec': envstep_count / duration,
+ 'avg_time_per_episode': n_episode / duration,
+ 'reward_mean': np.mean(episode_return),
+ 'reward_std': np.std(episode_return),
+ 'reward_max': np.max(episode_return),
+ 'reward_min': np.min(episode_return),
+ # 'each_reward': episode_return,
+ }
+ episode_info = eval_monitor.get_episode_info()
+ if episode_info is not None:
+ info.update(episode_info)
+ self._logger.info(self._logger.get_tabulate_vars_hor(info))
+ # self._logger.info(self._logger.get_tabulate_vars(info))
+ for k, v in info.items():
+ if k in ['train_iter', 'ckpt_name', 'each_reward']:
+ continue
+ if not np.isscalar(v):
+ continue
+ self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
+ self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep)
+
+ if render:
+ video_title = '{}_{}/'.format(self._instance_name, self._render.mode)
+ videos = eval_monitor.get_video()
+ render_iter = envstep if self._render.mode == 'envstep' else train_iter
+ from ding.utils import fps
+ self._tb_logger.add_video(video_title, videos, render_iter, fps(self._env))
+
+ episode_return = np.mean(episode_return)
+ if episode_return > self._max_episode_return:
+ if save_ckpt_fn:
+ save_ckpt_fn('ckpt_best.pth.tar')
+ self._max_episode_return = episode_return
+ stop_flag = episode_return >= self._stop_value and train_iter > 0
+ if stop_flag:
+ self._logger.info(
+ "[DI-engine serial pipeline] " + "Current episode_return: {:.4f} is greater than stop_value: {}".
+ format(episode_return, self._stop_value) + ", so your RL agent is converged, you can refer to " +
+ "'log/evaluator/evaluator_logger.txt' for details."
+ )
+
+ if get_world_size() > 1:
+ objects = [stop_flag, episode_info]
+ broadcast_object_list(objects, src=0)
+ stop_flag, episode_info = objects
+
+ episode_info = to_item(episode_info)
+ return stop_flag, episode_info
diff --git a/DI-engine/ding/worker/collector/marine_parallel_collector.py b/DI-engine/ding/worker/collector/marine_parallel_collector.py
new file mode 100644
index 0000000000000000000000000000000000000000..c659c7039856354d3e652da4c69bd1760c393ec4
--- /dev/null
+++ b/DI-engine/ding/worker/collector/marine_parallel_collector.py
@@ -0,0 +1,346 @@
+from typing import Dict, Any, List
+import copy
+import time
+import uuid
+from collections import namedtuple
+from threading import Thread
+from functools import partial
+import numpy as np
+import torch
+from easydict import EasyDict
+
+from ding.policy import create_policy, Policy
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.utils import get_data_compressor, pretty_print, PARALLEL_COLLECTOR_REGISTRY
+from ding.envs import BaseEnvTimestep, BaseEnvManager
+from .base_parallel_collector import BaseParallelCollector
+from .base_serial_collector import CachePool, TrajBuffer
+
+INF = float("inf")
+
+
+@PARALLEL_COLLECTOR_REGISTRY.register('marine')
+class MarineParallelCollector(BaseParallelCollector):
+ """
+ Feature:
+ - one policy or two policies, many envs
+ - async envs(step + reset)
+ - batch network eval
+ - different episode length env
+ - periodic policy update
+ - metadata + stepdata
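+    Examples:
+        A hedged sketch of how this collector is normally obtained in the parallel pipeline: the comm collector \
+        builds it from the coordinator's ``task_info`` instead of constructing it by hand.
+
+        >>> collector_cfg = EasyDict(task_info['collector_cfg'])
+        >>> collector = create_parallel_collector(collector_cfg)  # resolves to this class for the 'marine' type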
+ """
+ config = dict(
+ print_freq=5,
+ compressor='lz4',
+ update_policy_second=3,
+        # The following keys are set by the commander
+ # env
+ # policy
+ # collect_setting
+ # eval_flag
+ # policy_update_path
+ )
+
+ # override
+ def __init__(self, cfg: dict) -> None:
+ super().__init__(cfg)
+ self._update_policy_thread = Thread(
+ target=self._update_policy_periodically, args=(), name='update_policy', daemon=True
+ )
+ self._start_time = time.time()
+ self._compressor = get_data_compressor(self._cfg.compressor)
+
+ # create env
+ self._env_cfg = self._cfg.env
+ env_manager = self._setup_env_manager(self._env_cfg)
+ self.env_manager = env_manager
+
+ # create policy
+ if self._eval_flag:
+ assert len(self._cfg.policy) == 1
+ policy = [create_policy(self._cfg.policy[0], enable_field=['eval']).eval_mode]
+ self.policy = policy
+ self._policy_is_active = [None]
+ self._policy_iter = [None]
+ self._traj_buffer_length = self._traj_len if self._traj_len != INF else None
+ self._traj_buffer = {env_id: [TrajBuffer(self._traj_len)] for env_id in range(self._env_num)}
+ else:
+ assert len(self._cfg.policy) == 2
+ policy = [create_policy(self._cfg.policy[i], enable_field=['collect']).collect_mode for i in range(2)]
+ self.policy = policy
+ self._policy_is_active = [None for _ in range(2)]
+ self._policy_iter = [None for _ in range(2)]
+ self._traj_buffer_length = self._traj_len if self._traj_len != INF else None
+ self._traj_buffer = {
+ env_id: [TrajBuffer(self._traj_buffer_length) for _ in range(len(policy))]
+ for env_id in range(self._env_num)
+ }
+ # self._first_update_policy = True
+
+ self._episode_result = [[] for k in range(self._env_num)]
+ self._obs_pool = CachePool('obs', self._env_num)
+ self._policy_output_pool = CachePool('policy_output', self._env_num)
+ self._total_step = 0
+ self._total_sample = 0
+ self._total_episode = 0
+
+ @property
+ def policy(self) -> List[Policy]:
+ return self._policy
+
+ # override
+ @policy.setter
+ def policy(self, _policy: List[Policy]) -> None:
+ self._policy = _policy
+ self._n_episode = _policy[0].get_attribute('cfg').collect.get('n_episode', None)
+ self._n_sample = _policy[0].get_attribute('cfg').collect.get('n_sample', None)
+ assert any(
+ [t is None for t in [self._n_sample, self._n_episode]]
+        ), "n_episode and n_sample in policy cfg can't both be set at the same time"
+ # TODO(nyz) the same definition of traj_len in serial and parallel
+ if self._n_episode is not None:
+ self._traj_len = INF
+ elif self._n_sample is not None:
+ self._traj_len = self._n_sample
+
+    @property
+    def env_manager(self) -> BaseEnvManager:
+        return self._env_manager
+
+ # override
+ @env_manager.setter
+ def env_manager(self, _env_manager: BaseEnvManager) -> None:
+ self._env_manager = _env_manager
+ self._env_manager.launch()
+ self._env_num = self._env_manager.env_num
+ self._predefined_episode_count = self._env_num * self._env_manager._episode_num
+
+ def _setup_env_manager(self, cfg: EasyDict) -> BaseEnvManager:
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg)
+ if self._eval_flag:
+ env_cfg = evaluator_env_cfg
+ else:
+ env_cfg = collector_env_cfg
+ env_manager = create_env_manager(cfg.manager, [partial(env_fn, cfg=c) for c in env_cfg])
+ return env_manager
+
+ def _start_thread(self) -> None:
+        # the evaluator doesn't need to update the policy periodically; it only updates the policy when it starts
+ if not self._eval_flag:
+ self._update_policy_thread.start()
+
+ def _join_thread(self) -> None:
+ if not self._eval_flag:
+ self._update_policy_thread.join()
+ del self._update_policy_thread
+
+ # override
+ def close(self) -> None:
+ if self._end_flag:
+ return
+ self._end_flag = True
+ time.sleep(1)
+ if hasattr(self, '_env_manager'):
+ self._env_manager.close()
+ self._join_thread()
+
+ # override
+ def _policy_inference(self, obs: Dict[int, Any]) -> Dict[int, Any]:
+ env_ids = list(obs.keys())
+ if len(self._policy) > 1:
+ assert not self._eval_flag
+ obs = [{id: obs[id][i] for id in env_ids} for i in range(len(self._policy))]
+ else:
+ assert self._eval_flag
+ obs = [obs]
+ self._obs_pool.update(obs)
+ policy_outputs = []
+ for i in range(len(self._policy)):
+ if self._eval_flag:
+ policy_output = self._policy[i].forward(obs[i])
+ else:
+ policy_output = self._policy[i].forward(obs[i], **self._cfg.collect_setting)
+ policy_outputs.append(policy_output)
+ self._policy_output_pool.update(policy_outputs)
+ actions = {}
+ for env_id in env_ids:
+ action = [policy_outputs[i][env_id]['action'] for i in range(len(self._policy))]
+ action = torch.stack(action).squeeze()
+ actions[env_id] = action
+ return actions
+
+ # override
+ def _env_step(self, actions: Dict[int, Any]) -> Dict[int, Any]:
+ return self._env_manager.step(actions)
+
+ # override
+ def _process_timestep(self, timestep: Dict[int, namedtuple]) -> None:
+ for env_id, t in timestep.items():
+ if t.info.get('abnormal', False):
+                # If there is an abnormal timestep, reset all related variables; this env has already been reset by the env manager.
+ for c in self._traj_buffer[env_id]:
+ c.clear()
+ self._obs_pool.reset(env_id)
+ self._policy_output_pool.reset(env_id)
+ for p in self._policy:
+ p.reset([env_id])
+ continue
+ self._total_step += 1
+ t = [BaseEnvTimestep(t.obs[i], t.reward[i], t.done, t.info) for i in range(len(self._policy))]
+ if t[0].done:
+ self._total_episode += 1
+ if not self._eval_flag:
+ for i in range(len(self._policy)):
+ if self._policy_is_active[i]:
+ # Only active policy will store transition into replay buffer.
+ transition = self._policy[i].process_transition(
+ self._obs_pool[env_id][i], self._policy_output_pool[env_id][i], t[i]
+ )
+ self._traj_buffer[env_id][i].append(transition)
+ full_indices = []
+ for i in range(len(self._traj_buffer[env_id])):
+ if len(self._traj_buffer[env_id][i]) == self._traj_len:
+ full_indices.append(i)
+ if t[0].done or len(full_indices) > 0:
+ for i in full_indices:
+ train_sample = self._policy[i].get_train_sample(self._traj_buffer[env_id][i])
+ for s in train_sample:
+ s = self._compressor(s)
+ self._total_sample += 1
+ metadata = self._get_metadata(s, env_id)
+ self.send_stepdata(metadata['data_id'], s)
+ self.send_metadata(metadata)
+ self._traj_buffer[env_id][i].clear()
+ if t[0].done:
+ # env reset is done by env_manager automatically
+ self._obs_pool.reset(env_id)
+ self._policy_output_pool.reset(env_id)
+ for p in self._policy:
+ p.reset([env_id])
+ reward = t[0].info['eval_episode_return']
+ # Only left player's reward will be recorded.
+ left_reward = reward[0]
+ if isinstance(left_reward, torch.Tensor):
+ left_reward = left_reward.item()
+ self._episode_result[env_id].append(left_reward)
+ self.debug(
+ "Env {} finish episode, final reward: {}, collected episode: {}.".format(
+ env_id, reward, len(self._episode_result[env_id])
+ )
+ )
+ dones = [t.done for t in timestep.values()]
+ if any(dones):
+ collector_info = self._get_collector_info()
+ self.send_metadata(collector_info)
+
+ # override
+ def get_finish_info(self) -> dict:
+ duration = max(time.time() - self._start_time, 1e-8)
+ game_result = copy.deepcopy(self._episode_result)
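+        # Map raw episode rewards to game results,
+        # e.g. (illustrative) [[1., -1.], [0.]] -> [['wins', 'losses'], ['draws']].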
+ for i, env_result in enumerate(game_result):
+ for j, rew in enumerate(env_result):
+ if rew < 0:
+ game_result[i][j] = "losses"
+ elif rew == 0:
+ game_result[i][j] = "draws"
+ else:
+ game_result[i][j] = "wins"
+
+ finish_info = {
+ # 'finished_task': True, # flag
+ 'eval_flag': self._eval_flag,
+ # 'episode_num': self._episode_num,
+ 'env_num': self._env_num,
+ 'duration': duration,
+ 'collector_done': self._env_manager.done,
+ 'predefined_episode_count': self._predefined_episode_count,
+ 'real_episode_count': self._total_episode,
+ 'step_count': self._total_step,
+ 'sample_count': self._total_sample,
+ 'avg_time_per_episode': duration / max(1, self._total_episode),
+            'avg_time_per_step': duration / max(1, self._total_step),
+ 'avg_time_per_train_sample': duration / max(1, self._total_sample),
+ 'avg_step_per_episode': self._total_step / max(1, self._total_episode),
+ 'avg_sample_per_episode': self._total_sample / max(1, self._total_episode),
+ 'reward_mean': np.mean(self._episode_result),
+ 'reward_std': np.std(self._episode_result),
+ 'reward_raw': self._episode_result,
+ 'finish_time': time.time(),
+ 'game_result': game_result,
+ }
+ if not self._eval_flag:
+ finish_info['collect_setting'] = self._cfg.collect_setting
+ self._logger.info('\nFINISH INFO\n{}'.format(pretty_print(finish_info, direct_print=False)))
+ return finish_info
+
+ # override
+ def _update_policy(self) -> None:
+ path = self._cfg.policy_update_path
+ self._policy_is_active = self._cfg.policy_update_flag
+ for i in range(len(path)):
+ # if not self._first_update_policy and not self._policy_is_active[i]:
+ if not self._policy_is_active[i]:
+                # For the first update, all policies should be updated (i.e. initialized);
+                # afterwards, only the active player's policies should be updated.
+ continue
+ while True:
+ try:
+ policy_update_info = self.get_policy_update_info(path[i])
+ break
+ except Exception as e:
+ self.error('Policy {} update error: {}'.format(i + 1, e))
+ time.sleep(1)
+ if policy_update_info is None:
+ continue
+ self._policy_iter[i] = policy_update_info.pop('iter')
+ self._policy[i].load_state_dict(policy_update_info)
+            self.debug('Update policy {} with {}(iter{}) in {}'.format(i + 1, path[i], self._policy_iter[i], time.time()))
+ # self._first_update_policy = False
+
+ # ******************************** thread **************************************
+
+ def _update_policy_periodically(self) -> None:
+ last = time.time()
+ while not self._end_flag:
+ cur = time.time()
+ interval = cur - last
+ if interval < self._cfg.update_policy_second:
+ time.sleep(self._cfg.update_policy_second * 0.1)
+ continue
+ else:
+ self._update_policy()
+ last = time.time()
+ time.sleep(0.1)
+
+ def _get_metadata(self, stepdata: List, env_id: int) -> dict:
+ data_id = "env_{}_{}".format(env_id, str(uuid.uuid1()))
+ metadata = {
+ 'eval_flag': self._eval_flag,
+ 'data_id': data_id,
+ 'env_id': env_id,
+ 'policy_iter': self._policy_iter,
+ 'unroll_len': len(stepdata),
+ 'compressor': self._cfg.compressor,
+ 'get_data_time': time.time(),
+ # TODO(nyz) the relationship between traj priority and step priority
+ 'priority': 1.0,
+ 'cur_episode': self._total_episode,
+ 'cur_sample': self._total_sample,
+ 'cur_step': self._total_step,
+ }
+ return metadata
+
+ def _get_collector_info(self) -> dict:
+ return {
+ 'eval_flag': self._eval_flag,
+ 'get_info_time': time.time(),
+ 'collector_done': self._env_manager.done,
+ 'cur_episode': self._total_episode,
+ 'cur_sample': self._total_sample,
+ 'cur_step': self._total_step,
+ }
+
+ def __repr__(self) -> str:
+ return "MarineParallelCollector"
diff --git a/DI-engine/ding/worker/collector/metric_serial_evaluator.py b/DI-engine/ding/worker/collector/metric_serial_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..a160e437fcf772bf501eb6c2b70bdbb2281cf53d
--- /dev/null
+++ b/DI-engine/ding/worker/collector/metric_serial_evaluator.py
@@ -0,0 +1,225 @@
+from typing import Optional, Callable, Tuple, Any, List
+from abc import ABC, abstractmethod
+from collections import namedtuple
+import numpy as np
+import torch
+from torch.utils.data import DataLoader
+
+from ding.torch_utils import to_tensor, to_ndarray
+from ding.utils import build_logger, EasyTimer, SERIAL_EVALUATOR_REGISTRY, allreduce
+from .base_serial_evaluator import ISerialEvaluator, VectorEvalMonitor
+
+
+class IMetric(ABC):
+
+ @abstractmethod
+ def eval(self, inputs: Any, label: Any) -> dict:
+ raise NotImplementedError
+
+ @abstractmethod
+ def reduce_mean(self, inputs: List[Any]) -> Any:
+ raise NotImplementedError
+
+ @abstractmethod
+ def gt(self, metric1: Any, metric2: Any) -> bool:
+ """
+ Overview:
+            Whether metric1 is greater than or equal to metric2 (>=)
+
+ .. note::
+ If metric2 is None, return True
+ """
+ raise NotImplementedError
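+
+# A minimal IMetric implementation sketch (illustrative only, not part of this module's API);
+# ``AccuracyMetric`` and its ``acc`` key are hypothetical names:
+#
+#     class AccuracyMetric(IMetric):
+#
+#         def eval(self, inputs: torch.Tensor, label: torch.Tensor) -> dict:
+#             return {'acc': (inputs.argmax(dim=-1) == label).float().mean().item()}
+#
+#         def reduce_mean(self, inputs: List[dict]) -> dict:
+#             return {'acc': sum(x['acc'] for x in inputs) / len(inputs)}
+#
+#         def gt(self, metric1: dict, metric2: dict) -> bool:
+#             return metric2 is None or metric1['acc'] >= metric2['acc']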
+
+
+@SERIAL_EVALUATOR_REGISTRY.register('metric')
+class MetricSerialEvaluator(ISerialEvaluator):
+ """
+ Overview:
+        Metric serial evaluator class. The policy is evaluated by an objective metric,
+        where a (DataLoader, IMetric) pair plays the role of the env.
+ Interfaces:
+ __init__, reset, reset_policy, reset_env, close, should_eval, eval
+ Property:
+ env, policy
+ """
+
+ config = dict(
+ # Evaluate every "eval_freq" training iterations.
+ eval_freq=50,
+ )
+
+ def __init__(
+ self,
+ cfg: dict,
+ env: Tuple[DataLoader, IMetric] = None,
+ policy: namedtuple = None,
+ tb_logger: 'SummaryWriter' = None, # noqa
+ exp_name: Optional[str] = 'default_experiment',
+ instance_name: Optional[str] = 'evaluator',
+ ) -> None:
+ """
+ Overview:
+ Init method. Load config and use ``self._cfg`` setting to build common serial evaluator components,
+ e.g. logger helper, timer.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Configuration EasyDict.
+ """
+ self._cfg = cfg
+ self._exp_name = exp_name
+ self._instance_name = instance_name
+ if tb_logger is not None:
+ self._logger, _ = build_logger(
+ path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False
+ )
+ self._tb_logger = tb_logger
+ else:
+ self._logger, self._tb_logger = build_logger(
+ path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name
+ )
+ self.reset(policy, env)
+
+ self._timer = EasyTimer()
+ self._stop_value = cfg.stop_value
+
+ def reset_env(self, _env: Optional[Tuple[DataLoader, IMetric]] = None) -> None:
+ """
+ Overview:
+            Reset the evaluator's environment. In some cases, we need the evaluator to use the same policy in \
+                different environments; reset_env can be used to switch the environment.
+            If _env is not None, replace the old environment in the evaluator with the new one.
+ Arguments:
+ - env (:obj:`Optional[Tuple[DataLoader, IMetric]]`): Instance of the DataLoader and Metric
+ """
+ if _env is not None:
+ self._dataloader, self._metric = _env
+
+ def reset_policy(self, _policy: Optional[namedtuple] = None) -> None:
+ """
+ Overview:
+            Reset the evaluator's policy. In some cases, we need the evaluator to work in the same environment but \
+                with a different policy; reset_policy can be used to switch or reset the policy.
+ If _policy is None, reset the old policy.
+ If _policy is not None, replace the old policy in the evaluator with the new passed in policy.
+ Arguments:
+ - policy (:obj:`Optional[namedtuple]`): the api namedtuple of eval_mode policy
+ """
+ if _policy is not None:
+ self._policy = _policy
+ self._policy.reset()
+
+ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[Tuple[DataLoader, IMetric]] = None) -> None:
+ """
+ Overview:
+            Reset the evaluator's policy and environment, and use the new ones for subsequent evaluation.
+ If _env is not None, replace the old environment in the evaluator with the new one
+ If _policy is None, reset the old policy.
+ If _policy is not None, replace the old policy in the evaluator with the new passed in policy.
+ Arguments:
+ - policy (:obj:`Optional[namedtuple]`): the api namedtuple of eval_mode policy
+ - env (:obj:`Optional[Tuple[DataLoader, IMetric]]`): Instance of the DataLoader and Metric
+ """
+ if _env is not None:
+ self.reset_env(_env)
+ if _policy is not None:
+ self.reset_policy(_policy)
+ self._max_avg_eval_result = None
+ self._last_eval_iter = -1
+ self._end_flag = False
+
+ def close(self) -> None:
+ """
+ Overview:
+            Close the evaluator. If end_flag is False, flush and close the tb_logger.
+ """
+ if self._end_flag:
+ return
+ self._end_flag = True
+ self._tb_logger.flush()
+ self._tb_logger.close()
+
+ def __del__(self):
+ """
+ Overview:
+ Execute the close command and close the evaluator. __del__ is automatically called \
+ to destroy the evaluator instance when the evaluator finishes its work
+ """
+ self.close()
+
+ def should_eval(self, train_iter: int) -> bool:
+ """
+ Overview:
+            Determine whether to start evaluation: return True if at least ``eval_freq`` training iterations \
+                have passed since the last evaluation (the very first iteration always evaluates).
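+        Examples:
+            >>> # Illustrative, with eval_freq=50 and no evaluation performed yet:
+            >>> evaluator.should_eval(0)    # True (first call)
+            >>> evaluator.should_eval(30)   # False (only 30 iterations since last eval)
+            >>> evaluator.should_eval(50)   # True (>= eval_freq iterations since last eval)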
+ """
+ if train_iter == self._last_eval_iter:
+ return False
+ if (train_iter - self._last_eval_iter) < self._cfg.eval_freq and train_iter != 0:
+ return False
+ self._last_eval_iter = train_iter
+ return True
+
+ def eval(
+ self,
+ save_ckpt_fn: Callable = None,
+ train_iter: int = -1,
+ envstep: int = -1,
+ ) -> Tuple[bool, Any]:
+ '''
+ Overview:
+ Evaluate policy and store the best policy based on whether it reaches the highest historical reward.
+ Arguments:
+ - save_ckpt_fn (:obj:`Callable`): Saving ckpt function, which will be triggered by getting the best reward.
+ - train_iter (:obj:`int`): Current training iteration.
+ - envstep (:obj:`int`): Current env interaction step.
+ Returns:
+ - stop_flag (:obj:`bool`): Whether this training program can be ended.
+ - eval_metric (:obj:`float`): Current evaluation metric result.
+ '''
+ self._policy.reset()
+ eval_results = []
+
+ with self._timer:
+ self._logger.info("Evaluation begin...")
+ for batch_idx, batch_data in enumerate(self._dataloader):
+ inputs, label = to_tensor(batch_data)
+ policy_output = self._policy.forward(inputs)
+ eval_results.append(self._metric.eval(policy_output, label))
+ avg_eval_result = self._metric.reduce_mean(eval_results)
+ if self._cfg.multi_gpu:
+ device = self._policy.get_attribute('device')
+ for k in avg_eval_result.keys():
+ value_tensor = torch.FloatTensor([avg_eval_result[k]]).to(device)
+ allreduce(value_tensor)
+ avg_eval_result[k] = value_tensor.item()
+
+ duration = self._timer.value
+ info = {
+ 'train_iter': train_iter,
+ 'ckpt_name': 'iteration_{}.pth.tar'.format(train_iter),
+ 'data_length': len(self._dataloader),
+ 'evaluate_time': duration,
+ 'avg_time_per_data': duration / len(self._dataloader),
+ }
+ info.update(avg_eval_result)
+ self._logger.info(self._logger.get_tabulate_vars_hor(info))
+ # self._logger.info(self._logger.get_tabulate_vars(info))
+ for k, v in info.items():
+ if k in ['train_iter', 'ckpt_name']:
+ continue
+ if not np.isscalar(v):
+ continue
+ self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
+ self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep)
+ if self._metric.gt(avg_eval_result, self._max_avg_eval_result):
+ if save_ckpt_fn:
+ save_ckpt_fn('ckpt_best.pth.tar')
+ self._max_avg_eval_result = avg_eval_result
+ stop_flag = self._metric.gt(avg_eval_result, self._stop_value) and train_iter > 0
+ if stop_flag:
+ self._logger.info(
+ "[DI-engine serial pipeline] " +
+                "Current eval metric: {} is greater than stop_value: {}".format(avg_eval_result, self._stop_value) +
+ ", so your RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details."
+ )
+ return stop_flag, avg_eval_result
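+
+# Typical wiring sketch (illustrative; ``eval_dataloader``, ``eval_metric``, ``policy`` and ``learner``
+# are assumed to exist, mirroring the construction used in the unit test):
+#
+#     evaluator = MetricSerialEvaluator(cfg.policy.eval.evaluator, [eval_dataloader, eval_metric], policy.eval_mode)
+#     if evaluator.should_eval(learner.train_iter):
+#         stop, eval_result = evaluator.eval(learner.save_checkpoint, learner.train_iter)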
diff --git a/DI-engine/ding/worker/collector/sample_serial_collector.py b/DI-engine/ding/worker/collector/sample_serial_collector.py
new file mode 100644
index 0000000000000000000000000000000000000000..26db458edbcb5fe2881f1a4314c9b1295313d337
--- /dev/null
+++ b/DI-engine/ding/worker/collector/sample_serial_collector.py
@@ -0,0 +1,413 @@
+from typing import Optional, Any, List
+from collections import namedtuple
+from easydict import EasyDict
+import copy
+import numpy as np
+import torch
+
+from ding.envs import BaseEnvManager
+from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, one_time_warning, get_rank, get_world_size, \
+ broadcast_object_list, allreduce_data
+from ding.torch_utils import to_tensor, to_ndarray
+from .base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF, to_tensor_transitions
+
+
+@SERIAL_COLLECTOR_REGISTRY.register('sample')
+class SampleSerialCollector(ISerialCollector):
+ """
+ Overview:
+        Sample collector (n_sample). A sample is one training sample used to update the model; it is usually a
+        single transition, but it can also be a short trajectory containing several transitions, which is often
+        used with RNN-based models.
+ Interfaces:
+ __init__, reset, reset_env, reset_policy, collect, close
+ Property:
+ envstep
+ """
+
+ config = dict(deepcopy_obs=False, transform_obs=False, collect_print_freq=100)
+
+ def __init__(
+ self,
+ cfg: EasyDict,
+ env: BaseEnvManager = None,
+ policy: namedtuple = None,
+ tb_logger: 'SummaryWriter' = None, # noqa
+ exp_name: Optional[str] = 'default_experiment',
+ instance_name: Optional[str] = 'collector'
+ ) -> None:
+ """
+ Overview:
+ Initialization method.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config dict
+ - env (:obj:`BaseEnvManager`): the subclass of vectorized env_manager(BaseEnvManager)
+ - policy (:obj:`namedtuple`): the api namedtuple of collect_mode policy
+ - tb_logger (:obj:`SummaryWriter`): tensorboard handle
+ """
+ self._exp_name = exp_name
+ self._instance_name = instance_name
+ self._collect_print_freq = cfg.collect_print_freq
+ self._deepcopy_obs = cfg.deepcopy_obs # whether to deepcopy each data
+ self._transform_obs = cfg.transform_obs
+ self._cfg = cfg
+ self._timer = EasyTimer()
+ self._end_flag = False
+ self._rank = get_rank()
+ self._world_size = get_world_size()
+
+ if self._rank == 0:
+ if tb_logger is not None:
+ self._logger, _ = build_logger(
+ path='./{}/log/{}'.format(self._exp_name, self._instance_name),
+ name=self._instance_name,
+ need_tb=False
+ )
+ self._tb_logger = tb_logger
+ else:
+ self._logger, self._tb_logger = build_logger(
+ path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name
+ )
+ else:
+ self._logger, _ = build_logger(
+ path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False
+ )
+ self._tb_logger = None
+
+ self.reset(policy, env)
+
+ def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
+ """
+ Overview:
+ Reset the environment.
+ If _env is None, reset the old environment.
+ If _env is not None, replace the old environment in the collector with the new passed \
+ in environment and launch.
+ Arguments:
+ - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
+ env_manager(BaseEnvManager)
+ """
+ if _env is not None:
+ self._env = _env
+ self._env.launch()
+ self._env_num = self._env.env_num
+ else:
+ self._env.reset()
+
+ def reset_policy(self, _policy: Optional[namedtuple] = None) -> None:
+ """
+ Overview:
+ Reset the policy.
+ If _policy is None, reset the old policy.
+ If _policy is not None, replace the old policy in the collector with the new passed in policy.
+ Arguments:
+ - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy
+ """
+ assert hasattr(self, '_env'), "please set env first"
+ if _policy is not None:
+ self._policy = _policy
+ self._policy_cfg = self._policy.get_attribute('cfg')
+ self._default_n_sample = _policy.get_attribute('n_sample')
+ self._traj_len_inf = self._policy_cfg.traj_len_inf
+ self._unroll_len = _policy.get_attribute('unroll_len')
+ self._on_policy = _policy.get_attribute('on_policy')
+ if self._default_n_sample is not None and not self._traj_len_inf:
+ self._traj_len = max(
+ self._unroll_len,
+ self._default_n_sample // self._env_num + int(self._default_n_sample % self._env_num != 0)
+ )
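+                # e.g. (hypothetical values) n_sample=96, env_num=8, unroll_len=1:
+                # traj_len = max(1, 96 // 8 + 0) = 12, i.e. each env contributes up to 12 transitions per collect.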
+ self._logger.debug(
+ 'Set default n_sample mode(n_sample({}), env_num({}), traj_len({}))'.format(
+ self._default_n_sample, self._env_num, self._traj_len
+ )
+ )
+ else:
+ self._traj_len = INF
+ self._policy.reset()
+
+ def reset(self, _policy: Optional[namedtuple] = None, _env: Optional[BaseEnvManager] = None) -> None:
+ """
+ Overview:
+ Reset the environment and policy.
+ If _env is None, reset the old environment.
+ If _env is not None, replace the old environment in the collector with the new passed \
+ in environment and launch.
+ If _policy is None, reset the old policy.
+ If _policy is not None, replace the old policy in the collector with the new passed in policy.
+ Arguments:
+ - policy (:obj:`Optional[namedtuple]`): the api namedtuple of collect_mode policy
+ - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
+ env_manager(BaseEnvManager)
+ """
+ if _env is not None:
+ self.reset_env(_env)
+ if _policy is not None:
+ self.reset_policy(_policy)
+
+ if self._policy_cfg.type == 'dreamer_command':
+ self._states = None
+ self._resets = np.array([False for i in range(self._env_num)])
+ self._obs_pool = CachePool('obs', self._env_num, deepcopy=self._deepcopy_obs)
+ self._policy_output_pool = CachePool('policy_output', self._env_num)
+        # _traj_buffer is {env_id: TrajBuffer}; it stores up to traj_len transitions for each env
+ maxlen = self._traj_len if self._traj_len != INF else None
+ self._traj_buffer = {
+ env_id: TrajBuffer(maxlen=maxlen, deepcopy=self._deepcopy_obs)
+ for env_id in range(self._env_num)
+ }
+ self._env_info = {env_id: {'time': 0., 'step': 0, 'train_sample': 0} for env_id in range(self._env_num)}
+
+ self._episode_info = []
+ self._total_envstep_count = 0
+ self._total_episode_count = 0
+ self._total_train_sample_count = 0
+ self._total_duration = 0
+ self._last_train_iter = 0
+ self._end_flag = False
+
+ def _reset_stat(self, env_id: int) -> None:
+ """
+ Overview:
+            Reset the collector's state for a given env_id, including its traj_buffer, obs_pool, \
+                policy_output_pool and env_info entries. You can refer to base_serial_collector \
+                for more details.
+ Arguments:
+ - env_id (:obj:`int`): the id where we need to reset the collector's state
+ """
+ self._traj_buffer[env_id].clear()
+ self._obs_pool.reset(env_id)
+ self._policy_output_pool.reset(env_id)
+ self._env_info[env_id] = {'time': 0., 'step': 0, 'train_sample': 0}
+
+ @property
+ def envstep(self) -> int:
+ """
+ Overview:
+            Get the total envstep count.
+        Returns:
+ - envstep (:obj:`int`): the total envstep count
+ """
+ return self._total_envstep_count
+
+ def close(self) -> None:
+ """
+ Overview:
+ Close the collector. If end_flag is False, close the environment, flush the tb_logger\
+ and close the tb_logger.
+ """
+ if self._end_flag:
+ return
+ self._end_flag = True
+ self._env.close()
+ if self._tb_logger:
+ self._tb_logger.flush()
+ self._tb_logger.close()
+
+ def __del__(self) -> None:
+ """
+ Overview:
+ Execute the close command and close the collector. __del__ is automatically called to \
+ destroy the collector instance when the collector finishes its work
+ """
+ self.close()
+
+ def collect(
+ self,
+ n_sample: Optional[int] = None,
+ train_iter: int = 0,
+ drop_extra: bool = True,
+ random_collect: bool = False,
+ record_random_collect: bool = True,
+ policy_kwargs: Optional[dict] = None,
+ level_seeds: Optional[List] = None,
+ ) -> List[Any]:
+ """
+ Overview:
+ Collect `n_sample` data with policy_kwargs, which is already trained `train_iter` iterations.
+ Arguments:
+ - n_sample (:obj:`int`): The number of collecting data sample.
+ - train_iter (:obj:`int`): The number of training iteration when calling collect method.
+            - drop_extra (:obj:`bool`): Whether to drop extra return_data beyond `n_sample`.
+            - random_collect (:obj:`bool`): Whether this call performs random data collection (e.g. warm-up \
+                before training starts).
+            - record_random_collect (:obj:`bool`): Whether to output logs for random collection.
+            - policy_kwargs (:obj:`dict`): The keyword args for policy forward.
+            - level_seeds (:obj:`List`): Used in PLR, represents the seeds of the environments that \
+                generate the data.
+ Returns:
+ - return_data (:obj:`List`): A list containing training samples.
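+        Examples:
+            >>> # Minimal usage sketch (assumes ``collector``, ``learner`` and ``replay_buffer`` already exist):
+            >>> new_data = collector.collect(n_sample=96, train_iter=learner.train_iter, policy_kwargs={'eps': 0.5})
+            >>> replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)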
+ """
+ if n_sample is None:
+ if self._default_n_sample is None:
+ raise RuntimeError("Please specify collect n_sample")
+ else:
+ n_sample = self._default_n_sample
+ if n_sample % self._env_num != 0:
+ one_time_warning(
+                "Please make sure n_sample: {} is divisible by env_num: {}, ".format(n_sample, self._env_num) +
+                "otherwise it may cause convergence problems in a few algorithms"
+ )
+ if policy_kwargs is None:
+ policy_kwargs = {}
+ collected_sample = 0
+ collected_step = 0
+ collected_episode = 0
+ return_data = []
+
+ while collected_sample < n_sample:
+ with self._timer:
+ # Get current env obs.
+ obs = self._env.ready_obs
+ # Policy forward.
+ self._obs_pool.update(obs)
+ if self._transform_obs:
+ obs = to_tensor(obs, dtype=torch.float32)
+ if self._policy_cfg.type == 'dreamer_command' and not random_collect:
+ policy_output = self._policy.forward(obs, **policy_kwargs, reset=self._resets, state=self._states)
+ #self._states = {env_id: output['state'] for env_id, output in policy_output.items()}
+ self._states = [output['state'] for output in policy_output.values()]
+ else:
+ policy_output = self._policy.forward(obs, **policy_kwargs)
+ self._policy_output_pool.update(policy_output)
+ # Interact with env.
+ actions = {env_id: output['action'] for env_id, output in policy_output.items()}
+ actions = to_ndarray(actions)
+ timesteps = self._env.step(actions)
+
+ # TODO(nyz) this duration may be inaccurate in async env
+ interaction_duration = self._timer.value / len(timesteps)
+
+ # TODO(nyz) vectorize this for loop
+ for env_id, timestep in timesteps.items():
+ with self._timer:
+ if timestep.info.get('abnormal', False):
+                        # If there is an abnormal timestep, reset all the related variables (including this env).
+                        # Assume there is no special reset param; just reset this env.
+ self._env.reset({env_id: None})
+ self._policy.reset([env_id])
+ self._reset_stat(env_id)
+                        self._logger.info('Env{} returns an abnormal step, its info is {}'.format(env_id, timestep.info))
+ continue
+ if self._policy_cfg.type == 'dreamer_command' and not random_collect:
+ self._resets[env_id] = timestep.done
+ if self._policy_cfg.type == 'ngu_command': # for NGU policy
+ transition = self._policy.process_transition(
+ self._obs_pool[env_id], self._policy_output_pool[env_id], timestep, env_id
+ )
+ else:
+ transition = self._policy.process_transition(
+ self._obs_pool[env_id], self._policy_output_pool[env_id], timestep
+ )
+ if level_seeds is not None:
+ transition['seed'] = level_seeds[env_id]
+ # ``train_iter`` passed in from ``serial_entry``, indicates current collecting model's iteration.
+ transition['collect_iter'] = train_iter
+ self._traj_buffer[env_id].append(transition)
+ self._env_info[env_id]['step'] += 1
+ collected_step += 1
+ # prepare data
+ if timestep.done or len(self._traj_buffer[env_id]) == self._traj_len:
+ # If policy is r2d2:
+ # 1. For each collect_env, we want to collect data of length self._traj_len=INF
+ # unless the episode enters the 'done' state.
+                        # 2. The length of a train (sequence) sample in r2d2 is defined in the policy config
+                        # (please refer to r2d2.py), and in each collect phase
+                        # we collect a fixed number of such (sequence) samples.
+                        # 3. When timestep is done and we have only collected very few transitions in self._traj_buffer,
+                        # going through self._policy.get_train_sample will pad them automatically to get a
+                        # sequence sample of the required length (please refer to r2d2.py).
+
+ # Episode is done or traj_buffer(maxlen=traj_len) is full.
+ # indicate whether to shallow copy next obs, i.e., overlap of s_t and s_t+1
+ transitions = to_tensor_transitions(self._traj_buffer[env_id], not self._deepcopy_obs)
+ train_sample = self._policy.get_train_sample(transitions)
+ return_data.extend(train_sample)
+ self._env_info[env_id]['train_sample'] += len(train_sample)
+ collected_sample += len(train_sample)
+ self._traj_buffer[env_id].clear()
+
+ self._env_info[env_id]['time'] += self._timer.value + interaction_duration
+
+ # If env is done, record episode info and reset
+ if timestep.done:
+ collected_episode += 1
+ reward = timestep.info['eval_episode_return']
+ info = {
+ 'reward': reward,
+ 'time': self._env_info[env_id]['time'],
+ 'step': self._env_info[env_id]['step'],
+ 'train_sample': self._env_info[env_id]['train_sample'],
+ }
+ self._episode_info.append(info)
+ # Env reset is done by env_manager automatically
+ self._policy.reset([env_id])
+ self._reset_stat(env_id)
+
+ collected_duration = sum([d['time'] for d in self._episode_info])
+ # reduce data when enables DDP
+ if self._world_size > 1:
+ collected_sample = allreduce_data(collected_sample, 'sum')
+ collected_step = allreduce_data(collected_step, 'sum')
+ collected_episode = allreduce_data(collected_episode, 'sum')
+ collected_duration = allreduce_data(collected_duration, 'sum')
+ self._total_envstep_count += collected_step
+ self._total_episode_count += collected_episode
+ self._total_duration += collected_duration
+ self._total_train_sample_count += collected_sample
+ # log
+        if record_random_collect:  # default is True; callers set this to False during random collection
+ self._output_log(train_iter)
+ else:
+ self._episode_info.clear()
+ # on-policy reset
+ if self._on_policy:
+ for env_id in range(self._env_num):
+ self._reset_stat(env_id)
+
+ if drop_extra:
+ return return_data[:n_sample]
+ else:
+ return return_data
+
+ def _output_log(self, train_iter: int) -> None:
+ """
+ Overview:
+ Print the output log information. You can refer to the docs of `Best Practice` to understand \
+ the training generated logs and tensorboards.
+ Arguments:
+ - train_iter (:obj:`int`): the number of training iteration.
+ """
+ if self._rank != 0:
+ return
+ if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0:
+ self._last_train_iter = train_iter
+ episode_count = len(self._episode_info)
+ envstep_count = sum([d['step'] for d in self._episode_info])
+ train_sample_count = sum([d['train_sample'] for d in self._episode_info])
+ duration = sum([d['time'] for d in self._episode_info])
+ episode_return = [d['reward'] for d in self._episode_info]
+ info = {
+ 'episode_count': episode_count,
+ 'envstep_count': envstep_count,
+ 'train_sample_count': train_sample_count,
+ 'avg_envstep_per_episode': envstep_count / episode_count,
+ 'avg_sample_per_episode': train_sample_count / episode_count,
+ 'avg_envstep_per_sec': envstep_count / duration,
+ 'avg_train_sample_per_sec': train_sample_count / duration,
+ 'avg_episode_per_sec': episode_count / duration,
+ 'reward_mean': np.mean(episode_return),
+ 'reward_std': np.std(episode_return),
+ 'reward_max': np.max(episode_return),
+ 'reward_min': np.min(episode_return),
+ 'total_envstep_count': self._total_envstep_count,
+ 'total_train_sample_count': self._total_train_sample_count,
+ 'total_episode_count': self._total_episode_count,
+ # 'each_reward': episode_return,
+ }
+ self._episode_info.clear()
+ self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])))
+ for k, v in info.items():
+ if k in ['each_reward']:
+ continue
+ self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
+ if k in ['total_envstep_count']:
+ continue
+ self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count)
diff --git a/DI-engine/ding/worker/collector/tests/__init__.py b/DI-engine/ding/worker/collector/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/ding/worker/collector/tests/fake_cls_policy.py b/DI-engine/ding/worker/collector/tests/fake_cls_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..6bbebc0fd6496b4a0c59a26c2973c18500ff113a
--- /dev/null
+++ b/DI-engine/ding/worker/collector/tests/fake_cls_policy.py
@@ -0,0 +1,34 @@
+from ding.policy import Policy
+from ding.model import model_wrap
+
+
+class fake_policy(Policy):
+
+ def _init_learn(self):
+ pass
+
+ def _forward_learn(self, data):
+ pass
+
+ def _init_eval(self):
+ self._eval_model = model_wrap(self._model, 'base')
+
+ def _forward_eval(self, data):
+ self._eval_model.eval()
+ output = self._eval_model.forward(data)
+ return output
+
+ def _monitor_vars_learn(self):
+ return ['forward_time', 'backward_time', 'sync_time']
+
+ def _init_collect(self):
+ pass
+
+ def _forward_collect(self, data):
+ pass
+
+ def _process_transition(self):
+ pass
+
+ def _get_train_sample(self):
+ pass
diff --git a/DI-engine/ding/worker/collector/tests/fake_cpong_dqn_config.py b/DI-engine/ding/worker/collector/tests/fake_cpong_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a3d9d3e384818cf25b86c95bda7e10853dd9ba9
--- /dev/null
+++ b/DI-engine/ding/worker/collector/tests/fake_cpong_dqn_config.py
@@ -0,0 +1,97 @@
+from easydict import EasyDict
+from ding.config import parallel_transform
+
+fake_cpong_dqn_config = dict(
+ exp_name='fake_cpong_dqn',
+ env=dict(
+ collector_env_num=16,
+ collector_episode_num=2,
+ evaluator_env_num=8,
+ evaluator_episode_num=2,
+ stop_value=20,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=3,
+ encoder_hidden_size_list=[128, 128, 256],
+ ),
+ nstep=1,
+ discount_factor=0.99,
+ learn=dict(
+ batch_size=16,
+ learning_rate=0.001,
+ learner=dict(
+ learner_num=1,
+ send_policy_freq=1,
+ ),
+ ),
+ collect=dict(
+ n_sample=16,
+ collector=dict(
+ collector_num=2,
+ update_policy_second=5,
+ ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=5, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=100000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=10000,
+ enable_track_used_data=False,
+ ),
+ commander=dict(
+ collector_task_space=2,
+ learner_task_space=1,
+ eval_interval=5,
+ league=dict(),
+ ),
+ ),
+ )
+)
+fake_cpong_dqn_config = EasyDict(fake_cpong_dqn_config)
+main_config = fake_cpong_dqn_config
+
+fake_cpong_dqn_create_config = dict(
+ env=dict(
+ import_names=['ding.worker.collector.tests.test_marine_parallel_collector'],
+ type='fake_competitive_rl',
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='dqn_command'),
+ learner=dict(type='base', import_names=['ding.worker.learner.base_learner']),
+ collector=dict(
+ type='marine',
+ import_names=['ding.worker.collector.marine_parallel_collector'],
+ ),
+ commander=dict(
+ type='one_vs_one',
+ import_names=['ding.worker.coordinator.one_vs_one_parallel_commander'],
+ ),
+ comm_learner=dict(
+ type='flask_fs',
+ import_names=['ding.worker.learner.comm.flask_fs_learner'],
+ ),
+ comm_collector=dict(
+ type='flask_fs',
+ import_names=['ding.worker.collector.comm.flask_fs_collector'],
+ ),
+)
+fake_cpong_dqn_create_config = EasyDict(fake_cpong_dqn_create_config)
+create_config = fake_cpong_dqn_create_config
+
+fake_cpong_dqn_system_config = dict(
+ coordinator=dict(),
+ path_data='./data',
+ path_policy='./policy',
+ communication_mode='auto',
+ learner_gpu_num=0,
+)
+fake_cpong_dqn_system_config = EasyDict(fake_cpong_dqn_system_config)
+system_config = fake_cpong_dqn_system_config
diff --git a/DI-engine/ding/worker/collector/tests/speed_test/__init__.py b/DI-engine/ding/worker/collector/tests/speed_test/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/ding/worker/collector/tests/speed_test/fake_env.py b/DI-engine/ding/worker/collector/tests/speed_test/fake_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..731e990e3455010bb312fb589a91388e6f937140
--- /dev/null
+++ b/DI-engine/ding/worker/collector/tests/speed_test/fake_env.py
@@ -0,0 +1,86 @@
+from typing import Any, List, Union, Optional
+import time
+import gym
+import numpy as np
+
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.torch_utils import to_ndarray
+
+from ding.worker.collector.tests.speed_test.utils import random_change
+
+global env_sum
+env_sum = 0
+
+
+def env_sleep(duration):
+ time.sleep(duration)
+ global env_sum
+ env_sum += duration
+
+
+class FakeEnv(BaseEnv):
+
+ def __init__(self, cfg: dict) -> None:
+ self._obs_dim = cfg.get('obs_dim', 4)
+ self._action_dim = cfg.get('action_dim', 2)
+ self._episode_step_base = cfg.get('episode_step', 200)
+ self._reset_time = cfg.get('reset_time', 0.)
+ self._step_time = cfg.get('step_time', 0.)
+ self.reset()
+ # gym attribute
+ self.metadata = {'render.modes': ['human', 'rgb_array'], 'video.frames_per_second': 1}
+ self._observation_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(self._obs_dim, ), dtype=np.float32)
+ self._action_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(self._action_dim, ), dtype=np.float32)
+ self._reward_space = gym.spaces.Box(low=-1.0, high=1.0, shape=(1, ), dtype=np.float32)
+ self._init_flag = True
+
+ def reset(self) -> np.ndarray:
+ if hasattr(self, '_seed'):
+ self.seed()
+ self._episode_step = int(random_change(self._episode_step_base))
+ env_sleep(random_change(self._reset_time))
+ self._step_count = 0
+ self._eval_episode_return = 0.
+ obs = np.random.randn(self._obs_dim).astype(np.float32)
+ return obs
+
+ def close(self) -> None:
+ self._init_flag = False
+
+ def seed(self, seed: Optional[int] = None) -> None:
+ if seed is not None:
+ self._seed = seed
+ np.random.seed(self._seed)
+
+ def step(self, action: np.ndarray) -> BaseEnvTimestep:
+ env_sleep(random_change(self._step_time))
+ self._step_count += 1
+ obs = np.random.randn(self._obs_dim).astype(np.float32)
+ rew = np.random.randint(2)
+ done = True if self._step_count == self._episode_step else False
+ info = {}
+ self._eval_episode_return += rew
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+ rew = to_ndarray([rew]) # to shape (1,)
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def __repr__(self) -> str:
+ return "DI-engine Fake Env for collector profile test"
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ if replay_path is None:
+ replay_path = './video'
+ self._replay_path = replay_path
diff --git a/DI-engine/ding/worker/collector/tests/speed_test/fake_policy.py b/DI-engine/ding/worker/collector/tests/speed_test/fake_policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..b36b7bdf320144e30d1f5328ecc8e965af727d63
--- /dev/null
+++ b/DI-engine/ding/worker/collector/tests/speed_test/fake_policy.py
@@ -0,0 +1,96 @@
+from collections import namedtuple, deque
+from typing import Optional, List, Dict, Any, Tuple, Union
+import torch
+from easydict import EasyDict
+import time
+
+from ding.model import create_model
+from ding.utils import import_module, allreduce, broadcast, get_rank, POLICY_REGISTRY
+from ding.utils.data import default_collate, default_decollate
+from ding.policy import Policy
+from ding.rl_utils import get_train_sample
+
+from ding.worker.collector.tests.speed_test.utils import random_change
+
+
+class FakePolicy(Policy):
+ config = dict(
+ cuda=False,
+ on_policy=False,
+ forward_time=0.002,
+ learn=dict(),
+ collect=dict(
+ n_sample=80,
+ unroll_len=1,
+ collector=dict(collect_print_freq=1000000),
+ ),
+ eval=dict(),
+ other=dict(replay_buffer=dict(replay_buffer_size=10000, ), ),
+ )
+
+ def __init__(
+ self,
+ cfg: dict,
+ model: Optional[Union[type, torch.nn.Module]] = None,
+ enable_field: Optional[List[str]] = None
+ ) -> None:
+ self._cfg = cfg
+ self._cuda = cfg.cuda and torch.cuda.is_available()
+ self._init_collect()
+ self._forward_time = cfg.forward_time
+ self._on_policy = cfg.on_policy
+ self.policy_sum = 0
+ self.policy_times = 0
+
+ def policy_sleep(self, duration):
+ time.sleep(duration)
+ self.policy_sum += duration
+ self.policy_times += 1
+
+ def _init_learn(self) -> None:
+ pass
+
+ def _init_collect(self) -> None:
+ self._unroll_len = 1
+
+ def _init_eval(self) -> None:
+ pass
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ pass
+
+ def _create_model(self, cfg: dict, model: Optional[Union[type, torch.nn.Module]] = None) -> torch.nn.Module:
+ pass
+
+ def _forward_eval(self, data_id: List[int], data: dict) -> dict:
+ pass
+
+ def _forward_learn(self, data_id: List[int], data: dict) -> dict:
+ pass
+
+ # *************************************** collect function ************************************
+
+ def _forward_collect(self, data: dict, **kwargs) -> dict:
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ self.policy_sleep(random_change(self._forward_time))
+ output = {'action': torch.ones(data.shape[0], 1).long()}
+ output = default_decollate(output)
+ output = {i: d for i, d in zip(data_id, output)}
+ return output
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'action': model_output['action'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _get_train_sample(self, data: deque) -> Union[None, List[Any]]:
+ return get_train_sample(data, self._unroll_len)
+
+ def _reset_collect(self, data_id: Optional[List[int]] = None) -> None:
+ pass
diff --git a/DI-engine/ding/worker/collector/tests/speed_test/test_collector_profile.py b/DI-engine/ding/worker/collector/tests/speed_test/test_collector_profile.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f1a268c402286b9e29aa10c5026a5e105197bd3
--- /dev/null
+++ b/DI-engine/ding/worker/collector/tests/speed_test/test_collector_profile.py
@@ -0,0 +1,207 @@
+from ditk import logging
+import time
+import copy
+import pytest
+import numpy as np
+import gym
+from easydict import EasyDict
+from functools import partial
+
+from ding.worker import SampleSerialCollector, NaiveReplayBuffer
+from ding.envs import get_vec_env_setting, create_env_manager, AsyncSubprocessEnvManager, SyncSubprocessEnvManager,\
+ BaseEnvManager, get_env_manager_cls, DingEnvWrapper
+from ding.utils import deep_merge_dicts, set_pkg_seed, pretty_print
+
+from ding.worker.collector.tests.speed_test.fake_policy import FakePolicy
+from ding.worker.collector.tests.speed_test.fake_env import FakeEnv
+
+n_sample = 80
+env_policy_cfg_dict = dict(
+ # Small env and policy, such as Atari/Mujoco
+ small=dict(
+ size="small",
+ env=dict(
+ collector_env_num=8,
+ obs_dim=64,
+ action_dim=2,
+ episode_step=500,
+ reset_time=0.1,
+ step_time=0.005,
+ manager=dict(),
+ ),
+ policy=dict(forward_time=0.004),
+ ),
+ # Middle env and policy, such as Carla/Sumo/Vizdoom
+ middle=dict(
+ size="middle",
+ env=dict(
+ collector_env_num=8,
+ obs_dim=int(3e2), # int(3e3),
+ action_dim=2,
+ episode_step=500,
+ reset_time=0.5,
+ step_time=0.01,
+ manager=dict(),
+ ),
+ policy=dict(forward_time=0.008),
+ ),
+ # Big env and policy, such as SC2 full game
+ big=dict(
+ size="big",
+ env=dict(
+ collector_env_num=8,
+ obs_dim=int(3e3), # int(3e6),
+ action_dim=2,
+ episode_step=500,
+ reset_time=2,
+ step_time=0.1,
+ manager=dict(),
+ ),
+ policy=dict(forward_time=0.02)
+ ),
+ # cartpole env
+ cartpole=dict(
+ size='cartpole',
+ env=dict(collector_env_num=8, stop_value=195, reset_time=0.5, manager=dict(reset_inplace=True, )),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ dueling=True,
+ ),
+ collect=dict(
+ n_sample=n_sample,
+ collector=dict(collect_print_freq=1000000),
+ ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=20000, ),
+ ),
+ ),
+ )
+)
+
+
+def wrapped_cartpole_env():
+ return DingEnvWrapper(gym.make('CartPole-v0'))
+
+
+def wrapped_gym_cartpole_env():
+ return gym.make('CartPole-v0')
+
+
+# SLOW MODE: used in normal test
+# - Repeat 3 times; Collect 300 times;
+# - Test on small + middle + big env
+# - Test on base + async_subprocess + sync_subprocess env manager
+# - Test with reset_ratio = 1 and 5.
+# FAST MODE: used in CI benchmark test
+# - Only once (No repeat); Collect 50 times;
+# - Test on small env
+# - Test on base + sync_subprocess env manager
+# - Test with reset_ratio = 1.
+FAST_MODE = True
+if FAST_MODE:
+ # Note: 'base' takes approximately 6 times longer than 'subprocess'
+ test_env_manager_list = ['base', 'subprocess', 'gym_vector']
+ test_env_policy_cfg_dict = {'small': env_policy_cfg_dict['small'], 'cartpole': env_policy_cfg_dict['cartpole']}
+ env_reset_ratio_list = [1]
+ repeat_times_per_test = 1
+ collect_times_per_repeat = 50
+else:
+ test_env_manager_list = ['base', 'subprocess', 'sync_subprocess', 'gym_vector']
+ test_env_policy_cfg_dict = env_policy_cfg_dict
+ env_reset_ratio_list = [1, 5]
+ repeat_times_per_test = 3
+ collect_times_per_repeat = 300
+
+
+def compare_test(cfg: EasyDict, seed: int, test_name: str) -> None:
+ duration_list = []
+ total_collected_sample = n_sample * collect_times_per_repeat
+ for i in range(repeat_times_per_test):
+ # create collector_env
+ collector_env_cfg = copy.deepcopy(cfg.env)
+ collector_env_num = collector_env_cfg.collector_env_num
+ if cfg.size == 'cartpole':
+ if cfg.env.manager.type == 'gym_vector':
+ collector_env_fns = [wrapped_gym_cartpole_env for _ in range(collector_env_num)]
+ else:
+ collector_env_fns = [wrapped_cartpole_env for _ in range(collector_env_num)]
+ else:
+ collector_env_fns = [partial(FakeEnv, cfg=collector_env_cfg) for _ in range(collector_env_num)]
+
+ collector_env = create_env_manager(cfg.env.manager, collector_env_fns)
+ collector_env.seed(seed)
+ # create policy
+ policy = FakePolicy(cfg.policy)
+
+ # create collector and buffer
+ collector = SampleSerialCollector(cfg.policy.collect.collector, collector_env, policy.collect_mode)
+ replay_buffer = NaiveReplayBuffer(cfg.policy.other.replay_buffer)
+
+ # collect test
+
+ t1 = time.time()
+ for i in range(collect_times_per_repeat):
+ new_data = collector.collect()
+ assert len(new_data) == n_sample
+ replay_buffer.push(new_data, cur_collector_envstep=i * n_sample)
+ duration_list.append(time.time() - t1)
+
+ # close and release
+ collector.close()
+ replay_buffer.close()
+ del policy
+ del collector
+ del replay_buffer
+
+ fps = [total_collected_sample / duration for duration in duration_list]
+
+ template = "Test Name: {}\t Test Result: Avg FPS(env frame per second): {:.3f}±{:.3f} frame/s"
+ print(template.format(test_name, np.mean(fps), np.std(fps)))
+
+
+# TODO(nyz) fix CI bug when py==3.8.15
+@pytest.mark.tmp
+def test_collector_profile():
+ # ignore them for clear log
+ collector_log = logging.getLogger('collector_logger')
+ collector_log.disabled = True
+ buffer_log = logging.getLogger('buffer_logger')
+ buffer_log.disabled = True
+
+ seed = 0
+ set_pkg_seed(seed, use_cuda=False)
+ print("=========== test_collector_profile ===========")
+
+ for cfg_name, env_policy_cfg in test_env_policy_cfg_dict.items():
+ for env_manager_type in test_env_manager_list:
+ for env_reset_ratio in env_reset_ratio_list:
+
+ test_name = '{}-{}-reset{}'.format(cfg_name, env_manager_type, env_reset_ratio)
+ copy_cfg = EasyDict(copy.deepcopy(env_policy_cfg))
+ env_manager_cfg = EasyDict({'type': env_manager_type})
+
+ # modify args inplace
+ copy_cfg.policy = deep_merge_dicts(FakePolicy.default_config(), copy_cfg.policy)
+ copy_cfg.policy.collect.collector = deep_merge_dicts(
+ SampleSerialCollector.default_config(), copy_cfg.policy.collect.collector
+ )
+ copy_cfg.policy.collect.collector.n_sample = n_sample
+ copy_cfg.policy.other.replay_buffer = deep_merge_dicts(
+ NaiveReplayBuffer.default_config(), copy_cfg.policy.other.replay_buffer
+ )
+ copy_cfg.env.reset_time *= env_reset_ratio
+ manager_cfg = get_env_manager_cls(env_manager_cfg).default_config()
+ copy_cfg.env.manager = deep_merge_dicts(manager_cfg, copy_cfg.env.manager)
+ copy_cfg.env.manager.type = env_manager_type
+
+ compare_test(copy_cfg, seed, test_name)
diff --git a/DI-engine/ding/worker/collector/tests/speed_test/utils.py b/DI-engine/ding/worker/collector/tests/speed_test/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e13e1c072c1b678e57dd977aa5139f6fe381e22e
--- /dev/null
+++ b/DI-engine/ding/worker/collector/tests/speed_test/utils.py
@@ -0,0 +1,5 @@
+import numpy as np
+
+
+def random_change(number):
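+    # Randomly scale the input by a factor in [0.7, 1.3] to simulate timing jitter in the fake env/policy.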
+ return number * (1 + (np.random.random() - 0.5) * 0.6)
diff --git a/DI-engine/ding/worker/collector/tests/test_base_serial_collector.py b/DI-engine/ding/worker/collector/tests/test_base_serial_collector.py
new file mode 100644
index 0000000000000000000000000000000000000000..475a6a4b17489d66322ec4076c18b9c6becf876e
--- /dev/null
+++ b/DI-engine/ding/worker/collector/tests/test_base_serial_collector.py
@@ -0,0 +1,42 @@
+import pytest
+import numpy as np
+import torch
+from ding.worker.collector.base_serial_collector import to_tensor_transitions
+
+
+def get_transition():
+ return {
+ 'obs': np.random.random((2, 3)),
+ 'action': np.random.randint(0, 6, size=(1, )),
+ 'reward': np.random.random((1, )),
+ 'done': False,
+ 'next_obs': np.random.random((2, 3)),
+ }
+
+
+@pytest.mark.unittest
+def test_to_tensor_transitions():
+ # test case when shallow copy is True
+ transition_list = [get_transition() for _ in range(4)]
+ tensor_list = to_tensor_transitions(transition_list, shallow_copy_next_obs=True)
+ for i in range(len(tensor_list)):
+ tensor = tensor_list[i]
+ assert isinstance(tensor['obs'], torch.Tensor)
+ assert isinstance(tensor['action'], torch.Tensor), type(tensor['action'])
+ assert isinstance(tensor['reward'], torch.Tensor)
+ assert isinstance(tensor['done'], bool)
+ assert 'next_obs' in tensor
+ if i < len(tensor_list) - 1:
+ assert id(tensor['next_obs']) == id(tensor_list[i + 1]['obs'])
+ # test case when shallow copy is False
+ transition_list = [get_transition() for _ in range(4)]
+ tensor_list = to_tensor_transitions(transition_list, shallow_copy_next_obs=False)
+ for i in range(len(tensor_list)):
+ tensor = tensor_list[i]
+ assert isinstance(tensor['obs'], torch.Tensor)
+ assert isinstance(tensor['action'], torch.Tensor)
+ assert isinstance(tensor['reward'], torch.Tensor)
+ assert isinstance(tensor['done'], bool)
+ assert 'next_obs' in tensor
+ if i < len(tensor_list) - 1:
+ assert id(tensor['next_obs']) != id(tensor_list[i + 1]['obs'])
diff --git a/DI-engine/ding/worker/collector/tests/test_episode_serial_collector.py b/DI-engine/ding/worker/collector/tests/test_episode_serial_collector.py
new file mode 100644
index 0000000000000000000000000000000000000000..2586e84ef4ea6664f1a526e892feb4988f2c5931
--- /dev/null
+++ b/DI-engine/ding/worker/collector/tests/test_episode_serial_collector.py
@@ -0,0 +1,56 @@
+import pytest
+from ding.worker import EpisodeSerialCollector
+from ding.envs import BaseEnvManager, SyncSubprocessEnvManager, AsyncSubprocessEnvManager
+from ding.policy import DQNPolicy
+from ding.model import DQN
+from dizoo.classic_control.cartpole.envs import CartPoleEnv
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('env_manager_type', [BaseEnvManager, SyncSubprocessEnvManager])
+def test_collect(env_manager_type):
+ env = env_manager_type([lambda: CartPoleEnv({}) for _ in range(8)], env_manager_type.default_config())
+ env.seed(0)
+ model = DQN(obs_shape=4, action_shape=1)
+ policy = DQNPolicy(DQNPolicy.default_config(), model=model).collect_mode
+ collector = EpisodeSerialCollector(EpisodeSerialCollector.default_config(), env, policy)
+
+ collected_episode = collector.collect(
+ n_episode=18, train_iter=collector._collect_print_freq, policy_kwargs={'eps': 0.5}
+ )
+ assert len(collected_episode) == 18
+ assert all([e[-1]['done'] for e in collected_episode])
+ assert all([len(c) == 0 for c in collector._traj_buffer.values()])
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('env_manager_type', [BaseEnvManager, SyncSubprocessEnvManager])
+def test_abnormal_env_step(env_manager_type):
+
+ class AbnormalEnv(CartPoleEnv):
+
+ def step(self, action):
+ timestep = super().step(action)
+ new_info = timestep.info
+ if not hasattr(self, 'count'):
+ self.count = 0
+ if self.count <= 3:
+ new_info['abnormal'] = True
+ new_info['count'] = self.count
+ self.count += 1
+            timestep = timestep._replace(info=new_info)
+ return timestep
+
+ env = env_manager_type(
+ [lambda: CartPoleEnv({}) for _ in range(3)] + [lambda: AbnormalEnv({})], env_manager_type.default_config()
+ )
+ env.seed(0)
+ model = DQN(obs_shape=4, action_shape=1)
+ policy = DQNPolicy(DQNPolicy.default_config(), model=model).collect_mode
+ collector = EpisodeSerialCollector(EpisodeSerialCollector.default_config(), env, policy)
+
+ collected_episode = collector.collect(
+ n_episode=8, train_iter=collector._collect_print_freq, policy_kwargs={'eps': 0.5}
+ )
+ assert len(collected_episode) == 8
+ assert len(env.ready_obs) == 4
diff --git a/DI-engine/ding/worker/collector/tests/test_marine_parallel_collector.py b/DI-engine/ding/worker/collector/tests/test_marine_parallel_collector.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4253799e7b925cc9bfb3b3bf68f38995b8e02fc
--- /dev/null
+++ b/DI-engine/ding/worker/collector/tests/test_marine_parallel_collector.py
@@ -0,0 +1,76 @@
+from typing import Any, Union, List
+import copy
+import torch
+import numpy as np
+import pytest
+import os
+import gym
+
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.utils import ENV_REGISTRY
+from ding.entry import parallel_pipeline
+from .fake_cpong_dqn_config import fake_cpong_dqn_config, fake_cpong_dqn_create_config, fake_cpong_dqn_system_config
+
+
+@ENV_REGISTRY.register('fake_competitive_rl')
+class FakeCompetitiveRlEnv(BaseEnv):
+
+ def __init__(self, cfg: dict) -> None:
+ self._cfg = cfg
+ self._is_evaluator = cfg.is_evaluator
+ self.num_agents = 2
+ self.observation_space = gym.spaces.Box(low=0, high=256, shape=(2, 4, 84, 84), dtype=np.int64)
+ self.action_space = gym.spaces.Box(low=0, high=3, shape=(1, ), dtype=np.float32)
+ self.reward_space = gym.spaces.Box(
+ low=np.float32("-inf"), high=np.float32("inf"), shape=(1, ), dtype=np.float32
+ )
+
+ def reset(self) -> np.ndarray:
+ self._step_times = 0
+ obs_shape = (4, 84, 84)
+ if not self._is_evaluator:
+ obs_shape = (2, ) + obs_shape
+ obs = np.random.randint(0, 256, obs_shape).astype(np.float32)
+ return obs
+
+ def close(self) -> None:
+ pass
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ pass
+
+ def step(self, action: Union[torch.Tensor, np.ndarray, list]) -> BaseEnvTimestep:
+ obs_shape = (4, 84, 84)
+ if not self._is_evaluator:
+ obs_shape = (2, ) + obs_shape
+ obs = np.random.randint(0, 256, obs_shape).astype(np.float32)
+ rew = np.array([1.]) if self._is_evaluator else np.array([1., -1.])
+ done = False if self._step_times < 20 else True
+ info = {}
+ if done:
+ info['eval_episode_return'] = np.array([21.]) if self._is_evaluator else np.array([5., -5.])
+ self._step_times += 1
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def __repr__(self) -> str:
+        return "Fake Competitive RL Env({})".format(self._cfg.env_id)
+
+ @staticmethod
+ def create_collector_env_cfg(cfg: dict) -> List[dict]:
+ collector_cfg = copy.deepcopy(cfg)
+ collector_env_num = collector_cfg.pop('collector_env_num', 1)
+ collector_cfg.is_evaluator = False
+ return [collector_cfg for _ in range(collector_env_num)]
+
+ @staticmethod
+ def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
+ evaluator_cfg = copy.deepcopy(cfg)
+ evaluator_env_num = evaluator_cfg.pop('evaluator_env_num', 1)
+ evaluator_cfg.is_evaluator = True
+ return [evaluator_cfg for _ in range(evaluator_env_num)]
+
+
+@pytest.mark.unittest
+def test_1v1_collector():
+ parallel_pipeline([fake_cpong_dqn_config, fake_cpong_dqn_create_config, fake_cpong_dqn_system_config], 0)
+ os.popen("rm -rf data log policy ckpt* total_config.py")
diff --git a/DI-engine/ding/worker/collector/tests/test_metric_serial_evaluator.py b/DI-engine/ding/worker/collector/tests/test_metric_serial_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..2652f0ace4bda3c1f57d938af6e75a5dc9ed8787
--- /dev/null
+++ b/DI-engine/ding/worker/collector/tests/test_metric_serial_evaluator.py
@@ -0,0 +1,102 @@
+from ding.worker import MetricSerialEvaluator, IMetric
+from torch.utils.data import DataLoader
+import pytest
+import torch.utils.data as data
+
+import torch.nn as nn
+from ding.torch_utils import to_tensor
+import torch
+from easydict import EasyDict
+from ding.worker.collector.tests.fake_cls_policy import fake_policy
+
+fake_cls_config = dict(
+ exp_name='fake_config_for_test_metric_serial_evaluator',
+ policy=dict(
+ on_policy=False,
+ cuda=False,
+ eval=dict(batch_size=1, evaluator=dict(eval_freq=1, multi_gpu=False, stop_value=dict(acc=75.0))),
+ ),
+ env=dict(),
+)
+
+cfg = EasyDict(fake_cls_config)
+
+
+class fake_eval_dataset(data.Dataset):
+
+ def __init__(self) -> None:
+        self.data = [i for i in range(5)]  # [0, 1, 2, 3, 4]
+        self.target = [2 * i + 1 for i in range(5)]  # [1, 3, 5, 7, 9]
+
+ def __len__(self) -> int:
+ return len(self.data)
+
+ def __getitem__(self, index: int):
+ data = self.data[index]
+ target = self.target[index]
+ return data, target
+
+
+class fake_model(nn.Module): # y = 2*x+1
+
+ def __init__(self) -> None:
+ super(fake_model, self).__init__()
+ self.linear = nn.Linear(1, 1)
+ nn.init.constant_(self.linear.bias, 1)
+ nn.init.constant_(self.linear.weight, 2)
+
+ def forward(self, x):
+ x = to_tensor(x).float()
+ return self.linear(x)
+
+
+class fake_ClassificationMetric(IMetric):
+
+ @staticmethod
+ def accuracy(inputs: torch.Tensor, label: torch.Tensor) -> dict:
+ batch_size = label.size(0)
+ correct = inputs.eq(label)
+ return {'acc': correct.reshape(-1).float().sum(0) * 100. / batch_size}
+
+ def eval(self, inputs: torch.Tensor, label: torch.Tensor) -> dict:
+ output = self.accuracy(inputs, label)
+ for k in output:
+ output[k] = output[k].item()
+ return output
+
+ def reduce_mean(self, inputs) -> dict:
+ L = len(inputs)
+ output = {}
+ for k in inputs[0].keys():
+ output[k] = sum([t[k] for t in inputs]) / L
+ return output
+
+ def gt(self, metric1: dict, metric2: dict) -> bool:
+ if metric2 is None:
+ return True
+ for k in metric1:
+ if metric1[k] < metric2[k]:
+ return False
+ return True
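+ # The three methods above illustrate the IMetric interface the evaluator relies on: ``eval``
+ # scores one batch, ``reduce_mean`` averages per-batch results, and ``gt`` decides whether one
+ # result is at least as good as another (or the configured stop value). (Comment added for clarity.)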
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('cfg', [cfg])
+def test_evaluator(cfg):
+ model = fake_model()
+ eval_dataset = fake_eval_dataset()
+ eval_dataloader = DataLoader(eval_dataset, cfg.policy.eval.batch_size, num_workers=2)
+ policy = fake_policy(cfg.policy, model=model, enable_field=['eval'])
+ eval_metric = fake_ClassificationMetric()
+ evaluator = MetricSerialEvaluator(
+ cfg.policy.eval.evaluator, [eval_dataloader, eval_metric], policy.eval_mode, exp_name=cfg.exp_name
+ )
+
+ cur_iter = 0
+ assert evaluator.should_eval(cur_iter)
+
+ evaluator._last_eval_iter = 0
+ cur_iter = 1
+ stop, reward = evaluator.eval(None, cur_iter, 0)
+ assert stop
+ assert reward['acc'] == 100
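+# Why this passes (comment added for clarity): the linear layer is hard-initialized to exactly
+# y = 2x + 1, so predictions match the targets, the reduced accuracy is 100, and comparing it
+# against stop_value (acc=75.0) makes the evaluator report ``stop``.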
diff --git a/DI-engine/ding/worker/collector/tests/test_sample_serial_collector.py b/DI-engine/ding/worker/collector/tests/test_sample_serial_collector.py
new file mode 100644
index 0000000000000000000000000000000000000000..dbc0994f10726abd4279a34288628cdb8a1e3495
--- /dev/null
+++ b/DI-engine/ding/worker/collector/tests/test_sample_serial_collector.py
@@ -0,0 +1,39 @@
+import pytest
+from ding.worker import SampleSerialCollector
+from ding.envs import BaseEnvManager, SyncSubprocessEnvManager, AsyncSubprocessEnvManager
+from ding.policy import DQNPolicy
+from ding.model import DQN
+from dizoo.classic_control.cartpole.envs import CartPoleEnv
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('env_manager_type', [BaseEnvManager, SyncSubprocessEnvManager])
+def test_collect(env_manager_type):
+ env = env_manager_type([lambda: CartPoleEnv({}) for _ in range(8)], env_manager_type.default_config())
+ env.seed(0)
+ model = DQN(obs_shape=4, action_shape=1)
+ policy = DQNPolicy(DQNPolicy.default_config(), model=model).collect_mode
+ collector = SampleSerialCollector(SampleSerialCollector.default_config(), env, policy)
+
+ collected_sample = collector.collect(
+ n_sample=1000,
+ train_iter=collector._collect_print_freq,
+ record_random_collect=False,
+ policy_kwargs={'eps': 0.5}
+ )
+ assert len(collected_sample) == 1000
+
+
+@pytest.mark.unittest
+@pytest.mark.parametrize('env_manager_type', [BaseEnvManager, SyncSubprocessEnvManager])
+def test_random_collect(env_manager_type):
+ env = env_manager_type([lambda: CartPoleEnv({}) for _ in range(8)], env_manager_type.default_config())
+ env.seed(0)
+ model = DQN(obs_shape=4, action_shape=1)
+ policy = DQNPolicy(DQNPolicy.default_config(), model=model).collect_mode
+ collector = SampleSerialCollector(SampleSerialCollector.default_config(), env, policy)
+
+ collected_sample = collector.collect(
+ n_sample=1000, train_iter=collector._collect_print_freq, record_random_collect=True, policy_kwargs={'eps': 0.5}
+ )
+ assert len(collected_sample) == 1000
diff --git a/DI-engine/ding/worker/collector/zergling_parallel_collector.py b/DI-engine/ding/worker/collector/zergling_parallel_collector.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c1da9c41e4e47990deffba3af9a2f1b9ec24279
--- /dev/null
+++ b/DI-engine/ding/worker/collector/zergling_parallel_collector.py
@@ -0,0 +1,296 @@
+from typing import Dict, Any, List
+import time
+import uuid
+from collections import namedtuple
+from threading import Thread
+from functools import partial
+
+import numpy as np
+import torch
+from easydict import EasyDict
+
+from ding.policy import create_policy, Policy
+from ding.envs import get_vec_env_setting, create_env_manager, BaseEnvManager
+from ding.utils import get_data_compressor, pretty_print, PARALLEL_COLLECTOR_REGISTRY
+from .base_parallel_collector import BaseParallelCollector
+from .base_serial_collector import CachePool, TrajBuffer
+
+INF = float("inf")
+
+
+@PARALLEL_COLLECTOR_REGISTRY.register('zergling')
+class ZerglingParallelCollector(BaseParallelCollector):
+ """
+ Feature:
+ - one policy, many envs
+ - async envs(step + reset)
+ - batch network eval
+ - different episode length env
+ - periodic policy update
+ - metadata + stepdata
+ """
+ config = dict(
+ print_freq=5,
+ compressor='lz4',
+ update_policy_second=3,
+ # The following keys are set by the commander
+ # env
+ # policy
+ # collect_setting
+ # eval_flag
+ # policy_update_path
+ )
+
+ # override
+ def __init__(self, cfg: dict) -> None:
+ super().__init__(cfg)
+ self._update_policy_thread = Thread(
+ target=self._update_policy_periodically, args=(), name='update_policy', daemon=True
+ )
+ self._start_time = time.time()
+ self._compressor = get_data_compressor(self._cfg.compressor)
+
+ # create env
+ self._env_cfg = self._cfg.env
+ env_manager = self._setup_env_manager(self._env_cfg)
+ self.env_manager = env_manager
+
+ # create policy
+ if self._eval_flag:
+ policy = create_policy(self._cfg.policy, enable_field=['eval']).eval_mode
+ else:
+ policy = create_policy(self._cfg.policy, enable_field=['collect']).collect_mode
+ self.policy = policy
+
+ self._episode_result = [[] for k in range(self._env_num)]
+ self._obs_pool = CachePool('obs', self._env_num)
+ self._policy_output_pool = CachePool('policy_output', self._env_num)
+ self._traj_buffer = {env_id: TrajBuffer(self._traj_len) for env_id in range(self._env_num)}
+ self._total_step = 0
+ self._total_sample = 0
+ self._total_episode = 0
+
+ @property
+ def policy(self) -> Policy:
+ return self._policy
+
+ # override
+ @policy.setter
+ def policy(self, _policy: Policy) -> None:
+ self._policy = _policy
+ self._policy_cfg = self._policy.get_attribute('cfg')
+ self._n_sample = _policy.get_attribute('n_sample')
+ self._n_episode = _policy.get_attribute('n_episode')
+ assert not all(
+ [t is None for t in [self._n_sample, self._n_episode]]
+ ), "n_episode/n_sample in policy cfg can't be not None at the same time"
+ # TODO(nyz) the same definition of traj_len in serial and parallel
+ if self._n_episode is not None:
+ self._traj_len = INF
+ elif self._n_sample is not None:
+ self._traj_len = self._n_sample
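+ # Added note: with n_sample set (and n_episode unset) each env's TrajBuffer is flushed into train
+ # samples every ``n_sample`` transitions, while with n_episode set traj_len is INF and the buffer
+ # is only flushed at episode end (see _process_timestep below).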
+
+ @property
+ def env_manager(self) -> BaseEnvManager:
+ return self._env_manager
+
+ # override
+ @env_manager.setter
+ def env_manager(self, _env_manager: BaseEnvManager) -> None:
+ self._env_manager = _env_manager
+ self._env_manager.launch()
+ self._env_num = self._env_manager.env_num
+ self._predefined_episode_count = self._env_num * self._env_manager._episode_num
+
+ def _setup_env_manager(self, cfg: EasyDict) -> BaseEnvManager:
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg)
+ if self._eval_flag:
+ env_cfg = evaluator_env_cfg
+ else:
+ env_cfg = collector_env_cfg
+ env_manager = create_env_manager(cfg.manager, [partial(env_fn, cfg=c) for c in env_cfg])
+ return env_manager
+
+ def _start_thread(self) -> None:
+ # the evaluator doesn't need to update the policy periodically; it only updates the policy when it starts
+ if not self._eval_flag:
+ self._update_policy_thread.start()
+
+ def _join_thread(self) -> None:
+ if not self._eval_flag:
+ self._update_policy_thread.join()
+ del self._update_policy_thread
+
+ # override
+ def close(self) -> None:
+ if self._end_flag:
+ return
+ self._end_flag = True
+ time.sleep(1)
+ if hasattr(self, '_env_manager'):
+ self._env_manager.close()
+ self._join_thread()
+
+ # override
+ def _policy_inference(self, obs: Dict[int, Any]) -> Dict[int, Any]:
+ self._obs_pool.update(obs)
+ if self._eval_flag:
+ policy_output = self._policy.forward(obs)
+ else:
+ policy_output = self._policy.forward(obs, **self._cfg.collect_setting)
+ self._policy_output_pool.update(policy_output)
+ actions = {env_id: output['action'] for env_id, output in policy_output.items()}
+ return actions
+
+ # override
+ def _env_step(self, actions: Dict[int, Any]) -> Dict[int, Any]:
+ return self._env_manager.step(actions)
+
+ # override
+ def _process_timestep(self, timestep: Dict[int, namedtuple]) -> None:
+ send_data_time = []
+ for env_id, t in timestep.items():
+ if t.info.get('abnormal', False):
+ # if there is an abnormal timestep, reset all the related variables (this env has already been reset)
+ self._traj_buffer[env_id].clear()
+ self._obs_pool.reset(env_id)
+ self._policy_output_pool.reset(env_id)
+ self._policy.reset([env_id])
+ continue
+ self._total_step += 1
+ if t.done: # must be executed before send_metadata
+ self._total_episode += 1
+ if not self._eval_flag:
+ transition = self._policy.process_transition(
+ self._obs_pool[env_id], self._policy_output_pool[env_id], t
+ )
+ self._traj_buffer[env_id].append(transition)
+ if (not self._eval_flag) and (t.done or len(self._traj_buffer[env_id]) == self._traj_len):
+ train_sample = self._policy.get_train_sample(self._traj_buffer[env_id])
+ for s in train_sample:
+ s = self._compressor(s)
+ self._total_sample += 1
+ with self._timer:
+ metadata = self._get_metadata(s, env_id)
+ object_ref = self.send_stepdata(metadata['data_id'], s)
+ if object_ref:
+ metadata['object_ref'] = object_ref
+ self.send_metadata(metadata)
+ send_data_time.append(self._timer.value)
+ self._traj_buffer[env_id].clear()
+ if t.done:
+ # env reset is done by env_manager automatically
+ self._obs_pool.reset(env_id)
+ self._policy_output_pool.reset(env_id)
+ self._policy.reset([env_id])
+ reward = t.info['eval_episode_return']
+ if isinstance(reward, torch.Tensor):
+ reward = reward.item()
+ self._episode_result[env_id].append(reward)
+ self.debug(
+ "env {} finish episode, final reward: {}, collected episode {}".format(
+ env_id, reward, len(self._episode_result[env_id])
+ )
+ )
+ self.debug(
+ "send {} train sample with average time: {:.6f}".format(
+ len(send_data_time),
+ sum(send_data_time) / (1e-6 + len(send_data_time))
+ )
+ )
+ dones = [t.done for t in timestep.values()]
+ if any(dones):
+ collector_info = self._get_collector_info()
+ self.send_metadata(collector_info)
+
+ # override
+ def get_finish_info(self) -> dict:
+ duration = max(time.time() - self._start_time, 1e-8)
+ episode_result = sum(self._episode_result, [])
+ finish_info = {
+ 'eval_flag': self._eval_flag,
+ 'env_num': self._env_num,
+ 'duration': duration,
+ 'train_iter': self._policy_iter,
+ 'collector_done': self._env_manager.done,
+ 'predefined_episode_count': self._predefined_episode_count,
+ 'real_episode_count': self._total_episode,
+ 'step_count': self._total_step,
+ 'sample_count': self._total_sample,
+ 'avg_time_per_episode': duration / max(1, self._total_episode),
+ 'avg_time_per_step': duration / self._total_step,
+ 'avg_time_per_train_sample': duration / max(1, self._total_sample),
+ 'avg_step_per_episode': self._total_step / max(1, self._total_episode),
+ 'avg_sample_per_episode': self._total_sample / max(1, self._total_episode),
+ 'reward_mean': np.mean(episode_result) if len(episode_result) > 0 else 0,
+ 'reward_std': np.std(episode_result) if len(episode_result) > 0 else 0,
+ 'reward_raw': episode_result,
+ 'finish_time': time.time()
+ }
+ if not self._eval_flag:
+ finish_info['collect_setting'] = self._cfg.collect_setting
+ self._logger.info('\nFINISH INFO\n{}'.format(pretty_print(finish_info, direct_print=False)))
+ return finish_info
+
+ # override
+ def _update_policy(self) -> None:
+ path = self._cfg.policy_update_path
+ while True:
+ try:
+ policy_update_info = self.get_policy_update_info(path)
+ break
+ except Exception as e:
+ self.error('Policy update error: {}'.format(e))
+ time.sleep(1)
+ if policy_update_info is None:
+ return
+
+ self._policy_iter = policy_update_info.pop('iter')
+ self._policy.load_state_dict(policy_update_info)
+ self.debug('update policy with {}(iter{}) in {}'.format(path, self._policy_iter, time.time()))
+
+ # ******************************** thread **************************************
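+ # Added note: the thread below polls roughly every 10% of ``update_policy_second`` and reloads the
+ # policy from ``policy_update_path`` once a full interval has elapsed, so the collector keeps using
+ # a reasonably fresh learner checkpoint without blocking data collection.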
+
+ def _update_policy_periodically(self) -> None:
+ last = time.time()
+ while not self._end_flag:
+ cur = time.time()
+ interval = cur - last
+ if interval < self._cfg.update_policy_second:
+ time.sleep(self._cfg.update_policy_second * 0.1)
+ continue
+ else:
+ self._update_policy()
+ last = time.time()
+ time.sleep(0.1)
+
+ def _get_metadata(self, stepdata: List, env_id: int) -> dict:
+ data_id = "env_{}_{}".format(env_id, str(uuid.uuid1()))
+ metadata = {
+ 'eval_flag': self._eval_flag,
+ 'data_id': data_id,
+ 'env_id': env_id,
+ 'policy_iter': self._policy_iter,
+ 'unroll_len': len(stepdata),
+ 'compressor': self._cfg.compressor,
+ 'get_data_time': time.time(),
+ # TODO(nyz) the relationship between traj priority and step priority
+ 'priority': 1.0,
+ 'cur_episode': self._total_episode,
+ 'cur_sample': self._total_sample,
+ 'cur_step': self._total_step,
+ }
+ return metadata
+
+ def _get_collector_info(self) -> dict:
+ return {
+ 'eval_flag': self._eval_flag,
+ 'get_info_time': time.time(),
+ 'collector_done': self._env_manager.done,
+ 'cur_episode': self._total_episode,
+ 'cur_sample': self._total_sample,
+ 'cur_step': self._total_step,
+ }
+
+ def __repr__(self) -> str:
+ return "ZerglingParallelCollector"
diff --git a/DI-engine/ding/worker/coordinator/__init__.py b/DI-engine/ding/worker/coordinator/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..754a1b1bf057a49668f0a86b611650aa605fe114
--- /dev/null
+++ b/DI-engine/ding/worker/coordinator/__init__.py
@@ -0,0 +1,3 @@
+from .base_serial_commander import BaseSerialCommander
+from .base_parallel_commander import create_parallel_commander, get_parallel_commander_cls
+from .coordinator import Coordinator
diff --git a/DI-engine/ding/worker/coordinator/base_parallel_commander.py b/DI-engine/ding/worker/coordinator/base_parallel_commander.py
new file mode 100644
index 0000000000000000000000000000000000000000..31db4d5697fb06a1a3387a7b2f724d7a61ee1140
--- /dev/null
+++ b/DI-engine/ding/worker/coordinator/base_parallel_commander.py
@@ -0,0 +1,200 @@
+from abc import ABC, abstractmethod
+from collections import defaultdict
+from easydict import EasyDict
+import copy
+
+from ding.utils import import_module, COMMANDER_REGISTRY, LimitedSpaceContainer
+
+
+class BaseCommander(ABC):
+ r"""
+ Overview:
+ Base parallel commander abstract class.
+ Interface:
+ get_collector_task, judge_collector_finish, judge_learner_finish
+ """
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ cfg = EasyDict(copy.deepcopy(cls.config))
+ cfg.cfg_type = cls.__name__ + 'Dict'
+ return cfg
+
+ @abstractmethod
+ def get_collector_task(self) -> dict:
+ raise NotImplementedError
+
+ def judge_collector_finish(self, task_id: str, info: dict) -> bool:
+ collector_done = info.get('collector_done', False)
+ if collector_done:
+ return True
+ return False
+
+ def judge_learner_finish(self, task_id: str, info: dict) -> bool:
+ learner_done = info.get('learner_done', False)
+ if learner_done:
+ return True
+ return False
+
+
+@COMMANDER_REGISTRY.register('naive')
+class NaiveCommander(BaseCommander):
+ r"""
+ Overview:
+ A naive implementation of parallel commander.
+ Interface:
+ __init__, get_collector_task, get_learner_task, finish_collector_task, finish_learner_task,
+ notify_fail_collector_task, notify_fail_learner_task, update_learner_info
+ """
+ config = dict(
+ collector_task_space=1,
+ learner_task_space=1,
+ eval_interval=60,
+ )
+
+ def __init__(self, cfg: dict) -> None:
+ r"""
+ Overview:
+ Init the naive commander according to config
+ Arguments:
+ - cfg (:obj:`dict`): The whole config used to init the commander; ``cfg.policy.other.commander`` \
+ should include "collector_task_space" and "learner_task_space".
+ """
+ self._cfg = cfg
+ self._exp_name = cfg.exp_name
+ commander_cfg = self._cfg.policy.other.commander
+ self._collector_task_space = LimitedSpaceContainer(0, commander_cfg.collector_task_space)
+ self._learner_task_space = LimitedSpaceContainer(0, commander_cfg.learner_task_space)
+
+ self._collector_env_cfg = copy.deepcopy(self._cfg.env)
+ self._collector_env_cfg.pop('collector_episode_num')
+ self._collector_env_cfg.pop('evaluator_episode_num')
+ self._collector_env_cfg.manager.episode_num = self._cfg.env.collector_episode_num
+
+ self._collector_task_count = 0
+ self._learner_task_count = 0
+ self._learner_info = defaultdict(list)
+ self._learner_task_finish_count = 0
+ self._collector_task_finish_count = 0
+
+ def get_collector_task(self) -> dict:
+ r"""
+ Overview:
+ Get a new collector task when ``collector_task_count`` is smaller than ``collector_task_space``.
+ Return:
+ - task (:obj:`dict`): New collector task.
+ """
+ if self._collector_task_space.acquire_space():
+ self._collector_task_count += 1
+ collector_cfg = copy.deepcopy(self._cfg.policy.collect.collector)
+ collector_cfg.collect_setting = {'eps': 0.9}
+ collector_cfg.eval_flag = False
+ collector_cfg.policy = copy.deepcopy(self._cfg.policy)
+ collector_cfg.policy_update_path = 'test.pth'
+ collector_cfg.env = self._collector_env_cfg
+ collector_cfg.exp_name = self._exp_name
+ return {
+ 'task_id': 'collector_task_id{}'.format(self._collector_task_count),
+ 'buffer_id': 'test',
+ 'collector_cfg': collector_cfg,
+ }
+ else:
+ return None
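+ # Added note: returning None means the task space is exhausted; the coordinator's produce thread
+ # just retries later (see Coordinator._produce_collector_task).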
+
+ def get_learner_task(self) -> dict:
+ r"""
+ Overview:
+ Get a new learner task when ``learner_task_count`` is smaller than ``learner_task_space``.
+ Return:
+ - task (:obj:`dict`): the new learner task
+ """
+ if self._learner_task_space.acquire_space():
+ self._learner_task_count += 1
+ learner_cfg = copy.deepcopy(self._cfg.policy.learn.learner)
+ learner_cfg.exp_name = self._exp_name
+ return {
+ 'task_id': 'learner_task_id{}'.format(self._learner_task_count),
+ 'policy_id': 'test.pth',
+ 'buffer_id': 'test',
+ 'learner_cfg': learner_cfg,
+ 'replay_buffer_cfg': copy.deepcopy(self._cfg.policy.other.replay_buffer),
+ 'policy': copy.deepcopy(self._cfg.policy),
+ }
+ else:
+ return None
+
+ def finish_collector_task(self, task_id: str, finished_task: dict) -> None:
+ r"""
+ Overview:
+ Finish the collector task and increase ``collector_task_finish_count``.
+ """
+ self._collector_task_space.release_space()
+ self._collector_task_finish_count += 1
+
+ def finish_learner_task(self, task_id: str, finished_task: dict) -> str:
+ r"""
+ Overview:
+ Finish the learner task, increase ``learner_task_finish_count``, and return the task's buffer_id \
+ so that the corresponding buffer can be closed.
+ Return:
+ - buffer_id (:obj:`str`): the buffer_id of the finished task
+ """
+ self._learner_task_finish_count += 1
+ self._learner_task_space.release_space()
+ return finished_task['buffer_id']
+
+ def notify_fail_collector_task(self, task: dict) -> None:
+ r"""
+ Overview:
+ The naive commander simply releases the collector task space when a collector task fails.
+ """
+ self._collector_task_space.release_space()
+
+ def notify_fail_learner_task(self, task: dict) -> None:
+ r"""
+ Overview:
+ The naive commander simply releases the learner task space when a learner task fails.
+ """
+ self._learner_task_space.release_space()
+
+ def update_learner_info(self, task_id: str, info: dict) -> None:
+ r"""
+ Overview:
+ Append the info to the corresponding learner task's info list.
+ Arguments:
+ - task_id (:obj:`str`): the learner task_id
+ - info (:obj:`dict`): the info to append to learner
+ """
+ self._learner_info[task_id].append(info)
+
+ def increase_collector_task_space(self):
+ r""""
+ Overview:
+ Increase task space when a new collector has added dynamically.
+ """
+ self._collector_task_space.increase_space()
+
+ def decrease_collector_task_space(self):
+ r""""
+ Overview:
+ Decrease task space when a new collector has removed dynamically.
+ """
+ self._collector_task_space.decrease_space()
+
+
+def create_parallel_commander(cfg: EasyDict) -> BaseCommander:
+ r"""
+ Overview:
+ Create a parallel commander according to the config.
+ Arguments:
+ - cfg (:obj:`dict`): the whole config; ``cfg.policy.other.commander`` should include ``type`` and ``import_names``
+ """
+ cfg = EasyDict(cfg)
+ import_names = cfg.policy.other.commander.import_names
+ import_module(import_names)
+ return COMMANDER_REGISTRY.build(cfg.policy.other.commander.type, cfg=cfg)
+
+
+def get_parallel_commander_cls(cfg: EasyDict) -> type:
+ cfg = EasyDict(cfg)
+ import_module(cfg.get('import_names', []))
+ return COMMANDER_REGISTRY.get(cfg.type)
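+# Added note: ``create_parallel_commander`` expects the whole main config, since concrete commanders
+# such as ``NaiveCommander`` also read ``cfg.exp_name``, ``cfg.env`` and ``cfg.policy``; only
+# ``cfg.policy.other.commander`` (``type`` plus ``import_names``) selects which class is built here.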
diff --git a/DI-engine/ding/worker/coordinator/base_serial_commander.py b/DI-engine/ding/worker/coordinator/base_serial_commander.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba3e318b2a2f78ee202e507a9a271fc15c7b360e
--- /dev/null
+++ b/DI-engine/ding/worker/coordinator/base_serial_commander.py
@@ -0,0 +1,72 @@
+from collections import namedtuple
+from easydict import EasyDict
+import copy
+
+
+class BaseSerialCommander(object):
+ r"""
+ Overview:
+ Base serial commander class.
+ Interface:
+ __init__, step
+ Property:
+ policy
+ """
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ cfg = EasyDict(copy.deepcopy(cls.config))
+ cfg.cfg_type = cls.__name__ + 'Dict'
+ return cfg
+
+ config = {}
+
+ def __init__(
+ self,
+ cfg: dict,
+ learner: 'BaseLearner', # noqa
+ collector: 'BaseSerialCollector', # noqa
+ evaluator: 'InteractionSerialEvaluator', # noqa
+ replay_buffer: 'IBuffer', # noqa
+ policy: namedtuple = None,
+ ) -> None:
+ r"""
+ Overview:
+ Init the BaseSerialCommander
+ Arguments:
+ - cfg (:obj:`dict`): the config of commander
+ - learner (:obj:`BaseLearner`): the learner
+ - collector (:obj:`BaseSerialCollector`): the collector
+ - evaluator (:obj:`InteractionSerialEvaluator`): the evaluator
+ - replay_buffer (:obj:`IBuffer`): the buffer
+ """
+ self._cfg = cfg
+ self._learner = learner
+ self._collector = collector
+ self._evaluator = evaluator
+ self._replay_buffer = replay_buffer
+ self._info = {}
+ if policy is not None:
+ self.policy = policy
+
+ def step(self) -> dict:
+ r"""
+ Overview:
+ Step the commander: gather the latest learner/collector info, then ask the policy for the \
+ kwargs to use in the next collection.
+ Returns:
+ - collect_kwargs (:obj:`dict`): the kwargs passed to the collector for the next collection
+ """
+ # Update info
+ learn_info = self._learner.learn_info
+ collector_info = {'envstep': self._collector.envstep}
+ self._info.update(learn_info)
+ self._info.update(collector_info)
+ # update kwargs
+ collect_kwargs = self._policy.get_setting_collect(self._info)
+ return collect_kwargs
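+ # Added note: for epsilon-greedy style policies the returned dict typically carries the current
+ # exploration setting, e.g. {'eps': 0.9}, which the collector then forwards to policy.forward.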
+
+ @property
+ def policy(self) -> 'Policy': # noqa
+ return self._policy
+
+ @policy.setter
+ def policy(self, _policy: 'Policy') -> None: # noqa
+ self._policy = _policy
diff --git a/DI-engine/ding/worker/coordinator/comm_coordinator.py b/DI-engine/ding/worker/coordinator/comm_coordinator.py
new file mode 100644
index 0000000000000000000000000000000000000000..66fd9e0dc167253548ef69edbdb7b9caa1955a35
--- /dev/null
+++ b/DI-engine/ding/worker/coordinator/comm_coordinator.py
@@ -0,0 +1,568 @@
+import traceback
+import time
+import sys
+import requests
+from typing import Dict, Callable
+from threading import Thread
+
+from ding.utils import LockContext, LockContextType, get_operator_server_kwargs
+from ding.interaction import Master
+from ding.interaction.master.task import TaskStatus
+from .resource_manager import NaiveResourceManager
+from .operator_server import OperatorServer
+
+
+class CommCoordinator(object):
+ r"""
+ Overview:
+ The communication part of the coordinator, i.e. the coordinator's interaction with collectors and learners.
+ Interface:
+ __init__ , start, close, __del__, send_collector_task, send_learner_task
+ """
+
+ def __init__(self, cfg: dict, callback_fn: Dict[str, Callable], logger: 'logging.Logger') -> None: # noqa
+ r"""
+ Overview:
+ init the interactor of coordinator
+ Arguments:
+ - cfg (:obj:`dict`): The config file of communication coordinator
+ - callback_fn (:obj:`Dict[str, Callable]`): The callback functions given by coordinator
+ - logger (:obj:`logging.Logger`): The text logger.
+ """
+ self._cfg = cfg
+ self._callback_fn = callback_fn
+ self._logger = logger
+ self._max_retry_second = 120
+ self._end_flag = True
+
+ self._connection_collector = {}
+ self._connection_learner = {}
+ self._resource_manager = NaiveResourceManager()
+
+ self._remain_task_lock = LockContext(LockContextType.THREAD_LOCK)
+ self._remain_collector_task = set()
+ self._remain_learner_task = set()
+
+ if self._cfg.operator_server:
+ server_kwargs = get_operator_server_kwargs(self._cfg.operator_server)
+ self._operator_server = OperatorServer(**server_kwargs)
+ self._operator_server.set_worker_type('coordinator')
+ self._collector_target_num = self._cfg.operator_server.collector_target_num
+ self._learner_target_num = self._cfg.operator_server.learner_target_num
+ else:
+ self._operator_server = None
+
+ # for update resource
+ self._resource_lock = LockContext(LockContextType.THREAD_LOCK)
+
+ # failed connection
+ self._failed_learner_conn = set()
+ self._failed_collector_conn = set()
+
+ def start(self) -> None:
+ r"""
+ Overview:
+ start the coordinator interactor and manage resources and connections
+ """
+ self._end_flag = False
+ self._master = Master(self._cfg.host, self._cfg.port)
+ self._master.start()
+ self._master.ping()
+
+ # new connection from config
+ for _, (learner_id, learner_host, learner_port) in self._cfg.learner.items():
+ self._new_connection_learner(learner_id, learner_host, learner_port)
+ for _, (collector_id, collector_host, collector_port) in self._cfg.collector.items():
+ self._new_connection_collector(collector_id, collector_host, collector_port)
+
+ if self._operator_server:
+ # post init learner/collector demand
+ start_time, init_flag = time.time(), False
+ while time.time() - start_time <= self._max_retry_second and not self._end_flag:
+ success, _, message, _ = self._operator_server.post_replicas(
+ self._cfg.operator_server.init_replicas_request
+ )
+ if success:
+ self._logger.info("Post replicas demand to server successfully")
+ init_flag = True
+ break
+ else:
+ self._logger.info("Failed to post replicas request to server, message: {}".format(message))
+ time.sleep(2)
+
+ if not init_flag:
+ self._logger.info('Exit since cannot request replicas to operator-server...')
+ self.close()
+ sys.exit(1)
+
+ # create sync learner/collector thread
+ self._period_sync_with_server_thread = Thread(
+ target=self._period_sync_with_server, name="period_sync", daemon=True
+ )
+ self._period_sync_with_server_thread.start()
+
+ # wait for enough collector/learner
+ start_time = time.time()
+ enough_flag = False
+ while time.time() - start_time <= self._max_retry_second:
+ if len(self._connection_collector) < self._collector_target_num and len(self._connection_learner
+ ) < self._learner_target_num:
+ self._logger.info(
+ "Only can connect {} collectors, {} learners.".format(
+ len(self._connection_collector), len(self._connection_learner)
+ )
+ )
+ time.sleep(2)
+ else:
+ self._logger.info(
+ "Have connected {} collectors, {} learners, match limit requests.".format(
+ len(self._connection_collector), len(self._connection_learner)
+ )
+ )
+ self._logger.info("Total DI-engine pipeline start...")
+ enough_flag = True
+ break
+
+ if not enough_flag:
+ self._logger.error(
+ "Exit since only can connect {} collectors, {} learners.".format(
+ len(self._connection_collector), len(self._connection_learner)
+ )
+ )
+ self.close()
+ sys.exit(1)
+
+ if self._end_flag:
+ self._logger.error("connection max retries failed")
+ sys.exit(1)
+
+ def _new_connection_collector(
+ self,
+ collector_id: str,
+ collector_host: str,
+ collector_port: int,
+ increase_task_space: bool = False,
+ ) -> None:
+ start_time = time.time()
+ conn = None
+ while time.time() - start_time <= self._max_retry_second and not self._end_flag:
+ try:
+ if conn is None or not conn.is_connected:
+ conn = self._master.new_connection(collector_id, collector_host, collector_port)
+ conn.connect()
+ assert conn.is_connected
+ resource_task = self._get_resource(conn)
+ if resource_task.status != TaskStatus.COMPLETED:
+ self._logger.error("can't acquire resource for collector({})".format(collector_id))
+ continue
+ else:
+ with self._resource_lock:
+ self._resource_manager.update('collector', collector_id, resource_task.result)
+ self._connection_collector[collector_id] = conn
+ if increase_task_space:
+ self._callback_fn['deal_with_increase_collector']()
+ break
+
+ except Exception as e:
+ self._logger.error(
+ f"Collector({collector_id}) connection start error:\n" +
+ ''.join(traceback.format_tb(e.__traceback__)) + repr(e) + '\nAuto Retry...'
+ )
+ time.sleep(2)
+
+ if collector_id in self._connection_collector:
+ self._logger.info(f"Succeed to connect to collector({collector_id})")
+ else:
+ self._logger.info(f"Fail to connect to collector({collector_id})")
+ self._failed_collector_conn.add(collector_id)
+
+ def _new_connection_learner(self, learner_id: str, learner_host: str, learner_port: int) -> None:
+ start_time = time.time()
+ conn = None
+ while time.time() - start_time <= self._max_retry_second and not self._end_flag:
+ try:
+ if conn is None or not conn.is_connected:
+ conn = self._master.new_connection(learner_id, learner_host, learner_port)
+ conn.connect()
+ assert conn.is_connected
+ resource_task = self._get_resource(conn)
+ if resource_task.status != TaskStatus.COMPLETED:
+ self._logger.error("can't acquire resource for learner({})".format(learner_id))
+ continue
+ else:
+ with self._resource_lock:
+ self._resource_manager.update('learner', learner_id, resource_task.result)
+ self._connection_learner[learner_id] = conn
+ break
+
+ except Exception as e:
+ self._logger.error(
+ f"learner({learner_id}) connection start error:\n" + ''.join(traceback.format_tb(e.__traceback__)) +
+ repr(e) + '\nAuto Retry...'
+ )
+ time.sleep(2)
+
+ if learner_id in self._connection_learner:
+ self._logger.info(f"Succeed to connect to learner({learner_id})")
+ else:
+ self._logger.info(f"Fail to connect to learner({learner_id})")
+ self._failed_learner_conn.add(learner_id)
+
+ def close(self) -> None:
+ r"""
+ Overview:
+ close the coordinator interactor
+ """
+ if self._end_flag:
+ return
+ self._end_flag = True
+ # wait for execute thread
+ start_time = time.time()
+ # TODO
+ if self._operator_server:
+ self._period_sync_with_server_thread.join()
+ # wait from all slave receive DELETE
+ time.sleep(5)
+ while time.time() - start_time <= 60:
+ if len(self._remain_learner_task) == 0 and len(self._remain_collector_task) == 0:
+ break
+ else:
+ time.sleep(1)
+ for collector_id, conn in self._connection_collector.items():
+ conn.disconnect()
+ assert not conn.is_connected
+ for learner_id, conn in self._connection_learner.items():
+ conn.disconnect()
+ assert not conn.is_connected
+ self._master.close()
+
+ def __del__(self) -> None:
+ r"""
+ Overview:
+ __del__ method will close the coordinator interactor
+ """
+ self.close()
+
+ def _get_resource(self, conn: 'Connection') -> 'TaskResult': # noqa
+ r"""
+ Overview:
+ get the resources according to connection
+ Arguments:
+ - conn (:obj:`Connection`): the connection to get resource_task
+ """
+ resource_task = conn.new_task({'name': 'resource'})
+ resource_task.start().join()
+ return resource_task
+
+ def send_collector_task(self, collector_task: dict) -> bool:
+ r"""
+ Overview:
+ send the collector_task to collector_task threads and execute
+ Arguments:
+ - collector_task (:obj:`dict`): the collector_task to send
+ """
+ # assert not self._end_flag, "please start interaction first"
+ task_id = collector_task['task_id']
+ # according to resource info, assign task to a specific collector and adapt task
+ assigned_collector = self._resource_manager.assign_collector(collector_task)
+ if assigned_collector is None:
+ self._logger.error("collector task({}) doesn't have enough collector to execute".format(task_id))
+ return False
+ collector_task.update(assigned_collector)
+
+ collector_id = collector_task['collector_id']
+ start_task = self._connection_collector[collector_id].new_task(
+ {
+ 'name': 'collector_start_task',
+ 'task_info': collector_task
+ }
+ )
+ start_task.start().join()
+ if start_task.status != TaskStatus.COMPLETED:
+ self._resource_manager.update(
+ 'collector', assigned_collector['collector_id'], assigned_collector['resource_info']
+ )
+ self._logger.error('collector_task({}) start failed: {}'.format(task_id, start_task.result))
+ return False
+ else:
+ self._logger.info('collector task({}) is assigned to collector({})'.format(task_id, collector_id))
+ with self._remain_task_lock:
+ self._remain_collector_task.add(task_id)
+ collector_task_thread = Thread(
+ target=self._execute_collector_task, args=(collector_task, ), name='coordinator_collector_task'
+ )
+ collector_task_thread.start()
+ return True
+
+ def _execute_collector_task(self, collector_task: dict) -> None:
+ r"""
+ Overview:
+ execute the collector task
+ Arguments:
+ - collector_task (:obj:`dict`): the collector task to execute
+ """
+ close_flag = False
+ collector_id = collector_task['collector_id']
+ while not self._end_flag:
+ try:
+ # data task
+ data_task = self._connection_collector[collector_id].new_task({'name': 'collector_data_task'})
+ self._logger.info('collector data task begin')
+ data_task.start().join()
+ self._logger.info('collector data task end')
+ if data_task.status != TaskStatus.COMPLETED:
+ # TODO(deal with fail task)
+ self._logger.error('collector data task is failed')
+ continue
+ result = data_task.result
+ task_id = result.get('task_id', None)
+ # data result
+ if 'data_id' in result:
+ buffer_id = result.get('buffer_id', None)
+ data_id = result.get('data_id', None)
+ self._callback_fn['deal_with_collector_send_data'](task_id, buffer_id, data_id, result)
+ # info result
+ else:
+ is_finished = self._callback_fn['deal_with_collector_judge_finish'](task_id, result)
+ if not is_finished:
+ continue
+ # close task
+ self._logger.error('close_task: {}\n{}'.format(task_id, result))
+ close_task = self._connection_collector[collector_id].new_task({'name': 'collector_close_task'})
+ close_task.start().join()
+ if close_task.status != TaskStatus.COMPLETED:
+ # TODO(deal with fail task)
+ self._logger.error('collector close is failed')
+ break
+ result = close_task.result
+ task_id = result.get('task_id', None)
+ self._callback_fn['deal_with_collector_finish_task'](task_id, result)
+ resource_task = self._get_resource(self._connection_collector[collector_id])
+ if resource_task.status == TaskStatus.COMPLETED:
+ self._resource_manager.update('collector', collector_id, resource_task.result)
+ close_flag = True
+ break
+ except requests.exceptions.HTTPError as e:
+ if self._end_flag:
+ break
+ else:
+ raise e
+
+ if not close_flag:
+ close_task = self._connection_collector[collector_id].new_task({'name': 'collector_close_task'})
+ close_task.start().join()
+ with self._remain_task_lock:
+ self._remain_collector_task.remove(task_id)
+
+ def send_learner_task(self, learner_task: dict) -> bool:
+ r"""
+ Overview:
+ send the learner_task to learner_task threads and execute
+ Arguments:
+ - learner_task (:obj:`dict`): the learner_task to send
+ """
+ # assert not self._end_flag, "please start interaction first"
+ task_id = learner_task['task_id']
+ assigned_learner = self._resource_manager.assign_learner(learner_task)
+ if assigned_learner is None:
+ self._logger.error("learner task({}) doesn't have enough learner to execute".format(task_id))
+ return False
+ learner_task.update(assigned_learner)
+
+ learner_id = learner_task['learner_id']
+ start_task = self._connection_learner[learner_id].new_task(
+ {
+ 'name': 'learner_start_task',
+ 'task_info': learner_task
+ }
+ )
+ start_task.start().join()
+ if start_task.status != TaskStatus.COMPLETED:
+ self._resource_manager.update('learner', assigned_learner['learner_id'], assigned_learner['resource_info'])
+ self._logger.info('learner_task({}) start failed: {}'.format(task_id, start_task.result))
+ return False
+ else:
+ self._logger.info('learner task({}) is assigned to learner({})'.format(task_id, learner_id))
+ with self._remain_task_lock:
+ self._remain_learner_task.add(task_id)
+ learner_task_thread = Thread(
+ target=self._execute_learner_task, args=(learner_task, ), name='coordinator_learner_task'
+ )
+ learner_task_thread.start()
+ return True
+
+ def _execute_learner_task(self, learner_task: dict) -> None:
+ r"""
+ Overview:
+ execute the learner task
+ Arguments:
+ - learner_task (:obj:`dict`): the learner task to execute
+ """
+ close_flag = False
+ learner_id = learner_task['learner_id']
+ while not self._end_flag:
+ try:
+ # get data
+ get_data_task = self._connection_learner[learner_id].new_task({'name': 'learner_get_data_task'})
+ get_data_task.start().join()
+ if get_data_task.status != TaskStatus.COMPLETED:
+ # TODO(deal with fail task)
+ self._logger.error('learner get_data_task failed: {}'.format(get_data_task.result))
+ continue
+ result = get_data_task.result
+ task_id, buffer_id, batch_size = result['task_id'], result['buffer_id'], result['batch_size']
+ cur_learner_iter = result['cur_learner_iter']
+ sleep_count = 1
+ while True:
+ data = self._callback_fn['deal_with_learner_get_data'](
+ task_id, buffer_id, batch_size, cur_learner_iter
+ )
+ if self._end_flag or data is not None:
+ self._logger.info('sample result is ok')
+ break
+ else:
+ self._logger.info('sample result is None')
+ time.sleep(sleep_count)
+ sleep_count += 2
+ if self._end_flag:
+ break
+
+ # learn task
+ learn_task = self._connection_learner[learner_id].new_task({'name': 'learner_learn_task', 'data': data})
+ learn_task.start().join()
+ if learn_task.status != TaskStatus.COMPLETED:
+ # TODO(deal with fail task)
+ self._logger.error('learner learn_task failed: {}'.format(learn_task.result))
+ continue
+ result = learn_task.result
+ task_id, info = result['task_id'], result['info']
+ is_finished = self._callback_fn['deal_with_learner_judge_finish'](task_id, info)
+ if is_finished:
+ # close task and update resource
+ close_task = self._connection_learner[learner_id].new_task({'name': 'learner_close_task'})
+ close_task.start().join()
+ if close_task.status != TaskStatus.COMPLETED:
+ self._logger.error('learner close_task failed: {}'.format(close_task.result))
+ break
+ result = close_task.result
+ task_id = result.get('task_id', None)
+ self._callback_fn['deal_with_learner_finish_task'](task_id, result)
+ resource_task = self._get_resource(self._connection_learner[learner_id])
+ if resource_task.status == TaskStatus.COMPLETED:
+ self._resource_manager.update('learner', learner_id, resource_task.result)
+ close_flag = True
+ break
+ else:
+ # update info
+ buffer_id = result['buffer_id']
+ self._callback_fn['deal_with_learner_send_info'](task_id, buffer_id, info)
+ except requests.exceptions.HTTPError as e:
+ if self._end_flag:
+ break
+ else:
+ raise e
+
+ if not close_flag:
+ close_task = self._connection_learner[learner_id].new_task({'name': 'learner_close_task'})
+ close_task.start().join()
+ with self._remain_task_lock:
+ self._remain_learner_task.remove(task_id)
+
+ def _period_sync_with_server(self) -> None:
+ while not self._end_flag:
+ # First: send failed list to notify DI-engine server which replicas are failed,
+ # then terminate such replicas.
+ # self._logger.info("failed list:", list(self._failed_collector_conn), list(self._failed_learner_conn))
+ if len(self._failed_learner_conn) > 0 or len(self._failed_collector_conn) > 0:
+ collector_conn = []
+ for replica_conn in self._failed_collector_conn:
+ dns_name = replica_conn.split(":")[0]
+ pod_name_list = dns_name.split(".")[:-1]
+ pod_name = ".".join(pod_name_list)
+ collector_conn.append(pod_name)
+ learner_conn = []
+ for replica_conn in self._failed_learner_conn:
+ dns_name = replica_conn.split(":")[0]
+ pod_name_list = dns_name.split(".")[:-1]
+ pod_name = ".".join(pod_name_list)
+ learner_conn.append(pod_name)
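+ # Added illustration (hypothetical name): a failed connection id such as
+ # "learner-0.di-namespace.svc:8000" is reduced to the pod name "learner-0.di-namespace"
+ # before being reported to the operator server.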
+
+ success, _, message, _ = self._operator_server.post_replicas_failed(
+ learners=list(learner_conn), collectors=list(collector_conn)
+ )
+ if success:
+ # do not update collector or learner instantly, update at /GET replicas
+ self._failed_collector_conn.clear()
+ self._failed_learner_conn.clear()
+ else:
+ self._logger.error("Failed to send failed list to server, message: {}".format(message))
+
+ # get list from server
+ success, _, message, data = self._operator_server.get_replicas()
+ if success:
+ cur_collectors = data["collectors"]
+ cur_learners = data["learners"]
+ # self._logger.info("current list:", cur_collectors, cur_learners)
+ self._update_connection_collector(cur_collectors)
+ self._update_connection_learner(cur_learners)
+ else:
+ self._logger.error("Failed to sync with server, message: {}".format(message))
+
+ time.sleep(1)
+
+ def _update_connection_collector(self, cur_collectors: list) -> None:
+ conn_collectors = list(self._connection_collector.keys())
+ new_c = set(cur_collectors) - set(conn_collectors)
+ del_c = set(conn_collectors) - (set(cur_collectors) | self._failed_collector_conn)
+ # conns which have terminated in server side, clear up
+ self._failed_collector_conn = self._failed_collector_conn & set(cur_collectors)
+
+ # connect to each new collector
+ for collector_id in new_c:
+ collector_host, collector_port = collector_id.split(':')
+ self._new_connection_collector(collector_id, collector_host, int(collector_port), True)
+
+ for collector_id in del_c:
+ if collector_id in conn_collectors:
+ # TODO(nyz) whether to need to close task first
+ with self._resource_lock:
+ if not self._resource_manager.have_assigned('collector', collector_id):
+ self._resource_manager.delete("collector", collector_id)
+
+ if self._connection_collector[collector_id].is_connected:
+ conn = self._connection_collector.pop(collector_id)
+ conn.disconnect()
+ assert not conn.is_connected
+ self._callback_fn['deal_with_decrease_collector']()
+ else:
+ # ignore the operation of disconnect, since the pod will be terminated by server,
+ # just throw the connection
+ self._connection_collector.pop(collector_id)
+
+ def _update_connection_learner(self, cur_learners) -> None:
+ conn_learners = list(self._connection_learner.keys())
+ new_c = set(cur_learners) - set(conn_learners)
+ del_c = set(conn_learners) - (set(cur_learners) | self._failed_learner_conn)
+ # conns which have terminated in server side, clear up
+ self._failed_learner_conn = self._failed_learner_conn & set(cur_learners)
+
+ # connect to each new learner
+ for learner_id in new_c:
+ learner_host, learner_port = learner_id.split(':')
+ self._new_connection_learner(learner_id, learner_host, int(learner_port))
+
+ for learner_id in del_c:
+ if learner_id in conn_learners:
+ # TODO(nyz) whether to need to close task first
+ with self._resource_lock:
+ if not self._resource_manager.have_assigned('learner', learner_id):
+ self._resource_manager.delete("learner", learner_id)
+
+ if self._connection_learner[learner_id].is_connected:
+ conn = self._connection_learner.pop(learner_id)
+ conn.disconnect()
+ assert not conn.is_connected
+ else:
+ # ignore the operation of disconnect, since the pod will be terminated by server,
+ # just throw the connection
+ self._connection_learner.pop(learner_id)
diff --git a/DI-engine/ding/worker/coordinator/coordinator.py b/DI-engine/ding/worker/coordinator/coordinator.py
new file mode 100644
index 0000000000000000000000000000000000000000..508f78990abb9d35798e82de22434eeb1e07065c
--- /dev/null
+++ b/DI-engine/ding/worker/coordinator/coordinator.py
@@ -0,0 +1,480 @@
+import time
+import copy
+from typing import List
+from queue import Queue
+from threading import Thread
+from easydict import EasyDict
+
+from ding.utils import build_logger, LockContext, LockContextType, get_task_uid
+from ding.worker import create_buffer
+from .comm_coordinator import CommCoordinator
+from .base_parallel_commander import create_parallel_commander
+
+
+class TaskState(object):
+ r"""
+ Overview:
+ State recorder of the task, including ``task_id`` and ``start_time``.
+ Interface:
+ __init__
+ """
+
+ def __init__(self, task_id: str) -> None:
+ r"""
+ Overview:
+ Init the task state according to ``task_id`` and record the start time.
+ """
+ self.task_id = task_id
+ self.start_time = time.time()
+
+
+class Coordinator(object):
+ r"""
+ Overview:
+ the coordinator will manage parallel tasks and data
+ Interface:
+ __init__, start, close, __del__, state_dict, load_state_dict,
+ deal_with_collector_send_data, deal_with_collector_finish_task,
+ deal_with_learner_get_data, deal_with_learner_send_info, deal_with_learner_finish_task
+ Property:
+ system_shutdown_flag
+ """
+ config = dict(
+ collector_task_timeout=30,
+ learner_task_timeout=600,
+ operator_server=dict(),
+ )
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ cfg = EasyDict(copy.deepcopy(cls.config))
+ cfg.cfg_type = cls.__name__ + 'Dict'
+ return cfg
+
+ def __init__(self, cfg: dict) -> None:
+ r"""
+ Overview:
+ init method of the coordinator
+ Arguments:
+ - cfg (:obj:`dict`): the config file to init the coordinator
+ """
+ self._exp_name = cfg.main.exp_name
+ self._coordinator_uid = get_task_uid()
+ coor_cfg = cfg.system.coordinator
+ self._collector_task_timeout = coor_cfg.collector_task_timeout
+ self._learner_task_timeout = coor_cfg.learner_task_timeout
+
+ self._callback = {
+ 'deal_with_collector_send_data': self.deal_with_collector_send_data,
+ 'deal_with_collector_judge_finish': self.deal_with_collector_judge_finish,
+ 'deal_with_collector_finish_task': self.deal_with_collector_finish_task,
+ 'deal_with_learner_get_data': self.deal_with_learner_get_data,
+ 'deal_with_learner_send_info': self.deal_with_learner_send_info,
+ 'deal_with_learner_judge_finish': self.deal_with_learner_judge_finish,
+ 'deal_with_learner_finish_task': self.deal_with_learner_finish_task,
+ 'deal_with_increase_collector': self.deal_with_increase_collector,
+ 'deal_with_decrease_collector': self.deal_with_decrease_collector,
+ }
+ self._logger, _ = build_logger(path='./{}/log'.format(self._exp_name), name='coordinator', need_tb=False)
+ self._interaction = CommCoordinator(coor_cfg, self._callback, self._logger)
+ self._learner_task_queue = Queue()
+ self._collector_task_queue = Queue()
+ self._commander = create_parallel_commander(cfg.main) # commander can access all the main config
+ self._commander_lock = LockContext(LockContextType.THREAD_LOCK)
+ # ############## Thread #####################
+ # Assign thread todo
+ # Produce thread todo
+ self._assign_collector_thread = Thread(
+ target=self._assign_collector_task, args=(), name='coordinator_assign_collector'
+ )
+ self._assign_learner_thread = Thread(
+ target=self._assign_learner_task, args=(), name='coordinator_assign_learner'
+ )
+ self._produce_collector_thread = Thread(
+ target=self._produce_collector_task, args=(), name='coordinator_produce_collector'
+ )
+ self._produce_learner_thread = Thread(
+ target=self._produce_learner_task, args=(), name='coordinator_produce_learner'
+ )
+
+ self._replay_buffer = {}
+ self._task_state = {} # str -> TaskState
+ self._historical_task = []
+ # TODO remove used data
+ # TODO load/save state_dict
+ self._end_flag = True
+ self._system_shutdown_flag = False
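+ # Added overview note: the two produce threads ask the commander for new collector/learner tasks
+ # and push them into the queues above; the two assign threads pop tasks and hand them to
+ # CommCoordinator with retry/timeout handling; the registered callbacks route data and info from
+ # remote workers back into the replay buffers and the commander.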
+
+ def _assign_collector_task(self) -> None:
+ r"""
+ Overview:
+ The function to be called in the assign_collector_task thread.
+ Will get a collector task from ``collector_task_queue`` and assign it.
+ """
+ while not self._end_flag:
+ time.sleep(0.01)
+ # get valid task, abandon timeout task
+ if self._collector_task_queue.empty():
+ continue
+ else:
+ collector_task, put_time = self._collector_task_queue.get()
+ start_retry_time = time.time()
+ max_retry_time = 0.3 * self._collector_task_timeout
+ while True:
+ # timeout or assigned to collector
+ get_time = time.time()
+ if get_time - put_time >= self._collector_task_timeout:
+ self.info(
+ 'collector task({}) timeout: [{}, {}, {}/{}]'.format(
+ collector_task['task_id'], get_time, put_time, get_time - put_time,
+ self._collector_task_timeout
+ )
+ )
+ with self._commander_lock:
+ self._commander.notify_fail_collector_task(collector_task)
+ break
+ buffer_id = collector_task['buffer_id']
+ if buffer_id in self._replay_buffer:
+ if self._interaction.send_collector_task(collector_task):
+ self._record_task(collector_task)
+ self.info(
+ "collector_task({}) is successful to be assigned".format(collector_task['task_id'])
+ )
+ break
+ else:
+ self.info("collector_task({}) is failed to be assigned".format(collector_task['task_id']))
+ else:
+ self.info(
+ "collector_task({}) can't find proper buffer_id({})".format(
+ collector_task['task_id'], buffer_id
+ )
+ )
+ if time.time() - start_retry_time >= max_retry_time:
+ # reput into queue
+ self._collector_task_queue.put([collector_task, put_time])
+ self.info("collector task({}) reput into queue".format(collector_task['task_id']))
+ break
+ time.sleep(3)
+
+ def _assign_learner_task(self) -> None:
+ r"""
+ Overview:
+ The function to be called in the assign_learner_task thread.
+ Will take a learner task from learner_task_queue and assign the task.
+ """
+ while not self._end_flag:
+ time.sleep(0.01)
+ if self._learner_task_queue.empty():
+ continue
+ else:
+ learner_task, put_time = self._learner_task_queue.get()
+ start_retry_time = time.time()
+ max_retry_time = 0.1 * self._learner_task_timeout
+ while True:
+ # timeout or assigned to learner
+ get_time = time.time()
+ if get_time - put_time >= self._learner_task_timeout:
+ self.info(
+ 'learner task({}) timeout: [{}, {}, {}/{}]'.format(
+ learner_task['task_id'], get_time, put_time, get_time - put_time,
+ self._learner_task_timeout
+ )
+ )
+ with self._commander_lock:
+ self._commander.notify_fail_learner_task(learner_task)
+ break
+ if self._interaction.send_learner_task(learner_task):
+ self._record_task(learner_task)
+ # create replay_buffer
+ buffer_id = learner_task['buffer_id']
+ if buffer_id not in self._replay_buffer:
+ replay_buffer_cfg = learner_task.pop('replay_buffer_cfg')
+ self._replay_buffer[buffer_id] = create_buffer(replay_buffer_cfg, exp_name=self._exp_name)
+ self._replay_buffer[buffer_id].start()
+ self.info("replay_buffer({}) is created".format(buffer_id))
+ self.info("learner_task({}) is successful to be assigned".format(learner_task['task_id']))
+ break
+ else:
+ self.info("learner_task({}) is failed to be assigned".format(learner_task['task_id']))
+ if time.time() - start_retry_time >= max_retry_time:
+ # reput into queue
+ self._learner_task_queue.put([learner_task, put_time])
+ self.info("learner task({}) reput into queue".format(learner_task['task_id']))
+ break
+ time.sleep(3)
+
+ def _produce_collector_task(self) -> None:
+ r"""
+ Overview:
+ The function to be called in the ``produce_collector_task`` thread.
+ Will ask commander to produce a collector task, then put it into ``collector_task_queue``.
+ """
+ while not self._end_flag:
+ time.sleep(0.01)
+ with self._commander_lock:
+ collector_task = self._commander.get_collector_task()
+ if collector_task is None:
+ continue
+ self.info("collector task({}) put into queue".format(collector_task['task_id']))
+ self._collector_task_queue.put([collector_task, time.time()])
+
+ def _produce_learner_task(self) -> None:
+ r"""
+ Overview:
+ The function to be called in the produce_learner_task thread.
+ Will produce a learner task and put it into the learner_task_queue.
+ """
+ while not self._end_flag:
+ time.sleep(0.01)
+ with self._commander_lock:
+ learner_task = self._commander.get_learner_task()
+ if learner_task is None:
+ continue
+ self.info("learner task({}) put into queue".format(learner_task['task_id']))
+ self._learner_task_queue.put([learner_task, time.time()])
+
+ def state_dict(self) -> dict:
+ r"""
+ Overview:
+ Return empty state_dict.
+ """
+ return {}
+
+ def load_state_dict(self, state_dict: dict) -> None:
+ r"""
+ Overview:
+ Pass when load state_dict.
+ """
+ pass
+
+ def start(self) -> None:
+ r"""
+ Overview:
+ Start the coordinator, including launching the interaction thread and the collector/learner task threads.
+ """
+ self._end_flag = False
+ self._interaction.start()
+ self._produce_collector_thread.start()
+ self._assign_collector_thread.start()
+ self._produce_learner_thread.start()
+ self._assign_learner_thread.start()
+
+ def close(self) -> None:
+ r"""
+ Overview:
+ Close the coordinator, including closing the interaction thread, the collector/learner task threads and the \
+ buffers.
+ """
+ if self._end_flag:
+ return
+ self._end_flag = True
+ time.sleep(1)
+ self._produce_collector_thread.join()
+ self._assign_collector_thread.join()
+ self._produce_learner_thread.join()
+ self._assign_learner_thread.join()
+ self._interaction.close()
+ # close replay buffer
+ replay_buffer_keys = list(self._replay_buffer.keys())
+ for k in replay_buffer_keys:
+ v = self._replay_buffer.pop(k)
+ v.close()
+ self.info('coordinator is closed')
+
+ def __del__(self) -> None:
+ r"""
+ Overview:
+ __del__ method will close the coordinator.
+ """
+ self.close()
+
+ def deal_with_collector_send_data(self, task_id: str, buffer_id: str, data_id: str, data: dict) -> None:
+ r"""
+ Overview:
+ deal with the data send from collector
+ Arguments:
+ - task_id (:obj:`str`): the collector task_id
+ - buffer_id (:obj:`str`): the buffer_id
+ - data_id (:obj:`str`): the data_id
+ - data (:obj:`dict`): the data to be dealt with
+ """
+ if task_id not in self._task_state:
+ self.error('collector task({}) not in self._task_state when send data, throw it'.format(task_id))
+ return
+ if buffer_id not in self._replay_buffer:
+ self.error(
+ "collector task({}) data({}) doesn't have proper buffer_id({})".format(task_id, data_id, buffer_id)
+ )
+ return
+ self._replay_buffer[buffer_id].push(data, -1)
+ self.info('collector task({}) send data({})'.format(task_id, data_id))
+
+ def deal_with_collector_judge_finish(self, task_id: str, data: dict) -> bool:
+ if task_id not in self._task_state:
+ self.error('collector task({}) not in self._task_state when send data, throw it'.format(task_id))
+ return False
+ with self._commander_lock:
+ collector_finish_flag = self._commander.judge_collector_finish(task_id, data)
+ if collector_finish_flag:
+ self.info('collector task({}) is finished'.format(task_id))
+ return collector_finish_flag
+
+ def deal_with_collector_finish_task(self, task_id: str, finished_task: dict) -> None:
+ r"""
+ Overview:
+ finish the collector task
+ Arguments:
+ - task_id (:obj:`str`): the collector task_id
+ - finished_task (:obj:`dict`): the finished_task
+ """
+ if task_id not in self._task_state:
+ self.error('collector task({}) not in self._task_state when finish, throw it'.format(task_id))
+ return
+ # finish_task
+ with self._commander_lock:
+ # commander will judge whether the whole system has converged and should be shut down
+ self._system_shutdown_flag = self._commander.finish_collector_task(task_id, finished_task)
+ self._task_state.pop(task_id)
+ self._historical_task.append(task_id)
+ self.info('collector task({}) is finished'.format(task_id))
+
+ def deal_with_learner_get_data(self, task_id: str, buffer_id: str, batch_size: int,
+ cur_learner_iter: int) -> List[dict]:
+ r"""
+ Overview:
+ learner get the data from buffer
+ Arguments:
+ - task_id (:obj:`str`): the learner task_id
+ - buffer_id (:obj:`str`): the buffer_id
+ - batch_size (:obj:`int`): the batch_size to sample
+ - cur_learner_iter (:obj:`int`): the current learner iteration
+ """
+ if task_id not in self._task_state:
+ self.error("learner task({}) get data doesn't have proper task_id".format(task_id))
+ raise RuntimeError(
+ "invalid learner task_id({}) for get data, valid learner_id is {}".format(
+ task_id, self._task_state.keys()
+ )
+ )
+ if buffer_id not in self._replay_buffer:
+ self.error("learner task({}) get data doesn't have proper buffer_id({})".format(task_id, buffer_id))
+ return
+ self.info("learner task({}) get data".format(task_id))
+ return self._replay_buffer[buffer_id].sample(batch_size, cur_learner_iter)
+
+ def deal_with_learner_send_info(self, task_id: str, buffer_id: str, info: dict) -> None:
+ r"""
+ Overview:
+ the learner send the info and update the priority in buffer
+ Arguments:
+ - task_id (:obj:`str`): the learner task id
+ - buffer_id (:obj:`str`): the buffer_id of buffer to add info to
+ - info (:obj:`dict`): the info to add
+ """
+ if task_id not in self._task_state:
+ self.error("learner task({}) send info doesn't have proper task_id".format(task_id))
+ raise RuntimeError(
+ "invalid learner task_id({}) for send info, valid learner_id is {}".format(
+ task_id, self._task_state.keys()
+ )
+ )
+ if buffer_id not in self._replay_buffer:
+ self.error("learner task({}) send info doesn't have proper buffer_id({})".format(task_id, buffer_id))
+ return
+ self._replay_buffer[buffer_id].update(info['priority_info'])
+ with self._commander_lock:
+ self._commander.update_learner_info(task_id, info)
+ self.info("learner task({}) send info".format(task_id))
+
+ def deal_with_learner_judge_finish(self, task_id: str, info: dict) -> bool:
+ if task_id not in self._task_state:
+ self.error("learner task({}) finish task doesn't have proper task_id".format(task_id))
+ raise RuntimeError(
+ "invalid learner task_id({}) for finish task, valid learner_id is {}".format(
+ task_id, self._task_state.keys()
+ )
+ )
+ with self._commander_lock:
+ learner_finish_flag = self._commander.judge_learner_finish(task_id, info)
+ if learner_finish_flag:
+ self.info('learner task({}) is finished'.format(task_id))
+ return learner_finish_flag
+
+ def deal_with_learner_finish_task(self, task_id: str, finished_task: dict) -> None:
+ r"""
+ Overview:
+ finish the learner task, close the corresponding buffer
+ Arguments:
+ - task_id (:obj:`str`): the learner task_id
+ - finished_task (:obj:`dict`): the dict of task to finish
+ """
+ if task_id not in self._task_state:
+ self.error("learner task({}) finish task doesn't have proper task_id".format(task_id))
+ raise RuntimeError(
+ "invalid learner task_id({}) for finish task, valid learner_id is {}".format(
+ task_id, self._task_state.keys()
+ )
+ )
+ with self._commander_lock:
+ buffer_id = self._commander.finish_learner_task(task_id, finished_task)
+ self._task_state.pop(task_id)
+ self._historical_task.append(task_id)
+ self.info("learner task({}) finish".format(task_id))
+ # delete replay buffer
+ if buffer_id is not None:
+ replay_buffer = self._replay_buffer.pop(buffer_id)
+ replay_buffer.close()
+ self.info('replay_buffer({}) is closed'.format(buffer_id))
+
+ def deal_with_increase_collector(self):
+ r""""
+ Overview:
+ Increase task space when a new collector has added dynamically.
+ """
+ with self._commander_lock:
+ self._commander.increase_collector_task_space()
+
+ def deal_with_decrease_collector(self):
+ r""""
+ Overview:
+ Decrease task space when a new collector has removed dynamically.
+ """
+ with self._commander_lock:
+ self._commander.decrease_collector_task_space()
+
+ def info(self, s: str) -> None:
+ r"""
+ Overview:
+            Log an info message with the coordinator uid as prefix
+        Arguments:
+            - s (:obj:`str`): the message to log at info level
+ """
+ self._logger.info('[Coordinator({})]: {}'.format(self._coordinator_uid, s))
+
+ def error(self, s: str) -> None:
+ r"""
+ Overview:
+            Log an error message with the coordinator uid as prefix
+        Arguments:
+            - s (:obj:`str`): the error message to log
+ """
+ self._logger.error('[Coordinator({})]: {}'.format(self._coordinator_uid, s))
+
+ def _record_task(self, task: dict):
+ r"""
+ Overview:
+ Create task state to record task
+ Arguments:
+ - task (:obj:`dict`): the task dict
+ """
+ self._task_state[task['task_id']] = TaskState(task['task_id'])
+
+ @property
+ def system_shutdown_flag(self) -> bool:
+ r"""
+ Overview:
+            Return whether the whole system should be shut down
+        Returns:
+            - system_shutdown_flag (:obj:`bool`): whether the whole system should be shut down
+ """
+ return self._system_shutdown_flag
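+
+
+# A rough sketch of a learner task's lifecycle through the handlers above (illustrative only;
+# in practice these calls are triggered by the comm modules when slave workers report back):
+#
+#     data = coordinator.deal_with_learner_get_data(task_id, buffer_id, batch_size, cur_learner_iter)
+#     coordinator.deal_with_learner_send_info(task_id, buffer_id, {'priority_info': ..., 'learner_step': ...})
+#     if coordinator.deal_with_learner_judge_finish(task_id, info):
+#         coordinator.deal_with_learner_finish_task(task_id, finished_task)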
diff --git a/DI-engine/ding/worker/coordinator/one_vs_one_parallel_commander.py b/DI-engine/ding/worker/coordinator/one_vs_one_parallel_commander.py
new file mode 100644
index 0000000000000000000000000000000000000000..82b85420cbfeac11e5e19ee7e2fbe51cd455c52a
--- /dev/null
+++ b/DI-engine/ding/worker/coordinator/one_vs_one_parallel_commander.py
@@ -0,0 +1,374 @@
+from typing import Optional
+import time
+import copy
+
+from ding.utils import deep_merge_dicts
+from ding.policy import create_policy
+from ding.utils import LimitedSpaceContainer, get_task_uid, build_logger, COMMANDER_REGISTRY
+from ding.league import create_league, OneVsOneLeague
+from .base_parallel_commander import BaseCommander
+
+
+@COMMANDER_REGISTRY.register('one_vs_one')
+class OneVsOneCommander(BaseCommander):
+ r"""
+ Overview:
+ Parallel commander for battle games.
+ Interface:
+ __init__, get_collector_task, get_learner_task, finish_collector_task, finish_learner_task,
+        notify_fail_collector_task, notify_fail_learner_task, update_learner_info
+ """
+ config = dict(
+ collector_task_space=2,
+ learner_task_space=1,
+ eval_interval=60,
+ )
+
+ def __init__(self, cfg: dict) -> None:
+ r"""
+ Overview:
+ Init the 1v1 commander according to config.
+ Arguments:
+ - cfg (:obj:`dict`): Dict type config file.
+ """
+ self._cfg = cfg
+ self._exp_name = cfg.exp_name
+ commander_cfg = self._cfg.policy.other.commander
+ self._commander_cfg = commander_cfg
+
+ self._collector_env_cfg = copy.deepcopy(self._cfg.env)
+ self._collector_env_cfg.pop('collector_episode_num')
+ self._collector_env_cfg.pop('evaluator_episode_num')
+ self._collector_env_cfg.manager.episode_num = self._cfg.env.collector_episode_num
+ self._evaluator_env_cfg = copy.deepcopy(self._cfg.env)
+ self._evaluator_env_cfg.pop('collector_episode_num')
+ self._evaluator_env_cfg.pop('evaluator_episode_num')
+ self._evaluator_env_cfg.manager.episode_num = self._cfg.env.evaluator_episode_num
+
+ self._collector_task_space = LimitedSpaceContainer(0, commander_cfg.collector_task_space)
+ self._learner_task_space = LimitedSpaceContainer(0, commander_cfg.learner_task_space)
+ self._learner_info = [{'learner_step': 0}]
+ # TODO accumulate collect info
+ self._collector_info = []
+ self._total_collector_env_step = 0
+ self._evaluator_info = []
+ self._current_buffer_id = None
+ self._current_policy_id = [] # 1v1 commander has multiple policies
+ self._last_eval_time = 0
+ # policy_cfg must be deepcopyed
+ policy_cfg = copy.deepcopy(self._cfg.policy)
+ self._policy = create_policy(policy_cfg, enable_field=['command']).command_mode
+ self._logger, self._tb_logger = build_logger(
+ "./{}/log/commander".format(self._exp_name), "commander", need_tb=True
+ )
+ self._collector_logger, _ = build_logger(
+ "./{}/log/commander".format(self._exp_name), "commander_collector", need_tb=False
+ )
+ self._evaluator_logger, _ = build_logger(
+ "./{}/log/commander".format(self._exp_name), "commander_evaluator", need_tb=False
+ )
+ self._sub_logger = {
+ 'collector': self._collector_logger,
+ 'evaluator': self._evaluator_logger,
+ }
+ self._end_flag = False
+
+ # League
+ path_policy = commander_cfg.path_policy
+ self._path_policy = path_policy
+ commander_cfg.league.path_policy = path_policy
+ commander_cfg.league = deep_merge_dicts(OneVsOneLeague.default_config(), commander_cfg.league)
+ self._league = create_league(commander_cfg.league)
+ self._active_player = self._league.active_players[0]
+ self._current_player_id = {}
+
+ def get_collector_task(self) -> Optional[dict]:
+ r"""
+ Overview:
+ Return the new collector task when there is residual task space; Otherwise return None.
+ Return:
+ - task (:obj:`Optional[dict]`): New collector task.
+ """
+ if self._end_flag:
+ return None
+ if self._collector_task_space.acquire_space():
+ if self._current_buffer_id is None or len(self._current_policy_id) == 0:
+ self._collector_task_space.release_space()
+ return None
+ cur_time = time.time()
+ if cur_time - self._last_eval_time > self._commander_cfg.eval_interval:
+ eval_flag = True
+ self._last_eval_time = time.time()
+ else:
+ eval_flag = False
+ collector_cfg = copy.deepcopy(self._cfg.policy.collect.collector)
+ info = self._learner_info[-1]
+ info['envstep'] = self._total_collector_env_step
+ collector_cfg.collect_setting = self._policy.get_setting_collect(info)
+ eval_or_collect = "EVALUATOR" if eval_flag else "COLLECTOR"
+ task_id = '{}_task_{}'.format(eval_or_collect.lower(), get_task_uid())
+ league_job_dict = self._league.get_job_info(self._active_player.player_id, eval_flag)
+ # `self._current_player_id`: For eval, [id1, id2]; For collect, [id1].
+ self._current_player_id[task_id] = league_job_dict['player_id']
+ collector_cfg.policy_update_path = league_job_dict['checkpoint_path']
+ collector_cfg.policy_update_flag = league_job_dict['player_active_flag']
+ collector_cfg.eval_flag = eval_flag
+ collector_cfg.exp_name = self._exp_name
+ if eval_flag:
+ collector_cfg.policy = copy.deepcopy([self._cfg.policy])
+ collector_cfg.env = self._evaluator_env_cfg
+ collector_cfg.env.eval_opponent = league_job_dict['eval_opponent']
+ else:
+ collector_cfg.policy = copy.deepcopy([self._cfg.policy for _ in range(2)])
+ collector_cfg.env = self._collector_env_cfg
+ collector_command = {
+ 'task_id': task_id,
+ 'buffer_id': self._current_buffer_id,
+ 'collector_cfg': collector_cfg,
+ }
+ # self._logger.info(
+ # "[{}] Task starts:\n{}".format(
+ # eval_or_collect, '\n'.join(
+ # [
+ # '{}: {}'.format(k, v) for k, v in collector_command.items()
+ # if k not in ['collector_cfg', 'policy']
+ # ]
+ # )
+ # )
+ # )
+ return collector_command
+ else:
+ # self._logger.info("[{}] Fails to start because of no launch space".format(eval_or_collect.upper()))
+ return None
+
+ def get_learner_task(self) -> Optional[dict]:
+ r"""
+ Overview:
+ Return the new learner task when there is residual task space; Otherwise return None.
+ Return:
+ - task (:obj:`Optional[dict]`): New learner task.
+ """
+ if self._end_flag:
+ return None
+ if self._learner_task_space.acquire_space():
+ learner_cfg = copy.deepcopy(self._cfg.policy.learn.learner)
+ learner_cfg.exp_name = self._exp_name
+ learner_command = {
+ 'task_id': 'learner_task_{}'.format(get_task_uid()),
+ 'policy_id': self._init_policy_id(),
+ 'buffer_id': self._init_buffer_id(),
+ 'learner_cfg': learner_cfg,
+ 'replay_buffer_cfg': self._cfg.policy.other.replay_buffer,
+ 'policy': copy.deepcopy(self._cfg.policy),
+ 'league_save_checkpoint_path': self._active_player.checkpoint_path,
+ }
+ # self._logger.info(
+ # "[LEARNER] Task starts:\n{}".format(
+ # '\n'.join(
+ # [
+ # '{}: {}'.format(k, v) for k, v in learner_command.items()
+ # if k not in ['learner_cfg', 'replay_buffer_cfg', 'policy']
+ # ]
+ # )
+ # )
+ # )
+ return learner_command
+ else:
+ # self._logger.info("[LEARNER] Fails to start because of no launch space")
+ return None
+
+ def finish_collector_task(self, task_id: str, finished_task: dict) -> bool:
+ r"""
+ Overview:
+ Get collector's finish_task_info and release collector_task_space.
+ If collector's task is evaluation, judge the convergence and return it.
+ Arguments:
+ - task_id (:obj:`str`): the collector task_id
+ - finished_task (:obj:`dict`): the finished task
+ Returns:
+ - convergence (:obj:`bool`): Whether the stop val is reached and the algorithm is converged. \
+ If True, the pipeline can be finished. It is only effective for an evaluator finish task.
+ """
+ self._collector_task_space.release_space()
+ if finished_task['eval_flag']:
+ self._evaluator_info.append(finished_task)
+ # Evaluate difficulty increment
+ wins, games = 0, 0
+ game_result = finished_task['game_result']
+ for i in game_result:
+ for j in i:
+ if j == "wins":
+ wins += 1
+ games += 1
+            eval_win = wins / games > 0.7
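+            # e.g. game_result = [['wins', 'wins'], ['wins', 'losses']] gives wins=3, games=4,
+            # i.e. a win rate of 0.75 > 0.7, so eval_win is True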
+ player_update_info = {
+ 'player_id': self._active_player.player_id,
+ 'eval_win': eval_win,
+ }
+ difficulty_inc = self._league.update_active_player(player_update_info)
+ is_hardest = eval_win and not difficulty_inc
+ # Print log
+ train_iter = self._learner_info[-1]['learner_step']
+ info = {
+ 'train_iter': train_iter,
+ 'episode_count': finished_task['real_episode_count'],
+ 'step_count': finished_task['step_count'],
+                'avg_step_per_episode': finished_task['avg_step_per_episode'],
+                'avg_time_per_step': finished_task['avg_time_per_step'],
+                'avg_time_per_episode': finished_task['avg_time_per_episode'],
+ 'reward_mean': finished_task['reward_mean'],
+ 'reward_std': finished_task['reward_std'],
+ 'game_result': finished_task['game_result'],
+ 'eval_win': eval_win,
+ 'difficulty_inc': difficulty_inc,
+ }
+ self._sub_logger['evaluator'].info(
+ "[EVALUATOR] Task ends:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))
+ )
+ for k, v in info.items():
+ if k in ['train_iter', 'game_result', 'eval_win', 'difficulty_inc']:
+ continue
+ self._tb_logger.add_scalar('evaluator_iter/' + k, v, train_iter)
+ self._tb_logger.add_scalar('evaluator_step/' + k, v, self._total_collector_env_step)
+ # If evaluator task ends, whether to stop training should be judged.
+ eval_stop_value = self._cfg.env.stop_value
+            self._logger.debug(
+                'eval_stop_value: {}, reward_mean: {}, eval_win: {}, difficulty_inc: {}'.format(
+                    eval_stop_value, finished_task['reward_mean'], eval_win, difficulty_inc
+                )
+            )
+ if eval_stop_value is not None and finished_task['reward_mean'] >= eval_stop_value and is_hardest:
+ self._logger.info(
+ "[DI-engine parallel pipeline] Current episode_return: {} is greater than the stop_value: {}".
+ format(finished_task['reward_mean'], eval_stop_value) + ", so the total training program is over."
+ )
+ self._end_flag = True
+ return True
+ else:
+ self._collector_info.append(finished_task)
+ self._total_collector_env_step += finished_task['step_count']
+ # If collector task ends, league payoff should be updated.
+ payoff_update_dict = {
+ 'player_id': self._current_player_id.pop(task_id),
+ 'result': finished_task['game_result'],
+ }
+ self._league.finish_job(payoff_update_dict)
+ # Print log
+ train_iter = self._learner_info[-1]['learner_step']
+ info = {
+ 'train_iter': train_iter,
+ 'episode_count': finished_task['real_episode_count'],
+ 'step_count': finished_task['step_count'],
+                'avg_step_per_episode': finished_task['avg_step_per_episode'],
+                'avg_time_per_step': finished_task['avg_time_per_step'],
+                'avg_time_per_episode': finished_task['avg_time_per_episode'],
+ 'reward_mean': finished_task['reward_mean'],
+ 'reward_std': finished_task['reward_std'],
+ 'game_result': finished_task['game_result'],
+ }
+ self._sub_logger['collector'].info(
+ "[COLLECTOR] Task ends:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()]))
+ )
+ for k, v in info.items():
+ if k in ['train_iter', 'game_result']:
+ continue
+ self._tb_logger.add_scalar('collector_iter/' + k, v, train_iter)
+ self._tb_logger.add_scalar('collector_step/' + k, v, self._total_collector_env_step)
+ return False
+ return False
+
+ def finish_learner_task(self, task_id: str, finished_task: dict) -> str:
+ r"""
+ Overview:
+ Get learner's finish_task_info, release learner_task_space, reset corresponding variables.
+ Arguments:
+ - task_id (:obj:`str`): Learner task_id
+ - finished_task (:obj:`dict`): Learner's finish_learn_info.
+ Returns:
+ - buffer_id (:obj:`str`): Buffer id of the finished learner.
+ """
+ self._learner_task_space.release_space()
+ buffer_id = finished_task['buffer_id']
+ self._current_buffer_id = None
+ self._current_policy_id = []
+ self._learner_info = [{'learner_step': 0}]
+ self._evaluator_info = []
+ self._last_eval_time = 0
+ self._current_player_id = {}
+ # self._logger.info("[LEARNER] Task ends.")
+ return buffer_id
+
+ def notify_fail_collector_task(self, task: dict) -> None:
+ r"""
+ Overview:
+ Release task space when collector task fails.
+ """
+ self._collector_task_space.release_space()
+ # self._logger.info("[COLLECTOR/EVALUATOR] Task fails.")
+
+ def notify_fail_learner_task(self, task: dict) -> None:
+ r"""
+ Overview:
+ Release task space when learner task fails.
+ """
+ self._learner_task_space.release_space()
+ # self._logger.info("[LEARNER] Task fails.")
+
+ def update_learner_info(self, task_id: str, info: dict) -> None:
+ r"""
+ Overview:
+ Get learner info dict, use it to update commander record and league record.
+ Arguments:
+ - task_id (:obj:`str`): Learner task_id
+ - info (:obj:`dict`): Dict type learner info.
+ """
+ self._learner_info.append(info)
+ player_update_info = {
+ 'player_id': self._active_player.player_id,
+ 'train_iteration': info['learner_step'],
+ }
+ self._league.update_active_player(player_update_info)
+ self._logger.info("[LEARNER] Update info at step {}".format(player_update_info['train_iteration']))
+ snapshot = self._league.judge_snapshot(self._active_player.player_id)
+ if snapshot:
+ self._logger.info(
+ "[LEAGUE] Player {} snapshot at step {}".format(
+ player_update_info['player_id'], player_update_info['train_iteration']
+ )
+ )
+
+ def _init_policy_id(self) -> str:
+ r"""
+ Overview:
+ Init the policy id and return it.
+ Returns:
+ - policy_id (:obj:`str`): New initialized policy id.
+ """
+ policy_id = 'policy_{}'.format(get_task_uid())
+ self._current_policy_id.append(policy_id)
+ assert len(self._current_policy_id) <= 2
+ return policy_id
+
+ def _init_buffer_id(self) -> str:
+ r"""
+ Overview:
+ Init the buffer id and return it.
+ Returns:
+ - buffer_id (:obj:`str`): New initialized buffer id.
+ """
+ buffer_id = 'buffer_{}'.format(get_task_uid())
+        self._current_buffer_id = buffer_id  # TODO: clarify why 1v1 keeps two policies but only one shared buffer
+ return buffer_id
+
+ def increase_collector_task_space(self):
+ r""""
+ Overview:
+ Increase task space when a new collector has added dynamically.
+ """
+ self._collector_task_space.increase_space()
+
+ def decrease_collector_task_space(self):
+ r""""
+ Overview:
+ Decrease task space when a new collector has removed dynamically.
+ """
+ self._collector_task_space.decrease_space()
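+
+
+# A minimal sketch of how the coordinator is expected to drive this commander (illustrative only;
+# real task dicts are produced and consumed by the coordinator and comm modules, not built by hand):
+#
+#     commander = OneVsOneCommander(cfg)           # cfg: the compiled parallel config (main part)
+#     learner_task = commander.get_learner_task()  # must come first, it creates the policy/buffer ids
+#     eval_task = commander.get_collector_task()   # the first collector task is an evaluator task
+#     converged = commander.finish_collector_task(eval_task['task_id'], finished_task)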
diff --git a/DI-engine/ding/worker/coordinator/operator_server.py b/DI-engine/ding/worker/coordinator/operator_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d2152336b15dc411ca67049a25df1c502987dde
--- /dev/null
+++ b/DI-engine/ding/worker/coordinator/operator_server.py
@@ -0,0 +1,96 @@
+from typing import Optional, Mapping, Any
+from requests.exceptions import RequestException
+from ding.interaction.base import get_http_engine_class, get_values_from_response
+
+
+class OperatorServer:
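+    r"""
+    Overview:
+        Lightweight HTTP client used by the coordinator to talk to an operator server that manages
+        collector/learner replicas (e.g. when DI-engine runs on a cluster), by wrapping the
+        ``/replicas`` related endpoints.
+    """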
+
+ def __init__(
+ self,
+ host: str,
+ port: Optional[int] = None,
+ api_version: str = "v1alpha1",
+ https: bool = False,
+            namespace: Optional[str] = None,
+            name: Optional[str] = None,
+ ):
+ # request part
+ self.__http_engine = get_http_engine_class(headers={})()(host, port, https)
+ self.__api_version = api_version
+ self.__namespace = namespace
+ self.__my_name = name
+ self.__worker_type = None
+
+ @property
+ def api_version(self):
+ return self.__api_version
+
+ def set_worker_type(self, type):
+ assert type in ['coordinator', 'aggregator'], "invalid worker_type: {}".format(type)
+ self.__worker_type = type
+
+ def __prefix_with_api_version(self, path):
+ return self.__api_version + path
+
+ def get_replicas(self, name: str = None):
+ try:
+ if name is None:
+ assert self.__worker_type, "set worker type first"
+ params = {"namespace": self.__namespace, self.__worker_type: self.__my_name}
+ else:
+ params = {"namespace": self.__namespace, "name": name}
+ response = self.__http_engine.request('GET', self.__prefix_with_api_version('/replicas'), params=params)
+ except RequestException as err:
+ return self._error_request(err)
+ else:
+ return self._after_request(*get_values_from_response(response))
+
+ def post_replicas(self, data):
+ try:
+ data.update({"namespace": self.__namespace, "coordinator": self.__my_name})
+ response = self.__http_engine.request('POST', self.__prefix_with_api_version('/replicas'), data=data)
+ except RequestException as err:
+ return self._error_request(err)
+ else:
+ return self._after_request(*get_values_from_response(response))
+
+    def post_replicas_failed(self, collectors: Optional[list] = None, learners: Optional[list] = None):
+        try:
+            data = {
+                "namespace": self.__namespace,
+                "coordinator": self.__my_name,
+                "collectors": collectors if collectors is not None else [],
+                "learners": learners if learners is not None else [],
+            }
+ response = self.__http_engine.request('POST', self.__prefix_with_api_version('/replicas/failed'), data=data)
+ except RequestException as err:
+ return self._error_request(err)
+ else:
+ return self._after_request(*get_values_from_response(response))
+
+ def delete_replicas(self, n_collectors=0, n_learners=0):
+ try:
+ data = {
+ "namespace": self.__namespace,
+ "coordinator": self.__my_name,
+ "collectors": {
+ "replicas": n_collectors,
+ },
+ "learners": {
+ "replicas": n_learners,
+ }
+ }
+ response = self.__http_engine.request('DELETE', self.__prefix_with_api_version('/replicas'), data=data)
+ except RequestException as err:
+ return self._error_request(err)
+ else:
+ return self._after_request(*get_values_from_response(response))
+
+ def _after_request(
+ self, status_code: int, success: bool, code: int, message: Optional[str], data: Optional[Mapping[str, Any]]
+ ) -> Any:
+ return success, code, message, data
+
+ def _error_request(self, error: RequestException) -> Any:
+        # re-raise the original exception so the caller sees the real error message
+        raise error
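+
+
+# A minimal usage sketch (assuming an operator service is reachable at the given address;
+# the host, port and replica numbers below are illustrative):
+#
+#     server = OperatorServer(host='127.0.0.1', port=8080, namespace='di-job', name='coordinator-0')
+#     server.set_worker_type('coordinator')
+#     success, code, message, data = server.post_replicas(
+#         {'collectors': {'replicas': 2}, 'learners': {'replicas': 1}}
+#     )
+#     if success:
+#         success, code, message, data = server.get_replicas()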
diff --git a/DI-engine/ding/worker/coordinator/resource_manager.py b/DI-engine/ding/worker/coordinator/resource_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..e81227345046c3387bafe57f51f7ca3bcc8dc923
--- /dev/null
+++ b/DI-engine/ding/worker/coordinator/resource_manager.py
@@ -0,0 +1,71 @@
+from typing import Optional
+import random
+
+
+class NaiveResourceManager(object):
+ r"""
+ Overview:
+ the naive resource manager
+ Interface:
+        __init__, assign_collector, assign_learner, have_assigned, delete, update
+ """
+
+ def __init__(self) -> None:
+ r"""
+ Overview:
+            Initialize the resource manager
+ """
+ self._worker_type = ['collector', 'learner']
+ self._resource_info = {k: {} for k in self._worker_type}
+
+    def assign_collector(self, collector_task: dict) -> Optional[dict]:
+        r"""
+        Overview:
+            Assign the collector task to a randomly selected available collector and return its resource info.
+        Arguments:
+            - collector_task (:obj:`dict`): the collector task to assign
+        Returns:
+            - assignment (:obj:`Optional[dict]`): a dict with ``collector_id`` and ``resource_info``, \
+                or None if no collector is currently available
+        """
+ available_collector_list = list(self._resource_info['collector'].keys())
+ if len(available_collector_list) > 0:
+ selected_collector = random.sample(available_collector_list, 1)[0]
+ info = self._resource_info['collector'].pop(selected_collector)
+ return {'collector_id': selected_collector, 'resource_info': info}
+ else:
+ return None
+
+    def assign_learner(self, learner_task: dict) -> Optional[dict]:
+        r"""
+        Overview:
+            Assign the learner task to a randomly selected available learner and return its resource info.
+        Arguments:
+            - learner_task (:obj:`dict`): the learner task to assign
+        Returns:
+            - assignment (:obj:`Optional[dict]`): a dict with ``learner_id`` and ``resource_info``, \
+                or None if no learner is currently available
+        """
+ available_learner_list = list(self._resource_info['learner'].keys())
+ if len(available_learner_list) > 0:
+ selected_learner = random.sample(available_learner_list, 1)[0]
+ info = self._resource_info['learner'].pop(selected_learner)
+ return {'learner_id': selected_learner, 'resource_info': info}
+ else:
+ return None
+
+    def have_assigned(self, name: str, worker_id: str) -> bool:
+ assert name in self._worker_type, "invalid worker_type: {}".format(name)
+ if name == 'collector':
+ return worker_id in self._resource_info['collector']
+ elif name == 'learner':
+ return worker_id in self._resource_info['learner']
+
+    def delete(self, name: str, worker_id: str) -> bool:
+        assert name in self._worker_type, "invalid worker_type: {}".format(name)
+        if worker_id in self._resource_info[name]:
+            self._resource_info[name].pop(worker_id)
+ return True
+ else:
+ return False
+
+ def update(self, name: str, worker_id: str, resource_info: dict) -> None:
+ r"""
+ Overview:
+            Update the resource info of the given worker
+ """
+ assert name in self._worker_type, "invalid worker_type: {}".format(name)
+ self._resource_info[name][worker_id] = resource_info
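+
+
+# A minimal usage sketch (the worker ids and the resource dict below are illustrative):
+#
+#     manager = NaiveResourceManager()
+#     manager.update('collector', 'collector_0', {'cpu': 8, 'gpu': 0})
+#     assignment = manager.assign_collector({'task_id': 'collector_task_xxx'})
+#     # -> {'collector_id': 'collector_0', 'resource_info': {'cpu': 8, 'gpu': 0}}, or None if none registered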
diff --git a/DI-engine/ding/worker/coordinator/solo_parallel_commander.py b/DI-engine/ding/worker/coordinator/solo_parallel_commander.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab374ebbc9903fe4d8c3783667c48fe4c55d8474
--- /dev/null
+++ b/DI-engine/ding/worker/coordinator/solo_parallel_commander.py
@@ -0,0 +1,264 @@
+from typing import Optional
+import time
+import copy
+
+from ding.policy import create_policy
+from ding.utils import LimitedSpaceContainer, get_task_uid, build_logger, COMMANDER_REGISTRY
+from .base_parallel_commander import BaseCommander
+
+
+@COMMANDER_REGISTRY.register('solo')
+class SoloCommander(BaseCommander):
+ r"""
+ Overview:
+ Parallel commander for solo games.
+ Interface:
+ __init__, get_collector_task, get_learner_task, finish_collector_task, finish_learner_task,
+ notify_fail_collector_task, notify_fail_learner_task, update_learner_info
+ """
+ config = dict(
+ collector_task_space=1,
+ learner_task_space=1,
+ eval_interval=60,
+ )
+
+ def __init__(self, cfg: dict) -> None:
+ r"""
+ Overview:
+ Init the solo commander according to config.
+ Arguments:
+ - cfg (:obj:`dict`): Dict type config file.
+ """
+ self._cfg = cfg
+ self._exp_name = cfg.exp_name
+ commander_cfg = self._cfg.policy.other.commander
+ self._commander_cfg = commander_cfg
+
+ self._collector_env_cfg = copy.deepcopy(self._cfg.env)
+ self._collector_env_cfg.pop('collector_episode_num')
+ self._collector_env_cfg.pop('evaluator_episode_num')
+ self._collector_env_cfg.manager.episode_num = self._cfg.env.collector_episode_num
+ self._evaluator_env_cfg = copy.deepcopy(self._cfg.env)
+ self._evaluator_env_cfg.pop('collector_episode_num')
+ self._evaluator_env_cfg.pop('evaluator_episode_num')
+ self._evaluator_env_cfg.manager.episode_num = self._cfg.env.evaluator_episode_num
+
+ self._collector_task_space = LimitedSpaceContainer(0, commander_cfg.collector_task_space)
+ self._learner_task_space = LimitedSpaceContainer(0, commander_cfg.learner_task_space)
+ self._learner_info = [{'learner_step': 0}]
+ # TODO(nyz) accumulate collect info
+ self._collector_info = []
+ self._total_collector_env_step = 0
+ self._evaluator_info = []
+ self._current_buffer_id = None
+ self._current_policy_id = None
+ self._last_eval_time = 0
+ # policy_cfg must be deepcopyed
+ policy_cfg = copy.deepcopy(self._cfg.policy)
+ self._policy = create_policy(policy_cfg, enable_field=['command']).command_mode
+ self._logger, self._tb_logger = build_logger(
+ "./{}/log/commander".format(self._exp_name), "commander", need_tb=True
+ )
+ self._collector_logger, _ = build_logger(
+ "./{}/log/commander".format(self._exp_name), "commander_collector", need_tb=False
+ )
+ self._evaluator_logger, _ = build_logger(
+ "./{}/log/commander".format(self._exp_name), "commander_evaluator", need_tb=False
+ )
+ self._sub_logger = {
+ 'collector': self._collector_logger,
+ 'evaluator': self._evaluator_logger,
+ }
+ self._end_flag = False
+
+ def get_collector_task(self) -> Optional[dict]:
+ r"""
+ Overview:
+ Return the new collector task when there is residual task space; Otherwise return None.
+ Return:
+ - task (:obj:`Optional[dict]`): New collector task.
+ """
+ if self._end_flag:
+ return None
+ if self._collector_task_space.acquire_space():
+ if self._current_buffer_id is None or self._current_policy_id is None:
+ self._collector_task_space.release_space()
+ return None
+ cur_time = time.time()
+ if cur_time - self._last_eval_time > self._commander_cfg.eval_interval:
+ eval_flag = True
+ self._last_eval_time = time.time()
+ else:
+ eval_flag = False
+ collector_cfg = copy.deepcopy(self._cfg.policy.collect.collector)
+ # the newest info
+ info = self._learner_info[-1]
+ info['envstep'] = self._total_collector_env_step
+ collector_cfg.collect_setting = self._policy.get_setting_collect(info)
+ collector_cfg.policy_update_path = self._current_policy_id
+ collector_cfg.eval_flag = eval_flag
+ collector_cfg.policy = copy.deepcopy(self._cfg.policy)
+ collector_cfg.exp_name = self._exp_name
+ if eval_flag:
+ collector_cfg.env = self._evaluator_env_cfg
+ else:
+ collector_cfg.env = self._collector_env_cfg
+ return {
+ 'task_id': 'collector_task_{}'.format(get_task_uid()),
+ 'buffer_id': self._current_buffer_id,
+ 'collector_cfg': collector_cfg,
+ }
+ else:
+ return None
+
+ def get_learner_task(self) -> Optional[dict]:
+ r"""
+ Overview:
+ Return the new learner task when there is residual task space; Otherwise return None.
+ Return:
+ - task (:obj:`Optional[dict]`): New learner task.
+ """
+ if self._end_flag:
+ return None
+ if self._learner_task_space.acquire_space():
+ learner_cfg = copy.deepcopy(self._cfg.policy.learn.learner)
+ learner_cfg.exp_name = self._exp_name
+ return {
+ 'task_id': 'learner_task_{}'.format(get_task_uid()),
+ 'policy_id': self._init_policy_id(),
+ 'buffer_id': self._init_buffer_id(),
+ 'learner_cfg': learner_cfg,
+ 'replay_buffer_cfg': copy.deepcopy(self._cfg.policy.other.replay_buffer),
+ 'policy': copy.deepcopy(self._cfg.policy),
+ }
+ else:
+ return None
+
+ def finish_collector_task(self, task_id: str, finished_task: dict) -> bool:
+ r"""
+ Overview:
+ Get collector's finish_task_info and release collector_task_space.
+ If collector's task is evaluation, judge the convergence and return it.
+ Arguments:
+ - task_id (:obj:`str`): the collector task_id
+ - finished_task (:obj:`dict`): the finished task
+ Returns:
+ - convergence (:obj:`bool`): Whether the stop val is reached and the algorithm is converged. \
+ If True, the pipeline can be finished.
+ """
+ self._collector_task_space.release_space()
+ evaluator_or_collector = "evaluator" if finished_task['eval_flag'] else "collector"
+ train_iter = finished_task['train_iter']
+ info = {
+ 'train_iter': train_iter,
+ 'episode_count': finished_task['real_episode_count'],
+ 'step_count': finished_task['step_count'],
+            'avg_step_per_episode': finished_task['avg_step_per_episode'],
+            'avg_time_per_step': finished_task['avg_time_per_step'],
+            'avg_time_per_episode': finished_task['avg_time_per_episode'],
+ 'reward_mean': finished_task['reward_mean'],
+ 'reward_std': finished_task['reward_std'],
+ }
+ self._sub_logger[evaluator_or_collector].info(
+ "[{}] Task ends:\n{}".format(
+ evaluator_or_collector.upper(), '\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])
+ )
+ )
+ for k, v in info.items():
+ if k in ['train_iter']:
+ continue
+ self._tb_logger.add_scalar('{}_iter/'.format(evaluator_or_collector) + k, v, train_iter)
+ self._tb_logger.add_scalar('{}_step/'.format(evaluator_or_collector) + k, v, self._total_collector_env_step)
+ if finished_task['eval_flag']:
+ self._evaluator_info.append(finished_task)
+ eval_stop_value = self._cfg.env.stop_value
+ if eval_stop_value is not None and finished_task['reward_mean'] >= eval_stop_value:
+ self._logger.info(
+ "[DI-engine parallel pipeline] current episode_return: {} is greater than the stop_value: {}".
+ format(finished_task['reward_mean'], eval_stop_value) + ", so the total training program is over."
+ )
+ self._end_flag = True
+                return True
+            return False
+ else:
+ self._collector_info.append(finished_task)
+ self._total_collector_env_step += finished_task['step_count']
+ return False
+
+ def finish_learner_task(self, task_id: str, finished_task: dict) -> str:
+ r"""
+ Overview:
+ Get learner's finish_task_info, release learner_task_space, reset corresponding variables.
+ Arguments:
+ - task_id (:obj:`str`): Learner task_id
+ - finished_task (:obj:`dict`): Learner's finish_learn_info.
+ Returns:
+ - buffer_id (:obj:`str`): Buffer id of the finished learner.
+ """
+ self._learner_task_space.release_space()
+ buffer_id = finished_task['buffer_id']
+ self._current_buffer_id = None
+ self._current_policy_id = None
+ self._learner_info = [{'learner_step': 0}]
+ self._evaluator_info = []
+ self._last_eval_time = 0
+ return buffer_id
+
+ def notify_fail_collector_task(self, task: dict) -> None:
+ r"""
+ Overview:
+ Release task space when collector task fails.
+ """
+ self._collector_task_space.release_space()
+
+ def notify_fail_learner_task(self, task: dict) -> None:
+ r"""
+ Overview:
+ Release task space when learner task fails.
+ """
+ self._learner_task_space.release_space()
+
+ def update_learner_info(self, task_id: str, info: dict) -> None:
+ r"""
+ Overview:
+ Append the info to learner_info:
+ Arguments:
+ - task_id (:obj:`str`): Learner task_id
+ - info (:obj:`dict`): Dict type learner info.
+ """
+ self._learner_info.append(info)
+
+ def _init_policy_id(self) -> str:
+ r"""
+ Overview:
+ Init the policy id and return it.
+ Returns:
+ - policy_id (:obj:`str`): New initialized policy id.
+ """
+ policy_id = 'policy_{}'.format(get_task_uid())
+ self._current_policy_id = policy_id
+ return policy_id
+
+ def _init_buffer_id(self) -> str:
+ r"""
+ Overview:
+ Init the buffer id and return it.
+ Returns:
+ - buffer_id (:obj:`str`): New initialized buffer id.
+ """
+ buffer_id = 'buffer_{}'.format(get_task_uid())
+ self._current_buffer_id = buffer_id
+ return buffer_id
+
+ def increase_collector_task_space(self):
+ r""""
+ Overview:
+ Increase task space when a new collector has added dynamically.
+ """
+ self._collector_task_space.increase_space()
+
+ def decrease_collector_task_space(self):
+ r""""
+ Overview:
+ Decrease task space when a new collector has removed dynamically.
+ """
+ self._collector_task_space.decrease_space()
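+
+
+# A minimal sketch of the task scheduling behaviour above (illustrative only; real task dicts are
+# handled by the coordinator):
+#
+#     commander = SoloCommander(cfg)               # cfg: the compiled parallel config (main part)
+#     learner_task = commander.get_learner_task()  # creates the policy_id / buffer_id first
+#     task = commander.get_collector_task()        # evaluator-style task if eval_interval seconds have
+#                                                  # elapsed since the last evaluation, otherwise a
+#                                                  # plain collector task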
diff --git a/DI-engine/ding/worker/coordinator/tests/conftest.py b/DI-engine/ding/worker/coordinator/tests/conftest.py
new file mode 100644
index 0000000000000000000000000000000000000000..96c9ce67e8ea1cd8ba0e73fad9cd48a39d9a9231
--- /dev/null
+++ b/DI-engine/ding/worker/coordinator/tests/conftest.py
@@ -0,0 +1,108 @@
+import pytest
+from easydict import EasyDict
+
+from ding.config import compile_config_parallel
+from ding.worker.coordinator.one_vs_one_parallel_commander import OneVsOneCommander
+
+
+@pytest.fixture(scope='function')
+def setup_1v1commander():
+ nstep = 1
+ eval_interval = 5
+ main_config = dict(
+ exp_name='one_vs_one_test',
+ env=dict(
+ collector_env_num=8,
+ collector_episode_num=2,
+ evaluator_env_num=5,
+ evaluator_episode_num=1,
+ stop_value=20,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=3,
+ encoder_kwargs=dict(encoder_type='conv2d'),
+ ),
+ nstep=nstep,
+ learn=dict(
+ batch_size=32,
+ learning_rate=0.0001,
+ weight_decay=0.,
+ algo=dict(
+ target_update_freq=500,
+ discount_factor=0.99,
+ nstep=nstep,
+ ),
+ learner=dict(
+ learner_num=1,
+ send_policy_freq=1,
+ ),
+ ),
+ collect=dict(
+ traj_len=15,
+ algo=dict(nstep=nstep),
+ collector=dict(
+ collector_num=2,
+ update_policy_second=3,
+ ),
+ ),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1.,
+ end=0.005,
+ decay=1000000,
+ ),
+ commander=dict(
+ collector_task_space=2,
+ learner_task_space=1,
+ eval_interval=eval_interval,
+ league=dict(naive_sp_player=dict(one_phase_step=1000, ), ),
+ ),
+ replay_buffer=dict(),
+ ),
+ ),
+ )
+ main_config = EasyDict(main_config)
+ create_config = dict(
+ env=dict(
+            # 1v1 commander should use "competitive_rl".
+            # However, because this env is hard to install, we use "cartpole" instead.
+            # The commander does not need a real env anyway; the env config is only kept so that
+            # `compile_config_parallel` can run.
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='dqn_command'),
+ learner=dict(type='base', import_names=['ding.worker.learner.base_learner']),
+ collector=dict(
+ type='zergling',
+ import_names=['ding.worker.collector.zergling_parallel_collector'],
+ ),
+ commander=dict(
+ type='one_vs_one',
+ import_names=['ding.worker.coordinator.one_vs_one_parallel_commander'],
+ ),
+ comm_learner=dict(
+ type='flask_fs',
+ import_names=['ding.worker.learner.comm.flask_fs_learner'],
+ ),
+ comm_collector=dict(
+ type='flask_fs',
+ import_names=['ding.worker.collector.comm.flask_fs_collector'],
+ ),
+ league=dict(type='one_vs_one'),
+ )
+ system_config = dict(
+ coordinator=dict(),
+ path_data='./data',
+ path_policy='./policy',
+ communication_mode='auto',
+ learner_gpu_num=1,
+ )
+ system_config = EasyDict(system_config)
+ create_config = EasyDict(create_config)
+ config = compile_config_parallel(main_config, create_cfg=create_config, system_cfg=system_config)
+ return OneVsOneCommander(config['main'])
diff --git a/DI-engine/ding/worker/coordinator/tests/test_coordinator.py b/DI-engine/ding/worker/coordinator/tests/test_coordinator.py
new file mode 100644
index 0000000000000000000000000000000000000000..9beeffd77c085de06b3019be7c7e2b9612b568c9
--- /dev/null
+++ b/DI-engine/ding/worker/coordinator/tests/test_coordinator.py
@@ -0,0 +1,73 @@
+import pytest
+import os
+import time
+from ding.worker import Coordinator
+from ding.worker.learner.comm import NaiveLearner
+from ding.worker.collector.comm import NaiveCollector
+from ding.utils import find_free_port
+from ding.config import compile_config_parallel
+from ding.config.utils import parallel_test_main_config, parallel_test_create_config, parallel_test_system_config
+
+DATA_PREFIX = 'SLAVE_COLLECTOR_DATA_COORDINATOR_TEST'
+
+
+@pytest.fixture(scope='function')
+def setup_config():
+ return compile_config_parallel(
+ parallel_test_main_config, create_cfg=parallel_test_create_config, system_cfg=parallel_test_system_config
+ )
+
+
+@pytest.fixture(scope='function')
+def setup_collector(setup_config):
+ cfg = setup_config.system.coordinator.collector
+ collector = {}
+ for _, (name, host, port) in cfg.items():
+ collector[name] = NaiveCollector(host, port, prefix=DATA_PREFIX)
+ collector[name].start()
+ yield collector
+ for a in collector.values():
+ a.close()
+
+
+@pytest.fixture(scope='function')
+def setup_learner(setup_config):
+ cfg = setup_config.system.coordinator.learner
+ learner = {}
+ for _, (name, host, port) in cfg.items():
+ learner[name] = NaiveLearner(host, port, prefix=DATA_PREFIX)
+ learner[name].start()
+ yield learner
+ for l in learner.values():
+ l.close()
+
+
+@pytest.mark.unittest(rerun=5)
+class TestCoordinator:
+
+ def test_naive(self, setup_config, setup_collector, setup_learner):
+ os.popen('rm -rf {}*'.format(DATA_PREFIX))
+ assert len(setup_collector) == len(setup_config.system.coordinator.collector)
+ assert len(setup_learner) == len(setup_config.system.coordinator.learner)
+ try:
+ coordinator = Coordinator(setup_config)
+ coordinator.start()
+ while True:
+ if coordinator._commander._learner_task_finish_count == 1:
+ break
+ time.sleep(0.5)
+ coordinator.close()
+ except Exception as e:
+ os.popen('rm -rf {}*'.format(DATA_PREFIX))
+ assert False, e
+
+ collector_task_ids = [t for t in coordinator._historical_task if 'collector' in t]
+ for i in range(1, 21):
+ for t in collector_task_ids:
+ assert os.path.exists('{}_{}_{}'.format(DATA_PREFIX, t, i))
+ assert os.path.exists('{}_final_model.pth'.format(DATA_PREFIX))
+ assert len(coordinator._replay_buffer) == 0
+ learner_task_ids = [i for i in coordinator._historical_task if 'learner' in i]
+ for i in learner_task_ids:
+ assert len(coordinator._commander._learner_info[i]) == 5
+ os.popen('rm -rf {}*'.format(DATA_PREFIX))
diff --git a/DI-engine/ding/worker/coordinator/tests/test_fake_operator_server.py b/DI-engine/ding/worker/coordinator/tests/test_fake_operator_server.py
new file mode 100644
index 0000000000000000000000000000000000000000..20e0ac5a7bf7ed99145c7eafa8c3a6dbfec28602
--- /dev/null
+++ b/DI-engine/ding/worker/coordinator/tests/test_fake_operator_server.py
@@ -0,0 +1,178 @@
+import pytest
+import os
+import copy
+import time
+from threading import Thread
+import json
+from queue import Queue
+from flask import Flask, request
+
+from ding.worker import Coordinator
+from ding.worker.learner.comm import NaiveLearner
+from ding.worker.collector.comm import NaiveCollector
+from ding.utils import find_free_port
+from ding.config import compile_config_parallel
+from ding.config.utils import parallel_test_main_config, parallel_test_create_config, parallel_test_system_config
+
+DATA_PREFIX = 'SLAVE_COLLECTOR_DATA_FAKE_OPERATOR_TEST'
+init_replicas_request = {
+ "collectors": {
+ "cpu": "0.5",
+ "memory": "200Mi",
+ "replicas": 2,
+ },
+ "learners": {
+ "cpu": "0.5",
+ "memory": "200Mi",
+ "gpu": "0",
+ "replicas": 1,
+ },
+}
+api_version = 'v1alpha1'
+system_addr = 'https://0.0.0.0:14502'
+
+
+def create_app(creator):
+ app = Flask(__name__)
+
+ @app.route('/{}/replicas'.format(api_version), methods=['POST'])
+ def post_replicas():
+ data = json.loads(request.data.decode())
+ collectors = data['collectors']["replicas"]
+ learners = data['learners']["replicas"]
+ creator.set_target_source(learners, collectors)
+ return {'success': True, 'code': 0, 'message': '', 'data': ''}
+
+ @app.route('/{}/replicas'.format(api_version), methods=['GET'])
+ def get_replicas():
+ data = json.loads(request.data.decode())
+ return {'success': True, 'code': 0, 'message': '', 'data': creator.current_resource}
+
+ return app
+
+
+@pytest.fixture(scope='function')
+def setup_config():
+ cfg = compile_config_parallel(
+ parallel_test_main_config, create_cfg=parallel_test_create_config, system_cfg=parallel_test_system_config
+ )
+ cfg.system.coordinator.operator_server = dict(
+ system_addr=system_addr,
+ api_version=api_version,
+ init_replicas_request=init_replicas_request,
+ collector_target_num=len(cfg.system.coordinator.collector),
+ learner_target_num=len(cfg.system.coordinator.learner),
+ )
+ return cfg
+
+
+class Creator:
+
+ def __init__(self, learner_addr, collector_addr):
+ self.learner_addr = learner_addr
+ self.collector_addr = collector_addr
+ self.collector_demand = Queue()
+ self.learner_demand = Queue()
+ self.learners = {}
+ self.collectors = {}
+ self.end_flag = False
+
+ def set_target_source(self, learner_target, collector_target):
+ print('set_target_source', learner_target, collector_target)
+ time.sleep(3) # simulate
+ self.collector_demand.put(collector_target)
+ self.learner_demand.put(learner_target)
+
+ def start(self):
+ while not self.end_flag:
+ if self.learner_demand.empty() and self.collector_demand.empty():
+ time.sleep(0.1)
+ continue
+ else:
+ learner_demand, collector_demand = None, None
+ if not self.learner_demand.empty():
+ learner_demand = self.learner_demand.get()
+ if not self.collector_demand.empty():
+ collector_demand = self.collector_demand.get()
+
+                for i in range(collector_demand or 0):
+ name, host, port = self.collector_addr[i]
+ self.collectors[name] = NaiveCollector(host, port, prefix=DATA_PREFIX)
+ self.collectors[name].start()
+                for i in range(learner_demand or 0):
+ name, host, port = self.learner_addr[i]
+ self.learners[name] = NaiveLearner(host, port, prefix=DATA_PREFIX)
+ self.learners[name].start()
+
+ def close(self):
+ self.end_flag = True
+ time.sleep(1)
+ for t in self.learners.values():
+ t.close()
+ for t in self.collectors.values():
+ t.close()
+
+ @property
+ def current_resource(self):
+ collectors = {k: {} for k in self.collectors}
+ learners = {k: {} for k in self.learners}
+ return {"collectors": collectors, 'learners': learners}
+
+
+@pytest.fixture(scope='function')
+def setup_operator_server(setup_config):
+ host, port = system_addr.split("https://")[1].split(":")
+ port = int(port)
+ learner_addr = copy.deepcopy(setup_config.system.coordinator.learner)
+ learner_addr = list(learner_addr.values())
+ for i in range(len(learner_addr)):
+ learner_addr[i][0] = '{}:{}'.format(learner_addr[i][1], learner_addr[i][2])
+ collector_addr = copy.deepcopy(setup_config.system.coordinator.collector)
+ collector_addr = list(collector_addr.values())
+ for i in range(len(collector_addr)):
+ collector_addr[i][0] = '{}:{}'.format(collector_addr[i][1], collector_addr[i][2])
+ print(learner_addr, collector_addr)
+
+ creator = Creator(learner_addr, collector_addr)
+ creator_start_thread = Thread(target=creator.start, args=(), daemon=True)
+ creator_start_thread.start()
+
+ app = create_app(creator)
+ app_run_thread = Thread(target=app.run, args=(host, port), daemon=True)
+ app_run_thread.start()
+ yield app
+ creator.close()
+ print('end')
+
+
+@pytest.mark.unittest
+class TestCoordinatorFakeOperator:
+
+ def test_naive(self, setup_config, setup_operator_server):
+ os.popen('rm -rf {}*'.format(DATA_PREFIX))
+ # learner/collector is created by operator-server
+ setup_config.system.coordinator.learner = {}
+ setup_config.system.coordinator.collector = {}
+
+ try:
+ coordinator = Coordinator(setup_config)
+ coordinator.start()
+ while True:
+ if coordinator._commander._learner_task_finish_count == 1:
+ break
+ time.sleep(0.5)
+ coordinator.close()
+ except Exception as e:
+ os.popen('rm -rf {}*'.format(DATA_PREFIX))
+ assert False, e
+
+ collector_task_ids = [t for t in coordinator._historical_task if 'collector' in t]
+ for i in range(1, 21):
+ for t in collector_task_ids:
+ assert os.path.exists('{}_{}_{}'.format(DATA_PREFIX, t, i))
+ assert os.path.exists('{}_final_model.pth'.format(DATA_PREFIX))
+ assert len(coordinator._replay_buffer) == 0
+ learner_task_ids = [i for i in coordinator._historical_task if 'learner' in i]
+ for i in learner_task_ids:
+ assert len(coordinator._commander._learner_info[i]) == 5
+ os.popen('rm -rf {}*'.format(DATA_PREFIX))
diff --git a/DI-engine/ding/worker/coordinator/tests/test_one_vs_one_commander.py b/DI-engine/ding/worker/coordinator/tests/test_one_vs_one_commander.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f54f232a1f2b491fb4ea84458f4cc3847fe1086
--- /dev/null
+++ b/DI-engine/ding/worker/coordinator/tests/test_one_vs_one_commander.py
@@ -0,0 +1,144 @@
+import time
+import pytest
+import os
+
+
+@pytest.mark.unittest
+class Test1v1Commander:
+
+ def test_init(self, setup_1v1commander):
+ # basic
+ assert not setup_1v1commander._end_flag
+ # task space
+ assert setup_1v1commander._collector_task_space.cur == setup_1v1commander._collector_task_space.min_val == 0
+ assert setup_1v1commander._collector_task_space.max_val == 2
+ assert setup_1v1commander._learner_task_space.cur == setup_1v1commander._learner_task_space.min_val == 0
+ assert setup_1v1commander._learner_task_space.max_val == 1
+ # league
+ league = setup_1v1commander._league
+ active_players = league.active_players
+ assert len(active_players) == 1
+ active_player = active_players[0]
+ assert active_player.player_id == setup_1v1commander._active_player.player_id
+ # policy
+ assert 'eps' in setup_1v1commander._policy.get_setting_collect({'learner_step': 100, 'envstep': 10000})
+
+ def test_get_task(self, setup_1v1commander):
+        # Must get the learner task first, then the collector task.
+ assert setup_1v1commander.get_collector_task() is None
+
+ # Get learner task
+ learner_task_info = setup_1v1commander.get_learner_task()
+ assert setup_1v1commander._learner_task_space.cur == 1
+ learner_task_id = learner_task_info['task_id']
+ assert learner_task_id.startswith('learner_task_'), learner_task_info['task_id']
+ assert len(setup_1v1commander._current_policy_id) == 1
+ assert learner_task_info['policy_id'] == setup_1v1commander._current_policy_id[0]
+ assert learner_task_info['buffer_id'] == setup_1v1commander._current_buffer_id
+ assert setup_1v1commander.get_learner_task() is None
+
+ # Get evaluator task
+        # Only after the evaluator task is finished can a collector task be obtained.
+ evaluator_task_info = setup_1v1commander.get_collector_task()
+ assert setup_1v1commander._collector_task_space.cur == 1
+ evaluator_task_id = evaluator_task_info['task_id']
+ assert evaluator_task_id.startswith('evaluator_task_'), evaluator_task_info['task_id']
+ assert evaluator_task_info['collector_cfg'].eval_flag
+ env_kwargs = evaluator_task_info['collector_cfg'].env
+ assert env_kwargs.eval_opponent == setup_1v1commander._league.active_players[0]._eval_opponent_difficulty[0]
+ assert len(evaluator_task_info['collector_cfg'].policy) == 1
+
+ # Finish evaluator task, not reach stop value
+ finished_task_dict = {
+ 'eval_flag': True,
+ 'game_result': [['losses', 'losses'], ['losses', 'draws']],
+ 'train_iter': 0,
+ 'real_episode_count': 4,
+ 'step_count': 4 * 120,
+ 'avg_time_per_episode': 1.89,
+ 'avg_time_per_step': 1.89 / 120,
+ 'avg_step_per_episode': 120.,
+ 'reward_mean': -10.3,
+ 'reward_std': 3.4,
+ }
+ assert not setup_1v1commander.finish_collector_task(evaluator_task_id, finished_task_dict)
+ assert setup_1v1commander._collector_task_space.cur == 0
+
+ # Get collector_task
+ collector_task_info = setup_1v1commander.get_collector_task()
+ assert setup_1v1commander._collector_task_space.cur == 1
+ collector_task_id = collector_task_info['task_id']
+ assert collector_task_id.startswith('collector_task_'), collector_task_info['task_id']
+ assert collector_task_info['buffer_id'] == learner_task_info['buffer_id']
+ assert 'eps' in collector_task_info['collector_cfg'].collect_setting
+ policy_update_path = collector_task_info['collector_cfg'].policy_update_path
+ assert len(policy_update_path) == 2
+ assert policy_update_path[0] == policy_update_path[1]
+ policy_update_flag = collector_task_info['collector_cfg'].policy_update_flag
+ assert policy_update_flag[0] == policy_update_flag[1]
+ assert not collector_task_info['collector_cfg'].eval_flag
+ assert len(collector_task_info['collector_cfg'].policy) == 2
+
+ # Finish collector_task
+ finished_task_dict = {
+ 'eval_flag': False,
+ 'game_result': [['losses', 'losses'], ['losses', 'losses']],
+ 'step_count': 400,
+ 'train_iter': 20,
+ 'real_episode_count': 8,
+ 'avg_time_per_episode': 1.33,
+ 'avg_time_per_step': 1.33 / 500,
+ 'avg_step_per_episode': 50.,
+ 'reward_mean': 11.,
+ 'reward_std': 3.,
+ }
+ assert not setup_1v1commander.finish_collector_task(collector_task_id, finished_task_dict)
+ assert setup_1v1commander._collector_task_space.cur == 0
+
+ # Update learner info
+ for i in range(0, 101, 10):
+ learner_info = {
+ 'learner_step': i,
+ }
+ setup_1v1commander.update_learner_info('some_task_id', learner_info)
+
+ # Get evaluator task; Finish evaluator task and reach stop value.
+ time.sleep(5 + 0.1)
+ evaluator_task_info = setup_1v1commander.get_collector_task()
+ evaluator_task_id = evaluator_task_info['task_id']
+ assert setup_1v1commander._collector_task_space.cur == 1
+ assert evaluator_task_info['collector_cfg'].eval_flag
+ finished_task_dict = {
+ 'eval_flag': True,
+ 'game_result': [['wins', 'wins'], ['wins', 'wins']],
+ 'train_iter': 100,
+ 'real_episode_count': 4,
+ 'step_count': 4 * 120,
+ 'avg_time_per_episode': 1.89,
+ 'avg_time_per_step': 1.89 / 120,
+ 'avg_step_per_episode': 120.,
+ 'reward_mean': 20.,
+ 'reward_std': 0.,
+ }
+ assert setup_1v1commander.finish_collector_task(evaluator_task_id, finished_task_dict)
+ assert setup_1v1commander._end_flag
+ assert setup_1v1commander._collector_task_space.cur == 0
+
+ # Finish learner task
+ finished_task_dict = {'buffer_id': setup_1v1commander._current_buffer_id}
+ setup_1v1commander.finish_learner_task(learner_task_id, finished_task_dict)
+ assert setup_1v1commander._learner_task_space.cur == 0
+
+ @pytest.mark.notify
+ def test_notify(self, setup_1v1commander):
+ _ = setup_1v1commander.get_learner_task()
+ setup_1v1commander.notify_fail_learner_task({})
+ time.sleep(0.01)
+ assert setup_1v1commander._learner_task_space.cur == 0
+ _ = setup_1v1commander.get_collector_task()
+ setup_1v1commander.notify_fail_collector_task({})
+ time.sleep(0.01)
+ assert setup_1v1commander._collector_task_space.cur == 0
+
+ os.popen('rm -rf log')
+ os.popen('rm -rf total_config.py')
diff --git a/DI-engine/ding/worker/learner/__init__.py b/DI-engine/ding/worker/learner/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c75e26315d5ddd0bca930dcd5e7611899fa9eee4
--- /dev/null
+++ b/DI-engine/ding/worker/learner/__init__.py
@@ -0,0 +1,3 @@
+from .base_learner import BaseLearner, create_learner
+from .comm import BaseCommLearner, FlaskFileSystemLearner, create_comm_learner
+from .learner_hook import register_learner_hook, add_learner_hook, merge_hooks, LearnerHook, build_learner_hook_by_cfg
diff --git a/DI-engine/ding/worker/learner/base_learner.py b/DI-engine/ding/worker/learner/base_learner.py
new file mode 100644
index 0000000000000000000000000000000000000000..1144a412cd5cc13ef785492a6c9a93f7f15bebc2
--- /dev/null
+++ b/DI-engine/ding/worker/learner/base_learner.py
@@ -0,0 +1,536 @@
+from typing import Any, Union, Callable, List, Dict, Optional, Tuple
+from ditk import logging
+from collections import namedtuple
+from functools import partial
+from easydict import EasyDict
+
+import copy
+
+from ding.torch_utils import CountVar, auto_checkpoint, build_log_buffer
+from ding.utils import build_logger, EasyTimer, import_module, LEARNER_REGISTRY, get_rank, get_world_size
+from ding.utils.autolog import LoggedValue, LoggedModel, TickTime
+from ding.utils.data import AsyncDataLoader
+from .learner_hook import build_learner_hook_by_cfg, add_learner_hook, merge_hooks, LearnerHook
+
+
+@LEARNER_REGISTRY.register('base')
+class BaseLearner(object):
+ r"""
+ Overview:
+ Base class for policy learning.
+ Interface:
+ train, call_hook, register_hook, save_checkpoint, start, setup_dataloader, close
+ Property:
+ learn_info, priority_info, last_iter, train_iter, rank, world_size, policy
+ monitor, log_buffer, logger, tb_logger, ckpt_name, exp_name, instance_name
+ """
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ cfg = EasyDict(copy.deepcopy(cls.config))
+ cfg.cfg_type = cls.__name__ + 'Dict'
+ return cfg
+
+ config = dict(
+ train_iterations=int(1e9),
+ dataloader=dict(num_workers=0, ),
+ log_policy=True,
+ # --- Hooks ---
+ hook=dict(
+ load_ckpt_before_run='',
+ log_show_after_iter=100,
+ save_ckpt_after_iter=10000,
+ save_ckpt_after_run=True,
+ ),
+ )
+
+ _name = "BaseLearner" # override this variable for sub-class learner
+
+ def __init__(
+ self,
+ cfg: EasyDict,
+ policy: namedtuple = None,
+ tb_logger: Optional['SummaryWriter'] = None, # noqa
+ dist_info: Tuple[int, int] = None,
+ exp_name: Optional[str] = 'default_experiment',
+ instance_name: Optional[str] = 'learner',
+ ) -> None:
+ """
+ Overview:
+ Initialization method, build common learner components according to cfg, such as hook, wrapper and so on.
+ Arguments:
+            - cfg (:obj:`EasyDict`): Learner config, you can refer to ``cls.config`` for details.
+            - policy (:obj:`namedtuple`): A collection of policy functions of learn mode. The policy can also be \
+                set at runtime after initialization.
+ - tb_logger (:obj:`SummaryWriter`): Tensorboard summary writer.
+ - dist_info (:obj:`Tuple[int, int]`): Multi-GPU distributed training information.
+ - exp_name (:obj:`str`): Experiment name, which is used to indicate output directory.
+ - instance_name (:obj:`str`): Instance name, which should be unique among different learners.
+ Notes:
+ If you want to debug in sync CUDA mode, please add the following code at the beginning of ``__init__``.
+
+ .. code:: python
+
+ os.environ['CUDA_LAUNCH_BLOCKING'] = "1" # for debug async CUDA
+ """
+ self._cfg = cfg
+ self._exp_name = exp_name
+ self._instance_name = instance_name
+ self._ckpt_name = None
+ self._timer = EasyTimer()
+
+ # These 2 attributes are only used in parallel mode.
+ self._end_flag = False
+ self._learner_done = False
+ if dist_info is None:
+ self._rank = get_rank()
+ self._world_size = get_world_size()
+ else:
+ # Learner rank. Used to discriminate which GPU it uses.
+ self._rank, self._world_size = dist_info
+ if self._world_size > 1:
+ self._cfg.hook.log_reduce_after_iter = True
+
+ # Logger (Monitor will be initialized in policy setter)
+ # Only rank == 0 learner needs monitor and tb_logger, others only need text_logger to display terminal output.
+ if self._rank == 0:
+ if tb_logger is not None:
+ self._logger, _ = build_logger(
+ './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
+ )
+ self._tb_logger = tb_logger
+ else:
+ self._logger, self._tb_logger = build_logger(
+ './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name
+ )
+ else:
+ self._logger, _ = build_logger(
+ './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
+ )
+ self._tb_logger = None
+ self._log_buffer = {
+ 'scalar': build_log_buffer(),
+ 'scalars': build_log_buffer(),
+ 'histogram': build_log_buffer(),
+ }
+
+ # Setup policy
+ if policy is not None:
+ self.policy = policy
+
+ # Learner hooks. Used to do specific things at specific time point. Will be set in ``_setup_hook``
+ self._hooks = {'before_run': [], 'before_iter': [], 'after_iter': [], 'after_run': []}
+ # Last iteration. Used to record current iter.
+ self._last_iter = CountVar(init_val=0)
+
+ # Setup time wrapper and hook.
+ self._setup_wrapper()
+ self._setup_hook()
+
+ def _setup_hook(self) -> None:
+ """
+ Overview:
+ Setup hook for base_learner. Hook is the way to implement some functions at specific time point
+ in base_learner. You can refer to ``learner_hook.py``.
+ """
+ if hasattr(self, '_hooks'):
+ self._hooks = merge_hooks(self._hooks, build_learner_hook_by_cfg(self._cfg.hook))
+ else:
+ self._hooks = build_learner_hook_by_cfg(self._cfg.hook)
+
+ def _setup_wrapper(self) -> None:
+ """
+ Overview:
+ Use ``_time_wrapper`` to get ``train_time``.
+ Note:
+ ``data_time`` is wrapped in ``setup_dataloader``.
+ """
+ self._wrapper_timer = EasyTimer()
+ self.train = self._time_wrapper(self.train, 'scalar', 'train_time')
+
+ def _time_wrapper(self, fn: Callable, var_type: str, var_name: str) -> Callable:
+ """
+ Overview:
+ Wrap a function and record the time it used in ``_log_buffer``.
+ Arguments:
+ - fn (:obj:`Callable`): Function to be time_wrapped.
+ - var_type (:obj:`str`): Variable type, e.g. ['scalar', 'scalars', 'histogram'].
+ - var_name (:obj:`str`): Variable name, e.g. ['cur_lr', 'total_loss'].
+ Returns:
+ - wrapper (:obj:`Callable`): The wrapper to acquire a function's time.
+ """
+
+ def wrapper(*args, **kwargs) -> Any:
+ with self._wrapper_timer:
+ ret = fn(*args, **kwargs)
+ self._log_buffer[var_type][var_name] = self._wrapper_timer.value
+ return ret
+
+ return wrapper
+
+ def register_hook(self, hook: LearnerHook) -> None:
+ """
+ Overview:
+ Add a new learner hook.
+ Arguments:
+            - hook (:obj:`LearnerHook`): The hook to be added.
+ """
+ add_learner_hook(self._hooks, hook)
+
+ def train(self, data: dict, envstep: int = -1, policy_kwargs: Optional[dict] = None) -> None:
+ """
+ Overview:
+ Given training data, implement network update for one iteration and update related variables.
+ Learner's API for serial entry.
+ Also called in ``start`` for each iteration's training.
+ Arguments:
+            - data (:obj:`dict`): Training data which is retrieved from the replay buffer.
+            - envstep (:obj:`int`): Current env step count from the collector, used for TensorBoard logging.
+            - policy_kwargs (:obj:`Optional[dict]`): Extra keyword arguments passed to ``self._policy.forward``.
+        Returns:
+            - log_vars (:obj:`list`): List of training log variable dicts returned by ``self._policy.forward``.
+
+ .. note::
+
+ ``_policy`` must be set before calling this method.
+
+ ``_policy.forward`` method contains: forward, backward, grad sync(if in multi-gpu mode) and
+ parameter update.
+
+ ``before_iter`` and ``after_iter`` hooks are called at the beginning and ending.
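+
+        Examples:
+            >>> # A sketch of one serial training iteration; ``replay_buffer`` and ``collector`` are illustrative
+            >>> # objects from a typical serial pipeline, not defined in this module.
+            >>> train_data = replay_buffer.sample(batch_size, learner.train_iter)
+            >>> learner.train(train_data, collector.envstep)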
+ """
+ assert hasattr(self, '_policy'), "please set learner policy"
+ self.call_hook('before_iter')
+
+ if policy_kwargs is None:
+ policy_kwargs = {}
+
+ # Forward
+ log_vars = self._policy.forward(data, **policy_kwargs)
+
+ # Update replay buffer's priority info
+ if isinstance(log_vars, dict):
+ priority = log_vars.pop('priority', None)
+ elif isinstance(log_vars, list):
+ priority = log_vars[-1].pop('priority', None)
+ else:
+ raise TypeError("not support type for log_vars: {}".format(type(log_vars)))
+ if priority is not None:
+ replay_buffer_idx = [d.get('replay_buffer_idx', None) for d in data]
+ replay_unique_id = [d.get('replay_unique_id', None) for d in data]
+ self.priority_info = {
+ 'priority': priority,
+ 'replay_buffer_idx': replay_buffer_idx,
+ 'replay_unique_id': replay_unique_id,
+ }
+        # Distinguish vars of scalar, scalars and histogram type.
+        # A var is regarded as scalar type by default; scalars and histogram vars must be annotated with
+        # a "[scalars]" or "[histogram]" prefix in their keys.
+ self._collector_envstep = envstep
+ if isinstance(log_vars, dict):
+ log_vars = [log_vars]
+ for elem in log_vars:
+ scalars_vars, histogram_vars = {}, {}
+ for k in list(elem.keys()):
+ if "[scalars]" in k:
+ new_k = k.split(']')[-1]
+ scalars_vars[new_k] = elem.pop(k)
+ elif "[histogram]" in k:
+ new_k = k.split(']')[-1]
+ histogram_vars[new_k] = elem.pop(k)
+ # Update log_buffer
+ self._log_buffer['scalar'].update(elem)
+ self._log_buffer['scalars'].update(scalars_vars)
+ self._log_buffer['histogram'].update(histogram_vars)
+
+ self.call_hook('after_iter')
+ self._last_iter.add(1)
+
+ return log_vars
+
+ @auto_checkpoint
+ def start(self) -> None:
+ """
+ Overview:
+ [Only Used In Parallel Mode] Learner's API for parallel entry.
+ For each iteration, learner will get data through ``_next_data`` and call ``train`` to train.
+
+ .. note::
+
+ ``before_run`` and ``after_run`` hooks are called at the beginning and ending.
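+
+        Examples:
+            >>> # A sketch mirroring ``tests/test_base_learner.py``: policy and dataloader must be set up first.
+            >>> learner.policy = policy
+            >>> learner.setup_dataloader()
+            >>> learner.start()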
+ """
+ self._end_flag = False
+ self._learner_done = False
+ # before run hook
+ self.call_hook('before_run')
+
+ for i in range(self._cfg.train_iterations):
+ data = self._next_data()
+ if self._end_flag:
+ break
+ self.train(data)
+
+ self._learner_done = True
+ # after run hook
+ self.call_hook('after_run')
+
+ def setup_dataloader(self) -> None:
+ """
+ Overview:
+ [Only Used In Parallel Mode] Setup learner's dataloader.
+
+ .. note::
+
+            Only in parallel mode do we use ``get_data`` and ``_dataloader`` to fetch data from the file system;
+            in serial mode, data can be fetched from memory directly.
+
+            In parallel mode, ``get_data`` is set by the comm learner (see ``BaseCommLearner._create_learner``)
+            and should be callable. Users don't need to know the related details unless necessary.
+ """
+ cfg = self._cfg.dataloader
+ batch_size = self._policy.get_attribute('batch_size')
+ device = self._policy.get_attribute('device')
+ chunk_size = cfg.chunk_size if 'chunk_size' in cfg else batch_size
+ self._dataloader = AsyncDataLoader(
+ self.get_data, batch_size, device, chunk_size, collate_fn=lambda x: x, num_workers=cfg.num_workers
+ )
+ self._next_data = self._time_wrapper(self._next_data, 'scalar', 'data_time')
+
+ def _next_data(self) -> Any:
+ """
+ Overview:
+ [Only Used In Parallel Mode] Call ``_dataloader``'s ``__next__`` method to return next training data.
+ Returns:
+ - data (:obj:`Any`): Next training data from dataloader.
+ """
+ return next(self._dataloader)
+
+ def close(self) -> None:
+ """
+ Overview:
+ [Only Used In Parallel Mode] Close the related resources, e.g. dataloader, tensorboard logger, etc.
+ """
+ if self._end_flag:
+ return
+ self._end_flag = True
+ if hasattr(self, '_dataloader'):
+ self._dataloader.close()
+ if self._tb_logger:
+ self._tb_logger.flush()
+ self._tb_logger.close()
+
+ def __del__(self) -> None:
+ self.close()
+
+ def call_hook(self, name: str) -> None:
+ """
+ Overview:
+ Call the corresponding hook plugins according to position name.
+ Arguments:
+ - name (:obj:`str`): Hooks in which position to call, \
+ should be in ['before_run', 'after_run', 'before_iter', 'after_iter'].
+ """
+ for hook in self._hooks[name]:
+ hook(self)
+
+ def info(self, s: str) -> None:
+ """
+ Overview:
+ Log string info by ``self._logger.info``.
+ Arguments:
+ - s (:obj:`str`): The message to add into the logger.
+ """
+ self._logger.info('[RANK{}]: {}'.format(self._rank, s))
+
+ def debug(self, s: str) -> None:
+ self._logger.debug('[RANK{}]: {}'.format(self._rank, s))
+
+ def save_checkpoint(self, ckpt_name: str = None) -> None:
+ """
+ Overview:
+ Directly call ``save_ckpt_after_run`` hook to save checkpoint.
+ Note:
+ Must guarantee that "save_ckpt_after_run" is registered in "after_run" hook.
+ This method is called in:
+
+ - ``auto_checkpoint`` (``torch_utils/checkpoint_helper.py``), which is designed for \
+ saving checkpoint whenever an exception raises.
+ - ``serial_pipeline`` (``entry/serial_entry.py``). Used to save checkpoint when reaching \
+ new highest episode return.
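+
+        Examples:
+            >>> # e.g. save a checkpoint named 'best' (as done in the serial pipeline and in the unit test):
+            >>> learner.save_checkpoint('best')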
+ """
+ if ckpt_name is not None:
+ self.ckpt_name = ckpt_name
+ names = [h.name for h in self._hooks['after_run']]
+ assert 'save_ckpt_after_run' in names
+ idx = names.index('save_ckpt_after_run')
+ self._hooks['after_run'][idx](self)
+ self.ckpt_name = None
+
+ @property
+ def learn_info(self) -> dict:
+ """
+ Overview:
+ Get current info dict, which will be sent to commander, e.g. replay buffer priority update,
+ current iteration, hyper-parameter adjustment, whether task is finished, etc.
+ Returns:
+ - info (:obj:`dict`): Current learner info dict.
+ """
+ ret = {
+ 'learner_step': self._last_iter.val,
+ 'priority_info': self.priority_info,
+ 'learner_done': self._learner_done,
+ }
+ return ret
+
+ @property
+ def last_iter(self) -> CountVar:
+ return self._last_iter
+
+ @property
+ def train_iter(self) -> int:
+ return self._last_iter.val
+
+ @property
+ def monitor(self) -> 'TickMonitor': # noqa
+ return self._monitor
+
+ @property
+ def log_buffer(self) -> dict: # LogDict
+ return self._log_buffer
+
+ @log_buffer.setter
+ def log_buffer(self, _log_buffer: Dict[str, Dict[str, Any]]) -> None:
+ self._log_buffer = _log_buffer
+
+ @property
+ def logger(self) -> logging.Logger:
+ return self._logger
+
+ @property
+    def tb_logger(self) -> 'TensorBoardLogger':  # noqa
+ return self._tb_logger
+
+ @property
+ def exp_name(self) -> str:
+ return self._exp_name
+
+ @property
+ def instance_name(self) -> str:
+ return self._instance_name
+
+ @property
+ def rank(self) -> int:
+ return self._rank
+
+ @property
+ def world_size(self) -> int:
+ return self._world_size
+
+ @property
+ def policy(self) -> 'Policy': # noqa
+ return self._policy
+
+ @policy.setter
+ def policy(self, _policy: 'Policy') -> None: # noqa
+ """
+ Note:
+            The policy variable monitor is set alongside the policy,
+            because the monitored variables are determined by the specific policy.
+ """
+ self._policy = _policy
+ if self._rank == 0:
+ self._monitor = get_simple_monitor_type(self._policy.monitor_vars())(TickTime(), expire=10)
+ if self._cfg.log_policy:
+ self.info(self._policy.info())
+
+ @property
+ def priority_info(self) -> dict:
+ if not hasattr(self, '_priority_info'):
+ self._priority_info = {}
+ return self._priority_info
+
+ @priority_info.setter
+ def priority_info(self, _priority_info: dict) -> None:
+ self._priority_info = _priority_info
+
+ @property
+ def ckpt_name(self) -> str:
+ return self._ckpt_name
+
+ @ckpt_name.setter
+ def ckpt_name(self, _ckpt_name: str) -> None:
+ self._ckpt_name = _ckpt_name
+
+
+def create_learner(cfg: EasyDict, **kwargs) -> BaseLearner:
+ """
+ Overview:
+        Given the key (``cfg.type``), create a new learner instance if it is registered in ``LEARNER_REGISTRY``,
+        otherwise raise a ``KeyError``. In other words, a derived learner must first be registered, and only then
+        can ``create_learner`` be called to get an instance.
+    Arguments:
+        - cfg (:obj:`EasyDict`): Learner config. Necessary keys: [type, import_names].
+    Returns:
+        - learner (:obj:`BaseLearner`): The created new learner, which should be an instance of a class \
+            registered in ``LEARNER_REGISTRY``.
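+    Examples:
+        >>> # A sketch; 'base' is assumed to be a registered learner key in ``LEARNER_REGISTRY``.
+        >>> cfg = BaseLearner.default_config()
+        >>> cfg.type, cfg.import_names = 'base', []
+        >>> learner = create_learner(cfg, exp_name='my_exp')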
+ """
+ import_module(cfg.get('import_names', []))
+ return LEARNER_REGISTRY.build(cfg.type, cfg=cfg, **kwargs)
+
+
+class TickMonitor(LoggedModel):
+ """
+ Overview:
+ TickMonitor is to monitor related info during training.
+        Info includes: cur_lr, time(data, train, forward, backward), loss(total, ...).
+        These variables are first recorded in ``log_buffer``, then the corresponding variables in this monitor
+        are updated from ``log_buffer`` by a ``LearnerHook``, and finally printed to text and TensorBoard loggers.
+ Interface:
+ __init__, fixed_time, current_time, freeze, unfreeze, register_attribute_value, __getattr__
+ Property:
+ time, expire
+ """
+ data_time = LoggedValue(float)
+ train_time = LoggedValue(float)
+ total_collect_step = LoggedValue(float)
+ total_step = LoggedValue(float)
+ total_episode = LoggedValue(float)
+ total_sample = LoggedValue(float)
+ total_duration = LoggedValue(float)
+
+ def __init__(self, time_: 'BaseTime', expire: Union[int, float]): # noqa
+ LoggedModel.__init__(self, time_, expire)
+ self.__register()
+
+ def __register(self):
+
+ def __avg_func(prop_name: str) -> float:
+ records = self.range_values[prop_name]()
+ _list = [_value for (_begin_time, _end_time), _value in records]
+ return sum(_list) / len(_list) if len(_list) != 0 else 0
+
+ def __val_func(prop_name: str) -> float:
+ records = self.range_values[prop_name]()
+ return records[-1][1]
+
+ for k in getattr(self, '_LoggedModel__properties'):
+ self.register_attribute_value('avg', k, partial(__avg_func, prop_name=k))
+ self.register_attribute_value('val', k, partial(__val_func, prop_name=k))
+
+
+def get_simple_monitor_type(properties: List[str] = []) -> TickMonitor:
+ """
+ Overview:
+ Besides basic training variables provided in ``TickMonitor``, many policies have their own customized
+ ones to record and monitor. This function can return a customized tick monitor.
+ Compared with ``TickMonitor``, ``SimpleTickMonitor`` can record extra ``properties`` passed in by a policy.
+    Arguments:
+ - properties (:obj:`List[str]`): Customized properties to monitor.
+ Returns:
+ - simple_tick_monitor (:obj:`SimpleTickMonitor`): A simple customized tick monitor.
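+    Examples:
+        >>> # A sketch: build a monitor class that also tracks policy-specific variables, mirroring its use in
+        >>> # ``BaseLearner``'s policy setter.
+        >>> monitor_cls = get_simple_monitor_type(['total_loss', 'cur_lr'])
+        >>> monitor = monitor_cls(TickTime(), expire=10)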
+ """
+ if len(properties) == 0:
+ return TickMonitor
+ else:
+ attrs = {}
+ properties = [
+ 'data_time', 'train_time', 'sample_count', 'total_collect_step', 'total_step', 'total_sample',
+ 'total_episode', 'total_duration'
+ ] + properties
+ for p_name in properties:
+ attrs[p_name] = LoggedValue(float)
+ return type('SimpleTickMonitor', (TickMonitor, ), attrs)
diff --git a/DI-engine/ding/worker/learner/comm/__init__.py b/DI-engine/ding/worker/learner/comm/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..388fd23b0c19ee227ff7aa9b53e31cd533522283
--- /dev/null
+++ b/DI-engine/ding/worker/learner/comm/__init__.py
@@ -0,0 +1,3 @@
+from .base_comm_learner import BaseCommLearner, create_comm_learner
+from .flask_fs_learner import FlaskFileSystemLearner
+from .utils import NaiveLearner # for test
diff --git a/DI-engine/ding/worker/learner/comm/base_comm_learner.py b/DI-engine/ding/worker/learner/comm/base_comm_learner.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a9562888fdbf76b108acb9bbc8bc5692d1e6449
--- /dev/null
+++ b/DI-engine/ding/worker/learner/comm/base_comm_learner.py
@@ -0,0 +1,138 @@
+from abc import ABC, abstractmethod, abstractproperty
+from easydict import EasyDict
+
+from ding.utils import EasyTimer, import_module, get_task_uid, dist_init, dist_finalize, COMM_LEARNER_REGISTRY
+from ding.policy import create_policy
+from ding.worker.learner import create_learner
+
+
+class BaseCommLearner(ABC):
+ """
+ Overview:
+        Abstract base class for comm learner.
+ Interfaces:
+ __init__, send_policy, get_data, send_learn_info, start, close
+ Property:
+ hooks4call
+ """
+
+ def __init__(self, cfg: 'EasyDict') -> None: # noqa
+ """
+ Overview:
+ Initialization method.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config dict
+ """
+ self._cfg = cfg
+ self._learner_uid = get_task_uid()
+ self._timer = EasyTimer()
+ if cfg.multi_gpu:
+ self._rank, self._world_size = dist_init()
+ else:
+ self._rank, self._world_size = 0, 1
+ self._multi_gpu = cfg.multi_gpu
+ self._end_flag = True
+
+ @abstractmethod
+ def send_policy(self, state_dict: dict) -> None:
+ """
+ Overview:
+ Save learner's policy in corresponding path.
+ Will be registered in base learner.
+ Arguments:
+ - state_dict (:obj:`dict`): State dict of the runtime policy.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def get_data(self, batch_size: int) -> list:
+ """
+ Overview:
+ Get batched meta data from coordinator.
+ Will be registered in base learner.
+ Arguments:
+ - batch_size (:obj:`int`): Batch size.
+ Returns:
+ - stepdata (:obj:`list`): A list of training data, each element is one trajectory.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def send_learn_info(self, learn_info: dict) -> None:
+ """
+ Overview:
+ Send learn info to coordinator.
+ Will be registered in base learner.
+ Arguments:
+ - learn_info (:obj:`dict`): Learn info in dict type.
+ """
+ raise NotImplementedError
+
+ def start(self) -> None:
+ """
+ Overview:
+ Start comm learner.
+ """
+ self._end_flag = False
+
+ def close(self) -> None:
+ """
+ Overview:
+ Close comm learner.
+ """
+ self._end_flag = True
+ if self._multi_gpu:
+ dist_finalize()
+
+ @abstractproperty
+ def hooks4call(self) -> list:
+ """
+ Returns:
+ - hooks (:obj:`list`): The hooks which comm learner has. Will be registered in learner as well.
+ """
+ raise NotImplementedError
+
+ def _create_learner(self, task_info: dict) -> 'BaseLearner': # noqa
+ """
+ Overview:
+ Receive ``task_info`` passed from coordinator and create a learner.
+ Arguments:
+ - task_info (:obj:`dict`): Task info dict from coordinator. Should be like \
+ {"learner_cfg": xxx, "policy": xxx}.
+ Returns:
+ - learner (:obj:`BaseLearner`): Created base learner.
+
+ .. note::
+ Three methods('get_data', 'send_policy', 'send_learn_info'), dataloader and policy are set.
+            The reason why they are set here rather than in the base learner is that they highly depend on
+            the specific task. Only after the task info is passed from the coordinator to the comm learner
+            through the learner slave can they be determined and initialized.
+ """
+ # Prepare learner config and instantiate a learner object.
+ learner_cfg = EasyDict(task_info['learner_cfg'])
+ learner = create_learner(learner_cfg, dist_info=[self._rank, self._world_size], exp_name=learner_cfg.exp_name)
+ # Set 3 methods and dataloader in created learner that are necessary in parallel setting.
+ for item in ['get_data', 'send_policy', 'send_learn_info']:
+ setattr(learner, item, getattr(self, item))
+ # Set policy in created learner.
+ policy_cfg = task_info['policy']
+ policy_cfg = EasyDict(policy_cfg)
+ learner.policy = create_policy(policy_cfg, enable_field=['learn']).learn_mode
+ learner.setup_dataloader()
+ return learner
+
+
+def create_comm_learner(cfg: EasyDict) -> BaseCommLearner:
+ """
+ Overview:
+        Given the key (``cfg.type``), create a new comm learner instance if it is registered in
+        ``COMM_LEARNER_REGISTRY``, otherwise raise a ``KeyError``. In other words, a derived comm learner must
+        first be registered, and only then can ``create_comm_learner`` be called to get an instance.
+    Arguments:
+        - cfg (:obj:`EasyDict`): Comm learner config. Necessary keys: [type, import_names].
+    Returns:
+        - learner (:obj:`BaseCommLearner`): The created new comm learner, which should be an instance of a class \
+            registered in ``COMM_LEARNER_REGISTRY``.
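+    Examples:
+        >>> # A sketch mirroring ``comm/tests/test_learner_with_coordinator.py``; the cfg key is illustrative.
+        >>> comm_learner = create_comm_learner(cfg.system.learner0)
+        >>> comm_learner.start()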
+ """
+ import_module(cfg.get('import_names', []))
+ return COMM_LEARNER_REGISTRY.build(cfg.type, cfg=cfg)
diff --git a/DI-engine/ding/worker/learner/comm/flask_fs_learner.py b/DI-engine/ding/worker/learner/comm/flask_fs_learner.py
new file mode 100644
index 0000000000000000000000000000000000000000..3cba39c735f99f6f52c38bf5ee1a748b762a7853
--- /dev/null
+++ b/DI-engine/ding/worker/learner/comm/flask_fs_learner.py
@@ -0,0 +1,403 @@
+import os
+import time
+from typing import List, Union, Dict, Callable, Any
+from functools import partial
+from queue import Queue
+from threading import Thread
+
+from ding.utils import read_file, save_file, get_data_decompressor, COMM_LEARNER_REGISTRY
+from ding.utils.file_helper import read_from_di_store
+from ding.interaction import Slave, TaskFail
+from .base_comm_learner import BaseCommLearner
+from ..learner_hook import LearnerHook
+
+
+class LearnerSlave(Slave):
+ """
+ Overview:
+ A slave, whose master is coordinator.
+ Used to pass message between comm learner and coordinator.
+ """
+
+ def __init__(self, *args, callback_fn: Dict[str, Callable], **kwargs) -> None:
+ """
+ Overview:
+ Init callback functions additionally. Callback functions are methods in comm learner.
+ """
+ super().__init__(*args, **kwargs)
+ self._callback_fn = callback_fn
+
+ def _process_task(self, task: dict) -> Union[dict, TaskFail]:
+ """
+ Overview:
+ Process a task according to input task info dict, which is passed in by master coordinator.
+ For each type of task, you can refer to corresponding callback function in comm learner for details.
+ Arguments:
+            - task (:obj:`dict`): Task info dict. Must contain key "name".
+ Returns:
+ - result (:obj:`Union[dict, TaskFail]`): Task result dict, or task fail exception.
+ """
+ task_name = task['name']
+ if task_name == 'resource':
+ return self._callback_fn['deal_with_resource']()
+ elif task_name == 'learner_start_task':
+ self._current_task_info = task['task_info']
+ self._callback_fn['deal_with_learner_start'](self._current_task_info)
+ return {'message': 'learner task has started'}
+ elif task_name == 'learner_get_data_task':
+ data_demand = self._callback_fn['deal_with_get_data']()
+ ret = {
+ 'task_id': self._current_task_info['task_id'],
+ 'buffer_id': self._current_task_info['buffer_id'],
+ }
+ ret.update(data_demand)
+ return ret
+ elif task_name == 'learner_learn_task':
+ info = self._callback_fn['deal_with_learner_learn'](task['data'])
+ data = {'info': info}
+ data['buffer_id'] = self._current_task_info['buffer_id']
+ data['task_id'] = self._current_task_info['task_id']
+ return data
+ elif task_name == 'learner_close_task':
+ self._callback_fn['deal_with_learner_close']()
+ return {
+ 'task_id': self._current_task_info['task_id'],
+ 'buffer_id': self._current_task_info['buffer_id'],
+ }
+ else:
+ raise TaskFail(result={'message': 'task name error'}, message='illegal learner task <{}>'.format(task_name))
+
+
+@COMM_LEARNER_REGISTRY.register('flask_fs')
+class FlaskFileSystemLearner(BaseCommLearner):
+ """
+ Overview:
+ An implementation of CommLearner, using flask and the file system.
+ Interfaces:
+ __init__, send_policy, get_data, send_learn_info, start, close
+ Property:
+ hooks4call
+ """
+
+ def __init__(self, cfg: 'EasyDict') -> None: # noqa
+ """
+ Overview:
+ Init method.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config dict.
+ """
+ BaseCommLearner.__init__(self, cfg)
+
+ # Callback functions for message passing between comm learner and coordinator.
+ self._callback_fn = {
+ 'deal_with_resource': self.deal_with_resource,
+ 'deal_with_learner_start': self.deal_with_learner_start,
+ 'deal_with_get_data': self.deal_with_get_data,
+ 'deal_with_learner_learn': self.deal_with_learner_learn,
+ 'deal_with_learner_close': self.deal_with_learner_close,
+ }
+        # Learner slave to implement those callback functions.
+        # Host and port are used to build the connection with the master.
+ host, port = cfg.host, cfg.port
+ if isinstance(port, list):
+ port = port[self._rank]
+ elif isinstance(port, int) and self._world_size > 1:
+ port = port + self._rank
+ self._slave = LearnerSlave(host, port, callback_fn=self._callback_fn)
+
+ self._path_data = cfg.path_data # path to read data from
+ self._path_policy = cfg.path_policy # path to save policy
+
+ # Queues to store info dicts. Only one info is needed to pass between learner and coordinator at a time.
+ self._data_demand_queue = Queue(maxsize=1)
+ self._data_result_queue = Queue(maxsize=1)
+ self._learn_info_queue = Queue(maxsize=1)
+
+        # Task-level learner and policy will only be set once the task is received.
+ self._learner = None
+ self._policy_id = None
+
+ def start(self) -> None:
+ """
+ Overview:
+ Start comm learner itself and the learner slave.
+ """
+ BaseCommLearner.start(self)
+ self._slave.start()
+
+ def close(self) -> None:
+ """
+ Overview:
+ Join learner thread and close learner if still running.
+ Then close learner slave and comm learner itself.
+ """
+ if self._end_flag:
+ return
+ if self._learner is not None:
+ self.deal_with_learner_close()
+ self._slave.close()
+ BaseCommLearner.close(self)
+
+ def __del__(self) -> None:
+ """
+ Overview:
+ Call ``close`` for deletion.
+ """
+ self.close()
+
+ def deal_with_resource(self) -> dict:
+ """
+ Overview:
+ Callback function. Return how many resources are needed to start current learner.
+ Returns:
+ - resource (:obj:`dict`): Resource info dict, including ["gpu"].
+ """
+ return {'gpu': self._world_size}
+
+ def deal_with_learner_start(self, task_info: dict) -> None:
+ """
+ Overview:
+ Callback function. Create a learner and help register its hooks. Start a learner thread of the created one.
+ Arguments:
+ - task_info (:obj:`dict`): Task info dict.
+
+ .. note::
+ In ``_create_learner`` method in base class ``BaseCommLearner``, 3 methods
+ ('get_data', 'send_policy', 'send_learn_info'), dataloader and policy are set.
+ You can refer to it for details.
+ """
+ self._policy_id = task_info['policy_id']
+ self._league_save_checkpoint_path = task_info.get('league_save_checkpoint_path', None)
+ self._learner = self._create_learner(task_info)
+ for h in self.hooks4call:
+ self._learner.register_hook(h)
+ self._learner_thread = Thread(target=self._learner.start, args=(), daemon=True, name='learner_start')
+ self._learner_thread.start()
+
+ def deal_with_get_data(self) -> Any:
+ """
+ Overview:
+ Callback function. Get data demand info dict from ``_data_demand_queue``,
+ which will be sent to coordinator afterwards.
+ Returns:
+ - data_demand (:obj:`Any`): Data demand info dict.
+ """
+ data_demand = self._data_demand_queue.get()
+ return data_demand
+
+ def deal_with_learner_learn(self, data: dict) -> dict:
+ """
+ Overview:
+ Callback function. Put training data info dict (i.e. meta data), which is received from coordinator, into
+ ``_data_result_queue``, and wait for ``get_data`` to retrieve. Wait for learner training and
+ get learn info dict from ``_learn_info_queue``. If task is finished, join the learner thread and
+ close the learner.
+ Returns:
+ - learn_info (:obj:`Any`): Learn info dict.
+ """
+ self._data_result_queue.put(data)
+ learn_info = self._learn_info_queue.get()
+ return learn_info
+
+ def deal_with_learner_close(self) -> None:
+ self._learner.close()
+ self._learner_thread.join()
+ del self._learner_thread
+ self._learner = None
+ self._policy_id = None
+
+ # override
+ def send_policy(self, state_dict: dict) -> None:
+ """
+ Overview:
+ Save learner's policy in corresponding path, called by ``SendPolicyHook``.
+ Arguments:
+ - state_dict (:obj:`dict`): State dict of the policy.
+ """
+ if not os.path.exists(self._path_policy):
+ os.mkdir(self._path_policy)
+ path = self._policy_id
+ if self._path_policy not in path:
+ path = os.path.join(self._path_policy, path)
+ setattr(self, "_latest_policy_path", path)
+ save_file(path, state_dict, use_lock=True)
+
+ if self._league_save_checkpoint_path is not None:
+ save_file(self._league_save_checkpoint_path, state_dict, use_lock=True)
+
+ @staticmethod
+ def load_data_fn(path, meta: Dict[str, Any], decompressor: Callable) -> Any:
+ """
+ Overview:
+ The function that is used to load data file.
+        Arguments:
+            - path (:obj:`str`): Path of the data file to load.
+            - meta (:obj:`Dict[str, Any]`): Meta data info dict.
+            - decompressor (:obj:`Callable`): Decompress function.
+ Returns:
+ - s (:obj:`Any`): Data which is read from file.
+ """
+        # Due to possible read-write conflicts, reading may raise an error, so we retry in a while loop.
+ while True:
+ try:
+ s = read_from_di_store(path) if read_from_di_store else read_file(path, use_lock=False)
+ s = decompressor(s)
+ break
+ except Exception:
+ time.sleep(0.01)
+ unroll_len = meta.get('unroll_len', 1)
+ if 'unroll_split_begin' in meta:
+ begin = meta['unroll_split_begin']
+ if unroll_len == 1:
+ s = s[begin]
+ s.update(meta)
+ else:
+ end = begin + unroll_len
+ s = s[begin:end]
+ # add metadata key-value to stepdata
+ for i in range(len(s)):
+ s[i].update(meta)
+ else:
+ s.update(meta)
+ return s
+
+ # override
+ def get_data(self, batch_size: int) -> List[Callable]:
+ """
+ Overview:
+ Get a list of data loading function, which can be implemented by dataloader to read data from files.
+ Arguments:
+ - batch_size (:obj:`int`): Batch size.
+ Returns:
+ - data (:obj:`List[Callable]`): A list of callable data loading function.
+ """
+ while self._learner is None:
+ time.sleep(1)
+ # Tell coordinator that we need training data, by putting info dict in data_demand_queue.
+ assert self._data_demand_queue.qsize() == 0
+ self._data_demand_queue.put({'batch_size': batch_size, 'cur_learner_iter': self._learner.last_iter.val})
+ # Get a list of meta data (data info dict) from coordinator, by getting info dict from data_result_queue.
+ data = self._data_result_queue.get()
+ assert isinstance(data, list)
+ assert len(data) == batch_size, '{}/{}'.format(len(data), batch_size)
+ # Transform meta data to callable data loading function (partial ``load_data_fn``).
+ decompressor = get_data_decompressor(data[0].get('compressor', 'none'))
+ data = [
+ partial(
+ FlaskFileSystemLearner.load_data_fn,
+ path=m['object_ref'] if read_from_di_store else os.path.join(self._path_data, m['data_id']),
+ meta=m,
+ decompressor=decompressor,
+ ) for m in data
+ ]
+ return data
+
+ # override
+ def send_learn_info(self, learn_info: dict) -> None:
+ """
+ Overview:
+ Store learn info dict in queue, which will be retrieved by callback function "deal_with_learner_learn"
+ in learner slave, then will be sent to coordinator.
+ Arguments:
+            - learn_info (:obj:`dict`): Learn info in `dict` type. Keys are like 'learner_step', 'priority_info', \
+                'learner_done', etc. You can refer to ``learn_info``(``worker/learner/base_learner.py``) for details.
+ """
+ assert self._learn_info_queue.qsize() == 0
+ self._learn_info_queue.put(learn_info)
+
+ @property
+ def hooks4call(self) -> List[LearnerHook]:
+ """
+ Overview:
+ Return the hooks that are related to message passing with coordinator.
+ Returns:
+ - hooks (:obj:`list`): The hooks which comm learner has. Will be registered in learner as well.
+ """
+ return [
+ SendPolicyHook('send_policy', 100, position='before_run', ext_args={}),
+ SendPolicyHook('send_policy', 100, position='after_iter', ext_args={'send_policy_freq': 1}),
+ SendLearnInfoHook(
+ 'send_learn_info',
+ 100,
+ position='after_iter',
+ ext_args={'freq': 10},
+ ),
+ SendLearnInfoHook(
+ 'send_learn_info',
+ 100,
+ position='after_run',
+ ext_args={'freq': 1},
+ ),
+ ]
+
+
+class SendPolicyHook(LearnerHook):
+ """
+ Overview:
+ Hook to send policy
+ Interfaces:
+ __init__, __call__
+ Property:
+ name, priority, position
+ """
+
+ def __init__(self, *args, ext_args: dict = {}, **kwargs) -> None:
+ """
+ Overview:
+            Init SendPolicyHook.
+        Arguments:
+            - ext_args (:obj:`dict`): Extended arguments. Use ``ext_args.send_policy_freq`` to set the frequency.
+ """
+ super().__init__(*args, **kwargs)
+ if 'send_policy_freq' in ext_args:
+ self._freq = ext_args['send_policy_freq']
+ else:
+ self._freq = 1
+
+ def __call__(self, engine: 'BaseLearner') -> None: # noqa
+ """
+ Overview:
+ Save learner's policy in corresponding path at interval iterations by calling ``engine``'s ``send_policy``.
+ Saved file includes model_state_dict, learner_last_iter.
+ Arguments:
+ - engine (:obj:`BaseLearner`): The BaseLearner.
+
+ .. note::
+ Only rank == 0 learner will save policy.
+ """
+ last_iter = engine.last_iter.val
+ if engine.rank == 0 and last_iter % self._freq == 0:
+ state_dict = {'model': engine.policy.state_dict()['model'], 'iter': last_iter}
+ engine.send_policy(state_dict)
+ engine.debug('{} save iter{} policy'.format(engine.instance_name, last_iter))
+
+
+class SendLearnInfoHook(LearnerHook):
+ """
+ Overview:
+ Hook to send learn info
+ Interfaces:
+ __init__, __call__
+ Property:
+ name, priority, position
+ """
+
+ def __init__(self, *args, ext_args: dict, **kwargs) -> None:
+ """
+ Overview:
+            Init SendLearnInfoHook.
+        Arguments:
+            - ext_args (:obj:`dict`): Extended arguments. Use ``ext_args.freq`` to set the send frequency.
+ """
+ super().__init__(*args, **kwargs)
+ self._freq = ext_args['freq']
+
+ def __call__(self, engine: 'BaseLearner') -> None: # noqa
+ """
+ Overview:
+            Send learn info (including learner step and priority info) every call;
+            print a debug message at interval iterations.
+ Arguments:
+ - engine (:obj:`BaseLearner`): the BaseLearner
+ """
+ last_iter = engine.last_iter.val
+ engine.send_learn_info(engine.learn_info)
+ if last_iter % self._freq == 0:
+ engine.debug('{} save iter{} learn_info'.format(engine.instance_name, last_iter))
diff --git a/DI-engine/ding/worker/learner/comm/tests/test_learner_with_coordinator.py b/DI-engine/ding/worker/learner/comm/tests/test_learner_with_coordinator.py
new file mode 100644
index 0000000000000000000000000000000000000000..be98f1282227862564a0c5ee9ea607479617b89a
--- /dev/null
+++ b/DI-engine/ding/worker/learner/comm/tests/test_learner_with_coordinator.py
@@ -0,0 +1,76 @@
+import pytest
+import os
+import time
+from multiprocessing import Process
+
+from ding.worker import Coordinator, create_comm_learner
+from ding.worker.collector.comm import NaiveCollector
+from ding.utils import lists_to_dicts
+from ding.config import compile_config_parallel
+from ding.config.utils import parallel_test_main_config, parallel_test_create_config, parallel_test_system_config
+
+DATA_PREFIX = 'SLAVE_COLLECTOR_DATA_LEARNER_TEST'
+
+
+@pytest.fixture(scope='function')
+def setup_config():
+ cfg = compile_config_parallel(
+ parallel_test_main_config, create_cfg=parallel_test_create_config, system_cfg=parallel_test_system_config
+ )
+ cfg.main.policy.learn.learner.train_iterations = 100
+ return cfg
+
+
+@pytest.fixture(scope='function')
+def setup_collector(setup_config):
+ cfg = setup_config.system.coordinator.collector
+ collector = {}
+ for _, (name, host, port) in cfg.items():
+ collector[name] = NaiveCollector(host, port, prefix=DATA_PREFIX)
+ collector[name].start()
+ yield collector
+ for a in collector.values():
+ a.close()
+
+
+@pytest.fixture(scope='function')
+def setup_learner(setup_config):
+ learner = {}
+ for k, v in setup_config.system.items():
+ if 'learner' in k:
+ learner[k] = create_comm_learner(v)
+ learner[k].start()
+ yield learner
+ for l in learner.values():
+ l.close()
+
+
+@pytest.mark.unittest(rerun=5)
+class TestLearnerWithCoordinator:
+
+ def test_naive(self, setup_config, setup_collector, setup_learner):
+ os.popen('rm -rf {}*'.format(DATA_PREFIX))
+ assert len(setup_collector) == len(setup_config.system.coordinator.collector)
+ try:
+ coordinator = Coordinator(setup_config)
+ coordinator.start()
+ while True:
+ if coordinator._commander._learner_task_finish_count == 1:
+ break
+ time.sleep(0.5)
+ coordinator.close()
+ except Exception as e:
+ os.popen('rm -rf {}*'.format(DATA_PREFIX))
+ assert False, e
+
+ collector_task_ids = [t for t in coordinator._historical_task if 'collector' in t]
+ for i in range(1, 21):
+ for t in collector_task_ids:
+ assert os.path.exists('{}_{}_{}'.format(DATA_PREFIX, t, i))
+ assert len(coordinator._replay_buffer) == 0
+ learner_task_ids = [i for i in coordinator._historical_task if 'learner' in i]
+ for i in learner_task_ids:
+ assert len(
+ coordinator._commander._learner_info[i]
+ ) == setup_config.main.policy.learn.learner.train_iterations
+ os.popen('rm -rf {}*'.format(DATA_PREFIX))
diff --git a/DI-engine/ding/worker/learner/comm/utils.py b/DI-engine/ding/worker/learner/comm/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..dcd2c916d7dbda59bcc3e313d9920d584d1d07e1
--- /dev/null
+++ b/DI-engine/ding/worker/learner/comm/utils.py
@@ -0,0 +1,56 @@
+import time
+import os
+from ding.interaction import Slave, TaskFail
+from ding.utils import lists_to_dicts
+
+
+class NaiveLearner(Slave):
+
+ def __init__(self, *args, prefix='', **kwargs):
+ super().__init__(*args, **kwargs)
+ self._prefix = prefix
+
+ def _process_task(self, task):
+ task_name = task['name']
+ if task_name == 'resource':
+ return {'cpu': 'xxx', 'gpu': 'xxx'}
+ elif task_name == 'learner_start_task':
+ time.sleep(1)
+ self.task_info = task['task_info']
+ self.count = 0
+ return {'message': 'learner task has started'}
+ elif task_name == 'learner_get_data_task':
+ time.sleep(0.01)
+ return {
+ 'task_id': self.task_info['task_id'],
+ 'buffer_id': self.task_info['buffer_id'],
+ 'batch_size': 2,
+ 'cur_learner_iter': 1
+ }
+ elif task_name == 'learner_learn_task':
+ data = task['data']
+ if data is None:
+ raise TaskFail(result={'message': 'no data'})
+ time.sleep(0.1)
+ data = lists_to_dicts(data)
+ assert 'data_id' in data.keys()
+ priority_keys = ['replay_unique_id', 'replay_buffer_idx', 'priority']
+ self.count += 1
+ ret = {
+ 'info': {
+ 'learner_step': self.count
+ },
+ 'task_id': self.task_info['task_id'],
+ 'buffer_id': self.task_info['buffer_id']
+ }
+ ret['info']['priority_info'] = {k: data[k] for k in priority_keys}
+ if self.count > 5:
+ ret['info']['learner_done'] = True
+ os.popen('touch {}_final_model.pth'.format(self._prefix))
+ return ret
+ elif task_name == 'learner_close_task':
+ return {'task_id': self.task_info['task_id'], 'buffer_id': self.task_info['buffer_id']}
+ else:
+ raise TaskFail(
+ result={'message': 'task name error'}, message='illegal collector task <{}>'.format(task_name)
+ )
diff --git a/DI-engine/ding/worker/learner/learner_hook.py b/DI-engine/ding/worker/learner/learner_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..250a8f195081591d2c35a4608dc6a96ce2661507
--- /dev/null
+++ b/DI-engine/ding/worker/learner/learner_hook.py
@@ -0,0 +1,434 @@
+import numbers
+import os
+from abc import ABC, abstractmethod
+from typing import Any, Dict, List
+import torch
+from easydict import EasyDict
+
+import ding
+from ding.utils import allreduce, read_file, save_file, get_rank
+
+
+class Hook(ABC):
+ """
+ Overview:
+ Abstract class for hooks.
+ Interfaces:
+ __init__, __call__
+ Property:
+ name, priority
+ """
+
+ def __init__(self, name: str, priority: float, **kwargs) -> None:
+ """
+ Overview:
+ Init method for hooks. Set name and priority.
+ Arguments:
+ - name (:obj:`str`): The name of hook
+ - priority (:obj:`float`): The priority used in ``call_hook``'s calling sequence. \
+ Lower value means higher priority.
+ """
+ self._name = name
+ assert priority >= 0, "invalid priority value: {}".format(priority)
+ self._priority = priority
+
+ @property
+ def name(self) -> str:
+ return self._name
+
+ @property
+ def priority(self) -> float:
+ return self._priority
+
+ @abstractmethod
+ def __call__(self, engine: Any) -> Any:
+ """
+ Overview:
+ Should be overwritten by subclass.
+ Arguments:
+ - engine (:obj:`Any`): For LearnerHook, it should be ``BaseLearner`` or its subclass.
+ """
+ raise NotImplementedError
+
+
+class LearnerHook(Hook):
+ """
+ Overview:
+ Abstract class for hooks used in Learner.
+ Interfaces:
+ __init__
+ Property:
+ name, priority, position
+
+ .. note::
+
+ Subclass should implement ``self.__call__``.
+ """
+ positions = ['before_run', 'after_run', 'before_iter', 'after_iter']
+
+ def __init__(self, *args, position: str, **kwargs) -> None:
+ """
+ Overview:
+ Init LearnerHook.
+ Arguments:
+ - position (:obj:`str`): The position to call hook in learner. \
+ Must be in ['before_run', 'after_run', 'before_iter', 'after_iter'].
+ """
+ super().__init__(*args, **kwargs)
+ assert position in self.positions
+ self._position = position
+
+ @property
+ def position(self) -> str:
+ return self._position
+
+
+class LoadCkptHook(LearnerHook):
+ """
+ Overview:
+ Hook to load checkpoint
+ Interfaces:
+ __init__, __call__
+ Property:
+ name, priority, position
+ """
+
+ def __init__(self, *args, ext_args: EasyDict = EasyDict(), **kwargs) -> None:
+ """
+ Overview:
+ Init LoadCkptHook.
+ Arguments:
+            - ext_args (:obj:`EasyDict`): Extended arguments. Use ``ext_args.load_path`` to set the checkpoint path.
+ """
+ super().__init__(*args, **kwargs)
+ self._load_path = ext_args['load_path']
+
+ def __call__(self, engine: 'BaseLearner') -> None: # noqa
+ """
+ Overview:
+ Load checkpoint to learner. Checkpoint info includes policy state_dict and iter num.
+ Arguments:
+ - engine (:obj:`BaseLearner`): The BaseLearner to load checkpoint to.
+ """
+ path = self._load_path
+ if path == '': # not load
+ return
+ state_dict = read_file(path)
+ if 'last_iter' in state_dict:
+ last_iter = state_dict.pop('last_iter')
+ engine.last_iter.update(last_iter)
+ engine.policy.load_state_dict(state_dict)
+ engine.info('{} load ckpt in {}'.format(engine.instance_name, path))
+
+
+class SaveCkptHook(LearnerHook):
+ """
+ Overview:
+ Hook to save checkpoint
+ Interfaces:
+ __init__, __call__
+ Property:
+ name, priority, position
+ """
+
+ def __init__(self, *args, ext_args: EasyDict = EasyDict(), **kwargs) -> None:
+ """
+ Overview:
+            Init SaveCkptHook.
+        Arguments:
+            - ext_args (:obj:`EasyDict`): Extended arguments. Use ``ext_args.freq`` to set ``save_ckpt_freq``.
+ """
+ super().__init__(*args, **kwargs)
+ if ext_args == {}:
+ self._freq = 1
+ else:
+ self._freq = ext_args.freq
+
+ def __call__(self, engine: 'BaseLearner') -> None: # noqa
+ """
+ Overview:
+ Save checkpoint in corresponding path.
+ Checkpoint info includes policy state_dict and iter num.
+ Arguments:
+ - engine (:obj:`BaseLearner`): the BaseLearner which needs to save checkpoint
+ """
+ if engine.rank == 0 and engine.last_iter.val % self._freq == 0:
+ if engine.instance_name == 'learner':
+ dirname = './{}/ckpt'.format(engine.exp_name)
+ else:
+ dirname = './{}/ckpt_{}'.format(engine.exp_name, engine.instance_name)
+ if not os.path.exists(dirname):
+ try:
+ os.makedirs(dirname)
+ except FileExistsError:
+ pass
+ ckpt_name = engine.ckpt_name if engine.ckpt_name else 'iteration_{}.pth.tar'.format(engine.last_iter.val)
+ path = os.path.join(dirname, ckpt_name)
+ state_dict = engine.policy.state_dict()
+ state_dict.update({'last_iter': engine.last_iter.val})
+ save_file(path, state_dict)
+ engine.info('{} save ckpt in {}'.format(engine.instance_name, path))
+
+
+class LogShowHook(LearnerHook):
+ """
+ Overview:
+ Hook to show log
+ Interfaces:
+ __init__, __call__
+ Property:
+ name, priority, position
+ """
+
+ def __init__(self, *args, ext_args: EasyDict = EasyDict(), **kwargs) -> None:
+ """
+ Overview:
+            Init LogShowHook.
+        Arguments:
+            - ext_args (:obj:`EasyDict`): Extended arguments. Use ``ext_args.freq`` to set the log show frequency.
+ """
+ super().__init__(*args, **kwargs)
+ if ext_args == {}:
+ self._freq = 1
+ else:
+ self._freq = ext_args.freq
+
+ def __call__(self, engine: 'BaseLearner') -> None: # noqa
+ """
+ Overview:
+            Show the log and update the monitor and tb_logger if rank is 0 and at interval iterations;
+            clear the log buffer for all learners regardless of rank.
+ Arguments:
+ - engine (:obj:`BaseLearner`): the BaseLearner
+ """
+ # Only show log for rank 0 learner
+ if engine.rank != 0:
+ for k in engine.log_buffer:
+ engine.log_buffer[k].clear()
+ return
+ # For 'scalar' type variables: log_buffer -> tick_monitor -> monitor_time.step
+ for k, v in engine.log_buffer['scalar'].items():
+ setattr(engine.monitor, k, v)
+ engine.monitor.time.step()
+
+ iters = engine.last_iter.val
+ if iters % self._freq == 0:
+ engine.info("=== Training Iteration {} Result ===".format(iters))
+ # For 'scalar' type variables: tick_monitor -> var_dict -> text_logger & tb_logger
+ var_dict = {}
+ log_vars = engine.policy.monitor_vars()
+ attr = 'avg'
+ for k in log_vars:
+ k_attr = k + '_' + attr
+ var_dict[k_attr] = getattr(engine.monitor, attr)[k]()
+ engine.logger.info(engine.logger.get_tabulate_vars_hor(var_dict))
+ for k, v in var_dict.items():
+ engine.tb_logger.add_scalar('{}_iter/'.format(engine.instance_name) + k, v, iters)
+ engine.tb_logger.add_scalar('{}_step/'.format(engine.instance_name) + k, v, engine._collector_envstep)
+ # For 'histogram' type variables: log_buffer -> tb_var_dict -> tb_logger
+ tb_var_dict = {}
+ for k in engine.log_buffer['histogram']:
+ new_k = '{}/'.format(engine.instance_name) + k
+ tb_var_dict[new_k] = engine.log_buffer['histogram'][k]
+ for k, v in tb_var_dict.items():
+ engine.tb_logger.add_histogram(k, v, iters)
+ for k in engine.log_buffer:
+ engine.log_buffer[k].clear()
+
+
+class LogReduceHook(LearnerHook):
+ """
+ Overview:
+ Hook to reduce the distributed(multi-gpu) logs
+ Interfaces:
+ __init__, __call__
+ Property:
+ name, priority, position
+ """
+
+ def __init__(self, *args, ext_args: EasyDict = EasyDict(), **kwargs) -> None:
+ """
+ Overview:
+            Init LogReduceHook.
+        Arguments:
+            - ext_args (:obj:`EasyDict`): Extended arguments (currently unused by this hook).
+ """
+ super().__init__(*args, **kwargs)
+
+ def __call__(self, engine: 'BaseLearner') -> None: # noqa
+ """
+ Overview:
+            Reduce the logs from distributed (multi-GPU) learners.
+ Arguments:
+ - engine (:obj:`BaseLearner`): the BaseLearner
+ """
+
+ def aggregate(data):
+ r"""
+ Overview:
+                Aggregate the information from all ranks (usually via a synchronous allreduce).
+            Arguments:
+                - data (:obj:`dict`): Data that needs to be reduced. \
+                    Could be dict, list, tuple, torch.Tensor, numbers.Integral or numbers.Real.
+ Returns:
+ - new_data (:obj:`dict`): data after reduce
+ """
+ if isinstance(data, dict):
+ new_data = {k: aggregate(v) for k, v in data.items()}
+ elif isinstance(data, list) or isinstance(data, tuple):
+ new_data = [aggregate(t) for t in data]
+ elif isinstance(data, torch.Tensor):
+ new_data = data.clone().detach()
+ if ding.enable_linklink:
+ allreduce(new_data)
+ else:
+ new_data = new_data.to(get_rank())
+ allreduce(new_data)
+ new_data = new_data.cpu()
+ elif isinstance(data, numbers.Integral) or isinstance(data, numbers.Real):
+ new_data = torch.scalar_tensor(data).reshape([1])
+ if ding.enable_linklink:
+ allreduce(new_data)
+ else:
+ new_data = new_data.to(get_rank())
+ allreduce(new_data)
+ new_data = new_data.cpu()
+ new_data = new_data.item()
+ else:
+ raise TypeError("invalid type in reduce: {}".format(type(data)))
+ return new_data
+
+ engine.log_buffer = aggregate(engine.log_buffer)
+
+
+hook_mapping = {
+ 'load_ckpt': LoadCkptHook,
+ 'save_ckpt': SaveCkptHook,
+ 'log_show': LogShowHook,
+ 'log_reduce': LogReduceHook,
+}
+
+
+def register_learner_hook(name: str, hook_type: type) -> None:
+ """
+ Overview:
+ Add a new LearnerHook class to hook_mapping, so you can build one instance with `build_learner_hook_by_cfg`.
+ Arguments:
+        - name (:obj:`str`): Name of the hook to register.
+        - hook_type (:obj:`type`): The hook class you implemented, which should be a subclass of ``LearnerHook``.
+ Examples:
+ >>> class HookToRegister(LearnerHook):
+ >>> def __init__(*args, **kargs):
+ >>> ...
+ >>> ...
+ >>> def __call__(*args, **kargs):
+ >>> ...
+ >>> ...
+ >>> ...
+ >>> register_learner_hook('name_of_hook', HookToRegister)
+ >>> ...
+ >>> hooks = build_learner_hook_by_cfg(cfg)
+ """
+ assert issubclass(hook_type, LearnerHook)
+ hook_mapping[name] = hook_type
+
+
+simplified_hook_mapping = {
+ 'log_show_after_iter': lambda freq: hook_mapping['log_show']
+ ('log_show', 20, position='after_iter', ext_args=EasyDict({'freq': freq})),
+ 'load_ckpt_before_run': lambda path: hook_mapping['load_ckpt']
+ ('load_ckpt', 20, position='before_run', ext_args=EasyDict({'load_path': path})),
+ 'save_ckpt_after_iter': lambda freq: hook_mapping['save_ckpt']
+ ('save_ckpt_after_iter', 20, position='after_iter', ext_args=EasyDict({'freq': freq})),
+ 'save_ckpt_after_run': lambda _: hook_mapping['save_ckpt']('save_ckpt_after_run', 20, position='after_run'),
+ 'log_reduce_after_iter': lambda _: hook_mapping['log_reduce']('log_reduce_after_iter', 10, position='after_iter'),
+}
+
+
+def find_char(s: str, flag: str, num: int, reverse: bool = False) -> int:
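+    """
+    Overview:
+        Return the index of the ``num``-th occurrence of character ``flag`` in string ``s``.
+        If ``reverse`` is True, occurrences are counted from the end of the string. Return -1 if not found.
+    """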
+ assert num > 0, num
+ count = 0
+ iterable_obj = reversed(range(len(s))) if reverse else range(len(s))
+ for i in iterable_obj:
+ if s[i] == flag:
+ count += 1
+ if count == num:
+ return i
+ return -1
+
+
+def build_learner_hook_by_cfg(cfg: EasyDict) -> Dict[str, List[Hook]]:
+ """
+ Overview:
+ Build the learner hooks in hook_mapping by config.
+ This function is often used to initialize ``hooks`` according to cfg,
+ while add_learner_hook() is often used to add an existing LearnerHook to `hooks`.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config dict. Should be like {'hook': xxx}.
+ Returns:
+        - hooks (:obj:`Dict[str, List[Hook]]`): Keys should be in ['before_run', 'after_run', 'before_iter', \
+            'after_iter'], each value should be a list containing all hooks in this position.
+ Note:
+ Lower value means higher priority.
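+    Examples:
+        >>> # Simplified keys (see ``simplified_hook_mapping``) map a single value to a hook:
+        >>> cfg = EasyDict(dict(save_ckpt_after_iter=20, log_show_after_iter=100))
+        >>> hooks = build_learner_hook_by_cfg(cfg)
+        >>> # hooks['after_iter'] now contains a SaveCkptHook and a LogShowHook.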
+ """
+ hooks = {k: [] for k in LearnerHook.positions}
+ for key, value in cfg.items():
+ if key in simplified_hook_mapping and not isinstance(value, dict):
+ pos = key[find_char(key, '_', 2, reverse=True) + 1:]
+ hook = simplified_hook_mapping[key](value)
+ priority = hook.priority
+ else:
+ priority = value.get('priority', 100)
+ pos = value.position
+ ext_args = value.get('ext_args', {})
+ hook = hook_mapping[value.type](value.name, priority, position=pos, ext_args=ext_args)
+ idx = 0
+ for i in reversed(range(len(hooks[pos]))):
+ if priority >= hooks[pos][i].priority:
+ idx = i + 1
+ break
+ hooks[pos].insert(idx, hook)
+ return hooks
+
+
+def add_learner_hook(hooks: Dict[str, List[Hook]], hook: LearnerHook) -> None:
+ """
+ Overview:
+        Add a learner hook (:obj:`LearnerHook`) to hooks (:obj:`Dict[str, List[Hook]]`).
+    Arguments:
+        - hooks (:obj:`Dict[str, List[Hook]]`): You can refer to ``build_learner_hook_by_cfg``'s return ``hooks``.
+ - hook (:obj:`LearnerHook`): The LearnerHook which will be added to ``hooks``.
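+    Examples:
+        >>> # Mirrors ``tests/test_learner_hook.py``:
+        >>> hooks = build_learner_hook_by_cfg(EasyDict(dict(save_ckpt_after_iter=20)))
+        >>> hook = LogShowHook('log_show', 20, position='after_iter', ext_args=EasyDict({'freq': 100}))
+        >>> add_learner_hook(hooks, hook)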
+ """
+ position = hook.position
+ priority = hook.priority
+ idx = 0
+ for i in reversed(range(len(hooks[position]))):
+ if priority >= hooks[position][i].priority:
+ idx = i + 1
+ break
+ assert isinstance(hook, LearnerHook)
+ hooks[position].insert(idx, hook)
+
+
+def merge_hooks(hooks1: Dict[str, List[Hook]], hooks2: Dict[str, List[Hook]]) -> Dict[str, List[Hook]]:
+ """
+ Overview:
+        Merge two hooks dicts with the same keys; each value list is sorted by hook priority with a stable sort.
+    Arguments:
+        - hooks1 (:obj:`Dict[str, List[Hook]]`): hooks1 to be merged.
+        - hooks2 (:obj:`Dict[str, List[Hook]]`): hooks2 to be merged.
+    Returns:
+        - new_hooks (:obj:`Dict[str, List[Hook]]`): New merged hooks dict.
+ Note:
+ This merge function uses stable sort method without disturbing the same priority hook.
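+    Examples:
+        >>> # Mirrors ``tests/test_learner_hook.py``:
+        >>> hooks1 = build_learner_hook_by_cfg(EasyDict(dict(save_ckpt_after_iter=20)))
+        >>> hooks2 = build_learner_hook_by_cfg(EasyDict(dict(log_show_after_iter=100)))
+        >>> hooks = merge_hooks(hooks1, hooks2)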
+ """
+ assert set(hooks1.keys()) == set(hooks2.keys())
+ new_hooks = {}
+ for k in hooks1.keys():
+ new_hooks[k] = sorted(hooks1[k] + hooks2[k], key=lambda x: x.priority)
+ return new_hooks
+
+
+def show_hooks(hooks: Dict[str, List[Hook]]) -> None:
+ for k in hooks.keys():
+ print('{}: {}'.format(k, [x.__class__.__name__ for x in hooks[k]]))
diff --git a/DI-engine/ding/worker/learner/tests/test_base_learner.py b/DI-engine/ding/worker/learner/tests/test_base_learner.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd3d00df2c9da02a007a3534255a85a49722bf66
--- /dev/null
+++ b/DI-engine/ding/worker/learner/tests/test_base_learner.py
@@ -0,0 +1,136 @@
+import os
+import time
+
+import pytest
+import torch
+from easydict import EasyDict
+from typing import Any
+from functools import partial
+
+from ding.worker import BaseLearner
+from ding.worker.learner import LearnerHook, add_learner_hook, create_learner
+
+
+class FakeLearner(BaseLearner):
+
+ @staticmethod
+ def random_data():
+ return {
+ 'obs': torch.randn(2),
+ 'replay_buffer_idx': 0,
+ 'replay_unique_id': 0,
+ }
+
+ def get_data(self, batch_size):
+ return [self.random_data for _ in range(batch_size)]
+
+
+class FakePolicy:
+
+ def __init__(self):
+ self._model = torch.nn.Identity()
+
+ def forward(self, x):
+ return {
+ 'total_loss': torch.randn(1).squeeze(),
+ 'cur_lr': 0.1,
+ 'priority': [1., 2., 3.],
+ '[histogram]h_example': [1.2, 2.3, 3.4],
+ '[scalars]s_example': {
+ 'a': 5.,
+ 'b': 4.
+ },
+ }
+
+ def data_preprocess(self, x):
+ return x
+
+ def state_dict(self):
+ return {'model': self._model}
+
+ def load_state_dict(self, state_dict):
+ pass
+
+ def info(self):
+ return 'FakePolicy'
+
+ def monitor_vars(self):
+ return ['total_loss', 'cur_lr']
+
+ def get_attribute(self, name):
+ if name == 'cuda':
+ return False
+ elif name == 'device':
+ return 'cpu'
+ elif name == 'batch_size':
+ return 2
+ elif name == 'on_policy':
+ return False
+ else:
+ raise KeyError
+
+ def reset(self):
+ pass
+
+
+@pytest.mark.unittest
+class TestBaseLearner:
+
+ def _get_cfg(self, path):
+ cfg = BaseLearner.default_config()
+ cfg.import_names = []
+ cfg.learner_type = 'fake'
+ cfg.train_iterations = 10
+ cfg.hook.load_ckpt_before_run = path
+ cfg.hook.log_show_after_iter = 5
+ # Another way to build hook: Complete config
+ cfg.hook.save_ckpt_after_iter = dict(
+ name='save_ckpt_after_iter', type='save_ckpt', priority=40, position='after_iter', ext_args={'freq': 5}
+ )
+
+ return cfg
+
+ def test_naive(self):
+ os.popen('rm -rf iteration_5.pth.tar*')
+ time.sleep(1.0)
+ with pytest.raises(KeyError):
+ create_learner(EasyDict({'type': 'placeholder', 'import_names': []}))
+ path = os.path.join(os.path.dirname(__file__), './iteration_5.pth.tar')
+ torch.save({'model': {}, 'last_iter': 5}, path)
+ time.sleep(0.5)
+ cfg = self._get_cfg(path)
+ learner = FakeLearner(cfg, exp_name='exp_test')
+ learner.policy = FakePolicy()
+ learner.setup_dataloader()
+ learner.start()
+ time.sleep(2)
+ assert learner.last_iter.val == 10 + 5
+
+ # test hook
+ dir_name = '{}/ckpt'.format(learner.exp_name)
+ for n in [5, 10, 15]:
+ assert os.path.exists(dir_name + '/iteration_{}.pth.tar'.format(n))
+ for n in [0, 4, 7, 12]:
+ assert not os.path.exists(dir_name + '/iteration_{}.pth.tar'.format(n))
+ learner.debug('iter [5, 10, 15] exists; iter [0, 4, 7, 12] does not exist.')
+
+ learner.save_checkpoint('best')
+
+ info = learner.learn_info
+ for info_name in ['learner_step', 'priority_info', 'learner_done']:
+ assert info_name in info
+
+ class FakeHook(LearnerHook):
+
+ def __call__(self, engine: Any) -> Any:
+ pass
+
+ original_hook_num = len(learner._hooks['after_run'])
+ add_learner_hook(learner._hooks, FakeHook(name='fake_hook', priority=30, position='after_run'))
+ assert len(learner._hooks['after_run']) == original_hook_num + 1
+
+ os.popen('rm -rf iteration_5.pth.tar*')
+ os.popen('rm -rf ' + dir_name)
+ os.popen('rm -rf learner')
+ os.popen('rm -rf log')
+ learner.close()
diff --git a/DI-engine/ding/worker/learner/tests/test_learner_hook.py b/DI-engine/ding/worker/learner/tests/test_learner_hook.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf29f91a382544d5cb95092f088dd6907e9d3599
--- /dev/null
+++ b/DI-engine/ding/worker/learner/tests/test_learner_hook.py
@@ -0,0 +1,75 @@
+import easydict
+import pytest
+from ding.worker.learner import register_learner_hook, build_learner_hook_by_cfg, LearnerHook
+from ding.worker.learner.learner_hook import SaveCkptHook, LoadCkptHook, LogShowHook, LogReduceHook
+from ding.worker.learner.learner_hook import show_hooks, add_learner_hook, merge_hooks
+from easydict import EasyDict
+
+
+@pytest.fixture(scope='function')
+def setup_simplified_hook_cfg():
+ return dict(
+ save_ckpt_after_iter=20,
+ save_ckpt_after_run=True,
+ )
+
+
+@pytest.fixture(scope='function')
+def fake_setup_simplified_hook_cfg():
+ return dict(
+ log_show_after_iter=20,
+ log_reduce_after_iter=True,
+ )
+
+
+@pytest.mark.unittest
+class TestLearnerHook:
+
+ def test_register(self):
+
+ class FakeHook(LearnerHook):
+ pass
+
+ register_learner_hook('fake', FakeHook)
+ with pytest.raises(AssertionError):
+ register_learner_hook('placeholder', type)
+
+ def test_build_learner_hook_by_cfg(self, setup_simplified_hook_cfg):
+ hooks = build_learner_hook_by_cfg(setup_simplified_hook_cfg)
+ show_hooks(hooks)
+ assert len(hooks['before_run']) == 0
+ assert len(hooks['before_iter']) == 0
+ assert len(hooks['after_iter']) == 1
+ assert isinstance(hooks['after_iter'][0], SaveCkptHook)
+ assert len(hooks['after_run']) == 1
+ assert isinstance(hooks['after_run'][0], SaveCkptHook)
+
+ def test_add_learner_hook(self, setup_simplified_hook_cfg):
+ hooks = build_learner_hook_by_cfg(setup_simplified_hook_cfg)
+ hook_1 = LogShowHook('log_show', 20, position='after_iter', ext_args=EasyDict({'freq': 100}))
+ add_learner_hook(hooks, hook_1)
+ hook_2 = LoadCkptHook('load_ckpt', 20, position='before_run', ext_args=EasyDict({'load_path': './model.pth'}))
+ add_learner_hook(hooks, hook_2)
+ hook_3 = LogReduceHook('log_reduce', 10, position='after_iter')
+ add_learner_hook(hooks, hook_3)
+
+ show_hooks(hooks)
+ assert len(hooks['after_iter']) == 3
+ assert len(hooks['after_run']) == 1
+ assert len(hooks['before_run']) == 1
+ assert len(hooks['before_iter']) == 0
+ assert isinstance(hooks['after_run'][0], SaveCkptHook)
+ assert isinstance(hooks['before_run'][0], LoadCkptHook)
+
+ def test_merge_hooks(self, setup_simplified_hook_cfg, fake_setup_simplified_hook_cfg):
+ hooks = build_learner_hook_by_cfg(setup_simplified_hook_cfg)
+ show_hooks(hooks)
+ fake_hooks = build_learner_hook_by_cfg(fake_setup_simplified_hook_cfg)
+ show_hooks(fake_hooks)
+ hooks_ = merge_hooks(hooks, fake_hooks)
+ show_hooks(hooks_)
+ assert len(hooks_['after_iter']) == 3
+ assert len(hooks_['after_run']) == 1
+ assert len(hooks_['before_run']) == 0
+ assert len(hooks_['before_iter']) == 0
+ assert isinstance(hooks['after_run'][0], SaveCkptHook)
diff --git a/DI-engine/ding/worker/replay_buffer/__init__.py b/DI-engine/ding/worker/replay_buffer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4f1bf3e8772af5543dd8b43c4654d328150e3bd
--- /dev/null
+++ b/DI-engine/ding/worker/replay_buffer/__init__.py
@@ -0,0 +1,4 @@
+from .base_buffer import IBuffer, create_buffer, get_buffer_cls
+from .naive_buffer import NaiveReplayBuffer, SequenceReplayBuffer
+from .advanced_buffer import AdvancedReplayBuffer
+from .episode_buffer import EpisodeReplayBuffer
diff --git a/DI-engine/ding/worker/replay_buffer/advanced_buffer.py b/DI-engine/ding/worker/replay_buffer/advanced_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..31b4c46d666d4d3c7340d0f98efa52ad60531022
--- /dev/null
+++ b/DI-engine/ding/worker/replay_buffer/advanced_buffer.py
@@ -0,0 +1,787 @@
+import os
+import copy
+import time
+from typing import Union, Any, Optional, List, Dict, Tuple
+import numpy as np
+import hickle
+
+from ding.worker.replay_buffer import IBuffer
+from ding.utils import SumSegmentTree, MinSegmentTree, BUFFER_REGISTRY
+from ding.utils import LockContext, LockContextType, build_logger, get_rank
+from ding.utils.autolog import TickTime
+from .utils import UsedDataRemover, generate_id, SampledDataAttrMonitor, PeriodicThruputMonitor, ThruputController
+
+
+def to_positive_index(idx: Union[int, None], size: int) -> int:
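+    """
+    Overview:
+        Convert a possibly negative index ``idx`` to its positive equivalent for a container of length ``size``.
+        ``None`` and non-negative indexes are returned unchanged.
+    """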
+ if idx is None or idx >= 0:
+ return idx
+ else:
+ return size + idx
+
+
+@BUFFER_REGISTRY.register('advanced')
+class AdvancedReplayBuffer(IBuffer):
+ r"""
+ Overview:
+        Prioritized replay buffer, which extends ``NaiveReplayBuffer``'s basic functionality.
+ This replay buffer adds:
+
+ 1) Prioritized experience replay implemented by segment tree.
+        2) Data quality monitor, which monitors the use count and staleness of each piece of data.
+        3) Throughput monitor and control.
+        4) Logger, which logs 2) and 3) to TensorBoard or the text logger.
+ Interface:
+ start, close, push, update, sample, clear, count, state_dict, load_state_dict, default_config
+ Property:
+ beta, replay_buffer_size, push_count
+ """
+
+ config = dict(
+ type='advanced',
+ # Max length of the buffer.
+ replay_buffer_size=4096,
+        # Max number of times one data item can be used. Data will be removed once it has been used too many times.
+        max_use=float("inf"),
+        # Max staleness duration of one data item in the buffer; data will be removed if
+        # the duration from collection to training is too long, i.e. the data is too stale.
+        max_staleness=float("inf"),
+        # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
+        alpha=0.6,
+        # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+        beta=0.4,
+        # Anneal step for beta: 0 means no annealing
+        anneal_step=int(1e5),
+        # Whether to track used data, i.e. data that has been removed from the buffer and will never be used again.
+        enable_track_used_data=False,
+        # Whether to deepcopy data when inserting and sampling. For safety purposes.
+ deepcopy=False,
+ thruput_controller=dict(
+            # Rate limit. The ratio of "Sample Count" to "Push Count" should be in the [min, max] range.
+            # If greater than the max ratio, return `None` when calling ``sample``;
+            # If smaller than the min ratio, throw away the new data when calling ``push``.
+            push_sample_rate_limit=dict(
+                max=float("inf"),
+                min=0,
+            ),
+            # How many seconds the controller takes into account, i.e. for the past `window_seconds` seconds,
+            # the sample/push rate will be calculated and compared with `push_sample_rate_limit`.
+            window_seconds=30,
+            # The minimum ratio that the buffer must satisfy before anything can be sampled.
+            # The ratio is calculated as "Valid Count" divided by "Batch Size".
+            # E.g. if sample_min_limit_ratio = 2.0, valid_count = 50 and batch_size = 32, sampling is
+            # forbidden because 50 < 2.0 * 32.
+            sample_min_limit_ratio=1,
+ ),
+ # Monitor configuration for monitor and logger to use. This part does not affect buffer's function.
+ monitor=dict(
+ sampled_data_attr=dict(
+                # Past data will be used for the moving average.
+ average_range=5,
+ # Print data attributes every `print_freq` samples.
+ print_freq=200, # times
+ ),
+ periodic_thruput=dict(
+                # Every `seconds` seconds, the throughput (push/sample/remove count) will be printed.
+ seconds=60,
+ ),
+ ),
+ )
+
+ def __init__(
+ self,
+ cfg: dict,
+ tb_logger: Optional['SummaryWriter'] = None, # noqa
+ exp_name: Optional[str] = 'default_experiment',
+ instance_name: Optional[str] = 'buffer',
+    ) -> None:
+ """
+ Overview:
+ Initialize the buffer
+ Arguments:
+ - cfg (:obj:`dict`): Config dict.
+ - tb_logger (:obj:`Optional['SummaryWriter']`): Outer tb logger. Usually get this argument in serial mode.
+ - exp_name (:obj:`Optional[str]`): Name of this experiment.
+ - instance_name (:obj:`Optional[str]`): Name of this instance.
+ """
+ self._exp_name = exp_name
+ self._instance_name = instance_name
+ self._end_flag = False
+ self._cfg = cfg
+ self._rank = get_rank()
+ self._replay_buffer_size = self._cfg.replay_buffer_size
+ self._deepcopy = self._cfg.deepcopy
+ # ``_data`` is a circular queue to store data (full data or meta data)
+ self._data = [None for _ in range(self._replay_buffer_size)]
+        # Current valid data count, indicating how many elements in ``self._data`` are valid.
+ self._valid_count = 0
+ # How many pieces of data have been pushed into this buffer, should be no less than ``_valid_count``.
+ self._push_count = 0
+ # Point to the tail position where next data can be inserted, i.e. latest inserted data's next position.
+ self._tail = 0
+ # Is used to generate a unique id for each data: If a new data is inserted, its unique id will be this.
+ self._next_unique_id = 0
+ # Lock to guarantee thread safe
+ self._lock = LockContext(type_=LockContextType.THREAD_LOCK)
+        # Point to the head of the circular queue, i.e. the stalest (oldest) data in the queue.
+        # Because the buffer removes data due to staleness or use count, and because head stays at 0
+        # while the queue is not yet full, ``head`` may not be equal to ``tail``;
+        # otherwise the two are the same. Head is used to optimize the staleness check in ``_sample_check``.
+ self._head = 0
+ # use_count is {position_idx: use_count}
+ self._use_count = {idx: 0 for idx in range(self._cfg.replay_buffer_size)}
+        # Max priority till now. Used to initialize a data item's priority if "priority" is not passed in.
+        self._max_priority = 1.0
+        # A small positive number to avoid edge cases, e.g. "priority" == 0.
+ self._eps = 1e-5
+ # Data check function list, used in ``_append`` and ``_extend``. This buffer requires data to be dict.
+ self.check_list = [lambda x: isinstance(x, dict)]
+
+ self._max_use = self._cfg.max_use
+ self._max_staleness = self._cfg.max_staleness
+ self.alpha = self._cfg.alpha
+ assert 0 <= self.alpha <= 1, self.alpha
+ self._beta = self._cfg.beta
+ assert 0 <= self._beta <= 1, self._beta
+ self._anneal_step = self._cfg.anneal_step
+ if self._anneal_step != 0:
+ self._beta_anneal_step = (1 - self._beta) / self._anneal_step
+
+ # Prioritized sample.
+        # Capacity needs to be a power of 2.
+ capacity = int(np.power(2, np.ceil(np.log2(self.replay_buffer_size))))
+ # Sum segtree and min segtree are used to sample data according to priority.
+ self._sum_tree = SumSegmentTree(capacity)
+ self._min_tree = MinSegmentTree(capacity)
+
+ # Thruput controller
+ push_sample_rate_limit = self._cfg.thruput_controller.push_sample_rate_limit
+ self._always_can_push = True if push_sample_rate_limit['max'] == float('inf') else False
+ self._always_can_sample = True if push_sample_rate_limit['min'] == 0 else False
+ self._use_thruput_controller = not self._always_can_push or not self._always_can_sample
+ if self._use_thruput_controller:
+ self._thruput_controller = ThruputController(self._cfg.thruput_controller)
+ self._sample_min_limit_ratio = self._cfg.thruput_controller.sample_min_limit_ratio
+ assert self._sample_min_limit_ratio >= 1
+
+ # Monitor & Logger
+ monitor_cfg = self._cfg.monitor
+ if self._rank == 0:
+ if tb_logger is not None:
+ self._logger, _ = build_logger(
+ './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
+ )
+ self._tb_logger = tb_logger
+ else:
+ self._logger, self._tb_logger = build_logger(
+ './{}/log/{}'.format(self._exp_name, self._instance_name),
+ self._instance_name,
+ )
+ else:
+ self._logger, _ = build_logger(
+ './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
+ )
+ self._tb_logger = None
+ self._start_time = time.time()
+ # Sampled data attributes.
+ self._cur_learner_iter = -1
+ self._cur_collector_envstep = -1
+ self._sampled_data_attr_print_count = 0
+ self._sampled_data_attr_monitor = SampledDataAttrMonitor(
+ TickTime(), expire=monitor_cfg.sampled_data_attr.average_range
+ )
+ self._sampled_data_attr_print_freq = monitor_cfg.sampled_data_attr.print_freq
+ # Periodic thruput.
+ if self._rank == 0:
+ self._periodic_thruput_monitor = PeriodicThruputMonitor(
+ self._instance_name, monitor_cfg.periodic_thruput, self._logger, self._tb_logger
+ )
+
+ # Used data remover
+ self._enable_track_used_data = self._cfg.enable_track_used_data
+ if self._enable_track_used_data:
+ self._used_data_remover = UsedDataRemover()
+
+ def start(self) -> None:
+ """
+ Overview:
+            Start the buffer's used_data_remover thread if track_used_data is enabled.
+ """
+ if self._enable_track_used_data:
+ self._used_data_remover.start()
+
+ def close(self) -> None:
+ """
+ Overview:
+            Clear the buffer; join the buffer's used_data_remover thread if track_used_data is enabled.
+            Join the periodic throughput monitor and flush the tensorboard logger.
+ """
+ if self._end_flag:
+ return
+ self._end_flag = True
+ self.clear()
+ if self._rank == 0:
+ self._periodic_thruput_monitor.close()
+ self._tb_logger.flush()
+ self._tb_logger.close()
+ if self._enable_track_used_data:
+ self._used_data_remover.close()
+
+ def sample(self, size: int, cur_learner_iter: int, sample_range: slice = None) -> Optional[list]:
+ """
+ Overview:
+ Sample data with length ``size``.
+ Arguments:
+ - size (:obj:`int`): The number of the data that will be sampled.
+ - cur_learner_iter (:obj:`int`): Learner's current iteration, used to calculate staleness.
+ - sample_range (:obj:`slice`): Buffer slice for sampling, such as `slice(-10, None)`, which \
+ means only sample among the last 10 data
+ Returns:
+ - sample_data (:obj:`list`): A list of data with length ``size``
+ ReturnsKeys:
+ - necessary: original keys(e.g. `obs`, `action`, `next_obs`, `reward`, `info`), \
+ `replay_unique_id`, `replay_buffer_idx`
+ - optional(if use priority): `IS`, `priority`
+ """
+ if size == 0:
+ return []
+        can_sample_staleness, staleness_info = self._sample_check(size, cur_learner_iter)
+        if self._always_can_sample:
+            can_sample_thruput, thruput_info = True, "Always can sample because push_sample_rate_limit['min'] == 0"
+        else:
+            can_sample_thruput, thruput_info = self._thruput_controller.can_sample(size)
+        if not can_sample_staleness or not can_sample_thruput:
+            self._logger.info(
+                'Refuse to sample due to -- \nstaleness: {}, {} \nthruput: {}, {}'.format(
+                    not can_sample_staleness, staleness_info, not can_sample_thruput, thruput_info
+                )
+            )
+ return None
+ with self._lock:
+ indices = self._get_indices(size, sample_range)
+ result = self._sample_with_indices(indices, cur_learner_iter)
+            # Deepcopy duplicated entries in ``result``, since ``self._get_indices`` may return repeated
+            # indices, i.e. the same data item may be sampled more than once.
+ # if self._deepcopy==True -> all data is different
+ # if len(indices) == len(set(indices)) -> no duplicate data
+ if not self._deepcopy and len(indices) != len(set(indices)):
+ for i, index in enumerate(indices):
+ tmp = []
+ for j in range(i + 1, size):
+ if index == indices[j]:
+ tmp.append(j)
+ for j in tmp:
+ result[j] = copy.deepcopy(result[j])
+ self._monitor_update_of_sample(result, cur_learner_iter)
+ return result
+
+ def push(self, data: Union[List[Any], Any], cur_collector_envstep: int) -> None:
+ r"""
+ Overview:
+ Push a data into buffer.
+ Arguments:
+            - data (:obj:`Union[List[Any], Any]`): The data which will be pushed into the buffer. Can be one \
+                (in `Any` type), or many (in `List[Any]` type).
+ - cur_collector_envstep (:obj:`int`): Collector's current env step.
+ """
+ push_size = len(data) if isinstance(data, list) else 1
+ if self._always_can_push:
+ can_push, push_info = True, "Always can push because push_sample_rate_limit['max'] == float('inf')"
+ else:
+ can_push, push_info = self._thruput_controller.can_push(push_size)
+ if not can_push:
+ self._logger.info('Refuse to push because {}'.format(push_info))
+ return
+ if isinstance(data, list):
+ self._extend(data, cur_collector_envstep)
+ else:
+ self._append(data, cur_collector_envstep)
+
+ def save_data(self, file_name: str):
+ if not os.path.exists(os.path.dirname(file_name)):
+ if os.path.dirname(file_name) != "":
+ os.makedirs(os.path.dirname(file_name))
+ hickle.dump(py_obj=self._data, file_obj=file_name)
+
+ def load_data(self, file_name: str):
+ self.push(hickle.load(file_name), 0)
+
+ def _sample_check(self, size: int, cur_learner_iter: int) -> Tuple[bool, str]:
+ r"""
+ Overview:
+ Do preparations for sampling and check whether data is enough for sampling
+            Preparation includes removing stale data from ``self._data``.
+            The check judges whether this buffer has enough valid data to sample ``size`` items.
+ Arguments:
+ - size (:obj:`int`): The number of the data that will be sampled.
+ - cur_learner_iter (:obj:`int`): Learner's current iteration, used to calculate staleness.
+ Returns:
+ - can_sample (:obj:`bool`): Whether this buffer can sample enough data.
+ - str_info (:obj:`str`): Str type info, explaining why cannot sample. (If can sample, return "Can sample")
+
+ .. note::
+ This function must be called before data sample.
+ """
+ staleness_remove_count = 0
+ with self._lock:
+ if self._max_staleness != float("inf"):
+ p = self._head
+ while True:
+ if self._data[p] is not None:
+ staleness = self._calculate_staleness(p, cur_learner_iter)
+ if staleness >= self._max_staleness:
+ self._remove(p)
+ staleness_remove_count += 1
+ else:
+ # Since the circular queue ``self._data`` guarantees that data's staleness is decreasing
+ # from index self._head to index self._tail - 1, we can jump out of the loop as soon as
+ # meeting a fresh enough data
+ break
+ p = (p + 1) % self._replay_buffer_size
+ if p == self._tail:
+ # Traverse a circle and go back to the tail, which means can stop staleness checking now
+ break
+ str_info = "Remove {} elements due to staleness. ".format(staleness_remove_count)
+ if self._valid_count / size < self._sample_min_limit_ratio:
+ str_info += "Not enough for sampling. valid({}) / sample({}) < sample_min_limit_ratio({})".format(
+ self._valid_count, size, self._sample_min_limit_ratio
+ )
+ return False, str_info
+ else:
+ str_info += "Can sample."
+ return True, str_info
+
+ def _append(self, ori_data: Any, cur_collector_envstep: int = -1) -> None:
+ r"""
+ Overview:
+ Append a data item into queue.
+ Add two keys in data:
+
+ - replay_unique_id: The data item's unique id, using ``generate_id`` to generate it.
+            - replay_buffer_idx: The data item's position index in the queue. This position may already hold an \
+                old element, which is then replaced by the new input one. ``self._tail`` is used to locate it.
+ Arguments:
+ - ori_data (:obj:`Any`): The data which will be inserted.
+ - cur_collector_envstep (:obj:`int`): Collector's current env step, used to draw tensorboard.
+ """
+ with self._lock:
+ if self._deepcopy:
+ data = copy.deepcopy(ori_data)
+ else:
+ data = ori_data
+ try:
+ assert self._data_check(data)
+ except AssertionError:
+ # If data check fails, log it and return without any operations.
+ self._logger.info('Illegal data type [{}], reject it...'.format(type(data)))
+ return
+ self._push_count += 1
+ # remove->set weight->set data
+ if self._data[self._tail] is not None:
+ self._head = (self._tail + 1) % self._replay_buffer_size
+ self._remove(self._tail)
+ data['replay_unique_id'] = generate_id(self._instance_name, self._next_unique_id)
+ data['replay_buffer_idx'] = self._tail
+ self._set_weight(data)
+ self._data[self._tail] = data
+ self._valid_count += 1
+ if self._rank == 0:
+ self._periodic_thruput_monitor.valid_count = self._valid_count
+ self._tail = (self._tail + 1) % self._replay_buffer_size
+ self._next_unique_id += 1
+ self._monitor_update_of_push(1, cur_collector_envstep)
+
+ def _extend(self, ori_data: List[Any], cur_collector_envstep: int = -1) -> None:
+ r"""
+ Overview:
+ Extend a data list into queue.
+            Add two keys to each data item; refer to ``_append`` for more details.
+ Arguments:
+ - ori_data (:obj:`List[Any]`): The data list.
+ - cur_collector_envstep (:obj:`int`): Collector's current env step, used to draw tensorboard.
+ """
+ with self._lock:
+ if self._deepcopy:
+ data = copy.deepcopy(ori_data)
+ else:
+ data = ori_data
+ check_result = [self._data_check(d) for d in data]
+            # Only keep data items that pass ``_data_check``.
+ valid_data = [d for d, flag in zip(data, check_result) if flag]
+ length = len(valid_data)
+            # When updating ``_data`` and ``_use_count``, two cases must be considered regarding the
+            # relationship between "tail + data length" and "queue max length", to check whether the new
+            # data will exceed the queue's max length and wrap around.
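+            # e.g. with replay_buffer_size=64, tail=60 and 10 valid items: the first 4 items fill
+            # slots 60..63 and the remaining 6 wrap around to slots 0..5 (second branch below).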
+ if self._tail + length <= self._replay_buffer_size:
+ for j in range(self._tail, self._tail + length):
+ if self._data[j] is not None:
+ self._head = (j + 1) % self._replay_buffer_size
+ self._remove(j)
+ for i in range(length):
+ valid_data[i]['replay_unique_id'] = generate_id(self._instance_name, self._next_unique_id + i)
+ valid_data[i]['replay_buffer_idx'] = (self._tail + i) % self._replay_buffer_size
+ self._set_weight(valid_data[i])
+ self._push_count += 1
+ self._data[self._tail:self._tail + length] = valid_data
+ else:
+ data_start = self._tail
+ valid_data_start = 0
+ residual_num = len(valid_data)
+ while True:
+ space = self._replay_buffer_size - data_start
+ L = min(space, residual_num)
+ for j in range(data_start, data_start + L):
+ if self._data[j] is not None:
+ self._head = (j + 1) % self._replay_buffer_size
+ self._remove(j)
+ for i in range(valid_data_start, valid_data_start + L):
+ valid_data[i]['replay_unique_id'] = generate_id(self._instance_name, self._next_unique_id + i)
+ valid_data[i]['replay_buffer_idx'] = (self._tail + i) % self._replay_buffer_size
+ self._set_weight(valid_data[i])
+ self._push_count += 1
+ self._data[data_start:data_start + L] = valid_data[valid_data_start:valid_data_start + L]
+ residual_num -= L
+ if residual_num <= 0:
+ break
+ else:
+ data_start = 0
+ valid_data_start += L
+ self._valid_count += len(valid_data)
+ if self._rank == 0:
+ self._periodic_thruput_monitor.valid_count = self._valid_count
+ # Update ``tail`` and ``next_unique_id`` after the whole list is pushed into buffer.
+ self._tail = (self._tail + length) % self._replay_buffer_size
+ self._next_unique_id += length
+ self._monitor_update_of_push(length, cur_collector_envstep)
+
+ def update(self, info: dict) -> None:
+ r"""
+ Overview:
+            Update a data item's priority. Use `replay_buffer_idx` to locate it and `replay_unique_id` to verify it.
+ Arguments:
+ - info (:obj:`dict`): Info dict containing all necessary keys for priority update.
+ ArgumentsKeys:
+ - necessary: `replay_unique_id`, `replay_buffer_idx`, `priority`. All values are lists with the same length.
+ """
+ with self._lock:
+ if 'priority' not in info:
+ return
+ data = [info['replay_unique_id'], info['replay_buffer_idx'], info['priority']]
+ for id_, idx, priority in zip(*data):
+ # Only if the data still exists in the queue, will the update operation be done.
+ if self._data[idx] is not None \
+ and self._data[idx]['replay_unique_id'] == id_: # Verify the same transition(data)
+ assert priority >= 0, priority
+ assert self._data[idx]['replay_buffer_idx'] == idx
+ self._data[idx]['priority'] = priority + self._eps # Add epsilon to avoid priority == 0
+ self._set_weight(self._data[idx])
+ # Update max priority
+ self._max_priority = max(self._max_priority, priority)
+ else:
+ self._logger.debug(
+                        '[Skip Update]: buffer_idx: {}; replay_unique_id: {}; priority: {}'.format(
+ idx, id_, priority
+ )
+ )
+
+ def clear(self) -> None:
+ """
+ Overview:
+ Clear all the data and reset the related variables.
+ """
+ with self._lock:
+ for i in range(len(self._data)):
+ self._remove(i)
+ assert self._valid_count == 0, self._valid_count
+ self._head = 0
+ self._tail = 0
+ self._max_priority = 1.0
+
+ def __del__(self) -> None:
+ """
+ Overview:
+ Call ``close`` to delete the object.
+ """
+ if not self._end_flag:
+ self.close()
+
+ def _set_weight(self, data: Dict) -> None:
+ r"""
+ Overview:
+ Set sumtree and mintree's weight of the input data according to its priority.
+            If the input data does not have the key "priority", it will be set to ``self._max_priority`` instead.
+        Arguments:
+            - data (:obj:`Dict`): The data whose priority (weight) in the segment trees should be set/updated.
+ """
+ if 'priority' not in data.keys() or data['priority'] is None:
+ data['priority'] = self._max_priority
+ weight = data['priority'] ** self.alpha
+ idx = data['replay_buffer_idx']
+ self._sum_tree[idx] = weight
+ self._min_tree[idx] = weight
+
+ def _data_check(self, d: Any) -> bool:
+ r"""
+ Overview:
+ Data legality check, using rules(functions) in ``self.check_list``.
+ Arguments:
+ - d (:obj:`Any`): The data which needs to be checked.
+ Returns:
+ - result (:obj:`bool`): Whether the data passes the check.
+ """
+        # Only if the data passes all the check functions will the check return True.
+ return all([fn(d) for fn in self.check_list])
+
+ def _get_indices(self, size: int, sample_range: slice = None) -> list:
+ r"""
+ Overview:
+ Get the sample index list according to the priority probability.
+ Arguments:
+ - size (:obj:`int`): The number of the data that will be sampled
+ Returns:
+ - index_list (:obj:`list`): A list including all the sample indices, whose length should equal to ``size``.
+ """
+ # Divide [0, 1) into size intervals on average
+ intervals = np.array([i * 1.0 / size for i in range(size)])
+ # Uniformly sample within each interval
+ mass = intervals + np.random.uniform(size=(size, )) * 1. / size
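+        # e.g. size=4: intervals are [0, 0.25, 0.5, 0.75) and each gets one uniform draw from its own
+        # 0.25-wide stratum, i.e. stratified sampling over the (rescaled) priority mass below.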
+ if sample_range is None:
+            # Rescale to [0, S), where S is the sum of all priorities (root value of the sum tree)
+ mass *= self._sum_tree.reduce()
+ else:
+ # Rescale to [a, b)
+ start = to_positive_index(sample_range.start, self._replay_buffer_size)
+ end = to_positive_index(sample_range.stop, self._replay_buffer_size)
+ a = self._sum_tree.reduce(0, start)
+ b = self._sum_tree.reduce(0, end)
+ mass = mass * (b - a) + a
+ # Find prefix sum index to sample with probability
+ return [self._sum_tree.find_prefixsum_idx(m) for m in mass]
+
+ def _remove(self, idx: int, use_too_many_times: bool = False) -> None:
+ r"""
+ Overview:
+            Remove a data item (set the element in the list to ``None``) and update corresponding variables,
+            e.g. sum_tree, min_tree, valid_count.
+        Arguments:
+            - idx (:obj:`int`): Data at this position will be removed.
+            - use_too_many_times (:obj:`bool`): Whether the removal is triggered by reaching ``max_use``. \
+                If True and used-data tracking is enabled, the data is kept but masked out of future sampling.
+ """
+ if use_too_many_times:
+ if self._enable_track_used_data:
+                # This data must still be tracked (parallel mode), so do not remove it here;
+                # just make sure it will never be sampled again.
+ self._data[idx]['priority'] = 0
+ self._sum_tree[idx] = self._sum_tree.neutral_element
+ self._min_tree[idx] = self._min_tree.neutral_element
+ return
+ elif idx == self._head:
+ # Correct `self._head` when the queue head is removed due to use_count
+ self._head = (self._head + 1) % self._replay_buffer_size
+ if self._data[idx] is not None:
+ if self._enable_track_used_data:
+ self._used_data_remover.add_used_data(self._data[idx])
+ self._valid_count -= 1
+ if self._rank == 0:
+ self._periodic_thruput_monitor.valid_count = self._valid_count
+ self._periodic_thruput_monitor.remove_data_count += 1
+ self._data[idx] = None
+ self._sum_tree[idx] = self._sum_tree.neutral_element
+ self._min_tree[idx] = self._min_tree.neutral_element
+ self._use_count[idx] = 0
+
+ def _sample_with_indices(self, indices: List[int], cur_learner_iter: int) -> list:
+ r"""
+ Overview:
+            Sample data with ``indices``; remove a data item if it has been used too many times.
+ Arguments:
+ - indices (:obj:`List[int]`): A list including all the sample indices.
+ - cur_learner_iter (:obj:`int`): Learner's current iteration, used to calculate staleness.
+ Returns:
+            - data (:obj:`list`): Sampled data.
+ """
+ # Calculate max weight for normalizing IS
+ sum_tree_root = self._sum_tree.reduce()
+ p_min = self._min_tree.reduce() / sum_tree_root
+ max_weight = (self._valid_count * p_min) ** (-self._beta)
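+        # Illustrative note: P(i) = priority_i^alpha / sum_j priority_j^alpha, so the raw importance
+        # weight is w_i = (valid_count * P(i)) ** (-beta); dividing by ``max_weight`` (the weight of the
+        # lowest-priority entry) rescales every weight into (0, 1].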
+ data = []
+ for idx in indices:
+ assert self._data[idx] is not None
+ assert self._data[idx]['replay_buffer_idx'] == idx, (self._data[idx]['replay_buffer_idx'], idx)
+ if self._deepcopy:
+ copy_data = copy.deepcopy(self._data[idx])
+ else:
+ copy_data = self._data[idx]
+            # Store staleness, use count and IS (importance sampling weight) for monitoring and outer use
+ self._use_count[idx] += 1
+ copy_data['staleness'] = self._calculate_staleness(idx, cur_learner_iter)
+ copy_data['use'] = self._use_count[idx]
+ p_sample = self._sum_tree[idx] / sum_tree_root
+ weight = (self._valid_count * p_sample) ** (-self._beta)
+ copy_data['IS'] = weight / max_weight
+ data.append(copy_data)
+ if self._max_use != float("inf"):
+ # Remove datas whose "use count" is greater than ``max_use``
+ for idx in indices:
+ if self._use_count[idx] >= self._max_use:
+ self._remove(idx, use_too_many_times=True)
+ # Beta annealing
+ if self._anneal_step != 0:
+ self._beta = min(1.0, self._beta + self._beta_anneal_step)
+ return data
+
+ def _monitor_update_of_push(self, add_count: int, cur_collector_envstep: int = -1) -> None:
+ r"""
+ Overview:
+ Update values in monitor, then update text logger and tensorboard logger.
+ Called in ``_append`` and ``_extend``.
+ Arguments:
+            - add_count (:obj:`int`): How many data items are added into the buffer.
+ - cur_collector_envstep (:obj:`int`): Collector envstep, passed in by collector.
+ """
+ if self._rank == 0:
+ self._periodic_thruput_monitor.push_data_count += add_count
+ if self._use_thruput_controller:
+ self._thruput_controller.history_push_count += add_count
+ self._cur_collector_envstep = cur_collector_envstep
+
+ def _monitor_update_of_sample(self, sample_data: list, cur_learner_iter: int) -> None:
+ r"""
+ Overview:
+ Update values in monitor, then update text logger and tensorboard logger.
+ Called in ``sample``.
+ Arguments:
+ - sample_data (:obj:`list`): Sampled data. Used to get sample length and data's attributes, \
+ e.g. use, priority, staleness, etc.
+ - cur_learner_iter (:obj:`int`): Learner iteration, passed in by learner.
+ """
+ if self._rank == 0:
+ self._periodic_thruput_monitor.sample_data_count += len(sample_data)
+ if self._use_thruput_controller:
+ self._thruput_controller.history_sample_count += len(sample_data)
+ self._cur_learner_iter = cur_learner_iter
+ use_avg = sum([d['use'] for d in sample_data]) / len(sample_data)
+ use_max = max([d['use'] for d in sample_data])
+ priority_avg = sum([d['priority'] for d in sample_data]) / len(sample_data)
+ priority_max = max([d['priority'] for d in sample_data])
+ priority_min = min([d['priority'] for d in sample_data])
+ staleness_avg = sum([d['staleness'] for d in sample_data]) / len(sample_data)
+ staleness_max = max([d['staleness'] for d in sample_data])
+ self._sampled_data_attr_monitor.use_avg = use_avg
+ self._sampled_data_attr_monitor.use_max = use_max
+ self._sampled_data_attr_monitor.priority_avg = priority_avg
+ self._sampled_data_attr_monitor.priority_max = priority_max
+ self._sampled_data_attr_monitor.priority_min = priority_min
+ self._sampled_data_attr_monitor.staleness_avg = staleness_avg
+ self._sampled_data_attr_monitor.staleness_max = staleness_max
+ self._sampled_data_attr_monitor.time.step()
+ out_dict = {
+ 'use_avg': self._sampled_data_attr_monitor.avg['use'](),
+ 'use_max': self._sampled_data_attr_monitor.max['use'](),
+ 'priority_avg': self._sampled_data_attr_monitor.avg['priority'](),
+ 'priority_max': self._sampled_data_attr_monitor.max['priority'](),
+ 'priority_min': self._sampled_data_attr_monitor.min['priority'](),
+ 'staleness_avg': self._sampled_data_attr_monitor.avg['staleness'](),
+ 'staleness_max': self._sampled_data_attr_monitor.max['staleness'](),
+ 'beta': self._beta,
+ }
+ if self._sampled_data_attr_print_count % self._sampled_data_attr_print_freq == 0 and self._rank == 0:
+ self._logger.info("=== Sample data {} Times ===".format(self._sampled_data_attr_print_count))
+ self._logger.info(self._logger.get_tabulate_vars_hor(out_dict))
+ for k, v in out_dict.items():
+ iter_metric = self._cur_learner_iter if self._cur_learner_iter != -1 else None
+ step_metric = self._cur_collector_envstep if self._cur_collector_envstep != -1 else None
+ if iter_metric is not None:
+ self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, iter_metric)
+ if step_metric is not None:
+ self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, step_metric)
+ self._sampled_data_attr_print_count += 1
+
+ def _calculate_staleness(self, pos_index: int, cur_learner_iter: int) -> Optional[int]:
+ r"""
+ Overview:
+ Calculate a data's staleness according to its own attribute ``collect_iter``
+ and input parameter ``cur_learner_iter``.
+ Arguments:
+ - pos_index (:obj:`int`): The position index. Staleness of the data at this index will be calculated.
+ - cur_learner_iter (:obj:`int`): Learner's current iteration, used to calculate staleness.
+ Returns:
+ - staleness (:obj:`int`): Staleness of data at position ``pos_index``.
+
+ .. note::
+ Caller should guarantee that data at ``pos_index`` is not None; Otherwise this function may raise an error.
+ """
+ if self._data[pos_index] is None:
+ raise ValueError("Prioritized's data at index {} is None".format(pos_index))
+ else:
+ # Calculate staleness, remove it if too stale
+ collect_iter = self._data[pos_index].get('collect_iter', cur_learner_iter + 1)
+ if isinstance(collect_iter, list):
+ # Timestep transition's collect_iter is a list
+ collect_iter = min(collect_iter)
+ # ``staleness`` might be -1, means invalid, e.g. collector does not report collecting model iter,
+ # or it is a demonstration buffer(which means data is not generated by collector) etc.
+ staleness = cur_learner_iter - collect_iter
+ return staleness
+
+ def count(self) -> int:
+ """
+ Overview:
+            Count how many valid data items there are in the buffer.
+ Returns:
+ - count (:obj:`int`): Number of valid data.
+ """
+ return self._valid_count
+
+ @property
+ def beta(self) -> float:
+ return self._beta
+
+ @beta.setter
+ def beta(self, beta: float) -> None:
+ self._beta = beta
+
+ def state_dict(self) -> dict:
+ """
+ Overview:
+ Provide a state dict to keep a record of current buffer.
+ Returns:
+ - state_dict (:obj:`Dict[str, Any]`): A dict containing all important values in the buffer. \
+ With the dict, one can easily reproduce the buffer.
+ """
+ return {
+ 'data': self._data,
+ 'use_count': self._use_count,
+ 'tail': self._tail,
+ 'max_priority': self._max_priority,
+ 'anneal_step': self._anneal_step,
+ 'beta': self._beta,
+ 'head': self._head,
+ 'next_unique_id': self._next_unique_id,
+ 'valid_count': self._valid_count,
+ 'push_count': self._push_count,
+ 'sum_tree': self._sum_tree,
+ 'min_tree': self._min_tree,
+ }
+
+ def load_state_dict(self, _state_dict: dict, deepcopy: bool = False) -> None:
+ """
+ Overview:
+ Load state dict to reproduce the buffer.
+        Arguments:
+            - _state_dict (:obj:`Dict[str, Any]`): A dict containing all important values of the buffer.
+            - deepcopy (:obj:`bool`): Whether to deepcopy the values when loading them into the buffer.
+ """
+ assert 'data' in _state_dict
+ if set(_state_dict.keys()) == set(['data']):
+ self._extend(_state_dict['data'])
+ else:
+ for k, v in _state_dict.items():
+ if deepcopy:
+ setattr(self, '_{}'.format(k), copy.deepcopy(v))
+ else:
+ setattr(self, '_{}'.format(k), v)
+
+ @property
+ def replay_buffer_size(self) -> int:
+ return self._replay_buffer_size
+
+ @property
+ def push_count(self) -> int:
+ return self._push_count
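+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative only, not an official DI-engine example): push a few
+    # transition dicts and draw one prioritized sample. The keys 'obs' and 'priority' are
+    # placeholders; the buffer only requires each pushed item to be a dict.
+    demo_cfg = AdvancedReplayBuffer.default_config()
+    demo_cfg.replay_buffer_size = 8
+    demo_buffer = AdvancedReplayBuffer(demo_cfg, tb_logger=None, instance_name='advanced_demo')
+    demo_buffer.push([{'obs': np.random.randn(4), 'priority': None} for _ in range(8)], cur_collector_envstep=0)
+    batch = demo_buffer.sample(4, cur_learner_iter=0)
+    # Each sampled item is augmented with 'replay_unique_id', 'replay_buffer_idx', 'use',
+    # 'staleness' and the importance-sampling weight 'IS'; priorities can be updated afterwards.
+    demo_buffer.update(
+        {
+            'replay_unique_id': [d['replay_unique_id'] for d in batch],
+            'replay_buffer_idx': [d['replay_buffer_idx'] for d in batch],
+            'priority': [1.0 for _ in batch],
+        }
+    )
+    demo_buffer.close()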
diff --git a/DI-engine/ding/worker/replay_buffer/base_buffer.py b/DI-engine/ding/worker/replay_buffer/base_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..7231c34067d4492f0fa205d265200c4da97f2531
--- /dev/null
+++ b/DI-engine/ding/worker/replay_buffer/base_buffer.py
@@ -0,0 +1,149 @@
+from typing import Union, Dict, Any, List
+from abc import ABC, abstractmethod
+import copy
+from easydict import EasyDict
+
+from ding.utils import import_module, BUFFER_REGISTRY
+
+
+class IBuffer(ABC):
+ r"""
+ Overview:
+ Buffer interface
+ Interfaces:
+ default_config, push, update, sample, clear, count, state_dict, load_state_dict
+ """
+
+ @classmethod
+ def default_config(cls) -> EasyDict:
+ r"""
+ Overview:
+ Default config of this buffer class.
+ Returns:
+ - default_config (:obj:`EasyDict`)
+ """
+ cfg = EasyDict(copy.deepcopy(cls.config))
+ cfg.cfg_type = cls.__name__ + 'Dict'
+ return cfg
+
+ @abstractmethod
+ def push(self, data: Union[List[Any], Any], cur_collector_envstep: int) -> None:
+ r"""
+ Overview:
+ Push a data into buffer.
+ Arguments:
+            - data (:obj:`Union[List[Any], Any]`): The data which will be pushed into the buffer. Can be one \
+                (in `Any` type), or many (in `List[Any]` type).
+ - cur_collector_envstep (:obj:`int`): Collector's current env step.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def update(self, info: Dict[str, list]) -> None:
+ r"""
+ Overview:
+ Update data info, e.g. priority.
+ Arguments:
+            - info (:obj:`Dict[str, list]`): Info dict. Keys depend on the specific buffer type.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def sample(self, batch_size: int, cur_learner_iter: int) -> list:
+ r"""
+ Overview:
+ Sample data with length ``batch_size``.
+ Arguments:
+ - size (:obj:`int`): The number of the data that will be sampled.
+ - cur_learner_iter (:obj:`int`): Learner's current iteration.
+ Returns:
+ - sampled_data (:obj:`list`): A list of data with length `batch_size`.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def clear(self) -> None:
+ """
+ Overview:
+ Clear all the data and reset the related variables.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def count(self) -> int:
+ """
+ Overview:
+            Count how many valid data items there are in the buffer.
+ Returns:
+ - count (:obj:`int`): Number of valid data.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def save_data(self, file_name: str):
+ """
+ Overview:
+ Save buffer data into a file.
+ Arguments:
+ - file_name (:obj:`str`): file name of buffer data
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def load_data(self, file_name: str):
+ """
+ Overview:
+ Load buffer data from a file.
+ Arguments:
+ - file_name (:obj:`str`): file name of buffer data
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def state_dict(self) -> Dict[str, Any]:
+ """
+ Overview:
+ Provide a state dict to keep a record of current buffer.
+ Returns:
+ - state_dict (:obj:`Dict[str, Any]`): A dict containing all important values in the buffer. \
+ With the dict, one can easily reproduce the buffer.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def load_state_dict(self, _state_dict: Dict[str, Any]) -> None:
+ """
+ Overview:
+ Load state dict to reproduce the buffer.
+        Arguments:
+            - _state_dict (:obj:`Dict[str, Any]`): A dict containing all important values of the buffer.
+ """
+ raise NotImplementedError
+
+
+def create_buffer(cfg: EasyDict, *args, **kwargs) -> IBuffer:
+ r"""
+ Overview:
+ Create a buffer according to cfg and other arguments.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Buffer config.
+ ArgumentsKeys:
+ - necessary: `type`
+ """
+ import_module(cfg.get('import_names', []))
+ if cfg.type == 'naive':
+ kwargs.pop('tb_logger', None)
+ return BUFFER_REGISTRY.build(cfg.type, cfg, *args, **kwargs)
+
+
+def get_buffer_cls(cfg: EasyDict) -> type:
+ r"""
+ Overview:
+ Get a buffer class according to cfg.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Buffer config.
+ ArgumentsKeys:
+ - necessary: `type`
+ """
+ import_module(cfg.get('import_names', []))
+ return BUFFER_REGISTRY.get(cfg.type)
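+
+
+if __name__ == "__main__":
+    # Minimal sketch (illustrative only): build a buffer purely from a config dict through the
+    # registry helpers above. Importing the package makes the concrete buffer classes register
+    # themselves in BUFFER_REGISTRY; ``deep_merge_dicts`` is used to fill in default fields.
+    import ding.worker.replay_buffer  # noqa: F401, triggers buffer class registration
+    from ding.utils import deep_merge_dicts
+
+    demo_cfg = EasyDict(dict(type='naive'))
+    demo_cfg = deep_merge_dicts(get_buffer_cls(demo_cfg).default_config(), demo_cfg)
+    demo_buffer = create_buffer(demo_cfg, instance_name='registry_demo')
+    demo_buffer.push([{'obs': i} for i in range(4)], cur_collector_envstep=0)
+    assert demo_buffer.count() == 4
+    demo_buffer.close()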
diff --git a/DI-engine/ding/worker/replay_buffer/episode_buffer.py b/DI-engine/ding/worker/replay_buffer/episode_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3488a2cbe2a142ada4ce5912ba59673436af978e
--- /dev/null
+++ b/DI-engine/ding/worker/replay_buffer/episode_buffer.py
@@ -0,0 +1,19 @@
+from typing import List
+from ding.worker.replay_buffer import NaiveReplayBuffer
+from ding.utils import BUFFER_REGISTRY
+
+
+@BUFFER_REGISTRY.register('episode')
+class EpisodeReplayBuffer(NaiveReplayBuffer):
+ r"""
+ Overview:
+        Episode replay buffer stores complete episodes, i.e. each element in the buffer is one episode.
+        Some algorithms do not want to sample `batch_size` complete episodes, but rather transition sequences of
+        some fixed length. For such requirements, ``sample`` should be overridden.
+ Interface:
+ start, close, push, update, sample, clear, count, state_dict, load_state_dict, default_config
+ """
+
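+    # Note (illustrative): each stored element is expected to be one complete episode, e.g. a list of
+    # transition dicts collected from env reset to episode end, so ``episode_len`` below reports the
+    # number of transitions in each stored episode.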
+ @property
+ def episode_len(self) -> List[int]:
+ return [len(episode) for episode in self._data]
diff --git a/DI-engine/ding/worker/replay_buffer/naive_buffer.py b/DI-engine/ding/worker/replay_buffer/naive_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..4addc5838dd1f9db0827c7a3accb4cfefebf1bf1
--- /dev/null
+++ b/DI-engine/ding/worker/replay_buffer/naive_buffer.py
@@ -0,0 +1,565 @@
+import os
+import copy
+from typing import Union, Any, Optional, List
+import numpy as np
+import math
+import hickle
+from easydict import EasyDict
+
+from ding.worker.replay_buffer import IBuffer
+from ding.utils import LockContext, LockContextType, BUFFER_REGISTRY, build_logger
+from .utils import UsedDataRemover, PeriodicThruputMonitor
+
+
+@BUFFER_REGISTRY.register('naive')
+class NaiveReplayBuffer(IBuffer):
+ r"""
+ Overview:
+        Naive replay buffer, which can store and sample data.
+        A naive implementation of a replay buffer with no priority or other advanced features.
+        This buffer supports multi-thread/multi-process usage and guarantees thread safety, which means that
+        methods like ``sample``, ``push`` and ``clear`` are mutually exclusive.
+ Interface:
+ start, close, push, update, sample, clear, count, state_dict, load_state_dict, default_config
+ Property:
+ replay_buffer_size, push_count
+ """
+
+ config = dict(
+ type='naive',
+ replay_buffer_size=10000,
+ deepcopy=False,
+ # default `False` for serial pipeline
+ enable_track_used_data=False,
+ periodic_thruput_seconds=60,
+ )
+
+ def __init__(
+ self,
+ cfg: 'EasyDict', # noqa
+ tb_logger: Optional['SummaryWriter'] = None, # noqa
+ exp_name: Optional[str] = 'default_experiment',
+ instance_name: Optional[str] = 'buffer',
+ ) -> None:
+ """
+ Overview:
+ Initialize the buffer
+ Arguments:
+ - cfg (:obj:`dict`): Config dict.
+ - tb_logger (:obj:`Optional['SummaryWriter']`): Outer tb logger. Usually get this argument in serial mode.
+ - exp_name (:obj:`Optional[str]`): Name of this experiment.
+ - instance_name (:obj:`Optional[str]`): Name of this instance.
+ """
+ self._exp_name = exp_name
+ self._instance_name = instance_name
+ self._cfg = cfg
+ self._replay_buffer_size = self._cfg.replay_buffer_size
+ self._deepcopy = self._cfg.deepcopy
+ # ``_data`` is a circular queue to store data (full data or meta data)
+ self._data = [None for _ in range(self._replay_buffer_size)]
+        # Current valid data count, indicating how many elements in ``self._data`` are valid.
+ self._valid_count = 0
+ # How many pieces of data have been pushed into this buffer, should be no less than ``_valid_count``.
+ self._push_count = 0
+ # Point to the tail position where next data can be inserted, i.e. latest inserted data's next position.
+ self._tail = 0
+ # Lock to guarantee thread safe
+ self._lock = LockContext(type_=LockContextType.THREAD_LOCK)
+ self._end_flag = False
+ self._enable_track_used_data = self._cfg.enable_track_used_data
+ if self._enable_track_used_data:
+ self._used_data_remover = UsedDataRemover()
+ if tb_logger is not None:
+ self._logger, _ = build_logger(
+ './{}/log/{}'.format(self._exp_name, self._instance_name), self._instance_name, need_tb=False
+ )
+ self._tb_logger = tb_logger
+ else:
+ self._logger, self._tb_logger = build_logger(
+ './{}/log/{}'.format(self._exp_name, self._instance_name),
+ self._instance_name,
+ )
+        # Periodic throughput monitor. By default, the monitor range is 60 seconds; modify it as needed.
+ self._periodic_thruput_monitor = PeriodicThruputMonitor(
+ self._instance_name, EasyDict(seconds=self._cfg.periodic_thruput_seconds), self._logger, self._tb_logger
+ )
+
+ def start(self) -> None:
+ """
+ Overview:
+            Start the buffer's used_data_remover thread if track_used_data is enabled.
+ """
+ if self._enable_track_used_data:
+ self._used_data_remover.start()
+
+ def close(self) -> None:
+ """
+ Overview:
+            Clear the buffer; join the buffer's used_data_remover thread if track_used_data is enabled.
+ """
+ self.clear()
+ if self._enable_track_used_data:
+ self._used_data_remover.close()
+ self._tb_logger.flush()
+ self._tb_logger.close()
+
+ def push(self, data: Union[List[Any], Any], cur_collector_envstep: int) -> None:
+ r"""
+ Overview:
+ Push a data into buffer.
+ Arguments:
+            - data (:obj:`Union[List[Any], Any]`): The data which will be pushed into the buffer. Can be one \
+                (in `Any` type), or many (in `List[Any]` type).
+ - cur_collector_envstep (:obj:`int`): Collector's current env step. \
+ Not used in naive buffer, but preserved for compatibility.
+ """
+ if isinstance(data, list):
+ self._extend(data, cur_collector_envstep)
+ self._periodic_thruput_monitor.push_data_count += len(data)
+ else:
+ self._append(data, cur_collector_envstep)
+ self._periodic_thruput_monitor.push_data_count += 1
+
+ def sample(self,
+ size: int,
+ cur_learner_iter: int,
+ sample_range: slice = None,
+ replace: bool = False) -> Optional[list]:
+ """
+ Overview:
+ Sample data with length ``size``.
+ Arguments:
+ - size (:obj:`int`): The number of the data that will be sampled.
+ - cur_learner_iter (:obj:`int`): Learner's current iteration. \
+ Not used in naive buffer, but preserved for compatibility.
+ - sample_range (:obj:`slice`): Buffer slice for sampling, such as `slice(-10, None)`, which \
+ means only sample among the last 10 data
+ - replace (:obj:`bool`): Whether sample with replacement
+ Returns:
+ - sample_data (:obj:`list`): A list of data with length ``size``.
+ """
+ if size == 0:
+ return []
+ can_sample = self._sample_check(size, replace)
+ if not can_sample:
+ return None
+ with self._lock:
+ indices = self._get_indices(size, sample_range, replace)
+ sample_data = self._sample_with_indices(indices, cur_learner_iter)
+ self._periodic_thruput_monitor.sample_data_count += len(sample_data)
+ return sample_data
+
+ def save_data(self, file_name: str):
+ if not os.path.exists(os.path.dirname(file_name)):
+ if os.path.dirname(file_name) != "":
+ os.makedirs(os.path.dirname(file_name))
+ hickle.dump(py_obj=self._data, file_obj=file_name)
+
+ def load_data(self, file_name: str):
+ self.push(hickle.load(file_name), 0)
+
+ def _append(self, ori_data: Any, cur_collector_envstep: int = -1) -> None:
+ r"""
+ Overview:
+ Append a data item into ``self._data``.
+ Arguments:
+ - ori_data (:obj:`Any`): The data which will be inserted.
+ - cur_collector_envstep (:obj:`int`): Not used in this method, but preserved for compatibility.
+ """
+ with self._lock:
+ if self._deepcopy:
+ data = copy.deepcopy(ori_data)
+ else:
+ data = ori_data
+ self._push_count += 1
+ if self._data[self._tail] is None:
+ self._valid_count += 1
+ self._periodic_thruput_monitor.valid_count = self._valid_count
+ elif self._enable_track_used_data:
+ self._used_data_remover.add_used_data(self._data[self._tail])
+ self._data[self._tail] = data
+ self._tail = (self._tail + 1) % self._replay_buffer_size
+
+ def _extend(self, ori_data: List[Any], cur_collector_envstep: int = -1) -> None:
+ r"""
+ Overview:
+            Extend a data list into the queue.
+ Arguments:
+ - ori_data (:obj:`List[Any]`): The data list.
+ - cur_collector_envstep (:obj:`int`): Not used in this method, but preserved for compatibility.
+ """
+ with self._lock:
+ if self._deepcopy:
+ data = copy.deepcopy(ori_data)
+ else:
+ data = ori_data
+ length = len(data)
+            # When updating ``_data``, two cases must be considered regarding the relationship between
+            # "tail + data length" and "replay buffer size", to check whether the new data will exceed
+            # the buffer's max length and wrap around.
+ if self._tail + length <= self._replay_buffer_size:
+ if self._valid_count != self._replay_buffer_size:
+ self._valid_count += length
+ self._periodic_thruput_monitor.valid_count = self._valid_count
+ elif self._enable_track_used_data:
+ for i in range(length):
+ self._used_data_remover.add_used_data(self._data[self._tail + i])
+ self._push_count += length
+ self._data[self._tail:self._tail + length] = data
+ else:
+ new_tail = self._tail
+ data_start = 0
+ residual_num = len(data)
+ while True:
+ space = self._replay_buffer_size - new_tail
+ L = min(space, residual_num)
+ if self._valid_count != self._replay_buffer_size:
+ self._valid_count += L
+ self._periodic_thruput_monitor.valid_count = self._valid_count
+ elif self._enable_track_used_data:
+ for i in range(L):
+ self._used_data_remover.add_used_data(self._data[new_tail + i])
+ self._push_count += L
+ self._data[new_tail:new_tail + L] = data[data_start:data_start + L]
+ residual_num -= L
+ assert residual_num >= 0
+ if residual_num == 0:
+ break
+ else:
+ new_tail = 0
+ data_start += L
+            # Update ``tail`` after the whole list is pushed into the buffer.
+ self._tail = (self._tail + length) % self._replay_buffer_size
+
+ def _sample_check(self, size: int, replace: bool = False) -> bool:
+ r"""
+ Overview:
+            Check whether this buffer has at least `size` data items to sample.
+ Arguments:
+ - size (:obj:`int`): Number of data that will be sampled.
+ - replace (:obj:`bool`): Whether sample with replacement.
+ Returns:
+ - can_sample (:obj:`bool`): Whether this buffer can sample enough data.
+ """
+ if self._valid_count == 0:
+ print("The buffer is empty")
+ return False
+ if self._valid_count < size and not replace:
+ print(
+ "No enough elements for sampling without replacement (expect: {} / current: {})".format(
+ size, self._valid_count
+ )
+ )
+ return False
+ else:
+ return True
+
+ def update(self, info: dict) -> None:
+ r"""
+ Overview:
+ Naive Buffer does not need to update any info, but this method is preserved for compatibility.
+ """
+        print(
+            '[BUFFER WARNING] Naive Buffer does not need to update any info, '
+            'but the `update` method is preserved for compatibility.'
+        )
+
+ def clear(self) -> None:
+ """
+ Overview:
+ Clear all the data and reset the related variables.
+ """
+ with self._lock:
+ for i in range(len(self._data)):
+ if self._data[i] is not None:
+ if self._enable_track_used_data:
+ self._used_data_remover.add_used_data(self._data[i])
+ self._data[i] = None
+ self._valid_count = 0
+ self._periodic_thruput_monitor.valid_count = self._valid_count
+ self._push_count = 0
+ self._tail = 0
+
+ def __del__(self) -> None:
+ """
+ Overview:
+ Call ``close`` to delete the object.
+ """
+ self.close()
+
+ def _get_indices(self, size: int, sample_range: slice = None, replace: bool = False) -> list:
+ r"""
+ Overview:
+ Get the sample index list.
+ Arguments:
+            - size (:obj:`int`): The number of data items that will be sampled.
+            - sample_range (:obj:`slice`): Buffer slice for sampling, such as `slice(-10, None)`, which \
+                means only sample among the last 10 data items.
+            - replace (:obj:`bool`): Whether to sample with replacement.
+ Returns:
+ - index_list (:obj:`list`): A list including all the sample indices, whose length should equal to ``size``.
+ """
+ assert self._valid_count <= self._replay_buffer_size
+ if self._valid_count == self._replay_buffer_size:
+ tail = self._replay_buffer_size
+ else:
+ tail = self._tail
+ if sample_range is None:
+ indices = list(np.random.choice(a=tail, size=size, replace=replace))
+ else:
+ indices = list(range(tail))[sample_range]
+ indices = list(np.random.choice(indices, size=size, replace=replace))
+ return indices
+
+ def _sample_with_indices(self, indices: List[int], cur_learner_iter: int) -> list:
+ r"""
+ Overview:
+ Sample data with ``indices``.
+ Arguments:
+ - indices (:obj:`List[int]`): A list including all the sample indices.
+ - cur_learner_iter (:obj:`int`): Not used in this method, but preserved for compatibility.
+ Returns:
+            - data (:obj:`list`): Sampled data.
+ """
+ data = []
+ for idx in indices:
+ assert self._data[idx] is not None, idx
+ if self._deepcopy:
+ copy_data = copy.deepcopy(self._data[idx])
+ else:
+ copy_data = self._data[idx]
+ data.append(copy_data)
+ return data
+
+ def count(self) -> int:
+ """
+ Overview:
+            Count how many valid data items there are in the buffer.
+ Returns:
+ - count (:obj:`int`): Number of valid data.
+ """
+ return self._valid_count
+
+ def state_dict(self) -> dict:
+ """
+ Overview:
+ Provide a state dict to keep a record of current buffer.
+ Returns:
+ - state_dict (:obj:`Dict[str, Any]`): A dict containing all important values in the buffer. \
+ With the dict, one can easily reproduce the buffer.
+ """
+ return {
+ 'data': self._data,
+ 'tail': self._tail,
+ 'valid_count': self._valid_count,
+ 'push_count': self._push_count,
+ }
+
+ def load_state_dict(self, _state_dict: dict) -> None:
+ """
+ Overview:
+ Load state dict to reproduce the buffer.
+        Arguments:
+            - _state_dict (:obj:`Dict[str, Any]`): A dict containing all important values of the buffer.
+ """
+ assert 'data' in _state_dict
+ if set(_state_dict.keys()) == set(['data']):
+ self._extend(_state_dict['data'])
+ else:
+ for k, v in _state_dict.items():
+ setattr(self, '_{}'.format(k), v)
+
+ @property
+ def replay_buffer_size(self) -> int:
+ return self._replay_buffer_size
+
+ @property
+ def push_count(self) -> int:
+ return self._push_count
+
+
+@BUFFER_REGISTRY.register('elastic')
+class ElasticReplayBuffer(NaiveReplayBuffer):
+ r"""
+ Overview:
+        Elastic replay buffer, which stores data and supports dynamically changing the usable buffer size.
+        A naive implementation of a replay buffer with no priority or other advanced features.
+        This buffer supports multi-thread/multi-process usage and guarantees thread safety, which means that
+        methods like ``sample``, ``push`` and ``clear`` are mutually exclusive.
+ Interface:
+ start, close, push, update, sample, clear, count, state_dict, load_state_dict, default_config
+ Property:
+ replay_buffer_size, push_count
+ """
+
+ config = dict(
+ type='elastic',
+ replay_buffer_size=10000,
+ deepcopy=False,
+ # default `False` for serial pipeline
+ enable_track_used_data=False,
+ periodic_thruput_seconds=60,
+ )
+
+ def __init__(
+ self,
+ cfg: 'EasyDict', # noqa
+ tb_logger: Optional['SummaryWriter'] = None, # noqa
+ exp_name: Optional[str] = 'default_experiment',
+ instance_name: Optional[str] = 'buffer',
+ ) -> None:
+ """
+ Overview:
+ Initialize the buffer
+ Arguments:
+ - cfg (:obj:`dict`): Config dict.
+ - tb_logger (:obj:`Optional['SummaryWriter']`): Outer tb logger. Usually get this argument in serial mode.
+ - exp_name (:obj:`Optional[str]`): Name of this experiment.
+ - instance_name (:obj:`Optional[str]`): Name of this instance.
+ """
+ super().__init__(cfg, tb_logger, exp_name, instance_name)
+ self._set_buffer_size = self._cfg.set_buffer_size
+        self._current_buffer_size = self._set_buffer_size(0)  # Set the usable buffer size at the 0-th envstep.
+        # ``_current_buffer_size`` restricts how many of the most recent samples the buffer can draw from.
+
+ def _sample_check(self, size: int, replace: bool = False) -> bool:
+ r"""
+ Overview:
+            Check whether this buffer has at least `size` data items to sample.
+ Arguments:
+ - size (:obj:`int`): Number of data that will be sampled.
+ - replace (:obj:`bool`): Whether sample with replacement.
+ Returns:
+ - can_sample (:obj:`bool`): Whether this buffer can sample enough data.
+ """
+ valid_count = min(self._valid_count, self._current_buffer_size)
+ if valid_count == 0:
+ print("The buffer is empty")
+ return False
+        if valid_count < size and not replace:
+            print(
+                "Not enough elements for sampling without replacement (expected: {} / current: {})".format(
+                    size, valid_count
+                )
+            )
+ return False
+ else:
+ return True
+
+ def _get_indices(self, size: int, sample_range: slice = None, replace: bool = False) -> list:
+ r"""
+ Overview:
+ Get the sample index list.
+ Arguments:
+ - size (:obj:`int`): The number of the data that will be sampled.
+ - replace (:obj:`bool`): Whether sample with replacement.
+ Returns:
+ - index_list (:obj:`list`): A list including all the sample indices, whose length should equal to ``size``.
+ """
+ assert self._valid_count <= self._replay_buffer_size
+ assert sample_range is None # not support
+        window = min(self._valid_count, self._current_buffer_size)
+        indices = list(
+            (self._tail - 1 - np.random.choice(a=window, size=size, replace=replace)) % self._replay_buffer_size
+        )
+ return indices
+
+ def update(self, envstep):
+ self._current_buffer_size = self._set_buffer_size(envstep)
+
+
+@BUFFER_REGISTRY.register('sequence')
+class SequenceReplayBuffer(NaiveReplayBuffer):
+ r"""
+    Overview:
+        Sequence replay buffer, which samples fixed-length sequences of consecutive transitions
+        (``sequence`` steps per sampled item) instead of single transitions.
+    Interface:
+ start, close, push, update, sample, clear, count, state_dict, load_state_dict, default_config
+ Property:
+ replay_buffer_size, push_count
+ """
+
+ def sample(
+ self,
+ batch: int,
+ sequence: int,
+ cur_learner_iter: int,
+ sample_range: slice = None,
+ replace: bool = False
+ ) -> Optional[list]:
+ """
+ Overview:
+ Sample data with length ``size``.
+ Arguments:
+ - size (:obj:`int`): The number of the data that will be sampled.
+ - sequence (:obj:`int`): The length of the sequence of a data that will be sampled.
+ - cur_learner_iter (:obj:`int`): Learner's current iteration. \
+ Not used in naive buffer, but preserved for compatibility.
+ - sample_range (:obj:`slice`): Buffer slice for sampling, such as `slice(-10, None)`, which \
+ means only sample among the last 10 data
+ - replace (:obj:`bool`): Whether sample with replacement
+ Returns:
+            - sample_data (:obj:`list`): A list of sampled sequences with length ``batch``.
+ """
+ if batch == 0:
+ return []
+ can_sample = self._sample_check(batch * sequence, replace)
+ if not can_sample:
+ return None
+ with self._lock:
+ indices = self._get_indices(batch, sequence, sample_range, replace)
+ sample_data = self._sample_with_indices(indices, sequence, cur_learner_iter)
+ self._periodic_thruput_monitor.sample_data_count += len(sample_data)
+ return sample_data
+
+ def _get_indices(self, size: int, sequence: int, sample_range: slice = None, replace: bool = False) -> list:
+ r"""
+ Overview:
+ Get the sample index list.
+ Arguments:
+            - size (:obj:`int`): The number of sequences that will be sampled.
+            - sequence (:obj:`int`): The length of each sampled sequence.
+ - sample_range (:obj:`slice`): Buffer slice for sampling, such as `slice(-10, None)`, which \
+ means only sample among the last 10 data
+ Returns:
+ - index_list (:obj:`list`): A list including all the sample indices, whose length should equal to ``size``.
+ """
+ assert self._valid_count <= self._replay_buffer_size
+ if self._valid_count == self._replay_buffer_size:
+ tail = self._replay_buffer_size
+ else:
+ tail = self._tail
+ episodes = math.ceil(self._valid_count / 500)
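+        # Note: a fixed chunk length of 500 transitions per episode is hard-coded below; a start index is
+        # drawn so that a window of ``sequence`` consecutive steps stays inside a single chunk.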
+ batch = 0
+ indices = []
+ if sample_range is None:
+ while batch < size:
+ episode = np.random.choice(episodes)
+ length = tail - episode * 500 if tail - episode * 500 < 500 else 500
+ available = length - sequence
+ if available < 1:
+ continue
+ indices.append(np.random.randint(episode * 500, episode * 500 + available + 1))
+ batch += 1
+ else:
+ raise NotImplementedError("sample_range is not implemented in this version")
+ return indices
+
+ def _sample_with_indices(self, indices: List[int], sequence: int, cur_learner_iter: int) -> list:
+ r"""
+ Overview:
+ Sample data with ``indices``.
+ Arguments:
+ - indices (:obj:`List[int]`): A list including all the sample indices.
+ - cur_learner_iter (:obj:`int`): Not used in this method, but preserved for compatibility.
+ Returns:
+            - data (:obj:`list`): Sampled data.
+ """
+ data = []
+ for idx in indices:
+ assert self._data[idx] is not None, idx
+ if self._deepcopy:
+ copy_data = copy.deepcopy(self._data[idx:idx + sequence])
+ else:
+ copy_data = self._data[idx:idx + sequence]
+ data.append(copy_data)
+ return data
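+
+
+if __name__ == "__main__":
+    # Minimal sketch (illustrative only) of ElasticReplayBuffer's ``set_buffer_size`` hook. The config
+    # is assumed to carry a callable mapping the current collector envstep to the number of most recent
+    # samples that may be drawn from, as suggested by ``self._set_buffer_size(0)`` and ``update(envstep)``.
+    demo_cfg = copy.deepcopy(ElasticReplayBuffer.default_config())
+    demo_cfg.set_buffer_size = lambda envstep: min(1000 + envstep // 10, demo_cfg.replay_buffer_size)
+    demo_buffer = ElasticReplayBuffer(demo_cfg, tb_logger=None, instance_name='elastic_demo')
+    demo_buffer.push([{'obs': i} for i in range(32)], cur_collector_envstep=0)
+    demo_buffer.update(envstep=5000)  # widens the sampling window from 1000 to 1500 under this schedule
+    assert len(demo_buffer.sample(8, cur_learner_iter=0)) == 8
+    demo_buffer.close()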
diff --git a/DI-engine/ding/worker/replay_buffer/tests/conftest.py b/DI-engine/ding/worker/replay_buffer/tests/conftest.py
new file mode 100644
index 0000000000000000000000000000000000000000..52ff49f1168a93cb88e451d39b07bc9948412dc2
--- /dev/null
+++ b/DI-engine/ding/worker/replay_buffer/tests/conftest.py
@@ -0,0 +1,29 @@
+from typing import List
+import numpy as np
+from ding.utils import save_file
+
+ID_COUNT = 0
+np.random.seed(1)
+
+
+def generate_data(meta: bool = False) -> dict:
+ global ID_COUNT
+ ret = {'obs': np.random.randn(4), 'data_id': str(ID_COUNT)}
+ ID_COUNT += 1
+ p_weight = np.random.uniform()
+ if p_weight < 1 / 3:
+ pass # no key 'priority'
+ elif p_weight < 2 / 3:
+ ret['priority'] = None
+ else:
+ ret['priority'] = np.random.uniform() + 1e-3
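+    # The three branches above exercise the prioritized buffer's default-priority handling:
+    # a missing 'priority' key, an explicit None, and an explicit positive value (the first two
+    # fall back to the buffer's current max priority in ``_set_weight``).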
+ if not meta:
+ return ret
+ else:
+ obs = ret.pop('obs')
+ save_file(ret['data_id'], obs)
+ return ret
+
+
+def generate_data_list(count: int, meta: bool = False) -> List[dict]:
+ return [generate_data(meta) for _ in range(0, count)]
diff --git a/DI-engine/ding/worker/replay_buffer/tests/test_advanced_buffer.py b/DI-engine/ding/worker/replay_buffer/tests/test_advanced_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6bb71a6b36fdf11ab2de1122e0ede361f6b8bac
--- /dev/null
+++ b/DI-engine/ding/worker/replay_buffer/tests/test_advanced_buffer.py
@@ -0,0 +1,315 @@
+import copy
+from collections import defaultdict
+import numpy as np
+import pytest
+from easydict import EasyDict
+import os
+import pickle
+import time
+import tempfile
+
+from ding.worker.replay_buffer import AdvancedReplayBuffer
+from ding.utils import deep_merge_dicts
+from ding.worker.replay_buffer.tests.conftest import generate_data, generate_data_list
+
+demo_data_path = "test_demo_data"
+
+
+@pytest.fixture(scope="function")
+def setup_demo_buffer_factory():
+ demo_data = {'data': generate_data_list(10)}
+ with open(demo_data_path, "wb") as f:
+ pickle.dump(demo_data, f)
+
+ def generator():
+ while True:
+ cfg = copy.deepcopy(AdvancedReplayBuffer.default_config())
+ cfg.replay_buffer_size = 64
+ cfg.max_use = 2
+ cfg.max_staleness = 1000
+ cfg.alpha = 0.6
+ cfg.beta = 0.6
+ cfg.enable_track_used_data = False
+ yield AdvancedReplayBuffer(instance_name="demo", cfg=cfg)
+
+ return generator()
+
+
+@pytest.mark.unittest
+class TestAdvancedBuffer:
+
+ def test_push(self):
+ buffer_cfg = deep_merge_dicts(AdvancedReplayBuffer.default_config(), EasyDict(dict(replay_buffer_size=64)))
+ advanced_buffer = AdvancedReplayBuffer(buffer_cfg, tb_logger=None, instance_name='test')
+ start_pointer = advanced_buffer._tail
+        start_validlen = advanced_buffer.count()
+ start_data_id = advanced_buffer._next_unique_id
+ valid_count = 0
+ for _ in range(100):
+ if advanced_buffer._data[advanced_buffer._tail] is None:
+ valid_count += 1
+ advanced_buffer.push(generate_data(), 0)
+ assert (advanced_buffer.replay_buffer_size == 64)
+ assert (advanced_buffer.count() == 64 == start_vaildlen + valid_count)
+ assert (advanced_buffer.push_count == start_vaildlen + 100)
+ assert (advanced_buffer._tail == (start_pointer + 100) % advanced_buffer.replay_buffer_size)
+ assert (advanced_buffer._next_unique_id == start_data_id + 100)
+ # invalid item append test
+ advanced_buffer.push([], 0)
+ assert (advanced_buffer.count() == 64 == start_vaildlen + valid_count)
+ assert (advanced_buffer.push_count == start_vaildlen + 100)
+ assert (advanced_buffer._tail == (start_pointer + 100) % advanced_buffer.replay_buffer_size)
+ assert (advanced_buffer._next_unique_id == start_data_id + 100)
+
+ buffer_cfg = deep_merge_dicts(AdvancedReplayBuffer.default_config(), EasyDict(dict(replay_buffer_size=64)))
+ advanced_buffer = AdvancedReplayBuffer(buffer_cfg, tb_logger=None, instance_name='test')
+ start_pointer = advanced_buffer._tail
+ start_data_id = advanced_buffer._next_unique_id
+ replay_buffer_size = advanced_buffer.replay_buffer_size
+ extend_num = int(0.6 * replay_buffer_size)
+ for i in range(1, 4):
+ data = generate_data_list(extend_num)
+ advanced_buffer.push(data, 0)
+ assert advanced_buffer._tail == (start_pointer + extend_num * i) % replay_buffer_size
+ assert advanced_buffer._next_unique_id == start_data_id + extend_num * i
+ assert advanced_buffer._valid_count == min(start_data_id + extend_num * i, replay_buffer_size)
+
+ def test_save_and_load_data(self):
+ buffer_cfg = deep_merge_dicts(AdvancedReplayBuffer.default_config(), EasyDict(dict(replay_buffer_size=64)))
+ advanced_buffer = AdvancedReplayBuffer(buffer_cfg, tb_logger=None, instance_name='test')
+ start_pointer = advanced_buffer._tail
+ start_vaildlen = advanced_buffer.count()
+ start_data_id = advanced_buffer._next_unique_id
+ valid_count = 0
+ for _ in range(100):
+ if advanced_buffer._data[advanced_buffer._tail] is None:
+ valid_count += 1
+ advanced_buffer.push(generate_data(), 0)
+ assert (advanced_buffer.replay_buffer_size == 64)
+ assert (advanced_buffer.count() == 64 == start_vaildlen + valid_count)
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ test_file = os.path.join(tmpdirname, "data.hkl")
+ advanced_buffer.save_data(test_file)
+ advanced_buffer_new = AdvancedReplayBuffer(buffer_cfg, instance_name='test_new')
+ advanced_buffer_new.load_data(test_file)
+ assert (advanced_buffer_new.replay_buffer_size == 64)
+ assert (advanced_buffer_new.count() == 64 == start_vaildlen + valid_count)
+ assert (advanced_buffer_new.push_count == 64)
+
+ def test_update(self):
+ buffer_cfg = deep_merge_dicts(AdvancedReplayBuffer.default_config(), EasyDict(dict(replay_buffer_size=64)))
+ advanced_buffer = AdvancedReplayBuffer(buffer_cfg, tb_logger=None, instance_name='test')
+ for _ in range(64):
+ advanced_buffer.push(generate_data(), 0)
+ assert advanced_buffer.count() == sum([d is not None for d in advanced_buffer._data])
+ selected_idx = [1, 4, 8, 30, 63]
+ info = {'priority': [], 'replay_unique_id': [], 'replay_buffer_idx': []}
+ for idx in selected_idx:
+ info['priority'].append(np.random.uniform() + 64 - idx)
+ info['replay_unique_id'].append(advanced_buffer._data[idx]['replay_unique_id'])
+ info['replay_buffer_idx'].append(advanced_buffer._data[idx]['replay_buffer_idx'])
+
+ for _ in range(8):
+ advanced_buffer.push(generate_data(), 0)
+ origin_data = copy.deepcopy(advanced_buffer._data)
+ advanced_buffer.update(info)
+ assert (np.argmax(info['priority']) == 0)
+ assert (advanced_buffer._max_priority == max(info['priority'][2:]))
+ assert (advanced_buffer._max_priority != max(info['priority']))
+ for i in range(2):
+ assert (origin_data[selected_idx[i]]['priority'] == advanced_buffer._data[selected_idx[i]]['priority'])
+ eps = advanced_buffer._eps
+ for i in range(2, 5):
+ assert (info['priority'][i] + eps == advanced_buffer._data[selected_idx[i]]['priority'])
+ # test the case when data is None (e.g. removed due to reaching max use)
+ advanced_buffer._data[selected_idx[0]] = None
+ advanced_buffer._valid_count -= 1
+ advanced_buffer.update(info)
+
+ # test beta
+ advanced_buffer.beta = 1.
+ assert (advanced_buffer.beta == 1.)
+
+ def test_sample(self):
+ buffer_cfg = deep_merge_dicts(
+ AdvancedReplayBuffer.default_config(), EasyDict(dict(replay_buffer_size=64, max_use=2))
+ )
+ advanced_buffer = AdvancedReplayBuffer(buffer_cfg, tb_logger=None, instance_name='test')
+ for _ in range(64):
+ data = generate_data()
+ data['priority'] = None
+ advanced_buffer.push(data, 0)
+ use_dict = defaultdict(int)
+ while True:
+ can_sample, _ = advanced_buffer._sample_check(32, 0)
+ if not can_sample:
+ break
+ batch = advanced_buffer.sample(32, 0)
+ assert (len(batch) == 32)
+ assert (all([b['IS'] == 1.0 for b in batch])), [b['IS'] for b in batch] # because priority is not updated
+ idx = [b['replay_buffer_idx'] for b in batch]
+ for i in idx:
+ use_dict[i] += 1
+ assert sum(map(lambda x: x[1] >= advanced_buffer._max_use,
+ use_dict.items())) == advanced_buffer.replay_buffer_size - advanced_buffer.count()
+ for k, v in use_dict.items():
+ if v > advanced_buffer._max_use:
+ assert advanced_buffer._data[k] is None
+
+ for _ in range(64):
+ data = generate_data()
+ data['priority'] = None
+ advanced_buffer.push(data, 0)
+ batch = advanced_buffer.sample(10, 0, sample_range=slice(-20, -2))
+ assert len(batch) == 10
+
+ def test_head_tail(self):
+ buffer_cfg = deep_merge_dicts(
+ AdvancedReplayBuffer.default_config(), EasyDict(dict(replay_buffer_size=64, max_use=4))
+ )
+ advanced_buffer = AdvancedReplayBuffer(buffer_cfg, tb_logger=None, instance_name='test')
+ for i in range(65):
+ advanced_buffer.push(generate_data(), 0)
+ assert advanced_buffer._head == advanced_buffer._tail == 1
+ info = {'replay_unique_id': [], 'replay_buffer_idx': [], 'priority': []}
+ for data in advanced_buffer._data:
+ info['replay_unique_id'].append(data['replay_unique_id'])
+ info['replay_buffer_idx'].append(data['replay_buffer_idx'])
+ info['priority'].append(0.)
+ info['priority'][1] = 1000.
+ advanced_buffer.update(info)
+ while advanced_buffer._data[1] is not None:
+ data = advanced_buffer.sample(1, 0)
+ print(data)
+ advanced_buffer.push({'data_id': '1096'}, 0)
+ assert advanced_buffer._tail == 2
+ assert advanced_buffer._head == 2
+
+ def test_weight(self):
+ buffer_cfg = deep_merge_dicts(
+ AdvancedReplayBuffer.default_config(), EasyDict(dict(replay_buffer_size=64, max_use=1))
+ )
+ advanced_buffer = AdvancedReplayBuffer(buffer_cfg, tb_logger=None, instance_name='test')
+ assert (advanced_buffer.count() == 0) # assert empty buffer
+
+ def get_weights(data_):
+ weights_ = []
+ for d in data_:
+ if 'priority' not in d.keys() or d['priority'] is None:
+ weights_.append(advanced_buffer.max_priority)
+ else:
+ weights_.append(d['priority'])
+ weights_ = np.array(weights_)
+ weights_ = weights_ ** advanced_buffer.alpha
+ return weights_
+
+ # first part(20 elements, smaller than buffer.replay_buffer_size)
+ data = generate_data_list(20)
+ advanced_buffer.push(data, 0)
+
+ assert (advanced_buffer.replay_buffer_size == 64)
+ assert (advanced_buffer.beta == 0.4)
+ assert (advanced_buffer.alpha == 0.6)
+ assert (hasattr(advanced_buffer, '_sum_tree'))
+ assert (hasattr(advanced_buffer, '_min_tree'))
+ assert (advanced_buffer.count() == 20)
+
+ # tree test
+ weights = get_weights(data)
+ assert (np.fabs(weights.sum() - advanced_buffer._sum_tree.reduce()) < 1e-6)
+
+ # second part(80 elements, bigger than buffer.replay_buffer_size)
+ data = generate_data_list(80)
+ advanced_buffer.push(data, 0)
+ assert (advanced_buffer.count() == 64)
+ assert (advanced_buffer._next_unique_id == 20 + 80)
+ assert (advanced_buffer._tail == (20 + 80) % 64)
+ weights = get_weights(data[-64:])
+ assert (np.fabs(weights.sum() - advanced_buffer._sum_tree.reduce()) < 1e-6)
+ weights = get_weights(data[-36:])
+ assert (np.fabs(weights.sum() - advanced_buffer._sum_tree.reduce(start=0, end=36)) < 1e-6)
+
+ @pytest.mark.rate
+ def test_rate_limit(self):
+ buffer_cfg = AdvancedReplayBuffer.default_config()
+ buffer_cfg.replay_buffer_size = 1000
+ buffer_cfg.thruput_controller = EasyDict(
+ push_sample_rate_limit=dict(
+ max=2,
+ min=0.5,
+ ),
+ window_seconds=5,
+ sample_min_limit_ratio=1.5,
+ )
+ prioritized_buffer = AdvancedReplayBuffer(buffer_cfg, tb_logger=None, instance_name='test')
+
+ # Too many samples
+ data = generate_data_list(30)
+ prioritized_buffer.push(data, 0) # push: 30
+ for _ in range(3):
+ _ = prioritized_buffer.sample(19, 0) # sample: 3 * 19 = 57
+ sampled_data = prioritized_buffer.sample(19, 0)
+ assert sampled_data is None
+
+ # Too big batch_size
+ sampled_data = prioritized_buffer.sample(21, 0)
+ assert sampled_data is None
+
+ # Too many pushes
+ assert prioritized_buffer.count() == 30
+ for _ in range(2):
+ data = generate_data_list(30)
+ prioritized_buffer.push(data, 0) # push: 30 + 2 * 30 = 90
+ assert prioritized_buffer.count() == 90
+ data = generate_data_list(30)
+ prioritized_buffer.push(data, 0)
+ assert prioritized_buffer.count() == 90
+
+ # Test thruput_controller
+ cur_sample_count = prioritized_buffer._thruput_controller.history_sample_count
+ cur_push_count = prioritized_buffer._thruput_controller.history_push_count
+ time.sleep(buffer_cfg.thruput_controller.window_seconds)
+ assert abs(prioritized_buffer._thruput_controller.history_sample_count - cur_sample_count *
+ 0.01) < 1e-5, (cur_sample_count, prioritized_buffer._thruput_controller.history_sample_count)
+ assert abs(prioritized_buffer._thruput_controller.history_push_count - cur_push_count *
+ 0.01) < 1e-5, (cur_push_count, prioritized_buffer._thruput_controller.history_push_count)
+
+
+@pytest.mark.unittest(rerun=5)
+class TestDemonstrationBuffer:
+
+ def test_naive(self, setup_demo_buffer_factory):
+ setup_demo_buffer = next(setup_demo_buffer_factory)
+ naive_demo_buffer = next(setup_demo_buffer_factory)
+ while True:
+ with open(demo_data_path, 'rb+') as f:
+ data = pickle.load(f)
+ if len(data) != 0:
+ break
+ else: # for the stability of dist-test
+ demo_data = {'data': generate_data_list(10)}
+ with open(demo_data_path, "wb") as f:
+ pickle.dump(demo_data, f)
+
+ setup_demo_buffer.load_state_dict(data)
+ assert setup_demo_buffer.count() == len(data['data']) # assert buffer not empty
+ samples = setup_demo_buffer.sample(3, 0)
+ assert 'staleness' in samples[0]
+ assert samples[1]['staleness'] == -1
+ assert len(samples) == 3
+ update_info = {'replay_unique_id': ['demo_0', 'demo_2'], 'replay_buffer_idx': [0, 2], 'priority': [1.33, 1.44]}
+ setup_demo_buffer.update(update_info)
+ samples = setup_demo_buffer.sample(10, 0)
+ for sample in samples:
+ if sample['replay_unique_id'] == 'demo_0':
+ assert abs(sample['priority'] - 1.33) <= 0.01 + 1e-5, sample
+ if sample['replay_unique_id'] == 'demo_2':
+ assert abs(sample['priority'] - 1.44) <= 0.02 + 1e-5, sample
+
+ state_dict = setup_demo_buffer.state_dict()
+ naive_demo_buffer.load_state_dict(state_dict, deepcopy=True)
+ assert naive_demo_buffer._tail == setup_demo_buffer._tail
+ assert naive_demo_buffer._max_priority == setup_demo_buffer._max_priority
+
+ os.popen('rm -rf log')
+ os.popen('rm -rf {}'.format(demo_data_path))
diff --git a/DI-engine/ding/worker/replay_buffer/tests/test_naive_buffer.py b/DI-engine/ding/worker/replay_buffer/tests/test_naive_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0122669f9de5d7eefe4ae7d33ac3aa0bad6dbcc
--- /dev/null
+++ b/DI-engine/ding/worker/replay_buffer/tests/test_naive_buffer.py
@@ -0,0 +1,112 @@
+import pytest
+from easydict import EasyDict
+import os
+import time
+import tempfile
+
+from ding.worker.replay_buffer import NaiveReplayBuffer
+from ding.utils import deep_merge_dicts
+from ding.worker.replay_buffer.tests.conftest import generate_data, generate_data_list
+
+
+@pytest.mark.unittest
+class TestNaiveBuffer:
+
+ def test_push(self):
+ buffer_cfg = deep_merge_dicts(NaiveReplayBuffer.default_config(), EasyDict(dict(replay_buffer_size=64)))
+ naive_buffer = NaiveReplayBuffer(buffer_cfg, instance_name='test')
+ start_pointer = naive_buffer._tail
+ start_vaildlen = naive_buffer.count()
+ valid_count = 0
+ for _ in range(100):
+ if naive_buffer._data[naive_buffer._tail] is None:
+ valid_count += 1
+ naive_buffer.push(generate_data(), 0)
+ assert (naive_buffer.replay_buffer_size == 64)
+ assert (naive_buffer.count() == 64 == start_vaildlen + valid_count)
+ assert (naive_buffer.push_count == start_vaildlen + 100)
+ assert (naive_buffer._tail == (start_pointer + 100) % naive_buffer.replay_buffer_size)
+ naive_buffer.update({'no_info': True})
+
+ buffer_cfg = deep_merge_dicts(NaiveReplayBuffer.default_config(), EasyDict(dict(replay_buffer_size=64)))
+ naive_buffer = NaiveReplayBuffer(buffer_cfg, instance_name='test')
+ start_pointer = naive_buffer._tail
+ replay_buffer_size = naive_buffer.replay_buffer_size
+ extend_num = int(0.6 * replay_buffer_size)
+ for i in range(1, 4):
+ data = generate_data_list(extend_num)
+ naive_buffer.push(data, 0)
+ assert naive_buffer._tail == (start_pointer + extend_num * i) % replay_buffer_size
+
+ def test_save_and_load_data(self):
+ buffer_cfg = deep_merge_dicts(NaiveReplayBuffer.default_config(), EasyDict(dict(replay_buffer_size=64)))
+ naive_buffer = NaiveReplayBuffer(buffer_cfg, instance_name='test')
+ start_pointer = naive_buffer._tail
+ start_vaildlen = naive_buffer.count()
+ valid_count = 0
+ for _ in range(100):
+ if naive_buffer._data[naive_buffer._tail] is None:
+ valid_count += 1
+ naive_buffer.push(generate_data(), 0)
+ assert (naive_buffer.replay_buffer_size == 64)
+ assert (naive_buffer.count() == 64 == start_vaildlen + valid_count)
+ assert (naive_buffer.push_count == start_vaildlen + 100)
+ assert (naive_buffer._tail == (start_pointer + 100) % naive_buffer.replay_buffer_size)
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ test_file = os.path.join(tmpdirname, "data.hkl")
+ naive_buffer.save_data(test_file)
+ naive_buffer_new = NaiveReplayBuffer(buffer_cfg, instance_name='test_new')
+ naive_buffer_new.load_data(test_file)
+ assert (naive_buffer_new.replay_buffer_size == 64)
+ assert (naive_buffer_new.count() == 64 == start_vaildlen + valid_count)
+ assert (naive_buffer_new.push_count == 64)
+
+ def test_sample(self):
+ buffer_cfg = deep_merge_dicts(NaiveReplayBuffer.default_config(), EasyDict(dict(replay_buffer_size=64)))
+ naive_buffer = NaiveReplayBuffer(buffer_cfg, instance_name='test')
+ for _ in range(64):
+ naive_buffer.push(generate_data(), 0)
+ batch = naive_buffer.sample(32, 0)
+ assert len(batch) == 32
+ last_one_batch = naive_buffer.sample(1, 0, sample_range=slice(-1, None))
+ assert len(last_one_batch) == 1
+ assert last_one_batch[0] == naive_buffer._data[-1]
+ batch = naive_buffer.sample(5, 0, sample_range=slice(-10, -2))
+ sample_range_data = naive_buffer._data[-10:-2]
+ assert len(batch) == 5
+ for b in batch:
+ assert any([b['data_id'] == d['data_id'] for d in sample_range_data])
+
+ # test clear
+ naive_buffer.clear()
+ assert naive_buffer.count() == 0
+
+ @pytest.mark.used
+ def test_track_used_data(self):
+ buffer_cfg = deep_merge_dicts(
+ NaiveReplayBuffer.default_config(), EasyDict(dict(replay_buffer_size=10, enable_track_used_data=True))
+ )
+ naive_buffer = NaiveReplayBuffer(buffer_cfg, instance_name='test')
+ naive_buffer.start()
+
+ old_data_list = generate_data_list(10, meta=True)
+ naive_buffer.push(old_data_list, 0)
+ for data in old_data_list:
+ assert os.path.exists(data['data_id'])
+ assert naive_buffer.count() == 10
+ new_data_list = generate_data_list(8, meta=True)
+ naive_buffer.push(new_data_list, 0)
+ assert naive_buffer.count() == 10
+ for data in new_data_list:
+ assert os.path.exists(data['data_id'])
+ time.sleep(1)
+ for data in old_data_list[:8]:
+ assert not os.path.exists(data['data_id'])
+ naive_buffer.clear()
+ time.sleep(1)
+ for data in old_data_list[9:]:
+ assert not os.path.exists(data['data_id'])
+ for data in new_data_list:
+ assert not os.path.exists(data['data_id'])
+
+ naive_buffer.close()
diff --git a/DI-engine/ding/worker/replay_buffer/utils.py b/DI-engine/ding/worker/replay_buffer/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..07e2896085c1787c10433e6aba577c917d5714fd
--- /dev/null
+++ b/DI-engine/ding/worker/replay_buffer/utils.py
@@ -0,0 +1,284 @@
+from typing import Any
+import time
+from queue import Queue
+from typing import Union, Tuple
+from threading import Thread
+from functools import partial
+
+from ding.utils.autolog import LoggedValue, LoggedModel
+from ding.utils import LockContext, LockContextType, remove_file
+
+
+def generate_id(name, data_id: int) -> str:
+ """
+ Overview:
+ Use ``name`` and the input ``data_id`` to generate a unique id for the next data to be inserted.
+ Arguments:
+ - name (:obj:`str`): Buffer name, used as the prefix of the generated id.
+ - data_id (:obj:`int`): Current unique id.
+ Returns:
+ - id (:obj:`str`): Id in format "BufferName_DataId".
+ """
+ return "{}_{}".format(name, str(data_id))
+
+
+class UsedDataRemover:
+ """
+ Overview:
+ UsedDataRemover is a tool to remove data files that are no longer used.
+ Interface:
+ start, close, add_used_data
+ """
+
+ def __init__(self) -> None:
+ self._used_data = Queue()
+ self._delete_used_data_thread = Thread(target=self._delete_used_data, name='delete_used_data')
+ self._delete_used_data_thread.daemon = True
+ self._end_flag = True
+
+ def start(self) -> None:
+ """
+ Overview:
+ Start the `delete_used_data` thread.
+ """
+ self._end_flag = False
+ self._delete_used_data_thread.start()
+
+ def close(self) -> None:
+ """
+ Overview:
+ Remove all data files still queued in `self._used_data`, then stop the `delete_used_data` thread.
+ """
+ while not self._used_data.empty():
+ data_id = self._used_data.get()
+ remove_file(data_id)
+ self._end_flag = True
+
+ def add_used_data(self, data: Any) -> None:
+ """
+ Overview:
+ Put one used data item into `self._used_data` so that its underlying file can be removed later.
+ Arguments:
+ - data (:obj:`Any`): The used data item (a dict containing key ``data_id``) to queue for removal.
+ """
+ assert data is not None and isinstance(data, dict) and 'data_id' in data
+ self._used_data.put(data['data_id'])
+
+ def _delete_used_data(self) -> None:
+ while not self._end_flag:
+ if not self._used_data.empty():
+ data_id = self._used_data.get()
+ remove_file(data_id)
+ else:
+ time.sleep(0.001)
+
+
+class SampledDataAttrMonitor(LoggedModel):
+ """
+ Overview:
+ SampledDataAttrMonitor monitors read-out indicators over the most recent ``expire`` read-outs.
+ Indicators include: read-out time; average and max of read-out data items' use count; average, max and
+ min of read-out data items' priority; average and max of staleness.
+ Interface:
+ __init__, fixed_time, current_time, freeze, unfreeze, register_attribute_value, __getattr__
+ Property:
+ time, expire
+ """
+ use_max = LoggedValue(int)
+ use_avg = LoggedValue(float)
+ priority_max = LoggedValue(float)
+ priority_avg = LoggedValue(float)
+ priority_min = LoggedValue(float)
+ staleness_max = LoggedValue(int)
+ staleness_avg = LoggedValue(float)
+
+ def __init__(self, time_: 'BaseTime', expire: Union[int, float]): # noqa
+ LoggedModel.__init__(self, time_, expire)
+ self.__register()
+
+ def __register(self):
+
+ def __avg_func(prop_name: str) -> float:
+ records = self.range_values[prop_name]()
+ _list = [_value for (_begin_time, _end_time), _value in records]
+ return sum(_list) / len(_list) if len(_list) != 0 else 0
+
+ def __max_func(prop_name: str) -> Union[float, int]:
+ records = self.range_values[prop_name]()
+ _list = [_value for (_begin_time, _end_time), _value in records]
+ return max(_list) if len(_list) != 0 else 0
+
+ def __min_func(prop_name: str) -> Union[float, int]:
+ records = self.range_values[prop_name]()
+ _list = [_value for (_begin_time, _end_time), _value in records]
+ return min(_list) if len(_list) != 0 else 0
+
+ self.register_attribute_value('avg', 'use', partial(__avg_func, prop_name='use_avg'))
+ self.register_attribute_value('max', 'use', partial(__max_func, prop_name='use_max'))
+ self.register_attribute_value('avg', 'priority', partial(__avg_func, prop_name='priority_avg'))
+ self.register_attribute_value('max', 'priority', partial(__max_func, prop_name='priority_max'))
+ self.register_attribute_value('min', 'priority', partial(__min_func, prop_name='priority_min'))
+ self.register_attribute_value('avg', 'staleness', partial(__avg_func, prop_name='staleness_avg'))
+ self.register_attribute_value('max', 'staleness', partial(__max_func, prop_name='staleness_max'))
+
+
+class PeriodicThruputMonitor:
+ """
+ Overview:
+ PeriodicThruputMonitor is a tool to record and log (text & TensorBoard) how many data items are
+ pushed/sampled/removed/valid in a period of time. For TensorBoard, you can view it in 'buffer_{$NAME}_sec'.
+ Interface:
+ close
+ Property:
+ push_data_count, sample_data_count, remove_data_count, valid_count
+
+ .. note::
+ The `thruput_log` thread is initialized and started in the `__init__` method, so PeriodicThruputMonitor
+ only provides a single interface: `close`.
+ """
+
+ def __init__(self, name, cfg, logger, tb_logger) -> None:
+ self.name = name
+ self._end_flag = False
+ self._logger = logger
+ self._tb_logger = tb_logger
+ self._thruput_print_seconds = cfg.seconds
+ self._thruput_print_times = 0
+ self._thruput_start_time = time.time()
+ self._history_push_count = 0
+ self._history_sample_count = 0
+ self._remove_data_count = 0
+ self._valid_count = 0
+ self._thruput_log_thread = Thread(target=self._thrput_print_periodically, args=(), name='periodic_thruput_log')
+ self._thruput_log_thread.daemon = True
+ self._thruput_log_thread.start()
+
+ def _thrput_print_periodically(self) -> None:
+ while not self._end_flag:
+ time_passed = time.time() - self._thruput_start_time
+ if time_passed >= self._thruput_print_seconds:
+ self._logger.info('In the past {:.1f} seconds, buffer statistics is as follows:'.format(time_passed))
+ count_dict = {
+ 'pushed_in': self._history_push_count,
+ 'sampled_out': self._history_sample_count,
+ 'removed': self._remove_data_count,
+ 'current_have': self._valid_count,
+ }
+ self._logger.info(self._logger.get_tabulate_vars_hor(count_dict))
+ for k, v in count_dict.items():
+ self._tb_logger.add_scalar('{}_sec/'.format(self.name) + k, v, self._thruput_print_times)
+ self._history_push_count = 0
+ self._history_sample_count = 0
+ self._remove_data_count = 0
+ self._thruput_start_time = time.time()
+ self._thruput_print_times += 1
+ else:
+ time.sleep(min(1, self._thruput_print_seconds * 0.2))
+
+ def close(self) -> None:
+ """
+ Overview:
+ Stop the `thruput_log` thread by setting `self._end_flag` to `True`.
+ """
+ self._end_flag = True
+
+ def __del__(self) -> None:
+ self.close()
+
+ @property
+ def push_data_count(self) -> int:
+ return self._history_push_count
+
+ @push_data_count.setter
+ def push_data_count(self, count) -> None:
+ self._history_push_count = count
+
+ @property
+ def sample_data_count(self) -> int:
+ return self._history_sample_count
+
+ @sample_data_count.setter
+ def sample_data_count(self, count) -> None:
+ self._history_sample_count = count
+
+ @property
+ def remove_data_count(self) -> int:
+ return self._remove_data_count
+
+ @remove_data_count.setter
+ def remove_data_count(self, count) -> None:
+ self._remove_data_count = count
+
+ @property
+ def valid_count(self) -> int:
+ return self._valid_count
+
+ @valid_count.setter
+ def valid_count(self, count) -> None:
+ self._valid_count = count
+
+
+class ThruputController:
+
+ def __init__(self, cfg) -> None:
+ self._push_sample_rate_limit = cfg.push_sample_rate_limit
+ assert 'min' in self._push_sample_rate_limit and self._push_sample_rate_limit['min'] >= 0
+ assert 'max' in self._push_sample_rate_limit and self._push_sample_rate_limit['max'] <= float("inf")
+ window_seconds = cfg.window_seconds
+ self._decay_factor = 0.01 ** (1 / window_seconds)
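+ # e.g. with window_seconds=5, decay_factor = 0.01 ** (1 / 5) ~= 0.398, so after 5 one-second decay
+ # steps a counter shrinks to roughly 1% of its value (the behaviour checked in the rate-limit test)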
+
+ self._push_lock = LockContext(type_=LockContextType.THREAD_LOCK)
+ self._sample_lock = LockContext(type_=LockContextType.THREAD_LOCK)
+ self._history_push_count = 0
+ self._history_sample_count = 0
+
+ self._end_flag = False
+ self._count_decay_thread = Thread(target=self._count_decay, name='count_decay')
+ self._count_decay_thread.daemon = True
+ self._count_decay_thread.start()
+
+ def _count_decay(self) -> None:
+ while not self._end_flag:
+ time.sleep(1)
+ with self._push_lock:
+ self._history_push_count *= self._decay_factor
+ with self._sample_lock:
+ self._history_sample_count *= self._decay_factor
+
+ def can_push(self, push_size: int) -> Tuple[bool, str]:
+ if abs(self._history_sample_count) < 1e-5:
+ return True, "Can push because `self._history_sample_count` < 1e-5"
+ rate = (self._history_push_count + push_size) / self._history_sample_count
+ if rate > self._push_sample_rate_limit['max']:
+ return False, "push({}+{}) / sample({}) > limit_max({})".format(
+ self._history_push_count, push_size, self._history_sample_count, self._push_sample_rate_limit['max']
+ )
+ return True, "Can push."
+
+ def can_sample(self, sample_size: int) -> Tuple[bool, str]:
+ rate = self._history_push_count / (self._history_sample_count + sample_size)
+ if rate < self._push_sample_rate_limit['min']:
+ return False, "push({}) / sample({}+{}) < limit_min({})".format(
+ self._history_push_count, self._history_sample_count, sample_size, self._push_sample_rate_limit['min']
+ )
+ return True, "Can sample."
+
+ def close(self) -> None:
+ self._end_flag = True
+
+ @property
+ def history_push_count(self) -> int:
+ return self._history_push_count
+
+ @history_push_count.setter
+ def history_push_count(self, count) -> None:
+ with self._push_lock:
+ self._history_push_count = count
+
+ @property
+ def history_sample_count(self) -> int:
+ return self._history_sample_count
+
+ @history_sample_count.setter
+ def history_sample_count(self, count) -> None:
+ with self._sample_lock:
+ self._history_sample_count = count
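+
+
+# Rough usage sketch (config values are illustrative):
+# controller = ThruputController(EasyDict(push_sample_rate_limit=dict(min=0.5, max=2), window_seconds=5))
+# ok, msg = controller.can_push(32) # False once (push + 32) / sample would exceed the max ratio
+# controller.history_push_count += 32 # the replay buffer updates these counters after each push/sample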
diff --git a/DI-engine/ding/world_model/__init__.py b/DI-engine/ding/world_model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..45cb5bea3507858c233d0b48960d19ba955cf7d3
--- /dev/null
+++ b/DI-engine/ding/world_model/__init__.py
@@ -0,0 +1,2 @@
+from .base_world_model import WorldModel, DynaWorldModel, DreamWorldModel, HybridWorldModel, \
+ get_world_model_cls, create_world_model
diff --git a/DI-engine/ding/world_model/base_world_model.py b/DI-engine/ding/world_model/base_world_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..941710c2305a14bdf7f15ac454d4a0082bcecf25
--- /dev/null
+++ b/DI-engine/ding/world_model/base_world_model.py
@@ -0,0 +1,362 @@
+from typing import Tuple, Callable, Optional
+from collections import namedtuple
+from abc import ABC, abstractmethod
+
+import torch
+from torch import Tensor, nn
+from easydict import EasyDict
+
+from ding.worker import IBuffer
+from ding.envs import BaseEnv
+from ding.utils import deep_merge_dicts
+from ding.world_model.utils import get_rollout_length_scheduler
+
+from ding.utils import import_module, WORLD_MODEL_REGISTRY
+
+
+def get_world_model_cls(cfg):
+ import_module(cfg.get('import_names', []))
+ return WORLD_MODEL_REGISTRY.get(cfg.type)
+
+
+def create_world_model(cfg, *args, **kwargs):
+ import_module(cfg.get('import_names', []))
+ return WORLD_MODEL_REGISTRY.build(cfg.type, cfg, *args, **kwargs)
+
+
+class WorldModel(ABC):
+ r"""
+ Overview:
+ Abstract base class for world models.
+
+ Interfaces:
+ should_train, should_eval, train, eval, step
+ """
+
+ config = dict(
+ train_freq=250, # w.r.t environment step
+ eval_freq=250, # w.r.t environment step
+ cuda=True,
+ rollout_length_scheduler=dict(
+ type='linear',
+ rollout_start_step=20000,
+ rollout_end_step=150000,
+ rollout_length_min=1,
+ rollout_length_max=25,
+ )
+ )
+
+ def __init__(self, cfg: dict, env: BaseEnv, tb_logger: 'SummaryWriter'): # noqa
+ self.cfg = cfg
+ self.env = env
+ self.tb_logger = tb_logger
+
+ self._cuda = cfg.cuda
+ self.train_freq = cfg.train_freq
+ self.eval_freq = cfg.eval_freq
+ self.rollout_length_scheduler = get_rollout_length_scheduler(cfg.rollout_length_scheduler)
+
+ self.last_train_step = 0
+ self.last_eval_step = 0
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ # cannot call default_config() recursively,
+ # because the config would be overwritten by subclasses
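+ # e.g. DynaWorldModel.default_config() merges DynaWorldModel.config (the `other` section) with
+ # WorldModel.config (train_freq, eval_freq, cuda, rollout_length_scheduler), walking up the
+ # class hierarchy until ABC is reached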
+ merge_cfg = EasyDict(cfg_type=cls.__name__ + 'Dict')
+ while cls != ABC:
+ merge_cfg = deep_merge_dicts(merge_cfg, cls.config)
+ cls = cls.__base__
+ return merge_cfg
+
+ def should_train(self, envstep: int):
+ r"""
+ Overview:
+ Check whether need to train world model.
+ """
+ return (envstep - self.last_train_step) >= self.train_freq
+
+ def should_eval(self, envstep: int):
+ r"""
+ Overview:
+ Check whether need to evaluate world model.
+ """
+ return (envstep - self.last_eval_step) >= self.eval_freq and self.last_train_step != 0
+
+ @abstractmethod
+ def train(self, env_buffer: IBuffer, envstep: int, train_iter: int):
+ r"""
+ Overview:
+ Train world model using data from env_buffer.
+
+ Arguments:
+ - env_buffer (:obj:`IBuffer`): the buffer which collects real environment steps
+ - envstep (:obj:`int`): the current number of environment steps in real environment
+ - train_iter (:obj:`int`): the current number of policy training iterations
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def eval(self, env_buffer: IBuffer, envstep: int, train_iter: int):
+ r"""
+ Overview:
+ Evaluate world model using data from env_buffer.
+
+ Arguments:
+ - env_buffer (:obj:`IBuffer`): the buffer that collects real environment steps
+ - envstep (:obj:`int`): the current number of environment steps in real environment
+ - train_iter (:obj:`int`): the current number of policy training iterations
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def step(self, obs: Tensor, action: Tensor) -> Tuple[Tensor, Tensor, Tensor]:
+ r"""
+ Overview:
+ Take one step in world model.
+
+ Arguments:
+ - obs (:obj:`torch.Tensor`): current observations :math:`S_t`
+ - action (:obj:`torch.Tensor`): current actions :math:`A_t`
+
+ Returns:
+ - reward (:obj:`torch.Tensor`): rewards :math:`R_t`
+ - next_obs (:obj:`torch.Tensor`): next observations :math:`S_t+1`
+ - done (:obj:`torch.Tensor`): whether the episodes ends
+
+ Shapes:
+ :math:`B`: batch size
+ :math:`O`: observation dimension
+ :math:`A`: action dimension
+
+ - obs: [B, O]
+ - action: [B, A]
+ - reward: [B, ]
+ - next_obs: [B, O]
+ - done: [B, ]
+ """
+ raise NotImplementedError
+
+
+class DynaWorldModel(WorldModel, ABC):
+ r"""
+ Overview:
+ Dyna-style world model (summarized in arXiv: 1907.02057) which stores and\
+ reuses imagination rollouts in the imagination buffer.
+
+ Interfaces:
+ sample, fill_img_buffer, should_train, should_eval, train, eval, step
+ """
+
+ config = dict(
+ other=dict(
+ real_ratio=0.05,
+ rollout_retain=4,
+ rollout_batch_size=100000,
+ imagination_buffer=dict(
+ type='elastic',
+ replay_buffer_size=6000000,
+ deepcopy=False,
+ enable_track_used_data=False,
+ # set_buffer_size=set_buffer_size,
+ periodic_thruput_seconds=60,
+ ),
+ )
+ )
+
+ def __init__(self, cfg: dict, env: BaseEnv, tb_logger: 'SummaryWriter'): # noqa
+ super().__init__(cfg, env, tb_logger)
+ self.real_ratio = cfg.other.real_ratio
+ self.rollout_batch_size = cfg.other.rollout_batch_size
+ self.rollout_retain = cfg.other.rollout_retain
+ self.buffer_size_scheduler = \
+ lambda x: self.rollout_length_scheduler(x) * self.rollout_batch_size * self.rollout_retain
+
+ def sample(self, env_buffer: IBuffer, img_buffer: IBuffer, batch_size: int, train_iter: int) -> dict:
+ r"""
+ Overview:
+ Sample from the combination of environment buffer and imagination buffer with\
+ certain ratio to generate batched data for policy training.
+
+ Arguments:
+ - env_buffer (:obj:`IBuffer`): the buffer that collects real environment steps
+ - img_buffer (:obj:`IBuffer`): the buffer that collects imagination steps
+ - batch_size (:obj:`int`): the batch size for policy training
+ - train_iter (:obj:`int`): the current number of policy training iterations
+
+ Returns:
+ - data (:obj:`list`): the training data for policy training
+ """
+ env_batch_size = int(batch_size * self.real_ratio)
+ img_batch_size = batch_size - env_batch_size
+ env_data = env_buffer.sample(env_batch_size, train_iter)
+ img_data = img_buffer.sample(img_batch_size, train_iter)
+ train_data = env_data + img_data
+ return train_data
+
+ def fill_img_buffer(
+ self, policy: namedtuple, env_buffer: IBuffer, img_buffer: IBuffer, envstep: int, train_iter: int
+ ):
+ r"""
+ Overview:
+ Sample initial observations from the env_buffer, roll them out in the world model to generate new
+ data, and push the generated samples into the img_buffer.
+
+ Arguments:
+ - policy (:obj:`namedtuple`): policy in collect mode
+ - env_buffer (:obj:`IBuffer`): the buffer that collects real environment steps
+ - img_buffer (:obj:`IBuffer`): the buffer that collects imagination steps
+ - envstep (:obj:`int`): the current number of environment steps in real environment
+ - train_iter (:obj:`int`): the current number of policy training iterations
+ """
+ from ding.torch_utils import to_tensor
+ from ding.envs import BaseEnvTimestep
+ from ding.worker.collector.base_serial_collector import to_tensor_transitions
+
+ def step(obs, act):
+ # This function has the same input and output format as env manager's step
+ data_id = list(obs.keys())
+ obs = torch.stack([obs[id] for id in data_id], dim=0)
+ act = torch.stack([act[id] for id in data_id], dim=0)
+ with torch.no_grad():
+ rewards, next_obs, terminals = self.step(obs, act)
+ # terminals = self.termination_fn(next_obs)
+ timesteps = {
+ id: BaseEnvTimestep(n, r, d, {})
+ for id, n, r, d in zip(
+ data_id,
+ next_obs.cpu().numpy(),
+ rewards.unsqueeze(-1).cpu().numpy(), # ding api
+ terminals.cpu().numpy()
+ )
+ }
+ return timesteps
+
+ # set rollout length
+ rollout_length = self.rollout_length_scheduler(envstep)
+ # load data
+ data = env_buffer.sample(self.rollout_batch_size, train_iter, replace=True)
+ obs = {id: data[id]['obs'] for id in range(len(data))}
+ # rollout
+ buffer = [[] for id in range(len(obs))]
+ new_data = []
+ for i in range(rollout_length):
+ # get action
+ obs = to_tensor(obs, dtype=torch.float32)
+ policy_output = policy.forward(obs)
+ actions = {id: output['action'] for id, output in policy_output.items()}
+ # predict next obs and reward
+ # timesteps = self.step(obs, actions, env_model)
+ timesteps = step(obs, actions)
+ obs_new = {}
+ for id, timestep in timesteps.items():
+ transition = policy.process_transition(obs[id], policy_output[id], timestep)
+ transition['collect_iter'] = train_iter
+ buffer[id].append(transition)
+ if not timestep.done:
+ obs_new[id] = timestep.obs
+ if timestep.done or i + 1 == rollout_length:
+ transitions = to_tensor_transitions(buffer[id])
+ train_sample = policy.get_train_sample(transitions)
+ new_data.extend(train_sample)
+ if len(obs_new) == 0:
+ break
+ obs = obs_new
+
+ img_buffer.push(new_data, cur_collector_envstep=envstep)
+
+
+class DreamWorldModel(WorldModel, ABC):
+ r"""
+ Overview:
+ Dreamer-style world model which uses each imagination rollout only once\
+ and backpropagates through time (rollout) to optimize the policy.
+
+ Interfaces:
+ rollout, should_train, should_eval, train, eval, step
+ """
+
+ def rollout(self, obs: Tensor, actor_fn: Callable[[Tensor], Tuple[Tensor, Tensor]], envstep: int,
+ **kwargs) -> Tuple[Tensor, Tensor, Tensor, Tensor, Tensor, Optional[bool]]:
+ r"""
+ Overview:
+ Generate batched imagination rollouts starting from the current observations.\
+ This function is useful for value gradients where the policy is optimized by BPTT.
+
+ Arguments:
+ - obs (:obj:`Tensor`): the current observations :math:`S_t`
+ - actor_fn (:obj:`Callable`): the unified API :math:`(A_t, H_t) = pi(S_t)`
+ - envstep (:obj:`int`): the current number of environment steps in real environment
+
+ Returns:
+ - obss (:obj:`Tensor`): :math:`S_t, ..., S_t+n`
+ - actions (:obj:`Tensor`): :math:`A_t, ..., A_t+n`
+ - rewards (:obj:`Tensor`): :math:`R_t, ..., R_t+n-1`
+ - aug_rewards (:obj:`Tensor`): :math:`H_t, ..., H_t+n`, this can be entropy bonus as in SAC,
+ otherwise it should be a zero tensor
+ - dones (:obj:`Tensor`): :math:`\text{done}_t, ..., \text{done}_t+n`
+
+ Shapes:
+ :math:`N`: time step
+ :math:`B`: batch size
+ :math:`O`: observation dimension
+ :math:`A`: action dimension
+
+ - obss: :math:`[N+1, B, O]`, where obss[0] are the real observations
+ - actions: :math:`[N+1, B, A]`
+ - rewards: :math:`[N, B]`
+ - aug_rewards: :math:`[N+1, B]`
+ - dones: :math:`[N, B]`
+
+ .. note::
+ - The rollout length is determined by rollout length scheduler.
+
+ - actor_fn's inputs and outputs shape are similar to WorldModel.step()
+ """
+ horizon = self.rollout_length_scheduler(envstep)
+ if isinstance(self, nn.Module):
+ # Rollouts should propagate gradients only to policy,
+ # so make sure that the world model is not updated by rollout.
+ self.requires_grad_(False)
+ obss = [obs]
+ actions = []
+ rewards = []
+ aug_rewards = [] # -temperature*logprob
+ dones = []
+ for _ in range(horizon):
+ action, aug_reward = actor_fn(obs)
+ # done: probability of termination
+ reward, obs, done = self.step(obs, action, **kwargs)
+ reward = reward + aug_reward
+ obss.append(obs)
+ actions.append(action)
+ rewards.append(reward)
+ aug_rewards.append(aug_reward)
+ dones.append(done)
+ action, aug_reward = actor_fn(obs)
+ actions.append(action)
+ aug_rewards.append(aug_reward)
+ if isinstance(self, nn.Module):
+ self.requires_grad_(True)
+ return (
+ torch.stack(obss),
+ torch.stack(actions),
+ # rewards is an empty list when horizon=0
+ torch.stack(rewards) if rewards else torch.tensor(rewards, device=obs.device),
+ torch.stack(aug_rewards),
+ torch.stack(dones) if dones else torch.tensor(dones, device=obs.device)
+ )
+
+
+class HybridWorldModel(DynaWorldModel, DreamWorldModel, ABC):
+ r"""
+ Overview:
+ The hybrid model that combines reused and on-the-fly rollouts.
+
+ Interfaces:
+ rollout, sample, fill_img_buffer, should_train, should_eval, train, eval, step
+ """
+
+ def __init__(self, cfg: dict, env: BaseEnv, tb_logger: 'SummaryWriter'): # noqa
+ DynaWorldModel.__init__(self, cfg, env, tb_logger)
+ DreamWorldModel.__init__(self, cfg, env, tb_logger)
diff --git a/DI-engine/ding/world_model/ddppo.py b/DI-engine/ding/world_model/ddppo.py
new file mode 100644
index 0000000000000000000000000000000000000000..075c2aa63b7082c72b1eeaa9434fd8b21b6e600b
--- /dev/null
+++ b/DI-engine/ding/world_model/ddppo.py
@@ -0,0 +1,523 @@
+from functools import partial
+from ditk import logging
+import itertools
+import copy
+import numpy as np
+import multiprocessing
+import torch
+import torch.nn as nn
+
+from ding.utils import WORLD_MODEL_REGISTRY
+from ding.utils.data import default_collate
+from ding.torch_utils import unsqueeze_repeat
+from ding.world_model.base_world_model import HybridWorldModel
+from ding.world_model.model.ensemble import EnsembleModel, StandardScaler
+
+
+#======================= Helper functions =======================
+# tree_query = lambda datapoint: tree.query(datapoint, k=k+1)[1][1:]
+def tree_query(datapoint, tree, k):
+ return tree.query(datapoint, k=k + 1)[1][1:]
+
+
+def get_neighbor_index(data, k, serial=False):
+ """
+ data: [B, N]
+ k: int
+
+ ret: [B, k]
+ """
+ try:
+ from scipy.spatial import KDTree
+ except ImportError:
+ import sys
+ logging.warning("Please install scipy first, such as `pip3 install scipy`.")
+ sys.exit(1)
+ data = data.cpu().numpy()
+ tree = KDTree(data)
+
+ if serial:
+ nn_index = [torch.from_numpy(np.array(tree_query(d, tree, k))) for d in data]
+ nn_index = torch.stack(nn_index).long()
+ else:
+ # TODO: speed up multiprocessing
+ pool = multiprocessing.Pool(processes=multiprocessing.cpu_count())
+ fn = partial(tree_query, tree=tree, k=k)
+ nn_index = torch.from_numpy(np.array(list(pool.map(fn, data)), dtype=np.int32)).to(torch.long)
+ pool.close()
+ return nn_index
+
+
+ def get_batch_jacobian(net, x, noutputs): # x: [b, in_dim], noutputs: out_dim
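+ # Batched Jacobian via a single autograd.grad call: x is tiled to [b, noutputs, in_dim], net is
+ # applied row-wise, and an identity upstream gradient extracts dy_j/dx, giving a result of shape
+ # [b, noutputs, in_dim]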
+ x = x.unsqueeze(1) # b, 1 ,in_dim
+ n = x.size()[0]
+ x = x.repeat(1, noutputs, 1) # b, out_dim, in_dim
+ x.requires_grad_(True)
+ y = net(x)
+ upstream_gradient = torch.eye(noutputs).reshape(1, noutputs, noutputs).repeat(n, 1, 1).to(x.device)
+ re = torch.autograd.grad(y, x, upstream_gradient, create_graph=True)[0]
+
+ return re
+
+
+class EnsembleGradientModel(EnsembleModel):
+
+ def train(self, loss, loss_reg, reg):
+ self.optimizer.zero_grad()
+
+ loss += 0.01 * torch.sum(self.max_logvar) - 0.01 * torch.sum(self.min_logvar)
+ loss += reg * loss_reg
+ if self.use_decay:
+ loss += self.get_decay_loss()
+
+ loss.backward()
+
+ self.optimizer.step()
+
+
+# TODO: derive from MBPO instead of implementing from scratch
+@WORLD_MODEL_REGISTRY.register('ddppo')
+class DDPPOWorldMode(HybridWorldModel, nn.Module):
+ """rollout model + gradient model"""
+ config = dict(
+ model=dict(
+ ensemble_size=7,
+ elite_size=5,
+ state_size=None, # has to be specified
+ action_size=None, # has to be specified
+ reward_size=1,
+ hidden_size=200,
+ use_decay=False,
+ batch_size=256,
+ holdout_ratio=0.2,
+ max_epochs_since_update=5,
+ deterministic_rollout=True,
+ # parameters for DDPPO
+ gradient_model=True,
+ k=3,
+ reg=1,
+ neighbor_pool_size=10000,
+ train_freq_gradient_model=250
+ ),
+ )
+
+ def __init__(self, cfg, env, tb_logger):
+ HybridWorldModel.__init__(self, cfg, env, tb_logger)
+ nn.Module.__init__(self)
+
+ cfg = cfg.model
+ self.ensemble_size = cfg.ensemble_size
+ self.elite_size = cfg.elite_size
+ self.state_size = cfg.state_size
+ self.action_size = cfg.action_size
+ self.reward_size = cfg.reward_size
+ self.hidden_size = cfg.hidden_size
+ self.use_decay = cfg.use_decay
+ self.batch_size = cfg.batch_size
+ self.holdout_ratio = cfg.holdout_ratio
+ self.max_epochs_since_update = cfg.max_epochs_since_update
+ self.deterministic_rollout = cfg.deterministic_rollout
+ # parameters for DDPPO
+ self.gradient_model = cfg.gradient_model
+ self.k = cfg.k
+ self.reg = cfg.reg
+ self.neighbor_pool_size = cfg.neighbor_pool_size
+ self.train_freq_gradient_model = cfg.train_freq_gradient_model
+
+ self.rollout_model = EnsembleModel(
+ self.state_size,
+ self.action_size,
+ self.reward_size,
+ self.ensemble_size,
+ self.hidden_size,
+ use_decay=self.use_decay
+ )
+ self.scaler = StandardScaler(self.state_size + self.action_size)
+
+ self.ensemble_mse_losses = []
+ self.model_variances = []
+ self.elite_model_idxes = []
+
+ if self.gradient_model:
+ self.gradient_model = EnsembleGradientModel(
+ self.state_size,
+ self.action_size,
+ self.reward_size,
+ self.ensemble_size,
+ self.hidden_size,
+ use_decay=self.use_decay
+ )
+ self.elite_model_idxes_gradient_model = []
+
+ self.last_train_step_gradient_model = 0
+ self.serial_calc_nn = False
+
+ if self._cuda:
+ self.cuda()
+
+ def step(self, obs, act, batch_size=8192):
+
+ class Predict(torch.autograd.Function):
+ # TODO: align rollout_model elites with gradient_model elites
+ # use different model for forward and backward
+ @staticmethod
+ def forward(ctx, x):
+ ctx.save_for_backward(x)
+ mean, var = self.rollout_model(x, ret_log_var=False)
+ return torch.cat([mean, var], dim=-1)
+
+ @staticmethod
+ def backward(ctx, grad_out):
+ x, = ctx.saved_tensors
+ with torch.enable_grad():
+ x = x.detach()
+ x.requires_grad_(True)
+ mean, var = self.gradient_model(x, ret_log_var=False)
+ y = torch.cat([mean, var], dim=-1)
+ return torch.autograd.grad(y, x, grad_outputs=grad_out, create_graph=True)
+
+ if len(act.shape) == 1:
+ act = act.unsqueeze(1)
+ if self._cuda:
+ obs = obs.cuda()
+ act = act.cuda()
+ inputs = torch.cat([obs, act], dim=1)
+ inputs = self.scaler.transform(inputs)
+ # predict
+ ensemble_mean, ensemble_var = [], []
+ for i in range(0, inputs.shape[0], batch_size):
+ input = unsqueeze_repeat(inputs[i:i + batch_size], self.ensemble_size)
+ if not torch.is_grad_enabled() or not self.gradient_model:
+ b_mean, b_var = self.rollout_model(input, ret_log_var=False)
+ else:
+ # use gradient model to compute gradients during backward pass
+ output = Predict.apply(input)
+ b_mean, b_var = output.chunk(2, dim=2)
+ ensemble_mean.append(b_mean)
+ ensemble_var.append(b_var)
+ ensemble_mean = torch.cat(ensemble_mean, 1)
+ ensemble_var = torch.cat(ensemble_var, 1)
+ ensemble_mean[:, :, 1:] += obs.unsqueeze(0)
+ ensemble_std = ensemble_var.sqrt()
+ # sample from the predicted distribution
+ if self.deterministic_rollout:
+ ensemble_sample = ensemble_mean
+ else:
+ ensemble_sample = ensemble_mean + torch.randn_like(ensemble_mean).to(ensemble_mean) * ensemble_std
+ # sample from ensemble
+ model_idxes = torch.from_numpy(np.random.choice(self.elite_model_idxes, size=len(obs))).to(inputs.device)
+ batch_idxes = torch.arange(len(obs)).to(inputs.device)
+ sample = ensemble_sample[model_idxes, batch_idxes]
+ rewards, next_obs = sample[:, 0], sample[:, 1:]
+
+ return rewards, next_obs, self.env.termination_fn(next_obs)
+
+ def eval(self, env_buffer, envstep, train_iter):
+ data = env_buffer.sample(self.eval_freq, train_iter)
+ data = default_collate(data)
+ data['done'] = data['done'].float()
+ data['weight'] = data.get('weight', None)
+ obs = data['obs']
+ action = data['action']
+ reward = data['reward']
+ next_obs = data['next_obs']
+ if len(reward.shape) == 1:
+ reward = reward.unsqueeze(1)
+ if len(action.shape) == 1:
+ action = action.unsqueeze(1)
+
+ # build eval samples
+ inputs = torch.cat([obs, action], dim=1)
+ labels = torch.cat([reward, next_obs - obs], dim=1)
+ if self._cuda:
+ inputs = inputs.cuda()
+ labels = labels.cuda()
+
+ # normalize
+ inputs = self.scaler.transform(inputs)
+
+ # repeat for ensemble
+ inputs = unsqueeze_repeat(inputs, self.ensemble_size)
+ labels = unsqueeze_repeat(labels, self.ensemble_size)
+
+ # eval
+ with torch.no_grad():
+ mean, logvar = self.rollout_model(inputs, ret_log_var=True)
+ loss, mse_loss = self.rollout_model.loss(mean, logvar, labels)
+ ensemble_mse_loss = torch.pow(mean.mean(0) - labels[0], 2)
+ model_variance = mean.var(0)
+ self.tb_logger.add_scalar('env_model_step/eval_mse_loss', mse_loss.mean().item(), envstep)
+ self.tb_logger.add_scalar('env_model_step/eval_ensemble_mse_loss', ensemble_mse_loss.mean().item(), envstep)
+ self.tb_logger.add_scalar('env_model_step/eval_model_variances', model_variance.mean().item(), envstep)
+
+ self.last_eval_step = envstep
+
+ def train(self, env_buffer, envstep, train_iter):
+
+ def train_sample(data) -> tuple:
+ data = default_collate(data)
+ data['done'] = data['done'].float()
+ data['weight'] = data.get('weight', None)
+ obs = data['obs']
+ action = data['action']
+ reward = data['reward']
+ next_obs = data['next_obs']
+ if len(reward.shape) == 1:
+ reward = reward.unsqueeze(1)
+ if len(action.shape) == 1:
+ action = action.unsqueeze(1)
+ # build train samples
+ inputs = torch.cat([obs, action], dim=1)
+ labels = torch.cat([reward, next_obs - obs], dim=1)
+ if self._cuda:
+ inputs = inputs.cuda()
+ labels = labels.cuda()
+ return inputs, labels
+
+ logvar = dict()
+
+ data = env_buffer.sample(env_buffer.count(), train_iter)
+ inputs, labels = train_sample(data)
+ logvar.update(self._train_rollout_model(inputs, labels))
+
+ if self.gradient_model:
+ # update neighbor pool
+ if (envstep - self.last_train_step_gradient_model) >= self.train_freq_gradient_model:
+ n = min(env_buffer.count(), self.neighbor_pool_size)
+ self.neighbor_pool = env_buffer.sample(n, train_iter, sample_range=slice(-n, None))
+ inputs_reg, labels_reg = train_sample(self.neighbor_pool)
+ logvar.update(self._train_gradient_model(inputs, labels, inputs_reg, labels_reg))
+ self.last_train_step_gradient_model = envstep
+
+ self.last_train_step = envstep
+
+ # log
+ if self.tb_logger is not None:
+ for k, v in logvar.items():
+ self.tb_logger.add_scalar('env_model_step/' + k, v, envstep)
+
+ def _train_rollout_model(self, inputs, labels):
+ #split
+ num_holdout = int(inputs.shape[0] * self.holdout_ratio)
+ train_inputs, train_labels = inputs[num_holdout:], labels[num_holdout:]
+ holdout_inputs, holdout_labels = inputs[:num_holdout], labels[:num_holdout]
+
+ #normalize
+ self.scaler.fit(train_inputs)
+ train_inputs = self.scaler.transform(train_inputs)
+ holdout_inputs = self.scaler.transform(holdout_inputs)
+
+ #repeat for ensemble
+ holdout_inputs = unsqueeze_repeat(holdout_inputs, self.ensemble_size)
+ holdout_labels = unsqueeze_repeat(holdout_labels, self.ensemble_size)
+
+ self._epochs_since_update = 0
+ self._snapshots = {i: (-1, 1e10) for i in range(self.ensemble_size)}
+ self._save_states()
+ for epoch in itertools.count():
+
+ train_idx = torch.stack([torch.randperm(train_inputs.shape[0])
+ for _ in range(self.ensemble_size)]).to(train_inputs.device)
+ self.mse_loss = []
+ for start_pos in range(0, train_inputs.shape[0], self.batch_size):
+ idx = train_idx[:, start_pos:start_pos + self.batch_size]
+ train_input = train_inputs[idx]
+ train_label = train_labels[idx]
+ mean, logvar = self.rollout_model(train_input, ret_log_var=True)
+ loss, mse_loss = self.rollout_model.loss(mean, logvar, train_label)
+ self.rollout_model.train(loss)
+ self.mse_loss.append(mse_loss.mean().item())
+ self.mse_loss = sum(self.mse_loss) / len(self.mse_loss)
+
+ with torch.no_grad():
+ holdout_mean, holdout_logvar = self.rollout_model(holdout_inputs, ret_log_var=True)
+ _, holdout_mse_loss = self.rollout_model.loss(holdout_mean, holdout_logvar, holdout_labels)
+ self.curr_holdout_mse_loss = holdout_mse_loss.mean().item()
+ break_train = self._save_best(epoch, holdout_mse_loss)
+ if break_train:
+ break
+
+ self._load_states()
+ with torch.no_grad():
+ holdout_mean, holdout_logvar = self.rollout_model(holdout_inputs, ret_log_var=True)
+ _, holdout_mse_loss = self.rollout_model.loss(holdout_mean, holdout_logvar, holdout_labels)
+ sorted_loss, sorted_loss_idx = holdout_mse_loss.sort()
+ sorted_loss = sorted_loss.detach().cpu().numpy().tolist()
+ sorted_loss_idx = sorted_loss_idx.detach().cpu().numpy().tolist()
+ self.elite_model_idxes = sorted_loss_idx[:self.elite_size]
+ self.top_holdout_mse_loss = sorted_loss[0]
+ self.middle_holdout_mse_loss = sorted_loss[self.ensemble_size // 2]
+ self.bottom_holdout_mse_loss = sorted_loss[-1]
+ self.best_holdout_mse_loss = holdout_mse_loss.mean().item()
+ return {
+ 'rollout_model/mse_loss': self.mse_loss,
+ 'rollout_model/curr_holdout_mse_loss': self.curr_holdout_mse_loss,
+ 'rollout_model/best_holdout_mse_loss': self.best_holdout_mse_loss,
+ 'rollout_model/top_holdout_mse_loss': self.top_holdout_mse_loss,
+ 'rollout_model/middle_holdout_mse_loss': self.middle_holdout_mse_loss,
+ 'rollout_model/bottom_holdout_mse_loss': self.bottom_holdout_mse_loss,
+ }
+
+ def _get_jacobian(self, model, train_input_reg):
+ """
+ train_input_reg: [ensemble_size, B, state_size+action_size]
+
+ ret: [ensemble_size, B, state_size+reward_size, state_size+action_size]
+ """
+
+ def func(x):
+ x = x.view(self.ensemble_size, -1, self.state_size + self.action_size)
+ state = x[:, :, :self.state_size]
+ x = self.scaler.transform(x)
+ y, _ = model(x)
+ # `y[..., self.reward_size:] += state` is an in-place op that breaks autograd, so add via a zero tensor instead
+ null = torch.zeros_like(y)
+ null[:, :, self.reward_size:] += state
+ y = y + null
+
+ return y.view(-1, self.state_size + self.reward_size, self.state_size + self.reward_size)
+
+ # reshape input
+ train_input_reg = train_input_reg.view(-1, self.state_size + self.action_size)
+ jacobian = get_batch_jacobian(func, train_input_reg, self.state_size + self.reward_size)
+
+ # reshape jacobian
+ return jacobian.view(
+ self.ensemble_size, -1, self.state_size + self.reward_size, self.state_size + self.action_size
+ )
+
+ def _train_gradient_model(self, inputs, labels, inputs_reg, labels_reg):
+ #split
+ num_holdout = int(inputs.shape[0] * self.holdout_ratio)
+ train_inputs, train_labels = inputs[num_holdout:], labels[num_holdout:]
+ holdout_inputs, holdout_labels = inputs[:num_holdout], labels[:num_holdout]
+
+ #normalize
+ # self.scaler.fit(train_inputs)
+ train_inputs = self.scaler.transform(train_inputs)
+ holdout_inputs = self.scaler.transform(holdout_inputs)
+
+ #repeat for ensemble
+ holdout_inputs = unsqueeze_repeat(holdout_inputs, self.ensemble_size)
+ holdout_labels = unsqueeze_repeat(holdout_labels, self.ensemble_size)
+
+ #no split and normalization on regulation data
+ train_inputs_reg, train_labels_reg = inputs_reg, labels_reg
+
+ neighbor_index = get_neighbor_index(train_inputs_reg, self.k, serial=self.serial_calc_nn)
+ neighbor_inputs = train_inputs_reg[neighbor_index] # [N, k, state_size+action_size]
+ neighbor_labels = train_labels_reg[neighbor_index] # [N, k, state_size+reward_size]
+ neighbor_inputs_distance = (neighbor_inputs - train_inputs_reg.unsqueeze(1)) # [N, k, state_size+action_size]
+ neighbor_labels_distance = (neighbor_labels - train_labels_reg.unsqueeze(1)) # [N, k, state_size+reward_size]
+
+ self._epochs_since_update = 0
+ self._snapshots = {i: (-1, 1e10) for i in range(self.ensemble_size)}
+ self._save_states()
+ for epoch in itertools.count():
+
+ train_idx = torch.stack([torch.randperm(train_inputs.shape[0])
+ for _ in range(self.ensemble_size)]).to(train_inputs.device)
+
+ train_idx_reg = torch.stack([torch.randperm(train_inputs_reg.shape[0])
+ for _ in range(self.ensemble_size)]).to(train_inputs_reg.device)
+
+ self.mse_loss = []
+ self.grad_loss = []
+ for start_pos in range(0, train_inputs.shape[0], self.batch_size):
+ idx = train_idx[:, start_pos:start_pos + self.batch_size]
+ train_input = train_inputs[idx]
+ train_label = train_labels[idx]
+ mean, logvar = self.gradient_model(train_input, ret_log_var=True)
+ loss, mse_loss = self.gradient_model.loss(mean, logvar, train_label)
+
+ # regulation loss
+ if start_pos % train_inputs_reg.shape[0] < (start_pos + self.batch_size) % train_inputs_reg.shape[0]:
+ idx_reg = train_idx_reg[:, start_pos % train_inputs_reg.shape[0]:(start_pos + self.batch_size) %
+ train_inputs_reg.shape[0]]
+ else:
+ idx_reg = train_idx_reg[:, 0:(start_pos + self.batch_size) % train_inputs_reg.shape[0]]
+
+ train_input_reg = train_inputs_reg[idx_reg]
+ neighbor_input_distance = neighbor_inputs_distance[idx_reg
+ ] # [ensemble_size, B, k, state_size+action_size]
+ neighbor_label_distance = neighbor_labels_distance[idx_reg
+ ] # [ensemble_size, B, k, state_size+reward_size]
+
+ jacobian = self._get_jacobian(self.gradient_model, train_input_reg).unsqueeze(2).repeat_interleave(
+ self.k, dim=2
+ ) # [ensemble_size, B, k(repeat), state_size+reward_size, state_size+action_size]
+
+ directional_derivative = (jacobian @ neighbor_input_distance.unsqueeze(-1)).squeeze(
+ -1
+ ) # [ensemble_size, B, k, state_size+reward_size]
+
+ loss_reg = torch.pow((neighbor_label_distance - directional_derivative),
+ 2).sum(0).mean() # summed over ensemble networks
+
+ self.gradient_model.train(loss, loss_reg, self.reg)
+ self.mse_loss.append(mse_loss.mean().item())
+ self.grad_loss.append(loss_reg.item())
+
+ self.mse_loss = sum(self.mse_loss) / len(self.mse_loss)
+ self.grad_loss = sum(self.grad_loss) / len(self.grad_loss)
+
+ with torch.no_grad():
+ holdout_mean, holdout_logvar = self.gradient_model(holdout_inputs, ret_log_var=True)
+ _, holdout_mse_loss = self.gradient_model.loss(holdout_mean, holdout_logvar, holdout_labels)
+ self.curr_holdout_mse_loss = holdout_mse_loss.mean().item()
+ break_train = self._save_best(epoch, holdout_mse_loss)
+ if break_train:
+ break
+
+ self._load_states()
+ with torch.no_grad():
+ holdout_mean, holdout_logvar = self.gradient_model(holdout_inputs, ret_log_var=True)
+ _, holdout_mse_loss = self.gradient_model.loss(holdout_mean, holdout_logvar, holdout_labels)
+ sorted_loss, sorted_loss_idx = holdout_mse_loss.sort()
+ sorted_loss = sorted_loss.detach().cpu().numpy().tolist()
+ sorted_loss_idx = sorted_loss_idx.detach().cpu().numpy().tolist()
+ self.elite_model_idxes_gradient_model = sorted_loss_idx[:self.elite_size]
+ self.top_holdout_mse_loss = sorted_loss[0]
+ self.middle_holdout_mse_loss = sorted_loss[self.ensemble_size // 2]
+ self.bottom_holdout_mse_loss = sorted_loss[-1]
+ self.best_holdout_mse_loss = holdout_mse_loss.mean().item()
+ return {
+ 'gradient_model/mse_loss': self.mse_loss,
+ 'gradient_model/grad_loss': self.grad_loss,
+ 'gradient_model/curr_holdout_mse_loss': self.curr_holdout_mse_loss,
+ 'gradient_model/best_holdout_mse_loss': self.best_holdout_mse_loss,
+ 'gradient_model/top_holdout_mse_loss': self.top_holdout_mse_loss,
+ 'gradient_model/middle_holdout_mse_loss': self.middle_holdout_mse_loss,
+ 'gradient_model/bottom_holdout_mse_loss': self.bottom_holdout_mse_loss,
+ }
+
+ def _save_states(self):
+ self._states = copy.deepcopy(self.state_dict())
+
+ def _save_state(self, id):
+ state_dict = self.state_dict()
+ for k, v in state_dict.items():
+ if 'weight' in k or 'bias' in k:
+ self._states[k].data[id] = copy.deepcopy(v.data[id])
+
+ def _load_states(self):
+ self.load_state_dict(self._states)
+
+ def _save_best(self, epoch, holdout_losses):
+ updated = False
+ for i in range(len(holdout_losses)):
+ current = holdout_losses[i]
+ _, best = self._snapshots[i]
+ improvement = (best - current) / best
+ if improvement > 0.01:
+ self._snapshots[i] = (epoch, current)
+                self._save_state(i)
+                updated = True
+
+ if updated:
+ self._epochs_since_update = 0
+ else:
+ self._epochs_since_update += 1
+ return self._epochs_since_update > self.max_epochs_since_update
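
For reference, a minimal standalone sketch of the directional-derivative regularizer computed in `_train_gradient_model` above, assuming a single (non-ensemble) network and `torch.autograd.functional.jacobian` instead of the batched `_get_jacobian` helper:

    import torch

    f = torch.nn.Linear(4, 3)                          # stand-in for one dynamics network
    x = torch.randn(4)                                 # a transition input (state, action)
    dx = 0.1 * torch.randn(4)                          # offset to a neighboring input
    dy = f(x + dx) - f(x)                              # observed label offset between the neighbors
    jac = torch.autograd.functional.jacobian(f, x)     # [3, 4] Jacobian of f at x
    loss_reg = torch.pow(dy - jac @ dx, 2).mean()      # penalize mismatch with the directional derivative
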
diff --git a/DI-engine/ding/world_model/dreamer.py b/DI-engine/ding/world_model/dreamer.py
new file mode 100644
index 0000000000000000000000000000000000000000..eafe257454449cd39fc8b8c5f3776cb7969be2f3
--- /dev/null
+++ b/DI-engine/ding/world_model/dreamer.py
@@ -0,0 +1,271 @@
+import numpy as np
+import copy
+import torch
+from torch import nn
+
+from ding.utils import WORLD_MODEL_REGISTRY, lists_to_dicts
+from ding.utils.data import default_collate
+from ding.model import ConvEncoder
+from ding.world_model.base_world_model import WorldModel
+from ding.world_model.model.networks import RSSM, ConvDecoder
+from ding.torch_utils import to_device
+from ding.torch_utils.network.dreamer import DenseHead
+
+
+@WORLD_MODEL_REGISTRY.register('dreamer')
+class DREAMERWorldModel(WorldModel, nn.Module):
+ config = dict(
+ pretrain=100,
+ train_freq=2,
+ model=dict(
+ state_size=None,
+ action_size=None,
+ model_lr=1e-4,
+ reward_size=1,
+ hidden_size=200,
+ batch_size=256,
+ max_epochs_since_update=5,
+ dyn_stoch=32,
+ dyn_deter=512,
+ dyn_hidden=512,
+ dyn_input_layers=1,
+ dyn_output_layers=1,
+ dyn_rec_depth=1,
+ dyn_shared=False,
+ dyn_discrete=32,
+ act='SiLU',
+ norm='LayerNorm',
+ grad_heads=['image', 'reward', 'discount'],
+ units=512,
+ reward_layers=2,
+ discount_layers=2,
+ value_layers=2,
+ actor_layers=2,
+ cnn_depth=32,
+ encoder_kernels=[4, 4, 4, 4],
+ decoder_kernels=[4, 4, 4, 4],
+ reward_head='twohot_symlog',
+ kl_lscale=0.1,
+ kl_rscale=0.5,
+ kl_free=1.0,
+ kl_forward=False,
+ pred_discount=True,
+ dyn_mean_act='none',
+ dyn_std_act='sigmoid2',
+ dyn_temp_post=True,
+ dyn_min_std=0.1,
+ dyn_cell='gru_layer_norm',
+ unimix_ratio=0.01,
+ device='cuda' if torch.cuda.is_available() else 'cpu',
+ ),
+ )
+
+ def __init__(self, cfg, env, tb_logger):
+ WorldModel.__init__(self, cfg, env, tb_logger)
+ nn.Module.__init__(self)
+
+ self.pretrain_flag = True
+ self._cfg = cfg.model
+ #self._cfg.act = getattr(torch.nn, self._cfg.act),
+ #self._cfg.norm = getattr(torch.nn, self._cfg.norm),
+ self._cfg.act = nn.modules.activation.SiLU # nn.SiLU
+ self._cfg.norm = nn.modules.normalization.LayerNorm # nn.LayerNorm
+ self.state_size = self._cfg.state_size
+ self.action_size = self._cfg.action_size
+ self.reward_size = self._cfg.reward_size
+ self.hidden_size = self._cfg.hidden_size
+ self.batch_size = self._cfg.batch_size
+
+ self.encoder = ConvEncoder(
+ self.state_size,
+ hidden_size_list=[32, 64, 128, 256, 4096], # to last layer 128?
+ activation=torch.nn.SiLU(),
+ kernel_size=self._cfg.encoder_kernels,
+ layer_norm=True
+ )
+ self.embed_size = (
+ (self.state_size[1] // 2 ** (len(self._cfg.encoder_kernels))) ** 2 * self._cfg.cnn_depth *
+ 2 ** (len(self._cfg.encoder_kernels) - 1)
+ )
+ self.dynamics = RSSM(
+ self._cfg.dyn_stoch,
+ self._cfg.dyn_deter,
+ self._cfg.dyn_hidden,
+ self._cfg.dyn_input_layers,
+ self._cfg.dyn_output_layers,
+ self._cfg.dyn_rec_depth,
+ self._cfg.dyn_shared,
+ self._cfg.dyn_discrete,
+ self._cfg.act,
+ self._cfg.norm,
+ self._cfg.dyn_mean_act,
+ self._cfg.dyn_std_act,
+ self._cfg.dyn_temp_post,
+ self._cfg.dyn_min_std,
+ self._cfg.dyn_cell,
+ self._cfg.unimix_ratio,
+ self._cfg.action_size,
+ self.embed_size,
+ self._cfg.device,
+ )
+ self.heads = nn.ModuleDict()
+ if self._cfg.dyn_discrete:
+ feat_size = self._cfg.dyn_stoch * self._cfg.dyn_discrete + self._cfg.dyn_deter
+ else:
+ feat_size = self._cfg.dyn_stoch + self._cfg.dyn_deter
+ self.heads["image"] = ConvDecoder(
+ feat_size, # pytorch version
+ self._cfg.cnn_depth,
+ self._cfg.act,
+ self._cfg.norm,
+ self.state_size,
+ self._cfg.decoder_kernels,
+ )
+ self.heads["reward"] = DenseHead(
+ feat_size, # dyn_stoch * dyn_discrete + dyn_deter
+ (255, ),
+ self._cfg.reward_layers,
+ self._cfg.units,
+ 'SiLU', # self._cfg.act
+ 'LN', # self._cfg.norm
+ dist=self._cfg.reward_head,
+ outscale=0.0,
+ device=self._cfg.device,
+ )
+ if self._cfg.pred_discount:
+ self.heads["discount"] = DenseHead(
+ feat_size, # pytorch version
+ [],
+ self._cfg.discount_layers,
+ self._cfg.units,
+ 'SiLU', # self._cfg.act
+ 'LN', # self._cfg.norm
+ dist="binary",
+ device=self._cfg.device,
+ )
+
+ if self._cuda:
+ self.cuda()
+        # TODO: add gradient clipping and weight decay
+ self.optimizer = torch.optim.Adam(self.parameters(), lr=self._cfg.model_lr)
+
+ def step(self, obs, act):
+ pass
+
+ def eval(self, env_buffer, envstep, train_iter):
+ pass
+
+ def should_pretrain(self):
+ if self.pretrain_flag:
+ self.pretrain_flag = False
+ return True
+ return False
+
+ def train(self, env_buffer, envstep, train_iter, batch_size, batch_length):
+ self.last_train_step = envstep
+ data = env_buffer.sample(
+ batch_size, batch_length, train_iter
+ ) # [len=B, ele=[len=T, ele={dict_key: Tensor(any_dims)}]]
+ data = default_collate(data) # -> [len=T, ele={dict_key: Tensor(B, any_dims)}]
+ data = lists_to_dicts(data, recursive=True) # -> {some_key: T lists}, each list is [B, some_dim]
+ data = {k: torch.stack(data[k], dim=1) for k in data} # -> {dict_key: Tensor([B, T, any_dims])}
+
+ data['discount'] = data.get('discount', 1.0 - data['done'].float())
+ data['discount'] *= 0.997
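+        # scale by a fixed discount factor (gamma = 0.997, as in DreamerV3)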
+ data['weight'] = data.get('weight', None)
+ data['image'] = data['obs'] - 0.5
+ data = to_device(data, self._cfg.device)
+ if len(data['reward'].shape) == 2:
+ data['reward'] = data['reward'].unsqueeze(-1)
+ if len(data['action'].shape) == 2:
+ data['action'] = data['action'].unsqueeze(-1)
+ if len(data['discount'].shape) == 2:
+ data['discount'] = data['discount'].unsqueeze(-1)
+
+ self.requires_grad_(requires_grad=True)
+
+ image = data['image'].reshape([-1] + list(data['image'].shape[-3:]))
+ embed = self.encoder(image)
+ embed = embed.reshape(list(data['image'].shape[:-3]) + [embed.shape[-1]])
+
+ post, prior = self.dynamics.observe(embed, data["action"])
+ kl_loss, kl_value, loss_lhs, loss_rhs = self.dynamics.kl_loss(
+ post, prior, self._cfg.kl_forward, self._cfg.kl_free, self._cfg.kl_lscale, self._cfg.kl_rscale
+ )
+ losses = {}
+ likes = {}
+ for name, head in self.heads.items():
+ grad_head = name in self._cfg.grad_heads
+ feat = self.dynamics.get_feat(post)
+ feat = feat if grad_head else feat.detach()
+ pred = head(feat)
+ like = pred.log_prob(data[name])
+ likes[name] = like
+ losses[name] = -torch.mean(like)
+ model_loss = sum(losses.values()) + kl_loss
+
+ # ====================
+ # world model update
+ # ====================
+ self.optimizer.zero_grad()
+ model_loss.backward()
+ self.optimizer.step()
+
+ self.requires_grad_(requires_grad=False)
+ # log
+ if self.tb_logger is not None:
+ for name, loss in losses.items():
+ self.tb_logger.add_scalar(name + '_loss', loss.detach().cpu().numpy().item(), envstep)
+ self.tb_logger.add_scalar('kl_free', self._cfg.kl_free, envstep)
+ self.tb_logger.add_scalar('kl_lscale', self._cfg.kl_lscale, envstep)
+ self.tb_logger.add_scalar('kl_rscale', self._cfg.kl_rscale, envstep)
+ self.tb_logger.add_scalar('loss_lhs', loss_lhs.detach().cpu().numpy().item(), envstep)
+ self.tb_logger.add_scalar('loss_rhs', loss_rhs.detach().cpu().numpy().item(), envstep)
+ self.tb_logger.add_scalar('kl', torch.mean(kl_value).detach().cpu().numpy().item(), envstep)
+
+ prior_ent = torch.mean(self.dynamics.get_dist(prior).entropy()).detach().cpu().numpy()
+ post_ent = torch.mean(self.dynamics.get_dist(post).entropy()).detach().cpu().numpy()
+
+ self.tb_logger.add_scalar('prior_ent', prior_ent.item(), envstep)
+ self.tb_logger.add_scalar('post_ent', post_ent.item(), envstep)
+
+ context = dict(
+ embed=embed,
+ feat=self.dynamics.get_feat(post),
+ kl=kl_value,
+ postent=self.dynamics.get_dist(post).entropy(),
+ )
+ post = {k: v.detach() for k, v in post.items()}
+ return post, context
+
+ def _save_states(self, ):
+ self._states = copy.deepcopy(self.state_dict())
+
+ def _save_state(self, id):
+ state_dict = self.state_dict()
+ for k, v in state_dict.items():
+ if 'weight' in k or 'bias' in k:
+ self._states[k].data[id] = copy.deepcopy(v.data[id])
+
+ def _load_states(self):
+ self.load_state_dict(self._states)
+
+ def _save_best(self, epoch, holdout_losses):
+ updated = False
+ for i in range(len(holdout_losses)):
+ current = holdout_losses[i]
+ _, best = self._snapshots[i]
+ improvement = (best - current) / best
+ if improvement > 0.01:
+ self._snapshots[i] = (epoch, current)
+                self._save_state(i)
+                updated = True
+
+ if updated:
+ self._epochs_since_update = 0
+ else:
+ self._epochs_since_update += 1
+ return self._epochs_since_update > self.max_epochs_since_update
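
A minimal sketch, using plain `torch.stack` instead of the `default_collate`/`lists_to_dicts` helpers, of the reshaping described by the comments in `DREAMERWorldModel.train`: B sampled trajectories of T step dicts become one `{key: Tensor[B, T, ...]}` batch (shapes below are illustrative):

    import torch

    B, T = 4, 8
    # fake buffer output: B trajectories, each a list of T step dicts
    data = [[{'obs': torch.rand(3, 64, 64), 'reward': torch.rand(1)} for _ in range(T)] for _ in range(B)]
    # collate: for each key, stack over time within a trajectory, then over the batch
    batch = {
        k: torch.stack([torch.stack([step[k] for step in traj]) for traj in data])
        for k in data[0][0]
    }
    assert batch['obs'].shape == (B, T, 3, 64, 64)
    assert batch['reward'].shape == (B, T, 1)
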
diff --git a/DI-engine/ding/world_model/idm.py b/DI-engine/ding/world_model/idm.py
new file mode 100644
index 0000000000000000000000000000000000000000..308dd406f921b6971c63318b6d3e3885f2b0fdf6
--- /dev/null
+++ b/DI-engine/ding/world_model/idm.py
@@ -0,0 +1,138 @@
+import torch
+import torch.nn as nn
+from typing import Union, Optional, Dict
+import numpy as np
+
+from ding.model.common.head import DiscreteHead, RegressionHead, ReparameterizationHead
+from ding.utils import SequenceType, squeeze
+from ding.model.common.encoder import FCEncoder, ConvEncoder
+from torch.distributions import Independent, Normal
+
+
+class InverseDynamicsModel(nn.Module):
+ """
+    InverseDynamicsModel: inferring the missing action information from a state transition.
+    Input and output: given a pair of observations, return the action (s0, s1 --> a0 if n=2).
+ """
+
+ def __init__(
+ self,
+ obs_shape: Union[int, SequenceType],
+ action_shape: Union[int, SequenceType],
+ encoder_hidden_size_list: SequenceType = [60, 80, 100, 40],
+ action_space: str = "regression",
+ activation: Optional[nn.Module] = nn.LeakyReLU(),
+ norm_type: Optional[str] = None
+ ) -> None:
+ r"""
+ Overview:
+ Init the Inverse Dynamics (encoder + head) Model according to input arguments.
+ Arguments:
+ - obs_shape (:obj:`Union[int, SequenceType]`): Observation space shape, such as 8 or [4, 84, 84].
+ - action_shape (:obj:`Union[int, SequenceType]`): Action space shape, such as 6 or [2, 3, 3].
+ - encoder_hidden_size_list (:obj:`SequenceType`): Collection of ``hidden_size`` to pass to ``Encoder``, \
+ the last element must match ``head_hidden_size``.
+ - action_space (:obj:`String`): Action space, such as 'regression', 'reparameterization', 'discrete'.
+            - activation (:obj:`Optional[nn.Module]`): The type of activation function used in the networks; \
+                if ``None``, it defaults to ``nn.LeakyReLU()``, following https://arxiv.org/abs/1805.01954.
+ - norm_type (:obj:`Optional[str]`): The type of normalization in networks, see \
+ ``ding.torch_utils.fc_block`` for more details.
+ """
+ super(InverseDynamicsModel, self).__init__()
+ # For compatibility: 1, (1, ), [4, 32, 32]
+ obs_shape, action_shape = squeeze(obs_shape), squeeze(action_shape)
+        # FC encoder: obs and next_obs are concatenated, so the input shape is obs_shape * 2
+ if isinstance(obs_shape, int) or len(obs_shape) == 1:
+ self.encoder = FCEncoder(
+ obs_shape * 2, encoder_hidden_size_list, activation=activation, norm_type=norm_type
+ )
+ elif len(obs_shape) == 3:
+            # Conv encoder: obs and next_obs are concatenated, so the first channel is multiplied by 2
+ obs_shape = (obs_shape[0] * 2, *obs_shape[1:])
+ self.encoder = ConvEncoder(obs_shape, encoder_hidden_size_list, activation=activation, norm_type=norm_type)
+ else:
+ raise RuntimeError(
+                "unsupported obs_shape for pre-defined encoder: {}, please customize your own model".format(obs_shape)
+ )
+ self.action_space = action_space
+        assert self.action_space in ['regression', 'reparameterization', 'discrete'], \
+            "unsupported action_space: {}".format(self.action_space)
+ if self.action_space == "regression":
+ self.header = RegressionHead(
+ encoder_hidden_size_list[-1],
+ action_shape,
+ final_tanh=False,
+ activation=activation,
+ norm_type=norm_type
+ )
+ elif self.action_space == "reparameterization":
+ self.header = ReparameterizationHead(
+ encoder_hidden_size_list[-1],
+ action_shape,
+ sigma_type='conditioned',
+ activation=activation,
+ norm_type=norm_type
+ )
+ elif self.action_space == "discrete":
+ self.header = DiscreteHead(
+ encoder_hidden_size_list[-1], action_shape, activation=activation, norm_type=norm_type
+ )
+
+ def forward(self, x: torch.Tensor) -> Dict:
+ if self.action_space == "regression":
+ x = self.encoder(x)
+ x = self.header(x)
+ return {'action': x['pred']}
+ elif self.action_space == "reparameterization":
+ x = self.encoder(x)
+ x = self.header(x)
+ mu, sigma = x['mu'], x['sigma']
+ dist = Independent(Normal(mu, sigma), 1)
+ pred = dist.rsample()
+ action = torch.tanh(pred)
+ return {'logit': [mu, sigma], 'action': action}
+ elif self.action_space == "discrete":
+ x = self.encoder(x)
+ x = self.header(x)
+ return x
+
+ def predict_action(self, x: torch.Tensor) -> Dict:
+ if self.action_space == "discrete":
+ res = nn.Softmax(dim=-1)
+ action = torch.argmax(res(self.forward(x)['logit']), -1)
+ return {'action': action}
+ else:
+ return self.forward(x)
+
+ def train(self, training_set: dict, n_epoch: int, learning_rate: float, weight_decay: float):
+ r"""
+ Overview:
+            Train the IDM model: given a pair of states (s_t, s_t+1), predict the action a_t.
+
+        Arguments:
+            - training_set (:obj:`dict`): state transitions
+            - n_epoch (:obj:`int`): number of epochs
+            - learning_rate (:obj:`float`): learning rate for the optimizer
+            - weight_decay (:obj:`float`): weight decay for the optimizer
+ """
+ if self.action_space == "discrete":
+ criterion = nn.CrossEntropyLoss()
+ else:
+ # criterion = nn.MSELoss()
+ criterion = nn.L1Loss()
+ optimizer = torch.optim.AdamW(self.parameters(), lr=learning_rate, weight_decay=weight_decay)
+ loss_list = []
+ for itr in range(n_epoch):
+ data = training_set['obs']
+ y = training_set['action']
+ if self.action_space == "discrete":
+ y_pred = self.forward(data)['logit']
+ else:
+ y_pred = self.forward(data)['action']
+ loss = criterion(y_pred, y)
+ optimizer.zero_grad()
+ loss.backward()
+ optimizer.step()
+ loss_list.append(loss.item())
+ loss = np.mean(loss_list)
+ return loss
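
A hedged usage sketch of `InverseDynamicsModel` with the regression head, mirroring the shapes used in the new unit tests (the dimensions below are illustrative):

    import torch
    from ding.world_model.idm import InverseDynamicsModel

    B, obs_dim, act_dim = 32, 8, 2
    idm = InverseDynamicsModel(obs_dim, act_dim, action_space='regression')
    training_set = {
        'obs': torch.randn(B, obs_dim * 2),      # (s_t, s_{t+1}) concatenated per sample
        'action': torch.randn(B, act_dim),
    }
    loss = idm.train(training_set, n_epoch=5, learning_rate=1e-3, weight_decay=1e-4)
    pred = idm.predict_action(training_set['obs'])['action']   # [B, act_dim]
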
diff --git a/DI-engine/ding/world_model/mbpo.py b/DI-engine/ding/world_model/mbpo.py
new file mode 100644
index 0000000000000000000000000000000000000000..e22f076aac586a14dfdb43b0c46dcefb16e5a11d
--- /dev/null
+++ b/DI-engine/ding/world_model/mbpo.py
@@ -0,0 +1,276 @@
+import itertools
+import numpy as np
+import copy
+import torch
+from torch import nn
+
+from ding.utils import WORLD_MODEL_REGISTRY
+from ding.utils.data import default_collate
+from ding.world_model.base_world_model import HybridWorldModel
+from ding.world_model.model.ensemble import EnsembleModel, StandardScaler
+from ding.torch_utils import fold_batch, unfold_batch, unsqueeze_repeat
+
+
+@WORLD_MODEL_REGISTRY.register('mbpo')
+class MBPOWorldModel(HybridWorldModel, nn.Module):
+ config = dict(
+ model=dict(
+ ensemble_size=7,
+ elite_size=5,
+ state_size=None,
+ action_size=None,
+ reward_size=1,
+ hidden_size=200,
+ use_decay=False,
+ batch_size=256,
+ holdout_ratio=0.2,
+ max_epochs_since_update=5,
+ deterministic_rollout=True,
+ ),
+ )
+
+ def __init__(self, cfg, env, tb_logger):
+ HybridWorldModel.__init__(self, cfg, env, tb_logger)
+ nn.Module.__init__(self)
+
+ cfg = cfg.model
+ self.ensemble_size = cfg.ensemble_size
+ self.elite_size = cfg.elite_size
+ self.state_size = cfg.state_size
+ self.action_size = cfg.action_size
+ self.reward_size = cfg.reward_size
+ self.hidden_size = cfg.hidden_size
+ self.use_decay = cfg.use_decay
+ self.batch_size = cfg.batch_size
+ self.holdout_ratio = cfg.holdout_ratio
+ self.max_epochs_since_update = cfg.max_epochs_since_update
+ self.deterministic_rollout = cfg.deterministic_rollout
+
+ self.ensemble_model = EnsembleModel(
+ self.state_size,
+ self.action_size,
+ self.reward_size,
+ self.ensemble_size,
+ self.hidden_size,
+ use_decay=self.use_decay
+ )
+ self.scaler = StandardScaler(self.state_size + self.action_size)
+
+ if self._cuda:
+ self.cuda()
+
+ self.ensemble_mse_losses = []
+ self.model_variances = []
+ self.elite_model_idxes = []
+
+ def step(self, obs, act, batch_size=8192, keep_ensemble=False):
+ if len(act.shape) == 1:
+ act = act.unsqueeze(1)
+ if self._cuda:
+ obs = obs.cuda()
+ act = act.cuda()
+ inputs = torch.cat([obs, act], dim=-1)
+ if keep_ensemble:
+ inputs, dim = fold_batch(inputs, 1)
+ inputs = self.scaler.transform(inputs)
+ inputs = unfold_batch(inputs, dim)
+ else:
+ inputs = self.scaler.transform(inputs)
+ # predict
+ ensemble_mean, ensemble_var = [], []
+ batch_dim = 0 if len(inputs.shape) == 2 else 1
+ for i in range(0, inputs.shape[batch_dim], batch_size):
+ if keep_ensemble:
+ # inputs: [E, B, D]
+ input = inputs[:, i:i + batch_size]
+ else:
+ # input: [B, D]
+ input = unsqueeze_repeat(inputs[i:i + batch_size], self.ensemble_size)
+ b_mean, b_var = self.ensemble_model(input, ret_log_var=False)
+ ensemble_mean.append(b_mean)
+ ensemble_var.append(b_var)
+ ensemble_mean = torch.cat(ensemble_mean, 1)
+ ensemble_var = torch.cat(ensemble_var, 1)
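+        # the ensemble predicts [reward, next_obs - obs], so add obs back to recover next_obs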
+ if keep_ensemble:
+ ensemble_mean[:, :, 1:] += obs
+ else:
+ ensemble_mean[:, :, 1:] += obs.unsqueeze(0)
+ ensemble_std = ensemble_var.sqrt()
+ # sample from the predicted distribution
+ if self.deterministic_rollout:
+ ensemble_sample = ensemble_mean
+ else:
+ ensemble_sample = ensemble_mean + torch.randn_like(ensemble_mean).to(ensemble_mean) * ensemble_std
+ if keep_ensemble:
+ # [E, B, D]
+ rewards, next_obs = ensemble_sample[:, :, 0], ensemble_sample[:, :, 1:]
+ next_obs_flatten, dim = fold_batch(next_obs)
+ done = unfold_batch(self.env.termination_fn(next_obs_flatten), dim)
+ return rewards, next_obs, done
+ # sample from ensemble
+ model_idxes = torch.from_numpy(np.random.choice(self.elite_model_idxes, size=len(obs))).to(inputs.device)
+ batch_idxes = torch.arange(len(obs)).to(inputs.device)
+ sample = ensemble_sample[model_idxes, batch_idxes]
+ rewards, next_obs = sample[:, 0], sample[:, 1:]
+
+ return rewards, next_obs, self.env.termination_fn(next_obs)
+
+ def eval(self, env_buffer, envstep, train_iter):
+ data = env_buffer.sample(self.eval_freq, train_iter)
+ data = default_collate(data)
+ data['done'] = data['done'].float()
+ data['weight'] = data.get('weight', None)
+ obs = data['obs']
+ action = data['action']
+ reward = data['reward']
+ next_obs = data['next_obs']
+ if len(reward.shape) == 1:
+ reward = reward.unsqueeze(1)
+ if len(action.shape) == 1:
+ action = action.unsqueeze(1)
+
+ # build eval samples
+ inputs = torch.cat([obs, action], dim=1)
+ labels = torch.cat([reward, next_obs - obs], dim=1)
+ if self._cuda:
+ inputs = inputs.cuda()
+ labels = labels.cuda()
+
+ # normalize
+ inputs = self.scaler.transform(inputs)
+
+ # repeat for ensemble
+ inputs = unsqueeze_repeat(inputs, self.ensemble_size)
+ labels = unsqueeze_repeat(labels, self.ensemble_size)
+
+ # eval
+ with torch.no_grad():
+ mean, logvar = self.ensemble_model(inputs, ret_log_var=True)
+ loss, mse_loss = self.ensemble_model.loss(mean, logvar, labels)
+ ensemble_mse_loss = torch.pow(mean.mean(0) - labels[0], 2)
+ model_variance = mean.var(0)
+ self.tb_logger.add_scalar('env_model_step/eval_mse_loss', mse_loss.mean().item(), envstep)
+ self.tb_logger.add_scalar('env_model_step/eval_ensemble_mse_loss', ensemble_mse_loss.mean().item(), envstep)
+ self.tb_logger.add_scalar('env_model_step/eval_model_variances', model_variance.mean().item(), envstep)
+
+ self.last_eval_step = envstep
+
+ def train(self, env_buffer, envstep, train_iter):
+ data = env_buffer.sample(env_buffer.count(), train_iter)
+ data = default_collate(data)
+ data['done'] = data['done'].float()
+ data['weight'] = data.get('weight', None)
+ obs = data['obs']
+ action = data['action']
+ reward = data['reward']
+ next_obs = data['next_obs']
+ if len(reward.shape) == 1:
+ reward = reward.unsqueeze(1)
+ if len(action.shape) == 1:
+ action = action.unsqueeze(1)
+ # build train samples
+ inputs = torch.cat([obs, action], dim=1)
+ labels = torch.cat([reward, next_obs - obs], dim=1)
+ if self._cuda:
+ inputs = inputs.cuda()
+ labels = labels.cuda()
+ # train
+ logvar = self._train(inputs, labels)
+ self.last_train_step = envstep
+ # log
+ if self.tb_logger is not None:
+ for k, v in logvar.items():
+ self.tb_logger.add_scalar('env_model_step/' + k, v, envstep)
+
+ def _train(self, inputs, labels):
+ #split
+ num_holdout = int(inputs.shape[0] * self.holdout_ratio)
+ train_inputs, train_labels = inputs[num_holdout:], labels[num_holdout:]
+ holdout_inputs, holdout_labels = inputs[:num_holdout], labels[:num_holdout]
+
+ #normalize
+ self.scaler.fit(train_inputs)
+ train_inputs = self.scaler.transform(train_inputs)
+ holdout_inputs = self.scaler.transform(holdout_inputs)
+
+ #repeat for ensemble
+ holdout_inputs = unsqueeze_repeat(holdout_inputs, self.ensemble_size)
+ holdout_labels = unsqueeze_repeat(holdout_labels, self.ensemble_size)
+
+ self._epochs_since_update = 0
+ self._snapshots = {i: (-1, 1e10) for i in range(self.ensemble_size)}
+ self._save_states()
+ for epoch in itertools.count():
+
+ train_idx = torch.stack([torch.randperm(train_inputs.shape[0])
+ for _ in range(self.ensemble_size)]).to(train_inputs.device)
+ self.mse_loss = []
+ for start_pos in range(0, train_inputs.shape[0], self.batch_size):
+ idx = train_idx[:, start_pos:start_pos + self.batch_size]
+ train_input = train_inputs[idx]
+ train_label = train_labels[idx]
+ mean, logvar = self.ensemble_model(train_input, ret_log_var=True)
+ loss, mse_loss = self.ensemble_model.loss(mean, logvar, train_label)
+ self.ensemble_model.train(loss)
+ self.mse_loss.append(mse_loss.mean().item())
+ self.mse_loss = sum(self.mse_loss) / len(self.mse_loss)
+
+ with torch.no_grad():
+ holdout_mean, holdout_logvar = self.ensemble_model(holdout_inputs, ret_log_var=True)
+ _, holdout_mse_loss = self.ensemble_model.loss(holdout_mean, holdout_logvar, holdout_labels)
+ self.curr_holdout_mse_loss = holdout_mse_loss.mean().item()
+ break_train = self._save_best(epoch, holdout_mse_loss)
+ if break_train:
+ break
+
+ self._load_states()
+ with torch.no_grad():
+ holdout_mean, holdout_logvar = self.ensemble_model(holdout_inputs, ret_log_var=True)
+ _, holdout_mse_loss = self.ensemble_model.loss(holdout_mean, holdout_logvar, holdout_labels)
+ sorted_loss, sorted_loss_idx = holdout_mse_loss.sort()
+ sorted_loss = sorted_loss.detach().cpu().numpy().tolist()
+ sorted_loss_idx = sorted_loss_idx.detach().cpu().numpy().tolist()
+ self.elite_model_idxes = sorted_loss_idx[:self.elite_size]
+ self.top_holdout_mse_loss = sorted_loss[0]
+ self.middle_holdout_mse_loss = sorted_loss[self.ensemble_size // 2]
+ self.bottom_holdout_mse_loss = sorted_loss[-1]
+ self.best_holdout_mse_loss = holdout_mse_loss.mean().item()
+ return {
+ 'mse_loss': self.mse_loss,
+ 'curr_holdout_mse_loss': self.curr_holdout_mse_loss,
+ 'best_holdout_mse_loss': self.best_holdout_mse_loss,
+ 'top_holdout_mse_loss': self.top_holdout_mse_loss,
+ 'middle_holdout_mse_loss': self.middle_holdout_mse_loss,
+ 'bottom_holdout_mse_loss': self.bottom_holdout_mse_loss,
+ }
+
+ def _save_states(self, ):
+ self._states = copy.deepcopy(self.state_dict())
+
+ def _save_state(self, id):
+ state_dict = self.state_dict()
+ for k, v in state_dict.items():
+ if 'weight' in k or 'bias' in k:
+ self._states[k].data[id] = copy.deepcopy(v.data[id])
+
+ def _load_states(self):
+ self.load_state_dict(self._states)
+
+ def _save_best(self, epoch, holdout_losses):
+ updated = False
+ for i in range(len(holdout_losses)):
+ current = holdout_losses[i]
+ _, best = self._snapshots[i]
+ improvement = (best - current) / best
+ if improvement > 0.01:
+ self._snapshots[i] = (epoch, current)
+                self._save_state(i)
+                updated = True
+
+ if updated:
+ self._epochs_since_update = 0
+ else:
+ self._epochs_since_update += 1
+ return self._epochs_since_update > self.max_epochs_since_update
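
A small sketch of the elite-member sampling performed at the end of `MBPOWorldModel.step`: each transition in the batch gets the prediction of one randomly chosen elite ensemble member via advanced indexing (all sizes below are illustrative):

    import torch

    E, B, D = 7, 5, 1 + 4                         # ensemble size, batch, reward + state dims
    ensemble_sample = torch.randn(E, B, D)        # one prediction per ensemble member
    elite_idxes = torch.tensor([0, 2, 3])         # members with the lowest holdout MSE
    model_idxes = elite_idxes[torch.randint(len(elite_idxes), (B, ))]   # pick one elite per transition
    batch_idxes = torch.arange(B)
    sample = ensemble_sample[model_idxes, batch_idxes]                  # [B, D]
    rewards, next_obs = sample[:, 0], sample[:, 1:]
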
diff --git a/DI-engine/ding/world_model/model/__init__.py b/DI-engine/ding/world_model/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/ding/world_model/model/ensemble.py b/DI-engine/ding/world_model/model/ensemble.py
new file mode 100644
index 0000000000000000000000000000000000000000..87433fd8c533e01a29130c94926c1ae6474f406b
--- /dev/null
+++ b/DI-engine/ding/world_model/model/ensemble.py
@@ -0,0 +1,150 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from ding.torch_utils import Swish
+
+
+class StandardScaler(nn.Module):
+
+ def __init__(self, input_size: int):
+ super(StandardScaler, self).__init__()
+ self.register_buffer('std', torch.ones(1, input_size))
+ self.register_buffer('mu', torch.zeros(1, input_size))
+
+ def fit(self, data: torch.Tensor):
+ std, mu = torch.std_mean(data, dim=0, keepdim=True)
+ std[std < 1e-12] = 1
+ self.std.data.mul_(0.0).add_(std)
+ self.mu.data.mul_(0.0).add_(mu)
+
+ def transform(self, data: torch.Tensor):
+ return (data - self.mu) / self.std
+
+ def inverse_transform(self, data: torch.Tensor):
+ return self.std * data + self.mu
+
+
+class EnsembleFC(nn.Module):
+ __constants__ = ['in_features', 'out_features']
+ in_features: int
+ out_features: int
+ ensemble_size: int
+ weight: torch.Tensor
+
+ def __init__(self, in_features: int, out_features: int, ensemble_size: int, weight_decay: float = 0.) -> None:
+ super(EnsembleFC, self).__init__()
+ self.in_features = in_features
+ self.out_features = out_features
+ self.ensemble_size = ensemble_size
+ self.weight = nn.Parameter(torch.zeros(ensemble_size, in_features, out_features))
+ self.weight_decay = weight_decay
+ self.bias = nn.Parameter(torch.zeros(ensemble_size, 1, out_features))
+
+ def forward(self, input: torch.Tensor) -> torch.Tensor:
+ assert input.shape[0] == self.ensemble_size and len(input.shape) == 3
+ return torch.bmm(input, self.weight) + self.bias # w times x + b
+
+ def extra_repr(self) -> str:
+ return 'in_features={}, out_features={}, ensemble_size={}, weight_decay={}'.format(
+ self.in_features, self.out_features, self.ensemble_size, self.weight_decay
+ )
+
+
+class EnsembleModel(nn.Module):
+
+ def __init__(
+ self,
+ state_size,
+ action_size,
+ reward_size,
+ ensemble_size,
+ hidden_size=200,
+ learning_rate=1e-3,
+ use_decay=False
+ ):
+ super(EnsembleModel, self).__init__()
+
+ self.use_decay = use_decay
+ self.hidden_size = hidden_size
+ self.output_dim = state_size + reward_size
+
+ self.nn1 = EnsembleFC(state_size + action_size, hidden_size, ensemble_size, weight_decay=0.000025)
+ self.nn2 = EnsembleFC(hidden_size, hidden_size, ensemble_size, weight_decay=0.00005)
+ self.nn3 = EnsembleFC(hidden_size, hidden_size, ensemble_size, weight_decay=0.000075)
+ self.nn4 = EnsembleFC(hidden_size, hidden_size, ensemble_size, weight_decay=0.000075)
+ self.nn5 = EnsembleFC(hidden_size, self.output_dim * 2, ensemble_size, weight_decay=0.0001)
+ self.max_logvar = nn.Parameter(torch.ones(1, self.output_dim).float() * 0.5, requires_grad=False)
+ self.min_logvar = nn.Parameter(torch.ones(1, self.output_dim).float() * -10, requires_grad=False)
+ self.swish = Swish()
+
+ def init_weights(m: nn.Module):
+
+ def truncated_normal_init(t, mean: float = 0.0, std: float = 0.01):
+ torch.nn.init.normal_(t, mean=mean, std=std)
+ while True:
+ cond = torch.logical_or(t < mean - 2 * std, t > mean + 2 * std)
+ if not torch.sum(cond):
+ break
+ t = torch.where(cond, torch.nn.init.normal_(torch.ones(t.shape), mean=mean, std=std), t)
+ return t
+
+ if isinstance(m, nn.Linear) or isinstance(m, EnsembleFC):
+ input_dim = m.in_features
+ truncated_normal_init(m.weight, std=1 / (2 * np.sqrt(input_dim)))
+ m.bias.data.fill_(0.0)
+
+ self.apply(init_weights)
+
+ self.optimizer = torch.optim.Adam(self.parameters(), lr=learning_rate)
+
+ def forward(self, x: torch.Tensor, ret_log_var: bool = False):
+ x = self.swish(self.nn1(x))
+ x = self.swish(self.nn2(x))
+ x = self.swish(self.nn3(x))
+ x = self.swish(self.nn4(x))
+ x = self.nn5(x)
+
+ mean, logvar = x.chunk(2, dim=2)
+ logvar = self.max_logvar - F.softplus(self.max_logvar - logvar)
+ logvar = self.min_logvar + F.softplus(logvar - self.min_logvar)
+
+ if ret_log_var:
+ return mean, logvar
+ else:
+ return mean, torch.exp(logvar)
+
+ def get_decay_loss(self):
+ decay_loss = 0.
+ for m in self.modules():
+ if isinstance(m, EnsembleFC):
+ decay_loss += m.weight_decay * torch.sum(torch.square(m.weight)) / 2.
+ return decay_loss
+
+ def loss(self, mean: torch.Tensor, logvar: torch.Tensor, labels: torch.Tensor):
+ """
+ mean, logvar: Ensemble_size x N x dim
+ labels: Ensemble_size x N x dim
+ """
+ assert len(mean.shape) == len(logvar.shape) == len(labels.shape) == 3
+ inv_var = torch.exp(-logvar)
+ # Average over batch and dim, sum over ensembles.
+ mse_loss_inv = (torch.pow(mean - labels, 2) * inv_var).mean(dim=(1, 2))
+ var_loss = logvar.mean(dim=(1, 2))
+ with torch.no_grad():
+ # Used only for logging.
+ mse_loss = torch.pow(mean - labels, 2).mean(dim=(1, 2))
+ total_loss = mse_loss_inv.sum() + var_loss.sum()
+ return total_loss, mse_loss
+
+ def train(self, loss: torch.Tensor):
+ self.optimizer.zero_grad()
+
+ loss += 0.01 * torch.sum(self.max_logvar) - 0.01 * torch.sum(self.min_logvar)
+ if self.use_decay:
+ loss += self.get_decay_loss()
+
+ loss.backward()
+
+ self.optimizer.step()
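
A standalone sketch of the per-member Gaussian negative log-likelihood computed by `EnsembleModel.loss`, including the soft clamping of `logvar` between the learned bounds (shapes are illustrative):

    import torch
    import torch.nn.functional as F

    E, N, D = 7, 16, 5
    mean, labels = torch.randn(E, N, D), torch.randn(E, N, D)
    raw_logvar = torch.randn(E, N, D)
    max_logvar, min_logvar = torch.full((1, D), 0.5), torch.full((1, D), -10.0)
    # soft clamp keeps logvar within [min_logvar, max_logvar] while staying differentiable
    logvar = max_logvar - F.softplus(max_logvar - raw_logvar)
    logvar = min_logvar + F.softplus(logvar - min_logvar)
    inv_var = torch.exp(-logvar)
    nll = (torch.pow(mean - labels, 2) * inv_var).mean(dim=(1, 2)) + logvar.mean(dim=(1, 2))
    total_loss = nll.sum()    # summed over ensemble members, as in EnsembleModel.loss
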
diff --git a/DI-engine/ding/world_model/model/networks.py b/DI-engine/ding/world_model/model/networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..091fa4f827805dfa45040d654102a8d7258c5602
--- /dev/null
+++ b/DI-engine/ding/world_model/model/networks.py
@@ -0,0 +1,397 @@
+import math
+import numpy as np
+
+import torch
+from torch import nn
+import torch.nn.functional as F
+from torch import distributions as torchd
+
+from ding.torch_utils.network.dreamer import weight_init, uniform_weight_init, static_scan, \
+ OneHotDist, ContDist, SymlogDist, DreamerLayerNorm
+
+
+class RSSM(nn.Module):
+
+ def __init__(
+ self,
+ stoch=30,
+ deter=200,
+ hidden=200,
+ layers_input=1,
+ layers_output=1,
+ rec_depth=1,
+ shared=False,
+ discrete=False,
+ act=nn.ELU,
+ norm=nn.LayerNorm,
+ mean_act="none",
+ std_act="softplus",
+ temp_post=True,
+ min_std=0.1,
+ cell="gru",
+ unimix_ratio=0.01,
+ num_actions=None,
+ embed=None,
+ device=None,
+ ):
+ super(RSSM, self).__init__()
+ self._stoch = stoch
+ self._deter = deter
+ self._hidden = hidden
+ self._min_std = min_std
+ self._layers_input = layers_input
+ self._layers_output = layers_output
+ self._rec_depth = rec_depth
+ self._shared = shared
+ self._discrete = discrete
+ self._act = act
+ self._norm = norm
+ self._mean_act = mean_act
+ self._std_act = std_act
+ self._temp_post = temp_post
+ self._unimix_ratio = unimix_ratio
+ self._embed = embed
+ self._device = device
+
+ inp_layers = []
+ if self._discrete:
+ inp_dim = self._stoch * self._discrete + num_actions
+ else:
+ inp_dim = self._stoch + num_actions
+ if self._shared:
+ inp_dim += self._embed
+ for i in range(self._layers_input):
+ inp_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
+ inp_layers.append(self._norm(self._hidden, eps=1e-03))
+ inp_layers.append(self._act())
+ if i == 0:
+ inp_dim = self._hidden
+ self._inp_layers = nn.Sequential(*inp_layers)
+ self._inp_layers.apply(weight_init)
+
+ if cell == "gru":
+ self._cell = GRUCell(self._hidden, self._deter)
+ self._cell.apply(weight_init)
+ elif cell == "gru_layer_norm":
+ self._cell = GRUCell(self._hidden, self._deter, norm=True)
+ self._cell.apply(weight_init)
+ else:
+ raise NotImplementedError(cell)
+
+ img_out_layers = []
+ inp_dim = self._deter
+ for i in range(self._layers_output):
+ img_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
+ img_out_layers.append(self._norm(self._hidden, eps=1e-03))
+ img_out_layers.append(self._act())
+ if i == 0:
+ inp_dim = self._hidden
+ self._img_out_layers = nn.Sequential(*img_out_layers)
+ self._img_out_layers.apply(weight_init)
+
+ obs_out_layers = []
+ if self._temp_post:
+ inp_dim = self._deter + self._embed
+ else:
+ inp_dim = self._embed
+ for i in range(self._layers_output):
+ obs_out_layers.append(nn.Linear(inp_dim, self._hidden, bias=False))
+ obs_out_layers.append(self._norm(self._hidden, eps=1e-03))
+ obs_out_layers.append(self._act())
+ if i == 0:
+ inp_dim = self._hidden
+ self._obs_out_layers = nn.Sequential(*obs_out_layers)
+ self._obs_out_layers.apply(weight_init)
+
+ if self._discrete:
+ self._ims_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete)
+ self._ims_stat_layer.apply(weight_init)
+ self._obs_stat_layer = nn.Linear(self._hidden, self._stoch * self._discrete)
+ self._obs_stat_layer.apply(weight_init)
+ else:
+ self._ims_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
+ self._ims_stat_layer.apply(weight_init)
+ self._obs_stat_layer = nn.Linear(self._hidden, 2 * self._stoch)
+ self._obs_stat_layer.apply(weight_init)
+
+ def initial(self, batch_size):
+ deter = torch.zeros(batch_size, self._deter).to(self._device)
+ if self._discrete:
+ state = dict(
+ logit=torch.zeros([batch_size, self._stoch, self._discrete]).to(self._device),
+ stoch=torch.zeros([batch_size, self._stoch, self._discrete]).to(self._device),
+ deter=deter,
+ )
+ else:
+ state = dict(
+ mean=torch.zeros([batch_size, self._stoch]).to(self._device),
+ std=torch.zeros([batch_size, self._stoch]).to(self._device),
+ stoch=torch.zeros([batch_size, self._stoch]).to(self._device),
+ deter=deter,
+ )
+ return state
+
+ def observe(self, embed, action, state=None):
+        swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))  # swap the first two dims (batch, time)
+ if state is None:
+ state = self.initial(action.shape[0]) # {logit, stoch, deter}
+ # (batch, time, ch) -> (time, batch, ch)
+ embed, action = swap(embed), swap(action)
+ post, prior = static_scan(
+ lambda prev_state, prev_act, embed: self.obs_step(prev_state[0], prev_act, embed),
+ (action, embed),
+ (state, state),
+ )
+
+ # (time, batch, stoch, discrete_num) -> (batch, time, stoch, discrete_num)
+ post = {k: swap(v) for k, v in post.items()}
+ prior = {k: swap(v) for k, v in prior.items()}
+ return post, prior
+
+ def imagine(self, action, state=None):
+ swap = lambda x: x.permute([1, 0] + list(range(2, len(x.shape))))
+ if state is None:
+ state = self.initial(action.shape[0])
+ assert isinstance(state, dict), state
+ action = action
+ action = swap(action)
+ prior = static_scan(self.img_step, [action], state)
+ prior = prior[0]
+ prior = {k: swap(v) for k, v in prior.items()}
+ return prior
+
+ def get_feat(self, state):
+ stoch = state["stoch"]
+ if self._discrete:
+ shape = list(stoch.shape[:-2]) + [self._stoch * self._discrete]
+ stoch = stoch.reshape(shape)
+ return torch.cat([stoch, state["deter"]], -1)
+
+ def get_dist(self, state, dtype=None):
+ if self._discrete:
+ logit = state["logit"]
+ dist = torchd.independent.Independent(OneHotDist(logit, unimix_ratio=self._unimix_ratio), 1)
+ else:
+ mean, std = state["mean"], state["std"]
+ dist = ContDist(torchd.independent.Independent(torchd.normal.Normal(mean, std), 1))
+ return dist
+
+ def obs_step(self, prev_state, prev_action, embed, sample=True):
+        # if shared is True, prior and post both use the same networks (inp_layers, _img_out_layers, _ims_stat_layer)
+        # otherwise, post uses a separate network (_obs_out_layers) with prior['deter'] and embed as inputs
+ prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
+ prior = self.img_step(prev_state, prev_action, None, sample)
+ if self._shared:
+ post = self.img_step(prev_state, prev_action, embed, sample)
+ else:
+ if self._temp_post:
+ x = torch.cat([prior["deter"], embed], -1)
+ else:
+ x = embed
+ # (batch_size, prior_deter + embed) -> (batch_size, hidden)
+ x = self._obs_out_layers(x)
+ # (batch_size, hidden) -> (batch_size, stoch, discrete_num)
+ stats = self._suff_stats_layer("obs", x)
+ if sample:
+ stoch = self.get_dist(stats).sample()
+ else:
+ stoch = self.get_dist(stats).mode()
+ post = {"stoch": stoch, "deter": prior["deter"], **stats}
+ return post, prior
+
+    # this is the imagination step: it predicts the next latent state (prior) without an observation
+ def img_step(self, prev_state, prev_action, embed=None, sample=True):
+ # (batch, stoch, discrete_num)
+ prev_action *= (1.0 / torch.clip(torch.abs(prev_action), min=1.0)).detach()
+ prev_stoch = prev_state["stoch"]
+ if self._discrete:
+ shape = list(prev_stoch.shape[:-2]) + [self._stoch * self._discrete]
+ # (batch, stoch, discrete_num) -> (batch, stoch * discrete_num)
+ prev_stoch = prev_stoch.reshape(shape)
+ if self._shared:
+ if embed is None:
+ shape = list(prev_action.shape[:-1]) + [self._embed]
+ embed = torch.zeros(shape)
+ # (batch, stoch * discrete_num) -> (batch, stoch * discrete_num + action, embed)
+ x = torch.cat([prev_stoch, prev_action, embed], -1)
+ else:
+ x = torch.cat([prev_stoch, prev_action], -1)
+ # (batch, stoch * discrete_num + action, embed) -> (batch, hidden)
+ x = self._inp_layers(x)
+ for _ in range(self._rec_depth): # rec depth is not correctly implemented
+ deter = prev_state["deter"]
+ # (batch, hidden), (batch, deter) -> (batch, deter), (batch, deter)
+ x, deter = self._cell(x, [deter])
+ deter = deter[0] # Keras wraps the state in a list.
+ # (batch, deter) -> (batch, hidden)
+ x = self._img_out_layers(x)
+ # (batch, hidden) -> (batch_size, stoch, discrete_num)
+ stats = self._suff_stats_layer("ims", x)
+ if sample:
+ stoch = self.get_dist(stats).sample()
+ else:
+ stoch = self.get_dist(stats).mode()
+ prior = {"stoch": stoch, "deter": deter, **stats} # {stoch, deter, logit}
+ return prior
+
+ def _suff_stats_layer(self, name, x):
+ if self._discrete:
+ if name == "ims":
+ x = self._ims_stat_layer(x)
+ elif name == "obs":
+ x = self._obs_stat_layer(x)
+ else:
+ raise NotImplementedError
+ logit = x.reshape(list(x.shape[:-1]) + [self._stoch, self._discrete])
+ return {"logit": logit}
+ else:
+ if name == "ims":
+ x = self._ims_stat_layer(x)
+ elif name == "obs":
+ x = self._obs_stat_layer(x)
+ else:
+ raise NotImplementedError
+ mean, std = torch.split(x, [self._stoch] * 2, -1)
+ mean = {
+ "none": lambda: mean,
+ "tanh5": lambda: 5.0 * torch.tanh(mean / 5.0),
+ }[self._mean_act]()
+ std = {
+                "softplus": lambda: F.softplus(std),
+ "abs": lambda: torch.abs(std + 1),
+ "sigmoid": lambda: torch.sigmoid(std),
+ "sigmoid2": lambda: 2 * torch.sigmoid(std / 2),
+ }[self._std_act]()
+ std = std + self._min_std
+ return {"mean": mean, "std": std}
+
+ def kl_loss(self, post, prior, forward, free, lscale, rscale):
+ kld = torchd.kl.kl_divergence
+ dist = lambda x: self.get_dist(x)
+ sg = lambda x: {k: v.detach() for k, v in x.items()}
+ # forward == false -> (post, prior)
+ lhs, rhs = (prior, post) if forward else (post, prior)
+
+ # forward == false -> Lrep
+ value_lhs = value = kld(
+ dist(lhs) if self._discrete else dist(lhs)._dist,
+ dist(sg(rhs)) if self._discrete else dist(sg(rhs))._dist,
+ )
+ # forward == false -> Ldyn
+ value_rhs = kld(
+ dist(sg(lhs)) if self._discrete else dist(sg(lhs))._dist,
+ dist(rhs) if self._discrete else dist(rhs)._dist,
+ )
+ loss_lhs = torch.clip(torch.mean(value_lhs), min=free)
+ loss_rhs = torch.clip(torch.mean(value_rhs), min=free)
+ loss = lscale * loss_lhs + rscale * loss_rhs
+
+ return loss, value, loss_lhs, loss_rhs
+
+
+class ConvDecoder(nn.Module):
+
+ def __init__(
+ self,
+ inp_depth, # config.dyn_stoch * config.dyn_discrete + config.dyn_deter
+ depth=32,
+ act=nn.ELU,
+ norm=nn.LayerNorm,
+ shape=(3, 64, 64),
+ kernels=(3, 3, 3, 3),
+ outscale=1.0,
+ ):
+ super(ConvDecoder, self).__init__()
+ self._inp_depth = inp_depth
+ self._act = act
+ self._norm = norm
+ self._depth = depth
+ self._shape = shape
+ self._kernels = kernels
+ self._embed_size = ((64 // 2 ** (len(kernels))) ** 2 * depth * 2 ** (len(kernels) - 1))
+
+ self._linear_layer = nn.Linear(inp_depth, self._embed_size)
+        inp_dim = self._embed_size // 16  # divide by the final 4x4 feature map (16 cells) to get the channel count
+
+ layers = []
+ h, w = 4, 4
+ for i, kernel in enumerate(self._kernels):
+ depth = self._embed_size // 16 // (2 ** (i + 1))
+ act = self._act
+ bias = False
+ initializer = weight_init
+ if i == len(self._kernels) - 1:
+ depth = self._shape[0]
+ act = False
+ bias = True
+ norm = False
+ initializer = uniform_weight_init(outscale)
+
+ if i != 0:
+ inp_dim = 2 ** (len(self._kernels) - (i - 1) - 2) * self._depth
+ pad_h, outpad_h = self.calc_same_pad(k=kernel, s=2, d=1)
+ pad_w, outpad_w = self.calc_same_pad(k=kernel, s=2, d=1)
+ layers.append(
+ nn.ConvTranspose2d(
+ inp_dim,
+ depth,
+ kernel,
+ 2,
+ padding=(pad_h, pad_w),
+ output_padding=(outpad_h, outpad_w),
+ bias=bias,
+ )
+ )
+ if norm:
+ layers.append(DreamerLayerNorm(depth))
+ if act:
+ layers.append(act())
+ [m.apply(initializer) for m in layers[-3:]]
+ h, w = h * 2, w * 2
+
+ self.layers = nn.Sequential(*layers)
+
+ def calc_same_pad(self, k, s, d):
+ val = d * (k - 1) - s + 1
+ pad = math.ceil(val / 2)
+ outpad = pad * 2 - val
+ return pad, outpad
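+    # worked example: calc_same_pad(k=4, s=2, d=1) gives val = 2, pad = 1, outpad = 0,
+    # so each stride-2 ConvTranspose2d built in __init__ exactly doubles the spatial size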
+
+ def __call__(self, features, dtype=None):
+ x = self._linear_layer(features) # feature:[batch, time, stoch*discrete + deter]
+ x = x.reshape([-1, 4, 4, self._embed_size // 16])
+ x = x.permute(0, 3, 1, 2)
+ x = self.layers(x)
+ mean = x.reshape(list(features.shape[:-1]) + self._shape)
+ #mean = mean.permute(0, 1, 3, 4, 2)
+ return SymlogDist(mean)
+
+
+class GRUCell(nn.Module):
+
+ def __init__(self, inp_size, size, norm=False, act=torch.tanh, update_bias=-1):
+ super(GRUCell, self).__init__()
+ self._inp_size = inp_size # hidden
+ self._size = size # deter
+ self._act = act
+ self._norm = norm
+ self._update_bias = update_bias
+ self._layer = nn.Linear(inp_size + size, 3 * size, bias=False)
+ if norm:
+ self._norm = nn.LayerNorm(3 * size, eps=1e-03)
+
+ @property
+ def state_size(self):
+ return self._size
+
+ def forward(self, inputs, state):
+ state = state[0] # Keras wraps the state in a list.
+ parts = self._layer(torch.cat([inputs, state], -1))
+ if self._norm:
+ parts = self._norm(parts)
+ reset, cand, update = torch.split(parts, [self._size] * 3, -1)
+ reset = torch.sigmoid(reset)
+ cand = self._act(reset * cand)
+ update = torch.sigmoid(update + self._update_bias)
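+        # update_bias = -1 biases the update gate toward 0 at initialization,
+        # so the cell initially tends to keep its previous deterministic state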
+ output = update * cand + (1 - update) * state
+ return output, [output]
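
A minimal standalone sketch of the two-sided, stop-gradient KL balancing implemented in `RSSM.kl_loss`, using plain `torch.distributions.Normal` in place of the RSSM state dicts; the `clip` lower bound plays the role of the free-bits threshold:

    import torch
    from torch import distributions as torchd

    kld = torchd.kl.kl_divergence
    post = torchd.normal.Normal(torch.randn(8), torch.ones(8))
    prior = torchd.normal.Normal(torch.zeros(8), torch.ones(8))
    sg = lambda d: torchd.normal.Normal(d.mean.detach(), d.stddev.detach())   # stop-gradient copy

    free, lscale, rscale = 1.0, 0.1, 0.5
    loss_rep = torch.clip(kld(post, sg(prior)).mean(), min=free)    # trains the posterior toward a frozen prior
    loss_dyn = torch.clip(kld(sg(post), prior).mean(), min=free)    # trains the prior toward a frozen posterior
    kl_loss = lscale * loss_rep + rscale * loss_dyn
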
diff --git a/DI-engine/ding/world_model/model/tests/test_ensemble.py b/DI-engine/ding/world_model/model/tests/test_ensemble.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dcaf9395f862b4a08cc09c0852fd46108757f82
--- /dev/null
+++ b/DI-engine/ding/world_model/model/tests/test_ensemble.py
@@ -0,0 +1,29 @@
+import pytest
+import torch
+from itertools import product
+from ding.world_model.model.ensemble import EnsembleFC, EnsembleModel
+
+# arguments
+state_size = [16]
+action_size = [16, 1]
+reward_size = [1]
+args = list(product(*[state_size, action_size, reward_size]))
+
+
+@pytest.mark.unittest
+def test_EnsembleFC():
+ in_dim, out_dim, ensemble_size, B = 4, 8, 7, 64
+ fc = EnsembleFC(in_dim, out_dim, ensemble_size)
+ x = torch.randn(ensemble_size, B, in_dim)
+ y = fc(x)
+ assert y.shape == (ensemble_size, B, out_dim)
+
+
+@pytest.mark.parametrize('state_size, action_size, reward_size', args)
+def test_EnsembleModel(state_size, action_size, reward_size):
+ ensemble_size, B = 7, 64
+ model = EnsembleModel(state_size, action_size, reward_size, ensemble_size)
+ x = torch.randn(ensemble_size, B, state_size + action_size)
+ y = model(x)
+ assert len(y) == 2
+ assert y[0].shape == y[1].shape == (ensemble_size, B, state_size + reward_size)
diff --git a/DI-engine/ding/world_model/model/tests/test_networks.py b/DI-engine/ding/world_model/model/tests/test_networks.py
new file mode 100644
index 0000000000000000000000000000000000000000..c23c94cd3d6687974f67dd7c660dece3eac03384
--- /dev/null
+++ b/DI-engine/ding/world_model/model/tests/test_networks.py
@@ -0,0 +1,3 @@
+import pytest
+import torch
+from itertools import product
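
`test_networks.py` above is only an import stub; a minimal shape check it might eventually contain (a hypothetical sketch in the style of `test_ensemble.py`) could look like:

    import torch
    from ding.world_model.model.networks import GRUCell

    def test_GRUCell():
        B, inp_size, deter_size = 4, 8, 16
        cell = GRUCell(inp_size, deter_size, norm=True)
        x, state = torch.randn(B, inp_size), torch.zeros(B, deter_size)
        out, new_state = cell(x, [state])
        assert out.shape == (B, deter_size)
        assert new_state[0].shape == (B, deter_size)
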
diff --git a/DI-engine/ding/world_model/tests/test_ddppo.py b/DI-engine/ding/world_model/tests/test_ddppo.py
new file mode 100644
index 0000000000000000000000000000000000000000..138a2625cea0996b01b468dc4c58b1988346453d
--- /dev/null
+++ b/DI-engine/ding/world_model/tests/test_ddppo.py
@@ -0,0 +1,108 @@
+import pytest
+import torch
+from torch import nn
+
+from itertools import product
+from easydict import EasyDict
+from ding.world_model.ddppo import DDPPOWorldMode, get_batch_jacobian, get_neighbor_index
+from ding.utils import deep_merge_dicts
+
+# arguments
+state_size = [16]
+action_size = [16, 1]
+args = list(product(*[state_size, action_size]))
+
+
+@pytest.mark.unittest
+class TestDDPPO:
+
+ def get_world_model(self, state_size, action_size):
+ cfg = DDPPOWorldMode.default_config()
+ cfg.model.max_epochs_since_update = 0
+ cfg = deep_merge_dicts(
+ cfg, dict(cuda=False, model=dict(state_size=state_size, action_size=action_size, reward_size=1))
+ )
+ fake_env = EasyDict(termination_fn=lambda obs: torch.zeros_like(obs.sum(-1)).bool())
+ model = DDPPOWorldMode(cfg, fake_env, None)
+ model.serial_calc_nn = True
+ return model
+
+ def test_get_neighbor_index(self):
+ k = 2
+ data = torch.tensor([[0, 0, 0], [0, 0, 1], [0, 0, -1], [5, 0, 0], [5, 0, 1], [5, 0, -1]])
+ idx = get_neighbor_index(data, k, serial=True)
+ target_idx = torch.tensor([[2, 1], [0, 2], [0, 1], [5, 4], [3, 5], [3, 4]])
+ assert (idx - target_idx).sum() == 0
+
+ def test_get_batch_jacobian(self):
+ B, in_dim, out_dim = 64, 4, 8
+ net = nn.Linear(in_dim, out_dim)
+ x = torch.randn(B, in_dim)
+ jacobian = get_batch_jacobian(net, x, out_dim)
+ assert jacobian.shape == (B, out_dim, in_dim)
+
+ @pytest.mark.parametrize('state_size, action_size', args)
+ def test_get_jacobian(self, state_size, action_size):
+ B, ensemble_size = 64, 7
+ model = self.get_world_model(state_size, action_size)
+ train_input_reg = torch.randn(ensemble_size, B, state_size + action_size)
+ jacobian = model._get_jacobian(model.gradient_model, train_input_reg)
+ assert jacobian.shape == (ensemble_size, B, state_size + 1, state_size + action_size)
+ assert jacobian.requires_grad
+
+ @pytest.mark.parametrize('state_size, action_size', args)
+ def test_step(self, state_size, action_size):
+ states = torch.rand(128, state_size)
+ actions = torch.rand(128, action_size)
+ model = self.get_world_model(state_size, action_size)
+ model.elite_model_idxes = [0, 1]
+ rewards, next_obs, dones = model.step(states, actions)
+ assert rewards.shape == (128, )
+ assert next_obs.shape == (128, state_size)
+ assert dones.shape == (128, )
+
+ @pytest.mark.parametrize('state_size, action_size', args)
+ def test_train_rollout_model(self, state_size, action_size):
+ states = torch.rand(1280, state_size)
+ actions = torch.rand(1280, action_size)
+
+ next_states = states + actions.mean(1, keepdim=True)
+ rewards = next_states.mean(1, keepdim=True).repeat(1, 1)
+
+ inputs = torch.cat([states, actions], dim=1)
+ labels = torch.cat([rewards, next_states], dim=1)
+
+ model = self.get_world_model(state_size, action_size)
+ model._train_rollout_model(inputs[:64], labels[:64])
+
+ @pytest.mark.parametrize('state_size, action_size', args)
+ def test_train_graident_model(self, state_size, action_size):
+ states = torch.rand(1280, state_size)
+ actions = torch.rand(1280, action_size)
+
+ next_states = states + actions.mean(1, keepdim=True)
+ rewards = next_states.mean(1, keepdim=True)
+
+ inputs = torch.cat([states, actions], dim=1)
+ labels = torch.cat([rewards, next_states], dim=1)
+
+ model = self.get_world_model(state_size, action_size)
+ model._train_gradient_model(inputs[:64], labels[:64], inputs[:64], labels[:64])
+
+ @pytest.mark.parametrize('state_size, action_size', args[:1])
+ def test_others(self, state_size, action_size):
+ states = torch.rand(1280, state_size)
+ actions = torch.rand(1280, action_size)
+
+ next_states = states + actions.mean(1, keepdim=True)
+ rewards = next_states.mean(1, keepdim=True)
+
+ inputs = torch.cat([states, actions], dim=1)
+ labels = torch.cat([rewards, next_states], dim=1)
+
+ model = self.get_world_model(state_size, action_size)
+ model._train_rollout_model(inputs[:64], labels[:64])
+ model._train_gradient_model(inputs[:64], labels[:64], inputs[:64], labels[:64])
+ model._save_states()
+ model._load_states()
+ model._save_best(0, [1, 2, 3])
diff --git a/DI-engine/ding/world_model/tests/test_dreamerv3.py b/DI-engine/ding/world_model/tests/test_dreamerv3.py
new file mode 100644
index 0000000000000000000000000000000000000000..f93673d39bba2772a62720b974f18ee0f1a7a88b
--- /dev/null
+++ b/DI-engine/ding/world_model/tests/test_dreamerv3.py
@@ -0,0 +1,32 @@
+import pytest
+import torch
+
+from itertools import product
+from easydict import EasyDict
+from ding.world_model.dreamer import DREAMERWorldModel
+from ding.utils import deep_merge_dicts
+
+# arguments
+state_size = [[3, 64, 64]]
+action_size = [6, 1]
+args = list(product(*[state_size, action_size]))
+
+
+@pytest.mark.unittest
+class TestDREAMER:
+
+ def get_world_model(self, state_size, action_size):
+ cfg = DREAMERWorldModel.default_config()
+ cfg.model.max_epochs_since_update = 0
+ cfg = deep_merge_dicts(
+ cfg, dict(cuda=False, model=dict(state_size=state_size, action_size=action_size, reward_size=1))
+ )
+ fake_env = EasyDict(termination_fn=lambda obs: torch.zeros_like(obs.sum(-1)).bool())
+ return DREAMERWorldModel(cfg, fake_env, None)
+
+ @pytest.mark.parametrize('state_size, action_size', args)
+ def test_train(self, state_size, action_size):
+ states = torch.rand(1280, *state_size)
+ actions = torch.rand(1280, action_size)
+
+ model = self.get_world_model(state_size, action_size)
diff --git a/DI-engine/ding/world_model/tests/test_idm.py b/DI-engine/ding/world_model/tests/test_idm.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbc6b2d6880685cf1c8aa2465b6ab5b8eab43359
--- /dev/null
+++ b/DI-engine/ding/world_model/tests/test_idm.py
@@ -0,0 +1,102 @@
+import torch
+import pytest
+from itertools import product
+
+from ding.world_model.idm import InverseDynamicsModel
+from ding.torch_utils import is_differentiable
+from ding.utils import squeeze
+
+B = 4
+obs_shape_arg = [4, (8, ), (9, 64, 64)]
+encoder_hidden_size_list = [10, 20, 10]
+action_shape_arg = [6, (6, ), [6]]
+args = list(product(*[obs_shape_arg, action_shape_arg, ['regression', 'reparameterization']]))
+
+
+@pytest.mark.unittest
+class TestContinousIDM:
+
+ @pytest.mark.parametrize('obs_shape, action_shape, action_space', args)
+ def test_continuous_idm(self, obs_shape, action_shape, action_space):
+
+ model = InverseDynamicsModel(
+ obs_shape=obs_shape,
+ action_shape=action_shape,
+ encoder_hidden_size_list=encoder_hidden_size_list,
+ action_space=action_space,
+ )
+ inputs = {}
+ if isinstance(obs_shape, int):
+ inputs['obs'] = torch.randn(B, obs_shape * 2)
+ else:
+ inputs['obs'] = torch.randn(B, *(obs_shape[0] * 2, *obs_shape[1:]))
+ if isinstance(action_shape, int):
+ inputs['action'] = torch.randn(B, action_shape)
+ else:
+ inputs['action'] = torch.randn(B, *action_shape)
+ if action_space == 'regression':
+ action = model.predict_action(inputs['obs'])['action']
+ if isinstance(action_shape, int):
+ assert action.shape == (B, action_shape)
+ else:
+ assert action.shape == (B, *action_shape)
+ assert action.eq(action.clamp(-1, 1)).all()
+ elif action_space == 'reparameterization':
+ (mu, sigma) = model.predict_action(inputs['obs'])['logit']
+ action = model.predict_action(inputs['obs'])['action']
+ if isinstance(action_shape, int):
+ assert mu.shape == (B, action_shape)
+ assert sigma.shape == (B, action_shape)
+ assert action.shape == (B, action_shape)
+ else:
+ assert mu.shape == (B, *action_shape)
+ assert sigma.shape == (B, *action_shape)
+ assert action.shape == (B, *action_shape)
+
+ loss = model.train(inputs, n_epoch=10, learning_rate=0.01, weight_decay=1e-4)
+ assert isinstance(loss, float)
+
+
+B = 4
+obs_shape = [4, (8, ), (4, 64, 64)]
+action_shape = [6, (6, ), [6]]
+encoder_hidden_size_list = [10, 20, 10]
+args = list(product(*[obs_shape, action_shape]))
+action_space = 'discrete'
+
+
+@pytest.mark.unittest
+class TestDiscreteIDM:
+
+ @pytest.mark.parametrize('obs_shape, action_shape', args)
+ def test_discrete_idm(self, obs_shape, action_shape):
+ model = InverseDynamicsModel(
+ obs_shape=obs_shape,
+ action_shape=action_shape,
+ encoder_hidden_size_list=encoder_hidden_size_list,
+ action_space=action_space,
+ )
+ inputs = {}
+ if isinstance(obs_shape, int):
+ inputs['obs'] = torch.randn(B, obs_shape * 2)
+ else:
+ obs_shape = (obs_shape[0] * 2, *obs_shape[1:])
+ inputs['obs'] = torch.randn(B, *obs_shape)
+ # inputs['action'] = torch.randint(action_shape, B)
+ if isinstance(action_shape, int):
+ inputs['action'] = torch.randint(action_shape, (B, ))
+ else:
+ inputs['action'] = torch.randint(action_shape[0], (B, ))
+
+ outputs = model.forward(inputs['obs'])
+ assert isinstance(outputs, dict)
+ if isinstance(action_shape, int):
+ assert outputs['logit'].shape == (B, action_shape)
+ else:
+ assert outputs['logit'].shape == (B, *action_shape)
+ # self.test_train(model, inputs)
+ action = model.predict_action(inputs['obs'])['action']
+ assert action.shape == (B, )
+
+ loss = model.train(inputs, n_epoch=10, learning_rate=0.01, weight_decay=1e-4)
+ assert isinstance(loss, float)
diff --git a/DI-engine/ding/world_model/tests/test_mbpo.py b/DI-engine/ding/world_model/tests/test_mbpo.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd7c80a5fb9f6a2f7df1602958bbdce67eb34857
--- /dev/null
+++ b/DI-engine/ding/world_model/tests/test_mbpo.py
@@ -0,0 +1,67 @@
+import pytest
+import torch
+
+from itertools import product
+from easydict import EasyDict
+from ding.world_model.mbpo import MBPOWorldModel
+from ding.utils import deep_merge_dicts
+
+# arguments
+state_size = [16]
+action_size = [16, 1]
+args = list(product(*[state_size, action_size]))
+
+
+@pytest.mark.unittest
+class TestMBPO:
+
+ def get_world_model(self, state_size, action_size):
+ cfg = MBPOWorldModel.default_config()
+ cfg.model.max_epochs_since_update = 0
+ cfg = deep_merge_dicts(
+ cfg, dict(cuda=False, model=dict(state_size=state_size, action_size=action_size, reward_size=1))
+ )
+ fake_env = EasyDict(termination_fn=lambda obs: torch.zeros_like(obs.sum(-1)).bool())
+ return MBPOWorldModel(cfg, fake_env, None)
+
+ @pytest.mark.parametrize('state_size, action_size', args)
+ def test_step(self, state_size, action_size):
+ states = torch.rand(128, state_size)
+ actions = torch.rand(128, action_size)
+ model = self.get_world_model(state_size, action_size)
+ model.elite_model_idxes = [0, 1]
+ rewards, next_obs, dones = model.step(states, actions)
+ assert rewards.shape == (128, )
+ assert next_obs.shape == (128, state_size)
+ assert dones.shape == (128, )
+
+ @pytest.mark.parametrize('state_size, action_size', args)
+ def test_train(self, state_size, action_size):
+ states = torch.rand(1280, state_size)
+ actions = torch.rand(1280, action_size)
+
+ next_states = states + actions.mean(1, keepdim=True)
+ rewards = next_states.mean(1, keepdim=True)
+
+ inputs = torch.cat([states, actions], dim=1)
+ labels = torch.cat([rewards, next_states], dim=1)
+
+ model = self.get_world_model(state_size, action_size)
+ model._train(inputs[:64], labels[:64])
+
+ @pytest.mark.parametrize('state_size, action_size', args[:1])
+ def test_others(self, state_size, action_size):
+ states = torch.rand(1280, state_size)
+ actions = torch.rand(1280, action_size)
+
+ next_states = states + actions.mean(1, keepdim=True)
+ rewards = next_states.mean(1, keepdim=True)
+
+ inputs = torch.cat([states, actions], dim=1)
+ labels = torch.cat([rewards, next_states], dim=1)
+
+ model = self.get_world_model(state_size, action_size)
+ model._train(inputs[:64], labels[:64])
+ model._save_states()
+ model._load_states()
+ model._save_best(0, [1, 2, 3])
diff --git a/DI-engine/ding/world_model/tests/test_world_model.py b/DI-engine/ding/world_model/tests/test_world_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8dd620c595aa7a1023050640a956e40c5fd4333
--- /dev/null
+++ b/DI-engine/ding/world_model/tests/test_world_model.py
@@ -0,0 +1,123 @@
+import pytest
+import os
+import torch
+from easydict import EasyDict
+from ding.world_model.base_world_model import DreamWorldModel, DynaWorldModel
+from ding.worker.replay_buffer import NaiveReplayBuffer, EpisodeReplayBuffer
+
+
+@pytest.mark.unittest
+class TestDynaWorldModel:
+
+ @pytest.mark.parametrize('buffer_type', [NaiveReplayBuffer, EpisodeReplayBuffer])
+ def test_fill_img_buffer(self, buffer_type):
+ env_buffer = buffer_type(buffer_type.default_config(), None, 'dyna_exp_name', 'env_buffer_for_test')
+ img_buffer = buffer_type(buffer_type.default_config(), None, 'dyna_exp_name', 'img_buffer_for_test')
+ fake_config = EasyDict(
+ train_freq=250, # w.r.t environment step
+ eval_freq=250, # w.r.t environment step
+ cuda=False,
+ rollout_length_scheduler=dict(
+ type='linear',
+ rollout_start_step=20000,
+ rollout_end_step=150000,
+ rollout_length_min=1,
+ rollout_length_max=25,
+ ),
+ other=dict(
+ real_ratio=0.05,
+ rollout_retain=4,
+ rollout_batch_size=100000,
+ imagination_buffer=dict(
+ type='elastic',
+ replay_buffer_size=6000000,
+ deepcopy=False,
+ enable_track_used_data=False,
+ # set_buffer_size=set_buffer_size,
+ periodic_thruput_seconds=60,
+ ),
+ ),
+ )
+ T, B, O, A = 25, 20, 100, 30
+
+ class FakeModel(DynaWorldModel):
+
+ def train(self, env_buffer, envstep, train_iter):
+ pass
+
+ def eval(self, env_buffer, envstep, train_iter):
+ pass
+
+ def step(self, obs, action):
+ return (torch.zeros(B), torch.rand(B, O), obs.sum(-1) > 0)
+
+ from ding.policy import SACPolicy
+ from ding.model import ContinuousQAC
+
+ policy_config = SACPolicy.default_config()
+ policy_config.model.update(dict(obs_shape=2, action_shape=2))
+ model = ContinuousQAC(**policy_config.model)
+ policy = SACPolicy(policy_config, model=model).collect_mode
+
+ fake_model = FakeModel(fake_config, None, None)
+
+ env_buffer.push(
+ [
+ {
+ 'obs': torch.randn(2),
+ 'next_obs': torch.randn(2),
+ 'action': torch.randn(2),
+ 'reward': torch.randn(1),
+ 'done': False,
+ 'collect_iter': 0
+ }
+ ] * 20, 0
+ )
+
+ super(FakeModel, fake_model).fill_img_buffer(policy, env_buffer, img_buffer, 0, 0)
+ os.popen("rm -rf dyna_exp_name")
+
+
+@pytest.mark.unittest
+class TestDreamWorldModel:
+
+ def test_rollout(self):
+ fake_config = EasyDict(
+ train_freq=250, # w.r.t environment step
+ eval_freq=250, # w.r.t environment step
+ cuda=False,
+ rollout_length_scheduler=dict(
+ type='linear',
+ rollout_start_step=20000,
+ rollout_end_step=150000,
+ rollout_length_min=1,
+ rollout_length_max=25,
+ )
+ )
+ envstep = 150000
+ T, B, O, A = 25, 20, 100, 30
+
+ class FakeModel(DreamWorldModel):
+
+ def train(self, env_buffer, envstep, train_iter):
+ pass
+
+ def eval(self, env_buffer, envstep, train_iter):
+ pass
+
+ def step(self, obs, action):
+ return (torch.zeros(B), torch.rand(B, O), obs.sum(-1) > 0)
+
+ def fake_policy_fn(obs):
+ return torch.randn(B, A), torch.zeros(B)
+
+ fake_model = FakeModel(fake_config, None, None)
+
+ obs = torch.rand(B, O)
+ obss, actions, rewards, aug_rewards, dones = \
+ super(FakeModel, fake_model).rollout(obs, fake_policy_fn, envstep)
+ assert obss.shape == (T + 1, B, O)
+ assert actions.shape == (T + 1, B, A)
+ assert rewards.shape == (T, B)
+ assert aug_rewards.shape == (T + 1, B)
+ assert dones.shape == (T, B)
diff --git a/DI-engine/ding/world_model/tests/test_world_model_utils.py b/DI-engine/ding/world_model/tests/test_world_model_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..26ba5e7f5e1357aadc928cd4995021243ca1b2c4
--- /dev/null
+++ b/DI-engine/ding/world_model/tests/test_world_model_utils.py
@@ -0,0 +1,19 @@
+import pytest
+from easydict import EasyDict
+from ding.world_model.utils import get_rollout_length_scheduler
+
+
+@pytest.mark.unittest
+def test_get_rollout_length_scheduler():
+ fake_cfg = EasyDict(
+ type='linear',
+ rollout_start_step=20000,
+ rollout_end_step=150000,
+ rollout_length_min=1,
+ rollout_length_max=25,
+ )
+ scheduler = get_rollout_length_scheduler(fake_cfg)
+ assert scheduler(0) == 1
+ assert scheduler(19999) == 1
+ assert scheduler(150000) == 25
+ assert scheduler(1500000) == 25
diff --git a/DI-engine/ding/world_model/utils.py b/DI-engine/ding/world_model/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..15172699f98f227d2b954468ae07c0f9ac22ef42
--- /dev/null
+++ b/DI-engine/ding/world_model/utils.py
@@ -0,0 +1,25 @@
+from easydict import EasyDict
+from typing import Callable
+
+
+def get_rollout_length_scheduler(cfg: EasyDict) -> Callable[[int], int]:
+ """
+ Overview:
+ Get the rollout length scheduler that adapts rollout length based\
+ on the current environment steps.
+ Returns:
+ - scheduler (:obj:`Callble`): The function that takes envstep and\
+ return the current rollout length.
+ """
+ if cfg.type == 'linear':
+ x0 = cfg.rollout_start_step
+ x1 = cfg.rollout_end_step
+ y0 = cfg.rollout_length_min
+ y1 = cfg.rollout_length_max
+ w = (y1 - y0) / (x1 - x0)
+ b = y0
+ return lambda x: int(min(max(w * (x - x0) + b, y0), y1))
+ elif cfg.type == 'constant':
+ return lambda x: cfg.rollout_length
+ else:
+ raise KeyError("not implemented key: {}".format(cfg.type))
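For reference, a minimal usage sketch of the linear scheduler defined above (an illustration only, not part of the patch; the config values mirror the unit test in test_world_model_utils.py):

    from easydict import EasyDict
    from ding.world_model.utils import get_rollout_length_scheduler

    # Linear schedule: the rollout length grows from 1 to 25 as envstep goes
    # from 20000 to 150000, and is clamped to [1, 25] outside that interval.
    cfg = EasyDict(
        type='linear',
        rollout_start_step=20000,
        rollout_end_step=150000,
        rollout_length_min=1,
        rollout_length_max=25,
    )
    scheduler = get_rollout_length_scheduler(cfg)
    assert scheduler(0) == 1          # before rollout_start_step: clamped to the minimum
    assert scheduler(85000) == 13     # halfway through the interval: 1 + 0.5 * (25 - 1)
    assert scheduler(10 ** 6) == 25   # after rollout_end_step: clamped to the maximum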
diff --git a/DI-engine/dizoo/__init__.py b/DI-engine/dizoo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/atari/__init__.py b/DI-engine/dizoo/atari/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/atari/config/__init__.py b/DI-engine/dizoo/atari/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/atari/config/serial/__init__.py b/DI-engine/dizoo/atari/config/serial/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ecd50235d2da7907fbd27e076f92721ad6d8518
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/__init__.py
@@ -0,0 +1,5 @@
+from dizoo.atari.config.serial.enduro import *
+from dizoo.atari.config.serial.pong import *
+from dizoo.atari.config.serial.qbert import *
+from dizoo.atari.config.serial.spaceinvaders import *
+from dizoo.atari.config.serial.asterix import *
\ No newline at end of file
diff --git a/DI-engine/dizoo/atari/config/serial/asterix/__init__.py b/DI-engine/dizoo/atari/config/serial/asterix/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d25b637d560192c29b44e763a8359814d3d47ded
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/asterix/__init__.py
@@ -0,0 +1 @@
+from .asterix_mdqn_config import asterix_mdqn_config, asterix_mdqn_create_config
diff --git a/DI-engine/dizoo/atari/config/serial/asterix/asterix_mdqn_config.py b/DI-engine/dizoo/atari/config/serial/asterix/asterix_mdqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..675ae325401a40f10b47cf6494de168177b6a3a3
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/asterix/asterix_mdqn_config.py
@@ -0,0 +1,63 @@
+from copy import deepcopy
+from easydict import EasyDict
+
+asterix_mdqn_config = dict(
+ exp_name='asterix_mdqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=20000,
+ env_id='AsterixNoFrameskip-v0',
+        #'ALE/Asterix-v5' is available. But special setting is needed after gym make.
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=9,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=1,
+ discount_factor=0.99,
+ entropy_tau=0.03,
+ m_alpha=0.9,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ learner=dict(hook=dict(save_ckpt_after_iter=1000000, ))
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+)
+asterix_mdqn_config = EasyDict(asterix_mdqn_config)
+main_config = asterix_mdqn_config
+asterix_mdqn_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='mdqn'),
+)
+asterix_mdqn_create_config = EasyDict(asterix_mdqn_create_config)
+create_config = asterix_mdqn_create_config
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c asterix_mdqn_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0, max_env_step=int(1e7), dynamic_seed=False)
diff --git a/DI-engine/dizoo/atari/config/serial/enduro/__init__.py b/DI-engine/dizoo/atari/config/serial/enduro/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..73c5be00bd318e474660ed12454cfc22c1b74234
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/enduro/__init__.py
@@ -0,0 +1 @@
+from .enduro_dqn_config import enduro_dqn_config, enduro_dqn_create_config
diff --git a/DI-engine/dizoo/atari/config/serial/enduro/enduro_dqn_config.py b/DI-engine/dizoo/atari/config/serial/enduro/enduro_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb22c207379ace0adf9ee201d4bdc899ae51d677
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/enduro/enduro_dqn_config.py
@@ -0,0 +1,60 @@
+from copy import deepcopy
+from easydict import EasyDict
+
+enduro_dqn_config = dict(
+ exp_name='enduro_dqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=int(1e10),
+ env_id='EnduroNoFrameskip-v4',
+ #'ALE/Enduro-v5' is available. But special setting is needed after gym make.
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=9,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.5,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+)
+enduro_dqn_config = EasyDict(enduro_dqn_config)
+main_config = enduro_dqn_config
+enduro_dqn_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqn'),
+)
+enduro_dqn_create_config = EasyDict(enduro_dqn_create_config)
+create_config = enduro_dqn_create_config
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c enduro_dqn_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/enduro/enduro_impala_config.py b/DI-engine/dizoo/atari/config/serial/enduro/enduro_impala_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f13770aadc6e0a4045d8f69c74fe37ef9ce98ea
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/enduro/enduro_impala_config.py
@@ -0,0 +1,83 @@
+from copy import deepcopy
+from easydict import EasyDict
+
+enduro_impala_config = dict(
+ exp_name='enduro_impala_seed0',
+ env=dict(
+ collector_env_num=16,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='EnduroNoFrameskip-v4',
+ #'ALE/Enduro-v5' is available. But special setting is needed after gym make.
+ frame_stack=4
+ ),
+ policy=dict(
+ cuda=True,
+ # (int) the trajectory length to calculate v-trace target
+ unroll_len=64,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=9,
+ encoder_hidden_size_list=[128, 128, 512],
+ critic_head_hidden_size=512,
+ critic_head_layer_num=2,
+ actor_head_hidden_size=512,
+ actor_head_layer_num=2,
+ ),
+ learn=dict(
+ # (int) collect n_sample data, train model update_per_collect times
+ # here we follow ppo serial pipeline
+ update_per_collect=10,
+ # (int) the number of data for a train iteration
+ batch_size=128,
+ grad_clip_type='clip_norm',
+ clip_value=10.0,
+ learning_rate=0.0001,
+ # (float) loss weight of the value network, the weight of policy network is set to 1
+ value_weight=1.0,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.0000001,
+            # (float) discount factor for future reward, in the range [0, 1]
+ discount_factor=0.99,
+ # (float) additional discounting parameter
+ lambda_=1.0,
+ # (float) clip ratio of importance weights
+ rho_clip_ratio=1.0,
+ # (float) clip ratio of importance weights
+ c_clip_ratio=1.0,
+ # (float) clip ratio of importance sampling
+ rho_pg_clip_ratio=1.0,
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ n_sample=16,
+            # (float) discount factor for future reward, in the range [0, 1]
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ collector=dict(collect_print_freq=1000, ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=5000, )),
+ other=dict(replay_buffer=dict(
+ type='naive',
+ replay_buffer_size=500000,
+ max_use=100,
+ ), ),
+ ),
+)
+main_config = EasyDict(enduro_impala_config)
+
+enduro_impala_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='impala'),
+)
+create_config = EasyDict(enduro_impala_create_config)
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c enduro_impala_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/enduro/enduro_mdqn_config.py b/DI-engine/dizoo/atari/config/serial/enduro/enduro_mdqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b08f4ace95d18ffb640bbf0f232f4a7cabcf2ba
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/enduro/enduro_mdqn_config.py
@@ -0,0 +1,63 @@
+from copy import deepcopy
+from easydict import EasyDict
+
+enduro_mdqn_config = dict(
+ exp_name='enduro_mdqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=int(1e10),
+ env_id='EnduroNoFrameskip-v4',
+ #'ALE/Enduro-v5' is available. But special setting is needed after gym make.
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=9,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=1,
+ discount_factor=0.99,
+ entropy_tau=0.03,
+ m_alpha=0.9,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ learner=dict(hook=dict(save_ckpt_after_iter=1000000, ))
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.5,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+)
+enduro_mdqn_config = EasyDict(enduro_mdqn_config)
+main_config = enduro_mdqn_config
+enduro_mdqn_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='mdqn'),
+)
+enduro_mdqn_create_config = EasyDict(enduro_mdqn_create_config)
+create_config = enduro_mdqn_create_config
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c enduro_mdqn_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0, max_env_step=int(1e7), dynamic_seed=False)
diff --git a/DI-engine/dizoo/atari/config/serial/enduro/enduro_onppo_config.py b/DI-engine/dizoo/atari/config/serial/enduro/enduro_onppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba2ffa21d7aaac77b3829eb9b3cf640e5ad33c13
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/enduro/enduro_onppo_config.py
@@ -0,0 +1,66 @@
+from easydict import EasyDict
+
+enduro_onppo_config = dict(
+ exp_name='enduro_onppo_seed0',
+ env=dict(
+ collector_env_num=64,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='EnduroNoFrameskip-v4',
+ #'ALE/Enduro-v5' is available. But special setting is needed after gym make.
+ frame_stack=4,
+ manager=dict(shared_memory=False, )
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=9,
+ encoder_hidden_size_list=[32, 64, 64, 128],
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ critic_head_layer_num=2,
+ ),
+ learn=dict(
+ update_per_collect=24,
+ batch_size=128,
+ # (bool) Whether to normalize advantage. Default to False.
+ adv_norm=False,
+ learning_rate=0.0001,
+ # (float) loss weight of the value network, the weight of policy network is set to 1
+ value_weight=1.0,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.001, # [0.1, 0.01 ,0.0]
+ clip_ratio=0.1
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ n_sample=1024,
+ # (float) the trade-off factor lambda to balance 1step td and mc
+ gae_lambda=0.95,
+ discount_factor=0.99,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, )),
+ other=dict(replay_buffer=dict(
+ replay_buffer_size=10000,
+ max_use=3,
+ ), ),
+ ),
+)
+main_config = EasyDict(enduro_onppo_config)
+
+enduro_onppo_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+)
+create_config = EasyDict(enduro_onppo_create_config)
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial_onpolicy -c enduro_onppo_config.py -s 0
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/enduro/enduro_qrdqn_config.py b/DI-engine/dizoo/atari/config/serial/enduro/enduro_qrdqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce0409fedbd3c8ee25e8e5c8eec6c90402f866d8
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/enduro/enduro_qrdqn_config.py
@@ -0,0 +1,60 @@
+from easydict import EasyDict
+
+enduro_qrdqn_config = dict(
+ exp_name='enduro_qrdqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='EnduroNoFrameskip-v4',
+ #'ALE/Enduro-v5' is available. But special setting is needed after gym make.
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ num_quantiles=64,
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+)
+enduro_qrdqn_config = EasyDict(enduro_qrdqn_config)
+main_config = enduro_qrdqn_config
+enduro_qrdqn_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='qrdqn'),
+)
+enduro_qrdqn_create_config = EasyDict(enduro_qrdqn_create_config)
+create_config = enduro_qrdqn_create_config
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c enduro_qrdqn_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/enduro/enduro_rainbow_config.py b/DI-engine/dizoo/atari/config/serial/enduro/enduro_rainbow_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..0735ceeabcd7bb1dd7de1727e7a0324e88a3fd2d
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/enduro/enduro_rainbow_config.py
@@ -0,0 +1,62 @@
+from easydict import EasyDict
+
+enduro_rainbow_config = dict(
+ exp_name='enduro_rainbow_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='EnduroNoFrameskip-v4',
+ #'ALE/Enduro-v5' is available. But special setting is needed after gym make.
+ frame_stack=4
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=9,
+ encoder_hidden_size_list=[128, 128, 512],
+ v_min=-10,
+ v_max=10,
+ n_atom=51,
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+)
+enduro_rainbow_config = EasyDict(enduro_rainbow_config)
+main_config = enduro_rainbow_config
+enduro_rainbow_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='rainbow'),
+)
+enduro_rainbow_create_config = EasyDict(enduro_rainbow_create_config)
+create_config = enduro_rainbow_create_config
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c enduro_rainbow_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/montezuma/montezuma_ngu_config.py b/DI-engine/dizoo/atari/config/serial/montezuma/montezuma_ngu_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a3bc125942367d3b84b2a008864d5e7e6336cc7
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/montezuma/montezuma_ngu_config.py
@@ -0,0 +1,129 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 8
+nstep = 5
+max_env_step = int(10e6)
+
+montezuma_ngu_config = dict(
+ exp_name='montezuma_ngu_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=8,
+ env_id='MontezumaRevengeNoFrameskip-v4',
+ #'ALE/MontezumaRevenge-v5' is available. But special setting is needed after gym make.
+ obs_plus_prev_action_reward=True, # use specific env wrapper for ngu policy
+ stop_value=int(1e5),
+ frame_stack=4,
+ ),
+ rnd_reward_model=dict(
+ intrinsic_reward_type='add',
+ learning_rate=0.001,
+ obs_shape=[4, 84, 84],
+ action_shape=18,
+ batch_size=320,
+ update_per_collect=10,
+ only_use_last_five_frames_for_icm_rnd=False,
+ clear_buffer_per_iters=10,
+ nstep=nstep,
+ hidden_size_list=[128, 128, 64],
+ type='rnd-ngu',
+ ),
+ episodic_reward_model=dict(
+        # Whether to use the rescale trick on the last non-zero reward
+        # when combining extrinsic and intrinsic rewards.
+        # The rescale trick is only used in:
+        # 1. sparse-reward envs such as minigrid, where the last non-zero reward is a strong positive signal
+        # 2. envs where the last reward of each episode directly reflects the agent's completion of the task, e.g. lunarlander
+        # Note that the ngu intrinsic reward is a positive value (max value is 5); in these envs,
+        # the last non-zero reward should not be overwhelmed by intrinsic rewards, so we need to rescale the
+        # original last non-zero extrinsic reward.
+        # please refer to ngu_reward_model for details.
+        last_nonzero_reward_rescale=False,
+        # the rescale weight for the last non-zero reward, only used when last_nonzero_reward_rescale is True
+        # please refer to ngu_reward_model for details.
+ last_nonzero_reward_weight=1,
+ intrinsic_reward_type='add',
+ learning_rate=0.001,
+ obs_shape=[4, 84, 84],
+ action_shape=18,
+ batch_size=320,
+ update_per_collect=10, # 32*100/64=50
+ only_use_last_five_frames_for_icm_rnd=False,
+ clear_buffer_per_iters=10,
+ nstep=nstep,
+ hidden_size_list=[128, 128, 64],
+ type='episodic',
+ ),
+ policy=dict(
+ cuda=True,
+ on_policy=False,
+ priority=True,
+ priority_IS_weight=True,
+ discount_factor=0.997,
+ nstep=nstep,
+ burnin_step=20,
+        # (int) learn_unroll_len is the total length of [sequence sample] minus
+        # the length of the burnin part in [sequence sample],
+        # i.e., sequence sample length = unroll_len = burnin_step + learn_unroll_len
+ learn_unroll_len=80, # set this key according to the episode length
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=18,
+ encoder_hidden_size_list=[128, 128, 512],
+ collector_env_num=collector_env_num,
+ ),
+ learn=dict(
+ update_per_collect=8,
+ batch_size=64,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ ),
+ collect=dict(
+            # NOTE: It is important to set the key traj_len_inf=True here,
+            # to make sure self._traj_len=INF in serial_sample_collector.py.
+            # In a sequence-based policy, for each collect_env,
+            # we want to collect data of length self._traj_len=INF
+            # unless the episode enters the 'done' state.
+            # In each collect phase, we collect a total of n_sample sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=1e5,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=int(2e3),
+ # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
+ alpha=0.6,
+ # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+ beta=0.4,
+ )
+ ),
+ ),
+)
+montezuma_ngu_config = EasyDict(montezuma_ngu_config)
+main_config = montezuma_ngu_config
+montezuma_ngu_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ngu'),
+ rnd_reward_model=dict(type='rnd-ngu'),
+ episodic_reward_model=dict(type='episodic'),
+)
+montezuma_ngu_create_config = EasyDict(montezuma_ngu_create_config)
+create_config = montezuma_ngu_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_reward_model_ngu
+ serial_pipeline_reward_model_ngu([main_config, create_config], seed=0, max_env_step=max_env_step)
diff --git a/DI-engine/dizoo/atari/config/serial/phoenix/phoenix_fqf_config.py b/DI-engine/dizoo/atari/config/serial/phoenix/phoenix_fqf_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc0273ad568deca7f9dca3751ee4c7e68b185eac
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/phoenix/phoenix_fqf_config.py
@@ -0,0 +1,62 @@
+from easydict import EasyDict
+
+phoenix_fqf_config = dict(
+ exp_name='phoenix_fqf_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='PhoenixNoFrameskip-v4',
+ #'ALE/Phoenix-v5' is available. But special setting is needed after gym make.
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ num_quantiles=32,
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate_fraction=2.5e-9,
+ learning_rate_quantile=0.00005,
+ target_update_freq=500,
+ ent_coef=0,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=1000000, ),
+ ),
+ ),
+)
+phoenix_fqf_config = EasyDict(phoenix_fqf_config)
+main_config = phoenix_fqf_config
+phoenix_fqf_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='fqf'),
+)
+phoenix_fqf_create_config = EasyDict(phoenix_fqf_create_config)
+create_config = phoenix_fqf_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c phoenix_fqf_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
\ No newline at end of file
diff --git a/DI-engine/dizoo/atari/config/serial/phoenix/phoenix_iqn_config.py b/DI-engine/dizoo/atari/config/serial/phoenix/phoenix_iqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e2a2036af56987ddff53693a092612283fa2a99
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/phoenix/phoenix_iqn_config.py
@@ -0,0 +1,61 @@
+from easydict import EasyDict
+
+phoenix_iqn_config = dict(
+ exp_name='phoenix_iqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=20,
+ env_id='PhoenixNoFrameskip-v4',
+ #'ALE/Phoenix-v5' is available. But special setting is needed after gym make.
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ num_quantiles=32,
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ kappa=1.0,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=250000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+)
+phoenix_iqn_config = EasyDict(phoenix_iqn_config)
+main_config = phoenix_iqn_config
+phoenix_iqn_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='iqn'),
+)
+phoenix_iqn_create_config = EasyDict(phoenix_iqn_create_config)
+create_config = phoenix_iqn_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c phoenix_iqn_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/pitfall/pitfall_ngu_config.py b/DI-engine/dizoo/atari/config/serial/pitfall/pitfall_ngu_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e22563adcb6b7a5bb13e8d0ff7a60e6ea381604
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pitfall/pitfall_ngu_config.py
@@ -0,0 +1,132 @@
+from easydict import EasyDict
+
+collector_env_num = 32
+evaluator_env_num = 5
+nstep = 5
+max_env_step = int(10e6)
+
+pitfall_ngu_config = dict(
+ # Note:
+    # 1. the reward may only start to increase after at least 1e10 timesteps (i.e., 10000 million); please be patient.
+    # 2. a larger unroll_len and replay buffer size may yield better results, but also require more memory.
+ exp_name='pitfall_ngu_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=5,
+ env_id='PitfallNoFrameskip-v4',
+ #'ALE/Pitfall-v5' is available. But special setting is needed after gym make.
+ obs_plus_prev_action_reward=True, # use specific env wrapper for ngu policy
+ stop_value=int(1e5),
+ frame_stack=4,
+ ),
+ rnd_reward_model=dict(
+ intrinsic_reward_type='add', # 'assign'
+ learning_rate=1e-4,
+ obs_shape=[4, 84, 84],
+ action_shape=18,
+ batch_size=320,
+ update_per_collect=10,
+ only_use_last_five_frames_for_icm_rnd=False,
+ clear_buffer_per_iters=10,
+ nstep=nstep,
+ hidden_size_list=[128, 128, 64],
+ type='rnd-ngu',
+ ),
+ episodic_reward_model=dict(
+        # Whether to use the rescale trick on the last non-zero reward
+        # when combining extrinsic and intrinsic rewards.
+        # The rescale trick is only used in:
+        # 1. sparse-reward envs such as minigrid, where the last non-zero reward is a strong positive signal
+        # 2. envs where the last reward of each episode directly reflects the agent's completion of the task, e.g. lunarlander
+        # Note that the ngu intrinsic reward is a positive value (max value is 5); in these envs,
+        # the last non-zero reward should not be overwhelmed by intrinsic rewards, so we need to rescale the
+        # original last non-zero extrinsic reward.
+        # please refer to ngu_reward_model for details.
+        last_nonzero_reward_rescale=False,
+        # the rescale weight for the last non-zero reward, only used when last_nonzero_reward_rescale is True
+        # please refer to ngu_reward_model for details.
+ last_nonzero_reward_weight=1,
+ intrinsic_reward_type='add',
+ learning_rate=1e-4,
+ obs_shape=[4, 84, 84],
+ action_shape=18,
+ batch_size=320,
+ update_per_collect=10,
+ only_use_last_five_frames_for_icm_rnd=False,
+ clear_buffer_per_iters=10,
+ nstep=nstep,
+ hidden_size_list=[128, 128, 64],
+ type='episodic',
+ ),
+ policy=dict(
+ cuda=True,
+ on_policy=False,
+ priority=True,
+ priority_IS_weight=True,
+ discount_factor=0.997,
+ nstep=nstep,
+ burnin_step=20,
+        # (int) learn_unroll_len is the total length of [sequence sample] minus
+        # the length of the burnin part in [sequence sample],
+        # i.e., sequence sample length = unroll_len = burnin_step + learn_unroll_len
+ learn_unroll_len=80, # set this key according to the episode length
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=18,
+ encoder_hidden_size_list=[128, 128, 512],
+ collector_env_num=collector_env_num,
+ ),
+ learn=dict(
+ update_per_collect=8,
+ batch_size=64,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ ),
+ collect=dict(
+            # NOTE: It is important to set the key traj_len_inf=True here,
+            # to make sure self._traj_len=INF in serial_sample_collector.py.
+            # In a sequence-based policy, for each collect_env,
+            # we want to collect data of length self._traj_len=INF
+            # unless the episode enters the 'done' state.
+            # In each collect phase, we collect a total of n_sample sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=1e5,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=int(3e3),
+ # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
+ alpha=0.6,
+ # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+ beta=0.4,
+ )
+ ),
+ ),
+)
+pitfall_ngu_config = EasyDict(pitfall_ngu_config)
+main_config = pitfall_ngu_config
+pitfall_ngu_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ngu'),
+ rnd_reward_model=dict(type='rnd-ngu'),
+ episodic_reward_model=dict(type='episodic'),
+)
+pitfall_ngu_create_config = EasyDict(pitfall_ngu_create_config)
+create_config = pitfall_ngu_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_reward_model_ngu
+ serial_pipeline_reward_model_ngu([main_config, create_config], seed=0, max_env_step=max_env_step)
diff --git a/DI-engine/dizoo/atari/config/serial/pong/__init__.py b/DI-engine/dizoo/atari/config/serial/pong/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ce3db9a5b9e9fa9bd3e2b9c29ef0bc2b04e28ad
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/__init__.py
@@ -0,0 +1,3 @@
+from .pong_dqn_config import pong_dqn_config, pong_dqn_create_config
+from .pong_dqn_envpool_config import pong_dqn_envpool_config, pong_dqn_envpool_create_config
+from .pong_dqfd_config import pong_dqfd_config, pong_dqfd_create_config
\ No newline at end of file
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_a2c_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_a2c_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..eb48f4248efcf1ee3c4ceaa86fbd6ae23f6c8cd3
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_a2c_config.py
@@ -0,0 +1,62 @@
+from easydict import EasyDict
+
+pong_a2c_config = dict(
+ exp_name='pong_a2c_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ #'ALE/Pong-v5' is available. But special setting is needed after gym make.
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[64, 64, 128],
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ ),
+ learn=dict(
+ batch_size=160,
+ # (bool) Whether to normalize advantage. Default to False.
+ adv_norm=False,
+ learning_rate=0.0001414,
+ # (float) loss weight of the value network, the weight of policy network is set to 1
+ value_weight=0.5,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.01,
+ grad_norm=0.5,
+ betas=(0.0, 0.99),
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ n_sample=160,
+ # (float) the trade-off factor lambda to balance 1step td and mc
+ gae_lambda=0.99,
+ discount_factor=0.99,
+ ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ ),
+)
+main_config = EasyDict(pong_a2c_config)
+
+pong_a2c_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='a2c'),
+ replay_buffer=dict(type='naive'),
+)
+pong_a2c_create_config = EasyDict(pong_a2c_create_config)
+create_config = pong_a2c_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial_onpolicy -c pong_a2c_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_acer_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_acer_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..706bb83b3a90999ffa33d28f57826b3b962e762b
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_acer_config.py
@@ -0,0 +1,71 @@
+from easydict import EasyDict
+
+pong_acer_config = dict(
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ #'ALE/Pong-v5' is available. But special setting is needed after gym make.
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ critic_head_hidden_size=512,
+ critic_head_layer_num=2,
+ actor_head_hidden_size=512,
+ actor_head_layer_num=2,
+ ),
+ unroll_len=64,
+ learn=dict(
+ # (int) collect n_sample data, train model update_per_collect times
+ # here we follow impala serial pipeline
+ update_per_collect=10,
+ # (int) the number of data for a train iteration
+ batch_size=64,
+ # grad_clip_type='clip_norm',
+ # clip_value=10,
+ learning_rate_actor=0.0001,
+ learning_rate_critic=0.0003,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.01,
+            # (float) discount factor for future reward, in the range [0, 1]
+            discount_factor=0.9,
+            # (bool) whether to use the trust region constraint in the actor update
+ trust_region=True,
+ # (float) clip ratio of importance weights
+ c_clip_ratio=10,
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ n_sample=64,
+            # (float) discount factor for future reward, in the range [0, 1]
+ discount_factor=0.9,
+ collector=dict(collect_print_freq=1000, ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=5000, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=3000, ), ),
+ ),
+)
+main_config = EasyDict(pong_acer_config)
+
+pong_acer_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='acer'),
+)
+create_config = EasyDict(pong_acer_create_config)
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c pong_acer_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_c51_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_c51_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d9f77fa7b69bf8982c59b2aaa5b653bcc394917
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_c51_config.py
@@ -0,0 +1,62 @@
+from easydict import EasyDict
+
+pong_c51_config = dict(
+ exp_name='pong_c51_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ #'ALE/Pong-v5' is available. But special setting is needed after gym make.
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ v_min=-10,
+ v_max=10,
+ n_atom=51,
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=250000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+)
+pong_c51_config = EasyDict(pong_c51_config)
+main_config = pong_c51_config
+pong_c51_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='c51'),
+)
+pong_c51_create_config = EasyDict(pong_c51_create_config)
+create_config = pong_c51_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c pong_c51_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_cql_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_cql_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..7290863716896a24cb7a8616bf0d9c511e0fe350
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_cql_config.py
@@ -0,0 +1,68 @@
+from easydict import EasyDict
+
+pong_cql_config = dict(
+ exp_name='pong_cql_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ #'ALE/Pong-v5' is available. But special setting is needed after gym make.
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ num_quantiles=200,
+ ),
+ nstep=1,
+ discount_factor=0.99,
+ learn=dict(
+ train_epoch=30000,
+ batch_size=32,
+ learning_rate=0.00005,
+ target_update_freq=2000,
+ min_q_weight=10.0,
+ ),
+ collect=dict(
+ n_sample=100,
+ data_type='hdf5',
+ # Users should add their own data path here. Data path should lead to a file to store data or load the stored data.
+ # Absolute path is recommended.
+ # In DI-engine, it is usually located in ``exp_name`` directory
+ data_path='./default_experiment/expert.pkl',
+ ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=250000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+)
+pong_cql_config = EasyDict(pong_cql_config)
+main_config = pong_cql_config
+pong_cql_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='cql_discrete'),
+)
+pong_cql_create_config = EasyDict(pong_cql_create_config)
+create_config = pong_cql_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial_offline -c pong_cql_config.py -s 0`
+ from ding.entry import serial_pipeline_offline
+ serial_pipeline_offline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_dqfd_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_dqfd_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..dde918bd26386faa925808c695f0a058dc6f152a
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_dqfd_config.py
@@ -0,0 +1,80 @@
+from easydict import EasyDict
+
+pong_dqfd_config = dict(
+ exp_name='pong_dqfd_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ #'ALE/Pong-v5' is available. But special setting is needed after gym make.
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ lambda1=1.0,
+ lambda2=1.0,
+ lambda3=1e-5,
+ per_train_iter_k=10,
+ expert_replay_buffer_size=10000,
+            # the size of the expert (demonstration) replay buffer
+ ),
+ collect=dict(
+ n_sample=64,
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ model_path='model_path_placeholder',
+ # Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ ),
+ # Users should add their own path here (path should lead to a well-trained model)
+ # Absolute path is recommended
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=250000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+)
+pong_dqfd_config = EasyDict(pong_dqfd_config)
+main_config = pong_dqfd_config
+pong_dqfd_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqfd'),
+)
+pong_dqfd_create_config = EasyDict(pong_dqfd_create_config)
+create_config = pong_dqfd_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial_dqfd -c pong_dqfd_config.py -s 0`
+    # then input ``pong_dqfd_config.py`` when prompted.
+    # The reason we need to input the dqfd config is that we have to borrow its ``_get_train_sample`` function
+    # in the collector part, even though the expert model may be generated by other Q-learning algorithms.
+ from ding.entry.serial_entry_dqfd import serial_pipeline_dqfd
+ from dizoo.atari.config.serial.pong import pong_dqfd_config, pong_dqfd_create_config
+ expert_main_config = pong_dqfd_config
+ expert_create_config = pong_dqfd_create_config
+ serial_pipeline_dqfd((main_config, create_config), (expert_main_config, expert_create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_dqn_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e5f4040f1033da6c76e51740a39c6af2831cc05
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_dqn_config.py
@@ -0,0 +1,59 @@
+from easydict import EasyDict
+
+pong_dqn_config = dict(
+ exp_name='pong_dqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ #'ALE/Pong-v5' is available. But special setting is needed after gym make.
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=96, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=250000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+)
+pong_dqn_config = EasyDict(pong_dqn_config)
+main_config = pong_dqn_config
+pong_dqn_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqn'),
+)
+pong_dqn_create_config = EasyDict(pong_dqn_create_config)
+create_config = pong_dqn_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c pong_dqn_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_dqn_envpool_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_dqn_envpool_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b80e41548bbd9977724560f376aa448c5d91405
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_dqn_envpool_config.py
@@ -0,0 +1,62 @@
+from easydict import EasyDict
+
+pong_dqn_envpool_config = dict(
+ exp_name='pong_dqn_envpool_seed0',
+ env=dict(
+ collector_env_num=8,
+ collector_batch_size=8,
+ evaluator_env_num=8,
+ evaluator_batch_size=8,
+ n_evaluator_episode=8,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ #'ALE/Pong-v5' is available. But special setting is needed after gym make.
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=96, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=250000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+)
+pong_dqn_envpool_config = EasyDict(pong_dqn_envpool_config)
+main_config = pong_dqn_envpool_config
+pong_dqn_envpool_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='env_pool'),
+ policy=dict(type='dqn'),
+ replay_buffer=dict(type='deque'),
+)
+pong_dqn_envpool_create_config = EasyDict(pong_dqn_envpool_create_config)
+create_config = pong_dqn_envpool_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c pong_dqn_envpool_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_dqn_multi_gpu_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_dqn_multi_gpu_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..91b9de260a3d4a2813ffc652b8b2002869f8e032
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_dqn_multi_gpu_config.py
@@ -0,0 +1,61 @@
+from easydict import EasyDict
+
+pong_dqn_config = dict(
+ exp_name='pong_dqn_multi_gpu_seed0',
+ env=dict(
+ collector_env_num=4,
+ evaluator_env_num=4,
+ n_evaluator_episode=8,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ #'ALE/Pong-v5' is available. But special setting is needed after gym make.
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ multi_gpu=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=96, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=250000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+)
+pong_dqn_config = EasyDict(pong_dqn_config)
+main_config = pong_dqn_config
+pong_dqn_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqn'),
+)
+pong_dqn_create_config = EasyDict(pong_dqn_create_config)
+create_config = pong_dqn_create_config
+
+if __name__ == '__main__':
+ from ding.utils import DistContext
+ from ding.entry import serial_pipeline
+ with DistContext():
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_dqn_render_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_dqn_render_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b12ed17331b7c60863132bf829a7ee179631afc1
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_dqn_render_config.py
@@ -0,0 +1,65 @@
+from easydict import EasyDict
+
+pong_dqn_config = dict(
+ exp_name='pong_dqn_render_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ #'ALE/Pong-v5' is available. But special setting is needed after gym make.
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=96, ),
+ eval=dict(evaluator=dict(
+ eval_freq=4000,
+ render=dict(
+ render_freq=200000,
+ mode='train_iter',
+ ),
+ ), ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=250000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+)
+pong_dqn_config = EasyDict(pong_dqn_config)
+main_config = pong_dqn_config
+pong_dqn_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqn'),
+)
+pong_dqn_create_config = EasyDict(pong_dqn_create_config)
+create_config = pong_dqn_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c pong_dqn_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_dqn_stdim_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_dqn_stdim_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d568ab3bb762e6d7bc2ec57648f8c27b667afbe
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_dqn_stdim_config.py
@@ -0,0 +1,67 @@
+from easydict import EasyDict
+
+pong_dqn_stdim_config = dict(
+ exp_name='pong_dqn_stdim_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ #'ALE/Pong-v5' is available. But special setting is needed after gym make.
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ aux_model=dict(
+ encode_shape=64,
+ heads=[1, 1],
+ loss_type='infonce',
+ temperature=1.0,
+ ),
+ # the weight of the auxiliary loss to the TD loss
+ aux_loss_weight=0.003,
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=128,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=96, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=250000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+)
+pong_dqn_stdim_config = EasyDict(pong_dqn_stdim_config)
+main_config = pong_dqn_stdim_config
+pong_dqn_stdim_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqn_stdim'),
+)
+pong_dqn_stdim_create_config = EasyDict(pong_dqn_stdim_create_config)
+create_config = pong_dqn_stdim_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c pong_dqn_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_dt_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_dt_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..60d795ec29db09300d04fcf6cf180e131c08072f
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_dt_config.py
@@ -0,0 +1,68 @@
+from easydict import EasyDict
+from copy import deepcopy
+
+Pong_dt_config = dict(
+ exp_name='dt_log/atari/Pong/Pong_dt_seed0',
+ env=dict(
+ env_id='PongNoFrameskip-v4',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=20,
+ frame_stack=4,
+ is_train=False,
+ episode_num=10000, # stop in breakout
+ ),
+ dataset=dict(
+ env_type='atari',
+ num_steps=500000,
+ # num_steps=50,
+ num_buffers=50,
+ rtg_scale=None,
+ context_len=30,
+ data_dir_prefix='/mnt/nfs/luyd/d4rl_atari/Pong',
+ trajectories_per_buffer=10,
+ ),
+ policy=dict(
+ cuda=True,
+ multi_gpu=True,
+ stop_value=20,
+ evaluator_env_num=8,
+ rtg_target=20, # max target return to go
+        max_eval_ep_len=10000,  # max length of one episode
+ wt_decay=1e-4,
+ clip_grad_norm_p=1.0,
+ weight_decay=0.1,
+ warmup_steps=10000,
+ model=dict(
+ state_dim=(4, 84, 84),
+ act_dim=6,
+ n_blocks=6,
+ h_dim=128,
+ context_len=30,
+ n_heads=8,
+ drop_p=0.1,
+ continuous=False,
+ ),
+ batch_size=128,
+ learning_rate=6e-4,
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ collect=dict(
+ data_type='d4rl_trajectory',
+ unroll_len=1,
+ ),
+ ),
+)
+
+Pong_dt_config = EasyDict(Pong_dt_config)
+main_config = Pong_dt_config
+Pong_dt_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dt'),
+)
+Pong_dt_create_config = EasyDict(Pong_dt_create_config)
+create_config = Pong_dt_create_config
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_fqf_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_fqf_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..25a788aa0b3c2e88c972e75580f5cc967561bf94
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_fqf_config.py
@@ -0,0 +1,62 @@
+from easydict import EasyDict
+
+pong_fqf_config = dict(
+ exp_name='pong_fqf_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='PongNoFrameskip-v4',
+ # 'ALE/Pong-v5' is also available, but it needs extra settings after gym.make().
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ num_quantiles=32,
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ kappa=1.0,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate_fraction=2.5e-9,
+ learning_rate_quantile=0.00005,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=250000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+)
+pong_fqf_config = EasyDict(pong_fqf_config)
+main_config = pong_fqf_config
+pong_fqf_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='fqf'),
+)
+pong_fqf_create_config = EasyDict(pong_fqf_create_config)
+create_config = pong_fqf_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c pong_fqf_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
\ No newline at end of file
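Most configs in this directory share the exploration schedule eps=dict(type='exp', start=1., end=0.05, decay=250000). The standalone helper below is only an intuition-level sketch of such an exponential schedule; the exact formula used at runtime lives in DI-engine's epsilon-greedy utilities:

    import math

    def exp_epsilon(step: int, start: float = 1.0, end: float = 0.05, decay: float = 250000.0) -> float:
        # decays from `start` toward `end` with time constant `decay` (in environment steps)
        return end + (start - end) * math.exp(-step / decay)

    print(exp_epsilon(0), exp_epsilon(250_000), exp_epsilon(1_000_000))  # ~1.0, ~0.40, ~0.067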
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_gail_dqn_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_gail_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..505b75b62623928c066f782a6e6a67360c062954
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_gail_dqn_config.py
@@ -0,0 +1,90 @@
+from easydict import EasyDict
+
+pong_dqn_gail_config = dict(
+ exp_name='pong_gail_dqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ # 'ALE/Pong-v5' is also available, but it needs extra settings after gym.make().
+ frame_stack=4,
+ ),
+ reward_model=dict(
+ type='gail',
+ input_size=[4, 84, 84],
+ hidden_size=128,
+ batch_size=64,
+ learning_rate=1e-3,
+ update_per_collect=100,
+ collect_count=1000,
+ action_size=6,
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ expert_model_path='model_path_placeholder',
+ # Path where to store the reward model
+ reward_model_path='data_path_placeholder+/reward_model/ckpt/ckpt_best.pth.tar',
+ # Users should add their own data path here. Data path should lead to a file to store data or load the stored data.
+ # Absolute path is recommended.
+ # In DI-engine, it is usually located in ``exp_name`` directory
+ # e.g. 'exp_name/expert_data.pkl'
+ data_path='data_path_placeholder',
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=1,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=96, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=250000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+)
+pong_dqn_gail_config = EasyDict(pong_dqn_gail_config)
+main_config = pong_dqn_gail_config
+pong_dqn_gail_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqn'),
+)
+pong_dqn_gail_create_config = EasyDict(pong_dqn_gail_create_config)
+create_config = pong_dqn_gail_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial_gail -c pong_gail_dqn_config.py -s 0`
+ # then input the config you used to generate your expert model in the path mentioned above
+ # e.g. pong_dqn_config.py
+ from ding.entry import serial_pipeline_gail
+ from dizoo.atari.config.serial.pong import pong_dqn_config, pong_dqn_create_config
+ expert_main_config = pong_dqn_config
+ expert_create_config = pong_dqn_create_config
+ serial_pipeline_gail(
+ (main_config, create_config), (expert_main_config, expert_create_config),
+ max_env_step=1000000,
+ seed=0,
+ collect_data=True
+ )
\ No newline at end of file
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_impala_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_impala_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8a4136bf918c94a2c9413e5eb9ac6b5cd38e78c
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_impala_config.py
@@ -0,0 +1,78 @@
+from easydict import EasyDict
+
+pong_impala_config = dict(
+ exp_name='impala_log/pong_impala_seed0',
+ env=dict(
+ collector_env_num=12,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=21,
+ env_id='PongNoFrameskip-v4',
+ # 'ALE/Pong-v5' is also available, but it needs extra settings after gym.make().
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ # (int) the trajectory length to calculate v-trace target
+ unroll_len=64,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[64, 128, 256],
+ critic_head_hidden_size=256,
+ critic_head_layer_num=2,
+ actor_head_hidden_size=256,
+ actor_head_layer_num=2,
+ # impala_cnn_encoder=True,
+ ),
+ learn=dict(
+ # (int) collect n_sample data, train model update_per_collect times
+ # here we follow impala serial pipeline
+ update_per_collect=2,
+ # (int) the number of data for a train iteration
+ batch_size=128,
+ # optim_type='rmsprop',
+ grad_clip_type='clip_norm',
+ clip_value=0.5,
+ learning_rate=0.0006,
+ # (float) loss weight of the value network, the weight of policy network is set to 1
+ value_weight=0.5,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.01,
+ # (float) discount factor for future reward, defaults in [0, 1]
+ discount_factor=0.99,
+ # (float) additional discounting parameter
+ lambda_=0.95,
+ # (float) clip ratio of importance weights
+ rho_clip_ratio=1.0,
+ # (float) clip ratio of importance weights
+ c_clip_ratio=1.0,
+ # (float) clip ratio of importance sampling
+ rho_pg_clip_ratio=1.0,
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ n_sample=16,
+ collector=dict(collect_print_freq=1000, ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=2000, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=10000, sliced=False), ),
+ ),
+)
+main_config = EasyDict(pong_impala_config)
+
+pong_impala_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='impala'),
+)
+create_config = EasyDict(pong_impala_create_config)
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c pong_impala_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
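The rho_clip_ratio and c_clip_ratio keys above are the V-trace truncation levels from the IMPALA paper. A minimal sketch of how such clipped importance weights are formed from target and behaviour log-probabilities (illustrative only; DI-engine has its own batched v-trace implementation):

    import torch

    def clipped_is_weights(target_logp: torch.Tensor, behaviour_logp: torch.Tensor,
                           rho_clip_ratio: float = 1.0, c_clip_ratio: float = 1.0):
        ratio = torch.exp(target_logp - behaviour_logp)  # pi(a|s) / mu(a|s)
        rho = torch.clamp(ratio, max=rho_clip_ratio)     # used in the TD correction term
        c = torch.clamp(ratio, max=c_clip_ratio)         # used when propagating the trace
        return rho, c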
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_iqn_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_iqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5eed786a19c2609210d5c8a83cdcf7ba9e61e722
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_iqn_config.py
@@ -0,0 +1,62 @@
+from easydict import EasyDict
+
+pong_iqn_config = dict(
+ exp_name='pong_iqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ # 'ALE/Pong-v5' is also available, but it needs extra settings after gym.make().
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ num_quantiles=32,
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ kappa=1.0,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=250000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+)
+pong_iqn_config = EasyDict(pong_iqn_config)
+main_config = pong_iqn_config
+
+pong_iqn_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='iqn'),
+)
+pong_iqn_create_config = EasyDict(pong_iqn_create_config)
+create_config = pong_iqn_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c pong_iqn_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_ngu_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_ngu_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..215913d20d460477ac8253e2a221539b99c8ef43
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_ngu_config.py
@@ -0,0 +1,130 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 8
+nstep = 5
+max_env_step = int(10e6)
+
+pong_ngu_config = dict(
+ exp_name='pong_ngu_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ env_id='PongNoFrameskip-v4',
+ # 'ALE/Pong-v5' is also available, but it needs extra settings after gym.make().
+ obs_plus_prev_action_reward=True, # use specific env wrapper for ngu policy
+ stop_value=20,
+ frame_stack=4,
+ ),
+ rnd_reward_model=dict(
+ intrinsic_reward_type='add',
+ learning_rate=1e-4,
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ batch_size=320,
+ update_per_collect=10,
+ only_use_last_five_frames_for_icm_rnd=False,
+ clear_buffer_per_iters=10,
+ nstep=nstep,
+ hidden_size_list=[128, 128, 64],
+ type='rnd-ngu',
+ ),
+ episodic_reward_model=dict(
+ # whether to apply the rescale trick to the last non-zero reward
+ # when combining extrinsic and intrinsic rewards.
+ # the rescale trick is only used in:
+ # 1. sparse-reward envs such as minigrid, in which the last non-zero reward is a strong positive signal
+ # 2. envs where the last reward of each episode directly reflects the agent's completion of the task, e.g. lunarlander
+ # Note that the ngu intrinsic reward is a positive value (max value is 5); in these envs,
+ # the last non-zero reward should not be overwhelmed by intrinsic rewards, so we need to rescale the
+ # original last non-zero extrinsic reward.
+ # please refer to ngu_reward_model for details.
+ last_nonzero_reward_rescale=False,
+ # the rescale weight for the last non-zero reward; only used when last_nonzero_reward_rescale is True
+ # please refer to ngu_reward_model for details.
+ last_nonzero_reward_weight=1,
+ intrinsic_reward_type='add',
+ learning_rate=1e-4,
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ batch_size=320,
+ update_per_collect=10,
+ only_use_last_five_frames_for_icm_rnd=False,
+ clear_buffer_per_iters=10,
+ nstep=nstep,
+ hidden_size_list=[128, 128, 64],
+ type='episodic',
+ ),
+ policy=dict(
+ cuda=True,
+ on_policy=False,
+ priority=True,
+ priority_IS_weight=True,
+ discount_factor=0.997,
+ nstep=nstep,
+ burnin_step=20,
+ # (int) <learn_unroll_len> is the total length of [sequence sample] minus
+ # the length of burnin part in [sequence sample],
+ # i.e., <sequence sample length> = <unroll_len> = <burnin_step> + <learn_unroll_len>
+ learn_unroll_len=40, # set this key according to the episode length
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ collector_env_num=collector_env_num,
+ ),
+ learn=dict(
+ update_per_collect=8,
+ batch_size=64,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ ),
+ collect=dict(
+ # NOTE: It is important to set the key traj_len_inf=True here,
+ # to make sure self._traj_len=INF in serial_sample_collector.py.
+ # In sequence-based policy, for each collect_env,
+ # we want to collect data of length self._traj_len=INF
+ # unless the episode enters the 'done' state.
+ # In each collect phase, we collect a total of <n_sample> sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=1e5,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=int(2e4),
+ # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
+ alpha=0.6,
+ # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+ beta=0.4,
+ )
+ ),
+ ),
+)
+pong_ngu_config = EasyDict(pong_ngu_config)
+main_config = pong_ngu_config
+pong_ngu_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ngu'),
+ rnd_reward_model=dict(type='rnd-ngu'),
+ episodic_reward_model=dict(type='episodic'),
+)
+pong_ngu_create_config = EasyDict(pong_ngu_create_config)
+create_config = pong_ngu_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_ngu -c pong_ngu_config.py -s 0`
+ from ding.entry import serial_pipeline_ngu
+ serial_pipeline_ngu([main_config, create_config], seed=0, max_env_step=max_env_step)
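Both reward models above use intrinsic_reward_type='add', i.e. the learned bonus is added to the environment reward. The toy function below sketches the NGU-style combination (episodic bonus modulated by the clipped lifelong RND signal, then mixed with a weight beta); the beta value is a placeholder and the actual mixing happens inside DI-engine's reward models:

    def mix_reward(extrinsic: float, episodic_bonus: float, rnd_modulator: float, beta: float = 0.3) -> float:
        # NGU: intrinsic = episodic novelty * clip(RND lifelong modulator, 1, 5)
        intrinsic = episodic_bonus * min(max(rnd_modulator, 1.0), 5.0)
        return extrinsic + beta * intrinsic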
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_onppo_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_onppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b936d4e01f2ddd72a676da11db028b6207433d6c
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_onppo_config.py
@@ -0,0 +1,67 @@
+from easydict import EasyDict
+
+pong_onppo_config = dict(
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ # 'ALE/Pong-v5' is also available, but it needs extra settings after gym.make().
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ recompute_adv=True,
+ action_space='discrete',
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ action_space='discrete',
+ encoder_hidden_size_list=[64, 64, 128],
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ ),
+ learn=dict(
+ epoch_per_collect=10,
+ update_per_collect=1,
+ batch_size=320,
+ learning_rate=3e-4,
+ value_weight=0.5,
+ entropy_weight=0.001,
+ clip_ratio=0.2,
+ adv_norm=True,
+ value_norm=True,
+ # for on-policy ppo, when we recompute adv we need the key 'done' in data to split trajectories, so we must
+ # use ignore_done=False here,
+ # but when we add the key 'traj_flag' in data as a backup for 'done', we could choose to use ignore_done=True
+ # (e.g. for halfcheetah, where the episode length is fixed at 1000)
+ ignore_done=False,
+ grad_clip_type='clip_norm',
+ grad_clip_value=0.5,
+ ),
+ collect=dict(
+ n_sample=3200,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ eval=dict(evaluator=dict(eval_freq=5000, )),
+ ),
+)
+main_config = EasyDict(pong_onppo_config)
+
+pong_onppo_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+)
+create_config = EasyDict(pong_onppo_create_config)
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c pong_onppo_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
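The collect block sets discount_factor=0.99 and gae_lambda=0.95 for advantage estimation. For reference, a compact single-trajectory version of GAE consistent with those two keys (a sketch, not DI-engine's batched gae helper):

    from typing import List

    def gae(rewards: List[float], values: List[float], last_value: float,
            gamma: float = 0.99, lam: float = 0.95) -> List[float]:
        # `values` holds V(s_t) for each step; `last_value` bootstraps the state after the final step
        advs = [0.0] * len(rewards)
        running, next_v = 0.0, last_value
        for t in reversed(range(len(rewards))):
            delta = rewards[t] + gamma * next_v - values[t]
            running = delta + gamma * lam * running
            advs[t] = running
            next_v = values[t]
        return advs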
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_ppg_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_ppg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..13bc7ed448e9c1930e52c77ae0a838829f113157
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_ppg_config.py
@@ -0,0 +1,80 @@
+from easydict import EasyDict
+
+pong_ppg_config = dict(
+ exp_name='pong_ppg_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ # 'ALE/Pong-v5' is also available, but it needs extra settings after gym.make().
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[64, 64, 128],
+ critic_head_hidden_size=128,
+ actor_head_hidden_size=128,
+ ),
+ learn=dict(
+ update_per_collect=24,
+ batch_size=128,
+ # (bool) Whether to normalize the advantage. Defaults to False.
+ adv_norm=False,
+ learning_rate=0.0001,
+ # (float) loss weight of the value network, the weight of policy network is set to 1
+ value_weight=0.5,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.03,
+ clip_ratio=0.1,
+ epochs_aux=6,
+ beta_weight=1,
+ aux_freq=100
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ n_sample=1024,
+ # (float) the trade-off factor lambda to balance 1step td and mc
+ gae_lambda=0.95,
+ discount_factor=0.99,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, )),
+ other=dict(
+ replay_buffer=dict(
+ multi_buffer=True,
+ policy=dict(
+ replay_buffer_size=100000,
+ max_use=3,
+ ),
+ value=dict(
+ replay_buffer_size=100000,
+ max_use=5,
+ ),
+ ),
+ ),
+ ),
+)
+main_config = EasyDict(pong_ppg_config)
+
+pong_ppg_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppg_offpolicy'),
+)
+create_config = EasyDict(pong_ppg_create_config)
+
+if __name__ == "__main__":
+ import os
+ import warnings
+ from dizoo.atari.entry.atari_ppg_main import main
+ from dizoo.atari.entry.atari_ppg_main import __file__ as _origin_py_file
+ origin_py_file_rel = os.path.relpath(_origin_py_file, os.path.abspath(os.path.curdir))
+ warnings.warn(UserWarning(f"This config file can be executed by {repr(origin_py_file_rel)}"))
+ main(main_config)
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_qrdqn_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_qrdqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..df0ad1ab9db65af1dbac7350d164f9240eda4c92
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_qrdqn_config.py
@@ -0,0 +1,60 @@
+from easydict import EasyDict
+
+pong_qrdqn_config = dict(
+ exp_name='pong_qrdqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ # 'ALE/Pong-v5' is also available, but it needs extra settings after gym.make().
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ num_quantiles=64,
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=250000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+)
+pong_qrdqn_config = EasyDict(pong_qrdqn_config)
+main_config = pong_qrdqn_config
+pong_qrdqn_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='qrdqn'),
+)
+pong_qrdqn_create_config = EasyDict(pong_qrdqn_create_config)
+create_config = pong_qrdqn_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c pong_qrdqn_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_qrdqn_generation_data_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_qrdqn_generation_data_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..32a0346a37d094b351daee38a9ae516898a42b54
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_qrdqn_generation_data_config.py
@@ -0,0 +1,68 @@
+from easydict import EasyDict
+
+pong_qrdqn_config = dict(
+ exp_name='pong_qrdqn_generation_data_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ # 'ALE/Pong-v5' is also available, but it needs extra settings after gym.make().
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ num_quantiles=64,
+ ),
+ nstep=1,
+ discount_factor=0.99,
+ collect=dict(
+ collect_count=1000,
+ data_type='hdf5',
+ # pretrained RL model path; users can modify it to their own path
+ model_path='./pong_qrdqn_seed0/ckpt/ckpt_best.pth.tar',
+ # this prefix should be the same as exp_name
+ expert_data_path='./pong_qrdqn_generation_data_seed0/expert.pkl',
+ ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=250000,
+ collect=0.2,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+)
+pong_qrdqn_config = EasyDict(pong_qrdqn_config)
+main_config = pong_qrdqn_config
+pong_qrdqn_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='qrdqn'),
+)
+pong_qrdqn_create_config = EasyDict(pong_qrdqn_create_config)
+create_config = pong_qrdqn_create_config
+
+if __name__ == "__main__":
+ from ding.entry import collect_demo_data
+ cfg = main_config.policy.collect
+ collect_demo_data(
+ (main_config, create_config),
+ seed=0,
+ collect_count=cfg.collect_count,
+ expert_data_path=cfg.expert_data_path,
+ state_dict_path=cfg.model_path
+ )
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_r2d2_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_r2d2_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..6160332317fbe236f28d47409be4c12cee96d386
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_r2d2_config.py
@@ -0,0 +1,87 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 8
+pong_r2d2_config = dict(
+ exp_name='pong_r2d2_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ # 'ALE/Pong-v5' is also available, but it needs extra settings after gym.make().
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ priority_IS_weight=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ discount_factor=0.997,
+ nstep=5,
+ burnin_step=20,
+ # (int) the whole sequence length to unroll the RNN network minus
+ # the timesteps of burnin part,
+ # i.e., <the whole sequence length> = <unroll_len> = <burnin_step> + <learn_unroll_len>
+ learn_unroll_len=80,
+ learn=dict(
+ # according to the R2D2 paper, the actor parameter update interval is 400
+ # environment timesteps; in each collect phase we collect <n_sample> sequence
+ # samples, and the length of each sequence sample is <burnin_step> + <learn_unroll_len>,
+ # e.g. if n_sample=32 and the sequence length is 100, then 32*100/400=8,
+ # so we set update_per_collect=8 in most environments.
+ update_per_collect=8,
+ batch_size=64,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ ),
+ collect=dict(
+ # NOTE: It is important to set the key traj_len_inf=True here,
+ # to make sure self._traj_len=INF in serial_sample_collector.py.
+ # In R2D2 policy, for each collect_env, we want to collect data of length self._traj_len=INF
+ # unless the episode enters the 'done' state.
+ # In each collect phase, we collect a total of <n_sample> sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=1e5,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=10000,
+ # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
+ alpha=0.6,
+ # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+ beta=0.4,
+ )
+ ),
+ ),
+)
+pong_r2d2_config = EasyDict(pong_r2d2_config)
+main_config = pong_r2d2_config
+pong_r2d2_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='r2d2'),
+)
+pong_r2d2_create_config = EasyDict(pong_r2d2_create_config)
+create_config = pong_r2d2_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c pong_r2d2_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
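The update_per_collect=8 value follows from the arithmetic in the comment above: one parameter update per roughly 400 collected environment timesteps. A quick sanity check with the numbers used in this config:

    n_sample = 32                     # sequence samples per collect phase
    sequence_len = 20 + 80            # burnin_step + learn_unroll_len
    update_interval = 400             # R2D2 paper: one update per ~400 env timesteps
    print(n_sample * sequence_len / update_interval)  # -> 8.0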
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_r2d2_gtrxl_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_r2d2_gtrxl_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..860cce4e4cc1ef88fafd4d367542a8f57f3d0f41
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_r2d2_gtrxl_config.py
@@ -0,0 +1,89 @@
+from easydict import EasyDict
+
+collector_env_num = 4
+evaluator_env_num = 4
+pong_r2d2_gtrxl_config = dict(
+ exp_name='pong_r2d2_gtrxl_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=5,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ # 'ALE/Pong-v5' is also available, but it needs extra settings after gym.make().
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ on_policy=False,
+ priority=True,
+ priority_IS_weight=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ hidden_size=1024,
+ encoder_hidden_size_list=[128, 512, 1024],
+ gru_bias=2.,
+ memory_len=0,
+ dropout=0.1,
+ att_head_num=8,
+ att_layer_num=3,
+ att_head_dim=16,
+ ),
+ discount_factor=0.997,
+ burnin_step=0,
+ nstep=5,
+ unroll_len=25,
+ seq_len=20,
+ learn=dict(
+ update_per_collect=8,
+ batch_size=64,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ value_rescale=True,
+ ),
+ collect=dict(
+ # NOTE: It is important to set the key traj_len_inf=True here,
+ # to make sure self._traj_len=INF in serial_sample_collector.py.
+ # In R2D2 policy, for each collect_env, we want to collect data of length self._traj_len=INF
+ # unless the episode enters the 'done' state.
+ # In each collect phase, we collect a total of <n_sample> sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=300, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=1e5,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=10000,
+ # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
+ alpha=0.6,
+ # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+ beta=0.4,
+ )
+ ),
+ ),
+)
+pong_r2d2_gtrxl_config = EasyDict(pong_r2d2_gtrxl_config)
+main_config = pong_r2d2_gtrxl_config
+pong_r2d2_gtrxl_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='r2d2_gtrxl'),
+)
+pong_r2d2_gtrxl_create_config = EasyDict(pong_r2d2_gtrxl_create_config)
+create_config = pong_r2d2_gtrxl_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c pong_r2d2_gtrxl_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_r2d2_residual_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_r2d2_residual_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..db3411102f3f6431230b1cc30aec5820a5e6e4ff
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_r2d2_residual_config.py
@@ -0,0 +1,84 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 8
+pong_r2d2_residual_config = dict(
+ exp_name='pong_r2d2_residual_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ # 'ALE/Pong-v5' is also available, but it needs extra settings after gym.make().
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ priority_IS_weight=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ res_link=True,
+ ),
+ discount_factor=0.997,
+ nstep=5,
+ burnin_step=2,
+ # (int) the whole sequence length to unroll the RNN network minus
+ # the timesteps of burnin part,
+ # i.e., <the whole sequence length> = <unroll_len> = <burnin_step> + <learn_unroll_len>
+ learn_unroll_len=40,
+ learn=dict(
+ update_per_collect=8,
+ batch_size=64,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ ),
+ collect=dict(
+ # NOTE: It is important to set the key traj_len_inf=True here,
+ # to make sure self._traj_len=INF in serial_sample_collector.py.
+ # In sequence-based policy, for each collect_env,
+ # we want to collect data of length self._traj_len=INF
+ # unless the episode enters the 'done' state.
+ # In each collect phase, we collect a total of <n_sample> sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=1e5,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=10000,
+ # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
+ alpha=0.6,
+ # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+ beta=0.4,
+ )
+ ),
+ ),
+)
+pong_r2d2_residual_config = EasyDict(pong_r2d2_residual_config)
+main_config = pong_r2d2_residual_config
+pong_r2d2_residual_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='r2d2'),
+)
+pong_r2d2_residual_create_config = EasyDict(pong_r2d2_residual_create_config)
+create_config = pong_r2d2_residual_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c pong_r2d2_residual_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_r2d3_offppoexpert_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_r2d3_offppoexpert_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..a61221da87f3500cff990b5733904de08efbd5fd
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_r2d3_offppoexpert_config.py
@@ -0,0 +1,167 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 8
+expert_replay_buffer_size = int(5e3)
+"""
+agent config
+"""
+pong_r2d3_config = dict(
+ exp_name='pong_r2d3_offppo-expert_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ # 'ALE/Pong-v5' is also available, but it needs extra settings after gym.make().
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ priority_IS_weight=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ discount_factor=0.997,
+ nstep=5,
+ burnin_step=2,
+ # (int) the whole sequence length to unroll the RNN network minus
+ # the timesteps of burnin part,
+ # i.e., <the whole sequence length> = <unroll_len> = <burnin_step> + <learn_unroll_len>
+ learn_unroll_len=40,
+ learn=dict(
+ value_rescale=True,
+ update_per_collect=8,
+ batch_size=64,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ # DQFD related parameters
+ lambda1=1.0, # n-step return
+ lambda2=1, # 1.0, # supervised loss
+ lambda3=1e-5, # 1e-5, # L2 it's very important to set Adam optimizer optim_type='adamw'.
+ lambda_one_step_td=1, # 1-step return
+ margin_function=0.8, # margin function in JE, here we implement this as a constant
+ per_train_iter_k=0, # TODO(pu)
+ ),
+ collect=dict(
+ # NOTE: It is important to set the key traj_len_inf=True here,
+ # to make sure self._traj_len=INF in serial_sample_collector.py.
+ # In sequence-based policy, for each collect_env,
+ # we want to collect data of length self._traj_len=INF
+ # unless the episode enters the 'done' state.
+ # In each collect phase, we collect a total of <n_sample> sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ env_num=collector_env_num,
+ # The hyperparameter pho, the demo ratio, controls the proportion of data coming
+ # from expert demonstrations versus from the agent's own experience.
+ pho=1 / 4, # TODO(pu)
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=1e5,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=20000,
+ # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
+ alpha=0.6,
+ # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+ beta=0.4,
+ )
+ ),
+ ),
+)
+pong_r2d3_config = EasyDict(pong_r2d3_config)
+main_config = pong_r2d3_config
+pong_r2d3_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='r2d3'),
+)
+pong_r2d3_create_config = EasyDict(pong_r2d3_create_config)
+create_config = pong_r2d3_create_config
+"""
+expert config
+"""
+expert_pong_r2d3_config = dict(
+ exp_name='expert_pong_r2d3_offppo-expert_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ # 'ALE/Pong-v5' is also available, but it needs extra settings after gym.make().
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ priority_IS_weight=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[64, 64, 128], # ppo expert policy
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ ),
+ discount_factor=0.997,
+ burnin_step=20,
+ nstep=5,
+ learn=dict(expert_replay_buffer_size=expert_replay_buffer_size, ),
+ collect=dict(
+ # NOTE: It is important to set the key traj_len_inf=True here,
+ # to make sure self._traj_len=INF in serial_sample_collector.py.
+ # In sequence-based policy, for each collect_env,
+ # we want to collect data of length self._traj_len=INF
+ # unless the episode enters the 'done' state.
+ # In each collect phase, we collect a total of <n_sample> sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ # Users should add their own path here. path should lead to a well-trained model
+ # Absolute path is recommended.
+ model_path='./pong_offppo_seed0/ckpt/ckpt_best.pth.tar',
+ # Cut trajectories into pieces with length "unroll_len",
+ # which should be set to self._sequence_len of r2d2
+ unroll_len=42, # NOTE: should equal self._sequence_len in the r2d2 policy
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(
+ replay_buffer=dict(
+ replay_buffer_size=expert_replay_buffer_size,
+ # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
+ alpha=0.6,
+ # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+ beta=0.4,
+ ),
+ ),
+ ),
+)
+expert_pong_r2d3_config = EasyDict(expert_pong_r2d3_config)
+expert_main_config = expert_pong_r2d3_config
+expert_pong_r2d3_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='offppo_collect_traj'),
+)
+expert_pong_r2d3_create_config = EasyDict(expert_pong_r2d3_create_config)
+expert_create_config = expert_pong_r2d3_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_r2d3
+ serial_pipeline_r2d3((main_config, create_config), (expert_main_config, expert_create_config), seed=0)
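The pho=1/4 key is the demo ratio: roughly a quarter of every training batch is drawn from the expert buffer and the rest from the agent's own replay buffer. A minimal sketch of that split, with made-up function and variable names (the real sampling is done inside the R2D3 training loop):

    def split_batch(batch_size: int, pho: float = 0.25):
        # returns (number of expert samples, number of agent samples)
        n_expert = int(batch_size * pho)
        return n_expert, batch_size - n_expert

    print(split_batch(64))  # -> (16, 48)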
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_r2d3_r2d2expert_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_r2d3_r2d2expert_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..64d0bd06fc2ed4e5dc0ee4910f689d924acbf17e
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_r2d3_r2d2expert_config.py
@@ -0,0 +1,170 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 8
+expert_replay_buffer_size = int(5e3) # TODO(pu)
+"""
+agent config
+"""
+pong_r2d3_config = dict(
+ exp_name='pong_r2d3_r2d2expert_k0_pho1-4_rbs2e4_ds5e3_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ # 'ALE/Pong-v5' is also available, but it needs extra settings after gym.make().
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ priority_IS_weight=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ discount_factor=0.997,
+ burnin_step=2,
+ nstep=5,
+ # (int) the whole sequence length to unroll the RNN network minus
+ # the timesteps of burnin part,
+ # i.e., <the whole sequence length> = <burnin_step> + <learn_unroll_len>
+ unroll_len=40,
+ learn=dict(
+ # according to the r2d3 paper, the actor parameter update interval is 400
+ # environment timesteps; in each collect phase we collect 32 sequence
+ # samples, and the length of each sample sequence is <burnin_step> + <learn_unroll_len>,
+ # which is 100 in our setting, 32*100/400=8, so we set update_per_collect=8
+ # in most environments
+ value_rescale=True,
+ update_per_collect=8,
+ batch_size=64,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ # DQFD related parameters
+ lambda1=1.0, # n-step return
+ lambda2=1.0, # supervised loss
+ lambda3=1e-5, # L2 regularization; it is very important to set optim_type='adamw' for the Adam optimizer.
+ lambda_one_step_td=1.0, # 1-step return
+ margin_function=0.8, # margin function in JE, here we implement this as a constant
+ per_train_iter_k=0, # TODO(pu)
+ ),
+ collect=dict(
+ # NOTE: It is important to set the key traj_len_inf=True here,
+ # to make sure self._traj_len=INF in serial_sample_collector.py.
+ # In R2D2 policy, for each collect_env, we want to collect data of length self._traj_len=INF
+ # unless the episode enters the 'done' state.
+ # In each collect phase, we collect a total of <n_sample> sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ env_num=collector_env_num,
+ # The hyperparameter pho, the demo ratio, controls the proportion of data coming
+ # from expert demonstrations versus from the agent's own experience.
+ pho=1 / 4, # TODO(pu)
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=100000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=int(2e4), # TODO(pu)
+ # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
+ alpha=0.6,
+ # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+ beta=0.4,
+ )
+ ),
+ ),
+)
+pong_r2d3_config = EasyDict(pong_r2d3_config)
+main_config = pong_r2d3_config
+pong_r2d3_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='r2d3'),
+)
+pong_r2d3_create_config = EasyDict(pong_r2d3_create_config)
+create_config = pong_r2d3_create_config
+"""
+expert config
+"""
+expert_pong_r2d3_config = dict(
+ exp_name='expert_pong_r2d3_r2d2expert_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ # 'ALE/Pong-v5' is also available, but it needs extra settings after gym.make().
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ priority_IS_weight=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512], # r2d2 expert policy
+ ),
+ discount_factor=0.997,
+ burnin_step=2,
+ nstep=5,
+ learn=dict(expert_replay_buffer_size=expert_replay_buffer_size, ),
+ collect=dict(
+ # NOTE: It is important to set the key traj_len_inf=True here,
+ # to make sure self._traj_len=INF in serial_sample_collector.py.
+ # In sequence-based policy, for each collect_env,
+ # we want to collect data of length self._traj_len=INF
+ # unless the episode enters the 'done' state.
+ # In each collect phase, we collect a total of <n_sample> sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ # Users should add their own path here. path should lead to a well-trained model
+ # Absolute path is recommended.
+ model_path='./pong_r2d2_seed0/ckpt/ckpt_best.pth.tar',
+ # Cut trajectories into pieces with length "unroll_len",
+ # which should be set to self._sequence_len of r2d2
+ unroll_len=42, # NOTE: should equal self._sequence_len in the r2d2 policy
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(
+ replay_buffer=dict(
+ replay_buffer_size=expert_replay_buffer_size,
+ # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
+ alpha=0.6,
+ # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+ beta=0.4,
+ ),
+ ),
+ ),
+)
+expert_pong_r2d3_config = EasyDict(expert_pong_r2d3_config)
+expert_main_config = expert_pong_r2d3_config
+expert_pong_r2d3_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ # this policy is designed to collect r2d2 expert traj for r2d3
+ policy=dict(type='r2d2_collect_traj'),
+)
+expert_pong_r2d3_create_config = EasyDict(expert_pong_r2d3_create_config)
+expert_create_config = expert_pong_r2d3_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_r2d3
+ serial_pipeline_r2d3([main_config, create_config], [expert_main_config, expert_create_config], seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_rainbow_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_rainbow_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..f7b403d57a006771d38f0be3dd4cf471bd4c19a0
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_rainbow_config.py
@@ -0,0 +1,62 @@
+from easydict import EasyDict
+
+pong_rainbow_config = dict(
+ exp_name='pong_rainbow_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ # 'ALE/Pong-v5' is also available, but it needs extra settings after gym.make().
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ priority_IS_weight=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ v_min=-10,
+ v_max=10,
+ n_atom=51,
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=250000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+)
+main_config = EasyDict(pong_rainbow_config)
+
+pong_rainbow_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='rainbow'),
+)
+create_config = EasyDict(pong_rainbow_create_config)
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c pong_rainbow_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
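Rainbow's distributional (C51) head places its value distribution on n_atom fixed support points between v_min and v_max. With the settings above, the support looks like this (for illustration; the model constructs it internally):

    import torch

    v_min, v_max, n_atom = -10, 10, 51
    support = torch.linspace(v_min, v_max, n_atom)  # 51 atoms, spacing (v_max - v_min) / (n_atom - 1) = 0.4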
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_sqil_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_sqil_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d562795d52e2368ddf08f1e45375f9d054eaba4
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_sqil_config.py
@@ -0,0 +1,65 @@
+from easydict import EasyDict
+
+pong_sqil_config = dict(
+ exp_name='pong_sqil_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ # 'ALE/Pong-v5' is also available, but it needs extra settings after gym.make().
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=3,
+ discount_factor=0.97, # discount_factor: 0.97-0.99
+ learn=dict(update_per_collect=10, batch_size=32, learning_rate=0.0001, target_update_freq=500,
+ alpha=0.1), # alpha: 0.08-0.12
+ collect=dict(
+ n_sample=96,
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ model_path='model_path_placeholder',
+ ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=250000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+)
+pong_sqil_config = EasyDict(pong_sqil_config)
+main_config = pong_sqil_config
+pong_sqil_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='sql'),
+)
+pong_sqil_create_config = EasyDict(pong_sqil_create_config)
+create_config = pong_sqil_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial_sqil -c pong_sqil_config.py -s 0`
+ # then input the config you used to generate your expert model in the path mentioned above
+ # e.g. pong_dqn_config.py
+ from ding.entry import serial_pipeline_sqil
+ from dizoo.atari.config.serial.pong import pong_dqn_config, pong_dqn_create_config
+ expert_main_config = pong_dqn_config
+ expert_create_config = pong_dqn_create_config
+ serial_pipeline_sqil((main_config, create_config), (expert_main_config, expert_create_config), seed=0)
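SQIL reuses the SQL policy but relabels rewards: transitions taken from the expert data get a constant reward of 1 and the agent's own transitions get 0. A toy sketch of that relabeling step (illustrative; serial_pipeline_sqil performs the equivalent bookkeeping when filling its buffers):

    def relabel(transition: dict, is_expert: bool) -> dict:
        # SQIL: constant reward 1 for demonstration data, 0 for online data
        out = dict(transition)
        out['reward'] = 1.0 if is_expert else 0.0
        return out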
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_sql_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_sql_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bc2ab2ee0946e247ee368b9286294a13f2e01a3
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_sql_config.py
@@ -0,0 +1,53 @@
+from easydict import EasyDict
+
+pong_sql_config = dict(
+ exp_name='pong_sql_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ # 'ALE/Pong-v5' is also available, but it needs extra settings after gym.make().
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(update_per_collect=10, batch_size=32, learning_rate=0.0001, target_update_freq=500, alpha=0.12),
+ collect=dict(n_sample=96),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=250000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+)
+pong_sql_config = EasyDict(pong_sql_config)
+main_config = pong_sql_config
+pong_sql_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='sql'),
+)
+pong_sql_create_config = EasyDict(pong_sql_create_config)
+create_config = pong_sql_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c pong_sql_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_trex_offppo_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_trex_offppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..3351931380cdba2677dfea51868bcbe04f2b07c0
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_trex_offppo_config.py
@@ -0,0 +1,102 @@
+from easydict import EasyDict
+
+pong_trex_ppo_config = dict(
+ exp_name='pong_trex_offppo_seed0',
+ env=dict(
+ collector_env_num=4,
+ evaluator_env_num=4,
+ n_evaluator_episode=8,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ # 'ALE/Pong-v5' is also available, but it needs extra settings after gym.make().
+ frame_stack=4,
+ ),
+ reward_model=dict(
+ type='trex',
+ min_snippet_length=50,
+ max_snippet_length=100,
+ checkpoint_min=0,
+ checkpoint_max=100,
+ checkpoint_step=100,
+ learning_rate=1e-5,
+ update_per_collect=1,
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ # However, here in ``expert_model_path``, it is ``exp_name`` of the expert config.
+ expert_model_path='model_path_placeholder',
+ # Path where to store the reward model
+ reward_model_path='data_path_placeholder + /pong.params',
+ # Users should add their own data path here. Data path should lead to a file to store data or load the stored data.
+ # Absolute path is recommended.
+ # In DI-engine, it is usually located in ``exp_name`` directory
+ # See ding/entry/application_entry_trex_collect_data.py to collect the data
+ data_path='data_path_placeholder',
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=2048,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[64, 64, 128],
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ critic_head_layer_num=1,
+ ),
+ learn=dict(
+ update_per_collect=24,
+ batch_size=128,
+ # (bool) Whether to normalize the advantage. Defaults to False.
+ adv_norm=False,
+ learning_rate=0.0002,
+ # (float) loss weight of the value network, the weight of policy network is set to 1
+ value_weight=0.5,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.01,
+ clip_ratio=0.1,
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ n_sample=1024,
+ # (float) the trade-off factor lambda to balance 1step td and mc
+ gae_lambda=0.95,
+ discount_factor=0.99,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, )),
+ other=dict(replay_buffer=dict(
+ replay_buffer_size=100000,
+ max_use=5,
+ ), ),
+ ),
+)
+pong_trex_ppo_config = EasyDict(pong_trex_ppo_config)
+main_config = pong_trex_ppo_config
+
+pong_trex_ppo_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo_offpolicy'),
+)
+pong_trex_ppo_create_config = EasyDict(pong_trex_ppo_create_config)
+create_config = pong_trex_ppo_create_config
+
+if __name__ == "__main__":
+ # Users should first run ``pong_offppo_config.py`` to save models (or checkpoints).
+ # Note: the generated checkpoints should include iteration_'checkpoint_min'.pth.tar and iteration_'checkpoint_max'.pth.tar, with interval checkpoint_step,
+ # where checkpoint_max, checkpoint_min, checkpoint_step are specified above.
+ import argparse
+ import torch
+ from ding.entry import trex_collecting_data
+ from ding.entry import serial_pipeline_preference_based_irl
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
+ parser.add_argument('--seed', type=int, default=0)
+ parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
+ args = parser.parse_args()
+ # The function ``trex_collecting_data`` below is to collect episodic data for training the reward model in trex.
+ trex_collecting_data(args)
+ serial_pipeline_preference_based_irl((main_config, create_config))
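With checkpoint_min=0, checkpoint_max=100 and checkpoint_step=100, the TREX reward model expects exactly two expert checkpoints. The snippet below lists the filenames that must exist under expert_model_path, following the iteration_'N'.pth.tar naming mentioned in the comment above:

    checkpoint_min, checkpoint_max, checkpoint_step = 0, 100, 100
    expected = [
        'iteration_{}.pth.tar'.format(it)
        for it in range(checkpoint_min, checkpoint_max + 1, checkpoint_step)
    ]
    print(expected)  # ['iteration_0.pth.tar', 'iteration_100.pth.tar']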
diff --git a/DI-engine/dizoo/atari/config/serial/pong/pong_trex_sql_config.py b/DI-engine/dizoo/atari/config/serial/pong/pong_trex_sql_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1d4991294dd974d5eebd81ef93bf57ee37f2b49
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/pong/pong_trex_sql_config.py
@@ -0,0 +1,87 @@
+from easydict import EasyDict
+
+pong_trex_sql_config = dict(
+ exp_name='pong_trex_sql_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=20,
+ env_id='PongNoFrameskip-v4',
+ # 'ALE/Pong-v5' is also available, but it needs extra settings after gym.make().
+ frame_stack=4,
+ ),
+ reward_model=dict(
+ type='trex',
+ min_snippet_length=50,
+ max_snippet_length=100,
+ checkpoint_min=10000,
+ checkpoint_max=50000,
+ checkpoint_step=10000,
+ learning_rate=1e-5,
+ update_per_collect=1,
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ # However, here in ``expert_model_path``, it is ``exp_name`` of the expert config.
+ expert_model_path='model_path_placeholder',
+ # Path where to store the reward model
+ reward_model_path='data_path_placeholder + /pong.params',
+ # Users should add their own data path here. Data path should lead to a file to store data or load the stored data.
+ # Absolute path is recommended.
+ # In DI-engine, it is usually located in ``exp_name`` directory
+ # See ding/entry/application_entry_trex_collect_data.py to collect the data
+ data_path='data_path_placeholder',
+ ),
+ policy=dict(
+ cuda=False,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=1,
+ discount_factor=0.99,
+ learn=dict(update_per_collect=10, batch_size=32, learning_rate=0.0001, target_update_freq=500, alpha=0.12),
+ collect=dict(n_sample=96, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=250000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+)
+pong_trex_sql_config = EasyDict(pong_trex_sql_config)
+main_config = pong_trex_sql_config
+pong_trex_sql_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='sql'),
+)
+pong_trex_sql_create_config = EasyDict(pong_trex_sql_create_config)
+create_config = pong_trex_sql_create_config
+
+if __name__ == '__main__':
+ # Users should first run ``pong_sql_config.py`` to save models (or checkpoints).
+ # Note: the generated checkpoints should include iteration_'checkpoint_min'.pth.tar and iteration_'checkpoint_max'.pth.tar, with interval checkpoint_step,
+ # where checkpoint_max, checkpoint_min, checkpoint_step are specified above.
+ import argparse
+ import torch
+ from ding.entry import trex_collecting_data
+ from ding.entry import serial_pipeline_preference_based_irl
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
+ parser.add_argument('--seed', type=int, default=0)
+ parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
+ args = parser.parse_args()
+ # The function ``trex_collecting_data`` below is to collect episodic data for training the reward model in trex.
+ trex_collecting_data(args)
+ serial_pipeline_preference_based_irl((main_config, create_config))
diff --git a/DI-engine/dizoo/atari/config/serial/qbert/__init__.py b/DI-engine/dizoo/atari/config/serial/qbert/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..5032c3a751bf25a60f09d6687bc7011505777fae
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/qbert/__init__.py
@@ -0,0 +1,2 @@
+from .qbert_dqn_config import qbert_dqn_config, qbert_dqn_create_config
+from .qbert_dqfd_config import qbert_dqfd_config, qbert_dqfd_create_config
\ No newline at end of file
diff --git a/DI-engine/dizoo/atari/config/serial/qbert/qbert_a2c_config.py b/DI-engine/dizoo/atari/config/serial/qbert/qbert_a2c_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4b0cc7088b16812fc200b13a738aa55710e840d
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/qbert/qbert_a2c_config.py
@@ -0,0 +1,62 @@
+from easydict import EasyDict
+
+qbert_a2c_config = dict(
+ exp_name='qbert_a2c_seed0',
+ env=dict(
+ collector_env_num=16,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=1000000,
+ env_id='QbertNoFrameskip-v4',
+ #'ALE/Qbert-v5' is available. But special setting is needed after gym make.
+ frame_stack=4
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[32, 64, 64, 256],
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ critic_head_layer_num=2,
+ ),
+ learn=dict(
+ batch_size=300,
+ # (bool) Whether to normalize advantage. Default to False.
+ adv_norm=False,
+ learning_rate=0.0001414,
+ # (float) loss weight of the value network, the weight of policy network is set to 1
+ value_weight=0.5,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.01,
+ grad_norm=0.5,
+ betas=(0.0, 0.99),
+ ),
+ collect=dict(
+ # (int) collect n_sample data, then train the model once
+ n_sample=160,
+ # (float) the trade-off factor lambda to balance 1-step TD and Monte Carlo returns
+ gae_lambda=0.99,
+ discount_factor=0.99,
+ ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ ),
+)
+main_config = EasyDict(qbert_a2c_config)
+
+qbert_a2c_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='a2c'),
+ replay_buffer=dict(type='naive'),
+)
+create_config = EasyDict(qbert_a2c_create_config)
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial_onpolicy -c qbert_a2c_config.py -s 0
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/qbert/qbert_acer_config.py b/DI-engine/dizoo/atari/config/serial/qbert/qbert_acer_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..46bff881dd38ca37864e710390e5ed1bbaf6f300
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/qbert/qbert_acer_config.py
@@ -0,0 +1,72 @@
+from easydict import EasyDict
+
+qbert_acer_config = dict(
+ exp_name='qbert_acer_seed0',
+ env=dict(
+ collector_env_num=16,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=int(1e6),
+ env_id='QbertNoFrameskip-v4',
+ #'ALE/Qbert-v5' is available. But special setting is needed after gym make.
+ frame_stack=4,
+ manager=dict(shared_memory=False, )
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ critic_head_hidden_size=512,
+ critic_head_layer_num=2,
+ actor_head_hidden_size=512,
+ actor_head_layer_num=2
+ ),
+ unroll_len=64,
+ learn=dict(
+ # (int) collect n_sample data, train model update_per_collect times
+ # here we follow impala serial pipeline
+ update_per_collect=10,
+ # (int) the number of data for a train iteration
+ batch_size=64,
+ # grad_clip_type='clip_norm',
+ learning_rate_actor=0.0001,
+ learning_rate_critic=0.0003,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.01,
+ # (float) discount factor for future reward, in [0, 1]
+ discount_factor=0.99,
+ # (bool) whether to use the trust region constraint
+ trust_region=True,
+ # (float) clip ratio of importance weights
+ c_clip_ratio=10,
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ n_sample=64,
+ # (float) discount factor for future reward, in [0, 1]
+ discount_factor=0.99,
+ collector=dict(collect_print_freq=1000, ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=3000, ), ),
+ ),
+)
+main_config = EasyDict(qbert_acer_config)
+
+qbert_acer_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='acer'),
+)
+create_config = EasyDict(qbert_acer_create_config)
+
+if __name__ == "__main__":
+ # or you can enter ding -m serial -c qbert_acer_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/qbert/qbert_c51_config.py b/DI-engine/dizoo/atari/config/serial/qbert/qbert_c51_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ba0effe2c3e41456ae0e5275c9fd34b759fbc46e
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/qbert/qbert_c51_config.py
@@ -0,0 +1,61 @@
+from easydict import EasyDict
+
+qbert_c51_config = dict(
+ exp_name='qbert_c51_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=30000,
+ env_id='QbertNoFrameskip-v4',
+ #'ALE/Qbert-v5' is available. But special setting is needed after gym make.
+ frame_stack=4
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ v_min=-10,
+ v_max=10,
+ n_atom=51,
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+)
+main_config = EasyDict(qbert_c51_config)
+
+qbert_c51_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='c51'),
+)
+create_config = EasyDict(qbert_c51_create_config)
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c qbert_c51_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/qbert/qbert_cql_config.py b/DI-engine/dizoo/atari/config/serial/qbert/qbert_cql_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..de6aef6e896379689357453e6456e1ecef3d2b5c
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/qbert/qbert_cql_config.py
@@ -0,0 +1,68 @@
+from easydict import EasyDict
+
+qbert_cql_config = dict(
+ exp_name='qbert_cql_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=30000,
+ env_id='QbertNoFrameskip-v4',
+ #'ALE/Qbert-v5' is available. But special setting is needed after gym make.
+ frame_stack=4
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ num_quantiles=200,
+ ),
+ nstep=1,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ train_epoch=30000,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=2000,
+ min_q_weight=10.0,
+ ),
+ collect=dict(
+ n_sample=100,
+ data_type='naive',
+ # Users should add their own data path here. The data path should point to the file used to
+ # store or load the collected data. An absolute path is recommended.
+ # In DI-engine, it is usually located in the ``exp_name`` directory.
+ data_path='data_path_placeholder',
+ ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+)
+main_config = EasyDict(qbert_cql_config)
+
+qbert_cql_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='cql_discrete'),
+)
+create_config = EasyDict(qbert_cql_create_config)
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial_offline -c qbert_cql_config.py -s 0
+ from ding.entry import serial_pipeline_offline
+ serial_pipeline_offline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/qbert/qbert_dqfd_config.py b/DI-engine/dizoo/atari/config/serial/qbert/qbert_dqfd_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..3907aa3f098d239d3f084fd930c302e0c28c7d3e
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/qbert/qbert_dqfd_config.py
@@ -0,0 +1,76 @@
+from easydict import EasyDict
+
+qbert_dqfd_config = dict(
+ exp_name='qbert_dqfd_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=30000,
+ env_id='QbertNoFrameskip-v4',
+ #'ALE/Qbert-v5' is available. But special setting is needed after gym make.
+ frame_stack=4
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ lambda1=1.0,
+ lambda2=1.0,
+ lambda3=1e-5,
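+ # The three lambda weights above follow the DQfD objective (paper reference, not a DI-engine-specific
+ # formula): J(Q) = J_DQ(Q) + lambda1 * J_n(Q) + lambda2 * J_E(Q) + lambda3 * J_L2(Q),
+ # i.e. lambda1 weights the n-step return loss, lambda2 the supervised (large-margin) loss on expert
+ # demonstrations, and lambda3 the L2 regularization term.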
+ per_train_iter_k=10,
+ expert_replay_buffer_size=10000, # specify the buffer size of the expert buffer
+ ),
+ collect=dict(
+ n_sample=100,
+ # Users should add their own model path here. The model path should point to a model.
+ # An absolute path is recommended.
+ # In DI-engine, it is usually ``exp_name/ckpt/ckpt_best.pth.tar``.
+ model_path='model_path_placeholder'
+ ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+)
+qbert_dqfd_config = EasyDict(qbert_dqfd_config)
+main_config = qbert_dqfd_config
+qbert_dqfd_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqfd'),
+)
+qbert_dqfd_create_config = EasyDict(qbert_dqfd_create_config)
+create_config = qbert_dqfd_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial_dqfd -c qbert_dqfd_config.py -s 0`
+ # then input ``qbert_dqfd_config.py`` when prompted.
+ # The reason we need to input the dqfd config is that we have to borrow its ``_get_train_sample`` function
+ # in the collector part, even though the expert model may be generated by other Q-learning algorithms.
+ from ding.entry.serial_entry_dqfd import serial_pipeline_dqfd
+ from dizoo.atari.config.serial.qbert import qbert_dqfd_config, qbert_dqfd_create_config
+ expert_main_config = qbert_dqfd_config
+ expert_create_config = qbert_dqfd_create_config
+ serial_pipeline_dqfd([main_config, create_config], [expert_main_config, expert_create_config], seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/qbert/qbert_dqn_config.py b/DI-engine/dizoo/atari/config/serial/qbert/qbert_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b1cd1906f4063f1732cae52d867232d5b7b1420
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/qbert/qbert_dqn_config.py
@@ -0,0 +1,59 @@
+from easydict import EasyDict
+
+qbert_dqn_config = dict(
+ exp_name='qbert_dqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=30000,
+ env_id='QbertNoFrameskip-v4',
+ #'ALE/Qbert-v5' is available. But special setting is needed after gym make.
+ frame_stack=4
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+)
+qbert_dqn_config = EasyDict(qbert_dqn_config)
+main_config = qbert_dqn_config
+qbert_dqn_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqn'),
+)
+qbert_dqn_create_config = EasyDict(qbert_dqn_create_config)
+create_config = qbert_dqn_create_config
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c qbert_dqn_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/qbert/qbert_fqf_config.py b/DI-engine/dizoo/atari/config/serial/qbert/qbert_fqf_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..3241c924f3ed0f7c08012e6bc00341349c97f12d
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/qbert/qbert_fqf_config.py
@@ -0,0 +1,64 @@
+from easydict import EasyDict
+
+qbert_fqf_config = dict(
+ exp_name='qbert_fqf_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='QbertNoFrameskip-v4',
+ #'ALE/Qbert-v5' is available. But special setting is needed after gym make.
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ num_quantiles=32,
+ quantile_embedding_size=64,
+ ),
+ nstep=1,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate_fraction=2.5e-9,
+ learning_rate_quantile=0.00005,
+ target_update_freq=500,
+ ent_coef=0,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.01,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=1000000, ),
+ ),
+ ),
+)
+qbert_fqf_config = EasyDict(qbert_fqf_config)
+main_config = qbert_fqf_config
+
+qbert_fqf_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='fqf'),
+)
+qbert_fqf_create_config = EasyDict(qbert_fqf_create_config)
+create_config = qbert_fqf_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c qbert_fqf_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
\ No newline at end of file
diff --git a/DI-engine/dizoo/atari/config/serial/qbert/qbert_impala_config.py b/DI-engine/dizoo/atari/config/serial/qbert/qbert_impala_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..74fbdf4c0c9821468407a0295bd17af68b213011
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/qbert/qbert_impala_config.py
@@ -0,0 +1,76 @@
+from easydict import EasyDict
+
+qbert_impala_config = dict(
+ exp_name='qbert_impala_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='QbertNoFrameskip-v4',
+ #'ALE/Qbert-v5' is available. But special setting is needed after gym make.
+ frame_stack=4
+ ),
+ policy=dict(
+ cuda=True,
+ # (int) the trajectory length to calculate v-trace target
+ unroll_len=32,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 256, 512],
+ critic_head_hidden_size=512,
+ critic_head_layer_num=3,
+ actor_head_hidden_size=512,
+ actor_head_layer_num=3,
+ ),
+ learn=dict(
+ # (int) collect n_sample data, train model update_per_collect times
+ # here we follow impala serial pipeline
+ update_per_collect=10, # update_per_collect should be in [1, 10]
+ # (int) the number of data for a train iteration
+ batch_size=128,
+ grad_clip_type='clip_norm',
+ clip_value=5,
+ learning_rate=0.0003,
+ # (float) loss weight of the value network, the weight of policy network is set to 1
+ value_weight=0.5,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.01,
+ # (float) discount factor for future reward, in [0, 1]
+ discount_factor=0.99,
+ # (float) additional discounting parameter
+ lambda_=0.95,
+ # (float) clip ratio of importance weights
+ rho_clip_ratio=1.0,
+ # (float) clip ratio of importance weights
+ c_clip_ratio=1.0,
+ # (float) clip ratio of importance sampling
+ rho_pg_clip_ratio=1.0,
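+ # For reference, the clipping parameters above correspond to the V-trace target from the IMPALA paper
+ # (a sketch of the formula, not the exact DI-engine code):
+ # v_s = V(x_s) + sum_{t>=s} gamma^(t-s) * (prod_{i=s}^{t-1} c_i) * rho_t * (r_t + gamma * V(x_{t+1}) - V(x_t)),
+ # where rho_t = min(rho_clip_ratio, pi/mu) and c_i = lambda_ * min(c_clip_ratio, pi/mu);
+ # rho_pg_clip_ratio similarly clips the importance weight used in the policy-gradient term.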
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ n_sample=16,
+ collector=dict(collect_print_freq=1000, ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=5000, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=10000, ), ),
+ ),
+)
+main_config = EasyDict(qbert_impala_config)
+
+qbert_impala_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='impala'),
+ replay_buffer=dict(type='naive'),
+)
+create_config = EasyDict(qbert_impala_create_config)
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c qbert_impala_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/qbert/qbert_iqn_config.py b/DI-engine/dizoo/atari/config/serial/qbert/qbert_iqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..00f2f3218305830cbeef67c63b396f0ee8177a68
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/qbert/qbert_iqn_config.py
@@ -0,0 +1,60 @@
+from easydict import EasyDict
+
+qbert_iqn_config = dict(
+ exp_name='qbert_iqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=30000,
+ env_id='QbertNoFrameskip-v4',
+ #'ALE/Qbert-v5' is available. But special setting is needed after gym make.
+ frame_stack=4
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ num_quantiles=32,
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ kappa=1.0,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+)
+main_config = EasyDict(qbert_iqn_config)
+
+qbert_iqn_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='iqn'),
+)
+create_config = EasyDict(qbert_iqn_create_config)
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c qbert_iqn_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/qbert/qbert_ngu_config.py b/DI-engine/dizoo/atari/config/serial/qbert/qbert_ngu_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..00bc18f33ff29b55359c095e379ff555677efa6a
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/qbert/qbert_ngu_config.py
@@ -0,0 +1,129 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 8
+nstep = 5
+max_env_step = int(10e6)
+
+qbert_ngu_config = dict(
+ exp_name='qbert_ngu_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ env_id='QbertNoFrameskip-v4',
+ obs_plus_prev_action_reward=True, # use specific env wrapper for ngu policy
+ stop_value=int(1e6),
+ frame_stack=4,
+ ),
+ rnd_reward_model=dict(
+ intrinsic_reward_type='add',
+ learning_rate=1e-4,
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ batch_size=320,
+ update_per_collect=10,
+ only_use_last_five_frames_for_icm_rnd=False,
+ clear_buffer_per_iters=10,
+ nstep=nstep,
+ hidden_size_list=[128, 128, 64],
+ type='rnd-ngu',
+ ),
+ episodic_reward_model=dict(
+ # means whether to apply the rescale trick to the last non-zero reward
+ # when combining extrinsic and intrinsic rewards.
+ # the rescale trick is only used in:
+ # 1. sparse-reward envs such as minigrid, in which the last non-zero reward is a strong positive signal
+ # 2. envs where the last reward of each episode directly reflects whether the agent completed the task, e.g. lunarlander
+ # Note that the ngu intrinsic reward is a positive value (max value is 5); in these envs,
+ # the last non-zero reward should not be overwhelmed by intrinsic rewards, so we need to rescale the
+ # original last non-zero extrinsic reward.
+ # please refer to ngu_reward_model for details.
+ last_nonzero_reward_rescale=False,
+ # means the rescale value for the last non-zero reward, only used when last_nonzero_reward_rescale is True
+ # please refer to ngu_reward_model for details.
+ last_nonzero_reward_weight=1,
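+ # Illustrative example (not the exact ngu_reward_model implementation): in a sparse-reward env where
+ # the final extrinsic reward is +1 while intrinsic rewards can be as large as 5, enabling the rescale
+ # trick would multiply that last non-zero extrinsic reward by ``last_nonzero_reward_weight`` so that
+ # it is not drowned out when extrinsic and intrinsic rewards are added.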
+ intrinsic_reward_type='add',
+ learning_rate=1e-4,
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ batch_size=320,
+ update_per_collect=10,
+ only_use_last_five_frames_for_icm_rnd=False,
+ clear_buffer_per_iters=10,
+ nstep=nstep,
+ hidden_size_list=[128, 128, 64],
+ type='episodic',
+ ),
+ policy=dict(
+ cuda=True,
+ on_policy=False,
+ priority=True,
+ priority_IS_weight=True,
+ discount_factor=0.997,
+ nstep=nstep,
+ burnin_step=20,
+ # (int) ``learn_unroll_len`` is the total length of a [sequence sample] minus
+ # the length of the burnin part in the [sequence sample],
+ # i.e., sequence sample length = ``unroll_len`` = ``burnin_step`` + ``learn_unroll_len``
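+ # With the values in this config (illustrative arithmetic): burnin_step=20 and learn_unroll_len=40,
+ # so each sequence sample covers 20 + 40 = 60 environment steps.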
+ learn_unroll_len=40, # set this key according to the episode length
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ collector_env_num=collector_env_num,
+ ),
+ learn=dict(
+ update_per_collect=8,
+ batch_size=64,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ ),
+ collect=dict(
+ # NOTE: It is important to set the key traj_len_inf=True here,
+ # to make sure self._traj_len=INF in serial_sample_collector.py.
+ # In a sequence-based policy, for each collect_env,
+ # we want to collect data of length self._traj_len=INF
+ # unless the episode enters the 'done' state.
+ # In each collect phase, we collect a total of ``n_sample`` sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=1e5,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=int(2e4),
+ # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
+ alpha=0.6,
+ # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+ beta=0.4,
+ )
+ ),
+ ),
+)
+qbert_ngu_config = EasyDict(qbert_ngu_config)
+main_config = qbert_ngu_config
+qbert_ngu_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ngu'),
+ rnd_reward_model=dict(type='rnd-ngu'),
+ episodic_reward_model=dict(type='episodic'),
+)
+qbert_ngu_create_config = EasyDict(qbert_ngu_create_config)
+create_config = qbert_ngu_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_ngu -c qbert_ngu_config.py -s 0`
+ from ding.entry import serial_pipeline_ngu
+ serial_pipeline_ngu([main_config, create_config], seed=0, max_env_step=max_env_step)
diff --git a/DI-engine/dizoo/atari/config/serial/qbert/qbert_offppo_config.py b/DI-engine/dizoo/atari/config/serial/qbert/qbert_offppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d0de2b6b43f1a7cf2f959dd74c389c3486b2924b
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/qbert/qbert_offppo_config.py
@@ -0,0 +1,65 @@
+from easydict import EasyDict
+
+qbert_offppo_config = dict(
+ exp_name='qbert_offppo_seed0',
+ env=dict(
+ collector_env_num=16,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='QbertNoFrameskip-v4',
+ #'ALE/Qbert-v5' is available. But special setting is needed after gym make.
+ frame_stack=4
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[32, 64, 64, 128],
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ critic_head_layer_num=2,
+ ),
+ learn=dict(
+ update_per_collect=24,
+ batch_size=128,
+ # (bool) Whether to normalize advantage. Default to False.
+ adv_norm=False,
+ learning_rate=0.0001,
+ # (float) loss weight of the value network, the weight of policy network is set to 1
+ value_weight=1.0,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.03,
+ clip_ratio=0.1,
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ n_sample=1024,
+ # (float) the trade-off factor lambda to balance 1-step TD and Monte Carlo returns
+ gae_lambda=0.95,
+ discount_factor=0.99,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, )),
+ other=dict(replay_buffer=dict(
+ replay_buffer_size=100000,
+ max_use=3,
+ ), ),
+ ),
+)
+main_config = EasyDict(qbert_offppo_config)
+
+qbert_offppo_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo_offpolicy'),
+)
+create_config = EasyDict(qbert_offppo_create_config)
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c qbert_offppo_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/qbert/qbert_onppo_config.py b/DI-engine/dizoo/atari/config/serial/qbert/qbert_onppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..46d813ec4c04afc8f6fefa08b2455fe5e2fe670c
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/qbert/qbert_onppo_config.py
@@ -0,0 +1,68 @@
+from easydict import EasyDict
+
+qbert_onppo_config = dict(
+ exp_name='qbert_onppo_seed0',
+ env=dict(
+ collector_env_num=16,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=int(1e10),
+ env_id='QbertNoFrameskip-v4',
+ #'ALE/Qbert-v5' is available. But special setting is needed after gym make.
+ frame_stack=4
+ ),
+ policy=dict(
+ cuda=True,
+ recompute_adv=True,
+ action_space='discrete',
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ action_space='discrete',
+ encoder_hidden_size_list=[64, 64, 128],
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ ),
+ learn=dict(
+ epoch_per_collect=10,
+ update_per_collect=1,
+ batch_size=320,
+ learning_rate=3e-4,
+ value_weight=0.5,
+ entropy_weight=0.001,
+ clip_ratio=0.2,
+ adv_norm=True,
+ value_norm=True,
+ # For on-policy PPO, when we recompute the advantage, we need the key ``done`` in data to split
+ # trajectories, so we must use ignore_done=False here.
+ # But when we add the key ``traj_flag`` in data as a backup for ``done``, we could choose to use
+ # ignore_done=True (e.g. for halfcheetah, whose episode length is fixed at 1000).
+ ignore_done=False,
+ grad_clip_type='clip_norm',
+ grad_clip_value=0.5,
+ ),
+ collect=dict(
+ n_sample=3200,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ eval=dict(evaluator=dict(eval_freq=5000, )),
+ ),
+)
+main_config = EasyDict(qbert_onppo_config)
+
+qbert_onppo_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+)
+create_config = EasyDict(qbert_onppo_create_config)
+
+if __name__ == "__main__":
+ # or you can enter ding -m serial_onpolicy -c qbert_onppo_config.py -s 0
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/qbert/qbert_ppg_config.py b/DI-engine/dizoo/atari/config/serial/qbert/qbert_ppg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..51a953315ae1172dda53cb7f818779f7be1f751a
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/qbert/qbert_ppg_config.py
@@ -0,0 +1,77 @@
+from easydict import EasyDict
+
+qbert_ppg_config = dict(
+ exp_name='qbert_ppg_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=1000000,
+ env_id='QbertNoFrameskip-v4',
+ #'ALE/Qbert-v5' is available. But special setting is needed after gym make.
+ frame_stack=4
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[32, 64, 64, 128],
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ critic_head_layer_num=2,
+ ),
+ learn=dict(
+ update_per_collect=24,
+ batch_size=128,
+ # (bool) Whether to normalize advantage. Default to False.
+ adv_norm=False,
+ learning_rate=0.0001,
+ # (float) loss weight of the value network, the weight of policy network is set to 1
+ value_weight=0.5,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.03,
+ clip_ratio=0.1,
+ epochs_aux=6,
+ beta_weight=1,
+ aux_freq=100
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ n_sample=1024,
+ # (float) the trade-off factor lambda to balance 1-step TD and Monte Carlo returns
+ gae_lambda=0.95,
+ discount_factor=0.99,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, )),
+ other=dict(
+ replay_buffer=dict(
+ multi_buffer=True,
+ policy=dict(
+ replay_buffer_size=100000,
+ max_use=3,
+ ),
+ value=dict(
+ replay_buffer_size=100000,
+ max_use=10,
+ ),
+ ),
+ ),
+ ),
+)
+main_config = EasyDict(qbert_ppg_config)
+
+qbert_ppg_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppg_offpolicy'),
+)
+create_config = EasyDict(qbert_ppg_create_config)
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c qbert_ppg_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/qbert/qbert_qrdqn_config.py b/DI-engine/dizoo/atari/config/serial/qbert/qbert_qrdqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5485e7f016f497632027662de55c1ef7744bce9
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/qbert/qbert_qrdqn_config.py
@@ -0,0 +1,59 @@
+from easydict import EasyDict
+
+qbert_qrdqn_config = dict(
+ exp_name='qbert_qrdqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=30000,
+ env_id='QbertNoFrameskip-v4',
+ #'ALE/Qbert-v5' is available. But special setting is needed after gym make.
+ frame_stack=4
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ num_quantiles=64,
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+)
+main_config = EasyDict(qbert_qrdqn_config)
+
+qbert_qrdqn_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='qrdqn'),
+)
+create_config = EasyDict(qbert_qrdqn_create_config)
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c qbert_qrdqn_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/qbert/qbert_qrdqn_generation_data_config.py b/DI-engine/dizoo/atari/config/serial/qbert/qbert_qrdqn_generation_data_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..38eaaa7ca6cbc3e6c812e059b636a9d11972c354
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/qbert/qbert_qrdqn_generation_data_config.py
@@ -0,0 +1,71 @@
+from easydict import EasyDict
+
+qbert_qrdqn_config = dict(
+ exp_name='qbert_qrdqn_generation_data_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=30000,
+ env_id='QbertNoFrameskip-v4',
+ #'ALE/Qbert-v5' is available. But special setting is needed after gym make.
+ frame_stack=4
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ num_quantiles=64,
+ ),
+ nstep=1,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ learner=dict(
+ load_path='./expert/ckpt/ckpt_best.pth.tar',
+ hook=dict(
+ load_ckpt_before_run='./expert/ckpt/ckpt_best.pth.tar',
+ save_ckpt_after_run=False,
+ )
+ ),
+ ),
+ collect=dict(
+ n_sample=100,
+ data_type='hdf5',
+ save_path='./expert/expert.pkl',
+ ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=1000000,
+ collect=0.2,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+)
+main_config = EasyDict(qbert_qrdqn_config)
+
+qbert_qrdqn_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='qrdqn'),
+)
+create_config = EasyDict(qbert_qrdqn_create_config)
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c qbert_qrdqn_generation_data_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/qbert/qbert_r2d2_config.py b/DI-engine/dizoo/atari/config/serial/qbert/qbert_r2d2_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1c597eb2baf380ccf23092d194e9675689f3bec
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/qbert/qbert_r2d2_config.py
@@ -0,0 +1,90 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 8
+qbert_r2d2_config = dict(
+ exp_name='qbert_r2d2_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ stop_value=int(1e6),
+ env_id='QbertNoFrameskip-v4',
+ #'ALE/Qbert-v5' is available. But special setting is needed after gym make.
+ frame_stack=4
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ priority_IS_weight=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ res_link=False,
+ ),
+ discount_factor=0.997,
+ nstep=5,
+ burnin_step=20,
+ # (int) the whole sequence length to unroll the RNN network minus
+ # the timesteps of the burnin part,
+ # i.e., sequence sample length = ``unroll_len`` = ``burnin_step`` + ``learn_unroll_len``
+ learn_unroll_len=80,
+ learn=dict(
+ # According to the R2D2 paper, the actor parameter update interval is 400
+ # environment timesteps, and in each collect phase we collect ``n_sample`` sequence
+ # samples, where the length of each sequence sample is ``burnin_step`` + ``learn_unroll_len``.
+ # e.g. if n_sample=32 and the sequence length is 100, then 32*100/400=8,
+ # so we set update_per_collect=8 in most environments.
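+ # With the values in this config (illustrative arithmetic): the sequence length is
+ # burnin_step (20) + learn_unroll_len (80) = 100 and n_sample=32, so 32 * 100 / 400 = 8.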
+ update_per_collect=8,
+ batch_size=64,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ ),
+ collect=dict(
+ # NOTE: It is important to set the key traj_len_inf=True here,
+ # to make sure self._traj_len=INF in serial_sample_collector.py.
+ # In a sequence-based policy, for each collect_env,
+ # we want to collect data of length self._traj_len=INF
+ # unless the episode enters the 'done' state.
+ # In each collect phase, we collect a total of ``n_sample`` sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=1e5,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=10000,
+ # (Float type) How much prioritization is used: 0 means no prioritization
+ # while 1 means full prioritization
+ alpha=0.6,
+ # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+ beta=0.4,
+ )
+ ),
+ ),
+)
+qbert_r2d2_config = EasyDict(qbert_r2d2_config)
+main_config = qbert_r2d2_config
+qbert_r2d2_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='r2d2'),
+)
+qbert_r2d2_create_config = EasyDict(qbert_r2d2_create_config)
+create_config = qbert_r2d2_create_config
+
+if __name__ == "__main__":
+ # or you can enter ding -m serial -c qbert_r2d2_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/qbert/qbert_r2d2_gtrxl_config.py b/DI-engine/dizoo/atari/config/serial/qbert/qbert_r2d2_gtrxl_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..afc22c2eb002b864368b114789772027425e726b
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/qbert/qbert_r2d2_gtrxl_config.py
@@ -0,0 +1,83 @@
+from easydict import EasyDict
+
+qbert_r2d2_gtrxl_config = dict(
+ exp_name='qbert_r2d2_gtrxl_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='QbertNoFrameskip-v4',
+ #'ALE/Qbert-v5' is available. But special setting is needed after gym make.
+ frame_stack=4
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ priority_IS_weight=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 256, 1024],
+ hidden_size=1024,
+ gru_bias=1.,
+ memory_len=0,
+ ),
+ discount_factor=0.99,
+ burnin_step=0,
+ nstep=3,
+ unroll_len=13,
+ seq_len=10,
+ learn=dict(
+ update_per_collect=8,
+ batch_size=64,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ ),
+ collect=dict(
+ # NOTE: It is important to set the key traj_len_inf=True here,
+ # to make sure self._traj_len=INF in serial_sample_collector.py.
+ # In a sequence-based policy, for each collect_env,
+ # we want to collect data of length self._traj_len=INF
+ # unless the episode enters the 'done' state.
+ # In each collect phase, we collect a total of ``n_sample`` sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ env_num=8,
+ ),
+ eval=dict(env_num=8, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=1e5,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=10000,
+ # (Float type) How much prioritization is used: 0 means no prioritization
+ # while 1 means full prioritization
+ alpha=0.6,
+ # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+ beta=0.4,
+ )
+ ),
+ ),
+)
+qbert_r2d2_gtrxl_config = EasyDict(qbert_r2d2_gtrxl_config)
+main_config = qbert_r2d2_gtrxl_config
+qbert_r2d2_gtrxl_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='r2d2_gtrxl'),
+)
+qbert_r2d2_gtrxl_create_config = EasyDict(qbert_r2d2_gtrxl_create_config)
+create_config = qbert_r2d2_gtrxl_create_config
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c qbert_r2d2_gtrxl_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/qbert/qbert_rainbow_config.py b/DI-engine/dizoo/atari/config/serial/qbert/qbert_rainbow_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..26f761fd933ec2af72935ba43fab72cac835b047
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/qbert/qbert_rainbow_config.py
@@ -0,0 +1,63 @@
+from easydict import EasyDict
+
+qbert_rainbow_config = dict(
+ exp_name='qbert_rainbow_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=30000,
+ env_id='QbertNoFrameskip-v4',
+ #'ALE/Qbert-v5' is available. But special setting is needed after gym make.
+ frame_stack=4
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ priority_IS_weight=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ v_min=-10,
+ v_max=10,
+ n_atom=51,
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ iqn=False,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.05,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+)
+main_config = EasyDict(qbert_rainbow_config)
+
+qbert_rainbow_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='rainbow'),
+)
+create_config = EasyDict(qbert_rainbow_create_config)
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c qbert_rainbow_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/qbert/qbert_sqil_config.py b/DI-engine/dizoo/atari/config/serial/qbert/qbert_sqil_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b030513c7f24ca27020a25eda9687e1d9a4e7d21
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/qbert/qbert_sqil_config.py
@@ -0,0 +1,71 @@
+from easydict import EasyDict
+
+qbert_sqil_config = dict(
+ exp_name='qbert_sqil_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=30000,
+ env_id='QbertNoFrameskip-v4',
+ #'ALE/Qbert-v5' is available. But special setting is needed after gym make.
+ frame_stack=4
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=3,
+ discount_factor=0.97, # discount_factor: 0.97-0.99
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ alpha=0.1 # alpha: 0.08-0.12
+ ),
+ collect=dict(
+ n_sample=100,
+ # Users should add their own model path here. The model path should point to a model.
+ # An absolute path is recommended.
+ # In DI-engine, it is usually ``exp_name/ckpt/ckpt_best.pth.tar``.
+ model_path='model_path_placeholder'
+ ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+)
+qbert_sqil_config = EasyDict(qbert_sqil_config)
+main_config = qbert_sqil_config
+qbert_sqil_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqn'),
+)
+qbert_sqil_create_config = EasyDict(qbert_sqil_create_config)
+create_config = qbert_sqil_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial_sqil -c qbert_sqil_config.py -s 0`
+ # then input the config used to generate the expert model at the model path mentioned above,
+ # e.g. qbert_dqn_config.py
+ from ding.entry import serial_pipeline_sqil
+ from dizoo.atari.config.serial.qbert import qbert_dqn_config, qbert_dqn_create_config
+ expert_main_config = qbert_dqn_config
+ expert_create_config = qbert_dqn_create_config
+ serial_pipeline_sqil([main_config, create_config], [expert_main_config, expert_create_config], seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/qbert/qbert_sql_config.py b/DI-engine/dizoo/atari/config/serial/qbert/qbert_sql_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..7019318d453b4d7ab880e5a7c54959a14f94947b
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/qbert/qbert_sql_config.py
@@ -0,0 +1,59 @@
+from easydict import EasyDict
+
+qbert_sql_config = dict(
+ exp_name='qbert_sql_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=30000,
+ env_id='QbertNoFrameskip-v4',
+ #'ALE/Qbert-v5' is available. But special setting is needed after gym make.
+ frame_stack=4
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=100),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=500000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+)
+qbert_sql_config = EasyDict(qbert_sql_config)
+main_config = qbert_sql_config
+qbert_sql_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='sql'),
+)
+qbert_sql_create_config = EasyDict(qbert_sql_create_config)
+create_config = qbert_sql_create_config
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c qbert_sql_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/qbert/qbert_trex_dqn_config.py b/DI-engine/dizoo/atari/config/serial/qbert/qbert_trex_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..95d1d9716d5ce4a714db586caabc18d76119b757
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/qbert/qbert_trex_dqn_config.py
@@ -0,0 +1,86 @@
+from easydict import EasyDict
+
+qbert_trex_dqn_config = dict(
+ exp_name='qbert_trex_dqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=30000,
+ env_id='QbertNoFrameskip-v4',
+ #'ALE/Qbert-v5' is available. But special setting is needed after gym make.
+ frame_stack=4
+ ),
+ reward_model=dict(
+ type='trex',
+ min_snippet_length=30,
+ max_snippet_length=100,
+ checkpoint_min=0,
+ checkpoint_max=100,
+ checkpoint_step=100,
+ learning_rate=1e-5,
+ update_per_collect=1,
+ expert_model_path='abs model path',
+ reward_model_path='abs data path + ./qbert.params',
+ offline_data_path='abs data path',
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=1,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+)
+qbert_trex_dqn_config = EasyDict(qbert_trex_dqn_config)
+main_config = qbert_trex_dqn_config
+qbert_trex_dqn_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='dqn'),
+)
+qbert_trex_dqn_create_config = EasyDict(qbert_trex_dqn_create_config)
+create_config = qbert_trex_dqn_create_config
+
+if __name__ == "__main__":
+ # Users should first run ``qbert_dqn_config.py`` to save models (or checkpoints).
+ # Note: Users should check that the generated checkpoints include iteration_'checkpoint_min'.pth.tar
+ # through iteration_'checkpoint_max'.pth.tar at intervals of checkpoint_step,
+ # where checkpoint_max, checkpoint_min and checkpoint_step are specified above.
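+ # For example, with the values above (checkpoint_min=0, checkpoint_max=100, checkpoint_step=100),
+ # the expert experiment is expected to provide checkpoints named like (illustrative listing):
+ # iteration_0.pth.tar and iteration_100.pth.tar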
+ import argparse
+ import torch
+ from ding.entry import trex_collecting_data
+ from ding.entry import serial_pipeline_reward_model_trex
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
+ parser.add_argument('--seed', type=int, default=0)
+ parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
+ args = parser.parse_args()
+ # The function ``trex_collecting_data`` below is to collect episodic data for training the reward model in trex.
+ trex_collecting_data(args)
+ serial_pipeline_reward_model_trex((main_config, create_config))
diff --git a/DI-engine/dizoo/atari/config/serial/qbert/qbert_trex_offppo_config.py b/DI-engine/dizoo/atari/config/serial/qbert/qbert_trex_offppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..3621edc462e36e90377c87e6433581f1c53961c5
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/qbert/qbert_trex_offppo_config.py
@@ -0,0 +1,92 @@
+from easydict import EasyDict
+
+qbert_trex_ppo_config = dict(
+ exp_name='qbert_trex_offppo_seed0',
+ env=dict(
+ collector_env_num=16,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='QbertNoFrameskip-v4',
+ #'ALE/Qbert-v5' is available. But special setting is needed after gym make.
+ frame_stack=4
+ ),
+ reward_model=dict(
+ type='trex',
+ min_snippet_length=30,
+ max_snippet_length=100,
+ checkpoint_min=0,
+ checkpoint_max=100,
+ checkpoint_step=100,
+ learning_rate=1e-5,
+ update_per_collect=1,
+ expert_model_path='abs model path',
+ reward_model_path='abs data path + ./qbert.params',
+ offline_data_path='abs data path',
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[32, 64, 64, 128],
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ critic_head_layer_num=2,
+ ),
+ learn=dict(
+ update_per_collect=24,
+ batch_size=128,
+ # (bool) Whether to normalize advantage. Default to False.
+ adv_norm=False,
+ learning_rate=0.0001,
+ # (float) loss weight of the value network, the weight of policy network is set to 1
+ value_weight=1.0,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.03,
+ clip_ratio=0.1,
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ n_sample=1024,
+ # (float) the trade-off factor lambda to balance 1-step TD and Monte Carlo returns
+ gae_lambda=0.95,
+ discount_factor=0.99,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, )),
+ other=dict(replay_buffer=dict(
+ replay_buffer_size=100000,
+ max_use=3,
+ ), ),
+ ),
+)
+main_config = EasyDict(qbert_trex_ppo_config)
+
+qbert_trex_ppo_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo_offpolicy'),
+)
+create_config = EasyDict(qbert_trex_ppo_create_config)
+
+if __name__ == "__main__":
+ # Users should first run ``qbert_offppo_config.py`` to save models (or checkpoints).
+ # Note: Users should check that the generated checkpoints include iteration_'checkpoint_min'.pth.tar
+ # through iteration_'checkpoint_max'.pth.tar at intervals of checkpoint_step,
+ # where checkpoint_max, checkpoint_min and checkpoint_step are specified above.
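+ # With the values above (checkpoint_min=0, checkpoint_max=100, checkpoint_step=100), this means the
+ # expert run should have saved checkpoints such as iteration_0.pth.tar and iteration_100.pth.tar
+ # (illustrative file names).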
+ import argparse
+ import torch
+ from ding.entry import trex_collecting_data
+ from ding.entry import serial_pipeline_reward_model_trex
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
+ parser.add_argument('--seed', type=int, default=0)
+ parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
+ args = parser.parse_args()
+ # The function ``trex_collecting_data`` below is to collect episodic data for training the reward model in trex.
+ trex_collecting_data(args)
+ serial_pipeline_reward_model_trex((main_config, create_config))
diff --git a/DI-engine/dizoo/atari/config/serial/spaceinvaders/__init__.py b/DI-engine/dizoo/atari/config/serial/spaceinvaders/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4ff222aa613908609330e918c7b92e4bb8e5377
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/spaceinvaders/__init__.py
@@ -0,0 +1,2 @@
+from .spaceinvaders_dqn_config import spaceinvaders_dqn_config, spaceinvaders_dqn_create_config
+from .spaceinvaders_dqfd_config import spaceinvaders_dqfd_config, spaceinvaders_dqfd_create_config
\ No newline at end of file
diff --git a/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_a2c_config.py b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_a2c_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb7d4cbb63d22be54fde714c53a65e60de987706
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_a2c_config.py
@@ -0,0 +1,65 @@
+from copy import deepcopy
+from easydict import EasyDict
+
+spaceinvaders_a2c_config = dict(
+ exp_name='spaceinvaders_a2c_seed0',
+ env=dict(
+ collector_env_num=16,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='SpaceInvadersNoFrameskip-v4',
+ #'ALE/SpaceInvaders-v5' is available. But special setting is needed after gym make.
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[32, 64, 64, 128],
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ critic_head_layer_num=2,
+ ),
+ learn=dict(
+ batch_size=80,
+ # (bool) Whether to normalize advantage. Default to False.
+ adv_norm=False,
+ learning_rate=0.0001,
+ # (float) loss weight of the value network, the weight of policy network is set to 1
+ value_weight=0.5,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.01,
+ grad_norm=0.5,
+ betas=(0.3, 0.99),
+ ),
+ collect=dict(
+ # (int) collect n_sample data, then train the model once
+ n_sample=80,
+ # (float) the trade-off factor lambda to balance 1-step TD and Monte Carlo returns
+ gae_lambda=0.99,
+ discount_factor=0.99,
+ ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ ),
+)
+spaceinvaders_a2c_config = EasyDict(spaceinvaders_a2c_config)
+main_config = spaceinvaders_a2c_config
+
+spaceinvaders_a2c_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='a2c'),
+ replay_buffer=dict(type='naive'),
+)
+spaceinvaders_a2c_create_config = EasyDict(spaceinvaders_a2c_create_config)
+create_config = spaceinvaders_a2c_create_config
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial_onpolicy -c spaceinvaders_a2c_config.py -s 0
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_acer_config.py b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_acer_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..a7afdc3690fd4b6a2dabad96498bdb3d4d1e9064
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_acer_config.py
@@ -0,0 +1,77 @@
+from copy import deepcopy
+from easydict import EasyDict
+
+spaceinvaders_acer_config = dict(
+ exp_name='spaceinvaders_acer_seed0',
+ env=dict(
+ collector_env_num=16,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=int(1e6),
+ env_id='SpaceInvadersNoFrameskip-v4',
+ #'ALE/SpaceInvaders-v5' is available. But special setting is needed after gym make.
+ frame_stack=4,
+ manager=dict(shared_memory=False, )
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ critic_head_hidden_size=512,
+ critic_head_layer_num=2,
+ actor_head_hidden_size=512,
+ actor_head_layer_num=2,
+ ),
+ unroll_len=64,
+ learn=dict(
+ # (int) collect n_sample data, train model update_per_collect times
+ # here we follow impala serial pipeline
+ update_per_collect=10,
+ # (int) the number of data for a train iteration
+ batch_size=64,
+ # grad_clip_type='clip_norm',
+ # clip_value=10,
+ learning_rate_actor=0.00005,
+ learning_rate_critic=0.0001,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.03,
+ # (float) discount factor for future reward, in [0, 1]
+ discount_factor=0.99,
+ # (bool) whether to use the trust region constraint
+ trust_region=True,
+ # (float) clip ratio of importance weights
+ c_clip_ratio=10,
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ # n_sample=16,
+ n_sample=64,
+ # (float) discount factor for future reward, in [0, 1]
+ discount_factor=0.99,
+ collector=dict(collect_print_freq=1000, ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=3000, ), ),
+ ),
+)
+spaceinvaders_acer_config = EasyDict(spaceinvaders_acer_config)
+main_config = spaceinvaders_acer_config
+
+spaceinvaders_acer_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='acer'),
+)
+spaceinvaders_acer_create_config = EasyDict(spaceinvaders_acer_create_config)
+create_config = spaceinvaders_acer_create_config
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c spaceinvaders_acer_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_c51_config.py b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_c51_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..a12b4846c2c67319e4ad2273d1745e1101deb17b
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_c51_config.py
@@ -0,0 +1,64 @@
+from copy import deepcopy
+from easydict import EasyDict
+
+spaceinvaders_c51_config = dict(
+ exp_name='spaceinvaders_c51_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='SpaceInvadersNoFrameskip-v4',
+ #'ALE/SpaceInvaders-v5' is also available, but it needs a special setting after gym.make().
+ frame_stack=4,
+ manager=dict(shared_memory=False, )
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ v_min=-10,
+ v_max=10,
+ n_atom=51,
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+)
+spaceinvaders_c51_config = EasyDict(spaceinvaders_c51_config)
+main_config = spaceinvaders_c51_config
+spaceinvaders_c51_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='c51'),
+)
+spaceinvaders_c51_create_config = EasyDict(spaceinvaders_c51_create_config)
+create_config = spaceinvaders_c51_create_config
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c spaceinvaders_c51_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
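For reference, v_min, v_max and n_atom above define the fixed value support of the categorical (C51) return distribution, giving an atom spacing of (10 - (-10)) / (51 - 1) = 0.4. A small sketch of the standard construction from the C51 paper (variable names here are illustrative, not DI-engine's internals):

# Sketch: the value support implied by v_min / v_max / n_atom in the config above.
import torch

v_min, v_max, n_atom = -10., 10., 51
delta_z = (v_max - v_min) / (n_atom - 1)         # 0.4
support = torch.linspace(v_min, v_max, n_atom)   # atoms: -10.0, -9.6, ..., 9.6, 10.0

# Given per-atom probabilities p with shape (batch, action_num, n_atom),
# the Q-value is the expectation over this support: q = (p * support).sum(-1)
print(delta_z, support[:3])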
diff --git a/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_dqfd_config.py b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_dqfd_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b9d5fde6945d798acf0f9595adc12cc07c22d73e
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_dqfd_config.py
@@ -0,0 +1,76 @@
+from easydict import EasyDict
+
+spaceinvaders_dqfd_config = dict(
+ exp_name='spaceinvaders_dqfd_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='SpaceInvadersNoFrameskip-v4',
+ #'ALE/SpaceInvaders-v5' is also available, but it needs a special setting after gym.make().
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ lambda1=1.0, # n-step return
+ lambda2=1.0, # supervised loss
+ lambda3=1e-5, # L2 regularization
+ per_train_iter_k=10,
+ expert_replay_buffer_size=10000,  # the size of the expert replay buffer
+ ),
+ collect=dict(
+ n_sample=100,
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ model_path='model_path_placeholder',
+ ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+)
+spaceinvaders_dqfd_config = EasyDict(spaceinvaders_dqfd_config)
+main_config = spaceinvaders_dqfd_config
+spaceinvaders_dqfd_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqfd'),
+)
+spaceinvaders_dqfd_create_config = EasyDict(spaceinvaders_dqfd_create_config)
+create_config = spaceinvaders_dqfd_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial_dqfd -c spaceinvaders_dqfd_config.py -s 0`
+ # then input ``spaceinvaders_dqfd_config.py`` when prompted.
+ # The reason we need to input the dqfd config is that we have to borrow its ``_get_train_sample`` function
+ # in the collector part, even though the expert model may be generated by other Q-learning algorithms.
+ from ding.entry.serial_entry_dqfd import serial_pipeline_dqfd
+ from dizoo.atari.config.serial.spaceinvaders import spaceinvaders_dqfd_config, spaceinvaders_dqfd_create_config
+ expert_main_config = spaceinvaders_dqfd_config
+ expert_create_config = spaceinvaders_dqfd_create_config
+ serial_pipeline_dqfd([main_config, create_config], [expert_main_config, expert_create_config], seed=0)
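Before running this entry, the ``model_path`` placeholder in ``policy.collect`` must be replaced with a real expert checkpoint, as the comment in the config notes. A hedged sketch of wiring it in; the absolute checkpoint path below is a hypothetical example, not a file shipped with DI-engine:

# Sketch: fill in the expert checkpoint path before launching the DQfD pipeline.
from ding.entry.serial_entry_dqfd import serial_pipeline_dqfd
from dizoo.atari.config.serial.spaceinvaders.spaceinvaders_dqfd_config import main_config, create_config

# hypothetical expert checkpoint produced by a previous DQN run
main_config.policy.collect.model_path = '/abs/path/spaceinvaders_dqn_seed0/ckpt/ckpt_best.pth.tar'

if __name__ == '__main__':
    # the same config pair doubles as the expert config, as in the entry block above
    serial_pipeline_dqfd([main_config, create_config], [main_config, create_config], seed=0)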
diff --git a/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_dqn_config.py b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d0b810f4959d2ec71094d80fc082e6d691b90f8
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_dqn_config.py
@@ -0,0 +1,60 @@
+from copy import deepcopy
+from easydict import EasyDict
+
+spaceinvaders_dqn_config = dict(
+ exp_name='spaceinvaders_dqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ env_id='SpaceInvadersNoFrameskip-v4',
+ #'ALE/SpaceInvaders-v5' is also available, but it needs a special setting after gym.make().
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ random_collect_size=5000,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+)
+spaceinvaders_dqn_config = EasyDict(spaceinvaders_dqn_config)
+main_config = spaceinvaders_dqn_config
+spaceinvaders_dqn_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqn'),
+)
+spaceinvaders_dqn_create_config = EasyDict(spaceinvaders_dqn_create_config)
+create_config = spaceinvaders_dqn_create_config
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c spaceinvaders_dqn_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0, max_env_step=int(1e7))
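The eps block above anneals exploration from 1.0 down to 0.05 with an exponential schedule over roughly the first million environment steps. A small sketch of the schedule, assuming the usual exponential form eps(t) = end + (start - end) * exp(-t / decay); DI-engine's exact implementation may differ in detail:

# Sketch of the epsilon-greedy schedule implied by the `eps` block above (assumed form).
import math

start, end, decay = 1.0, 0.05, 1000000

def epsilon(step: int) -> float:
    return end + (start - end) * math.exp(-step / decay)

for step in (0, 250_000, 1_000_000, 3_000_000):
    print(step, round(epsilon(step), 3))   # ~1.0, ~0.79, ~0.40, ~0.097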
diff --git a/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_dqn_config_multi_gpu_ddp.py b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_dqn_config_multi_gpu_ddp.py
new file mode 100644
index 0000000000000000000000000000000000000000..093dd7d2ba5091378e96a16dec53043e6da19015
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_dqn_config_multi_gpu_ddp.py
@@ -0,0 +1,62 @@
+from copy import deepcopy
+from easydict import EasyDict
+
+spaceinvaders_dqn_config = dict(
+ exp_name='spaceinvaders_dqn_multi_gpu_ddp_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ env_id='SpaceInvadersNoFrameskip-v4',
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ multi_gpu=True,
+ priority=False,
+ random_collect_size=5000,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+)
+spaceinvaders_dqn_config = EasyDict(spaceinvaders_dqn_config)
+main_config = spaceinvaders_dqn_config
+spaceinvaders_dqn_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqn'),
+)
+spaceinvaders_dqn_create_config = EasyDict(spaceinvaders_dqn_create_config)
+create_config = spaceinvaders_dqn_create_config
+
+if __name__ == '__main__':
+ from ding.entry import serial_pipeline
+ from ding.utils import DDPContext, to_ddp_config
+ with DDPContext():
+ main_config = to_ddp_config(main_config)
+ serial_pipeline((main_config, create_config), seed=0, max_env_step=int(1e7))
diff --git a/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_dqn_config_multi_gpu_dp.py b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_dqn_config_multi_gpu_dp.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dcbdc8465d5a86bd2b35473997ae22ebb2dfd02
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_dqn_config_multi_gpu_dp.py
@@ -0,0 +1,61 @@
+from copy import deepcopy
+from easydict import EasyDict
+
+spaceinvaders_dqn_config = dict(
+ exp_name='spaceinvaders_dqn_multi_gpu_dp_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ env_id='SpaceInvadersNoFrameskip-v4',
+ #'ALE/SpaceInvaders-v5' is also available, but it needs a special setting after gym.make().
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+)
+spaceinvaders_dqn_config = EasyDict(spaceinvaders_dqn_config)
+main_config = spaceinvaders_dqn_config
+spaceinvaders_dqn_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqn'),
+)
+spaceinvaders_dqn_create_config = EasyDict(spaceinvaders_dqn_create_config)
+create_config = spaceinvaders_dqn_create_config
+
+if __name__ == '__main__':
+ from ding.entry import serial_pipeline
+ from ding.model.template.q_learning import DQN
+ from ding.torch_utils import DataParallel
+ model = DataParallel(DQN(obs_shape=[4, 84, 84], action_shape=6))
+ serial_pipeline((main_config, create_config), seed=0, model=model, max_env_step=int(1e7))
diff --git a/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_fqf_config.py b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_fqf_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..95df0d46571869f0c1246074f19af22d2fb608e4
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_fqf_config.py
@@ -0,0 +1,63 @@
+from easydict import EasyDict
+
+spaceinvaders_fqf_config = dict(
+ exp_name='spaceinvaders_fqf_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='SpaceInvadersNoFrameskip-v4',
+ #'ALE/SpaceInvaders-v5' is also available, but it needs a special setting after gym.make().
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ num_quantiles=32,
+ quantile_embedding_size=64,
+ ),
+ nstep=1,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate_fraction=2.5e-9,
+ learning_rate_quantile=0.00005,
+ target_update_freq=500,
+ ent_coef=0,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.01,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=1000000, ),
+ ),
+ ),
+)
+spaceinvaders_fqf_config = EasyDict(spaceinvaders_fqf_config)
+main_config = spaceinvaders_fqf_config
+spaceinvaders_fqf_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='fqf'),
+)
+spaceinvaders_fqf_create_config = EasyDict(spaceinvaders_fqf_create_config)
+create_config = spaceinvaders_fqf_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c spaceinvaders_fqf_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
\ No newline at end of file
diff --git a/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_impala_config.py b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_impala_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..19be3fc1116e904e171bcde0c8ab87d9be05dc7f
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_impala_config.py
@@ -0,0 +1,80 @@
+from easydict import EasyDict
+
+spaceinvaders_impala_config = dict(
+ exp_name='impala_log/spaceinvaders_impala_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='SpaceInvadersNoFrameskip-v4',
+ #'ALE/SpaceInvaders-v5' is also available, but it needs a special setting after gym.make().
+ frame_stack=4,
+ # manager=dict(shared_memory=False, )
+ ),
+ policy=dict(
+ cuda=True,
+ # (int) the trajectory length to calculate v-trace target
+ unroll_len=32,
+ random_collect_size=500,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 256, 256],
+ critic_head_hidden_size=256,
+ critic_head_layer_num=3,
+ actor_head_hidden_size=256,
+ actor_head_layer_num=3,
+ ),
+ learn=dict(
+ # (int) collect n_sample data, train model update_per_collect times
+ # here we follow the IMPALA serial pipeline
+ update_per_collect=2,  # update_per_collect should be in [1, 10]
+ # (int) the number of data for a train iteration
+ batch_size=128,
+ grad_clip_type='clip_norm',
+ clip_value=5,
+ learning_rate=0.0006,
+ # (float) loss weight of the value network, the weight of policy network is set to 1
+ value_weight=0.5,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.01,
+ # (float) discount factor for future reward, defaults in [0, 1]
+ discount_factor=0.99,
+ # (float) additional discounting parameter
+ lambda_=0.95,
+ # (float) clip ratio of importance weights
+ rho_clip_ratio=1.0,
+ # (float) clip ratio of importance weights
+ c_clip_ratio=1.0,
+ # (float) clip ratio of importance sampling
+ rho_pg_clip_ratio=1.0,
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ n_sample=16,
+ collector=dict(collect_print_freq=1000, ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=100000, sliced=True), ),
+ ),
+)
+spaceinvaders_impala_config = EasyDict(spaceinvaders_impala_config)
+main_config = spaceinvaders_impala_config
+
+spaceinvaders_impala_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='impala'),
+ replay_buffer=dict(type='naive'),
+)
+spaceinvaders_impala_create_config = EasyDict(spaceinvaders_impala_create_config)
+create_config = spaceinvaders_impala_create_config
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c spaceinvaders_impala_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
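rho_clip_ratio and c_clip_ratio above correspond to the rho-bar and c-bar truncation levels in the V-trace target from the IMPALA paper. A minimal sketch of how the truncated importance weights are formed (standard V-trace, not DI-engine's internal code; the tensors are illustrative):

# Sketch: truncated importance weights used by V-trace, matching rho_clip_ratio / c_clip_ratio above.
import torch

rho_clip_ratio, c_clip_ratio = 1.0, 1.0

def vtrace_weights(target_logp: torch.Tensor, behaviour_logp: torch.Tensor):
    """target_logp / behaviour_logp: log pi(a_t|s_t) under the current and behaviour policy."""
    is_ratio = torch.exp(target_logp - behaviour_logp)    # importance sampling ratio
    rho = torch.clamp(is_ratio, max=rho_clip_ratio)       # weights the TD error delta_t
    c = torch.clamp(is_ratio, max=c_clip_ratio)           # "trace cutting" coefficient
    return rho, c

rho, c = vtrace_weights(torch.tensor([-0.1, -2.0]), torch.tensor([-0.5, -0.3]))
print(rho, c)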
diff --git a/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_iqn_config.py b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_iqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..1fe4b459036f65637fece67b48309fab25f24822
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_iqn_config.py
@@ -0,0 +1,63 @@
+from copy import deepcopy
+from easydict import EasyDict
+
+spaceinvaders_iqn_config = dict(
+ exp_name='spaceinvaders_iqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='SpaceInvadersNoFrameskip-v4',
+ #'ALE/SpaceInvaders-v5' is also available, but it needs a special setting after gym.make().
+ frame_stack=4,
+ manager=dict(shared_memory=False, )
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ num_quantiles=32,
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ kappa=1.0,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+)
+spaceinvaders_iqn_config = EasyDict(spaceinvaders_iqn_config)
+main_config = spaceinvaders_iqn_config
+spaceinvaders_iqn_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='iqn'),
+)
+spaceinvaders_iqn_create_config = EasyDict(spaceinvaders_iqn_create_config)
+create_config = spaceinvaders_iqn_create_config
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c spaceinvaders_iqn_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
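kappa=1.0 above is the Huber threshold inside IQN's quantile regression loss. A compact sketch of the standard quantile Huber loss from the IQN paper (DI-engine's internal implementation may arrange the tensors differently):

# Sketch: quantile Huber loss used by IQN, with the kappa threshold from the config above.
import torch

def quantile_huber_loss(td_error: torch.Tensor, tau: torch.Tensor, kappa: float = 1.0):
    """td_error: TD errors per quantile; tau: quantile fractions in (0, 1); shapes must broadcast."""
    abs_err = td_error.abs()
    huber = torch.where(
        abs_err <= kappa,
        0.5 * td_error ** 2,
        kappa * (abs_err - 0.5 * kappa),
    )
    # asymmetric quantile weighting: |tau - 1{td_error < 0}|
    weight = (tau - (td_error.detach() < 0).float()).abs()
    return (weight * huber / kappa).mean()

loss = quantile_huber_loss(torch.randn(4, 32), torch.rand(1, 32), kappa=1.0)
print(loss)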
diff --git a/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_mdqn_config.py b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_mdqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..9abf5159d300ec0b8ae35c0c2fc7c307936fb4c6
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_mdqn_config.py
@@ -0,0 +1,63 @@
+from copy import deepcopy
+from easydict import EasyDict
+
+spaceinvaders_mdqn_config = dict(
+ exp_name='spaceinvaders_mdqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000,
+ env_id='SpaceInvadersNoFrameskip-v0',
+ #'ALE/SpaceInvaders-v5' is also available, but it needs a special setting after gym.make().
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=1,
+ discount_factor=0.99,
+ entropy_tau=0.03,
+ m_alpha=0.9,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ learner=dict(hook=dict(save_ckpt_after_iter=1000000, ))
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+)
+spaceinvaders_mdqn_config = EasyDict(spaceinvaders_mdqn_config)
+main_config = spaceinvaders_mdqn_config
+spaceinvaders_mdqn_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='mdqn'),
+)
+spaceinvaders_mdqn_create_config = EasyDict(spaceinvaders_mdqn_create_config)
+create_config = spaceinvaders_mdqn_create_config
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c spaceinvaders_mdqn_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0, max_env_step=int(3e7), dynamic_seed=False)
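entropy_tau and m_alpha above are the tau and alpha of Munchausen DQN: the immediate reward is augmented with a scaled, clipped log-policy term, and the bootstrap uses the soft (entropy-regularized) value of the next state. A compact sketch of that target following the M-DQN paper; a target network would normally supply q_tp1, and names here are illustrative:

# Sketch: the Munchausen-DQN target implied by entropy_tau / m_alpha above.
# Standard M-DQN formulation; the clipping bound l_0 = -1 is the usual choice from the paper.
import torch
import torch.nn.functional as F

tau, alpha, gamma, l_0 = 0.03, 0.9, 0.99, -1.0

def mdqn_target(reward, q_t, q_tp1, action, done):
    """q_t, q_tp1: (batch, action_num) Q-values; action: (batch,) taken actions (int64)."""
    log_pi_t = F.log_softmax(q_t / tau, dim=-1)                       # log pi(.|s_t)
    log_pi_a = log_pi_t.gather(1, action.unsqueeze(1)).squeeze(1)     # log pi(a_t|s_t)
    munchausen = alpha * torch.clamp(tau * log_pi_a, min=l_0, max=0.0)
    pi_tp1 = F.softmax(q_tp1 / tau, dim=-1)
    log_pi_tp1 = F.log_softmax(q_tp1 / tau, dim=-1)
    soft_v_tp1 = (pi_tp1 * (q_tp1 - tau * log_pi_tp1)).sum(dim=-1)    # soft next-state value
    return reward + munchausen + gamma * (1 - done) * soft_v_tp1

target = mdqn_target(torch.zeros(2), torch.randn(2, 6), torch.randn(2, 6), torch.tensor([0, 3]), torch.zeros(2))
print(target)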
diff --git a/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_ngu_config.py b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_ngu_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..f01caa4ac1f794d8770e5bf6e112e01dcf85abd1
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_ngu_config.py
@@ -0,0 +1,129 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 8
+nstep = 5
+max_env_step = int(10e6)
+
+spaceinvaders_ngu_config = dict(
+ exp_name='spaceinvaders_ngu_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ env_id='SpaceInvadersNoFrameskip-v4',
+ obs_plus_prev_action_reward=True, # use specific env wrapper for ngu policy
+ stop_value=int(1e6),
+ frame_stack=4,
+ ),
+ rnd_reward_model=dict(
+ intrinsic_reward_type='add',
+ learning_rate=1e-4,
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ batch_size=320,
+ update_per_collect=10,
+ only_use_last_five_frames_for_icm_rnd=False,
+ clear_buffer_per_iters=10,
+ nstep=nstep,
+ hidden_size_list=[128, 128, 64],
+ type='rnd-ngu',
+ ),
+ episodic_reward_model=dict(
+ # means whether to apply the rescale trick to the last non-zero reward
+ # when combining extrinsic and intrinsic rewards.
+ # the rescale trick is only used in:
+ # 1. sparse-reward envs such as minigrid, in which the last non-zero reward is a strong positive signal
+ # 2. envs where the last reward of each episode directly reflects whether the agent completed the task, e.g. lunarlander
+ # Note that the ngu intrinsic reward is a positive value (max value is 5); in these envs,
+ # the last non-zero reward should not be overwhelmed by intrinsic rewards, so we need to rescale the
+ # original last non-zero extrinsic reward.
+ # please refer to ngu_reward_model for details.
+ last_nonzero_reward_rescale=False,
+ # means the rescale value for the last non-zero reward, only used when last_nonzero_reward_rescale is True
+ # please refer to ngu_reward_model for details.
+ last_nonzero_reward_weight=1,
+ intrinsic_reward_type='add',
+ learning_rate=1e-4,
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ batch_size=320,
+ update_per_collect=10,
+ only_use_last_five_frames_for_icm_rnd=False,
+ clear_buffer_per_iters=10,
+ nstep=nstep,
+ hidden_size_list=[128, 128, 64],
+ type='episodic',
+ ),
+ policy=dict(
+ cuda=True,
+ on_policy=False,
+ priority=True,
+ priority_IS_weight=True,
+ discount_factor=0.997,
+ nstep=nstep,
+ burnin_step=20,
+ # (int) learn_unroll_len is the total length of [sequence sample] minus
+ # the length of the burnin part in [sequence sample],
+ # i.e., sequence sample length = unroll_len = burnin_step + learn_unroll_len
+ learn_unroll_len=40, # set this key according to the episode length
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ collector_env_num=collector_env_num,
+ ),
+ learn=dict(
+ update_per_collect=8,
+ batch_size=64,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ ),
+ collect=dict(
+ # NOTE: it is important to set the key traj_len_inf=True here,
+ # to make sure self._traj_len=INF in serial_sample_collector.py.
+ # In sequence-based policy, for each collect_env,
+ # we want to collect data of length self._traj_len=INF
+ # unless the episode enters the 'done' state.
+ # In each collect phase, we collect a total of n_sample sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=1e5,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=int(2e4),
+ # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
+ alpha=0.6,
+ # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+ beta=0.4,
+ )
+ ),
+ ),
+)
+spaceinvaders_ngu_config = EasyDict(spaceinvaders_ngu_config)
+main_config = spaceinvaders_ngu_config
+spaceinvaders_ngu_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ngu'),
+ rnd_reward_model=dict(type='rnd-ngu'),
+ episodic_reward_model=dict(type='episodic'),
+)
+spaceinvaders_ngu_create_config = EasyDict(spaceinvaders_ngu_create_config)
+create_config = spaceinvaders_ngu_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_ngu -c spaceinvaders_ngu_config.py -s 0`
+ from ding.entry import serial_pipeline_ngu
+ serial_pipeline_ngu([main_config, create_config], seed=0, max_env_step=max_env_step)
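With intrinsic_reward_type='add' in both reward models above, the reward the NGU policy trains on is the environment reward plus an intrinsic bonus: the episodic novelty term modulated by the RND-based lifelong factor, which the NGU paper clips to [1, L]. A hedged sketch of that combination (conceptual, not the exact DI-engine reward-model code; the cap of 5 matches the "max value is 5" note in the comments above):

# Sketch: how the NGU intrinsic bonus is combined with the extrinsic reward when
# intrinsic_reward_type='add'. Shapes and names are illustrative.
import torch

L_MAX = 5.0  # cap on the lifelong modulator

def ngu_reward(r_extrinsic, r_episodic, rnd_modulator):
    """r_episodic: episodic novelty bonus; rnd_modulator: lifelong (RND) novelty factor."""
    modulator = torch.clamp(rnd_modulator, min=1.0, max=L_MAX)
    r_intrinsic = r_episodic * modulator
    return r_extrinsic + r_intrinsic   # 'add' combination

r = ngu_reward(torch.tensor([0.0, 1.0]), torch.tensor([0.3, 0.1]), torch.tensor([0.8, 2.5]))
print(r)  # tensor([0.3000, 1.2500])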
diff --git a/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_offppo_config.py b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_offppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..2b284d9f5bf9a7194879a269596c27ec8de2a730
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_offppo_config.py
@@ -0,0 +1,69 @@
+from copy import deepcopy
+from easydict import EasyDict
+
+spaceinvaders_ppo_config = dict(
+ exp_name='spaceinvaders_offppo_seed0',
+ env=dict(
+ collector_env_num=16,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='SpaceInvadersNoFrameskip-v4',
+ #'ALE/SpaceInvaders-v5' is also available, but it needs a special setting after gym.make().
+ frame_stack=4,
+ manager=dict(shared_memory=False, )
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[32, 64, 64, 128],
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ critic_head_layer_num=2,
+ ),
+ learn=dict(
+ update_per_collect=24,
+ batch_size=128,
+ # (bool) Whether to normalize advantage. Defaults to False.
+ adv_norm=False,
+ learning_rate=0.0001,
+ # (float) loss weight of the value network, the weight of policy network is set to 1
+ value_weight=1.0,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.03,
+ clip_ratio=0.1,
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ n_sample=1024,
+ # (float) the trade-off factor lambda to balance 1-step TD and Monte Carlo return
+ gae_lambda=0.95,
+ discount_factor=0.99,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, )),
+ other=dict(replay_buffer=dict(
+ replay_buffer_size=100000,
+ max_use=5,
+ ), ),
+ ),
+)
+spaceinvaders_ppo_config = EasyDict(spaceinvaders_ppo_config)
+main_config = spaceinvaders_ppo_config
+
+spaceinvaders_ppo_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo_offpolicy'),
+)
+spaceinvaders_ppo_create_config = EasyDict(spaceinvaders_ppo_create_config)
+create_config = spaceinvaders_ppo_create_config
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c spaceinvaders_offppo_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_onppo_config.py b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_onppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..498381e3a8b60c9f792d63a07b21c0b46abaf2cd
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_onppo_config.py
@@ -0,0 +1,70 @@
+from copy import deepcopy
+from easydict import EasyDict
+
+spaceinvaders_ppo_config = dict(
+ exp_name='spaceinvaders_onppo_seed0',
+ env=dict(
+ collector_env_num=16,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=int(1e10),
+ env_id='SpaceInvadersNoFrameskip-v4',
+ #'ALE/SpaceInvaders-v5' is also available, but it needs a special setting after gym.make().
+ frame_stack=4,
+ manager=dict(shared_memory=False, )
+ ),
+ policy=dict(
+ cuda=True,
+ recompute_adv=True,
+ action_space='discrete',
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ action_space='discrete',
+ encoder_hidden_size_list=[64, 64, 128],
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ ),
+ learn=dict(
+ epoch_per_collect=10,
+ update_per_collect=1,
+ batch_size=320,
+ learning_rate=3e-4,
+ value_weight=0.5,
+ entropy_weight=0.001,
+ clip_ratio=0.2,
+ adv_norm=True,
+ value_norm=True,
+ # for on-policy PPO, when we recompute adv we need the key 'done' in data to split trajectories,
+ # so we must use ignore_done=False here;
+ # but when we add the key 'traj_flag' in data as a backup for 'done', we could choose ignore_done=True
+ # (e.g. for halfcheetah, whose episode length is fixed at 1000)
+ ignore_done=False,
+ grad_clip_type='clip_norm',
+ grad_clip_value=0.5,
+ ),
+ collect=dict(
+ n_sample=3200,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ eval=dict(evaluator=dict(eval_freq=5000, )),
+ ),
+)
+main_config = EasyDict(spaceinvaders_ppo_config)
+
+spaceinvaders_ppo_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+)
+create_config = EasyDict(spaceinvaders_ppo_create_config)
+
+if __name__ == "__main__":
+ # or you can enter ding -m serial_onpolicy -c spaceinvaders_onppo_config.py -s 0
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy([main_config, create_config], seed=0)
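discount_factor=0.99 and gae_lambda=0.95 above parameterize the generalized advantage estimate that on-policy PPO recomputes each epoch (recompute_adv=True). A minimal reverse-scan sketch of standard GAE (not the exact ding.rl_utils implementation):

# Sketch: generalized advantage estimation with the gamma / lambda from the config above.
import torch

gamma, lam = 0.99, 0.95

def gae(rewards, values, last_value, dones):
    """rewards, values, dones: (T,) tensors for one trajectory; last_value: bootstrap V(s_T)."""
    T = rewards.shape[0]
    adv = torch.zeros(T)
    next_value, running = last_value, 0.0
    for t in reversed(range(T)):
        not_done = 1.0 - dones[t]
        delta = rewards[t] + gamma * next_value * not_done - values[t]
        running = delta + gamma * lam * not_done * running
        adv[t] = running
        next_value = values[t]
    return adv  # value targets would be adv + values

print(gae(torch.ones(5), torch.zeros(5), torch.tensor(0.0), torch.zeros(5)))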
diff --git a/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_ppg_config.py b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_ppg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..af111ccf36b7ee595c5cf09860cc6712fcb6878d
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_ppg_config.py
@@ -0,0 +1,80 @@
+from copy import deepcopy
+from easydict import EasyDict
+
+spaceinvaders_ppg_config = dict(
+ exp_name='spaceinvaders_ppg_seed0',
+ env=dict(
+ collector_env_num=16,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='SpaceInvadersNoFrameskip-v4',
+ #'ALE/SpaceInvaders-v5' is also available, but it needs a special setting after gym.make().
+ frame_stack=4,
+ manager=dict(shared_memory=False, )
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[32, 64, 64, 128],
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ critic_head_layer_num=2,
+ ),
+ learn=dict(
+ update_per_collect=24,
+ batch_size=128,
+ # (bool) Whether to normalize advantage. Defaults to False.
+ adv_norm=False,
+ learning_rate=0.0001,
+ # (float) loss weight of the value network, the weight of policy network is set to 1
+ value_weight=0.5,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.03,
+ clip_ratio=0.1,
+ epochs_aux=6,
+ beta_weight=1,
+ aux_freq=100
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ n_sample=1024,
+ # (float) the trade-off factor lambda to balance 1-step TD and Monte Carlo return
+ gae_lambda=0.95,
+ discount_factor=0.99,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, )),
+ other=dict(
+ replay_buffer=dict(
+ multi_buffer=True,
+ policy=dict(
+ replay_buffer_size=100000,
+ max_use=3,
+ ),
+ value=dict(
+ replay_buffer_size=100000,
+ max_use=10,
+ ),
+ ),
+ ),
+ ),
+)
+spaceinvaders_ppg_config = EasyDict(spaceinvaders_ppg_config)
+main_config = spaceinvaders_ppg_config
+spaceinvaders_ppg_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppg_offpolicy'),
+)
+spaceinvaders_ppg_create_config = EasyDict(spaceinvaders_ppg_create_config)
+ create_config = spaceinvaders_ppg_create_config
+
+if __name__ == '__main__':
+ from dizoo.atari.entry.atari_ppg_main import main
+ # PPG needs to use a specific entry; you can run `dizoo/atari/entry/atari_ppg_main.py`
+ main(main_config)
diff --git a/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_qrdqn_config.py b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_qrdqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c67c33ea43f2ed15aa71d7005c56f6582ae4c712
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_qrdqn_config.py
@@ -0,0 +1,62 @@
+from copy import deepcopy
+from easydict import EasyDict
+
+spaceinvaders_qrdqn_config = dict(
+ exp_name='spaceinvaders_qrdqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='SpaceInvadersNoFrameskip-v4',
+ #'ALE/SpaceInvaders-v5' is also available, but it needs a special setting after gym.make().
+ frame_stack=4,
+ manager=dict(shared_memory=False, )
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ num_quantiles=64,
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+)
+spaceinvaders_qrdqn_config = EasyDict(spaceinvaders_qrdqn_config)
+main_config = spaceinvaders_qrdqn_config
+spaceinvaders_qrdqn_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='qrdqn'),
+)
+spaceinvaders_qrdqn_create_config = EasyDict(spaceinvaders_qrdqn_create_config)
+create_config = spaceinvaders_qrdqn_create_config
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c spaceinvaders_qrdqn_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0, max_env_step=int(1e7))
diff --git a/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_r2d2_config.py b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_r2d2_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d9145d741bf9224755fd299c70065e52817b96d
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_r2d2_config.py
@@ -0,0 +1,90 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 5
+spaceinvaders_r2d2_config = dict(
+ exp_name='spaceinvaders_r2d2_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=8,
+ stop_value=int(1e6),
+ env_id='SpaceInvadersNoFrameskip-v4',
+ #'ALE/SpaceInvaders-v5' is also available, but it needs a special setting after gym.make().
+ frame_stack=4,
+ manager=dict(shared_memory=False, )
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ priority_IS_weight=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ res_link=False,
+ ),
+ discount_factor=0.997,
+ nstep=5,
+ burnin_step=20,
+ # (int) learn_unroll_len is the whole sequence length to unroll the RNN network minus
+ # the timesteps of the burnin part,
+ # i.e., sequence sample length = unroll_len = burnin_step + learn_unroll_len
+ learn_unroll_len=80,
+ learn=dict(
+ # according to the R2D2 paper, the actor parameter update interval is 400
+ # environment timesteps, and in each collect phase we collect n_sample sequence
+ # samples, the length of each sequence sample being burnin_step + learn_unroll_len,
+ # e.g. if n_sample=32 and the sequence length is 100, then 32*100/400=8,
+ # so we will set update_per_collect=8 in most environments.
+ update_per_collect=8,
+ batch_size=64,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ ),
+ collect=dict(
+ # NOTE: it is important to set the key traj_len_inf=True here,
+ # to make sure self._traj_len=INF in serial_sample_collector.py.
+ # In sequence-based policy, for each collect_env,
+ # we want to collect data of length self._traj_len=INF
+ # unless the episode enters the 'done' state.
+ # In each collect phase, we collect a total of n_sample sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=1e5,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=10000,
+ # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
+ alpha=0.6,
+ # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+ beta=0.4,
+ )
+ ),
+ ),
+)
+spaceinvaders_r2d2_config = EasyDict(spaceinvaders_r2d2_config)
+main_config = spaceinvaders_r2d2_config
+spaceinvaders_r2d2_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='r2d2'),
+)
+spaceinvaders_r2d2_create_config = EasyDict(spaceinvaders_r2d2_create_config)
+create_config = spaceinvaders_r2d2_create_config
+
+if __name__ == "__main__":
+ # or you can enter ding -m serial -c spaceinvaders_r2d2_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
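The update_per_collect comment above reduces to a small calculation: each sequence sample spans burnin_step + learn_unroll_len timesteps, and the 400-step actor update interval from the R2D2 paper then fixes how many gradient updates to run per collect phase. With this config's values:

# Sketch: how update_per_collect=8 follows from the R2D2 config values above.
n_sample = 32
burnin_step = 20
learn_unroll_len = 80
actor_update_interval = 400   # environment timesteps, from the R2D2 paper

sequence_length = burnin_step + learn_unroll_len      # 100 timesteps per sequence sample
env_steps_per_collect = n_sample * sequence_length    # 3200 timesteps per collect phase
update_per_collect = env_steps_per_collect // actor_update_interval
print(sequence_length, env_steps_per_collect, update_per_collect)   # 100 3200 8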
diff --git a/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_r2d2_gtrxl_config.py b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_r2d2_gtrxl_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b176597821d741449c05e4c473271f6b28c7f239
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_r2d2_gtrxl_config.py
@@ -0,0 +1,88 @@
+from easydict import EasyDict
+
+spaceinvaders_r2d2_gtrxl_config = dict(
+ exp_name='spaceinvaders_r2d2_gtrxl_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='SpaceInvadersNoFrameskip-v4',
+ #'ALE/SpaceInvaders-v5' is also available, but it needs a special setting after gym.make().
+ frame_stack=4,
+ manager=dict(shared_memory=False)
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ priority_IS_weight=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ hidden_size=2048,
+ encoder_hidden_size_list=[128, 512, 2048],
+ gru_bias=1.0,
+ memory_len=0,
+ dropout=0.2,
+ att_layer_num=5,
+ att_head_dim=512,
+ ),
+ discount_factor=0.99,
+ nstep=3,
+ burnin_step=0,
+ unroll_len=13,
+ seq_len=10,
+ learn=dict(
+ update_per_collect=8,
+ batch_size=64,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ value_rescale=True,
+ init_memory='zero',
+ ),
+ collect=dict(
+ # NOTE: it is important to set the key traj_len_inf=True here,
+ # to make sure self._traj_len=INF in serial_sample_collector.py.
+ # In sequence-based policy, for each collect_env,
+ # we want to collect data of length self._traj_len=INF
+ # unless the episode enters the 'done' state.
+ # In each collect phase, we collect a total of n_sample sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ env_num=8,
+ ),
+ eval=dict(env_num=8, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=1e5,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=10000,
+ # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
+ alpha=0.6,
+ # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+ beta=0.4,
+ )
+ ),
+ ),
+)
+spaceinvaders_r2d2_gtrxl_config = EasyDict(spaceinvaders_r2d2_gtrxl_config)
+main_config = spaceinvaders_r2d2_gtrxl_config
+spaceinvaders_r2d2_gtrxl_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='r2d2_gtrxl'),
+)
+spaceinvaders_r2d2_gtrxl_create_config = EasyDict(spaceinvaders_r2d2_gtrxl_create_config)
+create_config = spaceinvaders_r2d2_gtrxl_create_config
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c spaceinvaders_r2d2_gtrxl_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_r2d2_residual_config.py b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_r2d2_residual_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..9341b9f0924bb409663e25a98477bd68b6084e44
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_r2d2_residual_config.py
@@ -0,0 +1,85 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 5
+spaceinvaders_r2d2_residual_config = dict(
+ exp_name='spaceinvaders_r2d2_residual_link_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='SpaceInvadersNoFrameskip-v4',
+ #'ALE/SpaceInvaders-v5' is also available, but it needs a special setting after gym.make().
+ frame_stack=4,
+ manager=dict(shared_memory=False, )
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ priority_IS_weight=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ res_link=True,
+ ),
+ discount_factor=0.997,
+ nstep=5,
+ burnin_step=2,
+ # (int) learn_unroll_len is the whole sequence length to unroll the RNN network minus
+ # the timesteps of the burnin part,
+ # i.e., sequence sample length = unroll_len = burnin_step + learn_unroll_len
+ learn_unroll_len=40,
+ learn=dict(
+ update_per_collect=8,
+ batch_size=64,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ ),
+ collect=dict(
+ # NOTE: it is important to set the key traj_len_inf=True here,
+ # to make sure self._traj_len=INF in serial_sample_collector.py.
+ # In sequence-based policy, for each collect_env,
+ # we want to collect data of length self._traj_len=INF
+ # unless the episode enters the 'done' state.
+ # In each collect phase, we collect a total of n_sample sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=1e5,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=10000,
+ # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
+ alpha=0.6,
+ # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+ beta=0.4,
+ )
+ ),
+ ),
+)
+spaceinvaders_r2d2_residual_config = EasyDict(spaceinvaders_r2d2_residual_config)
+main_config = spaceinvaders_r2d2_residual_config
+spaceinvaders_r2d2_residual_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='r2d2'),
+)
+spaceinvaders_r2d2_residual_create_config = EasyDict(spaceinvaders_r2d2_residual_create_config)
+create_config = spaceinvaders_r2d2_residual_create_config
+
+if __name__ == "__main__":
+ # or you can enter ding -m serial -c spaceinvaders_r2d2_residual_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_rainbow_config.py b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_rainbow_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..00be159e71fbdbcd95da9f93000fbe2988c08418
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_rainbow_config.py
@@ -0,0 +1,66 @@
+from copy import deepcopy
+from easydict import EasyDict
+
+spaceinvaders_rainbow_config = dict(
+ exp_name='spaceinvaders_rainbow_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='SpaceInvadersNoFrameskip-v4',
+ #'ALE/SpaceInvaders-v5' is also available, but it needs a special setting after gym.make().
+ frame_stack=4,
+ manager=dict(shared_memory=False, )
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ priority_IS_weight=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ v_min=-10,
+ v_max=10,
+ n_atom=51,
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ iqn=False,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.05,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+)
+spaceinvaders_rainbow_config = EasyDict(spaceinvaders_rainbow_config)
+main_config = spaceinvaders_rainbow_config
+spaceinvaders_rainbow_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='rainbow'),
+)
+spaceinvaders_rainbow_create_config = EasyDict(spaceinvaders_rainbow_create_config)
+create_config = spaceinvaders_rainbow_create_config
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c spaceinvaders_rainbow_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_sqil_config.py b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_sqil_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..84a83a63579e565117478145b03bfc1494eea113
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_sqil_config.py
@@ -0,0 +1,66 @@
+from easydict import EasyDict
+
+spaceinvaders_sqil_config = dict(
+ exp_name='spaceinvaders_sqil_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='SpaceInvadersNoFrameskip-v4',
+ #'ALE/SpaceInvaders-v5' is also available, but it needs a special setting after gym.make().
+ frame_stack=4,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=3,
+ discount_factor=0.97, # discount_factor: 0.97-0.99
+ learn=dict(update_per_collect=10, batch_size=32, learning_rate=0.0001, target_update_freq=500,
+ alpha=0.1), # alpha: 0.08-0.12
+ collect=dict(
+ n_sample=100,
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ model_path='model_path_placeholder',
+ ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+)
+spaceinvaders_sqil_config = EasyDict(spaceinvaders_sqil_config)
+main_config = spaceinvaders_sqil_config
+spaceinvaders_sqil_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='sql'),
+)
+spaceinvaders_sqil_create_config = EasyDict(spaceinvaders_sqil_create_config)
+create_config = spaceinvaders_sqil_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial_sqil -c spaceinvaders_sqil_config.py -s 0`
+ # then input the config you used to generate your expert model in the path mentioned above
+ # e.g. spaceinvaders_dqn_config.py
+ from ding.entry import serial_pipeline_sqil
+ from dizoo.atari.config.serial.spaceinvaders import spaceinvaders_dqn_config, spaceinvaders_dqn_create_config
+ expert_main_config = spaceinvaders_dqn_config
+ expert_create_config = spaceinvaders_dqn_create_config
+ serial_pipeline_sqil([main_config, create_config], [expert_main_config, expert_create_config], seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_sql_config.py b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_sql_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..888eeb1bfe6cc49af41c337620fd119e331de957
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_sql_config.py
@@ -0,0 +1,56 @@
+from copy import deepcopy
+from easydict import EasyDict
+
+spaceinvaders_sql_config = dict(
+ exp_name='spaceinvaders_sql_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='SpaceInvadersNoFrameskip-v4',
+ #'ALE/SpaceInvaders-v5' is also available, but it needs a special setting after gym.make().
+ frame_stack=4,
+ manager=dict(shared_memory=False, reset_inplace=True)
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(update_per_collect=10, batch_size=32, learning_rate=0.0001, target_update_freq=500, alpha=0.1),
+ collect=dict(n_sample=100),
+ eval=dict(evaluator=dict(eval_freq=1000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=500000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+)
+spaceinvaders_sql_config = EasyDict(spaceinvaders_sql_config)
+main_config = spaceinvaders_sql_config
+spaceinvaders_sql_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='sql'),
+)
+spaceinvaders_sql_create_config = EasyDict(spaceinvaders_sql_create_config)
+create_config = spaceinvaders_sql_create_config
+
+if __name__ == '__main__':
+ # or you can enter ding -m serial -c spaceinvaders_sql_config.py -s 0
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_trex_dqn_config.py b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_trex_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a4491d68a92d2e95ef9ed38791c6671f9da0fc4
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_trex_dqn_config.py
@@ -0,0 +1,100 @@
+from copy import deepcopy
+from easydict import EasyDict
+
+spaceinvaders_trex_dqn_config = dict(
+ exp_name='spaceinvaders_trex_dqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='SpaceInvadersNoFrameskip-v4',
+ #'ALE/SpaceInvaders-v5' is also available, but it needs a special setting after gym.make().
+ frame_stack=4,
+ manager=dict(shared_memory=False, )
+ ),
+ reward_model=dict(
+ type='trex',
+ min_snippet_length=50,
+ max_snippet_length=100,
+ checkpoint_min=10000,
+ checkpoint_max=90000,
+ checkpoint_step=10000,
+ num_snippets=100000,
+ learning_rate=1e-5,
+ update_per_collect=1,
+ # path to expert models that generate demonstration data
+ # Users should add their own model path here. Model path should lead to an exp_name.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name``.
+ # For example, if you want to use dqn to generate demos, you can use ``spaceinvaders_dqn``
+ expert_model_path='model_path_placeholder',
+ # path to save reward model
+ # Users should add their own model path here.
+ # Absolute path is recommended.
+ # For example, if you use ``spaceinvaders_drex``, then the reward model will be saved in this directory.
+ reward_model_path='model_path_placeholder + ./spaceinvaders.params',
+ # path to save generated observations.
+ # Users should add their own model path here.
+ # Absolute path is recommended.
+ # For example, if you use ``spaceinvaders_drex``, then all the generated data will be saved in this directory.
+ offline_data_path='data_path_placeholder',
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=1,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=100, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(replay_buffer_size=400000, ),
+ ),
+ ),
+)
+spaceinvaders_trex_dqn_config = EasyDict(spaceinvaders_trex_dqn_config)
+main_config = spaceinvaders_trex_dqn_config
+spaceinvaders_trex_dqn_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqn'),
+)
+spaceinvaders_trex_dqn_create_config = EasyDict(spaceinvaders_trex_dqn_create_config)
+create_config = spaceinvaders_trex_dqn_create_config
+
+if __name__ == '__main__':
+ # Users should first run ``spaceinvaders_dqn_config.py`` to save models (or checkpoints).
+ # Note: users should check that the generated checkpoints include iteration_{checkpoint_min}.pth.tar through iteration_{checkpoint_max}.pth.tar, spaced by checkpoint_step,
+ # where checkpoint_max, checkpoint_min, checkpoint_step are specified above.
+ import argparse
+ import torch
+ from ding.entry import trex_collecting_data
+ from ding.entry import serial_pipeline_trex
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
+ parser.add_argument('--seed', type=int, default=0)
+ parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
+ args = parser.parse_args()
+ # The function ``trex_collecting_data`` below is to collect episodic data for training the reward model in trex.
+ trex_collecting_data(args)
+ serial_pipeline_trex([main_config, create_config])
diff --git a/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_trex_offppo_config.py b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_trex_offppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..7934a67a7ded63e4cb3cdbeef6f737e55e076c1c
--- /dev/null
+++ b/DI-engine/dizoo/atari/config/serial/spaceinvaders/spaceinvaders_trex_offppo_config.py
@@ -0,0 +1,107 @@
+from copy import deepcopy
+from easydict import EasyDict
+
+spaceinvaders_trex_ppo_config = dict(
+ exp_name='spaceinvaders_trex_offppo_seed0',
+ env=dict(
+ collector_env_num=16,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10000000000,
+ env_id='SpaceInvadersNoFrameskip-v4',
+ #'ALE/SpaceInvaders-v5' is also available, but it needs a special setting after gym.make().
+ frame_stack=4,
+ manager=dict(shared_memory=False, )
+ ),
+ reward_model=dict(
+ type='trex',
+ min_snippet_length=30,
+ max_snippet_length=100,
+ checkpoint_min=0,
+ checkpoint_max=100,
+ checkpoint_step=100,
+ learning_rate=1e-5,
+ update_per_collect=1,
+ # path to expert models that generate demonstration data
+ # Users should add their own model path here. Model path should lead to an exp_name.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name``.
+ # For example, if you want to use dqn to generate demos, you can use ``spaceinvaders_dqn``
+ expert_model_path='model_path_placeholder',
+ # path to save reward model
+ # Users should add their own model path here.
+ # Absolute path is recommended.
+ # For example, if you use ``spaceinvaders_drex``, then the reward model will be saved in this directory.
+ reward_model_path='model_path_placeholder + ./spaceinvaders.params',
+ # path to save generated observations.
+ # Users should add their own model path here.
+ # Absolute path is recommended.
+ # For example, if you use ``spaceinvaders_drex``, then all the generated data will be saved in this directory.
+ offline_data_path='data_path_placeholder',
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=6,
+ encoder_hidden_size_list=[32, 64, 64, 128],
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ critic_head_layer_num=2,
+ ),
+ learn=dict(
+ update_per_collect=24,
+ batch_size=128,
+ # (bool) Whether to normalize advantage. Defaults to False.
+ adv_norm=False,
+ learning_rate=0.0001,
+ # (float) loss weight of the value network, the weight of policy network is set to 1
+ value_weight=1.0,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.03,
+ clip_ratio=0.1,
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ n_sample=1024,
+ # (float) the trade-off factor lambda to balance 1-step TD and MC returns
+ gae_lambda=0.95,
+ discount_factor=0.99,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, )),
+ other=dict(replay_buffer=dict(
+ replay_buffer_size=100000,
+ max_use=5,
+ ), ),
+ ),
+)
+spaceinvaders_trex_ppo_config = EasyDict(spaceinvaders_trex_ppo_config)
+main_config = spaceinvaders_trex_ppo_config
+
+spaceinvaders_trex_ppo_create_config = dict(
+ env=dict(
+ type='atari',
+ import_names=['dizoo.atari.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo_offpolicy'),
+)
+spaceinvaders_trex_ppo_create_config = EasyDict(spaceinvaders_trex_ppo_create_config)
+create_config = spaceinvaders_trex_ppo_create_config
+
+if __name__ == '__main__':
+ # Users should first run ``spaceinvaders_offppo_config.py`` to save models (or checkpoints).
+ # Note: check that the generated checkpoints include iteration_'checkpoint_min'.pth.tar and
+ # iteration_'checkpoint_max'.pth.tar, spaced by checkpoint_step, where checkpoint_min,
+ # checkpoint_max and checkpoint_step are specified above.
+ import argparse
+ import torch
+ from ding.entry import trex_collecting_data
+ from ding.entry import serial_pipeline_trex
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
+ parser.add_argument('--seed', type=int, default=0)
+ parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
+ args = parser.parse_args()
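+ # Optional sanity check (illustrative sketch, not part of the original entry script): verify that the
+ # expert checkpoints required by TREX exist, assuming the default DI-engine checkpoint layout
+ # ``<expert_model_path>/ckpt/iteration_<k>.pth.tar``.
+ import os
+ rm_cfg = main_config.reward_model
+ for k in range(rm_cfg.checkpoint_min, rm_cfg.checkpoint_max + 1, rm_cfg.checkpoint_step):
+     ckpt = os.path.join(rm_cfg.expert_model_path, 'ckpt', 'iteration_{}.pth.tar'.format(k))
+     assert os.path.exists(ckpt), 'missing expert checkpoint: {}'.format(ckpt)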
+ # The function ``trex_collecting_data`` below collects episodic data for training the TREX reward model.
+ trex_collecting_data(args)
+ serial_pipeline_trex([main_config, create_config])
diff --git a/DI-engine/dizoo/atari/entry/__init__.py b/DI-engine/dizoo/atari/entry/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/atari/entry/atari_dqn_main.py b/DI-engine/dizoo/atari/entry/atari_dqn_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..930d0ac64d0ab1c0ace05a242fc2ba581da2dcbf
--- /dev/null
+++ b/DI-engine/dizoo/atari/entry/atari_dqn_main.py
@@ -0,0 +1,75 @@
+import os
+import gym
+from tensorboardX import SummaryWriter
+from easydict import EasyDict
+from copy import deepcopy
+from functools import partial
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.envs import SyncSubprocessEnvManager
+from ding.policy import DQNPolicy
+from ding.model import DQN
+from ding.utils import set_pkg_seed, deep_merge_dicts
+from ding.rl_utils import get_epsilon_greedy_fn
+from dizoo.atari.envs import AtariEnv
+from dizoo.atari.config.serial.pong.pong_dqn_config import pong_dqn_config
+
+
+def main(cfg, seed=0, max_iterations=int(1e10)):
+ cfg = compile_config(
+ cfg,
+ SyncSubprocessEnvManager,
+ DQNPolicy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator,
+ AdvancedReplayBuffer,
+ save_cfg=True
+ )
+ collector_env_cfg = AtariEnv.create_collector_env_cfg(cfg.env)
+ evaluator_env_cfg = AtariEnv.create_evaluator_env_cfg(cfg.env)
+ collector_env = SyncSubprocessEnvManager(
+ env_fn=[partial(AtariEnv, cfg=c) for c in collector_env_cfg], cfg=cfg.env.manager
+ )
+ evaluator_env = SyncSubprocessEnvManager(
+ env_fn=[partial(AtariEnv, cfg=c) for c in evaluator_env_cfg], cfg=cfg.env.manager
+ )
+
+ collector_env.seed(seed)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ model = DQN(**cfg.policy.model)
+ policy = DQNPolicy(cfg.policy, model=model)
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = SampleSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = AdvancedReplayBuffer(
+ cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name, instance_name='replay_buffer'
+ )
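+ # Epsilon-greedy exploration schedule, annealed according to the number of collected env steps.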
+ eps_cfg = cfg.policy.other.eps
+ epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
+
+ while True:
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ eps = epsilon_greedy(collector.envstep)
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ for i in range(cfg.policy.learn.update_per_collect):
+ batch_size = learner.policy.get_attribute('batch_size')
+ train_data = replay_buffer.sample(batch_size, learner.train_iter)
+ if train_data is not None:
+ learner.train(train_data, collector.envstep)
+
+
+if __name__ == "__main__":
+ main(EasyDict(pong_dqn_config))
diff --git a/DI-engine/dizoo/atari/entry/atari_dt_main.py b/DI-engine/dizoo/atari/entry/atari_dt_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..b89bbaec7ee13530b6e87136d6b59607ac04d67f
--- /dev/null
+++ b/DI-engine/dizoo/atari/entry/atari_dt_main.py
@@ -0,0 +1,56 @@
+import torch.nn as nn
+import torch.distributed as dist
+from ditk import logging
+from ding.model import DecisionTransformer
+from ding.policy import DTPolicy
+from ding.envs import SubprocessEnvManagerV2
+from ding.envs import AllinObsWrapper
+from ding.data import create_dataset
+from ding.config import compile_config
+from ding.framework import task, ding_init
+from ding.framework.context import OfflineRLContext
+from ding.framework.middleware import interaction_evaluator, trainer, CkptSaver, offline_logger, termination_checker, \
+ OfflineMemoryDataFetcher
+from ding.utils import set_pkg_seed, DDPContext, to_ddp_config
+from dizoo.atari.envs import AtariEnv
+from dizoo.atari.config.serial.pong.pong_dt_config import main_config, create_config
+
+
+def main():
+ # If you don't have offline data, you need to prepare it first and set the data_path in the config.
+ # For demonstration, you can also train an RL policy (e.g. SAC) and collect some data with it.
+ logging.getLogger().setLevel(logging.INFO)
+ with DDPContext():
+ cmain_config = to_ddp_config(main_config)
+ cfg = compile_config(cmain_config, create_cfg=create_config, auto=True)
+ ding_init(cfg)
+ with task.start(async_mode=False, ctx=OfflineRLContext()):
+ evaluator_env = SubprocessEnvManagerV2(
+ env_fn=[lambda: AllinObsWrapper(AtariEnv(cfg.env)) for _ in range(cfg.env.evaluator_env_num)],
+ cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+ dataset = create_dataset(cfg)
+ cfg.policy.model.max_timestep = dataset.get_max_timestep()
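+ # State encoder (Nature-DQN style conv stack): maps stacked 4x84x84 Atari frames to an
+ # ``h_dim``-dimensional embedding that the Decision Transformer consumes as its state token.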
+ state_encoder = nn.Sequential(
+ nn.Conv2d(4, 32, 8, stride=4, padding=0), nn.ReLU(), nn.Conv2d(32, 64, 4, stride=2, padding=0),
+ nn.ReLU(), nn.Conv2d(64, 64, 3, stride=1, padding=0), nn.ReLU(), nn.Flatten(),
+ nn.Linear(3136, cfg.policy.model.h_dim), nn.Tanh()
+ )
+
+ model = DecisionTransformer(**cfg.policy.model, state_encoder=state_encoder)
+ # model.parallelize()
+ policy = DTPolicy(cfg.policy, model=model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(OfflineMemoryDataFetcher(cfg, dataset))
+ task.use(trainer(cfg, policy.learn_mode))
+ task.use(termination_checker(max_train_iter=3e4))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.use(offline_logger())
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/dizoo/atari/entry/atari_impala_main.py b/DI-engine/dizoo/atari/entry/atari_impala_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..69dce9f8a49d309dede64158c70c0e52e586f196
--- /dev/null
+++ b/DI-engine/dizoo/atari/entry/atari_impala_main.py
@@ -0,0 +1,51 @@
+import gym
+from ditk import logging
+from ding.model import VAC
+from ding.policy import IMPALAPolicy
+from ding.envs import SubprocessEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task, ding_init
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
+ CkptSaver, online_logger, termination_checker
+from ding.utils import set_pkg_seed
+from dizoo.atari.config.serial.pong.pong_impala_config import main_config, create_config
+from dizoo.atari.envs import AtariEnv
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ ding_init(cfg)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env_cfg = AtariEnv.create_collector_env_cfg(cfg.env)
+ evaluator_env_cfg = AtariEnv.create_evaluator_env_cfg(cfg.env)
+ collector_env = SubprocessEnvManagerV2(
+ env_fn=[lambda: AtariEnv(cfg=c) for c in collector_env_cfg], cfg=cfg.env.manager
+ )
+ evaluator_env = SubprocessEnvManagerV2(
+ env_fn=[lambda: AtariEnv(cfg=c) for c in evaluator_env_cfg], cfg=cfg.env.manager
+ )
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = VAC(**cfg.policy.model)
+ buffer_ = DequeBuffer(
+ size=cfg.policy.other.replay_buffer.replay_buffer_size, sliced=cfg.policy.other.replay_buffer.sliced
+ )
+ policy = IMPALAPolicy(cfg.policy, model=model)
+
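+ # Middleware pipeline: evaluate periodically, collect steps, push them to the buffer,
+ # train off-policy, log online metrics, save checkpoints, and stop after 1e7 env steps.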
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(
+ StepCollector(cfg, policy.collect_mode, collector_env, random_collect_size=cfg.policy.random_collect_size)
+ )
+ task.use(data_pusher(cfg, buffer_, group_by_env=True))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ task.use(online_logger(train_show_freq=300))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=10000))
+ task.use(termination_checker(max_env_step=1e7))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/dizoo/atari/entry/atari_ppg_main.py b/DI-engine/dizoo/atari/entry/atari_ppg_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..25236c3f670f281d10bfba814a3b8ab40cd03dc4
--- /dev/null
+++ b/DI-engine/dizoo/atari/entry/atari_ppg_main.py
@@ -0,0 +1,82 @@
+import os
+import gym
+from tensorboardX import SummaryWriter
+from easydict import EasyDict
+from copy import deepcopy
+from functools import partial
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.envs import SyncSubprocessEnvManager
+from ding.policy import PPGPolicy
+from ding.model import PPG
+from ding.utils import set_pkg_seed, deep_merge_dicts
+from dizoo.atari.envs import AtariEnv
+from dizoo.atari.config.serial.spaceinvaders.spaceinvaders_ppg_config import spaceinvaders_ppg_config
+
+
+def main(cfg, seed=0, max_iterations=int(1e10)):
+ cfg.exp_name = 'spaceinvaders_ppg_seed0'
+ cfg = compile_config(
+ cfg,
+ SyncSubprocessEnvManager,
+ PPGPolicy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator, {
+ 'policy': AdvancedReplayBuffer,
+ 'value': AdvancedReplayBuffer
+ },
+ save_cfg=True
+ )
+ collector_env_cfg = AtariEnv.create_collector_env_cfg(cfg.env)
+ evaluator_env_cfg = AtariEnv.create_evaluator_env_cfg(cfg.env)
+ collector_env = SyncSubprocessEnvManager(
+ env_fn=[partial(AtariEnv, cfg=c) for c in collector_env_cfg], cfg=cfg.env.manager
+ )
+ evaluator_env = SyncSubprocessEnvManager(
+ env_fn=[partial(AtariEnv, cfg=c) for c in evaluator_env_cfg], cfg=cfg.env.manager
+ )
+
+ collector_env.seed(seed)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ model = PPG(**cfg.policy.model)
+ policy = PPGPolicy(cfg.policy, model=model)
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = SampleSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ policy_buffer = AdvancedReplayBuffer(
+ cfg.policy.other.replay_buffer.policy, tb_logger, exp_name=cfg.exp_name, instance_name='policy_buffer'
+ )
+ value_buffer = AdvancedReplayBuffer(
+ cfg.policy.other.replay_buffer.value, tb_logger, exp_name=cfg.exp_name, instance_name='value_buffer'
+ )
+
+ while True:
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ new_data = collector.collect(train_iter=learner.train_iter)
+ policy_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ value_buffer.push(deepcopy(new_data), cur_collector_envstep=collector.envstep)
+ for i in range(cfg.policy.learn.update_per_collect):
+ batch_size = learner.policy.get_attribute('batch_size')
+ policy_data = policy_buffer.sample(batch_size['policy'], learner.train_iter)
+ value_data = value_buffer.sample(batch_size['value'], learner.train_iter)
+ if policy_data is not None and value_data is not None:
+ train_data = {'policy': policy_data, 'value': value_data}
+ learner.train(train_data, collector.envstep)
+ policy_buffer.clear()
+ value_buffer.clear()
+
+
+if __name__ == "__main__":
+ main(EasyDict(spaceinvaders_ppg_config))
diff --git a/DI-engine/dizoo/atari/entry/phoenix_fqf_main.py b/DI-engine/dizoo/atari/entry/phoenix_fqf_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..7d8ab5a3841a44d42360d13cf95bea5ede8eb858
--- /dev/null
+++ b/DI-engine/dizoo/atari/entry/phoenix_fqf_main.py
@@ -0,0 +1,74 @@
+import os
+import torch
+from tensorboardX import SummaryWriter
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.policy import FQFPolicy
+from ding.model import FQF
+from ding.utils import set_pkg_seed
+from ding.rl_utils import get_epsilon_greedy_fn
+from dizoo.atari.config.serial.phoenix.phoenix_fqf_config import phoenix_fqf_config, create_config
+from ding.utils import DistContext
+from functools import partial
+from ding.envs import get_vec_env_setting, create_env_manager
+
+
+def main(cfg, create_cfg, seed=0):
+
+ cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
+
+ # Create main components: env, policy
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+
+ # Set random seed for all packages and instances
+ collector_env.seed(seed)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ # Set up RL Policy
+ model = FQF(**cfg.policy.model)
+ policy = FQFPolicy(cfg.policy, model=model)
+
+ # Set up collection, training and evaluation utilities
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = SampleSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
+
+ # Set up other modules, e.g. epsilon greedy
+ eps_cfg = cfg.policy.other.eps
+ epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
+
+ # Training & Evaluation loop
+ while True:
+ # Evaluating at the beginning and with specific frequency
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Update other modules
+ eps = epsilon_greedy(collector.envstep)
+ # Sampling data from environments
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ # Training
+ for i in range(cfg.policy.learn.update_per_collect):
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is None:
+ break
+ learner.train(train_data, collector.envstep)
+
+ if collector.envstep >= int(1e7):
+ break
+
+
+if __name__ == "__main__":
+ # with DistContext():
+ main(phoenix_fqf_config, create_config)
diff --git a/DI-engine/dizoo/atari/entry/phoenix_iqn_main.py b/DI-engine/dizoo/atari/entry/phoenix_iqn_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..91f528505d5f21a97cbd4095cf1e414adcd118ab
--- /dev/null
+++ b/DI-engine/dizoo/atari/entry/phoenix_iqn_main.py
@@ -0,0 +1,74 @@
+import os
+import torch
+from tensorboardX import SummaryWriter
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.policy import IQNPolicy
+from ding.model import IQN
+from ding.utils import set_pkg_seed
+from ding.rl_utils import get_epsilon_greedy_fn
+from dizoo.atari.config.serial.phoenix.phoenix_iqn_config import phoenix_iqn_config, create_config
+from ding.utils import DistContext
+from functools import partial
+from ding.envs import get_vec_env_setting, create_env_manager
+
+
+def main(cfg, create_cfg, seed=0):
+
+ cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
+
+ # Create main components: env, policy
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+
+ # Set random seed for all packages and instances
+ collector_env.seed(seed)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ # Set up RL Policy
+ model = IQN(**cfg.policy.model)
+ policy = IQNPolicy(cfg.policy, model=model)
+
+ # Set up collection, training and evaluation utilities
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = SampleSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
+
+ # Set up other modules, e.g. epsilon greedy
+ eps_cfg = cfg.policy.other.eps
+ epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
+
+ # Training & Evaluation loop
+ while True:
+ # Evaluating at the beginning and with specific frequency
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Update other modules
+ eps = epsilon_greedy(collector.envstep)
+ # Sampling data from environments
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ # Training
+ for i in range(cfg.policy.learn.update_per_collect):
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is None:
+ break
+ learner.train(train_data, collector.envstep)
+
+ if collector.envstep >= int(1e7):
+ break
+
+
+if __name__ == "__main__":
+ # with DistContext():
+ main(phoenix_iqn_config, create_config)
diff --git a/DI-engine/dizoo/atari/entry/pong_cql_main.py b/DI-engine/dizoo/atari/entry/pong_cql_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..fca48a1ff696438a73fdcf7dceba42167bef4e18
--- /dev/null
+++ b/DI-engine/dizoo/atari/entry/pong_cql_main.py
@@ -0,0 +1,55 @@
+import torch
+from copy import deepcopy
+
+from ding.entry import serial_pipeline_offline, collect_demo_data, eval, serial_pipeline
+
+
+def train_cql(args):
+ from dizoo.atari.config.serial.pong.pong_cql_config import main_config, create_config
+ main_config.exp_name = 'pong_cql'
+ main_config.policy.collect.data_path = './pong/expert_demos.hdf5'
+ main_config.policy.collect.data_type = 'hdf5'
+ config = deepcopy([main_config, create_config])
+ serial_pipeline_offline(config, seed=args.seed)
+
+
+def eval_ckpt(args):
+ from dizoo.atari.config.serial.pong.pong_qrdqn_generation_data_config import main_config, create_config
+ main_config.exp_name = 'pong'
+ config = deepcopy([main_config, create_config])
+ eval(config, seed=args.seed, load_path='./pong/ckpt/ckpt_best.pth.tar')
+
+
+def generate(args):
+ from dizoo.atari.config.serial.pong.pong_qrdqn_generation_data_config import main_config, create_config
+ main_config.exp_name = 'pong'
+ main_config.policy.collect.save_path = './pong/expert.pkl'
+ config = deepcopy([main_config, create_config])
+ state_dict = torch.load('./pong/ckpt/ckpt_best.pth.tar', map_location='cpu')
+ collect_demo_data(
+ config,
+ collect_count=int(1e5),
+ seed=args.seed,
+ expert_data_path=main_config.policy.collect.save_path,
+ state_dict=state_dict
+ )
+
+
+def train_expert(args):
+ from dizoo.atari.config.serial.pong.pong_qrdqn_config import main_config, create_config
+ main_config.exp_name = 'pong'
+ config = deepcopy([main_config, create_config])
+ serial_pipeline(config, seed=args.seed, max_iterations=1e6)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--seed', '-s', type=int, default=10)
+ args = parser.parse_args()
+
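+ # Full pipeline: train a QR-DQN expert, evaluate its best checkpoint, generate expert
+ # demonstration data with it, and finally train CQL offline on that data.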
+ train_expert(args)
+ eval_ckpt(args)
+ generate(args)
+ train_cql(args)
diff --git a/DI-engine/dizoo/atari/entry/pong_dqn_envpool_main.py b/DI-engine/dizoo/atari/entry/pong_dqn_envpool_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..769fe4f261447b8d5843f40658495acb4ebb18d4
--- /dev/null
+++ b/DI-engine/dizoo/atari/entry/pong_dqn_envpool_main.py
@@ -0,0 +1,91 @@
+import os
+from easydict import EasyDict
+from tensorboardX import SummaryWriter
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.envs.env_manager.envpool_env_manager import PoolEnvManager
+from ding.policy import DQNPolicy
+from ding.model import DQN
+from ding.utils import set_pkg_seed
+from ding.rl_utils import get_epsilon_greedy_fn
+from dizoo.atari.config.serial import pong_dqn_envpool_config
+
+
+def main(cfg, seed=0, max_iterations=int(1e10)):
+ cfg.exp_name = 'atari_dqn_envpool'
+ cfg = compile_config(
+ cfg,
+ PoolEnvManager,
+ DQNPolicy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator,
+ AdvancedReplayBuffer,
+ save_cfg=True
+ )
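+ # envpool applies the Atari wrappers natively; episodic life and reward clipping are enabled
+ # only for the collector so that the evaluator reports raw, unclipped episode returns.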
+ collector_env_cfg = EasyDict(
+ {
+ 'env_id': cfg.env.env_id,
+ 'env_num': cfg.env.collector_env_num,
+ 'batch_size': cfg.env.collector_batch_size,
+ # env wrappers
+ 'episodic_life': True, # collector: True
+ 'reward_clip': True, # collector: True
+ 'gray_scale': cfg.env.get('gray_scale', True),
+ 'stack_num': cfg.env.get('stack_num', 4),
+ 'frame_skip': cfg.env.get('frame_skip', 4),
+ }
+ )
+ collector_env = PoolEnvManager(collector_env_cfg)
+ evaluator_env_cfg = EasyDict(
+ {
+ 'env_id': cfg.env.env_id,
+ 'env_num': cfg.env.evaluator_env_num,
+ 'batch_size': cfg.env.evaluator_batch_size,
+ # env wrappers
+ 'episodic_life': False, # evaluator: False
+ 'reward_clip': False, # evaluator: False
+ 'gray_scale': cfg.env.get('gray_scale', True),
+ 'stack_num': cfg.env.get('stack_num', 4),
+ 'frame_skip': cfg.env.get('frame_skip', 4),
+ }
+ )
+ evaluator_env = PoolEnvManager(evaluator_env_cfg)
+ collector_env.seed(seed)
+ evaluator_env.seed(seed)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ model = DQN(**cfg.policy.model)
+ policy = DQNPolicy(cfg.policy, model=model)
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = SampleSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = AdvancedReplayBuffer(
+ cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name, instance_name='replay_buffer'
+ )
+ eps_cfg = cfg.policy.other.eps
+ epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
+
+ while True:
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ eps = epsilon_greedy(collector.envstep)
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ for i in range(cfg.policy.learn.update_per_collect):
+ batch_size = learner.policy.get_attribute('batch_size')
+ train_data = replay_buffer.sample(batch_size, learner.train_iter)
+ if train_data is not None:
+ learner.train(train_data, collector.envstep)
+
+
+if __name__ == "__main__":
+ main(EasyDict(pong_dqn_envpool_config))
diff --git a/DI-engine/dizoo/atari/entry/pong_fqf_main.py b/DI-engine/dizoo/atari/entry/pong_fqf_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..816ec566c4b566505a1ae22a3064e00c11c00aee
--- /dev/null
+++ b/DI-engine/dizoo/atari/entry/pong_fqf_main.py
@@ -0,0 +1,74 @@
+import os
+import torch
+from tensorboardX import SummaryWriter
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.policy import FQFPolicy
+from ding.model import FQF
+from ding.utils import set_pkg_seed
+from ding.rl_utils import get_epsilon_greedy_fn
+from dizoo.atari.config.serial.pong.pong_fqf_config import pong_fqf_config, create_config
+from ding.utils import DistContext
+from functools import partial
+from ding.envs import get_vec_env_setting, create_env_manager
+
+
+def main(cfg, create_cfg, seed=0):
+
+ cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
+
+ # Create main components: env, policy
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+
+ # Set random seed for all packages and instances
+ collector_env.seed(seed)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ # Set up RL Policy
+ model = FQF(**cfg.policy.model)
+ policy = FQFPolicy(cfg.policy, model=model)
+
+ # Set up collection, training and evaluation utilities
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = SampleSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
+
+ # Set up other modules, e.g. epsilon greedy
+ eps_cfg = cfg.policy.other.eps
+ epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
+
+ # Training & Evaluation loop
+ while True:
+ # Evaluating at the beginning and with specific frequency
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Update other modules
+ eps = epsilon_greedy(collector.envstep)
+ # Sampling data from environments
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ # Training
+ for i in range(cfg.policy.learn.update_per_collect):
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is None:
+ break
+ learner.train(train_data, collector.envstep)
+
+ if collector.envstep >= 10000000:
+ break
+
+
+if __name__ == "__main__":
+ # with DistContext():
+ main(pong_fqf_config, create_config)
diff --git a/DI-engine/dizoo/atari/entry/qbert_cql_main.py b/DI-engine/dizoo/atari/entry/qbert_cql_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ce66ac226cdda8e6f04307f95c05718a392f3e3
--- /dev/null
+++ b/DI-engine/dizoo/atari/entry/qbert_cql_main.py
@@ -0,0 +1,57 @@
+import torch
+from copy import deepcopy
+
+from dizoo.atari.config.serial.qbert.qbert_qrdqn_generation_data_config import main_config, create_config
+from ding.entry import serial_pipeline_offline, collect_demo_data, eval, serial_pipeline
+
+
+def train_cql(args):
+ from dizoo.atari.config.serial.qbert.qbert_cql_config import main_config, create_config
+ main_config.exp_name = 'qbert_cql_num_200_weight_10'
+ main_config.policy.collect.data_path = './qbert/expert_demos.hdf5'
+ main_config.policy.collect.data_type = 'hdf5'
+ config = deepcopy([main_config, create_config])
+ serial_pipeline_offline(config, seed=args.seed)
+
+
+def eval_ckpt(args):
+ main_config.exp_name = 'qbert'
+ main_config.policy.learn.learner.load_path = './qbert/ckpt/ckpt_best.pth.tar'
+ main_config.policy.learn.learner.hook.load_ckpt_before_run = './qbert/ckpt/ckpt_best.pth.tar'
+ config = deepcopy([main_config, create_config])
+ eval(config, seed=args.seed, load_path=main_config.policy.learn.learner.hook.load_ckpt_before_run)
+
+
+def generate(args):
+ main_config.exp_name = 'qbert'
+ main_config.policy.learn.learner.load_path = './qbert/ckpt/ckpt_best.pth.tar'
+ main_config.policy.collect.save_path = './qbert/expert.pkl'
+ config = deepcopy([main_config, create_config])
+ state_dict = torch.load(main_config.policy.learn.learner.load_path, map_location='cpu')
+ collect_demo_data(
+ config,
+ collect_count=main_config.policy.other.replay_buffer.replay_buffer_size,
+ seed=args.seed,
+ expert_data_path=main_config.policy.collect.save_path,
+ state_dict=state_dict
+ )
+
+
+def train_expert(args):
+ from dizoo.atari.config.serial.qbert.qbert_qrdqn_config import main_config, create_config
+ main_config.exp_name = 'qbert'
+ config = deepcopy([main_config, create_config])
+ serial_pipeline(config, seed=args.seed, max_iterations=2e6)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--seed', '-s', type=int, default=10)
+ args = parser.parse_args()
+
+ train_expert(args)
+ eval_ckpt(args)
+ generate(args)
+ train_cql(args)
diff --git a/DI-engine/dizoo/atari/entry/qbert_fqf_main.py b/DI-engine/dizoo/atari/entry/qbert_fqf_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c87c549a89c20ac15fc2983ee3cb47bc969ae2f
--- /dev/null
+++ b/DI-engine/dizoo/atari/entry/qbert_fqf_main.py
@@ -0,0 +1,74 @@
+import os
+import torch
+from tensorboardX import SummaryWriter
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.policy import FQFPolicy
+from ding.model import FQF
+from ding.utils import set_pkg_seed
+from ding.rl_utils import get_epsilon_greedy_fn
+from dizoo.atari.config.serial.qbert.qbert_fqf_config import qbert_fqf_config, create_config
+from ding.utils import DistContext
+from functools import partial
+from ding.envs import get_vec_env_setting, create_env_manager
+
+
+def main(cfg, create_cfg, seed=0):
+
+ cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
+
+ # Create main components: env, policy
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+
+ # Set random seed for all packages and instances
+ collector_env.seed(seed)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ # Set up RL Policy
+ model = FQF(**cfg.policy.model)
+ policy = FQFPolicy(cfg.policy, model=model)
+
+ # Set up collection, training and evaluation utilities
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = SampleSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
+
+ # Set up other modules, e.g. epsilon greedy
+ eps_cfg = cfg.policy.other.eps
+ epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
+
+ # Training & Evaluation loop
+ while True:
+ # Evaluating at the beginning and with specific frequency
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Update other modules
+ eps = epsilon_greedy(collector.envstep)
+ # Sampling data from environments
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ # Training
+ for i in range(cfg.policy.learn.update_per_collect):
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is None:
+ break
+ learner.train(train_data, collector.envstep)
+
+ if collector.envstep >= 10000000:
+ break
+
+
+if __name__ == "__main__":
+ # with DistContext():
+ main(qbert_fqf_config, create_config)
diff --git a/DI-engine/dizoo/atari/entry/spaceinvaders_dqn_eval.py b/DI-engine/dizoo/atari/entry/spaceinvaders_dqn_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..35e15a578cfeb9c1b169c1e4577baa0abda9e9c4
--- /dev/null
+++ b/DI-engine/dizoo/atari/entry/spaceinvaders_dqn_eval.py
@@ -0,0 +1,60 @@
+import os
+import gym
+import torch
+from tensorboardX import SummaryWriter
+from easydict import EasyDict
+from functools import partial
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.envs import BaseEnvManager, DingEnvWrapper
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.policy import DQNPolicy
+from ding.model import DQN
+from ding.utils import set_pkg_seed
+from ding.rl_utils import get_epsilon_greedy_fn
+from dizoo.atari.config.serial.spaceinvaders.spaceinvaders_dqn_config import main_config, create_config
+
+
+def main(rl_cfg, seed=0):
+ main_cfg, create_cfg = rl_cfg
+ cfg = compile_config(
+ main_cfg,
+ BaseEnvManager,
+ DQNPolicy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator,
+ AdvancedReplayBuffer,
+ create_cfg=create_cfg,
+ save_cfg=True
+ )
+
+ create_cfg.policy.type = create_cfg.policy.type + '_command'
+ env_fn = None
+ cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
+ # Create main components: env, policy
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+
+ evaluator_env.enable_save_replay(cfg.env.replay_path)  # enable the replay-saving interface of the evaluator env
+
+ # Set random seed for all packages and instances
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ # Set up RL Policy
+ model = DQN(**cfg.policy.model)
+ policy = DQNPolicy(cfg.policy, model=model)
+ policy.eval_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cpu'))
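+ # Note: ``cfg.policy.load_path`` must point to a trained DQN checkpoint before running this script.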
+
+ # evaluate
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator.eval()
+
+
+if __name__ == "__main__":
+ main(rl_cfg=(main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/atari/entry/spaceinvaders_dqn_main_multi_gpu_ddp.py b/DI-engine/dizoo/atari/entry/spaceinvaders_dqn_main_multi_gpu_ddp.py
new file mode 100644
index 0000000000000000000000000000000000000000..18b2269583de52feb137e57f29e352cdc5d7536a
--- /dev/null
+++ b/DI-engine/dizoo/atari/entry/spaceinvaders_dqn_main_multi_gpu_ddp.py
@@ -0,0 +1,71 @@
+import os
+import torch
+from tensorboardX import SummaryWriter
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.policy import DQNPolicy
+from ding.model import DQN
+from ding.utils import set_pkg_seed
+from ding.rl_utils import get_epsilon_greedy_fn
+from dizoo.atari.config.serial.spaceinvaders.spaceinvaders_dqn_config_multi_gpu_ddp import spaceinvaders_dqn_config, create_config
+from ding.utils import DistContext
+from functools import partial
+from ding.envs import get_vec_env_setting, create_env_manager
+
+
+def main(cfg, create_cfg, seed=0):
+
+ cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
+
+ # Create main components: env, policy
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+
+ # Set random seed for all packages and instances
+ collector_env.seed(seed)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ # Set up RL Policy
+ model = DQN(**cfg.policy.model)
+ policy = DQNPolicy(cfg.policy, model=model)
+
+ # Set up collection, training and evaluation utilities
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = SampleSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
+
+ # Set up other modules, e.g. epsilon greedy
+ eps_cfg = cfg.policy.other.eps
+ epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
+
+ # Training & Evaluation loop
+ while True:
+ # Evaluating at the beginning and with specific frequency
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Update other modules
+ eps = epsilon_greedy(collector.envstep)
+ # Sampling data from environments
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ # Training
+ for i in range(cfg.policy.learn.update_per_collect):
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is None:
+ break
+ learner.train(train_data, collector.envstep)
+
+
+if __name__ == "__main__":
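+ # DistContext sets up and tears down the distributed (DDP) training context; this entry is
+ # expected to be launched with one process per GPU (e.g. via a torch.distributed launcher).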
+ with DistContext():
+ main(spaceinvaders_dqn_config, create_config)
diff --git a/DI-engine/dizoo/atari/entry/spaceinvaders_dqn_main_multi_gpu_dp.py b/DI-engine/dizoo/atari/entry/spaceinvaders_dqn_main_multi_gpu_dp.py
new file mode 100644
index 0000000000000000000000000000000000000000..229e25c0e981bee1b64d05fe6b4fadfd7356eb11
--- /dev/null
+++ b/DI-engine/dizoo/atari/entry/spaceinvaders_dqn_main_multi_gpu_dp.py
@@ -0,0 +1,71 @@
+import os
+import torch
+from tensorboardX import SummaryWriter
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.policy import DQNPolicy
+from ding.model import DQN
+from ding.utils import set_pkg_seed
+from ding.rl_utils import get_epsilon_greedy_fn
+from dizoo.atari.config.serial.spaceinvaders.spaceinvaders_dqn_config_multi_gpu_dp import spaceinvaders_dqn_config, create_config
+from ding.torch_utils import DataParallel
+from functools import partial
+from ding.envs import get_vec_env_setting, create_env_manager
+
+
+def main(cfg, create_cfg, seed=0):
+
+ cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
+
+ # Create main components: env, policy
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+
+ # Set random seed for all packages and instances
+ collector_env.seed(seed)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ # Set up RL Policy
+ model = DQN(**cfg.policy.model)
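+ # DataParallel replicates the model and splits each training batch across the visible GPUs
+ # within a single process (data-parallel training, as opposed to the DDP entry above).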
+ model = DataParallel(model)
+ policy = DQNPolicy(cfg.policy, model=model)
+
+ # Set up collection, training and evaluation utilities
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = SampleSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
+
+ # Set up other modules, e.g. epsilon greedy
+ eps_cfg = cfg.policy.other.eps
+ epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
+
+ # Training & Evaluation loop
+ while True:
+ # Evaluating at the beginning and with specific frequency
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Update other modules
+ eps = epsilon_greedy(collector.envstep)
+ # Sampling data from environments
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ # Training
+ for i in range(cfg.policy.learn.update_per_collect):
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is None:
+ break
+ learner.train(train_data, collector.envstep)
+
+
+if __name__ == "__main__":
+ main(spaceinvaders_dqn_config, create_config)
diff --git a/DI-engine/dizoo/atari/entry/spaceinvaders_fqf_main.py b/DI-engine/dizoo/atari/entry/spaceinvaders_fqf_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c46e87019ed240e2c827b5d28b266ce2f3586f3
--- /dev/null
+++ b/DI-engine/dizoo/atari/entry/spaceinvaders_fqf_main.py
@@ -0,0 +1,74 @@
+import os
+import torch
+from tensorboardX import SummaryWriter
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.policy import FQFPolicy
+from ding.model import FQF
+from ding.utils import set_pkg_seed
+from ding.rl_utils import get_epsilon_greedy_fn
+from dizoo.atari.config.serial.spaceinvaders.spaceinvaders_fqf_config import spaceinvaders_fqf_config, create_config
+from ding.utils import DistContext
+from functools import partial
+from ding.envs import get_vec_env_setting, create_env_manager
+
+
+def main(cfg, create_cfg, seed=0):
+
+ cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
+
+ # Create main components: env, policy
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+
+ # Set random seed for all packages and instances
+ collector_env.seed(seed)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ # Set up RL Policy
+ model = FQF(**cfg.policy.model)
+ policy = FQFPolicy(cfg.policy, model=model)
+
+ # Set up collection, training and evaluation utilities
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = SampleSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
+
+ # Set up other modules, e.g. epsilon greedy
+ eps_cfg = cfg.policy.other.eps
+ epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
+
+ # Training & Evaluation loop
+ while True:
+ # Evaluating at the beginning and with specific frequency
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Update other modules
+ eps = epsilon_greedy(collector.envstep)
+ # Sampling data from environments
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ # Training
+ for i in range(cfg.policy.learn.update_per_collect):
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is None:
+ break
+ learner.train(train_data, collector.envstep)
+
+ if collector.envstep >= 10000000:
+ break
+
+
+if __name__ == "__main__":
+ # with DistContext():
+ main(spaceinvaders_fqf_config, create_config)
diff --git a/DI-engine/dizoo/atari/envs/__init__.py b/DI-engine/dizoo/atari/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f94505654d0dd14ce9a37e046c12302e36ae3111
--- /dev/null
+++ b/DI-engine/dizoo/atari/envs/__init__.py
@@ -0,0 +1 @@
+from .atari_env import AtariEnv, AtariEnvMR
diff --git a/DI-engine/dizoo/atari/envs/atari_env.py b/DI-engine/dizoo/atari/envs/atari_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8c3298225907edfef9e8448c8a8f7c10d485b35
--- /dev/null
+++ b/DI-engine/dizoo/atari/envs/atari_env.py
@@ -0,0 +1,145 @@
+from typing import Any, List, Union, Sequence, Optional
+import copy
+import numpy as np
+import gym
+
+from ding.envs import BaseEnv, BaseEnvTimestep, update_shape
+from ding.utils import ENV_REGISTRY
+from ding.torch_utils import to_tensor, to_ndarray, to_list
+from .atari_wrappers import wrap_deepmind, wrap_deepmind_mr
+from ding.envs import ObsPlusPrevActRewWrapper
+
+
+@ENV_REGISTRY.register("atari")
+class AtariEnv(BaseEnv):
+
+ def __init__(self, cfg: dict) -> None:
+ self._cfg = cfg
+ self._init_flag = False
+ self._replay_path = None
+
+ def reset(self) -> np.ndarray:
+ if not self._init_flag:
+ self._env = self._make_env()
+ if self._replay_path is not None:
+ self._env = gym.wrappers.RecordVideo(
+ self._env,
+ video_folder=self._replay_path,
+ episode_trigger=lambda episode_id: True,
+ name_prefix='rl-video-{}'.format(id(self))
+ )
+ if hasattr(self._cfg, 'obs_plus_prev_action_reward') and self._cfg.obs_plus_prev_action_reward:
+ self._env = ObsPlusPrevActRewWrapper(self._env)
+ self._observation_space = self._env.observation_space
+ self._action_space = self._env.action_space
+ self._reward_space = gym.spaces.Box(
+ low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
+ )
+ self._init_flag = True
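+ # With ``dynamic_seed`` (the collector default), a different random seed is drawn at every reset;
+ # the evaluator disables it so that evaluation episodes are reproducible.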
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+ obs = self._env.reset()
+ obs = to_ndarray(obs)
+ self._eval_episode_return = 0.
+ return obs
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def step(self, action: np.ndarray) -> BaseEnvTimestep:
+ assert isinstance(action, np.ndarray), type(action)
+ action = action.item()
+ obs, rew, done, info = self._env.step(action)
+ # self._env.render()
+ self._eval_episode_return += rew
+ obs = to_ndarray(obs)
+ rew = to_ndarray([rew]).astype(np.float32) # wrapped to be transferred to a Tensor with shape (1,)
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ if replay_path is None:
+ replay_path = './video'
+ self._replay_path = replay_path
+
+ def random_action(self) -> np.ndarray:
+ random_action = self.action_space.sample()
+ random_action = to_ndarray([random_action], dtype=np.int64)
+ return random_action
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ def _make_env(self):
+ return wrap_deepmind(
+ self._cfg.env_id,
+ frame_stack=self._cfg.frame_stack,
+ episode_life=self._cfg.is_train,
+ clip_rewards=self._cfg.is_train
+ )
+
+ def __repr__(self) -> str:
+ return "DI-engine Atari Env({})".format(self._cfg.env_id)
+
+ @staticmethod
+ def create_collector_env_cfg(cfg: dict) -> List[dict]:
+ collector_env_num = cfg.pop('collector_env_num')
+ cfg = copy.deepcopy(cfg)
+ cfg.is_train = True
+ return [cfg for _ in range(collector_env_num)]
+
+ @staticmethod
+ def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
+ evaluator_env_num = cfg.pop('evaluator_env_num')
+ cfg = copy.deepcopy(cfg)
+ cfg.is_train = False
+ return [cfg for _ in range(evaluator_env_num)]
+
+
+@ENV_REGISTRY.register('atari_mr')
+class AtariEnvMR(AtariEnv):
+
+ def reset(self) -> np.ndarray:
+ if not self._init_flag:
+ self._env = self._make_env()
+ self._observation_space = self._env.observation_space
+ self._action_space = self._env.action_space
+ self._reward_space = gym.spaces.Box(
+ low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
+ )
+ self._init_flag = True
+ if hasattr(self, '_seed'):
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ obs = self._env.reset()
+ obs = to_ndarray(obs)
+ self._eval_episode_return = 0.
+ return obs
+
+ def _make_env(self):
+ return wrap_deepmind_mr(
+ self._cfg.env_id,
+ frame_stack=self._cfg.frame_stack,
+ episode_life=self._cfg.is_train,
+ clip_rewards=self._cfg.is_train
+ )
diff --git a/DI-engine/dizoo/atari/envs/atari_wrappers.py b/DI-engine/dizoo/atari/envs/atari_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..1eb288d1dc18702fb48cc7310ef3d8014a807110
--- /dev/null
+++ b/DI-engine/dizoo/atari/envs/atari_wrappers.py
@@ -0,0 +1,185 @@
+# Borrow a lot from openai baselines:
+# https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
+
+import gym
+from collections import deque
+from ding.envs import NoopResetWrapper, MaxAndSkipWrapper, EpisodicLifeWrapper, FireResetWrapper, WarpFrameWrapper, \
+ ScaledFloatFrameWrapper, \
+ ClipRewardWrapper, FrameStackWrapper
+import numpy as np
+from ding.utils.compression_helper import jpeg_data_compressor
+import cv2
+
+
+def wrap_deepmind(env_id, episode_life=True, clip_rewards=True, frame_stack=4, scale=True, warp_frame=True):
+ """Configure environment for DeepMind-style Atari. The observation is
+ channel-first: (c, h, w) instead of (h, w, c).
+
+ :param str env_id: the atari environment id.
+ :param bool episode_life: wrap the episode life wrapper.
+ :param bool clip_rewards: wrap the reward clipping wrapper.
+ :param int frame_stack: wrap the frame stacking wrapper.
+ :param bool scale: wrap the scaling observation wrapper.
+ :param bool warp_frame: wrap the grayscale + resize observation wrapper.
+ :return: the wrapped atari environment.
+ """
+ #assert 'NoFrameskip' in env_id
+ env = gym.make(env_id)
+ env = NoopResetWrapper(env, noop_max=30)
+ env = MaxAndSkipWrapper(env, skip=4)
+ if episode_life:
+ env = EpisodicLifeWrapper(env)
+ if 'FIRE' in env.unwrapped.get_action_meanings():
+ env = FireResetWrapper(env)
+ if warp_frame:
+ env = WarpFrameWrapper(env)
+ if scale:
+ env = ScaledFloatFrameWrapper(env)
+ if clip_rewards:
+ env = ClipRewardWrapper(env)
+ if frame_stack:
+ env = FrameStackWrapper(env, frame_stack)
+ return env
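+
+ # Example usage (illustrative; 'PongNoFrameskip-v4' is just a sample env id):
+ #     env = wrap_deepmind('PongNoFrameskip-v4', episode_life=False, clip_rewards=False)
+ #     obs = env.reset()  # channel-first observation of shape (frame_stack, 84, 84)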
+
+
+def wrap_deepmind_mr(env_id, episode_life=True, clip_rewards=True, frame_stack=4, scale=True, warp_frame=True):
+ """Configure environment for DeepMind-style Atari. The observation is
+ channel-first: (c, h, w) instead of (h, w, c).
+
+ :param str env_id: the atari environment id.
+ :param bool episode_life: wrap the episode life wrapper.
+ :param bool clip_rewards: wrap the reward clipping wrapper.
+ :param int frame_stack: wrap the frame stacking wrapper.
+ :param bool scale: wrap the scaling observation wrapper.
+ :param bool warp_frame: wrap the grayscale + resize observation wrapper.
+ :return: the wrapped atari environment.
+ """
+ assert 'MontezumaRevenge' in env_id
+ env = gym.make(env_id)
+ env = NoopResetWrapper(env, noop_max=30)
+ env = MaxAndSkipWrapper(env, skip=4)
+ if episode_life:
+ env = EpisodicLifeWrapper(env)
+ if 'FIRE' in env.unwrapped.get_action_meanings():
+ env = FireResetWrapper(env)
+ if warp_frame:
+ env = WarpFrameWrapper(env)
+ if scale:
+ env = ScaledFloatFrameWrapper(env)
+ if clip_rewards:
+ env = ClipRewardWrapper(env)
+ if frame_stack:
+ env = FrameStackWrapper(env, frame_stack)
+ return env
+
+
+class TimeLimit(gym.Wrapper):
+
+ def __init__(self, env, max_episode_steps=None):
+ super(TimeLimit, self).__init__(env)
+ self._max_episode_steps = max_episode_steps
+ self._elapsed_steps = 0
+
+ def step(self, ac):
+ observation, reward, done, info = self.env.step(ac)
+ self._elapsed_steps += 1
+ if self._elapsed_steps >= self._max_episode_steps:
+ done = True
+ info['TimeLimit.truncated'] = True
+ return observation, reward, done, info
+
+ def reset(self, **kwargs):
+ self._elapsed_steps = 0
+ return self.env.reset(**kwargs)
+
+
+class WarpFrame(gym.ObservationWrapper):
+
+ def __init__(self, env, width=84, height=84, grayscale=True, dict_space_key=None):
+ """
+ Warp frames to 84x84 as done in the Nature paper and later work.
+ If the environment uses dictionary observations, `dict_space_key` can be specified which indicates which
+ observation should be warped.
+ """
+ super().__init__(env)
+ self._width = width
+ self._height = height
+ self._grayscale = grayscale
+ self._key = dict_space_key
+ if self._grayscale:
+ num_colors = 1
+ else:
+ num_colors = 3
+
+ new_space = gym.spaces.Box(
+ low=0,
+ high=255,
+ shape=(self._height, self._width, num_colors),
+ dtype=np.uint8,
+ )
+ if self._key is None:
+ original_space = self.observation_space
+ self.observation_space = new_space
+ else:
+ original_space = self.observation_space.spaces[self._key]
+ self.observation_space.spaces[self._key] = new_space
+ assert original_space.dtype == np.uint8 and len(original_space.shape) == 3
+
+ def observation(self, obs):
+ if self._key is None:
+ frame = obs
+ else:
+ frame = obs[self._key]
+
+ if self._grayscale:
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
+ frame = cv2.resize(frame, (self._width, self._height), interpolation=cv2.INTER_AREA)
+ if self._grayscale:
+ frame = np.expand_dims(frame, -1)
+
+ if self._key is None:
+ obs = frame
+ else:
+ obs = obs.copy()
+ obs[self._key] = frame
+ return obs
+
+
+class JpegWrapper(gym.Wrapper):
+
+ def __init__(self, env, cvt_string=True):
+ """
+        Overview: convert the observation into a JPEG-compressed string to save memory
+ """
+ super().__init__(env)
+ self.cvt_string = cvt_string
+
+ def step(self, action):
+ observation, reward, done, info = self.env.step(action)
+ observation = observation.astype(np.uint8)
+
+ if self.cvt_string:
+ observation = jpeg_data_compressor(observation)
+
+ return observation, reward, done, info
+
+ def reset(self, **kwargs):
+ observation = self.env.reset(**kwargs)
+ observation = observation.astype(np.uint8)
+
+ if self.cvt_string:
+ observation = jpeg_data_compressor(observation)
+
+ return observation
+
+
+class GameWrapper(gym.Wrapper):
+
+ def __init__(self, env):
+ """
+        Overview: wrap the env to adapt it to the game interface
+ """
+ super().__init__(env)
+
+ def legal_actions(self):
+        return list(range(self.env.action_space.n))
diff --git a/DI-engine/dizoo/atari/envs/test_atari_env.py b/DI-engine/dizoo/atari/envs/test_atari_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..63c6018bcc3cbacbc5a90257f3db02787628789b
--- /dev/null
+++ b/DI-engine/dizoo/atari/envs/test_atari_env.py
@@ -0,0 +1,63 @@
+import pytest
+import numpy as np
+import gym
+from easydict import EasyDict
+import atari_py
+
+from dizoo.atari.envs import AtariEnv, AtariEnvMR
+
+
+@pytest.mark.envtest
+class TestAtariEnv:
+
+ def test_pong(self):
+ cfg = {'env_id': 'PongNoFrameskip-v4', 'frame_stack': 4, 'is_train': True}
+ cfg = EasyDict(cfg)
+ pong_env = AtariEnv(cfg)
+ pong_env.seed(0)
+ obs = pong_env.reset()
+ assert obs.shape == (cfg.frame_stack, 84, 84)
+ act_dim = pong_env.action_space.n
+ i = 0
+ while True:
+            # Both ``env.random_action()`` and sampling the action space with ``np.random``
+            # can generate a legal random action.
+ if i < 10:
+ random_action = np.random.choice(range(act_dim), size=(1, ))
+ i += 1
+ else:
+ random_action = pong_env.random_action()
+ timestep = pong_env.step(random_action)
+ assert timestep.obs.shape == (cfg.frame_stack, 84, 84)
+ assert timestep.reward.shape == (1, )
+ if timestep.done:
+ assert 'eval_episode_return' in timestep.info, timestep.info
+ break
+ print(pong_env.observation_space, pong_env.action_space, pong_env.reward_space)
+ print('eval_episode_return: {}'.format(timestep.info['eval_episode_return']))
+ pong_env.close()
+
+ def test_montezuma_revenge(self):
+ cfg = {'env_id': 'MontezumaRevengeDeterministic-v4', 'frame_stack': 4, 'is_train': True}
+ cfg = EasyDict(cfg)
+ mr_env = AtariEnvMR(cfg)
+ mr_env.seed(0)
+ obs = mr_env.reset()
+ assert obs.shape == (cfg.frame_stack, 84, 84)
+ act_dim = mr_env.action_space.n
+ i = 0
+ while True:
+ if i < 10:
+ random_action = np.random.choice(range(act_dim), size=(1, ))
+ i += 1
+ else:
+ random_action = mr_env.random_action()
+ timestep = mr_env.step(random_action)
+ assert timestep.obs.shape == (cfg.frame_stack, 84, 84)
+ assert timestep.reward.shape == (1, )
+ if timestep.done:
+ assert 'eval_episode_return' in timestep.info, timestep.info
+ break
+ print(mr_env.observation_space, mr_env.action_space, mr_env.reward_space)
+ print('eval_episode_return: {}'.format(timestep.info['eval_episode_return']))
+ mr_env.close()
diff --git a/DI-engine/dizoo/atari/example/atari_dqn.py b/DI-engine/dizoo/atari/example/atari_dqn.py
new file mode 100644
index 0000000000000000000000000000000000000000..660f8576d9f09f73b15fc62adf4eaf9a7e1e4a4a
--- /dev/null
+++ b/DI-engine/dizoo/atari/example/atari_dqn.py
@@ -0,0 +1,50 @@
+from copy import deepcopy
+from ditk import logging
+from ding.model import DQN
+from ding.policy import DQNPolicy
+from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
+ eps_greedy_handler, CkptSaver, nstep_reward_enhancer, termination_checker
+from ding.utils import set_pkg_seed
+from dizoo.atari.envs.atari_env import AtariEnv
+from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_cfg = deepcopy(cfg.env)
+ collector_cfg.is_train = True
+ evaluator_cfg = deepcopy(cfg.env)
+ evaluator_cfg.is_train = False
+ collector_env = SubprocessEnvManagerV2(
+ env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = SubprocessEnvManagerV2(
+ env_fn=[lambda: AtariEnv(evaluator_cfg) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = DQN(**cfg.policy.model)
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ policy = DQNPolicy(cfg.policy, model=model)
+
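+        # middleware pipeline: evaluate -> epsilon schedule -> collect -> n-step reward ->
+        # push to buffer -> off-policy learning -> checkpointing -> termination check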
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(eps_greedy_handler(cfg))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(nstep_reward_enhancer(cfg))
+ task.use(data_pusher(cfg, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
+ task.use(termination_checker(max_env_step=int(1e7)))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/dizoo/atari/example/atari_dqn_ddp.py b/DI-engine/dizoo/atari/example/atari_dqn_ddp.py
new file mode 100644
index 0000000000000000000000000000000000000000..217c5a5e10a6a4a158abb040cce7d034d7f11f39
--- /dev/null
+++ b/DI-engine/dizoo/atari/example/atari_dqn_ddp.py
@@ -0,0 +1,59 @@
+from copy import deepcopy
+from ditk import logging
+from ding.model import DQN
+from ding.policy import DQNPolicy
+from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.utils import DistContext, get_rank
+from ding.framework import task, ding_init
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
+ eps_greedy_handler, CkptSaver, nstep_reward_enhancer, online_logger, ddp_termination_checker
+from ding.utils import set_pkg_seed
+from dizoo.atari.envs.atari_env import AtariEnv
+from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ main_config.exp_name = 'pong_dqn_seed0_ddp'
+ main_config.policy.multi_gpu = True
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ ding_init(cfg)
+ with DistContext():
+ rank = get_rank()
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_cfg = deepcopy(cfg.env)
+ collector_cfg.is_train = True
+ evaluator_cfg = deepcopy(cfg.env)
+ evaluator_cfg.is_train = False
+ collector_env = SubprocessEnvManagerV2(
+ env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = SubprocessEnvManagerV2(
+ env_fn=[lambda: AtariEnv(evaluator_cfg) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = DQN(**cfg.policy.model)
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ policy = DQNPolicy(cfg.policy, model=model)
+
+ if rank == 0:
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(eps_greedy_handler(cfg))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(nstep_reward_enhancer(cfg))
+ task.use(data_pusher(cfg, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ if rank == 0:
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
+ task.use(online_logger(record_train_iter=True))
+ task.use(ddp_termination_checker(max_env_step=int(1e7), rank=rank))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/dizoo/atari/example/atari_dqn_dist.py b/DI-engine/dizoo/atari/example/atari_dqn_dist.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b9f01699a90a4b7aab4d7efd9d79000e3b316a3
--- /dev/null
+++ b/DI-engine/dizoo/atari/example/atari_dqn_dist.py
@@ -0,0 +1,85 @@
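+# Note: this example is intended to be launched through the ``ditask`` CLI so that
+# ``task.router`` is active and each process carries a 'learner', 'collector' or
+# 'evaluator' label (see the asserts below); the exact launch command is omitted here.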
+from copy import deepcopy
+from ditk import logging
+from ding.model import DQN
+from ding.policy import DQNPolicy
+from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task, ding_init
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
+ eps_greedy_handler, CkptSaver, context_exchanger, model_exchanger, termination_checker, nstep_reward_enhancer, \
+ online_logger
+from ding.utils import set_pkg_seed
+from dizoo.atari.envs.atari_env import AtariEnv
+from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ main_config.exp_name = 'pong_dqn_seed0_ditask_dist'
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ ding_init(cfg)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ assert task.router.is_active, "Please execute this script with ditask! See note in the header."
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = DQN(**cfg.policy.model)
+ policy = DQNPolicy(cfg.policy, model=model)
+
+ if 'learner' in task.router.labels:
+ logging.info("Learner running on node {}".format(task.router.node_id))
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ task.use(
+ context_exchanger(
+ send_keys=["train_iter"],
+ recv_keys=["trajectories", "episodes", "env_step", "env_episode"],
+ skip_n_iter=0
+ )
+ )
+ task.use(model_exchanger(model, is_learner=True))
+ task.use(nstep_reward_enhancer(cfg))
+ task.use(data_pusher(cfg, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
+
+ elif 'evaluator' in task.router.labels:
+ logging.info("Evaluator running on node {}".format(task.router.node_id))
+ evaluator_cfg = deepcopy(cfg.env)
+ evaluator_cfg.is_train = False
+ evaluator_env = SubprocessEnvManagerV2(
+ env_fn=[lambda: AtariEnv(evaluator_cfg) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
+ )
+ task.use(context_exchanger(recv_keys=["train_iter", "env_step"], skip_n_iter=1))
+ task.use(model_exchanger(model, is_learner=False))
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(CkptSaver(policy, cfg.exp_name, save_finish=False))
+ task.use(online_logger(record_train_iter=True))
+
+ elif 'collector' in task.router.labels:
+ logging.info("Collector running on node {}".format(task.router.node_id))
+ collector_cfg = deepcopy(cfg.env)
+ collector_cfg.is_train = True
+ collector_env = SubprocessEnvManagerV2(
+ env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
+ )
+ task.use(
+ context_exchanger(
+ send_keys=["trajectories", "episodes", "env_step", "env_episode"],
+ recv_keys=["train_iter"],
+ skip_n_iter=1
+ )
+ )
+ task.use(model_exchanger(model, is_learner=False))
+ task.use(eps_greedy_handler(cfg))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(termination_checker(max_env_step=int(1e7)))
+ else:
+ raise KeyError("invalid router labels: {}".format(task.router.labels))
+
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/dizoo/atari/example/atari_dqn_dist_ddp.py b/DI-engine/dizoo/atari/example/atari_dqn_dist_ddp.py
new file mode 100644
index 0000000000000000000000000000000000000000..5dbfc4e65c8d48900be1e157d16e4ebb6419e458
--- /dev/null
+++ b/DI-engine/dizoo/atari/example/atari_dqn_dist_ddp.py
@@ -0,0 +1,106 @@
+from copy import deepcopy
+from ditk import logging
+from ding.model import DQN
+from ding.policy import DQNPolicy
+from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task, ding_init
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
+ eps_greedy_handler, CkptSaver, context_exchanger, model_exchanger, termination_checker, nstep_reward_enhancer, \
+ online_logger
+from ding.utils import set_pkg_seed
+from dizoo.atari.envs.atari_env import AtariEnv
+from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config
+
+logging.getLogger().setLevel(logging.INFO)
+main_config.exp_name = 'pong_dqn_seed0_ditask_dist_ddp'
+
+
+def learner():
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ ding_init(cfg)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = DQN(**cfg.policy.model)
+ policy = DQNPolicy(cfg.policy, model=model, enable_field=['learn'])
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ assert task.router.is_active, "Please execute this script with ditask! See note in the header."
+ logging.info("Learner running on node {}".format(task.router.node_id))
+
+ from ding.utils import DistContext, get_rank
+ with DistContext():
+ rank = get_rank()
+ task.use(
+ context_exchanger(
+ send_keys=["train_iter"],
+ recv_keys=["trajectories", "episodes", "env_step", "env_episode"],
+ skip_n_iter=0
+ )
+ )
+ task.use(model_exchanger(model, is_learner=True))
+ task.use(nstep_reward_enhancer(cfg))
+ task.use(data_pusher(cfg, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ if rank == 0:
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
+ task.run()
+
+
+def collector():
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ ding_init(cfg)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = DQN(**cfg.policy.model)
+ policy = DQNPolicy(cfg.policy, model=model, enable_field=['collect'])
+ collector_cfg = deepcopy(cfg.env)
+ collector_cfg.is_train = True
+ collector_env = SubprocessEnvManagerV2(
+ env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
+ )
+
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ assert task.router.is_active, "Please execute this script with ditask! See note in the header."
+ logging.info("Collector running on node {}".format(task.router.node_id))
+
+ task.use(
+ context_exchanger(
+ send_keys=["trajectories", "episodes", "env_step", "env_episode"],
+ recv_keys=["train_iter"],
+ skip_n_iter=1
+ )
+ )
+ task.use(model_exchanger(model, is_learner=False))
+ task.use(eps_greedy_handler(cfg))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(termination_checker(max_env_step=int(1e7)))
+ task.run()
+
+
+def evaluator():
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ ding_init(cfg)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = DQN(**cfg.policy.model)
+ policy = DQNPolicy(cfg.policy, model=model, enable_field=['eval'])
+ evaluator_cfg = deepcopy(cfg.env)
+ evaluator_cfg.is_train = False
+ evaluator_env = SubprocessEnvManagerV2(
+ env_fn=[lambda: AtariEnv(evaluator_cfg) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ assert task.router.is_active, "Please execute this script with ditask! See note in the header."
+ logging.info("Evaluator running on node {}".format(task.router.node_id))
+
+ task.use(context_exchanger(recv_keys=["train_iter", "env_step"], skip_n_iter=1))
+ task.use(model_exchanger(model, is_learner=False))
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(CkptSaver(policy, cfg.exp_name, save_finish=False))
+ task.use(online_logger(record_train_iter=True))
+ task.run()
diff --git a/DI-engine/dizoo/atari/example/atari_dqn_dist_rdma.py b/DI-engine/dizoo/atari/example/atari_dqn_dist_rdma.py
new file mode 100644
index 0000000000000000000000000000000000000000..b364b37966a7a6b9b75b3bc5c691add1f08e74e4
--- /dev/null
+++ b/DI-engine/dizoo/atari/example/atari_dqn_dist_rdma.py
@@ -0,0 +1,72 @@
+from copy import deepcopy
+from ditk import logging
+from ding.model import DQN
+from ding.policy import DQNPolicy
+from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task, ding_init
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
+ eps_greedy_handler, CkptSaver, context_exchanger, model_exchanger, termination_checker, nstep_reward_enhancer, \
+ online_logger
+from ding.utils import set_pkg_seed
+from dizoo.atari.envs.atari_env import AtariEnv
+from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ main_config.exp_name = 'pong_dqn_seed0_dist_rdma'
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ ding_init(cfg)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ assert task.router.is_active, "Please execute this script with ditask! See note in the header."
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = DQN(**cfg.policy.model)
+ policy = DQNPolicy(cfg.policy, model=model)
+
+ if 'learner' in task.router.labels:
+ logging.info("Learner running on node {}".format(task.router.node_id))
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ task.use(
+ context_exchanger(
+ send_keys=["train_iter"],
+ recv_keys=["trajectories", "episodes", "env_step", "env_episode"],
+ skip_n_iter=0
+ )
+ )
+ task.use(model_exchanger(model, is_learner=True))
+ task.use(nstep_reward_enhancer(cfg))
+ task.use(data_pusher(cfg, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
+
+ elif 'collector' in task.router.labels:
+ logging.info("Collector running on node {}".format(task.router.node_id))
+ collector_cfg = deepcopy(cfg.env)
+ collector_cfg.is_train = True
+ collector_env = SubprocessEnvManagerV2(
+ env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
+ )
+ task.use(
+ context_exchanger(
+ send_keys=["trajectories", "episodes", "env_step", "env_episode"],
+ recv_keys=["train_iter"],
+ skip_n_iter=1
+ )
+ )
+ task.use(model_exchanger(model, is_learner=False))
+ task.use(eps_greedy_handler(cfg))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(termination_checker(max_env_step=int(1e7)))
+ else:
+ raise KeyError("invalid router labels: {}".format(task.router.labels))
+
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/dizoo/atari/example/atari_dqn_dp.py b/DI-engine/dizoo/atari/example/atari_dqn_dp.py
new file mode 100644
index 0000000000000000000000000000000000000000..540fb4ab0773b329ca397ea07cb12b8616d2c370
--- /dev/null
+++ b/DI-engine/dizoo/atari/example/atari_dqn_dp.py
@@ -0,0 +1,53 @@
+from copy import deepcopy
+from ditk import logging
+from ding.model import DQN
+from ding.policy import DQNPolicy
+from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.torch_utils import DataParallel
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
+ eps_greedy_handler, CkptSaver, nstep_reward_enhancer, termination_checker
+from ding.utils import set_pkg_seed
+from dizoo.atari.envs.atari_env import AtariEnv
+from dizoo.atari.config.serial.pong.pong_dqn_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ main_config.exp_name = 'pong_dqn_seed0_dp'
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_cfg = deepcopy(cfg.env)
+ collector_cfg.is_train = True
+ evaluator_cfg = deepcopy(cfg.env)
+ evaluator_cfg.is_train = False
+ collector_env = SubprocessEnvManagerV2(
+ env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = SubprocessEnvManagerV2(
+ env_fn=[lambda: AtariEnv(evaluator_cfg) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = DQN(**cfg.policy.model)
+ model = DataParallel(model)
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ policy = DQNPolicy(cfg.policy, model=model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(eps_greedy_handler(cfg))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(nstep_reward_enhancer(cfg))
+ task.use(data_pusher(cfg, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
+ task.use(termination_checker(max_env_step=int(1e7)))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/dizoo/atari/example/atari_ppo.py b/DI-engine/dizoo/atari/example/atari_ppo.py
new file mode 100644
index 0000000000000000000000000000000000000000..215c93a18f887dbf52204e9a7dc7a1d317299cbf
--- /dev/null
+++ b/DI-engine/dizoo/atari/example/atari_ppo.py
@@ -0,0 +1,47 @@
+from copy import deepcopy
+from ditk import logging
+from ding.model import VAC
+from ding.policy import PPOPolicy
+from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \
+ gae_estimator, termination_checker
+from ding.utils import set_pkg_seed
+from dizoo.atari.envs.atari_env import AtariEnv
+from dizoo.atari.config.serial.pong.pong_onppo_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_cfg = deepcopy(cfg.env)
+ collector_cfg.is_train = True
+ evaluator_cfg = deepcopy(cfg.env)
+ evaluator_cfg.is_train = False
+ collector_env = SubprocessEnvManagerV2(
+ env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = SubprocessEnvManagerV2(
+ env_fn=[lambda: AtariEnv(evaluator_cfg) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = VAC(**cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(gae_estimator(cfg, policy.collect_mode))
+ task.use(multistep_trainer(cfg, policy.learn_mode))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
+ task.use(termination_checker(max_env_step=int(1e7)))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/dizoo/atari/example/atari_ppo_ddp.py b/DI-engine/dizoo/atari/example/atari_ppo_ddp.py
new file mode 100644
index 0000000000000000000000000000000000000000..92267e846f4db557edc572100fa3e8256818e306
--- /dev/null
+++ b/DI-engine/dizoo/atari/example/atari_ppo_ddp.py
@@ -0,0 +1,56 @@
+from copy import deepcopy
+from ditk import logging
+from ding.model import VAC
+from ding.policy import PPOPolicy
+from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task, ding_init
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \
+ gae_estimator, ddp_termination_checker, online_logger
+from ding.utils import set_pkg_seed, DistContext, get_rank, get_world_size
+from dizoo.atari.envs.atari_env import AtariEnv
+from dizoo.atari.config.serial.pong.pong_onppo_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ with DistContext():
+ rank, world_size = get_rank(), get_world_size()
+        main_config.exp_name = 'pong_ppo_seed0_ddp_avgsplit'
+ main_config.policy.multi_gpu = True
+ main_config.policy.learn.batch_size = main_config.policy.learn.batch_size // world_size
+ main_config.policy.collect.n_sample = main_config.policy.collect.n_sample // world_size
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ ding_init(cfg)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_cfg = deepcopy(cfg.env)
+ collector_cfg.is_train = True
+ evaluator_cfg = deepcopy(cfg.env)
+ evaluator_cfg.is_train = False
+ collector_env = SubprocessEnvManagerV2(
+ env_fn=[lambda: AtariEnv(collector_cfg) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = SubprocessEnvManagerV2(
+ env_fn=[lambda: AtariEnv(evaluator_cfg) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = VAC(**cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
+
+ if rank == 0:
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(gae_estimator(cfg, policy.collect_mode))
+ task.use(multistep_trainer(cfg, policy.learn_mode))
+ if rank == 0:
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
+ task.use(ddp_termination_checker(max_env_step=int(1e7), rank=rank))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/dizoo/beergame/__init__.py b/DI-engine/dizoo/beergame/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/beergame/config/beergame_onppo_config.py b/DI-engine/dizoo/beergame/config/beergame_onppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..f6bec16a87b8c853b94a6119fdc09356d1f3d430
--- /dev/null
+++ b/DI-engine/dizoo/beergame/config/beergame_onppo_config.py
@@ -0,0 +1,70 @@
+from easydict import EasyDict
+
+beergame_ppo_config = dict(
+ exp_name='beergame_ppo_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=200,
+ role=0, # 0-3 : retailer, warehouse, distributor, manufacturer
+ agent_type='bs',
+ # type of co-player, 'bs'- base stock, 'Strm'- use Sterman formula to model typical human behavior
+ demandDistribution=0
+ # distribution of demand, default=0, '0=uniform, 1=normal distribution, 2=the sequence of 4,4,4,4,8,..., 3= basket data, 4= forecast data'
+ ),
+ policy=dict(
+ cuda=True,
+ recompute_adv=True,
+ action_space='discrete',
+ model=dict(
+ obs_shape=50, # statedim * multPerdInpt= 5 * 10
+ action_shape=5, # the quantity relative to the arriving order
+ action_space='discrete',
+ encoder_hidden_size_list=[64, 64, 128],
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ ),
+ learn=dict(
+ epoch_per_collect=10,
+ batch_size=320,
+ learning_rate=3e-4,
+ entropy_weight=0.001,
+ adv_norm=True,
+ value_norm=True,
+            # For on-policy PPO, recomputing adv requires the key 'done' in the data to split
+            # trajectories, so ignore_done=False must be used in that case; but when the key
+            # 'traj_flag' is added to the data as a backup for 'done', we may choose to use
+            # ignore_done=True (for halfcheetah, the episode length is 1000).
+ ignore_done=True,
+ ),
+ collect=dict(
+ n_episode=8,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ collector=dict(
+ get_train_sample=True,
+                reward_shaping=True,  # whether to use the total return to reshape the reward
+ ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ ),
+)
+beergame_ppo_config = EasyDict(beergame_ppo_config)
+main_config = beergame_ppo_config
+beergame_ppo_create_config = dict(
+ env=dict(
+ type='beergame',
+ import_names=['dizoo.beergame.envs.beergame_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='ppo'),
+ collector=dict(type='episode', ),
+)
+beergame_ppo_create_config = EasyDict(beergame_ppo_create_config)
+create_config = beergame_ppo_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c beergame_onppo_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/beergame/entry/beergame_eval.py b/DI-engine/dizoo/beergame/entry/beergame_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..5299107e788a97ff3fe4c45e62c0518f7da7a021
--- /dev/null
+++ b/DI-engine/dizoo/beergame/entry/beergame_eval.py
@@ -0,0 +1,42 @@
+import os
+import torch
+from tensorboardX import SummaryWriter
+
+from ding.config import compile_config
+from ding.worker import InteractionSerialEvaluator
+from ding.envs import BaseEnvManager
+from ding.policy import PPOPolicy
+from ding.model import VAC
+from ding.utils import set_pkg_seed
+from dizoo.beergame.config.beergame_onppo_config import beergame_ppo_config, beergame_ppo_create_config
+from ding.envs import get_vec_env_setting
+from functools import partial
+
+
+def main(cfg, seed=0):
+ env_fn = None
+ cfg, create_cfg = beergame_ppo_config, beergame_ppo_create_config
+ cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ cfg.env.manager.auto_reset = False
+ evaluator_env = BaseEnvManager(env_fn=[partial(env_fn, cfg=c) for c in evaluator_env_cfg], cfg=cfg.env.manager)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+ model = VAC(**cfg.policy.model)
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ policy = PPOPolicy(cfg.policy, model=model)
+ # set the path to save figure
+ cfg.policy.eval.evaluator.figure_path = './'
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ # load model
+ model.load_state_dict(torch.load('model path', map_location='cpu')["model"])
+ evaluator.eval(None, -1, -1)
+
+
+if __name__ == "__main__":
+ beergame_ppo_config.exp_name = 'beergame_evaluate'
+ main(beergame_ppo_config)
\ No newline at end of file
diff --git a/DI-engine/dizoo/beergame/envs/BGAgent.py b/DI-engine/dizoo/beergame/envs/BGAgent.py
new file mode 100644
index 0000000000000000000000000000000000000000..d06866afc2e59663245e86505016ff2a38dd3140
--- /dev/null
+++ b/DI-engine/dizoo/beergame/envs/BGAgent.py
@@ -0,0 +1,152 @@
+# Code Reference: https://github.com/OptMLGroup/DeepBeerInventory-RL.
+import argparse
+import numpy as np
+
+
+# Here we want to define the agent class for the BeerGame
+class Agent(object):
+ # initializes the agents with initial values for IL, OO and saves self.agentNum for recognizing the agents.
+ def __init__(
+ self, agentNum: int, IL: int, AO: int, AS: int, c_h: float, c_p: float, eta: int, compuType: str,
+ config: argparse.Namespace
+ ) -> None:
+ self.agentNum = agentNum
+ self.IL = IL # Inventory level of each agent - changes during the game
+ self.OO = 0 # Open order of each agent - changes during the game
+ self.ASInitial = AS # the initial arriving shipment.
+ self.ILInitial = IL # IL at which we start each game with this number
+        self.AOInitial = AO  # AO at which we start each game with this number
+ self.config = config # an instance of config is stored inside the class
+        self.curState = []  # the current state of the game
+ self.nextState = []
+ self.curReward = 0 # the reward observed at the current step
+ self.cumReward = 0 # cumulative reward; reset at the beginning of each episode
+        self.totRew = 0  # the total reward of all players, attributed to the current player
+ self.c_h = c_h # holding cost
+ self.c_p = c_p # backorder cost
+        self.eta = eta  # the total cost regularizer
+        self.AS = np.zeros((1, 1))  # arrived shipment
+ self.AO = np.zeros((1, 1)) # arrived order
+ self.action = 0 # the action at time t
+ self.compType = compuType
+ # self.compTypeTrain = compuType # rnd -> random / srdqn-> srdqn / Strm-> formula-Rong2008 / bs -> optimal policy if exists
+ # self.compTypeTest = compuType # rnd -> random / srdqn-> srdqn / Strm-> formula-Rong2008 / bs -> optimal policy if exists
+ self.alpha_b = self.config.alpha_b[self.agentNum] # parameters for the formula
+ self.betta_b = self.config.betta_b[self.agentNum] # parameters for the formula
+ if self.config.demandDistribution == 0:
+ self.a_b = np.mean((self.config.demandUp, self.config.demandLow)) # parameters for the formula
+ self.b_b = np.mean((self.config.demandUp, self.config.demandLow)) * (
+ np.mean((self.config.leadRecItemLow[self.agentNum], self.config.leadRecItemUp[self.agentNum])) +
+ np.mean((self.config.leadRecOrderLow[self.agentNum], self.config.leadRecOrderUp[self.agentNum]))
+ ) # parameters for the formula
+ elif self.config.demandDistribution == 1 or self.config.demandDistribution == 3 or self.config.demandDistribution == 4:
+ self.a_b = self.config.demandMu # parameters for the formula
+ self.b_b = self.config.demandMu * (
+ np.mean((self.config.leadRecItemLow[self.agentNum], self.config.leadRecItemUp[self.agentNum])) +
+ np.mean((self.config.leadRecOrderLow[self.agentNum], self.config.leadRecOrderUp[self.agentNum]))
+ ) # parameters for the formula
+ elif self.config.demandDistribution == 2:
+ self.a_b = 8 # parameters for the formula
+ self.b_b = (3 / 4.) * 8 * (
+ np.mean((self.config.leadRecItemLow[self.agentNum], self.config.leadRecItemUp[self.agentNum])) +
+ np.mean((self.config.leadRecOrderLow[self.agentNum], self.config.leadRecOrderUp[self.agentNum]))
+ ) # parameters for the formula
+ elif self.config.demandDistribution == 3:
+ self.a_b = 10 # parameters for the formula
+ self.b_b = 7 * (
+ np.mean((self.config.leadRecItemLow[self.agentNum], self.config.leadRecItemUp[self.agentNum])) +
+ np.mean((self.config.leadRecOrderLow[self.agentNum], self.config.leadRecOrderUp[self.agentNum]))
+ ) # parameters for the formula
+ else:
+            raise Exception('The demand distribution is not defined or it is not a valid type.')
+
+ self.hist = [] # this is used for plotting - keeps the history for only one game
+ self.hist2 = [] # this is used for animation usage
+        self.srdqnBaseStock = []  # this holds the base stock levels that srdqn has come up with. added on Nov 8, 2017
+ self.T = 0
+ self.bsBaseStock = 0
+ self.init_bsBaseStock = 0
+ self.nextObservation = []
+
+ if self.compType == 'srdqn':
+ # sets the initial input of the network
+ self.currentState = np.stack(
+ [self.curState for _ in range(self.config.multPerdInpt)], axis=0
+ ) # multPerdInpt observations stacked. each row is an observation
+
+ # reset player information
+ def resetPlayer(self, T: int):
+ self.IL = self.ILInitial
+ self.OO = 0
+ self.AS = np.squeeze(
+ np.zeros((1, T + max(self.config.leadRecItemUp) + max(self.config.leadRecOrderUp) + 10))
+        )  # arrived shipment
+ self.AO = np.squeeze(
+ np.zeros((1, T + max(self.config.leadRecItemUp) + max(self.config.leadRecOrderUp) + 10))
+ ) # arrived order
+ if self.agentNum != 0:
+ for i in range(self.config.leadRecOrderUp_aux[self.agentNum - 1]):
+ self.AO[i] = self.AOInitial[self.agentNum - 1]
+ for i in range(self.config.leadRecItemUp[self.agentNum]):
+ self.AS[i] = self.ASInitial
+ self.curReward = 0 # the reward observed at the current step
+        self.cumReward = 0  # cumulative reward; reset at the beginning of each episode
+ self.action = []
+ self.hist = []
+ self.hist2 = []
+        self.srdqnBaseStock = []  # this holds the base stock levels that srdqn has come up with. added on Nov 8, 2017
+ self.T = T
+ self.curObservation = self.getCurState(1) # this function gets the current state of the game
+ self.nextObservation = []
+ if self.compType == 'srdqn':
+ self.currentState = np.stack([self.curObservation for _ in range(self.config.multPerdInpt)], axis=0)
+
+    # updates the IL and OO at time t, after receiving "rec" number of items
+ def recieveItems(self, time: int) -> None:
+        self.IL = self.IL + self.AS[time]  # inventory level update
+        self.OO = self.OO - self.AS[time]  # inventory in transit update
+
+ # find action Value associated with the action list
+ def actionValue(self, curTime: int) -> int:
+ if self.config.fixedAction:
+ a = self.config.actionList[np.argmax(self.action)]
+ else:
+ # "d + x" rule
+ if self.compType == 'srdqn':
+ a = max(0, self.config.actionList[np.argmax(self.action)] * self.config.action_step + self.AO[curTime])
+ elif self.compType == 'rnd':
+ a = max(0, self.config.actionList[np.argmax(self.action)] + self.AO[curTime])
+ else:
+ a = max(0, self.config.actionListOpt[np.argmax(self.action)])
+
+ return a
+
+ # getReward returns the reward at the current state
+ def getReward(self) -> None:
+ # cost (holding + backorder) for one time unit
+ self.curReward = (self.c_p * max(0, -self.IL) + self.c_h * max(0, self.IL)) / 200. # self.config.Ttest #
+ self.curReward = -self.curReward
+ # make reward negative, because it is the cost
+
+ # sum total reward of each agent
+ self.cumReward = self.config.gamma * self.cumReward + self.curReward
+
+ # This function returns a np.array of the current state of the agent
+ def getCurState(self, t: int) -> np.ndarray:
+ if self.config.ifUseASAO:
+ if self.config.if_use_AS_t_plus_1:
+ curState = np.array(
+ [-1 * (self.IL < 0) * self.IL, 1 * (self.IL > 0) * self.IL, self.OO, self.AS[t], self.AO[t]]
+ )
+ else:
+ curState = np.array(
+ [-1 * (self.IL < 0) * self.IL, 1 * (self.IL > 0) * self.IL, self.OO, self.AS[t - 1], self.AO[t]]
+ )
+ else:
+ curState = np.array([-1 * (self.IL < 0) * self.IL, 1 * (self.IL > 0) * self.IL, self.OO])
+
+ if self.config.ifUseActionInD:
+ a = self.config.actionList[np.argmax(self.action)]
+ curState = np.concatenate((curState, np.array([a])))
+
+ return curState
diff --git a/DI-engine/dizoo/beergame/envs/__init__.py b/DI-engine/dizoo/beergame/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ffbfd4521883e47fdd3c992ff2f27ddb331223
--- /dev/null
+++ b/DI-engine/dizoo/beergame/envs/__init__.py
@@ -0,0 +1,2 @@
+from .clBeergame import clBeerGame
+from .beergame_core import BeerGame
diff --git a/DI-engine/dizoo/beergame/envs/beergame_core.py b/DI-engine/dizoo/beergame/envs/beergame_core.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f0ac619100aa776dcf9181f7c2a664319e85cd8
--- /dev/null
+++ b/DI-engine/dizoo/beergame/envs/beergame_core.py
@@ -0,0 +1,112 @@
+from __future__ import print_function
+from dizoo.beergame.envs import clBeerGame
+from torch import Tensor
+import numpy as np
+import random
+from .utils import get_config, update_config
+import gym
+import os
+from typing import Optional
+
+
+class BeerGame():
+
+ def __init__(self, role: int, agent_type: str, demandDistribution: int) -> None:
+ self._cfg, unparsed = get_config()
+ self._role = role
+ # prepare loggers and directories
+ # prepare_dirs_and_logger(self._cfg)
+ self._cfg = update_config(self._cfg)
+
+ # set agent type
+ if agent_type == 'bs':
+ self._cfg.agentTypes = ["bs", "bs", "bs", "bs"]
+ elif agent_type == 'Strm':
+ self._cfg.agentTypes = ["Strm", "Strm", "Strm", "Strm"]
+ self._cfg.agentTypes[role] = "srdqn"
+
+ self._cfg.demandDistribution = demandDistribution
+
+ # load demands:0=uniform, 1=normal distribution, 2=the sequence of 4,4,4,4,8,..., 3= basket data, 4= forecast data
+ if self._cfg.observation_data:
+ adsr = 'data/demandTr-obs-'
+ elif self._cfg.demandDistribution == 3:
+ if self._cfg.scaled:
+ adsr = 'data/basket_data/scaled'
+ else:
+ adsr = 'data/basket_data'
+ direc = os.path.realpath(adsr + '/demandTr-' + str(self._cfg.data_id) + '.npy')
+ self._demandTr = np.load(direc)
+ print("loaded training set=", direc)
+ elif self._cfg.demandDistribution == 4:
+ if self._cfg.scaled:
+ adsr = 'data/forecast_data/scaled'
+ else:
+ adsr = 'data/forecast_data'
+ direc = os.path.realpath(adsr + '/demandTr-' + str(self._cfg.data_id) + '.npy')
+ self._demandTr = np.load(direc)
+ print("loaded training set=", direc)
+ else:
+ if self._cfg.demandDistribution == 0: # uniform
+ self._demandTr = np.random.randint(0, self._cfg.demandUp, size=[self._cfg.demandSize, self._cfg.TUp])
+ elif self._cfg.demandDistribution == 1: # normal distribution
+ self._demandTr = np.round(
+ np.random.normal(
+ self._cfg.demandMu, self._cfg.demandSigma, size=[self._cfg.demandSize, self._cfg.TUp]
+ )
+ ).astype(int)
+ elif self._cfg.demandDistribution == 2: # the sequence of 4,4,4,4,8,...
+ self._demandTr = np.concatenate(
+ (4 * np.ones((self._cfg.demandSize, 4)), 8 * np.ones((self._cfg.demandSize, 98))), axis=1
+ ).astype(int)
+
+        # initialize an instance of Beergame
+ self._env = clBeerGame(self._cfg)
+ self.observation_space = gym.spaces.Box(
+ low=float("-inf"),
+ high=float("inf"),
+ shape=(self._cfg.stateDim * self._cfg.multPerdInpt, ),
+ dtype=np.float32
+ ) # state_space = state_dim * m (considering the reward delay)
+ self.action_space = gym.spaces.Discrete(self._cfg.actionListLen) # length of action list
+ self.reward_space = gym.spaces.Box(low=float("-inf"), high=float("inf"), shape=(1, ), dtype=np.float32)
+
+ # get the length of the demand.
+ self._demand_len = np.shape(self._demandTr)[0]
+
+ def reset(self):
+ self._env.resetGame(demand=self._demandTr[random.randint(0, self._demand_len - 1)])
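+        # flatten the (multPerdInpt, stateDim) state history of the controlled agent into a 1-D observation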
+ obs = [i for item in self._env.players[self._role].currentState for i in item]
+ return obs
+
+ def seed(self, seed: int) -> None:
+ self._seed = seed
+ np.random.seed(self._seed)
+
+ def close(self) -> None:
+ pass
+
+ def step(self, action: np.ndarray):
+ self._env.handelAction(action)
+ self._env.next()
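+        # slide the observation window: drop the oldest state row and append the newest observation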
+ newstate = np.append(
+ self._env.players[self._role].currentState[1:, :], [self._env.players[self._role].nextObservation], axis=0
+ )
+ self._env.players[self._role].currentState = newstate
+ obs = [i for item in newstate for i in item]
+ rew = self._env.players[self._role].curReward
+ done = (self._env.curTime == self._env.T)
+ info = {}
+ return obs, rew, done, info
+
+ def reward_shaping(self, reward: Tensor) -> Tensor:
+ self._totRew, self._cumReward = self._env.distTotReward(self._role)
+ reward += (self._cfg.distCoeff / 3) * ((self._totRew - self._cumReward) / (self._env.T))
+ return reward
+
+ def enable_save_figure(self, figure_path: Optional[str] = None) -> None:
+ self._cfg.ifSaveFigure = True
+ if figure_path is None:
+ figure_path = './'
+ self._cfg.figure_dir = figure_path
+ self._env.doTestMid(self._demandTr[random.randint(0, self._demand_len - 1)])
diff --git a/DI-engine/dizoo/beergame/envs/beergame_env.py b/DI-engine/dizoo/beergame/envs/beergame_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..f48cdd93cfd0c8a021aeaf90799b390196377924
--- /dev/null
+++ b/DI-engine/dizoo/beergame/envs/beergame_env.py
@@ -0,0 +1,84 @@
+import numpy as np
+from dizoo.beergame.envs.beergame_core import BeerGame
+from typing import Union, List, Optional
+
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.utils import ENV_REGISTRY
+from ding.torch_utils import to_ndarray
+import copy
+
+
+@ENV_REGISTRY.register('beergame')
+class BeerGameEnv(BaseEnv):
+
+ def __init__(self, cfg: dict) -> None:
+ self._cfg = cfg
+ self._init_flag = False
+
+ def reset(self) -> np.ndarray:
+ if not self._init_flag:
+ self._env = BeerGame(self._cfg.role, self._cfg.agent_type, self._cfg.demandDistribution)
+ self._observation_space = self._env.observation_space
+ self._action_space = self._env.action_space
+ self._reward_space = self._env.reward_space
+ self._init_flag = True
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+ self._eval_episode_return = 0
+ obs = self._env.reset()
+ obs = to_ndarray(obs).astype(np.float32)
+ return obs
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep:
+ if isinstance(action, np.ndarray) and action.shape == (1, ):
+ action = action.squeeze() # 0-dim array
+ obs, rew, done, info = self._env.step(action)
+ self._eval_episode_return += rew
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+ obs = to_ndarray(obs).astype(np.float32)
+        rew = to_ndarray([rew]).astype(np.float32)  # wrapped to be transferred to an array with shape (1,)
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def reward_shaping(self, transitions: List[dict]) -> List[dict]:
+ new_transitions = copy.deepcopy(transitions)
+ for trans in new_transitions:
+ trans['reward'] = self._env.reward_shaping(trans['reward'])
+ return new_transitions
+
+ def random_action(self) -> np.ndarray:
+ random_action = self.action_space.sample()
+ if isinstance(random_action, int):
+ random_action = to_ndarray([random_action], dtype=np.int64)
+ return random_action
+
+ def enable_save_figure(self, figure_path: Optional[str] = None) -> None:
+ self._env.enable_save_figure(figure_path)
+
+ @property
+ def observation_space(self) -> int:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> int:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> int:
+ return self._reward_space
+
+ def __repr__(self) -> str:
+ return "DI-engine Beergame Env"
diff --git a/DI-engine/dizoo/beergame/envs/clBeergame.py b/DI-engine/dizoo/beergame/envs/clBeergame.py
new file mode 100644
index 0000000000000000000000000000000000000000..9a237fbb68060d39a62d0aecdd542d13cf60bfaf
--- /dev/null
+++ b/DI-engine/dizoo/beergame/envs/clBeergame.py
@@ -0,0 +1,439 @@
+# Code Reference: https://github.com/OptMLGroup/DeepBeerInventory-RL.
+import numpy as np
+from random import randint
+from .BGAgent import Agent
+from matplotlib import rc
+rc('text', usetex=True)
+from .plotting import plotting, savePlot
+import matplotlib.pyplot as plt
+import os
+import time
+from time import gmtime, strftime
+
+
+class clBeerGame(object):
+
+ def __init__(self, config):
+ self.config = config
+ self.curGame = 0 # The number associated with the current game (counter of the game)
+ self.curTime = 0
+ self.totIterPlayed = 0 # total iterations of the game, played so far in this and previous games
+ self.players = self.createAgent() # create the agents
+ self.T = 0
+ self.demand = []
+ self.ifOptimalSolExist = self.config.ifOptimalSolExist
+ self.getOptimalSol()
+        self.totRew = 0  # the total reward of all players, attributed to the current player
+ self.resultTest = []
+ self.runnerMidlResults = [] # stores the results to use in runner comparisons
+ self.runnerFinlResults = [] # stores the results to use in runner comparisons
+ self.middleTestResult = [
+        ]  # stores all the intermediate results of bs, Strm, and random to avoid running the same tests multiple times
+        self.runNumber = 0  # the run number which is used when using the runner
+        self.strNum = 0  # the run number which is used when using the runner
+
+ # createAgent : Create agent objects (agentNum,IL,OO,c_h,c_p,type,config)
+ def createAgent(self):
+ agentTypes = self.config.agentTypes
+ return [
+ Agent(
+ i, self.config.ILInit[i], self.config.AOInit, self.config.ASInit[i], self.config.c_h[i],
+ self.config.c_p[i], self.config.eta[i], agentTypes[i], self.config
+ ) for i in range(self.config.NoAgent)
+ ]
+
+ # planHorizon : Find a random planning horizon
+ def planHorizon(self):
+ # TLow: minimum number for the planning horizon # TUp: maximum number for the planning horizon
+ # output: The planning horizon which is chosen randomly.
+ return randint(self.config.TLow, self.config.TUp)
+
+    # this function resets the game state at the start of a new game
+ def resetGame(self, demand: np.ndarray):
+ self.demand = demand
+ self.curTime = 0
+ self.curGame += 1
+ self.totIterPlayed += self.T
+ self.T = self.planHorizon()
+ # reset the required information of player for each episode
+ for k in range(0, self.config.NoAgent):
+ self.players[k].resetPlayer(self.T)
+
+ # update OO when there are initial IL,AO,AS
+ self.update_OO()
+
+ # correction on cost at time T according to the cost of the other players
+ def getTotRew(self):
+ totRew = 0
+ for i in range(self.config.NoAgent):
+ # sum all rewards for the agents and make correction
+ totRew += self.players[i].cumReward
+
+ for i in range(self.config.NoAgent):
+ self.players[i].curReward += self.players[i].eta * (totRew - self.players[i].cumReward) # /(self.T)
+
+ # make correction to the rewards in the experience replay for all iterations of current game
+ def distTotReward(self, role: int):
+ totRew = 0
+ optRew = 0.1 # why?
+ for i in range(self.config.NoAgent):
+ # sum all rewards for the agents and make correction
+ totRew += self.players[i].cumReward
+ totRew += optRew
+
+ return totRew, self.players[role].cumReward
+
+ def getAction(self, k: int, action: np.ndarray, playType="train"):
+ if playType == "train":
+ if self.players[k].compType == "srdqn":
+ self.players[k].action = np.zeros(self.config.actionListLen)
+ self.players[k].action[action] = 1
+ elif self.players[k].compType == "Strm":
+ self.players[k].action = np.zeros(self.config.actionListLenOpt)
+ self.players[k].action[np.argmin(np.abs(np.array(self.config.actionListOpt)\
+ - max(0, round(self.players[k].AO[self.curTime] + \
+ self.players[k].alpha_b*(self.players[k].IL - self.players[k].a_b) + \
+ self.players[k].betta_b*(self.players[k].OO - self.players[k].b_b)))))] = 1
+ elif self.players[k].compType == "rnd":
+ self.players[k].action = np.zeros(self.config.actionListLen)
+ a = np.random.randint(self.config.actionListLen)
+ self.players[k].action[a] = 1
+ elif self.players[k].compType == "bs":
+ self.players[k].action = np.zeros(self.config.actionListLenOpt)
+ if self.config.demandDistribution == 2:
+ if self.curTime and self.config.use_initial_BS <= 4:
+ self.players[k].action[np.argmin(np.abs(np.array(self.config.actionListOpt) - \
+ max(0, (self.players[k].int_bslBaseStock - (self.players[k].IL + self.players[k].OO - self.players[k].AO[self.curTime])))))] = 1
+ else:
+ self.players[k].action[np.argmin(np.abs(np.array(self.config.actionListOpt) - \
+ max(0, (self.players[k].bsBaseStock - (self.players[k].IL + self.players[k].OO - self.players[k].AO[self.curTime])))))] = 1
+ else:
+ self.players[k].action[np.argmin(np.abs(np.array(self.config.actionListOpt) - \
+ max(0, (self.players[k].bsBaseStock - (self.players[k].IL + self.players[k].OO - self.players[k].AO[self.curTime])))))] = 1
+ elif playType == "test":
+ if self.players[k].compTypeTest == "srdqn":
+ self.players[k].action = np.zeros(self.config.actionListLen)
+ self.players[k].action = self.players[k].brain.getDNNAction(self.playType)
+ elif self.players[k].compTypeTest == "Strm":
+ self.players[k].action = np.zeros(self.config.actionListLenOpt)
+
+ self.players[k].action[np.argmin(np.abs(np.array(self.config.actionListOpt)-\
+ max(0,round(self.players[k].AO[self.curTime] +\
+ self.players[k].alpha_b*(self.players[k].IL - self.players[k].a_b) +\
+ self.players[k].betta_b*(self.players[k].OO - self.players[k].b_b)))))] = 1
+ elif self.players[k].compTypeTest == "rnd":
+ self.players[k].action = np.zeros(self.config.actionListLen)
+ a = np.random.randint(self.config.actionListLen)
+ self.players[k].action[a] = 1
+ elif self.players[k].compTypeTest == "bs":
+ self.players[k].action = np.zeros(self.config.actionListLenOpt)
+
+ if self.config.demandDistribution == 2:
+ if self.curTime and self.config.use_initial_BS <= 4:
+ self.players[k].action [np.argmin(np.abs(np.array(self.config.actionListOpt)-\
+ max(0,(self.players[k].int_bslBaseStock - (self.players[k].IL + self.players[k].OO - self.players[k].AO[self.curTime]))) ))] = 1
+ else:
+ self.players[k].action [np.argmin(np.abs(np.array(self.config.actionListOpt)-\
+ max(0,(self.players[k].bsBaseStock - (self.players[k].IL + self.players[k].OO - self.players[k].AO[self.curTime]))) ))] = 1
+ else:
+ self.players[k].action [np.argmin(np.abs(np.array(self.config.actionListOpt)-\
+ max(0,(self.players[k].bsBaseStock - (self.players[k].IL + self.players[k].OO - self.players[k].AO[self.curTime]))) ))] = 1
+ else:
+            # no valid player type is defined
+            raise Exception('The player type is not defined or it is not a valid type.')
+
+ def next(self):
+ # get a random leadtime
+ leadTimeIn = randint(
+ self.config.leadRecItemLow[self.config.NoAgent - 1], self.config.leadRecItemUp[self.config.NoAgent - 1]
+ )
+        # handle the most upstream received shipment
+ self.players[self.config.NoAgent - 1].AS[self.curTime +
+ leadTimeIn] += self.players[self.config.NoAgent -
+ 1].actionValue(self.curTime)
+
+ for k in range(self.config.NoAgent - 1, -1, -1): # [3,2,1,0]
+
+ # get current IL and Backorder
+ current_IL = max(0, self.players[k].IL)
+ current_backorder = max(0, -self.players[k].IL)
+
+            # TODO: we have to get the AS and AO from the UI and update our AS and AO, so that the code updates the corresponding variables
+
+ # increase IL and decrease OO based on the action, for the next period
+ self.players[k].recieveItems(self.curTime)
+
+ # observe the reward
+ possible_shipment = min(
+ current_IL + self.players[k].AS[self.curTime], current_backorder + self.players[k].AO[self.curTime]
+ )
+
+ # plan arrivals of the items to the downstream agent
+ if self.players[k].agentNum > 0:
+ leadTimeIn = randint(self.config.leadRecItemLow[k - 1], self.config.leadRecItemUp[k - 1])
+ self.players[k - 1].AS[self.curTime + leadTimeIn] += possible_shipment
+
+ # update IL
+ self.players[k].IL -= self.players[k].AO[self.curTime]
+ # observe the reward
+ self.players[k].getReward()
+ self.players[k].hist[-1][-2] = self.players[k].curReward
+ self.players[k].hist2[-1][-2] = self.players[k].curReward
+
+ # update next observation
+ self.players[k].nextObservation = self.players[k].getCurState(self.curTime + 1)
+
+ if self.config.ifUseTotalReward:
+ # correction on cost at time T
+ if self.curTime == self.T:
+ self.getTotRew()
+
+ self.curTime += 1
+
+ def handelAction(self, action: np.ndarray, playType="train"):
+ # get random lead time
+ leadTime = randint(self.config.leadRecOrderLow[0], self.config.leadRecOrderUp[0])
+ # set AO
+ self.players[0].AO[self.curTime] += self.demand[self.curTime]
+ for k in range(0, self.config.NoAgent):
+ self.getAction(k, action, playType)
+
+ self.players[k].srdqnBaseStock += [self.players[k].actionValue( \
+ self.curTime) + self.players[k].IL + self.players[k].OO]
+
+ # update hist for the plots
+ self.players[k].hist += [[self.curTime, self.players[k].IL, self.players[k].OO,\
+ self.players[k].actionValue(self.curTime), self.players[k].curReward, self.players[k].srdqnBaseStock[-1]]]
+
+ if self.players[k].compType == "srdqn":
+ self.players[k].hist2 += [[self.curTime, self.players[k].IL, self.players[k].OO, self.players[k].AO[self.curTime], self.players[k].AS[self.curTime], \
+ self.players[k].actionValue(self.curTime), self.players[k].curReward, \
+ self.config.actionList[np.argmax(self.players[k].action)]]]
+
+ else:
+ self.players[k].hist2 += [[self.curTime, self.players[k].IL, self.players[k].OO, self.players[k].AO[self.curTime], self.players[k].AS[self.curTime], \
+ self.players[k].actionValue(self.curTime), self.players[k].curReward, 0]]
+
+ # updates OO and AO at time t+1
+ self.players[k].OO += self.players[k].actionValue(self.curTime) # open order level update
+ leadTime = randint(self.config.leadRecOrderLow[k], self.config.leadRecOrderUp[k])
+ if self.players[k].agentNum < self.config.NoAgent - 1:
+ self.players[k + 1].AO[self.curTime + leadTime] += self.players[k].actionValue(
+ self.curTime
+ ) # open order level update
+
+ # check the Shang and Song (2003) condition and, if it holds, obtain the base stock policy values for each agent
+ def getOptimalSol(self):
+ # if self.config.NoAgent != 1: (this branch is currently disabled by the always-False '1 == 2' guard below)
+ if self.config.NoAgent != 1 and 1 == 2:
+ # check the Shang and Song (2003) condition.
+ for k in range(self.config.NoAgent - 1):
+ if not (self.players[k].c_h == self.players[k + 1].c_h and self.players[k + 1].c_p == 0):
+ self.ifOptimalSolExist = False
+
+ # if the Shang and Song (2003) condition is satisfied, run the algorithm
+ if self.ifOptimalSolExist == True:
+ calculations = np.zeros((7, self.config.NoAgent))
+ for k in range(self.config.NoAgent):
+ # DL_high
+ calculations[0][k] = ((self.config.leadRecItemLow + self.config.leadRecItemUp + 2) / 2 \
+ + (self.config.leadRecOrderLow + self.config.leadRecOrderUp + 2) / 2) * \
+ (self.config.demandUp - self.config.demandLow - 1)
+ if k > 0:
+ calculations[0][k] += calculations[0][k - 1]
+ # probability_high
+ nominator_ch = 0
+ low_denominator_ch = 0
+ for j in range(k, self.config.NoAgent):
+ if j < self.config.NoAgent - 1:
+ nominator_ch += self.players[j + 1].c_h
+ low_denominator_ch += self.players[j].c_h
+ if k == 0:
+ high_denominator_ch = low_denominator_ch
+ calculations[2][k] = (self.players[0].c_p +
+ nominator_ch) / (self.players[0].c_p + low_denominator_ch + 0.0)
+ # probability_low
+ calculations[3][k] = (self.players[0].c_p +
+ nominator_ch) / (self.players[0].c_p + high_denominator_ch + 0.0)
+ # S_high
+ calculations[4] = np.round(np.multiply(calculations[0], calculations[2]))
+ # S_low
+ calculations[5] = np.round(np.multiply(calculations[0], calculations[3]))
+ # S_avg
+ calculations[6] = np.round(np.mean(calculations[4:6], axis=0))
+ # S', set the base stock values into each agent.
+ for k in range(self.config.NoAgent):
+ if k == 0:
+ self.players[k].bsBaseStock = calculations[6][k]
+
+ else:
+ self.players[k].bsBaseStock = calculations[6][k] - calculations[6][k - 1]
+ if self.players[k].bsBaseStock < 0:
+ self.players[k].bsBaseStock = 0
+ elif self.config.NoAgent == 1:
+ if self.config.demandDistribution == 0:
+ self.players[0].bsBaseStock = np.ceil(
+ self.config.c_h[0] / (self.config.c_h[0] + self.config.c_p[0] + 0.0)
+ ) * ((self.config.demandUp - self.config.demandLow - 1) / 2) * self.config.leadRecItemUp
+ else:
+ f = self.config.f
+ f_init = self.config.f_init
+ for k in range(self.config.NoAgent):
+ self.players[k].bsBaseStock = f[k]
+ self.players[k].int_bslBaseStock = f_init[k]
+
+ def update_OO(self):
+ for k in range(0, self.config.NoAgent):
+ if k < self.config.NoAgent - 1:
+ self.players[k].OO = sum(self.players[k + 1].AO) + sum(self.players[k].AS)
+ else:
+ self.players[k].OO = sum(self.players[k].AS)
+
+ def doTestMid(self, demandTs):
+ self.resultTest = []
+ m = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
+ self.doTest(m, demandTs)
+ print("---------------------------------------------------------------------------------------")
+ resultSummary = np.array(self.resultTest).mean(axis=0).tolist()
+
+ result_srdqn = ', '.join(map("{:.2f}".format, resultSummary[0]))
+ result_rand = ', '.join(map("{:.2f}".format, resultSummary[1]))
+ result_strm = ', '.join(map("{:.2f}".format, resultSummary[2]))
+ if self.ifOptimalSolExist:
+ result_bs = ', '.join(map("{:.2f}".format, resultSummary[3]))
+ print(
+ 'SUMMARY; {0:s}; ITER= {1:d}; OURPOLICY= [{2:s}]; SUM = {3:2.4f}; Rand= [{4:s}]; SUM = {5:2.4f}; STRM= [{6:s}]; SUM = {7:2.4f}; BS= [{8:s}]; SUM = {9:2.4f}'
+ .format(
+ strftime("%Y-%m-%d %H:%M:%S", gmtime()), self.curGame, result_srdqn, sum(resultSummary[0]),
+ result_rand, sum(resultSummary[1]), result_strm, sum(resultSummary[2]), result_bs,
+ sum(resultSummary[3])
+ )
+ )
+
+ else:
+ print(
+ 'SUMMARY; {0:s}; ITER= {1:d}; OURPOLICY= [{2:s}]; SUM = {3:2.4f}; Rand= [{4:s}]; SUM = {5:2.4f}; STRM= [{6:s}]; SUM = {7:2.4f}'
+ .format(
+ strftime("%Y-%m-%d %H:%M:%S", gmtime()), self.curGame, result_srdqn, sum(resultSummary[0]),
+ result_rand, sum(resultSummary[1]), result_strm, sum(resultSummary[2])
+ )
+ )
+
+ print("=======================================================================================")
+
+ def doTest(self, m, demand):
+ import matplotlib.pyplot as plt
+ if self.config.ifSaveFigure:
+ plt.figure(self.curGame, figsize=(12, 8), dpi=80, facecolor='w', edgecolor='k')
+
+ # self.demand = demand
+ # use dnn to get output.
+ Rsltdnn, plt = self.tester(self.config.agentTypes, plt, 'b', 'OurPolicy', m)
+ baseStockdata = self.players[0].srdqnBaseStock
+ # use the random policy to get output.
+ RsltRnd, plt = self.tester(["rnd", "rnd", "rnd", "rnd"], plt, 'y', 'RAND', m)
+
+ # use the Sterman formula (Strm) to get output.
+ RsltStrm, plt = self.tester(["Strm", "Strm", "Strm", "Strm"], plt, 'g', 'Strm', m)
+
+ # use optimal strategy to get output, if it works.
+ if self.ifOptimalSolExist:
+ if self.config.agentTypes == ["srdqn", "Strm", "Strm", "Strm"]:
+ Rsltbs, plt = self.tester(["bs", "Strm", "Strm", "Strm"], plt, 'r', 'Strm-BS', m)
+ elif self.config.agentTypes == ["Strm", "srdqn", "Strm", "Strm"]:
+ Rsltbs, plt = self.tester(["Strm", "bs", "Strm", "Strm"], plt, 'r', 'Strm-BS', m)
+ elif self.config.agentTypes == ["Strm", "Strm", "srdqn", "Strm"]:
+ Rsltbs, plt = self.tester(["Strm", "Strm", "bs", "Strm"], plt, 'r', 'Strm-BS', m)
+ elif self.config.agentTypes == ["Strm", "Strm", "Strm", "srdqn"]:
+ Rsltbs, plt = self.tester(["Strm", "Strm", "Strm", "bs"], plt, 'r', 'Strm-BS', m)
+ elif self.config.agentTypes == ["srdqn", "rnd", "rnd", "rnd"]:
+ Rsltbs, plt = self.tester(["bs", "rnd", "rnd", "rnd"], plt, 'r', 'RND-BS', m)
+ elif self.config.agentTypes == ["rnd", "srdqn", "rnd", "rnd"]:
+ Rsltbs, plt = self.tester(["rnd", "bs", "rnd", "rnd"], plt, 'r', 'RND-BS', m)
+ elif self.config.agentTypes == ["rnd", "rnd", "srdqn", "rnd"]:
+ Rsltbs, plt = self.tester(["rnd", "rnd", "bs", "rnd"], plt, 'r', 'RND-BS', m)
+ elif self.config.agentTypes == ["rnd", "rnd", "rnd", "srdqn"]:
+ Rsltbs, plt = self.tester(["rnd", "rnd", "rnd", "bs"], plt, 'r', 'RND-BS', m)
+ else:
+ Rsltbs, plt = self.tester(["bs", "bs", "bs", "bs"], plt, 'r', 'BS', m)
+ # hold the results of the optimal solution
+ self.middleTestResult += [[RsltRnd, RsltStrm, Rsltbs]]
+ else:
+ self.middleTestResult += [[RsltRnd, RsltStrm]]
+
+ else:
+ # retrieve the previously obtained results from their lists
+ RsltRnd = self.middleTestResult[m][0]
+ RsltStrm = self.middleTestResult[m][1]
+ if self.ifOptimalSolExist:
+ Rsltbs = self.middleTestResult[m][2]
+
+ # save the figure
+ if self.config.ifSaveFigure:
+ savePlot(self.players, self.curGame, Rsltdnn, RsltStrm, Rsltbs, RsltRnd, self.config, m)
+ plt.close()
+
+ result_srdqn = ', '.join(map("{:.2f}".format, Rsltdnn))
+ result_rand = ', '.join(map("{:.2f}".format, RsltRnd))
+ result_strm = ', '.join(map("{:.2f}".format, RsltStrm))
+ if self.ifOptimalSolExist:
+ result_bs = ', '.join(map("{:.2f}".format, Rsltbs))
+ print(
+ 'output; {0:s}; Iter= {1:s}; SRDQN= [{2:s}]; sum = {3:2.4f}; Rand= [{4:s}]; sum = {5:2.4f}; Strm= [{6:s}]; sum = {7:2.4f}; BS= [{8:s}]; sum = {9:2.4f}'
+ .format(
+ strftime("%Y-%m-%d %H:%M:%S", gmtime()), str(str(self.curGame) + "-" + str(m)), result_srdqn,
+ sum(Rsltdnn), result_rand, sum(RsltRnd), result_strm, sum(RsltStrm), result_bs, sum(Rsltbs)
+ )
+ )
+ self.resultTest += [[Rsltdnn, RsltRnd, RsltStrm, Rsltbs]]
+
+ else:
+ print(
+ 'output; {0:s}; Iter= {1:s}; SRDQN= [{2:s}]; sum = {3:2.4f}; Rand= [{4:s}]; sum = {5:2.4f}; Strm= [{6:s}]; sum = {7:2.4f}'
+ .format(
+ strftime("%Y-%m-%d %H:%M:%S", gmtime()), str(str(self.curGame) + "-" + str(m)), result_srdqn,
+ sum(Rsltdnn), result_rand, sum(RsltRnd), result_strm, sum(RsltStrm)
+ )
+ )
+
+ self.resultTest += [[Rsltdnn, RsltRnd, RsltStrm]]
+
+ return sum(Rsltdnn)
+
+ def tester(self, testType, plt, colori, labeli, m):
+
+ # set computation type for test
+ for k in range(0, self.config.NoAgent):
+ # self.players[k].compTypeTest = testType[k]
+ self.players[k].compType = testType[k]
+ # run the episode to get the results.
+ if labeli != 'OurPolicy':
+ result = self.playGame(self.demand)
+ else:
+ result = [-1 * self.players[i].cumReward for i in range(0, self.config.NoAgent)]
+ # add the results into the figure
+ if self.config.ifSaveFigure:
+ plt = plotting(plt, [np.array(self.players[i].hist) for i in range(0, self.config.NoAgent)], colori, labeli)
+ if self.config.ifsaveHistInterval and ((self.curGame == 0) or (self.curGame == 1) or (self.curGame == 2) or (self.curGame == 3) or ((self.curGame - 1) % self.config.saveHistInterval == 0)\
+ or ((self.curGame) % self.config.saveHistInterval == 0) or ((self.curGame) % self.config.saveHistInterval == 1) \
+ or ((self.curGame) % self.config.saveHistInterval == 2)) :
+ for k in range(0, self.config.NoAgent):
+ name = labeli + "-" + str(self.curGame) + "-" + "player" + "-" + str(k) + "-" + str(m)
+ np.save(os.path.join(self.config.model_dir, name), np.array(self.players[k].hist2))
+
+ # save the figure of base stocks
+ # if self.config.ifSaveFigure and (self.curGame in range(self.config.saveFigInt[0],self.config.saveFigInt[1])):
+ # for k in range(self.config.NoAgent):
+ # if self.players[k].compTypeTest == 'dnn':
+ # plotBaseStock(self.players[k].srdqnBaseStock, 'b', 'base stock of agent '+ str(self.players[k].agentNum), self.curGame, self.config, m)
+
+ return result, plt
+
+ def playGame(self, demand):
+ self.resetGame(demand)
+
+ # run the game
+ while self.curTime < self.T:
+ self.handelAction(np.array(0)) # action won't be used.
+ self.next()
+ return [-1 * self.players[i].cumReward for i in range(0, self.config.NoAgent)]
diff --git a/DI-engine/dizoo/beergame/envs/plotting.py b/DI-engine/dizoo/beergame/envs/plotting.py
new file mode 100644
index 0000000000000000000000000000000000000000..57776c9641d18fcac289957198a1102fa957f162
--- /dev/null
+++ b/DI-engine/dizoo/beergame/envs/plotting.py
@@ -0,0 +1,72 @@
+# Code Reference: https://github.com/OptMLGroup/DeepBeerInventory-RL.
+import os
+import numpy as np
+import matplotlib
+matplotlib.use('Agg')
+import matplotlib.pyplot as plt
+from pylab import *
+
+
+# plotting
+def plotting(plt, data, colori, pltLabel):
+ # plt.hold(True)
+
+ for i in range(np.shape(data)[0]):
+ plt.subplot(4, 5, 5 * i + 1)
+ plt.plot(np.transpose(data[i])[0, :], np.transpose(data[i])[1, :], colori, label=pltLabel)
+ plt.xlabel('Time')
+ plt.ylabel('IL')
+ plt.grid(True)
+
+ plt.subplot(4, 5, 5 * i + 2)
+ plt.plot(np.transpose(data[i])[0, :], np.transpose(data[i])[2, :], colori, label=pltLabel)
+ plt.xlabel('Time')
+ plt.ylabel('OO')
+ plt.grid(True)
+
+ plt.subplot(4, 5, 5 * i + 3)
+ plt.plot(np.transpose(data[i])[0, :], np.transpose(data[i])[3, :], colori, label=pltLabel)
+ plt.xlabel('Time')
+ plt.ylabel('a')
+ plt.grid(True)
+
+ plt.subplot(4, 5, 5 * i + 4)
+ plt.plot(np.transpose(data[i])[0, :], np.transpose(data[i])[5, :], colori, label=pltLabel)
+ plt.xlabel('Time')
+ plt.ylabel('OUTL')
+ plt.grid(True)
+
+ plt.subplot(4, 5, 5 * i + 5)
+ plt.plot(np.transpose(data[i])[0, :], -1 * np.transpose(data[i])[4, :], colori, label=pltLabel)
+ plt.xlabel('Time')
+ plt.ylabel('r')
+ plt.grid(True)
+
+ return plt
+
+
+def savePlot(players, curGame, Rsltdnn, RsltFrmu, RsltOptm, RsltRnd, config, m):
+ # add a title to the plot
+ if config.if_titled_figure:
+ plt.suptitle(
+ "sum OurPolicy=" + str(round(sum(Rsltdnn), 2)) + "; sum Strm=" + str(round(sum(RsltFrmu), 2)) +
+ "; sum BS=" + str(round(sum(RsltOptm), 2)) + "; sum Rnd=" + str(round(sum(RsltRnd), 2)) + "\n" +
+ "Ag OurPolicy=" + str([round(Rsltdnn[i], 2) for i in range(config.NoAgent)]) + "; Ag Strm=" +
+ str([round(RsltFrmu[i], 2) for i in range(config.NoAgent)]) + "; Ag BS=" +
+ str([round(RsltOptm[i], 2) for i in range(config.NoAgent)]) + "; Ag Rnd=" +
+ str([round(RsltRnd[i], 2) for i in range(config.NoAgent)]),
+ fontsize=12
+ )
+
+ # insert the legend into the figure
+ legend = plt.legend(bbox_to_anchor=(-1.4, -.165, 1., -.102), shadow=True, ncol=4)
+
+ # configure the spacing between subplots
+ plt.subplots_adjust(left=None, bottom=None, right=None, top=None, wspace=.5, hspace=.5)
+ # save the figure
+ path = os.path.join(config.figure_dir, 'saved_figures/')
+ if not os.path.exists(path):
+ os.mkdir(path)
+ plt.savefig(path + str(curGame) + '-' + str(m) + '.png', format='png')
+ print("figure" + str(curGame) + ".png saved in folder \"saved_figures\"")
+ plt.close(curGame)
diff --git a/DI-engine/dizoo/beergame/envs/utils.py b/DI-engine/dizoo/beergame/envs/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a6cf6f83da6ba442447dc5dce04a8c8418839de
--- /dev/null
+++ b/DI-engine/dizoo/beergame/envs/utils.py
@@ -0,0 +1,355 @@
+import argparse
+import os
+import numpy as np
+
+
+def str2bool(v):
+ return v.lower() in ('true', '1')
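+# e.g. str2bool('True') -> True, str2bool('1') -> True, str2bool('0') -> False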
+
+
+arg_lists = []
+parser = argparse.ArgumentParser()
+
+
+def add_argument_group(name):
+ arg = parser.add_argument_group(name)
+ arg_lists.append(arg)
+ return arg
+
+
+# crm
+game_arg = add_argument_group('BeerGame')
+game_arg.add_argument('--task', type=str, default='bg')
+game_arg.add_argument(
+ '--fixedAction',
+ type=str2bool,
+ default='False',
+ help='if you want actions in [0, actionMax], set it to True; with False, actions are taken from [actionLow, actionUp]'
+)
+game_arg.add_argument(
+ '--observation_data',
+ type=str2bool,
+ default=False,
+ help='if it is True, it uses data generated based on a few real-world observations'
+)
+game_arg.add_argument('--data_id', type=int, default=22, help='the default item id for the basket dataset')
+game_arg.add_argument('--TLow', type=int, default=100, help='duration of one GAME (lower bound)')
+game_arg.add_argument('--TUp', type=int, default=100, help='duration of one GAME (upper bound)')
+game_arg.add_argument(
+ '--demandDistribution',
+ type=int,
+ default=0,
+ help='0=uniform, 1=normal distribution, 2=the sequence of 4,4,4,4,8,..., 3= basket data, 4= forecast data'
+)
+game_arg.add_argument(
+ '--scaled', type=str2bool, default=False, help='if True, it uses the existing scaled parameters (if any)'
+)
+game_arg.add_argument('--demandSize', type=int, default=6100, help='the size of demand dataset')
+game_arg.add_argument('--demandLow', type=int, default=0, help='the lower bound of random demand')
+game_arg.add_argument('--demandUp', type=int, default=3, help='the upper bound of random demand')
+game_arg.add_argument('--demandMu', type=float, default=10, help='the mu of the normal distribution for demand ')
+game_arg.add_argument('--demandSigma', type=float, default=2, help='the sigma of the normal distribution for demand ')
+game_arg.add_argument('--actionMax', type=int, default=2, help='it works when fixedAction is True')
+game_arg.add_argument(
+ '--actionUp', type=int, default=2, help='bounds on my decision (upper bound), it works when fixedAction is True'
+)
+game_arg.add_argument(
+ '--actionLow', type=int, default=-2, help='bounds on my decision (lower bound), it works when fixedAction is True'
+)
+game_arg.add_argument(
+ '--action_step', type=int, default=1, help='The action value obtained by the DNN is multiplied by this value'
+)
+game_arg.add_argument('--actionList', type=list, default=[], help='The list of the available actions')
+game_arg.add_argument('--actionListLen', type=int, default=0, help='the length of the action list')
+game_arg.add_argument(
+ '--actionListOpt', type=int, default=0, help='the action list used by the optimal (base stock) and Sterman policies'
+)
+game_arg.add_argument('--actionListLenOpt', type=int, default=0, help='the length of actionListOpt')
+game_arg.add_argument('--agentTypes', type=list, default=['dnn', 'dnn', 'dnn', 'dnn'], help='the player types')
+game_arg.add_argument(
+ '--agent_type1', type=str, default='dnn', help='the player types for agent 1, it can be dnn, Strm, bs, rnd'
+)
+game_arg.add_argument(
+ '--agent_type2', type=str, default='dnn', help='the player types for agent 2, it can be dnn, Strm, bs, rnd'
+)
+game_arg.add_argument(
+ '--agent_type3', type=str, default='dnn', help='the player types for agent 3, it can be dnn, Strm, bs, rnd'
+)
+game_arg.add_argument(
+ '--agent_type4', type=str, default='dnn', help='the player types for agent 4, it can be dnn, Strm, bs, rnd'
+)
+game_arg.add_argument('--NoAgent', type=int, default=4, help='number of agents, currently it should be in {1,2,3,4}')
+game_arg.add_argument('--cp1', type=float, default=2.0, help='shortage cost of player 1')
+game_arg.add_argument('--cp2', type=float, default=0.0, help='shortage cost of player 2')
+game_arg.add_argument('--cp3', type=float, default=0.0, help='shortage cost of player 3')
+game_arg.add_argument('--cp4', type=float, default=0.0, help='shortage cost of player 4')
+game_arg.add_argument('--ch1', type=float, default=2.0, help='holding cost of player 1')
+game_arg.add_argument('--ch2', type=float, default=2.0, help='holding cost of player 2')
+game_arg.add_argument('--ch3', type=float, default=2.0, help='holding cost of player 3')
+game_arg.add_argument('--ch4', type=float, default=2.0, help='holding cost of player 4')
+game_arg.add_argument('--alpha_b1', type=float, default=-0.5, help='alpha of Sterman formula parameter for player 1')
+game_arg.add_argument('--alpha_b2', type=float, default=-0.5, help='alpha of Sterman formula parameter for player 2')
+game_arg.add_argument('--alpha_b3', type=float, default=-0.5, help='alpha of Sterman formula parameter for player 3')
+game_arg.add_argument('--alpha_b4', type=float, default=-0.5, help='alpha of Sterman formula parameter for player 4')
+game_arg.add_argument('--betta_b1', type=float, default=-0.2, help='beta of Sterman formula parameter for player 1')
+game_arg.add_argument('--betta_b2', type=float, default=-0.2, help='beta of Sterman formula parameter for player 2')
+game_arg.add_argument('--betta_b3', type=float, default=-0.2, help='beta of Sterman formula parameter for player 3')
+game_arg.add_argument('--betta_b4', type=float, default=-0.2, help='beta of Sterman formula parameter for player 4')
+game_arg.add_argument('--eta', type=list, default=[0, 4, 4, 4], help='the total cost regularizer')
+game_arg.add_argument('--distCoeff', type=int, default=20, help='the total cost regularizer')
+game_arg.add_argument(
+ '--ifUseTotalReward',
+ type=str2bool,
+ default='False',
+ help='if you want to have the total rewards in the experience replay, set it to true.'
+)
+game_arg.add_argument(
+ '--ifUsedistTotReward',
+ type=str2bool,
+ default='True',
+ help='If True, apply the reward correction in the experience replay for all iterations of the current game'
+)
+game_arg.add_argument(
+ '--ifUseASAO',
+ type=str2bool,
+ default='True',
+ help='if use AS and AO, i.e., received shipment and received orders in the input of DNN'
+)
+game_arg.add_argument('--ifUseActionInD', type=str2bool, default='False', help='if use action in the input of DNN')
+game_arg.add_argument(
+ '--stateDim', type=int, default=5, help='Number of elements in the state descriptor - depends on ifUseASAO'
+)
+game_arg.add_argument('--iftl', type=str2bool, default=False, help='if apply transfer learning')
+game_arg.add_argument(
+ '--ifTransferFromSmallerActionSpace',
+ type=str2bool,
+ default=False,
+ help='if you want to transfer knowledge from a network with a different action space size.'
+)
+game_arg.add_argument(
+ '--baseActionSize',
+ type=int,
+ default=5,
+ help='if ifTransferFromSmallerActionSpace is true, this determines the size of action space of saved network'
+)
+game_arg.add_argument(
+ '--tlBaseBrain',
+ type=int,
+ default=3,
+ help='the gameConfig of the base network for re-training with transfer-learning'
+)
+game_arg.add_argument('--baseDemandDistribution', type=int, default=0, help='same as the demandDistribution')
+game_arg.add_argument(
+ '--MultiAgent', type=str2bool, default=False, help='if run multi-agent RL model, not fully operational'
+)
+game_arg.add_argument(
+ '--MultiAgentRun',
+ type=list,
+ default=[True, True, True, True],
+ help='In the multi-RL setting, it determines which agent should get training.'
+)
+game_arg.add_argument(
+ '--if_use_AS_t_plus_1', type=str2bool, default='False', help='if use AS[t+1], not AS[t] in the input of DNN'
+)
+game_arg.add_argument(
+ '--ifSinglePathExist',
+ type=str2bool,
+ default=False,
+ help='If true it uses the predefined path in pre_model_dir and does not merge it with demandDistribution.'
+)
+game_arg.add_argument('--gamma', type=float, default=.99, help='discount factor for reward')
+game_arg.add_argument(
+ '--multPerdInpt', type=int, default=10, help='Number of history records which we feed into network'
+)
+
+# parameters of the leadtimes
+leadtimes_arg = add_argument_group('leadtimes')
+leadtimes_arg.add_argument(
+ '--leadRecItemLow', type=list, default=[2, 2, 2, 4], help='the min lead time for receiving items'
+)
+leadtimes_arg.add_argument(
+ '--leadRecItemUp', type=list, default=[2, 2, 2, 4], help='the max lead time for receiving items'
+)
+leadtimes_arg.add_argument(
+ '--leadRecOrderLow', type=int, default=[2, 2, 2, 0], help='the min lead time for receiving orders'
+)
+leadtimes_arg.add_argument(
+ '--leadRecOrderUp', type=int, default=[2, 2, 2, 0], help='the max lead time for receiving orders'
+)
+leadtimes_arg.add_argument('--ILInit', type=list, default=[0, 0, 0, 0], help='')
+leadtimes_arg.add_argument('--AOInit', type=list, default=[0, 0, 0, 0], help='')
+leadtimes_arg.add_argument('--ASInit', type=list, default=[0, 0, 0, 0], help='the initial shipment of each agent')
+leadtimes_arg.add_argument('--leadRecItem1', type=int, default=2, help='the min lead time for receiving items')
+leadtimes_arg.add_argument('--leadRecItem2', type=int, default=2, help='the min lead time for receiving items')
+leadtimes_arg.add_argument('--leadRecItem3', type=int, default=2, help='the min lead time for receiving items')
+leadtimes_arg.add_argument('--leadRecItem4', type=int, default=2, help='the min lead time for receiving items')
+leadtimes_arg.add_argument('--leadRecOrder1', type=int, default=2, help='the min lead time for receiving order')
+leadtimes_arg.add_argument('--leadRecOrder2', type=int, default=2, help='the min lead time for receiving order')
+leadtimes_arg.add_argument('--leadRecOrder3', type=int, default=2, help='the min lead time for receiving order')
+leadtimes_arg.add_argument('--leadRecOrder4', type=int, default=2, help='the min lead time for receiving order')
+leadtimes_arg.add_argument('--ILInit1', type=int, default=0, help='the initial inventory level of the agent')
+leadtimes_arg.add_argument('--ILInit2', type=int, default=0, help='the initial inventory level of the agent')
+leadtimes_arg.add_argument('--ILInit3', type=int, default=0, help='the initial inventory level of the agent')
+leadtimes_arg.add_argument('--ILInit4', type=int, default=0, help='the initial inventory level of the agent')
+leadtimes_arg.add_argument('--AOInit1', type=int, default=0, help='the initial arriving order of the agent')
+leadtimes_arg.add_argument('--AOInit2', type=int, default=0, help='the initial arriving order of the agent')
+leadtimes_arg.add_argument('--AOInit3', type=int, default=0, help='the initial arriving order of the agent')
+leadtimes_arg.add_argument('--AOInit4', type=int, default=0, help='the initial arriving order of the agent')
+leadtimes_arg.add_argument('--ASInit1', type=int, default=0, help='the initial arriving shipment of the agent')
+leadtimes_arg.add_argument('--ASInit2', type=int, default=0, help='the initial arriving shipment of the agent')
+leadtimes_arg.add_argument('--ASInit3', type=int, default=0, help='the initial arriving shipment of the agent')
+leadtimes_arg.add_argument('--ASInit4', type=int, default=0, help='the initial arriving shipment of the agent')
+
+# test
+test_arg = add_argument_group('testing')
+test_arg.add_argument(
+ '--testRepeatMid',
+ type=int,
+ default=50,
+ help='the number of episodes used for testing in the middle of training'
+)
+test_arg.add_argument('--testInterval', type=int, default=100, help='every xx games compute "test error"')
+test_arg.add_argument(
+ '--ifSaveFigure', type=str2bool, default=True, help='if it is True, save the figures in each test.'
+)
+test_arg.add_argument(
+ '--if_titled_figure',
+ type=str2bool,
+ default='True',
+ help='if it is True, save the figures with details in the title.'
+)
+test_arg.add_argument(
+ '--ifsaveHistInterval', type=str2bool, default=False, help='if every xx games save details of the episode'
+)
+test_arg.add_argument('--saveHistInterval', type=int, default=50000, help='every xx games save details of the play')
+test_arg.add_argument('--Ttest', type=int, default=100, help='it defines the number of periods in the test cases')
+test_arg.add_argument(
+ '--ifOptimalSolExist',
+ type=str2bool,
+ default=True,
+ help='if the instance has an optimal base stock policy, set it to True; otherwise it should be False.'
+)
+test_arg.add_argument('--f1', type=float, default=8, help='base stock policy decision of player 1')
+test_arg.add_argument('--f2', type=float, default=8, help='base stock policy decision of player 2')
+test_arg.add_argument('--f3', type=float, default=0, help='base stock policy decision of player 3')
+test_arg.add_argument('--f4', type=float, default=0, help='base stock policy decision of player 4')
+test_arg.add_argument(
+ '--f_init',
+ type=list,
+ default=[32, 32, 32, 24],
+ help='base stock policy decision for 4 time-steps on the C(4,8) demand distribution'
+)
+test_arg.add_argument('--use_initial_BS', type=str2bool, default=False, help='If use f_init set it to True')
+
+# reporting
+reporting_arg = add_argument_group('reporting')
+reporting_arg.add_argument('--Rsltdnn', type=list, default=[], help='the result of dnn play tests will be saved here')
+reporting_arg.add_argument(
+ '--RsltRnd', type=list, default=[], help='the result of random play tests will be saved here'
+)
+reporting_arg.add_argument(
+ '--RsltStrm', type=list, default=[], help='the result of heuristic formula play tests will be saved here'
+)
+reporting_arg.add_argument(
+ '--Rsltbs', type=list, default=[], help='the result of optimal play tests will be saved here'
+)
+reporting_arg.add_argument(
+ '--ifSaveHist',
+ type=str2bool,
+ default='False',
+ help=
+ 'if it is True, save history, prediction, and randBatch in each period. WARNING: only set it to True in small runs, as it saves a huge number of files.'
+)
+
+
+# buildActionList: actions for the beer game problem
+def buildActionList(config):
+ aDiv = 1 # step size between consecutive actions in the list
+ if config.fixedAction:
+ actions = list(
+ range(0, config.actionMax + 1, aDiv)
+ ) # the second argument of range is actionMax + 1, so the list covers 0..actionMax
+ else:
+ actions = list(range(config.actionLow, config.actionUp + 1, aDiv))
+ return actions
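+# For example, with the defaults defined above (fixedAction=False, actionLow=-2, actionUp=2),
+# buildActionList returns [-2, -1, 0, 1, 2].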
+
+
+# specify the dimension of the state of the game
+def getStateDim(config):
+ if config.ifUseASAO:
+ stateDim = 5
+ else:
+ stateDim = 3
+
+ if config.ifUseActionInD:
+ stateDim += 1
+
+ return stateDim
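+# With the defaults defined above (ifUseASAO=True, ifUseActionInD=False), this returns 5,
+# matching the default of the --stateDim argument.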
+
+
+def set_optimal(config):
+ if config.demandDistribution == 0:
+ if config.cp1 == 2 and config.ch1 == 2 and config.ch2 == 2 and config.ch3 == 2 and config.ch4 == 2:
+ config.f1 = 8.
+ config.f2 = 8.
+ config.f3 = 0.
+ config.f4 = 0.
+
+
+def get_config():
+ config, unparsed = parser.parse_known_args()
+ config = update_config(config)
+
+ return config, unparsed
+
+
+def fill_leadtime_initial_values(config):
+ config.leadRecItemLow = [config.leadRecItem1, config.leadRecItem2, config.leadRecItem3, config.leadRecItem4]
+ config.leadRecItemUp = [config.leadRecItem1, config.leadRecItem2, config.leadRecItem3, config.leadRecItem4]
+ config.leadRecOrderLow = [config.leadRecOrder1, config.leadRecOrder2, config.leadRecOrder3, config.leadRecOrder4]
+ config.leadRecOrderUp = [config.leadRecOrder1, config.leadRecOrder2, config.leadRecOrder3, config.leadRecOrder4]
+ config.ILInit = [config.ILInit1, config.ILInit2, config.ILInit3, config.ILInit4]
+ config.AOInit = [config.AOInit1, config.AOInit2, config.AOInit3, config.AOInit4]
+ config.ASInit = [config.ASInit1, config.ASInit2, config.ASInit3, config.ASInit4]
+
+
+def get_auxuliary_leadtime_initial_values(config):
+ config.leadRecOrderUp_aux = [config.leadRecOrder1, config.leadRecOrder2, config.leadRecOrder3, config.leadRecOrder4]
+ config.leadRecItemUp_aux = [config.leadRecItem1, config.leadRecItem2, config.leadRecItem3, config.leadRecItem4]
+
+
+def fix_lead_time_manufacturer(config):
+ if config.leadRecOrder4 > 0:
+ config.leadRecItem4 += config.leadRecOrder4
+ config.leadRecOrder4 = 0
+
+
+def set_sterman_parameters(config):
+ config.alpha_b = [config.alpha_b1, config.alpha_b2, config.alpha_b3, config.alpha_b4]
+ config.betta_b = [config.betta_b1, config.betta_b2, config.betta_b3, config.betta_b4]
+
+
+def update_config(config):
+ config.actionList = buildActionList(config) # The list of the available actions
+ config.actionListLen = len(config.actionList) # the length of the action list
+
+ set_optimal(config)
+ config.f = [config.f1, config.f2, config.f3, config.f4] # [6.4, 2.88, 2.08, 0.8]
+
+ config.actionListLen = len(config.actionList)
+ if config.demandDistribution == 0:
+ config.actionListOpt = list(range(0, int(max(config.actionUp * 30 + 1, 3 * sum(config.f))), 1))
+ else:
+ config.actionListOpt = list(range(0, int(max(config.actionUp * 30 + 1, 7 * sum(config.f))), 1))
+ config.actionListLenOpt = len(config.actionListOpt)
+
+ config.c_h = [config.ch1, config.ch2, config.ch3, config.ch4]
+ config.c_p = [config.cp1, config.cp2, config.cp3, config.cp4]
+
+ config.stateDim = getStateDim(config) # Number of elements in the state description - Depends on ifUseASAO
+ get_auxuliary_leadtime_initial_values(config)
+ fix_lead_time_manufacturer(config)
+ fill_leadtime_initial_values(config)
+ set_sterman_parameters(config)
+
+ return config
diff --git a/DI-engine/dizoo/bitflip/README.md b/DI-engine/dizoo/bitflip/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..bb524b65c47f7cb6bef633a3029ad58be88a15f6
--- /dev/null
+++ b/DI-engine/dizoo/bitflip/README.md
@@ -0,0 +1,15 @@
+## BitFlip Environment
+A simple environment whose goal is to flip a binary (0/1) sequence into a specific target state. As the number of bits increases, the task becomes harder.
+It is well suited for testing Hindsight Experience Replay (HER).
+
+## DI-engine's HER on BitFlip
+
+The table below shows the minimum number of envsteps needed for PureDQN and HER-DQN (as implemented in DI-engine) to converge. '-' means no convergence within 20M envsteps.
+
+| n_bit | PureDQN | HER-DQN |
+| ------ | ------- | ------- |
+| 15 | - | 150K |
+| 20 | - | 1.5M |
+
+DI-engine's HER-DQN converges in both settings, while pure DQN does not converge within 20M envsteps.
+
+You can refer to the RL algorithm doc for implementation and experiment details.
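+
+As a minimal sketch, HER-DQN training can also be launched from Python by mirroring the `__main__` entry of `bitflip_her_dqn_config.py`:
+
+```python
+from dizoo.bitflip.config import bitflip_her_dqn_config
+from dizoo.bitflip.entry.bitflip_dqn_main import main
+
+# run the HER-DQN training/evaluation loop defined in bitflip_dqn_main.py
+main(bitflip_her_dqn_config, seed=0)
+```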
diff --git a/DI-engine/dizoo/bitflip/__init__.py b/DI-engine/dizoo/bitflip/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/bitflip/config/__init__.py b/DI-engine/dizoo/bitflip/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ad65e6ba4d791d3c9053377d6e7320c4a1c6136
--- /dev/null
+++ b/DI-engine/dizoo/bitflip/config/__init__.py
@@ -0,0 +1,2 @@
+from .bitflip_her_dqn_config import bitflip_her_dqn_config, bitflip_her_dqn_create_config
+from .bitflip_pure_dqn_config import bitflip_pure_dqn_config, bitflip_pure_dqn_create_config
diff --git a/DI-engine/dizoo/bitflip/config/bitflip_her_dqn_config.py b/DI-engine/dizoo/bitflip/config/bitflip_her_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..2128cfa90192e8ff6fcc4b7c67c71d73515a5860
--- /dev/null
+++ b/DI-engine/dizoo/bitflip/config/bitflip_her_dqn_config.py
@@ -0,0 +1,76 @@
+from easydict import EasyDict
+
+n_bits = 5 # 15 or 20 bits show the difference between pure DQN and HER-DQN; 5 bits is used for the unit test
+bitflip_her_dqn_config = dict(
+ exp_name='bitflip_{}bit_herdqn_seed0'.format(n_bits),
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=16,
+ n_bits=n_bits,
+ n_evaluator_episode=16,
+ stop_value=0.9,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=2 * n_bits,
+ action_shape=n_bits,
+ encoder_hidden_size_list=[128, 128, 64],
+ dueling=True,
+ ),
+ # == Different from most DQN algorithms ==
+ # If discount_factor(gamma) > 0.9, it would be very difficult to converge
+ discount_factor=0.8,
+ learn=dict(
+ update_per_collect=10,
+ # batch_size = episode_size * sample_per_episode
+ # You can refer to cfg.other.her to learn about `episode_size` and `sample_per_episode`
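+ # With the values below (episode_size=32, sample_per_episode=4), each train iteration
+ # yields 32 * 4 = 128 training samples, matching batch_size=128.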
+ batch_size=128,
+ learning_rate=0.0005,
+ target_update_freq=500,
+ ),
+ collect=dict(
+ n_episode=8,
+ unroll_len=1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000)),
+ other=dict(
+ # == Different from most DQN algorithms ==
+ # Fixing epsilon at 0.2 leads to easier convergence, as proposed in the paper.
+ eps=dict(
+ type='exp',
+ start=0.2, # 0.8
+ end=0.2, # original 0.1, paper 0.15~0.2
+ decay=100, # 10000
+ ),
+ replay_buffer=dict(replay_buffer_size=4000, ),
+ her=dict(
+ her_strategy='future',
+ # her_replay_k=2, # `her_replay_k` is not used in episodic HER
+ # Sample how many episodes in each train iteration.
+ episode_size=32,
+ # Generate how many samples from one episode.
+ sample_per_episode=4,
+ ),
+ ),
+ ),
+)
+bitflip_her_dqn_config = EasyDict(bitflip_her_dqn_config)
+main_config = bitflip_her_dqn_config
+
+bitflip_her_dqn_create_config = dict(
+ env=dict(
+ type='bitflip',
+ import_names=['dizoo.bitflip.envs.bitflip_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='dqn'),
+ replay_buffer=dict(type='episode'),
+ collector=dict(type='episode'),
+)
+bitflip_her_dqn_create_config = EasyDict(bitflip_her_dqn_create_config)
+create_config = bitflip_her_dqn_create_config
+
+if __name__ == '__main__':
+ from dizoo.bitflip.entry.bitflip_dqn_main import main
+ main(main_config, seed=0)
diff --git a/DI-engine/dizoo/bitflip/config/bitflip_pure_dqn_config.py b/DI-engine/dizoo/bitflip/config/bitflip_pure_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e1eeb3e226ed67cae7aca48b1b81f7c6f5428ae
--- /dev/null
+++ b/DI-engine/dizoo/bitflip/config/bitflip_pure_dqn_config.py
@@ -0,0 +1,59 @@
+from easydict import EasyDict
+
+n_bits = 4
+bitflip_pure_dqn_config = dict(
+ exp_name='bitflip_{}bit_puredqn_seed0'.format(n_bits),
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_bits=n_bits,
+ n_evaluator_episode=8,
+ stop_value=0.9,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=2 * n_bits,
+ action_shape=n_bits,
+ encoder_hidden_size_list=[128, 128, 64],
+ dueling=True,
+ ),
+ discount_factor=0.9,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=128,
+ learning_rate=0.0005,
+ target_update_freq=500,
+ ),
+ collect=dict(n_episode=8, unroll_len=1, collector=dict(get_train_sample=True, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=4000, ),
+ ),
+ ),
+)
+bitflip_pure_dqn_config = EasyDict(bitflip_pure_dqn_config)
+main_config = bitflip_pure_dqn_config
+
+bitflip_pure_dqn_create_config = dict(
+ env=dict(
+ type='bitflip',
+ import_names=['dizoo.bitflip.envs.bitflip_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='dqn'),
+ replay_buffer=dict(type='episode'),
+ collector=dict(type='episode'),
+)
+bitflip_pure_dqn_create_config = EasyDict(bitflip_pure_dqn_create_config)
+create_config = bitflip_pure_dqn_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c bitflip_pure_dqn_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/bitflip/entry/__init__.py b/DI-engine/dizoo/bitflip/entry/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/bitflip/entry/bitflip_dqn_main.py b/DI-engine/dizoo/bitflip/entry/bitflip_dqn_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..89f996986352fa306792f74f6c441c03cd0c6ebf
--- /dev/null
+++ b/DI-engine/dizoo/bitflip/entry/bitflip_dqn_main.py
@@ -0,0 +1,103 @@
+import os
+import gym
+from tensorboardX import SummaryWriter
+from easydict import EasyDict
+from functools import partial
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, EpisodeSerialCollector, InteractionSerialEvaluator, EpisodeReplayBuffer
+from ding.envs import BaseEnvManager, DingEnvWrapper
+from ding.policy import DQNPolicy
+from ding.model import DQN
+from ding.utils import set_pkg_seed
+from ding.rl_utils import get_epsilon_greedy_fn
+from ding.reward_model import HerRewardModel
+from dizoo.bitflip.envs import BitFlipEnv
+from dizoo.bitflip.config import bitflip_pure_dqn_config, bitflip_her_dqn_config
+
+
+def main(cfg, seed=0, max_train_iter=int(1e8), max_env_step=int(1e8)):
+ cfg = compile_config(
+ cfg,
+ BaseEnvManager,
+ DQNPolicy,
+ BaseLearner,
+ EpisodeSerialCollector,
+ InteractionSerialEvaluator,
+ EpisodeReplayBuffer,
+ save_cfg=True
+ )
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+ collector_env = BaseEnvManager(
+ env_fn=[partial(BitFlipEnv, cfg=cfg.env) for _ in range(collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManager(
+ env_fn=[partial(BitFlipEnv, cfg=cfg.env) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ # Set random seed for all package and instance
+ collector_env.seed(seed)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ # Set up RL Policy
+ model = DQN(**cfg.policy.model)
+ policy = DQNPolicy(cfg.policy, model=model)
+
+ # Set up collection, training and evaluation utilities
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = EpisodeSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = EpisodeReplayBuffer(
+ cfg.policy.other.replay_buffer, exp_name=cfg.exp_name, instance_name='episode_buffer'
+ )
+
+ # Set up other modules, e.g., epsilon greedy and hindsight experience replay
+ eps_cfg = cfg.policy.other.eps
+ epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
+ her_cfg = cfg.policy.other.get('her', None)
+ if her_cfg is not None:
+ her_model = HerRewardModel(her_cfg, cfg.policy.cuda)
+
+ # Training & Evaluation loop
+ while True:
+ # Evaluate at the beginning and then at the specified frequency
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Update other modules
+ eps = epsilon_greedy(collector.envstep)
+ # Sampling data from environments
+ new_episode = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
+ replay_buffer.push(new_episode, cur_collector_envstep=collector.envstep)
+ # Training
+ for i in range(cfg.policy.learn.update_per_collect):
+ if her_cfg and her_model.episode_size is not None:
+ sample_size = her_model.episode_size
+ else:
+ sample_size = learner.policy.get_attribute('batch_size')
+ train_episode = replay_buffer.sample(sample_size, learner.train_iter)
+ if train_episode is None:
+ break
+ train_data = []
+ if her_cfg is not None:
+ her_episodes = []
+ for e in train_episode:
+ her_episodes.extend(her_model.estimate(e))
+ # Only use samples modified by HER reward_model to train.
+ for e in her_episodes:
+ train_data.extend(policy.collect_mode.get_train_sample(e))
+ learner.train(train_data, collector.envstep)
+ if learner.train_iter >= max_train_iter or collector.envstep >= max_env_step:
+ break
+
+
+if __name__ == "__main__":
+ # main(bitflip_pure_dqn_config)
+ main(bitflip_her_dqn_config)
diff --git a/DI-engine/dizoo/bitflip/envs/__init__.py b/DI-engine/dizoo/bitflip/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..05c4dc2ab3b5368751865d4ae83733647e8fb107
--- /dev/null
+++ b/DI-engine/dizoo/bitflip/envs/__init__.py
@@ -0,0 +1 @@
+from .bitflip_env import BitFlipEnv
diff --git a/DI-engine/dizoo/bitflip/envs/bitflip_env.py b/DI-engine/dizoo/bitflip/envs/bitflip_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c74b174be885ceeb2d4fd0cbc4bf71f66b058c2
--- /dev/null
+++ b/DI-engine/dizoo/bitflip/envs/bitflip_env.py
@@ -0,0 +1,91 @@
+import copy
+import random
+import numpy as np
+import gym
+from typing import Any, Dict, Optional, Union, List
+
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.utils import ENV_REGISTRY
+from ding.torch_utils import to_ndarray
+
+
+@ENV_REGISTRY.register('bitflip')
+class BitFlipEnv(BaseEnv):
+
+ def __init__(self, cfg: dict) -> None:
+ self._cfg = cfg
+ self._n_bits = cfg.n_bits
+ self._state = np.zeros(self._n_bits)
+ self._goal = np.zeros(self._n_bits)
+ self._curr_step = 0
+ self._maxsize = self._n_bits
+ self._eval_episode_return = 0
+ self._observation_space = gym.spaces.Box(low=0, high=1, shape=(2 * self._n_bits, ), dtype=np.float32)
+ self._action_space = gym.spaces.Discrete(self._n_bits)
+ self._reward_space = gym.spaces.Box(low=0.0, high=1.0, shape=(1, ), dtype=np.float32)
+
+ def reset(self) -> np.ndarray:
+ self._curr_step = 0
+ self._eval_episode_return = 0
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ random_seed = 100 * random.randint(1, 1000)
+ np.random.seed(self._seed + random_seed)
+ elif hasattr(self, '_seed'):
+ np.random.seed(self._seed)
+ self._state = np.random.randint(0, 2, size=(self._n_bits, )).astype(np.float32)
+ self._goal = np.random.randint(0, 2, size=(self._n_bits, )).astype(np.float32)
+
+ while (self._state == self._goal).all():
+ self._goal = np.random.randint(0, 2, size=(self._n_bits, )).astype(np.float32)
+
+ obs = np.concatenate([self._state, self._goal], axis=0)
+ return obs
+
+ def close(self) -> None:
+ pass
+
+ def check_success(self, state: np.ndarray, goal: np.ndarray) -> bool:
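+ # note: the passed state/goal arguments are not used; success is checked against the internal self._state and self._goal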
+ return (self._state == self._goal).all()
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ random.seed(self._seed)
+
+ def step(self, action: np.ndarray) -> BaseEnvTimestep:
+ self._state[action] = 1 - self._state[action]
+ if self.check_success(self._state, self._goal):
+ rew = np.array([1]).astype(np.float32)
+ done = True
+ else:
+ rew = np.array([0]).astype(np.float32)
+ done = False
+ self._eval_episode_return += float(rew)
+ if self._curr_step >= self._maxsize - 1:
+ done = True
+ info = {}
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+ self._curr_step += 1
+ obs = np.concatenate([self._state, self._goal], axis=0)
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def random_action(self) -> np.ndarray:
+ random_action = self.action_space.sample()
+ random_action = to_ndarray([random_action], dtype=np.int64)
+ return random_action
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ def __repr__(self) -> str:
+ return "DI-engine BitFlip Env({})".format('bitflip')
diff --git a/DI-engine/dizoo/bitflip/envs/test_bitfilp_env.py b/DI-engine/dizoo/bitflip/envs/test_bitfilp_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd5124e715ddba2a00cbb813fea4821bc053cef7
--- /dev/null
+++ b/DI-engine/dizoo/bitflip/envs/test_bitfilp_env.py
@@ -0,0 +1,24 @@
+import pytest
+from easydict import EasyDict
+import numpy as np
+from dizoo.bitflip.envs import BitFlipEnv
+
+
+@pytest.mark.envtest
+def test_bitfilp_env():
+ n_bits = 10
+ env = BitFlipEnv(EasyDict({'n_bits': n_bits}))
+ env.seed(314)
+ assert env._seed == 314
+ obs = env.reset()
+ assert obs.shape == (2 * n_bits, )
+ for i in range(10):
+ # Both ``env.random_action()`` and manually sampling with ``np.random`` over the
+ # action space can generate a legal random action.
+ if i < 5:
+ action = np.random.randint(0, n_bits, size=(1, ))
+ else:
+ action = env.random_action()
+ timestep = env.step(action)
+ assert timestep.obs.shape == (2 * n_bits, )
+ assert timestep.reward.shape == (1, )
diff --git a/DI-engine/dizoo/box2d/__init__.py b/DI-engine/dizoo/box2d/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/box2d/bipedalwalker/__init__.py b/DI-engine/dizoo/box2d/bipedalwalker/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e8c511ad631fad96ee1a528829543c495898f3f
--- /dev/null
+++ b/DI-engine/dizoo/box2d/bipedalwalker/__init__.py
@@ -0,0 +1 @@
+from dizoo.box2d.bipedalwalker.config import *
diff --git a/DI-engine/dizoo/box2d/bipedalwalker/config/__init__.py b/DI-engine/dizoo/box2d/bipedalwalker/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d12677bb06b2ed77e0ae5fa689074822f129023c
--- /dev/null
+++ b/DI-engine/dizoo/box2d/bipedalwalker/config/__init__.py
@@ -0,0 +1 @@
+from .bipedalwalker_sac_config import bipedalwalker_sac_config, bipedalwalker_sac_create_config
diff --git a/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_a2c_config.py b/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_a2c_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c82542597f8bbc5d890b37abff46523e93f75247
--- /dev/null
+++ b/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_a2c_config.py
@@ -0,0 +1,64 @@
+from easydict import EasyDict
+
+bipedalwalker_a2c_config = dict(
+ exp_name='bipedalwalker_a2c_seed0',
+ env=dict(
+ env_id='BipedalWalker-v3',
+ collector_env_num=8,
+ evaluator_env_num=8,
+ # (bool) Scale output action into legal range.
+ act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=300,
+ rew_clip=True,
+ # The path to save the game replay
+ # replay_path='./bipedalwalker_a2c_seed0/video',
+ ),
+ policy=dict(
+ cuda=True,
+ # load_path="./bipedalwalker_a2c_seed0/ckpt/ckpt_best.pth.tar",
+ action_space='continuous',
+ model=dict(
+ action_space='continuous',
+ obs_shape=24,
+ action_shape=4,
+ ),
+ learn=dict(
+ # (int) the number of data for a train iteration
+ batch_size=256,
+ learning_rate=0.0003,
+ # (float) loss weight of the value network, the weight of policy network is set to 1
+ value_weight=0.5,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.001,
+ # (float) discount factor for future reward, default is in [0, 1]
+ discount_factor=0.99,
+ adv_norm=True,
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ n_sample=512,
+ discount_factor=0.99,
+ collector=dict(collect_print_freq=100, ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, )),
+ ),
+)
+bipedalwalker_a2c_config = EasyDict(bipedalwalker_a2c_config)
+main_config = bipedalwalker_a2c_config
+bipedalwalker_a2c_create_config = dict(
+ env=dict(
+ type='bipedalwalker',
+ import_names=['dizoo.box2d.bipedalwalker.envs.bipedalwalker_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='a2c'),
+ replay_buffer=dict(type='naive'),
+)
+bipedalwalker_a2c_create_config = EasyDict(bipedalwalker_a2c_create_config)
+create_config = bipedalwalker_a2c_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c bipedalwalker_a2c_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_bco_config.py b/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_bco_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c98e695dcc31f03ea30a4a8984931374f5accc27
--- /dev/null
+++ b/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_bco_config.py
@@ -0,0 +1,94 @@
+from easydict import EasyDict
+
+bipedalwalker_bco_config = dict(
+ exp_name='bipedalwalker_bco_seed0',
+ env=dict(
+ env_id='BipedalWalker-v3',
+ collector_env_num=8,
+ evaluator_env_num=5,
+ # (bool) Scale output action into legal range.
+ act_scale=True,
+ n_evaluator_episode=5,
+ stop_value=300,
+ rew_clip=True,
+ # The path to save the game replay
+ replay_path=None,
+ ),
+ policy=dict(
+ # Whether to use cuda for network.
+ cuda=True,
+ continuous=True,
+ loss_type='l1_loss',
+ model=dict(
+ obs_shape=24,
+ action_shape=4,
+ action_space='regression',
+ actor_head_hidden_size=128,
+ ),
+ learn=dict(
+ train_epoch=30,
+ batch_size=128,
+ learning_rate=0.01,
+ weight_decay=1e-4,
+ decay_epoch=1000,
+ decay_rate=0.5,
+ warmup_lr=1e-4,
+ warmup_epoch=3,
+ optimizer='SGD',
+ lr_decay=True,
+ momentum=0.9,
+ ),
+ collect=dict(
+ n_episode=100,
+ # control the number (alpha*n_episode) of post-demonstration environment interactions at each iteration.
+ # Notice: alpha * n_episode > collector_env_num
+ model_path='abs model path', # expert model path
+ data_path='abs data path', # expert data path
+ noise=True,
+ noise_sigma=dict(
+ start=0.5,
+ end=0.1,
+ decay=1000000,
+ type='exp',
+ ),
+ noise_range=dict(
+ min=-1,
+ max=1,
+ ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=100000, ), ),
+ ),
+ bco=dict(
+ learn=dict(idm_batch_size=128, idm_learning_rate=0.001, idm_weight_decay=0, idm_train_epoch=50),
+ model=dict(
+ action_space='regression',
+ idm_encoder_hidden_size_list=[60, 80, 100, 40],
+ ),
+ alpha=0.2,
+ )
+)
+
+bipedalwalker_bco_config = EasyDict(bipedalwalker_bco_config)
+main_config = bipedalwalker_bco_config
+
+bipedalwalker_bco_create_config = dict(
+ env=dict(
+ type='bipedalwalker',
+ import_names=['dizoo.box2d.bipedalwalker.envs.bipedalwalker_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='bc'),
+ collector=dict(type='episode'),
+)
+bipedalwalker_bco_create_config = EasyDict(bipedalwalker_bco_create_config)
+create_config = bipedalwalker_bco_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_bco
+ from dizoo.box2d.bipedalwalker.config import bipedalwalker_sac_config, bipedalwalker_sac_create_config
+ expert_main_config = bipedalwalker_sac_config
+ expert_create_config = bipedalwalker_sac_create_config
+ serial_pipeline_bco(
+ [main_config, create_config], [expert_main_config, expert_create_config], seed=0, max_env_step=2000000
+ )
diff --git a/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_ddpg_config.py b/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_ddpg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..de70a09c86bbeb7541b4e5f670977e8f57ca3899
--- /dev/null
+++ b/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_ddpg_config.py
@@ -0,0 +1,55 @@
+from easydict import EasyDict
+
+bipedalwalker_ddpg_config = dict(
+ exp_name='bipedalwalker_ddpg_seed0',
+ env=dict(
+ env_id='BipedalWalker-v3',
+ collector_env_num=8,
+ evaluator_env_num=5,
+ # (bool) Scale output action into legal range.
+ act_scale=True,
+ n_evaluator_episode=5,
+ rew_clip=True,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=24,
+ action_shape=4,
+ twin_critic=False,
+ action_space='regression',
+ actor_head_hidden_size=400,
+ critic_head_hidden_size=400,
+ ),
+ learn=dict(
+ update_per_collect=64,
+ batch_size=256,
+ learning_rate_actor=0.0003,
+ learning_rate_critic=0.0003,
+ target_theta=0.005,
+ discount_factor=0.99,
+ learner=dict(hook=dict(log_show_after_iter=1000, ))
+ ),
+ collect=dict(n_sample=64, ),
+ other=dict(replay_buffer=dict(replay_buffer_size=300000, ), ),
+ ),
+)
+bipedalwalker_ddpg_config = EasyDict(bipedalwalker_ddpg_config)
+main_config = bipedalwalker_ddpg_config
+
+bipedalwalker_ddpg_create_config = dict(
+ env=dict(
+ type='bipedalwalker',
+ import_names=['dizoo.box2d.bipedalwalker.envs.bipedalwalker_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ddpg'),
+)
+bipedalwalker_ddpg_create_config = EasyDict(bipedalwalker_ddpg_create_config)
+create_config = bipedalwalker_ddpg_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c bipedalwalker_ddpg_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0, max_env_step=int(1e5))
diff --git a/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_dt_config.py b/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_dt_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..41055503fb8463ad386819d2c6a82785f18a3599
--- /dev/null
+++ b/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_dt_config.py
@@ -0,0 +1,82 @@
+from easydict import EasyDict
+from copy import deepcopy
+
+bipedalwalker_dt_config = dict(
+ exp_name='bipedalwalker_dt_1000eps_seed0',
+ env=dict(
+ env_name='BipedalWalker-v3',
+ collector_env_num=8,
+ evaluator_env_num=5,
+ act_scale=True,
+ n_evaluator_episode=5,
+ stop_value=300, # stop when the return reaches 300
+ rew_clip=True, # reward clip
+ replay_path=None,
+ ),
+ policy=dict(
+ stop_value=300,
+ device='cuda',
+ env_name='BipedalWalker-v3',
+ rtg_target=300, # max target return to go
+ max_eval_ep_len=1000, # max length of one episode
+ num_eval_ep=10, # number of evaluation episodes
+ batch_size=64,
+ wt_decay=1e-4,
+ warmup_steps=10000,
+ num_updates_per_iter=100,
+ context_len=20,
+ n_blocks=3,
+ embed_dim=128,
+ n_heads=1,
+ dropout_p=0.1,
+ log_dir='/home/wangzilin/research/dt/DI-engine/dizoo/box2d/bipedalwalker/dt_data/dt_log_1000eps',
+ model=dict(
+ state_dim=24,
+ act_dim=4,
+ n_blocks=3,
+ h_dim=128,
+ context_len=20,
+ n_heads=1,
+ drop_p=0.1,
+ continuous=True,
+ ),
+ discount_factor=0.999,
+ nstep=3,
+ learn=dict(
+ dataset_path='/home/wangzilin/research/dt/sac_data_1000eps.pkl',
+ learning_rate=0.0001,
+ target_update_freq=100,
+ kappa=1.0,
+ min_q_weight=4.0,
+ ),
+ collect=dict(unroll_len=1, ),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=1000, ),
+ ),
+ ),
+)
+
+bipedalwalker_dt_config = EasyDict(bipedalwalker_dt_config)
+main_config = bipedalwalker_dt_config
+bipedalwalker_dt_create_config = dict(
+ env=dict(
+ type='bipedalwalker',
+ import_names=['dizoo.box2d.bipedalwalker.envs.bipedalwalker_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dt'),
+)
+bipedalwalker_dt_create_config = EasyDict(bipedalwalker_dt_create_config)
+create_config = bipedalwalker_dt_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_dt
+ config = deepcopy([main_config, create_config])
+ serial_pipeline_dt(config, seed=0, max_train_iter=1000)
diff --git a/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_gail_sac_config.py b/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_gail_sac_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..4ef3d1b0687b084b5a827f951ce60682034d7a4a
--- /dev/null
+++ b/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_gail_sac_config.py
@@ -0,0 +1,96 @@
+from easydict import EasyDict
+
+obs_shape = 24
+act_shape = 4
+bipedalwalker_sac_gail_default_config = dict(
+ exp_name='bipedalwalker_sac_gail_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ # (bool) Scale output action into legal range.
+ act_scale=True,
+ n_evaluator_episode=5,
+ stop_value=300,
+ rew_clip=True,
+ # The path to save the game replay
+ replay_path=None,
+ ),
+ reward_model=dict(
+ type='gail',
+ input_size=obs_shape + act_shape,
+ hidden_size=64,
+ batch_size=64,
+ learning_rate=1e-3,
+ update_per_collect=100,
+ # Users should add their own expert model path here (an absolute path is recommended).
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ expert_model_path='model_path_placeholder',
+ # Path where the trained reward model is stored.
+ reward_model_path='data_path_placeholder+/reward_model/ckpt/ckpt_best.pth.tar',
+ # Users should add their own data path here (an absolute path is recommended); it points to the file
+ # used to store or load the collected expert data. In DI-engine, it is usually located in the ``exp_name`` directory.
+ data_path='data_path_placeholder',
+ collect_count=100000,
+ ),
+ policy=dict(
+ cuda=False,
+ priority=False,
+ random_collect_size=1000,
+ model=dict(
+ obs_shape=obs_shape,
+ action_shape=act_shape,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=128,
+ learning_rate_q=0.001,
+ learning_rate_policy=0.001,
+ learning_rate_alpha=0.0003,
+ ignore_done=True,
+ target_theta=0.005,
+ discount_factor=0.99,
+ auto_alpha=True,
+ value_network=False,
+ ),
+ collect=dict(
+ n_sample=128,
+ unroll_len=1,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=100000, ), ),
+ ),
+)
+bipedalwalker_sac_gail_default_config = EasyDict(bipedalwalker_sac_gail_default_config)
+main_config = bipedalwalker_sac_gail_default_config
+
+bipedalwalker_sac_gail_create_config = dict(
+ env=dict(
+ type='bipedalwalker',
+ import_names=['dizoo.box2d.bipedalwalker.envs.bipedalwalker_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='sac',
+ import_names=['ding.policy.sac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+bipedalwalker_sac_gail_create_config = EasyDict(bipedalwalker_sac_gail_create_config)
+create_config = bipedalwalker_sac_gail_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_gail -c bipedalwalker_sac_gail_config.py -s 0`
+ # then provide the config that was used to train your expert model (at the path mentioned above),
+ # e.g. bipedalwalker_sac_config.py
+ from ding.entry import serial_pipeline_gail
+ from dizoo.box2d.bipedalwalker.config import bipedalwalker_sac_config, bipedalwalker_sac_create_config
+ expert_main_config = bipedalwalker_sac_config
+ expert_create_config = bipedalwalker_sac_create_config
+ serial_pipeline_gail(
+ [main_config, create_config], [expert_main_config, expert_create_config], seed=0, collect_data=True
+ )
diff --git a/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_impala_config.py b/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_impala_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee58dbe8189e8180b93c421dda05529784e73b07
--- /dev/null
+++ b/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_impala_config.py
@@ -0,0 +1,73 @@
+from easydict import EasyDict
+
+bipedalwalker_impala_config = dict(
+ exp_name='bipedalwalker_impala_seed0',
+ env=dict(
+ env_id='BipedalWalker-v3',
+ collector_env_num=8,
+ evaluator_env_num=8,
+ # (bool) Scale output action into legal range.
+ act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=300,
+ rew_clip=True,
+ # The path to save the game replay
+ # replay_path='./bipedalwalker_impala_seed0/video',
+ ),
+ policy=dict(
+ cuda=True,
+ # (int) the trajectory length to calculate v-trace target
+ unroll_len=32,
+ random_collect_size=256,
+ # load_path="./bipedalwalker_impala_seed0/ckpt/ckpt_best.pth.tar",
+ action_space='continuous',
+ model=dict(
+ action_space='continuous',
+ obs_shape=24,
+ action_shape=4,
+ ),
+ learn=dict(
+ # (int) collect n_sample data, train model update_per_collect times
+ # here we follow the IMPALA serial pipeline
+ update_per_collect=3, # update_per_collect should be in [1, 10]
+ # (int) the number of data for a train iteration
+ batch_size=64,
+ grad_clip_type='clip_norm',
+ clip_value=5,
+ learning_rate=0.0003,
+ # (float) loss weight of the value network, the weight of policy network is set to 1
+ value_weight=0.5,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.01,
+ # (float) discount factor for future reward, in the range [0, 1]
+ discount_factor=0.99,
+ # (float) additional discounting parameter
+ lambda_=0.99,
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ n_sample=32,
+ collector=dict(collect_print_freq=1000, ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=10000, ), ),
+ ),
+)
+bipedalwalker_impala_config = EasyDict(bipedalwalker_impala_config)
+main_config = bipedalwalker_impala_config
+bipedalwalker_impala_create_config = dict(
+ env=dict(
+ type='bipedalwalker',
+ import_names=['dizoo.box2d.bipedalwalker.envs.bipedalwalker_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='impala'),
+ replay_buffer=dict(type='naive'),
+)
+bipedalwalker_impala_create_config = EasyDict(bipedalwalker_impala_create_config)
+create_config = bipedalwalker_impala_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c bipedalwalker_impala_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_pg_config.py b/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_pg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..96aa08aee80b5ba0f1f25ec84777eaebbe1a2e0e
--- /dev/null
+++ b/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_pg_config.py
@@ -0,0 +1,52 @@
+from easydict import EasyDict
+
+bipedalwalker_pg_config = dict(
+ exp_name='bipedalwalker_pg_seed0',
+ env=dict(
+ env_id='BipedalWalker-v3',
+ collector_env_num=8,
+ evaluator_env_num=8,
+ act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=300,
+ rew_clip=True,
+ ),
+ policy=dict(
+ cuda=True,
+ action_space='continuous',
+ model=dict(
+ action_space='continuous',
+ obs_shape=24,
+ action_shape=4,
+ ),
+ learn=dict(
+ batch_size=64,
+ learning_rate=0.001,
+ entropy_weight=0.001,
+ ),
+ collect=dict(
+ n_episode=8,
+ unroll_len=1,
+ discount_factor=0.99,
+ ),
+ eval=dict(evaluator=dict(eval_freq=200, ))
+ ),
+)
+bipedalwalker_pg_config = EasyDict(bipedalwalker_pg_config)
+main_config = bipedalwalker_pg_config
+bipedalwalker_pg_create_config = dict(
+ env=dict(
+ type='bipedalwalker',
+ import_names=['dizoo.box2d.bipedalwalker.envs.bipedalwalker_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='pg'),
+ collector=dict(type='episode'),
+)
+bipedalwalker_pg_create_config = EasyDict(bipedalwalker_pg_create_config)
+create_config = bipedalwalker_pg_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c bipedalwalker_pg_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_ppo_config.py b/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_ppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..348c25483448892e386811214af72a210a0d9b27
--- /dev/null
+++ b/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_ppo_config.py
@@ -0,0 +1,59 @@
+from easydict import EasyDict
+
+bipedalwalker_ppo_config = dict(
+ exp_name='bipedalwalker_ppo_seed0',
+ env=dict(
+ env_id='BipedalWalker-v3',
+ collector_env_num=8,
+ evaluator_env_num=5,
+ # (bool) Scale output action into legal range.
+ act_scale=True,
+ n_evaluator_episode=5,
+ stop_value=300,
+ rew_clip=True,
+ # The path to save the game replay
+ # replay_path='./bipedalwalker_ppo_seed0/video',
+ ),
+ policy=dict(
+ cuda=False,
+ load_path="./bipedalwalker_ppo_seed0/ckpt/ckpt_best.pth.tar",
+ action_space='continuous',
+ model=dict(
+ action_space='continuous',
+ obs_shape=24,
+ action_shape=4,
+ ),
+ learn=dict(
+ epoch_per_collect=10,
+ batch_size=64,
+ learning_rate=0.001,
+ value_weight=0.5,
+ entropy_weight=0.01,
+ clip_ratio=0.2,
+ adv_norm=True,
+ ),
+ collect=dict(
+ n_sample=2048,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ ),
+)
+bipedalwalker_ppo_config = EasyDict(bipedalwalker_ppo_config)
+main_config = bipedalwalker_ppo_config
+bipedalwalker_ppo_create_config = dict(
+ env=dict(
+ type='bipedalwalker',
+ import_names=['dizoo.box2d.bipedalwalker.envs.bipedalwalker_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+)
+bipedalwalker_ppo_create_config = EasyDict(bipedalwalker_ppo_create_config)
+create_config = bipedalwalker_ppo_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c bipedalwalker_ppo_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_ppopg_config.py b/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_ppopg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e2e7df403d711c769c51f07dc0641e2a53d4f6d
--- /dev/null
+++ b/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_ppopg_config.py
@@ -0,0 +1,84 @@
+from easydict import EasyDict
+import torch
+import torch.nn as nn
+from ding.model.common import FCEncoder, ReparameterizationHead
+
+bipedalwalker_ppo_config = dict(
+ exp_name='bipedalwalker_ppopg',
+ env=dict(
+ env_id='BipedalWalker-v3',
+ collector_env_num=8,
+ evaluator_env_num=5,
+ # (bool) Scale output action into legal range.
+ act_scale=True,
+ n_evaluator_episode=5,
+ stop_value=500,
+ rew_clip=True,
+ ),
+ policy=dict(
+ cuda=True,
+ action_space='continuous',
+ model=dict(
+ obs_shape=24,
+ action_shape=4,
+ ),
+ learn=dict(
+ epoch_per_collect=10,
+ batch_size=64,
+ learning_rate=3e-4,
+ entropy_weight=0.0001,
+ clip_ratio=0.2,
+ adv_norm=True,
+ ),
+ collect=dict(
+ n_episode=16,
+ discount_factor=0.99,
+ collector=dict(get_train_sample=True),
+ ),
+ ),
+)
+bipedalwalker_ppo_config = EasyDict(bipedalwalker_ppo_config)
+main_config = bipedalwalker_ppo_config
+bipedalwalker_ppo_create_config = dict(
+ env=dict(
+ type='bipedalwalker',
+ import_names=['dizoo.box2d.bipedalwalker.envs.bipedalwalker_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo_pg'),
+ collector=dict(type='episode'),
+)
+bipedalwalker_ppo_create_config = EasyDict(bipedalwalker_ppo_create_config)
+create_config = bipedalwalker_ppo_create_config
+
+
+class PPOPGContinuousModel(nn.Module):
+
+ def __init__(self, obs_shape, action_shape):
+ super(PPOPGContinuousModel, self).__init__()
+ self.encoder = nn.Sequential(nn.Linear(obs_shape, 64), nn.Tanh())
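+ # Gaussian policy head: predicts a mean and a state-conditioned sigma for each action dimension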
+ self.head = ReparameterizationHead(
+ hidden_size=64,
+ output_size=action_shape,
+ layer_num=2,
+ sigma_type='conditioned',
+ activation=nn.Tanh(),
+ )
+
+ def forward(self, inputs):
+ x = self.encoder(inputs)
+ x = self.head(x)
+ return {'logit': x}
+
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c bipedalwalker_ppopg_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ from copy import deepcopy
+ for seed in [1, 2, 3]:
+ new_main_config = deepcopy(main_config)
+ new_main_config.exp_name += "_seed{}".format(seed)
+ model = PPOPGContinuousModel(new_main_config.policy.model.obs_shape, new_main_config.policy.model.action_shape)
+ serial_pipeline_onpolicy(
+ [new_main_config, deepcopy(create_config)], seed=seed, max_env_step=int(5e6), model=model
+ )
diff --git a/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_sac_config.py b/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_sac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..f905c4031b899b448dc9454d714a19928325df1d
--- /dev/null
+++ b/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_sac_config.py
@@ -0,0 +1,57 @@
+from easydict import EasyDict
+
+bipedalwalker_sac_config = dict(
+ exp_name='bipedalwalker_sac_config0',
+ env=dict(
+ env_id='BipedalWalker-v3',
+ collector_env_num=8,
+ evaluator_env_num=5,
+ # (bool) Scale output action into legal range.
+ act_scale=True,
+ n_evaluator_episode=5,
+ rew_clip=True,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=24,
+ action_shape=4,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ ),
+ learn=dict(
+ update_per_collect=64,
+ batch_size=256,
+ learning_rate_q=0.0003,
+ learning_rate_policy=0.0003,
+ learning_rate_alpha=0.0003,
+ target_theta=0.005,
+ discount_factor=0.99,
+ auto_alpha=True,
+ learner=dict(hook=dict(log_show_after_iter=1000, ))
+ ),
+ collect=dict(n_sample=64, ),
+ other=dict(replay_buffer=dict(replay_buffer_size=300000, ), ),
+ ),
+)
+bipedalwalker_sac_config = EasyDict(bipedalwalker_sac_config)
+main_config = bipedalwalker_sac_config
+bipedalwalker_sac_create_config = dict(
+ env=dict(
+ type='bipedalwalker',
+ import_names=['dizoo.box2d.bipedalwalker.envs.bipedalwalker_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='sac', ),
+ replay_buffer=dict(type='naive', ),
+)
+bipedalwalker_sac_create_config = EasyDict(bipedalwalker_sac_create_config)
+create_config = bipedalwalker_sac_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c bipedalwalker_sac_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0, max_env_step=int(1e5))
diff --git a/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_td3_config.py b/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_td3_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..09cc3d1bf1a90026a313fff046aef74839da189a
--- /dev/null
+++ b/DI-engine/dizoo/box2d/bipedalwalker/config/bipedalwalker_td3_config.py
@@ -0,0 +1,62 @@
+from easydict import EasyDict
+
+bipedalwalker_td3_config = dict(
+ exp_name='bipedalwalker_td3_seed0',
+ env=dict(
+ env_id='BipedalWalker-v3',
+ collector_env_num=8,
+ evaluator_env_num=5,
+ # (bool) Scale output action into legal range.
+ act_scale=True,
+ n_evaluator_episode=5,
+ rew_clip=True,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=24,
+ action_shape=4,
+ twin_critic=True,
+ action_space='regression',
+ actor_head_hidden_size=400,
+ critic_head_hidden_size=400,
+ ),
+ learn=dict(
+ update_per_collect=64,
+ batch_size=256,
+ learning_rate_actor=0.0003,
+ learning_rate_critic=0.0003,
+ target_theta=0.005,
+ discount_factor=0.99,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.2,
+ noise_range=dict(
+ min=-0.5,
+ max=0.5,
+ ),
+ learner=dict(hook=dict(log_show_after_iter=1000, ))
+ ),
+ collect=dict(n_sample=64, ),
+ other=dict(replay_buffer=dict(replay_buffer_size=300000, ), ),
+ ),
+)
+bipedalwalker_td3_config = EasyDict(bipedalwalker_td3_config)
+main_config = bipedalwalker_td3_config
+
+bipedalwalker_td3_create_config = dict(
+ env=dict(
+ type='bipedalwalker',
+ import_names=['dizoo.box2d.bipedalwalker.envs.bipedalwalker_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='td3'),
+)
+bipedalwalker_td3_create_config = EasyDict(bipedalwalker_td3_create_config)
+create_config = bipedalwalker_td3_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c bipedalwalker_td3_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0, max_env_step=int(1e5))
diff --git a/DI-engine/dizoo/box2d/bipedalwalker/entry/__init__.py b/DI-engine/dizoo/box2d/bipedalwalker/entry/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/box2d/bipedalwalker/entry/bipedalwalker_ppo_eval.py b/DI-engine/dizoo/box2d/bipedalwalker/entry/bipedalwalker_ppo_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..1423c8c27f66ec59223459e82ddfde4c1e4132e4
--- /dev/null
+++ b/DI-engine/dizoo/box2d/bipedalwalker/entry/bipedalwalker_ppo_eval.py
@@ -0,0 +1,60 @@
+import os
+import gym
+import torch
+from tensorboardX import SummaryWriter
+from easydict import EasyDict
+from functools import partial
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.envs import BaseEnvManager, DingEnvWrapper
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.policy import PPOPolicy
+from ding.model import VAC
+from ding.utils import set_pkg_seed
+from ding.rl_utils import get_epsilon_greedy_fn
+from dizoo.box2d.bipedalwalker.config.bipedalwalker_ppo_config import main_config, create_config
+
+
+def main(rl_cfg, seed=0):
+ main_cfg, create_cfg = rl_cfg
+ cfg = compile_config(
+ main_cfg,
+ BaseEnvManager,
+ PPOPolicy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator,
+ AdvancedReplayBuffer,
+ create_cfg=create_cfg,
+ save_cfg=True
+ )
+
+ create_cfg.policy.type = create_cfg.policy.type + '_command'
+ env_fn = None
+ cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
+ # Create main components: env, policy
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+
+ evaluator_env.enable_save_replay(cfg.env.replay_path) # enable saving of evaluation replays
+
+ # Set random seed for all package and instance
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ # Set up RL Policy
+ model = VAC(**cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
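+ # load the trained checkpoint specified by cfg.policy.load_path into the eval-mode policy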
+ policy.eval_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cpu'))
+
+ # evaluate
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator.eval()
+
+
+if __name__ == "__main__":
+ main(rl_cfg=(main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/box2d/bipedalwalker/envs/__init__.py b/DI-engine/dizoo/box2d/bipedalwalker/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a5350579ef5c05cec1981370b6801c97554aa4c
--- /dev/null
+++ b/DI-engine/dizoo/box2d/bipedalwalker/envs/__init__.py
@@ -0,0 +1 @@
+from .bipedalwalker_env import BipedalWalkerEnv
diff --git a/DI-engine/dizoo/box2d/bipedalwalker/envs/bipedalwalker_env.py b/DI-engine/dizoo/box2d/bipedalwalker/envs/bipedalwalker_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae4d5260324e205544570fd118e3aa3f02e44e8b
--- /dev/null
+++ b/DI-engine/dizoo/box2d/bipedalwalker/envs/bipedalwalker_env.py
@@ -0,0 +1,108 @@
+from typing import Any, List, Union, Optional
+import time
+import gym
+import numpy as np
+from ding.envs import BaseEnv, BaseEnvTimestep, FrameStackWrapper
+from ding.torch_utils import to_ndarray, to_list
+from ding.envs.common.common_function import affine_transform
+from ding.utils import ENV_REGISTRY
+
+
+@ENV_REGISTRY.register('bipedalwalker')
+class BipedalWalkerEnv(BaseEnv):
+
+ def __init__(self, cfg: dict) -> None:
+ self._cfg = cfg
+ self._init_flag = False
+ self._act_scale = cfg.act_scale
+ self._rew_clip = cfg.rew_clip
+ if "replay_path" in cfg:
+ self._replay_path = cfg.replay_path
+ else:
+ self._replay_path = None
+
+ def reset(self) -> np.ndarray:
+ if not self._init_flag:
+ self._env = gym.make('BipedalWalker-v3')
+ self._observation_space = self._env.observation_space
+ self._action_space = self._env.action_space
+ self._reward_space = gym.spaces.Box(
+ low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
+ )
+ self._init_flag = True
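+ # with dynamic_seed, the env is re-seeded with a random offset on every reset to diversify collected episodes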
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+ if self._replay_path is not None:
+ self._env = gym.wrappers.RecordVideo(
+ self._env,
+ video_folder=self._replay_path,
+ episode_trigger=lambda episode_id: True,
+ name_prefix='rl-video-{}'.format(id(self))
+ )
+ self._eval_episode_return = 0
+ obs = self._env.reset()
+ obs = to_ndarray(obs).astype(np.float32)
+ return obs
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def render(self) -> None:
+ self._env.render()
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def step(self, action: np.ndarray) -> BaseEnvTimestep:
+ assert isinstance(action, np.ndarray), type(action)
+ if action.shape == (1, ):
+ action = action.squeeze() # 0-dim array
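+ # if act_scale is set, map the normalized action in [-1, 1] into the env's real action range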
+ if self._act_scale:
+ action = affine_transform(action, min_val=self.action_space.low, max_val=self.action_space.high)
+
+ obs, rew, done, info = self._env.step(action)
+ self._eval_episode_return += rew
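+ # clip the large negative falling penalty (-100 in BipedalWalker) to -10 to reduce reward variance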
+ if self._rew_clip:
+ rew = max(-10, rew)
+ rew = np.float32(rew)
+
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+ obs = to_ndarray(obs).astype(np.float32)
+ rew = to_ndarray([rew]) # wrap the reward into an array with shape (1,)
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ if replay_path is None:
+ replay_path = './video'
+ self._replay_path = replay_path
+
+ def random_action(self) -> np.ndarray:
+ random_action = self.action_space.sample()
+ if isinstance(random_action, np.ndarray):
+ pass
+ elif isinstance(random_action, int):
+ random_action = to_ndarray([random_action], dtype=np.int64)
+ return random_action
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ def __repr__(self) -> str:
+ return "DI-engine BipedalWalker Env"
diff --git a/DI-engine/dizoo/box2d/bipedalwalker/envs/test_bipedalwalker.py b/DI-engine/dizoo/box2d/bipedalwalker/envs/test_bipedalwalker.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a5cce4755dc395c1cdf3d772846c52a75a14683
--- /dev/null
+++ b/DI-engine/dizoo/box2d/bipedalwalker/envs/test_bipedalwalker.py
@@ -0,0 +1,28 @@
+import pytest
+from easydict import EasyDict
+import numpy as np
+from dizoo.box2d.bipedalwalker.envs import BipedalWalkerEnv
+
+
+@pytest.mark.envtest
+class TestBipedalWalkerEnv:
+
+ def test_naive(self):
+ env = BipedalWalkerEnv(EasyDict({'act_scale': True, 'rew_clip': True, 'replay_path': None}))
+ env.seed(123)
+ assert env._seed == 123
+ obs = env.reset()
+ assert obs.shape == (24, )
+ for i in range(10):
+ random_action = env.random_action()
+ timestep = env.step(random_action)
+ print(timestep)
+ assert isinstance(timestep.obs, np.ndarray)
+ assert isinstance(timestep.done, bool)
+ assert timestep.obs.shape == (24, )
+ assert timestep.reward.shape == (1, )
+ assert timestep.reward >= env.reward_space.low
+ assert timestep.reward <= env.reward_space.high
+ # assert isinstance(timestep, tuple)
+ print(env.observation_space, env.action_space, env.reward_space)
+ env.close()
diff --git a/DI-engine/dizoo/box2d/carracing/__init__.py b/DI-engine/dizoo/box2d/carracing/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/box2d/carracing/config/__init__.py b/DI-engine/dizoo/box2d/carracing/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1571e58a64c573064668bf2fa120ba76a2e6fa1a
--- /dev/null
+++ b/DI-engine/dizoo/box2d/carracing/config/__init__.py
@@ -0,0 +1 @@
+from .carracing_dqn_config import carracing_dqn_config, carracing_dqn_create_config
diff --git a/DI-engine/dizoo/box2d/carracing/config/carracing_dqn_config.py b/DI-engine/dizoo/box2d/carracing/config/carracing_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..1792056a836169218537f8644b08afafb1d67304
--- /dev/null
+++ b/DI-engine/dizoo/box2d/carracing/config/carracing_dqn_config.py
@@ -0,0 +1,60 @@
+from easydict import EasyDict
+
+nstep = 3
+carracing_dqn_config = dict(
+ exp_name='carracing_dqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ env_id='CarRacing-v2',
+ continuous=False,
+ n_evaluator_episode=8,
+ stop_value=900,
+ # replay_path='./carracing_dqn_seed0/video',
+ ),
+ policy=dict(
+ cuda=True,
+ # load_path='carracing_dqn_seed0/ckpt/ckpt_best.pth.tar',
+ model=dict(
+ obs_shape=[3, 96, 96],
+ action_shape=5,
+ encoder_hidden_size_list=[64, 64, 128],
+ dueling=True,
+ ),
+ discount_factor=0.99,
+ nstep=nstep,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=64,
+ learning_rate=0.0001,
+ target_update_freq=100,
+ ),
+ collect=dict(n_sample=64, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=50000,
+ ), replay_buffer=dict(replay_buffer_size=100000, )
+ ),
+ ),
+)
+carracing_dqn_config = EasyDict(carracing_dqn_config)
+main_config = carracing_dqn_config
+
+carracing_dqn_create_config = dict(
+ env=dict(
+ type='carracing',
+ import_names=['dizoo.box2d.carracing.envs.carracing_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqn'),
+)
+carracing_dqn_create_config = EasyDict(carracing_dqn_create_config)
+create_config = carracing_dqn_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c carracing_dqn_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/box2d/carracing/envs/__init__.py b/DI-engine/dizoo/box2d/carracing/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a36760ccf7f21077051dbde9f3275d711da37ac3
--- /dev/null
+++ b/DI-engine/dizoo/box2d/carracing/envs/__init__.py
@@ -0,0 +1 @@
+from .carracing_env import CarRacingEnv
diff --git a/DI-engine/dizoo/box2d/carracing/envs/carracing_env.py b/DI-engine/dizoo/box2d/carracing/envs/carracing_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..60ebaa97d14bb155c39c77b5e0beeb20d3fc5f21
--- /dev/null
+++ b/DI-engine/dizoo/box2d/carracing/envs/carracing_env.py
@@ -0,0 +1,160 @@
+from typing import Optional
+import copy
+import os
+
+import gym
+import numpy as np
+from easydict import EasyDict
+
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.envs import ObsPlusPrevActRewWrapper
+from ding.envs.common import affine_transform, save_frames_as_gif
+from ding.torch_utils import to_ndarray
+from ding.utils import ENV_REGISTRY
+
+
+@ENV_REGISTRY.register('carracing')
+class CarRacingEnv(BaseEnv):
+
+ config = dict(
+ replay_path=None,
+ save_replay_gif=False,
+ replay_path_gif=None,
+ action_clip=False,
+ )
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ cfg = EasyDict(copy.deepcopy(cls.config))
+ cfg.cfg_type = cls.__name__ + 'Dict'
+ return cfg
+
+ def __init__(self, cfg: dict) -> None:
+ self._cfg = cfg
+ self._init_flag = False
+ # env_id:CarRacing-v2
+ self._env_id = cfg.env_id
+ self._replay_path = None
+ self._replay_path_gif = cfg.replay_path_gif
+ self._save_replay_gif = cfg.save_replay_gif
+ self._save_replay_count = 0
+ if cfg.continuous:
+ self._act_scale = cfg.act_scale # act_scale only works in continuous env
+ self._action_clip = cfg.action_clip
+ else:
+ self._act_scale = False
+
+ def reset(self) -> np.ndarray:
+ if not self._init_flag:
+ self._env = gym.make(self._cfg.env_id, continuous=self._cfg.continuous)
+ if self._replay_path is not None:
+ self._env = gym.wrappers.RecordVideo(
+ self._env,
+ video_folder=self._replay_path,
+ episode_trigger=lambda episode_id: True,
+ name_prefix='rl-video-{}'.format(id(self))
+ )
+ self._observation_space = gym.spaces.Box(
+ low=np.min(self._env.observation_space.low.astype(np.float32) / 255),
+ high=np.max(self._env.observation_space.high.astype(np.float32) / 255),
+ shape=(
+ self._env.observation_space.shape[2], self._env.observation_space.shape[0],
+ self._env.observation_space.shape[1]
+ ),
+ dtype=np.float32
+ )
+ self._action_space = self._env.action_space
+ self._reward_space = gym.spaces.Box(
+ low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
+ )
+ self._init_flag = True
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+ self._eval_episode_return = 0
+ obs = self._env.reset()
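+ # scale pixels to [0, 1] and transpose HWC -> CHW to match the conv encoder input (3, 96, 96)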
+ obs = obs.astype(np.float32) / 255
+ obs = obs.transpose(2, 0, 1)
+ obs = to_ndarray(obs)
+ if self._save_replay_gif:
+ self._frames = []
+ return obs
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def render(self) -> None:
+ self._env.render()
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def step(self, action: np.ndarray) -> BaseEnvTimestep:
+ assert isinstance(action, np.ndarray), type(action)
+ if action.shape == (1, ):
+ action = action.item() # 0-dim array
+ if self._act_scale:
+ action = affine_transform(action, action_clip=self._action_clip, min_val=-1, max_val=1)
+ if self._save_replay_gif:
+ self._frames.append(self._env.render(mode='rgb_array'))
+ obs, rew, done, info = self._env.step(action)
+ obs = obs.astype(np.float32) / 255
+ obs = obs.transpose(2, 0, 1)
+ self._eval_episode_return += rew
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+ if self._save_replay_gif:
+ if not os.path.exists(self._replay_path_gif):
+ os.makedirs(self._replay_path_gif)
+ path = os.path.join(
+ self._replay_path_gif, '{}_episode_{}.gif'.format(self._env_id, self._save_replay_count)
+ )
+ save_frames_as_gif(self._frames, path)
+ self._save_replay_count += 1
+
+ obs = to_ndarray(obs)
+ rew = to_ndarray([rew]).astype(np.float32) # wrap the reward into an array with shape (1,)
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ if replay_path is None:
+ replay_path = './video'
+ self._replay_path = replay_path
+ self._save_replay_gif = True
+ self._save_replay_count = 0
+ # note: this function can lead to meaningless results
+ self._env = gym.wrappers.RecordVideo(
+ self._env,
+ video_folder=self._replay_path,
+ episode_trigger=lambda episode_id: True,
+ name_prefix='rl-video-{}'.format(id(self))
+ )
+
+ def random_action(self) -> np.ndarray:
+ random_action = self.action_space.sample()
+ if isinstance(random_action, np.ndarray):
+ pass
+ elif isinstance(random_action, int):
+ random_action = to_ndarray([random_action], dtype=np.int64)
+ return random_action
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ def __repr__(self) -> str:
+ return "DI-engine CarRacing Env"
diff --git a/DI-engine/dizoo/box2d/carracing/envs/test_carracing_env.py b/DI-engine/dizoo/box2d/carracing/envs/test_carracing_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..47a5fa463809f84a4b3f6673c83c9e502f1d8785
--- /dev/null
+++ b/DI-engine/dizoo/box2d/carracing/envs/test_carracing_env.py
@@ -0,0 +1,28 @@
+import pytest
+import numpy as np
+from easydict import EasyDict
+from dizoo.box2d.carracing.envs import CarRacingEnv
+
+
+@pytest.mark.envtest
+@pytest.mark.parametrize('cfg', [EasyDict({'env_id': 'CarRacing-v2', 'continuous': False, 'act_scale': False})])
+class TestCarRacing:
+
+ def test_naive(self, cfg):
+ env = CarRacingEnv(cfg)
+ env.seed(314)
+ assert env._seed == 314
+ obs = env.reset()
+ assert obs.shape == (3, 96, 96)
+ for i in range(10):
+ random_action = env.random_action()
+ timestep = env.step(random_action)
+ print(timestep)
+ assert isinstance(timestep.obs, np.ndarray)
+ assert isinstance(timestep.done, bool)
+ assert timestep.obs.shape == (3, 96, 96)
+ assert timestep.reward.shape == (1, )
+ assert timestep.reward >= env.reward_space.low
+ assert timestep.reward <= env.reward_space.high
+ print(env.observation_space, env.action_space, env.reward_space)
+ env.close()
diff --git a/DI-engine/dizoo/box2d/lunarlander/__init__.py b/DI-engine/dizoo/box2d/lunarlander/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/__init__.py b/DI-engine/dizoo/box2d/lunarlander/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..200970a2083f1b64df3b9dcccacbc747b5976a7c
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/__init__.py
@@ -0,0 +1,6 @@
+from .lunarlander_dqn_config import lunarlander_dqn_config, lunarlander_dqn_create_config
+from .lunarlander_gail_dqn_config import lunarlander_dqn_gail_create_config, lunarlander_dqn_gail_config
+from .lunarlander_dqfd_config import lunarlander_dqfd_config, lunarlander_dqfd_create_config
+from .lunarlander_qrdqn_config import lunarlander_qrdqn_config, lunarlander_qrdqn_create_config
+from .lunarlander_trex_dqn_config import lunarlander_trex_dqn_config, lunarlander_trex_dqn_create_config
+from .lunarlander_trex_offppo_config import lunarlander_trex_ppo_config, lunarlander_trex_ppo_create_config
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_a2c_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_a2c_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5469bd58457ed7219da2d59ff7608806c80b77ab
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_a2c_config.py
@@ -0,0 +1,48 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 8
+lunarlander_a2c_config = dict(
+ exp_name='lunarlander_a2c_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ env_id='LunarLander-v2',
+ n_evaluator_episode=evaluator_env_num,
+ stop_value=200,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ ),
+ learn=dict(
+ batch_size=160,
+ learning_rate=3e-4,
+ entropy_weight=0.001,
+ adv_norm=True,
+ ),
+ collect=dict(
+ n_sample=320,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ ),
+)
+lunarlander_a2c_config = EasyDict(lunarlander_a2c_config)
+main_config = lunarlander_a2c_config
+lunarlander_a2c_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='a2c'),
+)
+lunarlander_a2c_create_config = EasyDict(lunarlander_a2c_create_config)
+create_config = lunarlander_a2c_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy([main_config, create_config], seed=0, max_env_step=int(1e7))
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_acer_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_acer_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..1673b3abef0b453148b310c0b15a31777593a3e0
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_acer_config.py
@@ -0,0 +1,82 @@
+from easydict import EasyDict
+
+nstep = 3
+lunarlander_acer_config = dict(
+ exp_name='lunarlander_acer_seed0',
+ env=dict(
+ # Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
+ # Env number respectively for collector and evaluator.
+ collector_env_num=8,
+ evaluator_env_num=8,
+ env_id='LunarLander-v2',
+ n_evaluator_episode=8,
+ stop_value=200,
+ ),
+ policy=dict(
+ # Whether to use cuda for network.
+ cuda=False,
+ # Model config used for model creating. Remember to change this,
+ # especially "obs_shape" and "action_shape" according to specific env.
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ encoder_hidden_size_list=[512, 64],
+ # Whether to use dueling head.
+ ),
+ # Reward's future discount factor, aka. gamma.
+ discount_factor=0.99,
+ # How many steps in td error.
+ nstep=nstep,
+ unroll_len=32,
+ # learn_mode config
+ learn=dict(
+ # (int) collect n_sample data, train model update_per_collect times
+ # here we follow impala serial pipeline
+ update_per_collect=10,
+ # (int) the number of data for a train iteration
+ batch_size=32,
+ # grad_clip_type='clip_norm',
+ # clip_value=10,
+ learning_rate_actor=0.0001,
+ learning_rate_critic=0.0001,
+ # (float) loss weight of the value network, the weight of policy network is set to 1
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.0,
+ # (float) discount factor for future reward, in the range [0, 1]
+ discount_factor=0.99,
+ # (float) additional discounting parameter
+ # (int) the trajectory length to calculate v-trace target
+ # (float) clip ratio of importance weights
+ c_clip_ratio=10,
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ n_sample=16,
+ # (float) discount factor for future reward, in the range [0, 1]
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ collector=dict(collect_print_freq=1000, ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=5000, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=50000, ), ),
+ ),
+)
+lunarlander_acer_config = EasyDict(lunarlander_acer_config)
+main_config = lunarlander_acer_config
+
+lunarlander_acer_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='acer'),
+ replay_buffer=dict(type='naive')
+)
+lunarlander_acer_create_config = EasyDict(lunarlander_acer_create_config)
+create_config = lunarlander_acer_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c lunarlander_acer_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_bco_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_bco_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f7d0bfba7533838f3ff5fa389fb1723eaa29b31
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_bco_config.py
@@ -0,0 +1,89 @@
+from easydict import EasyDict
+
+nstep = 3
+lunarlander_bco_config = dict(
+ exp_name='lunarlander_bco_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ env_id='LunarLander-v2',
+ n_evaluator_episode=8,
+ stop_value=200,
+ ),
+ policy=dict(
+ # Whether to use cuda for network.
+ cuda=True,
+ continuous=False,
+ loss_type='l1_loss',
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ encoder_hidden_size_list=[512, 64],
+ # Whether to use dueling head.
+ dueling=True,
+ ),
+ # learn_mode config
+ learn=dict(
+ update_per_collect=10,
+ train_epoch=20,
+ batch_size=64,
+ learning_rate=0.001,
+ weight_decay=1e-4,
+ decay_epoch=1000,
+ decay_rate=0.5,
+ warmup_lr=1e-4,
+ warmup_epoch=3,
+ optimizer='SGD',
+ lr_decay=True,
+ momentum=0.9,
+ ),
+ # collect_mode config
+ collect=dict(
+ n_episode=100,
+ model_path='abs model path', # expert model path
+ data_path='abs data path', # expert data path
+ ),
+ # eval_mode config
+ eval=dict(evaluator=dict(eval_freq=50, )),
+ # command_mode config
+ other=dict(
+ # Epsilon greedy with decay.
+ eps=dict(
+ # Decay type. Support ['exp', 'linear'].
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=50000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, )
+ ),
+ ),
+ bco=dict(
+ learn=dict(idm_batch_size=256, idm_learning_rate=0.001, idm_weight_decay=1e-4, idm_train_epoch=10),
+ model=dict(idm_encoder_hidden_size_list=[60, 80, 100, 40], action_space='discrete'),
+ alpha=0.2,
+ )
+)
+lunarlander_bco_config = EasyDict(lunarlander_bco_config)
+main_config = lunarlander_bco_config
+
+lunarlander_bco_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='bc'),
+ collector=dict(type='episode'),
+)
+lunarlander_bco_create_config = EasyDict(lunarlander_bco_create_config)
+create_config = lunarlander_bco_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_bco
+ from dizoo.box2d.lunarlander.config import lunarlander_dqn_config, lunarlander_dqn_create_config
+ expert_main_config = lunarlander_dqn_config
+ expert_create_config = lunarlander_dqn_create_config
+ serial_pipeline_bco(
+ [main_config, create_config], [expert_main_config, expert_create_config], seed=0, max_env_step=2000000
+ )
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_c51_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_c51_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a843f838cf62b38d5fb8ea2a6cf853ba2864f6a
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_c51_config.py
@@ -0,0 +1,60 @@
+from easydict import EasyDict
+
+lunarlander_c51_config = dict(
+ exp_name='lunarlander_c51_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ env_id='LunarLander-v2',
+ n_evaluator_episode=8,
+ stop_value=200,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ encoder_hidden_size_list=[512, 64],
+ v_min=-30,
+ v_max=30,
+ n_atom=51,
+ ),
+ discount_factor=0.99,
+ nstep=3,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=64,
+ learning_rate=0.001,
+ target_update_freq=100,
+ ),
+ collect=dict(
+ n_sample=64,
+ unroll_len=1,
+ ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=50000,
+ ), replay_buffer=dict(replay_buffer_size=100000, )
+ ),
+ ),
+)
+lunarlander_c51_config = EasyDict(lunarlander_c51_config)
+main_config = lunarlander_c51_config
+lunarlander_c51_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='c51'),
+)
+lunarlander_c51_create_config = EasyDict(lunarlander_c51_create_config)
+create_config = lunarlander_c51_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c lunarlander_c51_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_cont_ddpg_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_cont_ddpg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c03c041c9ff797c0a52a2541981f8de1435e4c74
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_cont_ddpg_config.py
@@ -0,0 +1,70 @@
+from easydict import EasyDict
+
+lunarlander_ddpg_config = dict(
+ exp_name='lunarlander_cont_ddpgs_seed0',
+ env=dict(
+ env_id='LunarLanderContinuous-v2',
+ collector_env_num=8,
+ evaluator_env_num=8,
+ # (bool) Scale output action into legal range.
+ act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=200,
+ ),
+ policy=dict(
+ cuda=False,
+ priority=False,
+ random_collect_size=0,
+ model=dict(
+ obs_shape=8,
+ action_shape=2,
+ twin_critic=True,
+ action_space='regression',
+ ),
+ learn=dict(
+ update_per_collect=2,
+ batch_size=128,
+ learning_rate_actor=0.001,
+ learning_rate_critic=0.001,
+ ignore_done=False, # TODO(pu)
+ # (int) When critic network updates once, how many times will actor network update.
+ # Delayed Policy Updates in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
+ # Default 1 for DDPG, 2 for TD3.
+ actor_update_freq=1,
+ # (bool) Whether to add noise on target network's action.
+ # Target Policy Smoothing Regularization in original TD3 paper(https://arxiv.org/pdf/1802.09477.pdf).
+ # Default True for TD3, False for DDPG.
+ noise=False,
+ noise_sigma=0.1,
+ noise_range=dict(
+ min=-0.5,
+ max=0.5,
+ ),
+ ),
+ collect=dict(
+ n_sample=48,
+ noise_sigma=0.1,
+ collector=dict(collect_print_freq=1000, ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ other=dict(replay_buffer=dict(replay_buffer_size=20000, ), ),
+ ),
+)
+lunarlander_ddpg_config = EasyDict(lunarlander_ddpg_config)
+main_config = lunarlander_ddpg_config
+
+lunarlander_ddpg_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ddpg'),
+)
+lunarlander_ddpg_create_config = EasyDict(lunarlander_ddpg_create_config)
+create_config = lunarlander_ddpg_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c lunarlander_cont_ddpg_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_cont_sac_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_cont_sac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..f8a8ab47e761eea8f845fcfb75bfb74cc64979e4
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_cont_sac_config.py
@@ -0,0 +1,53 @@
+from easydict import EasyDict
+
+lunarlander_sac_config = dict(
+ exp_name='lunarlander_cont_sac_seed0',
+ env=dict(
+ env_id='LunarLanderContinuous-v2',
+ collector_env_num=4,
+ evaluator_env_num=8,
+ # (bool) Scale output action into legal range.
+ act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=200,
+ ),
+ policy=dict(
+ cuda=False,
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=8,
+ action_shape=2,
+ twin_critic=True,
+ action_space='reparameterization',
+ ),
+ learn=dict(
+ update_per_collect=256,
+ batch_size=128,
+ learning_rate_q=1e-3,
+ learning_rate_policy=3e-4,
+ learning_rate_alpha=3e-4,
+ auto_alpha=True,
+ ),
+ collect=dict(n_sample=256, ),
+ eval=dict(evaluator=dict(eval_freq=1000, ), ),
+ other=dict(replay_buffer=dict(replay_buffer_size=int(1e5), ), ),
+ ),
+)
+lunarlander_sac_config = EasyDict(lunarlander_sac_config)
+main_config = lunarlander_sac_config
+
+lunarlander_sac_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='sac'),
+)
+lunarlander_sac_create_config = EasyDict(lunarlander_sac_create_config)
+create_config = lunarlander_sac_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c lunarlander_cont_sac_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_cont_td3_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_cont_td3_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d95932f237598331ee2a93c2fda3e29ad62a8a7c
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_cont_td3_config.py
@@ -0,0 +1,61 @@
+from easydict import EasyDict
+
+lunarlander_td3_config = dict(
+ exp_name='lunarlander_cont_td3_seed0',
+ env=dict(
+ env_id='LunarLanderContinuous-v2',
+ collector_env_num=4,
+ evaluator_env_num=8,
+ # (bool) Scale output action into legal range.
+ act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=200,
+ ),
+ policy=dict(
+ cuda=False,
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=8,
+ action_shape=2,
+ twin_critic=True,
+ action_space='regression',
+ ),
+ learn=dict(
+ update_per_collect=256,
+ batch_size=128,
+ learning_rate_actor=3e-4,
+ learning_rate_critic=1e-3,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.1,
+ noise_range=dict(
+ min=-0.5,
+ max=0.5,
+ ),
+ ),
+ collect=dict(
+ n_sample=256,
+ noise_sigma=0.1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, ), ),
+ other=dict(replay_buffer=dict(replay_buffer_size=int(1e5), ), ),
+ ),
+)
+lunarlander_td3_config = EasyDict(lunarlander_td3_config)
+main_config = lunarlander_td3_config
+
+lunarlander_td3_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='td3'),
+)
+lunarlander_td3_create_config = EasyDict(lunarlander_td3_create_config)
+create_config = lunarlander_td3_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c lunarlander_cont_td3_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_cont_td3_vae_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_cont_td3_vae_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..70c727329c438affdff979578c7c2fe8e7865d1e
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_cont_td3_vae_config.py
@@ -0,0 +1,70 @@
+from easydict import EasyDict
+
+lunarlander_td3vae_config = dict(
+ exp_name='lunarlander_cont_td3_vae_seed0',
+ env=dict(
+ env_id='LunarLanderContinuous-v2',
+ collector_env_num=8,
+ evaluator_env_num=8,
+ # (bool) Scale output action into legal range.
+ act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=200,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ random_collect_size=10000,
+ original_action_shape=2,
+ model=dict(
+ obs_shape=8,
+ action_shape=6, # latent_action_dim
+ twin_critic=True,
+ action_space='regression',
+ ),
+ learn=dict(
+ warm_up_update=int(1e4),
+ rl_vae_update_circle=1, # train rl 1 iter, vae 1 iter
+ update_per_collect_rl=256,
+ update_per_collect_vae=10,
+ batch_size=128,
+ learning_rate_actor=3e-4,
+ learning_rate_critic=3e-4,
+ learning_rate_vae=1e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.1,
+ noise_range=dict(
+ min=-0.5,
+ max=0.5,
+ ),
+ ),
+ collect=dict(
+ n_sample=256,
+ unroll_len=1,
+ noise_sigma=0, # NOTE: noise is added to the original action in the _forward_collect method of the td3_vae policy
+ collector=dict(collect_print_freq=1000, ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ other=dict(replay_buffer=dict(replay_buffer_size=int(1e5), ), ),
+ ),
+)
+lunarlander_td3vae_config = EasyDict(lunarlander_td3vae_config)
+main_config = lunarlander_td3vae_config
+
+lunarlander_td3vae_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='td3_vae'),
+)
+lunarlander_td3vae_create_config = EasyDict(lunarlander_td3vae_create_config)
+create_config = lunarlander_td3vae_create_config
+
+if __name__ == '__main__':
+ from ding.entry import serial_pipeline_td3_vae
+ serial_pipeline_td3_vae([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_discrete_sac_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_discrete_sac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b344e526de9d48bf6be69258180dbf0aee3d5c68
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_discrete_sac_config.py
@@ -0,0 +1,74 @@
+from easydict import EasyDict
+
+lunarlander_sac_config = dict(
+ exp_name='lunarlander_discrete_sac_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ env_id='LunarLander-v2',
+ n_evaluator_episode=8,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+ random_collect_size=0,
+ multi_agent=False,
+ model=dict(
+ agent_obs_shape=8,
+ global_obs_shape=8,
+ action_shape=4,
+ twin_critic=True,
+ actor_head_hidden_size=64,
+ critic_head_hidden_size=64,
+ ),
+ learn=dict(
+ update_per_collect=2,
+ batch_size=64,
+ learning_rate_q=5e-3,
+ learning_rate_policy=5e-3,
+ learning_rate_alpha=3e-4,
+ ignore_done=False,
+ target_theta=0.01,
+ discount_factor=0.99,
+ alpha=0.2,
+ auto_alpha=False,
+ ),
+ collect=dict(
+ env_num=8,
+ n_sample=256,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(
+ evaluator=dict(eval_freq=50, ),
+ env_num=5,
+ ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=50000,
+ ), replay_buffer=dict(replay_buffer_size=100000, )
+ ),
+ ),
+)
+
+lunarlander_sac_config = EasyDict(lunarlander_sac_config)
+main_config = lunarlander_sac_config
+
+lunarlander_sac_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='sac_discrete', ),
+)
+lunarlander_sac_create_config = EasyDict(lunarlander_sac_create_config)
+create_config = lunarlander_sac_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c lunarlander_discrete_sac_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_dqfd_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_dqfd_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee2043de6057fd21dd071622772c80f13c47e74d
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_dqfd_config.py
@@ -0,0 +1,75 @@
+from easydict import EasyDict
+
+lunarlander_dqfd_config = dict(
+ exp_name='lunarlander_dqfd_seed0',
+ env=dict(
+ # Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
+ collector_env_num=8,
+ evaluator_env_num=8,
+ env_id='LunarLander-v2',
+ n_evaluator_episode=8,
+ stop_value=200,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ encoder_hidden_size_list=[512, 64],
+ dueling=True,
+ ),
+ nstep=3,
+ discount_factor=0.97,
+ learn=dict(
+ batch_size=64,
+ learning_rate=0.001,
+ lambda1=1.0,
+ lambda2=1.0,
+ lambda3=1e-5,
+ per_train_iter_k=10,
+ expert_replay_buffer_size=10000, # specify the size of the expert replay buffer
+ ),
+ collect=dict(
+ n_sample=64,
+ # Users should add their own expert model path here (an absolute path is recommended).
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ model_path='model_path_placeholder',
+ # Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=50, )), # note: evaluate once every ``eval_freq`` training iterations
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=20000, ),
+ ),
+ ),
+)
+lunarlander_dqfd_config = EasyDict(lunarlander_dqfd_config)
+main_config = lunarlander_dqfd_config
+lunarlander_dqfd_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqfd'),
+)
+lunarlander_dqfd_create_config = EasyDict(lunarlander_dqfd_create_config)
+create_config = lunarlander_dqfd_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial_dqfd -c lunarlander_dqfd_config.py -s 0`
+ # then input ``lunarlander_dqfd_config.py`` when prompted.
+ # We need to pass the DQfD config because its ``_get_train_sample`` function is required
+ # in the collector, even though the expert model may be generated by other Q-learning algorithms.
+ from ding.entry.serial_entry_dqfd import serial_pipeline_dqfd
+ from dizoo.box2d.lunarlander.config import lunarlander_dqfd_config, lunarlander_dqfd_create_config
+ expert_main_config = lunarlander_dqfd_config
+ expert_create_config = lunarlander_dqfd_create_config
+ serial_pipeline_dqfd([main_config, create_config], [expert_main_config, expert_create_config], seed=0)
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_dqn_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b4a67562dcbe24e3bc28921af591f3f594bd802f
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_dqn_config.py
@@ -0,0 +1,80 @@
+from easydict import EasyDict
+
+nstep = 3
+lunarlander_dqn_config = dict(
+ exp_name='lunarlander_dqn_seed0',
+ env=dict(
+ # Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
+ # Env number respectively for collector and evaluator.
+ collector_env_num=8,
+ evaluator_env_num=8,
+ env_id='LunarLander-v2',
+ n_evaluator_episode=8,
+ stop_value=200,
+ # The path to save the game replay
+ # replay_path='./lunarlander_dqn_seed0/video',
+ ),
+ policy=dict(
+ # Whether to use cuda for network.
+ cuda=True,
+ load_path="./lunarlander_dqn_seed0/ckpt/ckpt_best.pth.tar",
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ encoder_hidden_size_list=[512, 64],
+ # Whether to use dueling head.
+ dueling=True,
+ ),
+ # Reward's future discount factor, aka. gamma.
+ discount_factor=0.99,
+ # How many steps in td error.
+ nstep=nstep,
+ # learn_mode config
+ learn=dict(
+ update_per_collect=10,
+ batch_size=64,
+ learning_rate=0.001,
+ # Frequency of target network update.
+ target_update_freq=100,
+ ),
+ # collect_mode config
+ collect=dict(
+ # You can use either "n_sample" or "n_episode" in collector.collect.
+ # Get "n_sample" samples per collect.
+ n_sample=64,
+ # Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ ),
+ # command_mode config
+ other=dict(
+ # Epsilon greedy with decay.
+ eps=dict(
+ # Decay type. Support ['exp', 'linear'].
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=50000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, )
+ ),
+ ),
+)
+lunarlander_dqn_config = EasyDict(lunarlander_dqn_config)
+main_config = lunarlander_dqn_config
+
+lunarlander_dqn_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ # env_manager=dict(type='base'),
+ policy=dict(type='dqn'),
+)
+lunarlander_dqn_create_config = EasyDict(lunarlander_dqn_create_config)
+create_config = lunarlander_dqn_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c lunarlander_dqn_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
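+
+
+# Illustrative sketch (an assumed schedule shape, not necessarily DI-engine's
+# exact formula): one common form of the exponential epsilon decay described by
+# the `eps` block above, falling from `start` towards `end` over roughly
+# `decay` environment steps. ``example_exp_eps`` is a hypothetical helper.
+def example_exp_eps(step, start=0.95, end=0.1, decay=50000):
+    import math
+    # epsilon decays exponentially from `start` and asymptotically approaches `end`
+    return end + (start - end) * math.exp(-step / decay)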
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_dqn_deque_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_dqn_deque_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab02cfcb64987c39b962256666f174b0bb6fab8c
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_dqn_deque_config.py
@@ -0,0 +1,79 @@
+from easydict import EasyDict
+
+nstep = 3
+lunarlander_dqn_config = dict(
+ exp_name='lunarlander_dqn_deque_seed0',
+ env=dict(
+ # Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
+ # Env number respectively for collector and evaluator.
+ collector_env_num=8,
+ evaluator_env_num=8,
+ env_id='LunarLander-v2',
+ n_evaluator_episode=8,
+ stop_value=200,
+ ),
+ policy=dict(
+ # Whether to use cuda for network.
+ cuda=False,
+ priority=True,
+ priority_IS_weight=False,
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ encoder_hidden_size_list=[512, 64],
+ # Whether to use dueling head.
+ dueling=True,
+ ),
+ # Reward's future discount factor, aka. gamma.
+ discount_factor=0.99,
+ # How many steps in td error.
+ nstep=nstep,
+ # learn_mode config
+ learn=dict(
+ update_per_collect=10,
+ batch_size=64,
+ learning_rate=0.001,
+ # Frequency of target network update.
+ target_update_freq=100,
+ ),
+ # collect_mode config
+ collect=dict(
+ # You can use either "n_sample" or "n_episode" in collector.collect.
+ # Get "n_sample" samples per collect.
+ n_sample=64,
+ # Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ ),
+ # command_mode config
+ other=dict(
+ # Epsilon greedy with decay.
+ eps=dict(
+ # Decay type. Support ['exp', 'linear'].
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=50000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, )
+ ),
+ ),
+)
+lunarlander_dqn_config = EasyDict(lunarlander_dqn_config)
+main_config = lunarlander_dqn_config
+
+lunarlander_dqn_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqn'),
+ replay_buffer=dict(type='deque'),
+)
+lunarlander_dqn_create_config = EasyDict(lunarlander_dqn_create_config)
+create_config = lunarlander_dqn_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c lunarlander_dqn_deque_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_dt_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_dt_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a3ff0a165b55e3784780cf944537ab2a3d71614
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_dt_config.py
@@ -0,0 +1,64 @@
+from easydict import EasyDict
+import torch
+from copy import deepcopy
+
+lunarlander_dt_config = dict(
+ exp_name='data_dt/lunarlander_dt_1000eps_rtgt300_meel1000_seed0_debug',
+ env=dict(
+ env_id='LunarLander-v2',
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=200,
+ ),
+ policy=dict(
+ stop_value=200,
+ state_mean=None,
+ state_std=None,
+ device='cuda',
+ env_name='LunarLander-v2',
+ rtg_target=300, # max target reward_to_go
+ rtg_scale=150,
+ max_eval_ep_len=1000, # max len of one episode # TODO
+ wt_decay=1e-4,
+ warmup_steps=10000,
+ context_len=20, # TODO
+ evaluator_env_num=8,
+ log_dir='DI-engine/dizoo/box2d/lunarlander/dt_log_1000eps',
+ model=dict(
+ state_dim=8,
+ act_dim=4,
+ n_blocks=3,
+ h_dim=128,
+ context_len=20,
+ n_heads=1,
+ drop_p=0.1,
+ continuous=False, # TODO
+ ),
+ discount_factor=0.999,
+ nstep=3,
+ learn=dict(
+ dataset_path='DI-engine/dizoo/box2d/lunarlander/offline_data/dt_data/dqn_data_1000eps.pkl', # TODO
+ learning_rate=3e-4,
+ batch_size=64, # training batch size
+ target_update_freq=100,
+ ),
+ collect=dict(
+ data_type='d4rl_trajectory',
+ data_path='DI-engine/dizoo/box2d/lunarlander/offline_data/dt_data/dqn_data_1000eps.pkl',
+ unroll_len=1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, )),
+ ),
+)
+lunarlander_dt_config = EasyDict(lunarlander_dt_config)
+main_config = lunarlander_dt_config
+lunarlander_dt_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dt'),
+)
+lunarlander_dt_create_config = EasyDict(lunarlander_dt_create_config)
+create_config = lunarlander_dt_create_config
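+
+
+# Illustrative sketch (hypothetical helper, not part of DI-engine): how a
+# return-to-go sequence is typically built and scaled for Decision Transformer
+# conditioning. `rtg_scale` above (150) divides the raw returns, and the
+# evaluation target `rtg_target` (300) is scaled the same way before being fed
+# to the model.
+def returns_to_go_sketch(rewards, rtg_scale=150):
+    # suffix sums of rewards, divided by the scale factor
+    rtg, running = [], 0.0
+    for r in reversed(rewards):
+        running += r
+        rtg.append(running / rtg_scale)
+    return list(reversed(rtg))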
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_gail_dqn_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_gail_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..855f84598099586b232703e345aa91d864a060fe
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_gail_dqn_config.py
@@ -0,0 +1,108 @@
+from easydict import EasyDict
+
+nstep = 1
+lunarlander_dqn_gail_config = dict(
+ exp_name='lunarlander_dqn_gail_seed0',
+ env=dict(
+ # Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
+ # Env number respectively for collector and evaluator.
+ collector_env_num=8,
+ evaluator_env_num=8,
+ env_id='LunarLander-v2',
+ n_evaluator_episode=8,
+ stop_value=200,
+ ),
+ reward_model=dict(
+ type='gail',
+ input_size=9,
+ hidden_size=64,
+ batch_size=64,
+ learning_rate=1e-3,
+ update_per_collect=100,
+ collect_count=100000,
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ expert_model_path='model_path_placeholder',
+ # Path where to store the reward model
+ reward_model_path='data_path_placeholder+/reward_model/ckpt/ckpt_best.pth.tar',
+ # Users should add their own data path here. Data path should lead to a file to store data or load the stored data.
+ # Absolute path is recommended.
+ # In DI-engine, it is usually located in ``exp_name`` directory
+ # e.g. 'exp_name/expert_data.pkl'
+ data_path='data_path_placeholder',
+ ),
+ policy=dict(
+ # Whether to use cuda for network.
+ cuda=False,
+ # Whether the RL algorithm is on-policy or off-policy.
+ on_policy=False,
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ encoder_hidden_size_list=[512, 64],
+ # Whether to use dueling head.
+ dueling=True,
+ ),
+ # Reward's future discount factor, aka. gamma.
+ discount_factor=0.99,
+ # How many steps in td error.
+ nstep=nstep,
+ # learn_mode config
+ learn=dict(
+ update_per_collect=10,
+ batch_size=64,
+ learning_rate=0.001,
+ # Frequency of target network update.
+ target_update_freq=100,
+ ),
+ # collect_mode config
+ collect=dict(
+ # You can use either "n_sample" or "n_episode" in collector.collect.
+ # Get "n_sample" samples per collect.
+ n_sample=64,
+ # Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ ),
+ # command_mode config
+ other=dict(
+ # Epsilon greedy with decay.
+ eps=dict(
+ # Decay type. Support ['exp', 'linear'].
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=50000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, )
+ ),
+ ),
+)
+lunarlander_dqn_gail_config = EasyDict(lunarlander_dqn_gail_config)
+main_config = lunarlander_dqn_gail_config
+
+lunarlander_dqn_gail_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqn'),
+)
+lunarlander_dqn_gail_create_config = EasyDict(lunarlander_dqn_gail_create_config)
+create_config = lunarlander_dqn_gail_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_gail -c lunarlander_dqn_gail_config.py -s 0`
+ # then input the config you used to generate your expert model in the path mentioned above
+ # e.g. lunarlander_dqn_config.py
+ from ding.entry import serial_pipeline_gail
+ from dizoo.box2d.lunarlander.config import lunarlander_dqn_config, lunarlander_dqn_create_config
+ expert_main_config = lunarlander_dqn_config
+ expert_create_config = lunarlander_dqn_create_config
+ serial_pipeline_gail(
+ [main_config, create_config], [expert_main_config, expert_create_config],
+ max_env_step=1000000,
+ seed=0,
+ collect_data=True
+ )
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_gcl_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_gcl_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..60065ae33bd9f247c46879c80fa4d1a0c86247a9
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_gcl_config.py
@@ -0,0 +1,69 @@
+from easydict import EasyDict
+
+lunarlander_ppo_config = dict(
+ exp_name='lunarlander_gcl_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ env_id='LunarLander-v2',
+ n_evaluator_episode=8,
+ stop_value=200,
+ ),
+ reward_model=dict(
+ learning_rate=0.001,
+ input_size=9,
+ batch_size=32,
+ continuous=False,
+ update_per_collect=20,
+ ),
+ policy=dict(
+ cuda=False,
+ action_space='discrete',
+ recompute_adv=True,
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ action_space='discrete',
+ ),
+ learn=dict(
+ update_per_collect=8,
+ batch_size=800,
+ learning_rate=0.001,
+ value_weight=0.5,
+ entropy_weight=0.01,
+ clip_ratio=0.2,
+ adv_norm=True,
+ ),
+ collect=dict(
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ model_path='model_path_placeholder',
+            # If you need the data collected by the collector to contain the logit key, which reflects the
+            # probability of each action, you can set this key to True.
+            # In Guided Cost Learning, we need the logit to train the reward model, so we set the key to True.
+ collector_logit=True,
+ n_sample=800,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ ),
+)
+lunarlander_ppo_config = EasyDict(lunarlander_ppo_config)
+main_config = lunarlander_ppo_config
+lunarlander_ppo_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+ reward_model=dict(type='guided_cost'),
+)
+lunarlander_ppo_create_config = EasyDict(lunarlander_ppo_create_config)
+create_config = lunarlander_ppo_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_guided_cost
+ serial_pipeline_guided_cost([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_impala_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_impala_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..8725f5bd83bce909d676e71517c072e98f575430
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_impala_config.py
@@ -0,0 +1,74 @@
+from easydict import EasyDict
+
+lunarlander_impala_config = dict(
+ exp_name='impala_log/lunarlander_impala_seed0',
+ env=dict(
+ env_id='LunarLander-v2',
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=3000,
+ ),
+ policy=dict(
+ cuda=True,
+ # (int) the trajectory length to calculate v-trace target
+ unroll_len=32,
+ random_collect_size=256,
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ encoder_hidden_size_list=[64, 64],
+ ),
+ learn=dict(
+ # (int) collect n_sample data, train model update_per_collect times
+ # here we follow ppo serial pipeline
+ update_per_collect=10,
+ # (int) the number of data for a train iteration
+ batch_size=128,
+ grad_clip_type='clip_norm',
+ clip_value=5,
+ learning_rate=0.0003,
+ # (float) loss weight of the value network, the weight of policy network is set to 1
+ value_weight=0.5,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.0001,
+            # (float) discount factor for future reward, defaults to a value in [0, 1]
+ discount_factor=0.99,
+ # (float) additional discounting parameter
+ lambda_=0.95,
+ # (float) clip ratio of importance weights
+ rho_clip_ratio=1.0,
+ # (float) clip ratio of importance weights
+ c_clip_ratio=1.0,
+ # (float) clip ratio of importance sampling
+ rho_pg_clip_ratio=1.0,
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model update_per_collect times
+ n_sample=32,
+ ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000, sliced=True), ),
+ ),
+)
+
+lunarlander_impala_config = EasyDict(lunarlander_impala_config)
+main_config = lunarlander_impala_config
+
+lunarlander_impala_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='impala'),
+ replay_buffer=dict(type='naive'),
+)
+
+lunarlander_impala_create_config = EasyDict(lunarlander_impala_create_config)
+create_config = lunarlander_impala_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c lunarlander_impala_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
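+
+
+# Illustrative sketch (assumed formulation, not DI-engine's v-trace code): the
+# clipped importance weights that `rho_clip_ratio` / `c_clip_ratio` above refer
+# to. ``clipped_is_weights_sketch`` is a hypothetical helper for intuition only.
+def clipped_is_weights_sketch(target_prob, behaviour_prob, rho_clip=1.0, c_clip=1.0):
+    # importance ratio between the learner (target) policy and the actor (behaviour) policy
+    ratio = target_prob / behaviour_prob
+    # rho is used in the td-like delta term, c in the trace coefficients
+    return min(rho_clip, ratio), min(c_clip, ratio)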
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_ngu_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_ngu_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4cb7cfe2630289d99fc9b35e5dbe521d9511fd2
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_ngu_config.py
@@ -0,0 +1,130 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 8
+nstep = 5
+lunarlander_ngu_config = dict(
+ exp_name='lunarlander_ngu_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ env_id='LunarLander-v2',
+ obs_plus_prev_action_reward=True, # use specific env wrapper for ngu policy
+ n_evaluator_episode=evaluator_env_num,
+ stop_value=195,
+ ),
+ rnd_reward_model=dict(
+ intrinsic_reward_type='add',
+ learning_rate=5e-4,
+ obs_shape=8,
+ action_shape=4,
+ batch_size=320, # transitions
+ update_per_collect=10,
+ only_use_last_five_frames_for_icm_rnd=False,
+ clear_buffer_per_iters=10,
+ nstep=5,
+ hidden_size_list=[128, 128, 64],
+ type='rnd-ngu',
+ ),
+ episodic_reward_model=dict(
+        # whether to apply the rescale trick to the last non-zero reward
+        # when combining extrinsic and intrinsic rewards.
+        # the rescale trick is only used in:
+        # 1. sparse-reward envs such as minigrid, in which the last non-zero reward is a strong positive signal
+        # 2. envs where the last reward of each episode directly reflects the agent's completion of the task, e.g. lunarlander
+        # Note that the ngu intrinsic reward is a positive value (max value is 5). In these envs,
+        # the last non-zero reward should not be overwhelmed by intrinsic rewards, so we need to rescale the
+        # original last non-zero extrinsic reward.
+ # please refer to ngu_reward_model for details.
+ last_nonzero_reward_rescale=True,
+ # means the rescale value for the last non-zero reward, only used when last_nonzero_reward_rescale is True
+ # please refer to ngu_reward_model for details.
+ last_nonzero_reward_weight=100,
+ intrinsic_reward_type='add',
+ learning_rate=5e-4,
+ obs_shape=8,
+ action_shape=4,
+ batch_size=320, # transitions
+ update_per_collect=10,
+ only_use_last_five_frames_for_icm_rnd=False,
+ clear_buffer_per_iters=10,
+ nstep=nstep,
+ hidden_size_list=[128, 128, 64],
+ type='episodic',
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ priority_IS_weight=True,
+ discount_factor=0.997,
+ nstep=nstep,
+ burnin_step=10,
+        # (int) <learn_unroll_len> is the total length of [sequence sample] minus
+        # the length of burnin part in [sequence sample],
+        # i.e., <sequence sample length> = <unroll_len> = <burnin_step> + <learn_unroll_len>
+ learn_unroll_len=20, # set this key according to the episode length
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ encoder_hidden_size_list=[128, 128, 64],
+ collector_env_num=collector_env_num,
+ ),
+ learn=dict(
+ # according to the R2D2 paper, actor parameter update interval is 400
+            # environment timesteps, and in per collect phase, we collect <n_sample> sequence
+            # samples, the length of each sequence sample is <burnin_step> + <learn_unroll_len>,
+            # e.g. if n_sample=32 and the sequence sample length is 100, then 32*100/400=8,
+ # we will set update_per_collect=8 in most environments.
+ update_per_collect=8,
+ batch_size=32,
+ learning_rate=1e-4,
+ target_update_theta=0.001,
+ ),
+ collect=dict(
+            # NOTE: It is important to set the key traj_len_inf=True here,
+            # to make sure self._traj_len=INF in serial_sample_collector.py.
+            # In sequence-based policy, for each collect_env,
+            # we want to collect data of length self._traj_len=INF
+            # unless the episode enters the 'done' state.
+            # In each collect phase, we collect a total of <n_sample> sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=1e5,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=int(5e4),
+ # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
+ alpha=0.6,
+ # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+ beta=0.4,
+ )
+ ),
+ ),
+)
+lunarlander_ngu_config = EasyDict(lunarlander_ngu_config)
+main_config = lunarlander_ngu_config
+lunarlander_ngu_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ngu'),
+ rnd_reward_model=dict(type='rnd-ngu'),
+ episodic_reward_model=dict(type='episodic'),
+)
+lunarlander_ngu_create_config = EasyDict(lunarlander_ngu_create_config)
+create_config = lunarlander_ngu_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_ngu -c lunarlander_ngu_config.py -s 0`
+ from ding.entry import serial_pipeline_ngu
+ serial_pipeline_ngu([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_offppo_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_offppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..2c933f5a1bb476923b72bb7e203c06b1d2abebd2
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_offppo_config.py
@@ -0,0 +1,53 @@
+from easydict import EasyDict
+
+lunarlander_ppo_config = dict(
+ exp_name='lunarlander_offppo_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ env_id='LunarLander-v2',
+ n_evaluator_episode=8,
+ stop_value=200,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ ),
+ learn=dict(
+ update_per_collect=4,
+ batch_size=64,
+ learning_rate=0.001,
+ value_weight=0.5,
+ entropy_weight=0.01,
+ clip_ratio=0.2,
+ nstep=1,
+ nstep_return=False,
+ adv_norm=True,
+ ),
+ collect=dict(
+ n_sample=128,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ ),
+)
+lunarlander_ppo_config = EasyDict(lunarlander_ppo_config)
+main_config = lunarlander_ppo_config
+lunarlander_ppo_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo_offpolicy'),
+)
+lunarlander_ppo_create_config = EasyDict(lunarlander_ppo_create_config)
+create_config = lunarlander_ppo_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c lunarlander_offppo_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
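+
+
+# Illustrative sketch (hypothetical helper, not DI-engine's implementation): the
+# generalized advantage estimation that `discount_factor` and `gae_lambda` above
+# parameterize, computed backwards over a trajectory of rewards and values
+# (done flags are omitted for brevity).
+def gae_sketch(rewards, values, next_value, gamma=0.99, lam=0.95):
+    advantages, gae = [], 0.0
+    values = list(values) + [next_value]
+    for t in reversed(range(len(rewards))):
+        delta = rewards[t] + gamma * values[t + 1] - values[t]
+        gae = delta + gamma * lam * gae
+        advantages.append(gae)
+    return list(reversed(advantages))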
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_pg_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_pg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa66ef4ae93391c0eaa96e6d9ae6f86b0dd4facd
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_pg_config.py
@@ -0,0 +1,45 @@
+from easydict import EasyDict
+
+lunarlander_pg_config = dict(
+ exp_name='lunarlander_pg_seed0',
+ env=dict(
+ env_id='LunarLander-v2',
+ collector_env_num=4,
+ evaluator_env_num=4,
+ n_evaluator_episode=4,
+ stop_value=200,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ ),
+ learn=dict(
+ batch_size=320,
+ learning_rate=3e-4,
+ entropy_weight=0.001,
+ grad_norm=0.5,
+ ),
+ collect=dict(n_episode=8, discount_factor=0.99),
+ eval=dict(evaluator=dict(eval_freq=1000, ), ),
+ ),
+)
+lunarlander_pg_config = EasyDict(lunarlander_pg_config)
+main_config = lunarlander_pg_config
+lunarlander_pg_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='pg'),
+ collector=dict(type='episode'),
+)
+lunarlander_pg_create_config = EasyDict(lunarlander_pg_create_config)
+create_config = lunarlander_pg_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c lunarlander_pg_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0, max_env_step=int(1e7))
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_ppo_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_ppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad622c444dcda73d38e20bc3eaaa3af95952daed
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_ppo_config.py
@@ -0,0 +1,50 @@
+from easydict import EasyDict
+
+lunarlander_ppo_config = dict(
+ exp_name='lunarlander_ppo_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ env_id='LunarLander-v2',
+ n_evaluator_episode=5,
+ stop_value=200,
+ ),
+ policy=dict(
+ recompute_adv=True,
+ cuda=True,
+ action_space='discrete',
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ action_space='discrete',
+ ),
+ learn=dict(
+ epoch_per_collect=10,
+ batch_size=64,
+ learning_rate=3e-4,
+ entropy_weight=0.01,
+ adv_norm=True,
+ value_norm=True,
+ ),
+ collect=dict(
+ n_sample=512,
+ discount_factor=0.99,
+ ),
+ ),
+)
+lunarlander_ppo_config = EasyDict(lunarlander_ppo_config)
+main_config = lunarlander_ppo_config
+lunarlander_ppo_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+)
+lunarlander_ppo_create_config = EasyDict(lunarlander_ppo_create_config)
+create_config = lunarlander_ppo_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_qrdqn_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_qrdqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b173f220a9834b710b7e64207174296bf2518380
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_qrdqn_config.py
@@ -0,0 +1,53 @@
+from easydict import EasyDict
+
+lunarlander_qrdqn_config = dict(
+ exp_name='lunarlander_qrdqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ env_id='LunarLander-v2',
+ n_evaluator_episode=8,
+ stop_value=200,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ ),
+ nstep=1,
+ discount_factor=0.97,
+ learn=dict(
+ batch_size=64,
+ learning_rate=0.001,
+ ),
+ collect=dict(n_sample=128, ),
+ eval=dict(evaluator=dict(eval_freq=50, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=20000, ),
+ ),
+ ),
+)
+lunarlander_qrdqn_config = EasyDict(lunarlander_qrdqn_config)
+main_config = lunarlander_qrdqn_config
+lunarlander_qrdqn_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='qrdqn'),
+)
+lunarlander_qrdqn_create_config = EasyDict(lunarlander_qrdqn_create_config)
+create_config = lunarlander_qrdqn_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c lunarlander_qrdqn_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_r2d2_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_r2d2_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..f67ad5830ba94b3acb6b51ca5c516c7310b36e00
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_r2d2_config.py
@@ -0,0 +1,87 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 8
+lunarlander_r2d2_config = dict(
+ exp_name='lunarlander_r2d2_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ env_id='LunarLander-v2',
+ n_evaluator_episode=8,
+ stop_value=200,
+ ),
+ policy=dict(
+ cuda=True,
+ on_policy=False,
+ priority=True,
+ priority_IS_weight=True,
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ discount_factor=0.997,
+ burnin_step=2,
+ nstep=5,
+ # (int) the whole sequence length to unroll the RNN network minus
+ # the timesteps of burnin part,
+        # i.e., <the whole sequence length> = <unroll_len> = <burnin_step> + <learn_unroll_len>
+ learn_unroll_len=40,
+ learn=dict(
+ # according to the R2D2 paper, actor parameter update interval is 400
+            # environment timesteps, and in per collect phase, we collect <n_sample> sequence
+            # samples, the length of each sequence sample is <burnin_step> + <learn_unroll_len>,
+            # e.g. if n_sample=32 and the sequence sample length is 100, then 32*100/400=8,
+ # we will set update_per_collect=8 in most environments.
+ update_per_collect=8,
+ batch_size=64,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ ),
+ collect=dict(
+            # NOTE: It is important to set the key traj_len_inf=True here,
+            # to make sure self._traj_len=INF in serial_sample_collector.py.
+            # In R2D2 policy, for each collect_env, we want to collect data of length self._traj_len=INF
+            # unless the episode enters the 'done' state.
+            # In each collect phase, we collect a total of <n_sample> sequence samples.
+ n_sample=32,
+ unroll_len=2 + 40,
+ traj_len_inf=True,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=1e5,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=50000,
+ # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
+ alpha=0.6,
+ # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+ beta=0.4,
+ )
+ ),
+ ),
+)
+lunarlander_r2d2_config = EasyDict(lunarlander_r2d2_config)
+main_config = lunarlander_r2d2_config
+lunarlander_r2d2_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='r2d2'),
+)
+lunarlander_r2d2_create_config = EasyDict(lunarlander_r2d2_create_config)
+create_config = lunarlander_r2d2_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c lunarlander_r2d2_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
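+
+
+# Sanity-check sketch (illustration only, not part of the pipeline): the whole
+# RNN sequence length equals burnin_step + learn_unroll_len, which is exactly
+# what collect.unroll_len (= 2 + 40) encodes above.
+def _check_sequence_length_sketch(cfg=main_config):
+    assert cfg.policy.burnin_step + cfg.policy.learn_unroll_len == cfg.policy.collect.unroll_len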
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_r2d2_gtrxl_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_r2d2_gtrxl_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..82966d0a284b32a532e9dd5d26f8b42d341bb5ec
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_r2d2_gtrxl_config.py
@@ -0,0 +1,86 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 8
+lunarlander_r2d2_gtrxl_config = dict(
+ exp_name='lunarlander_r2d2_gtrxl_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=8,
+ stop_value=200,
+ env_id='LunarLander-v2',
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ priority_IS_weight=True,
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ memory_len=0, # length of transformer memory (can be 0)
+ hidden_size=256,
+ gru_bias=1.,
+ att_layer_num=3,
+ dropout=0.1,
+ att_head_dim=64,
+ att_head_num=8,
+ ),
+ discount_factor=0.99,
+ nstep=5,
+        burnin_step=0, # how many steps are used to initialize the memory (can be 0)
+ unroll_len=25, # trajectory len
+ seq_len=20, # transformer input segment
+ learn=dict(
+ update_per_collect=8,
+ batch_size=64,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ value_rescale=True,
+ ),
+ collect=dict(
+            # NOTE: It is important to set the key traj_len_inf=True here,
+            # to make sure self._traj_len=INF in serial_sample_collector.py.
+            # In sequence-based policy, for each collect_env,
+            # we want to collect data of length self._traj_len=INF
+            # unless the episode enters the 'done' state.
+            # In each collect phase, we collect a total of <n_sample> sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=1e5,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=50000,
+ # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
+ alpha=0.6,
+ # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+ beta=0.4,
+ )
+ ),
+ ),
+)
+lunarlander_r2d2_gtrxl_config = EasyDict(lunarlander_r2d2_gtrxl_config)
+main_config = lunarlander_r2d2_gtrxl_config
+lunarlander_r2d2_gtrxl_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='r2d2_gtrxl'),
+)
+lunarlander_r2d2_gtrxl_create_config = EasyDict(lunarlander_r2d2_gtrxl_create_config)
+create_config = lunarlander_r2d2_gtrxl_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c lunarlander_r2d2_gtrxl_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_r2d3_ppoexpert_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_r2d3_ppoexpert_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..28e5638652acaf5ae640dc3ffca63236a386068c
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_r2d3_ppoexpert_config.py
@@ -0,0 +1,162 @@
+import os
+from easydict import EasyDict
+
+module_path = os.path.dirname(__file__)
+
+collector_env_num = 8
+evaluator_env_num = 8
+expert_replay_buffer_size = int(5e3)
+"""agent config"""
+lunarlander_r2d3_config = dict(
+ exp_name='lunarlander_r2d3_ppoexpert_seed0',
+ env=dict(
+ # Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ env_id='LunarLander-v2',
+ n_evaluator_episode=8,
+ stop_value=200,
+ ),
+ policy=dict(
+ cuda=True,
+ on_policy=False,
+ priority=True,
+ priority_IS_weight=True,
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ discount_factor=0.997,
+ nstep=5,
+ burnin_step=2,
+ # (int) the whole sequence length to unroll the RNN network minus
+ # the timesteps of burnin part,
+        # i.e., <the whole sequence length> = <unroll_len> = <burnin_step> + <learn_unroll_len>
+ learn_unroll_len=40,
+ learn=dict(
+ # according to the r2d3 paper, actor parameter update interval is 400
+ # environment timesteps, and in per collect phase, we collect 32 sequence
+            # samples, the length of each sample sequence is <burnin_step> + <learn_unroll_len>,
+            # which is 100 in our setting, 32*100/400=8, so we set update_per_collect=8
+ # in most environments
+ value_rescale=True,
+ update_per_collect=8,
+ batch_size=64,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ # DQFD related parameters
+ lambda1=1.0, # n-step return
+ lambda2=1.0, # supervised loss
+            lambda3=1e-5, # L2 regularization; it's very important to set Adam optimizer optim_type='adamw'.
+ lambda_one_step_td=1, # 1-step return
+ margin_function=0.8, # margin function in JE, here we implement this as a constant
+ per_train_iter_k=0, # TODO(pu)
+ ),
+ collect=dict(
+            # NOTE: It is important to set the key traj_len_inf=True here,
+            # to make sure self._traj_len=INF in serial_sample_collector.py.
+            # In sequence-based policy, for each collect_env,
+            # we want to collect data of length self._traj_len=INF
+            # unless the episode enters the 'done' state.
+            # In each collect phase, we collect a total of <n_sample> sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ env_num=collector_env_num,
+            # The hyperparameter pho, the demo ratio, controls the proportion of data coming
+ # from expert demonstrations versus from the agent's own experience.
+ pho=1 / 4., # TODO(pu)
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=100000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=int(1e5),
+ # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
+ alpha=0.6, # priority exponent default=0.6
+ # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+ beta=0.4,
+ )
+ ),
+ ),
+)
+lunarlander_r2d3_config = EasyDict(lunarlander_r2d3_config)
+main_config = lunarlander_r2d3_config
+lunarlander_r2d3_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='r2d3'),
+)
+lunarlander_r2d3_create_config = EasyDict(lunarlander_r2d3_create_config)
+create_config = lunarlander_r2d3_create_config
+"""export config"""
+
+expert_lunarlander_r2d3_config = dict(
+ exp_name='expert_lunarlander_r2d3_ppoexpert_seed0',
+ env=dict(
+ # Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
+ manager=dict(shared_memory=True, reset_inplace=True),
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=5,
+ stop_value=200,
+ ),
+ policy=dict(
+ cuda=True,
+ on_policy=False,
+ priority=True,
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ encoder_hidden_size_list=[128, 128, 64], # ppo
+ ),
+ discount_factor=0.997,
+ burnin_step=2,
+ nstep=5,
+ learn=dict(expert_replay_buffer_size=expert_replay_buffer_size, ),
+ collect=dict(
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ model_path='model_path_placeholder',
+ # Cut trajectories into pieces with length "unroll_len",
+            # which should be set to self._sequence_len of r2d2
+            unroll_len=42, # NOTE: should equal self._sequence_len in r2d2 policy
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(
+ replay_buffer=dict(
+ replay_buffer_size=expert_replay_buffer_size,
+ # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
+ alpha=0.9, # priority exponent default=0.6
+ # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+ beta=0.4,
+ )
+ ),
+ ),
+)
+expert_lunarlander_r2d3_config = EasyDict(expert_lunarlander_r2d3_config)
+expert_main_config = expert_lunarlander_r2d3_config
+expert_lunarlander_r2d3_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='offppo_collect_traj'), # this policy is designed to collect off-ppo expert traj for r2d3
+)
+expert_lunarlander_r2d3_create_config = EasyDict(expert_lunarlander_r2d3_create_config)
+expert_create_config = expert_lunarlander_r2d3_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_r2d3
+ serial_pipeline_r2d3([main_config, create_config], [expert_main_config, expert_create_config], seed=0)
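+
+
+# Illustrative sketch (hypothetical helper, not DI-engine's sampler): how the
+# demo ratio `pho` above is typically interpreted, i.e. a fraction `pho` of each
+# training batch is drawn from the expert demonstration buffer and the rest
+# from the agent's own replay buffer.
+def split_batch_sketch(batch_size=64, pho=1 / 4.):
+    n_expert = int(batch_size * pho)
+    n_agent = batch_size - n_expert
+    return n_expert, n_agent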
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_r2d3_r2d2expert_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_r2d3_r2d2expert_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..281f8eb07d959ca1ffdbe70f08b36fad850bd2d2
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_r2d3_r2d2expert_config.py
@@ -0,0 +1,166 @@
+import os
+from easydict import EasyDict
+
+module_path = os.path.dirname(__file__)
+
+collector_env_num = 8
+evaluator_env_num = 8
+expert_replay_buffer_size = int(5e3)
+"""agent config"""
+lunarlander_r2d3_config = dict(
+ exp_name='lunarlander_r2d3_r2d2expert_seed0',
+ env=dict(
+ # Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ env_id='LunarLander-v2',
+ n_evaluator_episode=8,
+ stop_value=200,
+ ),
+ policy=dict(
+ cuda=True,
+ on_policy=False,
+ priority=True,
+ priority_IS_weight=True,
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ discount_factor=0.997,
+ nstep=5,
+ burnin_step=2,
+ # (int) the whole sequence length to unroll the RNN network minus
+ # the timesteps of burnin part,
+        # i.e., <the whole sequence length> = <unroll_len> = <burnin_step> + <learn_unroll_len>
+ learn_unroll_len=40,
+ learn=dict(
+ # according to the r2d3 paper, actor parameter update interval is 400
+ # environment timesteps, and in per collect phase, we collect 32 sequence
+            # samples, the length of each sample sequence is <burnin_step> + <learn_unroll_len>,
+            # which is 100 in our setting, 32*100/400=8, so we set update_per_collect=8
+ # in most environments
+ value_rescale=True,
+ update_per_collect=8,
+ batch_size=64,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ # DQFD related parameters
+ lambda1=1.0, # n-step return
+ lambda2=1.0, # supervised loss
+            lambda3=1e-5, # L2 regularization; it's very important to set Adam optimizer optim_type='adamw'.
+ lambda_one_step_td=1, # 1-step return
+ margin_function=0.8, # margin function in JE, here we implement this as a constant
+ per_train_iter_k=0, # TODO(pu)
+ ),
+ collect=dict(
+            # NOTE: It is important to set the key traj_len_inf=True here,
+            # to make sure self._traj_len=INF in serial_sample_collector.py.
+            # In R2D2 policy, for each collect_env, we want to collect data of length self._traj_len=INF
+            # unless the episode enters the 'done' state.
+            # In each collect phase, we collect a total of <n_sample> sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ env_num=collector_env_num,
+            # The hyperparameter pho, the demo ratio, controls the proportion of data coming
+ # from expert demonstrations versus from the agent's own experience.
+ pho=1 / 4, # TODO(pu)
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=100000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=int(1e4),
+ # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
+ alpha=0.6, # priority exponent default=0.6
+ # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+ beta=0.4,
+ )
+ ),
+ ),
+)
+lunarlander_r2d3_config = EasyDict(lunarlander_r2d3_config)
+main_config = lunarlander_r2d3_config
+lunarlander_r2d3_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='r2d3'),
+)
+lunarlander_r2d3_create_config = EasyDict(lunarlander_r2d3_create_config)
+create_config = lunarlander_r2d3_create_config
+"""export config"""
+expert_lunarlander_r2d3_config = dict(
+ exp_name='expert_lunarlander_r2d3_r2d2expert_seed0',
+ env=dict(
+ # Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=5,
+ stop_value=200,
+ ),
+ policy=dict(
+ cuda=True,
+ on_policy=False,
+ priority=True,
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ encoder_hidden_size_list=[128, 128, 512], # r2d2
+ ),
+ discount_factor=0.997,
+ burnin_step=2,
+ nstep=5,
+ learn=dict(expert_replay_buffer_size=expert_replay_buffer_size, ),
+ collect=dict(
+            # NOTE: It is important to set the key traj_len_inf=True here,
+            # to make sure self._traj_len=INF in serial_sample_collector.py.
+            # In R2D2 policy, for each collect_env, we want to collect data of length self._traj_len=INF
+            # unless the episode enters the 'done' state.
+            # In each collect phase, we collect a total of <n_sample> sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ model_path='model_path_placeholder',
+ # Cut trajectories into pieces with length "unroll_len",
+            # which should be set to self._sequence_len of r2d2
+            unroll_len=42, # NOTE: should equal self._sequence_len in r2d2 policy
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(
+ replay_buffer=dict(
+ replay_buffer_size=expert_replay_buffer_size,
+ # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
+ alpha=0.9, # priority exponent default=0.6
+ # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+ beta=0.4,
+ )
+ ),
+ ),
+)
+expert_lunarlander_r2d3_config = EasyDict(expert_lunarlander_r2d3_config)
+expert_main_config = expert_lunarlander_r2d3_config
+expert_lunarlander_r2d3_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='r2d2_collect_traj'), # this policy is designed to collect r2d2 expert traj for r2d3
+)
+expert_lunarlander_r2d3_create_config = EasyDict(expert_lunarlander_r2d3_create_config)
+expert_create_config = expert_lunarlander_r2d3_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_r2d3
+ serial_pipeline_r2d3([main_config, create_config], [expert_main_config, expert_create_config], seed=0)
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_rnd_onppo_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_rnd_onppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e141d711887bb2b9d87f2258ce6040eb54b0ba2
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_rnd_onppo_config.py
@@ -0,0 +1,78 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 8
+lunarlander_ppo_rnd_config = dict(
+ exp_name='lunarlander_rnd_onppo_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ env_id='LunarLander-v2',
+ n_evaluator_episode=evaluator_env_num,
+ stop_value=200,
+ ),
+ reward_model=dict(
+ intrinsic_reward_type='add',
+ # means the relative weight of RND intrinsic_reward.
+ # If intrinsic_reward_weight=None, we will automatically set it based on
+ # the absolute value of the difference between max and min extrinsic reward in the sampled mini-batch
+ # please refer to rnd_reward_model for details.
+ intrinsic_reward_weight=None,
+        # means the rescale value of the RND intrinsic_reward, only used when intrinsic_reward_weight is None.
+ # please refer to rnd_reward_model for details.
+ intrinsic_reward_rescale=0.001,
+ learning_rate=5e-4,
+ obs_shape=8,
+ batch_size=320,
+ update_per_collect=4,
+ obs_norm=True,
+ obs_norm_clamp_min=-1,
+ obs_norm_clamp_max=1,
+ clear_buffer_per_iters=10,
+ ),
+ policy=dict(
+ recompute_adv=True,
+ cuda=True,
+ action_space='discrete',
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ action_space='discrete',
+ ),
+ learn=dict(
+ epoch_per_collect=10,
+ update_per_collect=1,
+ batch_size=64,
+ learning_rate=3e-4,
+ value_weight=0.5,
+ entropy_weight=0.01,
+ clip_ratio=0.2,
+ adv_norm=True,
+ value_norm=True,
+ ),
+ collect=dict(
+ n_sample=512,
+ collector_env_num=collector_env_num,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ ),
+)
+lunarlander_ppo_rnd_config = EasyDict(lunarlander_ppo_rnd_config)
+main_config = lunarlander_ppo_rnd_config
+lunarlander_ppo_rnd_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+ reward_model=dict(type='rnd')
+)
+lunarlander_ppo_rnd_create_config = EasyDict(lunarlander_ppo_rnd_create_config)
+create_config = lunarlander_ppo_rnd_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_reward_model_onpolicy
+ serial_pipeline_reward_model_onpolicy([main_config, create_config], seed=0)
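+
+
+# Illustrative sketch (assumed formulation, for intuition only): the observation
+# normalization implied by `obs_norm`, `obs_norm_clamp_min` and
+# `obs_norm_clamp_max` above, i.e. standardize with running statistics and clamp
+# to [-1, 1] before feeding observations to the RND networks.
+def normalize_obs_sketch(obs, mean, std, clamp_min=-1.0, clamp_max=1.0):
+    normalized = (obs - mean) / (std + 1e-8)
+    return max(clamp_min, min(clamp_max, normalized))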
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_sqil_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_sqil_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..638c2d29814f72a254bec204f6db9d4125b450f4
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_sqil_config.py
@@ -0,0 +1,66 @@
+from easydict import EasyDict
+
+lunarlander_sqil_config = dict(
+ exp_name='lunarlander_sqil_seed0',
+ env=dict(
+ # Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
+ collector_env_num=8,
+ evaluator_env_num=8,
+ env_id='LunarLander-v2',
+ n_evaluator_episode=8,
+ stop_value=200,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ encoder_hidden_size_list=[128, 128, 64],
+ dueling=True,
+ ),
+ nstep=1,
+ discount_factor=0.97,
+ learn=dict(batch_size=64, learning_rate=0.001, alpha=0.08),
+ collect=dict(
+ n_sample=64,
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ model_path='model_path_placeholder',
+            # Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ ),
+        eval=dict(evaluator=dict(eval_freq=50, )), # note: evaluate every "eval_freq" training iterations
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=20000, ),
+ ),
+ ),
+)
+lunarlander_sqil_config = EasyDict(lunarlander_sqil_config)
+main_config = lunarlander_sqil_config
+lunarlander_sqil_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='sql'),
+)
+lunarlander_sqil_create_config = EasyDict(lunarlander_sqil_create_config)
+create_config = lunarlander_sqil_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial_sqil -c lunarlander_sqil_config.py -s 0`
+ # then input the config you used to generate your expert model in the path mentioned above
+    # e.g. lunarlander_dqn_config.py
+ from ding.entry import serial_pipeline_sqil
+ from dizoo.box2d.lunarlander.config import lunarlander_dqn_config, lunarlander_dqn_create_config
+ expert_main_config = lunarlander_dqn_config
+ expert_create_config = lunarlander_dqn_create_config
+ serial_pipeline_sqil([main_config, create_config], [expert_main_config, expert_create_config], seed=0)
\ No newline at end of file
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_sql_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_sql_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..60b7857868d3c28a8885406d64e928077c70be14
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_sql_config.py
@@ -0,0 +1,52 @@
+from easydict import EasyDict
+
+lunarlander_sql_config = dict(
+ exp_name='lunarlander_sql_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ env_id='LunarLander-v2',
+ n_evaluator_episode=8,
+ stop_value=200,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ encoder_hidden_size_list=[128, 128, 64],
+ dueling=True,
+ ),
+ nstep=1,
+ discount_factor=0.97,
+ learn=dict(batch_size=64, learning_rate=0.001, alpha=0.08),
+ collect=dict(n_sample=64),
+        eval=dict(evaluator=dict(eval_freq=50, )), # note: evaluate every "eval_freq" training iterations
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=20000, ),
+ ),
+ ),
+)
+lunarlander_sql_config = EasyDict(lunarlander_sql_config)
+main_config = lunarlander_sql_config
+lunarlander_sql_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='sql'),
+)
+lunarlander_sql_create_config = EasyDict(lunarlander_sql_create_config)
+create_config = lunarlander_sql_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c lunarlander_sql_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
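+
+
+# Illustrative sketch (assumed formulation, not DI-engine's implementation): the
+# role of the temperature `alpha` configured in `learn` above. In soft
+# Q-learning the state value is a temperature-scaled log-sum-exp of the
+# Q-values, and the policy is the corresponding Boltzmann distribution.
+def soft_value_sketch(q_values, alpha=0.08):
+    import math
+    # V(s) = alpha * log sum_a exp(Q(s, a) / alpha)
+    return alpha * math.log(sum(math.exp(q / alpha) for q in q_values))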
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_trex_dqn_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_trex_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..790ca5c271a17e21c28e48926d9b152769c0b072
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_trex_dqn_config.py
@@ -0,0 +1,110 @@
+from easydict import EasyDict
+
+nstep = 1
+lunarlander_trex_dqn_config = dict(
+ exp_name='lunarlander_trex_dqn_seed0',
+ env=dict(
+ # Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
+ # Env number respectively for collector and evaluator.
+ collector_env_num=8,
+ evaluator_env_num=8,
+ env_id='LunarLander-v2',
+ n_evaluator_episode=8,
+ stop_value=200,
+ ),
+ reward_model=dict(
+ type='trex',
+ min_snippet_length=30,
+ max_snippet_length=100,
+ checkpoint_min=1000,
+ checkpoint_max=9000,
+ checkpoint_step=1000,
+ num_snippets=60000,
+ learning_rate=1e-5,
+ update_per_collect=1,
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ expert_model_path='model_path_placeholder',
+ # Path where to store the reward model
+ reward_model_path='data_path_placeholder + /lunarlander.params',
+ # Users should add their own data path here. Data path should lead to a file to store data or load the stored data.
+ # Absolute path is recommended.
+ # In DI-engine, it is usually located in ``exp_name`` directory
+ # e.g. 'exp_name/expert_data.pkl'
+ data_path='data_path_placeholder',
+ ),
+ policy=dict(
+ # Whether to use cuda for network.
+ cuda=False,
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ encoder_hidden_size_list=[512, 64],
+ # Whether to use dueling head.
+ dueling=True,
+ ),
+ # Reward's future discount factor, aka. gamma.
+ discount_factor=0.99,
+ # How many steps in td error.
+ nstep=nstep,
+ # learn_mode config
+ learn=dict(
+ update_per_collect=10,
+ batch_size=64,
+ learning_rate=0.001,
+ # Frequency of target network update.
+ target_update_freq=100,
+ ),
+ # collect_mode config
+ collect=dict(
+ # You can use either "n_sample" or "n_episode" in collector.collect.
+ # Get "n_sample" samples per collect.
+ n_sample=64,
+ # Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ ),
+ # command_mode config
+ other=dict(
+ # Epsilon greedy with decay.
+ eps=dict(
+ # Decay type. Support ['exp', 'linear'].
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=50000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, )
+ ),
+ ),
+)
+lunarlander_trex_dqn_config = EasyDict(lunarlander_trex_dqn_config)
+main_config = lunarlander_trex_dqn_config
+
+lunarlander_trex_dqn_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqn'),
+)
+lunarlander_trex_dqn_create_config = EasyDict(lunarlander_trex_dqn_create_config)
+create_config = lunarlander_trex_dqn_create_config
+
+if __name__ == '__main__':
+ # Users should first run ``lunarlander_dqn_config.py`` to save models (or checkpoints).
+    # Note: Users should check that the generated checkpoints include iteration_'checkpoint_min'.pth.tar and iteration_'checkpoint_max'.pth.tar, saved at intervals of checkpoint_step,
+ # where checkpoint_max, checkpoint_min, checkpoint_step are specified above.
+ import argparse
+ import torch
+ from ding.entry import trex_collecting_data
+ from ding.entry import serial_pipeline_trex
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
+ parser.add_argument('--seed', type=int, default=0)
+ parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
+ args = parser.parse_args()
+ # The function ``trex_collecting_data`` below is to collect episodic data for training the reward model in trex.
+ trex_collecting_data(args)
+ serial_pipeline_trex([main_config, create_config])
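+
+
+# Illustrative sketch (checkpoint naming assumed from the note in the __main__
+# block above): the range of expert checkpoints that checkpoint_min /
+# checkpoint_max / checkpoint_step describe, i.e. iteration_1000.pth.tar up to
+# iteration_9000.pth.tar in steps of 1000.
+def expected_expert_ckpts_sketch(ckpt_min=1000, ckpt_max=9000, ckpt_step=1000):
+    return ['iteration_{}.pth.tar'.format(i) for i in range(ckpt_min, ckpt_max + 1, ckpt_step)]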
diff --git a/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_trex_offppo_config.py b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_trex_offppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..37e2c78fdd28e57e73c1d4536c77d7863e388a98
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/config/lunarlander_trex_offppo_config.py
@@ -0,0 +1,87 @@
+from easydict import EasyDict
+
+lunarlander_trex_ppo_config = dict(
+ exp_name='lunarlander_trex_offppo_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ env_id='LunarLander-v2',
+ n_evaluator_episode=8,
+ stop_value=200,
+ ),
+ reward_model=dict(
+ type='trex',
+ min_snippet_length=30,
+ max_snippet_length=100,
+ checkpoint_min=1000,
+ checkpoint_max=9000,
+ checkpoint_step=1000,
+ learning_rate=1e-5,
+ update_per_collect=1,
+        # Users should set their own expert model path here (an absolute path is recommended).
+        # In DI-engine, the checkpoint is usually saved at ``exp_name/ckpt/ckpt_best.pth.tar``,
+        # but ``expert_model_path`` expects the ``exp_name`` of the expert config.
+        expert_model_path='model_path_placeholder',
+        # Path where the trained reward model will be stored.
+        reward_model_path='data_path_placeholder + /lunarlander.params',
+        # Users should set their own data path here; it points to the file where collected data
+        # is stored or loaded from (an absolute path is recommended).
+        # In DI-engine, it is usually located in the ``exp_name`` directory.
+        # See ding/entry/application_entry_trex_collect_data.py for how to collect the data.
+        data_path='data_path_placeholder',
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ ),
+ learn=dict(
+ update_per_collect=4,
+ batch_size=64,
+ learning_rate=0.001,
+ value_weight=0.5,
+ entropy_weight=0.01,
+ clip_ratio=0.2,
+ nstep=1,
+ nstep_return=False,
+ adv_norm=True,
+ ),
+ collect=dict(
+ n_sample=128,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ ),
+)
+lunarlander_trex_ppo_config = EasyDict(lunarlander_trex_ppo_config)
+main_config = lunarlander_trex_ppo_config
+lunarlander_trex_ppo_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo_offpolicy'),
+)
+lunarlander_trex_ppo_create_config = EasyDict(lunarlander_trex_ppo_create_config)
+create_config = lunarlander_trex_ppo_create_config
+
+if __name__ == '__main__':
+    # Users should first run ``lunarlander_offppo_config.py`` to save models (or checkpoints).
+    # Note: the saved checkpoints must include iteration_'checkpoint_min'.pth.tar and
+    # iteration_'checkpoint_max'.pth.tar, with one checkpoint every 'checkpoint_step' iterations,
+    # where checkpoint_min, checkpoint_max and checkpoint_step are specified above.
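+    # With the values above (checkpoint_min=1000, checkpoint_max=9000, checkpoint_step=1000),
+    # the expert run is expected to contain iteration_1000.pth.tar, iteration_2000.pth.tar, ...,
+    # iteration_9000.pth.tar (assumed to live in the expert experiment's ``ckpt`` directory).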
+ import argparse
+ import torch
+ from ding.entry import trex_collecting_data
+ from ding.entry import serial_pipeline_trex
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
+ parser.add_argument('--seed', type=int, default=0)
+ parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
+ args = parser.parse_args()
+ # The function ``trex_collecting_data`` below is to collect episodic data for training the reward model in trex.
+ trex_collecting_data(args)
+ serial_pipeline_trex([main_config, create_config])
diff --git a/DI-engine/dizoo/box2d/lunarlander/entry/__init__.py b/DI-engine/dizoo/box2d/lunarlander/entry/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/box2d/lunarlander/entry/lunarlander_dqn_eval.py b/DI-engine/dizoo/box2d/lunarlander/entry/lunarlander_dqn_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..a87a80f8e8237668a7de4aedecc4a7f660eec56a
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/entry/lunarlander_dqn_eval.py
@@ -0,0 +1,60 @@
+import os
+import gym
+import torch
+from tensorboardX import SummaryWriter
+from easydict import EasyDict
+from functools import partial
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.envs import BaseEnvManager, DingEnvWrapper
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.policy import DQNPolicy
+from ding.model import DQN
+from ding.utils import set_pkg_seed
+from ding.rl_utils import get_epsilon_greedy_fn
+from dizoo.box2d.lunarlander.config.lunarlander_dqn_config import main_config, create_config
+
+
+def main(rl_cfg, seed=0):
+ main_cfg, create_cfg = rl_cfg
+ cfg = compile_config(
+ main_cfg,
+ BaseEnvManager,
+ DQNPolicy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator,
+ AdvancedReplayBuffer,
+ create_cfg=create_cfg,
+ save_cfg=True
+ )
+
+ create_cfg.policy.type = create_cfg.policy.type + '_command'
+ env_fn = None
+ cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
+ # Create main components: env, policy
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+
+    evaluator_env.enable_save_replay(cfg.env.replay_path)  # enable the replay-saving interface
+
+ # Set random seed for all package and instance
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ # Set up RL Policy
+ model = DQN(**cfg.policy.model)
+ policy = DQNPolicy(cfg.policy, model=model)
+ policy.eval_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cpu'))
+
+ # evaluate
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator.eval()
+
+
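+# Note: ``cfg.policy.load_path`` must point to a trained DQN checkpoint before running this script,
+# e.g. (hypothetical path) main_config.policy.load_path = './lunarlander_dqn_seed0/ckpt/ckpt_best.pth.tar'.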
+if __name__ == "__main__":
+ main(rl_cfg=(main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py b/DI-engine/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1c28ed975faa52734f7883cb2f7f2f0a8713267
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/entry/lunarlander_dqn_example.py
@@ -0,0 +1,68 @@
+import gym
+from ditk import logging
+from ding.data.model_loader import FileModelLoader
+from ding.data.storage_loader import FileStorageLoader
+from ding.model import DQN
+from ding.policy import DQNPolicy
+from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task, ding_init
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
+ eps_greedy_handler, CkptSaver, ContextExchanger, ModelExchanger, online_logger, termination_checker, \
+ nstep_reward_enhancer
+from ding.utils import set_pkg_seed
+from dizoo.box2d.lunarlander.config.lunarlander_dqn_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True, save_cfg=task.router.node_id == 0)
+ ding_init(cfg)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = SubprocessEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.collector_env_num)],
+ cfg=cfg.env.manager
+ )
+ evaluator_env = SubprocessEnvManagerV2(
+ env_fn=[lambda: DingEnvWrapper(gym.make("LunarLander-v2")) for _ in range(cfg.env.evaluator_env_num)],
+ cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = DQN(**cfg.policy.model)
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ policy = DQNPolicy(cfg.policy, model=model)
+
+ # Consider the case with multiple processes
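+        # Note: ``task.router.is_active`` is only True when this script is launched in parallel mode
+        # (e.g. via DI-engine's ``ditask`` launcher, an assumption here); when run directly as a single
+        # process, the role assignment and exchangers below are skipped.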
+ if task.router.is_active:
+ # You can use labels to distinguish between workers with different roles,
+ # here we use node_id to distinguish.
+ if task.router.node_id == 0:
+ task.add_role(task.role.LEARNER)
+ elif task.router.node_id == 1:
+ task.add_role(task.role.EVALUATOR)
+ else:
+ task.add_role(task.role.COLLECTOR)
+
+ # Sync their context and model between each worker.
+ task.use(ContextExchanger(skip_n_iter=1))
+ task.use(ModelExchanger(model))
+
+ # Here is the part of single process pipeline.
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(eps_greedy_handler(cfg))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(nstep_reward_enhancer(cfg))
+ task.use(data_pusher(cfg, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ task.use(online_logger(train_show_freq=50))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
+ task.use(termination_checker(max_env_step=int(3e6)))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/dizoo/box2d/lunarlander/envs/__init__.py b/DI-engine/dizoo/box2d/lunarlander/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1395cd6c6d864b0569bb1e621291c55f74c79902
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/envs/__init__.py
@@ -0,0 +1 @@
+from .lunarlander_env import LunarLanderEnv
diff --git a/DI-engine/dizoo/box2d/lunarlander/envs/lunarlander_env.py b/DI-engine/dizoo/box2d/lunarlander/envs/lunarlander_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..235d8881155cb515bfd86931ea5df019d1480073
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/envs/lunarlander_env.py
@@ -0,0 +1,150 @@
+import copy
+import os
+from typing import Optional
+
+import gym
+import numpy as np
+from easydict import EasyDict
+
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.envs import ObsPlusPrevActRewWrapper
+from ding.envs.common import affine_transform, save_frames_as_gif
+from ding.torch_utils import to_ndarray
+from ding.utils import ENV_REGISTRY
+
+
+@ENV_REGISTRY.register('lunarlander')
+class LunarLanderEnv(BaseEnv):
+
+ config = dict(
+ replay_path=None,
+ save_replay_gif=False,
+ replay_path_gif=None,
+ action_clip=False,
+ )
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ cfg = EasyDict(copy.deepcopy(cls.config))
+ cfg.cfg_type = cls.__name__ + 'Dict'
+ return cfg
+
+ def __init__(self, cfg: dict) -> None:
+ self._cfg = cfg
+ self._init_flag = False
+ # env_id: LunarLander-v2, LunarLanderContinuous-v2
+ self._env_id = cfg.env_id
+ self._replay_path = None
+ self._replay_path_gif = cfg.replay_path_gif
+ self._save_replay_gif = cfg.save_replay_gif
+ self._save_replay_count = 0
+ if 'Continuous' in self._env_id:
+ self._act_scale = cfg.act_scale # act_scale only works in continuous env
+ self._action_clip = cfg.action_clip
+ else:
+ self._act_scale = False
+
+ def reset(self) -> np.ndarray:
+ if not self._init_flag:
+ self._env = gym.make(self._cfg.env_id)
+ if self._replay_path is not None:
+ self._env = gym.wrappers.RecordVideo(
+ self._env,
+ video_folder=self._replay_path,
+ episode_trigger=lambda episode_id: True,
+ name_prefix='rl-video-{}'.format(id(self))
+ )
+ if hasattr(self._cfg, 'obs_plus_prev_action_reward') and self._cfg.obs_plus_prev_action_reward:
+ self._env = ObsPlusPrevActRewWrapper(self._env)
+ self._observation_space = self._env.observation_space
+ self._action_space = self._env.action_space
+ self._reward_space = gym.spaces.Box(
+ low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
+ )
+ self._init_flag = True
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+ self._eval_episode_return = 0
+ obs = self._env.reset()
+ obs = to_ndarray(obs)
+ if self._save_replay_gif:
+ self._frames = []
+ return obs
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def render(self) -> None:
+ self._env.render()
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def step(self, action: np.ndarray) -> BaseEnvTimestep:
+ assert isinstance(action, np.ndarray), type(action)
+ if action.shape == (1, ):
+ action = action.item() # 0-dim array
+ if self._act_scale:
+ action = affine_transform(action, action_clip=self._action_clip, min_val=-1, max_val=1)
+ if self._save_replay_gif:
+ self._frames.append(self._env.render(mode='rgb_array'))
+ obs, rew, done, info = self._env.step(action)
+ self._eval_episode_return += rew
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+ if self._save_replay_gif:
+ if not os.path.exists(self._replay_path_gif):
+ os.makedirs(self._replay_path_gif)
+ path = os.path.join(
+ self._replay_path_gif, '{}_episode_{}.gif'.format(self._env_id, self._save_replay_count)
+ )
+ save_frames_as_gif(self._frames, path)
+ self._save_replay_count += 1
+
+ obs = to_ndarray(obs)
+        rew = to_ndarray([rew]).astype(np.float32)  # wrapped into an array with shape (1, )
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ if replay_path is None:
+ replay_path = './video'
+ self._replay_path = replay_path
+ self._save_replay_gif = True
+ self._save_replay_count = 0
+        # Note: calling this function directly can lead to meaningless results,
+        # since ``reset`` also applies the RecordVideo wrapper when ``_replay_path`` is set.
+ self._env = gym.wrappers.RecordVideo(
+ self._env,
+ video_folder=self._replay_path,
+ episode_trigger=lambda episode_id: True,
+ name_prefix='rl-video-{}'.format(id(self))
+ )
+
+ def random_action(self) -> np.ndarray:
+ random_action = self.action_space.sample()
+ if isinstance(random_action, np.ndarray):
+ pass
+ elif isinstance(random_action, int):
+ random_action = to_ndarray([random_action], dtype=np.int64)
+ return random_action
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ def __repr__(self) -> str:
+ return "DI-engine LunarLander Env"
diff --git a/DI-engine/dizoo/box2d/lunarlander/envs/test_lunarlander_env.py b/DI-engine/dizoo/box2d/lunarlander/envs/test_lunarlander_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..b828ff16c9e7550087e5224ebd3956d34673fe3c
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/envs/test_lunarlander_env.py
@@ -0,0 +1,41 @@
+from time import time
+import pytest
+import numpy as np
+from easydict import EasyDict
+from dizoo.box2d.lunarlander.envs import LunarLanderEnv
+
+
+@pytest.mark.envtest
+@pytest.mark.parametrize(
+ 'cfg', [
+ EasyDict({
+ 'env_id': 'LunarLander-v2',
+ 'act_scale': False
+ }),
+ EasyDict({
+ 'env_id': 'LunarLanderContinuous-v2',
+ 'act_scale': True
+ })
+ ]
+)
+class TestLunarLanderEnvEnv:
+
+ def test_naive(self, cfg):
+ env = LunarLanderEnv(cfg)
+ env.seed(314)
+ assert env._seed == 314
+ obs = env.reset()
+ assert obs.shape == (8, )
+ for i in range(10):
+ random_action = env.random_action()
+ timestep = env.step(random_action)
+ print(timestep)
+ assert isinstance(timestep.obs, np.ndarray)
+ assert isinstance(timestep.done, bool)
+ assert timestep.obs.shape == (8, )
+ assert timestep.reward.shape == (1, )
+ assert timestep.reward >= env.reward_space.low
+ assert timestep.reward <= env.reward_space.high
+ # assert isinstance(timestep, tuple)
+ print(env.observation_space, env.action_space, env.reward_space)
+ env.close()
diff --git a/DI-engine/dizoo/box2d/lunarlander/offline_data/collect_dqn_data_config.py b/DI-engine/dizoo/box2d/lunarlander/offline_data/collect_dqn_data_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7cc7b383dac486ebb63302bec37414a8dfc606c
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/offline_data/collect_dqn_data_config.py
@@ -0,0 +1,124 @@
+from easydict import EasyDict
+
+nstep = 3
+lunarlander_dqn_config = dict(
+ exp_name='lunarlander',
+ env=dict(
+        # Whether to use shared memory; only effective when the env manager type is 'subprocess'.
+        # Numbers of environments for the collector and evaluator, respectively.
+ collector_env_num=8,
+ evaluator_env_num=8,
+ env_id='LunarLander-v2',
+ n_evaluator_episode=8,
+ stop_value=200,
+ ),
+ policy=dict(
+ # Whether to use cuda for network.
+ cuda=True,
+ model=dict(
+ obs_shape=8,
+ action_shape=4,
+ encoder_hidden_size_list=[512, 64],
+ # Whether to use dueling head.
+ dueling=True,
+ ),
+ # Reward's future discount factor, aka. gamma.
+ discount_factor=0.99,
+        # Number of steps in the n-step TD error.
+ nstep=nstep,
+ # learn_mode config
+ learn=dict(
+ # NOTE
+ learner=dict(
+ train_iterations=1000000000,
+ dataloader=dict(num_workers=0, ),
+ log_policy=True,
+ hook=dict(
+                    # TODO: sys.path is modified elsewhere, so an absolute path has to be used here.
+                    # This may be fixed in a future version.
+                    load_ckpt_before_run='./ckpt_best.pth.tar',
+ # load_ckpt_before_run='DI-engine/dizoo/box2d/lunarlander/dt_data/ckpt/ckpt_best.pth.tar',
+ log_show_after_iter=100,
+ save_ckpt_after_iter=10000,
+ save_ckpt_after_run=False,
+ ),
+ cfg_type='BaseLearnerDict',
+                load_path='./ckpt_best.pth.tar',  # TODO: same as the path above.
+ # load_path='DI-engine/dizoo/box2d/lunarlander/dt_data/ckpt/ckpt_best.pth.tar',
+ ),
+ update_per_collect=10,
+ batch_size=64,
+ learning_rate=0.001,
+ # Frequency of target network update.
+ target_update_freq=100,
+ ),
+ # collect_mode config
+ collect=dict(
+ # You can use either "n_sample" or "n_episode" in collector.collect.
+ # Get "n_sample" samples per collect.
+ n_sample=64,
+ # Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ # NOTE
+ # save
+ # data_type='hdf5',
+ data_type='naive',
+ save_path='./dt_data/dqn_data_1000eps.pkl', # TODO(pu)
+ # load
+ data_path='./dt_data/dqn_data_10eps.pkl', # TODO(pu)
+ ),
+ # command_mode config
+ other=dict(
+ # Epsilon greedy with decay.
+ eps=dict(
+ # Decay type. Support ['exp', 'linear'].
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=50000,
+ ),
+ # NOTE
+ replay_buffer=dict(
+ type='advanced',
+ # replay_buffer_size=100000,
+ replay_buffer_size=1000, # TODO(pu)
+ max_use=float('inf'),
+ max_staleness=float('inf'),
+ alpha=0.6,
+ beta=0.4,
+ anneal_step=100000,
+ enable_track_used_data=False,
+ deepcopy=False,
+ thruput_controller=dict(
+ push_sample_rate_limit=dict(
+ max=float('inf'),
+ min=0,
+ ),
+ window_seconds=30,
+ sample_min_limit_ratio=1,
+ ),
+ monitor=dict(
+ sampled_data_attr=dict(
+ average_range=5,
+ print_freq=200,
+ ),
+ periodic_thruput=dict(seconds=60, ),
+ ),
+ cfg_type='AdvancedReplayBufferDict',
+ ),
+ ),
+ ),
+)
+lunarlander_dqn_config = EasyDict(lunarlander_dqn_config)
+main_config = lunarlander_dqn_config
+
+lunarlander_dqn_create_config = dict(
+ env=dict(
+ type='lunarlander',
+ import_names=['dizoo.box2d.lunarlander.envs.lunarlander_env'],
+ ),
+ # env_manager=dict(type='subprocess'),
+ env_manager=dict(type='base'),
+ policy=dict(type='dqn'),
+)
+lunarlander_dqn_create_config = EasyDict(lunarlander_dqn_create_config)
+create_config = lunarlander_dqn_create_config
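+
+# This config is consumed by ``lunarlander_collect_data.py`` in the same directory: ``eval_ckpt`` evaluates
+# the checkpoint at ``policy.learn.learner.hook.load_ckpt_before_run``, and ``generate`` loads
+# ``policy.learn.learner.load_path`` to collect ``policy.other.replay_buffer.replay_buffer_size`` episodes
+# of demonstration data, which are saved to ``policy.collect.save_path``.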
diff --git a/DI-engine/dizoo/box2d/lunarlander/offline_data/lunarlander_collect_data.py b/DI-engine/dizoo/box2d/lunarlander/offline_data/lunarlander_collect_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..58256b758ed438f590b1109324e1629a6db8902c
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/offline_data/lunarlander_collect_data.py
@@ -0,0 +1,33 @@
+from dizoo.box2d.lunarlander.offline_data.collect_dqn_data_config import main_config, create_config
+from ding.entry import collect_episodic_demo_data, eval
+import torch
+import copy
+
+
+def eval_ckpt(args):
+ config = copy.deepcopy([main_config, create_config])
+ # eval(config, seed=args.seed, load_path=main_config.policy.learn.learner.hook.load_ckpt_before_run, replay_path='./replay')
+ eval(config, seed=args.seed, load_path=main_config.policy.learn.learner.hook.load_ckpt_before_run)
+
+
+def generate(args):
+ config = copy.deepcopy([main_config, create_config])
+ state_dict = torch.load(main_config.policy.learn.learner.load_path, map_location='cpu')
+ collect_episodic_demo_data(
+ config,
+ collect_count=main_config.policy.other.replay_buffer.replay_buffer_size,
+ seed=args.seed,
+ expert_data_path=main_config.policy.collect.save_path,
+ state_dict=state_dict
+ )
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--seed', '-s', type=int, default=0)
+ args = parser.parse_args()
+
+ eval_ckpt(args)
+ generate(args)
diff --git a/DI-engine/dizoo/box2d/lunarlander/offline_data/lunarlander_show_data.py b/DI-engine/dizoo/box2d/lunarlander/offline_data/lunarlander_show_data.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3a63b071d650f9aec9970a05835be0b8cd18d50
--- /dev/null
+++ b/DI-engine/dizoo/box2d/lunarlander/offline_data/lunarlander_show_data.py
@@ -0,0 +1,50 @@
+from dizoo.classic_control.cartpole.offline_data.collect_dqn_data_config import main_config, create_config
+
+from ding.entry import serial_pipeline_offline
+import os
+import torch
+from torch.utils.data import DataLoader
+from ding.config import read_config, compile_config
+from ding.utils.data import create_dataset
+
+
+def train(args):
+ config = [main_config, create_config]
+ input_cfg = config
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ else:
+ cfg, create_cfg = input_cfg
+ create_cfg.policy.type = create_cfg.policy.type + '_command'
+ cfg = compile_config(cfg, seed=args.seed, auto=True, create_cfg=create_cfg)
+
+ # Dataset
+ dataset = create_dataset(cfg)
+ print(dataset.__len__())
+
+ # print(dataset.__getitem__(0))
+ print(dataset.__getitem__(0)[0]['action'])
+
+ # episode_action = []
+    # for i in range(dataset.__getitem__(0).__len__()): # length of the first collected episode
+ # episode_action.append(dataset.__getitem__(0)[i]['action'])
+
+ # stacked action of the first collected episode
+ episode_action = torch.stack(
+        [dataset.__getitem__(0)[i]['action'] for i in range(dataset.__getitem__(0).__len__())], dim=0
+ )
+
+ # dataloader = DataLoader(dataset, cfg.policy.learn.batch_size, shuffle=True, collate_fn=lambda x: x)
+ # for i, train_data in enumerate(dataloader):
+ # print(i, train_data)
+ # serial_pipeline_offline(config, seed=args.seed)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--seed', '-s', type=int, default=0)
+ args = parser.parse_args()
+
+ train(args)
diff --git a/DI-engine/dizoo/bsuite/__init__.py b/DI-engine/dizoo/bsuite/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/bsuite/config/__init__.py b/DI-engine/dizoo/bsuite/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b137891791fe96927ad78e64b0aad7bded08bdc
--- /dev/null
+++ b/DI-engine/dizoo/bsuite/config/__init__.py
@@ -0,0 +1 @@
+
diff --git a/DI-engine/dizoo/bsuite/config/serial/bandit_noise/bandit_noise_0_dqn_config.py b/DI-engine/dizoo/bsuite/config/serial/bandit_noise/bandit_noise_0_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bd43dd6f9041830e3d161611b48e92a07f53c61
--- /dev/null
+++ b/DI-engine/dizoo/bsuite/config/serial/bandit_noise/bandit_noise_0_dqn_config.py
@@ -0,0 +1,51 @@
+from easydict import EasyDict
+
+bandit_noise_0_dqn_config = dict(
+ exp_name='bandit_noise_0_dqn',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=1,
+ n_evaluator_episode=10,
+ env_id='bandit_noise/0',
+ stop_value=0.8,
+ ),
+ policy=dict(
+ load_path='',
+ cuda=True,
+ model=dict(
+ obs_shape=1,
+ action_shape=11,
+ encoder_hidden_size_list=[128, 128, 64],
+ dueling=True,
+ ),
+ nstep=1,
+ discount_factor=0.97,
+ learn=dict(
+ batch_size=64,
+ learning_rate=0.001,
+ ),
+ collect=dict(n_sample=8),
+ eval=dict(evaluator=dict(eval_freq=20, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=20000, ),
+ ),
+ ),
+)
+bandit_noise_0_dqn_config = EasyDict(bandit_noise_0_dqn_config)
+main_config = bandit_noise_0_dqn_config
+bandit_noise_0_dqn_create_config = dict(
+ env=dict(
+ type='bsuite',
+ import_names=['dizoo.bsuite.envs.bsuite_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='dqn'),
+)
+bandit_noise_0_dqn_create_config = EasyDict(bandit_noise_0_dqn_create_config)
+create_config = bandit_noise_0_dqn_create_config
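+
+if __name__ == "__main__":
+    # Entry added for parity with the other bsuite configs in this directory (a sketch that assumes
+    # the standard serial pipeline); or you can enter `ding -m serial -c bandit_noise_0_dqn_config.py -s 0`
+    from ding.entry import serial_pipeline
+    serial_pipeline((main_config, create_config), seed=0)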
diff --git a/DI-engine/dizoo/bsuite/config/serial/cartpole_swingup/cartpole_swingup_0_dqn_config.py b/DI-engine/dizoo/bsuite/config/serial/cartpole_swingup/cartpole_swingup_0_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..657d3515d97a6f3c07239bb794b31398c60148d7
--- /dev/null
+++ b/DI-engine/dizoo/bsuite/config/serial/cartpole_swingup/cartpole_swingup_0_dqn_config.py
@@ -0,0 +1,57 @@
+from easydict import EasyDict
+
+cartpole_swingup_dqn_config = dict(
+ exp_name='cartpole_swingup_0_dqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=1,
+ n_evaluator_episode=10,
+ env_id='cartpole_swingup/0',
+ stop_value=100,
+ manager=dict(shared_memory=False, )
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ model=dict(
+ obs_shape=8,
+ action_shape=3,
+ encoder_hidden_size_list=[128, 128, 64],
+ dueling=True,
+ ),
+ nstep=1,
+ discount_factor=0.97, # discount_factor: 0.97-0.99
+ learn=dict(
+ batch_size=64,
+ learning_rate=0.001,
+ ),
+ collect=dict(n_sample=8),
+ eval=dict(evaluator=dict(eval_freq=200, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=20000, ),
+ ),
+ ),
+)
+cartpole_swingup_dqn_config = EasyDict(cartpole_swingup_dqn_config)
+main_config = cartpole_swingup_dqn_config
+cartpole_swingup_dqn_create_config = dict(
+ env=dict(
+ type='bsuite',
+ import_names=['dizoo.bsuite.envs.bsuite_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqn'),
+)
+cartpole_swingup_dqn_create_config = EasyDict(cartpole_swingup_dqn_create_config)
+create_config = cartpole_swingup_dqn_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c cartpole_swingup_0_dqn_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/bsuite/config/serial/memory_len/memory_len_0_a2c_config.py b/DI-engine/dizoo/bsuite/config/serial/memory_len/memory_len_0_a2c_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..77790a768b893da7a94f9e8b34a1e27cdc2a3d7f
--- /dev/null
+++ b/DI-engine/dizoo/bsuite/config/serial/memory_len/memory_len_0_a2c_config.py
@@ -0,0 +1,52 @@
+from easydict import EasyDict
+
+memory_len_a2c_config = dict(
+ exp_name='memory_len_0_a2c_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=1,
+ n_evaluator_episode=20,
+ env_id='memory_len/0', # this environment configuration is 1 'memory steps' long
+ stop_value=1.,
+ manager=dict(shared_memory=False, )
+ ),
+ policy=dict(
+ cuda=False,
+ priority=True,
+ model=dict(
+ obs_shape=3,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ ),
+ learn=dict(
+ batch_size=64,
+ normalize_advantage=False,
+ learning_rate=0.001,
+ value_weight=0.5,
+ entropy_weight=0.01,
+ ),
+ collect=dict(
+ n_sample=80,
+ gae_lambda=0.95,
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, )),
+ ),
+)
+memory_len_a2c_config = EasyDict(memory_len_a2c_config)
+main_config = memory_len_a2c_config
+
+memory_len_a2c_create_config = dict(
+ env=dict(
+ type='bsuite',
+ import_names=['dizoo.bsuite.envs.bsuite_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='a2c'),
+)
+memory_len_a2c_create_config = EasyDict(memory_len_a2c_create_config)
+create_config = memory_len_a2c_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c memory_len_0_a2c_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/bsuite/config/serial/memory_len/memory_len_0_dqn_config.py b/DI-engine/dizoo/bsuite/config/serial/memory_len/memory_len_0_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e120bfe5b6d067170565620ee9ace01321c44760
--- /dev/null
+++ b/DI-engine/dizoo/bsuite/config/serial/memory_len/memory_len_0_dqn_config.py
@@ -0,0 +1,57 @@
+from easydict import EasyDict
+
+memory_len_dqn_config = dict(
+ exp_name='memory_len_0_dqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=1,
+ n_evaluator_episode=20,
+ env_id='memory_len/0', # this environment configuration is 1 'memory steps' long
+ stop_value=1.,
+ manager=dict(shared_memory=False, )
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ model=dict(
+ obs_shape=3,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ dueling=True,
+ ),
+ nstep=1,
+ discount_factor=0.97, # discount_factor: 0.97-0.99
+ learn=dict(
+ batch_size=64,
+ learning_rate=0.001,
+ ),
+ collect=dict(n_sample=8),
+ eval=dict(evaluator=dict(eval_freq=100, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=20000, ),
+ ),
+ ),
+)
+memory_len_dqn_config = EasyDict(memory_len_dqn_config)
+main_config = memory_len_dqn_config
+memory_len_dqn_create_config = dict(
+ env=dict(
+ type='bsuite',
+ import_names=['dizoo.bsuite.envs.bsuite_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqn'),
+)
+memory_len_dqn_create_config = EasyDict(memory_len_dqn_create_config)
+create_config = memory_len_dqn_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c memory_len_0_dqn_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/bsuite/config/serial/memory_len/memory_len_15_r2d2_config.py b/DI-engine/dizoo/bsuite/config/serial/memory_len/memory_len_15_r2d2_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..72d7308082dade86016f964a8448b67207a78add
--- /dev/null
+++ b/DI-engine/dizoo/bsuite/config/serial/memory_len/memory_len_15_r2d2_config.py
@@ -0,0 +1,64 @@
+from easydict import EasyDict
+
+memory_len_r2d2_config = dict(
+ exp_name='memory_len_15_r2d2_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=1,
+ n_evaluator_episode=20,
+ env_id='memory_len/15', # this environment configuration is 30 'memory steps' long
+ stop_value=1.,
+ manager=dict(shared_memory=False, )
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ priority_IS_weight=True,
+ model=dict(
+ obs_shape=3,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ ),
+ discount_factor=0.997, # discount_factor: 0.97-0.99
+        burnin_step=1,  # fixed to 1 since the early steps are the most important
+        nstep=3,
+        unroll_len=40,  # for better convergence, unroll_len should be > 'memory steps' = 30
+ learn=dict(
+ update_per_collect=8,
+ batch_size=64,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ ),
+ collect=dict(
+ each_iter_n_sample=32,
+ env_num=8,
+ ),
+ eval=dict(env_num=1, evaluator=dict(eval_freq=100, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=1e5,
+ ),
+ replay_buffer=dict(replay_buffer_size=50000, ),
+ ),
+ ),
+)
+memory_len_r2d2_config = EasyDict(memory_len_r2d2_config)
+main_config = memory_len_r2d2_config
+memory_len_r2d2_create_config = dict(
+ env=dict(
+ type='bsuite',
+ import_names=['dizoo.bsuite.envs.bsuite_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='r2d2'),
+)
+memory_len_r2d2_create_config = EasyDict(memory_len_r2d2_create_config)
+create_config = memory_len_r2d2_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c memory_len_15_r2d2_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/bsuite/config/serial/memory_len/memory_len_15_r2d2_gtrxl_config.py b/DI-engine/dizoo/bsuite/config/serial/memory_len/memory_len_15_r2d2_gtrxl_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..74d623522cd3fd0a26976a8be88f74b0c72aa7a3
--- /dev/null
+++ b/DI-engine/dizoo/bsuite/config/serial/memory_len/memory_len_15_r2d2_gtrxl_config.py
@@ -0,0 +1,68 @@
+from easydict import EasyDict
+
+memory_len_r2d2_gtrxl_config = dict(
+ exp_name='memory_len_15_r2d2_gtrxl_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=1,
+ n_evaluator_episode=20,
+ env_id='memory_len/15', # this environment configuration is 30 'memory steps' long
+ stop_value=1.,
+ manager=dict(shared_memory=False, )
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ priority_IS_weight=True,
+ model=dict(
+ obs_shape=3,
+ action_shape=2,
+ memory_len=0,
+ hidden_size=64,
+ gru_bias=1. # gru_bias: 0. to 2.
+ ),
+ discount_factor=0.997, # discount_factor: 0.97-0.99
+ nstep=3,
+ burnin_step=0,
+ unroll_len=35, # unroll_len >= seq_len
+        seq_len=35,  # for better convergence, seq_len should be > 'memory steps' = 30
+ learn=dict(
+ update_per_collect=8,
+ batch_size=64,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ init_memory='zero', # 'zero' or 'old', how to initialize the transformer memory
+ ),
+ collect=dict(
+ each_iter_n_sample=32,
+ env_num=8,
+ ),
+ eval=dict(env_num=1, evaluator=dict(eval_freq=10, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=1e5,
+ ),
+ replay_buffer=dict(replay_buffer_size=50000, ),
+ ),
+ ),
+)
+memory_len_r2d2_gtrxl_config = EasyDict(memory_len_r2d2_gtrxl_config)
+main_config = memory_len_r2d2_gtrxl_config
+memory_len_r2d2_gtrxl_create_config = dict(
+ env=dict(
+ type='bsuite',
+ import_names=['dizoo.bsuite.envs.bsuite_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='r2d2_gtrxl'),
+)
+memory_len_r2d2_gtrxl_create_config = EasyDict(memory_len_r2d2_gtrxl_create_config)
+create_config = memory_len_r2d2_gtrxl_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c memory_len_15_r2d2_gtrxl_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/bsuite/envs/__init__.py b/DI-engine/dizoo/bsuite/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..652abb7e7aff3626b9e9bfb9301e6f45b14c1d01
--- /dev/null
+++ b/DI-engine/dizoo/bsuite/envs/__init__.py
@@ -0,0 +1 @@
+from .bsuite_env import BSuiteEnv
diff --git a/DI-engine/dizoo/bsuite/envs/bsuite_env.py b/DI-engine/dizoo/bsuite/envs/bsuite_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..915411f57d7479495d9af45e413e7ad863d07a80
--- /dev/null
+++ b/DI-engine/dizoo/bsuite/envs/bsuite_env.py
@@ -0,0 +1,107 @@
+from typing import Any, List, Union, Optional
+import time
+import copy
+import gym
+import numpy as np
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.torch_utils import to_ndarray, to_list
+from ding.utils import ENV_REGISTRY
+
+import bsuite
+from bsuite.utils import gym_wrapper
+from bsuite import sweep
+
+
+@ENV_REGISTRY.register('bsuite')
+class BSuiteEnv(BaseEnv):
+
+ def __init__(self, cfg: dict) -> None:
+ self._cfg = cfg
+ self._init_flag = False
+ self.env_id = cfg.env_id
+ self.env_name = self.env_id.split('/')[0]
+
+ def reset(self) -> np.ndarray:
+ if not self._init_flag:
+ raw_env = bsuite.load_from_id(bsuite_id=self.env_id)
+ self._env = gym_wrapper.GymFromDMEnv(raw_env)
+ self._observation_space = self._env.observation_space
+ self._action_space = self._env.action_space
+ self._reward_space = gym.spaces.Box(
+ low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float64
+ )
+ self._init_flag = True
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+ self._eval_episode_return = 0
+ obs = self._env.reset()
+ if obs.shape[0] == 1:
+ obs = obs[0]
+ obs = to_ndarray(obs).astype(np.float32)
+ return obs
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def step(self, action: np.ndarray) -> BaseEnvTimestep:
+ assert isinstance(action, np.ndarray), type(action)
+ if action.shape[0] == 1:
+ action = action[0]
+ obs, rew, done, info = self._env.step(action)
+ self._eval_episode_return += rew
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+ if obs.shape[0] == 1:
+ obs = obs[0]
+ obs = to_ndarray(obs)
+        rew = to_ndarray([rew])  # wrapped into an array with shape (1, )
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def config_info(self) -> dict:
+        config_info = sweep.SETTINGS[self.env_id]  # additional info specific to this env configuration
+ config_info['num_episodes'] = self._env.bsuite_num_episodes
+ return config_info
+
+ def random_action(self) -> np.ndarray:
+ random_action = self.action_space.sample()
+ random_action = to_ndarray([random_action], dtype=np.int64)
+ return random_action
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ def __repr__(self) -> str:
+ return "DI-engine BSuite Env({})".format(self.env_id)
+
+ @staticmethod
+ def create_collector_env_cfg(cfg: dict) -> List[dict]:
+ collector_env_num = cfg.pop('collector_env_num')
+ cfg = copy.deepcopy(cfg)
+ cfg.is_train = True
+ return [cfg for _ in range(collector_env_num)]
+
+ @staticmethod
+ def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
+ evaluator_env_num = cfg.pop('evaluator_env_num')
+ cfg = copy.deepcopy(cfg)
+ cfg.is_train = False
+ return [cfg for _ in range(evaluator_env_num)]
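+
+# A minimal usage sketch (illustrative only):
+#   from easydict import EasyDict
+#   env = BSuiteEnv(EasyDict({'env_id': 'memory_len/0'}))
+#   env.seed(0)
+#   obs = env.reset()                        # shape (3,) for memory_len
+#   info = env.config_info()                 # bsuite sweep settings plus 'num_episodes'
+#   timestep = env.step(env.random_action())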
diff --git a/DI-engine/dizoo/bsuite/envs/test_bsuite_env.py b/DI-engine/dizoo/bsuite/envs/test_bsuite_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..93a330bfb8c998608094f20fd45f52309ab7f424
--- /dev/null
+++ b/DI-engine/dizoo/bsuite/envs/test_bsuite_env.py
@@ -0,0 +1,43 @@
+from time import time
+import pytest
+import numpy as np
+from easydict import EasyDict
+from dizoo.bsuite.envs import BSuiteEnv
+
+
+@pytest.mark.envtest
+class TestBSuiteEnv:
+
+ def test_memory_len(self):
+ cfg = {'env_id': 'memory_len/0'}
+ cfg = EasyDict(cfg)
+ memory_len_env = BSuiteEnv(cfg)
+ memory_len_env.seed(0)
+ obs = memory_len_env.reset()
+ assert obs.shape == (3, )
+ while True:
+ random_action = memory_len_env.random_action()
+ timestep = memory_len_env.step(random_action)
+ assert timestep.obs.shape == (3, )
+ assert timestep.reward.shape == (1, )
+ if timestep.done:
+ assert 'eval_episode_return' in timestep.info, timestep.info
+ break
+ memory_len_env.close()
+
+ def test_cartpole_swingup(self):
+ cfg = {'env_id': 'cartpole_swingup/0'}
+ cfg = EasyDict(cfg)
+ bandit_noise_env = BSuiteEnv(cfg)
+ bandit_noise_env.seed(0)
+ obs = bandit_noise_env.reset()
+ assert obs.shape == (8, )
+ while True:
+ random_action = bandit_noise_env.random_action()
+ timestep = bandit_noise_env.step(random_action)
+ assert timestep.obs.shape == (8, )
+ assert timestep.reward.shape == (1, )
+ if timestep.done:
+ assert 'eval_episode_return' in timestep.info, timestep.info
+ break
+ bandit_noise_env.close()
diff --git a/DI-engine/dizoo/classic_control/__init__.py b/DI-engine/dizoo/classic_control/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/classic_control/acrobot/__init__.py b/DI-engine/dizoo/classic_control/acrobot/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/classic_control/acrobot/config/__init__.py b/DI-engine/dizoo/classic_control/acrobot/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..036dbf6a93519311e76af36bd6ecbc8984a642d6
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/acrobot/config/__init__.py
@@ -0,0 +1 @@
+from .acrobot_dqn_config import acrobot_dqn_config, acrobot_dqn_create_config
diff --git a/DI-engine/dizoo/classic_control/acrobot/config/acrobot_dqn_config.py b/DI-engine/dizoo/classic_control/acrobot/config/acrobot_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4957db987f2e3e65ec3d745439f459dd63eae8d4
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/acrobot/config/acrobot_dqn_config.py
@@ -0,0 +1,55 @@
+from easydict import EasyDict
+
+acrobot_dqn_config = dict(
+ exp_name='acrobot_dqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=-60,
+ env_id='Acrobot-v1',
+ replay_path='acrobot_dqn_seed0/video',
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=6,
+ action_shape=3,
+ encoder_hidden_size_list=[256, 256],
+ dueling=True,
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=128,
+ learning_rate=0.0001,
+ target_update_freq=250,
+ ),
+ collect=dict(n_sample=96, ),
+ eval=dict(evaluator=dict(eval_freq=2000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=250000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+)
+acrobot_dqn_config = EasyDict(acrobot_dqn_config)
+main_config = acrobot_dqn_config
+acrobot_dqn_create_config = dict(
+ env=dict(type='acrobot', import_names=['dizoo.classic_control.acrobot.envs.acrobot_env']),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqn'),
+ replay_buffer=dict(type='deque', import_names=['ding.data.buffer.deque_buffer_wrapper']),
+)
+acrobot_dqn_create_config = EasyDict(acrobot_dqn_create_config)
+create_config = acrobot_dqn_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/acrobot/envs/__init__.py b/DI-engine/dizoo/classic_control/acrobot/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..be6537f2c90a1eb092a6b70717529cf2e74847fb
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/acrobot/envs/__init__.py
@@ -0,0 +1 @@
+from .acrobot_env import AcroBotEnv
diff --git a/DI-engine/dizoo/classic_control/acrobot/envs/acrobot_env.py b/DI-engine/dizoo/classic_control/acrobot/envs/acrobot_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c2632331597580a695c5b2da02d758f46300250
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/acrobot/envs/acrobot_env.py
@@ -0,0 +1,98 @@
+from typing import Any, List, Union, Optional
+import time
+import gym
+import copy
+import numpy as np
+from easydict import EasyDict
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.torch_utils import to_ndarray, to_list
+from ding.utils import ENV_REGISTRY
+from ding.envs import ObsPlusPrevActRewWrapper
+
+
+@ENV_REGISTRY.register('acrobot')
+class AcroBotEnv(BaseEnv):
+
+ def __init__(self, cfg: dict = {}) -> None:
+ self._cfg = cfg
+ self._init_flag = False
+ self._replay_path = None
+ self._observation_space = gym.spaces.Box(
+ low=np.array([-1.0, -1.0, -1.0, -1.0, -12.57, -28.27]),
+ high=np.array([1.0, 1.0, 1.0, 1.0, 12.57, 28.27]),
+ shape=(6, ),
+ dtype=np.float32
+ )
+ self._action_space = gym.spaces.Discrete(3)
+ self._action_space.seed(0) # default seed
+ self._reward_space = gym.spaces.Box(low=-1.0, high=0.0, shape=(1, ), dtype=np.float32)
+
+ def reset(self) -> np.ndarray:
+ if not self._init_flag:
+ self._env = gym.make('Acrobot-v1')
+ if self._replay_path is not None:
+ self._env = gym.wrappers.RecordVideo(
+ self._env,
+ video_folder=self._replay_path,
+ episode_trigger=lambda episode_id: True,
+ name_prefix='rl-video-{}'.format(id(self))
+ )
+ self._init_flag = True
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ self._action_space.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+ self._action_space.seed(self._seed)
+ self._observation_space = self._env.observation_space
+ self._eval_episode_return = 0
+ obs = self._env.reset()
+ obs = to_ndarray(obs)
+ return obs
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep:
+ if isinstance(action, np.ndarray) and action.shape == (1, ):
+ action = action.squeeze() # 0-dim array
+ obs, rew, done, info = self._env.step(action)
+ self._eval_episode_return += rew
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+ obs = to_ndarray(obs)
+        rew = to_ndarray([rew]).astype(np.float32)  # wrapped into an array with shape (1, )
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ if replay_path is None:
+ replay_path = './video'
+ self._replay_path = replay_path
+
+ def random_action(self) -> np.ndarray:
+ random_action = self.action_space.sample()
+ random_action = to_ndarray([random_action], dtype=np.int64)
+ return random_action
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ def __repr__(self) -> str:
+ return "DI-engine Acrobot Env"
diff --git a/DI-engine/dizoo/classic_control/acrobot/envs/test_acrobot_env.py b/DI-engine/dizoo/classic_control/acrobot/envs/test_acrobot_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..fba0914cfa0870986e25362649bb08dad1bf53cc
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/acrobot/envs/test_acrobot_env.py
@@ -0,0 +1,35 @@
+import pytest
+import numpy as np
+from dizoo.classic_control.acrobot.envs import AcroBotEnv
+
+
+@pytest.mark.envtest
+class TestAcrobotEnv:
+
+ def test_naive(self):
+ env = AcroBotEnv({})
+ env.seed(314, dynamic_seed=False)
+ assert env._seed == 314
+ obs = env.reset()
+ assert obs.shape == (6, )
+ for _ in range(5):
+ env.reset()
+ np.random.seed(314)
+ print('=' * 60)
+ for i in range(10):
+            # Both ``env.random_action()`` and sampling directly from the action space
+            # (seeded via ``np.random``) generate legal random actions.
+ if i < 5:
+ random_action = np.array([env.action_space.sample()])
+ else:
+ random_action = env.random_action()
+ timestep = env.step(random_action)
+ print(timestep)
+ assert isinstance(timestep.obs, np.ndarray)
+ assert isinstance(timestep.done, bool)
+ assert timestep.obs.shape == (6, )
+ assert timestep.reward.shape == (1, )
+ assert timestep.reward >= env.reward_space.low
+ assert timestep.reward <= env.reward_space.high
+ print(env.observation_space, env.action_space, env.reward_space)
+ env.close()
diff --git a/DI-engine/dizoo/classic_control/cartpole/__init__.py b/DI-engine/dizoo/classic_control/cartpole/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/__init__.py b/DI-engine/dizoo/classic_control/cartpole/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d6d124274610c604fa6a9bd5afc3387b0d46280
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/__init__.py
@@ -0,0 +1,23 @@
+from .cartpole_a2c_config import cartpole_a2c_config, cartpole_a2c_create_config
+from .cartpole_acer_config import cartpole_acer_config, cartpole_acer_create_config
+from .cartpole_c51_config import cartpole_c51_config, cartpole_c51_create_config
+from .cartpole_dqfd_config import cartpole_dqfd_config, cartpole_dqfd_create_config
+from .cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config
+from .cartpole_dqn_gail_config import cartpole_dqn_gail_config, cartpole_dqn_gail_create_config
+from .cartpole_gcl_config import cartpole_gcl_ppo_onpolicy_config, cartpole_gcl_ppo_onpolicy_create_config
+from .cartpole_impala_config import cartpole_impala_config, cartpole_impala_create_config
+from .cartpole_iqn_config import cartpole_iqn_config, cartpole_iqn_create_config
+from .cartpole_ppo_offpolicy_config import cartpole_ppo_offpolicy_config, cartpole_ppo_offpolicy_create_config
+from .cartpole_ppg_config import cartpole_ppg_config, cartpole_ppg_create_config
+from .cartpole_ppo_config import cartpole_ppo_config, cartpole_ppo_create_config
+from .cartpole_qrdqn_config import cartpole_qrdqn_config, cartpole_qrdqn_create_config
+from .cartpole_r2d2_config import cartpole_r2d2_config, cartpole_r2d2_create_config
+from .cartpole_rainbow_config import cartpole_rainbow_config, cartpole_rainbow_create_config
+from .cartpole_sqil_config import cartpole_sqil_config, cartpole_sqil_create_config
+from .cartpole_sql_config import cartpole_sql_config, cartpole_sql_create_config
+from .cartpole_sqn_config import cartpole_sqn_config, cartpole_sqn_create_config
+from .cartpole_trex_dqn_config import cartpole_trex_dqn_config, cartpole_trex_dqn_create_config
+from .cartpole_trex_offppo_config import cartpole_trex_offppo_config, cartpole_trex_offppo_create_config
+from .cartpole_trex_onppo_config import cartpole_trex_ppo_onpolicy_config, cartpole_trex_ppo_onpolicy_create_config
+from .cartpole_mdqn_config import cartpole_mdqn_config, cartpole_mdqn_create_config
+# from .cartpole_ppo_default_loader import cartpole_ppo_default_loader
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_a2c_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_a2c_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec6f93cd6ebf0f0e936a402e42364509046feda1
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_a2c_config.py
@@ -0,0 +1,51 @@
+from easydict import EasyDict
+
+cartpole_a2c_config = dict(
+ exp_name='cartpole_a2c_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+        # (bool) Whether to use the on-policy training pipeline (behaviour policy and training policy are the same).
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ ),
+ learn=dict(
+ batch_size=40,
+ learning_rate=0.001,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.01,
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ n_sample=80,
+            # (float) the trade-off factor lambda balancing 1-step TD and Monte-Carlo returns (GAE)
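+            # In GAE, the advantage is approximately sum_l (gamma * lambda)^l * delta_{t+l}, where
+            # delta_t = r_t + gamma * V(s_{t+1}) - V(s_t); lambda=0 recovers 1-step TD and lambda=1
+            # recovers Monte-Carlo returns.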
+ gae_lambda=0.95,
+ ),
+ eval=dict(evaluator=dict(eval_freq=50, )),
+ ),
+)
+cartpole_a2c_config = EasyDict(cartpole_a2c_config)
+main_config = cartpole_a2c_config
+
+cartpole_a2c_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='a2c'),
+)
+cartpole_a2c_create_config = EasyDict(cartpole_a2c_create_config)
+create_config = cartpole_a2c_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c cartpole_a2c_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_acer_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_acer_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3555e6aee8ac87ec9d08e3ef3c45addb72f42c4
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_acer_config.py
@@ -0,0 +1,70 @@
+from easydict import EasyDict
+
+cartpole_acer_config = dict(
+ exp_name='cartpole_acer_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[64, 64],
+ ),
+ # (int) the trajectory length to calculate Q retrace target
+ unroll_len=32,
+ learn=dict(
+ # (int) collect n_sample data, train model update_per_collect times
+ # here we follow ppo serial pipeline
+ update_per_collect=4,
+ # (int) the number of data for a train iteration
+ batch_size=16,
+ learning_rate_actor=0.0005,
+ learning_rate_critic=0.0005,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ # entropy_weight=0.0001,
+ entropy_weight=0.0,
+            # (float) discount factor for future reward, in the range [0, 1]
+ discount_factor=0.9,
+            # (bool) whether to use the trust-region update
+            trust_region=True,
+            # (float) clip ratio of the importance weights (c-clip)
+            c_clip_ratio=10,
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ n_sample=16,
+            # (float) discount factor for future reward, in the range [0, 1]
+ discount_factor=0.9,
+ collector=dict(collect_print_freq=1000, ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=200, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=10000, ), ),
+ ),
+)
+
+cartpole_acer_config = EasyDict(cartpole_acer_config)
+main_config = cartpole_acer_config
+
+cartpole_acer_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='acer'),
+)
+
+cartpole_acer_create_config = EasyDict(cartpole_acer_create_config)
+create_config = cartpole_acer_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c cartpole_acer_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_bc_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_bc_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b1975718f32af46bb18567f5df5aa9391ae61d69
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_bc_config.py
@@ -0,0 +1,44 @@
+from easydict import EasyDict
+
+cartpole_bc_config = dict(
+ exp_name='cartpole_bc_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+ continuous=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[64, 64, 128],
+ ),
+ learn=dict(
+ batch_size=64,
+ learning_rate=0.01,
+ learner=dict(hook=dict(save_ckpt_after_iter=1000)),
+ train_epoch=20,
+ ),
+ eval=dict(evaluator=dict(eval_freq=40, ))
+ ),
+)
+cartpole_bc_config = EasyDict(cartpole_bc_config)
+main_config = cartpole_bc_config
+cartpole_bc_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='bc'),
+)
+cartpole_bc_create_config = EasyDict(cartpole_bc_create_config)
+create_config = cartpole_bc_create_config
+
+if __name__ == "__main__":
+    # Note: users need to generate expert data first and save it to ``expert_data_path``.
+    from ding.entry import serial_pipeline_bc
+    # ``expert_data_path`` is a placeholder; point it at the generated expert data file.
+    expert_data_path = 'data_path_placeholder'
+    serial_pipeline_bc([main_config, create_config], seed=0, data_path=expert_data_path)
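+    # Expert data can be produced beforehand with DI-engine's demo-data collection entries,
+    # e.g. ``ding.entry.collect_episodic_demo_data`` (see
+    # dizoo/box2d/lunarlander/offline_data/lunarlander_collect_data.py for a usage sketch).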
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_bco_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_bco_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..79b5f0165f32962bdff8db57953448e5749f8692
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_bco_config.py
@@ -0,0 +1,77 @@
+from easydict import EasyDict
+
+cartpole_bco_config = dict(
+ exp_name='cartpole_bco_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ replay_path=None,
+ ),
+ policy=dict(
+ cuda=True,
+ continuous=False,
+ loss_type='l1_loss',
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ dueling=True,
+ ),
+ learn=dict(
+ train_epoch=20,
+ batch_size=128,
+ learning_rate=0.001,
+ weight_decay=1e-4,
+ momentum=0.9,
+ decay_epoch=30,
+ decay_rate=1,
+ warmup_lr=1e-4,
+ warmup_epoch=3,
+ optimizer='SGD',
+ lr_decay=True,
+ ),
+ collect=dict(
+ n_episode=10,
+ # control the number (alpha*n_episode) of post-demonstration environment interactions at each iteration.
+ # Notice: alpha * n_episode > collector_env_num
+ model_path='abs model path', # epxert model path
+ data_path='abs data path', # expert data path
+ ),
+ eval=dict(evaluator=dict(eval_freq=40, )),
+ other=dict(eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ), )
+ ),
+ bco=dict(
+ learn=dict(idm_batch_size=32, idm_learning_rate=0.001, idm_weight_decay=1e-4, idm_train_epoch=10),
+ model=dict(idm_encoder_hidden_size_list=[60, 80, 100, 40], action_space='discrete'),
+ alpha=0.8,
+ )
+)
+cartpole_bco_config = EasyDict(cartpole_bco_config)
+main_config = cartpole_bco_config
+cartpole_bco_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='bc'),
+ collector=dict(type='episode')
+)
+cartpole_bco_create_config = EasyDict(cartpole_bco_create_config)
+create_config = cartpole_bco_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_bco
+ from dizoo.classic_control.cartpole.config import cartpole_dqn_config, cartpole_dqn_create_config
+ expert_main_config = cartpole_dqn_config
+ expert_create_config = cartpole_dqn_create_config
+ serial_pipeline_bco(
+ [main_config, create_config], [expert_main_config, expert_create_config], seed=0, max_env_step=100000
+ )
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_c51_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_c51_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9f85572921cc78f7b952c15e192d0b96b18402a
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_c51_config.py
@@ -0,0 +1,61 @@
+from easydict import EasyDict
+
+cartpole_c51_config = dict(
+ exp_name='cartpole_c51_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+ priority=True,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ v_min=-10,
+ v_max=10,
+ n_atom=51,
+ ),
+ discount_factor=0.97,
+ nstep=3,
+ learn=dict(
+ update_per_collect=3,
+ batch_size=64,
+ learning_rate=0.001,
+ target_update_freq=100,
+ ),
+ collect=dict(
+ n_sample=80,
+ unroll_len=1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=40, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ), replay_buffer=dict(replay_buffer_size=20000, )
+ ),
+ ),
+)
+cartpole_c51_config = EasyDict(cartpole_c51_config)
+main_config = cartpole_c51_config
+cartpole_c51_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='c51'),
+)
+cartpole_c51_create_config = EasyDict(cartpole_c51_create_config)
+create_config = cartpole_c51_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c cartpole_c51_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_cql_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_cql_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b1932e5ad910679b33da9dad300245fb627aca9
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_cql_config.py
@@ -0,0 +1,53 @@
+from easydict import EasyDict
+
+cartpole_discrete_cql_config = dict(
+ exp_name='cartpole_cql_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ num_quantiles=64,
+ ),
+ discount_factor=0.97,
+ nstep=3,
+ learn=dict(
+ train_epoch=3000,
+ batch_size=64,
+ learning_rate=0.001,
+ target_update_freq=100,
+ kappa=1.0,
+ min_q_weight=4.0,
+ ),
+ collect=dict(
+ data_type='hdf5',
+ # offline data path
+ data_path='./cartpole_qrdqn_generation_data_seed0/expert_demos.hdf5',
+ ),
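+ # The offline dataset above can be produced with ``cartpole_qrdqn_generation_data_config.py``
+ # in this directory by rolling out a pretrained QRDQN expert; the exact name of the generated
+ # file may differ from ``expert_demos.hdf5`` depending on the configured save path.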
+ eval=dict(evaluator=dict(eval_freq=100, )),
+ ),
+)
+cartpole_discrete_cql_config = EasyDict(cartpole_discrete_cql_config)
+main_config = cartpole_discrete_cql_config
+cartpole_discrete_cql_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='discrete_cql'),
+)
+cartpole_discrete_cql_create_config = EasyDict(cartpole_discrete_cql_create_config)
+create_config = cartpole_discrete_cql_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_offline -c cartpole_cql_config.py -s 0`
+ from ding.entry import serial_pipeline_offline
+ serial_pipeline_offline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_decision_transformer.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_decision_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..87f0312bfd6a3f99ce29688e563171e1ba1465bc
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_decision_transformer.py
@@ -0,0 +1,84 @@
+from easydict import EasyDict
+import torch
+from copy import deepcopy
+
+cartpole_dt_config = dict(
+ exp_name='cartpole_dt',
+ env=dict(
+ env_name='CartPole-v0',
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=195,
+ ),
+ policy=dict(
+ device='cuda',
+ stop_value=195,
+ env_name='CartPole-v0',
+ dataset='medium', # medium / medium-replay / medium-expert
+ rtg_scale=1000, # normalize returns to go
+ max_eval_ep_len=1000, # max len of one episode
+ num_eval_ep=10, # num of evaluation episodes
+ batch_size=64, # training batch size
+ # batch_size= 2, # debug
+ lr=1e-4,
+ wt_decay=1e-4,
+ warmup_steps=10000,
+ num_updates_per_iter=100,
+ context_len=20,
+ n_blocks=3,
+ embed_dim=128,
+ n_heads=1,
+ dropout_p=0.1,
+ log_dir='/home/puyuan/DI-engine/dizoo/classic_control/cartpole/dt_log',
+ max_test_ep_len=200,
+ model=dict(
+ state_dim=4,
+ act_dim=2,
+ n_blocks=3,
+ h_dim=128,
+ context_len=20,
+ n_heads=1,
+ drop_p=0.1,
+ continuous=False,
+ ),
+ discount_factor=0.999,
+ nstep=3,
+ learn=dict(
+ dataset_path='/home/puyuan/DI-engine/dizoo/classic_control/cartpole/dt_data/data/expert_data_1000eps.pkl',
+ learning_rate=0.001,
+ target_update_freq=100,
+ kappa=1.0,
+ min_q_weight=4.0,
+ ),
+ collect=dict(unroll_len=1, ),
+ eval=dict(evaluator=dict(eval_freq=100, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=int(1e4),
+ ),
+ replay_buffer=dict(replay_buffer_size=int(2e4), )
+ ),
+ ),
+)
+cartpole_dt_config = EasyDict(cartpole_dt_config)
+main_config = cartpole_dt_config
+cartpole_dt_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='dt'),
+)
+cartpole_dt_create_config = EasyDict(cartpole_dt_create_config)
+create_config = cartpole_dt_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_dt, collect_demo_data, eval, serial_pipeline
+ main_config.exp_name = 'cartpole_dt'
+ config = deepcopy([main_config, create_config])
+ serial_pipeline_dt(config, seed=0, max_train_iter=200)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_dqfd_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_dqfd_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebc95c3fd111a84ccb0a6605de7a6887ac46ad16
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_dqfd_config.py
@@ -0,0 +1,75 @@
+from easydict import EasyDict
+
+cartpole_dqfd_config = dict(
+ exp_name='cartpole_dqfd_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ dueling=True,
+ ),
+ nstep=3,
+ discount_factor=0.97,
+ learn=dict(
+ batch_size=64,
+ learning_rate=0.001,
+ lambda1=1, # n-step return
+ lambda2=3.0, # supervised loss
+ # setting lambda3 to 0 (no L2 loss), together with expert_replay_buffer_size = 0 and lambda1 = 0,
+ # recovers plain one-step PDD-DQN
+ lambda3=0, # L2 regularization
+ per_train_iter_k=10,
+ expert_replay_buffer_size=10000, # size of the expert replay buffer
+ ),
+ collect=dict(
+ n_sample=8,
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ model_path='model_path_placeholder',
+ ),
+ # note: eval_freq below is the number of training iterations between two evaluations
+ eval=dict(evaluator=dict(eval_freq=50, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=20000, ),
+ ),
+ ),
+)
+cartpole_dqfd_config = EasyDict(cartpole_dqfd_config)
+main_config = cartpole_dqfd_config
+cartpole_dqfd_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='dqfd'),
+)
+cartpole_dqfd_create_config = EasyDict(cartpole_dqfd_create_config)
+create_config = cartpole_dqfd_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_dqfd -c cartpole_dqfd_config.py -s 0`
+ # then input ``cartpole_dqfd_config.py`` when prompted.
+ # We need to pass the dqfd config because its ``_get_train_sample`` function has to be reused
+ # in the collector, even though the expert model may be generated by other Q-learning algorithms.
+ from ding.entry.serial_entry_dqfd import serial_pipeline_dqfd
+ from dizoo.classic_control.cartpole.config import cartpole_dqfd_config, cartpole_dqfd_create_config
+ expert_main_config = cartpole_dqfd_config
+ expert_create_config = cartpole_dqfd_create_config
+ serial_pipeline_dqfd((main_config, create_config), (expert_main_config, expert_create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_dqn_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e2cb85433838ff5c9999d861ae93310126dc26c
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_dqn_config.py
@@ -0,0 +1,59 @@
+from easydict import EasyDict
+
+cartpole_dqn_config = dict(
+ exp_name='cartpole_dqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ replay_path='cartpole_dqn_seed0/video',
+ ),
+ policy=dict(
+ cuda=False,
+ load_path='cartpole_dqn_seed0/ckpt/ckpt_best.pth.tar', # necessary for eval
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ dueling=True,
+ # dropout=0.1,
+ ),
+ nstep=1,
+ discount_factor=0.97,
+ learn=dict(
+ update_per_collect=5,
+ batch_size=64,
+ learning_rate=0.001,
+ ),
+ collect=dict(n_sample=8),
+ eval=dict(evaluator=dict(eval_freq=40, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=20000, ),
+ ),
+ ),
+)
+cartpole_dqn_config = EasyDict(cartpole_dqn_config)
+main_config = cartpole_dqn_config
+cartpole_dqn_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='dqn'),
+ replay_buffer=dict(type='deque', import_names=['ding.data.buffer.deque_buffer_wrapper']),
+)
+cartpole_dqn_create_config = EasyDict(cartpole_dqn_create_config)
+create_config = cartpole_dqn_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c cartpole_dqn_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_dqn_gail_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_dqn_gail_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b438e648e3ff861c3322a8f447b2f0a934fe3c91
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_dqn_gail_config.py
@@ -0,0 +1,80 @@
+from easydict import EasyDict
+
+cartpole_dqn_gail_config = dict(
+ exp_name='cartpole_dqn_gail_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ reward_model=dict(
+ type='gail',
+ input_size=5,
+ hidden_size=64,
+ batch_size=64,
+ learning_rate=1e-3,
+ update_per_collect=100,
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ # If collect_data is True, this expert_model_path is used to collect expert data first;
+ # otherwise data is loaded directly from the user-defined data_path (a hypothetical example follows below).
+ expert_model_path='model_path_placeholder',
+ collect_count=1000,
+ ),
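+ # Hypothetical example (not in the original config): following the pattern above and the DQN config
+ # in this directory, expert_model_path could be 'cartpole_dqn_seed0/ckpt/ckpt_best.pth.tar'; with
+ # collect_data=True, the pipeline first rolls out this expert to gather demonstrations before GAIL training.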
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ dueling=True,
+ ),
+ nstep=1,
+ discount_factor=0.97,
+ learn=dict(
+ batch_size=64,
+ learning_rate=0.001,
+ update_per_collect=3,
+ ),
+ collect=dict(n_sample=64),
+ eval=dict(evaluator=dict(eval_freq=10, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=20000, ),
+ ),
+ ),
+)
+cartpole_dqn_gail_config = EasyDict(cartpole_dqn_gail_config)
+main_config = cartpole_dqn_gail_config
+cartpole_dqn_gail_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='dqn'),
+)
+cartpole_dqn_gail_create_config = EasyDict(cartpole_dqn_gail_create_config)
+create_config = cartpole_dqn_gail_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_gail -c cartpole_dqn_gail_config.py -s 0`
+ # then input the config you used to generate your expert model in the path mentioned above
+ # e.g. cartpole_dqn_config.py
+ from ding.entry import serial_pipeline_gail
+ from dizoo.classic_control.cartpole.config import cartpole_dqn_config, cartpole_dqn_create_config
+ expert_main_config = cartpole_dqn_config
+ expert_create_config = cartpole_dqn_create_config
+ serial_pipeline_gail(
+ (main_config, create_config), (expert_main_config, expert_create_config),
+ max_env_step=1000000,
+ seed=0,
+ collect_data=True
+ )
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_dqn_rnd_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_dqn_rnd_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..caf7a4f54791816825f6a3b5f07c618e40343102
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_dqn_rnd_config.py
@@ -0,0 +1,58 @@
+from easydict import EasyDict
+
+cartpole_dqn_config = dict(
+ exp_name='cartpole_dqn_rnd',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ reward_model=dict(
+ intrinsic_reward_type='add',
+ learning_rate=1e-3,
+ obs_shape=4,
+ batch_size=32,
+ update_per_collect=10,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ dueling=True,
+ ),
+ nstep=1,
+ discount_factor=0.97,
+ learn=dict(
+ batch_size=64,
+ learning_rate=0.001,
+ ),
+ collect=dict(n_sample=8),
+ eval=dict(evaluator=dict(eval_freq=40, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=20000, ),
+ ),
+ ),
+)
+cartpole_dqn_config = EasyDict(cartpole_dqn_config)
+main_config = cartpole_dqn_config
+cartpole_dqn_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='dqn'),
+ replay_buffer=dict(type='deque'),
+ reward_model=dict(type='rnd'),
+)
+cartpole_dqn_create_config = EasyDict(cartpole_dqn_create_config)
+create_config = cartpole_dqn_create_config
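+
+if __name__ == "__main__":
+ # A hedged sketch, not part of the original file: like the ICM config in this directory, this RND
+ # config is expected to run through a reward-model pipeline entry; whether
+ # ``serial_pipeline_reward_model_offpolicy`` accepts this exact config is an assumption to verify.
+ from ding.entry import serial_pipeline_reward_model_offpolicy
+ serial_pipeline_reward_model_offpolicy([main_config, create_config], seed=0)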
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_dqn_stdim_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_dqn_stdim_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3dfeff7a1712b38f0a6d63485439d3a8f128b00
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_dqn_stdim_config.py
@@ -0,0 +1,65 @@
+from easydict import EasyDict
+
+cartpole_dqn_stdim_config = dict(
+ exp_name='cartpole_dqn_stdim_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ replay_path='cartpole_dqn_stdim_seed0/video',
+ ),
+ policy=dict(
+ cuda=False,
+ load_path='cartpole_dqn_stdim_seed0/ckpt/ckpt_best.pth.tar', # necessary for eval
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ dueling=True,
+ ),
+ aux_model=dict(
+ encode_shape=64,
+ heads=[1, 1],
+ loss_type='infonce',
+ temperature=1.0,
+ ),
+ # the weight of the auxiliary loss to the TD loss
+ aux_loss_weight=0.003,
+ nstep=1,
+ discount_factor=0.97,
+ learn=dict(
+ batch_size=64,
+ learning_rate=0.001,
+ ),
+ collect=dict(n_sample=8),
+ eval=dict(evaluator=dict(eval_freq=40, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=20000, ),
+ ),
+ ),
+)
+cartpole_dqn_stdim_config = EasyDict(cartpole_dqn_stdim_config)
+main_config = cartpole_dqn_stdim_config
+cartpole_dqn_stdim_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='dqn_stdim'),
+ replay_buffer=dict(type='deque', import_names=['ding.data.buffer.deque_buffer_wrapper']),
+)
+cartpole_dqn_stdim_create_config = EasyDict(cartpole_dqn_stdim_create_config)
+create_config = cartpole_dqn_stdim_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c cartpole_dqn_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_drex_dqn_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_drex_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c898528a395d6dfead411a0dc215e015f6859123
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_drex_dqn_config.py
@@ -0,0 +1,85 @@
+from easydict import EasyDict
+
+cartpole_drex_dqn_config = dict(
+ exp_name='cartpole_drex_dqn_seed0',
+ env=dict(
+ manager=dict(shared_memory=True, reset_inplace=True),
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ reward_model=dict(
+ type='drex',
+ min_snippet_length=5,
+ max_snippet_length=100,
+ checkpoint_min=0,
+ checkpoint_max=1000,
+ checkpoint_step=1000,
+ learning_rate=1e-5,
+ update_per_collect=1,
+ # path to expert models that generate demonstration data
+ # Users should add their own model path here. Model path should lead to an exp_name.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name``.
+ # For example, if you want to use dqn to generate demos, you can use ``cartpole_dqn_seed0``
+ # (hypothetical values for these placeholders are sketched after this reward_model block).
+ expert_model_path='expert_model_path_placeholder',
+ # path to save reward model
+ # Users should add their own model path here.
+ # Absolute path is recommended.
+ # For example, if you use ``cartpole_drex_dqn_seed0``, then the reward model will be saved in this directory.
+ reward_model_path='reward_model_path_placeholder + ./spaceinvaders.params',
+ # path to save generated observations.
+ # Users should add their own model path here.
+ # Absolute path is recommended.
+ # For example, if you use ``cartpole_drex_dqn_seed0``, then all the generated data will be saved in this directory.
+ offline_data_path='offline_data_path_placeholder',
+ # path to pretrained bc model. If omitted, bc will be trained instead.
+ # Users should add their own model path here. Model path should lead to a model ckpt.
+ # Absolute path is recommended.
+ bc_path='bc_path_placeholder',
+ # list of noises
+ eps_list=[0, 0.5, 1],
+ num_trajs_per_bin=20,
+ bc_iterations=6000,
+ ),
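+ # Hypothetical example values for the placeholders above (not part of the original config):
+ # expert_model_path='cartpole_dqn_seed0' (an exp_name directory containing the expert checkpoints),
+ # reward_model_path='./cartpole_drex_dqn_seed0/cartpole.params',
+ # offline_data_path='./cartpole_drex_dqn_seed0'.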
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ dueling=True,
+ ),
+ nstep=1,
+ discount_factor=0.97,
+ learn=dict(
+ batch_size=64,
+ learning_rate=0.001,
+ ),
+ collect=dict(n_sample=8, collector=dict(get_train_sample=False, )),
+ eval=dict(evaluator=dict(eval_freq=40, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=20000, ),
+ ),
+ ),
+)
+cartpole_drex_dqn_config = EasyDict(cartpole_drex_dqn_config)
+main_config = cartpole_drex_dqn_config
+cartpole_drex_dqn_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqn'),
+ collector=dict(type='episode'),
+)
+cartpole_drex_dqn_create_config = EasyDict(cartpole_drex_dqn_create_config)
+create_config = cartpole_drex_dqn_create_config
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_dt_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_dt_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..7eebd77428e71d9b26a3535834428fe58d6b3612
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_dt_config.py
@@ -0,0 +1,61 @@
+from easydict import EasyDict
+
+cartpole_discrete_dt_config = dict(
+ exp_name='cartpole_dt_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ dataset=dict(
+ data_dir_prefix='./cartpole_qrdqn_generation_data_seed0/expert_demos.hdf5',
+ rtg_scale=None,
+ context_len=20,
+ env_type='classic',
+ ),
+ policy=dict(
+ cuda=False,
+ rtg_target=10,
+ evaluator_env_num=5,
+ clip_grad_norm_p=1.0,
+ state_mean=1,
+ state_std=0,
+ model=dict(
+ state_dim=4,
+ act_dim=2,
+ n_blocks=6,
+ h_dim=128,
+ context_len=20,
+ n_heads=8,
+ drop_p=0.1,
+ continuous=False,
+ ),
+ max_timestep=1000,
+ discount_factor=0.97,
+ nstep=3,
+ batch_size=64,
+ learning_rate=0.001,
+ target_update_freq=100,
+ kappa=1.0,
+ min_q_weight=4.0,
+ collect=dict(
+ data_type='hdf5',
+ data_path='./cartpole_qrdqn_generation_data_seed0/expert_demos.hdf5',
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, )),
+ ),
+)
+cartpole_discrete_dt_config = EasyDict(cartpole_discrete_dt_config)
+main_config = cartpole_discrete_dt_config
+cartpole_discrete_dt_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='dt'),
+)
+cartpole_discrete_dt_create_config = EasyDict(cartpole_discrete_dt_create_config)
+create_config = cartpole_discrete_dt_create_config
+# You can run this config with an entry file such as `ding/example/dt.py`.
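+# A hedged usage sketch (the exact invocation depends on ``ding/example/dt.py``): first generate
+# offline data with ``cartpole_qrdqn_generation_data_config.py`` in this directory, then run the entry
+# script from the repository root, e.g. ``python ding/example/dt.py``, which is expected to load
+# ``main_config``/``create_config`` from this file.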
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_fqf_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_fqf_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca65670d815ed7191c5576ac8055d466ee5f3d6c
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_fqf_config.py
@@ -0,0 +1,63 @@
+from easydict import EasyDict
+
+cartpole_fqf_config = dict(
+ exp_name='cartpole_fqf_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ replay_path='cartpole_fqf_seed0/video',
+ ),
+ policy=dict(
+ cuda=False,
+ priority=True,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ num_quantiles=32,
+ quantile_embedding_size=64,
+ ),
+ discount_factor=0.97,
+ nstep=1,
+ learn=dict(
+ update_per_collect=3,
+ batch_size=64,
+ learning_rate_fraction=0.0001,
+ learning_rate_quantile=0.0001,
+ target_update_freq=100,
+ ent_coef=0,
+ ),
+ collect=dict(
+ n_sample=80,
+ unroll_len=1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=40, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ), replay_buffer=dict(replay_buffer_size=20000, )
+ ),
+ ),
+)
+cartpole_fqf_config = EasyDict(cartpole_fqf_config)
+main_config = cartpole_fqf_config
+cartpole_fqf_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='fqf'),
+)
+cartpole_fqf_create_config = EasyDict(cartpole_fqf_create_config)
+create_config = cartpole_fqf_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c cartpole_fqf_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
\ No newline at end of file
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_gcl_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_gcl_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c8faf0834694a36007ad51e68cdbc22e08cbce
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_gcl_config.py
@@ -0,0 +1,68 @@
+from easydict import EasyDict
+
+cartpole_gcl_ppo_onpolicy_config = dict(
+ exp_name='cartpole_gcl_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ reward_model=dict(
+ learning_rate=0.001,
+ input_size=5,
+ batch_size=32,
+ continuous=False,
+ update_per_collect=10,
+ ),
+ policy=dict(
+ cuda=False,
+ recompute_adv=True,
+ action_space='discrete',
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ action_space='discrete',
+ encoder_hidden_size_list=[64, 64, 128],
+ critic_head_hidden_size=128,
+ actor_head_hidden_size=128,
+ ),
+ learn=dict(
+ update_per_collect=2,
+ batch_size=64,
+ learning_rate=0.001,
+ entropy_weight=0.01,
+ ),
+ collect=dict(
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ model_path='model_path_placeholder',
+ # If the collected data should contain the logit key (which reflects the action probabilities),
+ # set collector_logit to True.
+ # In Guided Cost Learning, the logit is needed to train the reward model, so it is set to True here.
+ collector_logit=True, # add logit into collected transition
+ n_sample=256,
+ discount_factor=0.9,
+ gae_lambda=0.95,
+ ),
+ eval=dict(evaluator=dict(eval_freq=50, ), ),
+ ),
+)
+cartpole_gcl_ppo_onpolicy_config = EasyDict(cartpole_gcl_ppo_onpolicy_config)
+main_config = cartpole_gcl_ppo_onpolicy_config
+cartpole_gcl_ppo_onpolicy_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='ppo'),
+ reward_model=dict(type='guided_cost'),
+)
+cartpole_gcl_ppo_onpolicy_create_config = EasyDict(cartpole_gcl_ppo_onpolicy_create_config)
+create_config = cartpole_gcl_ppo_onpolicy_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_guided_cost
+ serial_pipeline_guided_cost((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_impala_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_impala_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c78d9392af10bff4191e41ec509849035c4ba16a
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_impala_config.py
@@ -0,0 +1,76 @@
+from easydict import EasyDict
+
+cartpole_impala_config = dict(
+ exp_name='cartpole_impala_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[64, 64],
+ ),
+ # (int) the trajectory length to calculate v-trace target
+ unroll_len=8,
+ learn=dict(
+ # (int) collect n_sample data, train model update_per_collect times
+ # here we follow ppo serial pipeline
+ update_per_collect=4,
+ # (int) the number of data for a train iteration
+ batch_size=16,
+ learning_rate=0.0005,
+ # (float) loss weight of the value network, the weight of policy network is set to 1
+ value_weight=0.5,
+ # (float) loss weight of the entropy regularization, the weight of policy network is set to 1
+ entropy_weight=0.0001,
+ # (float) discount factor for future reward, valued in [0, 1]
+ discount_factor=0.9,
+ # (float) additional discounting parameter
+ lambda_=0.95,
+ # (float) clip ratio of importance weights
+ rho_clip_ratio=1.0,
+ # (float) clip ratio of importance weights
+ c_clip_ratio=1.0,
+ # (float) clip ratio of importance sampling
+ rho_pg_clip_ratio=1.0,
+ ),
+ collect=dict(
+ # (int) collect n_sample data, train model n_iteration times
+ n_sample=16,
+ # (float) discount factor for future reward, valued in [0, 1]
+ discount_factor=0.9,
+ gae_lambda=0.95,
+ collector=dict(collect_print_freq=1000, ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=200, )),
+ other=dict(replay_buffer=dict(
+ replay_buffer_size=1000,
+ max_use=16,
+ ), ),
+ ),
+)
+
+cartpole_impala_config = EasyDict(cartpole_impala_config)
+main_config = cartpole_impala_config
+
+cartpole_impala_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='impala'),
+)
+
+cartpole_impala_create_config = EasyDict(cartpole_impala_create_config)
+create_config = cartpole_impala_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c cartpole_impala_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_iqn_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_iqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d6fca73e9371d8202f4d5a22d84de033f25247f4
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_iqn_config.py
@@ -0,0 +1,59 @@
+from easydict import EasyDict
+
+cartpole_iqn_config = dict(
+ exp_name='cartpole_iqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ num_quantiles=32,
+ ),
+ discount_factor=0.97,
+ nstep=3,
+ learn=dict(
+ update_per_collect=3,
+ batch_size=64,
+ learning_rate=0.001,
+ target_update_freq=100,
+ kappa=1.0,
+ ),
+ collect=dict(
+ n_sample=80,
+ unroll_len=1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=40, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ), replay_buffer=dict(replay_buffer_size=20000, )
+ ),
+ ),
+)
+cartpole_iqn_config = EasyDict(cartpole_iqn_config)
+main_config = cartpole_iqn_config
+cartpole_iqn_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='iqn'),
+)
+cartpole_iqn_create_config = EasyDict(cartpole_iqn_create_config)
+create_config = cartpole_iqn_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c cartpole_iqn_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_mdqn_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_mdqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b72a6375a1905c5c6c56b577f9e0035596b6a25a
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_mdqn_config.py
@@ -0,0 +1,58 @@
+from easydict import EasyDict
+
+cartpole_mdqn_config = dict(
+ exp_name='cartpole_mdqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ dueling=True,
+ ),
+ nstep=1,
+ discount_factor=0.97,
+ entropy_tau=0.03,
+ m_alpha=0.9,
+ learn=dict(
+ update_per_collect=5,
+ batch_size=64,
+ learning_rate=0.001,
+ ),
+ collect=dict(n_sample=8),
+ eval=dict(evaluator=dict(eval_freq=40, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=20000, ),
+ ),
+ ),
+)
+cartpole_mdqn_config = EasyDict(cartpole_mdqn_config)
+main_config = cartpole_mdqn_config
+cartpole_mdqn_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='mdqn'),
+ replay_buffer=dict(type='deque', import_names=['ding.data.buffer.deque_buffer_wrapper']),
+)
+cartpole_mdqn_create_config = EasyDict(cartpole_mdqn_create_config)
+create_config = cartpole_mdqn_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c cartpole_mdqn_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0, dynamic_seed=False)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_ngu_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_ngu_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..3aecbbb01b6704cd2625019d6e408a0b06f5e32b
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_ngu_config.py
@@ -0,0 +1,124 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 8
+nstep = 5
+cartpole_ngu_config = dict(
+ exp_name='cartpole_ngu_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ obs_plus_prev_action_reward=True, # use specific env wrapper for ngu policy
+ stop_value=195,
+ ),
+ rnd_reward_model=dict(
+ intrinsic_reward_type='add',
+ learning_rate=5e-4,
+ obs_shape=4,
+ action_shape=2,
+ batch_size=128, # transitions
+ update_per_collect=10,
+ only_use_last_five_frames_for_icm_rnd=False,
+ clear_buffer_per_iters=10,
+ nstep=nstep,
+ hidden_size_list=[128, 128, 64],
+ type='rnd-ngu',
+ ),
+ episodic_reward_model=dict(
+ # Whether to apply the rescale trick to the last non-zero reward
+ # when combining extrinsic and intrinsic rewards.
+ # The rescale trick is only used in:
+ # 1. sparse-reward envs such as minigrid, where the last non-zero reward is a strong positive signal;
+ # 2. envs where the last reward of each episode directly reflects task completion, e.g. lunarlander.
+ # Note that the NGU intrinsic reward is positive (max value 5); in these envs the last non-zero
+ # extrinsic reward should not be overwhelmed by intrinsic rewards, so the original last non-zero
+ # extrinsic reward is rescaled.
+ # please refer to ngu_reward_model for details.
+ last_nonzero_reward_rescale=False,
+ # the rescale value for the last non-zero reward; only used when last_nonzero_reward_rescale is True
+ # please refer to ngu_reward_model for details.
+ last_nonzero_reward_weight=1,
+ intrinsic_reward_type='add',
+ learning_rate=5e-4,
+ obs_shape=4,
+ action_shape=2,
+ batch_size=128, # transitions
+ update_per_collect=10,
+ only_use_last_five_frames_for_icm_rnd=False,
+ clear_buffer_per_iters=10,
+ nstep=nstep,
+ hidden_size_list=[128, 128, 64],
+ type='episodic',
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ priority_IS_weight=True,
+ discount_factor=0.997,
+ nstep=nstep,
+ burnin_step=2,
+ # (int) learn_unroll_len is the total length of a sequence sample minus
+ # the length of its burnin part,
+ # i.e., sequence sample length = unroll_len = burnin_step + learn_unroll_len
+ learn_unroll_len=40, # set this key according to the episode length
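+ # For example, with burnin_step=2 and learn_unroll_len=40 as set here, each sequence sample
+ # is expected to cover 2 + 40 = 42 environment steps.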
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ collector_env_num=collector_env_num,
+ ),
+ learn=dict(
+ update_per_collect=8,
+ batch_size=32,
+ learning_rate=1e-4,
+ target_update_theta=0.001,
+ ),
+ collect=dict(
+ # NOTE: It is important to set traj_len_inf=True here,
+ # so that self._traj_len=INF in serial_sample_collector.py.
+ # In a sequence-based policy, for each collect_env,
+ # we want to collect data of length self._traj_len=INF
+ # unless the episode reaches the 'done' state.
+ # In each collect phase, we collect a total of n_sample sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=1e4,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=int(1e4),
+ # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
+ alpha=0.6,
+ # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+ beta=0.4,
+ )
+ ),
+ ),
+)
+cartpole_ngu_config = EasyDict(cartpole_ngu_config)
+main_config = cartpole_ngu_config
+cartpole_ngu_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='ngu'),
+ rnd_reward_model=dict(type='rnd-ngu'),
+ episodic_reward_model=dict(type='episodic'),
+)
+cartpole_ngu_create_config = EasyDict(cartpole_ngu_create_config)
+create_config = cartpole_ngu_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_ngu -c cartpole_ngu_config.py -s 0`
+ from ding.entry import serial_pipeline_ngu
+ serial_pipeline_ngu([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_pg_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_pg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..af3ee5ba044c16801a4d3851ff6b110e17a55a93
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_pg_config.py
@@ -0,0 +1,43 @@
+from easydict import EasyDict
+
+cartpole_pg_config = dict(
+ exp_name='cartpole_pg_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ ),
+ learn=dict(
+ batch_size=64,
+ learning_rate=0.001,
+ entropy_weight=0.001,
+ ),
+ collect=dict(n_episode=80, unroll_len=1, discount_factor=0.9),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ ),
+)
+cartpole_pg_config = EasyDict(cartpole_pg_config)
+main_config = cartpole_pg_config
+cartpole_pg_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='pg'),
+ collector=dict(type='episode'),
+)
+cartpole_pg_create_config = EasyDict(cartpole_pg_create_config)
+create_config = cartpole_pg_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c cartpole_pg_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_ppg_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_ppg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d02a71ec79a5177170d2509367dddbae737ce97b
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_ppg_config.py
@@ -0,0 +1,75 @@
+from easydict import EasyDict
+
+cartpole_ppg_config = dict(
+ exp_name='cartpole_ppg_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[64, 64, 128],
+ critic_head_hidden_size=128,
+ actor_head_hidden_size=128,
+ ),
+ learn=dict(
+ update_per_collect=5,
+ batch_size=64,
+ learning_rate=0.001,
+ value_weight=0.5,
+ entropy_weight=0.01,
+ clip_ratio=0.2,
+ ),
+ collect=dict(
+ n_sample=128,
+ unroll_len=1,
+ discount_factor=0.9,
+ gae_lambda=0.95,
+ ),
+ eval=dict(evaluator=dict(eval_freq=40, )),
+ other=dict(
+ replay_buffer=dict(
+ multi_buffer=True,
+ policy=dict(
+ replay_buffer_size=100,
+ max_use=10,
+ ),
+ value=dict(
+ replay_buffer_size=1000,
+ max_use=100,
+ ),
+ ),
+ ),
+ ),
+)
+cartpole_ppg_config = EasyDict(cartpole_ppg_config)
+main_config = cartpole_ppg_config
+cartpole_ppg_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='ppg_offpolicy'),
+ replay_buffer=dict(
+ policy=dict(type='advanced'),
+ value=dict(type='advanced'),
+ )
+)
+cartpole_ppg_create_config = EasyDict(cartpole_ppg_create_config)
+create_config = cartpole_ppg_create_config
+
+if __name__ == "__main__":
+ # This config file can be executed by `dizoo/classic_control/cartpole/entry/cartpole_ppg_main.py`
+ import os
+ import warnings
+ from dizoo.classic_control.cartpole.entry.cartpole_ppg_main import main
+ from dizoo.classic_control.cartpole.entry.cartpole_ppg_main import __file__ as _origin_py_file
+ origin_py_file_rel = os.path.relpath(_origin_py_file, os.path.abspath(os.path.curdir))
+ warnings.warn(UserWarning(f"This config file can be executed by {repr(origin_py_file_rel)}"))
+ main(cartpole_ppg_config)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_ppo_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_ppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8a5108721a86b340142633088d4be34f0dffab4
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_ppo_config.py
@@ -0,0 +1,56 @@
+from easydict import EasyDict
+
+cartpole_ppo_config = dict(
+ exp_name='cartpole_ppo_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+ action_space='discrete',
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ action_space='discrete',
+ encoder_hidden_size_list=[64, 64, 128],
+ critic_head_hidden_size=128,
+ actor_head_hidden_size=128,
+ ),
+ learn=dict(
+ epoch_per_collect=2,
+ batch_size=64,
+ learning_rate=0.001,
+ value_weight=0.5,
+ entropy_weight=0.01,
+ clip_ratio=0.2,
+ learner=dict(hook=dict(save_ckpt_after_iter=100)),
+ ),
+ collect=dict(
+ n_sample=256,
+ unroll_len=1,
+ discount_factor=0.9,
+ gae_lambda=0.95,
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ ),
+)
+cartpole_ppo_config = EasyDict(cartpole_ppo_config)
+main_config = cartpole_ppo_config
+cartpole_ppo_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='ppo'),
+)
+cartpole_ppo_create_config = EasyDict(cartpole_ppo_create_config)
+create_config = cartpole_ppo_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c cartpole_ppo_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_ppo_icm_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_ppo_icm_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e1c937c08ff495e3097abf3270282185f59ba654
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_ppo_icm_config.py
@@ -0,0 +1,64 @@
+from easydict import EasyDict
+
+cartpole_ppo_icm_config = dict(
+ exp_name='cartpole_ppo_icm_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ reward_model=dict(
+ intrinsic_reward_type='add',
+ learning_rate=0.001,
+ obs_shape=4,
+ action_shape=2,
+ batch_size=32,
+ update_per_collect=10,
+ ),
+ policy=dict(
+ cuda=False,
+ action_space='discrete',
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ action_space='discrete',
+ encoder_hidden_size_list=[64, 64, 128],
+ critic_head_hidden_size=128,
+ actor_head_hidden_size=128,
+ ),
+ learn=dict(
+ epoch_per_collect=2,
+ batch_size=64,
+ learning_rate=0.001,
+ value_weight=0.5,
+ entropy_weight=0.01,
+ clip_ratio=0.2,
+ ),
+ collect=dict(
+ n_sample=256,
+ unroll_len=1,
+ discount_factor=0.9,
+ gae_lambda=0.95,
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ ),
+)
+cartpole_ppo_icm_config = EasyDict(cartpole_ppo_icm_config)
+main_config = cartpole_ppo_icm_config
+cartpole_ppo_icm_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='ppo_offpolicy'),
+ reward_model=dict(type='icm'),
+)
+cartpole_ppo_icm_create_config = EasyDict(cartpole_ppo_icm_create_config)
+create_config = cartpole_ppo_icm_create_config
+
+if __name__ == '__main__':
+ # TODO: confirm which mode should be used in the CLI
+ from ding.entry import serial_pipeline_reward_model_offpolicy
+ serial_pipeline_reward_model_offpolicy([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_ppo_offpolicy_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_ppo_offpolicy_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..f952ecadaa9f2fe745f2aadb5da5df41b7075936
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_ppo_offpolicy_config.py
@@ -0,0 +1,56 @@
+from easydict import EasyDict
+
+cartpole_ppo_offpolicy_config = dict(
+ exp_name='cartpole_ppo_offpolicy_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[64, 64, 128],
+ critic_head_hidden_size=128,
+ actor_head_hidden_size=128,
+ action_space='discrete',
+ ),
+ learn=dict(
+ update_per_collect=6,
+ batch_size=64,
+ learning_rate=0.001,
+ value_weight=0.5,
+ entropy_weight=0.01,
+ clip_ratio=0.2,
+ learner=dict(hook=dict(save_ckpt_after_iter=1000)),
+ ),
+ collect=dict(
+ n_sample=128,
+ unroll_len=1,
+ discount_factor=0.9,
+ gae_lambda=0.95,
+ ),
+ eval=dict(evaluator=dict(eval_freq=40, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=5000))
+ ),
+)
+cartpole_ppo_offpolicy_config = EasyDict(cartpole_ppo_offpolicy_config)
+main_config = cartpole_ppo_offpolicy_config
+cartpole_ppo_offpolicy_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='ppo_offpolicy'),
+)
+cartpole_ppo_offpolicy_create_config = EasyDict(cartpole_ppo_offpolicy_create_config)
+create_config = cartpole_ppo_offpolicy_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c cartpole_ppo_offpolicy_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_ppo_stdim_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_ppo_stdim_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..3f6060797cc667b437237db955fb8aa660921905
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_ppo_stdim_config.py
@@ -0,0 +1,66 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 5
+cartpole_ppo_stdim_config = dict(
+ exp_name='cartpole_onppo_stdim_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=True,
+ action_space='discrete',
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ action_space='discrete',
+ encoder_hidden_size_list=[64, 64, 128],
+ critic_head_hidden_size=128,
+ actor_head_hidden_size=128,
+ ),
+ aux_model=dict(
+ encode_shape=64,
+ heads=[1, 1],
+ loss_type='infonce',
+ temperature=1.0,
+ ),
+ # the weight of the auxiliary loss to the TD loss
+ aux_loss_weight=0.003,
+ learn=dict(
+ epoch_per_collect=2,
+ batch_size=64,
+ learning_rate=0.001,
+ value_weight=0.5,
+ entropy_weight=0.01,
+ clip_ratio=0.2,
+ learner=dict(hook=dict(save_ckpt_after_iter=100)),
+ ),
+ collect=dict(
+ n_sample=256,
+ unroll_len=1,
+ discount_factor=0.9,
+ gae_lambda=0.95,
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ ),
+)
+cartpole_ppo_stdim_config = EasyDict(cartpole_ppo_stdim_config)
+main_config = cartpole_ppo_stdim_config
+cartpole_ppo_stdim_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='ppo_stdim'),
+)
+cartpole_ppo_stdim_create_config = EasyDict(cartpole_ppo_stdim_create_config)
+create_config = cartpole_ppo_stdim_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c cartpole_ppo_stdim_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_ppopg_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_ppopg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..623a3b50480bcbe2533c11ad053b8427a036941d
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_ppopg_config.py
@@ -0,0 +1,44 @@
+from easydict import EasyDict
+
+cartpole_ppopg_config = dict(
+ exp_name='cartpole_ppopg_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ ),
+ learn=dict(
+ epoch_per_collect=1,
+ batch_size=64,
+ learning_rate=0.001,
+ entropy_weight=0.001,
+ ),
+ collect=dict(n_episode=80, unroll_len=1, discount_factor=0.9, collector=dict(get_train_sample=True)),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ ),
+)
+cartpole_ppopg_config = EasyDict(cartpole_ppopg_config)
+main_config = cartpole_ppopg_config
+cartpole_ppopg_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='ppo_pg'),
+ collector=dict(type='episode'),
+)
+cartpole_ppopg_create_config = EasyDict(cartpole_ppopg_create_config)
+create_config = cartpole_ppopg_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c cartpole_ppopg_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_qrdqn_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_qrdqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd873d0b3a5f3d7e644c941411e59ea58a951a81
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_qrdqn_config.py
@@ -0,0 +1,59 @@
+from easydict import EasyDict
+
+cartpole_qrdqn_config = dict(
+ exp_name='cartpole_qrdqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ num_quantiles=64,
+ ),
+ discount_factor=0.97,
+ nstep=3,
+ learn=dict(
+ update_per_collect=3,
+ batch_size=64,
+ learning_rate=0.001,
+ target_update_freq=100,
+ kappa=1.0,
+ ),
+ collect=dict(
+ n_sample=80,
+ unroll_len=1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ), replay_buffer=dict(replay_buffer_size=20000, )
+ ),
+ ),
+)
+cartpole_qrdqn_config = EasyDict(cartpole_qrdqn_config)
+main_config = cartpole_qrdqn_config
+cartpole_qrdqn_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='qrdqn'),
+)
+cartpole_qrdqn_create_config = EasyDict(cartpole_qrdqn_create_config)
+create_config = cartpole_qrdqn_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c cartpole_qrdqn_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_qrdqn_generation_data_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_qrdqn_generation_data_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..73cfeef33249e63e49eb22ae0003a2c59b77aeec
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_qrdqn_generation_data_config.py
@@ -0,0 +1,50 @@
+from easydict import EasyDict
+
+cartpole_qrdqn_generation_data_config = dict(
+ exp_name='cartpole_qrdqn_generation_data_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ num_quantiles=64,
+ ),
+ discount_factor=0.97,
+ nstep=3,
+ collect=dict(
+ collect_count=1000,
+ data_type='hdf5',
+ # pretrained RL model path; users can modify it to their own path
+ model_path='./cartpole_qrdqn_seed0/ckpt/ckpt_best.pth.tar',
+ # this prefix should be the same as exp_name
+ save_path='./cartpole_qrdqn_generation_data_seed0/expert.pkl',
+ ),
+ other=dict(eps=dict(collect=0.2, ), ),
+ ),
+)
+cartpole_qrdqn_generation_data_config = EasyDict(cartpole_qrdqn_generation_data_config)
+main_config = cartpole_qrdqn_generation_data_config
+cartpole_qrdqn_generation_data_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='qrdqn'),
+)
+cartpole_qrdqn_generation_data_create_config = EasyDict(cartpole_qrdqn_generation_data_create_config)
+create_config = cartpole_qrdqn_generation_data_create_config
+
+if __name__ == "__main__":
+ from ding.entry import collect_demo_data
+ cfg = main_config.policy.collect
+ collect_demo_data(
+ (main_config, create_config), seed=0, collect_count=cfg.collect_count, state_dict_path=cfg.model_path
+ )
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_r2d2_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_r2d2_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..600c885cc967e57dc8c7e6366f76eddd96a8aa36
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_r2d2_config.py
@@ -0,0 +1,78 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 8
+cartpole_r2d2_config = dict(
+ exp_name='cartpole_r2d2_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+ priority=False,
+ priority_IS_weight=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ ),
+ discount_factor=0.995,
+ nstep=5,
+ burnin_step=2,
+ # (int) the whole sequence length to unroll the RNN network minus
+ # the timesteps of the burnin part,
+ # i.e., whole sequence length = unroll_len = burnin_step + learn_unroll_len
+ learn_unroll_len=40,
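+ # For example, with burnin_step=2 and learn_unroll_len=40 as set here, each sequence sample
+ # covers 2 + 40 = 42 steps, matching collect.unroll_len=2 + 40 below.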
+ learn=dict(
+ # According to the R2D2 paper, the actor parameter update interval is 400
+ # environment timesteps. In each collect phase we collect 32 sequence
+ # samples, and the length of each sequence is burnin_step + learn_unroll_len,
+ # which is 100 in that setting; 32*100/400=8, so update_per_collect=8
+ # is used in most environments.
+ update_per_collect=5,
+ batch_size=64,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ ),
+ collect=dict(
+ # NOTE: It is important to set traj_len_inf=True here,
+ # so that self._traj_len=INF in serial_sample_collector.py.
+ # In the R2D2 policy, for each collect_env, we want to collect data of length self._traj_len=INF
+ # unless the episode reaches the 'done' state.
+ # In each collect phase, we collect a total of n_sample sequence samples.
+ n_sample=32,
+ unroll_len=2 + 40,
+ traj_len_inf=True,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=30)),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=10000,
+ ), replay_buffer=dict(replay_buffer_size=100000, )
+ ),
+ ),
+)
+cartpole_r2d2_config = EasyDict(cartpole_r2d2_config)
+main_config = cartpole_r2d2_config
+cartpole_r2d2_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='r2d2'),
+)
+cartpole_r2d2_create_config = EasyDict(cartpole_r2d2_create_config)
+create_config = cartpole_r2d2_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c cartpole_r2d2_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
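The burnin/unroll comments in this config can be checked with a few lines of arithmetic; the sketch below only uses the values defined above and does not touch DI-engine internals:

burnin_step, learn_unroll_len, n_sample = 2, 40, 32
unroll_len = burnin_step + learn_unroll_len      # 42, matching collect.unroll_len = 2 + 40 above
env_steps_per_collect = n_sample * unroll_len    # 1344 env timesteps gathered per collect phase
# With the R2D2 rule of thumb of roughly one update per 400 env timesteps,
# 1344 / 400 is about 3.4, the same order as the update_per_collect=5 chosen here.
print(unroll_len, env_steps_per_collect)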
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_r2d2_gtrxl_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_r2d2_gtrxl_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a2cdac0202d533286c55a4e14a84d826f8f29d5
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_r2d2_gtrxl_config.py
@@ -0,0 +1,78 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 5
+cartpole_r2d2_gtrxl_config = dict(
+ exp_name='cartpole_r2d2_gtrxl_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+ priority=False,
+ priority_IS_weight=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ memory_len=5, # length of transformer memory (can be 0)
+ hidden_size=256,
+ gru_bias=2.,
+ att_layer_num=3,
+ dropout=0.,
+ att_head_num=8,
+ ),
+ discount_factor=0.99,
+ nstep=3,
+        burnin_step=4,  # how many steps are used to initialize the memory (can be 0)
+        unroll_len=11,  # trajectory length
+        seq_len=8,  # length of the transformer input segment
+        # training sequence length = unroll_len - burnin_step - nstep
+ learn=dict(
+ update_per_collect=8,
+ batch_size=64,
+ learning_rate=0.0005,
+ target_update_freq=500,
+ value_rescale=True,
+ init_memory='old', # 'zero' or 'old', how to initialize the memory
+ ),
+ collect=dict(
+            # NOTE: It is important to set the key traj_len_inf=True here,
+            # to make sure self._traj_len=INF in serial_sample_collector.py.
+            # In the R2D2 policy, for each collect_env, we want to collect data of length self._traj_len=INF
+            # unless the episode enters the 'done' state.
+            # In each collect phase, we collect a total of n_sample sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=20)),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=10000,
+ ), replay_buffer=dict(replay_buffer_size=100000, )
+ ),
+ ),
+)
+cartpole_r2d2_gtrxl_config = EasyDict(cartpole_r2d2_gtrxl_config)
+main_config = cartpole_r2d2_gtrxl_config
+cartpole_r2d2_gtrxl_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='r2d2_gtrxl'),
+)
+cartpole_r2d2_gtrxl_create_config = EasyDict(cartpole_r2d2_gtrxl_create_config)
+create_config = cartpole_r2d2_gtrxl_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c cartpole_r2d2_gtrxl_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
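Following the 'training sequence' comment in this config, the effective number of trained timesteps per sample can be worked out directly from the values above (an arithmetic sketch only; the policy internals may book-keep this differently):

unroll_len, burnin_step, nstep, seq_len, memory_len = 11, 4, 3, 8, 5
training_len = unroll_len - burnin_step - nstep   # 4 timesteps are actually trained on per sample
print(training_len, seq_len, memory_len)          # seq_len / memory_len are model-side transformer settings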
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_r2d2_residual_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_r2d2_residual_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed2b72bab3e0eab2cec7cb947df97e4eac95f478
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_r2d2_residual_config.py
@@ -0,0 +1,80 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 8
+cartpole_r2d2_residual_config = dict(
+ exp_name='cartpole_r2d2_residual_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+ priority=False,
+ priority_IS_weight=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ res_link=True,
+ ),
+ discount_factor=0.997,
+ nstep=5,
+ burnin_step=10,
+        # (int) learn_unroll_len is the total length of [sequence sample] minus
+        # the length of the burnin part in [sequence sample],
+        # i.e., sequence sample length = unroll_len = burnin_step + learn_unroll_len
+ learn_unroll_len=20, # set this key according to the episode length
+ learn=dict(
+            # according to the R2D2 paper, the actor parameter update interval is 400
+            # environment timesteps, and in each collect phase we collect n_sample sequence
+            # samples, the length of each sequence sample being burnin_step + learn_unroll_len,
+            # e.g. if n_sample=32 and the sequence length is 100, then 32*100/400=8,
+            # so we set update_per_collect=8 in most environments.
+ update_per_collect=8,
+ batch_size=64,
+ learning_rate=0.0005,
+ # according to the R2D2 paper, the target network update interval is 2500
+ target_update_freq=2500,
+ ),
+ collect=dict(
+            # NOTE: It is important to set the key traj_len_inf=True here,
+            # to make sure self._traj_len=INF in serial_sample_collector.py.
+            # In a sequence-based policy, for each collect_env,
+            # we want to collect data of length self._traj_len=INF
+            # unless the episode enters the 'done' state.
+            # In each collect phase, we collect a total of n_sample sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=20)),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=10000,
+ ), replay_buffer=dict(replay_buffer_size=100000, )
+ ),
+ ),
+)
+cartpole_r2d2_residual_config = EasyDict(cartpole_r2d2_residual_config)
+main_config = cartpole_r2d2_residual_config
+cartpole_r2d2_residual_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='r2d2'),
+)
+cartpole_r2d2_residual_create_config = EasyDict(cartpole_r2d2_residual_create_config)
+create_config = cartpole_r2d2_residual_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c cartpole_r2d2_residual_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_rainbow_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_rainbow_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..a678022049c375d1b9a5613e66e251704e146534
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_rainbow_config.py
@@ -0,0 +1,58 @@
+from easydict import EasyDict
+
+cartpole_rainbow_config = dict(
+ exp_name='cartpole_rainbow_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+ priority=True,
+ discount_factor=0.97,
+ nstep=3,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ ),
+ learn=dict(
+ update_per_collect=3,
+ batch_size=64,
+ learning_rate=0.001,
+ target_update_freq=100,
+ ),
+ collect=dict(
+ n_sample=80,
+ unroll_len=1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ), replay_buffer=dict(replay_buffer_size=20000, )
+ ),
+ ),
+)
+cartpole_rainbow_config = EasyDict(cartpole_rainbow_config)
+main_config = cartpole_rainbow_config
+cartpole_rainbow_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='rainbow'),
+)
+cartpole_rainbow_create_config = EasyDict(cartpole_rainbow_create_config)
+create_config = cartpole_rainbow_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c cartpole_rainbow_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_rnd_onppo_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_rnd_onppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed4e89560ace3383f6665eb2564f805ec9c23fb1
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_rnd_onppo_config.py
@@ -0,0 +1,70 @@
+from easydict import EasyDict
+
+cartpole_ppo_rnd_config = dict(
+ exp_name='cartpole_ppo_rnd_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ reward_model=dict(
+ intrinsic_reward_type='add',
+ intrinsic_reward_weight=0.001,
+        # The rescale value of the RND intrinsic reward is only used when intrinsic_reward_weight is None;
+        # please refer to rnd_reward_model for details.
+ learning_rate=5e-4,
+ obs_shape=4,
+ batch_size=32,
+ update_per_collect=4,
+ obs_norm=True,
+ obs_norm_clamp_min=-1,
+ obs_norm_clamp_max=1,
+ clear_buffer_per_iters=10,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[64, 64, 128],
+ critic_head_hidden_size=128,
+ actor_head_hidden_size=128,
+ ),
+ learn=dict(
+ update_per_collect=6,
+ batch_size=64,
+ learning_rate=0.001,
+ value_weight=0.5,
+ entropy_weight=0.01,
+ clip_ratio=0.2,
+ ),
+ collect=dict(
+ n_sample=128,
+ unroll_len=1,
+ discount_factor=0.9,
+ gae_lambda=0.95,
+ ),
+ eval=dict(evaluator=dict(eval_freq=100))
+ ),
+)
+cartpole_ppo_rnd_config = EasyDict(cartpole_ppo_rnd_config)
+main_config = cartpole_ppo_rnd_config
+cartpole_ppo_rnd_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='ppo_offpolicy'),
+ reward_model=dict(type='rnd'),
+)
+cartpole_ppo_rnd_create_config = EasyDict(cartpole_ppo_rnd_create_config)
+create_config = cartpole_ppo_rnd_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_reward_model_onpolicy
+ serial_pipeline_reward_model_onpolicy((main_config, create_config), seed=0)
+ # you can use the following pipeline to execute pure PPO
+ # from ding.entry import serial_pipeline_onpolicy
+ # serial_pipeline_onpolicy((main_config, create_config), seed=0)
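Because intrinsic_reward_type='add' above, the RND bonus is mixed additively into the environment reward, scaled by intrinsic_reward_weight. The helper below is only an illustrative sketch of that combination (its name and exact scaling are assumptions; see DI-engine's rnd_reward_model for the real implementation):

def combine_rewards(extrinsic_reward: float, rnd_bonus: float, intrinsic_reward_weight: float = 0.001) -> float:
    # 'add' type: the weighted intrinsic bonus is added on top of the env reward
    return extrinsic_reward + intrinsic_reward_weight * rnd_bonus

print(combine_rewards(1.0, 0.5))  # e.g. a CartPole step reward of 1 plus a small RND bonus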
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_sac_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_sac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..74c528157703ad22b273799924bdc6ce1e010e63
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_sac_config.py
@@ -0,0 +1,72 @@
+from easydict import EasyDict
+
+cartpole_sac_config = dict(
+ exp_name='cartpole_sac_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+ random_collect_size=0,
+ multi_agent=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ twin_critic=True,
+ actor_head_hidden_size=64,
+ critic_head_hidden_size=64,
+ ),
+ learn=dict(
+ update_per_collect=2,
+ batch_size=64,
+ learning_rate_q=5e-3,
+ learning_rate_policy=5e-3,
+ learning_rate_alpha=3e-4,
+ ignore_done=False,
+ target_theta=0.01,
+ discount_factor=0.99,
+ alpha=0.2,
+ auto_alpha=False,
+ ),
+ collect=dict(
+ env_num=8,
+ n_sample=256,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(
+ evaluator=dict(eval_freq=50, ),
+ env_num=5,
+ ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=50000,
+ ), replay_buffer=dict(replay_buffer_size=100000, )
+ ),
+ ),
+)
+
+cartpole_sac_config = EasyDict(cartpole_sac_config)
+main_config = cartpole_sac_config
+
+cartpole_sac_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='discrete_sac'),
+)
+cartpole_sac_create_config = EasyDict(cartpole_sac_create_config)
+create_config = cartpole_sac_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c cartpole_sac_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_sqil_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_sqil_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e4553f72952149a80536df9263e9d5d6c7a6543
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_sqil_config.py
@@ -0,0 +1,63 @@
+from easydict import EasyDict
+
+cartpole_sqil_config = dict(
+ exp_name='cartpole_sqil_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ dueling=True,
+ ),
+ nstep=1,
+ discount_factor=0.97,
+ learn=dict(batch_size=64, learning_rate=0.001, alpha=0.12),
+ collect=dict(
+ n_sample=8,
+            # Users should add their own model path here. The path should point to a trained model checkpoint.
+            # An absolute path is recommended.
+            # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ model_path='cartpole_dqn_seed0/ckpt/eval.pth.tar'
+ ),
+        # note: the policy is evaluated every eval_freq training iterations
+ eval=dict(evaluator=dict(eval_freq=50, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=20000, ),
+ ),
+ ),
+)
+cartpole_sqil_config = EasyDict(cartpole_sqil_config)
+main_config = cartpole_sqil_config
+cartpole_sqil_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='sql'),
+)
+cartpole_sqil_create_config = EasyDict(cartpole_sqil_create_config)
+create_config = cartpole_sqil_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial_sqil -c cartpole_sqil_config.py -s 0`
+ # then input the config you used to generate your expert model in the path mentioned above
+ # e.g. spaceinvaders_dqn_config.py
+ from ding.entry import serial_pipeline_sqil
+ from dizoo.classic_control.cartpole.config import cartpole_dqn_config, cartpole_dqn_create_config
+ expert_main_config = cartpole_dqn_config
+ expert_create_config = cartpole_dqn_create_config
+ serial_pipeline_sqil((main_config, create_config), (expert_main_config, expert_create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_sql_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_sql_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..bcdbee32ed1d02a69b29184a9dd8506f19cdc655
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_sql_config.py
@@ -0,0 +1,51 @@
+from easydict import EasyDict
+
+cartpole_sql_config = dict(
+ exp_name='cartpole_sql_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ dueling=True,
+ ),
+ nstep=1,
+ discount_factor=0.97,
+ learn=dict(batch_size=64, learning_rate=0.001, alpha=0.12),
+ collect=dict(n_sample=8),
+        eval=dict(evaluator=dict(eval_freq=50, )),  # note: the policy is evaluated every eval_freq training iterations
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=20000, ),
+ ),
+ ),
+)
+cartpole_sql_config = EasyDict(cartpole_sql_config)
+main_config = cartpole_sql_config
+cartpole_sql_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='sql'),
+)
+cartpole_sql_create_config = EasyDict(cartpole_sql_create_config)
+create_config = cartpole_sql_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c cartpole_sql_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_sqn_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_sqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa779ab785af0b072c83d0b649c31938561a6459
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_sqn_config.py
@@ -0,0 +1,61 @@
+from easydict import EasyDict
+
+update_per_collect = 8
+cartpole_sqn_config = dict(
+ exp_name='cartpole_sqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[64, 64],
+ # Whether to use dueling head.
+ dueling=True,
+ ),
+ learn=dict(
+ multi_gpu=False,
+ update_per_collect=update_per_collect,
+ batch_size=64,
+ learning_rate_q=0.001,
+ learning_rate_alpha=0.001,
+ alpha=0.2,
+ target_entropy=0.2,
+ ),
+ collect=dict(
+ n_sample=update_per_collect * 2,
+ nstep=1,
+ ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.8,
+ decay=2000,
+ ), replay_buffer=dict(replay_buffer_size=10000, )
+ ),
+ )
+)
+cartpole_sqn_config = EasyDict(cartpole_sqn_config)
+main_config = cartpole_sqn_config
+
+cartpole_sqn_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='sqn'),
+)
+cartpole_sqn_create_config = EasyDict(cartpole_sqn_create_config)
+create_config = cartpole_sqn_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c cartpole_sqn_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_trex_dqn_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_trex_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..306cadd6f2ca6fbaed286a21d229a66424543acf
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_trex_dqn_config.py
@@ -0,0 +1,80 @@
+from easydict import EasyDict
+
+cartpole_trex_dqn_config = dict(
+ exp_name='cartpole_trex_dqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ reward_model=dict(
+ type='trex',
+ min_snippet_length=5,
+ max_snippet_length=100,
+ checkpoint_min=0,
+ checkpoint_max=500,
+ checkpoint_step=100,
+ learning_rate=1e-5,
+ update_per_collect=1,
+ num_trajs=6,
+ num_snippets=6000,
+ expert_model_path='cartpole_dqn_seed0', # expert model experiment directory path
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ dueling=True,
+ ),
+ nstep=1,
+ discount_factor=0.97,
+ learn=dict(
+ update_per_collect=5,
+ batch_size=64,
+ learning_rate=0.001,
+ ),
+ collect=dict(n_sample=8),
+ eval=dict(evaluator=dict(eval_freq=40, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=20000, ),
+ ),
+ ),
+)
+cartpole_trex_dqn_config = EasyDict(cartpole_trex_dqn_config)
+main_config = cartpole_trex_dqn_config
+cartpole_trex_dqn_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='dqn'),
+)
+cartpole_trex_dqn_create_config = EasyDict(cartpole_trex_dqn_create_config)
+create_config = cartpole_trex_dqn_create_config
+
+if __name__ == "__main__":
+ # Users should first run ``cartpole_dqn_config.py`` to save models (or checkpoints).
+    # Note: Users should check that the generated checkpoints include iteration_'checkpoint_min'.pth.tar through
+    # iteration_'checkpoint_max'.pth.tar at intervals of checkpoint_step,
+    # where checkpoint_max, checkpoint_min and checkpoint_step are specified above.
+ import argparse
+ import torch
+ from ding.entry import trex_collecting_data
+ from ding.entry import serial_pipeline_reward_model_trex
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
+ parser.add_argument('--seed', type=int, default=0)
+ parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
+ args = parser.parse_args()
+ # The function ``trex_collecting_data`` below is to collect episodic data for training the reward model in trex.
+ trex_collecting_data(args)
+ serial_pipeline_reward_model_trex((main_config, create_config))
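The checkpoint note in the __main__ block above can be made concrete: with checkpoint_min=0, checkpoint_max=500 and checkpoint_step=100, six checkpoints are expected. The listing below assumes they live under the expert experiment directory's ckpt/ subfolder, as in the other configs in this directory; adjust the prefix if your layout differs:

checkpoint_min, checkpoint_max, checkpoint_step = 0, 500, 100
expected = [
    'cartpole_dqn_seed0/ckpt/iteration_{}.pth.tar'.format(i)
    for i in range(checkpoint_min, checkpoint_max + 1, checkpoint_step)
]
print(expected)  # iteration_0 ... iteration_500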
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_trex_offppo_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_trex_offppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b58535f900ba02eddc6811cb65ad7ead898d706c
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_trex_offppo_config.py
@@ -0,0 +1,80 @@
+from easydict import EasyDict
+
+cartpole_trex_offppo_config = dict(
+ exp_name='cartpole_trex_offppo_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ reward_model=dict(
+ type='trex',
+ min_snippet_length=5,
+ max_snippet_length=100,
+ checkpoint_min=0,
+ checkpoint_max=100,
+ checkpoint_step=100,
+ learning_rate=1e-5,
+ update_per_collect=1,
+ expert_model_path='abs model path',
+ reward_model_path='abs data path + ./cartpole.params',
+ data_path='abs data path',
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[64, 64, 128],
+ critic_head_hidden_size=128,
+ actor_head_hidden_size=128,
+ critic_head_layer_num=1,
+ ),
+ learn=dict(
+ update_per_collect=6,
+ batch_size=64,
+ learning_rate=0.001,
+ value_weight=0.5,
+ entropy_weight=0.01,
+ clip_ratio=0.2,
+ ),
+ collect=dict(
+ n_sample=128,
+ unroll_len=1,
+ discount_factor=0.9,
+ gae_lambda=0.95,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=5000))
+ ),
+)
+cartpole_trex_offppo_config = EasyDict(cartpole_trex_offppo_config)
+main_config = cartpole_trex_offppo_config
+cartpole_trex_offppo_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='ppo_offpolicy'),
+ reward_model=dict(type='trex'),
+)
+cartpole_trex_offppo_create_config = EasyDict(cartpole_trex_offppo_create_config)
+create_config = cartpole_trex_offppo_create_config
+
+if __name__ == "__main__":
+ # Users should first run ``cartpole_offppo_config.py`` to save models (or checkpoints).
+    # Note: Users should check that the generated checkpoints include iteration_'checkpoint_min'.pth.tar through
+    # iteration_'checkpoint_max'.pth.tar at intervals of checkpoint_step,
+    # where checkpoint_max, checkpoint_min and checkpoint_step are specified above.
+ import argparse
+ import torch
+ from ding.entry import trex_collecting_data
+ from ding.entry import serial_pipeline_reward_model_trex
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
+ parser.add_argument('--seed', type=int, default=0)
+ parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
+ args = parser.parse_args()
+ # The function ``trex_collecting_data`` below is to collect episodic data for training the reward model in trex.
+ trex_collecting_data(args)
+ serial_pipeline_reward_model_trex((main_config, create_config))
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/cartpole_trex_onppo_config.py b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_trex_onppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..71b4d4a1361cff45d525717ec0ff0fb0784a8b75
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/cartpole_trex_onppo_config.py
@@ -0,0 +1,81 @@
+from easydict import EasyDict
+
+cartpole_trex_ppo_onpolicy_config = dict(
+ exp_name='cartpole_trex_onppo_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ reward_model=dict(
+ type='trex',
+ min_snippet_length=5,
+ max_snippet_length=100,
+ checkpoint_min=0,
+ checkpoint_max=100,
+ checkpoint_step=100,
+ learning_rate=1e-5,
+ update_per_collect=1,
+ expert_model_path='abs model path',
+ reward_model_path='abs data path + ./cartpole.params',
+ data_path='abs data path',
+ ),
+ policy=dict(
+ cuda=False,
+ continuous=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[64, 64, 128],
+ critic_head_hidden_size=128,
+ actor_head_hidden_size=128,
+ ),
+ learn=dict(
+ epoch_per_collect=2,
+ batch_size=64,
+ learning_rate=0.001,
+ value_weight=0.5,
+ entropy_weight=0.01,
+ clip_ratio=0.2,
+ learner=dict(hook=dict(save_ckpt_after_iter=1000)),
+ ),
+ collect=dict(
+ n_sample=256,
+ unroll_len=1,
+ discount_factor=0.9,
+ gae_lambda=0.95,
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ ),
+)
+cartpole_trex_ppo_onpolicy_config = EasyDict(cartpole_trex_ppo_onpolicy_config)
+main_config = cartpole_trex_ppo_onpolicy_config
+cartpole_trex_ppo_onpolicy_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='ppo'),
+ reward_model=dict(type='trex'),
+)
+cartpole_trex_ppo_onpolicy_create_config = EasyDict(cartpole_trex_ppo_onpolicy_create_config)
+create_config = cartpole_trex_ppo_onpolicy_create_config
+
+if __name__ == "__main__":
+ # Users should first run ``cartpole_onppo_config.py`` to save models (or checkpoints).
+    # Note: Users should check that the generated checkpoints include iteration_'checkpoint_min'.pth.tar through
+    # iteration_'checkpoint_max'.pth.tar at intervals of checkpoint_step,
+    # where checkpoint_max, checkpoint_min and checkpoint_step are specified above.
+ import argparse
+ import torch
+ from ding.entry import trex_collecting_data
+ from ding.entry import serial_pipeline_reward_model_trex_onpolicy
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
+ parser.add_argument('--seed', type=int, default=0)
+ parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
+ args = parser.parse_args()
+ # The function ``trex_collecting_data`` below is to collect episodic data for training the reward model in trex.
+ trex_collecting_data(args)
+ serial_pipeline_reward_model_trex_onpolicy((main_config, create_config))
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/parallel/__init__.py b/DI-engine/dizoo/classic_control/cartpole/config/parallel/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..848edf9e6bf9f5c3561612d5b489c198a56f35b1
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/parallel/__init__.py
@@ -0,0 +1 @@
+from .cartpole_dqn_config import cartpole_dqn_config, cartpole_dqn_create_config, cartpole_dqn_system_config
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/parallel/cartpole_dqn_config.py b/DI-engine/dizoo/classic_control/cartpole/config/parallel/cartpole_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a64b72931bbc293d0f524b1900fb06ffc5c3099
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/parallel/cartpole_dqn_config.py
@@ -0,0 +1,101 @@
+from easydict import EasyDict
+
+cartpole_dqn_config = dict(
+ exp_name='cartpole_dqn',
+ env=dict(
+ collector_env_num=8,
+ collector_episode_num=2,
+ evaluator_env_num=5,
+ evaluator_episode_num=1,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ dueling=True,
+ ),
+ nstep=3,
+ discount_factor=0.97,
+ learn=dict(
+ batch_size=32,
+ learning_rate=0.001,
+ learner=dict(
+ learner_num=1,
+ send_policy_freq=1,
+ ),
+ ),
+ collect=dict(
+ n_sample=16,
+ collector=dict(
+ collector_num=2,
+ update_policy_second=3,
+ ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=50, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=100000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=100000,
+ enable_track_used_data=False,
+ ),
+ commander=dict(
+ collector_task_space=2,
+ learner_task_space=1,
+ eval_interval=5,
+ ),
+ ),
+ ),
+)
+cartpole_dqn_config = EasyDict(cartpole_dqn_config)
+main_config = cartpole_dqn_config
+
+cartpole_dqn_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='dqn_command'),
+ learner=dict(type='base', import_names=['ding.worker.learner.base_learner']),
+ collector=dict(
+ type='zergling',
+ import_names=['ding.worker.collector.zergling_parallel_collector'],
+ ),
+ commander=dict(
+ type='solo',
+ import_names=['ding.worker.coordinator.solo_parallel_commander'],
+ ),
+ comm_learner=dict(
+ type='flask_fs',
+ import_names=['ding.worker.learner.comm.flask_fs_learner'],
+ ),
+ comm_collector=dict(
+ type='flask_fs',
+ import_names=['ding.worker.collector.comm.flask_fs_collector'],
+ ),
+)
+cartpole_dqn_create_config = EasyDict(cartpole_dqn_create_config)
+create_config = cartpole_dqn_create_config
+
+cartpole_dqn_system_config = dict(
+ coordinator=dict(),
+ path_data='./{}/data'.format(main_config.exp_name),
+ path_policy='./{}/policy'.format(main_config.exp_name),
+ communication_mode='auto',
+ learner_gpu_num=1,
+)
+cartpole_dqn_system_config = EasyDict(cartpole_dqn_system_config)
+system_config = cartpole_dqn_system_config
+
+if __name__ == '__main__':
+ from ding.entry.parallel_entry import parallel_pipeline
+
+ parallel_pipeline((main_config, create_config, system_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/parallel/cartpole_dqn_config_k8s.py b/DI-engine/dizoo/classic_control/cartpole/config/parallel/cartpole_dqn_config_k8s.py
new file mode 100644
index 0000000000000000000000000000000000000000..21e5c461abebfacffb382aa9460881847da74ea6
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/parallel/cartpole_dqn_config_k8s.py
@@ -0,0 +1,113 @@
+from easydict import EasyDict
+
+cartpole_dqn_config = dict(
+ exp_name='cartpole_dqn',
+ env=dict(
+ collector_env_num=8,
+ collector_episode_num=2,
+ evaluator_env_num=5,
+ evaluator_episode_num=1,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=4,
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 64],
+ dueling=True,
+ ),
+ nstep=3,
+ discount_factor=0.97,
+ learn=dict(
+ batch_size=32,
+ learning_rate=0.001,
+ learner=dict(
+ learner_num=1,
+ send_policy_freq=1,
+ ),
+ ),
+ collect=dict(
+ n_sample=16,
+ collector=dict(
+ collector_num=2,
+ update_policy_second=3,
+ ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=50, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=100000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=100000,
+ enable_track_used_data=False,
+ ),
+ commander=dict(
+                # the collector task space is increased once resources are obtained from the server
+ collector_task_space=0,
+ learner_task_space=1,
+ eval_interval=5,
+ ),
+ ),
+ ),
+)
+cartpole_dqn_config = EasyDict(cartpole_dqn_config)
+main_config = cartpole_dqn_config
+
+cartpole_dqn_create_config = dict(
+ env=dict(
+ type='cartpole',
+ import_names=['dizoo.classic_control.cartpole.envs.cartpole_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='dqn_command'),
+ learner=dict(type='base', import_names=['ding.worker.learner.base_learner']),
+ collector=dict(
+ type='zergling',
+ import_names=['ding.worker.collector.zergling_parallel_collector'],
+ ),
+ commander=dict(
+ type='solo',
+ import_names=['ding.worker.coordinator.solo_parallel_commander'],
+ ),
+ comm_learner=dict(
+ type='flask_fs',
+ import_names=['ding.worker.learner.comm.flask_fs_learner'],
+ ),
+ comm_collector=dict(
+ type='flask_fs',
+ import_names=['ding.worker.collector.comm.flask_fs_collector'],
+ ),
+)
+cartpole_dqn_create_config = EasyDict(cartpole_dqn_create_config)
+create_config = cartpole_dqn_create_config
+
+cartpole_dqn_system_config = dict(
+ coordinator=dict(
+ operator_server=dict(
+ system_addr='di-server.di-system:8080',
+ api_version='/v1alpha1',
+ init_replicas_request=dict(
+ collectors={
+ "replicas": 2,
+ },
+ learners={
+ "gpus": "0",
+ "replicas": 1,
+ },
+ ),
+ collector_target_num=2,
+ learner_target_num=1,
+ ),
+ ),
+ path_data='./{}/data'.format(main_config.exp_name),
+ path_policy='./{}/policy'.format(main_config.exp_name),
+ communication_mode='auto',
+ learner_gpu_num=1,
+)
+cartpole_dqn_system_config = EasyDict(cartpole_dqn_system_config)
+system_config = cartpole_dqn_system_config
diff --git a/DI-engine/dizoo/classic_control/cartpole/config/parallel/cartpole_dqn_dist.sh b/DI-engine/dizoo/classic_control/cartpole/config/parallel/cartpole_dqn_dist.sh
new file mode 100755
index 0000000000000000000000000000000000000000..553b9161ed61b77ecebe0fd3c90efb990428782f
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/config/parallel/cartpole_dqn_dist.sh
@@ -0,0 +1,15 @@
+export PYTHONUNBUFFERED=1
+ding -m dist --module config -p slurm -c cartpole_dqn_config.py -s 0 -lh SH-IDC1-10-5-36-161 -clh SH-IDC1-10-5-36-140
+
+srun -p VI_SP_Y_V100_A -w SH-IDC1-10-5-36-161 --gres=gpu:1 ding -m dist --module learner --module-name learner0 -c cartpole_dqn_config.py.pkl -s 0 &
+srun -p VI_SP_Y_V100_A -w SH-IDC1-10-5-36-140 ding -m dist --module collector --module-name collector0 -c cartpole_dqn_config.py.pkl -s 0 &
+srun -p VI_SP_Y_V100_A -w SH-IDC1-10-5-36-140 ding -m dist --module collector --module-name collector1 -c cartpole_dqn_config.py.pkl -s 0 &
+
+ding -m dist --module coordinator -p slurm -c cartpole_dqn_config.py.pkl -s 0
+
+# the following commands are for local testing
+# ding -m dist --module config -p local -c cartpole_dqn_config.py -s 0
+# ding -m dist --module learner --module-name learner0 -c cartpole_dqn_config.py.pkl -s 0 &
+# ding -m dist --module collector --module-name collector0 -c cartpole_dqn_config.py.pkl -s 0 &
+# ding -m dist --module collector --module-name collector1 -c cartpole_dqn_config.py.pkl -s 0 &
+# ding -m dist --module coordinator -p local -c cartpole_dqn_config.py.pkl -s 0
diff --git a/DI-engine/dizoo/classic_control/cartpole/entry/__init__.py b/DI-engine/dizoo/classic_control/cartpole/entry/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_c51_deploy.py b/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_c51_deploy.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf1025dd0443e9fbfbe52d9d0f589c4f24c5241c
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_c51_deploy.py
@@ -0,0 +1,33 @@
+import gym
+import torch
+from easydict import EasyDict
+from ding.config import compile_config
+from ding.envs import DingEnvWrapper
+from ding.policy import C51Policy, single_env_forward_wrapper
+from ding.model import C51DQN
+from dizoo.classic_control.cartpole.config.cartpole_c51_config import cartpole_c51_config, cartpole_c51_create_config
+
+
+def main(main_config: EasyDict, create_config: EasyDict, ckpt_path: str):
+ main_config.exp_name = 'cartpole_c51_deploy'
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ env = DingEnvWrapper(gym.make('CartPole-v0'), EasyDict(env_wrapper='default'))
+ model = C51DQN(**cfg.policy.model)
+ state_dict = torch.load(ckpt_path, map_location='cpu')
+ model.load_state_dict(state_dict['model'])
+ policy = C51Policy(cfg.policy, model=model).eval_mode
+ forward_fn = single_env_forward_wrapper(policy.forward)
+
+ obs = env.reset()
+ returns = 0.
+ while True:
+ action = forward_fn(obs)
+ obs, rew, done, info = env.step(action)
+ returns += rew
+ if done:
+ break
+    print(f'Deploy is finished, final episode return is: {returns}')
+
+
+if __name__ == "__main__":
+ main(cartpole_c51_config, cartpole_c51_create_config, 'cartpole_c51_seed0/ckpt/ckpt_best.pth.tar')
diff --git a/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_c51_main.py b/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_c51_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..29bc6f9ee485e2449cbe26bcf321b2624582e7c1
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_c51_main.py
@@ -0,0 +1,84 @@
+import os
+import gym
+from tensorboardX import SummaryWriter
+from easydict import EasyDict
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.envs import BaseEnvManager, DingEnvWrapper
+from ding.policy import C51Policy
+from ding.model import C51DQN
+from ding.utils import set_pkg_seed
+from ding.rl_utils import get_epsilon_greedy_fn
+from dizoo.classic_control.cartpole.config.cartpole_c51_config import cartpole_c51_config
+
+
+# Get the DI-engine form of the env class
+def wrapped_cartpole_env():
+ return DingEnvWrapper(
+ gym.make('CartPole-v0'),
+ EasyDict(env_wrapper='default'),
+ )
+
+
+def main(cfg, seed=0):
+ cfg = compile_config(
+ cfg,
+ BaseEnvManager,
+ C51Policy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator,
+ AdvancedReplayBuffer,
+ save_cfg=True
+ )
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+ collector_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(collector_env_num)], cfg=cfg.env.manager)
+ evaluator_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(evaluator_env_num)], cfg=cfg.env.manager)
+
+ # Set random seed for all package and instance
+ collector_env.seed(seed)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ # Set up RL Policy
+ model = C51DQN(**cfg.policy.model)
+ policy = C51Policy(cfg.policy, model=model)
+
+ # Set up collection, training and evaluation utilities
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = SampleSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
+
+    # Set up other modules, e.g. epsilon greedy
+ eps_cfg = cfg.policy.other.eps
+ epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
+
+ # Training & Evaluation loop
+ while True:
+ # Evaluating at the beginning and with specific frequency
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Update other modules
+ eps = epsilon_greedy(collector.envstep)
+ # Sampling data from environments
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ # Training
+ for i in range(cfg.policy.learn.update_per_collect):
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is None:
+ break
+ learner.train(train_data, collector.envstep)
+
+
+if __name__ == "__main__":
+ main(cartpole_c51_config)
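The epsilon-greedy schedule built by get_epsilon_greedy_fn above decays epsilon from eps_cfg.start to eps_cfg.end as the collector's env step grows. A rough sketch of what an exponential ('exp') schedule of this kind computes, with illustrative values in the same range as the configs in this directory (the exact formula is an assumption, not a copy of DI-engine's implementation):

import math

def exp_epsilon(step: int, start: float = 0.95, end: float = 0.1, decay: int = 10000) -> float:
    # exponential interpolation from start down to end over roughly `decay` env steps
    return end + (start - end) * math.exp(-step / decay)

for step in (0, 5000, 20000, 100000):
    print(step, round(exp_epsilon(step), 3))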
diff --git a/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_cql_main.py b/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_cql_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..311a2c7c11cff644f4ec30b5ab1a69c82e533e75
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_cql_main.py
@@ -0,0 +1,58 @@
+import torch
+from copy import deepcopy
+
+from ding.entry import serial_pipeline_offline, collect_demo_data, eval, serial_pipeline
+
+
+def train_cql(args):
+ from dizoo.classic_control.cartpole.config.cartpole_cql_config import main_config, create_config
+ main_config.exp_name = 'cartpole_cql'
+ main_config.policy.collect.data_path = './cartpole/expert_demos.hdf5'
+ main_config.policy.collect.data_type = 'hdf5'
+ config = deepcopy([main_config, create_config])
+ serial_pipeline_offline(config, seed=args.seed)
+
+
+def eval_ckpt(args):
+ from dizoo.classic_control.cartpole.config.cartpole_qrdqn_config import main_config, create_config
+ main_config, create_config = deepcopy(main_config), deepcopy(create_config)
+ main_config.exp_name = 'cartpole'
+ config = deepcopy([main_config, create_config])
+ eval(config, seed=args.seed, load_path='./cartpole/ckpt/ckpt_best.pth.tar')
+
+
+def generate(args):
+ from dizoo.classic_control.cartpole.config.cartpole_qrdqn_generation_data_config import main_config, create_config
+ main_config.exp_name = 'cartpole'
+ main_config.policy.collect.save_path = './cartpole/expert.pkl'
+ main_config.policy.collect.data_type = 'hdf5'
+ config = deepcopy([main_config, create_config])
+ state_dict = torch.load('./cartpole/ckpt/ckpt_best.pth.tar', map_location='cpu')
+ collect_demo_data(
+ config,
+ collect_count=10000,
+ seed=args.seed,
+ expert_data_path=main_config.policy.collect.save_path,
+ state_dict=state_dict
+ )
+
+
+def train_expert(args):
+ from dizoo.classic_control.cartpole.config.cartpole_qrdqn_config import main_config, create_config
+ main_config, create_config = deepcopy(main_config), deepcopy(create_config)
+ main_config.exp_name = 'cartpole'
+ config = deepcopy([main_config, create_config])
+ serial_pipeline(config, seed=args.seed)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--seed', '-s', type=int, default=10)
+ args = parser.parse_args()
+
+ train_expert(args)
+ eval_ckpt(args)
+ generate(args)
+ train_cql(args)
diff --git a/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_dqn_buffer_main.py b/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_dqn_buffer_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d225ddca077ac4f3dbd44afd66a5ee5de043408
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_dqn_buffer_main.py
@@ -0,0 +1,84 @@
+import os
+import gym
+from tensorboardX import SummaryWriter
+from easydict import EasyDict
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, DequeBufferWrapper
+from ding.envs import BaseEnvManager, DingEnvWrapper
+from ding.policy import DQNPolicy
+from ding.model import DQN
+from ding.utils import set_pkg_seed
+from ding.rl_utils import get_epsilon_greedy_fn
+from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config
+
+
+# Get the DI-engine form of the env class
+def wrapped_cartpole_env():
+ return DingEnvWrapper(
+ gym.make('CartPole-v0'),
+ EasyDict(env_wrapper='default'),
+ )
+
+
+def main(cfg, seed=0):
+ cfg = compile_config(
+ cfg,
+ BaseEnvManager,
+ DQNPolicy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator,
+ DequeBufferWrapper,
+ save_cfg=True
+ )
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+ collector_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(collector_env_num)], cfg=cfg.env.manager)
+ evaluator_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(evaluator_env_num)], cfg=cfg.env.manager)
+
+ # Set random seed for all package and instance
+ collector_env.seed(seed)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ # Set up RL Policy
+ model = DQN(**cfg.policy.model)
+ policy = DQNPolicy(cfg.policy, model=model)
+
+ # Set up collection, training and evaluation utilities
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = SampleSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = DequeBufferWrapper(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
+
+    # Set up other modules, e.g. epsilon greedy
+ eps_cfg = cfg.policy.other.eps
+ epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
+
+ # Training & Evaluation loop
+ while True:
+ # Evaluating at the beginning and with specific frequency
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Update other modules
+ eps = epsilon_greedy(collector.envstep)
+ # Sampling data from environments
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ # Training
+ for i in range(cfg.policy.learn.update_per_collect):
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is None:
+ break
+ learner.train(train_data, collector.envstep)
+
+
+if __name__ == "__main__":
+ main(cartpole_dqn_config)
diff --git a/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_dqn_eval.py b/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_dqn_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..21031a3e345459deac95683833478c66b216f6b1
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_dqn_eval.py
@@ -0,0 +1,60 @@
+import os
+import gym
+import torch
+from tensorboardX import SummaryWriter
+from easydict import EasyDict
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.envs import BaseEnvManager, DingEnvWrapper
+from ding.policy import DQNPolicy
+from ding.model import DQN
+from ding.utils import set_pkg_seed
+from ding.rl_utils import get_epsilon_greedy_fn
+from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config
+
+
+# Get the DI-engine form of the env class
+def wrapped_cartpole_env():
+ return DingEnvWrapper(
+ gym.make('CartPole-v0'),
+ EasyDict(env_wrapper='default'),
+ )
+ # from dizoo.classic_control.cartpole.envs.cartpole_env import CartPoleEnv
+ # return CartPoleEnv({})
+
+
+def main(cfg, seed=0):
+ cfg = compile_config(
+ cfg,
+ BaseEnvManager,
+ DQNPolicy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator,
+ AdvancedReplayBuffer,
+ save_cfg=True
+ )
+ evaluator_env_num = cfg.env.evaluator_env_num
+ evaluator_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(evaluator_env_num)], cfg=cfg.env.manager)
+    evaluator_env.enable_save_replay(cfg.env.replay_path)  # switch on the replay-saving interface
+
+ # Set random seed for all package and instance
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ # Set up RL Policy
+ model = DQN(**cfg.policy.model)
+ policy = DQNPolicy(cfg.policy, model=model)
+ policy.eval_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cpu'))
+
+ # evaluate
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator.eval()
+
+
+if __name__ == "__main__":
+ main(cartpole_dqn_config)
diff --git a/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_dqn_main.py b/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_dqn_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..02c153efe0ff66bf67c0e4a29e095fd59a335d67
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_dqn_main.py
@@ -0,0 +1,91 @@
+import os
+import gym
+from tensorboardX import SummaryWriter
+from easydict import EasyDict
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.envs import BaseEnvManager, DingEnvWrapper
+from ding.policy import DQNPolicy
+from ding.model import DQN
+from ding.utils import set_pkg_seed
+from ding.rl_utils import get_epsilon_greedy_fn
+from dizoo.classic_control.cartpole.config.cartpole_dqn_config import cartpole_dqn_config
+
+
+# Get the DI-engine form of the env class
+def wrapped_cartpole_env():
+ return DingEnvWrapper(
+ gym.make('CartPole-v0'),
+ EasyDict(env_wrapper='default'),
+ )
+
+
+def main(cfg, seed=0):
+ cfg = compile_config(
+ cfg,
+ BaseEnvManager,
+ DQNPolicy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator,
+ AdvancedReplayBuffer,
+ save_cfg=True
+ )
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+ collector_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(collector_env_num)], cfg=cfg.env.manager)
+ evaluator_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(evaluator_env_num)], cfg=cfg.env.manager)
+
+ # Set random seed for all package and instance
+ collector_env.seed(seed)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ # Set up RL Policy
+ model = DQN(**cfg.policy.model)
+ policy = DQNPolicy(cfg.policy, model=model)
+
+ # Set up collection, training and evaluation utilities
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = SampleSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
+
+    # Set up other modules, e.g. epsilon greedy
+ eps_cfg = cfg.policy.other.eps
+ epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
+
+ # Training & Evaluation loop
+ while True:
+ # Evaluating at the beginning and with specific frequency
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Update other modules
+ eps = epsilon_greedy(collector.envstep)
+ # Sampling data from environments
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ # Training
+ for i in range(cfg.policy.learn.update_per_collect):
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is None:
+ break
+ learner.train(train_data, collector.envstep)
+ # evaluate
+ evaluator_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(evaluator_env_num)], cfg=cfg.env.manager)
+    evaluator_env.enable_save_replay(cfg.env.replay_path)  # switch on the replay-saving interface
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+
+
+if __name__ == "__main__":
+ main(cartpole_dqn_config)
diff --git a/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_fqf_main.py b/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_fqf_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..1a6508641661da7e388c4ced9ad5ec22a74bb629
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_fqf_main.py
@@ -0,0 +1,92 @@
+import sys
+import os
+import gym
+from tensorboardX import SummaryWriter
+from easydict import EasyDict
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.envs import BaseEnvManager, DingEnvWrapper
+from ding.policy import FQFPolicy
+from ding.model import FQF
+from ding.utils import set_pkg_seed
+from ding.rl_utils import get_epsilon_greedy_fn
+from dizoo.classic_control.cartpole.config.cartpole_fqf_config import cartpole_fqf_config
+
+
+# Get the DI-engine form of the env class
+def wrapped_cartpole_env():
+ return DingEnvWrapper(
+ gym.make('CartPole-v0'),
+ EasyDict(env_wrapper='default'),
+ )
+
+
+def main(cfg, seed=0):
+ cfg = compile_config(
+ cfg,
+ BaseEnvManager,
+ FQFPolicy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator,
+ AdvancedReplayBuffer,
+ save_cfg=True
+ )
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+ collector_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(collector_env_num)], cfg=cfg.env.manager)
+ evaluator_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(evaluator_env_num)], cfg=cfg.env.manager)
+    # evaluator_env.enable_save_replay(cfg.env.replay_path)  # switch on the replay-saving interface
+ # Set random seed for all package and instance
+ collector_env.seed(seed)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ # Set up RL Policy
+ model = FQF(**cfg.policy.model)
+ policy = FQFPolicy(cfg.policy, model=model)
+
+ # Set up collection, training and evaluation utilities
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = SampleSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
+
+    # Set up other modules, e.g. epsilon greedy
+ eps_cfg = cfg.policy.other.eps
+ epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
+
+ # Training & Evaluation loop
+ while True:
+ # Evaluating at the beginning and with specific frequency
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Update other modules
+ eps = epsilon_greedy(collector.envstep)
+ # Sampling data from environments
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ # Training
+ for i in range(cfg.policy.learn.update_per_collect):
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is None:
+ break
+ learner.train(train_data, collector.envstep)
+ # evaluate
+ evaluator_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(evaluator_env_num)], cfg=cfg.env.manager)
+    evaluator_env.enable_save_replay(cfg.env.replay_path)  # switch on the replay-saving interface
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+
+
+if __name__ == "__main__":
+ main(cartpole_fqf_config)
diff --git a/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_ppg_main.py b/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_ppg_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..d845ed2ce95e4aa9939a5ae0572d0133e1cf7272
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_ppg_main.py
@@ -0,0 +1,83 @@
+import os
+import gym
+from tensorboardX import SummaryWriter
+from easydict import EasyDict
+from copy import deepcopy
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.envs import BaseEnvManager, DingEnvWrapper
+from ding.policy import PPGOffPolicy
+from ding.model import PPG
+from ding.utils import set_pkg_seed, deep_merge_dicts
+from dizoo.classic_control.cartpole.config.cartpole_ppg_config import cartpole_ppg_config
+
+
+def wrapped_cartpole_env():
+ return DingEnvWrapper(
+ gym.make('CartPole-v0'),
+ EasyDict(env_wrapper='default'),
+ )
+
+
+def main(cfg, seed=0, max_train_iter=int(1e8), max_env_step=int(1e8)):
+ cfg = compile_config(
+ cfg,
+ BaseEnvManager,
+ PPGOffPolicy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator, {
+ 'policy': AdvancedReplayBuffer,
+ 'value': AdvancedReplayBuffer
+ },
+ save_cfg=True
+ )
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+ collector_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(collector_env_num)], cfg=cfg.env.manager)
+ evaluator_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(evaluator_env_num)], cfg=cfg.env.manager)
+
+ collector_env.seed(seed)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ model = PPG(**cfg.policy.model)
+ policy = PPGOffPolicy(cfg.policy, model=model)
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = SampleSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
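+ # PPG trains from two separate buffers: one for policy-phase samples and one for
+ # value-phase samples (the same collected data is pushed to both below).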
+ policy_buffer = AdvancedReplayBuffer(
+ cfg.policy.other.replay_buffer.policy, tb_logger, exp_name=cfg.exp_name, instance_name='policy_buffer'
+ )
+ value_buffer = AdvancedReplayBuffer(
+ cfg.policy.other.replay_buffer.value, tb_logger, exp_name=cfg.exp_name, instance_name='value_buffer'
+ )
+
+ while True:
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ new_data = collector.collect(train_iter=learner.train_iter)
+ policy_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ value_buffer.push(deepcopy(new_data), cur_collector_envstep=collector.envstep)
+ for i in range(cfg.policy.learn.update_per_collect):
+ batch_size = learner.policy.get_attribute('batch_size')
+ policy_data = policy_buffer.sample(batch_size['policy'], learner.train_iter)
+ value_data = value_buffer.sample(batch_size['value'], learner.train_iter)
+ if policy_data is not None and value_data is not None:
+ train_data = {'policy': policy_data, 'value': value_data}
+ learner.train(train_data, collector.envstep)
+ policy_buffer.clear()
+ value_buffer.clear()
+ if learner.train_iter >= max_train_iter or collector.envstep >= max_env_step:
+ break
+
+
+if __name__ == "__main__":
+ main(cartpole_ppg_config)
diff --git a/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_ppo_main.py b/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_ppo_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd5a41f2612d2b93f6e961d634800d0cc52598e2
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_ppo_main.py
@@ -0,0 +1,55 @@
+import os
+import gym
+from tensorboardX import SummaryWriter
+from easydict import EasyDict
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator
+from ding.envs import BaseEnvManager, DingEnvWrapper
+from ding.policy import PPOPolicy
+from ding.model import VAC
+from ding.utils import set_pkg_seed, deep_merge_dicts
+from dizoo.classic_control.cartpole.config.cartpole_ppo_config import cartpole_ppo_config
+
+
+def wrapped_cartpole_env():
+ return DingEnvWrapper(
+ gym.make('CartPole-v0'),
+ EasyDict(env_wrapper='default'),
+ )
+
+
+def main(cfg, seed=0, max_iterations=int(1e10)):
+ cfg = compile_config(
+ cfg, BaseEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, save_cfg=True
+ )
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+ collector_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(collector_env_num)], cfg=cfg.env.manager)
+ evaluator_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(evaluator_env_num)], cfg=cfg.env.manager)
+
+ collector_env.seed(seed)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ model = VAC(**cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = SampleSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+
+ for _ in range(max_iterations):
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ new_data = collector.collect(train_iter=learner.train_iter)
+ learner.train(new_data, collector.envstep)
+
+
+if __name__ == "__main__":
+ main(cartpole_ppo_config)
diff --git a/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_ppo_offpolicy_main.py b/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_ppo_offpolicy_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..a35c1d266d61a32999f5a5c1b21080c5617696cb
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/entry/cartpole_ppo_offpolicy_main.py
@@ -0,0 +1,67 @@
+import os
+import gym
+from tensorboardX import SummaryWriter
+from easydict import EasyDict
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, NaiveReplayBuffer
+from ding.envs import BaseEnvManager, DingEnvWrapper
+from ding.policy import PPOOffPolicy
+from ding.model import VAC
+from ding.utils import set_pkg_seed, deep_merge_dicts
+from dizoo.classic_control.cartpole.config.cartpole_offppo_config import cartpole_offppo_config
+
+
+def wrapped_cartpole_env():
+ return DingEnvWrapper(
+ gym.make('CartPole-v0'),
+ EasyDict(env_wrapper='default'),
+ )
+
+
+def main(cfg, seed=0, max_iterations=int(1e10)):
+ cfg = compile_config(
+ cfg,
+ BaseEnvManager,
+ PPOOffPolicy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator,
+ NaiveReplayBuffer,
+ save_cfg=True
+ )
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+ collector_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(collector_env_num)], cfg=cfg.env.manager)
+ evaluator_env = BaseEnvManager(env_fn=[wrapped_cartpole_env for _ in range(evaluator_env_num)], cfg=cfg.env.manager)
+
+ collector_env.seed(seed)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ model = VAC(**cfg.policy.model)
+ policy = PPOOffPolicy(cfg.policy, model=model)
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = SampleSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = NaiveReplayBuffer(cfg.policy.other.replay_buffer, exp_name=cfg.exp_name)
+
+ for _ in range(max_iterations):
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ new_data = collector.collect(train_iter=learner.train_iter)
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ for i in range(cfg.policy.learn.update_per_collect):
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is not None:
+ learner.train(train_data, collector.envstep)
+
+
+if __name__ == "__main__":
+ main(cartpole_offppo_config)
diff --git a/DI-engine/dizoo/classic_control/cartpole/envs/__init__.py b/DI-engine/dizoo/classic_control/cartpole/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6aa0dc9b7948455c4fd09886ed4dae70b53f658a
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/envs/__init__.py
@@ -0,0 +1 @@
+from .cartpole_env import CartPoleEnv
diff --git a/DI-engine/dizoo/classic_control/cartpole/envs/cartpole_env.py b/DI-engine/dizoo/classic_control/cartpole/envs/cartpole_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..7cd36702b01ac2ce2d311c4d01136294ffd23e5c
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/envs/cartpole_env.py
@@ -0,0 +1,100 @@
+from typing import Any, List, Union, Optional
+import time
+import gym
+import copy
+import numpy as np
+from easydict import EasyDict
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.torch_utils import to_ndarray, to_list
+from ding.utils import ENV_REGISTRY
+from ding.envs import ObsPlusPrevActRewWrapper
+
+
+@ENV_REGISTRY.register('cartpole')
+class CartPoleEnv(BaseEnv):
+
+ def __init__(self, cfg: dict = {}) -> None:
+ self._cfg = cfg
+ self._init_flag = False
+ self._replay_path = None
+ self._observation_space = gym.spaces.Box(
+ low=np.array([-4.8, float("-inf"), -0.42, float("-inf")]),
+ high=np.array([4.8, float("inf"), 0.42, float("inf")]),
+ shape=(4, ),
+ dtype=np.float32
+ )
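+ # Two discrete actions in CartPole-v0: 0 pushes the cart left, 1 pushes it right.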
+ self._action_space = gym.spaces.Discrete(2)
+ self._action_space.seed(0) # default seed
+ self._reward_space = gym.spaces.Box(low=0.0, high=1.0, shape=(1, ), dtype=np.float32)
+
+ def reset(self) -> np.ndarray:
+ if not self._init_flag:
+ self._env = gym.make('CartPole-v0')
+ if self._replay_path is not None:
+ self._env = gym.wrappers.RecordVideo(
+ self._env,
+ video_folder=self._replay_path,
+ episode_trigger=lambda episode_id: True,
+ name_prefix='rl-video-{}'.format(id(self))
+ )
+ if hasattr(self._cfg, 'obs_plus_prev_action_reward') and self._cfg.obs_plus_prev_action_reward:
+ self._env = ObsPlusPrevActRewWrapper(self._env)
+ self._init_flag = True
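+ # With dynamic_seed (used for collector envs), a different derived seed is applied at
+ # every reset; evaluator envs pass dynamic_seed=False so episodes stay reproducible.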
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ self._action_space.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+ self._action_space.seed(self._seed)
+ self._observation_space = self._env.observation_space
+ self._eval_episode_return = 0
+ obs = self._env.reset()
+ obs = to_ndarray(obs)
+ return obs
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep:
+ if isinstance(action, np.ndarray) and action.shape == (1, ):
+ action = action.squeeze() # 0-dim array
+ obs, rew, done, info = self._env.step(action)
+ self._eval_episode_return += rew
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+ obs = to_ndarray(obs)
+ rew = to_ndarray([rew]).astype(np.float32) # wrap the scalar reward into an array with shape (1,)
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ if replay_path is None:
+ replay_path = './video'
+ self._replay_path = replay_path
+
+ def random_action(self) -> np.ndarray:
+ random_action = self.action_space.sample()
+ random_action = to_ndarray([random_action], dtype=np.int64)
+ return random_action
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ def __repr__(self) -> str:
+ return "DI-engine CartPole Env"
diff --git a/DI-engine/dizoo/classic_control/cartpole/envs/test_cartpole_env.py b/DI-engine/dizoo/classic_control/cartpole/envs/test_cartpole_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f9e3407b1a64ce00cfd28e586c923743290fa34
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/envs/test_cartpole_env.py
@@ -0,0 +1,35 @@
+import pytest
+import numpy as np
+from dizoo.classic_control.cartpole.envs import CartPoleEnv
+
+
+@pytest.mark.envtest
+class TestCartPoleEnv:
+
+ def test_naive(self):
+ env = CartPoleEnv({})
+ env.seed(314, dynamic_seed=False)
+ assert env._seed == 314
+ obs = env.reset()
+ assert obs.shape == (4, )
+ for _ in range(5):
+ env.reset()
+ np.random.seed(314)
+ print('=' * 60)
+ for i in range(10):
+ # Both ``env.random_action()`` and sampling directly from the action space
+ # can generate a legal random action.
+ if i < 5:
+ random_action = np.array([env.action_space.sample()])
+ else:
+ random_action = env.random_action()
+ timestep = env.step(random_action)
+ print(timestep)
+ assert isinstance(timestep.obs, np.ndarray)
+ assert isinstance(timestep.done, bool)
+ assert timestep.obs.shape == (4, )
+ assert timestep.reward.shape == (1, )
+ assert timestep.reward >= env.reward_space.low
+ assert timestep.reward <= env.reward_space.high
+ print(env.observation_space, env.action_space, env.reward_space)
+ env.close()
diff --git a/DI-engine/dizoo/classic_control/cartpole/envs/test_cartpole_env_manager.py b/DI-engine/dizoo/classic_control/cartpole/envs/test_cartpole_env_manager.py
new file mode 100644
index 0000000000000000000000000000000000000000..986bbac7e99bc322de2ef03383df24840d6977ef
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/cartpole/envs/test_cartpole_env_manager.py
@@ -0,0 +1,34 @@
+import pytest
+import numpy as np
+from ding.envs import BaseEnvManager
+from dizoo.classic_control.cartpole.envs import CartPoleEnv
+
+
+@pytest.mark.envtest
+class TestCartPoleEnv:
+
+ def test_naive(self):
+ env_num = 8
+ env = BaseEnvManager([lambda: CartPoleEnv({}) for _ in range(env_num)], BaseEnvManager.default_config())
+ env.seed(314, dynamic_seed=False)
+ env.launch()
+ for _ in range(5):
+ env.reset()
+ np.random.seed(314)
+ for i in range(10):
+ obs = env.ready_obs
+ assert len(obs) == env_num
+ random_action = {i: np.array([env.action_space.sample()]) for i in range(env_num)}
+ timesteps = env.step(random_action)
+ # print(timesteps)
+ assert isinstance(timesteps, dict)
+ # test one of timesteps
+ timestep = timesteps[0]
+ assert isinstance(timestep.obs, np.ndarray)
+ assert isinstance(timestep.done, bool)
+ assert timestep.obs.shape == (4, )
+ assert timestep.reward.shape == (1, )
+ assert timestep.reward >= env.reward_space.low
+ assert timestep.reward <= env.reward_space.high
+ print(env.observation_space, env.action_space, env.reward_space)
+ env.close()
diff --git a/DI-engine/dizoo/classic_control/mountain_car/__init__.py b/DI-engine/dizoo/classic_control/mountain_car/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/classic_control/mountain_car/config/mtcar_rainbow_config.py b/DI-engine/dizoo/classic_control/mountain_car/config/mtcar_rainbow_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b293d444945a7af9f9ad2fd9b2fa9d035dcdc8da
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/mountain_car/config/mtcar_rainbow_config.py
@@ -0,0 +1,63 @@
+from easydict import EasyDict
+
+# DI-engine uses EasyDict for configuration by convention
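+# e.g. cfg = EasyDict(dict(env=dict(collector_env_num=8))) supports both
+# cfg.env.collector_env_num and cfg['env']['collector_env_num'] access.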
+mtcar_rainbow_config = EasyDict(
+ dict(
+ exp_name='mtcar_rainbow_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=195,
+ ),
+ policy=dict(
+ cuda=False,
+ priority=True,
+ discount_factor=0.97,
+ nstep=3,
+ model=dict(
+ obs_shape=2,
+ action_shape=3,
+ encoder_hidden_size_list=[128, 128, 64],
+ ),
+ learn=dict(
+ update_per_collect=3,
+ batch_size=64,
+ learning_rate=0.001,
+ target_update_freq=100,
+ ),
+ collect=dict(
+ n_sample=80,
+ unroll_len=1,
+ ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=20000, )
+ ),
+ ),
+ )
+)
+
+main_config = mtcar_rainbow_config
+
+mtcar_rainbow_create_config = EasyDict(
+ dict(
+ env=dict(
+ type='mountain_car',
+ import_names=['dizoo.classic_control.mountain_car.envs.mtcar_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='rainbow'),
+ )
+)
+
+create_config = mtcar_rainbow_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/mountain_car/envs/__init__.py b/DI-engine/dizoo/classic_control/mountain_car/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9e8ca86d5f7e2d6aba7860428d9169e60244bf54
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/mountain_car/envs/__init__.py
@@ -0,0 +1 @@
+from .mtcar_env import MountainCarEnv
diff --git a/DI-engine/dizoo/classic_control/mountain_car/envs/mtcar_env.py b/DI-engine/dizoo/classic_control/mountain_car/envs/mtcar_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..515b4247656f792457c97f2002daf3546ef0b49d
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/mountain_car/envs/mtcar_env.py
@@ -0,0 +1,129 @@
+from typing import Any, List, Union, Optional
+import gym
+import numpy as np
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.torch_utils import to_ndarray, to_list
+from ding.utils import ENV_REGISTRY
+
+
+@ENV_REGISTRY.register('mountain_car')
+class MountainCarEnv(BaseEnv):
+ """
+ Implementation of DI-engine's version of the Mountain Car deterministic MDP.
+
+ Important references that contributed to the creation of this env:
+ > Source code of OpenAI's mountain car gym : https://is.gd/y1FkMT
+ > Gym documentation of mountain car : https://is.gd/29S0dt
+ > Based on DI-engine's existing implementation of cartpole_env.py
+ > DI-engine's env creation conventions : https://is.gd/ZHLISj
+
+ Only __init__, step, seed and reset are mandatory and important;
+ the other methods are provided for convenience.
+ """
+
+ def __init__(self, cfg: dict = {}) -> None:
+ self._cfg = cfg
+ self._init_flag = False
+ self._replay_path = None
+
+ # Following specifications from https://is.gd/29S0dt
+ self._observation_space = gym.spaces.Box(
+ low=np.array([-1.2, -0.07]), high=np.array([0.6, 0.07]), shape=(2, ), dtype=np.float32
+ )
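+ # Three discrete actions in MountainCar-v0: 0 = push left, 1 = no push, 2 = push right.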
+ self._action_space = gym.spaces.Discrete(3, start=0)
+ self._reward_space = gym.spaces.Box(low=-1, high=0.0, shape=(1, ), dtype=np.float32)
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def reset(self) -> np.ndarray:
+ # Instantiate environment if not already done so
+ if not self._init_flag:
+ self._env = gym.make('MountainCar-v0')
+ self._init_flag = True
+
+ # Check if we have a valid replay path and save replay video accordingly
+ if self._replay_path is not None:
+ self._env = gym.wrappers.RecordVideo(
+ self._env,
+ video_folder=self._replay_path,
+ episode_trigger=lambda episode_id: True,
+ name_prefix='rl-video-{}'.format(id(self))
+ )
+
+ # Set the seeds for randomization.
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ self._action_space.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+ self._action_space.seed(self._seed)
+
+ # Get first observation from original environment
+ obs = self._env.reset()
+
+ # Convert to numpy array as output
+ obs = to_ndarray(obs).astype(np.float32)
+
+ # Init the episode return: cumulative sum of the real rewards obtained over a whole episode,
+ # used to evaluate the agent's performance on this environment, not used for training.
+ self._eval_episode_return = 0.
+ return obs
+
+ def step(self, action: np.ndarray) -> BaseEnvTimestep:
+
+ # Make sure the input action is a numpy ndarray
+ assert isinstance(action, np.ndarray), type(action)
+
+ # Squeeze to a 0-dim array so the underlying gym env receives a scalar action
+ action = action.squeeze()
+
+ # Take a step of faith into the unknown!
+ obs, rew, done, info = self._env.step(action)
+
+ # Accumulate reward
+ self._eval_episode_return += rew
+
+ # Save the final cumulative reward when the episode ends.
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+
+ # Make sure outputs conform to DI-engine conventions
+ obs = to_ndarray(obs)
+ rew = to_ndarray([rew]).astype(np.float32)
+
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def close(self) -> None:
+ # If init flag is False, then reset() was never run, no point closing.
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ if replay_path is None:
+ replay_path = './video'
+ self._replay_path = replay_path
+
+ def random_action(self) -> np.ndarray:
+ random_action = self.action_space.sample()
+ random_action = to_ndarray([random_action], dtype=np.int64)
+ return random_action
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ def __repr__(self) -> str:
+ return "DI-engine Mountain Car Env"
diff --git a/DI-engine/dizoo/classic_control/mountain_car/envs/test_mtcar_env.py b/DI-engine/dizoo/classic_control/mountain_car/envs/test_mtcar_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc8c125a128060b47089238e8916d2101f150b6f
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/mountain_car/envs/test_mtcar_env.py
@@ -0,0 +1,36 @@
+import pytest
+import numpy as np
+from dizoo.classic_control.mountain_car.envs import MountainCarEnv
+
+
+@pytest.mark.envtest
+class TestMountainCarEnv:
+
+ def test_naive(self):
+ env = MountainCarEnv()
+ env.seed(314, dynamic_seed=False)
+ assert env._seed == 314
+ obs = env.reset()
+ assert obs.shape == (2, )
+ for _ in range(5):
+ env.reset()
+ np.random.seed(314)
+ print('=' * 60)
+ for i in range(10):
+ # Both ``env.random_action()`` and sampling directly from the action space
+ # can generate a legal random action.
+ if i < 5:
+ random_action = np.array([env.action_space.sample()])
+ else:
+ random_action = env.random_action()
+ timestep = env.step(random_action)
+ print("Action taken : ", random_action)
+ print(timestep, "\n")
+ assert isinstance(timestep.obs, np.ndarray)
+ assert isinstance(timestep.done, bool)
+ assert timestep.obs.shape == (2, )
+ assert timestep.reward.shape == (1, )
+ assert timestep.reward >= env.reward_space.low
+ assert timestep.reward <= env.reward_space.high
+ print(env.observation_space, env.action_space, env.reward_space)
+ env.close()
diff --git a/DI-engine/dizoo/classic_control/pendulum/__init__.py b/DI-engine/dizoo/classic_control/pendulum/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/classic_control/pendulum/config/__init__.py b/DI-engine/dizoo/classic_control/pendulum/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7c2988f06013fe509c111c0dd3c3b207bd8ee3f
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/config/__init__.py
@@ -0,0 +1,7 @@
+from .pendulum_ddpg_config import pendulum_ddpg_config, pendulum_ddpg_create_config
+from .pendulum_td3_config import pendulum_td3_config, pendulum_td3_create_config
+from .pendulum_sac_config import pendulum_sac_config, pendulum_sac_create_config
+from .pendulum_d4pg_config import pendulum_d4pg_config, pendulum_d4pg_create_config
+from .pendulum_ppo_config import pendulum_ppo_config, pendulum_ppo_create_config
+from .pendulum_td3_bc_config import pendulum_td3_bc_config, pendulum_td3_bc_create_config
+from .pendulum_td3_data_generation_config import pendulum_td3_generation_config, pendulum_td3_generation_create_config
\ No newline at end of file
diff --git a/DI-engine/dizoo/classic_control/pendulum/config/mbrl/pendulum_mbsac_ddppo_config.py b/DI-engine/dizoo/classic_control/pendulum/config/mbrl/pendulum_mbsac_ddppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..0df8695a9227b0b31d8b540cdd95a574d72afaf4
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/config/mbrl/pendulum_mbsac_ddppo_config.py
@@ -0,0 +1,116 @@
+from easydict import EasyDict
+
+from ding.entry import serial_pipeline_dream
+
+# environment hypo
+env_id = 'Pendulum-v1'
+obs_shape = 3
+action_shape = 1
+
+# gpu
+cuda = False
+
+main_config = dict(
+ exp_name='pendulum_mbsac_ddppo_seed0',
+ env=dict(
+ env_id=env_id, # only for backward compatibility
+ collector_env_num=10,
+ evaluator_env_num=5,
+ # (bool) Scale output action into legal range.
+ act_scale=True,
+ n_evaluator_episode=5,
+ stop_value=-250,
+ ),
+ policy=dict(
+ cuda=cuda,
+ # backward compatibility: it is better to
+ # put random_collect_size in policy.other
+ random_collect_size=1000,
+ model=dict(
+ obs_shape=obs_shape,
+ action_shape=action_shape,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ ),
+ learn=dict(
+ lambda_=0.8,
+ sample_state=False,
+ update_per_collect=1,
+ batch_size=128,
+ learning_rate_q=0.001,
+ learning_rate_policy=0.001,
+ learning_rate_alpha=0.0003,
+ ignore_done=True,
+ target_theta=0.005,
+ discount_factor=0.99,
+ auto_alpha=False,
+ value_network=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(evaluator=dict(eval_freq=100, )),
+ other=dict(
+ # environment buffer
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+ world_model=dict(
+ eval_freq=100, # w.r.t envstep
+ train_freq=100, # w.r.t envstep
+ cuda=cuda,
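+ # Linearly anneals the imagined rollout length from rollout_length_min to
+ # rollout_length_max between rollout_start_step and rollout_end_step
+ # (constant here since min == max).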
+ rollout_length_scheduler=dict(
+ type='linear',
+ rollout_start_step=2000,
+ rollout_end_step=15000,
+ rollout_length_min=3,
+ rollout_length_max=3,
+ ),
+ model=dict(
+ gradient_model=True,
+ k=3,
+ reg=50,
+ neighbor_pool_size=1000,
+ train_freq_gradient_model=500,
+ # standard ensemble dynamics model parameters (as in MBPO)
+ ensemble_size=5,
+ elite_size=3,
+ state_size=obs_shape,
+ action_size=action_shape,
+ reward_size=1,
+ hidden_size=100,
+ use_decay=True,
+ batch_size=64,
+ holdout_ratio=0.1,
+ max_epochs_since_update=5,
+ deterministic_rollout=True,
+ ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+
+create_config = dict(
+ env=dict(
+ type='mbpendulum',
+ import_names=['dizoo.classic_control.pendulum.envs.pendulum_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='mbsac',
+ import_names=['ding.policy.mbpolicy.mbsac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+ world_model=dict(
+ type='ddppo',
+ import_names=['ding.world_model.ddppo'],
+ ),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+ serial_pipeline_dream((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/pendulum/config/mbrl/pendulum_mbsac_mbpo_config.py b/DI-engine/dizoo/classic_control/pendulum/config/mbrl/pendulum_mbsac_mbpo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd24c415105ef84d6850c79841627f8e533efb2d
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/config/mbrl/pendulum_mbsac_mbpo_config.py
@@ -0,0 +1,110 @@
+from easydict import EasyDict
+
+from ding.entry import serial_pipeline_dream
+
+# environment hypo
+env_id = 'Pendulum-v1'
+obs_shape = 3
+action_shape = 1
+
+# gpu
+cuda = False
+
+main_config = dict(
+ exp_name='pendulum_mbsac_mbpo_seed0',
+ env=dict(
+ env_id=env_id, # only for backward compatibility
+ collector_env_num=10,
+ evaluator_env_num=5,
+ # (bool) Scale output action into legal range.
+ act_scale=True,
+ n_evaluator_episode=5,
+ stop_value=-250,
+ ),
+ policy=dict(
+ cuda=cuda,
+ # backward compatibility: it is better to
+ # put random_collect_size in policy.other
+ random_collect_size=1000,
+ model=dict(
+ obs_shape=obs_shape,
+ action_shape=action_shape,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ ),
+ learn=dict(
+ lambda_=0.8,
+ sample_state=False,
+ update_per_collect=1,
+ batch_size=128,
+ learning_rate_q=0.001,
+ learning_rate_policy=0.001,
+ learning_rate_alpha=0.0003,
+ ignore_done=True,
+ target_theta=0.005,
+ discount_factor=0.99,
+ auto_alpha=False,
+ value_network=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(evaluator=dict(eval_freq=100, )),
+ other=dict(
+ # environment buffer
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+ world_model=dict(
+ eval_freq=100, # w.r.t envstep
+ train_freq=100, # w.r.t envstep
+ cuda=cuda,
+ rollout_length_scheduler=dict(
+ type='linear',
+ rollout_start_step=2000,
+ rollout_end_step=15000,
+ rollout_length_min=3,
+ rollout_length_max=3,
+ ),
+ model=dict(
+ ensemble_size=5,
+ elite_size=3,
+ state_size=obs_shape,
+ action_size=action_shape,
+ reward_size=1,
+ hidden_size=100,
+ use_decay=True,
+ batch_size=64,
+ holdout_ratio=0.1,
+ max_epochs_since_update=5,
+ deterministic_rollout=True,
+ ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+
+create_config = dict(
+ env=dict(
+ type='mbpendulum',
+ import_names=['dizoo.classic_control.pendulum.envs.pendulum_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='mbsac',
+ import_names=['ding.policy.mbpolicy.mbsac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+ world_model=dict(
+ type='mbpo',
+ import_names=['ding.world_model.mbpo'],
+ ),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+ serial_pipeline_dream((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/pendulum/config/mbrl/pendulum_sac_ddppo_config.py b/DI-engine/dizoo/classic_control/pendulum/config/mbrl/pendulum_sac_ddppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..827aac4c7638c70d53514019797c0220c0b2c3ad
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/config/mbrl/pendulum_sac_ddppo_config.py
@@ -0,0 +1,121 @@
+from easydict import EasyDict
+
+from ding.entry import serial_pipeline_dyna
+
+# environment hypo
+env_id = 'Pendulum-v1'
+obs_shape = 3
+action_shape = 1
+
+# gpu
+cuda = False
+
+main_config = dict(
+ exp_name='pendulum_sac_ddppo_seed0',
+ env=dict(
+ env_id=env_id, # only for backward compatibility
+ collector_env_num=10,
+ evaluator_env_num=5,
+ # (bool) Scale output action into legal range.
+ act_scale=True,
+ n_evaluator_episode=5,
+ stop_value=-250,
+ ),
+ policy=dict(
+ cuda=cuda,
+ # backward compatibility: it is better to
+ # put random_collect_size in policy.other
+ random_collect_size=1000,
+ model=dict(
+ obs_shape=obs_shape,
+ action_shape=action_shape,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=128,
+ learning_rate_q=0.001,
+ learning_rate_policy=0.001,
+ learning_rate_alpha=0.0003,
+ ignore_done=True,
+ target_theta=0.005,
+ discount_factor=0.99,
+ auto_alpha=False,
+ value_network=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(evaluator=dict(eval_freq=100, )),
+ other=dict(
+ # environment buffer
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+ world_model=dict(
+ eval_freq=100, # w.r.t envstep
+ train_freq=100, # w.r.t envstep
+ cuda=cuda,
+ rollout_length_scheduler=dict(
+ type='linear',
+ rollout_start_step=2000,
+ rollout_end_step=15000,
+ rollout_length_min=1,
+ rollout_length_max=1,
+ ),
+ model=dict(
+ gradient_model=True,
+ k=3,
+ reg=50,
+ neighbor_pool_size=1000,
+ train_freq_gradient_model=500,
+ # standard ensemble dynamics model parameters (as in MBPO)
+ ensemble_size=5,
+ elite_size=3,
+ state_size=obs_shape,
+ action_size=action_shape,
+ reward_size=1,
+ hidden_size=100,
+ use_decay=True,
+ batch_size=64,
+ holdout_ratio=0.1,
+ max_epochs_since_update=5,
+ deterministic_rollout=True,
+ ),
+ other=dict(
+ rollout_batch_size=10000,
+ rollout_retain=4,
+ real_ratio=0.05,
+ imagination_buffer=dict(replay_buffer_size=600000, ),
+ ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+
+create_config = dict(
+ env=dict(
+ type='mbpendulum',
+ import_names=['dizoo.classic_control.pendulum.envs.pendulum_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='sac',
+ import_names=['ding.policy.sac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+ imagination_buffer=dict(type='elastic', ),
+ world_model=dict(
+ type='ddppo',
+ import_names=['ding.world_model.ddppo'],
+ ),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+ serial_pipeline_dyna((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/pendulum/config/mbrl/pendulum_sac_mbpo_config.py b/DI-engine/dizoo/classic_control/pendulum/config/mbrl/pendulum_sac_mbpo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..24fa887da0d5c6db146666f44a523df17014aeef
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/config/mbrl/pendulum_sac_mbpo_config.py
@@ -0,0 +1,115 @@
+from easydict import EasyDict
+
+from ding.entry import serial_pipeline_dyna
+
+# environment hypo
+env_id = 'Pendulum-v1'
+obs_shape = 3
+action_shape = 1
+
+# gpu
+cuda = False
+
+main_config = dict(
+ exp_name='pendulum_sac_mbpo_seed0',
+ env=dict(
+ env_id=env_id, # only for backward compatibility
+ collector_env_num=10,
+ evaluator_env_num=5,
+ # (bool) Scale output action into legal range.
+ act_scale=True,
+ n_evaluator_episode=5,
+ stop_value=-250,
+ ),
+ policy=dict(
+ cuda=cuda,
+ # backward compatibility: it is better to
+ # put random_collect_size in policy.other
+ random_collect_size=1000,
+ model=dict(
+ obs_shape=obs_shape,
+ action_shape=action_shape,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=128,
+ learning_rate_q=0.001,
+ learning_rate_policy=0.001,
+ learning_rate_alpha=0.0003,
+ ignore_done=True,
+ target_theta=0.005,
+ discount_factor=0.99,
+ auto_alpha=False,
+ value_network=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(evaluator=dict(eval_freq=100, )),
+ other=dict(
+ # environment buffer
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+ world_model=dict(
+ eval_freq=100, # w.r.t envstep
+ train_freq=100, # w.r.t envstep
+ cuda=cuda,
+ rollout_length_scheduler=dict(
+ type='linear',
+ rollout_start_step=2000,
+ rollout_end_step=15000,
+ rollout_length_min=1,
+ rollout_length_max=1,
+ ),
+ model=dict(
+ ensemble_size=5,
+ elite_size=3,
+ state_size=obs_shape,
+ action_size=action_shape,
+ reward_size=1,
+ hidden_size=100,
+ use_decay=True,
+ batch_size=64,
+ holdout_ratio=0.1,
+ max_epochs_since_update=5,
+ deterministic_rollout=True,
+ ),
+ other=dict(
+ rollout_batch_size=10000,
+ rollout_retain=4,
+ real_ratio=0.05,
+ imagination_buffer=dict(replay_buffer_size=600000, ),
+ ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+
+create_config = dict(
+ env=dict(
+ type='mbpendulum',
+ import_names=['dizoo.classic_control.pendulum.envs.pendulum_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='sac',
+ import_names=['ding.policy.sac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+ imagination_buffer=dict(type='elastic', ),
+ world_model=dict(
+ type='mbpo',
+ import_names=['ding.world_model.mbpo'],
+ ),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+ serial_pipeline_dyna((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/pendulum/config/mbrl/pendulum_stevesac_mbpo_config.py b/DI-engine/dizoo/classic_control/pendulum/config/mbrl/pendulum_stevesac_mbpo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e382f9ab916f9a4ea8b808bd004bb70b0244cc0
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/config/mbrl/pendulum_stevesac_mbpo_config.py
@@ -0,0 +1,109 @@
+from easydict import EasyDict
+
+from ding.entry import serial_pipeline_dream
+
+# environment hypo
+env_id = 'Pendulum-v1'
+obs_shape = 3
+action_shape = 1
+
+# gpu
+cuda = False
+
+main_config = dict(
+ exp_name='pendulum_stevesac_mbpo_seed0',
+ env=dict(
+ env_id=env_id, # only for backward compatibility
+ collector_env_num=10,
+ evaluator_env_num=5,
+ # (bool) Scale output action into legal range.
+ act_scale=True,
+ n_evaluator_episode=5,
+ stop_value=-250,
+ ),
+ policy=dict(
+ cuda=cuda,
+ # backward compatibility: it is better to
+ # put random_collect_size in policy.other
+ random_collect_size=1000,
+ model=dict(
+ obs_shape=obs_shape,
+ action_shape=action_shape,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ ),
+ learn=dict(
+ ensemble_size=5,
+ update_per_collect=1,
+ batch_size=128,
+ learning_rate_q=0.001,
+ learning_rate_policy=0.001,
+ learning_rate_alpha=0.0003,
+ ignore_done=True,
+ target_theta=0.005,
+ discount_factor=0.99,
+ auto_alpha=False,
+ value_network=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(evaluator=dict(eval_freq=100, )),
+ other=dict(
+ # environment buffer
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+ world_model=dict(
+ eval_freq=100, # w.r.t envstep
+ train_freq=100, # w.r.t envstep
+ cuda=cuda,
+ rollout_length_scheduler=dict(
+ type='linear',
+ rollout_start_step=2000,
+ rollout_end_step=15000,
+ rollout_length_min=3,
+ rollout_length_max=3,
+ ),
+ model=dict(
+ ensemble_size=5,
+ elite_size=3,
+ state_size=obs_shape,
+ action_size=action_shape,
+ reward_size=1,
+ hidden_size=100,
+ use_decay=True,
+ batch_size=64,
+ holdout_ratio=0.1,
+ max_epochs_since_update=5,
+ deterministic_rollout=True,
+ ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+
+create_config = dict(
+ env=dict(
+ type='mbpendulum',
+ import_names=['dizoo.classic_control.pendulum.envs.pendulum_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='stevesac',
+ import_names=['ding.policy.mbpolicy.mbsac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+ world_model=dict(
+ type='mbpo',
+ import_names=['ding.world_model.mbpo'],
+ ),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+ serial_pipeline_dream((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/pendulum/config/pendulum_a2c_config.py b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_a2c_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d66e8b1a17be011a002adea6e94290a59df32126
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_a2c_config.py
@@ -0,0 +1,51 @@
+from easydict import EasyDict
+
+pendulum_a2c_config = dict(
+ exp_name='pendulum_a2c_seed0',
+ env=dict(
+ collector_env_num=10,
+ evaluator_env_num=5,
+ act_scale=True,
+ n_evaluator_episode=5,
+ stop_value=-200,
+ ),
+ policy=dict(
+ cuda=False,
+ action_space='continuous',
+ model=dict(
+ action_space='continuous',
+ obs_shape=3,
+ action_shape=1,
+ ),
+ learn=dict(
+ epoch_per_collect=10,
+ batch_size=32,
+ learning_rate=3e-5,
+ value_weight=0.5,
+ entropy_weight=0.0,
+ ),
+ collect=dict(
+ n_sample=200,
+ unroll_len=1,
+ discount_factor=0.99,
+ ),
+ eval=dict(evaluator=dict(eval_freq=200, ))
+ ),
+)
+pendulum_a2c_config = EasyDict(pendulum_a2c_config)
+main_config = pendulum_a2c_config
+pendulum_a2c_create_config = dict(
+ env=dict(
+ type='pendulum',
+ import_names=['dizoo.classic_control.pendulum.envs.pendulum_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='a2c'),
+)
+pendulum_a2c_create_config = EasyDict(pendulum_a2c_create_config)
+create_config = pendulum_a2c_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c pendulum_a2c_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/classic_control/pendulum/config/pendulum_bdq_config.py b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_bdq_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..59bffb05b9acaef0f24a1646907ba6cec7c17550
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_bdq_config.py
@@ -0,0 +1,62 @@
+from easydict import EasyDict
+pendulum_bdq_config = dict(
+ exp_name='pendulum_bdq_seed0',
+ env=dict(
+ collector_env_num=10,
+ evaluator_env_num=5,
+ # (bool) Scale output action into legal range.
+ act_scale=True,
+ n_evaluator_episode=5,
+ stop_value=-250,
+ continuous=False,
+ # The path to save the game replay
+ # replay_path='./pendulum_bdq_seed0/video',
+ ),
+ policy=dict(
+ cuda=False,
+ load_path='pendulum_bdq_seed0/ckpt/ckpt_best.pth.tar', # necessary for eval
+ model=dict(
+ obs_shape=3,
+ num_branches=1,
+ action_bins_per_branch=11,
+ encoder_hidden_size_list=[128, 128, 64],
+ ),
+ nstep=1,
+ discount_factor=0.97,
+ learn=dict(
+ batch_size=64,
+ learning_rate=0.001,
+ ),
+ collect=dict(n_sample=8),
+ eval=dict(evaluator=dict(eval_freq=40, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=20000, ),
+ ),
+ ),
+)
+pendulum_bdq_config = EasyDict(pendulum_bdq_config)
+main_config = pendulum_bdq_config
+pendulum_bdq_create_config = dict(
+ env=dict(
+ type='pendulum',
+ import_names=['dizoo.classic_control.pendulum.envs.pendulum_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='bdq'),
+ replay_buffer=dict(type='deque', import_names=['ding.data.buffer.deque_buffer_wrapper']),
+)
+pendulum_bdq_create_config = EasyDict(pendulum_bdq_create_config)
+create_config = pendulum_bdq_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c pendulum_bdq_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/pendulum/config/pendulum_cql_config.py b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_cql_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4d9efa7b1841d27e330c5051fed26389cb7c26d3
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_cql_config.py
@@ -0,0 +1,65 @@
+from easydict import EasyDict
+
+pendulum_cql_config = dict(
+ exp_name='pendulum_cql',
+ env=dict(
+ evaluator_env_num=5,
+ act_scale=True,
+ n_evaluator_episode=5,
+ stop_value=-250,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=3,
+ action_shape=1,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ ),
+ learn=dict(
+ train_epoch=30000,
+ batch_size=128,
+ learning_rate_q=3e-4,
+ learning_rate_policy=1e-3,
+ learning_rate_alpha=1e-3,
+ ignore_done=True,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ auto_alpha=True,
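+ # CQL-specific terms: min_q_weight scales the conservative Q-value penalty;
+ # a negative lagrange_thresh is commonly used to disable the Lagrange variant.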
+ lagrange_thresh=-1.0,
+ min_q_weight=5.0,
+ ),
+ collect=dict(
+ data_type='hdf5',
+ data_path='./pendulum_sac_data_generation/expert_demos.hdf5',
+ collector_logit=False,
+ ),
+ command=dict(),
+ eval=dict(evaluator=dict(eval_freq=100, )),
+ ),
+)
+
+pendulum_cql_config = EasyDict(pendulum_cql_config)
+main_config = pendulum_cql_config
+
+pendulum_cql_create_config = dict(
+ env=dict(
+ type='pendulum',
+ import_names=['dizoo.classic_control.pendulum.envs.pendulum_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='cql',
+ import_names=['ding.policy.cql'],
+ ),
+)
+pendulum_cql_create_config = EasyDict(pendulum_cql_create_config)
+create_config = pendulum_cql_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_offline -c pendulum_cql_config.py -s 0`
+ from ding.entry import serial_pipeline_offline
+ serial_pipeline_offline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/classic_control/pendulum/config/pendulum_d4pg_config.py b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_d4pg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..82ef19e21cbfc185dae2e06749b747c430bb6441
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_d4pg_config.py
@@ -0,0 +1,65 @@
+from easydict import EasyDict
+
+pendulum_d4pg_config = dict(
+ exp_name='pendulum_d4pg_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ # (bool) Scale output action into legal range.
+ act_scale=True,
+ n_evaluator_episode=5,
+ stop_value=-250,
+ ),
+ policy=dict(
+ cuda=False,
+ priority=True,
+ nstep=3,
+ discount_factor=0.995,
+ random_collect_size=800,
+ model=dict(
+ obs_shape=3,
+ action_shape=1,
+ action_space='regression',
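+ # Distributional critic: n_atom atoms spread over the value support [v_min, v_max].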
+ v_min=-100,
+ v_max=100,
+ n_atom=51,
+ ),
+ learn=dict(
+ update_per_collect=2,
+ batch_size=128,
+ learning_rate_actor=0.001,
+ learning_rate_critic=0.001,
+ ignore_done=True,
+ actor_update_freq=1,
+ noise=False,
+ ),
+ collect=dict(
+ n_sample=48,
+ noise_sigma=0.1,
+ collector=dict(collect_print_freq=1000, ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ other=dict(replay_buffer=dict(
+ replay_buffer_size=20000,
+ max_use=16,
+ ), ),
+ ),
+)
+pendulum_d4pg_config = EasyDict(pendulum_d4pg_config)
+main_config = pendulum_d4pg_config
+
+pendulum_d4pg_create_config = dict(
+ env=dict(
+ type='pendulum',
+ import_names=['dizoo.classic_control.pendulum.envs.pendulum_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='d4pg'),
+)
+pendulum_d4pg_create_config = EasyDict(pendulum_d4pg_create_config)
+create_config = pendulum_d4pg_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c pendulum_d4pg_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/classic_control/pendulum/config/pendulum_ddpg_config.py b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_ddpg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..883e9a5011ae88dc0553b992a8b904c3f6937650
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_ddpg_config.py
@@ -0,0 +1,61 @@
+from easydict import EasyDict
+
+pendulum_ddpg_config = dict(
+ exp_name='pendulum_ddpg_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ # (bool) Scale output action into legal range.
+ act_scale=True,
+ n_evaluator_episode=5,
+ stop_value=-250,
+ ),
+ policy=dict(
+ cuda=False,
+ priority=False,
+ random_collect_size=800,
+ model=dict(
+ obs_shape=3,
+ action_shape=1,
+ twin_critic=False,
+ action_space='regression',
+ ),
+ learn=dict(
+ update_per_collect=2,
+ batch_size=128,
+ learning_rate_actor=0.001,
+ learning_rate_critic=0.001,
+ ignore_done=True,
+ actor_update_freq=1,
+ noise=False,
+ ),
+ collect=dict(
+ n_sample=48,
+ noise_sigma=0.1,
+ collector=dict(collect_print_freq=1000, ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, )),
+ other=dict(replay_buffer=dict(
+ replay_buffer_size=20000,
+ max_use=16,
+ ), ),
+ ),
+)
+pendulum_ddpg_config = EasyDict(pendulum_ddpg_config)
+main_config = pendulum_ddpg_config
+
+pendulum_ddpg_create_config = dict(
+ env=dict(
+ type='pendulum',
+ import_names=['dizoo.classic_control.pendulum.envs.pendulum_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='ddpg'),
+)
+pendulum_ddpg_create_config = EasyDict(pendulum_ddpg_create_config)
+create_config = pendulum_ddpg_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c pendulum_ddpg_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/classic_control/pendulum/config/pendulum_dqn_config.py b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4f5f3536c7fb7f97c70c929132e8d2cfd767bcd
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_dqn_config.py
@@ -0,0 +1,61 @@
+from easydict import EasyDict
+
+pendulum_dqn_config = dict(
+ exp_name='pendulum_dqn_seed0',
+ env=dict(
+ collector_env_num=10,
+ evaluator_env_num=5,
+ # (bool) Scale output action into legal range.
+ act_scale=True,
+ n_evaluator_episode=5,
+ stop_value=-250,
+ continuous=False,
+ # The path to save the game replay
+ # replay_path='./pendulum_dqn_seed0/video',
+ ),
+ policy=dict(
+ cuda=False,
+ load_path='pendulum_dqn_seed0/ckpt/ckpt_best.pth.tar', # necessary for eval
+ model=dict(
+ obs_shape=3,
+ action_shape=11, # discretize the continuous action space into 11 discrete actions
+ encoder_hidden_size_list=[128, 128, 64],
+ dueling=True,
+ ),
+ nstep=1,
+ discount_factor=0.97,
+ learn=dict(
+ batch_size=64,
+ learning_rate=0.001,
+ ),
+ collect=dict(n_sample=8),
+ eval=dict(evaluator=dict(eval_freq=40, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=20000, ),
+ ),
+ ),
+)
+pendulum_dqn_config = EasyDict(pendulum_dqn_config)
+main_config = pendulum_dqn_config
+pendulum_dqn_create_config = dict(
+ env=dict(
+ type='pendulum',
+ import_names=['dizoo.classic_control.pendulum.envs.pendulum_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='dqn'),
+ replay_buffer=dict(type='deque', import_names=['ding.data.buffer.deque_buffer_wrapper']),
+)
+pendulum_dqn_create_config = EasyDict(pendulum_dqn_create_config)
+create_config = pendulum_dqn_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c pendulum_dqn_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/pendulum/config/pendulum_ibc_config.py b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_ibc_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c56f283fe8f28fdd9f7fc1feb82708390310968
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_ibc_config.py
@@ -0,0 +1,50 @@
+from easydict import EasyDict
+
+cuda = False
+multi_gpu = False
+
+main_config = dict(
+ exp_name='pendulum_ibc_seed0',
+ env=dict(
+ evaluator_env_num=5,
+ act_scale=True,
+ n_evaluator_episode=5,
+ stop_value=-250,
+ ),
+ policy=dict(
+ cuda=cuda,
+ model=dict(obs_shape=3, action_shape=1, stochastic_optim=dict(
+ type='mcmc',
+ cuda=cuda,
+ )),
+ learn=dict(
+ multi_gpu=multi_gpu,
+ train_epoch=15,
+ batch_size=256,
+ optim=dict(learning_rate=1e-5, ),
+ learner=dict(hook=dict(log_show_after_iter=1000)),
+ ),
+ collect=dict(
+ data_type='hdf5',
+ data_path='./pendulum_sac_data_generation/expert_demos.hdf5',
+ collector_logit=False,
+ ),
+ eval=dict(evaluator=dict(eval_freq=-1, )),
+ ),
+)
+pendulum_ibc_config = EasyDict(main_config)
+main_config = pendulum_ibc_config
+
+pendulum_ibc_create_config = dict(
+ env=dict(
+ type='pendulum',
+ import_names=['dizoo.classic_control.pendulum.envs.pendulum_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='ibc',
+ import_names=['ding.policy.ibc'],
+ ),
+)
+pendulum_ibc_create_config = EasyDict(pendulum_ibc_create_config)
+create_config = pendulum_ibc_create_config
diff --git a/DI-engine/dizoo/classic_control/pendulum/config/pendulum_pg_config.py b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_pg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b512548398eff4faaa851505964df5aca82bfcd3
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_pg_config.py
@@ -0,0 +1,50 @@
+from easydict import EasyDict
+
+pendulum_pg_config = dict(
+ exp_name='pendulum_pg_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ act_scale=True,
+ n_evaluator_episode=5,
+ stop_value=-200,
+ ),
+ policy=dict(
+ cuda=False,
+ action_space='continuous',
+ model=dict(
+ action_space='continuous',
+ obs_shape=3,
+ action_shape=1,
+ ),
+ learn=dict(
+ batch_size=400,
+ learning_rate=0.001,
+ entropy_weight=0.001,
+ ),
+ collect=dict(
+ n_episode=2,
+ unroll_len=1,
+ discount_factor=0.99,
+ ),
+ eval=dict(evaluator=dict(eval_freq=200, ))
+ ),
+)
+pendulum_pg_config = EasyDict(pendulum_pg_config)
+main_config = pendulum_pg_config
+pendulum_pg_create_config = dict(
+ env=dict(
+ type='pendulum',
+ import_names=['dizoo.classic_control.pendulum.envs.pendulum_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='pg'),
+ collector=dict(type='episode'),
+)
+pendulum_pg_create_config = EasyDict(pendulum_pg_create_config)
+create_config = pendulum_pg_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c pendulum_pg_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/classic_control/pendulum/config/pendulum_ppo_config.py b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_ppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..2431a5aa2bbc016ad44c4fedef32327e23a19642
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_ppo_config.py
@@ -0,0 +1,62 @@
+from easydict import EasyDict
+
+pendulum_ppo_config = dict(
+ exp_name='pendulum_ppo_seed0',
+ env=dict(
+ collector_env_num=10,
+ evaluator_env_num=5,
+ act_scale=True,
+ n_evaluator_episode=5,
+ stop_value=-250,
+ ),
+ policy=dict(
+ cuda=False,
+ action_space='continuous',
+ recompute_adv=True,
+ model=dict(
+ obs_shape=3,
+ action_shape=1,
+ encoder_hidden_size_list=[64, 64],
+ action_space='continuous',
+ actor_head_layer_num=0,
+ critic_head_layer_num=0,
+ sigma_type='conditioned',
+ bound_type='tanh',
+ ),
+ learn=dict(
+ epoch_per_collect=10,
+ batch_size=32,
+ learning_rate=3e-5,
+ value_weight=0.5,
+ entropy_weight=0.0,
+ clip_ratio=0.2,
+ adv_norm=False,
+ value_norm=True,
+ ignore_done=True,
+ ),
+ collect=dict(
+ n_sample=200,
+ unroll_len=1,
+ discount_factor=0.9,
+ gae_lambda=1.,
+ ),
+ eval=dict(evaluator=dict(eval_freq=200, ))
+ ),
+)
+pendulum_ppo_config = EasyDict(pendulum_ppo_config)
+main_config = pendulum_ppo_config
+pendulum_ppo_create_config = dict(
+ env=dict(
+ type='pendulum',
+ import_names=['dizoo.classic_control.pendulum.envs.pendulum_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='ppo'),
+)
+pendulum_ppo_create_config = EasyDict(pendulum_ppo_create_config)
+create_config = pendulum_ppo_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c pendulum_ppo_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/classic_control/pendulum/config/pendulum_sac_config.py b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_sac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc95db33ff82aa66f76e5b01b5af710f390d3ced
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_sac_config.py
@@ -0,0 +1,58 @@
+from easydict import EasyDict
+
+pendulum_sac_config = dict(
+ exp_name='pendulum_sac_seed0',
+ env=dict(
+ collector_env_num=10,
+ evaluator_env_num=5,
+ # (bool) Scale output action into legal range.
+ act_scale=True,
+ n_evaluator_episode=5,
+ stop_value=-250,
+ ),
+ policy=dict(
+ cuda=False,
+ priority=False,
+ random_collect_size=1000,
+ model=dict(
+ obs_shape=3,
+ action_shape=1,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=128,
+ learning_rate_q=0.001,
+ learning_rate_policy=0.001,
+ learning_rate_alpha=0.0003,
+ ignore_done=True,
+ target_theta=0.005,
+ discount_factor=0.99,
+ auto_alpha=True,
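+            # (Added note) ``target_theta`` is the soft (Polyak) update coefficient of the target
+            # networks, and ``auto_alpha=True`` makes the entropy temperature alpha a learnable
+            # parameter trained with ``learning_rate_alpha``.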
+ ),
+ collect=dict(n_sample=10, ),
+ eval=dict(evaluator=dict(eval_freq=100, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=100000, ), ),
+ ),
+)
+pendulum_sac_config = EasyDict(pendulum_sac_config)
+main_config = pendulum_sac_config
+
+pendulum_sac_create_config = dict(
+ env=dict(
+ type='pendulum',
+ import_names=['dizoo.classic_control.pendulum.envs.pendulum_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='sac'),
+)
+pendulum_sac_create_config = EasyDict(pendulum_sac_create_config)
+create_config = pendulum_sac_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c pendulum_sac_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/classic_control/pendulum/config/pendulum_sac_data_generation_config.py b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_sac_data_generation_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ff5fdd242e488157e60aa6ec21a43b863536df3d
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_sac_data_generation_config.py
@@ -0,0 +1,44 @@
+from easydict import EasyDict
+
+pendulum_sac_data_generation_config = dict(
+ exp_name='pendulum_sac_data_generation',
+ env=dict(
+ collector_env_num=10,
+ act_scale=True,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=3,
+ action_shape=1,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ ),
+ collect=dict(
+ n_sample=1000,
+ save_path='./pendulum_sac_data_generation/expert.pkl',
+ data_type='hdf5',
+ state_dict_path='./pendulum_sac_seed0/ckpt/final.pth.tar',
+ ),
+ ),
+)
+
+pendulum_sac_data_generation_config = EasyDict(pendulum_sac_data_generation_config)
+main_config = pendulum_sac_data_generation_config
+
+pendulum_sac_data_generation_create_config = dict(
+ env=dict(
+ type='pendulum',
+ import_names=['dizoo.classic_control.pendulum.envs.pendulum_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='sac',
+ import_names=['ding.policy.sac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+pendulum_sac_data_generation_create_config = EasyDict(pendulum_sac_data_generation_create_config)
+create_config = pendulum_sac_data_generation_create_config
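+
+# Note (added for clarity): this config has no ``__main__`` entry of its own; it is meant for
+# generating offline data from an already trained SAC policy -- ``state_dict_path`` above points
+# to that checkpoint, while ``save_path`` and ``data_type`` control where and how the collected
+# transitions are stored. A minimal usage sketch (assuming such a checkpoint exists):
+#   from ding.entry import collect_demo_data
+#   collect_demo_data([main_config, create_config], seed=0, collect_count=main_config.policy.collect.n_sample)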
diff --git a/DI-engine/dizoo/classic_control/pendulum/config/pendulum_sqil_sac_config.py b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_sqil_sac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..f80b85235f637697919c88f2cfcffcc41a4908f9
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_sqil_sac_config.py
@@ -0,0 +1,77 @@
+from easydict import EasyDict
+
+obs_shape = 3
+action_shape = 1
+pendulum_sqil_sac_config = dict(
+ exp_name='pendulum_sqil_sac_seed0',
+ env=dict(
+ collector_env_num=10,
+ evaluator_env_num=5,
+ # (bool) Scale output action into legal range.
+ act_scale=True,
+ n_evaluator_episode=5,
+ stop_value=-250,
+ ),
+ policy=dict(
+ cuda=False,
+ random_collect_size=1000,
+ expert_random_collect_size=1000,
+ model=dict(
+ obs_shape=obs_shape,
+ action_shape=action_shape,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ ),
+ nstep=1,
+ discount_factor=0.97,
+ learn=dict(
+ update_per_collect=1,
+ batch_size=128,
+ learning_rate_q=0.001,
+ learning_rate_policy=0.001,
+ learning_rate_alpha=0.0003,
+ ignore_done=True,
+ target_theta=0.005,
+ discount_factor=0.99,
+ auto_alpha=True,
+ ),
+ collect=dict(
+ n_sample=10,
+            # Users should add their own expert model path here, i.e. a path that points to a
+            # trained checkpoint. An absolute path is recommended.
+            # In DI-engine, checkpoints are saved as ``exp_name/ckpt/ckpt_best.pth.tar``.
+ model_path='pendulum_sac_seed0/ckpt/eval.pth.tar',
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=100000, ), ),
+ ),
+)
+pendulum_sqil_sac_config = EasyDict(pendulum_sqil_sac_config)
+main_config = pendulum_sqil_sac_config
+
+pendulum_sqil_sac_create_config = dict(
+ env=dict(
+ type='pendulum',
+ import_names=['dizoo.classic_control.pendulum.envs.pendulum_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='sac'),
+)
+pendulum_sqil_sac_create_config = EasyDict(pendulum_sqil_sac_create_config)
+create_config = pendulum_sqil_sac_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_sqil -c pendulum_sqil_sac_config.py -s 0`
+ # then input the config you used to generate your expert model in the path mentioned above
+ # e.g. pendulum_sac_config.py
+ from ding.entry import serial_pipeline_sqil
+ from dizoo.classic_control.pendulum.config.pendulum_sac_config import pendulum_sac_config, pendulum_sac_create_config
+ expert_main_config = pendulum_sac_config
+ expert_create_config = pendulum_sac_create_config
+ serial_pipeline_sqil(
+ [main_config, create_config],
+ [expert_main_config, expert_create_config],
+ seed=0,
+ )
diff --git a/DI-engine/dizoo/classic_control/pendulum/config/pendulum_td3_bc_config.py b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_td3_bc_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..8583fc6adaf55b0341eb2c5ea43252a6f191da90
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_td3_bc_config.py
@@ -0,0 +1,75 @@
+from easydict import EasyDict
+
+pendulum_td3_bc_config = dict(
+ exp_name='pendulum_td3_bc_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ norm_obs=dict(
+ use_norm=True,
+ offline_stats=dict(use_offline_stats=True, ),
+ ),
+ # (bool) Scale output action into legal range.
+ act_scale=True,
+ n_evaluator_episode=5,
+ stop_value=-250,
+ ),
+ policy=dict(
+ cuda=False,
+ priority=False,
+ random_collect_size=800,
+ model=dict(
+ obs_shape=3,
+ action_shape=1,
+ twin_critic=True,
+ action_space='regression',
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ ),
+ learn=dict(
+ train_epoch=30000,
+ batch_size=128,
+ learning_rate_actor=1e-4,
+ learning_rate_critic=1e-3,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.2,
+ noise_range=dict(
+ min=-0.5,
+ max=0.5,
+ ),
+ alpha=2.5,
+ ),
+ collect=dict(
+ noise_sigma=0.1,
+ data_type='hdf5',
+ data_path='./td3/expert_demos.hdf5',
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ other=dict(replay_buffer=dict(replay_buffer_size=20000, ), ),
+ ),
+)
+pendulum_td3_bc_config = EasyDict(pendulum_td3_bc_config)
+main_config = pendulum_td3_bc_config
+
+pendulum_td3_bc_create_config = dict(
+ env=dict(
+ type='pendulum',
+ import_names=['dizoo.classic_control.pendulum.envs.pendulum_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='td3_bc',
+ import_names=['ding.policy.td3_bc'],
+ ),
+)
+pendulum_td3_bc_create_config = EasyDict(pendulum_td3_bc_create_config)
+create_config = pendulum_td3_bc_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_offline -c pendulum_td3_bc_config.py -s 0`
+ from ding.entry import serial_pipeline_offline
+ serial_pipeline_offline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/classic_control/pendulum/config/pendulum_td3_config.py b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_td3_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ff06425d28332610d6f81d366d9c5999000f2c9
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_td3_config.py
@@ -0,0 +1,63 @@
+from easydict import EasyDict
+
+pendulum_td3_config = dict(
+ exp_name='pendulum_td3_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ # (bool) Scale output action into legal range.
+ act_scale=True,
+ n_evaluator_episode=5,
+ stop_value=-250,
+ ),
+ policy=dict(
+ cuda=False,
+ priority=False,
+ random_collect_size=800,
+ model=dict(
+ obs_shape=3,
+ action_shape=1,
+ twin_critic=True,
+ action_space='regression',
+ ),
+ learn=dict(
+ update_per_collect=2,
+ batch_size=128,
+ learning_rate_actor=0.001,
+ learning_rate_critic=0.001,
+ ignore_done=True,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.1,
+ noise_range=dict(
+ min=-0.5,
+ max=0.5,
+ ),
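+            # (Added note) The noise settings above implement TD3's target-policy smoothing:
+            # clipped Gaussian noise (sigma=0.1, clipped to [-0.5, 0.5]) is added to the target
+            # action, while ``actor_update_freq=2`` provides the delayed actor updates and
+            # ``twin_critic=True`` (in ``model``) the clipped double-Q learning.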
+ ),
+ collect=dict(
+ n_sample=48,
+ noise_sigma=0.1,
+ collector=dict(collect_print_freq=1000, ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ other=dict(replay_buffer=dict(replay_buffer_size=20000, ), ),
+ ),
+)
+pendulum_td3_config = EasyDict(pendulum_td3_config)
+main_config = pendulum_td3_config
+
+pendulum_td3_create_config = dict(
+ env=dict(
+ type='pendulum',
+ import_names=['dizoo.classic_control.pendulum.envs.pendulum_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='td3'),
+)
+pendulum_td3_create_config = EasyDict(pendulum_td3_create_config)
+create_config = pendulum_td3_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c pendulum_td3_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/classic_control/pendulum/config/pendulum_td3_data_generation_config.py b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_td3_data_generation_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..19f46cb6e02b21b4b6c0067419503b0114bb22b5
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/config/pendulum_td3_data_generation_config.py
@@ -0,0 +1,71 @@
+from easydict import EasyDict
+
+pendulum_td3_generation_config = dict(
+ exp_name='pendulum_td3_generation_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=10,
+ # (bool) Scale output action into legal range.
+ act_scale=True,
+ n_evaluator_episode=10,
+ stop_value=-250,
+ ),
+ policy=dict(
+ cuda=False,
+ priority=False,
+ random_collect_size=800,
+ model=dict(
+ obs_shape=3,
+ action_shape=1,
+ twin_critic=True,
+ action_space='regression',
+ ),
+ learn=dict(
+ update_per_collect=2,
+ batch_size=128,
+ learning_rate_actor=0.001,
+ learning_rate_critic=0.001,
+ ignore_done=True,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.2,
+ noise_range=dict(
+ min=-0.5,
+ max=0.5,
+ ),
+ learner=dict(
+ load_path='./td3/ckpt/ckpt_best.pth.tar',
+ hook=dict(
+ load_ckpt_before_run='./td3/ckpt/ckpt_best.pth.tar',
+ save_ckpt_after_run=False,
+ )
+ ),
+ ),
+ collect=dict(
+ n_sample=10,
+ noise_sigma=0.1,
+ collector=dict(collect_print_freq=1000, ),
+ save_path='./td3/expert.pkl',
+ data_type='hdf5',
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ other=dict(replay_buffer=dict(replay_buffer_size=40000, ), ),
+ ),
+)
+pendulum_td3_generation_config = EasyDict(pendulum_td3_generation_config)
+main_config = pendulum_td3_generation_config
+
+pendulum_td3_generation_create_config = dict(
+ env=dict(
+ type='pendulum',
+ import_names=['dizoo.classic_control.pendulum.envs.pendulum_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='ddpg'),
+)
+pendulum_td3_generation_create_config = EasyDict(pendulum_td3_generation_create_config)
+create_config = pendulum_td3_generation_create_config
+
+if __name__ == "__main__":
+ from ding.entry import collect_demo_data
+ collect_demo_data([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/classic_control/pendulum/entry/__init__.py b/DI-engine/dizoo/classic_control/pendulum/entry/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/classic_control/pendulum/entry/pendulum_cql_ddpg_main.py b/DI-engine/dizoo/classic_control/pendulum/entry/pendulum_cql_ddpg_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac9576d1382eac5277cf56b3deccace67fc1f105
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/entry/pendulum_cql_ddpg_main.py
@@ -0,0 +1,58 @@
+import torch
+from copy import deepcopy
+
+from dizoo.classic_control.pendulum.config.pendulum_ddpg_data_generation_config import main_config, create_config
+from ding.entry import serial_pipeline_offline, collect_demo_data, eval, serial_pipeline
+
+
+def train_cql(args):
+ from dizoo.classic_control.pendulum.config.pendulum_cql_config import main_config, create_config
+ main_config.exp_name = 'cql_ddpg'
+ main_config.policy.learn.data_path = './ddpg/expert_demos.hdf5'
+ main_config.policy.learn.data_type = 'hdf5'
+ config = deepcopy([main_config, create_config])
+ serial_pipeline_offline(config, seed=args.seed)
+
+
+def eval_ckpt(args):
+ main_config.exp_name = 'ddpg'
+ main_config.policy.learn.learner.load_path = './ddpg/ckpt/ckpt_best.pth.tar'
+ main_config.policy.learn.learner.hook.load_ckpt_before_run = './ddpg/ckpt/ckpt_best.pth.tar'
+ config = deepcopy([main_config, create_config])
+ eval(config, seed=args.seed, load_path=main_config.policy.learn.learner.hook.load_ckpt_before_run)
+
+
+def generate(args):
+ main_config.exp_name = 'ddpg'
+ main_config.policy.learn.learner.load_path = './ddpg/ckpt/ckpt_best.pth.tar'
+ main_config.policy.learn.save_path = './ddpg/expert.pkl'
+ # main_config.policy.learn.data_type = 'hdf5'
+ config = deepcopy([main_config, create_config])
+ state_dict = torch.load(main_config.policy.learn.learner.load_path, map_location='cpu')
+ collect_demo_data(
+ config,
+ collect_count=main_config.policy.other.replay_buffer.replay_buffer_size,
+ seed=args.seed,
+ expert_data_path=main_config.policy.learn.save_path,
+ state_dict=state_dict
+ )
+
+
+def train_expert(args):
+ from dizoo.classic_control.pendulum.config.pendulum_ddpg_config import main_config, create_config
+ main_config.exp_name = 'ddpg'
+ config = deepcopy([main_config, create_config])
+ serial_pipeline(config, seed=args.seed)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--seed', '-s', type=int, default=10)
+ args = parser.parse_args()
+
+ # train_expert(args)
+ # eval_ckpt(args)
+ generate(args)
+ # train_cql(args)
diff --git a/DI-engine/dizoo/classic_control/pendulum/entry/pendulum_cql_main.py b/DI-engine/dizoo/classic_control/pendulum/entry/pendulum_cql_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c89f31926ac0a944eca5f95389f0adfb9942bc4
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/entry/pendulum_cql_main.py
@@ -0,0 +1,59 @@
+import torch
+from copy import deepcopy
+
+from dizoo.classic_control.pendulum.config.pendulum_sac_data_generation_config import main_config, create_config
+from ding.entry import serial_pipeline_offline, collect_demo_data, eval, serial_pipeline
+
+
+def train_cql(args):
+ from dizoo.classic_control.pendulum.config.pendulum_cql_config import main_config, create_config
+ main_config.exp_name = 'cql_sac'
+ main_config.policy.collect.data_path = './sac/expert_demos.hdf5'
+ main_config.policy.collect.data_type = 'hdf5'
+ config = deepcopy([main_config, create_config])
+ serial_pipeline_offline(config, seed=args.seed)
+
+
+def eval_ckpt(args):
+ main_config.exp_name = 'sac'
+ main_config.policy.learn.learner.load_path = './sac/ckpt/ckpt_best.pth.tar'
+ main_config.policy.learn.learner.hook.load_ckpt_before_run = './sac/ckpt/ckpt_best.pth.tar'
+ state_dict = torch.load(main_config.policy.learn.learner.load_path, map_location='cpu')
+ config = deepcopy([main_config, create_config])
+ eval(config, seed=args.seed, load_path=main_config.policy.learn.learner.hook.load_ckpt_before_run)
+
+
+def generate(args):
+ main_config.exp_name = 'sac'
+ main_config.policy.learn.learner.load_path = './sac/ckpt/ckpt_best.pth.tar'
+ main_config.policy.collect.save_path = './sac/expert.pkl'
+ main_config.policy.collect.data_type = 'hdf5'
+ config = deepcopy([main_config, create_config])
+ state_dict = torch.load(main_config.policy.learn.learner.load_path, map_location='cpu')
+ collect_demo_data(
+ config,
+ collect_count=main_config.policy.other.replay_buffer.replay_buffer_size,
+ seed=args.seed,
+ expert_data_path=main_config.policy.collect.save_path,
+ state_dict=state_dict
+ )
+
+
+def train_expert(args):
+ from dizoo.classic_control.pendulum.config.pendulum_sac_config import main_config, create_config
+ config = deepcopy([main_config, create_config])
+ config[0].exp_name = 'sac'
+ serial_pipeline(config, seed=args.seed)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--seed', '-s', type=int, default=10)
+ args = parser.parse_args()
+
+ train_expert(args)
+ eval_ckpt(args)
+ generate(args)
+ train_cql(args)
diff --git a/DI-engine/dizoo/classic_control/pendulum/entry/pendulum_d4pg_main.py b/DI-engine/dizoo/classic_control/pendulum/entry/pendulum_d4pg_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..e84eb4b10dd129d848f345ff2568abfba087dd0a
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/entry/pendulum_d4pg_main.py
@@ -0,0 +1,76 @@
+import os
+import gym
+from tensorboardX import SummaryWriter
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.envs import BaseEnvManager
+from ding.policy.d4pg import D4PGPolicy
+from ding.model.template import QACDIST
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.pendulum.envs import PendulumEnv
+from dizoo.classic_control.pendulum.config.pendulum_d4pg_config import pendulum_d4pg_config
+
+
+def main(cfg, seed=0):
+ cfg = compile_config(
+ cfg,
+ BaseEnvManager,
+ D4PGPolicy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator,
+ AdvancedReplayBuffer,
+ save_cfg=True
+ )
+
+ # Set up envs for collection and evaluation
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+ collector_env = BaseEnvManager(
+ env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManager(
+ env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ # Set random seed for all package and instance
+ collector_env.seed(seed)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ # Set up RL Policy
+ model = QACDIST(**cfg.policy.model)
+ policy = D4PGPolicy(cfg.policy, model=model)
+
+ # Set up collection, training and evaluation utilities
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = SampleSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
+
+ # Training & Evaluation loop
+ while True:
+ # Evaluate at the beginning and with specific frequency
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Collect data from environments
+ new_data = collector.collect(train_iter=learner.train_iter)
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ # Train
+ for i in range(cfg.policy.learn.update_per_collect):
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is None:
+ break
+ learner.train(train_data, collector.envstep)
+ replay_buffer.update(learner.priority_info)
+
+
+if __name__ == "__main__":
+ main(pendulum_d4pg_config, seed=0)
diff --git a/DI-engine/dizoo/classic_control/pendulum/entry/pendulum_ddpg_main.py b/DI-engine/dizoo/classic_control/pendulum/entry/pendulum_ddpg_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..0153071ed7eab1e11adb4ed7631423f34b2c9f8e
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/entry/pendulum_ddpg_main.py
@@ -0,0 +1,86 @@
+import os
+import gym
+from tensorboardX import SummaryWriter
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.envs import BaseEnvManager, DingEnvWrapper
+from ding.policy import DDPGPolicy
+from ding.model import ContinuousQAC
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.pendulum.envs import PendulumEnv
+from dizoo.classic_control.pendulum.config.pendulum_ddpg_config import pendulum_ddpg_config
+
+
+def main(cfg, seed=0):
+ cfg = compile_config(
+ cfg,
+ BaseEnvManager,
+ DDPGPolicy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator,
+ AdvancedReplayBuffer,
+ save_cfg=True
+ )
+
+ # Set up envs for collection and evaluation
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+    # You can use either `PendulumEnv` or `DingEnvWrapper` to create a pendulum env and then build an env manager from it.
+ # == Use `DingEnvWrapper`
+ collector_env = BaseEnvManager(
+ env_fn=[lambda: DingEnvWrapper(env=gym.make('Pendulum-v1'), cfg=cfg.env) for _ in range(collector_env_num)],
+ cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManager(
+ env_fn=[lambda: DingEnvWrapper(env=gym.make('Pendulum-v1'), cfg=cfg.env) for _ in range(evaluator_env_num)],
+ cfg=cfg.env.manager
+ )
+ # == Use `PendulumEnv`
+ # collector_env = BaseEnvManager(
+ # env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(collector_env_num)], cfg=cfg.env.manager
+ # )
+ # evaluator_env = BaseEnvManager(
+ # env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
+ # )
+
+ # Set random seed for all package and instance
+ collector_env.seed(seed)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ # Set up RL Policy
+ model = ContinuousQAC(**cfg.policy.model)
+ policy = DDPGPolicy(cfg.policy, model=model)
+
+ # Set up collection, training and evaluation utilities
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = SampleSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
+
+ # Training & Evaluation loop
+ while True:
+ # Evaluate at the beginning and with specific frequency
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Collect data from environments
+ new_data = collector.collect(train_iter=learner.train_iter)
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ # Train
+ for i in range(cfg.policy.learn.update_per_collect):
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is None:
+ break
+ learner.train(train_data, collector.envstep)
+
+
+if __name__ == "__main__":
+ main(pendulum_ddpg_config, seed=0)
diff --git a/DI-engine/dizoo/classic_control/pendulum/entry/pendulum_dqn_eval.py b/DI-engine/dizoo/classic_control/pendulum/entry/pendulum_dqn_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb80ad42ad550342bd16b4e5b2a19593096ecc9d
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/entry/pendulum_dqn_eval.py
@@ -0,0 +1,60 @@
+import os
+import gym
+import torch
+from tensorboardX import SummaryWriter
+from easydict import EasyDict
+from functools import partial
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.envs import BaseEnvManager, DingEnvWrapper
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.policy import DQNPolicy
+from ding.model import DQN
+from ding.utils import set_pkg_seed
+from ding.rl_utils import get_epsilon_greedy_fn
+from dizoo.classic_control.pendulum.config.pendulum_dqn_config import main_config, create_config
+
+
+def main(rl_cfg, seed=0):
+ main_cfg, create_cfg = rl_cfg
+ cfg = compile_config(
+ main_cfg,
+ BaseEnvManager,
+ DQNPolicy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator,
+ AdvancedReplayBuffer,
+ create_cfg=create_cfg,
+ save_cfg=True
+ )
+
+ create_cfg.policy.type = create_cfg.policy.type + '_command'
+ env_fn = None
+ cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
+ # Create main components: env, policy
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+
+    evaluator_env.enable_save_replay(cfg.env.replay_path)  # enable the replay-saving interface
+
+ # Set random seed for all package and instance
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ # Set up RL Policy
+ model = DQN(**cfg.policy.model)
+ policy = DQNPolicy(cfg.policy, model=model)
+ policy.eval_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cpu'))
+
+ # evaluate
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator.eval()
+
+
+if __name__ == "__main__":
+ main(rl_cfg=(main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/classic_control/pendulum/entry/pendulum_ppo_main.py b/DI-engine/dizoo/classic_control/pendulum/entry/pendulum_ppo_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..41ebece26295178b78f0c8a487cee802ec358ee8
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/entry/pendulum_ppo_main.py
@@ -0,0 +1,55 @@
+import os
+import gym
+from tensorboardX import SummaryWriter
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, NaiveReplayBuffer
+from ding.envs import BaseEnvManager, DingEnvWrapper
+from ding.policy import PPOPolicy
+from ding.model import VAC
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.pendulum.envs import PendulumEnv
+from dizoo.classic_control.pendulum.config.pendulum_ppo_config import pendulum_ppo_config
+
+
+def main(cfg, seed=0, max_iterations=int(1e10)):
+ cfg = compile_config(
+ cfg,
+ BaseEnvManager,
+ PPOPolicy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator,
+ NaiveReplayBuffer,
+ save_cfg=True
+ )
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+ collector_env = BaseEnvManager(
+ env_fn=[lambda: PendulumEnv(cfg=cfg.env) for _ in range(collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManager(
+ env_fn=[lambda: PendulumEnv(cfg=cfg.env) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ collector_env.seed(seed)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ model = VAC(**cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
+ tb_logger = SummaryWriter(os.path.join('./log/', 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger)
+ collector = SampleSerialCollector(cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger)
+ evaluator = InteractionSerialEvaluator(cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger)
+
+ for _ in range(max_iterations):
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ new_data = collector.collect(train_iter=learner.train_iter)
+ learner.train(new_data, collector.envstep)
+
+
+if __name__ == "__main__":
+ main(pendulum_ppo_config)
diff --git a/DI-engine/dizoo/classic_control/pendulum/entry/pendulum_td3_bc_main.py b/DI-engine/dizoo/classic_control/pendulum/entry/pendulum_td3_bc_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ed1b1f7dc83d3863134e2abfde1583bd7203f88
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/entry/pendulum_td3_bc_main.py
@@ -0,0 +1,60 @@
+import torch
+from copy import deepcopy
+
+from dizoo.classic_control.pendulum.config.pendulum_td3_data_generation_config import main_config, create_config
+from ding.entry import serial_pipeline_offline, collect_demo_data, eval, serial_pipeline
+
+
+def train_td3_bc(args):
+ from dizoo.classic_control.pendulum.config.pendulum_td3_bc_config import main_config, create_config
+ main_config.exp_name = 'td3_bc'
+ main_config.policy.collect.data_path = './td3/expert_demos.hdf5'
+ main_config.policy.collect.data_type = 'hdf5'
+ config = deepcopy([main_config, create_config])
+ serial_pipeline_offline(config, seed=args.seed)
+
+
+def eval_ckpt(args):
+ main_config.exp_name = 'td3'
+ main_config.policy.learn.learner.load_path = './td3/ckpt/ckpt_best.pth.tar'
+ main_config.policy.learn.learner.hook.load_ckpt_before_run = './td3/ckpt/ckpt_best.pth.tar'
+ state_dict = torch.load(main_config.policy.learn.learner.load_path, map_location='cpu')
+ config = deepcopy([main_config, create_config])
+ # eval(config, seed=args.seed, load_path=main_config.policy.learn.learner.hook.load_ckpt_before_run)
+ eval(config, seed=args.seed, state_dict=state_dict)
+
+
+def generate(args):
+ main_config.exp_name = 'td3'
+ main_config.policy.learn.learner.load_path = './td3/ckpt/ckpt_best.pth.tar'
+ main_config.policy.collect.save_path = './td3/expert.pkl'
+ main_config.policy.collect.data_type = 'hdf5'
+ config = deepcopy([main_config, create_config])
+ state_dict = torch.load(main_config.policy.learn.learner.load_path, map_location='cpu')
+ collect_demo_data(
+ config,
+ collect_count=main_config.policy.other.replay_buffer.replay_buffer_size,
+ seed=args.seed,
+ expert_data_path=main_config.policy.collect.save_path,
+ state_dict=state_dict
+ )
+
+
+def train_expert(args):
+ from dizoo.classic_control.pendulum.config.pendulum_td3_config import main_config, create_config
+ main_config.exp_name = 'td3'
+ config = deepcopy([main_config, create_config])
+ serial_pipeline(config, seed=args.seed)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--seed', '-s', type=int, default=0)
+ args = parser.parse_args()
+
+ train_expert(args)
+ eval_ckpt(args)
+ generate(args)
+ train_td3_bc(args)
diff --git a/DI-engine/dizoo/classic_control/pendulum/entry/pendulum_td3_main.py b/DI-engine/dizoo/classic_control/pendulum/entry/pendulum_td3_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..36bfce14f3eb5958fd7979c695c704e6e69ff68b
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/entry/pendulum_td3_main.py
@@ -0,0 +1,82 @@
+import os
+import gym
+from tensorboardX import SummaryWriter
+from torch.optim.lr_scheduler import LambdaLR
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.envs import BaseEnvManager, DingEnvWrapper
+from ding.policy import DDPGPolicy
+from ding.model import ContinuousQAC
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.pendulum.envs import PendulumEnv
+from dizoo.classic_control.pendulum.config.pendulum_td3_config import pendulum_td3_config
+
+
+def main(cfg, seed=0):
+ cfg = compile_config(
+ cfg,
+ BaseEnvManager,
+ DDPGPolicy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator,
+ AdvancedReplayBuffer,
+ save_cfg=True
+ )
+
+ # Set up envs for collection and evaluation
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+ collector_env = BaseEnvManager(
+ env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManager(
+ env_fn=[lambda: PendulumEnv(cfg.env) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ # Set random seed for all package and instance
+ collector_env.seed(seed)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ # Set up RL Policy
+ model = ContinuousQAC(**cfg.policy.model)
+ policy = DDPGPolicy(cfg.policy, model=model)
+ # lr_scheduler demo
+ lr_scheduler = LambdaLR(
+ policy.learn_mode.get_attribute('optimizer_actor'), lr_lambda=lambda iters: min(1.0, 0.5 + 0.5 * iters / 1000)
+ )
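+    # (Added note) The lambda above scales the actor lr by 0.5 at iteration 0, warms it up
+    # linearly to 1.0x at iteration 1000 and then holds it constant; ``lr_scheduler.step()`` is
+    # called once per training iteration in the loop below.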
+
+ # Set up collection, training and evaluation utilities
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = SampleSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
+
+ # Training & Evaluation loop
+ while True:
+ # Evaluate at the beginning and with specific frequency
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Collect data from environments
+ new_data = collector.collect(train_iter=learner.train_iter)
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ # Train
+ for i in range(cfg.policy.learn.update_per_collect):
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is None:
+ break
+ learner.train(train_data, collector.envstep)
+ lr_scheduler.step()
+ tb_logger.add_scalar('other_iter/scheduled_lr', lr_scheduler.get_last_lr()[0], learner.train_iter)
+
+
+if __name__ == "__main__":
+ main(pendulum_td3_config, seed=0)
diff --git a/DI-engine/dizoo/classic_control/pendulum/envs/__init__.py b/DI-engine/dizoo/classic_control/pendulum/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..85455798230de6c6df45592a1c1fe527dcdb5049
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/envs/__init__.py
@@ -0,0 +1 @@
+from .pendulum_env import PendulumEnv
diff --git a/DI-engine/dizoo/classic_control/pendulum/envs/pendulum_env.py b/DI-engine/dizoo/classic_control/pendulum/envs/pendulum_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3265cbaae142bfb1757138dc45872415f03174e
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/envs/pendulum_env.py
@@ -0,0 +1,129 @@
+from typing import Any, Union, Optional
+import gym
+import torch
+import numpy as np
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.envs.common.common_function import affine_transform
+from ding.utils import ENV_REGISTRY
+from ding.torch_utils import to_ndarray, to_list
+
+
+@ENV_REGISTRY.register('pendulum')
+class PendulumEnv(BaseEnv):
+
+ def __init__(self, cfg: dict) -> None:
+ self._cfg = cfg
+ self._act_scale = cfg.act_scale
+ self._env = gym.make('Pendulum-v1')
+ self._init_flag = False
+ self._replay_path = None
+ if 'continuous' in cfg.keys():
+ self._continuous = cfg.continuous
+ else:
+ self._continuous = True
+ self._observation_space = gym.spaces.Box(
+ low=np.array([-1.0, -1.0, -8.0]), high=np.array([1.0, 1.0, 8.0]), shape=(3, ), dtype=np.float32
+ )
+ if self._continuous:
+ self._action_space = gym.spaces.Box(low=-2.0, high=2.0, shape=(1, ), dtype=np.float32)
+ else:
+ self._discrete_action_num = 11
+ self._action_space = gym.spaces.Discrete(self._discrete_action_num)
+ self._action_space.seed(0) # default seed
+ self._reward_space = gym.spaces.Box(
+ low=-1 * (3.14 * 3.14 + 0.1 * 8 * 8 + 0.001 * 2 * 2), high=0.0, shape=(1, ), dtype=np.float32
+ )
+
+ def reset(self) -> np.ndarray:
+ if not self._init_flag:
+ self._env = gym.make('Pendulum-v1')
+ if self._replay_path is not None:
+ self._env = gym.wrappers.RecordVideo(
+ self._env,
+ video_folder=self._replay_path,
+ episode_trigger=lambda episode_id: True,
+ name_prefix='rl-video-{}'.format(id(self))
+ )
+ self._init_flag = True
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ self._action_space.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+ self._action_space.seed(self._seed)
+ obs = self._env.reset()
+ obs = to_ndarray(obs).astype(np.float32)
+ self._eval_episode_return = 0.
+ return obs
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def step(self, action: np.ndarray) -> BaseEnvTimestep:
+ assert isinstance(action, np.ndarray), type(action)
+        # if the action space is discrete, convert the action to a float in [-1, 1]
+ if not self._continuous:
+ action = (action / (self._discrete_action_num - 1)) * 2 - 1
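+            # e.g. with the default 11 discrete actions: 0 -> -1.0, 5 -> 0.0, 10 -> 1.0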
+ # scale into [-2, 2]
+ if self._act_scale:
+ action = affine_transform(action, min_val=self._env.action_space.low, max_val=self._env.action_space.high)
+ obs, rew, done, info = self._env.step(action)
+ self._eval_episode_return += rew
+ obs = to_ndarray(obs).astype(np.float32)
+        # wrap the reward into an array with shape (1, )
+ rew = to_ndarray([rew]).astype(np.float32)
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ if replay_path is None:
+ replay_path = './video'
+ self._replay_path = replay_path
+
+ def random_action(self) -> np.ndarray:
+        # handle both continuous and discrete action spaces
+ if self._continuous:
+ random_action = self.action_space.sample().astype(np.float32)
+ else:
+ random_action = self.action_space.sample()
+ random_action = to_ndarray([random_action], dtype=np.int64)
+ return random_action
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ def __repr__(self) -> str:
+ return "DI-engine Pendulum Env({})".format(self._cfg.env_id)
+
+
+@ENV_REGISTRY.register('mbpendulum')
+class MBPendulumEnv(PendulumEnv):
+
+ def termination_fn(self, next_obs: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+            This function determines whether each next state is a terminal state.
+        .. note::
+            Done is always False for the pendulum environment, so an all-False tensor is returned.
+ """
+ done = torch.zeros_like(next_obs.sum(-1)).bool()
+ return done
diff --git a/DI-engine/dizoo/classic_control/pendulum/envs/test_pendulum_env.py b/DI-engine/dizoo/classic_control/pendulum/envs/test_pendulum_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7ee142b30cd17f0aa4c5c19c8f67abe8784dd84
--- /dev/null
+++ b/DI-engine/dizoo/classic_control/pendulum/envs/test_pendulum_env.py
@@ -0,0 +1,55 @@
+import pytest
+import numpy as np
+from easydict import EasyDict
+from torch import rand
+from dizoo.classic_control.pendulum.envs import PendulumEnv
+
+
+@pytest.mark.envtest
+class TestPendulumEnv:
+
+ def test_naive(self):
+ env = PendulumEnv(EasyDict({'act_scale': True}))
+ env.seed(314)
+ assert env._seed == 314
+ obs = env.reset()
+ assert obs.shape == (3, )
+ for i in range(10):
+            # Both ``env.random_action()`` and sampling via ``np.random`` / the action space
+            # can generate a legal random action.
+ if i < 5:
+ random_action = np.tanh(np.random.random(1))
+ else:
+ random_action = env.random_action()
+ timestep = env.step(random_action)
+ assert timestep.obs.shape == (3, )
+ assert timestep.reward.shape == (1, )
+ assert timestep.reward >= env.reward_space.low
+ assert timestep.reward <= env.reward_space.high
+ # assert isinstance(timestep, tuple)
+ print(env.observation_space, env.action_space, env.reward_space)
+ env.close()
+
+ def test_discrete(self):
+ env = PendulumEnv(EasyDict({'act_scale': True, 'continuous': False}))
+ env.seed(314)
+ assert env._seed == 314
+ obs = env.reset()
+ assert obs.shape == (3, )
+ for i in range(10):
+            # Both ``env.random_action()`` and sampling via ``np.random`` / the action space
+            # can generate a legal random action.
+ if i < 5:
+ random_action = np.array([env.action_space.sample()])
+ else:
+ random_action = env.random_action()
+ timestep = env.step(random_action)
+ print(env.observation_space, env.action_space, env.reward_space)
+            print(timestep.obs, timestep.reward, timestep.done)
+ assert timestep.reward.shape == (1, )
+ assert timestep.obs.shape == (3, )
+ assert timestep.reward >= env.reward_space.low
+ assert timestep.reward <= env.reward_space.high
+ # assert isinstance(timestep, tuple)
+ print(env.observation_space, env.action_space, env.reward_space)
+ env.close()
diff --git a/DI-engine/dizoo/cliffwalking/__init__.py b/DI-engine/dizoo/cliffwalking/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/cliffwalking/config/cliffwalking_dqn_config.py b/DI-engine/dizoo/cliffwalking/config/cliffwalking_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c852858ab70167750f2897d16373a2551a27aeca
--- /dev/null
+++ b/DI-engine/dizoo/cliffwalking/config/cliffwalking_dqn_config.py
@@ -0,0 +1,60 @@
+from easydict import EasyDict
+
+cliffwalking_dqn_config = dict(
+ exp_name='cliffwalking_dqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=-13, # the optimal value of cliffwalking env
+ max_episode_steps=300,
+ ),
+ policy=dict(
+ cuda=True,
+ load_path="./cliffwalking_dqn_seed0/ckpt/ckpt_best.pth.tar",
+ model=dict(
+ obs_shape=48,
+ action_shape=4,
+ encoder_hidden_size_list=[512, 64],
+ ),
+ discount_factor=0.99,
+ nstep=1,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=128,
+ learning_rate=0.001,
+ target_update_freq=100,
+ ),
+ collect=dict(
+ n_sample=64,
+ unroll_len=1,
+ ),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=0.95,
+ end=0.25,
+ decay=50000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, )
+ ),
+ ),
+)
+cliffwalking_dqn_config = EasyDict(cliffwalking_dqn_config)
+main_config = cliffwalking_dqn_config
+
+cliffwalking_dqn_create_config = dict(
+ env=dict(
+ type='cliffwalking',
+ import_names=['dizoo.cliffwalking.envs.cliffwalking_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqn'),
+)
+cliffwalking_dqn_create_config = EasyDict(cliffwalking_dqn_create_config)
+create_config = cliffwalking_dqn_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c cliffwalking_dqn_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/cliffwalking/entry/cliffwalking_dqn_deploy.py b/DI-engine/dizoo/cliffwalking/entry/cliffwalking_dqn_deploy.py
new file mode 100644
index 0000000000000000000000000000000000000000..02fe49a0a768b2069f3d5a5fbada42fa809b3406
--- /dev/null
+++ b/DI-engine/dizoo/cliffwalking/entry/cliffwalking_dqn_deploy.py
@@ -0,0 +1,39 @@
+import gym
+import torch
+from easydict import EasyDict
+
+from ding.config import compile_config
+from ding.envs import DingEnvWrapper
+from ding.model import DQN
+from ding.policy import DQNPolicy, single_env_forward_wrapper
+from dizoo.cliffwalking.config.cliffwalking_dqn_config import create_config, main_config
+from dizoo.cliffwalking.envs.cliffwalking_env import CliffWalkingEnv
+
+
+def main(main_config: EasyDict, create_config: EasyDict, ckpt_path: str):
+ main_config.exp_name = f'cliffwalking_dqn_seed0_deploy'
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ env = CliffWalkingEnv(cfg.env)
+ env.enable_save_replay(replay_path=f'./{main_config.exp_name}/video')
+ model = DQN(**cfg.policy.model)
+ state_dict = torch.load(ckpt_path, map_location='cpu')
+ model.load_state_dict(state_dict['model'])
+ policy = DQNPolicy(cfg.policy, model=model).eval_mode
+ forward_fn = single_env_forward_wrapper(policy.forward)
+ obs = env.reset()
+ returns = 0.
+ while True:
+ action = forward_fn(obs)
+ obs, rew, done, info = env.step(action)
+ returns += rew
+ if done:
+ break
+    print(f'Deploy is finished, final episode return is: {returns}')
+
+
+if __name__ == "__main__":
+ main(
+ main_config=main_config,
+ create_config=create_config,
+ ckpt_path=f'./cliffwalking_dqn_seed0/ckpt/ckpt_best.pth.tar'
+ )
diff --git a/DI-engine/dizoo/cliffwalking/entry/cliffwalking_dqn_main.py b/DI-engine/dizoo/cliffwalking/entry/cliffwalking_dqn_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..8a8c082155ca00b624b33e67b29e817358621dce
--- /dev/null
+++ b/DI-engine/dizoo/cliffwalking/entry/cliffwalking_dqn_main.py
@@ -0,0 +1,50 @@
+import gym
+from ditk import logging
+
+from ding.config import compile_config
+from ding.data import DequeBuffer
+from ding.envs import BaseEnvManagerV2, DingEnvWrapper
+from ding.framework import ding_init, task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import CkptSaver, OffPolicyLearner, StepCollector, data_pusher, eps_greedy_handler, \
+ interaction_evaluator, online_logger
+from ding.model import DQN
+from ding.policy import DQNPolicy
+from ding.utils import set_pkg_seed
+from dizoo.cliffwalking.config.cliffwalking_dqn_config import create_config, main_config
+from dizoo.cliffwalking.envs.cliffwalking_env import CliffWalkingEnv
+
+
+def main():
+ filename = '{}/log.txt'.format(main_config.exp_name)
+ logging.getLogger(with_files=[filename]).setLevel(logging.INFO)
+
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ ding_init(cfg)
+
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: CliffWalkingEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+        env_fn=[lambda: CliffWalkingEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = DQN(**cfg.policy.model)
+ buffer = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ policy = DQNPolicy(cfg.policy, model=model)
+
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(eps_greedy_handler(cfg))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(data_pusher(cfg, buffer))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer))
+ task.use(online_logger(train_show_freq=10))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
+ task.run()
+
+
+if __name__ == '__main__':
+ main()
diff --git a/DI-engine/dizoo/cliffwalking/envs/__init__.py b/DI-engine/dizoo/cliffwalking/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d90c30675ff46506214738f7c16c1dc9b03301bc
--- /dev/null
+++ b/DI-engine/dizoo/cliffwalking/envs/__init__.py
@@ -0,0 +1 @@
+from .cliffwalking_env import CliffWalkingEnv
diff --git a/DI-engine/dizoo/cliffwalking/envs/cliffwalking_env.py b/DI-engine/dizoo/cliffwalking/envs/cliffwalking_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..79d53ba64c24035c660e5072187b926f7b8fcf71
--- /dev/null
+++ b/DI-engine/dizoo/cliffwalking/envs/cliffwalking_env.py
@@ -0,0 +1,111 @@
+import copy
+from typing import List, Union, Optional
+
+import gym
+import numpy as np
+from easydict import EasyDict
+
+from ding.envs.env.base_env import BaseEnv, BaseEnvTimestep
+from ding.torch_utils import to_ndarray
+from ding.utils import ENV_REGISTRY
+
+
+@ENV_REGISTRY.register('cliffwalking')
+class CliffWalkingEnv(BaseEnv):
+
+ def __init__(self, cfg: dict) -> None:
+ self._cfg = EasyDict(
+ env_id='CliffWalking',
+ render_mode='rgb_array',
+            max_episode_steps=300,  # default max trajectory length, used to truncate potentially endless episodes
+ )
+ self._cfg.update(cfg)
+ self._init_flag = False
+ self._replay_path = None
+ self._observation_space = gym.spaces.Box(low=0, high=1, shape=(48, ), dtype=np.float32)
+ self._env = gym.make(
+ "CliffWalking", render_mode=self._cfg.render_mode, max_episode_steps=self._cfg.max_episode_steps
+ )
+ self._action_space = self._env.action_space
+ self._reward_space = gym.spaces.Box(
+ low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
+ )
+
+ def reset(self) -> np.ndarray:
+ if not self._init_flag:
+ self._env = gym.make(
+ "CliffWalking", render_mode=self._cfg.render_mode, max_episode_steps=self._cfg.max_episode_steps
+ )
+ self._init_flag = True
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ dy_seed = self._seed + 100 * np.random.randint(1, 1000)
+ self._env.seed(dy_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+ if self._replay_path is not None:
+ self._env = gym.wrappers.RecordVideo(
+ self._env,
+ video_folder=self._replay_path,
+ episode_trigger=lambda episode_id: True,
+ name_prefix='cliffwalking-{}'.format(id(self))
+ )
+ obs = self._env.reset()
+ obs_encode = self._encode_obs(obs)
+ self._eval_episode_return = 0.
+ return obs_encode
+
+ def close(self) -> None:
+ try:
+ self._env.close()
+ del self._env
+ except:
+ pass
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(seed)
+
+ def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep:
+ if isinstance(action, np.ndarray):
+ if action.shape == (1, ):
+ action = action.squeeze() # 0-dim array
+ action = action.item()
+ obs, reward, done, info = self._env.step(action)
+ obs_encode = self._encode_obs(obs)
+ self._eval_episode_return += reward
+ reward = to_ndarray([reward], dtype=np.float32)
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+ return BaseEnvTimestep(obs_encode, reward, done, info)
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ if replay_path is None:
+ replay_path = './video'
+ self._replay_path = replay_path
+
+ def random_action(self) -> np.ndarray:
+ random_action = self.action_space.sample()
+ if isinstance(random_action, int):
+ random_action = to_ndarray([random_action], dtype=np.int64)
+ return random_action
+
+ def _encode_obs(self, obs) -> np.ndarray:
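+        # The raw observation is a single integer in [0, 47], the agent's cell in the 4 x 12
+        # grid (index = row * 12 + col); it is encoded as a 48-dim one-hot vector to match the
+        # ``obs_shape=48`` used by the DQN config.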
+ onehot = np.zeros(48, dtype=np.float32)
+ onehot[int(obs)] = 1
+ return onehot
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ def __repr__(self) -> str:
+ return "DI-engine CliffWalking Env"
diff --git a/DI-engine/dizoo/cliffwalking/envs/test_cliffwalking_env.py b/DI-engine/dizoo/cliffwalking/envs/test_cliffwalking_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..b378d1a1a85dfe51f1ffccfd558d6e25894775e2
--- /dev/null
+++ b/DI-engine/dizoo/cliffwalking/envs/test_cliffwalking_env.py
@@ -0,0 +1,35 @@
+import numpy as np
+import pytest
+from dizoo.cliffwalking.envs import CliffWalkingEnv
+
+
+@pytest.mark.envtest
+class TestCliffWalkingEnv:
+
+ def test_naive(self):
+ env = CliffWalkingEnv({})
+ env.seed(314, dynamic_seed=False)
+ assert env._seed == 314
+ obs = env.reset()
+ assert obs.shape == (48, )
+ for _ in range(5):
+ env.reset()
+ np.random.seed(314)
+ print('=' * 60)
+ for i in range(10):
+            # Both ``env.random_action()`` and sampling via ``np.random`` / the action space
+            # can generate a legal random action.
+ if i < 5:
+ random_action = np.array([env.action_space.sample()])
+ else:
+ random_action = env.random_action()
+ timestep = env.step(random_action)
+ print(timestep)
+ assert isinstance(timestep.obs, np.ndarray)
+ assert isinstance(timestep.done, bool)
+ assert timestep.obs.shape == (48, )
+ assert timestep.reward.shape == (1, )
+ assert timestep.reward >= env.reward_space.low
+ assert timestep.reward <= env.reward_space.high
+ print(env.observation_space, env.action_space, env.reward_space)
+ env.close()
diff --git a/DI-engine/dizoo/common/__init__.py b/DI-engine/dizoo/common/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/common/policy/__init__.py b/DI-engine/dizoo/common/policy/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/common/policy/md_dqn.py b/DI-engine/dizoo/common/policy/md_dqn.py
new file mode 100644
index 0000000000000000000000000000000000000000..01cc9e13e9a6f4a9cd2bc50ce6449425132f94d2
--- /dev/null
+++ b/DI-engine/dizoo/common/policy/md_dqn.py
@@ -0,0 +1,103 @@
+from typing import Dict, Any
+import torch
+from ding.rl_utils import q_nstep_td_data, q_nstep_td_error
+from ding.policy import DQNPolicy
+from ding.utils import POLICY_REGISTRY
+from ding.policy.common_utils import default_preprocess_learn
+from ding.torch_utils import to_device
+
+
+@POLICY_REGISTRY.register('md_dqn')
+class MultiDiscreteDQNPolicy(DQNPolicy):
+ r"""
+ Overview:
+ Policy class of Multi-discrete action space DQN algorithm.
+ """
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ """
+ Overview:
+            Forward computation of learn mode (policy update). It supports both single and multi-discrete \
+                action spaces, depending on whether ``q_value`` is a list.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, a batch of data for training, values are torch.Tensor or \
+ np.ndarray or dict/list combinations.
+ Returns:
+            - info_dict (:obj:`Dict[str, Any]`): Dict type data, an info dict indicating the training result, which will be \
+ recorded in text log and tensorboard, values are python scalar or a list of scalars.
+ ArgumentsKeys:
+ - necessary: ``obs``, ``action``, ``reward``, ``next_obs``, ``done``
+ - optional: ``value_gamma``, ``IS``
+ ReturnsKeys:
+ - necessary: ``cur_lr``, ``total_loss``, ``priority``
+ - optional: ``action_distribution``
+ """
+ data = default_preprocess_learn(
+ data,
+ use_priority=self._priority,
+ use_priority_IS_weight=self._cfg.priority_IS_weight,
+ ignore_done=self._cfg.learn.ignore_done,
+ use_nstep=True
+ )
+ if self._cuda:
+ data = to_device(data, self._device)
+ # ====================
+ # Q-learning forward
+ # ====================
+ self._learn_model.train()
+ self._target_model.train()
+ # Current q value (main model)
+ q_value = self._learn_model.forward(data['obs'])['logit']
+ # Target q value
+ with torch.no_grad():
+ target_q_value = self._target_model.forward(data['next_obs'])['logit']
+ # Max q value action (main model)
+ target_q_action = self._learn_model.forward(data['next_obs'])['action']
+
+ value_gamma = data.get('value_gamma')
+ if isinstance(q_value, list):
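+            # Multi-discrete case: ``q_value`` is a list with one tensor per action branch; an
+            # n-step TD error is computed for every branch and the losses / priorities are
+            # averaged over the branches.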
+ act_num = len(q_value)
+ loss, td_error_per_sample = [], []
+ q_value_list = []
+ for i in range(act_num):
+ td_data = q_nstep_td_data(
+ q_value[i], target_q_value[i], data['action'][i], target_q_action[i], data['reward'], data['done'],
+ data['weight']
+ )
+ loss_, td_error_per_sample_ = q_nstep_td_error(
+ td_data, self._gamma, nstep=self._nstep, value_gamma=value_gamma
+ )
+ loss.append(loss_)
+ td_error_per_sample.append(td_error_per_sample_.abs())
+ q_value_list.append(q_value[i].mean().item())
+ loss = sum(loss) / (len(loss) + 1e-8)
+ td_error_per_sample = sum(td_error_per_sample) / (len(td_error_per_sample) + 1e-8)
+ q_value_mean = sum(q_value_list) / act_num
+ else:
+ data_n = q_nstep_td_data(
+ q_value, target_q_value, data['action'], target_q_action, data['reward'], data['done'], data['weight']
+ )
+ loss, td_error_per_sample = q_nstep_td_error(
+ data_n, self._gamma, nstep=self._nstep, value_gamma=value_gamma
+ )
+ q_value_mean = q_value.mean().item()
+
+ # ====================
+ # Q-learning update
+ # ====================
+ self._optimizer.zero_grad()
+ loss.backward()
+ if self._cfg.multi_gpu:
+ self.sync_gradients(self._learn_model)
+ self._optimizer.step()
+
+ # =============
+ # after update
+ # =============
+ self._target_model.update(self._learn_model.state_dict())
+ return {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': loss.item(),
+ 'q_value_mean': q_value_mean,
+ 'priority': td_error_per_sample.abs().tolist(),
+ }
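+
+
+# Illustrative note (not part of the original algorithm): in the multi-discrete case the model's
+# ``logit`` output is a list with one (batch, action_dim_i) tensor per action branch, and the
+# per-branch n-step TD losses are averaged, exactly as in ``_forward_learn`` above. A toy,
+# self-contained sketch of that averaging:
+if __name__ == "__main__":
+    branch_q = [torch.randn(4, 3), torch.randn(4, 5)]  # two action branches, batch size 4
+    branch_losses = [q.mean() for q in branch_q]  # stand-in for the per-branch q_nstep_td_error
+    total_loss = sum(branch_losses) / len(branch_losses)
+    print('averaged multi-discrete loss:', total_loss.item())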
diff --git a/DI-engine/dizoo/common/policy/md_ppo.py b/DI-engine/dizoo/common/policy/md_ppo.py
new file mode 100644
index 0000000000000000000000000000000000000000..537744e955684845e988c269a3fa5d0539edd009
--- /dev/null
+++ b/DI-engine/dizoo/common/policy/md_ppo.py
@@ -0,0 +1,188 @@
+from typing import List, Dict, Any, Tuple, Union
+import torch
+
+from ding.policy import PPOPolicy, PPOOffPolicy
+from ding.rl_utils import ppo_data, ppo_error, gae, gae_data
+from ding.utils import POLICY_REGISTRY, split_data_generator
+from ding.torch_utils import to_device
+from ding.policy.common_utils import default_preprocess_learn
+
+
+@POLICY_REGISTRY.register('md_ppo')
+class MultiDiscretePPOPolicy(PPOPolicy):
+ r"""
+ Overview:
+ Policy class of Multi-discrete action space PPO algorithm.
+ """
+
+ def _forward_learn(self, data: Dict[str, Any]) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Forward and backward function of learn mode.
+ Arguments:
+ - data (:obj:`dict`): Dict type data
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`):
+ Including current lr, total_loss, policy_loss, value_loss, entropy_loss, \
+ adv_max, adv_mean, value_max, value_mean, approx_kl, clipfrac
+ """
+ data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=False)
+ if self._cuda:
+ data = to_device(data, self._device)
+ # ====================
+ # PPO forward
+ # ====================
+ return_infos = []
+ self._learn_model.train()
+
+ for epoch in range(self._cfg.learn.epoch_per_collect):
+ if self._recompute_adv:
+ with torch.no_grad():
+ value = self._learn_model.forward(data['obs'], mode='compute_critic')['value']
+ next_value = self._learn_model.forward(data['next_obs'], mode='compute_critic')['value']
+ if self._value_norm:
+ value *= self._running_mean_std.std
+ next_value *= self._running_mean_std.std
+
+ compute_adv_data = gae_data(value, next_value, data['reward'], data['done'], data['traj_flag'])
+                # GAE needs (T, B) shaped input and returns (T, B) shaped output
+ data['adv'] = gae(compute_adv_data, self._gamma, self._gae_lambda)
+ # value = value[:-1]
+ unnormalized_returns = value + data['adv']
+
+ if self._value_norm:
+ data['value'] = value / self._running_mean_std.std
+ data['return'] = unnormalized_returns / self._running_mean_std.std
+ self._running_mean_std.update(unnormalized_returns.cpu().numpy())
+ else:
+ data['value'] = value
+ data['return'] = unnormalized_returns
+
+ else: # don't recompute adv
+ if self._value_norm:
+ unnormalized_return = data['adv'] + data['value'] * self._running_mean_std.std
+ data['return'] = unnormalized_return / self._running_mean_std.std
+ self._running_mean_std.update(unnormalized_return.cpu().numpy())
+ else:
+ data['return'] = data['adv'] + data['value']
+
+ for batch in split_data_generator(data, self._cfg.learn.batch_size, shuffle=True):
+ output = self._learn_model.forward(batch['obs'], mode='compute_actor_critic')
+ adv = batch['adv']
+ if self._adv_norm:
+ # Normalize advantage in a train_batch
+ adv = (adv - adv.mean()) / (adv.std() + 1e-8)
+
+ # Calculate ppo error
+ loss_list = []
+ info_list = []
+ action_num = len(batch['action'])
+ for i in range(action_num):
+ ppo_batch = ppo_data(
+ output['logit'][i], batch['logit'][i], batch['action'][i], output['value'], batch['value'], adv,
+ batch['return'], batch['weight']
+ )
+ ppo_loss, ppo_info = ppo_error(ppo_batch, self._clip_ratio)
+ loss_list.append(ppo_loss)
+ info_list.append(ppo_info)
+ avg_policy_loss = sum([item.policy_loss for item in loss_list]) / action_num
+ avg_value_loss = sum([item.value_loss for item in loss_list]) / action_num
+ avg_entropy_loss = sum([item.entropy_loss for item in loss_list]) / action_num
+ avg_approx_kl = sum([item.approx_kl for item in info_list]) / action_num
+ avg_clipfrac = sum([item.clipfrac for item in info_list]) / action_num
+
+ wv, we = self._value_weight, self._entropy_weight
+ total_loss = avg_policy_loss + wv * avg_value_loss - we * avg_entropy_loss
+
+ self._optimizer.zero_grad()
+ total_loss.backward()
+ self._optimizer.step()
+
+ return_info = {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': total_loss.item(),
+ 'policy_loss': avg_policy_loss.item(),
+ 'value_loss': avg_value_loss.item(),
+ 'entropy_loss': avg_entropy_loss.item(),
+ 'adv_max': adv.max().item(),
+ 'adv_mean': adv.mean().item(),
+ 'value_mean': output['value'].mean().item(),
+ 'value_max': output['value'].max().item(),
+ 'approx_kl': avg_approx_kl,
+ 'clipfrac': avg_clipfrac,
+ }
+ return_infos.append(return_info)
+ return return_infos
+
+
+@POLICY_REGISTRY.register('md_ppo_offpolicy')
+class MultiDiscretePPOOffPolicy(PPOOffPolicy):
+ r"""
+ Overview:
+ Policy class of Multi-discrete action space off-policy PPO algorithm.
+ """
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Forward and backward function of learn mode.
+ Arguments:
+ - data (:obj:`dict`): Dict type data
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`):
+ Including current lr, total_loss, policy_loss, value_loss, entropy_loss, \
+ adv_abs_max, approx_kl, clipfrac
+ """
+ assert not self._nstep_return
+ data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=self._nstep_return)
+ if self._cuda:
+ data = to_device(data, self._device)
+ # ====================
+ # PPO forward
+ # ====================
+
+ self._learn_model.train()
+ # normal ppo
+ output = self._learn_model.forward(data['obs'], mode='compute_actor_critic')
+ adv = data['adv']
+ return_ = data['value'] + adv
+ if self._adv_norm:
+ # Normalize advantage in a total train_batch
+ adv = (adv - adv.mean()) / (adv.std() + 1e-8)
+ # Calculate ppo error
+ loss_list = []
+ info_list = []
+ action_num = len(data['action'])
+ for i in range(action_num):
+ ppodata = ppo_data(
+ output['logit'][i], data['logit'][i], data['action'][i], output['value'], data['value'], adv, return_,
+ data['weight']
+ )
+ ppo_loss, ppo_info = ppo_error(ppodata, self._clip_ratio)
+ loss_list.append(ppo_loss)
+ info_list.append(ppo_info)
+ avg_policy_loss = sum([item.policy_loss for item in loss_list]) / action_num
+ avg_value_loss = sum([item.value_loss for item in loss_list]) / action_num
+ avg_entropy_loss = sum([item.entropy_loss for item in loss_list]) / action_num
+ avg_approx_kl = sum([item.approx_kl for item in info_list]) / action_num
+ avg_clipfrac = sum([item.clipfrac for item in info_list]) / action_num
+
+ wv, we = self._value_weight, self._entropy_weight
+ total_loss = avg_policy_loss + wv * avg_value_loss - we * avg_entropy_loss
+
+ # ====================
+ # PPO update
+ # ====================
+ self._optimizer.zero_grad()
+ total_loss.backward()
+ self._optimizer.step()
+ return {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': total_loss.item(),
+            'policy_loss': avg_policy_loss.item(),
+            'value_loss': avg_value_loss.item(),
+            'entropy_loss': avg_entropy_loss.item(),
+ 'adv_abs_max': adv.abs().max().item(),
+ 'approx_kl': avg_approx_kl,
+ 'clipfrac': avg_clipfrac,
+ }
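+
+
+# Usage sketch (illustrative; the config keys below follow DI-engine's registry convention and are
+# assumptions, not part of this file): either policy can be selected by its registered name, e.g.
+#
+#   policy=dict(type='md_ppo', import_names=['dizoo.common.policy.md_ppo'])
+#
+# where the model's ``logit`` output, as well as the stored ``logit`` and ``action`` fields in the
+# collected data, are lists with one entry per discrete action branch.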
diff --git a/DI-engine/dizoo/common/policy/md_rainbow_dqn.py b/DI-engine/dizoo/common/policy/md_rainbow_dqn.py
new file mode 100644
index 0000000000000000000000000000000000000000..8f0e83b7ba96d7f923745d33f3aeaec7d43619a8
--- /dev/null
+++ b/DI-engine/dizoo/common/policy/md_rainbow_dqn.py
@@ -0,0 +1,105 @@
+from typing import Dict, Any
+import torch
+from ding.torch_utils import to_device
+from ding.rl_utils import dist_nstep_td_data, dist_nstep_td_error, dist_1step_td_data, dist_1step_td_error
+from ding.policy import RainbowDQNPolicy
+from ding.utils import POLICY_REGISTRY
+from ding.policy.common_utils import default_preprocess_learn
+
+
+@POLICY_REGISTRY.register('md_rainbow_dqn')
+class MultiDiscreteRainbowDQNPolicy(RainbowDQNPolicy):
+ r"""
+ Overview:
+        Multi-discrete action space Rainbow DQN algorithm.
+ """
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ """
+ Overview:
+            Forward and backward function of learn mode: acquire the data, calculate the loss and \
+            optimize the learner model.
+
+ Arguments:
+ - data (:obj:`dict`): Dict type data, including at least ['obs', 'next_obs', 'reward', 'action']
+
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`): Including cur_lr, total_loss and priority
+ - cur_lr (:obj:`float`): current learning rate
+ - total_loss (:obj:`float`): the calculated loss
+ - priority (:obj:`list`): the priority of samples
+ """
+ data = default_preprocess_learn(
+ data,
+ use_priority=self._priority,
+ use_priority_IS_weight=self._cfg.priority_IS_weight,
+ ignore_done=self._cfg.learn.ignore_done,
+ use_nstep=True
+ )
+ if self._cuda:
+ data = to_device(data, self._device)
+ # ====================
+ # Rainbow forward
+ # ====================
+ self._learn_model.train()
+ self._target_model.train()
+ # reset noise of noisenet for both main model and target model
+ self._reset_noise(self._learn_model)
+ self._reset_noise(self._target_model)
+ q_dist = self._learn_model.forward(data['obs'])['distribution']
+ with torch.no_grad():
+ target_q_dist = self._target_model.forward(data['next_obs'])['distribution']
+ self._reset_noise(self._learn_model)
+ target_q_action = self._learn_model.forward(data['next_obs'])['action']
+
+ value_gamma = data.get('value_gamma', None)
+ if isinstance(q_dist, torch.Tensor):
+ td_data = dist_nstep_td_data(
+ q_dist, target_q_dist, data['action'], target_q_action, data['reward'], data['done'], data['weight']
+ )
+ loss, td_error_per_sample = dist_nstep_td_error(
+ td_data,
+ self._gamma,
+ self._v_min,
+ self._v_max,
+ self._n_atom,
+ nstep=self._nstep,
+ value_gamma=value_gamma
+ )
+ else:
+ act_num = len(q_dist)
+ losses = []
+ td_error_per_samples = []
+ for i in range(act_num):
+ td_data = dist_nstep_td_data(
+ q_dist[i], target_q_dist[i], data['action'][i], target_q_action[i], data['reward'], data['done'],
+ data['weight']
+ )
+ td_loss, td_error_per_sample = dist_nstep_td_error(
+ td_data,
+ self._gamma,
+ self._v_min,
+ self._v_max,
+ self._n_atom,
+ nstep=self._nstep,
+ value_gamma=value_gamma
+ )
+ losses.append(td_loss)
+ td_error_per_samples.append(td_error_per_sample)
+ loss = sum(losses) / (len(losses) + 1e-8)
+ td_error_per_sample_mean = sum(td_error_per_samples) / (len(td_error_per_samples) + 1e-8)
+ # ====================
+ # Rainbow update
+ # ====================
+ self._optimizer.zero_grad()
+ loss.backward()
+ self._optimizer.step()
+ # =============
+ # after update
+ # =============
+ self._target_model.update(self._learn_model.state_dict())
+ return {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': loss.item(),
+ 'priority': td_error_per_sample_mean.abs().tolist(),
+ }
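+
+
+# Note (illustrative, not part of the original file): in the multi-discrete branch above every
+# td_error_per_sample tensor has shape (batch, ), so the elementwise mean across branches keeps
+# one priority value per sample, e.g. with two branches:
+#
+#   priority = (td_error_branch_0 + td_error_branch_1) / 2  # still shape (batch, )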
diff --git a/DI-engine/dizoo/competitive_rl/README.md b/DI-engine/dizoo/competitive_rl/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..1d87730953d190fc05cfd5f4526f68628056c3d1
--- /dev/null
+++ b/DI-engine/dizoo/competitive_rl/README.md
@@ -0,0 +1,2 @@
+Environment "Competitive RL"'s original repo is https://github.com/cuhkrlcourse/competitive-rl.
+You can refer to it for guide on installation and usage.
\ No newline at end of file
diff --git a/DI-engine/dizoo/competitive_rl/__init__.py b/DI-engine/dizoo/competitive_rl/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/competitive_rl/config/cpong_dqn_config.py b/DI-engine/dizoo/competitive_rl/config/cpong_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..facf75237987e30eef0f2b67928cc846a5d508a1
--- /dev/null
+++ b/DI-engine/dizoo/competitive_rl/config/cpong_dqn_config.py
@@ -0,0 +1,98 @@
+from easydict import EasyDict
+from ding.config import parallel_transform
+
+cpong_dqn_config = dict(
+ env=dict(
+ collector_env_num=16,
+ collector_episode_num=2,
+ evaluator_env_num=8,
+ evaluator_episode_num=2,
+ stop_value=20,
+ opponent_type="builtin", # opponent_type is only used in evaluator
+ env_id='cPongDouble-v0',
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=3,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=1,
+ discount_factor=0.99,
+ learn=dict(
+ batch_size=16,
+ learning_rate=0.001,
+ learner=dict(
+ learner_num=1,
+ send_policy_freq=1,
+ ),
+ ),
+ collect=dict(
+ n_sample=16,
+ collector=dict(
+ collector_num=2,
+ update_policy_second=3,
+ ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=50, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=100000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=100000,
+ enable_track_used_data=False,
+ ),
+ commander=dict(
+ collector_task_space=2,
+ learner_task_space=1,
+ eval_interval=5,
+ league=dict(),
+ ),
+ ),
+ )
+)
+cpong_dqn_config = EasyDict(cpong_dqn_config)
+main_config = cpong_dqn_config
+
+cpong_dqn_create_config = dict(
+ env=dict(
+ import_names=['dizoo.competitive_rl.envs.competitive_rl_env'],
+ type='competitive_rl',
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='dqn_command'),
+ learner=dict(type='base', import_names=['ding.worker.learner.base_learner']),
+ collector=dict(
+ type='marine',
+ import_names=['ding.worker.collector.marine_parallel_collector'],
+ ),
+ commander=dict(
+ type='one_vs_one',
+ import_names=['ding.worker.coordinator.one_vs_one_parallel_commander'],
+ ),
+ comm_learner=dict(
+ type='flask_fs',
+ import_names=['ding.worker.learner.comm.flask_fs_learner'],
+ ),
+ comm_collector=dict(
+ type='flask_fs',
+ import_names=['ding.worker.collector.comm.flask_fs_collector'],
+ ),
+)
+cpong_dqn_create_config = EasyDict(cpong_dqn_create_config)
+create_config = cpong_dqn_create_config
+
+cpong_dqn_system_config = dict(
+ coordinator=dict(),
+ path_data='./data',
+ path_policy='./policy',
+ communication_mode='auto',
+ learner_gpu_num=0,
+)
+cpong_dqn_system_config = EasyDict(cpong_dqn_system_config)
+system_config = cpong_dqn_system_config
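+
+
+# Launch sketch (assumption: ``parallel_pipeline`` from ``ding.entry`` accepts the three configs
+# defined above; check the entry function's exact signature before relying on this snippet).
+if __name__ == "__main__":
+    from ding.entry import parallel_pipeline
+    parallel_pipeline([main_config, create_config, system_config], seed=0)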
diff --git a/DI-engine/dizoo/competitive_rl/envs/__init__.py b/DI-engine/dizoo/competitive_rl/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b24801e62fec9f2c280705e49a51c64510d3e364
--- /dev/null
+++ b/DI-engine/dizoo/competitive_rl/envs/__init__.py
@@ -0,0 +1 @@
+from .competitive_rl_env import CompetitiveRlEnv
diff --git a/DI-engine/dizoo/competitive_rl/envs/competitive_rl_env.py b/DI-engine/dizoo/competitive_rl/envs/competitive_rl_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..4db8964c9d07acc2258016debb2d146c94034741
--- /dev/null
+++ b/DI-engine/dizoo/competitive_rl/envs/competitive_rl_env.py
@@ -0,0 +1,181 @@
+from typing import Any, Union, List
+import copy
+import numpy as np
+import gym
+import competitive_rl
+
+from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo, update_shape
+from ding.envs.common.env_element import EnvElement, EnvElementInfo
+from ding.envs.common.common_function import affine_transform
+from ding.torch_utils import to_ndarray, to_list
+from .competitive_rl_env_wrapper import BuiltinOpponentWrapper, wrap_env
+from ding.utils import ENV_REGISTRY
+
+competitive_rl.register_competitive_envs()
+"""
+The observation spaces:
+cPong-v0: Box(210, 160, 3)
+cPongDouble-v0: Tuple(Box(210, 160, 3), Box(210, 160, 3))
+cCarRacing-v0: Box(96, 96, 1)
+cCarRacingDouble-v0: Box(96, 96, 1)
+
+The action spaces:
+cPong-v0: Discrete(3)
+cPongDouble-v0: Tuple(Discrete(3), Discrete(3))
+cCarRacing-v0: Box(2,)
+cCarRacingDouble-v0: Dict(0:Box(2,), 1:Box(2,))
+
+cPongTournament-v0
+"""
+
+COMPETITIVERL_INFO_DICT = {
+ 'cPongDouble-v0': BaseEnvInfo(
+ agent_num=1,
+ obs_space=EnvElementInfo(
+ shape=(210, 160, 3),
+ # shape=(4, 84, 84),
+ value={
+ 'min': 0,
+ 'max': 255,
+ 'dtype': np.float32
+ },
+ ),
+ act_space=EnvElementInfo(
+            shape=(1, ),  # different from https://github.com/cuhkrlcourse/competitive-rl#usage
+ value={
+ 'min': 0,
+ 'max': 3,
+ 'dtype': np.float32
+ },
+ ),
+ rew_space=EnvElementInfo(
+ shape=(1, ),
+ value={
+ 'min': np.float32("-inf"),
+ 'max': np.float32("inf"),
+ 'dtype': np.float32
+ },
+ ),
+ use_wrappers=None,
+ ),
+}
+
+
+@ENV_REGISTRY.register('competitive_rl')
+class CompetitiveRlEnv(BaseEnv):
+
+ def __init__(self, cfg: dict) -> None:
+ self._cfg = cfg
+ self._env_id = self._cfg.env_id
+
+ # opponent_type is used to control builtin opponent agent, which is useful in evaluator.
+ is_evaluator = self._cfg.get("is_evaluator", False)
+ opponent_type = None
+ if is_evaluator:
+ opponent_type = self._cfg.get("opponent_type", None)
+ self._builtin_wrap = self._env_id == "cPongDouble-v0" and is_evaluator and opponent_type == "builtin"
+ self._opponent = self._cfg.get('eval_opponent', 'RULE_BASED')
+
+ self._init_flag = False
+
+ def reset(self) -> np.ndarray:
+ if not self._init_flag:
+ self._env = self._make_env(only_info=False)
+ self._init_flag = True
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+ obs = self._env.reset()
+ obs = to_ndarray(obs)
+ obs = self.process_obs(obs) # process
+
+ if self._builtin_wrap:
+ self._eval_episode_return = np.array([0.])
+ else:
+ self._eval_episode_return = np.array([0., 0.])
+ return obs
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep:
+ action = to_ndarray(action)
+ action = self.process_action(action) # process
+
+ obs, rew, done, info = self._env.step(action)
+
+ if not isinstance(rew, tuple):
+ rew = [rew]
+ rew = np.array(rew)
+ self._eval_episode_return += rew
+
+ obs = to_ndarray(obs)
+ obs = self.process_obs(obs) # process
+
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def info(self) -> BaseEnvInfo:
+ if self._env_id in COMPETITIVERL_INFO_DICT:
+ info = copy.deepcopy(COMPETITIVERL_INFO_DICT[self._env_id])
+ info.use_wrappers = self._make_env(only_info=True)
+ obs_shape, act_shape, rew_shape = update_shape(
+ info.obs_space.shape, info.act_space.shape, info.rew_space.shape, info.use_wrappers.split('\n')
+ )
+ info.obs_space.shape = obs_shape
+ info.act_space.shape = act_shape
+ info.rew_space.shape = rew_shape
+ if not self._builtin_wrap:
+ info.obs_space.shape = (2, ) + info.obs_space.shape
+ info.act_space.shape = (2, )
+ info.rew_space.shape = (2, )
+ return info
+ else:
+ raise NotImplementedError('{} not found in COMPETITIVERL_INFO_DICT [{}]'\
+ .format(self._env_id, COMPETITIVERL_INFO_DICT.keys()))
+
+ def _make_env(self, only_info=False):
+ return wrap_env(self._env_id, self._builtin_wrap, self._opponent, only_info=only_info)
+
+ def __repr__(self) -> str:
+        return "DI-engine Competitive RL Env({})".format(self._cfg.env_id)
+
+ @staticmethod
+ def create_collector_env_cfg(cfg: dict) -> List[dict]:
+ collector_cfg = copy.deepcopy(cfg)
+ collector_env_num = collector_cfg.pop('collector_env_num', 1)
+ collector_cfg.is_evaluator = False
+ return [collector_cfg for _ in range(collector_env_num)]
+
+ @staticmethod
+ def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
+ evaluator_cfg = copy.deepcopy(cfg)
+ evaluator_env_num = evaluator_cfg.pop('evaluator_env_num', 1)
+ evaluator_cfg.is_evaluator = True
+ return [evaluator_cfg for _ in range(evaluator_env_num)]
+
+ def process_action(self, action: np.ndarray) -> Union[tuple, dict, np.ndarray]:
+        # If in a double-agent env, transform the action passed in from outside into a tuple or dict.
+ if self._env_id == "cPongDouble-v0" and not self._builtin_wrap:
+ return (action[0].squeeze(), action[1].squeeze())
+ elif self._env_id == "cCarRacingDouble-v0":
+ return {0: action[0].squeeze(), 1: action[1].squeeze()}
+ else:
+ return action.squeeze()
+
+ def process_obs(self, obs: Union[tuple, np.ndarray]) -> Union[tuple, np.ndarray]:
+        # Duplicate the observation in the car racing double-agent env, to keep it aligned with the pong double-agent env.
+ if self._env_id == "cCarRacingDouble-v0":
+ obs = np.stack([obs, copy.deepcopy(obs)])
+ return obs
diff --git a/DI-engine/dizoo/competitive_rl/envs/competitive_rl_env_wrapper.py b/DI-engine/dizoo/competitive_rl/envs/competitive_rl_env_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..db43ed97ffa6eaa44d0dcf3aeefcedae73deedc1
--- /dev/null
+++ b/DI-engine/dizoo/competitive_rl/envs/competitive_rl_env_wrapper.py
@@ -0,0 +1,231 @@
+import cv2
+import gym
+import os.path as osp
+import numpy as np
+from typing import Union, Optional
+from collections import deque
+from competitive_rl.pong.builtin_policies import get_builtin_agent_names, single_obs_space, single_act_space, get_random_policy, get_rule_based_policy
+from competitive_rl.utils.policy_serving import Policy
+
+
+def get_compute_action_function_ours(agent_name, num_envs=1):
+ resource_dir = osp.join(osp.dirname(__file__), "resources", "pong")
+ if agent_name == "STRONG":
+ return Policy(
+ single_obs_space,
+ single_act_space,
+ num_envs,
+ osp.join(resource_dir, "checkpoint-strong.pkl"),
+ use_light_model=False
+ )
+ if agent_name == "MEDIUM":
+ return Policy(
+ single_obs_space,
+ single_act_space,
+ num_envs,
+ osp.join(resource_dir, "checkpoint-medium.pkl"),
+ use_light_model=True
+ )
+ if agent_name == "ALPHA_PONG":
+ return Policy(
+ single_obs_space,
+ single_act_space,
+ num_envs,
+ osp.join(resource_dir, "checkpoint-alphapong.pkl"),
+ use_light_model=False
+ )
+ if agent_name == "WEAK":
+ return Policy(
+ single_obs_space,
+ single_act_space,
+ num_envs,
+ osp.join(resource_dir, "checkpoint-weak.pkl"),
+ use_light_model=True
+ )
+ if agent_name == "RANDOM":
+ return get_random_policy(num_envs)
+ if agent_name == "RULE_BASED":
+ return get_rule_based_policy(num_envs)
+ raise ValueError("Unknown agent name: {}".format(agent_name))
+
+
+class BuiltinOpponentWrapper(gym.Wrapper):
+
+ def __init__(self, env: 'gym.Env', num_envs: int = 1) -> None: # noqa
+ super().__init__(env)
+ self.agents = {
+ agent_name: get_compute_action_function_ours(agent_name, num_envs)
+ for agent_name in get_builtin_agent_names()
+ }
+ self.agent_names = list(self.agents)
+ self.prev_opponent_obs = None
+ self.current_opponent_name = "RULE_BASED"
+ self.current_opponent = self.agents[self.current_opponent_name]
+ self.observation_space = env.observation_space[0]
+ self.action_space = env.action_space[0]
+ self.num_envs = num_envs
+
+ def reset_opponent(self, agent_name: str) -> None:
+ assert agent_name in self.agent_names, (agent_name, self.agent_names)
+ self.current_opponent_name = agent_name
+ self.current_opponent = self.agents[self.current_opponent_name]
+
+ def step(self, action):
+ tuple_action = (action.item(), self.current_opponent(self.prev_opponent_obs))
+ obs, rew, done, info = self.env.step(tuple_action)
+ self.prev_opponent_obs = obs[1]
+ # if done.ndim == 2:
+ # done = done[:, 0]
+ # return obs[0], rew[:, 0].reshape(-1, 1), done.reshape(-1, 1), info
+ return obs[0], rew[0], done, info
+
+ def reset(self):
+ obs = self.env.reset()
+ self.prev_opponent_obs = obs[1]
+ return obs[0]
+
+ def seed(self, s):
+ self.env.seed(s)
+
+
+def wrap_env(env_id, builtin_wrap, opponent, frame_stack=4, warp_frame=True, only_info=False):
+    """Configure the competitive-rl environment in a DeepMind-Atari style. The observation is
+    channel-first: (c, h, w) instead of (h, w, c).
+
+    :param str env_id: the competitive-rl environment id.
+    :param bool builtin_wrap: wrap the env with a builtin opponent, turning it into a single-agent env.
+    :param str opponent: name of the builtin opponent used when ``builtin_wrap`` is True.
+    :param int frame_stack: wrap the frame stacking wrapper.
+    :param bool warp_frame: wrap the grayscale + resize observation wrapper.
+    :param bool only_info: if True, return the newline-separated wrapper names instead of the wrapped env.
+    :return: the wrapped environment, or the wrapper info string when ``only_info`` is True.
+    """
+ if not only_info:
+ env = gym.make(env_id)
+ if builtin_wrap:
+ env = BuiltinOpponentWrapper(env)
+ env.reset_opponent(opponent)
+
+ if warp_frame:
+ env = WarpFrameWrapperCompetitveRl(env, builtin_wrap)
+ if frame_stack:
+ env = FrameStackWrapperCompetitiveRl(env, frame_stack, builtin_wrap)
+ return env
+ else:
+ wrapper_info = ''
+ if builtin_wrap:
+ wrapper_info += BuiltinOpponentWrapper.__name__ + '\n'
+        if warp_frame:
+            wrapper_info += WarpFrameWrapperCompetitveRl.__name__ + '\n'
+        if frame_stack:
+            wrapper_info += FrameStackWrapperCompetitiveRl.__name__ + '\n'
+ return wrapper_info
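+
+# Usage note (illustrative): ``CompetitiveRlEnv`` calls ``wrap_env`` twice -- once with
+# ``only_info=False`` to build the wrapped gym env and once with ``only_info=True`` to obtain the
+# newline-separated wrapper names consumed by ``update_shape`` in ``CompetitiveRlEnv.info``, e.g.
+#
+#   env = wrap_env('cPongDouble-v0', builtin_wrap=True, opponent='RULE_BASED')
+#   names = wrap_env('cPongDouble-v0', True, 'RULE_BASED', only_info=True)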
+
+
+class WarpFrameWrapperCompetitveRl(gym.ObservationWrapper):
+ """Warp frames to 84x84 as done in the Nature paper and later work.
+
+ :param gym.Env env: the environment to wrap.
+ """
+
+ def __init__(self, env, builtin_wrap):
+ super().__init__(env)
+ self.size = 84
+ obs_space = env.observation_space
+ self.builtin_wrap = builtin_wrap
+ if builtin_wrap:
+ # single player
+ self.observation_space = gym.spaces.Box(
+ low=np.min(obs_space.low),
+ high=np.max(obs_space.high),
+ shape=(self.size, self.size),
+ dtype=obs_space.dtype
+ )
+ else:
+ # double player
+ self.observation_space = gym.spaces.tuple.Tuple(
+ [
+ gym.spaces.Box(
+ low=np.min(obs_space[0].low),
+ high=np.max(obs_space[0].high),
+ shape=(self.size, self.size),
+ dtype=obs_space[0].dtype
+ ) for _ in range(len(obs_space))
+ ]
+ )
+
+ def observation(self, frame):
+ """returns the current observation from a frame"""
+ if self.builtin_wrap:
+ frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
+ return cv2.resize(frame, (self.size, self.size), interpolation=cv2.INTER_AREA)
+ else:
+ frames = []
+ for one_frame in frame:
+ one_frame = cv2.cvtColor(one_frame, cv2.COLOR_RGB2GRAY)
+ one_frame = cv2.resize(one_frame, (self.size, self.size), interpolation=cv2.INTER_AREA)
+ frames.append(one_frame)
+ return frames
+
+
+class FrameStackWrapperCompetitiveRl(gym.Wrapper):
+ """Stack n_frames last frames.
+
+ :param gym.Env env: the environment to wrap.
+ :param int n_frames: the number of frames to stack.
+ """
+
+ def __init__(self, env, n_frames, builtin_wrap):
+ super().__init__(env)
+ self.n_frames = n_frames
+
+ self.builtin_wrap = builtin_wrap
+ obs_space = env.observation_space
+ if self.builtin_wrap:
+ self.frames = deque([], maxlen=n_frames)
+ shape = (n_frames, ) + obs_space.shape
+ self.observation_space = gym.spaces.Box(
+ low=np.min(obs_space.low), high=np.max(obs_space.high), shape=shape, dtype=obs_space.dtype
+ )
+ else:
+ self.frames = [deque([], maxlen=n_frames) for _ in range(len(obs_space))]
+ shape = (n_frames, ) + obs_space[0].shape
+ self.observation_space = gym.spaces.tuple.Tuple(
+ [
+ gym.spaces.Box(
+ low=np.min(obs_space[0].low),
+ high=np.max(obs_space[0].high),
+ shape=shape,
+ dtype=obs_space[0].dtype
+ ) for _ in range(len(obs_space))
+ ]
+ )
+
+ def reset(self):
+ if self.builtin_wrap:
+ obs = self.env.reset()
+ for _ in range(self.n_frames):
+ self.frames.append(obs)
+ return self._get_ob(self.frames)
+ else:
+ obs = self.env.reset()
+ for i, one_obs in enumerate(obs):
+ for _ in range(self.n_frames):
+ self.frames[i].append(one_obs)
+ return np.stack([self._get_ob(self.frames[i]) for i in range(len(obs))])
+
+ def step(self, action):
+ obs, reward, done, info = self.env.step(action)
+ if self.builtin_wrap:
+ self.frames.append(obs)
+ return self._get_ob(self.frames), reward, done, info
+ else:
+ for i, one_obs in enumerate(obs):
+ self.frames[i].append(one_obs)
+ return np.stack([self._get_ob(self.frames[i]) for i in range(len(obs))], axis=0), reward, done, info
+
+ @staticmethod
+ def _get_ob(frames):
+        # the original wrapper uses `LazyFrames`, but since we use an np buffer
+        # it has no effect here
+ return np.stack(frames, axis=0)
diff --git a/DI-engine/dizoo/competitive_rl/envs/test_competitive_rl.py b/DI-engine/dizoo/competitive_rl/envs/test_competitive_rl.py
new file mode 100644
index 0000000000000000000000000000000000000000..4608debd8e87f2b450972f682eed9d74fea67722
--- /dev/null
+++ b/DI-engine/dizoo/competitive_rl/envs/test_competitive_rl.py
@@ -0,0 +1,70 @@
+import competitive_rl
+import pytest
+import numpy as np
+from easydict import EasyDict
+from dizoo.competitive_rl.envs.competitive_rl_env import CompetitiveRlEnv
+
+
+@pytest.mark.envtest
+class TestCompetitiveRlEnv:
+
+ def test_pong_single(self):
+ cfg = dict(
+ opponent_type="builtin",
+ is_evaluator=True,
+ env_id='cPongDouble-v0',
+ )
+ cfg = EasyDict(cfg)
+ env = CompetitiveRlEnv(cfg)
+ env.seed(314)
+ assert env._seed == 314
+ obs = env.reset()
+ assert obs.shape == env.info().obs_space.shape
+ # act_shape = env.info().act_space.shape
+ act_val = env.info().act_space.value
+ min_val, max_val = act_val['min'], act_val['max']
+ np.random.seed(314)
+ i = 0
+ while True:
+ random_action = np.random.randint(min_val, max_val, size=(1, ))
+ timestep = env.step(random_action)
+ if timestep.done:
+ print(timestep)
+ print('Env episode has {} steps'.format(i))
+ break
+ assert isinstance(timestep.obs, np.ndarray)
+ assert isinstance(timestep.done, bool)
+ assert timestep.obs.shape == env.info().obs_space.shape
+ assert timestep.reward.shape == env.info().rew_space.shape
+ assert timestep.reward >= env.info().rew_space.value['min']
+ assert timestep.reward <= env.info().rew_space.value['max']
+ i += 1
+ print(env.info())
+ env.close()
+
+ def test_pong_double(self):
+ cfg = dict(env_id='cPongDouble-v0', )
+ cfg = EasyDict(cfg)
+ env = CompetitiveRlEnv(cfg)
+ env.seed(314)
+ assert env._seed == 314
+ obs = env.reset()
+ assert obs.shape == env.info().obs_space.shape
+ act_val = env.info().act_space.value
+ min_val, max_val = act_val['min'], act_val['max']
+ np.random.seed(314)
+ i = 0
+ while True:
+ random_action = [np.random.randint(min_val, max_val, size=(1, )) for _ in range(2)]
+ timestep = env.step(random_action)
+ if timestep.done:
+ print(timestep)
+ print('Env episode has {} steps'.format(i))
+ break
+ assert isinstance(timestep.obs, np.ndarray)
+ assert isinstance(timestep.done, bool)
+ assert timestep.obs.shape == env.info().obs_space.shape
+ assert timestep.reward.shape == env.info().rew_space.shape
+ i += 1
+ print(env.info())
+ env.close()
diff --git a/DI-engine/dizoo/d4rl/__init__.py b/DI-engine/dizoo/d4rl/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/d4rl/config/__init__.py b/DI-engine/dizoo/d4rl/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..92bca79cc2eb0328c88558e4075006c2c55558ad
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/__init__.py
@@ -0,0 +1,3 @@
+# from .hopper_cql_config import hopper_cql_config
+# from .hopper_expert_cql_config import hopper_expert_cql_config
+# from .hopper_medium_cql_config import hopper_medium_cql_config
diff --git a/DI-engine/dizoo/d4rl/config/antmaze_umaze_pd_config.py b/DI-engine/dizoo/d4rl/config/antmaze_umaze_pd_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..96ca022545b1ade73661305a1a1552c5db2646e6
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/antmaze_umaze_pd_config.py
@@ -0,0 +1,101 @@
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="antmaze_umaze_pd_seed0",
+ env=dict(
+ env_id='antmaze-umaze-v0',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ returns_scale=1.0,
+ termination_penalty=-100,
+ max_path_length=1000,
+ use_padding=True,
+ include_returns=True,
+ normed=False,
+ stop_value=8000,
+ horizon=256,
+ obs_dim=29,
+ action_dim=8,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ diffuser_model='GaussianDiffusion',
+ diffuser_model_cfg=dict(
+ model='DiffusionUNet1d',
+ model_cfg=dict(
+ transition_dim=37,
+ dim=32,
+ dim_mults=[1, 2, 4, 8],
+ returns_condition=False,
+ kernel_size=5,
+ attention=False,
+ ),
+ horizon=256,
+ obs_dim=29,
+ action_dim=8,
+ n_timesteps=20,
+ predict_epsilon=False,
+ loss_discount=1,
+ action_weight=10,
+ ),
+ value_model='ValueDiffusion',
+ value_model_cfg=dict(
+ model='TemporalValue',
+ model_cfg=dict(
+                    horizon=256,
+ transition_dim=37,
+ dim=32,
+ dim_mults=[1, 2, 4, 8],
+ kernel_size=5,
+ ),
+ horizon=256,
+ obs_dim=29,
+ action_dim=8,
+ n_timesteps=20,
+ predict_epsilon=True,
+ loss_discount=1,
+ ),
+ n_guide_steps=2,
+ scale=0.1,
+ t_stopgrad=2,
+ scale_grad_by_std=True,
+ ),
+ normalizer='GaussianNormalizer',
+ learn=dict(
+ data_path=None,
+ train_epoch=60000,
+ gradient_accumulate_every=2,
+ batch_size=32,
+ learning_rate=2e-4,
+ discount_factor=0.99,
+ plan_batch_size=64,
+ learner=dict(hook=dict(save_ckpt_after_iter=1000000000, )),
+ ),
+ collect=dict(data_type='diffuser_traj', ),
+ eval=dict(
+ evaluator=dict(eval_freq=500, ),
+ test_ret=0.9,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='pd',
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
\ No newline at end of file
diff --git a/DI-engine/dizoo/d4rl/config/halfcheetah_expert_cql_config.py b/DI-engine/dizoo/d4rl/config/halfcheetah_expert_cql_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..58ead98dc57868cd0e7f4be78128f49b1a4805ba
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/halfcheetah_expert_cql_config.py
@@ -0,0 +1,55 @@
+# You can conduct Experiments on D4RL with this config file through the following command:
+# cd ../entry && python d4rl_cql_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="halfcheetah_expert_cql_seed0",
+ env=dict(
+ env_id='halfcheetah-expert-v2',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ ),
+ learn=dict(
+ data_path=None,
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_q=3e-4,
+ learning_rate_policy=1e-4,
+ learning_rate_alpha=1e-4,
+ alpha=0.2,
+ auto_alpha=False,
+ lagrange_thresh=-1.0,
+ min_q_weight=5.0,
+ ),
+ collect=dict(data_type='d4rl', ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='cql',
+ import_names=['ding.policy.cql'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
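+
+
+# Alternative launch sketch (assumption: ``serial_pipeline_offline`` is the offline-RL entry point;
+# the ``d4rl_cql_main.py`` script mentioned at the top of this file remains the reference way to run it).
+if __name__ == "__main__":
+    from ding.entry import serial_pipeline_offline
+    serial_pipeline_offline([main_config, create_config], seed=0)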
diff --git a/DI-engine/dizoo/d4rl/config/halfcheetah_expert_dt_config.py b/DI-engine/dizoo/d4rl/config/halfcheetah_expert_dt_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..617d17bc73a0f2830eb2f4bb257567912198b431
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/halfcheetah_expert_dt_config.py
@@ -0,0 +1,70 @@
+from easydict import EasyDict
+from copy import deepcopy
+
+halfcheetah_dt_config = dict(
+ exp_name='dt_log/d4rl/halfcheetah/halfcheetah_expert_dt_seed0',
+ env=dict(
+ env_id='HalfCheetah-v3',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ dataset=dict(
+ env_type='mujoco',
+ rtg_scale=1000,
+ context_len=30,
+ data_dir_prefix='d4rl/halfcheetah_expert-v2.pkl',
+ ),
+ policy=dict(
+ cuda=True,
+ stop_value=6000,
+ state_mean=None,
+ state_std=None,
+ evaluator_env_num=8,
+ env_name='HalfCheetah-v3',
+ rtg_target=6000, # max target return to go
+        max_eval_ep_len=1000,  # max length of one episode
+ wt_decay=1e-4,
+ warmup_steps=10000,
+ context_len=20,
+ weight_decay=0.1,
+ clip_grad_norm_p=0.25,
+ model=dict(
+            state_dim=17,
+            act_dim=6,
+ n_blocks=3,
+ h_dim=128,
+ context_len=20,
+ n_heads=1,
+ drop_p=0.1,
+ continuous=True,
+ ),
+ batch_size=64,
+ learning_rate=1e-4,
+ collect=dict(
+ data_type='d4rl_trajectory',
+ unroll_len=1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ ),
+)
+
+halfcheetah_dt_config = EasyDict(halfcheetah_dt_config)
+main_config = halfcheetah_dt_config
+halfcheetah_dt_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dt'),
+)
+halfcheetah_dt_create_config = EasyDict(halfcheetah_dt_create_config)
+create_config = halfcheetah_dt_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_dt
+ config = deepcopy([main_config, create_config])
+ serial_pipeline_dt(config, seed=0, max_train_iter=1000)
diff --git a/DI-engine/dizoo/d4rl/config/halfcheetah_expert_td3bc_config.py b/DI-engine/dizoo/d4rl/config/halfcheetah_expert_td3bc_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..3cae080e92224aafe0c2201e56c684583ed8e6bf
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/halfcheetah_expert_td3bc_config.py
@@ -0,0 +1,65 @@
+# You can conduct Experiments on D4RL with this config file through the following command:
+# cd ../entry && python d4rl_td3_bc_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name='halfcheetah_expert_td3-bc_seed0',
+ env=dict(
+ env_id='halfcheetah-expert-v2',
+ norm_obs=dict(
+ use_norm=True,
+ offline_stats=dict(use_offline_stats=True, ),
+ ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ ),
+ learn=dict(
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_actor=0.0003,
+ learning_rate_critic=0.0003,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.2,
+ noise_range={
+ 'min': -0.5,
+ 'max': 0.5
+ },
+ alpha=2.5,
+ ),
+ collect=dict(
+ data_type='d4rl',
+ data_path=None,
+ ),
+ eval=dict(evaluator=dict(eval_freq=10000, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+main_config = EasyDict(main_config)
+main_config = main_config
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(
+ cfg_type='BaseEnvManagerDict',
+ type='base',
+ ),
+ policy=dict(
+ type='td3_bc',
+ import_names=['ding.policy.td3_bc'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/halfcheetah_medium_bcq_config.py b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_bcq_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..c0199dcb0905cf2b2ab7445ac564bf78de41df98
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_bcq_config.py
@@ -0,0 +1,55 @@
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="halfcheetah_medium_bcq_seed0",
+ env=dict(
+ env_id='halfcheetah-medium-v2',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=7000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ actor_head_hidden_size=[400, 300],
+ critic_head_hidden_size=[400, 300],
+ phi=0.05,
+ ),
+ learn=dict(
+ data_path=None,
+ train_epoch=30000,
+ batch_size=100,
+ learning_rate_q=3e-3,
+ learning_rate_policy=3e-3,
+ learning_rate_alpha=3e-3,
+ lmbda=0.75,
+ learner=dict(hook=dict(save_ckpt_after_iter=1000000000, )),
+ ),
+ collect=dict(data_type='d4rl', ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+ seed=123,
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='bcq',
+ import_names=['ding.policy.bcq'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/halfcheetah_medium_cql_config.py b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_cql_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..84a504cadb2d67cc7fd3d6ee7009fa2339f11107
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_cql_config.py
@@ -0,0 +1,55 @@
+# You can conduct Experiments on D4RL with this config file through the following command:
+# cd ../entry && python d4rl_cql_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="halfcheetah_medium_cql_seed0",
+ env=dict(
+ env_id='halfcheetah-medium-v2',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ ),
+ learn=dict(
+ data_path=None,
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_q=3e-4,
+ learning_rate_policy=1e-4,
+ learning_rate_alpha=1e-4,
+ alpha=0.2,
+ auto_alpha=False,
+ lagrange_thresh=-1.0,
+ min_q_weight=5.0,
+ ),
+ collect=dict(data_type='d4rl', ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='cql',
+ import_names=['ding.policy.cql'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/halfcheetah_medium_dt_config.py b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_dt_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..7521af6dd5e80ab367fa4bf607554b79a5c82c7c
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_dt_config.py
@@ -0,0 +1,70 @@
+from easydict import EasyDict
+from copy import deepcopy
+
+halfcheetah_dt_config = dict(
+ exp_name='dt_log/d4rl/halfcheetah/halfcheetah_medium_dt_seed0',
+ env=dict(
+ env_id='HalfCheetah-v3',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ dataset=dict(
+ env_type='mujoco',
+ rtg_scale=1000,
+ context_len=30,
+ data_dir_prefix='d4rl/halfcheetah_medium-v2.pkl',
+ ),
+ policy=dict(
+ cuda=True,
+ stop_value=6000,
+ state_mean=None,
+ state_std=None,
+ evaluator_env_num=8,
+ env_name='HalfCheetah-v3',
+ rtg_target=6000, # max target return to go
+        max_eval_ep_len=1000,  # max length of one episode
+ wt_decay=1e-4,
+ warmup_steps=10000,
+ context_len=20,
+ weight_decay=0.1,
+ clip_grad_norm_p=0.25,
+ model=dict(
+            state_dim=17,
+            act_dim=6,
+ n_blocks=3,
+ h_dim=128,
+ context_len=20,
+ n_heads=1,
+ drop_p=0.1,
+ continuous=True,
+ ),
+ batch_size=64,
+ learning_rate=1e-4,
+ collect=dict(
+ data_type='d4rl_trajectory',
+ unroll_len=1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ ),
+)
+
+halfcheetah_dt_config = EasyDict(halfcheetah_dt_config)
+main_config = halfcheetah_dt_config
+halfcheetah_dt_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dt'),
+)
+halfcheetah_dt_create_config = EasyDict(halfcheetah_dt_create_config)
+create_config = halfcheetah_dt_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_dt
+ config = deepcopy([main_config, create_config])
+ serial_pipeline_dt(config, seed=0, max_train_iter=1000)
diff --git a/DI-engine/dizoo/d4rl/config/halfcheetah_medium_edac_config.py b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_edac_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..66ea8039dc6f648c2482f50ed65a133a697ef7cb
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_edac_config.py
@@ -0,0 +1,58 @@
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="halfcheetah_medium_edac_seed0",
+ env=dict(
+ env_id='halfcheetah-medium-v2',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=7600,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ ensemble_num=10,
+ actor_head_hidden_size=256,
+ actor_head_layer_num=3,
+ critic_head_hidden_size=256,
+ critic_head_layer_num=3,
+ ),
+ learn=dict(
+ data_path=None,
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_q=3e-4,
+ learning_rate_policy=3e-4,
+ learning_rate_alpha=3e-4,
+ alpha=1,
+ auto_alpha=True,
+ eta=1.0,
+ with_q_entropy=False,
+ learner=dict(hook=dict(save_ckpt_after_iter=100000, )),
+ ),
+ collect=dict(data_type='d4rl', ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ ),
+ seed=0,
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='edac',
+ import_names=['ding.policy.edac'],
+ ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/halfcheetah_medium_expert_bcq_config.py b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_expert_bcq_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..6c3ac39c18198b1fbe6259d6798b82ee07ad33c4
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_expert_bcq_config.py
@@ -0,0 +1,55 @@
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="halfcheetah_medium_expert_bcq_seed0",
+ env=dict(
+ env_id='halfcheetah-medium-expert-v2',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=12000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ actor_head_hidden_size=[400, 300],
+ critic_head_hidden_size=[400, 300],
+ phi=0.05,
+ ),
+ learn=dict(
+ data_path=None,
+ train_epoch=30000,
+ batch_size=100,
+ learning_rate_q=3e-3,
+ learning_rate_policy=3e-3,
+ learning_rate_alpha=3e-3,
+ lmbda=0.75,
+ learner=dict(hook=dict(save_ckpt_after_iter=1000000000, )),
+ ),
+ collect=dict(data_type='d4rl', ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+ seed=123,
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='bcq',
+ import_names=['ding.policy.bcq'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/halfcheetah_medium_expert_cql_config.py b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_expert_cql_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..05aa2d175214c18d5fc7adb5cc6931c2d34df775
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_expert_cql_config.py
@@ -0,0 +1,55 @@
+# You can conduct Experiments on D4RL with this config file through the following command:
+# cd ../entry && python d4rl_cql_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="halfcheetah_medium_expert_cql_seed0",
+ env=dict(
+ env_id='halfcheetah-medium-expert-v2',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ ),
+ learn=dict(
+ data_path=None,
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_q=3e-4,
+ learning_rate_policy=1e-4,
+ learning_rate_alpha=1e-4,
+ alpha=0.2,
+ auto_alpha=False,
+ lagrange_thresh=-1.0,
+ min_q_weight=5.0,
+ ),
+ collect=dict(data_type='d4rl', ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='cql',
+ import_names=['ding.policy.cql'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/halfcheetah_medium_expert_dt_config.py b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_expert_dt_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..1f9c636d20ec8b99e42bfefd8c3ab404c60f575a
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_expert_dt_config.py
@@ -0,0 +1,70 @@
+from easydict import EasyDict
+from copy import deepcopy
+
+halfcheetah_dt_config = dict(
+ exp_name='dt_log/d4rl/halfcheetah/halfcheetah_medium_expert_dt_seed0',
+ env=dict(
+ env_id='HalfCheetah-v3',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ dataset=dict(
+ env_type='mujoco',
+ rtg_scale=1000,
+ context_len=30,
+ data_dir_prefix='d4rl/halfcheetah_medium_expert-v2.pkl',
+ ),
+ policy=dict(
+ cuda=True,
+ stop_value=6000,
+ state_mean=None,
+ state_std=None,
+ evaluator_env_num=8,
+ env_name='HalfCheetah-v3',
+ rtg_target=6000, # max target return to go
+        max_eval_ep_len=1000,  # max length of one episode
+ wt_decay=1e-4,
+ warmup_steps=10000,
+ context_len=20,
+ weight_decay=0.1,
+ clip_grad_norm_p=0.25,
+ model=dict(
+            state_dim=17,
+            act_dim=6,
+ n_blocks=3,
+ h_dim=128,
+ context_len=20,
+ n_heads=1,
+ drop_p=0.1,
+ continuous=True,
+ ),
+ batch_size=64,
+ learning_rate=1e-4,
+ collect=dict(
+ data_type='d4rl_trajectory',
+ unroll_len=1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ ),
+)
+
+halfcheetah_dt_config = EasyDict(halfcheetah_dt_config)
+main_config = halfcheetah_dt_config
+halfcheetah_dt_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dt'),
+)
+halfcheetah_dt_create_config = EasyDict(halfcheetah_dt_create_config)
+create_config = halfcheetah_dt_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_dt
+ config = deepcopy([main_config, create_config])
+ serial_pipeline_dt(config, seed=0, max_train_iter=1000)
diff --git a/DI-engine/dizoo/d4rl/config/halfcheetah_medium_expert_edac_config.py b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_expert_edac_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..17e897f048c8ba690a8ef4bbdc42f57e8096c0e3
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_expert_edac_config.py
@@ -0,0 +1,58 @@
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="halfcheetah_medium_expert_edac_seed123",
+ env=dict(
+ env_id='halfcheetah-medium-expert-v2',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=13000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ ensemble_num=10,
+ actor_head_hidden_size=256,
+ actor_head_layer_num=3,
+ critic_head_hidden_size=256,
+ critic_head_layer_num=3,
+ ),
+ learn=dict(
+ data_path=None,
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_q=3e-4,
+ learning_rate_policy=3e-4,
+ learning_rate_alpha=3e-4,
+ alpha=1,
+ auto_alpha=True,
+ eta=5.0,
+ with_q_entropy=False,
+ learner=dict(hook=dict(save_ckpt_after_iter=100000, )),
+ ),
+ collect=dict(data_type='d4rl', ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ ),
+ seed=123,
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='edac',
+ import_names=['ding.policy.edac'],
+ ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/halfcheetah_medium_expert_pd_config.py b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_expert_pd_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..66c8ba8d91d236f6342049d77e8eb1f8572d2766
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_expert_pd_config.py
@@ -0,0 +1,101 @@
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="halfcheetah_medium_expert_pd_seed0",
+ env=dict(
+ env_id='halfcheetah-medium-expert-v2',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ returns_scale=1.0,
+ termination_penalty=-100,
+ max_path_length=1000,
+ use_padding=True,
+ include_returns=True,
+ normed=False,
+ stop_value=12000,
+ horizon=4,
+ obs_dim=17,
+ action_dim=6,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ diffuser_model='GaussianDiffusion',
+ diffuser_model_cfg=dict(
+ model='DiffusionUNet1d',
+ model_cfg=dict(
+ transition_dim=23,
+ dim=32,
+ dim_mults=[1, 4, 8],
+ returns_condition=False,
+ kernel_size=5,
+ attention=True,
+ ),
+ horizon=4,
+ obs_dim=17,
+ action_dim=6,
+ n_timesteps=20,
+ predict_epsilon=False,
+ loss_discount=1,
+ action_weight=10,
+ ),
+ value_model='ValueDiffusion',
+ value_model_cfg=dict(
+ model='TemporalValue',
+ model_cfg=dict(
+                    horizon=4,
+ transition_dim=23,
+ dim=32,
+ dim_mults=[1, 4, 8],
+ kernel_size=5,
+ ),
+ horizon=4,
+ obs_dim=17,
+ action_dim=6,
+ n_timesteps=20,
+ predict_epsilon=True,
+ loss_discount=1,
+ ),
+ n_guide_steps=2,
+ scale=0.001,
+ t_stopgrad=4,
+ scale_grad_by_std=True,
+ ),
+ normalizer='GaussianNormalizer',
+ learn=dict(
+ data_path=None,
+ train_epoch=60000,
+ gradient_accumulate_every=2,
+ batch_size=32,
+ learning_rate=2e-4,
+ discount_factor=0.99,
+ plan_batch_size=64,
+ learner=dict(hook=dict(save_ckpt_after_iter=1000000000, )),
+ ),
+ collect=dict(data_type='diffuser_traj', ),
+ eval=dict(
+ evaluator=dict(eval_freq=500, ),
+ test_ret=0.9,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='pd',
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/halfcheetah_medium_expert_td3bc_config.py b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_expert_td3bc_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..ed99a2d3f08c14c9407f9ff6aed0c37c2dff4b81
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_expert_td3bc_config.py
@@ -0,0 +1,65 @@
+# You can conduct Experiments on D4RL with this config file through the following command:
+# cd ../entry && python d4rl_td3_bc_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name='halfcheetah_medium_expert_td3-bc_seed0',
+ env=dict(
+ env_id='halfcheetah-medium-expert-v2',
+ norm_obs=dict(
+ use_norm=True,
+ offline_stats=dict(use_offline_stats=True, ),
+ ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=13000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ ),
+ learn=dict(
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_actor=0.0003,
+ learning_rate_critic=0.0003,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.2,
+ noise_range={
+ 'min': -0.5,
+ 'max': 0.5
+ },
+ alpha=2.5,
+ ),
+ collect=dict(
+ data_type='d4rl',
+ data_path=None,
+ ),
+ eval=dict(evaluator=dict(eval_freq=10000, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+main_config = EasyDict(main_config)
+main_config = main_config
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(
+ cfg_type='BaseEnvManagerDict',
+ type='base',
+ ),
+ policy=dict(
+ type='td3_bc',
+ import_names=['ding.policy.td3_bc'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/halfcheetah_medium_pd_config.py b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_pd_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..674395a4e16de200c70667539afbf1b85941d4b7
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_pd_config.py
@@ -0,0 +1,101 @@
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="halfcheetah_medium_pd_seed0",
+ env=dict(
+ env_id='halfcheetah-medium-v2',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ returns_scale=1.0,
+ termination_penalty=-100,
+ max_path_length=1000,
+ use_padding=True,
+ include_returns=True,
+ normed=False,
+ stop_value=8000,
+ horizon=4,
+ obs_dim=17,
+ action_dim=6,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ diffuser_model='GaussianDiffusion',
+ diffuser_model_cfg=dict(
+ model='DiffusionUNet1d',
+ model_cfg=dict(
+ transition_dim=23,
+ dim=32,
+ dim_mults=[1, 4, 8],
+ returns_condition=False,
+ kernel_size=5,
+ attention=True,
+ ),
+ horizon=4,
+ obs_dim=17,
+ action_dim=6,
+ n_timesteps=20,
+ predict_epsilon=False,
+ loss_discount=1,
+ action_weight=10,
+ ),
+ value_model='ValueDiffusion',
+ value_model_cfg=dict(
+ model='TemporalValue',
+ model_cfg=dict(
+                    horizon=4,
+ transition_dim=23,
+ dim=32,
+ dim_mults=[1, 4, 8],
+ kernel_size=5,
+ ),
+ horizon=4,
+ obs_dim=17,
+ action_dim=6,
+ n_timesteps=20,
+ predict_epsilon=True,
+ loss_discount=1,
+ ),
+ n_guide_steps=2,
+ scale=0.001,
+ t_stopgrad=4,
+ scale_grad_by_std=True,
+ ),
+ normalizer='GaussianNormalizer',
+ learn=dict(
+ data_path=None,
+ train_epoch=60000,
+ gradient_accumulate_every=2,
+ batch_size=32,
+ learning_rate=2e-4,
+ discount_factor=0.99,
+ plan_batch_size=64,
+ learner=dict(hook=dict(save_ckpt_after_iter=1000000000, )),
+ ),
+ collect=dict(data_type='diffuser_traj', ),
+ eval=dict(
+ evaluator=dict(eval_freq=500, ),
+ test_ret=0.9,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='pd',
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
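+
+# Sanity-check sketch (illustrative, not part of the original file): the plan
+# diffuser operates on concatenated (observation, action) transitions, so
+# transition_dim is expected to equal obs_dim + action_dim (17 + 6 = 23 here).
+if __name__ == "__main__":
+    diffuser_cfg = main_config.policy.model.diffuser_model_cfg
+    assert diffuser_cfg.model_cfg.transition_dim == \
+        main_config.env.obs_dim + main_config.env.action_dim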
diff --git a/DI-engine/dizoo/d4rl/config/halfcheetah_medium_replay_cql_config.py b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_replay_cql_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..823e08d37073312691bc54181c00db997cb314e7
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_replay_cql_config.py
@@ -0,0 +1,55 @@
+# You can conduct experiments on D4RL with this config file via the following command:
+# cd ../entry && python d4rl_cql_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="halfcheetah_medium_replay_cql_seed0",
+ env=dict(
+ env_id='halfcheetah-medium-replay-v2',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ ),
+ learn=dict(
+ data_path=None,
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_q=3e-4,
+ learning_rate_policy=1e-4,
+ learning_rate_alpha=1e-4,
+ alpha=0.2,
+ auto_alpha=False,
+ lagrange_thresh=-1.0,
+ min_q_weight=5.0,
+ ),
+ collect=dict(data_type='d4rl', ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='cql',
+ import_names=['ding.policy.cql'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
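+
+# Illustrative launch sketch (assumption, not part of the original file): as an
+# alternative to the d4rl_cql_main.py entry mentioned above, this config pair
+# can likely be run through the generic offline pipeline as well.
+if __name__ == "__main__":
+    from ding.entry import serial_pipeline_offline
+    serial_pipeline_offline([main_config, create_config], seed=0)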
diff --git a/DI-engine/dizoo/d4rl/config/halfcheetah_medium_replay_dt_config.py b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_replay_dt_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa07e22280e4ac40a0dc46a56dd037face2a3b8d
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_replay_dt_config.py
@@ -0,0 +1,70 @@
+from easydict import EasyDict
+from copy import deepcopy
+
+halfcheetah_dt_config = dict(
+ exp_name='dt_log/d4rl/halfcheetah/halfcheetah_medium_replay_dt_seed0',
+ env=dict(
+ env_id='HalfCheetah-v3',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ dataset=dict(
+ env_type='mujoco',
+ rtg_scale=1000,
+ context_len=30,
+ data_dir_prefix='d4rl/halfcheetah_medium_replay-v2.pkl',
+ ),
+ policy=dict(
+ cuda=True,
+ stop_value=6000,
+ state_mean=None,
+ state_std=None,
+ evaluator_env_num=8,
+ env_name='HalfCheetah-v3',
+ rtg_target=6000, # max target return to go
+        max_eval_ep_len=1000,  # max length of one episode
+ wt_decay=1e-4,
+ warmup_steps=10000,
+ context_len=20,
+ weight_decay=0.1,
+ clip_grad_norm_p=0.25,
+ model=dict(
+ state_dim=11,
+ act_dim=3,
+ n_blocks=3,
+ h_dim=128,
+ context_len=20,
+ n_heads=1,
+ drop_p=0.1,
+ continuous=True,
+ ),
+ batch_size=64,
+ learning_rate=1e-4,
+ collect=dict(
+ data_type='d4rl_trajectory',
+ unroll_len=1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ ),
+)
+
+halfcheetah_dt_config = EasyDict(halfcheetah_dt_config)
+main_config = halfcheetah_dt_config
+halfcheetah_dt_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dt'),
+)
+halfcheetah_dt_create_config = EasyDict(halfcheetah_dt_create_config)
+create_config = halfcheetah_dt_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_dt
+ config = deepcopy([main_config, create_config])
+ serial_pipeline_dt(config, seed=0, max_train_iter=1000)
diff --git a/DI-engine/dizoo/d4rl/config/halfcheetah_medium_replay_td3bc_config.py b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_replay_td3bc_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..22cf7ff544bccd99388ebc11776062cac5df520d
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_replay_td3bc_config.py
@@ -0,0 +1,65 @@
+# You can conduct experiments on D4RL with this config file via the following command:
+# cd ../entry && python d4rl_td3_bc_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name='halfcheetah_medium_replay_td3-bc_seed0',
+ env=dict(
+ env_id='halfcheetah-medium-replay-v2',
+ norm_obs=dict(
+ use_norm=True,
+ offline_stats=dict(use_offline_stats=True, ),
+ ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ ),
+ learn=dict(
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_actor=0.0003,
+ learning_rate_critic=0.0003,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.2,
+ noise_range={
+ 'min': -0.5,
+ 'max': 0.5
+ },
+ alpha=2.5,
+ ),
+ collect=dict(
+ data_type='d4rl',
+ data_path=None,
+ ),
+ eval=dict(evaluator=dict(eval_freq=10000, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+main_config = EasyDict(main_config)
+main_config = main_config
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(
+ cfg_type='BaseEnvManagerDict',
+ type='base',
+ ),
+ policy=dict(
+ type='td3_bc',
+ import_names=['ding.policy.td3_bc'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/halfcheetah_medium_td3bc_config.py b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_td3bc_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..541588e2619ae7e9d34b80933d2bffdba5728b64
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/halfcheetah_medium_td3bc_config.py
@@ -0,0 +1,65 @@
+# You can conduct experiments on D4RL with this config file via the following command:
+# cd ../entry && python d4rl_td3_bc_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name='halfcheetah_medium_td3-bc_seed0',
+ env=dict(
+ env_id='halfcheetah-medium-v2',
+ norm_obs=dict(
+ use_norm=True,
+ offline_stats=dict(use_offline_stats=True, ),
+ ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=7600,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ ),
+ learn=dict(
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_actor=0.0003,
+ learning_rate_critic=0.0003,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.2,
+ noise_range={
+ 'min': -0.5,
+ 'max': 0.5
+ },
+ alpha=2.5,
+ ),
+ collect=dict(
+ data_type='d4rl',
+ data_path=None,
+ ),
+ eval=dict(evaluator=dict(eval_freq=10000, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+main_config = EasyDict(main_config)
+main_config = main_config
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(
+ cfg_type='BaseEnvManagerDict',
+ type='base',
+ ),
+ policy=dict(
+ type='td3_bc',
+ import_names=['ding.policy.td3_bc'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/halfcheetah_random_cql_config.py b/DI-engine/dizoo/d4rl/config/halfcheetah_random_cql_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..58ead98dc57868cd0e7f4be78128f49b1a4805ba
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/halfcheetah_random_cql_config.py
@@ -0,0 +1,55 @@
+# You can conduct experiments on D4RL with this config file via the following command:
+# cd ../entry && python d4rl_cql_main.py
+from easydict import EasyDict
+
+main_config = dict(
+    exp_name="halfcheetah_random_cql_seed0",
+    env=dict(
+        env_id='halfcheetah-random-v2',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ ),
+ learn=dict(
+ data_path=None,
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_q=3e-4,
+ learning_rate_policy=1e-4,
+ learning_rate_alpha=1e-4,
+ alpha=0.2,
+ auto_alpha=False,
+ lagrange_thresh=-1.0,
+ min_q_weight=5.0,
+ ),
+ collect=dict(data_type='d4rl', ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='cql',
+ import_names=['ding.policy.cql'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/halfcheetah_random_dt_config.py b/DI-engine/dizoo/d4rl/config/halfcheetah_random_dt_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc912c5e5528d074d27892ad2c110001889896b8
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/halfcheetah_random_dt_config.py
@@ -0,0 +1,82 @@
+from easydict import EasyDict
+from copy import deepcopy
+
+halfcheetah_dt_config = dict(
+ exp_name='halfcheetah_random_dt_seed0',
+ env=dict(
+ env_id='HalfCheetah-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ stop_value=6000,
+ cuda=True,
+ env_name='HalfCheetah-v3',
+ rtg_target=6000, # max target return to go
+        max_eval_ep_len=1000,  # max length of one episode
+ num_eval_ep=10, # num of evaluation episode
+ batch_size=64,
+ wt_decay=1e-4,
+ warmup_steps=10000,
+ num_updates_per_iter=100,
+ context_len=20,
+ n_blocks=3,
+ embed_dim=128,
+ n_heads=1,
+ dropout_p=0.1,
+ log_dir='/home/wangzilin/research/dt/DI-engine/dizoo/d4rl/dt_data/halfcheetah_random_dt_log',
+ model=dict(
+ state_dim=17,
+ act_dim=6,
+ n_blocks=3,
+ h_dim=128,
+ context_len=20,
+ n_heads=1,
+ drop_p=0.1,
+ continuous=True,
+ ),
+ discount_factor=0.999,
+ nstep=3,
+ learn=dict(
+ dataset_path='/mnt/lustre/wangzilin/d4rl_data/halfcheetah-random-v2.pkl',
+ learning_rate=0.0001,
+ target_update_freq=100,
+ kappa=1.0,
+ min_q_weight=4.0
+ ),
+ collect=dict(unroll_len=1, ),
+        eval=dict(evaluator=dict(eval_freq=100, ), ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=1000, ),
+ ),
+ ),
+)
+
+halfcheetah_dt_config = EasyDict(halfcheetah_dt_config)
+main_config = halfcheetah_dt_config
+halfcheetah_dt_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dt'),
+)
+halfcheetah_dt_create_config = EasyDict(halfcheetah_dt_create_config)
+create_config = halfcheetah_dt_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_dt
+ config = deepcopy([main_config, create_config])
+ serial_pipeline_dt(config, seed=0, max_train_iter=1000)
diff --git a/DI-engine/dizoo/d4rl/config/halfcheetah_random_td3bc_config.py b/DI-engine/dizoo/d4rl/config/halfcheetah_random_td3bc_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..85c03478bc2c6d971ae77b2f5a7f06a661a80016
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/halfcheetah_random_td3bc_config.py
@@ -0,0 +1,65 @@
+# You can conduct experiments on D4RL with this config file via the following command:
+# cd ../entry && python d4rl_td3_bc_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name='halfcheetah_random_td3-bc_seed0',
+ env=dict(
+ env_id='halfcheetah-random-v2',
+ norm_obs=dict(
+ use_norm=True,
+ offline_stats=dict(use_offline_stats=True, ),
+ ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ ),
+ learn=dict(
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_actor=0.0003,
+ learning_rate_critic=0.0003,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.2,
+ noise_range={
+ 'min': -0.5,
+ 'max': 0.5
+ },
+ alpha=2.5,
+ ),
+ collect=dict(
+ data_type='d4rl',
+ data_path=None,
+ ),
+ eval=dict(evaluator=dict(eval_freq=10000, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+main_config = EasyDict(main_config)
+main_config = main_config
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(
+ cfg_type='BaseEnvManagerDict',
+ type='base',
+ ),
+ policy=dict(
+ type='td3_bc',
+ import_names=['ding.policy.td3_bc'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/hopper_expert_cql_config.py b/DI-engine/dizoo/d4rl/config/hopper_expert_cql_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b0d175c73712647f7c12fbfb97f0728e0b7b6a6
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/hopper_expert_cql_config.py
@@ -0,0 +1,55 @@
+# You can conduct experiments on D4RL with this config file via the following command:
+# cd ../entry && python d4rl_cql_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="hopper_expert_cql_seed0",
+ env=dict(
+ env_id='hopper-expert-v0',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ ),
+ learn=dict(
+ data_path=None,
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_q=3e-4,
+ learning_rate_policy=1e-4,
+ learning_rate_alpha=1e-4,
+ alpha=0.2,
+ auto_alpha=False,
+ lagrange_thresh=-1.0,
+ min_q_weight=5.0,
+ ),
+ collect=dict(data_type='d4rl', ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='cql',
+ import_names=['ding.policy.cql'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/hopper_expert_dt_config.py b/DI-engine/dizoo/d4rl/config/hopper_expert_dt_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..26387afde53e83885801e1f4a5f52524c0eaeca0
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/hopper_expert_dt_config.py
@@ -0,0 +1,70 @@
+from easydict import EasyDict
+from copy import deepcopy
+
+hopper_dt_config = dict(
+ exp_name='dt_log/d4rl/hopper/hopper_expert_dt_seed0',
+ env=dict(
+ env_id='Hopper-v3',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=3600,
+ ),
+ dataset=dict(
+ env_type='mujoco',
+ rtg_scale=1000,
+ context_len=20,
+ data_dir_prefix='d4rl/hopper_expert-v2.pkl',
+ ),
+ policy=dict(
+ cuda=True,
+ stop_value=3600,
+ state_mean=None,
+ state_std=None,
+ evaluator_env_num=8,
+ env_name='Hopper-v3',
+ rtg_target=3600, # max target return to go
+        max_eval_ep_len=1000,  # max length of one episode
+ wt_decay=1e-4,
+ warmup_steps=10000,
+ context_len=20,
+ weight_decay=0.1,
+ clip_grad_norm_p=0.25,
+ model=dict(
+ state_dim=11,
+ act_dim=3,
+ n_blocks=3,
+ h_dim=128,
+ context_len=20,
+ n_heads=1,
+ drop_p=0.1,
+ continuous=True,
+ ),
+ batch_size=64,
+ learning_rate=1e-4,
+ collect=dict(
+ data_type='d4rl_trajectory',
+ unroll_len=1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ ),
+)
+
+hopper_dt_config = EasyDict(hopper_dt_config)
+main_config = hopper_dt_config
+hopper_dt_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dt'),
+)
+hopper_dt_create_config = EasyDict(hopper_dt_create_config)
+create_config = hopper_dt_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_dt
+ config = deepcopy([main_config, create_config])
+ serial_pipeline_dt(config, seed=0, max_train_iter=1000)
diff --git a/DI-engine/dizoo/d4rl/config/hopper_expert_td3bc_config.py b/DI-engine/dizoo/d4rl/config/hopper_expert_td3bc_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..0e35474c48d332f2a7fff9e0095106cb25fd2e12
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/hopper_expert_td3bc_config.py
@@ -0,0 +1,61 @@
+# You can conduct experiments on D4RL with this config file via the following command:
+# cd ../entry && python d4rl_td3_bc_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name='hopper_expert_td3-bc_seed0',
+ env=dict(
+ env_id='hopper-expert-v2',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ ),
+ learn=dict(
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_actor=0.0003,
+ learning_rate_critic=0.0003,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.2,
+ noise_range={
+ 'min': -0.5,
+ 'max': 0.5
+ },
+ alpha=2.5,
+ ),
+ collect=dict(
+ data_type='d4rl',
+ data_path=None,
+ ),
+ eval=dict(evaluator=dict(eval_freq=10000, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+main_config = EasyDict(main_config)
+main_config = main_config
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(
+ cfg_type='BaseEnvManagerDict',
+ type='base',
+ ),
+ policy=dict(
+ type='td3_bc',
+ import_names=['ding.policy.td3_bc'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/hopper_medium_bcq_config.py b/DI-engine/dizoo/d4rl/config/hopper_medium_bcq_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..06282d16808338b26757e964d4dd229b554a31b1
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/hopper_medium_bcq_config.py
@@ -0,0 +1,55 @@
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="hopper_medium_bcq_seed0_43_v0",
+ env=dict(
+ env_id='hopper-medium-v0',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=3500,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ actor_head_hidden_size=[400, 300],
+ critic_head_hidden_size=[400, 300],
+ phi=0.05,
+ ),
+ learn=dict(
+ data_path=None,
+ train_epoch=30000,
+ batch_size=100,
+ learning_rate_q=3e-3,
+ learning_rate_policy=3e-3,
+ learning_rate_alpha=3e-3,
+ lmbda=0.75,
+ learner=dict(hook=dict(save_ckpt_after_iter=1000000000, )),
+ ),
+ collect=dict(data_type='d4rl', ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+ seed=123,
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='bcq',
+ import_names=['ding.policy.bcq'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/hopper_medium_cql_config.py b/DI-engine/dizoo/d4rl/config/hopper_medium_cql_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..09db9ea287c2366e6a78b0997a1fd35330abcd6e
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/hopper_medium_cql_config.py
@@ -0,0 +1,55 @@
+# You can conduct experiments on D4RL with this config file via the following command:
+# cd ../entry && python d4rl_cql_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="hopper_medium_cql_seed0",
+ env=dict(
+ env_id='hopper-medium-v0',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ ),
+ learn=dict(
+ data_path=None,
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_q=3e-4,
+ learning_rate_policy=1e-4,
+ learning_rate_alpha=1e-4,
+ alpha=0.2,
+ auto_alpha=False,
+ lagrange_thresh=-1.0,
+ min_q_weight=5.0,
+ ),
+ collect=dict(data_type='d4rl', ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='cql',
+ import_names=['ding.policy.cql'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/hopper_medium_dt_config.py b/DI-engine/dizoo/d4rl/config/hopper_medium_dt_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..a5c389e67eb486faac8beb40363d96cdf43fa3ab
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/hopper_medium_dt_config.py
@@ -0,0 +1,70 @@
+from easydict import EasyDict
+from copy import deepcopy
+
+hopper_dt_config = dict(
+ exp_name='dt_log/d4rl/hopper/hopper_medium_dt_seed0',
+ env=dict(
+ env_id='Hopper-v3',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=3600,
+ ),
+ dataset=dict(
+ env_type='mujoco',
+ rtg_scale=1000,
+ context_len=20,
+        data_dir_prefix='d4rl/hopper_medium-v2.pkl',
+ ),
+ policy=dict(
+ cuda=True,
+ stop_value=3600,
+ state_mean=None,
+ state_std=None,
+ evaluator_env_num=8,
+ env_name='Hopper-v3',
+ rtg_target=3600, # max target return to go
+        max_eval_ep_len=1000,  # max length of one episode
+ wt_decay=1e-4,
+ warmup_steps=10000,
+ context_len=20,
+ weight_decay=0.1,
+ clip_grad_norm_p=0.25,
+ model=dict(
+ state_dim=11,
+ act_dim=3,
+ n_blocks=3,
+ h_dim=128,
+ context_len=20,
+ n_heads=1,
+ drop_p=0.1,
+ continuous=True,
+ ),
+ batch_size=64,
+ learning_rate=1e-4,
+ collect=dict(
+ data_type='d4rl_trajectory',
+ unroll_len=1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, ), ),
+ ),
+)
+
+hopper_dt_config = EasyDict(hopper_dt_config)
+main_config = hopper_dt_config
+hopper_dt_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dt'),
+)
+hopper_dt_create_config = EasyDict(hopper_dt_create_config)
+create_config = hopper_dt_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_dt
+ config = deepcopy([main_config, create_config])
+ serial_pipeline_dt(config, seed=0, max_train_iter=1000)
diff --git a/DI-engine/dizoo/d4rl/config/hopper_medium_edac_config.py b/DI-engine/dizoo/d4rl/config/hopper_medium_edac_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..f14fad350f428f65537fbbc1d1c14b75f290d39e
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/hopper_medium_edac_config.py
@@ -0,0 +1,58 @@
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="hopper_medium_edac_seed0",
+ env=dict(
+ env_id='hopper-medium-v2',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=3700,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ ensemble_num=50,
+ actor_head_hidden_size=256,
+ actor_head_layer_num=3,
+ critic_head_hidden_size=256,
+ critic_head_layer_num=3,
+ ),
+ learn=dict(
+ data_path=None,
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_q=3e-4,
+ learning_rate_policy=3e-4,
+ learning_rate_alpha=3e-4,
+ alpha=1,
+ auto_alpha=True,
+ eta=1.0,
+ with_q_entropy=False,
+ learner=dict(hook=dict(save_ckpt_after_iter=100000, )),
+ ),
+ collect=dict(data_type='d4rl', ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ ),
+ seed=0,
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='edac',
+ import_names=['ding.policy.edac'],
+ ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
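+
+# Illustrative launch sketch (assumption, not part of the original file): this
+# config also records a top-level `seed` field, which a launcher could reuse so
+# the run matches the seed encoded in exp_name.
+if __name__ == "__main__":
+    from ding.entry import serial_pipeline_offline
+    serial_pipeline_offline([main_config, create_config], seed=main_config.seed)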
diff --git a/DI-engine/dizoo/d4rl/config/hopper_medium_expert_bc_config.py b/DI-engine/dizoo/d4rl/config/hopper_medium_expert_bc_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..348361dd2d51b4f81b7b4954d82846d7e0f69bc3
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/hopper_medium_expert_bc_config.py
@@ -0,0 +1,58 @@
+from easydict import EasyDict
+
+cuda = True
+multi_gpu = False
+
+main_config = dict(
+ exp_name='hopper_medium_expert_bc_seed0',
+ env=dict(
+ env_id='hopper-medium-expert-v0',
+ norm_obs=dict(
+ use_norm=True,
+ offline_stats=dict(use_offline_stats=True, ),
+ ),
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ use_act_scale=True,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=cuda,
+ continuous=True,
+ loss_type='mse_loss',
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ action_space='regression',
+ actor_head_hidden_size=512,
+ actor_head_layer_num=4,
+ ),
+ learn=dict(
+ multi_gpu=multi_gpu,
+ train_epoch=15,
+ batch_size=256,
+ learning_rate=1e-5,
+ learner=dict(hook=dict(log_show_after_iter=1000)),
+ ),
+ collect=dict(
+ data_type='d4rl',
+ data_path=None,
+ ),
+ eval=dict(evaluator=dict(eval_freq=-1, )),
+ ),
+)
+main_config = EasyDict(main_config)
+main_config = main_config
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base', ),
+ policy=dict(
+ type='bc',
+ import_names=['ding.policy.bc'],
+ ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/hopper_medium_expert_bcq_config.py b/DI-engine/dizoo/d4rl/config/hopper_medium_expert_bcq_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..ac48ee4847b3c84dffe60c357deb2fb51f850288
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/hopper_medium_expert_bcq_config.py
@@ -0,0 +1,55 @@
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="hopper_medium_expert_bcq_seed0",
+ env=dict(
+ env_id='hopper-medium-expert-v0',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=3800,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ actor_head_hidden_size=[400, 300],
+ critic_head_hidden_size=[400, 300],
+ phi=0.05,
+ ),
+ learn=dict(
+ data_path=None,
+ train_epoch=30000,
+ batch_size=100,
+ learning_rate_q=3e-3,
+ learning_rate_policy=3e-3,
+ learning_rate_alpha=3e-3,
+ lmbda=0.75,
+ learner=dict(hook=dict(save_ckpt_after_iter=1000000000, )),
+ ),
+ collect=dict(data_type='d4rl', ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+ seed=123,
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='bcq',
+ import_names=['ding.policy.bcq'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/hopper_medium_expert_cql_config.py b/DI-engine/dizoo/d4rl/config/hopper_medium_expert_cql_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c089d4bd33426a252fec9a5a2ecb7ffe9d0ae2c
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/hopper_medium_expert_cql_config.py
@@ -0,0 +1,55 @@
+# You can conduct experiments on D4RL with this config file via the following command:
+# cd ../entry && python d4rl_cql_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="hopper_medium_expert_cql_seed0",
+ env=dict(
+ env_id='hopper-medium-expert-v0',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ ),
+ learn=dict(
+ data_path=None,
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_q=3e-4,
+ learning_rate_policy=1e-4,
+ learning_rate_alpha=1e-4,
+ alpha=0.2,
+ auto_alpha=False,
+ lagrange_thresh=-1.0,
+ min_q_weight=5.0,
+ ),
+ collect=dict(data_type='d4rl', ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='cql',
+ import_names=['ding.policy.cql'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/hopper_medium_expert_dt_config.py b/DI-engine/dizoo/d4rl/config/hopper_medium_expert_dt_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5934590bf1c1fa94ac793c94e07cfa634262a077
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/hopper_medium_expert_dt_config.py
@@ -0,0 +1,70 @@
+from easydict import EasyDict
+from copy import deepcopy
+
+hopper_dt_config = dict(
+ exp_name='dt_log/d4rl/hopper/hopper_medium_expert_dt',
+ env=dict(
+ env_id='Hopper-v3',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=3600,
+ ),
+ dataset=dict(
+ env_type='mujoco',
+ rtg_scale=1000,
+ context_len=20,
+ data_dir_prefix='d4rl/hopper_medium_expert.pkl',
+ ),
+ policy=dict(
+ cuda=True,
+ stop_value=3600,
+ state_mean=None,
+ state_std=None,
+ evaluator_env_num=8,
+ env_name='Hopper-v3',
+ rtg_target=3600, # max target return to go
+        max_eval_ep_len=1000,  # max length of one episode
+ wt_decay=1e-4,
+ warmup_steps=10000,
+ context_len=20,
+ weight_decay=0.1,
+ clip_grad_norm_p=0.25,
+ model=dict(
+ state_dim=11,
+ act_dim=3,
+ n_blocks=3,
+ h_dim=128,
+ context_len=20,
+ n_heads=1,
+ drop_p=0.1,
+ continuous=True,
+ ),
+ batch_size=64,
+ learning_rate=1e-4,
+ collect=dict(
+ data_type='d4rl_trajectory',
+ unroll_len=1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ ),
+)
+
+hopper_dt_config = EasyDict(hopper_dt_config)
+main_config = hopper_dt_config
+hopper_dt_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dt'),
+)
+hopper_dt_create_config = EasyDict(hopper_dt_create_config)
+create_config = hopper_dt_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_dt
+ config = deepcopy([main_config, create_config])
+ serial_pipeline_dt(config, seed=0, max_train_iter=1000)
diff --git a/DI-engine/dizoo/d4rl/config/hopper_medium_expert_edac_config.py b/DI-engine/dizoo/d4rl/config/hopper_medium_expert_edac_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..5bbc5b375dbcf6faf078d0f8a076c883f2b6c947
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/hopper_medium_expert_edac_config.py
@@ -0,0 +1,58 @@
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="hopper_medium_expert_edac_seed0",
+ env=dict(
+ env_id='hopper-medium-expert-v2',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=5000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ ensemble_num=50,
+ actor_head_hidden_size=256,
+ actor_head_layer_num=3,
+ critic_head_hidden_size=256,
+ critic_head_layer_num=3,
+ ),
+ learn=dict(
+ data_path=None,
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_q=3e-4,
+ learning_rate_policy=3e-4,
+ learning_rate_alpha=3e-4,
+ alpha=1,
+ auto_alpha=False,
+ eta=1.0,
+ with_q_entropy=False,
+ learner=dict(hook=dict(save_ckpt_after_iter=100000, )),
+ ),
+ collect=dict(data_type='d4rl', ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ ),
+ seed=0,
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='edac',
+ import_names=['ding.policy.edac'],
+ ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/hopper_medium_expert_ibc_ar_config.py b/DI-engine/dizoo/d4rl/config/hopper_medium_expert_ibc_ar_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5d1090dc772873057c8ea11ae50f5ffc0f73fc9f
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/hopper_medium_expert_ibc_ar_config.py
@@ -0,0 +1,54 @@
+from easydict import EasyDict
+
+cuda = True
+multi_gpu = False
+
+main_config = dict(
+ exp_name='hopper_medium_expert_ibc_ar_seed0',
+ env=dict(
+ env_id='hopper-medium-expert-v0',
+ norm_obs=dict(
+ use_norm=True,
+ offline_stats=dict(use_offline_stats=True, ),
+ ),
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ use_act_scale=True,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=cuda,
+ model=dict(obs_shape=11, action_shape=3, stochastic_optim=dict(type='ardfo', )),
+ learn=dict(
+ multi_gpu=multi_gpu,
+ train_epoch=15,
+ batch_size=256,
+ optim=dict(learning_rate=1e-5, ),
+ learner=dict(hook=dict(log_show_after_iter=1000)),
+ ),
+ collect=dict(
+ data_type='d4rl',
+ data_path=None,
+ ),
+ eval=dict(evaluator=dict(eval_freq=-1, )),
+ ),
+)
+main_config = EasyDict(main_config)
+main_config = main_config
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base', ),
+ policy=dict(
+ type='ibc',
+ import_names=['ding.policy.ibc'],
+ model=dict(
+ type='arebm',
+ import_names=['ding.model.template.ebm'],
+ ),
+ ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/hopper_medium_expert_ibc_config.py b/DI-engine/dizoo/d4rl/config/hopper_medium_expert_ibc_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..0f040970e60a44bf224372b9957cf7b83a52b410
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/hopper_medium_expert_ibc_config.py
@@ -0,0 +1,50 @@
+from easydict import EasyDict
+
+cuda = True
+multi_gpu = False
+
+main_config = dict(
+ exp_name='hopper_medium_expert_ibc_seed0',
+ env=dict(
+ env_id='hopper-medium-expert-v0',
+ norm_obs=dict(
+ use_norm=True,
+ offline_stats=dict(use_offline_stats=True, ),
+ ),
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ use_act_scale=True,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=cuda,
+ model=dict(obs_shape=11, action_shape=3, stochastic_optim=dict(type='dfo', )),
+ learn=dict(
+ multi_gpu=multi_gpu,
+ train_epoch=15,
+ batch_size=256,
+ optim=dict(learning_rate=1e-5, ),
+ learner=dict(hook=dict(log_show_after_iter=1000)),
+ ),
+ collect=dict(
+ data_type='d4rl',
+ data_path=None,
+ ),
+ eval=dict(evaluator=dict(eval_freq=-1, )),
+ ),
+)
+main_config = EasyDict(main_config)
+main_config = main_config
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base', ),
+ policy=dict(
+ type='ibc',
+ import_names=['ding.policy.ibc'],
+ ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/hopper_medium_expert_ibc_mcmc_config.py b/DI-engine/dizoo/d4rl/config/hopper_medium_expert_ibc_mcmc_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..478e0c5d44b1fb04c01cfc58e22815a04707da1a
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/hopper_medium_expert_ibc_mcmc_config.py
@@ -0,0 +1,50 @@
+from easydict import EasyDict
+
+cuda = True
+multi_gpu = False
+
+main_config = dict(
+ exp_name='hopper_medium_expert_ibc_mcmc_seed0',
+ env=dict(
+ env_id='hopper-medium-expert-v0',
+ norm_obs=dict(
+ use_norm=True,
+ offline_stats=dict(use_offline_stats=True, ),
+ ),
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ use_act_scale=True,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=cuda,
+ model=dict(obs_shape=11, action_shape=3, stochastic_optim=dict(type='mcmc', )),
+ learn=dict(
+ multi_gpu=multi_gpu,
+ train_epoch=15,
+ batch_size=256,
+ optim=dict(learning_rate=1e-5, ),
+ learner=dict(hook=dict(log_show_after_iter=1000)),
+ ),
+ collect=dict(
+ data_type='d4rl',
+ data_path=None,
+ ),
+ eval=dict(evaluator=dict(eval_freq=-1, )),
+ ),
+)
+main_config = EasyDict(main_config)
+main_config = main_config
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base', ),
+ policy=dict(
+ type='ibc',
+ import_names=['ding.policy.ibc'],
+ ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/hopper_medium_expert_pd_config.py b/DI-engine/dizoo/d4rl/config/hopper_medium_expert_pd_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..3df47f8d1b741ad85c3871f8d88b6856e70444aa
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/hopper_medium_expert_pd_config.py
@@ -0,0 +1,101 @@
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="hopper_medium_expert_pd_seed0",
+ env=dict(
+ env_id='hopper-medium-expert-v2',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ returns_scale=1.0,
+ termination_penalty=-100,
+ max_path_length=1000,
+ use_padding=True,
+ include_returns=True,
+ normed=False,
+ stop_value=8000,
+ horizon=32,
+ obs_dim=11,
+ action_dim=3,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ diffuser_model='GaussianDiffusion',
+ diffuser_model_cfg=dict(
+ model='DiffusionUNet1d',
+ model_cfg=dict(
+ transition_dim=14,
+ dim=32,
+ dim_mults=[1, 2, 4, 8],
+ returns_condition=False,
+ kernel_size=5,
+ attention=False,
+ ),
+ horizon=32,
+ obs_dim=11,
+ action_dim=3,
+ n_timesteps=20,
+ predict_epsilon=False,
+ loss_discount=1,
+ action_weight=10,
+ ),
+ value_model='ValueDiffusion',
+ value_model_cfg=dict(
+ model='TemporalValue',
+ model_cfg=dict(
+                    horizon=32,
+ transition_dim=14,
+ dim=32,
+ dim_mults=[1, 2, 4, 8],
+ kernel_size=5,
+ ),
+ horizon=32,
+ obs_dim=11,
+ action_dim=3,
+ n_timesteps=20,
+ predict_epsilon=True,
+ loss_discount=1,
+ ),
+ n_guide_steps=2,
+ scale=0.0001,
+ t_stopgrad=4,
+ scale_grad_by_std=True,
+ ),
+ normalizer='GaussianNormalizer',
+ learn=dict(
+ data_path=None,
+ train_epoch=60000,
+ gradient_accumulate_every=2,
+ batch_size=32,
+ learning_rate=2e-4,
+ discount_factor=0.99,
+ plan_batch_size=64,
+ learner=dict(hook=dict(save_ckpt_after_iter=1000000000, )),
+ ),
+ collect=dict(data_type='diffuser_traj', ),
+ eval=dict(
+ evaluator=dict(eval_freq=500, ),
+ test_ret=0.9,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='pd',
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
\ No newline at end of file
diff --git a/DI-engine/dizoo/d4rl/config/hopper_medium_expert_td3bc_config.py b/DI-engine/dizoo/d4rl/config/hopper_medium_expert_td3bc_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b51ed523fa17dd03a3a671fe1cbc73a2a6169da5
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/hopper_medium_expert_td3bc_config.py
@@ -0,0 +1,61 @@
+# You can conduct experiments on D4RL with this config file via the following command:
+# cd ../entry && python d4rl_td3_bc_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name='hopper_medium_expert_td3-bc_seed0',
+ env=dict(
+ env_id='hopper-medium-expert-v2',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ ),
+ learn=dict(
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_actor=0.0003,
+ learning_rate_critic=0.0003,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.2,
+ noise_range={
+ 'min': -0.5,
+ 'max': 0.5
+ },
+ alpha=2.5,
+ ),
+ collect=dict(
+ data_type='d4rl',
+ data_path=None,
+ ),
+ eval=dict(evaluator=dict(eval_freq=10000, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+main_config = EasyDict(main_config)
+main_config = main_config
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(
+ cfg_type='BaseEnvManagerDict',
+ type='base',
+ ),
+ policy=dict(
+ type='td3_bc',
+ import_names=['ding.policy.td3_bc'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/hopper_medium_pd_config.py b/DI-engine/dizoo/d4rl/config/hopper_medium_pd_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..8dfee5d824bf1665e015b145bd8a59bc02a22889
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/hopper_medium_pd_config.py
@@ -0,0 +1,101 @@
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="hopper_medium_pd_seed0",
+ env=dict(
+ env_id='hopper-medium-v2',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ returns_scale=1.0,
+ termination_penalty=-100,
+ max_path_length=1000,
+ use_padding=True,
+ include_returns=True,
+ normed=False,
+ stop_value=8000,
+ horizon=32,
+ obs_dim=11,
+ action_dim=3,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ diffuser_model='GaussianDiffusion',
+ diffuser_model_cfg=dict(
+ model='DiffusionUNet1d',
+ model_cfg=dict(
+ transition_dim=14,
+ dim=32,
+ dim_mults=[1, 2, 4, 8],
+ returns_condition=False,
+ kernel_size=5,
+ attention=False,
+ ),
+ horizon=32,
+ obs_dim=11,
+ action_dim=3,
+ n_timesteps=20,
+ predict_epsilon=False,
+ loss_discount=1,
+ action_weight=10,
+ ),
+ value_model='ValueDiffusion',
+ value_model_cfg=dict(
+ model='TemporalValue',
+ model_cfg=dict(
+                    horizon=32,
+ transition_dim=14,
+ dim=32,
+ dim_mults=[1, 2, 4, 8],
+ kernel_size=5,
+ ),
+ horizon=32,
+ obs_dim=11,
+ action_dim=3,
+ n_timesteps=20,
+ predict_epsilon=True,
+ loss_discount=1,
+ ),
+ n_guide_steps=2,
+ scale=0.1,
+ t_stopgrad=2,
+ scale_grad_by_std=True,
+ ),
+ normalizer='GaussianNormalizer',
+ learn=dict(
+ data_path=None,
+ train_epoch=60000,
+ gradient_accumulate_every=2,
+ batch_size=32,
+ learning_rate=2e-4,
+ discount_factor=0.99,
+ plan_batch_size=64,
+ learner=dict(hook=dict(save_ckpt_after_iter=1000000000, )),
+ ),
+ collect=dict(data_type='diffuser_traj', ),
+ eval=dict(
+ evaluator=dict(eval_freq=500, ),
+ test_ret=0.9,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='pd',
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
\ No newline at end of file
diff --git a/DI-engine/dizoo/d4rl/config/hopper_medium_replay_cql_config.py b/DI-engine/dizoo/d4rl/config/hopper_medium_replay_cql_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d995bfd5643b1446b4735a031c7bd89d416a40a4
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/hopper_medium_replay_cql_config.py
@@ -0,0 +1,55 @@
+# You can conduct experiments on D4RL with this config file via the following command:
+# cd ../entry && python d4rl_cql_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="hopper_medium_replay_cql_seed0",
+ env=dict(
+ env_id='hopper-medium-replay-v0',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ ),
+ learn=dict(
+ data_path=None,
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_q=3e-4,
+ learning_rate_policy=1e-4,
+ learning_rate_alpha=1e-4,
+ alpha=0.2,
+ auto_alpha=False,
+ lagrange_thresh=-1.0,
+ min_q_weight=5.0,
+ ),
+ collect=dict(data_type='d4rl', ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='cql',
+ import_names=['ding.policy.cql'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/hopper_medium_replay_dt_config.py b/DI-engine/dizoo/d4rl/config/hopper_medium_replay_dt_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2615ba1b9b0727a879ba377aaa8ad7171b08157
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/hopper_medium_replay_dt_config.py
@@ -0,0 +1,70 @@
+from easydict import EasyDict
+from copy import deepcopy
+
+hopper_dt_config = dict(
+ exp_name='dt_log/d4rl/hopper/hopper_medium_replay_dt_seed0',
+ env=dict(
+ env_id='Hopper-v3',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=3600,
+ ),
+ dataset=dict(
+ env_type='mujoco',
+ rtg_scale=1000,
+ context_len=30,
+ data_dir_prefix='d4rl/hopper_medium_replay-v2.pkl',
+ ),
+ policy=dict(
+ cuda=True,
+ stop_value=3600,
+ state_mean=None,
+ state_std=None,
+ evaluator_env_num=8,
+ env_name='Hopper-v3',
+ rtg_target=3600, # max target return to go
+        max_eval_ep_len=1000,  # max length of one episode
+ wt_decay=1e-4,
+ warmup_steps=10000,
+ context_len=20,
+ weight_decay=0.1,
+ clip_grad_norm_p=0.25,
+ model=dict(
+ state_dim=11,
+ act_dim=3,
+ n_blocks=3,
+ h_dim=128,
+ context_len=20,
+ n_heads=1,
+ drop_p=0.1,
+ continuous=True,
+ ),
+ batch_size=64,
+ learning_rate=1e-4,
+ collect=dict(
+ data_type='d4rl_trajectory',
+ unroll_len=1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ ),
+)
+
+hopper_dt_config = EasyDict(hopper_dt_config)
+main_config = hopper_dt_config
+hopper_dt_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dt'),
+)
+hopper_dt_create_config = EasyDict(hopper_dt_create_config)
+create_config = hopper_dt_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_dt
+ config = deepcopy([main_config, create_config])
+ serial_pipeline_dt(config, seed=0, max_train_iter=1000)
diff --git a/DI-engine/dizoo/d4rl/config/hopper_medium_replay_td3bc_config.py b/DI-engine/dizoo/d4rl/config/hopper_medium_replay_td3bc_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ca9cdef0632393cbac3f843b512b661b1518714
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/hopper_medium_replay_td3bc_config.py
@@ -0,0 +1,61 @@
+# You can conduct experiments on D4RL with this config file via the following command:
+# cd ../entry && python d4rl_td3_bc_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name='hopper_medium_replay_td3-bc_seed0',
+ env=dict(
+ env_id='hopper-medium-replay-v2',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ ),
+ learn=dict(
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_actor=0.0003,
+ learning_rate_critic=0.0003,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.2,
+ noise_range={
+ 'min': -0.5,
+ 'max': 0.5
+ },
+ alpha=2.5,
+ ),
+ collect=dict(
+ data_type='d4rl',
+ data_path=None,
+ ),
+ eval=dict(evaluator=dict(eval_freq=10000, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+main_config = EasyDict(main_config)
+main_config = main_config
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(
+ cfg_type='BaseEnvManagerDict',
+ type='base',
+ ),
+ policy=dict(
+ type='td3_bc',
+ import_names=['ding.policy.td3_bc'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/hopper_medium_td3bc_config.py b/DI-engine/dizoo/d4rl/config/hopper_medium_td3bc_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd318545e6de3974703d30be0be3c79abc998274
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/hopper_medium_td3bc_config.py
@@ -0,0 +1,61 @@
+# You can conduct experiments on D4RL with this config file via the following command:
+# cd ../entry && python d4rl_td3_bc_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name='hopper_medium_td3-bc_seed0',
+ env=dict(
+ env_id='hopper-medium-v2',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ ),
+ learn=dict(
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_actor=0.0003,
+ learning_rate_critic=0.0003,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.2,
+ noise_range={
+ 'min': -0.5,
+ 'max': 0.5
+ },
+ alpha=2.5,
+ ),
+ collect=dict(
+ data_type='d4rl',
+ data_path=None,
+ ),
+ eval=dict(evaluator=dict(eval_freq=10000, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+main_config = EasyDict(main_config)
+main_config = main_config
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(
+ cfg_type='BaseEnvManagerDict',
+ type='base',
+ ),
+ policy=dict(
+ type='td3_bc',
+ import_names=['ding.policy.td3_bc'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/hopper_random_cql_config.py b/DI-engine/dizoo/d4rl/config/hopper_random_cql_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b0d175c73712647f7c12fbfb97f0728e0b7b6a6
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/hopper_random_cql_config.py
@@ -0,0 +1,55 @@
+# You can conduct experiments on D4RL with this config file via the following command:
+# cd ../entry && python d4rl_cql_main.py
+from easydict import EasyDict
+
+main_config = dict(
+    exp_name="hopper_random_cql_seed0",
+    env=dict(
+        env_id='hopper-random-v0',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ ),
+ learn=dict(
+ data_path=None,
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_q=3e-4,
+ learning_rate_policy=1e-4,
+ learning_rate_alpha=1e-4,
+ alpha=0.2,
+ auto_alpha=False,
+ lagrange_thresh=-1.0,
+ min_q_weight=5.0,
+ ),
+ collect=dict(data_type='d4rl', ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='cql',
+ import_names=['ding.policy.cql'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/hopper_random_dt_config.py b/DI-engine/dizoo/d4rl/config/hopper_random_dt_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..a009058eba2f07ab7c03d1e4c0a01c5a64a68897
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/hopper_random_dt_config.py
@@ -0,0 +1,82 @@
+from easydict import EasyDict
+from copy import deepcopy
+
+hopper_dt_config = dict(
+ exp_name='hopper_random_dt_seed0',
+ env=dict(
+ env_id='Hopper-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ stop_value=6000,
+ cuda=True,
+ env_name='Hopper-v3',
+ rtg_target=6000, # max target return to go
+        max_eval_ep_len=1000,  # max length of one episode
+ num_eval_ep=10, # num of evaluation episode
+ batch_size=64,
+ wt_decay=1e-4,
+ warmup_steps=10000,
+ num_updates_per_iter=100,
+ context_len=20,
+ n_blocks=3,
+ embed_dim=128,
+ n_heads=1,
+ dropout_p=0.1,
+ log_dir='/home/wangzilin/research/dt/DI-engine/dizoo/d4rl/dt_data/hopper_random_dt_log',
+ model=dict(
+ state_dim=11,
+ act_dim=3,
+ n_blocks=3,
+ h_dim=128,
+ context_len=20,
+ n_heads=1,
+ drop_p=0.1,
+ continuous=True,
+ ),
+ discount_factor=0.999,
+ nstep=3,
+ learn=dict(
+ dataset_path='/mnt/lustre/wangzilin/d4rl_data/hopper-random-v2.pkl',
+ learning_rate=0.0001,
+ target_update_freq=100,
+ kappa=1.0,
+ min_q_weight=4.0
+ ),
+ collect=dict(unroll_len=1, ),
+        eval=dict(evaluator=dict(eval_freq=100, ), ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=1000, ),
+ ),
+ ),
+)
+
+hopper_dt_config = EasyDict(hopper_dt_config)
+main_config = hopper_dt_config
+hopper_dt_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dt'),
+)
+hopper_dt_create_config = EasyDict(hopper_dt_create_config)
+create_config = hopper_dt_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_dt
+ config = deepcopy([main_config, create_config])
+ serial_pipeline_dt(config, seed=0, max_train_iter=1000)
diff --git a/DI-engine/dizoo/d4rl/config/hopper_random_td3bc_config.py b/DI-engine/dizoo/d4rl/config/hopper_random_td3bc_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..abc1f3ab60825edb647711fa1b0368d14fea9162
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/hopper_random_td3bc_config.py
@@ -0,0 +1,65 @@
+# You can conduct experiments on D4RL with this config file through the following command:
+# cd ../entry && python d4rl_td3_bc_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name='hopper_random_td3-bc_seed0',
+ env=dict(
+ env_id='hopper-random-v0',
+ norm_obs=dict(
+ use_norm=True,
+ offline_stats=dict(use_offline_stats=True, ),
+ ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ ),
+ learn=dict(
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_actor=0.0003,
+ learning_rate_critic=0.0003,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.2,
+ noise_range={
+ 'min': -0.5,
+ 'max': 0.5
+ },
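+ # alpha is the TD3+BC trade-off between the Q-value objective and the behavior-cloning term;
+ # 2.5 matches the default reported in the original TD3+BC paper.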
+ alpha=2.5,
+ ),
+ collect=dict(
+ data_type='d4rl',
+ data_path=None,
+ ),
+ eval=dict(evaluator=dict(eval_freq=10000, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+main_config = EasyDict(main_config)
+main_config = main_config
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(
+ cfg_type='BaseEnvManagerDict',
+ type='base',
+ ),
+ policy=dict(
+ type='td3_bc',
+ import_names=['ding.policy.td3_bc'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/kitchen_complete_bc_config.py b/DI-engine/dizoo/d4rl/config/kitchen_complete_bc_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..413696993d9ee3472566077aefd1b9e7e354b0e6
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/kitchen_complete_bc_config.py
@@ -0,0 +1,58 @@
+from easydict import EasyDict
+
+cuda = True
+multi_gpu = False
+
+main_config = dict(
+ exp_name='kitchen_complete_bc_seed0',
+ env=dict(
+ env_id='kitchen-complete-v0',
+ norm_obs=dict(
+ use_norm=True,
+ offline_stats=dict(use_offline_stats=True, ),
+ ),
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ use_act_scale=True,
+ stop_value=1e10,
+ ),
+ policy=dict(
+ cuda=cuda,
+ continuous=True,
+ loss_type='mse_loss',
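+ # kitchen-complete-v0 (FrankaKitchen) uses a 60-dim observation and a 9-dim action space.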
+ model=dict(
+ obs_shape=60,
+ action_shape=9,
+ action_space='regression',
+ actor_head_hidden_size=512,
+ actor_head_layer_num=4,
+ ),
+ learn=dict(
+ multi_gpu=multi_gpu,
+ train_epoch=1000,
+ batch_size=256,
+ learning_rate=1e-5,
+ learner=dict(hook=dict(log_show_after_iter=100)),
+ ),
+ collect=dict(
+ data_type='d4rl',
+ data_path=None,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, )),
+ ),
+)
+main_config = EasyDict(main_config)
+main_config = main_config
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base', ),
+ policy=dict(
+ type='bc',
+ import_names=['ding.policy.bc'],
+ ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/kitchen_complete_ibc_ar_config.py b/DI-engine/dizoo/d4rl/config/kitchen_complete_ibc_ar_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbb7198af0485a4b86811a69c9c83dadae46d7fa
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/kitchen_complete_ibc_ar_config.py
@@ -0,0 +1,54 @@
+from easydict import EasyDict
+
+cuda = True
+multi_gpu = False
+
+main_config = dict(
+ exp_name='kitchen_complete_ibc_ar_seed0',
+ env=dict(
+ env_id='kitchen-complete-v0',
+ norm_obs=dict(
+ use_norm=True,
+ offline_stats=dict(use_offline_stats=True, ),
+ ),
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ use_act_scale=True,
+ stop_value=1e10,
+ ),
+ policy=dict(
+ cuda=cuda,
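+ # stochastic_optim type 'ardfo' presumably selects the autoregressive derivative-free optimizer
+ # from the implicit BC paper; the sibling configs use the plain 'dfo' and gradient-based 'mcmc' samplers.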
+ model=dict(obs_shape=60, action_shape=9, stochastic_optim=dict(type='ardfo', )),
+ learn=dict(
+ multi_gpu=multi_gpu,
+ train_epoch=1000,
+ batch_size=256,
+ optim=dict(learning_rate=1e-5, ),
+ learner=dict(hook=dict(log_show_after_iter=100)),
+ ),
+ collect=dict(
+ data_type='d4rl',
+ data_path=None,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, )),
+ ),
+)
+main_config = EasyDict(main_config)
+main_config = main_config
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base', ),
+ policy=dict(
+ type='ibc',
+ import_names=['ding.policy.ibc'],
+ model=dict(
+ type='arebm',
+ import_names=['ding.model.template.ebm'],
+ ),
+ ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/kitchen_complete_ibc_config.py b/DI-engine/dizoo/d4rl/config/kitchen_complete_ibc_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..1606cb7792d05f6cd9967e474a75294999d1c3cc
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/kitchen_complete_ibc_config.py
@@ -0,0 +1,50 @@
+from easydict import EasyDict
+
+cuda = True
+multi_gpu = False
+
+main_config = dict(
+ exp_name='kitchen_complete_ibc_seed0',
+ env=dict(
+ env_id='kitchen-complete-v0',
+ norm_obs=dict(
+ use_norm=True,
+ offline_stats=dict(use_offline_stats=True, ),
+ ),
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ use_act_scale=True,
+ stop_value=1e10,
+ ),
+ policy=dict(
+ cuda=cuda,
+ model=dict(obs_shape=60, action_shape=9, stochastic_optim=dict(type='dfo', )),
+ learn=dict(
+ multi_gpu=multi_gpu,
+ train_epoch=1000,
+ batch_size=256,
+ optim=dict(learning_rate=1e-5, ),
+ learner=dict(hook=dict(log_show_after_iter=100)),
+ ),
+ collect=dict(
+ data_type='d4rl',
+ data_path=None,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, )),
+ ),
+)
+main_config = EasyDict(main_config)
+main_config = main_config
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base', ),
+ policy=dict(
+ type='ibc',
+ import_names=['ding.policy.ibc'],
+ ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/kitchen_complete_ibc_mcmc_config.py b/DI-engine/dizoo/d4rl/config/kitchen_complete_ibc_mcmc_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..14924d525741019b7d26f388aaa8b515303f85ab
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/kitchen_complete_ibc_mcmc_config.py
@@ -0,0 +1,50 @@
+from easydict import EasyDict
+
+cuda = True
+multi_gpu = False
+
+main_config = dict(
+ exp_name='kitchen_complete_ibc_mcmc_seed0',
+ env=dict(
+ env_id='kitchen-complete-v0',
+ norm_obs=dict(
+ use_norm=True,
+ offline_stats=dict(use_offline_stats=True, ),
+ ),
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ use_act_scale=True,
+ stop_value=1e10,
+ ),
+ policy=dict(
+ cuda=cuda,
+ model=dict(obs_shape=60, action_shape=9, stochastic_optim=dict(type='mcmc', )),
+ learn=dict(
+ multi_gpu=multi_gpu,
+ train_epoch=1000,
+ batch_size=256,
+ optim=dict(learning_rate=1e-5, ),
+ learner=dict(hook=dict(log_show_after_iter=100)),
+ ),
+ collect=dict(
+ data_type='d4rl',
+ data_path=None,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, )),
+ ),
+)
+main_config = EasyDict(main_config)
+main_config = main_config
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base', ),
+ policy=dict(
+ type='ibc',
+ import_names=['ding.policy.ibc'],
+ ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/maze2d_large_pd_config.py b/DI-engine/dizoo/d4rl/config/maze2d_large_pd_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..a68838213a6775b271cf91a18fbbbadf7f52fb7d
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/maze2d_large_pd_config.py
@@ -0,0 +1,84 @@
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="maze2d_large_pd_seed0",
+ env=dict(
+ env_id='maze2d-large-v1',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ returns_scale=1.0,
+ termination_penalty=None,
+ max_path_length=40000,
+ use_padding=False,
+ include_returns=False,
+ normed=False,
+ stop_value=500,
+ horizon=384,
+ obs_dim=4,
+ action_dim=2,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ diffuser_model='GaussianDiffusion',
+ diffuser_model_cfg=dict(
+ model='DiffusionUNet1d',
+ model_cfg=dict(
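+ # transition_dim = obs_dim + action_dim (4 + 2 for maze2d).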
+ transition_dim=6,
+ dim=32,
+ dim_mults=[1, 4, 8],
+ returns_condition=False,
+ kernel_size=5,
+ attention=False,
+ ),
+ horizon=384,
+ obs_dim=4,
+ action_dim=2,
+ n_timesteps=256,
+ predict_epsilon=False,
+ loss_discount=1,
+ clip_denoised=True,
+ action_weight=1,
+ ),
+ value_model=None,
+ value_model_cfg=None,
+ ),
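+ # 'LimitsNormalizer' presumably rescales each dimension to [-1, 1] using dataset min/max statistics.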
+ normalizer='LimitsNormalizer',
+ learn=dict(
+ data_path=None,
+ train_epoch=60000,
+ gradient_accumulate_every=2,
+ batch_size=32,
+ learning_rate=2e-4,
+ discount_factor=0.99,
+ plan_batch_size=1,
+ include_returns=False,
+ learner=dict(hook=dict(save_ckpt_after_iter=1000000000, )),
+ ),
+ collect=dict(data_type='diffuser_traj', ),
+ eval=dict(
+ evaluator=dict(eval_freq=500, ),
+ test_ret=0.9,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='pd',
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
\ No newline at end of file
diff --git a/DI-engine/dizoo/d4rl/config/maze2d_medium_pd_config.py b/DI-engine/dizoo/d4rl/config/maze2d_medium_pd_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..a14cac7480f35185af2f64e0eb43e94ba8470c85
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/maze2d_medium_pd_config.py
@@ -0,0 +1,84 @@
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="maze2d_medium_pd_seed0",
+ env=dict(
+ env_id='maze2d-medium-v1',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ returns_scale=1.0,
+ termination_penalty=None,
+ max_path_length=40000,
+ use_padding=False,
+ include_returns=False,
+ normed=False,
+ stop_value=357,
+ horizon=256,
+ obs_dim=4,
+ action_dim=2,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ diffuser_model='GaussianDiffusion',
+ diffuser_model_cfg=dict(
+ model='DiffusionUNet1d',
+ model_cfg=dict(
+ transition_dim=6,
+ dim=32,
+ dim_mults=[1, 4, 8],
+ returns_condition=False,
+ kernel_size=5,
+ attention=False,
+ ),
+ horizon=256,
+ obs_dim=4,
+ action_dim=2,
+ n_timesteps=256,
+ predict_epsilon=False,
+ loss_discount=1,
+ clip_denoised=True,
+ action_weight=1,
+ ),
+ value_model=None,
+ value_model_cfg=None,
+ ),
+ normalizer='LimitsNormalizer',
+ learn=dict(
+ data_path=None,
+ train_epoch=60000,
+ gradient_accumulate_every=2,
+ batch_size=32,
+ learning_rate=2e-4,
+ discount_factor=0.99,
+ plan_batch_size=1,
+ include_returns=False,
+ learner=dict(hook=dict(save_ckpt_after_iter=1000000000, )),
+ ),
+ collect=dict(data_type='diffuser_traj', ),
+ eval=dict(
+ evaluator=dict(eval_freq=500, ),
+ test_ret=0.9,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='pd',
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
\ No newline at end of file
diff --git a/DI-engine/dizoo/d4rl/config/maze2d_umaze_pd_config.py b/DI-engine/dizoo/d4rl/config/maze2d_umaze_pd_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..462d10651e8aed8638bd47fbb69a359aae25112a
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/maze2d_umaze_pd_config.py
@@ -0,0 +1,84 @@
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="maze2d_umaze_pd_seed0",
+ env=dict(
+ env_id='maze2d-umaze-v1',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ returns_scale=1.0,
+ termination_penalty=None,
+ max_path_length=40000,
+ use_padding=False,
+ include_returns=False,
+ normed=False,
+ stop_value=190,
+ horizon=128,
+ obs_dim=4,
+ action_dim=2,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ diffuser_model='GaussianDiffusion',
+ diffuser_model_cfg=dict(
+ model='DiffusionUNet1d',
+ model_cfg=dict(
+ transition_dim=6,
+ dim=32,
+ dim_mults=[1, 4, 8],
+ returns_condition=False,
+ kernel_size=5,
+ attention=False,
+ ),
+ horizon=128,
+ obs_dim=4,
+ action_dim=2,
+ n_timesteps=64,
+ predict_epsilon=False,
+ loss_discount=1,
+ clip_denoised=True,
+ action_weight=1,
+ ),
+ value_model=None,
+ value_model_cfg=None,
+ ),
+ normalizer='LimitsNormalizer',
+ learn=dict(
+ data_path=None,
+ train_epoch=60000,
+ gradient_accumulate_every=2,
+ batch_size=32,
+ learning_rate=2e-4,
+ discount_factor=0.99,
+ plan_batch_size=1,
+ include_returns=False,
+ learner=dict(hook=dict(save_ckpt_after_iter=1000000000, )),
+ ),
+ collect=dict(data_type='diffuser_traj', ),
+ eval=dict(
+ evaluator=dict(eval_freq=500, ),
+ test_ret=0.9,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='pd',
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
\ No newline at end of file
diff --git a/DI-engine/dizoo/d4rl/config/pen_human_bc_config.py b/DI-engine/dizoo/d4rl/config/pen_human_bc_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..215b706ffc3a3fe92b7b41d542e3cc01343a641d
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/pen_human_bc_config.py
@@ -0,0 +1,58 @@
+from easydict import EasyDict
+
+cuda = True
+multi_gpu = False
+
+main_config = dict(
+ exp_name='pen_human_bc_seed0',
+ env=dict(
+ env_id='pen-human-v0',
+ norm_obs=dict(
+ use_norm=True,
+ offline_stats=dict(use_offline_stats=True, ),
+ ),
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ use_act_scale=True,
+ stop_value=1e10,
+ ),
+ policy=dict(
+ cuda=cuda,
+ continuous=True,
+ loss_type='mse_loss',
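+ # pen-human-v0 (Adroit pen) uses a 45-dim observation and a 24-dim action space.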
+ model=dict(
+ obs_shape=45,
+ action_shape=24,
+ action_space='regression',
+ actor_head_hidden_size=512,
+ actor_head_layer_num=4,
+ ),
+ learn=dict(
+ multi_gpu=multi_gpu,
+ train_epoch=1000,
+ batch_size=256,
+ learning_rate=1e-5,
+ learner=dict(hook=dict(log_show_after_iter=100)),
+ ),
+ collect=dict(
+ data_type='d4rl',
+ data_path=None,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, )),
+ ),
+)
+main_config = EasyDict(main_config)
+main_config = main_config
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base', ),
+ policy=dict(
+ type='bc',
+ import_names=['ding.policy.bc'],
+ ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/pen_human_ibc_ar_config.py b/DI-engine/dizoo/d4rl/config/pen_human_ibc_ar_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4f59733fd56eda6a9501c2795f9f7de38a95925b
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/pen_human_ibc_ar_config.py
@@ -0,0 +1,56 @@
+from easydict import EasyDict
+
+cuda = True
+multi_gpu = False
+
+main_config = dict(
+ exp_name='pen_human_ibc_ar_seed0',
+ env=dict(
+ env_id='pen-human-v0',
+ norm_obs=dict(
+ use_norm=True,
+ offline_stats=dict(use_offline_stats=True, ),
+ ),
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ use_act_scale=True,
+ stop_value=1e10,
+ ),
+ policy=dict(
+ cuda=cuda,
+ model=dict(
+ obs_shape=45, action_shape=24, hidden_size=128, hidden_layer_num=4, stochastic_optim=dict(type='ardfo', )
+ ),
+ learn=dict(
+ multi_gpu=multi_gpu,
+ train_epoch=1000,
+ batch_size=256,
+ optim=dict(learning_rate=1e-5, ),
+ learner=dict(hook=dict(log_show_after_iter=100)),
+ ),
+ collect=dict(
+ data_type='d4rl',
+ data_path=None,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, )),
+ ),
+)
+main_config = EasyDict(main_config)
+main_config = main_config
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base', ),
+ policy=dict(
+ type='ibc',
+ import_names=['ding.policy.ibc'],
+ model=dict(
+ type='arebm',
+ import_names=['ding.model.template.ebm'],
+ ),
+ ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/pen_human_ibc_config.py b/DI-engine/dizoo/d4rl/config/pen_human_ibc_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ed4f6d17bc2fc613f942b35e3d93d8b5a49781f
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/pen_human_ibc_config.py
@@ -0,0 +1,50 @@
+from easydict import EasyDict
+
+cuda = True
+multi_gpu = False
+
+main_config = dict(
+ exp_name='pen_human_ibc_seed0',
+ env=dict(
+ env_id='pen-human-v0',
+ norm_obs=dict(
+ use_norm=True,
+ offline_stats=dict(use_offline_stats=True, ),
+ ),
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ use_act_scale=True,
+ stop_value=1e10,
+ ),
+ policy=dict(
+ cuda=cuda,
+ model=dict(obs_shape=45, action_shape=24, stochastic_optim=dict(type='dfo', )),
+ learn=dict(
+ multi_gpu=multi_gpu,
+ train_epoch=1000,
+ batch_size=256,
+ optim=dict(learning_rate=1e-5, ),
+ learner=dict(hook=dict(log_show_after_iter=100)),
+ ),
+ collect=dict(
+ data_type='d4rl',
+ data_path=None,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, )),
+ ),
+)
+main_config = EasyDict(main_config)
+main_config = main_config
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base', ),
+ policy=dict(
+ type='ibc',
+ import_names=['ding.policy.ibc'],
+ ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/pen_human_ibc_mcmc_config.py b/DI-engine/dizoo/d4rl/config/pen_human_ibc_mcmc_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4dd6b37f909f22d9f11ca6d52bb083938661dcf1
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/pen_human_ibc_mcmc_config.py
@@ -0,0 +1,50 @@
+from easydict import EasyDict
+
+cuda = True
+multi_gpu = False
+
+main_config = dict(
+ exp_name='pen_human_ibc_mcmc_seed0',
+ env=dict(
+ env_id='pen-human-v0',
+ norm_obs=dict(
+ use_norm=True,
+ offline_stats=dict(use_offline_stats=True, ),
+ ),
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ use_act_scale=True,
+ stop_value=1e10,
+ ),
+ policy=dict(
+ cuda=cuda,
+ model=dict(obs_shape=45, action_shape=24, stochastic_optim=dict(type='mcmc', )),
+ learn=dict(
+ multi_gpu=multi_gpu,
+ train_epoch=1000,
+ batch_size=256,
+ optim=dict(learning_rate=1e-5, ),
+ learner=dict(hook=dict(log_show_after_iter=100)),
+ ),
+ collect=dict(
+ data_type='d4rl',
+ data_path=None,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, )),
+ ),
+)
+main_config = EasyDict(main_config)
+main_config = main_config
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base', ),
+ policy=dict(
+ type='ibc',
+ import_names=['ding.policy.ibc'],
+ ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/walker2d_expert_cql_config.py b/DI-engine/dizoo/d4rl/config/walker2d_expert_cql_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..346dd1a1f227532c8bf4aa657dba8a88e78310b9
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/walker2d_expert_cql_config.py
@@ -0,0 +1,55 @@
+# You can conduct experiments on D4RL with this config file through the following command:
+# cd ../entry && python d4rl_cql_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="walker2d_expert_cql_seed0",
+ env=dict(
+ env_id='walker2d-expert-v2',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ ),
+ learn=dict(
+ data_path=None,
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_q=3e-4,
+ learning_rate_policy=1e-4,
+ learning_rate_alpha=1e-4,
+ alpha=0.2,
+ auto_alpha=False,
+ lagrange_thresh=-1.0,
+ min_q_weight=5.0,
+ ),
+ collect=dict(data_type='d4rl', ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='cql',
+ import_names=['ding.policy.cql'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/walker2d_expert_dt_config.py b/DI-engine/dizoo/d4rl/config/walker2d_expert_dt_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..3658f8ce030466b9044b1eae62fd58957a19238b
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/walker2d_expert_dt_config.py
@@ -0,0 +1,70 @@
+from easydict import EasyDict
+from copy import deepcopy
+
+walk2d_dt_config = dict(
+ exp_name='dt_log/d4rl/walk2d/walk2d_expert_dt_seed0',
+ env=dict(
+ env_id='Walk2d-v3',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=5000,
+ ),
+ dataset=dict(
+ env_type='mujoco',
+ rtg_scale=1000,
+ context_len=30,
+ data_dir_prefix='d4rl/walk2d_expert-v2.pkl',
+ ),
+ policy=dict(
+ cuda=True,
+ stop_value=5000,
+ state_mean=None,
+ state_std=None,
+ evaluator_env_num=8,
+ env_name='Walker2d-v3',
+ rtg_target=5000, # max target return to go
+ max_eval_ep_len=1000, # max length of one episode
+ wt_decay=1e-4,
+ warmup_steps=10000,
+ context_len=20,
+ weight_decay=0.1,
+ clip_grad_norm_p=0.25,
+ model=dict(
+ state_dim=17,
+ act_dim=6,
+ n_blocks=3,
+ h_dim=128,
+ context_len=20,
+ n_heads=1,
+ drop_p=0.1,
+ continuous=True,
+ ),
+ batch_size=64,
+ learning_rate=1e-4,
+ collect=dict(
+ data_type='d4rl_trajectory',
+ unroll_len=1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ ),
+)
+
+walk2d_dt_config = EasyDict(walk2d_dt_config)
+main_config = walk2d_dt_config
+walk2d_dt_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dt'),
+)
+walk2d_dt_create_config = EasyDict(walk2d_dt_create_config)
+create_config = walk2d_dt_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_dt
+ config = deepcopy([main_config, create_config])
+ serial_pipeline_dt(config, seed=0, max_train_iter=1000)
diff --git a/DI-engine/dizoo/d4rl/config/walker2d_expert_td3bc_config.py b/DI-engine/dizoo/d4rl/config/walker2d_expert_td3bc_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..3cae080e92224aafe0c2201e56c684583ed8e6bf
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/walker2d_expert_td3bc_config.py
@@ -0,0 +1,65 @@
+# You can conduct experiments on D4RL with this config file through the following command:
+# cd ../entry && python d4rl_td3_bc_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name='walker2d_expert_td3-bc_seed0',
+ env=dict(
+ env_id='walker2d-expert-v2',
+ norm_obs=dict(
+ use_norm=True,
+ offline_stats=dict(use_offline_stats=True, ),
+ ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ ),
+ learn=dict(
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_actor=0.0003,
+ learning_rate_critic=0.0003,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.2,
+ noise_range={
+ 'min': -0.5,
+ 'max': 0.5
+ },
+ alpha=2.5,
+ ),
+ collect=dict(
+ data_type='d4rl',
+ data_path=None,
+ ),
+ eval=dict(evaluator=dict(eval_freq=10000, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+main_config = EasyDict(main_config)
+main_config = main_config
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(
+ cfg_type='BaseEnvManagerDict',
+ type='base',
+ ),
+ policy=dict(
+ type='td3_bc',
+ import_names=['ding.policy.td3_bc'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/walker2d_medium_cql_config.py b/DI-engine/dizoo/d4rl/config/walker2d_medium_cql_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..afacebae0ba7c36ed1e586fbbc7cf875f14199ce
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/walker2d_medium_cql_config.py
@@ -0,0 +1,55 @@
+# You can conduct experiments on D4RL with this config file through the following command:
+# cd ../entry && python d4rl_cql_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="walker2d_medium_cql_seed0",
+ env=dict(
+ env_id='walker2d-medium-v2',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ ),
+ learn=dict(
+ data_path=None,
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_q=3e-4,
+ learning_rate_policy=1e-4,
+ learning_rate_alpha=1e-4,
+ alpha=0.2,
+ auto_alpha=False,
+ lagrange_thresh=-1.0,
+ min_q_weight=5.0,
+ ),
+ collect=dict(data_type='d4rl', ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='cql',
+ import_names=['ding.policy.cql'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/walker2d_medium_dt_config.py b/DI-engine/dizoo/d4rl/config/walker2d_medium_dt_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8d88699ce9e242a0da7a3d35dc9f9f47340c869
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/walker2d_medium_dt_config.py
@@ -0,0 +1,70 @@
+from easydict import EasyDict
+from copy import deepcopy
+
+walk2d_dt_config = dict(
+ exp_name='dt_log/d4rl/walk2d/walk2d_medium_dt',
+ env=dict(
+ env_id='Walker2d-v3',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=5000,
+ ),
+ dataset=dict(
+ env_type='mujoco',
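+ # returns-to-go are presumably divided by rtg_scale before being fed to the transformer.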
+ rtg_scale=1000,
+ context_len=20,
+ data_dir_prefix='d4rl/walker2d_medium-v2.pkl',
+ ),
+ policy=dict(
+ cuda=True,
+ stop_value=5000,
+ state_mean=None,
+ state_std=None,
+ evaluator_env_num=8,
+ env_name='Walker2d-v3',
+ rtg_target=5000, # max target return to go
+ max_eval_ep_len=1000, # max length of one episode
+ wt_decay=1e-4,
+ warmup_steps=10000,
+ context_len=20,
+ weight_decay=0.1,
+ clip_grad_norm_p=0.25,
+ model=dict(
+ state_dim=17,
+ act_dim=6,
+ n_blocks=3,
+ h_dim=128,
+ context_len=20,
+ n_heads=1,
+ drop_p=0.1,
+ continuous=True,
+ ),
+ batch_size=64,
+ learning_rate=1e-4,
+ collect=dict(
+ data_type='d4rl_trajectory',
+ unroll_len=1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ ),
+)
+
+walk2d_dt_config = EasyDict(walk2d_dt_config)
+main_config = walk2d_dt_config
+walk2d_dt_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dt'),
+)
+walk2d_dt_create_config = EasyDict(walk2d_dt_create_config)
+create_config = walk2d_dt_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_dt
+ config = deepcopy([main_config, create_config])
+ serial_pipeline_dt(config, seed=0, max_train_iter=1000)
diff --git a/DI-engine/dizoo/d4rl/config/walker2d_medium_expert_cql_config.py b/DI-engine/dizoo/d4rl/config/walker2d_medium_expert_cql_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..f05d15c346d88e33aec1e38442aef161a2ab9a2e
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/walker2d_medium_expert_cql_config.py
@@ -0,0 +1,55 @@
+# You can conduct experiments on D4RL with this config file through the following command:
+# cd ../entry && python d4rl_cql_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="walker2d_medium_expert_cql_seed0",
+ env=dict(
+ env_id='walker2d-medium-expert-v2',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ ),
+ learn=dict(
+ data_path=None,
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_q=3e-4,
+ learning_rate_policy=1e-4,
+ learning_rate_alpha=1e-4,
+ alpha=0.2,
+ auto_alpha=False,
+ lagrange_thresh=-1.0,
+ min_q_weight=5.0,
+ ),
+ collect=dict(data_type='d4rl', ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='cql',
+ import_names=['ding.policy.cql'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/walker2d_medium_expert_dt_config.py b/DI-engine/dizoo/d4rl/config/walker2d_medium_expert_dt_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..225d00c2e3d3d091faf94050df76e0c23bf46d5e
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/walker2d_medium_expert_dt_config.py
@@ -0,0 +1,70 @@
+from easydict import EasyDict
+from copy import deepcopy
+
+walk2d_dt_config = dict(
+ exp_name='dt_log/d4rl/walk2d/walk2d_medium_expert_dt_seed0',
+ env=dict(
+ env_id='Walker2d-v3',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=5000,
+ ),
+ dataset=dict(
+ env_type='mujoco',
+ rtg_scale=1000,
+ context_len=30,
+ data_dir_prefix='d4rl/walk2d_medium_expert-v2.pkl',
+ ),
+ policy=dict(
+ cuda=True,
+ stop_value=5000,
+ state_mean=None,
+ state_std=None,
+ evaluator_env_num=8,
+ env_name='Walker2d-v3',
+ rtg_target=5000, # max target return to go
+ max_eval_ep_len=1000, # max length of one episode
+ wt_decay=1e-4,
+ warmup_steps=10000,
+ context_len=20,
+ weight_decay=0.1,
+ clip_grad_norm_p=0.25,
+ model=dict(
+ state_dim=17,
+ act_dim=6,
+ n_blocks=3,
+ h_dim=128,
+ context_len=20,
+ n_heads=1,
+ drop_p=0.1,
+ continuous=True,
+ ),
+ batch_size=64,
+ learning_rate=1e-4,
+ collect=dict(
+ data_type='d4rl_trajectory',
+ unroll_len=1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ ),
+)
+
+walk2d_dt_config = EasyDict(walk2d_dt_config)
+main_config = walk2d_dt_config
+walk2d_dt_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dt'),
+)
+walk2d_dt_create_config = EasyDict(walk2d_dt_create_config)
+create_config = walk2d_dt_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_dt
+ config = deepcopy([main_config, create_config])
+ serial_pipeline_dt(config, seed=0, max_train_iter=1000)
diff --git a/DI-engine/dizoo/d4rl/config/walker2d_medium_expert_pd_config.py b/DI-engine/dizoo/d4rl/config/walker2d_medium_expert_pd_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..3d4c060e83c404b725fc21814fbc2456e6878f70
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/walker2d_medium_expert_pd_config.py
@@ -0,0 +1,101 @@
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="walker2d_medium_expert_pd_seed0",
+ env=dict(
+ env_id='walker2d-medium-expert-v2',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ returns_scale=1.0,
+ termination_penalty=-100,
+ max_path_length=1000,
+ use_padding=True,
+ include_returns=True,
+ normed=False,
+ stop_value=8000,
+ horizon=32,
+ obs_dim=17,
+ action_dim=6,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ diffuser_model='GaussianDiffusion',
+ diffuser_model_cfg=dict(
+ model='DiffusionUNet1d',
+ model_cfg=dict(
+ transition_dim=23,
+ dim=32,
+ dim_mults=[1, 2, 4, 8],
+ returns_condition=False,
+ kernel_size=5,
+ attention=False,
+ ),
+ horizon=32,
+ obs_dim=17,
+ action_dim=6,
+ n_timesteps=20,
+ predict_epsilon=False,
+ loss_discount=1,
+ action_weight=10,
+ ),
+ value_model='ValueDiffusion',
+ value_model_cfg=dict(
+ model='TemporalValue',
+ model_cfg=dict(
+ horizon=32,
+ transition_dim=23,
+ dim=32,
+ dim_mults=[1, 2, 4, 8],
+ kernel_size=5,
+ ),
+ horizon=32,
+ obs_dim=17,
+ action_dim=6,
+ n_timesteps=20,
+ predict_epsilon=True,
+ loss_discount=1,
+ ),
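+ # value-guided sampling knobs: presumably n_guide_steps gradient steps of the value model per denoising step,
+ # with step size 'scale', no guidance for the final t_stopgrad steps, optionally scaled by each step's std.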
+ n_guide_steps=2,
+ scale=0.1,
+ t_stopgrad=2,
+ scale_grad_by_std=True,
+ ),
+ normalizer='GaussianNormalizer',
+ learn=dict(
+ data_path=None,
+ train_epoch=60000,
+ gradient_accumulate_every=2,
+ batch_size=32,
+ learning_rate=2e-4,
+ discount_factor=0.99,
+ plan_batch_size=64,
+ learner=dict(hook=dict(save_ckpt_after_iter=1000000000, )),
+ ),
+ collect=dict(data_type='diffuser_traj', ),
+ eval=dict(
+ evaluator=dict(eval_freq=500, ),
+ test_ret=0.9,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='pd',
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
\ No newline at end of file
diff --git a/DI-engine/dizoo/d4rl/config/walker2d_medium_expert_td3bc_config.py b/DI-engine/dizoo/d4rl/config/walker2d_medium_expert_td3bc_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..2473191de5658aeab4d0f36b3dc451499151d28a
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/walker2d_medium_expert_td3bc_config.py
@@ -0,0 +1,65 @@
+# You can conduct experiments on D4RL with this config file through the following command:
+# cd ../entry && python d4rl_td3_bc_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name='walker2d_medium_expert_td3-bc_seed0',
+ env=dict(
+ env_id='walker2d-medium-expert-v2',
+ norm_obs=dict(
+ use_norm=True,
+ offline_stats=dict(use_offline_stats=True, ),
+ ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ ),
+ learn=dict(
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_actor=0.0003,
+ learning_rate_critic=0.0003,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.2,
+ noise_range={
+ 'min': -0.5,
+ 'max': 0.5
+ },
+ alpha=2.5,
+ ),
+ collect=dict(
+ data_type='d4rl',
+ data_path=None,
+ ),
+ eval=dict(evaluator=dict(eval_freq=10000, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+main_config = EasyDict(main_config)
+main_config = main_config
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(
+ cfg_type='BaseEnvManagerDict',
+ type='base',
+ ),
+ policy=dict(
+ type='td3_bc',
+ import_names=['ding.policy.td3_bc'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/walker2d_medium_pd_config.py b/DI-engine/dizoo/d4rl/config/walker2d_medium_pd_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..29fce259c8eafb5261f50acb0c53323d6fb55a43
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/walker2d_medium_pd_config.py
@@ -0,0 +1,101 @@
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="walker2d_medium_pd_seed0",
+ env=dict(
+ env_id='walker2d-medium-v2',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ returns_scale=1.0,
+ termination_penalty=-100,
+ max_path_length=1000,
+ use_padding=True,
+ include_returns=True,
+ normed=False,
+ stop_value=8000,
+ horizon=32,
+ obs_dim=17,
+ action_dim=6,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ diffuser_model='GaussianDiffusion',
+ diffuser_model_cfg=dict(
+ model='DiffusionUNet1d',
+ model_cfg=dict(
+ transition_dim=23,
+ dim=32,
+ dim_mults=[1, 2, 4, 8],
+ returns_condition=False,
+ kernel_size=5,
+ attention=False,
+ ),
+ horizon=32,
+ obs_dim=17,
+ action_dim=6,
+ n_timesteps=20,
+ predict_epsilon=False,
+ loss_discount=1,
+ action_weight=10,
+ ),
+ value_model='ValueDiffusion',
+ value_model_cfg=dict(
+ model='TemporalValue',
+ model_cfg=dict(
+ horizon=32,
+ transition_dim=23,
+ dim=32,
+ dim_mults=[1, 2, 4, 8],
+ kernel_size=5,
+ ),
+ horizon=32,
+ obs_dim=17,
+ action_dim=6,
+ n_timesteps=20,
+ predict_epsilon=True,
+ loss_discount=1,
+ ),
+ n_guide_steps=2,
+ scale=0.1,
+ t_stopgrad=2,
+ scale_grad_by_std=True,
+ ),
+ normalizer='GaussianNormalizer',
+ learn=dict(
+ data_path=None,
+ train_epoch=60000,
+ gradient_accumulate_every=2,
+ batch_size=32,
+ learning_rate=2e-4,
+ discount_factor=0.99,
+ plan_batch_size=64,
+ learner=dict(hook=dict(save_ckpt_after_iter=1000000000, )),
+ ),
+ collect=dict(data_type='diffuser_traj', ),
+ eval=dict(
+ evaluator=dict(eval_freq=500, ),
+ test_ret=0.9,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='pd',
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
\ No newline at end of file
diff --git a/DI-engine/dizoo/d4rl/config/walker2d_medium_replay_cql_config.py b/DI-engine/dizoo/d4rl/config/walker2d_medium_replay_cql_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..23437423b6f27037869d7e9206716aaf39aaa067
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/walker2d_medium_replay_cql_config.py
@@ -0,0 +1,55 @@
+# You can conduct experiments on D4RL with this config file through the following command:
+# cd ../entry && python d4rl_cql_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="walker2d_medium_replay_cql_seed0",
+ env=dict(
+ env_id='walker2d-medium-replay-v2',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ ),
+ learn=dict(
+ data_path=None,
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_q=3e-4,
+ learning_rate_policy=1e-4,
+ learning_rate_alpha=1e-4,
+ alpha=0.2,
+ auto_alpha=False,
+ lagrange_thresh=-1.0,
+ min_q_weight=5.0,
+ ),
+ collect=dict(data_type='d4rl', ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='cql',
+ import_names=['ding.policy.cql'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/walker2d_medium_replay_dt_config.py b/DI-engine/dizoo/d4rl/config/walker2d_medium_replay_dt_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b96375b242edda2bc1dd9165a4b010f55f5bad57
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/walker2d_medium_replay_dt_config.py
@@ -0,0 +1,70 @@
+from easydict import EasyDict
+from copy import deepcopy
+
+walk2d_dt_config = dict(
+ exp_name='dt_log/d4rl/walk2d/walk2d_medium_replay_dt_seed0',
+ env=dict(
+ env_id='Walker2d-v3',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=5000,
+ ),
+ dataset=dict(
+ env_type='mujoco',
+ rtg_scale=1000,
+ context_len=30,
+ data_dir_prefix='d4rl/walk2d_medium_replay-v2.pkl',
+ ),
+ policy=dict(
+ cuda=True,
+ stop_value=5000,
+ state_mean=None,
+ state_std=None,
+ evaluator_env_num=8,
+ env_name='Walker2d-v3',
+ rtg_target=5000, # max target return to go
+ max_eval_ep_len=1000, # max length of one episode
+ wt_decay=1e-4,
+ warmup_steps=10000,
+ context_len=20,
+ weight_decay=0.1,
+ clip_grad_norm_p=0.25,
+ model=dict(
+ state_dim=17,
+ act_dim=6,
+ n_blocks=3,
+ h_dim=128,
+ context_len=20,
+ n_heads=1,
+ drop_p=0.1,
+ continuous=True,
+ ),
+ batch_size=64,
+ learning_rate=1e-4,
+ collect=dict(
+ data_type='d4rl_trajectory',
+ unroll_len=1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ ),
+)
+
+walk2d_dt_config = EasyDict(walk2d_dt_config)
+main_config = walk2d_dt_config
+walk2d_dt_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dt'),
+)
+walk2d_dt_create_config = EasyDict(walk2d_dt_create_config)
+create_config = walk2d_dt_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_dt
+ config = deepcopy([main_config, create_config])
+ serial_pipeline_dt(config, seed=0, max_train_iter=1000)
diff --git a/DI-engine/dizoo/d4rl/config/walker2d_medium_replay_td3bc_config.py b/DI-engine/dizoo/d4rl/config/walker2d_medium_replay_td3bc_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..0d13a8d7e2158a75190991cac4ac754a2c10a505
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/walker2d_medium_replay_td3bc_config.py
@@ -0,0 +1,65 @@
+# You can conduct experiments on D4RL with this config file through the following command:
+# cd ../entry && python d4rl_td3_bc_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name='walker2d_medium_replay_td3-bc_seed0',
+ env=dict(
+ env_id='walker2d-medium-replay-v2',
+ norm_obs=dict(
+ use_norm=True,
+ offline_stats=dict(use_offline_stats=True, ),
+ ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ ),
+ learn=dict(
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_actor=0.0003,
+ learning_rate_critic=0.0003,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.2,
+ noise_range={
+ 'min': -0.5,
+ 'max': 0.5
+ },
+ alpha=2.5,
+ ),
+ collect=dict(
+ data_type='d4rl',
+ data_path=None,
+ ),
+ eval=dict(evaluator=dict(eval_freq=10000, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+main_config = EasyDict(main_config)
+main_config = main_config
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(
+ cfg_type='BaseEnvManagerDict',
+ type='base',
+ ),
+ policy=dict(
+ type='td3_bc',
+ import_names=['ding.policy.td3_bc'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/walker2d_medium_td3bc_config.py b/DI-engine/dizoo/d4rl/config/walker2d_medium_td3bc_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..d496087c96ef2d9b4df35d0a6186fdf7e3741e7b
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/walker2d_medium_td3bc_config.py
@@ -0,0 +1,65 @@
+# You can conduct experiments on D4RL with this config file through the following command:
+# cd ../entry && python d4rl_td3_bc_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name='walker2d_medium_td3-bc_seed0',
+ env=dict(
+ env_id='walker2d-medium-v2',
+ norm_obs=dict(
+ use_norm=True,
+ offline_stats=dict(use_offline_stats=True, ),
+ ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ ),
+ learn=dict(
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_actor=0.0003,
+ learning_rate_critic=0.0003,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.2,
+ noise_range={
+ 'min': -0.5,
+ 'max': 0.5
+ },
+ alpha=2.5,
+ ),
+ collect=dict(
+ data_type='d4rl',
+ data_path=None,
+ ),
+ eval=dict(evaluator=dict(eval_freq=10000, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+main_config = EasyDict(main_config)
+main_config = main_config
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(
+ cfg_type='BaseEnvManagerDict',
+ type='base',
+ ),
+ policy=dict(
+ type='td3_bc',
+ import_names=['ding.policy.td3_bc'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/walker2d_random_cql_config.py b/DI-engine/dizoo/d4rl/config/walker2d_random_cql_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..346dd1a1f227532c8bf4aa657dba8a88e78310b9
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/walker2d_random_cql_config.py
@@ -0,0 +1,55 @@
+# You can conduct experiments on D4RL with this config file through the following command:
+# cd ../entry && python d4rl_cql_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name="walker2d_expert_cql_seed0",
+ env=dict(
+ env_id='walker2d-random-v2',
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ ),
+ learn=dict(
+ data_path=None,
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_q=3e-4,
+ learning_rate_policy=1e-4,
+ learning_rate_alpha=1e-4,
+ alpha=0.2,
+ auto_alpha=False,
+ lagrange_thresh=-1.0,
+ min_q_weight=5.0,
+ ),
+ collect=dict(data_type='d4rl', ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+main_config = main_config
+
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='cql',
+ import_names=['ding.policy.cql'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/config/walker2d_random_dt_config.py b/DI-engine/dizoo/d4rl/config/walker2d_random_dt_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a397f0efb2bf91d38e5d1ae24098ac298b0eb86
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/walker2d_random_dt_config.py
@@ -0,0 +1,82 @@
+from easydict import EasyDict
+from copy import deepcopy
+
+walker2d_dt_config = dict(
+ exp_name='walker2d_random_dt_seed0',
+ env=dict(
+ env_id='Walker2d-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ stop_value=6000,
+ cuda=True,
+ env_name='Walker2d-v3',
+ rtg_target=6000, # max target return to go
+ max_eval_ep_len=1000, # max length of one episode
+ num_eval_ep=10, # number of evaluation episodes
+ batch_size=64,
+ wt_decay=1e-4,
+ warmup_steps=10000,
+ num_updates_per_iter=100,
+ context_len=20,
+ n_blocks=3,
+ embed_dim=128,
+ n_heads=1,
+ dropout_p=0.1,
+ log_dir='/home/wangzilin/research/dt/DI-engine/dizoo/d4rl/dt_data/walker2d_random_dt_log',
+ model=dict(
+ state_dim=17,
+ act_dim=6,
+ n_blocks=3,
+ h_dim=128,
+ context_len=20,
+ n_heads=1,
+ drop_p=0.1,
+ continuous=True,
+ ),
+ discount_factor=0.999,
+ nstep=3,
+ learn=dict(
+ dataset_path='/mnt/lustre/wangzilin/d4rl_data/walker2d-random-v2.pkl',
+ learning_rate=0.0001,
+ target_update_freq=100,
+ kappa=1.0,
+ min_q_weight=4.0
+ ),
+ collect=dict(unroll_len=1, ),
+ eval=dict(evaluator=dict(eval_freq=100, ), ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=1000, ),
+ ),
+ ),
+)
+
+walker2d_dt_config = EasyDict(walker2d_dt_config)
+main_config = walker2d_dt_config
+walker2d_dt_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dt'),
+)
+walker2d_dt_create_config = EasyDict(walker2d_dt_create_config)
+create_config = walker2d_dt_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_dt
+ config = deepcopy([main_config, create_config])
+ serial_pipeline_dt(config, seed=0, max_train_iter=1000)
diff --git a/DI-engine/dizoo/d4rl/config/walker2d_random_td3bc_config.py b/DI-engine/dizoo/d4rl/config/walker2d_random_td3bc_config.py
new file mode 100755
index 0000000000000000000000000000000000000000..a38e1e8662fc854bb8f53d6b5e7ca6c22c8a2b64
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/config/walker2d_random_td3bc_config.py
@@ -0,0 +1,65 @@
+# You can conduct experiments on D4RL with this config file through the following command:
+# cd ../entry && python d4rl_td3_bc_main.py
+from easydict import EasyDict
+
+main_config = dict(
+ exp_name='walker2d_random_td3-bc_seed0',
+ env=dict(
+ env_id='walker2d-random-v2',
+ norm_obs=dict(
+ use_norm=True,
+ offline_stats=dict(use_offline_stats=True, ),
+ ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ ),
+ learn=dict(
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_actor=0.0003,
+ learning_rate_critic=0.0003,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.2,
+ noise_range={
+ 'min': -0.5,
+ 'max': 0.5
+ },
+ alpha=2.5,
+ ),
+ collect=dict(
+ data_type='d4rl',
+ data_path=None,
+ ),
+ eval=dict(evaluator=dict(eval_freq=10000, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+main_config = EasyDict(main_config)
+main_config = main_config
+create_config = dict(
+ env=dict(
+ type='d4rl',
+ import_names=['dizoo.d4rl.envs.d4rl_env'],
+ ),
+ env_manager=dict(
+ cfg_type='BaseEnvManagerDict',
+ type='base',
+ ),
+ policy=dict(
+ type='td3_bc',
+ import_names=['ding.policy.td3_bc'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+create_config = EasyDict(create_config)
+create_config = create_config
diff --git a/DI-engine/dizoo/d4rl/entry/__init__.py b/DI-engine/dizoo/d4rl/entry/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/d4rl/entry/d4rl_bcq_main.py b/DI-engine/dizoo/d4rl/entry/d4rl_bcq_main.py
new file mode 100755
index 0000000000000000000000000000000000000000..099f6e025b3d44768a68662dcef8fc133ad78462
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/entry/d4rl_bcq_main.py
@@ -0,0 +1,21 @@
+from ding.entry import serial_pipeline_offline
+from ding.config import read_config
+from pathlib import Path
+
+
+def train(args):
+ # launch from anywhere
+ config = Path(__file__).absolute().parent.parent / 'config' / args.config
+ config = read_config(str(config))
+ config[0].exp_name = config[0].exp_name.replace('0', str(args.seed))
+ serial_pipeline_offline(config, seed=args.seed)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--seed', '-s', type=int, default=0)
+ parser.add_argument('--config', '-c', type=str, default='halfcheetah_medium_bcq_config.py')
+ args = parser.parse_args()
+ train(args)
diff --git a/DI-engine/dizoo/d4rl/entry/d4rl_cql_main.py b/DI-engine/dizoo/d4rl/entry/d4rl_cql_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a8934a90a38c13d63f61e9b5475f09428308ec0
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/entry/d4rl_cql_main.py
@@ -0,0 +1,21 @@
+from ding.entry import serial_pipeline_offline
+from ding.config import read_config
+from pathlib import Path
+
+
+def train(args):
+ # launch from anywhere
+ config = Path(__file__).absolute().parent.parent / 'config' / args.config
+ config = read_config(str(config))
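+ # the default exp_name ends in 'seed0'; substituting '0' bakes the chosen seed into the experiment directory
+ # (note: str.replace swaps every '0' in the name, not just the trailing one)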
+ config[0].exp_name = config[0].exp_name.replace('0', str(args.seed))
+ serial_pipeline_offline(config, seed=args.seed)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--seed', '-s', type=int, default=10)
+ parser.add_argument('--config', '-c', type=str, default='hopper_expert_cql_config.py')
+ args = parser.parse_args()
+ train(args)
diff --git a/DI-engine/dizoo/d4rl/entry/d4rl_dt_mujoco.py b/DI-engine/dizoo/d4rl/entry/d4rl_dt_mujoco.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6bf93e2c521ef8b728057d353eadd1c4c8ed3bf
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/entry/d4rl_dt_mujoco.py
@@ -0,0 +1,48 @@
+import gym
+import torch
+import numpy as np
+from ditk import logging
+from ding.model.template.decision_transformer import DecisionTransformer
+from ding.policy import DTPolicy
+from ding.envs import BaseEnvManagerV2
+from ding.envs.env_wrappers.env_wrappers import AllinObsWrapper
+from ding.data import create_dataset
+from ding.config import compile_config
+from ding.framework import task, ding_init
+from ding.framework.context import OfflineRLContext
+from ding.framework.middleware import (
+ interaction_evaluator, trainer, CkptSaver, offline_data_fetcher_from_mem, offline_logger, termination_checker
+)
+from ding.utils import set_pkg_seed
+from dizoo.d4rl.envs import D4RLEnv
+from dizoo.d4rl.config.hopper_medium_dt_config import main_config, create_config
+
+
+def main():
+ # If you don't have offline data, you need to prepare it first and set the data_path in the config.
+ # For demonstration, we can also train an RL policy (e.g. SAC) online and collect some data.
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ ding_init(cfg)
+ with task.start(async_mode=False, ctx=OfflineRLContext()):
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: AllinObsWrapper(D4RLEnv(cfg.env)) for _ in range(cfg.env.evaluator_env_num)],
+ cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ dataset = create_dataset(cfg)
+ # env_data_stats = dataset.get_d4rl_dataset_stats(cfg.policy.dataset_name)
+ cfg.policy.state_mean, cfg.policy.state_std = dataset.get_state_stats()
+ model = DecisionTransformer(**cfg.policy.model)
+ policy = DTPolicy(cfg.policy, model=model)
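+ # middleware pipeline: evaluate the policy, fetch offline batches from memory, train,
+ # stop after max_train_iter, save checkpoints periodically, and log offline metrics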
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(offline_data_fetcher_from_mem(cfg, dataset))
+ task.use(trainer(cfg, policy.learn_mode))
+ task.use(termination_checker(max_train_iter=5e4))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
+ task.use(offline_logger())
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/dizoo/d4rl/entry/d4rl_edac_main.py b/DI-engine/dizoo/d4rl/entry/d4rl_edac_main.py
new file mode 100755
index 0000000000000000000000000000000000000000..b6710836cbdaa40073b0be01cb5f6560888141bf
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/entry/d4rl_edac_main.py
@@ -0,0 +1,21 @@
+from ding.entry import serial_pipeline_offline
+from ding.config import read_config
+from pathlib import Path
+
+
+def train(args):
+ # launch from anywhere
+ config = Path(__file__).absolute().parent.parent / 'config' / args.config
+ config = read_config(str(config))
+ config[0].exp_name = config[0].exp_name.replace('0', str(args.seed))
+ serial_pipeline_offline(config, seed=args.seed)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--seed', '-s', type=int, default=0)
+ parser.add_argument('--config', '-c', type=str, default='halfcheetah_medium_edac_config.py')
+ args = parser.parse_args()
+ train(args)
diff --git a/DI-engine/dizoo/d4rl/entry/d4rl_ibc_main.py b/DI-engine/dizoo/d4rl/entry/d4rl_ibc_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..a112916f6c3659f0ea800f01b8e967aa2f57f3ca
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/entry/d4rl_ibc_main.py
@@ -0,0 +1,35 @@
+import os
+from ding.entry import serial_pipeline_offline
+from ding.config import read_config
+from ding.utils import dist_init
+from pathlib import Path
+import torch
+import torch.multiprocessing as mp
+
+
+def offline_worker(rank, config, args):
+ dist_init(rank=rank, world_size=torch.cuda.device_count())
+ serial_pipeline_offline(config, seed=args.seed)
+
+
+def train(args):
+ # launch from anywhere
+ config = Path(__file__).absolute().parent.parent / 'config' / args.config
+ config = read_config(str(config))
+ config[0].exp_name = config[0].exp_name.replace('0', str(args.seed))
+ if not config[0].policy.multi_gpu:
+ serial_pipeline_offline(config, seed=args.seed)
+ else:
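+        # multi-GPU offline training: spawn one worker process per visible GPU;
+        # MASTER_ADDR/PORT define the torch.distributed rendezvous (the port is assumed to be free)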
+ os.environ["MASTER_ADDR"] = "localhost"
+ os.environ["MASTER_PORT"] = "29600"
+ mp.spawn(offline_worker, nprocs=torch.cuda.device_count(), args=(config, args))
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--seed', '-s', type=int, default=0)
+ parser.add_argument('--config', '-c', type=str, default='hopper_medium_expert_ibc_config.py')
+ args = parser.parse_args()
+ train(args)
diff --git a/DI-engine/dizoo/d4rl/entry/d4rl_pd_main.py b/DI-engine/dizoo/d4rl/entry/d4rl_pd_main.py
new file mode 100755
index 0000000000000000000000000000000000000000..1ca3c5b2995edcc88ddf5be4e012cd1ea781870b
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/entry/d4rl_pd_main.py
@@ -0,0 +1,21 @@
+from ding.entry import serial_pipeline_offline
+from ding.config import read_config
+from pathlib import Path
+
+
+def train(args):
+ # launch from anywhere
+ config = Path(__file__).absolute().parent.parent / 'config' / args.config
+ config = read_config(str(config))
+ config[0].exp_name = config[0].exp_name.replace('0', str(args.seed))
+ serial_pipeline_offline(config, seed=args.seed)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--seed', '-s', type=int, default=10)
+ parser.add_argument('--config', '-c', type=str, default='halfcheetah_medium_pd_config.py')
+ args = parser.parse_args()
+ train(args)
\ No newline at end of file
diff --git a/DI-engine/dizoo/d4rl/entry/d4rl_td3_bc_main.py b/DI-engine/dizoo/d4rl/entry/d4rl_td3_bc_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..b25bf904a5ce54576dcd962a8126778494344c7d
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/entry/d4rl_td3_bc_main.py
@@ -0,0 +1,21 @@
+from ding.entry import serial_pipeline_offline
+from ding.config import read_config
+from pathlib import Path
+
+
+def train(args):
+ # launch from anywhere
+ config = Path(__file__).absolute().parent.parent / 'config' / args.config
+ config = read_config(str(config))
+ config[0].exp_name = config[0].exp_name.replace('0', str(args.seed))
+ serial_pipeline_offline(config, seed=args.seed)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--seed', '-s', type=int, default=10)
+ parser.add_argument('--config', '-c', type=str, default='hopper_medium_expert_td3bc_config.py')
+ args = parser.parse_args()
+ train(args)
diff --git a/DI-engine/dizoo/d4rl/envs/__init__.py b/DI-engine/dizoo/d4rl/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..530ec5aab4eaef44fd325a47c28a084e477d9da8
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/envs/__init__.py
@@ -0,0 +1 @@
+from .d4rl_env import D4RLEnv
diff --git a/DI-engine/dizoo/d4rl/envs/d4rl_env.py b/DI-engine/dizoo/d4rl/envs/d4rl_env.py
new file mode 100755
index 0000000000000000000000000000000000000000..db770fd0992213a753ea14d484da063cc4252b00
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/envs/d4rl_env.py
@@ -0,0 +1,204 @@
+from typing import Any, Union, List
+import copy
+import numpy as np
+import gym
+import matplotlib.pyplot as plt
+import einops
+import imageio
+from easydict import EasyDict
+
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.envs.common.env_element import EnvElement, EnvElementInfo
+from ding.envs.common.common_function import affine_transform
+from ding.torch_utils import to_ndarray, to_list
+from .d4rl_wrappers import wrap_d4rl
+from ding.utils import ENV_REGISTRY
+
+MAZE_BOUNDS = {
+ 'maze2d-umaze-v1': (0, 5, 0, 5),
+ 'maze2d-medium-v1': (0, 8, 0, 8),
+ 'maze2d-large-v1': (0, 9, 0, 12)
+}
+
+def plot2img(fig, remove_margins=True):
+ # https://stackoverflow.com/a/35362787/2912349
+ # https://stackoverflow.com/a/54334430/2912349
+
+ from matplotlib.backends.backend_agg import FigureCanvasAgg
+
+ if remove_margins:
+ fig.subplots_adjust(left=0, bottom=0, right=1, top=1, wspace=0, hspace=0)
+
+ canvas = FigureCanvasAgg(fig)
+ canvas.draw()
+ img_as_string, (width, height) = canvas.print_to_buffer()
+    # np.fromstring is deprecated for raw binary data; np.frombuffer reads the same bytes
+    return np.frombuffer(img_as_string, dtype='uint8').reshape((height, width, 4))
+
+def zipsafe(*args):
+ length = len(args[0])
+ assert all([len(a) == length for a in args])
+ return zip(*args)
+
+def zipkw(*args, **kwargs):
+ nargs = len(args)
+ keys = kwargs.keys()
+ vals = [kwargs[k] for k in keys]
+ zipped = zipsafe(*args, *vals)
+ for items in zipped:
+ zipped_args = items[:nargs]
+ zipped_kwargs = {k: v for k, v in zipsafe(keys, items[nargs:])}
+ yield zipped_args, zipped_kwargs
+
+@ENV_REGISTRY.register('d4rl')
+class D4RLEnv(BaseEnv):
+
+ def __init__(self, cfg: dict) -> None:
+ self._cfg = cfg
+ self._use_act_scale = cfg.use_act_scale
+ self._init_flag = False
+ if 'maze' in self._cfg.env_id:
+ self.observations = []
+ self._extent = (0, 1, 1, 0)
+
+ def reset(self) -> np.ndarray:
+ if not self._init_flag:
+ self._env = self._make_env(only_info=False)
+ self._env.observation_space.dtype = np.float32 # To unify the format of envs in DI-engine
+ self._observation_space = self._env.observation_space
+ if 'maze' in self._cfg.env_id:
+ new_low = np.tile(self._observation_space.low, 2)
+ new_high = np.tile(self._observation_space.high, 2)
+ self._observation_space = gym.spaces.Box(low=new_low, high=new_high)
+ self._action_space = self._env.action_space
+ self._reward_space = gym.spaces.Box(
+ low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
+ )
+ self._init_flag = True
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+ if 'maze' in self._cfg.env_id:
+ target = self._env.get_target()
+ self.target_obs = np.array([*target, 0, 0])
+ obs = self._env.reset()
+ if 'maze' in self._cfg.env_id:
+ self.observations.append(obs)
+ obs = np.hstack((obs, self.target_obs))
+ obs = to_ndarray(obs).astype('float32')
+ self._eval_episode_return = 0.
+ return obs
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep:
+ action = to_ndarray(action)
+ if self._use_act_scale:
+ action_range = {'min': self.action_space.low[0], 'max': self.action_space.high[0], 'dtype': np.float32}
+ action = affine_transform(action, min_val=action_range['min'], max_val=action_range['max'])
+ obs, rew, done, info = self._env.step(action)
+ self._eval_episode_return += rew
+ if 'maze' in self._cfg.env_id:
+ self.observations.append(obs)
+ obs = np.hstack([obs, self.target_obs])
+ obs = to_ndarray(obs).astype('float32')
+        rew = to_ndarray([rew])  # wrap the scalar reward into an array with shape (1, )
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+ # self.composite('/mnt/PD/render/rollout.png',self.observations,ncol=1)
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def renders(self, observations, conditions=None, title=None):
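+        # NOTE: ``self._background`` and ``self._remove_margins`` are assumed to be prepared by the caller
+        # (e.g. from the maze layout) before ``renders``/``composite`` are invoked; they are not set in this class.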
+ bounds = MAZE_BOUNDS[self._cfg.env_id]
+
+ observations = observations + .5
+ if len(bounds) == 2:
+ _, scale = bounds
+ observations /= scale
+ elif len(bounds) == 4:
+ _, iscale, _, jscale = bounds
+ observations[:, 0] /= iscale
+ observations[:, 1] /= jscale
+ else:
+ raise RuntimeError(f'Unrecognized bounds for {self._cfg.env_id}: {bounds}')
+
+ if conditions is not None:
+ conditions /= scale
+
+ plt.clf()
+ fig = plt.gcf()
+ fig.set_size_inches(5, 5)
+ plt.imshow(self._background * .5,
+ extent=self._extent, cmap=plt.cm.binary, vmin=0, vmax=1)
+
+ path_length = len(observations)
+        colors = plt.cm.jet(np.linspace(0, 1, path_length))
+        plt.plot(observations[:, 1], observations[:, 0], c='black', zorder=10)
+        plt.scatter(observations[:, 1], observations[:, 0], c=colors, zorder=20)
+ plt.axis('off')
+ plt.title(title)
+ img = plot2img(fig, remove_margins=self._remove_margins)
+ return img
+
+ def composite(self, savepath, paths, ncol=5, **kwargs):
+ assert len(paths) % ncol == 0, 'Number of paths must be divisible by number of columns'
+
+ images = []
+ for path, kw in zipkw(paths, **kwargs):
+ img = self.renders(*path, **kw)
+ images.append(img)
+ images = np.stack(images, axis=0)
+
+ nrow = len(images) // ncol
+ images = einops.rearrange(images,
+ '(nrow ncol) H W C -> (nrow H) (ncol W) C', nrow=nrow, ncol=ncol)
+ imageio.imsave(savepath, images)
+ print(f'Saved {len(paths)} samples to: {savepath}')
+
+ def _make_env(self, only_info=False):
+ return wrap_d4rl(
+ self._cfg.env_id,
+ norm_obs=self._cfg.get(
+ 'norm_obs',
+ EasyDict(use_norm=False, offline_stats=dict(use_offline_stats=False, )),
+ ),
+ norm_reward=self._cfg.get('norm_reward', EasyDict(use_norm=False, )),
+ only_info=only_info
+ )
+
+ def __repr__(self) -> str:
+ return "DI-engine D4RL Env({})".format(self._cfg.env_id)
+
+ @staticmethod
+ def create_collector_env_cfg(cfg: dict) -> List[dict]:
+ collector_cfg = copy.deepcopy(cfg)
+ collector_env_num = collector_cfg.pop('collector_env_num', 1)
+ return [collector_cfg for _ in range(collector_env_num)]
+
+ @staticmethod
+ def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
+ evaluator_cfg = copy.deepcopy(cfg)
+ evaluator_env_num = evaluator_cfg.pop('evaluator_env_num', 1)
+ evaluator_cfg.get('norm_reward', EasyDict(use_norm=False, )).use_norm = False
+ return [evaluator_cfg for _ in range(evaluator_env_num)]
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
diff --git a/DI-engine/dizoo/d4rl/envs/d4rl_wrappers.py b/DI-engine/dizoo/d4rl/envs/d4rl_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..e9296657ad7a23b557521824a071d37b55f90e07
--- /dev/null
+++ b/DI-engine/dizoo/d4rl/envs/d4rl_wrappers.py
@@ -0,0 +1,51 @@
+from typing import Dict
+import gym
+import numpy as np
+from easydict import EasyDict
+from ditk import logging
+from ding.envs import ObsNormWrapper, StaticObsNormWrapper, RewardNormWrapper
+
+try:
+    import d4rl  # registers D4RL environments with OpenAI Gym
+except ImportError:
+    logging.warning("d4rl is not installed; please install it (see https://github.com/rail-berkeley/d4rl)")
+
+
+def wrap_d4rl(
+ env_id,
+        norm_obs: Dict = EasyDict(use_norm=False, offline_stats=dict(use_offline_stats=False, )),
+        norm_reward: Dict = EasyDict(use_norm=False, ),
+ only_info=False
+) -> gym.Env:
+ r"""
+ Overview:
+ Wrap Mujoco Env to preprocess env step's return info, e.g. observation normalization, reward normalization, etc.
+ Arguments:
+ - env_id (:obj:`str`): Mujoco environment id, for example "HalfCheetah-v3"
+ - norm_obs (:obj:`EasyDict`): Whether to normalize observation or not
+ - norm_reward (:obj:`EasyDict`): Whether to normalize reward or not. For evaluator, environment's reward \
+ should not be normalized: Either ``norm_reward`` is None or ``norm_reward.use_norm`` is False can do this.
+ Returns:
+ - wrapped_env (:obj:`gym.Env`): The wrapped mujoco environment
+ """
+ if not only_info:
+ env = gym.make(env_id)
+ if norm_obs is not None and norm_obs.use_norm:
+ offline_stats = norm_obs.get('offline_stats', dict(use_offline_stats=False))
+ if offline_stats.use_offline_stats:
+ env = StaticObsNormWrapper(env, offline_stats.mean, offline_stats.std)
+ else:
+ env = ObsNormWrapper(env)
+ if norm_reward is not None and norm_reward.use_norm:
+ env = RewardNormWrapper(env, norm_reward.reward_discount)
+ return env
+ else:
+ wrapper_info = ''
+ if norm_obs is not None and norm_obs.use_norm:
+ offline_stats = norm_obs.get('offline_stats', dict(use_offline_stats=False))
+ if offline_stats.use_offline_stats:
+ wrapper_info = StaticObsNormWrapper.__name__ + '\n'
+ else:
+ wrapper_info = ObsNormWrapper.__name__ + '\n'
+ if norm_reward is not None and norm_reward.use_norm:
+ wrapper_info += RewardNormWrapper.__name__ + '\n'
+ return wrapper_info
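+
+# A minimal usage sketch (assuming d4rl is installed and the "hopper-medium-v0" dataset id is available):
+#
+#   env = wrap_d4rl('hopper-medium-v0')   # defaults: no observation/reward normalization
+#   obs = env.reset()
+#   obs, rew, done, info = env.step(env.action_space.sample())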
diff --git a/DI-engine/dizoo/dmc2gym/__init__.py b/DI-engine/dizoo/dmc2gym/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/dmc2gym/config/cartpole_balance/cartpole_balance_dreamer_config.py b/DI-engine/dizoo/dmc2gym/config/cartpole_balance/cartpole_balance_dreamer_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..66f7c7e2a43754c950568184230d4532167eaedd
--- /dev/null
+++ b/DI-engine/dizoo/dmc2gym/config/cartpole_balance/cartpole_balance_dreamer_config.py
@@ -0,0 +1,93 @@
+from easydict import EasyDict
+
+from ding.entry import serial_pipeline_dreamer
+
+cuda = False
+
+cartpole_balance_dreamer_config = dict(
+ exp_name='dmc2gym_cartpole_balance_dreamer',
+ env=dict(
+ env_id='dmc2gym_cartpole_balance',
+ domain_name='cartpole',
+ task_name='balance',
+ frame_skip=1,
+ warp_frame=True,
+ scale=True,
+ clip_rewards=False,
+ action_repeat=2,
+ frame_stack=1,
+ from_pixels=True,
+ resize=64,
+ collector_env_num=1,
+ evaluator_env_num=1,
+ n_evaluator_episode=1,
+ stop_value=1000, # 1000
+ ),
+ policy=dict(
+ cuda=cuda,
+ # it is better to put random_collect_size in policy.other
+ random_collect_size=2500,
+ model=dict(
+ obs_shape=(3, 64, 64),
+ action_shape=1,
+ actor_dist='normal',
+ ),
+ learn=dict(
+ lambda_=0.95,
+ learning_rate=3e-5,
+ batch_size=16,
+ batch_length=64,
+ imag_sample=True,
+ discount=0.997,
+ reward_EMA=True,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ action_size=1, # has to be specified
+ collect_dyn_sample=True,
+ ),
+ command=dict(),
+ eval=dict(evaluator=dict(eval_freq=5000, )),
+ other=dict(
+ # environment buffer
+ replay_buffer=dict(replay_buffer_size=500000, periodic_thruput_seconds=60),
+ ),
+ ),
+ world_model=dict(
+ pretrain=100,
+ train_freq=2,
+ cuda=cuda,
+ model=dict(
+ state_size=(3, 64, 64), # has to be specified
+ action_size=1, # has to be specified
+ reward_size=1,
+ batch_size=16,
+ ),
+ ),
+)
+
+cartpole_balance_dreamer_config = EasyDict(cartpole_balance_dreamer_config)
+
+cartpole_balance_create_config = dict(
+ env=dict(
+ type='dmc2gym',
+ import_names=['dizoo.dmc2gym.envs.dmc2gym_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='dreamer',
+ import_names=['ding.policy.mbpolicy.dreamer'],
+ ),
+ replay_buffer=dict(type='sequence', ),
+ world_model=dict(
+ type='dreamer',
+ import_names=['ding.world_model.dreamer'],
+ ),
+)
+cartpole_balance_create_config = EasyDict(cartpole_balance_create_config)
+
+if __name__ == '__main__':
+ serial_pipeline_dreamer(
+ (cartpole_balance_dreamer_config, cartpole_balance_create_config), seed=0, max_env_step=500000
+ )
diff --git a/DI-engine/dizoo/dmc2gym/config/cheetah_run/cheetah_run_dreamer_config.py b/DI-engine/dizoo/dmc2gym/config/cheetah_run/cheetah_run_dreamer_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..32a43463e7b923ff4ddff56f6ad8fdbce4bcceb5
--- /dev/null
+++ b/DI-engine/dizoo/dmc2gym/config/cheetah_run/cheetah_run_dreamer_config.py
@@ -0,0 +1,91 @@
+from easydict import EasyDict
+
+from ding.entry import serial_pipeline_dreamer
+
+cuda = False
+
+cheetah_run_dreamer_config = dict(
+ exp_name='dmc2gym_cheetah_run_dreamer',
+ env=dict(
+ env_id='dmc2gym_cheetah_run',
+ domain_name='cheetah',
+ task_name='run',
+ frame_skip=1,
+ warp_frame=True,
+ scale=True,
+ clip_rewards=False,
+ action_repeat=2,
+ frame_stack=1,
+ from_pixels=True,
+ resize=64,
+ collector_env_num=1,
+ evaluator_env_num=1,
+ n_evaluator_episode=1,
+ stop_value=1000, # 1000
+ ),
+ policy=dict(
+ cuda=cuda,
+ # it is better to put random_collect_size in policy.other
+ random_collect_size=2500,
+ model=dict(
+ obs_shape=(3, 64, 64),
+ action_shape=6,
+ actor_dist='normal',
+ ),
+ learn=dict(
+ lambda_=0.95,
+ learning_rate=3e-5,
+ batch_size=16,
+ batch_length=64,
+ imag_sample=True,
+ discount=0.997,
+ reward_EMA=True,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ action_size=6, # has to be specified
+ collect_dyn_sample=True,
+ ),
+ command=dict(),
+ eval=dict(evaluator=dict(eval_freq=5000, )),
+ other=dict(
+ # environment buffer
+ replay_buffer=dict(replay_buffer_size=500000, periodic_thruput_seconds=60),
+ ),
+ ),
+ world_model=dict(
+ pretrain=100,
+ train_freq=2,
+ cuda=cuda,
+ model=dict(
+ state_size=(3, 64, 64), # has to be specified
+ action_size=6, # has to be specified
+ reward_size=1,
+ batch_size=16,
+ ),
+ ),
+)
+
+cheetah_run_dreamer_config = EasyDict(cheetah_run_dreamer_config)
+
+cheetah_run_create_config = dict(
+ env=dict(
+ type='dmc2gym',
+ import_names=['dizoo.dmc2gym.envs.dmc2gym_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='dreamer',
+ import_names=['ding.policy.mbpolicy.dreamer'],
+ ),
+ replay_buffer=dict(type='sequence', ),
+ world_model=dict(
+ type='dreamer',
+ import_names=['ding.world_model.dreamer'],
+ ),
+)
+cheetah_run_create_config = EasyDict(cheetah_run_create_config)
+
+if __name__ == '__main__':
+ serial_pipeline_dreamer((cheetah_run_dreamer_config, cheetah_run_create_config), seed=0, max_env_step=500000)
diff --git a/DI-engine/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py b/DI-engine/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..de8e09e3d873fc8020ab9ad66e3a47b363d5d003
--- /dev/null
+++ b/DI-engine/dizoo/dmc2gym/config/dmc2gym_dreamer_config.py
@@ -0,0 +1,93 @@
+from easydict import EasyDict
+
+from ding.entry import serial_pipeline_dreamer
+
+cuda = False
+
+cartpole_balance_dreamer_config = dict(
+ exp_name='dmc2gym_cartpole_balance_dreamer',
+ env=dict(
+ env_id='dmc2gym_cartpole_balance',
+ domain_name='cartpole',
+ task_name='balance',
+ frame_skip=1,
+ warp_frame=True,
+ scale=True,
+ clip_rewards=False,
+ action_repeat=2,
+ frame_stack=1,
+ from_pixels=True,
+ resize=64,
+ collector_env_num=1,
+ evaluator_env_num=1,
+ n_evaluator_episode=1,
+ stop_value=1000, # 1000
+ ),
+ policy=dict(
+ cuda=cuda,
+ # it is better to put random_collect_size in policy.other
+ random_collect_size=2500,
+ model=dict(
+ obs_shape=(3, 64, 64),
+ action_shape=1,
+ actor_dist='normal',
+ ),
+ learn=dict(
+ lambda_=0.95,
+ learning_rate=3e-5,
+ batch_size=16,
+ batch_length=64,
+ imag_sample=True,
+ discount=0.997,
+ reward_EMA=True,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ action_size=1, # has to be specified
+ collect_dyn_sample=True,
+ ),
+ command=dict(),
+ eval=dict(evaluator=dict(eval_freq=5000, )),
+ other=dict(
+ # environment buffer
+ replay_buffer=dict(replay_buffer_size=500000, periodic_thruput_seconds=60),
+ ),
+ ),
+ world_model=dict(
+ pretrain=100,
+ train_freq=2,
+ cuda=cuda,
+ model=dict(
+ state_size=(3, 64, 64), # has to be specified
+ action_size=1, # has to be specified
+ reward_size=1,
+ batch_size=16,
+ ),
+ ),
+)
+
+cartpole_balance_dreamer_config = EasyDict(cartpole_balance_dreamer_config)
+
+cartpole_balance_create_config = dict(
+ env=dict(
+ type='dmc2gym',
+ import_names=['dizoo.dmc2gym.envs.dmc2gym_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='dreamer',
+ import_names=['ding.policy.mbpolicy.dreamer'],
+ ),
+ replay_buffer=dict(type='sequence', ),
+ world_model=dict(
+ type='dreamer',
+ import_names=['ding.world_model.dreamer'],
+ ),
+)
+cartpole_balance_create_config = EasyDict(cartpole_balance_create_config)
+
+if __name__ == '__main__':
+ serial_pipeline_dreamer(
+ (cartpole_balance_dreamer_config, cartpole_balance_create_config), seed=0, max_env_step=1000000
+ )
diff --git a/DI-engine/dizoo/dmc2gym/config/dmc2gym_ppo_config.py b/DI-engine/dizoo/dmc2gym/config/dmc2gym_ppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..207b398e63765a18b13a1c53e1b4284c28562f6d
--- /dev/null
+++ b/DI-engine/dizoo/dmc2gym/config/dmc2gym_ppo_config.py
@@ -0,0 +1,63 @@
+from easydict import EasyDict
+
+cartpole_balance_ppo_config = dict(
+ exp_name='dmc2gym_cartpole_balance_ppo',
+ env=dict(
+ env_id='dmc2gym_cartpole_balance',
+ domain_name='cartpole',
+ task_name='balance',
+ from_pixels=False,
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=1000,
+ ),
+ policy=dict(
+ cuda=True,
+ recompute_adv=True,
+ action_space='discrete',
+ model=dict(
+ obs_shape=5,
+ action_shape=1,
+ action_space='discrete',
+ encoder_hidden_size_list=[64, 64, 128],
+ critic_head_hidden_size=128,
+ actor_head_hidden_size=128,
+ ),
+ learn=dict(
+ epoch_per_collect=2,
+ batch_size=64,
+ learning_rate=0.001,
+ value_weight=0.5,
+ entropy_weight=0.01,
+ clip_ratio=0.2,
+ learner=dict(hook=dict(save_ckpt_after_iter=100)),
+ ),
+ collect=dict(
+ n_sample=256,
+ unroll_len=1,
+ discount_factor=0.9,
+ gae_lambda=0.95,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=10000, ), ),
+ )
+)
+cartpole_balance_ppo_config = EasyDict(cartpole_balance_ppo_config)
+main_config = cartpole_balance_ppo_config
+
+cartpole_balance_create_config = dict(
+ env=dict(
+ type='dmc2gym',
+ import_names=['dizoo.dmc2gym.envs.dmc2gym_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='ppo'),
+ replay_buffer=dict(type='naive', ),
+)
+cartpole_balance_create_config = EasyDict(cartpole_balance_create_config)
+create_config = cartpole_balance_create_config
+
+# To use this config, run dizoo/dmc2gym/entry/dmc2gym_onppo_main.py, which imports it directly.
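+# For example (a sketch, assuming DI-engine is installed and you launch from the repository root):
+#   python dizoo/dmc2gym/entry/dmc2gym_onppo_main.py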
diff --git a/DI-engine/dizoo/dmc2gym/config/dmc2gym_sac_pixel_config.py b/DI-engine/dizoo/dmc2gym/config/dmc2gym_sac_pixel_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0155b1ebd2cd26ccdd2e4b534a9bbe3481f4035
--- /dev/null
+++ b/DI-engine/dizoo/dmc2gym/config/dmc2gym_sac_pixel_config.py
@@ -0,0 +1,79 @@
+from easydict import EasyDict
+import os
+# os.environ['MUJOCO_GL']="egl"
+dmc2gym_sac_config = dict(
+ exp_name='dmc2gym_sac_pixel_seed0',
+ env=dict(
+ env_id='dmc2gym-v0',
+ domain_name="cartpole",
+ task_name="swingup",
+ frame_skip=4,
+ warp_frame=True,
+ scale=True,
+ clip_rewards=False,
+ frame_stack=3,
+ from_pixels=True, # pixel obs
+ channels_first=False, # obs shape (height, width, 3)
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=1e6,
+ manager=dict(shared_memory=False, ),
+ ),
+ policy=dict(
+ model_type='pixel',
+ cuda=True,
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=(3, 84, 84),
+ action_shape=1,
+ twin_critic=True,
+ encoder_hidden_size_list=[32, 32, 32],
+ actor_head_hidden_size=1024,
+ critic_head_hidden_size=1024,
+ share_encoder=True,
+ ),
+ learn=dict(
+ ignore_done=True,
+ update_per_collect=1,
+ batch_size=128,
+ learning_rate_q=1e-3,
+ learning_rate_policy=1e-3,
+ learning_rate_alpha=3e-4,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=True,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ ),
+ eval=dict(),
+ other=dict(replay_buffer=dict(replay_buffer_size=100000, ), ),
+ ),
+)
+
+dmc2gym_sac_config = EasyDict(dmc2gym_sac_config)
+main_config = dmc2gym_sac_config
+
+dmc2gym_sac_create_config = dict(
+ env=dict(
+ type='dmc2gym',
+ import_names=['dizoo.dmc2gym.envs.dmc2gym_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='sac',
+ import_names=['ding.policy.sac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+dmc2gym_sac_create_config = EasyDict(dmc2gym_sac_create_config)
+create_config = dmc2gym_sac_create_config
+
+if __name__ == "__main__":
+    # or you can enter `ding -m serial -c dmc2gym_sac_pixel_config.py -s 0 --env-step 1e7`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
\ No newline at end of file
diff --git a/DI-engine/dizoo/dmc2gym/config/dmc2gym_sac_state_config.py b/DI-engine/dizoo/dmc2gym/config/dmc2gym_sac_state_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..840629423b9e54156bb977d2e94dfc5a47d98751
--- /dev/null
+++ b/DI-engine/dizoo/dmc2gym/config/dmc2gym_sac_state_config.py
@@ -0,0 +1,69 @@
+from easydict import EasyDict
+
+dmc2gym_sac_config = dict(
+ exp_name='dmc2gym_sac_state_seed0',
+ env=dict(
+ env_id='dmc2gym-v0',
+ domain_name="cartpole",
+ task_name="swingup",
+ frame_skip=8,
+ frame_stack=1,
+ from_pixels=False, # state obs
+ channels_first=False, # obs shape (height, width, 3)
+ collector_env_num=16,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=1e6,
+ manager=dict(shared_memory=False, ),
+ ),
+ policy=dict(
+ model_type='state',
+ cuda=True,
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=5,
+ action_shape=1,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ ignore_done=True,
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_q=1e-3,
+ learning_rate_policy=1e-3,
+ learning_rate_alpha=3e-4,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=True,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ ),
+ eval=dict(),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ ),
+)
+
+dmc2gym_sac_config = EasyDict(dmc2gym_sac_config)
+main_config = dmc2gym_sac_config
+
+dmc2gym_sac_create_config = dict(
+ env=dict(
+ type='dmc2gym',
+ import_names=['dizoo.dmc2gym.envs.dmc2gym_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='sac',
+ import_names=['ding.policy.sac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+dmc2gym_sac_create_config = EasyDict(dmc2gym_sac_create_config)
+create_config = dmc2gym_sac_create_config
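+
+if __name__ == "__main__":
+    # a launch sketch mirroring the pixel config above; alternatively:
+    # `ding -m serial -c dmc2gym_sac_state_config.py -s 0 --env-step 1e7`
+    from ding.entry import serial_pipeline
+    serial_pipeline((main_config, create_config), seed=0)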
diff --git a/DI-engine/dizoo/dmc2gym/config/walker_walk/walker_walk_dreamer_config.py b/DI-engine/dizoo/dmc2gym/config/walker_walk/walker_walk_dreamer_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..16e76eac391cbc08508d836fc74cc348ac30d3e9
--- /dev/null
+++ b/DI-engine/dizoo/dmc2gym/config/walker_walk/walker_walk_dreamer_config.py
@@ -0,0 +1,91 @@
+from easydict import EasyDict
+
+from ding.entry import serial_pipeline_dreamer
+
+cuda = False
+
+walker_walk_dreamer_config = dict(
+ exp_name='dmc2gym_walker_walk_dreamer',
+ env=dict(
+ env_id='dmc2gym_walker_walk',
+ domain_name='walker',
+ task_name='walk',
+ frame_skip=1,
+ warp_frame=True,
+ scale=True,
+ clip_rewards=False,
+ action_repeat=2,
+ frame_stack=1,
+ from_pixels=True,
+ resize=64,
+ collector_env_num=1,
+ evaluator_env_num=1,
+ n_evaluator_episode=1,
+ stop_value=1000, # 1000
+ ),
+ policy=dict(
+ cuda=cuda,
+ # it is better to put random_collect_size in policy.other
+ random_collect_size=2500,
+ model=dict(
+ obs_shape=(3, 64, 64),
+ action_shape=6,
+ actor_dist='normal',
+ ),
+ learn=dict(
+ lambda_=0.95,
+ learning_rate=3e-5,
+ batch_size=16,
+ batch_length=64,
+ imag_sample=True,
+ discount=0.997,
+ reward_EMA=True,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ action_size=6, # has to be specified
+ collect_dyn_sample=True,
+ ),
+ command=dict(),
+ eval=dict(evaluator=dict(eval_freq=5000, )),
+ other=dict(
+ # environment buffer
+ replay_buffer=dict(replay_buffer_size=500000, periodic_thruput_seconds=60),
+ ),
+ ),
+ world_model=dict(
+ pretrain=100,
+ train_freq=2,
+ cuda=cuda,
+ model=dict(
+ state_size=(3, 64, 64), # has to be specified
+ action_size=6, # has to be specified
+ reward_size=1,
+ batch_size=16,
+ ),
+ ),
+)
+
+walker_walk_dreamer_config = EasyDict(walker_walk_dreamer_config)
+
+walker_walk_create_config = dict(
+ env=dict(
+ type='dmc2gym',
+ import_names=['dizoo.dmc2gym.envs.dmc2gym_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='dreamer',
+ import_names=['ding.policy.mbpolicy.dreamer'],
+ ),
+ replay_buffer=dict(type='sequence', ),
+ world_model=dict(
+ type='dreamer',
+ import_names=['ding.world_model.dreamer'],
+ ),
+)
+walker_walk_create_config = EasyDict(walker_walk_create_config)
+
+if __name__ == '__main__':
+ serial_pipeline_dreamer((walker_walk_dreamer_config, walker_walk_create_config), seed=0, max_env_step=500000)
diff --git a/DI-engine/dizoo/dmc2gym/entry/dmc2gym_onppo_main.py b/DI-engine/dizoo/dmc2gym/entry/dmc2gym_onppo_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..412a46577a68ddbb074c073ead66a4de0dbce704
--- /dev/null
+++ b/DI-engine/dizoo/dmc2gym/entry/dmc2gym_onppo_main.py
@@ -0,0 +1,124 @@
+import os
+from easydict import EasyDict
+from functools import partial
+from tensorboardX import SummaryWriter
+import dmc2gym
+
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator
+from ding.model import VAC
+from ding.policy import PPOPolicy
+from ding.envs import DingEnvWrapper, EvalEpisodeReturnWrapper, BaseEnvManager
+from ding.config import compile_config
+from ding.utils import set_pkg_seed
+from dizoo.dmc2gym.config.dmc2gym_ppo_config import cartpole_balance_ppo_config
+from dizoo.dmc2gym.envs.dmc2gym_env import *
+
+
+class Dmc2GymWrapper(gym.Wrapper):
+
+ def __init__(self, env, cfg):
+ super().__init__(env)
+ cfg = EasyDict(cfg)
+ self._cfg = cfg
+
+ env_info = dmc2gym_env_info[cfg.domain_name][cfg.task_name]
+
+ self._observation_space = env_info["observation_space"](
+ from_pixels=self._cfg["from_pixels"],
+ height=self._cfg["height"],
+ width=self._cfg["width"],
+ channels_first=self._cfg["channels_first"]
+ )
+ self._action_space = env_info["action_space"]
+ self._reward_space = env_info["reward_space"](self._cfg["frame_skip"])
+
+ def _process_obs(self, obs):
+ if self._cfg["from_pixels"]:
+ obs = to_ndarray(obs).astype(np.uint8)
+ else:
+ obs = to_ndarray(obs).astype(np.float32)
+ return obs
+
+ def step(self, action):
+ action = np.array([action]).astype('float32')
+ obs, reward, done, info = self.env.step(action)
+ return self._process_obs(obs), reward, done, info
+
+ def reset(self):
+ obs = self.env.reset()
+ return self._process_obs(obs)
+
+
+def wrapped_dmc2gym_env(cfg):
+ default_cfg = {
+ "frame_skip": 3,
+ "from_pixels": True,
+ "visualize_reward": False,
+ "height": 100,
+ "width": 100,
+ "channels_first": True,
+ }
+ default_cfg.update(cfg)
+
+ return DingEnvWrapper(
+ dmc2gym.make(
+ domain_name=default_cfg["domain_name"],
+ task_name=default_cfg["task_name"],
+ seed=1,
+ visualize_reward=default_cfg["visualize_reward"],
+ from_pixels=default_cfg["from_pixels"],
+ height=default_cfg["height"],
+ width=default_cfg["width"],
+ frame_skip=default_cfg["frame_skip"]
+ ),
+ cfg={
+ 'env_wrapper': [
+ lambda env: Dmc2GymWrapper(env, default_cfg),
+ lambda env: EvalEpisodeReturnWrapper(env),
+ ]
+ }
+ )
+
+
+def main(cfg, seed=0, max_env_step=int(1e10), max_train_iter=int(1e10)):
+ cfg = compile_config(
+ cfg, BaseEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, save_cfg=True
+ )
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+ collector_env = BaseEnvManager(
+ env_fn=[partial(wrapped_dmc2gym_env, cfg=cartpole_balance_ppo_config.env) for _ in range(collector_env_num)],
+ cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManager(
+ env_fn=[partial(wrapped_dmc2gym_env, cfg=cartpole_balance_ppo_config.env) for _ in range(evaluator_env_num)],
+ cfg=cfg.env.manager
+ )
+
+ collector_env.seed(seed)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ model = VAC(**cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = SampleSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+
+ while True:
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ new_data = collector.collect(train_iter=learner.train_iter)
+ learner.train(new_data, collector.envstep)
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+
+
+if __name__ == '__main__':
+ main(cartpole_balance_ppo_config)
diff --git a/DI-engine/dizoo/dmc2gym/entry/dmc2gym_sac_pixel_main.py b/DI-engine/dizoo/dmc2gym/entry/dmc2gym_sac_pixel_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..e03fcc2f0586a0f9ee6d82ea7d1aab4de2a01a3a
--- /dev/null
+++ b/DI-engine/dizoo/dmc2gym/entry/dmc2gym_sac_pixel_main.py
@@ -0,0 +1,89 @@
+from tensorboardX import SummaryWriter
+from ditk import logging
+import os
+import numpy as np
+from ding.model.template.qac import ContinuousQAC
+from ding.policy import SACPolicy
+from ding.envs import BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import data_pusher, StepCollector, interaction_evaluator, \
+ CkptSaver, OffPolicyLearner, termination_checker
+from ding.utils import set_pkg_seed
+from dizoo.dmc2gym.envs.dmc2gym_env import DMC2GymEnv
+from dizoo.dmc2gym.config.dmc2gym_sac_pixel_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ main_config.exp_name = 'dmc2gym_sac_pixel_seed0'
+ main_config.policy.cuda = True
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+
+ num_seed = 1
+ for seed_i in range(num_seed):
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'seed' + str(seed_i)))
+
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: DMC2GymEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: DMC2GymEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = ContinuousQAC(**cfg.policy.model)
+ logging.info(model)
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ policy = SACPolicy(cfg.policy, model=model)
+
+ def _add_scalar(ctx):
+ if ctx.eval_value != -np.inf:
+ tb_logger.add_scalar('evaluator_step/reward', ctx.eval_value, global_step=ctx.env_step)
+ collector_rewards = [ctx.trajectories[i]['reward'] for i in range(len(ctx.trajectories))]
+ collector_mean_reward = sum(collector_rewards) / len(ctx.trajectories)
+ # collector_max_reward = max(collector_rewards)
+ # collector_min_reward = min(collector_rewards)
+ tb_logger.add_scalar('collecter_step/mean_reward', collector_mean_reward, global_step=ctx.env_step)
+ # tb_logger.add_scalar('collecter_step/max_reward', collector_max_reward, global_step= ctx.env_step)
+ # tb_logger.add_scalar('collecter_step/min_reward', collector_min_reward, global_step= ctx.env_step)
+ tb_logger.add_scalar(
+ 'collecter_step/avg_env_step_per_episode',
+ ctx.env_step / ctx.env_episode,
+ global_step=ctx.env_step
+ )
+
+ def _add_train_scalar(ctx):
+ len_train = len(ctx.train_output)
+ cur_lr_q_avg = sum([ctx.train_output[i]['cur_lr_q'] for i in range(len_train)]) / len_train
+ cur_lr_p_avg = sum([ctx.train_output[i]['cur_lr_p'] for i in range(len_train)]) / len_train
+ critic_loss_avg = sum([ctx.train_output[i]['critic_loss'] for i in range(len_train)]) / len_train
+ policy_loss_avg = sum([ctx.train_output[i]['policy_loss'] for i in range(len_train)]) / len_train
+ total_loss_avg = sum([ctx.train_output[i]['total_loss'] for i in range(len_train)]) / len_train
+ tb_logger.add_scalar('learner_step/cur_lr_q_avg', cur_lr_q_avg, global_step=ctx.env_step)
+ tb_logger.add_scalar('learner_step/cur_lr_p_avg', cur_lr_p_avg, global_step=ctx.env_step)
+ tb_logger.add_scalar('learner_step/critic_loss_avg', critic_loss_avg, global_step=ctx.env_step)
+ tb_logger.add_scalar('learner_step/policy_loss_avg', policy_loss_avg, global_step=ctx.env_step)
+ tb_logger.add_scalar('learner_step/total_loss_avg', total_loss_avg, global_step=ctx.env_step)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(
+ StepCollector(
+ cfg, policy.collect_mode, collector_env, random_collect_size=cfg.policy.random_collect_size
+ )
+ )
+ task.use(_add_scalar)
+ task.use(data_pusher(cfg, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ task.use(_add_train_scalar)
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=int(1e5)))
+ task.use(termination_checker(max_env_step=int(5e6)))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/dizoo/dmc2gym/entry/dmc2gym_sac_state_main.py b/DI-engine/dizoo/dmc2gym/entry/dmc2gym_sac_state_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..6f2393071c9a187388916ca26e4a4b66f8ec3c7e
--- /dev/null
+++ b/DI-engine/dizoo/dmc2gym/entry/dmc2gym_sac_state_main.py
@@ -0,0 +1,88 @@
+from ditk import logging
+from ding.model import ContinuousQAC
+from ding.policy import SACPolicy
+from ding.envs import BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import data_pusher, StepCollector, interaction_evaluator, \
+ CkptSaver, OffPolicyLearner, termination_checker
+from ding.utils import set_pkg_seed
+from dizoo.dmc2gym.envs.dmc2gym_env import DMC2GymEnv
+from dizoo.dmc2gym.config.dmc2gym_sac_state_config import main_config, create_config
+import numpy as np
+from tensorboardX import SummaryWriter
+import os
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ main_config.exp_name = 'dmc2gym_sac_state_nseed_5M'
+ main_config.policy.cuda = True
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+
+ num_seed = 4
+ for seed_i in range(num_seed):
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'seed' + str(seed_i)))
+
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: DMC2GymEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: DMC2GymEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = ContinuousQAC(**cfg.policy.model)
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ policy = SACPolicy(cfg.policy, model=model)
+
+ def _add_scalar(ctx):
+ if ctx.eval_value != -np.inf:
+ tb_logger.add_scalar('evaluator_step/reward', ctx.eval_value, global_step=ctx.env_step)
+ collector_rewards = [ctx.trajectories[i]['reward'] for i in range(len(ctx.trajectories))]
+ collector_mean_reward = sum(collector_rewards) / len(ctx.trajectories)
+ # collector_max_reward = max(collector_rewards)
+ # collector_min_reward = min(collector_rewards)
+ tb_logger.add_scalar('collecter_step/mean_reward', collector_mean_reward, global_step=ctx.env_step)
+ # tb_logger.add_scalar('collecter_step/max_reward', collector_max_reward, global_step= ctx.env_step)
+ # tb_logger.add_scalar('collecter_step/min_reward', collector_min_reward, global_step= ctx.env_step)
+ tb_logger.add_scalar(
+ 'collecter_step/avg_env_step_per_episode',
+ ctx.env_step / ctx.env_episode,
+ global_step=ctx.env_step
+ )
+
+ def _add_train_scalar(ctx):
+ len_train = len(ctx.train_output)
+ cur_lr_q_avg = sum([ctx.train_output[i]['cur_lr_q'] for i in range(len_train)]) / len_train
+ cur_lr_p_avg = sum([ctx.train_output[i]['cur_lr_p'] for i in range(len_train)]) / len_train
+ critic_loss_avg = sum([ctx.train_output[i]['critic_loss'] for i in range(len_train)]) / len_train
+ policy_loss_avg = sum([ctx.train_output[i]['policy_loss'] for i in range(len_train)]) / len_train
+ total_loss_avg = sum([ctx.train_output[i]['total_loss'] for i in range(len_train)]) / len_train
+ tb_logger.add_scalar('learner_step/cur_lr_q_avg', cur_lr_q_avg, global_step=ctx.env_step)
+ tb_logger.add_scalar('learner_step/cur_lr_p_avg', cur_lr_p_avg, global_step=ctx.env_step)
+ tb_logger.add_scalar('learner_step/critic_loss_avg', critic_loss_avg, global_step=ctx.env_step)
+ tb_logger.add_scalar('learner_step/policy_loss_avg', policy_loss_avg, global_step=ctx.env_step)
+ tb_logger.add_scalar('learner_step/total_loss_avg', total_loss_avg, global_step=ctx.env_step)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(
+ StepCollector(
+ cfg, policy.collect_mode, collector_env, random_collect_size=cfg.policy.random_collect_size
+ )
+ )
+ task.use(_add_scalar)
+ task.use(data_pusher(cfg, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ task.use(_add_train_scalar)
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=int(1e5)))
+ task.use(termination_checker(max_env_step=int(5e6)))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/dizoo/dmc2gym/entry/dmc2gym_save_replay_example.py b/DI-engine/dizoo/dmc2gym/entry/dmc2gym_save_replay_example.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9930dc193a6982d8ffa4f747552db3beece00b3
--- /dev/null
+++ b/DI-engine/dizoo/dmc2gym/entry/dmc2gym_save_replay_example.py
@@ -0,0 +1,120 @@
+import os
+import gym
+import torch
+from tensorboardX import SummaryWriter
+from easydict import EasyDict
+from functools import partial
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.envs import BaseEnvManager
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.policy import DDPGPolicy
+from ding.utils import set_pkg_seed
+
+cartpole_balance_ddpg_config = dict(
+ exp_name='dmc2gym_cartpole_balance_ddpg_eval',
+ env=dict(
+ env_id='dmc2gym_cartpole_balance',
+ domain_name='cartpole',
+ task_name='balance',
+ from_pixels=False,
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ replay_path='./dmc2gym_cartpole_balance_ddpg_eval/video',
+ stop_value=1000,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=2560,
+ load_path="./dmc2gym_cartpole_balance_ddpg/ckpt/iteration_10000.pth.tar",
+ model=dict(
+ obs_shape=5,
+ action_shape=1,
+ twin_critic=False,
+ actor_head_hidden_size=128,
+ critic_head_hidden_size=128,
+ action_space='regression',
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=128,
+ learning_rate_actor=1e-3,
+ learning_rate_critic=1e-3,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ actor_update_freq=1,
+ noise=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ noise_sigma=0.1,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=10000, ), ),
+ )
+)
+cartpole_balance_ddpg_config = EasyDict(cartpole_balance_ddpg_config)
+main_config = cartpole_balance_ddpg_config
+
+cartpole_balance_create_config = dict(
+ env=dict(
+ type='dmc2gym',
+ import_names=['dizoo.dmc2gym.envs.dmc2gym_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='ddpg',
+ import_names=['ding.policy.ddpg'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+cartpole_balance_create_config = EasyDict(cartpole_balance_create_config)
+create_config = cartpole_balance_create_config
+
+
+def main(cfg, create_cfg, seed=0):
+ cfg = compile_config(
+ cfg,
+ BaseEnvManager,
+ DDPGPolicy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator,
+ AdvancedReplayBuffer,
+ create_cfg=create_cfg,
+ save_cfg=True
+ )
+
+ create_cfg.policy.type = create_cfg.policy.type + '_command'
+ env_fn = None
+ cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
+ # Create main components: env, policy
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+
+ evaluator_env.enable_save_replay(cfg.env.replay_path)
+
+ # Set random seed for all package and instance
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ # Set up RL Policy
+ policy = DDPGPolicy(cfg.policy)
+ policy.eval_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cpu'))
+
+ # evaluate
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator.eval()
+
+
+if __name__ == "__main__":
+ main(main_config, create_config, seed=0)
diff --git a/DI-engine/dizoo/dmc2gym/envs/__init__.py b/DI-engine/dizoo/dmc2gym/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..63d00d0acf92f9f07d7c895843be6635e1676913
--- /dev/null
+++ b/DI-engine/dizoo/dmc2gym/envs/__init__.py
@@ -0,0 +1 @@
+from .dmc2gym_env import DMC2GymEnv
diff --git a/DI-engine/dizoo/dmc2gym/envs/dmc2gym_env.py b/DI-engine/dizoo/dmc2gym/envs/dmc2gym_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b2ffc519c72453201d03acde2c46a90bf8ca616
--- /dev/null
+++ b/DI-engine/dizoo/dmc2gym/envs/dmc2gym_env.py
@@ -0,0 +1,249 @@
+from typing import Optional, Callable
+import gym
+from gym.spaces import Box
+import numpy as np
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.envs.common.common_function import affine_transform
+from ding.torch_utils import to_ndarray
+from ding.utils import ENV_REGISTRY
+import dmc2gym
+from ding.envs import WarpFrameWrapper, ScaledFloatFrameWrapper, ClipRewardWrapper, ActionRepeatWrapper, \
+    FrameStackWrapper
+
+
+def dmc2gym_observation_space(dim, minimum=-np.inf, maximum=np.inf, dtype=np.float32) -> Callable:
+
+ def observation_space(from_pixels=True, height=84, width=84, channels_first=True) -> Box:
+ if from_pixels:
+ shape = [3, height, width] if channels_first else [height, width, 3]
+ return Box(low=0, high=255, shape=shape, dtype=np.uint8)
+ else:
+ return Box(np.repeat(minimum, dim).astype(dtype), np.repeat(maximum, dim).astype(dtype), dtype=dtype)
+
+ return observation_space
+
+
+def dmc2gym_state_space(dim, minimum=-np.inf, maximum=np.inf, dtype=np.float32) -> Box:
+ return Box(np.repeat(minimum, dim).astype(dtype), np.repeat(maximum, dim).astype(dtype), dtype=dtype)
+
+
+def dmc2gym_action_space(dim, minimum=-1, maximum=1, dtype=np.float32) -> Box:
+ return Box(np.repeat(minimum, dim).astype(dtype), np.repeat(maximum, dim).astype(dtype), dtype=dtype)
+
+
+def dmc2gym_reward_space(minimum=0, maximum=1, dtype=np.float32) -> Callable:
+
+ def reward_space(frame_skip=1) -> Box:
+ return Box(
+ np.repeat(minimum * frame_skip, 1).astype(dtype),
+ np.repeat(maximum * frame_skip, 1).astype(dtype),
+ dtype=dtype
+ )
+
+ return reward_space
+
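+# These helpers are factories: e.g. ``dmc2gym_observation_space(5)(from_pixels=False)`` yields a
+# 5-dim float32 Box, while the default pixel call yields a (3, 84, 84) uint8 Box.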
+
+"""
+Default observation, state, action and reward space specifications for dmc2gym environments.
+"""
+dmc2gym_env_info = {
+ "ball_in_cup": {
+ "catch": {
+ "observation_space": dmc2gym_observation_space(8),
+ "state_space": dmc2gym_state_space(8),
+ "action_space": dmc2gym_action_space(2),
+ "reward_space": dmc2gym_reward_space()
+ }
+ },
+ "cartpole": {
+ "balance": {
+ "observation_space": dmc2gym_observation_space(5),
+ "state_space": dmc2gym_state_space(5),
+ "action_space": dmc2gym_action_space(1),
+ "reward_space": dmc2gym_reward_space()
+ },
+ "swingup": {
+ "observation_space": dmc2gym_observation_space(5),
+ "state_space": dmc2gym_state_space(5),
+ "action_space": dmc2gym_action_space(1),
+ "reward_space": dmc2gym_reward_space()
+ }
+ },
+ "cheetah": {
+ "run": {
+ "observation_space": dmc2gym_observation_space(17),
+ "state_space": dmc2gym_state_space(17),
+ "action_space": dmc2gym_action_space(6),
+ "reward_space": dmc2gym_reward_space()
+ }
+ },
+ "finger": {
+ "spin": {
+ "observation_space": dmc2gym_observation_space(9),
+ "state_space": dmc2gym_state_space(9),
+ "action_space": dmc2gym_action_space(1),
+ "reward_space": dmc2gym_reward_space()
+ }
+ },
+ "reacher": {
+ "easy": {
+ "observation_space": dmc2gym_observation_space(6),
+ "state_space": dmc2gym_state_space(6),
+ "action_space": dmc2gym_action_space(2),
+ "reward_space": dmc2gym_reward_space()
+ }
+ },
+ "walker": {
+ "walk": {
+ "observation_space": dmc2gym_observation_space(24),
+ "state_space": dmc2gym_state_space(24),
+ "action_space": dmc2gym_action_space(6),
+ "reward_space": dmc2gym_reward_space()
+ }
+ }
+}
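+
+# Example lookup (a sketch): entry scripts resolve a task's spaces via this table, e.g.
+#   obs_space = dmc2gym_env_info['cartpole']['balance']['observation_space'](from_pixels=False)
+#   act_space = dmc2gym_env_info['cartpole']['balance']['action_space']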
+
+
+@ENV_REGISTRY.register('dmc2gym')
+class DMC2GymEnv(BaseEnv):
+
+ def __init__(self, cfg: dict = {}) -> None:
+ assert cfg.domain_name in dmc2gym_env_info, '{}/{}'.format(cfg.domain_name, dmc2gym_env_info.keys())
+ assert cfg.task_name in dmc2gym_env_info[
+ cfg.domain_name], '{}/{}'.format(cfg.task_name, dmc2gym_env_info[cfg.domain_name].keys())
+
+ # default config for dmc2gym env
+ self._cfg = {
+ "frame_skip": 4,
+ 'warp_frame': False,
+ 'scale': False,
+ 'clip_rewards': False,
+ 'action_repeat': 1,
+ "frame_stack": 3,
+ "from_pixels": True,
+ "visualize_reward": False,
+ "height": 84,
+ "width": 84,
+ "channels_first": True,
+ "resize": 84,
+ }
+
+ self._cfg.update(cfg)
+
+ self._init_flag = False
+
+ self._replay_path = None
+
+ self._observation_space = dmc2gym_env_info[cfg.domain_name][cfg.task_name]["observation_space"](
+ from_pixels=self._cfg["from_pixels"],
+ height=self._cfg["height"],
+ width=self._cfg["width"],
+ channels_first=self._cfg["channels_first"]
+ )
+ self._action_space = dmc2gym_env_info[cfg.domain_name][cfg.task_name]["action_space"]
+ self._reward_space = dmc2gym_env_info[cfg.domain_name][cfg.task_name]["reward_space"](self._cfg["frame_skip"])
+
+ def reset(self) -> np.ndarray:
+ if not self._init_flag:
+
+ self._env = dmc2gym.make(
+ domain_name=self._cfg["domain_name"],
+ task_name=self._cfg["task_name"],
+ seed=1,
+ visualize_reward=self._cfg["visualize_reward"],
+ from_pixels=self._cfg["from_pixels"],
+ height=self._cfg["height"],
+ width=self._cfg["width"],
+ frame_skip=self._cfg["frame_skip"],
+ channels_first=self._cfg["channels_first"],
+ )
+
+ # optional env wrapper
+ if self._cfg['warp_frame']:
+ self._env = WarpFrameWrapper(self._env, size=self._cfg['resize'])
+ if self._cfg['scale']:
+ self._env = ScaledFloatFrameWrapper(self._env)
+ if self._cfg['clip_rewards']:
+ self._env = ClipRewardWrapper(self._env)
+ if self._cfg['action_repeat']:
+ self._env = ActionRepeatWrapper(self._env, self._cfg['action_repeat'])
+ if self._cfg['frame_stack'] > 1:
+ self._env = FrameStackWrapper(self._env, self._cfg['frame_stack'])
+
+ # set the obs, action space of wrapped env
+ self._observation_space = self._env.observation_space
+ self._action_space = self._env.action_space
+
+ if self._replay_path is not None:
+ if gym.version.VERSION > '0.22.0':
+ self._env.metadata.update({'render_modes': ["rgb_array"]})
+ else:
+ self._env.metadata.update({'render.modes': ["rgb_array"]})
+ self._env = gym.wrappers.RecordVideo(
+ self._env,
+ video_folder=self._replay_path,
+ episode_trigger=lambda episode_id: True,
+ name_prefix='rl-video-{}'.format(id(self))
+ )
+ self._env.start_video_recorder()
+
+ self._init_flag = True
+
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+
+ self._eval_episode_return = 0
+ obs = self._env.reset()
+
+ obs = to_ndarray(obs).astype(np.float32)
+
+ return obs
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def step(self, action: np.ndarray) -> BaseEnvTimestep:
+ action = action.astype('float32')
+ action = affine_transform(action, min_val=self._env.action_space.low, max_val=self._env.action_space.high)
+ obs, rew, done, info = self._env.step(action)
+ self._eval_episode_return += rew
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+
+ obs = to_ndarray(obs).astype(np.float32)
+        rew = to_ndarray([rew]).astype(np.float32)  # wrap the scalar reward into an array with shape (1, )
+
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ if replay_path is None:
+ replay_path = './video'
+ self._replay_path = replay_path
+
+ def random_action(self) -> np.ndarray:
+ random_action = self.action_space.sample().astype(np.float32)
+ return random_action
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ def __repr__(self) -> str:
+ return "DI-engine DeepMind Control Suite to gym Env: " + self._cfg["domain_name"] + ":" + self._cfg["task_name"]
diff --git a/DI-engine/dizoo/dmc2gym/envs/test_dmc2gym_env.py b/DI-engine/dizoo/dmc2gym/envs/test_dmc2gym_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..5245a7a86ab49f56642c3407d0f7a13c2278a631
--- /dev/null
+++ b/DI-engine/dizoo/dmc2gym/envs/test_dmc2gym_env.py
@@ -0,0 +1,49 @@
+import pytest
+import numpy as np
+from easydict import EasyDict
+from dizoo.dmc2gym.envs import DMC2GymEnv
+from torch import float32
+
+
+@pytest.mark.envtest
+class TestDMC2GymEnv:
+
+ def test_naive(self):
+ env = DMC2GymEnv(EasyDict({
+ "domain_name": "cartpole",
+ "task_name": "balance",
+ "frame_skip": 2,
+ }))
+ env.seed(314, dynamic_seed=False)
+ assert env._seed == 314
+ obs = env.reset()
+ assert obs.shape == (
+ 3,
+ 100,
+ 100,
+ )
+ for _ in range(5):
+ env.reset()
+ np.random.seed(314)
+ print('=' * 60)
+ for i in range(10):
+            # Both ``env.random_action()`` and sampling the action space directly (seeded via ``np.random``)
+            # generate legal random actions.
+ if i < 5:
+ random_action = np.array(env.action_space.sample(), dtype=np.float32)
+ else:
+ random_action = env.random_action()
+ timestep = env.step(random_action)
+ print(timestep)
+ assert isinstance(timestep.obs, np.ndarray)
+ assert isinstance(timestep.done, bool)
+ assert timestep.obs.shape == (
+ 3,
+ 100,
+ 100,
+ )
+ assert timestep.reward.shape == (1, )
+ assert timestep.reward >= env.reward_space.low
+ assert timestep.reward <= env.reward_space.high
+ print(env.observation_space, env.action_space, env.reward_space)
+ env.close()
diff --git a/DI-engine/dizoo/evogym/__init__.py b/DI-engine/dizoo/evogym/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/evogym/config/bridgewalker_ddpg_config.py b/DI-engine/dizoo/evogym/config/bridgewalker_ddpg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b712a0020ed80ea055ccb858e6e881a7ebb14f30
--- /dev/null
+++ b/DI-engine/dizoo/evogym/config/bridgewalker_ddpg_config.py
@@ -0,0 +1,69 @@
+from easydict import EasyDict
+
+bridgewalker_ddpg_config = dict(
+ exp_name='evogym_bridgewalker_ddpg_seed0',
+ env=dict(
+ env_id='BridgeWalker-v0',
+ robot='speed_bot',
+ robot_dir='../envs',
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10,
+ manager=dict(shared_memory=True, ),
+ # The path to save the game replay
+ # replay_path='./evogym_walker_ddpg_seed0/video',
+ ),
+ policy=dict(
+ cuda=True,
+ # load_path="./evogym_walker_ddpg_seed0/ckpt/ckpt_best.pth.tar",
+ random_collect_size=1000,
+ model=dict(
+ obs_shape=59,
+ action_shape=10,
+ twin_critic=False,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ action_space='regression',
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_actor=1e-3,
+ learning_rate_critic=1e-3,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99, # discount_factor: 0.97-0.99
+ actor_update_freq=1,
+ noise=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ noise_sigma=0.1,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ )
+)
+bridgewalker_ddpg_config = EasyDict(bridgewalker_ddpg_config)
+main_config = bridgewalker_ddpg_config
+
+bridgewalker_ddpg_create_config = dict(
+ env=dict(
+ type='evogym',
+ import_names=['dizoo.evogym.envs.evogym_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='ddpg',
+ import_names=['ding.policy.ddpg'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+bridgewalker_ddpg_create_config = EasyDict(bridgewalker_ddpg_create_config)
+create_config = bridgewalker_ddpg_create_config
+
+if __name__ == "__main__":
+    # or you can enter `ding -m serial -c bridgewalker_ddpg_config.py -s 0 --env-step 1e7`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/evogym/config/carrier_ppo_config.py b/DI-engine/dizoo/evogym/config/carrier_ppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..16a39193922b05f712259d2ecc02ee85363e2125
--- /dev/null
+++ b/DI-engine/dizoo/evogym/config/carrier_ppo_config.py
@@ -0,0 +1,65 @@
+from easydict import EasyDict
+
+carry_ppo_config = dict(
+    exp_name='evogym_carrier_ppo_seed0',
+ env=dict(
+ env_id='Carrier-v0',
+ robot='carry_bot',
+ robot_dir='./dizoo/evogym/envs',
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10,
+ manager=dict(shared_memory=True, ),
+ # The path to save the game replay
+ # replay_path='./evogym_carry_ppo_seed0/video',
+ ),
+ policy=dict(
+ cuda=True,
+ recompute_adv=True,
+ # load_path="./evogym_carry_ppo_seed0/ckpt/ckpt_best.pth.tar",
+ model=dict(
+ obs_shape=70,
+ action_shape=12,
+ action_space='continuous',
+ ),
+ action_space='continuous',
+ learn=dict(
+ epoch_per_collect=10,
+ batch_size=256,
+ learning_rate=3e-3,
+ value_weight=0.5,
+ entropy_weight=0.01,
+ clip_ratio=0.2,
+ adv_norm=True,
+ value_norm=True,
+ ),
+ collect=dict(
+ n_sample=2048,
+ gae_lambda=0.97,
+ ),
+ eval=dict(evaluator=dict(eval_freq=5000, )),
+ )
+)
+carry_ppo_config = EasyDict(carry_ppo_config)
+main_config = carry_ppo_config
+
+carry_ppo_create_config = dict(
+ env=dict(
+ type='evogym',
+ import_names=['dizoo.evogym.envs.evogym_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='ppo',
+ import_names=['ding.policy.ppo'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+carry_ppo_create_config = EasyDict(carry_ppo_create_config)
+create_config = carry_ppo_create_config
+
+if __name__ == "__main__":
+    # or you can enter `ding -m serial -c carrier_ppo_config.py -s 0 --env-step 1e7`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/evogym/config/walker_ddpg_config.py b/DI-engine/dizoo/evogym/config/walker_ddpg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..58d1c65048e5ec90294949a21104b278e99b1812
--- /dev/null
+++ b/DI-engine/dizoo/evogym/config/walker_ddpg_config.py
@@ -0,0 +1,69 @@
+from easydict import EasyDict
+
+walker_ddpg_config = dict(
+ exp_name='evogym_walker_ddpg_seed0',
+ env=dict(
+ env_id='Walker-v0',
+ robot='speed_bot',
+ robot_dir='./dizoo/evogym/envs',
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=10,
+ manager=dict(shared_memory=True, ),
+ # The path to save the game replay
+ # replay_path='./evogym_walker_ddpg_seed0/video',
+ ),
+ policy=dict(
+ cuda=True,
+ # load_path="./evogym_walker_ddpg_seed0/ckpt/ckpt_best.pth.tar",
+ random_collect_size=1000,
+ model=dict(
+ obs_shape=58,
+ action_shape=10,
+ twin_critic=False,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ action_space='regression',
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_actor=1e-3,
+ learning_rate_critic=1e-3,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99, # discount_factor: 0.97-0.99
+ actor_update_freq=1,
+ noise=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ noise_sigma=0.1,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ )
+)
+walker_ddpg_config = EasyDict(walker_ddpg_config)
+main_config = walker_ddpg_config
+
+walker_ddpg_create_config = dict(
+ env=dict(
+ type='evogym',
+ import_names=['dizoo.evogym.envs.evogym_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='ddpg',
+ import_names=['ding.policy.ddpg'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+walker_ddpg_create_config = EasyDict(walker_ddpg_create_config)
+create_config = walker_ddpg_create_config
+
+if __name__ == "__main__":
+    # or you can enter `ding -m serial -c walker_ddpg_config.py -s 0 --env-step 1e7`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/evogym/config/walker_ppo_config.py b/DI-engine/dizoo/evogym/config/walker_ppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..2f1939a3d00490526ad8161586510c246fe6d9c3
--- /dev/null
+++ b/DI-engine/dizoo/evogym/config/walker_ppo_config.py
@@ -0,0 +1,65 @@
+from easydict import EasyDict
+
+walker_ppo_config = dict(
+ exp_name='evogym_walker_ppo_seed0',
+ env=dict(
+ env_id='Walker-v0',
+ robot='speed_bot',
+ robot_dir='./dizoo/evogym/envs',
+ collector_env_num=1,
+ evaluator_env_num=1,
+ n_evaluator_episode=1,
+ stop_value=10,
+ manager=dict(shared_memory=True, ),
+ # The path to save the game replay
+ # replay_path='./evogym_walker_ppo_seed0/video',
+ ),
+ policy=dict(
+ cuda=True,
+ recompute_adv=True,
+ # load_path="./evogym_walker_ppo_seed0/ckpt/ckpt_best.pth.tar",
+ model=dict(
+ obs_shape=58,
+ action_shape=10,
+ action_space='continuous',
+ ),
+ action_space='continuous',
+ learn=dict(
+ epoch_per_collect=10,
+ batch_size=256,
+ learning_rate=3e-4,
+ value_weight=0.5,
+ entropy_weight=0.0,
+ clip_ratio=0.2,
+ adv_norm=True,
+ value_norm=True,
+ ),
+ collect=dict(
+ n_sample=2048,
+ gae_lambda=0.97,
+ ),
+ eval=dict(evaluator=dict(eval_freq=5000, )),
+ )
+)
+walker_ppo_config = EasyDict(walker_ppo_config)
+main_config = walker_ppo_config
+
+walker_ppo_create_config = dict(
+ env=dict(
+ type='evogym',
+ import_names=['dizoo.evogym.envs.evogym_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='ppo',
+ import_names=['ding.policy.ppo'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+walker_ppo_create_config = EasyDict(walker_ppo_create_config)
+create_config = walker_ppo_create_config
+
+if __name__ == "__main__":
+    # or you can enter `ding -m serial -c walker_ppo_config.py -s 0 --env-step 1e7`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/evogym/entry/walker_ppo_eval.py b/DI-engine/dizoo/evogym/entry/walker_ppo_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c1e7060925b9c7d629176e47aef10afaccc6de4
--- /dev/null
+++ b/DI-engine/dizoo/evogym/entry/walker_ppo_eval.py
@@ -0,0 +1,57 @@
+import os
+import gym
+import torch
+from tensorboardX import SummaryWriter
+from easydict import EasyDict
+from functools import partial
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.envs import BaseEnvManager
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.policy import PPOPolicy
+from ding.utils import set_pkg_seed
+
+from dizoo.evogym.config.walker_ppo_config import main_config, create_config
+
+
+def main(cfg, create_cfg, seed=0):
+ cfg = compile_config(
+ cfg,
+ BaseEnvManager,
+ PPOPolicy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator,
+ AdvancedReplayBuffer,
+ create_cfg=create_cfg,
+ save_cfg=True
+ )
+
+ create_cfg.policy.type = create_cfg.policy.type + '_command'
+ env_fn = None
+ cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
+ # Create main components: env, policy
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+
+ evaluator_env.enable_save_replay(cfg.env.replay_path)
+
+ # Set random seed for all package and instance
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ # Set up RL Policy
+ policy = PPOPolicy(cfg.policy)
+ policy.eval_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cpu'))
+
+ # evaluate
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator.eval()
+
+
+if __name__ == "__main__":
+ main(main_config, create_config, seed=0)
diff --git a/DI-engine/dizoo/evogym/envs/__init__.py b/DI-engine/dizoo/evogym/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ee6f41343b98dee1870b3023598dc9bc58caf0b
--- /dev/null
+++ b/DI-engine/dizoo/evogym/envs/__init__.py
@@ -0,0 +1 @@
+from .evogym_env import EvoGymEnv
diff --git a/DI-engine/dizoo/evogym/envs/evogym_env.py b/DI-engine/dizoo/evogym/envs/evogym_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..e0d8d59cdfc230472fb2636987321d2aba9131da
--- /dev/null
+++ b/DI-engine/dizoo/evogym/envs/evogym_env.py
@@ -0,0 +1,178 @@
+from typing import Any, Union, List, Optional
+import os
+import time
+import copy
+import numpy as np
+import gym
+from easydict import EasyDict
+
+from ding.envs import BaseEnv, BaseEnvTimestep, EvalEpisodeReturnWrapper
+from ding.envs.common.common_function import affine_transform
+from ding.torch_utils import to_ndarray, to_list
+from ding.utils import ENV_REGISTRY
+
+import evogym.envs
+from evogym import WorldObject, sample_robot
+from evogym.sim import EvoSim
+
+
+@ENV_REGISTRY.register('evogym')
+class EvoGymEnv(BaseEnv):
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ cfg = EasyDict(copy.deepcopy(cls.config))
+ cfg.cfg_type = cls.__name__ + 'Dict'
+ return cfg
+
+ config = dict(
+ env_id='Walker-v0',
+ robot='speed_bot', # refer to 'world data' for more robots configurations
+ robot_h=5, # only used for random robots
+ robot_w=5, # only used for random robots
+        robot_pd=None,  # only used for random robots, probability distribution of randomly generated components
+        robot_dir=""  # only used for predefined robots, path to the robot config, e.g. envs/world_data/my_bot.json
+ )
+
+ def __init__(self, cfg: dict) -> None:
+ self._cfg = cfg
+ self._init_flag = False
+ self._replay_path = None
+ if 'robot_dir' not in self._cfg.keys():
+            self._cfg.robot_dir = '../'
+
+ def reset(self) -> np.ndarray:
+ if not self._init_flag:
+ self._env = self._make_env()
+ self._env.observation_space.dtype = np.float32 # To unify the format of envs in DI-engine
+ self._observation_space = self._env.observation_space
+ self.num_actuators = self._env.get_actuator_indices('robot').size
+            # by default the action space is double (float64), so create a new space with dtype float32
+ self._action_space = gym.spaces.Box(low=0.6, high=1.6, shape=(self.num_actuators, ), dtype=np.float32)
+ self._reward_space = gym.spaces.Box(
+ low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
+ )
+ self._init_flag = True
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+ if self._replay_path is not None:
+ gym.logger.set_level(gym.logger.DEBUG)
+ # make render mode compatible with gym
+ if gym.version.VERSION > '0.22.0':
+ self._env.metadata.update({'render_modes': ["rgb_array"]})
+ else:
+ self._env.metadata.update({'render.modes': ["rgb_array"]})
+ self._env = gym.wrappers.RecordVideo(
+ self._env,
+ video_folder=self._replay_path,
+ episode_trigger=lambda episode_id: True,
+ name_prefix='rl-video-{}-{}'.format(id(self), time.time())
+ )
+ obs = self._env.reset()
+ obs = to_ndarray(obs).astype('float32')
+ return obs
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep:
+ action = to_ndarray(action).astype(np.float32)
+ obs, rew, done, info = self._env.step(action)
+ obs = to_ndarray(obs).astype(np.float32)
+ rew = to_ndarray([rew]).astype(np.float32)
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def _make_env(self):
+ # robot configuration can be read from file or created randomly
+ if self._cfg.robot in [None, 'random']:
+ h, w = 5, 5
+ pd = None
+ if 'robot_h' in self._cfg.keys():
+ assert self._cfg.robot_h > 0
+ h = self._cfg.robot_h
+ if 'robot_w' in self._cfg.keys():
+ assert self._cfg.robot_w > 0
+ w = self._cfg.robot_w
+            if 'robot_pd' in self._cfg.keys() and self._cfg.robot_pd is not None:
+                assert isinstance(self._cfg.robot_pd, np.ndarray)
+                pd = self._cfg.robot_pd
+ structure = sample_robot((h, w), pd)
+ else:
+ structure = self.read_robot_from_file(self._cfg.robot, self._cfg.robot_dir)
+ env = gym.make(self._cfg.env_id, body=structure[0])
+ env = EvalEpisodeReturnWrapper(env)
+ return env
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ if replay_path is None:
+ replay_path = './video'
+ self._replay_path = replay_path
+
+ def random_action(self) -> np.ndarray:
+ return self.action_space.sample()
+
+ def __repr__(self) -> str:
+ return "DI-engine EvoGym Env({})".format(self._cfg.env_id)
+
+ @staticmethod
+ def create_collector_env_cfg(cfg: dict) -> List[dict]:
+ collector_cfg = copy.deepcopy(cfg)
+ collector_env_num = collector_cfg.pop('collector_env_num', 1)
+ return [collector_cfg for _ in range(collector_env_num)]
+
+ @staticmethod
+ def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
+ evaluator_cfg = copy.deepcopy(cfg)
+ evaluator_env_num = evaluator_cfg.pop('evaluator_env_num', 1)
+ return [evaluator_cfg for _ in range(evaluator_env_num)]
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ @staticmethod
+ def read_robot_from_file(file_name, root_dir='../'):
+ possible_paths = [
+ os.path.join(file_name),
+ os.path.join(f'{file_name}.npz'),
+ os.path.join(f'{file_name}.json'),
+ os.path.join(root_dir, 'world_data', file_name),
+ os.path.join(root_dir, 'world_data', f'{file_name}.npz'),
+ os.path.join(root_dir, 'world_data', f'{file_name}.json'),
+ ]
+
+ best_path = None
+ for path in possible_paths:
+ if os.path.exists(path):
+ best_path = path
+ break
+
+        if best_path is None:
+            raise FileNotFoundError('Cannot find robot config "{}" under "{}"'.format(file_name, root_dir))
+        if best_path.endswith('json'):
+ robot_object = WorldObject.from_json(best_path)
+ return (robot_object.get_structure(), robot_object.get_connections())
+ if best_path.endswith('npz'):
+ structure_data = np.load(best_path)
+ structure = []
+ for key, value in structure_data.items():
+ structure.append(value)
+ return tuple(structure)
+ return None
diff --git a/DI-engine/dizoo/evogym/envs/test/test_evogym_env.py b/DI-engine/dizoo/evogym/envs/test/test_evogym_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b0f17410c9849b3e5003636742836f8eda75090
--- /dev/null
+++ b/DI-engine/dizoo/evogym/envs/test/test_evogym_env.py
@@ -0,0 +1,37 @@
+import pytest
+import numpy as np
+from easydict import EasyDict
+
+from ding.utils import set_pkg_seed
+from dizoo.evogym.envs import EvoGymEnv
+
+
+@pytest.mark.envtest
+@pytest.mark.parametrize('robot', ['speed_bot', 'random'])
+def test_evogym_env_eval_episode_return(robot):
+ set_pkg_seed(1234, use_cuda=False)
+ env = EvoGymEnv(EasyDict({'env_id': 'Walker-v0', 'robot': robot, 'robot_dir': '../'}))
+ env.seed(1234)
+ env.reset()
+ action_dim = env.action_space.shape
+ eval_episode_return = np.array([0.], dtype=np.float32)
+ if robot == 'speed_bot':
+ assert env.observation_space.shape == (58, )
+ assert action_dim == (10, )
+
+ while True:
+ action = np.random.random(size=action_dim)
+ timestep = env.step(action)
+ eval_episode_return += timestep.reward
+ print("{}(dtype: {})".format(timestep.reward, timestep.reward.dtype))
+ if timestep.done:
+ print(
+ "{}({}), {}({})".format(
+ timestep.info['eval_episode_return'], type(timestep.info['eval_episode_return']),
+ eval_episode_return, type(eval_episode_return)
+ )
+ )
+            # The return accumulated here and ``eval_episode_return`` recorded by the wrapper may differ
+            # slightly due to floating-point precision, so compare them with a relative tolerance.
+ assert abs(timestep.info['eval_episode_return'].item() - eval_episode_return.item()) / \
+ abs(timestep.info['eval_episode_return'].item()) < 1e-5
+ break
diff --git a/DI-engine/dizoo/evogym/envs/test/visualize_simple_env.py b/DI-engine/dizoo/evogym/envs/test/visualize_simple_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..2203209fbe32176aa0436a8ca3d4bdf8dea6272f
--- /dev/null
+++ b/DI-engine/dizoo/evogym/envs/test/visualize_simple_env.py
@@ -0,0 +1,33 @@
+import gym
+from evogym import sample_robot
+from gym.wrappers import Monitor
+
+# import envs from the envs folder and register them
+import evogym.envs
+from dizoo.evogym.envs.viewer import DingEvoViewer
+from evogym.sim import EvoSim
+
+if __name__ == '__main__':
+ gym.logger.set_level(gym.logger.DEBUG)
+ # create a random robot
+ body, connections = sample_robot((5, 5))
+
+ # make the SimpleWalkingEnv using gym.make and with the robot information
+ #env = EvoGymEnv(EasyDict({'env_id': 'Walker-v0', 'robot': 'speed_bot', 'robot_dir': '../'}))
+ #env.enable_save_replay('video')
+
+ env = gym.make('Walker-v0', body=body)
+ env.default_viewer = DingEvoViewer(EvoSim(env.world))
+ env = Monitor(env, './video', force=True)
+ env.__class__.render = env.default_viewer.render
+    env.metadata['render.modes'] = ['rgb_array']
+
+ env.reset()
+    # step the environment for 100 iterations
+ for i in range(100):
+ action = env.action_space.sample()
+ ob, reward, done, info = env.step(action)
+ x = env.render()
+ if done:
+ env.reset()
+ env.close()
diff --git a/DI-engine/dizoo/evogym/envs/world_data/carry_bot.json b/DI-engine/dizoo/evogym/envs/world_data/carry_bot.json
new file mode 100644
index 0000000000000000000000000000000000000000..bd9af791d9b7051c62eeea5d9ae896f867b1005f
--- /dev/null
+++ b/DI-engine/dizoo/evogym/envs/world_data/carry_bot.json
@@ -0,0 +1,142 @@
+{
+ "grid_width": 5,
+ "grid_height": 5,
+ "objects": {
+ "new_object_1": {
+ "indices": [
+ 20,
+ 15,
+ 16,
+ 17,
+ 18,
+ 19,
+ 10,
+ 11,
+ 12,
+ 13,
+ 14,
+ 5,
+ 6,
+ 8,
+ 9,
+ 0,
+ 1,
+ 3,
+ 4
+ ],
+ "types": [
+ 1,
+ 3,
+ 3,
+ 3,
+ 3,
+ 3,
+ 1,
+ 3,
+ 3,
+ 3,
+ 1,
+ 2,
+ 4,
+ 4,
+ 2,
+ 2,
+ 4,
+ 4,
+ 2
+ ],
+ "neighbors": {
+ "20": [
+ 15
+ ],
+ "15": [
+ 10,
+ 16,
+ 20
+ ],
+ "16": [
+ 15,
+ 11,
+ 17
+ ],
+ "17": [
+ 16,
+ 12,
+ 18
+ ],
+ "18": [
+ 17,
+ 13,
+ 19
+ ],
+ "19": [
+ 18,
+ 14
+ ],
+ "10": [
+ 5,
+ 11,
+ 15
+ ],
+ "11": [
+ 10,
+ 6,
+ 12,
+ 16
+ ],
+ "12": [
+ 11,
+ 13,
+ 17
+ ],
+ "13": [
+ 12,
+ 14,
+ 8,
+ 18
+ ],
+ "14": [
+ 9,
+ 13,
+ 19
+ ],
+ "5": [
+ 10,
+ 0,
+ 6
+ ],
+ "6": [
+ 5,
+ 1,
+ 11
+ ],
+ "8": [
+ 9,
+ 3,
+ 13
+ ],
+ "9": [
+ 14,
+ 4,
+ 8
+ ],
+ "0": [
+ 5,
+ 1
+ ],
+ "1": [
+ 0,
+ 6
+ ],
+ "3": [
+ 4,
+ 8
+ ],
+ "4": [
+ 9,
+ 3
+ ]
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/DI-engine/dizoo/evogym/envs/world_data/simple_evironment.json b/DI-engine/dizoo/evogym/envs/world_data/simple_evironment.json
new file mode 100644
index 0000000000000000000000000000000000000000..b820344789236b0228ee5c212c923af47003b56a
--- /dev/null
+++ b/DI-engine/dizoo/evogym/envs/world_data/simple_evironment.json
@@ -0,0 +1,178 @@
+{
+ "grid_width": 20,
+ "grid_height": 10,
+ "objects": {
+ "box": {
+ "indices": [
+ 164,
+ 165,
+ 166,
+ 144,
+ 145,
+ 146
+ ],
+ "types": [
+ 2,
+ 2,
+ 2,
+ 2,
+ 2,
+ 2
+ ],
+ "neighbors": {
+ "164": [
+ 144,
+ 165
+ ],
+ "165": [
+ 164,
+ 166,
+ 145
+ ],
+ "166": [
+ 146,
+ 165
+ ],
+ "144": [
+ 164,
+ 145
+ ],
+ "145": [
+ 144,
+ 146,
+ 165
+ ],
+ "146": [
+ 145,
+ 166
+ ]
+ }
+ },
+ "ground": {
+ "indices": [
+ 0,
+ 1,
+ 2,
+ 3,
+ 4,
+ 5,
+ 6,
+ 7,
+ 8,
+ 9,
+ 10,
+ 11,
+ 12,
+ 13,
+ 14,
+ 15,
+ 16,
+ 17,
+ 18,
+ 19
+ ],
+ "types": [
+ 5,
+ 5,
+ 5,
+ 5,
+ 5,
+ 5,
+ 5,
+ 5,
+ 5,
+ 5,
+ 5,
+ 5,
+ 5,
+ 5,
+ 5,
+ 5,
+ 5,
+ 5,
+ 5,
+ 5
+ ],
+ "neighbors": {
+ "0": [
+ 1
+ ],
+ "1": [
+ 0,
+ 2
+ ],
+ "2": [
+ 1,
+ 3
+ ],
+ "3": [
+ 2,
+ 4
+ ],
+ "4": [
+ 3,
+ 5
+ ],
+ "5": [
+ 4,
+ 6
+ ],
+ "6": [
+ 5,
+ 7
+ ],
+ "7": [
+ 6,
+ 8
+ ],
+ "8": [
+ 7,
+ 9
+ ],
+ "9": [
+ 8,
+ 10
+ ],
+ "10": [
+ 9,
+ 11
+ ],
+ "11": [
+ 10,
+ 12
+ ],
+ "12": [
+ 11,
+ 13
+ ],
+ "13": [
+ 12,
+ 14
+ ],
+ "14": [
+ 13,
+ 15
+ ],
+ "15": [
+ 14,
+ 16
+ ],
+ "16": [
+ 15,
+ 17
+ ],
+ "17": [
+ 16,
+ 18
+ ],
+ "18": [
+ 17,
+ 19
+ ],
+ "19": [
+ 18
+ ]
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/DI-engine/dizoo/evogym/envs/world_data/speed_bot.json b/DI-engine/dizoo/evogym/envs/world_data/speed_bot.json
new file mode 100644
index 0000000000000000000000000000000000000000..9f7694774a245760cc5c96104ed01792f7a5d31b
--- /dev/null
+++ b/DI-engine/dizoo/evogym/envs/world_data/speed_bot.json
@@ -0,0 +1,120 @@
+{
+ "grid_width": 5,
+ "grid_height": 5,
+ "objects": {
+ "new_object_1": {
+ "indices": [
+ 15,
+ 16,
+ 17,
+ 18,
+ 19,
+ 10,
+ 11,
+ 12,
+ 13,
+ 14,
+ 5,
+ 6,
+ 8,
+ 9,
+ 0,
+ 4
+ ],
+ "types": [
+ 2,
+ 3,
+ 3,
+ 3,
+ 1,
+ 1,
+ 3,
+ 3,
+ 3,
+ 1,
+ 3,
+ 1,
+ 1,
+ 3,
+ 3,
+ 3
+ ],
+ "neighbors": {
+ "15": [
+ 10,
+ 16
+ ],
+ "16": [
+ 15,
+ 11,
+ 17
+ ],
+ "17": [
+ 16,
+ 12,
+ 18
+ ],
+ "18": [
+ 17,
+ 13,
+ 19
+ ],
+ "19": [
+ 18,
+ 14
+ ],
+ "10": [
+ 5,
+ 11,
+ 15
+ ],
+ "11": [
+ 10,
+ 6,
+ 12,
+ 16
+ ],
+ "12": [
+ 11,
+ 13,
+ 17
+ ],
+ "13": [
+ 12,
+ 14,
+ 8,
+ 18
+ ],
+ "14": [
+ 9,
+ 13,
+ 19
+ ],
+ "5": [
+ 10,
+ 0,
+ 6
+ ],
+ "6": [
+ 5,
+ 11
+ ],
+ "8": [
+ 9,
+ 13
+ ],
+ "9": [
+ 14,
+ 8,
+ 4
+ ],
+ "0": [
+ 5
+ ],
+ "4": [
+ 9
+ ]
+ }
+ }
+ }
+}
\ No newline at end of file
diff --git a/DI-engine/dizoo/gfootball/README.md b/DI-engine/dizoo/gfootball/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..5b66855f20eb78bc017218ecb0f4a5629f9cabc1
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/README.md
@@ -0,0 +1,129 @@
+# DI-engine Gfootball
+
+
+## Structure
+
+The file tree of the ``dizoo/gfootball`` directory is roughly as follows:
+
+```
+├── README.md
+├── __init__.py
+├── config
+│ ├── gfootball_counter_mappo_config.py
+│ ├── gfootball_counter_masac_config.py
+│ ├── gfootball_keeper_mappo_config.py
+│ └── gfootball_keeper_masac_config.py
+├── entry
+│ ├── __init__.py
+│ ├── gfootball_bc_config.py
+│ ├── gfootball_bc_kaggle5th_main.py
+│ ├── gfootball_bc_rule_lt0_main.py
+│ ├── gfootball_bc_rule_main.py
+│ ├── gfootball_dqn_config.py
+│ └── parallel
+│ ├── show_dataset.py
+│ ├── test_accuracy.py
+├── envs
+│ ├── __init__.py
+│ ├── action
+│ ├── fake_dataset.py
+│ ├── gfootball_academy_env.py
+│ ├── gfootball_env.py
+│ ├── gfootballsp_env.py
+│ ├── obs
+│ ├── reward
+│ └── tests
+├── gfootball.gif
+├── model
+│ ├── __init__.py
+│ ├── bots
+│ ├── conv1d
+│ └── q_network
+├── policy
+│ ├── __init__.py
+│ └── ppo_lstm.py
+└── replay.py
+```
+
+
+Where:
+
+- config: multi-agent algorithm configs for the ``gfootball_academy_env`` environment
+
+- entry: imitation learning and reinforcement learning configs for the ``gfootball_env`` environment, together with related utility functions
+
+- envs: the gfootball environments ``gfootball_academy_env``, ``gfootball_env``, ``gfootballsp_env``, plus the ``obs``, ``action`` and ``reward`` processing functions
+
+- model: the gfootball models:
+
+    - q_network: neural network models (and their default settings) used for imitation learning and reinforcement learning
+
+    - conv1d: neural network model used for ``ppo self play training``
+
+    - bots: existing rule-based or trained expert models for the gfootball environment
+
+
+
+## Environment
+
+The Gfootball environment is the Google Research Football environment; see https://github.com/google-research/football for its open-source code and installation instructions.
+
+DI-engine wraps the Google Research Football environment so that it conforms to the DI-engine environment interface and is easy to use. For concrete usage, refer to ``dizoo/gfootball/envs/tests/test_env_gfootball.py``; a minimal sketch is also given below.
+
+Currently, the DI-engine Gfootball environment supports playing against the built-in AI; an interface for agent-vs-agent play will be designed later.
+
+The DI-engine Gfootball environment also supports saving replays. After setting ``save_replay=True`` in the environment config, replays are saved automatically, including an .avi video file and a .dump file, under the ``./tmp/football`` folder of the current working directory. The .avi video uses the 2D representation by default.
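+
+The following minimal sketch is not taken from the repository: it assumes the wrapper is exposed as ``GfootballEnv`` from ``dizoo.gfootball.envs`` and follows the standard DI-engine ``BaseEnv`` interface (``reset``/``step`` returning a ``BaseEnvTimestep``); check ``dizoo/gfootball/envs/tests/test_env_gfootball.py`` for the authoritative usage.
+
+```python
+import numpy as np
+from easydict import EasyDict
+# Hypothetical import: check dizoo/gfootball/envs/__init__.py for the actual class name and path.
+from dizoo.gfootball.envs import GfootballEnv
+
+cfg = EasyDict(dict(
+    env_name="11_vs_11_easy_stochastic",
+    save_replay=True,  # dump .avi and .dump replay files under ./tmp/football
+))
+env = GfootballEnv(cfg)
+obs = env.reset()
+for _ in range(10):
+    # the default gfootball action set has 19 discrete actions
+    action = np.array(np.random.randint(19))
+    timestep = env.step(action)  # BaseEnvTimestep(obs, reward, done, info)
+    if timestep.done:
+        obs = env.reset()
+env.close()
+```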
+
+
+
+If you need the 3D representation (the real game rendering), find the .dump file of the corresponding episode and render the video with ``replay.py``, for example:
+
+```bash
+python replay.py --trace_file=./tmp/football/episode_done_20210331-132800614938.dump
+```
+
+
+
+## Model
+
+Model is divided into a bot part and a model part.
+
+### bots
+
+bots currently include:
+
+*Note: all bots come from the kaggle competition community of Google Research Football with Manchester City F.C.*
+
+- The rule-based `rule_based_bot_model`. This hard-coded bot comes from the kaggle competition community and has provided the community's RL bots with plenty of material for imitation learning. In DI-engine its code is adapted from https://www.kaggle.com/eugenkeil/simple-baseline-bot.
+
+- The kaggle 5th-place RL model ``kaggle_5th_place_model.py``, used in DI-engine to provide imitation learning material. Our code is adapted from https://github.com/YuriCat/TamakEriFever; see https://www.kaggle.com/c/google-football/discussion/203412 for ikki407 & yuricat's introduction to this excellent work.
+
+### q_network
+
+The ``q_network`` directory stores the models for imitation learning and reinforcement learning together with their default settings.
+
+### conv1d
+
+A model that uses ``conv1d`` to extract features of teammates on the same team, combined with an LSTM. After self-play training for 100k episodes, this model reaches a win rate above 80% against the built-in hard AI. The final trained model is available at: https://drive.google.com/file/d/1O1I3Mcjnh9mwAVDyqhp5coksTDPqMZmG/view?usp=sharing
+
+We also provide a recording of one game in which the football AI trained with this model plays against the game's built-in AI. The left team, controlled by our trained model, beats the built-in AI (right team) 4-0. The recording is available at:
+https://drive.google.com/file/d/1n-_bF63IQ49b-p0nEZt_NPTL-dmNkoKs/view?usp=sharing
+
+## Entry Files
+
+### Imitation Learning (Behaviour Cloning)
+
+Imitation learning entries are provided: ``FootballNaiveQ`` under the ``q_network`` directory serves as the Q network / policy network, and supervised learning is performed with the rule-based model ``rule_based_bot_model`` and the kaggle 5th-place RL model ``kaggle_5th_place_model.py`` as labels. See the related entry files under `dizoo/gfootball/entry` (a condensed Python sketch of the BC training phase follows the list below):
+
+- `gfootball_bc_rule_main.py`
+- `gfootball_bc_rule_lt0_main.py`
+- `gfootball_bc_kaggle5th_main.py`
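+
+The sketch below condenses the BC training phase of `gfootball_bc_rule_main.py` / `gfootball_bc_kaggle5th_main.py`; the demo-data collection phase is omitted, and the ``data_path_transitions`` value is a placeholder you should point at your collected transitions file.
+
+```python
+from copy import deepcopy
+from ding.entry import serial_pipeline_bc
+from dizoo.gfootball.entry.gfootball_bc_config import gfootball_bc_config, gfootball_bc_create_config
+from dizoo.gfootball.model.q_network.football_q_network import FootballNaiveQ
+
+# Placeholder path: the demo-transition file produced in phase 1 of the entry scripts.
+data_path_transitions = './gfootball_demo_transitions.pkl'
+
+bc_config = [deepcopy(gfootball_bc_config), deepcopy(gfootball_bc_create_config)]
+bc_config[0].policy.learn.train_epoch = 1000  # key hyper-parameter
+football_naive_q = FootballNaiveQ()
+serial_pipeline_bc(bc_config, seed=0, data_path=data_path_transitions, model=football_naive_q)
+```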
+
+### Reinforcement learning
+
+Currently the DQN algorithm is used; see the related entry file under `dizoo/gfootball/entry` (a launch sketch follows the list below):
+- `gfootball_dqn_config.py`
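+
+A minimal launch sketch, equivalent to the ``__main__`` block of `gfootball_dqn_config.py` (it assumes gfootball and DI-engine are installed):
+
+```python
+from ding.entry import serial_pipeline
+from dizoo.gfootball.entry.gfootball_dqn_config import main_config, create_config
+from dizoo.gfootball.model.q_network.football_q_network import FootballNaiveQ
+
+# FootballNaiveQ is the Q network used by the gfootball DQN entry.
+football_naive_q = FootballNaiveQ()
+serial_pipeline((main_config, create_config), model=football_naive_q, seed=0, max_env_step=int(5e6))
+```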
+
+### Self Play PPO (work in progress)
+
+Entry for training with self-play PPO, using the league module and the PPO algorithm provided by DI-engine. See the entry `dizoo/gfootball/entry/parallel/gfootball_ppo_parallel_config.py`.
diff --git a/DI-engine/dizoo/gfootball/__init__.py b/DI-engine/dizoo/gfootball/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/gfootball/config/gfootball_counter_mappo_config.py b/DI-engine/dizoo/gfootball/config/gfootball_counter_mappo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..59c74bf5bd8c4a2a0e58ef674989b30fe54c3aa9
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/config/gfootball_counter_mappo_config.py
@@ -0,0 +1,76 @@
+from easydict import EasyDict
+
+agent_num = 4
+obs_dim = 34
+collector_env_num = 8
+evaluator_env_num = 32
+
+main_config = dict(
+ exp_name='gfootball_counter_mappo_seed0',
+ env=dict(
+ env_name='academy_counterattack_hard',
+ agent_num=agent_num,
+ obs_dim=obs_dim,
+ n_evaluator_episode=32,
+ stop_value=1,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ cuda=True,
+ # share_weight=True,
+ multi_agent=True,
+ model=dict(
+ # (int) agent_num: The number of the agent.
+ agent_num=agent_num,
+ # (int) obs_shape: The shape of observation of each agent.
+ # (int) global_obs_shape: The shape of global observation.
+ agent_obs_shape=obs_dim,
+ global_obs_shape=int(obs_dim * 2),
+ # (int) action_shape: The number of action which each agent can take.
+ action_shape=19,
+ ),
+ learn=dict(
+ epoch_per_collect=10,
+ batch_size=3200,
+ learning_rate=5e-4,
+ # ==============================================================
+            # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) The loss weight of value network, policy network weight is set to 1
+ value_weight=0.5,
+ # (float) The loss weight of entropy regularization, policy network weight is set to 1
+ entropy_weight=0.01,
+ # (float) PPO clip ratio, defaults to 0.2
+ clip_ratio=0.05,
+ # (bool) Whether to use advantage norm in a whole training batch
+ adv_norm=False,
+ value_norm=True,
+ ppo_param_init=True,
+ grad_clip_type='clip_norm',
+ grad_clip_value=10,
+ ignore_done=False,
+ ),
+ collect=dict(env_num=collector_env_num, n_sample=3200),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=50, )),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='gfootball-academy',
+ import_names=['dizoo.gfootball.envs.gfootball_academy_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c gfootball_counter_mappo_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/gfootball/config/gfootball_counter_masac_config.py b/DI-engine/dizoo/gfootball/config/gfootball_counter_masac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..3abff5671337a475e5d1ac8b506387631ff720eb
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/config/gfootball_counter_masac_config.py
@@ -0,0 +1,87 @@
+from easydict import EasyDict
+
+agent_num = 4
+obs_dim = 34
+collector_env_num = 8
+evaluator_env_num = 32
+
+gfootball_counter_masac_default_config = dict(
+ exp_name='gfootball_counter_masac_seed0',
+ env=dict(
+ env_name='academy_counterattack_hard',
+ agent_num=agent_num,
+ obs_dim=obs_dim,
+ n_evaluator_episode=32,
+ stop_value=1,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ cuda=True,
+ # share_weight=True,
+ random_collect_size=int(1e4),
+ model=dict(
+ agent_num=agent_num,
+ agent_obs_shape=34,
+ global_obs_shape=68,
+ action_shape=19,
+ twin_critic=True,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=50,
+ batch_size=320,
+ learning_rate_q=5e-4,
+ learning_rate_policy=5e-4,
+ learning_rate_alpha=5e-5,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ auto_alpha=True,
+ log_space=True,
+ ),
+ collect=dict(
+ env_num=collector_env_num,
+ n_sample=1600,
+ unroll_len=1,
+ ),
+ eval=dict(
+ evaluator=dict(eval_freq=50, ),
+ env_num=evaluator_env_num,
+ ),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=int(5e4),
+ ),
+ replay_buffer=dict(replay_buffer_size=int(1e6), ),
+ ),
+ ),
+)
+
+gfootball_counter_masac_default_config = EasyDict(gfootball_counter_masac_default_config)
+main_config = gfootball_counter_masac_default_config
+
+gfootball_counter_masac_default_create_config = dict(
+ env=dict(
+ type='gfootball-academy',
+ import_names=['dizoo.gfootball.envs.gfootball_academy_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='sac_discrete', ),
+)
+gfootball_counter_masac_default_create_config = EasyDict(gfootball_counter_masac_default_create_config)
+create_config = gfootball_counter_masac_default_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c gfootball_counter_masac_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/gfootball/config/gfootball_keeper_mappo_config.py b/DI-engine/dizoo/gfootball/config/gfootball_keeper_mappo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5623d6c801a79e0dbf5405a34cbb42c614d3b2d2
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/config/gfootball_keeper_mappo_config.py
@@ -0,0 +1,73 @@
+from easydict import EasyDict
+
+agent_num = 3
+obs_dim = 26
+collector_env_num = 8
+evaluator_env_num = 32
+
+main_config = dict(
+ exp_name='gfootball_keeper_mappo_seed0',
+ env=dict(
+ env_name='academy_3_vs_1_with_keeper',
+ agent_num=agent_num,
+ obs_dim=obs_dim,
+ n_evaluator_episode=32,
+ stop_value=1,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ cuda=True,
+ multi_agent=True,
+ model=dict(
+            # (int) obs_shape: The shape of observation of each agent.
+ # (int) global_obs_shape: The shape of global observation.
+ agent_obs_shape=obs_dim,
+ global_obs_shape=int(obs_dim * 2),
+ # (int) action_shape: The number of action which each agent can take.
+ action_shape=19,
+ ),
+ learn=dict(
+ epoch_per_collect=10,
+ batch_size=3200,
+ learning_rate=5e-4,
+ # ==============================================================
+            # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) The loss weight of value network, policy network weight is set to 1
+ value_weight=0.5,
+ # (float) The loss weight of entropy regularization, policy network weight is set to 1
+ entropy_weight=0.01,
+ # (float) PPO clip ratio, defaults to 0.2
+ clip_ratio=0.05,
+ # (bool) Whether to use advantage norm in a whole training batch
+ adv_norm=False,
+ value_norm=True,
+ ppo_param_init=True,
+ grad_clip_type='clip_norm',
+ grad_clip_value=10,
+ ignore_done=False,
+ ),
+ collect=dict(env_num=collector_env_num, n_sample=3200),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=50, )),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='gfootball-academy',
+ import_names=['dizoo.gfootball.envs.gfootball_academy_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c gfootball_keeper_mappo_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/gfootball/config/gfootball_keeper_masac_config.py b/DI-engine/dizoo/gfootball/config/gfootball_keeper_masac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0cd01538667dd0844404fde2454b34d628be9da
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/config/gfootball_keeper_masac_config.py
@@ -0,0 +1,85 @@
+from easydict import EasyDict
+
+agent_num = 3
+obs_dim = 26
+collector_env_num = 8
+evaluator_env_num = 32
+
+gfootball_keeper_masac_default_config = dict(
+ exp_name='gfootball_keeper_masac_seed0',
+ env=dict(
+ env_name='academy_3_vs_1_with_keeper',
+ agent_num=agent_num,
+ obs_dim=obs_dim,
+ n_evaluator_episode=32,
+ stop_value=1,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=int(1e4),
+ model=dict(
+ agent_obs_shape=obs_dim,
+ global_obs_shape=int(obs_dim * 2),
+ action_shape=19,
+ twin_critic=True,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=50,
+ batch_size=320,
+ learning_rate_q=5e-4,
+ learning_rate_policy=5e-4,
+ learning_rate_alpha=5e-5,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ auto_alpha=True,
+ log_space=True,
+ ),
+ collect=dict(
+ env_num=collector_env_num,
+ n_sample=1600,
+ unroll_len=1,
+ ),
+ eval=dict(
+ evaluator=dict(eval_freq=50, ),
+ env_num=evaluator_env_num,
+ ),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=int(5e4),
+ ),
+ replay_buffer=dict(replay_buffer_size=int(1e6), ),
+ ),
+ ),
+)
+
+gfootball_keeper_masac_default_config = EasyDict(gfootball_keeper_masac_default_config)
+main_config = gfootball_keeper_masac_default_config
+
+gfootball_keeper_masac_default_create_config = dict(
+ env=dict(
+ type='gfootball-academy',
+ import_names=['dizoo.gfootball.envs.gfootball_academy_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='sac_discrete', ),
+)
+gfootball_keeper_masac_default_create_config = EasyDict(gfootball_keeper_masac_default_create_config)
+create_config = gfootball_keeper_masac_default_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c gfootball_keeper_masac_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/gfootball/entry/__init__.py b/DI-engine/dizoo/gfootball/entry/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/gfootball/entry/gfootball_bc_config.py b/DI-engine/dizoo/gfootball/entry/gfootball_bc_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c24d70786379c4ca1fa3486f3f8fa6b9104538a
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/entry/gfootball_bc_config.py
@@ -0,0 +1,72 @@
+"""
+Overview:
+ Here is the behaviour cloning (BC) default config for gfootball.
+ For main entry, please refer to the gfootball_bc_rule_main.py,
+ gfootball_bc_rule_lt0_main.py, gfootball_bc_kaggle5th_main.py in the same directory.
+"""
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 5
+
+gfootball_bc_config = dict(
+ exp_name='gfootball_bc_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+        stop_value=999,  # do not stop early; training length is controlled by the number of training epochs
+ env_name="11_vs_11_easy_stochastic",
+ # env_name="11_vs_11_stochastic", # default: medium
+ # env_name="11_vs_11_hard_stochastic",
+ save_replay_gif=False,
+ manager=dict(shared_memory=False, ),
+ ),
+ policy=dict(
+ env_name='gfootball',
+ continuous=False,
+ # action_shape is effective only when continuous=False
+ action_shape=19,
+ show_train_test_accuracy=False,
+        # Note: only if show_train_test_accuracy=True will accuracy be tested on the train and validation datasets,
+        # using the pre-trained BC model located at the path below.
+        # Users should set their own BC model path here; the path should point to a model checkpoint.
+        # An absolute path is recommended. In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ bc_model_path='bc_model_path_placeholder',
+ cuda=True,
+ model=dict(),
+ learn=dict(
+ update_per_collect=20,
+ batch_size=512,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ weight_decay=None,
+ ce_label_smooth=False,
+ show_accuracy=False,
+ ),
+ collect=dict(n_sample=4096, ),
+ eval=dict(evaluator=dict(eval_freq=1000)),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=int(1e6), ),
+ ),
+ ),
+)
+gfootball_bc_config = EasyDict(gfootball_bc_config)
+main_config = gfootball_bc_config
+
+gfootball_bc_create_config = dict(
+ env=dict(
+ type='gfootball',
+ import_names=['dizoo.gfootball.envs.gfootball_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='bc'),
+)
+gfootball_bc_create_config = EasyDict(gfootball_bc_create_config)
+create_config = gfootball_bc_create_config
diff --git a/DI-engine/dizoo/gfootball/entry/gfootball_bc_kaggle5th_main.py b/DI-engine/dizoo/gfootball/entry/gfootball_bc_kaggle5th_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..711302709fc98a8228b0ad1555a16cb07021b3ad
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/entry/gfootball_bc_kaggle5th_main.py
@@ -0,0 +1,62 @@
+"""
+Overview:
+ Here is the behaviour cloning (BC) main entry for gfootball.
+ We first collect demo data using ``FootballKaggle5thPlaceModel``, then train the BC model
+ using the collected demo data,
+ and (optional) test accuracy in train dataset and test dataset of the trained BC model
+"""
+from copy import deepcopy
+import os
+from ding.entry import serial_pipeline_bc, collect_demo_data
+from ding.config import read_config, compile_config
+from ding.policy import create_policy
+from dizoo.gfootball.entry.gfootball_bc_config import gfootball_bc_config, gfootball_bc_create_config
+from dizoo.gfootball.model.q_network.football_q_network import FootballNaiveQ
+from dizoo.gfootball.model.bots.kaggle_5th_place_model import FootballKaggle5thPlaceModel
+
+path = os.path.abspath(__file__)
+dir_path = os.path.dirname(path)
+
+# In the gfootball env, one episode is about 3000 transitions,
+# so 3e5 transitions correspond to 100 episodes; collecting them needs about 180 GB of memory.
+seed = 0
+gfootball_bc_config.exp_name = 'gfootball_bc_kaggle5th_seed0'
+demo_transitions = int(3e5) # key hyper-parameter
+data_path_transitions = dir_path + f'/gfootball_kaggle5th_{demo_transitions}-demo-transitions.pkl'
+"""
+phase 1: collect demo data utilizing ``FootballKaggle5thPlaceModel``
+"""
+train_config = [deepcopy(gfootball_bc_config), deepcopy(gfootball_bc_create_config)]
+input_cfg = train_config
+if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+else:
+ cfg, create_cfg = input_cfg
+create_cfg.policy.type = create_cfg.policy.type + '_command'
+cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
+
+football_kaggle_5th_place_model = FootballKaggle5thPlaceModel()
+expert_policy = create_policy(
+ cfg.policy, model=football_kaggle_5th_place_model, enable_field=['learn', 'collect', 'eval', 'command']
+)
+
+# collect expert demo data
+state_dict = expert_policy.collect_mode.state_dict()
+collect_config = [deepcopy(gfootball_bc_config), deepcopy(gfootball_bc_create_config)]
+collect_demo_data(
+ collect_config,
+ seed=seed,
+ expert_data_path=data_path_transitions,
+ collect_count=demo_transitions,
+ model=football_kaggle_5th_place_model,
+ state_dict=state_dict,
+)
+"""
+phase 2: BC training
+"""
+bc_config = [deepcopy(gfootball_bc_config), deepcopy(gfootball_bc_create_config)]
+bc_config[0].policy.learn.train_epoch = 1000 # key hyper-parameter
+football_naive_q = FootballNaiveQ()
+_, converge_stop_flag = serial_pipeline_bc(
+ bc_config, seed=seed, data_path=data_path_transitions, model=football_naive_q
+)
diff --git a/DI-engine/dizoo/gfootball/entry/gfootball_bc_rule_lt0_main.py b/DI-engine/dizoo/gfootball/entry/gfootball_bc_rule_lt0_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a385cb8bdf88570883b816acc0d2f42676235ae
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/entry/gfootball_bc_rule_lt0_main.py
@@ -0,0 +1,105 @@
+"""
+Overview:
+ Here is the behaviour cloning (BC) main entry for gfootball.
+ We first collect demo data using rule model, then train the BC model
+ using the demo data whose return is larger than 0,
+ and (optional) test accuracy in train dataset and test dataset of the trained BC model
+"""
+from copy import deepcopy
+import os
+import torch
+import logging
+import test_accuracy
+from ding.entry import serial_pipeline_bc, collect_episodic_demo_data, episode_to_transitions_filter, eval
+from ding.config import read_config, compile_config
+from ding.policy import create_policy
+from dizoo.gfootball.entry.gfootball_bc_config import gfootball_bc_config, gfootball_bc_create_config
+from dizoo.gfootball.model.q_network.football_q_network import FootballNaiveQ
+from dizoo.gfootball.model.bots.rule_based_bot_model import FootballRuleBaseModel
+
+path = os.path.abspath(__file__)
+dir_path = os.path.dirname(path)
+logging.basicConfig(level=logging.INFO)
+
+# Note: in the gfootball env, one episode is about 3000 transitions,
+# so 200 episodes correspond to 6e5 transitions; collecting them needs about 350 GB of memory.
+seed = 0
+gfootball_bc_config.exp_name = 'gfootball_bc_rule_200ep_lt0_seed0'
+demo_episodes = 200 # key hyper-parameter
+data_path_episode = dir_path + f'/gfootball_rule_{demo_episodes}eps.pkl'
+data_path_transitions_lt0 = dir_path + f'/gfootball_rule_{demo_episodes}eps_transitions_lt0.pkl'
+"""
+phase 1: collect demo data utilizing rule model
+"""
+input_cfg = [deepcopy(gfootball_bc_config), deepcopy(gfootball_bc_create_config)]
+if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+else:
+ cfg, create_cfg = input_cfg
+cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg)
+
+football_rule_base_model = FootballRuleBaseModel()
+expert_policy = create_policy(cfg.policy, model=football_rule_base_model, enable_field=['learn', 'collect', 'eval'])
+
+# collect rule/expert demo data
+state_dict = expert_policy.collect_mode.state_dict()
+collect_config = [deepcopy(gfootball_bc_config), deepcopy(gfootball_bc_create_config)]
+eval_config = deepcopy(collect_config)
+
+# eval demo model
+# if save replay
+# eval(eval_config, seed=seed, model=football_rule_base_model, replay_path=dir_path + f'/gfootball_rule_replay/')
+# if not save replay
+# eval(eval_config, seed=seed, model=football_rule_base_model, state_dict=state_dict)
+
+# collect demo data
+collect_episodic_demo_data(
+ collect_config,
+ seed=seed,
+ expert_data_path=data_path_episode,
+ collect_count=demo_episodes,
+ model=football_rule_base_model,
+ state_dict=state_dict
+)
+# Note: only episodes whose return is larger than 0 are used as demo data
+episode_to_transitions_filter(
+ data_path=data_path_episode, expert_data_path=data_path_transitions_lt0, nstep=1, min_episode_return=1
+)
+"""
+phase 2: BC training
+"""
+bc_config = [deepcopy(gfootball_bc_config), deepcopy(gfootball_bc_create_config)]
+bc_config[0].policy.learn.train_epoch = 1000 # key hyper-parameter
+football_naive_q = FootballNaiveQ()
+
+_, converge_stop_flag = serial_pipeline_bc(
+ bc_config, seed=seed, data_path=data_path_transitions_lt0, model=football_naive_q
+)
+
+if bc_config[0].policy.show_train_test_accuracy:
+ """
+ phase 3: test accuracy in train dataset and test dataset
+ """
+ bc_model_path = bc_config[0].policy.bc_model_path
+
+ # load trained model
+ bc_config[0].policy.learn.batch_size = int(3000) # the total dataset
+ state_dict = torch.load(bc_model_path)
+ football_naive_q.load_state_dict(state_dict['model'])
+ policy = create_policy(cfg.policy, model=football_naive_q, enable_field=['eval'])
+
+ # calculate accuracy in train dataset
+ print('==' * 10)
+ print('calculate accuracy in train dataset')
+ print('==' * 10)
+ # Users should add their own bc train_data_path here. Absolute path is recommended.
+ train_data_path = dir_path + f'/gfootball_rule_100eps_transitions_lt0_train.pkl'
+ test_accuracy.test_accuracy_in_dataset(train_data_path, cfg.policy.learn.batch_size, policy)
+
+ # calculate accuracy in test dataset
+ print('==' * 10)
+ print('calculate accuracy in test dataset')
+ print('==' * 10)
+ # Users should add their own bc test_data_path here. Absolute path is recommended.
+ test_data_path = dir_path + f'/gfootball_rule_50eps_transitions_lt0_test.pkl'
+ test_accuracy.test_accuracy_in_dataset(test_data_path, cfg.policy.learn.batch_size, policy)
diff --git a/DI-engine/dizoo/gfootball/entry/gfootball_bc_rule_main.py b/DI-engine/dizoo/gfootball/entry/gfootball_bc_rule_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f824788e62c6ef00c7e2aea4e8490e5f17b6fbc
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/entry/gfootball_bc_rule_main.py
@@ -0,0 +1,100 @@
+"""
+Overview:
+ Here is the behaviour cloning (BC) main entry for gfootball.
+ We first collect demo data using rule model, then train the bc model using the demo data,
+ and (optional) test accuracy in train dataset and test dataset of the trained bc model
+"""
+from copy import deepcopy
+import os
+import torch
+import logging
+import test_accuracy
+from ding.entry import serial_pipeline_bc, collect_demo_data
+from ding.config import read_config, compile_config
+from ding.policy import create_policy
+from dizoo.gfootball.entry.gfootball_bc_config import gfootball_bc_config, gfootball_bc_create_config
+from dizoo.gfootball.model.q_network.football_q_network import FootballNaiveQ
+from dizoo.gfootball.model.bots.rule_based_bot_model import FootballRuleBaseModel
+
+path = os.path.abspath(__file__)
+dir_path = os.path.dirname(path)
+logging.basicConfig(level=logging.INFO)
+
+# Note: in the gfootball env, one episode is about 3000 transitions,
+# so 3e5 transitions correspond to 100 episodes; collecting them needs about 180 GB of memory.
+seed = 0
+gfootball_bc_config.exp_name = 'gfootball_bc_rule_seed0_100eps_epc1000_bs512'
+demo_transitions = int(3e5) # key hyper-parameter
+
+data_path_transitions = dir_path + f'/gfootball_rule_{demo_transitions}-demo-transitions.pkl'
+"""
+phase 1: collect demo data utilizing rule model
+"""
+input_cfg = [deepcopy(gfootball_bc_config), deepcopy(gfootball_bc_create_config)]
+if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+else:
+ cfg, create_cfg = input_cfg
+cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg)
+
+football_rule_base_model = FootballRuleBaseModel()
+expert_policy = create_policy(cfg.policy, model=football_rule_base_model, enable_field=['learn', 'collect', 'eval'])
+
+# collect rule/expert demo data
+state_dict = expert_policy.collect_mode.state_dict()
+collect_config = [deepcopy(gfootball_bc_config), deepcopy(gfootball_bc_create_config)]
+
+# eval demo model
+# eval_config = deepcopy(collect_config)
+# # if save replay
+# eval(eval_config, seed=seed, model=football_rule_base_model, replay_path=dir_path + f'/gfootball_rule_replay/')
+# # if not save replay
+# eval(eval_config, seed=seed, model=football_rule_base_model, state_dict=state_dict)
+
+# collect demo data
+collect_demo_data(
+ collect_config,
+ seed=seed,
+ expert_data_path=data_path_transitions,
+ collect_count=demo_transitions,
+ model=football_rule_base_model,
+ state_dict=state_dict,
+)
+"""
+phase 2: BC training
+"""
+bc_config = [deepcopy(gfootball_bc_config), deepcopy(gfootball_bc_create_config)]
+bc_config[0].policy.learn.train_epoch = 1000 # key hyper-parameter
+football_naive_q = FootballNaiveQ()
+
+_, converge_stop_flag = serial_pipeline_bc(
+ bc_config, seed=seed, data_path=data_path_transitions, model=football_naive_q
+)
+
+if bc_config[0].policy.show_train_test_accuracy:
+ """
+ phase 3: test accuracy in train dataset and test dataset
+ """
+ bc_model_path = bc_config[0].policy.bc_model_path
+
+ # load trained bc model
+ bc_config[0].policy.learn.batch_size = int(3000)
+ state_dict = torch.load(bc_model_path)
+ football_naive_q.load_state_dict(state_dict['model'])
+ policy = create_policy(cfg.policy, model=football_naive_q, enable_field=['eval'])
+
+ # calculate accuracy in train dataset
+ print('==' * 10)
+ print('calculate accuracy in train dataset')
+ print('==' * 10)
+ # Users should add their own bc train_data_path here. Absolute path is recommended.
+ train_data_path = dir_path + f'/gfootball_rule_300000-demo-transitions_train.pkl'
+ test_accuracy.test_accuracy_in_dataset(train_data_path, cfg.policy.learn.batch_size, policy)
+
+ # calculate accuracy in test dataset
+ print('==' * 10)
+ print('calculate accuracy in test dataset')
+ print('==' * 10)
+ # Users should add their own bc test_data_path here. Absolute path is recommended.
+ test_data_path = dir_path + f'/gfootball_rule_150000-demo-transitions_test.pkl'
+ test_accuracy.test_accuracy_in_dataset(test_data_path, cfg.policy.learn.batch_size, policy)
diff --git a/DI-engine/dizoo/gfootball/entry/gfootball_dqn_config.py b/DI-engine/dizoo/gfootball/entry/gfootball_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a05edaa3a3b968dfd91ad26701e5145fe26bc6e
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/entry/gfootball_dqn_config.py
@@ -0,0 +1,61 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 5
+
+gfootball_dqn_main_config = dict(
+ exp_name='gfootball_dqn_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ stop_value=999,
+ env_name="11_vs_11_easy_stochastic",
+ # env_name="11_vs_11_stochastic", # default: medium
+ # env_name="11_vs_11_hard_stochastic",
+ save_replay_gif=False,
+ manager=dict(shared_memory=False, ),
+ ),
+ policy=dict(
+ cuda=True,
+ nstep=5,
+ discount_factor=0.997,
+ model=dict(),
+ learn=dict(
+ update_per_collect=20,
+ batch_size=512,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=256),
+ eval=dict(evaluator=dict(eval_freq=5000)),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1,
+ end=0.05,
+ decay=int(2e5),
+ ),
+ replay_buffer=dict(replay_buffer_size=int(5e5), ),
+ ),
+ ),
+)
+gfootball_dqn_main_config = EasyDict(gfootball_dqn_main_config)
+main_config = gfootball_dqn_main_config
+
+gfootball_dqn_create_config = dict(
+ env=dict(
+ type='gfootball',
+ import_names=['dizoo.gfootball.envs.gfootball_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqn'),
+)
+gfootball_dqn_create_config = EasyDict(gfootball_dqn_create_config)
+create_config = gfootball_dqn_create_config
+
+if __name__ == '__main__':
+ from ding.entry import serial_pipeline
+ from dizoo.gfootball.model.q_network.football_q_network import FootballNaiveQ
+ football_naive_q = FootballNaiveQ()
+ serial_pipeline((main_config, create_config), model=football_naive_q, seed=0, max_env_step=int(5e6))
diff --git a/DI-engine/dizoo/gfootball/entry/parallel/gfootball_il_parallel_config.py b/DI-engine/dizoo/gfootball/entry/parallel/gfootball_il_parallel_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..96f0ef21864510fda2c2cf673951064d2a2095ff
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/entry/parallel/gfootball_il_parallel_config.py
@@ -0,0 +1,123 @@
+from easydict import EasyDict
+from ding.config import parallel_transform
+
+__policy_default_config = dict(
+ use_cuda=False,
+ policy_type='IL',
+ model=dict(),
+ learn=dict(
+ train_iteration=20,
+ batch_size=64,
+ learning_rate=0.0002,
+ algo=dict(discount_factor=0.99, ),
+ ),
+ collect=dict(),
+ command=dict(),
+)
+
+__base_learner_default_config = dict(
+ load_path='',
+ use_cuda=False,
+ dataloader=dict(
+ batch_size=64,
+ chunk_size=64,
+ num_workers=0,
+ ),
+ hook=dict(
+ load_ckpt=dict(
+ name='load_ckpt',
+ type='load_ckpt',
+ priority=20,
+ position='before_run',
+ ),
+ log_show=dict(
+ name='log_show',
+ type='log_show',
+ priority=20,
+ position='after_iter',
+ ext_args=dict(freq=50),
+ ),
+ save_ckpt_after_run=dict(
+ name='save_ckpt_after_run',
+ type='save_ckpt',
+ priority=20,
+ position='after_run',
+ )
+ ),
+)
+
+__zergling_collector_default_config = dict(
+ collector_type='zergling',
+ import_names=['ding.worker.collector.zergling_parallel_collector'],
+ print_freq=10,
+ compressor='lz4',
+ policy_update_freq=3,
+ env_kwargs=dict(
+ import_names=['dizoo.gfootball.envs.gfootball_env'],
+ env_type='gfootball',
+ collector_env_num=2,
+ collector_episode_num=2,
+ evaluator_env_num=2,
+ evaluator_episode_num=2,
+ eval_stop_val=3,
+ manager=dict(shared_memory=False, ),
+ ),
+)
+
+__coordinator_default_config = dict(
+ collector_task_timeout=30,
+ learner_task_timeout=600,
+ interaction=dict(
+ host='auto',
+ port='auto',
+ ),
+ commander=dict(
+ parallel_commander_type='solo',
+ import_names=['ding.worker.coordinator.solo_parallel_commander'],
+ collector_task_space=2,
+ learner_task_space=1,
+ learner_cfg=__base_learner_default_config,
+ collector_cfg=__zergling_collector_default_config,
+ replay_buffer_cfg=dict(buffer_name=['agent'], agent=dict(
+ meta_maxlen=100000,
+ max_reuse=10,
+ )),
+ policy=__policy_default_config,
+ max_iterations=int(1e9),
+ eval_interval=500,
+ ),
+)
+__coordinator_default_config = EasyDict(__coordinator_default_config)
+
+main_config = dict(
+ coordinator=__coordinator_default_config,
+ learner0=dict(
+ import_names=['ding.worker.learner.comm.flask_fs_learner'],
+ comm_learner_type='flask_fs',
+ host='auto',
+ port='auto',
+ path_data='./data',
+ path_policy='.',
+ send_policy_freq=1,
+ use_distributed=False,
+ ),
+ collector0=dict(
+ import_names=['ding.worker.collector.comm.flask_fs_collector'],
+ comm_collector_type='flask_fs',
+ host='auto',
+ port='auto',
+ path_data='./data',
+ path_policy='.',
+ queue_maxsize=8,
+ ),
+ collector1=dict(
+ import_names=['ding.worker.collector.comm.flask_fs_collector'],
+ comm_collector_type='flask_fs',
+ host='auto',
+ port='auto',
+ path_data='./data',
+ path_policy='.',
+ queue_maxsize=8,
+ ),
+)
+main_config = parallel_transform(main_config)
diff --git a/DI-engine/dizoo/gfootball/entry/parallel/gfootball_ppo_parallel_config.py b/DI-engine/dizoo/gfootball/entry/parallel/gfootball_ppo_parallel_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a4d74cc648a5d945bf7002988de45faa2540585
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/entry/parallel/gfootball_ppo_parallel_config.py
@@ -0,0 +1,103 @@
+from easydict import EasyDict
+from ding.config import parallel_transform
+from copy import deepcopy
+
+gfootball_ppo_config = dict(
+ env=dict(
+ collector_env_num=1,
+ collector_episode_num=1,
+ evaluator_env_num=1,
+ evaluator_episode_num=1,
+ stop_value=5,
+ save_replay=False,
+ render=False,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(type='conv1d', import_names=['dizoo.gfootball.model.conv1d.conv1d']),
+ nstep=1,
+ discount_factor=0.995,
+ learn=dict(
+ batch_size=32,
+ learning_rate=0.001,
+ learner=dict(
+ learner_num=1,
+ send_policy_freq=1,
+ ),
+ ),
+ collect=dict(
+ n_sample=20,
+ env_num=1,
+ collector=dict(
+ collector_num=1,
+ update_policy_second=3,
+ ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=50), env_num=1),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=100000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=100000,
+ enable_track_used_data=True,
+ ),
+ commander=dict(
+ collector_task_space=2,
+ learner_task_space=1,
+ eval_interval=5,
+ league=dict(),
+ ),
+ ),
+ )
+)
+gfootball_ppo_config = EasyDict(gfootball_ppo_config)
+main_config = gfootball_ppo_config
+
+gfootball_ppo_create_config = dict(
+ env=dict(
+ import_names=['dizoo.gfootball.envs.gfootballsp_env'],
+ type='gfootball_sp',
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo_lstm_command', import_names=['dizoo.gfootball.policy.ppo_lstm']),
+ learner=dict(type='base', import_names=['ding.worker.learner.base_learner']),
+ collector=dict(
+ type='marine',
+ import_names=['ding.worker.collector.marine_parallel_collector'],
+ ),
+ commander=dict(
+ type='one_vs_one',
+ import_names=['ding.worker.coordinator.one_vs_one_parallel_commander'],
+ ),
+ comm_learner=dict(
+ type='flask_fs',
+ import_names=['ding.worker.learner.comm.flask_fs_learner'],
+ ),
+ comm_collector=dict(
+ type='flask_fs',
+ import_names=['ding.worker.collector.comm.flask_fs_collector'],
+ ),
+)
+gfootball_ppo_create_config = EasyDict(gfootball_ppo_create_config)
+create_config = gfootball_ppo_create_config
+
+gfootball_ppo_system_config = dict(
+ path_data='./data',
+ path_policy='./policy',
+ communication_mode='auto',
+ learner_multi_gpu=False,
+ learner_gpu_num=1,
+ coordinator=dict()
+)
+gfootball_ppo_system_config = EasyDict(gfootball_ppo_system_config)
+system_config = gfootball_ppo_system_config
+
+if __name__ == '__main__':
+ # or you can run: `ding -m parallel -c gfootball_ppo_parallel_config.py -s 0`
+ from ding.entry import parallel_pipeline
+ config = tuple([deepcopy(main_config), deepcopy(create_config), deepcopy(system_config)])
+ parallel_pipeline(config, seed=0)
diff --git a/DI-engine/dizoo/gfootball/entry/show_dataset.py b/DI-engine/dizoo/gfootball/entry/show_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c5d8d3d338c66e632fb1138bf46d62dc9698262
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/entry/show_dataset.py
@@ -0,0 +1,56 @@
+"""
+Overview:
+ The following is to show some statistics of the dataset in gfootball env.
+"""
+import torch
+import numpy as np
+import os
+from ding.config import read_config, compile_config
+from ding.utils.data import create_dataset
+from dizoo.gfootball.entry.gfootball_bc_config import main_config, create_config
+
+path = os.path.abspath(__file__)
+dir_path = os.path.dirname(path)
+
+if __name__ == "__main__":
+ config = [main_config, create_config]
+ input_cfg = config
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ else:
+ cfg, create_cfg = input_cfg
+ cfg = compile_config(cfg, seed=0, auto=True, create_cfg=create_cfg)
+ cfg.policy.collect.data_type = 'naive'
+ """episode data"""
+ # Users should add their own BC data path here.
+ cfg.policy.collect.data_path = dir_path + '/gfootball_rule_100eps.pkl'
+ dataset = create_dataset(cfg)
+
+ print('num_episodes', dataset.__len__())
+ print('episode 0, transition 0', dataset.__getitem__(0)[0])
+ episodes_len = np.array([len(dataset.__getitem__(i)) for i in range(dataset.__len__())])
+ print('episodes_len', episodes_len)
+ return_of_episode = torch.stack(
+ [
+ torch.stack(
+ [dataset.__getitem__(episode)[i]['reward'] for i in range(dataset.__getitem__(episode).__len__())],
+ axis=0
+ ).sum(0) for episode in range(dataset.__len__())
+ ],
+ axis=0
+ )
+ print('return_of_episode', return_of_episode)
+ print(return_of_episode.mean(), return_of_episode.max(), return_of_episode.min())
+ """transition data"""
+ # Users should add their own BC data path here.
+ cfg.policy.collect.data_path = dir_path + '/gfootball_rule_100eps_transitions_lt0.pkl'
+ dataset = create_dataset(cfg)
+
+ print('num_transitions', dataset.__len__())
+ print('transition 0: ', dataset.__getitem__(0))
+
+ reward_of_transitions = torch.stack(
+ [dataset.__getitem__(transition)['reward'] for transition in range(dataset.__len__())], axis=0
+ )
+ print('reward_of_transitions', reward_of_transitions)
+ print(reward_of_transitions.mean(), reward_of_transitions.max(), reward_of_transitions.min())
diff --git a/DI-engine/dizoo/gfootball/entry/test_accuracy.py b/DI-engine/dizoo/gfootball/entry/test_accuracy.py
new file mode 100644
index 0000000000000000000000000000000000000000..07af2a2145c95b523af206341ad613672f2b4105
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/entry/test_accuracy.py
@@ -0,0 +1,41 @@
+import torch
+import logging
+import math
+from ding.torch_utils import to_list
+from ding.utils.data import NaiveRLDataset
+from torch.utils.data import DataLoader
+
+logging.basicConfig(level=logging.INFO)
+
+
+def test_accuracy_in_dataset(data_path, batch_size, policy):
+ """
+ Overview:
+ Evaluate the overall accuracy and the per-action accuracy of ``policy`` on the
+ gfootball dataset loaded from ``data_path``.
+ """
+ dataset = NaiveRLDataset(data_path)
+ dataloader = DataLoader(dataset, batch_size)
+
+ total_accuracy_in_dataset = []
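+ # gfootball's raw action set has 19 discrete actions, so per-action accuracy is tracked for ids 0-18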
+ action_accuracy_in_dataset = {k: [] for k in range(19)}
+ for _, minibatch in enumerate(dataloader):
+ policy_output = policy._forward_eval(minibatch['obs'])
+ pred_action = policy_output['action']
+ total_accuracy = (pred_action == minibatch['action'].view(-1)).float().mean()
+ total_accuracy_in_dataset.append(total_accuracy)
+
+ for action_unique in to_list(torch.unique(minibatch['action'])):
+ # find the index where action is `action_unique` in `pred_action`
+ action_index = (pred_action == action_unique).nonzero(as_tuple=True)[0]
+ action_accuracy = (pred_action[action_index] == minibatch['action'].view(-1)[action_index]).float().mean()
+ if math.isnan(action_accuracy):
+ action_accuracy = 0.0
+ action_accuracy_in_dataset[action_unique].append(action_accuracy)
+ # logging.info(f'the accuracy of action {action_unique} in current train mini-batch is: {action_accuracy}')
+
+ logging.info(f'total accuracy in dataset is: {torch.tensor(total_accuracy_in_dataset).mean().item()}')
+ logging.info(
+ f'accuracy of each action in dataset is (nan means the action does not appear in the dataset): '
+ f'{ {k: torch.tensor(action_accuracy_in_dataset[k]).mean().item() for k in range(19)} }'
+ )
diff --git a/DI-engine/dizoo/gfootball/envs/__init__.py b/DI-engine/dizoo/gfootball/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d543ae1a3a61ce60356a635e36e0b685a833096a
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/envs/__init__.py
@@ -0,0 +1,7 @@
+import warnings
+
+try:
+ from .gfootball_env import GfootballEnv
+except ImportError:
+ warnings.warn("not found gfootball env, please install it")
+ GfootballEnv = None
diff --git a/DI-engine/dizoo/gfootball/envs/action/gfootball_action.py b/DI-engine/dizoo/gfootball/envs/action/gfootball_action.py
new file mode 100644
index 0000000000000000000000000000000000000000..34f633886e633283eb5275bebd3ddad1ff53833e
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/envs/action/gfootball_action.py
@@ -0,0 +1,93 @@
+from collections import namedtuple
+
+import numpy as np
+
+from ding.envs.common import EnvElement
+
+
+class GfootballSpAction(EnvElement):
+ _name = "gfootballSpAction"
+ _action_keys = ['action_type']
+ Action = namedtuple('Action', _action_keys)
+
+ def _init(self, cfg):
+ self.default_val = None
+ self.template = {
+ 'action_type': {
+ 'name': 'action_type',
+ 'shape': (17, ),
+ 'value': {
+ 'min': 0,
+ 'max': 16,
+ 'dtype': int,
+ 'dinfo': 'int value',
+ },
+ 'env_value': 'type of action in the gfootball action set',
+ 'to_agent_processor': lambda x: x,
+ 'from_agent_processor': lambda x: x,
+ 'necessary': True,
+ }
+ }
+ self._shape = (17, )
+ self._value = {
+ 'min': 0,
+ 'max': 16,
+ 'dtype': int,
+ 'dinfo': 'int value, action_meanings: []',
+ }
+
+ def _to_agent_processor(self, action):
+ return action
+
+ def _from_agent_processor(self, action):
+ return action
+
+ # override
+ def _details(self):
+ return '\t'.join(self._action_keys)
+
+
+class GfootballRawAction(EnvElement):
+ '''
+ For the raw (19-dim) action set, refer to the gfootball documentation.
+ '''
+ _name = "gfootballRawAction"
+ _action_keys = ['action_type']
+ Action = namedtuple('Action', _action_keys)
+
+ def _init(self, cfg):
+ self._default_val = None
+ self.template = {
+ 'action_type': {
+ 'name': 'action_type',
+ 'shape': (19, ),
+ 'value': {
+ 'min': 0,
+ 'max': 18,
+ 'dtype': int,
+ 'dinfo': 'int value',
+ },
+ 'env_value': 'type of action in the gfootball action set',
+ 'to_agent_processor': lambda x: x,
+ 'from_agent_processor': lambda x: x,
+ 'necessary': True,
+ }
+ }
+ self._shape = (19, )
+ self._value = {
+ 'min': 0,
+ 'max': 18,
+ 'dtype': int,
+ 'dinfo': 'int value, action_meanings: []',
+ }
+
+ def _to_agent_processor(self, action):
+ return action
+
+ def _from_agent_processor(self, action):
+ return action
+
+ # override
+ def _details(self):
+ return '\t'.join(self._action_keys)
diff --git a/DI-engine/dizoo/gfootball/envs/action/gfootball_action_runner.py b/DI-engine/dizoo/gfootball/envs/action/gfootball_action_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..436661958c1336c746ae0278dd0f5775bcab3f72
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/envs/action/gfootball_action_runner.py
@@ -0,0 +1,21 @@
+import copy
+
+import numpy as np
+
+from ding.envs.common import EnvElementRunner
+from ding.envs.env.base_env import BaseEnv
+from .gfootball_action import GfootballRawAction
+
+
+class GfootballRawActionRunner(EnvElementRunner):
+
+ def _init(self, cfg, *args, **kwargs) -> None:
+ # set self._core and other state variable
+ self._core = GfootballRawAction(cfg)
+
+ def get(self, engine: BaseEnv) -> np.array:
+ agent_action = copy.deepcopy(engine.agent_action)
+ return agent_action
+
+ def reset(self) -> None:
+ pass
diff --git a/DI-engine/dizoo/gfootball/envs/fake_dataset.py b/DI-engine/dizoo/gfootball/envs/fake_dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..3189a658e0f00d60a05783b23396bf69b36cef5b
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/envs/fake_dataset.py
@@ -0,0 +1,83 @@
+import random
+import numpy as np
+
+from dizoo.gfootball.envs.obs.gfootball_obs import PlayerObs, MatchObs
+from ding.utils.data import default_collate
+
+
+def generate_data(player_obs: dict) -> np.ndarray:
+ dim = player_obs['dim']
+ min = player_obs['value']['min']
+ max = player_obs['value']['max']
+ dinfo = player_obs['value']['dinfo']
+ if dinfo in ['one-hot', 'boolean vector']:
+ data = np.zeros((dim, ), dtype=np.float32)
+ data[random.randint(0, dim - 1)] = 1
+ return data
+ elif dinfo == 'float':
+ data = np.random.rand(dim)
+ for dim_idx in range(dim):
+ data[dim_idx] = min[dim_idx] + (max[dim_idx] - min[dim_idx]) * data[dim_idx]
+ return data
+
+
+class FakeGfootballDataset:
+
+ def __init__(self):
+ match_obs = MatchObs({})
+ player_obs = PlayerObs({})
+ self.match_obs_info = match_obs.template
+ self.player_obs_info = player_obs.template
+ self.action_dim = 19
+ self.batch_size = 4
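+ # action_dim matches gfootball's 19-action raw set; batch_size doubles as the length of this fake dataset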
+ del match_obs, player_obs
+
+ def __len__(self) -> int:
+ return self.batch_size
+
+ def get_random_action(self) -> np.array:
+ return np.random.randint(0, self.action_dim - 1, size=(1, ))
+
+ def get_random_obs(self) -> dict:
+ inputs = {}
+ for match_obs in self.match_obs_info:
+ key = match_obs['ret_key']
+ data = generate_data(match_obs)
+ inputs[key] = data
+ players_list = []
+ for _ in range(22):
+ one_player = {}
+ for player_obs in self.player_obs_info:
+ key = player_obs['ret_key']
+ data = generate_data(player_obs)
+ one_player[key] = data
+ players_list.append(one_player)
+ inputs['players'] = players_list
+ return inputs
+
+ def get_batched_obs(self, bs: int) -> dict:
+ batch = []
+ for _ in range(bs):
+ batch.append(self.get_random_obs())
+ return default_collate(batch)
+
+ def get_random_reward(self) -> np.array:
+ return np.array([random.random() - 0.5])
+
+ def get_random_terminals(self) -> int:
+ sample = random.random()
+ if sample > 0.99:
+ return 1
+ return 0
+
+ def get_batch_sample(self, bs: int) -> list:
+ batch = []
+ for _ in range(bs):
+ step = {}
+ step['obs'] = self.get_random_obs()
+ step['next_obs'] = self.get_random_obs()
+ step['action'] = self.get_random_action()
+ step['done'] = self.get_random_terminals()
+ step['reward'] = self.get_random_reward()
+ batch.append(step)
+ return batch
diff --git a/DI-engine/dizoo/gfootball/envs/gfootball_academy_env.py b/DI-engine/dizoo/gfootball/envs/gfootball_academy_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..183ac41d25711e00cee684013b64cd256c59e544
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/envs/gfootball_academy_env.py
@@ -0,0 +1,367 @@
+"""
+The code below is adapted from https://github.com/lich14/CDS/tree/main/CDS_GRF/envs/grf,
+which is from the codebase of the CDS paper "Celebrating Diversity in Shared Multi-Agent Reinforcement Learning"
+"""
+
+import gfootball.env as football_env
+from gfootball.env import observation_preprocessing
+import gym
+import numpy as np
+from ding.utils import ENV_REGISTRY
+from typing import Any, List, Union, Optional
+import copy
+import torch
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.torch_utils import to_ndarray, to_list
+import os
+from matplotlib import animation
+import matplotlib.pyplot as plt
+
+
+@ENV_REGISTRY.register('gfootball-academy')
+class GfootballAcademyEnv(BaseEnv):
+
+ def __init__(
+ self,
+ cfg: dict,
+ dense_reward=False,
+ write_full_episode_dumps=False,
+ write_goal_dumps=False,
+ dump_freq=1000,
+ render=False,
+ time_limit=150,
+ time_step=0,
+ stacked=False,
+ representation="simple115",
+ rewards='scoring',
+ logdir='football_dumps',
+ write_video=True,
+ number_of_right_players_agent_controls=0,
+ ):
+ """
+ 'academy_3_vs_1_with_keeper'
+ n_agents=3,
+ obs_dim=26,
+ 'academy_counterattack_hard'
+ n_agents=4,
+ obs_dim=34,
+ """
+ self._cfg = cfg
+ self._save_replay = False
+ self._save_replay_count = 0
+ self._replay_path = None
+ self.dense_reward = dense_reward
+ self.write_full_episode_dumps = write_full_episode_dumps
+ self.write_goal_dumps = write_goal_dumps
+ self.dump_freq = dump_freq
+ self.render = render
+ self.env_name = self._cfg.env_name # TODO
+ self.n_agents = self._cfg.agent_num
+ self.obs_dim = self._cfg.obs_dim
+
+ self.episode_limit = time_limit
+ self.time_step = time_step
+ self.stacked = stacked
+ self.representation = representation
+ self.rewards = rewards
+ self.logdir = logdir
+ self.write_video = write_video
+ self.number_of_right_players_agent_controls = number_of_right_players_agent_controls
+
+ self._env = football_env.create_environment(
+ write_full_episode_dumps=self.write_full_episode_dumps,
+ write_goal_dumps=self.write_goal_dumps,
+ env_name=self.env_name,
+ stacked=self.stacked,
+ representation=self.representation,
+ rewards=self.rewards,
+ logdir=self.logdir,
+ render=self.render,
+ write_video=self.write_video,
+ dump_frequency=self.dump_freq,
+ number_of_left_players_agent_controls=self.n_agents,
+ number_of_right_players_agent_controls=self.number_of_right_players_agent_controls,
+ channel_dimensions=(observation_preprocessing.SMM_WIDTH, observation_preprocessing.SMM_HEIGHT)
+ )
+
+ obs_space_low = self._env.observation_space.low[0][:self.obs_dim]
+ obs_space_high = self._env.observation_space.high[0][:self.obs_dim]
+
+ self._action_space = gym.spaces.Dict(
+ {agent_i: gym.spaces.Discrete(self._env.action_space.nvec[1])
+ for agent_i in range(self.n_agents)}
+ )
+ self._observation_space = gym.spaces.Dict(
+ {
+ agent_i:
+ gym.spaces.Box(low=obs_space_low, high=obs_space_high, dtype=self._env.observation_space.dtype)
+ for agent_i in range(self.n_agents)
+ }
+ )
+ self._reward_space = gym.spaces.Box(low=0, high=100, shape=(1, ), dtype=np.float32) # TODO(pu)
+
+ self.n_actions = self.action_space[0].n
+
+ def get_simple_obs(self, index=-1):
+ full_obs = self._env.unwrapped.observation()[0]
+ simple_obs = []
+
+ if self.env_name == 'academy_3_vs_1_with_keeper':
+ if index == -1:
+ # global state, absolute position
+ simple_obs.append(full_obs['left_team'][-self.n_agents:].reshape(-1))
+ simple_obs.append(full_obs['left_team_direction'][-self.n_agents:].reshape(-1))
+
+ simple_obs.append(full_obs['right_team'].reshape(-1))
+ simple_obs.append(full_obs['right_team_direction'].reshape(-1))
+
+ simple_obs.append(full_obs['ball'])
+ simple_obs.append(full_obs['ball_direction'])
+ else:
+ # local state, relative position
+ ego_position = full_obs['left_team'][-self.n_agents + index].reshape(-1)
+ simple_obs.append(ego_position)
+ simple_obs.append(
+ (np.delete(full_obs['left_team'][-self.n_agents:], index, axis=0) - ego_position).reshape(-1)
+ )
+
+ simple_obs.append(full_obs['left_team_direction'][-self.n_agents + index].reshape(-1))
+ simple_obs.append(
+ np.delete(full_obs['left_team_direction'][-self.n_agents:], index, axis=0).reshape(-1)
+ )
+
+ simple_obs.append((full_obs['right_team'] - ego_position).reshape(-1))
+ simple_obs.append(full_obs['right_team_direction'].reshape(-1))
+
+ simple_obs.append(full_obs['ball'][:2] - ego_position)
+ simple_obs.append(full_obs['ball'][-1].reshape(-1))
+ simple_obs.append(full_obs['ball_direction'])
+
+ elif self.env_name == 'academy_counterattack_hard':
+ if index == -1:
+ # global state, absolute position
+ simple_obs.append(full_obs['left_team'][-self.n_agents:].reshape(-1))
+ simple_obs.append(full_obs['left_team_direction'][-self.n_agents:].reshape(-1))
+
+ simple_obs.append(full_obs['right_team'][0])
+ simple_obs.append(full_obs['right_team'][1])
+ simple_obs.append(full_obs['right_team'][2])
+ simple_obs.append(full_obs['right_team_direction'][0])
+ simple_obs.append(full_obs['right_team_direction'][1])
+ simple_obs.append(full_obs['right_team_direction'][2])
+
+ simple_obs.append(full_obs['ball'])
+ simple_obs.append(full_obs['ball_direction'])
+
+ else:
+ # local state, relative position
+ ego_position = full_obs['left_team'][-self.n_agents + index].reshape(-1)
+ simple_obs.append(ego_position)
+ simple_obs.append(
+ (np.delete(full_obs['left_team'][-self.n_agents:], index, axis=0) - ego_position).reshape(-1)
+ )
+
+ simple_obs.append(full_obs['left_team_direction'][-self.n_agents + index].reshape(-1))
+ simple_obs.append(
+ np.delete(full_obs['left_team_direction'][-self.n_agents:], index, axis=0).reshape(-1)
+ )
+
+ simple_obs.append(full_obs['right_team'][0] - ego_position)
+ simple_obs.append(full_obs['right_team'][1] - ego_position)
+ simple_obs.append(full_obs['right_team'][2] - ego_position)
+ simple_obs.append(full_obs['right_team_direction'][0])
+ simple_obs.append(full_obs['right_team_direction'][1])
+ simple_obs.append(full_obs['right_team_direction'][2])
+
+ simple_obs.append(full_obs['ball'][:2] - ego_position)
+ simple_obs.append(full_obs['ball'][-1].reshape(-1))
+ simple_obs.append(full_obs['ball_direction'])
+
+ simple_obs = np.concatenate(simple_obs)
+ return simple_obs
+
+ def get_global_state(self):
+ return self.get_simple_obs(-1)
+
+ def get_global_special_state(self):
+ return [np.concatenate([self.get_global_state(), self.get_obs_agent(i)]) for i in range(self.n_agents)]
+
+ def check_if_done(self):
+ cur_obs = self._env.unwrapped.observation()[0]
+ ball_loc = cur_obs['ball']
+ ours_loc = cur_obs['left_team'][-self.n_agents:]
+
+ if ball_loc[0] < 0 or any(ours_loc[:, 0] < 0):
+ """
+ This is based on the CDS paper:
+ 'We make a small and reasonable change to the half-court offensive scenarios: our players will lose if
+ they or the ball returns to our half-court.'
+ """
+ return True
+
+ return False
+
+ def reset(self):
+ """Returns initial observations and states."""
+ if self._save_replay:
+ self._frames = []
+ self.time_step = 0
+ self._env.reset()
+ obs = {
+ 'agent_state': np.stack(self.get_obs(), axis=0).astype(np.float32),
+ # Note: here 'global_state' is the agent_specific_global_state,
+ # we simply concatenate the global_state and agent_state
+ 'global_state': np.stack(
+ self.get_global_special_state(),
+ axis=0,
+ ).astype(np.float32),
+ 'action_mask': np.stack(self.get_avail_actions(), axis=0).astype(np.float32),
+ }
+
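+ # when dynamic_seed is enabled, re-seed the underlying env with a random offset so each reset starts a different episode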
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+ self._eval_episode_return = 0
+
+ return obs
+
+ def step(self, actions):
+ """Returns reward, terminated, info."""
+ assert isinstance(actions, np.ndarray) or isinstance(actions, list), type(actions)
+ self.time_step += 1
+ if isinstance(actions, np.ndarray):
+ actions = actions.tolist()
+
+ if self._save_replay:
+ self._frames.append(self._env.render(mode='rgb_array'))
+ _, original_rewards, done, infos = self._env.step(actions)
+ obs = {
+ 'agent_state': np.stack(self.get_obs(), axis=0).astype(np.float32),
+ # Note: here 'global_state' is the agent_specific_global_state,
+ # we simply concatenate the global_state and agent_state
+ 'global_state': np.stack(
+ self.get_global_special_state(),
+ axis=0,
+ ).astype(np.float32),
+ 'action_mask': np.stack(self.get_avail_actions(), axis=0).astype(np.float32),
+ }
+ rewards = list(original_rewards)
+
+ if self.time_step >= self.episode_limit:
+ done = True
+
+ if self.check_if_done():
+ done = True
+
+ if done:
+ if self._save_replay:
+ path = os.path.join(
+ self._replay_path, '{}_episode_{}.gif'.format(self.env_name, self._save_replay_count)
+ )
+ self.display_frames_as_gif(self._frames, path)
+ self._save_replay_count += 1
+
+ if sum(rewards) <= 0:
+ """
+ This is based on the CDS paper:
+ "Environmental reward only occurs at the end of the game.
+ They will get +100 if they win, else get -1."
+ Concretely, in this branch (sum(rewards) <= 0) the returned reward is 0 while done=False and -1 once done=True;
+ in the else branch (sum(rewards) > 0, i.e. our team scored) the returned reward is 100.
+ """
+ infos['eval_episode_return'] = infos['score_reward'] # TODO(pu)
+ return BaseEnvTimestep(obs, np.array(-int(done)).astype(np.float32), done, infos)
+ else:
+ infos['eval_episode_return'] = infos['score_reward']
+ return BaseEnvTimestep(obs, np.array(100).astype(np.float32), done, infos)
+
+ def get_obs(self):
+ """Returns all agent observations in a list."""
+ obs = [self.get_simple_obs(i) for i in range(self.n_agents)]
+ return obs
+
+ def get_obs_agent(self, agent_id):
+ """Returns observation for agent_id."""
+ return self.get_simple_obs(agent_id)
+
+ def get_obs_size(self):
+ """Returns the size of the observation."""
+ return self.obs_dim
+
+ def get_state(self):
+ """Returns the global state."""
+ return self.get_global_state()
+
+ def get_state_size(self):
+ """Returns the size of the global state."""
+ return self.obs_dim
+
+ def get_avail_actions(self):
+ """Returns the available actions of all agents in a list."""
+ return [[1 for _ in range(self.n_actions)] for agent_id in range(self.n_agents)]
+
+ def get_avail_agent_actions(self, agent_id):
+ """Returns the available actions for agent_id."""
+ return self.get_avail_actions()[agent_id]
+
+ def render(self):
+ pass
+
+ def close(self):
+ self._env.close()
+
+ def save_replay(self):
+ """Save a replay."""
+ pass
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def random_action(self) -> np.ndarray:
+ random_action = self.action_space.sample()
+ random_action = to_ndarray([random_action], dtype=np.int64)
+ return random_action
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ def __repr__(self) -> str:
+ return f'GfootballEnv Academy Env {self.env_name}'
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ """
+ Overview:
+ Save replay file in the given path
+ Arguments:
+ - replay_path(:obj:`str`): Storage path.
+ """
+ if replay_path is None:
+ replay_path = './video'
+ self._save_replay = True
+ self._replay_path = replay_path
+ self._save_replay_count = 0
+
+ @staticmethod
+ def display_frames_as_gif(frames: list, path: str) -> None:
+ patch = plt.imshow(frames[0])
+ plt.axis('off')
+
+ def animate(i):
+ patch.set_data(frames[i])
+
+ anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=5)
+ anim.save(path, writer='imagemagick', fps=20)
diff --git a/DI-engine/dizoo/gfootball/envs/gfootball_env.py b/DI-engine/dizoo/gfootball/envs/gfootball_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..49209d21eb48091bb06ee2cdd1b694dd149f03e1
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/envs/gfootball_env.py
@@ -0,0 +1,229 @@
+import gfootball
+import gfootball.env as football_env
+
+import copy
+from collections import namedtuple
+from typing import List, Any, Optional
+
+import numpy as np
+from ding.envs import BaseEnv
+from ding.utils import ENV_REGISTRY
+from .action.gfootball_action_runner import GfootballRawActionRunner
+from .obs.gfootball_obs_runner import GfootballObsRunner
+from .reward.gfootball_reward_runner import GfootballRewardRunner
+import gym
+from ding.torch_utils import to_ndarray, to_list
+import os
+from matplotlib import animation
+import matplotlib.pyplot as plt
+from ding.envs import ObsPlusPrevActRewWrapper
+
+
+@ENV_REGISTRY.register('gfootball')
+class GfootballEnv(BaseEnv):
+ timestep = namedtuple('GfootballTimestep', ['obs', 'reward', 'done', 'info'])
+
+ info_template = namedtuple('GFootballEnvInfo', ['obs_space', 'act_space', 'rew_space'])
+
+ def __init__(self, cfg):
+ self._cfg = cfg
+ self._action_helper = GfootballRawActionRunner(cfg)
+ self._reward_helper = GfootballRewardRunner(cfg)
+ self._obs_helper = GfootballObsRunner(cfg)
+ self.save_replay = cfg.get("save_replay", False)
+ self._launch_env_flag = False
+ self._launch_env()
+ self.env_name = self._cfg.env_name
+ self._save_replay_gif = self._cfg.save_replay_gif
+
+ def _launch_env(self, gui=False):
+
+ self._env = football_env.create_environment(
+ # default env_name="11_vs_11_stochastic",
+ env_name=self._cfg.env_name,
+ representation='raw',
+ stacked=False,
+ logdir='./tmp/football',
+ write_goal_dumps=False,
+ write_full_episode_dumps=self.save_replay,
+ write_video=self.save_replay,
+ render=False
+ )
+ self._launch_env_flag = True
+
+ def reset(self) -> dict:
+ if hasattr(self._cfg, 'obs_plus_prev_action_reward') and self._cfg.obs_plus_prev_action_reward:
+ # for NGU
+ self.prev_action = -1 # null action
+ self.prev_reward_extrinsic = 0 # null reward
+
+ if self._save_replay_gif:
+ self._frames = []
+ if not self._launch_env_flag:
+ self._launch_env()
+ self._football_obs = self._env.reset()[0]
+ self._reward_helper.reset()
+ self._obs_helper.reset()
+ self._action_helper.reset()
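+ # build the gym observation/action/reward spaces from the helpers' meta-info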
+ self._observation_space = gym.spaces.Dict(
+ {
+ 'match': gym.spaces.Dict(
+ {
+ k: gym.spaces.Discrete(v['max']) if v['dinfo'] == 'one-hot' else
+ gym.spaces.Box(low=np.array(v['min']), high=np.array(v['max']), dtype=np.float32)
+ for k, v in self._obs_helper.info['match'].value.items()
+ }
+ ),
+ 'player': gym.spaces.Dict(
+ {
+ k: gym.spaces.Discrete(v['max']) if v['dinfo'] == 'one-hot' else
+ gym.spaces.Box(low=np.array(v['min']), high=np.array(v['max']), dtype=np.float32)
+ for k, v in self._obs_helper.info['player'].value['players'].items()
+ }
+ )
+ }
+ )
+ self._action_space = gym.spaces.Discrete(self._action_helper.info.shape[0])
+ self._reward_space = gym.spaces.Box(
+ low=self._reward_helper.info.value['min'],
+ high=self._reward_helper.info.value['max'],
+ shape=self._reward_helper.info.shape,
+ dtype=np.float32
+ )
+
+ self.obs = self._obs_helper.get(self)
+
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+ if hasattr(self._cfg, 'obs_plus_prev_action_reward') and self._cfg.obs_plus_prev_action_reward:
+ # for NGU
+ return {
+ 'obs': {
+ 'processed_obs': self.obs,
+ 'raw_obs': self._football_obs
+ },
+ 'prev_action': self.prev_action,
+ 'prev_reward_extrinsic': self.prev_reward_extrinsic
+ }
+ else:
+ return {'processed_obs': self.obs, 'raw_obs': self._football_obs}
+
+ def step(self, action: np.array) -> 'GfootballEnv.timestep':
+ assert self._launch_env_flag
+ self.agent_action = action
+ action = action.item()
+ # env step
+ if self._save_replay_gif:
+ self._frames.append(self._env.render(mode='rgb_array'))
+ self._football_obs, self._reward_of_action, self._is_done, self._info = self._env.step(action)
+ self._football_obs = self._football_obs[0]
+ self.action = self._action_helper.get(self)
+ self.reward = self._reward_helper.get(self)
+ self.obs = self._obs_helper.get(self)
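+ # the helper runners read the cached raw obs / agent action from this env instance and convert them into processed features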
+
+ info = {'cum_reward': self._reward_helper.cum_reward}
+ if self._is_done:
+ info['eval_episode_return'] = to_ndarray(self._reward_helper.cum_reward)
+ if self._save_replay_gif:
+ path = os.path.join(
+ self._replay_path, '{}_episode_{}.gif'.format(self.env_name, self._save_replay_gif_count)
+ )
+ self.display_frames_as_gif(self._frames, path)
+ self._save_replay_gif_count += 1
+ print(f'saved one episode replay gif to {path}')
+ # TODO(pu)
+ self.reward = to_ndarray(self.reward)
+
+ if hasattr(self._cfg, 'obs_plus_prev_action_reward') and self._cfg.obs_plus_prev_action_reward:
+ # for NGU
+ self.prev_action = action
+ self.prev_reward_extrinsic = self.reward
+ obs = {
+ 'obs': {
+ 'processed_obs': self.obs,
+ 'raw_obs': self._football_obs
+ },
+ 'prev_action': self.prev_action,
+ 'prev_reward_extrinsic': self.prev_reward_extrinsic
+ }
+ else:
+ obs = {'processed_obs': self.obs, 'raw_obs': self._football_obs}
+
+ return GfootballEnv.timestep(obs, reward=self.reward, done=self._is_done, info=info)
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def close(self) -> None:
+ self._env.close()
+
+ def __repr__(self) -> str:
+ return 'GfootballEnv:\n\
+ \tobservation[{}]\n\
+ \taction[{}]\n\
+ \treward[{}]\n'.format(repr(self._obs_helper), repr(self._action_helper), repr(self._reward_helper))
+
+ def info(self) -> 'GfootballEnv.info':
+ info_data = {
+ 'obs_space': self._obs_helper.info,
+ 'act_space': self._action_helper.info,
+ 'rew_space': self._reward_helper.info,
+ }
+ return GfootballEnv.info_template(**info_data)
+
+ @staticmethod
+ def create_collector_env_cfg(cfg: dict) -> List[dict]:
+ collector_env_num = cfg.pop('collector_env_num', 1)
+ cfg = copy.deepcopy(cfg)
+ cfg.save_replay = False
+ return [cfg for _ in range(collector_env_num)]
+
+ @staticmethod
+ def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
+ evaluator_env_num = cfg.pop('evaluator_env_num', 1)
+ cfg = copy.deepcopy(cfg)
+ cfg.save_replay = True
+ return [cfg for _ in range(evaluator_env_num)]
+
+ def random_action(self) -> np.ndarray:
+ random_action = self.action_space.sample()
+ random_action = to_ndarray([random_action], dtype=np.int64)
+ return random_action
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ if replay_path is None:
+ replay_path = './video'
+ self._save_replay_gif = True
+ self._replay_path = replay_path
+ self._save_replay_gif_count = 0
+
+ @staticmethod
+ def display_frames_as_gif(frames: list, path: str) -> None:
+ patch = plt.imshow(frames[0])
+ plt.axis('off')
+
+ def animate(i):
+ patch.set_data(frames[i])
+
+ anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=5)
+ anim.save(path, writer='imagemagick', fps=20)
+
+
+GfootballTimestep = GfootballEnv.timestep
diff --git a/DI-engine/dizoo/gfootball/envs/gfootballsp_env.py b/DI-engine/dizoo/gfootball/envs/gfootballsp_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c3131b9db791c315720a357113da94e3103435c
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/envs/gfootballsp_env.py
@@ -0,0 +1,185 @@
+import copy
+from collections import namedtuple
+from typing import Any, List, Union
+
+import gfootball
+import gfootball.env as football_env
+import numpy as np
+from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo
+from ding.envs.common.env_element import EnvElement, EnvElementInfo
+from ding.torch_utils import to_ndarray
+from ding.utils import ENV_REGISTRY
+from dizoo.gfootball.envs.obs.encoder import FeatureEncoder
+from dizoo.gfootball.envs.obs.gfootball_obs import FullObs
+from dizoo.gfootball.envs.action.gfootball_action import GfootballSpAction
+
+
+@ENV_REGISTRY.register('gfootball_sp')
+class GfootballEnv(BaseEnv):
+
+ timestep = namedtuple('GfootballTimestep', ['obs', 'reward', 'done', 'info'])
+ info_template = namedtuple('GFootballEnvInfo', ['obs_space', 'act_space', 'rew_space'])
+
+ def __init__(self, cfg: dict) -> None:
+ self._cfg = cfg
+ self.save_replay = self._cfg.save_replay
+ # self.env_name = cfg.get("env_name", "11_vs_11_kaggle")
+ self.gui = self._cfg.render
+ self._obs_helper = FullObs(cfg)
+ self._action_helper = GfootballSpAction(cfg)
+ self._launch_env_flag = False
+ self._encoder = FeatureEncoder()
+ self.is_evaluator = self._cfg.get("is_evaluator", False)
+ if self.is_evaluator:
+ self.env_name = "11_vs_11_hard_stochastic"
+ self.right_role_num = 0
+ else:
+ self.env_name = "11_vs_11_kaggle"
+ self.right_role_num = 1
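+ # evaluator: control only the left team against the built-in hard AI;
+ # collector: self-play on 11_vs_11_kaggle, also controlling one right-side player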
+
+ def _make_env(self):
+ self._env = football_env.create_environment(
+ env_name=self.env_name,
+ representation='raw',
+ stacked=False,
+ logdir='/tmp/football',
+ write_goal_dumps=False,
+ write_full_episode_dumps=self.save_replay,
+ write_video=self.save_replay,
+ render=self.gui,
+ number_of_right_players_agent_controls=self.right_role_num
+ )
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+ self._launch_env_flag = True
+ # one episode-return accumulator per controlled side (identical for evaluator and collector)
+ self._eval_episode_return = [0, 0]
+
+ def reset(self) -> np.ndarray:
+ if not self._launch_env_flag:
+ self._make_env()
+ self._init_flag = True
+ self._env.reset()
+ obs = self._env.observation()
+ if self.is_evaluator:
+ self._prev_obs = obs[0]
+ obs = self._encoder.encode(obs[0])
+ return [obs, obs]
+ else:
+ self._prev_obs, self.prev_obs_opponent = obs
+ obs_ = self._encoder.encode(obs[0])
+ obs_opponent = self._encoder.encode(obs[1])
+ return [obs_, obs_opponent]
+
+ def close(self) -> None:
+ if self._launch_env_flag:
+ self._env.close()
+ self._launch_env_flag = False
+
+ def seed(self, seed: int, dynamic_seed: bool = None) -> None:
+ self._seed = seed
+ if dynamic_seed:
+ self._dynamic_seed = dynamic_seed
+
+ def step(self, action) -> 'GfootballEnv.timestep':
+ action = to_ndarray(action)
+ # action = self.process_action(action) # process
+ raw_obs, raw_rew, done, info = self._env.step(action)
+ if self.is_evaluator:
+ raw_obs = raw_obs[0]
+ rew = GfootballEnv.calc_reward(raw_rew, self._prev_obs, raw_obs)
+ obs = to_ndarray(self._encoder.encode(raw_obs))
+ rew = [rew, rew]
+ obs = [obs, obs]
+ self._eval_episode_return[0] += raw_rew
+ self._eval_episode_return[1] += raw_rew
+ else:
+ rew = GfootballEnv.calc_reward(raw_rew[0], self._prev_obs, raw_obs[0])
+ rew_oppo = GfootballEnv.calc_reward(raw_rew[1], self._prev_obs, raw_obs[1])
+ rew = [rew, rew_oppo]
+ obs = [to_ndarray(self._encoder.encode(raw_obs[0])), to_ndarray(self._encoder.encode(raw_obs[1]))]
+ self._eval_episode_return[0] += raw_rew[0]
+ self._eval_episode_return[1] += raw_rew[1]
+
+ if done:
+ if self.is_evaluator:
+ info['eval_episode_return'] = self._eval_episode_return
+ else:
+ info[0]['eval_episode_return'] = self._eval_episode_return[0]
+ info[1]['eval_episode_return'] = self._eval_episode_return[1]
+
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def info(self) -> BaseEnvInfo:
+ info_data = {
+ 'obs_space': self._obs_helper.info,
+ 'act_space': self._action_helper.info,
+ 'rew_space': EnvElementInfo(
+ shape=1,
+ value={
+ 'min': np.float64("-inf"),
+ 'max': np.float64("inf"),
+ 'dtype': np.float32
+ },
+ ),
+ }
+ return GfootballEnv.info_template(**info_data)
+
+ def __repr__(self) -> str:
+ return "DI-engine Gfootball Env({})".format(self.env_name)
+
+ @staticmethod
+ def calc_reward(rew, prev_obs, obs):
+ """
+ Reward design follows [football-paris](https://github.com/seungeunrho/football-paris/blob/main/rewarders/rewarder_basic.py).
+ """
+ ball_x, ball_y, ball_z = obs['ball']
+ MIDDLE_X, PENALTY_X, END_X = 0.2, 0.64, 1.0
+ PENALTY_Y, END_Y = 0.27, 0.42
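+ # pitch coordinates are normalized: x in [-1, 1], y in [-0.42, 0.42]; PENALTY_*/END_* mark the penalty area and field borders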
+
+ ball_position_r = 0.0
+ if (-END_X <= ball_x and ball_x < -PENALTY_X) and (-PENALTY_Y < ball_y and ball_y < PENALTY_Y):
+ ball_position_r = -2.0
+ elif (-END_X <= ball_x and ball_x < -MIDDLE_X) and (-END_Y < ball_y and ball_y < END_Y):
+ ball_position_r = -1.0
+ elif (-MIDDLE_X <= ball_x and ball_x <= MIDDLE_X) and (-END_Y < ball_y and ball_y < END_Y):
+ ball_position_r = 0.0
+ elif (PENALTY_X < ball_x and ball_x <= END_X) and (-PENALTY_Y < ball_y and ball_y < PENALTY_Y):
+ ball_position_r = 2.0
+ elif (MIDDLE_X < ball_x and ball_x <= END_X) and (-END_Y < ball_y and ball_y < END_Y):
+ ball_position_r = 1.0
+ else:
+ ball_position_r = 0.0
+
+ left_yellow = np.sum(obs["left_team_yellow_card"]) - np.sum(prev_obs["left_team_yellow_card"])
+ right_yellow = np.sum(obs["right_team_yellow_card"]) - np.sum(prev_obs["right_team_yellow_card"])
+ yellow_r = right_yellow - left_yellow
+
+ win_reward = 0.0
+ if obs['steps_left'] == 0:
+ [my_score, opponent_score] = obs['score']
+ if my_score > opponent_score:
+ win_reward = 1.0
+
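+ # shaped reward: win bonus, scaled scoring reward from the env, a small ball-position shaping term, and a yellow-card difference term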
+ reward = 5.0 * win_reward + 5.0 * rew + 0.003 * ball_position_r + yellow_r
+
+ return reward
+
+ @staticmethod
+ def create_collector_env_cfg(cfg: dict) -> List[dict]:
+ collector_cfg = copy.deepcopy(cfg)
+ collector_env_num = collector_cfg.pop('collector_env_num', 1)
+ collector_cfg.is_evaluator = False
+ return [collector_cfg for _ in range(collector_env_num)]
+
+ @staticmethod
+ def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
+ evaluator_cfg = copy.deepcopy(cfg)
+ evaluator_env_num = evaluator_cfg.pop('evaluator_env_num', 1)
+ evaluator_cfg.is_evaluator = True
+ return [evaluator_cfg for _ in range(evaluator_env_num)]
diff --git a/DI-engine/dizoo/gfootball/envs/obs/encoder.py b/DI-engine/dizoo/gfootball/envs/obs/encoder.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e2e5fc47b871055c464b0f48d867c7574663a69
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/envs/obs/encoder.py
@@ -0,0 +1,178 @@
+import numpy as np
+
+
+class FeatureEncoder:
+ """
+ Feature encoder follows [football-paris](https://github.com/seungeunrho/football-paris/blob/main/encoders/encoder_basic.py).
+ """
+
+ def __init__(self):
+ self.active = -1
+ self.player_pos_x, self.player_pos_y = 0, 0
+ self.n_player = 10
+
+ def get_feature_dims(self):
+ dims = {
+ 'player': 36,
+ 'ball': 18,
+ 'left_team': 7,
+ 'left_team_closest': 7,
+ 'right_team': 7,
+ 'right_team_closest': 7,
+ }
+ return dims
+
+ def encode(self, obs):
+ player_num = obs['active']
+
+ player_pos_x, player_pos_y = obs['left_team'][player_num]
+ player_direction = np.array(obs['left_team_direction'][player_num])
+ player_speed = np.linalg.norm(player_direction)
+ player_role = obs['left_team_roles'][player_num]
+ player_role_onehot = np.eye(self.n_player)[player_role]
+ player_tired = obs['left_team_tired_factor'][player_num]
+ is_dribbling = obs['sticky_actions'][9]
+ is_sprinting = obs['sticky_actions'][8]
+
+ ball_x, ball_y, ball_z = obs['ball']
+ ball_x_relative = ball_x - player_pos_x
+ ball_y_relative = ball_y - player_pos_y
+ ball_x_speed, ball_y_speed, _ = obs['ball_direction']
+ ball_distance = np.linalg.norm([ball_x_relative, ball_y_relative])
+ ball_speed = np.linalg.norm([ball_x_speed, ball_y_speed])
+ # ball_owned_team: -1 means not owned, 0 means left (our) team, 1 means right team
+ ball_owned = 0.0 if obs['ball_owned_team'] == -1 else 1.0
+ ball_owned_by_us = 1.0 if obs['ball_owned_team'] == 0 else 0.0
+
+ ball_which_zone = self._encode_ball_which_zone(ball_x, ball_y)
+
+ if ball_distance > 0.03:
+ ball_far = 1.0
+ else:
+ ball_far = 0.0
+
+ avail = self._get_avail(obs, ball_distance)
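+ # note: player direction/speed features are scaled by 100 and ball direction/speed by 20 before concatenation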
+ player_state = np.concatenate(
+ (
+ avail[2:], obs['left_team'][player_num], player_direction * 100, [player_speed * 100],
+ player_role_onehot, [ball_far, player_tired, is_dribbling, is_sprinting]
+ )
+ )
+
+ ball_state = np.concatenate(
+ (
+ np.array(obs['ball']), np.array(ball_which_zone), np.array([ball_x_relative, ball_y_relative]),
+ np.array(obs['ball_direction']) * 20,
+ np.array([ball_speed * 20, ball_distance, ball_owned, ball_owned_by_us])
+ )
+ )
+
+ obs_left_team = np.delete(obs['left_team'], player_num, axis=0)
+ obs_left_team_direction = np.delete(obs['left_team_direction'], player_num, axis=0)
+ left_team_relative = obs_left_team
+ left_team_distance = np.linalg.norm(left_team_relative - obs['left_team'][player_num], axis=1, keepdims=True)
+ left_team_speed = np.linalg.norm(obs_left_team_direction, axis=1, keepdims=True)
+ left_team_tired = np.delete(obs['left_team_tired_factor'], player_num, axis=0).reshape(-1, 1)
+ left_team_state = np.concatenate((left_team_relative*2, obs_left_team_direction*100, left_team_speed*100, \
+ left_team_distance*2, left_team_tired), axis=1)
+ left_closest_idx = np.argmin(left_team_distance)
+ left_closest_state = left_team_state[left_closest_idx]
+
+ obs_right_team = np.array(obs['right_team'])
+ obs_right_team_direction = np.array(obs['right_team_direction'])
+ right_team_distance = np.linalg.norm(obs_right_team - obs['left_team'][player_num], axis=1, keepdims=True)
+ right_team_speed = np.linalg.norm(obs_right_team_direction, axis=1, keepdims=True)
+ right_team_tired = np.array(obs['right_team_tired_factor']).reshape(-1, 1)
+ right_team_state = np.concatenate((obs_right_team*2, obs_right_team_direction*100, right_team_speed*100, \
+ right_team_distance*2, right_team_tired), axis=1)
+ right_closest_idx = np.argmin(right_team_distance)
+ right_closest_state = right_team_state[right_closest_idx]
+
+ state_dict = {
+ "player": player_state,
+ "ball": ball_state,
+ "left_team": left_team_state,
+ "left_closest": left_closest_state,
+ "right_team": right_team_state,
+ "right_closest": right_closest_state,
+ "avail": avail
+ }
+
+ return state_dict
+
+ def _get_avail(self, obs, ball_distance):
+ avail = [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]
+ NO_OP, MOVE, LONG_PASS, HIGH_PASS, SHORT_PASS, SHOT, SPRINT, RELEASE_MOVE, \
+ RELEASE_SPRINT, SLIDE, DRIBBLE, RELEASE_DRIBBLE = 0, 1, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18
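+ # avail is a 0/1 mask over gfootball's 19 raw actions; the ids above follow the default action set ordering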
+
+ if obs['ball_owned_team'] == 1: # opponents owning ball
+ avail[LONG_PASS], avail[HIGH_PASS], avail[SHORT_PASS], avail[SHOT], avail[DRIBBLE] = 0, 0, 0, 0, 0
+ # ground ball that is far from the controlled player
+ elif obs['ball_owned_team'] == -1 and ball_distance > 0.03 and obs['game_mode'] == 0:
+ avail[LONG_PASS], avail[HIGH_PASS], avail[SHORT_PASS], avail[SHOT], avail[DRIBBLE] = 0, 0, 0, 0, 0
+ else: # my team owning ball
+ avail[SLIDE] = 0
+
+ # Dealing with sticky actions
+ sticky_actions = obs['sticky_actions']
+ if sticky_actions[8] == 0: # sprinting
+ avail[RELEASE_SPRINT] = 0
+
+ if sticky_actions[9] == 1: # dribbling
+ avail[SLIDE] = 0
+ else:
+ avail[RELEASE_DRIBBLE] = 0
+
+ if np.sum(sticky_actions[:8]) == 0:
+ avail[RELEASE_MOVE] = 0
+
+ # if too far, no shot
+ ball_x, ball_y, _ = obs['ball']
+ if ball_x < 0.64 or ball_y < -0.27 or 0.27 < ball_y:
+ avail[SHOT] = 0
+ elif (0.64 <= ball_x and ball_x <= 1.0) and (-0.27 <= ball_y and ball_y <= 0.27):
+ avail[HIGH_PASS], avail[LONG_PASS] = 0, 0
+
+ if obs['game_mode'] == 2 and ball_x < -0.7: # Our GoalKick
+ avail = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ avail[LONG_PASS], avail[HIGH_PASS], avail[SHORT_PASS] = 1, 1, 1
+ return np.array(avail)
+
+ elif obs['game_mode'] == 4 and ball_x > 0.9: # Our CornerKick
+ avail = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ avail[LONG_PASS], avail[HIGH_PASS], avail[SHORT_PASS] = 1, 1, 1
+ return np.array(avail)
+
+ elif obs['game_mode'] == 6 and ball_x > 0.6: # Our PenaltyKick
+ avail = [1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ avail[SHOT] = 1
+ return np.array(avail)
+
+ return np.array(avail)
+
+ def _encode_ball_which_zone(self, ball_x, ball_y):
+ MIDDLE_X, PENALTY_X, END_X = 0.2, 0.64, 1.0
+ LEFT_PENALTY, LEFT_HALF, HALF, RIGHT_PENALTY, RIGHT_HALF, OTHERS = 0, 1, 2, 3, 4, 5
+ PENALTY_Y, END_Y = 0.27, 0.42
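+ # returns a one-hot vector over 6 coarse pitch zones (left penalty area, left half, middle, right penalty area, right half, other)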
+ res = np.eye(6)
+ if (-END_X <= ball_x and ball_x < -PENALTY_X) and (-PENALTY_Y < ball_y and ball_y < PENALTY_Y):
+ return res[LEFT_PENALTY]
+ elif (-END_X <= ball_x and ball_x < -MIDDLE_X) and (-END_Y < ball_y and ball_y < END_Y):
+ return res[LEFT_HALF]
+ elif (-MIDDLE_X <= ball_x and ball_x <= MIDDLE_X) and (-END_Y < ball_y and ball_y < END_Y):
+ return res[HALF]
+ elif (PENALTY_X < ball_x and ball_x <= END_X) and (-PENALTY_Y < ball_y and ball_y < PENALTY_Y):
+ return res[RIGHT_PENALTY]
+ elif (MIDDLE_X < ball_x and ball_x <= END_X) and (-END_Y < ball_y and ball_y < END_Y):
+ return res[RIGHT_HALF]
+ else:
+ return res[OTHERS]
diff --git a/DI-engine/dizoo/gfootball/envs/obs/gfootball_obs.py b/DI-engine/dizoo/gfootball/envs/obs/gfootball_obs.py
new file mode 100644
index 0000000000000000000000000000000000000000..0014e0dcf5c7d31ae3e72fa8fbffb1e1c7f2f2e9
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/envs/obs/gfootball_obs.py
@@ -0,0 +1,454 @@
+import numpy as np
+import torch
+import math
+
+from ding.envs.common import EnvElement
+from functools import partial
+from ding.torch_utils import one_hot
+from ding.envs.common import div_func, div_one_hot
+
+N_PLAYER = 11
+
+
+def score_preprocess(scores):
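+ # each score is clipped at 10 and one-hot encoded into 11 bins (0-9 plus ">=10"); the two team scores are concatenated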
+ ret = []
+ for score in scores:
+ clip_score = torch.clamp_max(score.unsqueeze(0), 10) # 0-9: 0-9; 10: >=10
+ ret.append(one_hot(clip_score, num=11).squeeze(0))
+ return torch.cat(ret, dim=0)
+
+
+class MatchObs(EnvElement):
+ _name = "GFootballMatchObs"
+
+ def _init(self, cfg):
+ self._default_val = None
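+ # each template entry maps one raw observation field ('key') to a processed feature ('ret_key'),
+ # together with its dimension, preprocessing op and value meta-info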
+ self.template = [
+ # ------Ball information
+ {
+ 'key': 'ball',
+ 'ret_key': 'ball_position',
+ 'dim': 3,
+ 'op': lambda x: x,
+ 'value': {
+ 'min': (-1, -0.42, 0),
+ 'max': (1, 0.42, 100),
+ 'dtype': float,
+ 'dinfo': 'float'
+ },
+ 'other': 'float (x, y, z)'
+ },
+ {
+ 'key': 'ball_direction',
+ 'ret_key': 'ball_direction',
+ 'dim': 3,
+ 'op': lambda x: x,
+ 'value': {
+ 'min': (-1, -0.42, 0),
+ 'max': (1, 0.42, 100),
+ 'dtype': float,
+ 'dinfo': 'float'
+ },
+ 'other': 'float (x, y, z)'
+ },
+ {
+ 'key': 'ball_rotation',
+ 'ret_key': 'ball_rotation',
+ 'dim': 3,
+ 'op': lambda x: x,
+ 'value': {
+ 'min': (-math.pi, -math.pi, -math.pi),
+ 'max': (math.pi, math.pi, math.pi),
+ 'dtype': float,
+ 'dinfo': 'float'
+ },
+ 'other': 'float (x, y, z)'
+ },
+ {
+ 'key': 'ball_owned_team',
+ 'ret_key': 'ball_owned_team',
+ 'dim': 3,
+ 'op': lambda x: partial(one_hot, num=3)(x + 1),
+ 'value': {
+ 'min': 0,
+ 'max': 2,
+ 'dtype': float,
+ 'dinfo': 'one-hot'
+ },
+ 'other': 'one hot 3 value',
+ 'meaning': ['NotOwned', 'LeftTeam', 'RightTeam']
+ },
+ {
+ 'key': 'ball_owned_player',
+ 'ret_key': 'ball_owned_player',
+ 'dim': N_PLAYER + 1, # 0...N-1: player index, N: nobody
+ 'op': lambda x: partial(one_hot, num=N_PLAYER + 1)(x + N_PLAYER + 1 if x == -1 else x),
+ 'value': {
+ 'min': 0,
+ 'max': 2,
+ 'dtype': float,
+ 'dinfo': 'one-hot'
+ },
+ 'other': 'one hot 12 value',
+ 'meaning': 'index of player'
+ },
+ # ------Controlled player information
+ {
+ 'key': 'active',
+ 'ret_key': 'active_player',
+ 'dim': N_PLAYER,
+ 'op': partial(one_hot, num=N_PLAYER),
+ 'value': {
+ 'min': 0,
+ 'max': 2,
+ 'dtype': float,
+ 'dinfo': 'one-hot'
+ },
+ 'other': 'one hot 11 value',
+ 'meaning': 'index of controlled player'
+ },
+ {
+ 'key': 'designated', # In non-multiagent mode it is always equal to `active`
+ 'ret_key': 'designated_player',
+ 'dim': N_PLAYER,
+ 'op': partial(one_hot, num=N_PLAYER),
+ 'value': {
+ 'min': 0,
+ 'max': 2,
+ 'dtype': float,
+ 'dinfo': 'one-hot'
+ },
+ 'other': 'one hot 11 value',
+ 'meaning': 'index of player'
+ },
+ {
+ 'key': 'sticky_actions',
+ 'ret_key': 'active_player_sticky_actions',
+ 'dim': 10,
+ 'op': lambda x: x,
+ 'value': {
+ 'min': 0,
+ 'max': 2,
+ 'dtype': float,
+ 'dinfo': 'boolean vector'
+ },
+ 'other': 'boolean vector with 10 value',
+ 'meaning': [
+ 'Left', 'TopLeft', 'Top', 'TopRight', 'Right', 'BottomRight', 'Bottom', 'BottomLeft', 'Sprint',
+ 'Dribble'
+ ] # 8 directions are one-hot
+ },
+ # ------Match state
+ {
+ 'key': 'score',
+ 'ret_key': 'score',
+ 'dim': 22,
+ 'op': score_preprocess,
+ 'value': {
+ 'min': 0,
+ 'max': 2,
+ 'dtype': float,
+ 'dinfo': 'one-hot'
+ },
+ 'other': 'each score one hot 11 values(10 for 0-9, 1 for over 10), concat two scores',
+ },
+ {
+ 'key': 'steps_left',
+ 'ret_key': 'steps_left',
+ 'dim': 30,
+ 'op': partial(div_one_hot, max_val=2999, ratio=100),
+ 'value': {
+ 'min': 0,
+ 'max': 2,
+ 'dtype': float,
+ 'dinfo': 'one-hot'
+ },
+ 'other': 'div(50), one hot 30 values',
+ },
+ {
+ 'key': 'game_mode',
+ 'ret_key': 'game_mode',
+ 'dim': 7,
+ 'op': partial(one_hot, num=7),
+ 'value': {
+ 'min': 0,
+ 'max': 2,
+ 'dtype': float,
+ 'dinfo': 'one-hot'
+ },
+ 'other': 'one-hot 7 values',
+ 'meaning': ['Normal', 'KickOff', 'GoalKick', 'FreeKick', 'Corner', 'ThrowIn', 'Penalty']
+ },
+ ]
+ self.cfg = cfg
+ self._shape = {t['key']: t['dim'] for t in self.template}
+ self._value = {t['key']: t['value'] for t in self.template}
+ self._to_agent_processor = self.parse
+ self._from_agent_processor = None
+
+ def parse(self, obs: dict) -> dict:
+ '''
+ Overview:
+ Parse the raw gfootball observation dict into processed match-level features according to ``self.template``.
+ Arguments:
+ - obs (:obj:`dict`): the raw observation dict returned by the gfootball env
+ Returns:
+ - ret (:obj:`dict`): mapping from each template's ``ret_key`` to the processed numpy feature
+ '''
+ ret = {}
+ for item in self.template:
+ key = item['key']
+ ret_key = item['ret_key']
+ data = obs[key]
+ if not isinstance(data, list):
+ data = [data]
+ data = torch.Tensor(data) if item['value']['dinfo'] != 'one-hot' else torch.LongTensor(data)
+ try:
+ data = item['op'](data)
+ except RuntimeError:
+ print(item, data)
+ raise RuntimeError
+ if len(data.shape) == 2:
+ data = data.squeeze(0)
+ ret[ret_key] = data.numpy()
+ return ret
+
+ def _details(self):
+ return 'Match Global Obs: Ball, Controlled Player and Match State'
+
+
+class PlayerObs(EnvElement):
+ _name = "GFootballPlayerObs"
+
+ def _init(self, cfg):
+ self._default_val = None
+ self.template = [
+ {
+ 'key': 'team',
+ 'ret_key': 'team',
+ 'dim': 2,
+ 'op': partial(one_hot, num=2), # 0 for left, 1 for right
+ 'value': {
+ 'min': 0,
+ 'max': 2,
+ 'dtype': float,
+ 'dinfo': 'one-hot'
+ },
+ 'other': 'one-hot 2 values for which team'
+ },
+ {
+ 'key': 'index',
+ 'ret_key': 'index',
+ 'dim': N_PLAYER,
+ 'op': partial(one_hot, num=N_PLAYER),
+ 'value': {
+ 'min': 0,
+ 'max': N_PLAYER,
+ 'dtype': float,
+ 'dinfo': 'one-hot'
+ },
+ 'other': 'one-hot N_PLAYER values for index in one team'
+ },
+ {
+ 'key': 'position',
+ 'ret_key': 'position',
+ 'dim': 2,
+ 'op': lambda x: x,
+ 'value': {
+ 'min': (-1, -0.42),
+ 'max': (1, 0.42),
+ 'dtype': float,
+ 'dinfo': 'float'
+ },
+ 'other': 'float (x, y)'
+ },
+ {
+ 'key': 'direction',
+ 'ret_key': 'direction',
+ 'dim': 2,
+ 'op': lambda x: x,
+ 'value': {
+ 'min': (-1, -0.42),
+ 'max': (1, 0.42),
+ 'dtype': float,
+ 'dinfo': 'float'
+ },
+ 'other': 'float'
+ },
+ {
+ 'key': 'tired_factor',
+ 'ret_key': 'tired_factor',
+ 'dim': 1,
+ 'op': lambda x: x,
+ 'value': {
+ 'min': (0, ),
+ 'max': (1, ),
+ 'dtype': float,
+ 'dinfo': 'float'
+ },
+ 'other': 'float'
+ },
+ {
+ 'key': 'yellow_card',
+ 'ret_key': 'yellow_card',
+ 'dim': 2,
+ 'op': partial(one_hot, num=2),
+ 'value': {
+ 'min': 0,
+ 'max': 2,
+ 'dtype': float,
+ 'dinfo': 'one-hot'
+ },
+ 'other': 'one hot 2 values'
+ },
+ {
+ 'key': 'active', # 0(False) means got a red card
+ 'ret_key': 'active',
+ 'dim': 2,
+ 'op': partial(one_hot, num=2),
+ 'value': {
+ 'min': 0,
+ 'max': 2,
+ 'dtype': float,
+ 'dinfo': 'one-hot'
+ },
+ 'other': 'float'
+ },
+ {
+ 'key': 'roles',
+ 'ret_key': 'role',
+ 'dim': 10,
+ 'op': partial(one_hot, num=10),
+ 'value': {
+ 'min': 0,
+ 'max': 2,
+ 'dtype': float,
+ 'dinfo': 'one-hot'
+ },
+ 'other': 'one-hot 10 values',
+ 'meaning': [
+ 'GoalKeeper', 'CentreBack', 'LeftBack', 'RightBack', 'DefenceMidfield', 'CentralMidfield',
+ 'LeftMidfield', 'RightMidfield', 'AttackMidfield', 'CentralFront'
+ ]
+ },
+ ]
+ self.cfg = cfg
+ self._shape = {'players': {t['key']: t['dim'] for t in self.template}}
+ self._value = {'players': {t['key']: t['value'] for t in self.template}}
+ self._to_agent_processor = self.parse
+ self._from_agent_processor = None
+
+ def parse(self, obs: dict) -> dict:
+ players = []
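+ # 22 player entries in total: indices 0-10 for the left team, 11-21 for the right team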
+ for player_idx in range(N_PLAYER):
+ players.append(self._parse(obs, 'left_team', player_idx))
+ for player_idx in range(N_PLAYER):
+ players.append(self._parse(obs, 'right_team', player_idx))
+ return {'players': players}
+
+ def _parse(self, obs: dict, left_right: str, player_idx) -> dict:
+ player_dict = {
+ 'team': 0 if left_right == 'left_team' else 1,
+ 'index': player_idx,
+ }
+ for item in self.template:
+ key = item['key']
+ ret_key = item['ret_key']
+ if key in ['team', 'index']:
+ data = player_dict[key]
+ elif key == 'position':
+ player_stat = left_right
+ data = obs[player_stat][player_idx]
+ else:
+ player_stat = left_right + '_' + key
+ data = obs[player_stat][player_idx]
+ if not isinstance(data, np.ndarray):
+ data = [data]
+ data = torch.Tensor(data) if item['value']['dinfo'] != 'one-hot' else torch.LongTensor(data)
+ try:
+ data = item['op'](data)
+ except RuntimeError:
+ print(item, data)
+ raise RuntimeError
+ if len(data.shape) == 2:
+ data = data.squeeze(0)
+ player_dict[ret_key] = data.numpy()
+ return player_dict
+
+ def _details(self):
+ return 'Single Player Obs'
+
+
+class FullObs(EnvElement):
+ _name = "GFootballFullObs"
+
+ def _init(self, cfg):
+ self._default_val = None
+ self.template = [
+ {
+ 'key': 'player',
+ 'ret_key': 'player',
+ 'dim': 36,
+ 'op': lambda x: x,
+ 'value': {
+ 'min': (
+ 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, -1, -0.42, -1, -0.42, 0, 0, 0, 0, 0, 0, 0, 0,
+ 0, 0, 0, 0, 0, 0, 0, 0
+ ),
+ 'max': (
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0.42, 1, 0.42, float(np.inf), 1, 1, 1, 1,
+ 1, 1, 1, 1, 1, 1, 1, 1, 1, 1
+ ),
+ 'dtype': float,
+ 'dinfo': 'mix'
+ },
+ 'other': 'mixed active player info'
+ },
+ {
+ 'key': 'ball',
+ 'ret_key': 'ball',
+ 'dim': 18,
+ 'op': lambda x: x,
+ 'value': {
+ 'min': (-1, -0.42, 0, 0, 0, 0, 0, 0, 0, -2, -0.84, -20, -8.4, 0, 0, 0, 0, 0),
+ 'max': (1, 0.42, 100, 1, 1, 1, 1, 1, 1, 2, 0.84, 20, 8.4, np.inf, np.inf, 2.5, 1, 1),
+ 'dtype': float,
+ 'dinfo': 'mix'
+ },
+ 'other': 'mixed ball info, relative to active player'
+ },
+ {
+ 'key': 'LeftTeam',
+ 'ret_key': 'LeftTeam',
+ 'dim': 7,
+ 'op': lambda x: x,
+ 'value': {
+ 'min': (-1, -0.42, -1, -0.42, 0, 0, 0),
+ 'max': (1, 0.42, 1, 0.42, 100, 2.5, 1),
+ 'dtype': float,
+ 'dinfo': 'mix'
+ },
+ 'other': 'mixed player info, relative to the active player; \
+ 10+1 entries (all left-team members plus the closest one)'
+ },
+ {
+ 'key': 'RightTeam',
+ 'ret_key': 'RightTeam',
+ 'dim': 7,
+ 'op': lambda x: x,
+ 'value': {
+ 'min': (-1, -0.42, -1, -0.42, 0, 0, 0),
+ 'max': (1, 0.42, 1, 0.42, 100, 2.5, 1),
+ 'dtype': float,
+ 'dinfo': 'mix'
+ },
+                'other': 'mixed player info, relative to active player, '
+                'will have 10+1 entries (all right team members and the closest member)'
+ },
+ ]
+ self.cfg = cfg
+ self._shape = {t['key']: t['dim'] for t in self.template}
+ self._value = {t['key']: t['value'] for t in self.template}
+
+ def _details(self):
+ return 'Full Obs for Gfootball Self Play'
diff --git a/DI-engine/dizoo/gfootball/envs/obs/gfootball_obs_runner.py b/DI-engine/dizoo/gfootball/envs/obs/gfootball_obs_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9d7cee93ea76c046a96c9e870d83fa8d14af2e9
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/envs/obs/gfootball_obs_runner.py
@@ -0,0 +1,33 @@
+import copy
+
+import numpy as np
+
+from ding.envs.common import EnvElementRunner, EnvElement
+from ding.envs.env.base_env import BaseEnv
+from .gfootball_obs import PlayerObs, MatchObs
+from ding.utils import deep_merge_dicts
+
+
+class GfootballObsRunner(EnvElementRunner):
+
+ def _init(self, cfg, *args, **kwargs) -> None:
+ # set self._core and other state variable
+ self._obs_match = MatchObs(cfg)
+ self._obs_player = PlayerObs(cfg)
+ self._core = self._obs_player # placeholder
+
+ def get(self, engine: BaseEnv) -> dict:
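+        # parse the raw football obs into match-level and player-level features and merge them into one dict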
+ ret = copy.deepcopy(engine._football_obs)
+ # print(ret, type(ret))
+ assert isinstance(ret, dict)
+ match_obs = self._obs_match._to_agent_processor(ret)
+ players_obs = self._obs_player._to_agent_processor(ret)
+ return deep_merge_dicts(match_obs, players_obs)
+
+ def reset(self) -> None:
+ pass
+
+ # override
+ @property
+ def info(self):
+ return {'match': self._obs_match.info, 'player': self._obs_player.info}
diff --git a/DI-engine/dizoo/gfootball/envs/reward/gfootball_reward.py b/DI-engine/dizoo/gfootball/envs/reward/gfootball_reward.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bdb9cf06f8e5a4b5dc1c0d997de69ebd441dcd9
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/envs/reward/gfootball_reward.py
@@ -0,0 +1,49 @@
+from collections import namedtuple
+
+import numpy as np
+
+from ding.envs.common import EnvElement
+
+
+class GfootballReward(EnvElement):
+ _name = "gfootballReward"
+ _reward_keys = ['reward_value']
+    Reward = namedtuple('Reward', _reward_keys)
+
+ MinReward = -1.0
+ MaxReward = 1.0
+
+ def _init(self, cfg) -> None:
+ self._default_val = 0.0
+ self.template = {
+ 'reward_value': {
+ 'name': 'reward_value',
+ 'shape': (1, ),
+ 'value': {
+ 'min': -1.0,
+ 'max': 1.0,
+ 'dtype': float,
+ 'dinfo': 'float value',
+ },
+ 'env_value': 'reward of action',
+ 'to_agent_processor': lambda x: x,
+ 'from_agent_processor': lambda x: x,
+ 'necessary': True,
+ }
+ }
+ self._shape = (1, )
+ self._value = {
+ 'min': -1.0,
+ 'max': 1.0,
+ 'dtype': float,
+ 'dinfo': 'float value',
+ }
+
+    def _to_agent_processor(self, reward: float) -> np.ndarray:
+ return np.array([reward], dtype=float)
+
+ def _from_agent_processor(self, reward: float) -> float:
+ return reward
+
+ def _details(self):
+ return '\t'.join(self._reward_keys)
diff --git a/DI-engine/dizoo/gfootball/envs/reward/gfootball_reward_runner.py b/DI-engine/dizoo/gfootball/envs/reward/gfootball_reward_runner.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e62c6075e1a538d302950bb74ea101fc2dbb005
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/envs/reward/gfootball_reward_runner.py
@@ -0,0 +1,27 @@
+import copy
+
+import numpy as np
+import torch
+
+from ding.envs.common import EnvElementRunner
+from ding.envs.env.base_env import BaseEnv
+from .gfootball_reward import GfootballReward
+
+
+class GfootballRewardRunner(EnvElementRunner):
+
+ def _init(self, cfg, *args, **kwargs) -> None:
+ # set self._core and other state variable
+ self._core = GfootballReward(cfg)
+ self._cum_reward = 0.0
+
+    def get(self, engine: BaseEnv) -> np.ndarray:
+ ret = copy.deepcopy(engine._reward_of_action)
+ self._cum_reward += ret
+ return self._core._to_agent_processor(ret)
+
+ def reset(self) -> None:
+ self._cum_reward = 0.0
+
+ @property
+    def cum_reward(self) -> torch.Tensor:
+ return torch.FloatTensor([self._cum_reward])
diff --git a/DI-engine/dizoo/gfootball/envs/tests/test_env_gfootball.py b/DI-engine/dizoo/gfootball/envs/tests/test_env_gfootball.py
new file mode 100644
index 0000000000000000000000000000000000000000..86b704a087a983a953c79955f4fc6fb9a68011bd
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/envs/tests/test_env_gfootball.py
@@ -0,0 +1,41 @@
+import pytest
+import numpy as np
+import pprint
+
+try:
+ from dizoo.gfootball.envs.gfootball_env import GfootballEnv
+except ModuleNotFoundError:
+    print("[WARNING] gfootball is not installed. Install it to use the gfootball env; otherwise ignore this warning.")
+
+
+@pytest.mark.envtest
+class TestGfootballEnv:
+
+ def get_random_action(self, min_value, max_value):
+ action = np.random.randint(min_value, max_value + 1, (1, ))
+ return action
+
+ def test_naive(self):
+ env = GfootballEnv({})
+ print(env.info())
+ reset_obs = env.reset()
+ print('after reset:', reset_obs)
+ pp = pprint.PrettyPrinter(indent=2)
+ for i in range(3000):
+ action = self.get_random_action(env.info().act_space.value['min'], env.info().act_space.value['max'])
+ timestep = env.step(action)
+ reward = timestep.reward
+ print('reward:', reward)
+ # assert reward.shape == 1
+ obs = timestep.obs
+ print("raw_obs = ", obs['raw_obs'])
+ obs = obs['processed_obs']
+ assert obs['ball_owned_team'].shape[0] == 3
+ assert obs['ball_owned_player'].shape[0] == 12
+ assert obs['active_player'].shape[0] == 11
+ assert obs['score'].shape[0] == 22
+ assert obs['steps_left'].shape[0] == 30
+ print('observation: ')
+ pp.pprint(obs)
+ print('--step {} with action {}'.format(i, action))
+ print('end')
diff --git a/DI-engine/dizoo/gfootball/envs/tests/test_env_gfootball_academy.py b/DI-engine/dizoo/gfootball/envs/tests/test_env_gfootball_academy.py
new file mode 100644
index 0000000000000000000000000000000000000000..c81421cfbe7c69964b90303ff9b06bc404a486db
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/envs/tests/test_env_gfootball_academy.py
@@ -0,0 +1,88 @@
+import pytest
+import numpy as np
+import pprint
+from easydict import EasyDict
+
+try:
+ from dizoo.gfootball.envs.gfootball_academy_env import GfootballAcademyEnv
+except ModuleNotFoundError:
+    print("[WARNING] gfootball is not installed. Install it to use the gfootball env; otherwise ignore this warning.")
+
+cfg_keeper = EasyDict(dict(
+ env_name='academy_3_vs_1_with_keeper',
+ agent_num=3,
+ obs_dim=26,
+))
+
+cfg_counter = EasyDict(dict(
+ env_name='academy_counterattack_hard',
+ agent_num=4,
+ obs_dim=34,
+))
+
+
+@pytest.mark.envtest
+class TestGfootballAcademyEnv:
+
+ def get_random_action(self, min_value, max_value):
+ action = np.random.randint(min_value, max_value + 1, (1, ))
+ return action
+
+ def test_academy_3_vs_1_with_keeper(self):
+ cfg = cfg_keeper
+ env = GfootballAcademyEnv(cfg)
+ print(env.observation_space, env._action_space, env.reward_space)
+ pp = pprint.PrettyPrinter(indent=2)
+ for i in range(2):
+ eps_len = 0
+ # env.enable_save_replay(replay_path='./video')
+ reset_obs = env.reset()
+ while True:
+ eps_len += 1
+ action = env.random_action()[0]
+ action = [int(action_agent) for k, action_agent in action.items()]
+ timestep = env.step(action)
+ obs = timestep.obs
+ reward = timestep.reward
+ done = timestep.done
+ # print('observation: ')
+ # pp.pprint(obs)
+ assert obs['agent_state'].shape == (cfg.agent_num, cfg.obs_dim)
+ assert obs['global_state'].shape == (cfg.agent_num, cfg.obs_dim * 2)
+ assert obs['action_mask'].shape == (cfg.agent_num, 19)
+
+ print('step {}, action: {}, reward: {}'.format(eps_len, action, reward))
+ if done:
+ break
+ assert reward == -1 or reward == 100
+ print(f'Episode {i} done! The episode length is {eps_len}. The last reward is {reward}.')
+ print('End')
+
+ def test_academy_counterattack_hard(self):
+ cfg = cfg_counter
+ env = GfootballAcademyEnv(cfg)
+ print(env.observation_space, env._action_space, env.reward_space)
+ pp = pprint.PrettyPrinter(indent=2)
+ for i in range(2):
+ eps_len = 0
+ reset_obs = env.reset()
+ while True:
+ eps_len += 1
+ action = env.random_action()[0]
+ action = [int(action_agent) for k, action_agent in action.items()]
+ timestep = env.step(action)
+ obs = timestep.obs
+ reward = timestep.reward
+ done = timestep.done
+ # print('observation: ')
+ # pp.pprint(obs)
+ assert obs['agent_state'].shape == (cfg.agent_num, cfg.obs_dim)
+ assert obs['global_state'].shape == (cfg.agent_num, cfg.obs_dim * 2)
+ assert obs['action_mask'].shape == (cfg.agent_num, 19)
+
+ print('step {}, action: {}, reward: {}'.format(eps_len, action, reward))
+ if done:
+ break
+ assert reward == -1 or reward == 100
+ print(f'Episode {i} done! The episode length is {eps_len}. The last reward is {reward}.')
+ print('End')
diff --git a/DI-engine/dizoo/gfootball/model/__init__.py b/DI-engine/dizoo/gfootball/model/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/gfootball/model/bots/TamakEriFever/config.yaml b/DI-engine/dizoo/gfootball/model/bots/TamakEriFever/config.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..e883bc70218c31148eb36dbb828134bcc41e4ceb
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/model/bots/TamakEriFever/config.yaml
@@ -0,0 +1,42 @@
+
+env_args:
+ env: 'Football'
+ source: 'football_ikki'
+ frames_per_sec: 10 # we cannot change
+
+ frame_skip: 0
+ limit_steps: 3002
+
+train_args:
+ gamma_per_sec: 0.97
+ lambda_per_sec: 0.4
+ forward_steps: 64
+ compress_steps: 16
+ entropy_regularization: 1.3e-3
+ monte_carlo_rate: 1.0
+ update_episodes: 400
+ batch_size: 192
+ minimum_episodes: 3000
+ maximum_episodes: 30000
+ num_batchers: 23
+ eval_rate: 0.1
+ replay_rate: 0 # 0.1
+ supervised_weight: 0 # 0.1
+ record_dir: "records/"
+ randomized_start_rate: 0.3
+ randomized_start_max_steps: 400
+ reward_reset: True
+ worker:
+ num_gather: 2
+ num_process: 6
+ seed: 1800
+ restart_epoch: 1679
+
+entry_args:
+ remote_host: ''
+ num_gather: 2
+ num_process: 6
+
+eval_args:
+ remote_host: ''
+
diff --git a/DI-engine/dizoo/gfootball/model/bots/TamakEriFever/football/util.py b/DI-engine/dizoo/gfootball/model/bots/TamakEriFever/football/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c1aa2fe0e4d7ac0a81d79f1407ac89172a5ed03
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/model/bots/TamakEriFever/football/util.py
@@ -0,0 +1,91 @@
+# https://github.com/Kaggle/kaggle-environments/blob/master/kaggle_environments/envs/football/helpers.py
+
+import enum
+from functools import wraps
+from typing import *
+
+
+class Action(enum.IntEnum):
+ Idle = 0
+ Left = 1
+ TopLeft = 2
+ Top = 3
+ TopRight = 4
+ Right = 5
+ BottomRight = 6
+ Bottom = 7
+ BottomLeft = 8
+ LongPass = 9
+ HighPass = 10
+ ShortPass = 11
+ Shot = 12
+ Sprint = 13
+ ReleaseDirection = 14
+ ReleaseSprint = 15
+ Slide = 16
+ Dribble = 17
+ ReleaseDribble = 18
+
+
+sticky_index_to_action = [
+ Action.Left, Action.TopLeft, Action.Top, Action.TopRight, Action.Right, Action.BottomRight, Action.Bottom,
+ Action.BottomLeft, Action.Sprint, Action.Dribble
+]
+
+action_to_sticky_index = {a: index for index, a in enumerate(sticky_index_to_action)}
+
+
+class PlayerRole(enum.IntEnum):
+ GoalKeeper = 0
+ CenterBack = 1
+ LeftBack = 2
+ RightBack = 3
+ DefenceMidfield = 4
+ CentralMidfield = 5
+ LeftMidfield = 6
+ RIghtMidfield = 7
+ AttackMidfield = 8
+ CentralFront = 9
+
+
+class GameMode(enum.IntEnum):
+ Normal = 0
+ KickOff = 1
+ GoalKick = 2
+ FreeKick = 3
+ Corner = 4
+ ThrowIn = 5
+ Penalty = 6
+
+
+def human_readable_agent(agent: Callable[[Dict], Action]):
+ """
+ Decorator allowing for more human-friendly implementation of the agent function.
+ @human_readable_agent
+ def my_agent(obs):
+ ...
+ return football_action_set.action_right
+ """
+
+ @wraps(agent)
+ def agent_wrapper(obs) -> List[int]:
+ # Extract observations for the first (and only) player we control.
+ obs = obs['players_raw'][0]
+ # Turn 'sticky_actions' into a set of active actions (strongly typed).
+ obs['sticky_actions'] = {
+ sticky_index_to_action[nr]
+ for nr, action in enumerate(obs['sticky_actions']) if action
+ }
+ # Turn 'game_mode' into an enum.
+ obs['game_mode'] = GameMode(obs['game_mode'])
+ # In case of single agent mode, 'designated' is always equal to 'active'.
+ if 'designated' in obs:
+ del obs['designated']
+        # Convert players' roles to enum.
+ obs['left_team_roles'] = [PlayerRole(role) for role in obs['left_team_roles']]
+ obs['right_team_roles'] = [PlayerRole(role) for role in obs['right_team_roles']]
+
+ action = agent(obs)
+ return [action.value]
+
+ return agent_wrapper
diff --git a/DI-engine/dizoo/gfootball/model/bots/TamakEriFever/football_ikki.py b/DI-engine/dizoo/gfootball/model/bots/TamakEriFever/football_ikki.py
new file mode 100644
index 0000000000000000000000000000000000000000..c45865158d3e2467a9f7000f8afb6868f584dbb0
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/model/bots/TamakEriFever/football_ikki.py
@@ -0,0 +1,1329 @@
+import os
+import sys
+import random
+import json
+import copy
+import enum
+from functools import partial
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from dizoo.gfootball.model.bots.TamakEriFever.handyrl_core.model import BaseModel, Dense
+from dizoo.gfootball.model.bots.TamakEriFever.football.util import *
+
+# import dizoo.gfootball.model.TamakEriFever.football.rulebaseA as rulebaseA
+# import dizoo.gfootball.model.TamakEriFever.football.rulebaseB as rulebaseB
+# import dizoo.gfootball.model.TamakEriFever.football.rulebaseC as rulebaseC
+# #import football.rulebaseD as rulebaseD
+# import dizoo.gfootball.model.TamakEriFever.football.rulebaseE as rulebaseE
+# import dizoo.gfootball.model.TamakEriFever.football.rulebaseF as rulebaseF
+
+
+class MultiHeadAttention(nn.Module):
+ # multi head attention for sets
+ # https://github.com/akurniawan/pytorch-transformer/blob/master/modules/attention.py
+ def __init__(self, in_dim, out_dim, out_heads, relation_dim=0, residual=False, projection=True, layer_norm=True):
+ super().__init__()
+ self.in_dim = in_dim
+ self.out_dim = out_dim
+ self.out_heads = out_heads
+ self.relation_dim = relation_dim
+ assert self.out_dim % self.out_heads == 0
+ self.query_layer = nn.Linear(self.in_dim + self.relation_dim, self.out_dim, bias=False)
+ self.key_layer = nn.Linear(self.in_dim + self.relation_dim, self.out_dim, bias=False)
+ self.value_layer = nn.Linear(self.in_dim, self.out_dim, bias=False)
+ self.residual = residual
+ self.projection = projection
+ if self.projection:
+ self.proj_layer = nn.Linear(self.out_dim, self.out_dim)
+ self.layer_norm = layer_norm
+ if self.layer_norm:
+ self.ln = nn.LayerNorm(self.out_dim)
+
+ self.reset_parameters()
+
+ def reset_parameters(self):
+ nn.init.uniform_(self.query_layer.weight, -0.1, 0.1)
+ nn.init.uniform_(self.key_layer.weight, -0.1, 0.1)
+ nn.init.uniform_(self.value_layer.weight, -0.1, 0.1)
+ if self.projection:
+ nn.init.uniform_(self.proj_layer.weight, -0.1, 0.1)
+
+ def forward(self, query, key, relation=None, mask=None, key_mask=None, distance=None):
+ """
+ Args:
+ query (torch.Tensor): [batch, query_len, in_dim]
+ key (torch.Tensor): [batch, key_len, in_dim]
+ relation (torch.Tensor): [batch, query_len, key_len, relation_dim]
+ mask (torch.Tensor): [batch, query_len]
+ key_mask (torch.Tensor): [batch, key_len]
+ Returns:
+ torch.Tensor: [batch, query_len, out_dim]
+ """
+
+ query_len = query.size(-2)
+ key_len = key.size(-2)
+ head_dim = self.out_dim // self.out_heads
+
+ if key_mask is None:
+ if torch.equal(query, key):
+ key_mask = mask
+
+ if relation is not None:
+ relation = relation.view(-1, query_len, key_len, self.relation_dim)
+
+ query_ = query.view(-1, query_len, 1, self.in_dim).repeat(1, 1, key_len, 1)
+ query_ = torch.cat([query_, relation], dim=-1)
+
+ key_ = key.view(-1, 1, key_len, self.in_dim).repeat(1, query_len, 1, 1)
+ key_ = torch.cat([key_, relation], dim=-1)
+
+ Q = self.query_layer(query_).view(-1, query_len * key_len, self.out_heads, head_dim)
+ K = self.key_layer(key_).view(-1, query_len * key_len, self.out_heads, head_dim)
+
+ Q = Q.transpose(1, 2).contiguous().view(-1, query_len, key_len, head_dim)
+ K = K.transpose(1, 2).contiguous().view(-1, query_len, key_len, head_dim)
+
+ attention = (Q * K).sum(dim=-1)
+ else:
+ Q = self.query_layer(query).view(-1, query_len, self.out_heads, head_dim)
+ K = self.key_layer(key).view(-1, key_len, self.out_heads, head_dim)
+
+ Q = Q.transpose(1, 2).contiguous().view(-1, query_len, head_dim)
+ K = K.transpose(1, 2).contiguous().view(-1, key_len, head_dim)
+
+ attention = torch.bmm(Q, K.transpose(1, 2))
+
+ if distance is not None:
+ attention = attention - torch.log1p(distance.repeat(self.out_heads, 1, 1))
+ attention = attention * (float(head_dim) ** -0.5)
+
+ if key_mask is not None:
+ attention = attention.view(-1, self.out_heads, query_len, key_len)
+ attention = attention + ((1 - key_mask) * -1e32).view(-1, 1, 1, key_len)
+ attention = F.softmax(attention, dim=-1)
+ if mask is not None:
+ attention = attention * mask.view(-1, 1, query_len, 1)
+ attention = attention.contiguous().view(-1, query_len, key_len)
+
+ V = self.value_layer(key).view(-1, key_len, self.out_heads, head_dim)
+ V = V.transpose(1, 2).contiguous().view(-1, key_len, head_dim)
+
+ output = torch.bmm(attention, V).view(-1, self.out_heads, query_len, head_dim)
+ output = output.transpose(1, 2).contiguous().view(*query.size()[:-2], query_len, self.out_dim)
+
+ if self.projection:
+ output = self.proj_layer(output)
+
+ if self.residual:
+ output = output + query
+
+ if self.layer_norm:
+ output = self.ln(output)
+
+ if mask is not None:
+ output = output * mask.unsqueeze(-1)
+ attention = attention.view(*query.size()[:-2], self.out_heads, query_len, key_len).detach()
+
+ return output, attention
+
+
+class ResidualBlock(nn.Module):
+
+ def __init__(self, in_channels, out_channels, activation='relu'):
+ super().__init__()
+ self.in_channels, self.out_channels, self.activation = in_channels, out_channels, activation
+ self.blocks = nn.Identity()
+ self.activate = nn.ReLU() # activation_func(activation)
+ self.shortcut = nn.Identity()
+
+ def forward(self, x):
+ residual = x
+ if self.should_apply_shortcut:
+ residual = self.shortcut(x)
+ x = self.blocks(x)
+ x += residual
+ x = self.activate(x)
+ return x
+
+ @property
+ def should_apply_shortcut(self):
+ return self.in_channels != self.out_channels
+
+
+class Conv2dAuto(nn.Conv2d):
+
+ def __init__(self, *args, **kwargs):
+ super().__init__(*args, **kwargs)
+ self.padding = (
+ self.kernel_size[0] // 2, self.kernel_size[1] // 2
+ ) # dynamic add padding based on the kernel_size
+
+
+class ResNetResidualBlock(ResidualBlock):
+
+ def __init__(self, in_channels, out_channels, expansion=1, downsampling=1, *args, **kwargs):
+ super().__init__(in_channels, out_channels, *args, **kwargs)
+ self.expansion, self.downsampling, self.conv = expansion, downsampling, partial(
+ Conv2dAuto, kernel_size=3, bias=False
+ )
+ self.shortcut = nn.Sequential(
+ nn.Conv2d(self.in_channels, self.expanded_channels, kernel_size=1, stride=self.downsampling, bias=False),
+ nn.BatchNorm2d(self.expanded_channels)
+ ) if self.should_apply_shortcut else None
+
+ @property
+ def expanded_channels(self):
+ return self.out_channels * self.expansion
+
+ @property
+ def should_apply_shortcut(self):
+ return self.in_channels != self.expanded_channels
+
+
+def activation_func(activation):
+ return nn.ModuleDict(
+ [
+ ['relu', nn.ReLU(inplace=True)], ['leaky_relu',
+ nn.LeakyReLU(negative_slope=0.01, inplace=True)],
+ ['selu', nn.SELU(inplace=True)], ['none', nn.Identity()]
+ ]
+ )[activation]
+
+
+def conv_bn(in_channels, out_channels, conv, *args, **kwargs):
+ conv3x3 = partial(Conv2dAuto, kernel_size=3, bias=False)
+ return nn.Sequential(conv3x3(in_channels, out_channels, *args, **kwargs), nn.BatchNorm2d(out_channels))
+
+
+class ResNetBasicBlock(ResNetResidualBlock):
+ """
+    Basic ResNet block composed of two 3x3 conv / batchnorm / activation layers
+ """
+ expansion = 1
+
+ def __init__(self, in_channels, out_channels, *args, **kwargs):
+ super().__init__(in_channels, out_channels, *args, **kwargs)
+ self.blocks = nn.Sequential(
+ conv_bn(self.in_channels, self.out_channels, conv=self.conv, bias=False, stride=self.downsampling),
+ activation_func(self.activation),
+ conv_bn(self.out_channels, self.expanded_channels, conv=self.conv, bias=False),
+ )
+
+
+class FootballNet(BaseModel):
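+    # Overall structure (see __init__/forward below): a player-set encoder, stacked multi-head
+    # attention blocks, a control module fused with a CNN branch and an action-history GRU,
+    # and a head producing policy logits plus tanh-squashed value (v) and return (r) outputs.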
+
+ class FootballEncoder(nn.Module):
+
+ def __init__(self, filters):
+ super().__init__()
+ self.player_embedding = nn.Embedding(32, 5, padding_idx=0)
+ self.mode_embedding = nn.Embedding(8, 3, padding_idx=0)
+ self.fc_teammate = nn.Linear(23, filters)
+ self.fc_opponent = nn.Linear(23, filters)
+ self.fc = nn.Linear(filters + 41, filters)
+
+ def forward(self, x):
+ bs = x['mode_index'].size(0)
+ # scalar features
+ m_emb = self.mode_embedding(x['mode_index']).view(bs, -1)
+ ball = x['ball']
+ s = torch.cat([ball, x['match'], x['distance']['b2o'].view(bs, -1), m_emb], dim=1)
+
+ # player features
+ p_emb_self = self.player_embedding(x['player_index']['self'])
+ ball_concat_self = ball.view(bs, 1, -1).repeat(1, x['player']['self'].size(1), 1)
+ p_self = torch.cat([x['player']['self'], p_emb_self, ball_concat_self], dim=2)
+
+ p_emb_opp = self.player_embedding(x['player_index']['opp'])
+ ball_concat_opp = ball.view(bs, 1, -1).repeat(1, x['player']['opp'].size(1), 1)
+ p_opp = torch.cat([x['player']['opp'], p_emb_opp, ball_concat_opp], dim=2)
+
+ # encoding linear layer
+ p_self = self.fc_teammate(p_self)
+ p_opp = self.fc_opponent(p_opp)
+
+ p = F.relu(torch.cat([p_self, p_opp], dim=1))
+ s_concat = s.view(bs, 1, -1).repeat(1, p.size(1), 1)
+ """
+ TODO(pu): How to deal with dimension mismatch better?
+ original code is:
+ p = torch.cat([p, x['distance']['p2bo'].view(bs, p.size(1), -1), s_concat], dim=2)
+ """
+ p = torch.cat([p, x['distance']['p2bo'].repeat(1, 2, 1).view(bs, p.size(1), -1), s_concat], dim=2)
+ h = F.relu(self.fc(p))
+
+ # relation
+ rel = None # x['distance']['p2p']
+ distance = None # x['distance']['p2p']
+
+ return h, rel, distance
+
+ class FootballBlock(nn.Module):
+
+ def __init__(self, filters, heads):
+ super().__init__()
+ self.attention = MultiHeadAttention(filters, filters, heads, relation_dim=0, residual=True, projection=True)
+
+ def forward(self, x, rel, distance=None):
+ h, _ = self.attention(x, x, relation=rel, distance=distance)
+ return h
+
+ class FootballControll(nn.Module):
+
+ def __init__(self, filters, final_filters):
+ super().__init__()
+ self.filters = filters
+ self.attention = MultiHeadAttention(filters, filters, 1, residual=False, projection=True)
+ # self.fc_control = Dense(filters * 3, final_filters, bnunits=final_filters)
+ self.fc_control = Dense(filters * 3, final_filters, bnunits=final_filters)
+
+ def forward(self, x, e, control_flag):
+ x_controled = (x * control_flag).sum(dim=1, keepdim=True)
+ e_controled = (e * control_flag).sum(dim=1, keepdim=True)
+
+ h, _ = self.attention(x_controled, x)
+
+ h = torch.cat([x_controled, e_controled, h], dim=2).view(x.size(0), -1)
+ # h = torch.cat([h, cnn_h.view(cnn_h.size(0), -1)], dim=1)
+ h = self.fc_control(h)
+ return h
+
+ class FootballHead(nn.Module):
+
+ def __init__(self, filters):
+ super().__init__()
+ self.head_p = nn.Linear(filters, 19, bias=False)
+ self.head_p_special = nn.Linear(filters, 1 + 8 * 4, bias=False)
+ self.head_v = nn.Linear(filters, 1, bias=True)
+ self.head_r = nn.Linear(filters, 1, bias=False)
+
+ def forward(self, x):
+ p = self.head_p(x)
+ p2 = self.head_p_special(x)
+ v = self.head_v(x)
+ r = self.head_r(x)
+ return torch.cat([p, p2], -1), v, r
+
+ class CNNModel(nn.Module):
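+        # 1x1 convolutions over the 53-channel 11x11 pairwise 'cnn_feature' map built in feature_from_states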
+
+ def __init__(self, final_filters):
+ super().__init__()
+ self.conv1 = nn.Sequential(
+ nn.Conv2d(53, 128, kernel_size=1, stride=1, bias=False), nn.ReLU(inplace=True),
+ nn.Conv2d(128, 160, kernel_size=1, stride=1, bias=False), nn.ReLU(inplace=True),
+ nn.Conv2d(160, 128, kernel_size=1, stride=1, bias=False), nn.ReLU(inplace=True)
+ )
+ self.pool1 = nn.AdaptiveAvgPool2d((1, 11))
+ self.conv2 = nn.Sequential(
+ nn.BatchNorm2d(128),
+ nn.Conv2d(128, 160, kernel_size=(1, 1), stride=1, bias=False),
+ nn.ReLU(inplace=True),
+ nn.BatchNorm2d(160),
+ nn.Conv2d(160, 96, kernel_size=(1, 1), stride=1, bias=False),
+ nn.ReLU(inplace=True),
+ nn.BatchNorm2d(96),
+ nn.Conv2d(96, final_filters, kernel_size=(1, 1), stride=1, bias=False),
+ nn.ReLU(inplace=True),
+ nn.BatchNorm2d(final_filters),
+ )
+ self.pool2 = nn.AdaptiveAvgPool2d((1, 1))
+ self.flatten = nn.Flatten()
+
+ def forward(self, x):
+ x = x['cnn_feature']
+ x = self.conv1(x)
+ x = self.pool1(x)
+ x = self.conv2(x)
+ x = self.pool2(x)
+ x = self.flatten(x)
+ return x
+
+ class SMMEncoder(nn.Module):
+
+ class SMMBlock(nn.Module):
+
+ def __init__(self, in_filters, out_filters, residuals=2):
+ super().__init__()
+ self.conv1 = nn.Conv2d(in_filters, out_filters, kernel_size=3, stride=1, bias=False)
+ self.pool1 = nn.MaxPool2d(3, stride=2)
+ self.blocks = nn.ModuleList([ResNetBasicBlock(out_filters, out_filters) for _ in range(residuals)])
+
+ def forward(self, x):
+ h = self.conv1(x)
+ h = self.pool1(h)
+ for block in self.blocks:
+ h = block(h)
+ return h
+
+ def __init__(self, filters):
+ super().__init__()
+ # 4, 72, 96 => filters, 1, 3
+ self.blocks = nn.ModuleList(
+ [
+ self.SMMBlock(4, filters),
+ self.SMMBlock(filters, filters),
+ self.SMMBlock(filters, filters),
+ self.SMMBlock(filters, filters),
+ ]
+ )
+
+ def forward(self, x):
+ x = x['smm']
+ h = x
+ for block in self.blocks:
+ h = block(h)
+ h = F.relu(h)
+ return h
+
+ class ActionHistoryEncoder(nn.Module):
+
+ def __init__(self, input_size=19, hidden_size=64, num_layers=2, bidirectional=True):
+ super().__init__()
+ self.action_emd = nn.Embedding(19, 8)
+ self.rnn = nn.GRU(8, hidden_size, num_layers, batch_first=True, bidirectional=bidirectional)
+
+ def forward(self, x):
+ h = self.action_emd(x['action_history'])
+ h = h.squeeze(dim=2)
+ self.rnn.flatten_parameters()
+ h, _ = self.rnn(h)
+ return h
+
+ def __init__(self, env, args={}, action_length=None):
+ super().__init__(env, args, action_length)
+ blocks = 5
+ filters = 96
+ final_filters = 128
+ smm_filters = 32
+ self.encoder = self.FootballEncoder(filters)
+ self.blocks = nn.ModuleList([self.FootballBlock(filters, 8) for _ in range(blocks)])
+ self.control = self.FootballControll(filters, final_filters) # to head
+
+ self.cnn = self.CNNModel(final_filters) # to control
+ # self.smm = self.SMMEncoder(smm_filters) # to control
+ rnn_hidden = 64
+ self.rnn = self.ActionHistoryEncoder(19, rnn_hidden, 2)
+
+ self.head = self.FootballHead(final_filters + final_filters + rnn_hidden * 2)
+ # self.head = self.FootballHead(19, final_filters)
+
+ def init_hidden(self, batch_size=None):
+ return None
+
+ def forward(self, x, hidden):
+ e, rel, distance = self.encoder(x)
+ h = e
+ for block in self.blocks:
+ h = block(h, rel, distance)
+ cnn_h = self.cnn(x)
+ # smm_h = self.smm(x)
+ # h = self.control(h, e, x['control_flag'], cnn_h, smm_h)
+ h = self.control(h, e, x['control_flag'])
+ rnn_h = self.rnn(x)
+
+ # p, v, r = self.head(torch.cat([h,
+ # cnn_h.view(cnn_h.size(0), -1),
+ # smm_h.view(smm_h.size(0), -1)], axis=-1))
+
+ rnn_h_head_tail = rnn_h[:, 0, :] + rnn_h[:, -1, :]
+ rnn_h_plus_stick = torch.cat([rnn_h_head_tail[:, :-4], x['control']], dim=1)
+ p, v, r = self.head(torch.cat([
+ h,
+ cnn_h.view(cnn_h.size(0), -1),
+ rnn_h_plus_stick,
+ ], axis=-1))
+ # p, v, r = self.head(h)
+
+ return p, torch.tanh(v), torch.tanh(r), hidden
+
+
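+# default raw observation, returned by Environment.raw_observation before any env state has been recorded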
+OBS_TEMPLATE = {
+ "controlled_players": 1,
+ "players_raw": [
+ {
+ "right_team_active": [True, True, True, True, True, True, True, True, True, True, True],
+ "right_team_yellow_card": [False, False, False, False, False, False, False, False, False, False, False],
+ "left_team_tired_factor": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+ "right_team_roles": [0, 2, 1, 1, 3, 5, 5, 5, 6, 9, 7],
+ "left_team": [
+ [-1.0110293626785278, -0.0], [-0.4266543984413147, -0.19894461333751678],
+ [-0.5055146813392639, -0.06459399312734604], [-0.5055146813392639, 0.06459297984838486],
+ [-0.4266543984413147, 0.19894461333751678], [-0.18624374270439148, -0.10739918798208237],
+ [-0.270525187253952, -0.0], [-0.18624374270439148, 0.10739918798208237],
+ [-0.010110294446349144, -0.21961550414562225], [-0.05055147036910057, -0.0],
+ [-0.010110294446349144, 0.21961753070354462]
+ ],
+ "ball": [0.0, -0.0, 0.11061639338731766],
+ "ball_owned_team": -1,
+ "right_team_direction": [
+ [-0.0, 0.0], [-0.0, 0.0], [-0.0, 0.0], [-0.0, 0.0], [-0.0, 0.0], [-0.0, 0.0], [-0.0, 0.0], [-0.0, 0.0],
+ [-0.0, 0.0], [-0.0, 0.0], [-0.0, 0.0]
+ ],
+ "left_team_direction": [
+ [0.0, -0.0], [0.0, -0.0], [0.0, -0.0], [0.0, -0.0], [0.0, -0.0], [0.0, -0.0], [0.0, -0.0], [0.0, -0.0],
+ [0.0, -0.0], [0.0, -0.0], [0.0, -0.0]
+ ],
+ "left_team_roles": [0, 2, 1, 1, 3, 5, 5, 5, 6, 9, 7],
+ "score": [0, 0],
+ "left_team_active": [True, True, True, True, True, True, True, True, True, True, True],
+ "game_mode": 0,
+ "steps_left": 3001,
+ "ball_direction": [-0.0, 0.0, 0.006163952872157097],
+ "ball_owned_player": -1,
+ "right_team": [
+ [1.0110293626785278, 0.0], [0.4266543984413147, 0.19894461333751678],
+ [0.5055146813392639, 0.06459399312734604], [0.5055146813392639, -0.06459297984838486],
+ [0.4266543984413147, -0.19894461333751678], [0.18624374270439148, 0.10739918798208237],
+ [0.270525187253952, 0.0], [0.18624374270439148, -0.10739918798208237],
+ [0.010110294446349144, 0.21961550414562225], [-0.0, -0.02032535709440708], [-0.0, 0.02032535709440708]
+ ],
+ "left_team_yellow_card": [False, False, False, False, False, False, False, False, False, False, False],
+ "ball_rotation": [0.0, -0.0, 0.0],
+ "right_team_tired_factor": [0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0],
+ "designated": 6,
+ "active": 6,
+ "sticky_actions": [0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ }
+ ]
+}
+
+INFO_TEMPLATE = {'half_step': 1500}
+
+
+# feature
+def feature_from_states(states, info, player):
+ # observation list to input tensor
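+    # the newest of the last HISTORY_LENGTH states drives the spatial features; the action history covers all of them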
+
+ HISTORY_LENGTH = 8
+
+ obs_history_ = [s[player]['observation']['players_raw'][0] for s in reversed(states[-HISTORY_LENGTH:])]
+ obs_history = obs_history_ + [obs_history_[-1]] * (HISTORY_LENGTH - len(obs_history_))
+ obs = obs_history[0]
+
+ action_history_ = [s[player]['action'][0] for s in reversed(states[-HISTORY_LENGTH:])]
+ action_history = action_history_ + [0] * (HISTORY_LENGTH - len(action_history_))
+ """
+ ・left players (x)
+ ・left players (y)
+ ・right players (x)
+ ・right players (y)
+ ・ball (x)
+ ・ball (y)
+ ・left goal (x)
+ ・left goal (y)
+ ・right goal (x)
+ ・right goal (y)
+ ・active (x)
+ ・active (y)
+
+ ・left players (x) - right players (x)
+ ・left players (y) - right players (y)
+ ・left players (x) - ball (x)
+ ・left players (y) - ball (y)
+ ・left players (x) - goal (x)
+ ・left players (y) - goal (y)
+ ・left players (x) - active (x)
+ ・left players (y) - active (y)
+
+ ・left players direction (x)
+ ・left players direction (y)
+ ・right players direction (x)
+ ・right players direction (y)
+ ・left players direction (x) - right players direction (x)
+ ・left players direction (y) - right players direction (y)
+ """
+
+ # left players
+ obs_left_team = np.array(obs['left_team'])
+ left_player_x = np.repeat(obs_left_team[:, 0][..., None], 11, axis=1)
+ left_player_y = np.repeat(obs_left_team[:, 1][..., None], 11, axis=1)
+
+ # right players
+ obs_right_team = np.array(obs['right_team'])
+ right_player_x = np.repeat(obs_right_team[:, 0][..., None], 11, axis=1).transpose(1, 0)
+ right_player_y = np.repeat(obs_right_team[:, 1][..., None], 11, axis=1).transpose(1, 0)
+
+ # ball
+ obs_ball = np.array(obs['ball'])
+ ball_x = np.ones((11, 11)) * obs_ball[0]
+ ball_y = np.ones((11, 11)) * obs_ball[1]
+ ball_z = np.ones((11, 11)) * obs_ball[2]
+
+ # goal
+ left_goal, right_goal = [-1, 0], [1, 0]
+ left_goal_x = np.ones((11, 11)) * left_goal[0]
+ left_goal_y = np.ones((11, 11)) * left_goal[1]
+ right_goal_x = np.ones((11, 11)) * right_goal[0]
+ right_goal_y = np.ones((11, 11)) * right_goal[1]
+
+ # side line
+ side_line_y = [-.42, .42]
+ side_line_y_top = np.ones((11, 11)) * side_line_y[0]
+ side_line_y_bottom = np.ones((11, 11)) * side_line_y[1]
+
+ # active
+ active = np.array(obs['active'])
+ active_player_x = np.repeat(obs_left_team[active][0][..., None, None], 11, axis=1).repeat(11, axis=0)
+ active_player_y = np.repeat(obs_left_team[active][1][..., None, None], 11, axis=1).repeat(11, axis=0)
+
+ # left players - right players
+ left_minus_right_player_x = obs_left_team[:, 0][..., None] - obs_right_team[:, 0]
+ left_minus_right_player_y = obs_left_team[:, 1][..., None] - obs_right_team[:, 1]
+
+ # left players - ball
+ left_minus_ball_x = (obs_left_team[:, 0][..., None] - obs_ball[0]).repeat(11, axis=1)
+ left_minus_ball_y = (obs_left_team[:, 1][..., None] - obs_ball[1]).repeat(11, axis=1)
+
+ # left players - right goal
+ left_minus_right_goal_x = (obs_left_team[:, 0][..., None] - right_goal[0]).repeat(11, axis=1)
+ left_minus_right_goal_y = (obs_left_team[:, 1][..., None] - right_goal[1]).repeat(11, axis=1)
+
+ # left players - left goal
+ left_minus_left_goal_x = (obs_left_team[:, 0][..., None] - left_goal[0]).repeat(11, axis=1)
+ left_minus_left_goal_y = (obs_left_team[:, 1][..., None] - left_goal[1]).repeat(11, axis=1)
+
+ # right players - right goal
+ right_minus_right_goal_x = (obs_right_team[:, 0][..., None] - right_goal[0]).repeat(11, axis=1).transpose(1, 0)
+ right_minus_right_goal_y = (obs_right_team[:, 1][..., None] - right_goal[1]).repeat(11, axis=1).transpose(1, 0)
+
+ # right players - left goal
+ right_minus_left_goal_x = (obs_right_team[:, 0][..., None] - left_goal[0]).repeat(11, axis=1).transpose(1, 0)
+ right_minus_left_goal_y = (obs_right_team[:, 1][..., None] - left_goal[1]).repeat(11, axis=1).transpose(1, 0)
+
+ # left players (x) - active
+ left_minus_active_x = (obs_left_team[:, 0][..., None] - obs_left_team[active][0]).repeat(11, axis=1)
+ left_minus_active_y = (obs_left_team[:, 1][..., None] - obs_left_team[active][1]).repeat(11, axis=1)
+
+ # right player - ball
+ right_minus_ball_x = (obs_right_team[:, 0][..., None] - obs_ball[0]).repeat(11, axis=1).transpose(1, 0)
+ right_minus_ball_y = (obs_right_team[:, 1][..., None] - obs_ball[1]).repeat(11, axis=1).transpose(1, 0)
+
+ # right player - active
+ right_minus_active_x = (obs_right_team[:, 0][..., None] - obs_left_team[active][0]).repeat(
+ 11, axis=1
+ ).transpose(1, 0)
+ right_minus_active_y = (obs_right_team[:, 1][..., None] - obs_left_team[active][1]).repeat(
+ 11, axis=1
+ ).transpose(1, 0)
+
+ # left player - side line
+ left_minus_side_top = np.abs(obs_left_team[:, 1][..., None] - side_line_y[0]).repeat(11, axis=1)
+ left_minus_side_bottom = np.abs(obs_left_team[:, 1][..., None] - side_line_y[1]).repeat(11, axis=1)
+
+ # right player - side line
+ right_minus_side_top = np.abs(obs_right_team[:, 1][..., None] - side_line_y[0]).repeat(11, axis=1).transpose(1, 0)
+ right_minus_side_bottom = np.abs(obs_right_team[:, 1][..., None] - side_line_y[1]).repeat(
+ 11, axis=1
+ ).transpose(1, 0)
+
+ # left players direction
+ obs_left_team_direction = np.array(obs['left_team_direction'])
+ left_player_direction_x = np.repeat(obs_left_team_direction[:, 0][..., None], 11, axis=1)
+ left_player_direction_y = np.repeat(obs_left_team_direction[:, 1][..., None], 11, axis=1)
+
+ # right players direction
+ obs_right_team_direction = np.array(obs['right_team_direction'])
+ right_player_direction_x = np.repeat(obs_right_team_direction[:, 0][..., None], 11, axis=1).transpose(1, 0)
+ right_player_direction_y = np.repeat(obs_right_team_direction[:, 1][..., None], 11, axis=1).transpose(1, 0)
+
+ # ball direction
+ obs_ball_direction = np.array(obs['ball_direction'])
+ ball_direction_x = np.ones((11, 11)) * obs_ball_direction[0]
+ ball_direction_y = np.ones((11, 11)) * obs_ball_direction[1]
+ ball_direction_z = np.ones((11, 11)) * obs_ball_direction[2]
+
+ # left players direction - right players direction
+ left_minus_right_player_direction_x = obs_left_team_direction[:, 0][..., None] - obs_right_team_direction[:, 0]
+ left_minus_right_player_direction_y = obs_left_team_direction[:, 1][..., None] - obs_right_team_direction[:, 1]
+
+ # left players direction - ball direction
+ left_minus_ball_direction_x = (obs_left_team_direction[:, 0][..., None] - obs_ball_direction[0]).repeat(11, axis=1)
+ left_minus_ball_direction_y = (obs_left_team_direction[:, 1][..., None] - obs_ball_direction[1]).repeat(11, axis=1)
+
+ # right players direction - ball direction
+ right_minus_ball_direction_x = (obs_right_team_direction[:, 0][..., None] - obs_ball_direction[0]).repeat(
+ 11, axis=1
+ ).transpose(1, 0)
+ right_minus_ball_direction_y = (obs_right_team_direction[:, 1][..., None] - obs_ball_direction[1]).repeat(
+ 11, axis=1
+ ).transpose(1, 0)
+
+ # ball rotation
+ obs_ball_rotation = np.array(obs['ball_rotation'])
+ ball_rotation_x = np.ones((11, 11)) * obs_ball_rotation[0]
+ ball_rotation_y = np.ones((11, 11)) * obs_ball_rotation[1]
+ ball_rotation_z = np.ones((11, 11)) * obs_ball_rotation[2]
+
+ cnn_feature = np.stack(
+ [
+ left_player_x,
+ left_player_y,
+ right_player_x,
+ right_player_y,
+ ball_x,
+ ball_y,
+ ball_z,
+ left_goal_x,
+ left_goal_y,
+ right_goal_x,
+ right_goal_y,
+ side_line_y_top,
+ side_line_y_bottom,
+ active_player_x,
+ active_player_y,
+ left_minus_right_player_x,
+ left_minus_right_player_y,
+ left_minus_right_goal_x,
+ left_minus_right_goal_y,
+ left_minus_left_goal_x,
+ left_minus_left_goal_y,
+ right_minus_right_goal_x,
+ right_minus_right_goal_y,
+ right_minus_left_goal_x,
+ right_minus_left_goal_y,
+ left_minus_side_top,
+ left_minus_side_bottom,
+ right_minus_side_top,
+ right_minus_side_bottom,
+ right_minus_ball_x,
+ right_minus_ball_y,
+ right_minus_active_x,
+ right_minus_active_y,
+ left_minus_ball_x,
+ left_minus_ball_y,
+ left_minus_active_x,
+ left_minus_active_y,
+ ball_direction_x,
+ ball_direction_y,
+ ball_direction_z,
+ left_minus_ball_direction_x,
+ left_minus_ball_direction_y,
+ right_minus_ball_direction_x,
+ right_minus_ball_direction_y,
+ left_player_direction_x,
+ left_player_direction_y,
+ right_player_direction_x,
+ right_player_direction_y,
+ left_minus_right_player_direction_x,
+ left_minus_right_player_direction_y,
+ ball_rotation_x,
+ ball_rotation_y,
+ ball_rotation_z,
+ ],
+ axis=0
+ )
+
+ # ball
+    BALL_OWNED_1HOT = {-1: [0, 0], 0: [1, 0], 1: [0, 1]}
+    ball_owned_team_ = obs['ball_owned_team']
+    ball_owned_team = BALL_OWNED_1HOT[ball_owned_team_]  # {-1, 0, 1} None, self, opponent
+ PLAYER_1HOT = np.concatenate([np.eye(11), np.zeros((1, 11))])
+ ball_owned_player_ = PLAYER_1HOT[obs['ball_owned_player']] # {-1, N-1}
+ if ball_owned_team_ == -1:
+ my_ball_owned_player = PLAYER_1HOT[-1]
+ op_ball_owned_player = PLAYER_1HOT[-1]
+ elif ball_owned_team_ == 0:
+ my_ball_owned_player = ball_owned_player_
+ op_ball_owned_player = PLAYER_1HOT[-1]
+ else:
+ my_ball_owned_player = PLAYER_1HOT[-1]
+ op_ball_owned_player = ball_owned_player_
+
+ ball_features = np.concatenate([obs['ball'], obs['ball_direction'], obs['ball_rotation']]).astype(np.float32)
+
+ # self team
+ left_team_features = np.concatenate(
+ [
+ [[1] for _ in obs['left_team']], # left team flag
+ obs['left_team'], # position
+ obs['left_team_direction'],
+ [[v] for v in obs['left_team_tired_factor']],
+ [[v] for v in obs['left_team_yellow_card']],
+ [[v] for v in obs['left_team_active']],
+ my_ball_owned_player[..., np.newaxis]
+ ],
+ axis=1
+ ).astype(np.float32)
+
+ left_team_indice = np.arange(0, 11, dtype=np.int32)
+
+ # opponent team
+ right_team_features = np.concatenate(
+ [
+ [[0] for _ in obs['right_team']], # right team flag
+ obs['right_team'], # position
+ obs['right_team_direction'],
+ [[v] for v in obs['right_team_tired_factor']],
+ [[v] for v in obs['right_team_yellow_card']],
+ [[v] for v in obs['right_team_active']],
+ op_ball_owned_player[..., np.newaxis]
+ ],
+ axis=1
+ ).astype(np.float32)
+
+ right_team_indice = np.arange(0, 11, dtype=np.int32)
+
+ # distance information
+ def get_distance(xy1, xy2):
+ return (((xy1 - xy2) ** 2).sum(axis=-1)) ** 0.5
+
+ def get_line_distance(x1, x2):
+ return np.abs(x1 - x2)
+
+ def multi_scale(x, scale):
+ return 2 / (1 + np.exp(-np.array(x)[..., np.newaxis] / np.array(scale)))
+
+ both_team = np.array(obs['left_team'] + obs['right_team'], dtype=np.float32)
+ ball = np.array([obs['ball'][:2]], dtype=np.float32)
+ goal = np.array([[-1, 0], [1, 0]], dtype=np.float32)
+ goal_line_x = np.array([-1, 1], dtype=np.float32)
+ side_line_y = np.array([-.42, .42], dtype=np.float32)
+
+ # ball <-> goal, goal line, side line distance
+ b2g_distance = get_distance(ball, goal)
+ b2gl_distance = get_line_distance(ball[0][0], goal_line_x)
+ b2sl_distance = get_line_distance(ball[0][1], side_line_y)
+ b2o_distance = np.concatenate([b2g_distance, b2gl_distance, b2sl_distance], axis=-1)
+
+ # player <-> ball, goal, back line, side line distance
+ p2b_distance = get_distance(both_team[:, np.newaxis, :], ball[np.newaxis, :, :])
+ p2g_distance = get_distance(both_team[:, np.newaxis, :], goal[np.newaxis, :, :])
+ p2gl_distance = get_line_distance(both_team[:, :1], goal_line_x[np.newaxis, :])
+ p2sl_distance = get_line_distance(both_team[:, 1:], side_line_y[np.newaxis, :])
+ p2bo_distance = np.concatenate([p2b_distance, p2g_distance, p2gl_distance, p2sl_distance], axis=-1)
+
+ # player <-> player distance
+ p2p_distance = get_distance(both_team[:, np.newaxis, :], both_team[np.newaxis, :, :])
+
+ # apply Multiscale to distances
+ # def concat_multiscale(x, scale):
+ # return np.concatenate([x[...,np.newaxis], 1 - multi_scale(x, scale)], axis=-1)
+
+ # distance_scales = [.01, .05, .25, 1.25]
+ # b2o_distance = 1 - multi_scale(b2o_distance, distance_scales).reshape(-1)
+ # p2bo_distance = 1 - multi_scale(p2bo_distance, distance_scales).reshape(len(both_team), -1)
+ # p2p_distance = 1 - multi_scale(p2p_distance, distance_scales).reshape(len(both_team), len(both_team), -1)
+
+ # controlled player information
+ control_flag_ = np.array(PLAYER_1HOT[obs['active']], dtype=np.float32)
+ control_flag = np.concatenate([control_flag_, np.zeros(len(obs['right_team']))])[..., np.newaxis]
+
+ # controlled status information
+ DIR = [
+ [-1, 0],
+ [-.707, -.707],
+ [0, 1],
+ [.707, -.707], # L, TL, T, TR
+ [1, 0],
+ [.707, .707],
+ [0, -1],
+ [-.707, .707] # R, BR, B, BL
+ ]
+ sticky_direction = DIR[obs['sticky_actions'][:8].index(1)] if 1 in obs['sticky_actions'][:8] else [0, 0]
+ sticky_flags = obs['sticky_actions'][8:]
+
+ control_features = np.concatenate([
+ sticky_direction,
+ sticky_flags,
+ ]).astype(np.float32)
+
+ # Match state
+ if obs['steps_left'] > info['half_step']:
+ steps_left_half = obs['steps_left'] - info['half_step']
+ else:
+ steps_left_half = obs['steps_left']
+ match_features = np.concatenate(
+ [
+ multi_scale(obs['score'], [1, 3]).ravel(),
+ multi_scale(obs['score'][0] - obs['score'][1], [1, 3]),
+ multi_scale(obs['steps_left'], [10, 100, 1000, 10000]),
+ multi_scale(steps_left_half, [10, 100, 1000, 10000]),
+ ball_owned_team,
+ ]
+ ).astype(np.float32)
+
+ mode_index = np.array([obs['game_mode']], dtype=np.int32)
+
+ # Super Mini Map
+ # SMM_WIDTH = 96 #// 3
+ # SMM_HEIGHT = 72 #// 3
+ # SMM_LAYERS = ['left_team', 'right_team', 'ball', 'active']
+
+ # # Normalized minimap coordinates
+ # MINIMAP_NORM_X_MIN = -1.0
+ # MINIMAP_NORM_X_MAX = 1.0
+ # MINIMAP_NORM_Y_MIN = -1.0 / 2.25
+ # MINIMAP_NORM_Y_MAX = 1.0 / 2.25
+
+ # _MARKER_VALUE = 1 # 255
+
+ # def get_smm_layers(config):
+ # return SMM_LAYERS
+
+ # def mark_points(frame, points):
+ # """Draw dots corresponding to 'points'.
+ # Args:
+ # frame: 2-d matrix representing one SMM channel ([y, x])
+ # points: a list of (x, y) coordinates to be marked
+ # """
+ # for p in range(len(points) // 2):
+ # x = int((points[p * 2] - MINIMAP_NORM_X_MIN) /
+ # (MINIMAP_NORM_X_MAX - MINIMAP_NORM_X_MIN) * frame.shape[1])
+ # y = int((points[p * 2 + 1] - MINIMAP_NORM_Y_MIN) /
+ # (MINIMAP_NORM_Y_MAX - MINIMAP_NORM_Y_MIN) * frame.shape[0])
+ # x = max(0, min(frame.shape[1] - 1, x))
+ # y = max(0, min(frame.shape[0] - 1, y))
+ # frame[y, x] = _MARKER_VALUE
+
+ # def generate_smm(observation, config=None,
+ # channel_dimensions=(SMM_WIDTH, SMM_HEIGHT)):
+ # """Returns a list of minimap observations given the raw features for each
+ # active player.
+ # Args:
+ # observation: raw features from the environment
+ # config: environment config
+ # channel_dimensions: resolution of SMM to generate
+ # Returns:
+ # (N, H, W, C) - shaped np array representing SMM. N stands for the number of
+ # players we are controlling.
+ # """
+ # frame = np.zeros((len(observation), channel_dimensions[1],
+ # channel_dimensions[0], len(get_smm_layers(config))),
+ # dtype=np.uint8)
+
+ # for o_i, o in enumerate(observation):
+ # for index, layer in enumerate(get_smm_layers(config)):
+ # assert layer in o
+ # if layer == 'active':
+ # if o[layer] == -1:
+ # continue
+ # mark_points(frame[o_i, :, :, index],
+ # np.array(o['left_team'][o[layer]]).reshape(-1))
+ # else:
+ # mark_points(frame[o_i, :, :, index], np.array(o[layer]).reshape(-1))
+ # return frame
+
+ # smm = generate_smm([obs]).transpose(3, 1, 2, 0).squeeze(3).astype(np.float32)
+
+ # ACTION_1HOT = np.eye(19)
+ # action_history = np.stack([ACTION_1HOT[a] for a in action_history]).astype(np.float32)
+ action_history = np.array(action_history, dtype=np.int32)[..., None]
+
+ return {
+ # features
+ 'ball': ball_features,
+ 'match': match_features,
+ 'player': {
+ 'self': left_team_features,
+ 'opp': right_team_features
+ },
+ 'control': control_features,
+ 'player_index': {
+ 'self': left_team_indice,
+ 'opp': right_team_indice
+ },
+ 'mode_index': mode_index,
+ 'control_flag': control_flag,
+ # distances
+ 'distance': {
+ 'p2p': p2p_distance,
+ 'p2bo': p2bo_distance,
+ 'b2o': b2o_distance
+ },
+ # CNN
+ 'cnn_feature': cnn_feature,
+ # SuperMiniMap
+ # 'smm': smm,
+ 'action_history': action_history
+ }
+
+
+KICK_ACTIONS = {
+ Action.LongPass: 20,
+ Action.HighPass: 28,
+ Action.ShortPass: 36,
+ Action.Shot: 44,
+}
+
+
+class Environment:
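+    # two-player wrapper around the kaggle-environments football env used by the TamakEriFever bot;
+    # the action space is augmented: 19 atomic actions plus 4 kick actions x 8 directions (KICK_ACTIONS)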
+ ACTION_LEN = 19 + 4 * 8
+ ACTION_IDX = list(range(ACTION_LEN))
+
+ def __init__(self, args={}):
+ self.env_map = {}
+ self.env = None
+ self.limit_steps = args.get('limit_steps', 100000)
+ self.frame_skip = args.get('frame_skip', 0)
+ self.reset_common()
+
+ def reset_common(self):
+ self.finished = False
+ self.prev_score = [0, 0]
+ self.reset_flag = False
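+        # remaining checkpoint distances per player, consumed by the checkpoint reward in reward()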
+ self.checkpoint = [
+ [0.95, 0.85, 0.75, 0.65, 0.55, 0.45, 0.35, 0.25, 0.15, 0.05],
+ [0.95, 0.85, 0.75, 0.65, 0.55, 0.45, 0.35, 0.25, 0.15, 0.05]
+ ]
+ self.states = []
+ self.half_step = 1500
+ self.reserved_action = [None, None]
+
+ def reset(self, args={}):
+ if len(self.env_map) == 0:
+ from gfootball.env import football_action_set
+ from gfootball.env.wrappers import Simple115StateWrapper
+ from kaggle_environments import make
+
+ self.ACTION_STR = football_action_set.action_set_v1
+ self.ACTION2STR = {i: j for i, j in enumerate(football_action_set.action_set_v1)}
+ self.STR2ACTION = {j: i for i, j in self.ACTION2STR.items()}
+
+ # self.env_map[3000] = make("football", configuration={"scenario_name": "11_vs_11_kaggle"})
+ # self.env_map[1000] = make("football", configuration={"scenario_name": "11_vs_11_kaggle_1000_500"})
+ # self.env_map[500] = make("football", configuration={"scenario_name": "11_vs_11_kaggle_500_250"})
+ # self.env_map[9999] = make("football", configuration={"scenario_name": "11_vs_11_kaggle_random"})
+ # self.env_map[99999] = make("football", configuration={"scenario_name": "11_vs_11_kaggle_random_long"})
+
+ self.env_map["real"] = make("football", configuration={"scenario_name": "11_vs_11_kaggle"})
+ self.env_map["eval"] = make("football", configuration={"scenario_name": "11_vs_11_kaggle_1000_500"})
+ self.env_map["train"] = make("football", configuration={"scenario_name": "11_vs_11_kaggle_train"})
+
+ # decide limit steps
+
+ # if args.get('role', {}) == 'e':
+ # self.env = self.env_map[1000]
+ # else:
+ # limit_rate = args.get('limit_rate', 1.0)
+ # if limit_rate > 0.9:
+ # self.env = self.env_map[3000]
+ # elif limit_rate >= 0:
+ # self.env = self.env_map[99999]
+
+ role = args.get('role', '')
+ limit_rate = args.get('limit_rate', 1)
+ if role == 'g':
+ self.env = self.env_map['train' if limit_rate < 0.95 else 'real']
+ elif role == 'e':
+ self.env = self.env_map['eval']
+ else:
+ self.env = self.env_map['real']
+
+ state = self.env.reset()
+ self.resets_info(state)
+
+ def resets_info(self, state):
+ self.reset_common()
+ state = copy.deepcopy(state)
+ state = [self._preprocess_state(s) for s in state]
+ self.states.append(state)
+ self.half_step = state[0]['observation']['players_raw'][0]['steps_left'] // 2
+
+ def reset_info(self, state):
+ self.resets_info(state)
+
+ def chance(self):
+ pass
+
+ def action2str(self, a: int):
+ # return self.ACTION2STR[a]
+ return str(a)
+
+ def str2action(self, s: str):
+ # return self.STR2ACTION[s]
+ return int(s)
+
+ def plays(self, actions):
+ self._plays(actions)
+
+ def _plays(self, actions):
+ # state transition function
+ # action is integer (0 ~ 18)
+ actions = copy.deepcopy(actions)
+ for i, res_action in enumerate(self.reserved_action):
+ if res_action is not None:
+ actions[i] = res_action
+
+ # augmented action to atomic action
+ for i, action in enumerate(actions):
+ atomic_a, reserved_a = self.special_to_actions(action)
+ actions[i] = atomic_a
+ self.reserved_action[i] = reserved_a
+
+ # step environment
+ state = self.env.step([[actions[0]], [actions[1]]])
+ state = copy.deepcopy(state)
+ state = [self._preprocess_state(s) for s in state]
+ self.states.append(state)
+
+ # update status
+ if state[0]['status'] == 'DONE' or len(self.states) > self.limit_steps:
+ self.finished = True
+
+ def plays_info(self, state):
+        # state transition function as an agent
+ state = copy.deepcopy(state)
+ state = [self._preprocess_state(s) for s in state]
+ self.states.append(state)
+
+ def play_info(self, state):
+ self.plays_info(state)
+
+ def diff_info(self):
+ return self.states[-1]
+
+ def turns(self):
+ return self.players()
+
+ def players(self):
+ return [0, 1]
+
+ def terminal(self):
+ # check whether the state is terminal
+ return self.finished
+
+ def reward(self):
+ prev_score = self.prev_score
+ score = self.score()
+
+ rs = []
+ scored_player = None
+ for p in self.players():
+ r = 1.0 * (score[p] - prev_score[p]) - 1.0 * (score[1 - p] - prev_score[1 - p])
+ rs.append(r)
+ if r != 0:
+ self.reset_flag = True
+ scored_player = p
+
+ self.prev_score = self.score()
+ return rs
+
+ def get_goal_distance(xy1):
+ return (((xy1 - np.array([1, 0])) ** 2).sum(axis=-1)) ** 0.5
+
+ # checkpoint reward (https://arxiv.org/pdf/1907.11180.pdf)
+ checkpoint_reward = []
+ for p in self.players():
+ obs = self.raw_observation(p)['players_raw'][0]
+ ball_owned_team = obs['ball_owned_team']
+ if ball_owned_team == p and len(self.checkpoint[p]) != 0:
+ ball = obs['ball'][:2]
+ goal_distance = get_goal_distance(ball)
+ if goal_distance < self.checkpoint[p][0]:
+ cr = 0
+ for idx, c in enumerate(self.checkpoint[p]):
+ if goal_distance < c:
+ cr += 0.1
+ else:
+ break
+ self.checkpoint[p] = self.checkpoint[p][idx:]
+ checkpoint_reward.append(cr)
+ else:
+ checkpoint_reward.append(0)
+ else:
+ checkpoint_reward.append(0)
+
+ if scored_player is not None:
+ checkpoint_reward[scored_player] += len(
+ self.checkpoint[scored_player]
+            ) * 0.1  # add the remaining checkpoint reward when scoring (0.1 per remaining checkpoint)
+ self.checkpoint[scored_player] = []
+
+ return [rs[p] + checkpoint_reward[p] for p in self.players()]
+
+ def is_reset_state(self):
+ if self.reset_flag:
+ self.reset_flag = False
+ return True
+ return False
+
+ def score(self):
+ if len(self.states) == 0:
+ return [0, 0]
+ obs = self.states[-1]
+ return [
+ obs[0]['observation']['players_raw'][0]['score'][0], obs[1]['observation']['players_raw'][0]['score'][0]
+ ]
+
+ def outcome(self):
+ if len(self.states) == 0:
+ return [0, 0]
+ scores = self.score()
+ if scores[0] > scores[1]:
+ score_diff = scores[0] - scores[1]
+ outcome_tanh = np.tanh(score_diff ** 0.8)
+ return [outcome_tanh, -outcome_tanh]
+ elif scores[0] < scores[1]:
+ score_diff = scores[1] - scores[0]
+ outcome_tanh = np.tanh(score_diff ** 0.8)
+ return [-outcome_tanh, outcome_tanh]
+ return [0, 0]
+
+ def legal_actions(self, player):
+ # legal action list
+ all_actions = [i for i in copy.copy(self.ACTION_IDX) if i != 19]
+
+ if len(self.states) == 0:
+ return all_actions
+
+ # obs from view of the player
+ obs = self.raw_observation(player)['players_raw'][0]
+ # Illegal actions
+ illegal_actions = set()
+ # You have a ball?
+ ball_owned_team = obs['ball_owned_team']
+ if ball_owned_team != 0: # not owned or free
+ illegal_actions.add(int(Action.LongPass))
+ illegal_actions.add(int(Action.HighPass))
+ illegal_actions.add(int(Action.ShortPass))
+ illegal_actions.add(int(Action.Shot))
+ illegal_actions.add(int(Action.Dribble))
+ for d in range(8):
+ illegal_actions.add(KICK_ACTIONS[Action.LongPass] + d)
+ illegal_actions.add(KICK_ACTIONS[Action.HighPass] + d)
+ illegal_actions.add(KICK_ACTIONS[Action.ShortPass] + d)
+ illegal_actions.add(KICK_ACTIONS[Action.Shot] + d)
+ else: # owned
+ illegal_actions.add(int(Action.Slide))
+
+ # Already sticky action?
+ sticky_actions = obs['sticky_actions']
+ if type(sticky_actions) == set:
+ sticky_actions = [0] * 10
+
+ if sticky_actions[action_to_sticky_index[Action.Sprint]] == 0: # not action_sprint
+ illegal_actions.add(int(Action.ReleaseSprint))
+
+ if sticky_actions[action_to_sticky_index[Action.Dribble]] == 0: # not action_dribble
+ illegal_actions.add(int(Action.ReleaseDribble))
+
+ if 1 not in sticky_actions[:8]:
+ illegal_actions.add(int(Action.ReleaseDirection))
+
+ return [a for a in all_actions if a not in illegal_actions]
+
+ def action_length(self):
+ # maximum size of policy (it determines output size of policy function)
+ return self.ACTION_LEN
+
+ def raw_observation(self, player):
+ if len(self.states) > 0:
+ return self.states[-1][player]['observation']
+ else:
+ return OBS_TEMPLATE
+
+ def observation(self, player):
+ # input feature for neural nets
+ info = {'half_step': self.half_step}
+ return feature_from_states(self.states, info, player)
+
+ def _preprocess_state(self, player_state):
+ if player_state is None:
+ return player_state
+
+ # in ball-dead state, set ball owned player and team
+ o = player_state['observation']['players_raw'][0]
+ mode = o['game_mode']
+ if mode == GameMode.FreeKick or \
+ mode == GameMode.Corner or \
+ mode == GameMode.Penalty or \
+ mode == GameMode.GoalKick:
+ # find nearest player and team
+ def dist(xy1, xy2):
+ return ((xy1[0] - xy2[0]) ** 2 + (xy1[1] - xy2[1]) ** 2) ** 0.5
+
+ team_player_position = [(0, i, p) for i, p in enumerate(o['left_team'])] + \
+ [(1, i, p) for i, p in enumerate(o['right_team'])]
+ distances = [(t[0], t[1], dist(t[2], o['ball'][:2])) for t in team_player_position]
+ distances = sorted(distances, key=lambda x: x[2])
+ # print(mode, [t[2] for t in distances])
+ # print(o['ball_owned_team'], o['ball_owned_player'], '->', distances[0][0], distances[0][1])
+ # input()
+ o['ball_owned_team'] = distances[0][0]
+ o['ball_owned_player'] = distances[0][1]
+
+ # in the beginning, fill actions with 0
+ if len(player_state['action']) == 0:
+ player_state['action'].append(0)
+
+ return player_state
+
+ def special_to_actions(self, saction):
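+        # map an augmented action id to (atomic action, reserved action): kick ids 20-51 trigger the kick now
+        # and reserve the chosen direction for the next step (applied in _plays)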
+ if not 0 <= saction < 52:
+ return [0, None]
+ for a, index in KICK_ACTIONS.items():
+ if index <= saction < index + 8:
+ return [a, Action(saction - index + 1)]
+ return [saction, None]
+
+ '''def action_to_specials(self, action):
+ p = np.zeros(self.action_length())
+ p[action] = 1
+
+ sticky_direction =
+
+
+ if action == Action.LongPass:
+ return
+
+ return p / p.sum()'''
+
+ def funcname(self, parameter_list):
+ """
+ docstring
+ """
+ pass
+
+ def net(self):
+ return FootballNet
+
+ def rule_based_action(self, player):
+ return 19
+
+ # def rule_based_action_A(self, player):
+ # return rulebaseA._agent(self.states[-1][player]['observation'])
+
+ # def rule_based_action_B(self, player):
+ # return rulebaseB._agent(self.states[-1][player]['observation'])
+
+ # def rule_based_action_C(self, player):
+ # return rulebaseC._agent(self.states[-1][player]['observation'])
+
+ # #def rule_based_action_D(self, player):
+ # # return rulebaseD._agent(self.states[-1][player]['observation'])
+
+ # def rule_based_action_E(self, player):
+ # return rulebaseE._agent(self.states[-1][player]['observation'])
+
+ # def rule_based_action_F(self, player):
+ # return rulebaseF._agent(self.states[-1][player]['observation'])
+
+
+if __name__ == '__main__':
+ e = Environment()
+ net = e.net()(e)
+ net.eval()
+ for _ in range(1):
+ e.reset()
+ o = e.observation(0)
+ net.inference(o, None)
+ while not e.terminal():
+ # print(e)
+ _ = e.observation(0)
+ _ = e.observation(1)
+ print(e.env.configuration.episodeSteps)
+ print(e.raw_observation(0)['players_raw'][0]['steps_left'])
+ action_list = [0, 0]
+ action_list[0] = random.choice(e.legal_actions(0))
+            action_list[1] = e.rule_based_action(1)  # the rule_based_action_* variants above are commented out
+ print(len(e.states), action_list)
+ e.plays(action_list)
+ print(e.checkpoint)
+ print(e.reward())
+ print(e)
+ print(e.score())
+ print(e.outcome())
diff --git a/DI-engine/dizoo/gfootball/model/bots/TamakEriFever/handyrl_core/model.py b/DI-engine/dizoo/gfootball/model/bots/TamakEriFever/handyrl_core/model.py
new file mode 100644
index 0000000000000000000000000000000000000000..3da4622b52a29b24fde15d96f6ed741f9e521bfd
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/model/bots/TamakEriFever/handyrl_core/model.py
@@ -0,0 +1,295 @@
+# Copyright (c) 2020 DeNA Co., Ltd.
+# Licensed under The MIT License [see LICENSE for details]
+
+# neural nets
+
+import numpy as np
+import torch
+torch.set_num_threads(1)
+
+import torch.nn as nn
+import torch.nn.functional as F
+
+from .util import map_r
+
+
+def load_model(model, model_path):
+ loaded_dict_ = torch.load(model_path)
+ model_dict = model.state_dict()
+ loaded_dict = {k: v for k, v in loaded_dict_.items() if k in model_dict}
+ model_dict.update(loaded_dict)
+ model.load_state_dict(model_dict)
+ return model
+
+
+def to_torch(x, transpose=False, unsqueeze=None):
+ if x is None:
+ return None
+ elif isinstance(x, (list, tuple, set)):
+ return type(x)(to_torch(xx, transpose, unsqueeze) for xx in x)
+ elif isinstance(x, dict):
+ return type(x)((key, to_torch(xx, transpose, unsqueeze)) for key, xx in x.items())
+
+ a = np.array(x)
+ if transpose:
+ a = np.swapaxes(a, 0, 1)
+ if unsqueeze is not None:
+ a = np.expand_dims(a, unsqueeze)
+
+ if a.dtype == np.int32 or a.dtype == np.int64:
+ t = torch.LongTensor(a)
+ else:
+ t = torch.FloatTensor(a)
+
+ return t.contiguous()
+
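+# Usage sketch (illustration only): to_torch recursively converts nested containers of
+# numpy arrays into tensors while keeping the container types, e.g.
+#   to_torch({'obs': np.zeros((4, 3), np.float32)})['obs'].shape   -> torch.Size([4, 3])
+#   to_torch(np.arange(6).reshape(2, 3), transpose=True).shape     -> torch.Size([3, 2])
+#   to_torch(np.zeros((4, 3), np.float32), unsqueeze=0).shape      -> torch.Size([1, 4, 3])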
+
+def to_numpy(x):
+ return map_r(x, lambda x: x.detach().numpy() if x is not None else None)
+
+
+def to_gpu(data):
+ return map_r(data, lambda x: x.cuda() if x is not None else None)
+
+
+def to_gpu_or_not(data, gpu):
+ return to_gpu(data) if gpu else data
+
+
+def softmax(x):
+    # subtract the max and use keepdims so this also works on batched (2-D) inputs
+    x = np.exp(x - np.max(x, axis=-1, keepdims=True))
+    return x / x.sum(axis=-1, keepdims=True)
+
+
+class Conv(nn.Module):
+
+ def __init__(self, filters0, filters1, kernel_size, bn, bias=True):
+ super().__init__()
+ if bn:
+ bias = False
+ self.conv = nn.Conv2d(filters0, filters1, kernel_size, stride=1, padding=kernel_size // 2, bias=bias)
+ self.bn = nn.BatchNorm2d(filters1) if bn else None
+
+ def forward(self, x):
+ h = self.conv(x)
+ if self.bn is not None:
+ h = self.bn(h)
+ return h
+
+
+class Dense(nn.Module):
+
+ def __init__(self, units0, units1, bnunits=0, bias=True):
+ super().__init__()
+ if bnunits > 0:
+ bias = False
+ self.dense = nn.Linear(units0, units1, bias=bias)
+ self.bnunits = bnunits
+ self.bn = nn.BatchNorm1d(bnunits) if bnunits > 0 else None
+
+ def forward(self, x):
+ h = self.dense(x)
+ if self.bn is not None:
+ size = h.size()
+ h = h.view(-1, self.bnunits)
+ h = self.bn(h)
+ h = h.view(*size)
+ return h
+
+
+class WideResidualBlock(nn.Module):
+
+ def __init__(self, filters, kernel_size, bn):
+ super().__init__()
+ self.conv1 = Conv(filters, filters, kernel_size, bn, not bn)
+ self.conv2 = Conv(filters, filters, kernel_size, bn, not bn)
+
+ def forward(self, x):
+ return F.relu(x + self.conv2(F.relu(self.conv1(x))))
+
+
+class WideResNet(nn.Module):
+
+ def __init__(self, blocks, filters):
+ super().__init__()
+ self.blocks = nn.ModuleList([WideResidualBlock(filters, 3, bn=False) for _ in range(blocks)])
+
+ def forward(self, x):
+ h = x
+ for block in self.blocks:
+ h = block(h)
+ return h
+
+
+class Encoder(nn.Module):
+
+ def __init__(self, input_size, filters):
+ super().__init__()
+
+ self.input_size = input_size
+ self.conv = Conv(input_size[0], filters, 3, bn=False)
+ self.activation = nn.LeakyReLU(0.1)
+
+ def forward(self, x):
+ return self.activation(self.conv(x))
+
+
+class Head(nn.Module):
+
+ def __init__(self, input_size, out_filters, outputs):
+ super().__init__()
+
+ self.board_size = input_size[1] * input_size[2]
+ self.out_filters = out_filters
+
+ self.conv = Conv(input_size[0], out_filters, 1, bn=False)
+ self.activation = nn.LeakyReLU(0.1)
+ self.fc = nn.Linear(self.board_size * out_filters, outputs, bias=False)
+
+ def forward(self, x):
+ h = self.activation(self.conv(x))
+ h = self.fc(h.view(-1, self.board_size * self.out_filters))
+ return h
+
+
+class ConvLSTMCell(nn.Module):
+
+ def __init__(self, input_dim, hidden_dim, kernel_size, bias):
+ super().__init__()
+
+ self.input_dim = input_dim
+ self.hidden_dim = hidden_dim
+
+ self.kernel_size = kernel_size
+ self.padding = kernel_size[0] // 2, kernel_size[1] // 2
+ self.bias = bias
+
+ self.conv = nn.Conv2d(
+ in_channels=self.input_dim + self.hidden_dim,
+ out_channels=4 * self.hidden_dim,
+ kernel_size=self.kernel_size,
+ padding=self.padding,
+ bias=self.bias
+ )
+
+ def init_hidden(self, input_size, batch_size):
+ return tuple(
+ [
+ torch.zeros(*batch_size, self.hidden_dim, *input_size),
+ torch.zeros(*batch_size, self.hidden_dim, *input_size),
+ ]
+ )
+
+ def forward(self, input_tensor, cur_state):
+ h_cur, c_cur = cur_state
+
+ combined = torch.cat([input_tensor, h_cur], dim=-3) # concatenate along channel axis
+ combined_conv = self.conv(combined)
+
+ cc_i, cc_f, cc_o, cc_g = torch.split(combined_conv, self.hidden_dim, dim=-3)
+ i = torch.sigmoid(cc_i)
+ f = torch.sigmoid(cc_f)
+ o = torch.sigmoid(cc_o)
+ g = torch.tanh(cc_g)
+
+ c_next = f * c_cur + i * g
+ h_next = o * torch.tanh(c_next)
+
+ return h_next, c_next
+
+
+class DRC(nn.Module):
+
+ def __init__(self, num_layers, input_dim, hidden_dim, kernel_size=3, bias=True):
+ super().__init__()
+ self.num_layers = num_layers
+
+ blocks = []
+ for _ in range(self.num_layers):
+ blocks.append(
+ ConvLSTMCell(
+ input_dim=input_dim, hidden_dim=hidden_dim, kernel_size=(kernel_size, kernel_size), bias=bias
+ )
+ )
+ self.blocks = nn.ModuleList(blocks)
+
+ def init_hidden(self, input_size, batch_size):
+ if batch_size is None: # for inference
+ with torch.no_grad():
+ return to_numpy(self.init_hidden(input_size, []))
+ else: # for training
+ hs, cs = [], []
+ for block in self.blocks:
+ h, c = block.init_hidden(input_size, batch_size)
+ hs.append(h)
+ cs.append(c)
+
+ return torch.stack(hs), torch.stack(cs)
+
+ def forward(self, x, hidden, num_repeats):
+ if hidden is None:
+ hidden = self.init_hidden(x.shape[-2:], x.shape[:-3])
+
+ hs = [hidden[0][i] for i in range(self.num_layers)]
+ cs = [hidden[1][i] for i in range(self.num_layers)]
+ for _ in range(num_repeats):
+ for i, block in enumerate(self.blocks):
+ hs[i], cs[i] = block(x, (hs[i], cs[i]))
+
+ return hs[-1], (torch.stack(hs), torch.stack(cs))
+
+
+# simple model
+
+
+class BaseModel(nn.Module):
+
+ def __init__(self, env=None, args=None, action_length=None):
+ super().__init__()
+ self.action_length = env.action_length() if action_length is None else action_length
+
+ def init_hidden(self, batch_size=None):
+ return None
+
+ def inference(self, x, hidden, **kwargs):
+ # numpy array -> numpy array
+ self.eval()
+ with torch.no_grad():
+ xt = to_torch(x, unsqueeze=0)
+ ht = to_torch(hidden, unsqueeze=1)
+ outputs = self.forward(xt, ht, **kwargs)
+
+ return tuple(
+ [(to_numpy(o).squeeze(0) if o is not None else None) for o in outputs[:-1]] + \
+ [map_r(outputs[-1], lambda o: to_numpy(o).squeeze(1)) if outputs[-1] is not None else None]
+ )
+
+
+class RandomModel(BaseModel):
+
+ def inference(self, x=None, hidden=None):
+ return np.zeros(self.action_length), np.zeros(1), np.zeros(1), None
+
+
+class DuelingNet(BaseModel):
+
+ def __init__(self, env, args={}):
+ super().__init__(env, args)
+
+ self.input_size = env.observation().shape
+
+ layers, filters = args.get('layers', 3), args.get('filters', 32)
+ internal_size = (filters, *self.input_size[1:])
+
+ self.encoder = Encoder(self.input_size, filters)
+ self.body = WideResNet(layers, filters)
+ self.head_p = Head(internal_size, 2, self.action_length)
+ self.head_v = Head(internal_size, 1, 1)
+
+ def forward(self, x, hidden=None):
+ h = self.encoder(x)
+ h = self.body(h)
+ h_p = self.head_p(h)
+ h_v = self.head_v(h)
+
+ return h_p, torch.tanh(h_v), None, None
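+
+
+if __name__ == '__main__':
+    # Hedged shape-check sketch (added for illustration, not part of the original HandyRL code):
+    # run a 2-layer DRC stack over a dummy spatial input and inspect the recurrent state layout.
+    drc = DRC(num_layers=2, input_dim=8, hidden_dim=16, kernel_size=3)
+    dummy = torch.zeros(4, 8, 5, 5)  # (batch, channels, H, W)
+    out, (hs, cs) = drc(dummy, None, num_repeats=3)
+    # out: (4, 16, 5, 5); hs and cs stack one state per layer: (2, 4, 16, 5, 5)
+    print(out.shape, hs.shape, cs.shape)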
diff --git a/DI-engine/dizoo/gfootball/model/bots/TamakEriFever/handyrl_core/util.py b/DI-engine/dizoo/gfootball/model/bots/TamakEriFever/handyrl_core/util.py
new file mode 100644
index 0000000000000000000000000000000000000000..bbd0ea38387a7d638d9354824e713d133b6827ef
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/model/bots/TamakEriFever/handyrl_core/util.py
@@ -0,0 +1,66 @@
+# Copyright (c) 2020 DeNA Co., Ltd.
+# Licensed under The MIT License [see LICENSE for details]
+
+
+def map_r(x, callback_fn=None):
+ # recursive map function
+ if isinstance(x, (list, tuple, set)):
+ return type(x)(map_r(xx, callback_fn) for xx in x)
+ elif isinstance(x, dict):
+ return type(x)((key, map_r(xx, callback_fn)) for key, xx in x.items())
+ return callback_fn(x) if callback_fn is not None else None
+
+
+def bimap_r(x, y, callback_fn=None):
+ if isinstance(x, (list, tuple)):
+ return type(x)(bimap_r(xx, y[i], callback_fn) for i, xx in enumerate(x))
+ elif isinstance(x, dict):
+ return type(x)((key, bimap_r(xx, y[key], callback_fn)) for key, xx in x.items())
+ return callback_fn(x, y) if callback_fn is not None else None
+
+
+def trimap_r(x, y, z, callback_fn=None):
+ if isinstance(x, (list, tuple)):
+ return type(x)(trimap_r(xx, y[i], z[i], callback_fn) for i, xx in enumerate(x))
+ elif isinstance(x, dict):
+ return type(x)((key, trimap_r(xx, y[key], z[key], callback_fn)) for key, xx in x.items())
+ return callback_fn(x, y, z) if callback_fn is not None else None
+
+
+def type_r(x):
+ type_s = str(type(x))
+ print(type(x))
+ if isinstance(x, (list, tuple, set)):
+ return {type_s: type_r(xx) for xx in x}
+ elif isinstance(x, dict):
+ return {type_s: type_r(xx) for xx in x.values()}
+ return type_s
+
+
+def rotate(x, max_depth=1024):
+ if max_depth == 0:
+ return x
+ if isinstance(x, (list, tuple)):
+ if isinstance(x[0], (list, tuple)):
+ return type(x[0])(
+ rotate(type(x)(xx[i] for xx in x), max_depth - 1) \
+ for i, _ in enumerate(x[0])
+ )
+ elif isinstance(x[0], dict):
+ return type(x[0])(
+ (key, rotate(type(x)(xx[key] for xx in x), max_depth - 1)) \
+ for key in x[0]
+ )
+ elif isinstance(x, dict):
+ x_front = x[list(x.keys())[0]]
+ if isinstance(x_front, (list, tuple)):
+ return type(x_front)(
+ rotate(type(x)((key, xx[i]) for key, xx in x.items()), max_depth - 1) \
+ for i, _ in enumerate(x_front)
+ )
+ elif isinstance(x_front, dict):
+ return type(x_front)(
+ (key2, rotate(type(x)((key1, xx[key2]) for key1, xx in x.items()), max_depth - 1)) \
+ for key2 in x_front
+ )
+ return x
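+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (illustration only): map_r applies a callback to every leaf of a
+    # nested structure, bimap_r zips two structures, and rotate transposes the two outermost levels.
+    print(map_r({'a': [1, 2], 'b': (3, 4)}, lambda v: v * 10))  # {'a': [10, 20], 'b': (30, 40)}
+    print(bimap_r([1, 2], [10, 20], lambda u, v: u + v))  # [11, 22]
+    print(rotate([{'p': 1, 'v': 2}, {'p': 3, 'v': 4}]))  # {'p': [1, 3], 'v': [2, 4]}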
diff --git a/DI-engine/dizoo/gfootball/model/bots/TamakEriFever/readme.md b/DI-engine/dizoo/gfootball/model/bots/TamakEriFever/readme.md
new file mode 100644
index 0000000000000000000000000000000000000000..db8dbd6322a404d7c6c0d02b172e36247dcfaaf0
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/model/bots/TamakEriFever/readme.md
@@ -0,0 +1,5 @@
+This is the 5th place solution of the Kaggle Google Football (gfootball) competition.
+
+See https://www.kaggle.com/c/google-football/discussion/203412 for details.
+
+Thanks to [@kyazuki](https://www.kaggle.com/kyazuki) and [@yuricat](https://www.kaggle.com/yuricat), who generously shared their code.
\ No newline at end of file
diff --git a/DI-engine/dizoo/gfootball/model/bots/TamakEriFever/submission.py b/DI-engine/dizoo/gfootball/model/bots/TamakEriFever/submission.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a5f3d2e15f6a952cc9f9df3a39623dc6229a77a
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/model/bots/TamakEriFever/submission.py
@@ -0,0 +1,70 @@
+import os.path as osp
+import yaml
+
+import numpy as np
+import torch
+
+from .football_ikki import Environment
+from .handyrl_core.model import load_model
+
+model_path = osp.join(osp.dirname(__file__), 'models/1679.pth')
+
+with open(osp.join(osp.dirname(__file__), 'config.yaml')) as f:
+ config = yaml.safe_load(f)
+
+env = Environment(config['env_args'])
+model = load_model(env.net()(env), model_path)
+model.eval()
+
+
+def output_think(env, obs, actions, p, v, r):
+ pmask = np.ones_like(p)
+ pmask[actions] = 0
+ p = p - pmask * 1e32
+
+ def softmax(x):
+ x = np.exp(x - np.max(x, axis=-1))
+ return x / x.sum(axis=-1)
+
+ sticky_actions = obs['players_raw'][0]['sticky_actions']
+ print(sticky_actions)
+
+ print(actions)
+ print((softmax(p) * 1000).astype(int))
+ print(v)
+ print(r)
+
+
+prev_action = 0
+reserved_action = None
+
+
+def agent(obs):
+ global prev_action, reserved_action
+
+ info = [{'observation': obs, 'action': [prev_action]}, None]
+ env.play_info(info)
+ # print('step %d' % len(env.states))
+
+ x = env.observation(0)
+
+ p, v, r, _ = model.inference(x, None)
+ actions = env.legal_actions(0)
+
+ # output_think(env, obs, actions, p, v, r)
+
+ ap_list = sorted([(a, p[a]) for a in actions], key=lambda x: -x[1])
+
+    # you need to return a list containing a single action (an int in [1, 18]);
+    # the model output might be a float, so make sure an int is returned.
+ action = ap_list[0][0]
+
+ if reserved_action is not None:
+ prev_action = reserved_action
+ reserved_action = None
+ # print('###RESERVED###')
+ else:
+ # split action
+ prev_action, reserved_action = env.special_to_actions(action)
+
+ return [prev_action]
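+
+
+# Note (illustration): combined kick-and-direction outputs from the net are split by
+# ``special_to_actions`` into two primitive actions; one is returned this step and the other is
+# kept in ``reserved_action`` to be returned on the next call to ``agent``.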
diff --git a/DI-engine/dizoo/gfootball/model/bots/TamakEriFever/view_test.py b/DI-engine/dizoo/gfootball/model/bots/TamakEriFever/view_test.py
new file mode 100644
index 0000000000000000000000000000000000000000..10e076fed6cd9f64ade766fd632b820ff995a230
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/model/bots/TamakEriFever/view_test.py
@@ -0,0 +1,31 @@
+# Set up the Environment.
+
+import time
+
+from kaggle_environments import make
+
+# opponent = "football/idle.py"
+# opponent = "football/rulebaseC.py"
+opponent = "builtin_ai"
+
+video_title = "chain"
+video_path = "videos/" + video_title + "_" + opponent.split("/")[-1].replace(".py",
+ "") + str(int(time.time())) + ".webm"
+
+env = make(
+ "football",
+ configuration={
+ "save_video": True,
+ "scenario_name": "11_vs_11_kaggle",
+ "running_in_notebook": False
+ },
+ info={"LiveVideoPath": video_path},
+ debug=True
+)
+output = env.run(["submission.py", opponent])[-1]
+
+scores = [output[i]['observation']['players_raw'][0]['score'][0] for i in range(2)]
+print('Left player: score = %s, status = %s, info = %s' % (scores[0], output[0]['status'], output[0]['info']))
+print('Right player: score = %s, status = %s, info = %s' % (scores[1], output[1]['status'], output[1]['info']))
+
+env.render(mode="human", width=800, height=600)
diff --git a/DI-engine/dizoo/gfootball/model/bots/__init__.py b/DI-engine/dizoo/gfootball/model/bots/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ec340923464f19ee795866072b99911fec9c8be6
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/model/bots/__init__.py
@@ -0,0 +1,2 @@
+from .kaggle_5th_place_model import FootballKaggle5thPlaceModel
+from .rule_based_bot_model import FootballRuleBaseModel
diff --git a/DI-engine/dizoo/gfootball/model/bots/kaggle_5th_place_model.py b/DI-engine/dizoo/gfootball/model/bots/kaggle_5th_place_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..5ed9a5d7cf14996ef7178b226407557e37721d1b
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/model/bots/kaggle_5th_place_model.py
@@ -0,0 +1,40 @@
+from kaggle_environments.envs.football.helpers import *
+from math import sqrt
+from enum import Enum
+import torch
+import torch.nn as nn
+import numpy as np
+from ding.torch_utils import tensor_to_list, one_hot, to_ndarray, to_tensor, to_dtype
+from ding.utils import MODEL_REGISTRY
+from .TamakEriFever.submission import agent
+
+
+@MODEL_REGISTRY.register('football_kaggle_5th_place')
+class FootballKaggle5thPlaceModel(torch.nn.Module):
+
+ def __init__(self):
+ super(FootballKaggle5thPlaceModel, self).__init__()
+        # be compatible with the bc policy
+ # to avoid: ValueError: optimizer got an empty parameter list
+ self._dummy_param = nn.Parameter(torch.zeros(1, 1))
+
+ def forward(self, data):
+ actions = []
+ data = data['raw_obs']
+ if isinstance(data['score'], list):
+            # to be compatible with the collect phase in subprocess mode
+ data['score'] = torch.stack(data['score'], dim=-1)
+ # dict of raw observations -> list of dict, each element in the list is the raw obs in a timestep
+ data = [{k: v[i] for k, v in data.items()} for i in range(data['left_team'].shape[0])]
+ for d in data:
+            # the raw obs in one timestep
+ if isinstance(d['steps_left'], torch.Tensor):
+ d = {k: v.cpu() for k, v in d.items()}
+ d = to_ndarray(d)
+ for k in ['active', 'designated', 'ball_owned_player', 'ball_owned_team']:
+ d[k] = int(d[k])
+ for k in ['sticky_actions']:
+ d[k] = list(d[k])
+ d = {'controlled_players': 1, 'players_raw': [d]}
+ actions.append(agent(d)[0])
+ return {'action': torch.LongTensor(actions), 'logit': one_hot(torch.LongTensor(actions), 19)}
diff --git a/DI-engine/dizoo/gfootball/model/bots/rule_based_bot_model.py b/DI-engine/dizoo/gfootball/model/bots/rule_based_bot_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..8c08198e4e4daa77e0d2f7ca91f75ecac7a248e9
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/model/bots/rule_based_bot_model.py
@@ -0,0 +1,785 @@
+"""
+## referenced https://www.kaggle.com/eugenkeil/simple-baseline-bot by @eugenkeil
+
+## referenced https://www.kaggle.com/david1013/tunable-baseline-bot by @david1013
+
+"""
+from kaggle_environments.envs.football.helpers import *
+from math import sqrt
+from enum import Enum
+import random
+import torch
+import torch.nn as nn
+import numpy as np
+from ding.torch_utils import tensor_to_list, one_hot, to_ndarray
+from ding.utils import MODEL_REGISTRY
+from ding.torch_utils import to_tensor, to_dtype
+"""
+Readable Reminder
+*********************
+class Action(Enum):
+ Idle = 0
+ Left = 1
+ TopLeft = 2
+ Top = 3
+ TopRight = 4
+ Right = 5
+ BottomRight = 6
+ Bottom = 7
+ BottomLeft = 8
+ LongPass= 9
+ HighPass = 10
+ ShortPass = 11
+ Shot = 12
+ Sprint = 13
+ ReleaseDirection = 14
+ ReleaseSprint = 15
+ Slide = 16
+ Dribble = 17
+ ReleaseDribble = 18
+
+
+sticky_index_to_action = [
+ Action.Left,
+ Action.TopLeft,
+ Action.Top,
+ Action.TopRight,
+ Action.Right,
+ Action.BottomRight,
+ Action.Bottom,
+ Action.BottomLeft,
+ Action.Sprint,
+ Action.Dribble
+]
+
+
+class PlayerRole(Enum):
+ GoalKeeper = 0
+ CenterBack = 1
+ LeftBack = 2
+ RightBack = 3
+ DefenceMidfield = 4
+ CentralMidfield = 5
+ LeftMidfield = 6
+ RIghtMidfield = 7
+ AttackMidfield = 8
+ CentralFront = 9
+
+
+class GameMode(Enum):
+ Normal = 0
+ KickOff = 1
+ GoalKick = 2
+ FreeKick = 3
+ Corner = 4
+ ThrowIn = 5
+ Penalty = 6
+"""
+
+
+class Stiuation(Enum):
+ Delaying = 0
+ Offencing = 1
+ Deffencing = 2
+
+
+class Line(object):
+
+ def __init__(self, pos1, pos2):
+ self.a = 1
+ x1, y1 = pos1
+ x2, y2 = pos2
+ if (y2 - y1) != 0.0:
+ self.b = (x2 - x1) / (y2 - y1)
+ else:
+ self.b = 1e5
+ self.c = -x1 - (self.b * y2)
+ self.length = dist(pos1, pos2)
+
+ def distToLine(self, pos):
+ return (self.a * pos[0] + self.b * pos[1] + self.c) / sqrt(self.a ** 2 + self.b ** 2)
+
+
+roles = [0, 7, 9, 2, 1, 1, 3, 5, 5, 5, 6]
+passes = [Action.ShortPass, Action.LongPass, Action.HighPass]
+
+offenseScore = {
+ 0: [-8.0, 0.0],
+ 1: [0.6, 0.8],
+ 2: [0.6, 0.85],
+ 3: [0.6, 0.85],
+ 4: [0.7, 0.9],
+ 5: [0.8, 0.9],
+ 6: [1, 1],
+ 7: [1, 1],
+ 8: [1, 1.1],
+ 9: [1.1, 1.2]
+}
+
+passBias = 2.0
+
+defenceThreatDist = 0.3
+threatAvg = 3.0
+
+shotDistAbs = 0.03
+shotDistFactor = 0.6
+
+offenseGoalDistFactor = 3.0
+offenseKeeperDistFactor = 0.5
+offenseTirenessFactor = 0.3
+
+sprintTirenessFactor = 0.5
+
+passForShotFactor = 0.6
+
+FREEKICK_SHOT_AREA = [[0.5, 1], [-0.2, 0.2]]
+
+START_SHOT_AREA1 = [[0.6, 0.75], [-0.2, 0.2]]
+START_SHOT_AREA2 = [[0.75, 0.95], [-0.13, 0.13]]
+
+PASS_FOR_SHOT_AREA1 = [[0.75, 1], [-0.42, -0.18]]
+PASS_FOR_SHOT_AREA2 = [[0.75, 1], [0.18, 0.42]]
+
+KEEPER_ZONE_AREA = [[0.75, 1], [-0.2, 0.2]]
+LONG_SHOT_RANGE_AREA = [[0.5, 1], [-0.25, 0.25]]
+SPRINT_AREA = [[-0.1, 0.6], [-0.42, 0.42]]
+DEFENCE_SPRING_AREA = [[-0.7, 0.4], [-0.4, 0.4]]
+# DRIBBLE_AREA = [[-0.1, 0.2], [-0.3, 0.3]]
+SLIDE_AREA = [[-0.65, 0], [-0.42, 0.42]]
+
+takenSelfFactor = 0.5
+passFactors = {Action.HighPass: [1.0, 1.2, 3.0], Action.ShortPass: [1.1, 1.5, 1.5], Action.LongPass: [1.0, 1.2, 2]}
+
+# top right/ Bottom left corner are:
+# [1, -0.42] and [-1, 0.42], respectively.
+
+
+def dist(pos1, pos2):
+ return sqrt((pos1[1] - pos2[1]) ** 2 + (pos1[0] - pos2[0]) ** 2)
+
+
+def dirSign(x):
+ if abs(x) < 0.01:
+ return 1
+ elif x < 0:
+ return 0
+ return 2
+
+
+def plusPos(pos1, pos2):
+ return [pos1[0] + pos2[0], pos1[1] + pos2[1]]
+
+
+def vec2dir(vec):
+ p = sqrt(vec[0] ** 2 + vec[1] ** 2)
+ coef = 1 / p
+ return [vec[0] * coef, vec[1] * coef]
+
+
+TOTAL_STEP = 3000
+
+# functions help moving
+
+directions = [
+ [Action.TopLeft, Action.Top, Action.TopRight], [Action.Left, Action.Idle, Action.Right],
+ [Action.BottomLeft, Action.Bottom, Action.BottomRight]
+]
+
+
+def insideArea(pos, area):
+ return area[0][0] <= pos[0] <= area[0][1] and area[1][0] <= pos[1] <= area[1][1]
+
+
+def gotoDir(x, y):
+ xdir = dirSign(x)
+ ydir = dirSign(y)
+ return directions[ydir][xdir]
+
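+# Example (illustration): gotoDir(0.5, -0.3) gives dirSign -> (2, 0), so directions[0][2]
+# == Action.TopRight, i.e. move up and to the right towards the target offset.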
+
+class Processer(object):
+
+ def __init__(self):
+ self._obs = {}
+ self._curPos = None
+ self._keeperPos = None
+ self._goalPos = [1, 0]
+ self._shot_dir_ready = False
+ self._pass_dir_ready = False
+ self._ball_is_free = False
+ self._we_have_ball = False
+ self._enemy_have_ball = False
+ self._our_goalkeeper_have_ball = False
+ self._shot_buf_player = None
+ self._shot_buf_step = -1
+ self._pass_buf_player = None
+ self._pass_buf_step = -1
+ self._score_diff = 0
+ self._pass_type = Action.ShortPass
+
+ def preprocess(self):
+ self._game_mode = self._obs['game_mode']
+ self._cur_player = self._obs['active']
+ if self._obs['score'].shape[0] == 2:
+ self._score_diff = self._obs['score'][0] - self._obs['score'][1]
+ else:
+ self._score_diff = self._obs['score']
+
+ self._curPos = self._obs['left_team'][self._obs['active']]
+ self._curDir = self._obs['left_team_direction'][self._obs['active']]
+ self._keeperPos = self._obs['right_team'][0]
+ self._ballPos = self._obs['ball']
+
+ self._ourPos = self._obs['left_team']
+ self._enemyPos = self._obs['right_team']
+
+ self._ball_is_free = self._obs['ball_owned_team'] == -1
+ self._we_have_ball = self._obs['ball_owned_team'] == 0
+ self._enemy_have_ball = self._obs['ball_owned_team'] == 1
+ self._our_goalkeeper_have_ball = self._obs['ball_owned_player'] == 0 and self._we_have_ball
+ self._our_active_have_ball = self._we_have_ball and self._obs['ball_owned_player'] == self._obs['active']
+
+ self._controlled_role = self._obs['left_team_roles'][self._obs['active']]
+
+ self._most_foward_enemy_pos = self.getMostForwardEnemyPos()
+ self._closest_enemey_pos = self.getClosestEnemyPos()
+ self._closest_enemey_to_cur_vec = [
+ self._curPos[0] - self._closest_enemey_pos[0], self._curPos[1] - self._closest_enemey_pos[1]
+ ]
+ self._closest_enemey_to_cur_dir = vec2dir(self._closest_enemey_to_cur_vec)
+ self._cloest_enemey_dist = dist(self._curPos, self._closest_enemey_pos)
+ self._remain_step = self._obs['steps_left']
+
+ self._cur_tireness = self._obs['left_team_tired_factor'][self._obs['active']]
+ self._our_tireness = self._obs['left_team_tired_factor']
+
+ self._dribbling = Action.Dribble in self._obs['sticky_actions']
+ self._sprinting = Action.Sprint in self._obs['sticky_actions']
+
+ self._our_goalkeeper_active = self._cur_player == 0
+
+ # TODO
+ self._ball_dir = self._obs['ball_direction']
+ self._ball_owner_dir = self.getBallOwnerDir()
+ self._ball_owner_pos = self.getBallOwnerPos()
+
+ if self._enemy_have_ball:
+ self._closest_to_enemy_pos, self._closest_to_enemy_player = self.getClosestToEnemy()
+
+ if not self._shot_dir_ready:
+ self._shot_buf_player = -1
+
+ # general helper
+ ################################
+ def getRole(self, i):
+ return roles[i]
+
+ # general helper for init
+ #################################
+ def getBallOwnerPos(self):
+ if self._ball_is_free:
+ return None
+ elif self._we_have_ball:
+ return self._obs['left_team'][self._obs['ball_owned_player']]
+ else:
+ return self._obs['right_team'][self._obs['ball_owned_player']]
+
+ def getBallOwnerDir(self):
+ if self._ball_is_free:
+ return None
+ elif self._we_have_ball:
+ return self._obs['left_team_direction'][self._obs['ball_owned_player']]
+ else:
+ return self._obs['right_team_direction'][self._obs['ball_owned_player']]
+
+ # general movement
+ ##################################
+ def gobetweenKeeperGate(self):
+ xdir = dirSign(self._keeperPos[0] / 2 + self._goalPos[0] / 2 - self._curPos[0] - 0.05)
+ ydir = dirSign(self._keeperPos[1] / 2 + self._goalPos[1] / 2 - self._curPos[1])
+ return directions[ydir][xdir]
+
+ def gotoDst(self, x, y):
+ xdir = dirSign(x - self._curPos[0])
+ ydir = dirSign(y - self._curPos[1])
+ return directions[ydir][xdir]
+
+ def getMostForwardEnemyPos(self):
+ ret = [0, 0]
+ i = 0
+ for pos in self._obs['right_team']:
+ if i == 0:
+ i += 1
+ continue
+ if pos[0] > ret[0]:
+ ret = pos
+ return ret
+
+ def getAvgDefenceDistToPlayer(self, *args):
+ if len(args) == 0:
+ i = self._cur_player
+ else:
+ i = args[0]
+ sumDist = 0
+ for pos in self._enemyPos:
+ if dist(pos, self._ourPos[i]) < defenceThreatDist:
+ sumDist += dist(pos, self._ourPos[i])
+ return sumDist / threatAvg
+
+ def getClosestEnemy(self, *args):
+ if len(args) == 0:
+ i = self._cur_player
+ else:
+ i = args[0]
+ closest_pos = self._keeperPos
+ closest_index = 0
+ index = 0
+        closest_dist = dist(self._ourPos[i], closest_pos)  # initialise with the keeper's actual distance
+ for pos in self._obs['right_team']:
+ if dist(pos, self._ourPos[i]) < dist(self._ourPos[i], closest_pos):
+ closest_pos = pos
+ closest_index = index
+ closest_dist = dist(pos, self._ourPos[i])
+ index += 1
+ return [closest_pos, closest_index, closest_dist]
+
+ def getClosestEnemyPos(self, *args):
+ if len(args) == 0:
+ i = self._cur_player
+ else:
+ i = args[0]
+ return self.getClosestEnemy(i)[0]
+
+ def getClosestEnemyDist(self, *args):
+ if len(args) == 0:
+ i = self._cur_player
+ else:
+ i = args[0]
+ return self.getClosestEnemy(i)[2]
+
+ def should_sprint(self):
+ if self._cur_tireness * sprintTirenessFactor > ((TOTAL_STEP - self._remain_step) / TOTAL_STEP) + 0.2:
+ return False
+ if self._enemy_have_ball:
+ return insideArea(self._curPos, DEFENCE_SPRING_AREA)
+ if self._we_have_ball:
+ return insideArea(self._curPos, SPRINT_AREA)
+
+ # help Judge Shooting
+ def shotWill(self):
+ if insideArea(self._curPos, START_SHOT_AREA1) or insideArea(self._curPos, START_SHOT_AREA2):
+ return True
+ elif not insideArea(self._keeperPos, KEEPER_ZONE_AREA) and insideArea(self._curPos, LONG_SHOT_RANGE_AREA):
+ return True
+ if dist(self._curPos, self._keeperPos) < shotDistFactor * dist(self._keeperPos, self._goalPos) + shotDistAbs:
+ return True
+ return False
+
+ # short pass
+ # def shortPassForShot(self):
+ # if insideArea(self._curPos, PASS_FOR_SHOT_AREA1) or insideArea(self._curPos, PASS_FOR_SHOT_AREA2):
+ # if not self.judgeOffside():
+ # return True
+ # return False
+
+ # help defense
+ #########################
+
+ def getClosestToEnemy(self):
+ retpos = self._obs['left_team'][0]
+ index = 0
+ retindex = index
+ for pos in self._obs['left_team']:
+ if dist(pos, self._ball_owner_pos) < dist(retpos, self._ball_owner_pos):
+ retpos = pos
+ retindex = index
+ index += 1
+ return retpos, retindex
+
+ def getMinxLeftTeam(self):
+ i = 0
+ retpos = [1, 0]
+ for pos in self._ourPos:
+ if i == 0:
+ i += 1
+ continue
+ if pos[0] < retpos[0]:
+ retpos = pos
+ return retpos
+
+ # After testing we know that sliding is not good, so no slide
+ def should_slide(self):
+ if not self._enemy_have_ball:
+ return False
+ # TODO
+ # replace 'and True' -> 'has yellow card'
+ if self._curPos[0] < self._ball_owner_pos[0] - 0.01 and self._curPos[0] < self._ballPos[0] - 0.007 and dist(
+ self._curPos, self._ball_owner_pos) < 0.03 and self._curDir[0] < 0 and insideArea(self._curPos,
+ SLIDE_AREA) and True:
+ return True
+ return False
+
+ # TODO
+ # can this be smarter?
+ def should_chase(self):
+ if self._curPos[0] > self._ball_owner_pos[0] + 0.02 and self._curPos[0] != self._closest_to_enemy_pos[0]:
+ return False
+ minLeftTeamPos = self.getMinxLeftTeam()
+ if self._curPos[0] > self._ball_owner_pos[0] + 0.03 and self._ball_owner_pos[0] - minLeftTeamPos[0] > 1.5 * abs(
+ self._ball_owner_pos[1] - minLeftTeamPos[1]):
+ return False
+ return True
+
+    # helpers to get the ball out of our own zone
+ def shotAway(self):
+ # disable or enable ?
+ return False
+ if self._curPos[0] < -0.7 and self._our_active_have_ball:
+ return True
+ return False
+
+ # def passAway(self):
+ # if self._curPos[0] < -0.4 and self._our_active_have_ball:
+ # return True
+ # return False
+
+ # functions use to judge passing
+ def judgeOffside(self, *args):
+ if len(args) == 0:
+ LeftTeam = 0
+ for pos in self._obs['left_team']:
+ LeftTeam = max(LeftTeam, pos[0])
+ else:
+ LeftTeam = self._ourPos[args[0]][0]
+ maxRightTeam = self.getMostForwardEnemyPos()[0]
+ return LeftTeam > maxRightTeam
+
+ # TODO
+ def passWill(self):
+ curOffenceMark = self.offenseMark(self._cur_player)
+ bestPassMark, bestPassType, bestPassIndex = self.getBestPass()
+ if bestPassMark > curOffenceMark + passBias:
+ # print("cur pos=", self._curPos)
+ # print("cur off score = ", curOffenceMark)
+ # print("best pass mark = ", bestPassMark)
+ # print("remain step = ", self._remain_step)
+ # print("best pass type = ", bestPassType)
+ # print("want to pass to = ", bestPassIndex)
+ return True, bestPassType, bestPassIndex
+ else:
+ return False, Action.ShortPass, -1
+
+ # TODO
+ def getBestPass(self):
+ if not self._our_active_have_ball:
+ return -1, Action.ShortPass, -1
+ bestPassType = Action.ShortPass
+ bestPassIndex = -1
+ bestPassMark = -10
+ for index in range(11):
+ # can't pass to yourself
+ if index == self._cur_player:
+ continue
+ passMark, passType = self.passMarkTo(index)
+ if passMark > bestPassMark:
+ bestPassMark = passMark
+ bestPassType = passType
+ bestPassIndex = index
+ return bestPassMark, bestPassType, bestPassIndex
+
+ # TODO
+ def passMarkTo(self, i):
+ bestPassType = Action.ShortPass
+ bestPassMark = -10
+ for t in passes:
+ if self.getPassSuccessMark(i, t) + self.offenseMark(i) > bestPassMark:
+ bestPassType = t
+ bestPassMark = self.getPassSuccessMark(i, t) + self.offenseMark(i)
+ return bestPassMark, bestPassType
+
+ def getRoleOffenceScore(self, i):
+ r = roles[i]
+ adder, multier = offenseScore[r]
+ return adder, multier
+
+ # TODO
+ # around 1.0 to 10.0
+ def offenseMark(self, i):
+ mark = 0.0
+ mark += self.getClosestEnemyDist(i)
+ mark += self.getAvgDefenceDistToPlayer(i)
+ # the closer to enemy goal the better
+ mark += 3.0 / (dist(self._ourPos[i], self._goalPos) + 0.2)
+ # but should be further to goalie
+ mark -= 0.5 / (dist(self._ourPos[i], self._keeperPos) + 0.2)
+ # offense pluser for role
+ adder, multier = self.getRoleOffenceScore(i)
+ mark *= multier
+ mark += adder
+ # ADD tireness
+ mark += 1.0 - self._our_tireness[i] * offenseTirenessFactor
+ if insideArea(self._ourPos[i], PASS_FOR_SHOT_AREA1) or insideArea(self._ourPos[i], PASS_FOR_SHOT_AREA2):
+ mark = mark * passForShotFactor
+ return mark
+
+ # TODO
+ # range from
+ def getPassSuccessMark(self, i, passType):
+ # you can't pass to yourself right?
+ if i == self._cur_player:
+ return -10
+ # can't pass offside ball
+ if self.judgeOffside(i):
+ return -10
+ mark = 0.0
+ # calculate intercept
+ # if passType == Action.HighPass:
+ # interceptFactor = 1.0
+ # distFactor = 1.2
+ # takenFactor = 3.0
+ # elif passType == Action.ShortPass:
+ # interceptFactor = 1.0
+ # distFactor = 1.5
+ # takenFactor = 1.5
+ # else:
+ # interceptFactor = 1.2
+ # distFactor = 1.2
+ # takenFactor = 1.5
+ interceptFactor = passFactors[passType][0]
+ distFactor = passFactors[passType][1]
+ takenFactor = passFactors[passType][2]
+ l = Line(self._curPos, self._ourPos[i])
+ minDist = 2
+ for pos in self._enemyPos:
+ minDist = min(minDist, l.distToLine(pos))
+ mark += (minDist * interceptFactor)
+ # calculate taken
+ taken = self.getClosestEnemyDist(i) + takenSelfFactor * self.getClosestEnemyDist()
+ mark += (taken * takenFactor)
+ # calculate dist
+ mark += (l.length * distFactor)
+ return mark
+
+ # freeKick
+ def shotFreeKick(self):
+ if insideArea(self._curPos, FREEKICK_SHOT_AREA):
+ return True
+ return False
+
+ # TODO
+ def cutAngleWithClosest(self):
+ x = self._keeperPos[0] / 2 + self._goalPos[0] / 2 - self._curPos[0]
+ y = self._keeperPos[1] / 2 + self._goalPos[1] / 2 - self._curPos[1]
+ x += self._closest_enemey_to_cur_dir[0] * (0.05 / (self._cloest_enemey_dist + 0.03))
+ y += self._closest_enemey_to_cur_dir[1] * (0.05 / (self._cloest_enemey_dist + 0.03))
+ return gotoDir(x, y)
+
+ def process(self, obs):
+ self._obs = obs
+ self.preprocess()
+
+ # TODO
+ # of course you can only shot in penalty
+ if self._game_mode == GameMode.Penalty:
+ return Action.Shot
+
+ if self._game_mode == GameMode.Corner:
+ if self._pass_dir_ready:
+ return self._pass_type
+ bestPassMark, bestPassType, bestPassIndex = self.getBestPass()
+ self._pass_dir_ready = True
+ self._pass_type = bestPassType
+ return self.gotoDst(self._ourPos[bestPassIndex][0], self._ourPos[bestPassIndex][1])
+
+ if self._game_mode == GameMode.FreeKick:
+ if self.shotFreeKick():
+ return Action.Shot
+ else:
+ if self._pass_dir_ready:
+ return self._pass_type
+ bestPassMark, bestPassType, bestPassIndex = self.getBestPass()
+ self._pass_dir_ready = True
+ self._pass_type = bestPassType
+ return self.gotoDst(self._ourPos[bestPassIndex][0], self._ourPos[bestPassIndex][1])
+
+ if self._game_mode == GameMode.KickOff:
+ return Action.ShortPass
+
+ if self._game_mode == GameMode.ThrowIn:
+ if self._pass_dir_ready:
+ return self._pass_type
+ bestPassMark, bestPassType, bestPassIndex = self.getBestPass()
+ self._pass_dir_ready = True
+ self._pass_type = bestPassType
+ return self.gotoDst(self._ourPos[bestPassIndex][0], self._ourPos[bestPassIndex][1])
+
+ if self._our_active_have_ball and not self._our_goalkeeper_have_ball:
+ if self._shot_dir_ready and self._cur_player == self._shot_buf_player and self._remain_step == self._shot_buf_step - 1:
+ self._shot_dir_ready = False
+ self._shot_buf_player = -1
+ self._shot_buf_step = -1
+ return Action.Shot
+ if self.shotWill():
+ self._shot_buf_player = self._cur_player
+ self._shot_buf_step = self._remain_step
+ self._shot_dir_ready = True
+ # TODO
+ # improve shot direction
+ return self.gobetweenKeeperGate()
+ if self._pass_dir_ready and self._cur_player == self._pass_buf_player and self._remain_step == self._pass_buf_step - 1:
+ self._pass_dir_ready = False
+ self._pass_buf_player = -1
+ self._pass_buf_step = -1
+ return self._pass_type
+ # elif self.passAway() and self._curDir[0] > 0.0:
+ # return Action.HighPass
+ # elif self.shortPassForShot():
+ # return Action.ShortPass
+ else:
+ self._shot_dir_ready = False
+ self._pass_dir_ready = False
+ doPass, doPassType, doPassIndex = self.passWill()
+ if doPass:
+ self._pass_dir_ready = True
+ self._pass_type = doPassType
+ self._pass_buf_step = self._remain_step
+ self._pass_buf_player = self._cur_player
+ return self.gotoDst(self._ourPos[doPassIndex][0], self._ourPos[doPassIndex][1])
+ # ADD avoid opponent
+ if self._closest_enemey_to_cur_vec[0] > 0:
+ # closest enemy behind me and left
+ if not self._sprinting and self.should_sprint():
+ return Action.Sprint
+ if self._dribbling and dist(self._curPos, self._closest_enemey_pos) > 0.02:
+ return Action.ReleaseDribble
+ return self.gobetweenKeeperGate()
+ elif dist(self._curPos, self._closest_enemey_pos) < 0.02:
+ # enemy too close, start dribble
+ # if not self._dribbling:
+ # return Action.Dribble
+ # enemy infront of me, try to cut an angle
+ return self.cutAngleWithClosest()
+ else:
+ # no enemy near me
+ if self._dribbling:
+ return Action.ReleaseDribble
+ if not self._sprinting:
+ return Action.Sprint
+ # ADD release sprint
+ # if self._sprinting and not self.should_sprint():
+ # return Action.ReleaseSprintt
+ # elif not insideArea(curPos, SPRINT_AREA) and Action.Sprint in obs['sticky_actions']:
+ # return Action.ReleaseSprint
+ return self.gobetweenKeeperGate()
+ elif self._we_have_ball and not self._our_goalkeeper_have_ball and not self._our_active_have_ball:
+ self._shot_dir_ready = False
+ return self.gotoDst(self._goalPos[0], self._goalPos[1])
+ elif self._our_goalkeeper_have_ball:
+ self._shot_dir_ready = False
+ if self._our_goalkeeper_active:
+ return Action.HighPass
+ if self._sprinting:
+ return Action.ReleaseSprint
+ return self.gobetweenKeeperGate()
+
+ self._shot_dir_ready = False
+ # ball in enemy or ball free
+ if self._dribbling:
+ return Action.ReleaseDribble
+
+ if self._ball_is_free:
+ if not self._sprinting and self.should_sprint():
+ return Action.Sprint
+ return self.gotoDst(self._ballPos[0] + 2 * self._ball_dir[0], self._ballPos[1] + 2 * self._ball_dir[1])
+
+ if self._enemy_have_ball:
+ # TODO
+ # defense now!
+            # if you can't catch him and you are not the closest one to the goal, just quit chasing.
+ """
+ if not self.should_chase():
+ if self._sprinting:
+ return Action.ReleaseSprint
+ return Action.Idle
+ if self.should_slide():
+ return Action.Slide
+ """
+ if not self._sprinting and self.should_sprint() and self.should_chase():
+ return Action.Sprint
+            # intercept the ball, see https://www.kaggle.com/c/google-football/discussion/191804
+ return self.gotoDst(
+ self._ballPos[0] + 1 * self._ball_dir[0] + 1 * self._ball_owner_dir[0],
+ self._ballPos[1] + 1 * self._ball_dir[1] + 1 * self._ball_owner_dir[1]
+ )
+
+ return self.gotoDst(self._goalPos[0], self._goalPos[1])
+
+
+processer = Processer()
+
+
+# @human_readable_agent
+def agent(obs):
+ global processer
+ return processer.process(obs)
+
+
+def raw_obs_to_readable(obs):
+ # print("obs = ", obs)
+ # print("obs sticky=", obs['active_player_sticky_actions'])
+ obs['sticky_actions'] = {sticky_index_to_action[nr] for nr, action in enumerate(obs['sticky_actions']) if action}
+ # Turn 'game_mode' into an enum.
+ obs['game_mode'] = GameMode(obs['game_mode'])
+ # In case of single agent mode, 'designated' is always equal to 'active'.
+ if 'designated' in obs:
+ del obs['designated']
+    # Convert players' roles to enums.
+ obs['left_team_roles'] = [PlayerRole(role) for role in obs['left_team_roles']]
+ obs['right_team_roles'] = [PlayerRole(role) for role in obs['right_team_roles']]
+ return obs
+
+
+def rule_agent(obs):
+ # obs = obs[0]
+ obs = raw_obs_to_readable(obs)
+ return agent(obs).value
+
+
+def idel_agent(obs):
+ return 0
+
+
+def random_agent(obs):
+ return random.randint(0, 18)
+
+
+agents_map = {"random": random_agent, "rule": rule_agent, "idel": idel_agent}
+
+
+@MODEL_REGISTRY.register('football_rule')
+class FootballRuleBaseModel(torch.nn.Module):
+
+ def __init__(self, cfg={}):
+ super(FootballRuleBaseModel, self).__init__()
+ self.agent_type = cfg.get('agent_type', 'rule')
+ self._agent = agents_map[self.agent_type]
+        # be compatible with the bc policy
+ # to avoid: ValueError: optimizer got an empty parameter list
+ self._dummy_param = nn.Parameter(torch.zeros(1, 1))
+
+ def forward(self, data):
+ actions = []
+ data = data['raw_obs']
+ if isinstance(data['score'], list):
+            # to be compatible with the collect phase in subprocess mode
+ data['score'] = torch.stack(data['score'], dim=-1)
+ # dict of raw observations -> list of dict, each element in the list is the raw obs in one timestep
+ data = [{k: v[i] for k, v in data.items()} for i in range(data['left_team'].shape[0])]
+ for d in data:
+            # the raw obs in one timestep
+ if isinstance(d['steps_left'], torch.Tensor):
+ d = {k: v.cpu() for k, v in d.items()}
+ d = to_ndarray(d)
+ for k in ['active', 'designated', 'ball_owned_player', 'ball_owned_team']:
+ d[k] = int(d[k])
+ actions.append(self._agent(d))
+ return {'action': torch.LongTensor(actions), 'logit': one_hot(torch.LongTensor(actions), 19)}
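+
+
+# Usage sketch (illustration only): the registered model can be built directly, e.g.
+#   bot = FootballRuleBaseModel({'agent_type': 'random'})  # or 'rule' (default) / 'idel'
+# and ``forward`` expects a dict whose 'raw_obs' entry holds the batched raw observations.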
diff --git a/DI-engine/dizoo/gfootball/model/conv1d/conv1d.py b/DI-engine/dizoo/gfootball/model/conv1d/conv1d.py
new file mode 100644
index 0000000000000000000000000000000000000000..025fa04c3d6f58c6e16463f74df0305706a0b999
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/model/conv1d/conv1d.py
@@ -0,0 +1,138 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from ding.utils import MODEL_REGISTRY, deep_merge_dicts
+from ding.config import read_config
+from dizoo.gfootball.model.conv1d.conv1d_default_config import conv1d_default_config
+
+
+@MODEL_REGISTRY.register('conv1d')
+class GfootballConv1DModel(nn.Module):
+
+ def __init__(
+ self,
+ cfg: dict = {},
+ ) -> None:
+ super(GfootballConv1DModel, self).__init__()
+ self.cfg = deep_merge_dicts(conv1d_default_config, cfg)
+
+ self.fc_player = nn.Linear(
+ self.cfg.feature_embedding.player.input_dim, self.cfg.feature_embedding.player.output_dim
+ )
+ self.fc_ball = nn.Linear(self.cfg.feature_embedding.ball.input_dim, self.cfg.feature_embedding.ball.output_dim)
+ self.fc_left = nn.Linear(
+ self.cfg.feature_embedding.left_team.input_dim, self.cfg.feature_embedding.left_team.output_dim
+ )
+ self.fc_right = nn.Linear(
+ self.cfg.feature_embedding.right_team.input_dim, self.cfg.feature_embedding.right_team.output_dim
+ )
+ self.fc_left_closest = nn.Linear(
+ self.cfg.feature_embedding.left_closest.input_dim, self.cfg.feature_embedding.left_closest.output_dim
+ )
+ self.fc_right_closest = nn.Linear(
+ self.cfg.feature_embedding.right_closest.input_dim, self.cfg.feature_embedding.right_closest.output_dim
+ )
+
+ self.conv1d_left = nn.Conv1d(
+ self.cfg.feature_embedding.left_team.output_dim,
+ self.cfg.feature_embedding.left_team.conv1d_output_channel,
+ 1,
+ stride=1
+ )
+ self.conv1d_right = nn.Conv1d(
+ self.cfg.feature_embedding.right_team.output_dim,
+ self.cfg.feature_embedding.right_team.conv1d_output_channel,
+ 1,
+ stride=1
+ )
+ self.fc_left2 = nn.Linear(
+ self.cfg.feature_embedding.left_team.conv1d_output_channel * 10,
+ self.cfg.feature_embedding.left_team.fc_output_dim
+ )
+ self.fc_right2 = nn.Linear(
+ self.cfg.feature_embedding.right_team.conv1d_output_channel * 11,
+ self.cfg.feature_embedding.right_team.fc_output_dim
+ )
+ self.fc_cat = nn.Linear(self.cfg.fc_cat.input_dim, self.cfg.lstm_size)
+
+ self.norm_player = nn.LayerNorm(64)
+ self.norm_ball = nn.LayerNorm(64)
+ self.norm_left = nn.LayerNorm(48)
+ self.norm_left2 = nn.LayerNorm(96)
+ self.norm_left_closest = nn.LayerNorm(48)
+ self.norm_right = nn.LayerNorm(48)
+ self.norm_right2 = nn.LayerNorm(96)
+ self.norm_right_closest = nn.LayerNorm(48)
+ self.norm_cat = nn.LayerNorm(self.cfg.lstm_size)
+
+ self.lstm = nn.LSTM(self.cfg.lstm_size, self.cfg.lstm_size)
+
+ self.fc_pi_a1 = nn.Linear(self.cfg.lstm_size, self.cfg.policy_head.hidden_dim)
+ self.fc_pi_a2 = nn.Linear(self.cfg.policy_head.hidden_dim, self.cfg.policy_head.act_shape)
+ self.norm_pi_a1 = nn.LayerNorm(164)
+
+ self.fc_pi_m1 = nn.Linear(self.cfg.lstm_size, 164)
+ self.fc_pi_m2 = nn.Linear(164, 8)
+ self.norm_pi_m1 = nn.LayerNorm(164)
+
+ self.fc_v1 = nn.Linear(self.cfg.lstm_size, self.cfg.value_head.hidden_dim)
+ self.norm_v1 = nn.LayerNorm(164)
+ self.fc_v2 = nn.Linear(self.cfg.value_head.hidden_dim, self.cfg.value_head.output_dim, bias=False)
+
+ def forward(self, state_dict):
+ player_state = state_dict["player"].unsqueeze(0)
+ ball_state = state_dict["ball"].unsqueeze(0)
+ left_team_state = state_dict["left_team"].unsqueeze(0)
+ left_closest_state = state_dict["left_closest"].unsqueeze(0)
+ right_team_state = state_dict["right_team"].unsqueeze(0)
+ right_closest_state = state_dict["right_closest"].unsqueeze(0)
+ avail = state_dict["avail"].unsqueeze(0)
+
+ player_embed = self.norm_player(self.fc_player(player_state))
+ ball_embed = self.norm_ball(self.fc_ball(ball_state))
+ left_team_embed = self.norm_left(self.fc_left(left_team_state)) # horizon, batch, n, dim
+ left_closest_embed = self.norm_left_closest(self.fc_left_closest(left_closest_state))
+ right_team_embed = self.norm_right(self.fc_right(right_team_state))
+ right_closest_embed = self.norm_right_closest(self.fc_right_closest(right_closest_state))
+ [horizon, batch_size, n_player, dim] = left_team_embed.size()
+ left_team_embed = left_team_embed.view(horizon * batch_size, n_player,
+ dim).permute(0, 2, 1) # horizon * batch, dim1, n
+ left_team_embed = F.relu(self.conv1d_left(left_team_embed)).permute(0, 2, 1) # horizon * batch, n, dim2
+ left_team_embed = left_team_embed.reshape(horizon * batch_size,
+ -1).view(horizon, batch_size, -1) # horizon, batch, n * dim2
+ left_team_embed = F.relu(self.norm_left2(self.fc_left2(left_team_embed)))
+
+ right_team_embed = right_team_embed.view(horizon * batch_size, n_player + 1,
+ dim).permute(0, 2, 1) # horizon * batch, dim1, n
+ right_team_embed = F.relu(self.conv1d_right(right_team_embed)).permute(0, 2, 1) # horizon * batch, n * dim2
+        # Usually we need to call reshape() or contiguous() after permute/transpose etc.
+        # to make sure the tensor is laid out contiguously in memory.
+        right_team_embed = right_team_embed.reshape(horizon * batch_size, -1).view(horizon, batch_size, -1)
+        # view() can only be used on a contiguous tensor; reshape() does not have this limitation.
+ right_team_embed = F.relu(self.norm_right2(self.fc_right2(right_team_embed)))
+
+ cat = torch.cat(
+ [player_embed, ball_embed, left_team_embed, right_team_embed, left_closest_embed, right_closest_embed], 2
+ )
+ cat = F.relu(self.norm_cat(self.fc_cat(cat)))
+ hidden = state_dict.pop('prev_state', None)
+ if hidden is None:
+ h_in = (
+ torch.zeros([1, batch_size, self.cfg.lstm_size],
+ dtype=torch.float), torch.zeros([1, batch_size, self.cfg.lstm_size], dtype=torch.float)
+ )
+ else:
+ h_in = hidden
+ out, h_out = self.lstm(cat, h_in)
+
+ a_out = F.relu(self.norm_pi_a1(self.fc_pi_a1(out)))
+ a_out = self.fc_pi_a2(a_out)
+ logit = a_out + (avail - 1) * 1e7
+ prob = F.softmax(logit, dim=2)
+
+ v = F.relu(self.norm_v1(self.fc_v1(out)))
+ v = self.fc_v2(v)
+
+ return {'logit': prob.squeeze(0), 'value': v.squeeze(0), 'next_state': h_out}
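+
+
+if __name__ == '__main__':
+    # Hedged smoke-test sketch (illustration only), using the default config dimensions:
+    # player 36, ball 18, 10 left / 11 right team slots with 7 features each, 19 available actions.
+    model = GfootballConv1DModel()
+    B = 2
+    dummy_obs = {
+        'player': torch.randn(B, 36),
+        'ball': torch.randn(B, 18),
+        'left_team': torch.randn(B, 10, 7),
+        'left_closest': torch.randn(B, 7),
+        'right_team': torch.randn(B, 11, 7),
+        'right_closest': torch.randn(B, 7),
+        'avail': torch.ones(B, 19),
+    }
+    out = model(dummy_obs)
+    print(out['logit'].shape, out['value'].shape)  # torch.Size([2, 19]) torch.Size([2, 1])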
diff --git a/DI-engine/dizoo/gfootball/model/conv1d/conv1d_default_config.py b/DI-engine/dizoo/gfootball/model/conv1d/conv1d_default_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b2b2ffe265a26962a4eed4f07acd6bc1d04f884
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/model/conv1d/conv1d_default_config.py
@@ -0,0 +1,44 @@
+from easydict import EasyDict
+
+conv1d_config = dict(
+ feature_embedding=dict(
+ player=dict(
+ input_dim=36,
+ output_dim=64,
+ ),
+ ball=dict(
+ input_dim=18,
+ output_dim=64,
+ ),
+ left_team=dict(
+ input_dim=7,
+ output_dim=48,
+ conv1d_output_channel=36,
+ fc_output_dim=96,
+ ),
+ right_team=dict(
+ input_dim=7,
+ output_dim=48,
+ conv1d_output_channel=36,
+ fc_output_dim=96,
+ ),
+ left_closest=dict(
+ input_dim=7,
+ output_dim=48,
+ ),
+ right_closest=dict(
+ input_dim=7,
+ output_dim=48,
+ )
+ ),
+ fc_cat=dict(input_dim=416, ),
+ lstm_size=256,
+ policy_head=dict(
+ input_dim=256,
+ hidden_dim=164,
+ act_shape=19,
+ ),
+ value_head=dict(input_dim=256, hidden_dim=164, output_dim=1),
+)
+
+conv1d_default_config = EasyDict(conv1d_config)
diff --git a/DI-engine/dizoo/gfootball/model/q_network/football_q_network.py b/DI-engine/dizoo/gfootball/model/q_network/football_q_network.py
new file mode 100644
index 0000000000000000000000000000000000000000..166f16618bba182f33608522dbdb8a71bd0c2710
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/model/q_network/football_q_network.py
@@ -0,0 +1,289 @@
+from functools import partial
+from ding.utils import deep_merge_dicts, MODEL_REGISTRY
+from ding.utils.data import default_collate
+from ding.torch_utils import fc_block, Transformer, ResFCBlock, \
+ conv2d_block, ResBlock, build_activation, ScatterConnection
+import torch
+import torch.nn as nn
+from ding.utils import MODEL_REGISTRY, SequenceType, squeeze
+from ding.model.common import FCEncoder, ConvEncoder, DiscreteHead, DuelingHead, MultiHead
+from .football_q_network_default_config import default_model_config
+
+
+@MODEL_REGISTRY.register('football_naive_q')
+class FootballNaiveQ(nn.Module):
+ """
+ Overview:
+ Q model for gfootball.
+        utilizes a dedicated football observation encoder, composed of a ``ScalarEncoder``
+        together with either a ``PlayerEncoder`` (transformer) or a ``SpatialEncoder``.
+ """
+
+ def __init__(
+ self,
+ cfg: dict = {},
+ ) -> None:
+ super(FootballNaiveQ, self).__init__()
+ self.cfg = deep_merge_dicts(default_model_config, cfg)
+ scalar_encoder_arch = self.cfg.encoder.match_scalar
+ player_encoder_arch = self.cfg.encoder.player
+ self.scalar_encoder = ScalarEncoder(cfg=scalar_encoder_arch)
+ self.player_type = player_encoder_arch.encoder_type
+ assert self.player_type in ['transformer', 'spatial']
+ if self.player_type == 'transformer':
+ self.player_encoder = PlayerEncoder(cfg=player_encoder_arch.transformer)
+ elif self.player_type == 'spatial':
+ self.player_encoder = SpatialEncoder(cfg=player_encoder_arch.spatial)
+ scalar_dim = self.scalar_encoder.output_dim
+ player_dim = self.player_encoder.output_dim
+ head_input_dim = scalar_dim + player_dim
+ self.pred_head = FootballHead(input_dim=head_input_dim, cfg=self.cfg.policy)
+
+ def forward(self, x: dict) -> dict:
+ """
+ Overview:
+ Use obs to run MLP or transformer with ``FootballNaiveQ`` and return the prediction dictionary.
+ Arguments:
+ - x (:obj:`Dict`): Dict containing keyword ``processed_obs`` (:obj:`Dict`) and ``raw_obs`` (:obj:`Dict`).
+ Returns:
+ - outputs (:obj:`Dict`): Dict containing keyword ``logit`` (:obj:`torch.Tensor`) and ``action`` (:obj:`torch.Tensor`).
+ Shapes:
+ - x: :math:`(B, N)`, where ``B = batch_size`` and ``N = hidden_size``.
+ - logit: :math:`(B, A)`, where ``A = action_dim``.
+ - action: :math:`(B, )`.
+ """
+ if isinstance(x, dict) and len(x) == 2:
+ x = x['processed_obs']
+ scalar_encodings = self.scalar_encoder(x)
+ if self.player_type == 'transformer':
+ player_encodings = self.player_encoder(x['players'], x['active_player'])
+ elif self.player_type == 'spatial':
+ player_encodings = self.player_encoder(x['players'])
+ encoding_list = list(scalar_encodings.values()) + [player_encodings]
+ x = torch.cat(encoding_list, dim=1)
+
+ x = self.pred_head(x)
+ return {'logit': x, 'action': torch.argmax(x, dim=-1)}
+
+
+class ScalarEncoder(nn.Module):
+
+ def __init__(self, cfg: dict) -> None:
+ super(ScalarEncoder, self).__init__()
+ self.cfg = cfg
+ self.act = nn.ReLU()
+ self.output_dim = 0
+ for k, arch in cfg.items():
+ self.output_dim += arch['output_dim']
+ encoder = fc_block(arch['input_dim'], arch['output_dim'], activation=self.act)
+ setattr(self, k, encoder)
+
+ def forward(self, x: dict) -> dict:
+ """
+ Shape:
+ - input: dict{scalar_name: scalar_tensor(:math: `(B, scalar_dim)`)}
+ - output: dict{scalar_name: scalar_encoded_tensor(:math: `(B, scalar_encoded_dim)`)}
+ """
+ fixed_scalar_sequence = [
+ 'ball_position', 'ball_direction', 'ball_rotation', 'ball_owned_team', 'ball_owned_player', 'active_player',
+ 'designated_player', 'active_player_sticky_actions', 'score', 'steps_left', 'game_mode'
+ ]
+ encodings = {}
+ for k in fixed_scalar_sequence:
+ data = x[k]
+ encodings[k] = getattr(self, k)(data)
+ if len(encodings[k].shape) == 1:
+ encodings[k].unsqueeze_(0)
+ elif len(encodings[k].shape) == 3:
+ encodings[k].squeeze_(0)
+ return encodings
+
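+# With the default match_scalar config, the scalar encodings gathered here are later concatenated
+# in ``FootballNaiveQ.forward`` into a 608-dim vector (7 x 32 + 2 x 64 + 2 x 128 = 608).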
+
+def cat_player_attr(player_data: dict) -> torch.Tensor:
+ """
+ Arguments:
+ player_data: {this_attr_name: [B, this_attr_dim]}
+ Returns:
+ attr: [B, total_attr_dim]
+ """
+ fixed_player_attr_sequence = [
+ 'team', 'index', 'position', 'direction', 'tired_factor', 'yellow_card', 'active', 'role'
+ ]
+ attr = []
+ for k in fixed_player_attr_sequence:
+ if len(player_data[k].shape) == 1 and k != 'tired_factor':
+ player_data[k].unsqueeze_(0) # TODO(pu): expand batch_dim
+ elif len(player_data[k].shape) == 1 and k == 'tired_factor':
+ player_data[k].unsqueeze_(-1) # TODO(pu): expand data_dim
+
+ if len(player_data[k].shape) == 3:
+ # TODO(pu): to be compatible with serial_entry_bc
+ # ``res = policy._forward_eval(bat['obs'])``
+ player_data[k].squeeze_(0)
+ attr.append(player_data[k])
+ attr = torch.cat(attr, dim=-1)
+ return attr
+
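+# With the default player_attr_dim (team 2, index 11, position 2, direction 2, tired_factor 1,
+# yellow_card 2, active 2, role 10), each player concatenates to a 32-dim attribute vector, so the
+# 22-player transformer encoder ends up with output_dim = 22 x 32 = 704.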
+
+class PlayerEncoder(nn.Module):
+
+ def __init__(
+ self,
+ cfg: dict,
+ ) -> None:
+ super(PlayerEncoder, self).__init__()
+ self.act = nn.ReLU()
+ self.player_num = cfg.player_num
+ assert self.player_num in [1, 22], self.player_num
+ self.output_dim = sum([dim for k, dim in cfg.player_attr_dim.items()]) * self.player_num
+ player_transformer = Transformer(
+ input_dim=cfg.input_dim,
+ head_dim=cfg.head_dim,
+ hidden_dim=cfg.hidden_dim,
+ output_dim=cfg.output_dim,
+ head_num=cfg.head_num,
+ mlp_num=cfg.mlp_num,
+ layer_num=cfg.layer_num,
+ dropout_ratio=cfg.dropout_ratio,
+ activation=self.act,
+ )
+ setattr(self, 'players', player_transformer)
+
+ def forward(self, x: list, active_player: torch.Tensor) -> torch.Tensor:
+ """
+ Shape:
+ - input: list[len=22(=player_num/M)] -> element: dict{attr_name: attr_tensor(:math: `(B, attr_dim)`)}
+ - active_player: :math: `(B, 11)`)
+ - output: :math: `(B, player_num*total_attr_dim)`, player_num is in [1, 22]
+ """
+ player_input = self.get_player_input(x, active=active_player) # (player_num*B, total_attr_dim)
+ # player_output = getattr(self, 'players')(player_input, tensor_output=True) # (player_num*B, total_attr_dim, 1)
+ player_output = getattr(self, 'players')(player_input) # (player_num*B, total_attr_dim, 1)
+ player_output = player_output.squeeze(dim=2) # (player_num*B, total_attr_dim)
+ player_output = player_output.reshape((22, -1, player_output.shape[1])) # (player_num, B, total_attr_dim)
+ player_output = player_output.permute(1, 0, 2) # (B, player_num, total_attr_dim)
+ player_output = player_output.reshape((player_output.shape[0], -1)) # (B, player_num*total_attr_dim)
+ return player_output
+
+ def get_player_input(self, data: list, active: torch.Tensor) -> torch.Tensor:
+ if self.player_num == 1:
+ bs = data[0]['index'].shape[0]
+ batch_player = [None for _ in range(bs)]
+ for player in data:
+ for idx in range(bs):
+ if batch_player[idx] is not None:
+ continue
+ if torch.nonzero(player['index'][idx]).item() == torch.nonzero(active[idx]).item() \
+ and torch.nonzero(player['team'][idx]).item() == 0:
+ batch_player[idx] = {k: v[idx] for k, v in player.items()}
+ if None not in batch_player:
+ break
+ # old_batch_player: list[len=bs] -> element: dict{attr_name: attr_tensor(:math: `(attr_dim)`)}
+ batch_player = default_collate(batch_player)
+ # new_batch_player: dict{attr_name: attr_tensor(:math: `(bs, attr_dim)`)}
+ return cat_player_attr(batch_player).unsqueeze(dim=2)
+ elif self.player_num == 22:
+ players = []
+ for player in data:
+ players.append(cat_player_attr(player))
+ players = torch.cat(players, dim=0)
+ players = players.unsqueeze(dim=2)
+ return players
+
+
+class SpatialEncoder(nn.Module):
+
+ def __init__(
+ self,
+ cfg: dict,
+ ) -> None:
+ super(SpatialEncoder, self).__init__()
+ self.act = build_activation(cfg.activation)
+ self.norm = cfg.norm_type
+ self.scatter = ScatterConnection(cfg.scatter_type)
+ input_dim = sum([dim for k, dim in cfg.player_attr_dim.items()]) # player_attr total dim
+ self.project = conv2d_block(input_dim, cfg.project_dim, 1, 1, 0, activation=self.act, norm_type=self.norm)
+ down_layers = []
+ dims = [cfg.project_dim] + cfg.down_channels
+ self.down_channels = cfg.down_channels
+ for i in range(len(self.down_channels)):
+ down_layers.append(nn.AvgPool2d(2, 2))
+ down_layers.append(conv2d_block(dims[i], dims[i + 1], 3, 1, 1, activation=self.act, norm_type=self.norm))
+ self.downsample = nn.Sequential(*down_layers)
+ self.res = nn.ModuleList()
+ dim = dims[-1]
+ self.resblock_num = cfg.resblock_num
+ for i in range(cfg.resblock_num):
+ self.res.append(ResBlock(dim, activation=self.act, norm_type=self.norm))
+
+ self.gap = nn.AdaptiveAvgPool2d((1, 1))
+ self.fc = fc_block(dim, cfg.fc_dim, activation=self.act)
+ self.output_dim = cfg.fc_dim
+
+ def forward(self, x: list) -> torch.Tensor:
+ """
+ Shape:
+ - input: list[len=22(=player_num/M)] -> element: dict{attr_name: attr_tensor(:math: `(B, attr_dim)`)}
+ - output: :math: `(B, fc_dim)`
+ """
+ players = []
+ players_loc = []
+ granularity = 0.01
+ H, W = 84, 200
+ for player in x:
+ players.append(cat_player_attr(player))
+ device = player['position'].device
+ player_loc = ((player['position'] + torch.FloatTensor([1., 0.42]).to(device)) / granularity).long()
+ player_loc_yx = player_loc[:, [1, 0]]
+ players_loc.append(player_loc_yx)
+ players = torch.stack(players, dim=1) # [B, M, N]
+ players_loc = torch.stack(players_loc, dim=1) # [B, M, 2]
+ players_loc[..., 0] = players_loc[..., 0].clamp(0, H - 1)
+ players_loc[..., 1] = players_loc[..., 1].clamp(0, W - 1)
+ x = self.scatter(players, (H, W), players_loc)
+ x = self.project(x)
+ x = self.downsample(x)
+ for block in self.res:
+ x = block(x)
+ x = self.gap(x)
+ x = x.view(x.shape[:2])
+ x = self.fc(x)
+ return x
+
+
+class FootballHead(nn.Module):
+
+ def __init__(
+ self,
+ input_dim: int,
+ cfg: dict,
+ ) -> None:
+ super(FootballHead, self).__init__()
+ self.act = nn.ReLU()
+ self.input_dim = input_dim
+ self.hidden_dim = cfg.res_block.hidden_dim
+ self.res_num = cfg.res_block.block_num
+ self.dueling = cfg.dqn.dueling
+ self.a_layer_num = cfg.dqn.a_layer_num
+ self.v_layer_num = cfg.dqn.v_layer_num
+ self.action_dim = cfg.action_dim
+ self.pre_fc = fc_block(in_channels=input_dim, out_channels=self.hidden_dim, activation=self.act)
+ res_blocks_list = []
+ for i in range(self.res_num):
+ res_blocks_list.append(ResFCBlock(in_channels=self.hidden_dim, activation=self.act, norm_type=None))
+ self.res_blocks = nn.Sequential(*res_blocks_list)
+ head_fn = partial(
+ DuelingHead, a_layer_num=self.a_layer_num, v_layer_num=self.v_layer_num
+ ) if self.dueling else nn.Linear
+ self.pred = head_fn(self.hidden_dim, self.action_dim)
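+ # note: ``DuelingHead`` outputs a dict containing the key 'logit', which ``forward`` below indexes;
+ # the default config keeps ``dueling=True`` (a plain ``nn.Linear`` head would return a raw tensor instead).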
+
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ """
+ Shape:
+ - input: :math: `(B, input_dim)`, input_dim is the sum of all encoders' output_dim
+ - output: :math: `(B, action_dim)`
+ """
+ x = self.pre_fc(x)
+ x = self.res_blocks(x)
+ x = self.pred(x)
+ return x['logit']
diff --git a/DI-engine/dizoo/gfootball/model/q_network/football_q_network_default_config.py b/DI-engine/dizoo/gfootball/model/q_network/football_q_network_default_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ba0a617f49832218f3d8975b22dc24e00747af7
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/model/q_network/football_q_network_default_config.py
@@ -0,0 +1,58 @@
+from easydict import EasyDict
+
+model_config = dict(
+ # ===== Encoder =====
+ encoder=dict(
+ match_scalar=dict(
+ ball_position=dict(input_dim=3, output_dim=32),
+ ball_direction=dict(input_dim=3, output_dim=32),
+ ball_rotation=dict(input_dim=3, output_dim=32),
+ ball_owned_team=dict(input_dim=3, output_dim=32),
+ ball_owned_player=dict(input_dim=12, output_dim=32),
+ active_player=dict(input_dim=11, output_dim=32),
+ designated_player=dict(input_dim=11, output_dim=32),
+ active_player_sticky_actions=dict(input_dim=10, output_dim=64),
+ score=dict(input_dim=22, output_dim=64),
+ steps_left=dict(input_dim=30, output_dim=128),
+ game_mode=dict(input_dim=7, output_dim=128),
+ ),
+ player=dict(
+ # choices: ['transformer', 'spatial']
+ encoder_type='transformer',
+ transformer=dict(
+ player_num=22,
+ player_attr_dim=dict(
+ team=2, index=11, position=2, direction=2, tired_factor=1, yellow_card=2, active=2, role=10
+ ),
+ input_dim=1,
+ head_dim=64,
+ hidden_dim=128,
+ output_dim=1,
+ head_num=2,
+ mlp_num=2,
+ layer_num=3,
+ dropout_ratio=1
+ ),
+ spatial=dict(
+ resblock_num=4,
+ fc_dim=256,
+ project_dim=32,
+ down_channels=[64, 128],
+ activation='relu',
+ norm_type='BN',
+ scatter_type='add',
+ player_attr_dim=dict(
+ team=2, index=11, position=2, direction=2, tired_factor=1, yellow_card=2, active=2, role=10
+ ),
+ ),
+ )
+ ),
+ # ===== Policy =====
+ policy=dict(
+ res_block=dict(hidden_dim=1024, block_num=3),
+ dqn=dict(dueling=True, a_layer_num=2, v_layer_num=2),
+ action_dim=19,
+ )
+)
+
+default_model_config = EasyDict(model_config)
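+
+# A minimal usage sketch (mirroring ``tests/test_football_model.py``): this config is consumed directly by
+# ``FootballNaiveQ``, e.g.
+#   from dizoo.gfootball.model.q_network.football_q_network import FootballNaiveQ
+#   model = FootballNaiveQ(default_model_config)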
diff --git a/DI-engine/dizoo/gfootball/model/q_network/tests/test_football_model.py b/DI-engine/dizoo/gfootball/model/q_network/tests/test_football_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..87b91a9ccc4015f8bbeffedcd227e3ff65fd37b2
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/model/q_network/tests/test_football_model.py
@@ -0,0 +1,45 @@
+import pytest
+import copy
+import torch
+import os
+import yaml
+from easydict import EasyDict
+from dizoo.gfootball.model.q_network.football_q_network import FootballNaiveQ
+from ding.torch_utils import to_tensor, to_dtype
+from dizoo.gfootball.envs.fake_dataset import FakeGfootballDataset
+import pprint
+from dizoo.gfootball.model.q_network.football_q_network_default_config import default_model_config
+
+
+@pytest.mark.envtest
+class TestModel:
+
+ def test_encoder(self, config=default_model_config):
+ B = 4
+ scalar_encoder_arch = config.encoder.match_scalar
+ player_attr_dim = config.encoder.player.transformer.player_attr_dim
+ action_dim = config.policy.action_dim
+ cfg = copy.deepcopy(config)
+
+ for t in ['transformer', 'spatial']:
+ cfg.encoder.player.encoder_type = t
+
+ inputs = {}
+ for k, v in scalar_encoder_arch.items():
+ inputs[k] = torch.randn(B, v['input_dim'])
+ inputs['players'] = []
+ for _ in range(22):
+ inputs['players'].append({k: torch.randn(B, v) for k, v in player_attr_dim.items()})
+ fake_dataset = FakeGfootballDataset()
+ inputs = fake_dataset.get_batched_obs(bs=B)
+ pp = pprint.PrettyPrinter(indent=2)
+ print('observation: ')
+ pp.pprint(inputs)
+
+ model = FootballNaiveQ(cfg)
+ assert isinstance(model, torch.nn.Module)
+ inputs = to_dtype(inputs, torch.float32)
+ inputs = to_tensor(inputs)
+ outputs = model(inputs)
+ assert outputs['logit'].shape == (B, 19)
+ assert outputs['action'].shape == (B, )
diff --git a/DI-engine/dizoo/gfootball/policy/__init__.py b/DI-engine/dizoo/gfootball/policy/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..08a90dd21ed359fd60a378fe8c546887fd82348d
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/policy/__init__.py
@@ -0,0 +1 @@
+from .ppo_lstm import PPOPolicy, PPOCommandModePolicy
diff --git a/DI-engine/dizoo/gfootball/policy/ppo_lstm.py b/DI-engine/dizoo/gfootball/policy/ppo_lstm.py
new file mode 100644
index 0000000000000000000000000000000000000000..4380f1261c7579b8c19a7ace63e06784ed21d251
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/policy/ppo_lstm.py
@@ -0,0 +1,347 @@
+from typing import List, Dict, Any, Tuple, Union, Optional
+from collections import namedtuple, deque
+import torch
+import copy
+
+from ding.torch_utils import Adam, to_device
+from ding.rl_utils import ppo_data, ppo_error, ppo_policy_error, ppo_policy_data, get_gae_with_default_last_value, \
+ v_nstep_td_data, v_nstep_td_error, get_nstep_return_data, get_train_sample
+
+from ding.model import model_wrap
+from ding.utils import POLICY_REGISTRY, deep_merge_dicts
+from ding.utils.data import default_collate, default_decollate
+from ding.policy.base_policy import Policy
+from ding.policy.common_utils import default_preprocess_learn
+from ding.policy.command_mode_policy_instance import DummyCommandModePolicy
+
+
+@POLICY_REGISTRY.register('ppo_lstm')
+class PPOPolicy(Policy):
+ r"""
+ Overview:
+ Policy class of PPO algorithm.
+ """
+ config = dict(
+ # (str) RL policy register name (refer to function "POLICY_REGISTRY").
+ type='ppo_lstm',
+ # (bool) Whether to use cuda for network.
+ cuda=False,
+ # (bool) Whether the RL algorithm is on-policy or off-policy. (Note: in practice PPO can also be used in an off-policy way)
+ on_policy=True,
+ # (bool) Whether to use priority(priority sample, IS weight, update priority)
+ priority=False,
+ # (bool) Whether to use Importance Sampling weight to correct the biased update. If True, priority must be True.
+ priority_IS_weight=False,
+ # (bool) Whether to use nstep_return for value loss
+ nstep_return=False,
+ nstep=3,
+ learn=dict(
+ # How many updates(iterations) to train after collector's one collection.
+ # Bigger "update_per_collect" means bigger off-policy.
+ # collect data -> update policy-> collect data -> ...
+ update_per_collect=5,
+ batch_size=64,
+ learning_rate=0.001,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) The loss weight of value network, policy network weight is set to 1
+ value_weight=0.5,
+ # (float) The loss weight of entropy regularization, policy network weight is set to 1
+ entropy_weight=0.01,
+ # (float) PPO clip ratio, defaults to 0.2
+ clip_ratio=0.2,
+ # (bool) Whether to use advantage norm in a whole training batch
+ adv_norm=False,
+ ignore_done=False,
+ ),
+ collect=dict(
+ # (int) Only one of [n_sample, n_episode] should be set
+ # n_sample=64,
+ # (int) Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) Reward's future discount factor, aka. gamma.
+ discount_factor=0.99,
+ # (float) GAE lambda factor for the balance of bias and variance(1-step td and mc)
+ gae_lambda=0.95,
+ ),
+ eval=dict(),
+ # Although ppo is an on-policy algorithm, ding reuses the buffer mechanism and clears the buffer after each update.
+ # Note replay_buffer_size must be greater than n_sample.
+ other=dict(replay_buffer=dict(replay_buffer_size=1000, ), ),
+ )
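+
+ # A minimal, illustrative sketch of how this policy could be referenced from a ``create_config``
+ # (mirroring the pattern used by other configs in dizoo; shown here only as an example, not used below):
+ #   create_config = dict(
+ #       policy=dict(type='ppo_lstm', import_names=['dizoo.gfootball.policy.ppo_lstm']),
+ #   )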
+
+ def _init_learn(self) -> None:
+ r"""
+ Overview:
+ Learn mode init method. Called by ``self.__init__``.
+ Init the optimizer, algorithm config and the main model.
+ """
+ self._priority = self._cfg.priority
+ self._priority_IS_weight = self._cfg.priority_IS_weight
+ assert not self._priority and not self._priority_IS_weight, "Priority is not implemented in PPO"
+ # Orthogonal init
+ for m in self._model.modules():
+ if isinstance(m, torch.nn.Conv2d):
+ torch.nn.init.orthogonal_(m.weight)
+ if isinstance(m, torch.nn.Linear):
+ torch.nn.init.orthogonal_(m.weight)
+ # Optimizer
+ self._optimizer = Adam(self._model.parameters(), lr=self._cfg.learn.learning_rate)
+ self._learn_model = model_wrap(self._model, wrapper_name='base')
+ # self._learn_model = model_wrap(self._learn_model, wrapper_name='hidden_state', state_num=self._cfg.learn.batch_size)
+
+ # Algorithm config
+ self._value_weight = self._cfg.learn.value_weight
+ self._entropy_weight = self._cfg.learn.entropy_weight
+ self._clip_ratio = self._cfg.learn.clip_ratio
+ self._adv_norm = self._cfg.learn.adv_norm
+ self._nstep = self._cfg.nstep
+ self._nstep_return = self._cfg.nstep_return
+ # Main model
+ self._learn_model.reset()
+
+ def _forward_learn(self, data: dict) -> Dict[str, Any]:
+ r"""
+ Overview:
+ Forward and backward function of learn mode.
+ Arguments:
+ - data (:obj:`dict`): Dict type data
+ Returns:
+ - info_dict (:obj:`Dict[str, Any]`):
+ Including current lr, total_loss, policy_loss, value_loss, entropy_loss, \
+ adv_abs_max, approx_kl, clipfrac
+ """
+ data = default_preprocess_learn(data, ignore_done=self._cfg.learn.ignore_done, use_nstep=self._nstep_return)
+ if self._cuda:
+ data = to_device(data, self._device)
+ # ====================
+ # PPO forward
+ # ====================
+
+ self._learn_model.train()
+ # normal ppo
+ if not self._nstep_return:
+ output = self._learn_model.forward(data['obs'])
+ adv = data['adv']
+ if self._adv_norm:
+ # Normalize advantage in a total train_batch
+ adv = (adv - adv.mean()) / (adv.std() + 1e-8)
+ return_ = data['value'] + adv
+ # Calculate ppo error
+ ppodata = ppo_data(
+ output['logit'], data['logit'], data['action'], output['value'], data['value'], adv, return_,
+ data['weight']
+ )
+ ppo_loss, ppo_info = ppo_error(ppodata, self._clip_ratio)
+ wv, we = self._value_weight, self._entropy_weight
+ total_loss = ppo_loss.policy_loss + wv * ppo_loss.value_loss - we * ppo_loss.entropy_loss
+
+ else:
+ output = self._learn_model.forward(data['obs'])
+ adv = data['adv']
+ if self._adv_norm:
+ # Normalize advantage in a total train_batch
+ adv = (adv - adv.mean()) / (adv.std() + 1e-8)
+
+ # Calculate ppo error
+ ppodata = ppo_policy_data(output['logit'], data['logit'], data['action'], adv, data['weight'])
+ ppo_policy_loss, ppo_info = ppo_policy_error(ppodata, self._clip_ratio)
+ wv, we = self._value_weight, self._entropy_weight
+ next_obs = data.get('next_obs')
+ value_gamma = data.get('value_gamma')
+ reward = data.get('reward')
+ # current value
+ value = self._learn_model.forward(data['obs'])
+ # target value
+ next_data = {'obs': next_obs}
+ target_value = self._learn_model.forward(next_data['obs'])
+ # TODO what should we do here to keep shape
+ assert self._nstep > 1
+ td_data = v_nstep_td_data(
+ value['value'], target_value['value'], reward.t(), data['done'], data['weight'], value_gamma
+ )
+ # calculate v_nstep_td critic_loss
+ critic_loss, td_error_per_sample = v_nstep_td_error(td_data, self._gamma, self._nstep)
+ ppo_loss_data = namedtuple('ppo_loss', ['policy_loss', 'value_loss', 'entropy_loss'])
+ ppo_loss = ppo_loss_data(ppo_policy_loss.policy_loss, critic_loss, ppo_policy_loss.entropy_loss)
+ total_loss = ppo_policy_loss.policy_loss + wv * critic_loss - we * ppo_policy_loss.entropy_loss
+
+ # ====================
+ # PPO update
+ # ====================
+ self._optimizer.zero_grad()
+ total_loss.backward()
+ self._optimizer.step()
+ return {
+ 'cur_lr': self._optimizer.defaults['lr'],
+ 'total_loss': total_loss.item(),
+ 'policy_loss': ppo_loss.policy_loss.item(),
+ 'value_loss': ppo_loss.value_loss.item(),
+ 'entropy_loss': ppo_loss.entropy_loss.item(),
+ 'adv_abs_max': adv.abs().max().item(),
+ 'approx_kl': ppo_info.approx_kl,
+ 'clipfrac': ppo_info.clipfrac,
+ }
+
+ def _state_dict_learn(self) -> Dict[str, Any]:
+ return {
+ 'model': self._learn_model.state_dict(),
+ 'optimizer': self._optimizer.state_dict(),
+ }
+
+ def _load_state_dict_learn(self, state_dict: Dict[str, Any]) -> None:
+ self._learn_model.load_state_dict(state_dict['model'])
+ self._optimizer.load_state_dict(state_dict['optimizer'])
+
+ def _init_collect(self) -> None:
+ r"""
+ Overview:
+ Collect mode init method. Called by ``self.__init__``.
+ Init traj and unroll length, collect model.
+ """
+ self._unroll_len = self._cfg.collect.unroll_len
+ self._collect_model = model_wrap(self._model, wrapper_name='multinomial_sample')
+ # self._collect_model = model_wrap(
+ # self._collect_model, wrapper_name='hidden_state', state_num=self._cfg.collect.env_num, save_prev_state=True
+ # )
+ self._collect_model.reset()
+ self._gamma = self._cfg.collect.discount_factor
+ self._gae_lambda = self._cfg.collect.gae_lambda
+ self._nstep = self._cfg.nstep
+ self._nstep_return = self._cfg.nstep_return
+
+ def _forward_collect(self, data: dict) -> dict:
+ r"""
+ Overview:
+ Forward function of collect mode.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): Dict type data, including at least inferred action according to input obs.
+ ReturnsKeys
+ - necessary: ``action``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._collect_model.eval()
+ with torch.no_grad():
+ output = self._collect_model.forward(data)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _process_transition(self, obs: Any, model_output: dict, timestep: namedtuple) -> dict:
+ """
+ Overview:
+ Generate dict type transition data from inputs.
+ Arguments:
+ - obs (:obj:`Any`): Env observation
+ - model_output (:obj:`dict`): Output of collect model, including at least ['action']
+ - timestep (:obj:`namedtuple`): Output after env step, including at least ['obs', 'reward', 'done']\
+ (here 'obs' indicates obs after env step).
+ Returns:
+ - transition (:obj:`dict`): Dict type transition data.
+ """
+ if not self._nstep_return:
+ transition = {
+ 'obs': obs,
+ 'logit': model_output['logit'],
+ 'action': model_output['action'],
+ 'value': model_output['value'],
+ 'prev_state': model_output['prev_state'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ else:
+ transition = {
+ 'obs': obs,
+ 'next_obs': timestep.obs,
+ 'logit': model_output['logit'],
+ 'action': model_output['action'],
+ 'prev_state': model_output['prev_state'],
+ 'value': model_output['value'],
+ 'reward': timestep.reward,
+ 'done': timestep.done,
+ }
+ return transition
+
+ def _get_train_sample(self, data: deque) -> Union[None, List[Any]]:
+ r"""
+ Overview:
+ Get the trajectory and calculate GAE, return one data to cache for next time calculation
+ Arguments:
+ - data (:obj:`deque`): The trajectory's cache
+ Returns:
+ - samples (:obj:`dict`): The training samples generated
+ """
+ data = get_gae_with_default_last_value(
+ data,
+ data[-1]['done'],
+ gamma=self._gamma,
+ gae_lambda=self._gae_lambda,
+ cuda=self._cuda,
+ )
+
+ if not self._nstep_return:
+ return get_train_sample(data, self._unroll_len)
+ else:
+ return get_nstep_return_data(data, self._nstep)
+
+ def _init_eval(self) -> None:
+ r"""
+ Overview:
+ Evaluate mode init method. Called by ``self.__init__``.
+ Init eval model with argmax strategy.
+ """
+ self._eval_model = model_wrap(self._model, wrapper_name='argmax_sample')
+ # self._eval_model = model_wrap(self._model, wrapper_name='hidden_state', state_num=self._cfg.eval.env_num)
+ self._eval_model.reset()
+
+ def _forward_eval(self, data: dict) -> dict:
+ r"""
+ Overview:
+ Forward function of eval mode, similar to ``self._forward_collect``.
+ Arguments:
+ - data (:obj:`Dict[str, Any]`): Dict type data, stacked env data for predicting policy_output(action), \
+ values are torch.Tensor or np.ndarray or dict/list combinations, keys are env_id indicated by integer.
+ Returns:
+ - output (:obj:`Dict[int, Any]`): The dict of predicting action for the interaction with env.
+ ReturnsKeys
+ - necessary: ``action``
+ """
+ data_id = list(data.keys())
+ data = default_collate(list(data.values()))
+ if self._cuda:
+ data = to_device(data, self._device)
+ # data = {'obs': data}
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ output = default_decollate(output)
+ return {i: d for i, d in zip(data_id, output)}
+
+ def _reset_eval(self, data_id: Optional[List[int]] = None) -> None:
+ self._eval_model.reset(data_id=data_id)
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ return 'vac', ['ding.model.template.vac']
+
+ def _monitor_vars_learn(self) -> List[str]:
+ return super()._monitor_vars_learn() + [
+ 'policy_loss', 'value_loss', 'entropy_loss', 'adv_abs_max', 'approx_kl', 'clipfrac'
+ ]
+
+
+@POLICY_REGISTRY.register('ppo_lstm_command')
+class PPOCommandModePolicy(PPOPolicy, DummyCommandModePolicy):
+ pass
diff --git a/DI-engine/dizoo/gfootball/replay.py b/DI-engine/dizoo/gfootball/replay.py
new file mode 100644
index 0000000000000000000000000000000000000000..32c26977778406ebc13a1263774fdd6f36614d38
--- /dev/null
+++ b/DI-engine/dizoo/gfootball/replay.py
@@ -0,0 +1,40 @@
+# coding=utf-8
+# Copyright 2019 Google LLC
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Script allowing to replay a given trace file.
+ Example usage:
+ python replay.py --trace_file=/tmp/dumps/shutdown_20190521-165136974075.dump
+"""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+from gfootball.env import script_helpers
+
+from absl import app
+from absl import flags
+
+FLAGS = flags.FLAGS
+
+flags.DEFINE_string('trace_file', None, 'Trace file to replay')
+flags.DEFINE_integer('fps', 10, 'How many frames per second to render')
+flags.mark_flag_as_required('trace_file')
+
+
+def main(_):
+ script_helpers.ScriptHelpers().replay(FLAGS.trace_file, FLAGS.fps)
+
+
+if __name__ == '__main__':
+ app.run(main)
diff --git a/DI-engine/dizoo/gym_anytrading/__init__.py b/DI-engine/dizoo/gym_anytrading/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/gym_anytrading/config/__init__.py b/DI-engine/dizoo/gym_anytrading/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb96be11120372aa94cdf627890814371219dfa4
--- /dev/null
+++ b/DI-engine/dizoo/gym_anytrading/config/__init__.py
@@ -0,0 +1 @@
+from .stocks_dqn_config import stocks_dqn_config, stocks_dqn_create_config
diff --git a/DI-engine/dizoo/gym_anytrading/config/stocks_dqn_config.py b/DI-engine/dizoo/gym_anytrading/config/stocks_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c05a1f597480629fa27b0cdced00599489f61041
--- /dev/null
+++ b/DI-engine/dizoo/gym_anytrading/config/stocks_dqn_config.py
@@ -0,0 +1,92 @@
+from easydict import EasyDict
+
+stocks_dqn_config = dict(
+ exp_name='stocks_dqn_seed0',
+ env=dict(
+ # Whether to use shared memory. Only effective if "env_manager_type" is 'subprocess'
+ # Env number respectively for collector and evaluator.
+ collector_env_num=8,
+ evaluator_env_num=8,
+ env_id='stocks-v0',
+ n_evaluator_episode=8,
+ stop_value=2,
+ # one trading year.
+ eps_length=253,
+ # associated with the feature length.
+ window_size=20,
+ # the path to save result image.
+ save_path='./fig/',
+ # the raw data file name
+ stocks_data_filename='STOCKS_GOOGL',
+ # the stocks range percentage used by train/test.
+ # if one of them is None, train & test set will use all data by default.
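+ # e.g. train_range=0.8 and test_range=-0.2 are intended to use the first 80% of the rows for
+ # collecting and the last 20% for evaluating (the split is applied in ``StocksEnv._process_data``).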
+ train_range=None,
+ test_range=None,
+ ),
+ policy=dict(
+ # Whether to use cuda for network.
+ cuda=True,
+ model=dict(
+ obs_shape=62,
+ action_shape=5,
+ encoder_hidden_size_list=[128],
+ head_layer_num=1,
+ # Whether to use dueling head.
+ dueling=True,
+ ),
+ # Reward's future discount factor, aka. gamma.
+ discount_factor=0.99,
+ # How many steps in td error.
+ nstep=5,
+ # learn_mode config
+ learn=dict(
+ update_per_collect=10,
+ batch_size=64,
+ learning_rate=0.001,
+ # Frequency of target network update.
+ target_update_freq=100,
+ ignore_done=True,
+ ),
+ # collect_mode config
+ collect=dict(
+ # You can use either "n_sample" or "n_episode" in collector.collect.
+ # Get "n_sample" samples per collect.
+ n_sample=64,
+ # Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ ),
+ # command_mode config
+ other=dict(
+ # Epsilon greedy with decay.
+ eps=dict(
+ # Decay type. Support ['exp', 'linear'].
+ type='exp',
+ start=0.95,
+ end=0.1,
+ decay=50000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, )
+ ),
+ ),
+)
+stocks_dqn_config = EasyDict(stocks_dqn_config)
+main_config = stocks_dqn_config
+
+stocks_dqn_create_config = dict(
+ env=dict(
+ type='stocks-v0',
+ import_names=['dizoo.gym_anytrading.envs.stocks_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='dqn', ),
+ evaluator=dict(
+ type='trading_interaction',
+ import_names=['dizoo.gym_anytrading.worker'],
+ ),
+)
+stocks_dqn_create_config = EasyDict(stocks_dqn_create_config)
+create_config = stocks_dqn_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0, max_env_step=int(1e7))
diff --git a/DI-engine/dizoo/gym_anytrading/envs/README.md b/DI-engine/dizoo/gym_anytrading/envs/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..0be3fe219b7553b1d6d704688935124a4e9cd39c
--- /dev/null
+++ b/DI-engine/dizoo/gym_anytrading/envs/README.md
@@ -0,0 +1,99 @@
+# DI-engine AnyTrading
+
+AnyTrading is a collection of OpenAI Gym environments for reinforcement learning-based trading algorithms.
+
+Based on the original gym-anytrading environment (available at https://github.com/AminHP/gym-anytrading), we made a number of modifications to improve it.
+
+In our environment, TradingEnv is an abstract environment defined to support all kinds of trading environments. StocksEnv, which inherits and extends TradingEnv, backtests the trading data of Google stock from 2009 to 2018.
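+
+A minimal usage sketch (mirroring `envs/test_stocks_env.py`; it assumes the Google stocks csv has been placed under `envs/data`, see `envs/data/README.md`):
+
+```python
+from easydict import EasyDict
+from dizoo.gym_anytrading.envs import StocksEnv
+
+env = StocksEnv(EasyDict(dict(
+    env_id='stocks-v0', eps_length=300, window_size=20,
+    train_range=None, test_range=None, stocks_data_filename='STOCKS_GOOGL',
+)))
+obs = env.reset()                          # shape (62, ): window_size * 3 features + position + tick
+timestep = env.step(env.random_action())   # BaseEnvTimestep(obs, reward, done, info)
+```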
+
+## Environment Properties
+
+The original design of gym-anytrading is quite simple, aiming to let the agent learn in a fast and efficient way. However, we found that
+several defects of the original environment make it difficult to train agents, and that such a simplified environment cannot describe a real trading environment well. Therefore, lots of modifications have been made. The following subsections explain why these modifications are meaningful.
+
+### State Machine
+We use a state machine to describe how the TradingEnv interacts with the agent as well as how an agent makes profits.
+
+As shown below, the state machine uses three kinds of trading positions and five kinds of trading actions (the "Hold" action is not shown) to describe how the transaction evolves over time.
+
+![state machine](./statemachine.png)
+
+### Trading Positions
+
+Short:
+ If the current env is in the Short state, it means that the agent has borrowed stocks from the securities companies.
+
+Flat:
+ If the current env is in the Flat state, it means that the agent does not hold any shares.
+
+Long:
+ If the current env is in the Long state, it means that the agent has converted all of its funds into stocks.
+
+### Trading Actions
+
+Double_Sell:
+ sell all the stocks the agent holds as well as the stocks it borrowed from the securities companies.
+
+Sell:
+ sell the stocks the agent holds.
+
+Hold:
+ maintain the current position.
+
+Buy:
+ buy stocks at the current close price.
+
+Double_Buy:
+ return the borrowed shares to the securities companies and exchange all the funds on hand for stocks at the current close price.
+
+### How profit and loss happen
+
+If a profit or loss occurs, it means that one of the following two cycles in the state machine has occurred.
+
+- buying long
+ - Flat -> Long -> Flat
+- short selling
+ - Flat -> Short -> Flat
+
+### Current Profit Calculation
+
+According to the above definitions, the formula of the accumulated profit is:
+
+$\prod_{buying\ long}(close_{curr}/close_{pre}\ *\ cost) * \prod_{short\ selling}((2-close_{curr}/close_{pre})\ *\ cost)$
+
+where $close_{curr}$ and $close_{pre}$ are the close prices at the current and the previous trade tick, and $cost=(1-fee_{ask})(1-fee_{bid})$ accounts for the transaction fees.
+
+
+### Reward Function
+
+
+
+Comparing the reinforcement learning objective ($\mathbb{E}_{\tau}\sum\ r$) with the profit formula above, we can derive the reward function:
+
+- buying long:
+ - $log(close_{curr} / close_{pre})+log(cost)$
+- short selling:
+ - $log(2 - close_{curr} / close_{pre})+log(cost)$
+- otherwise:
+ - 0
+
+so that maximizing $\mathbb{E}_{\tau} \sum r$
+is equivalent to maximizing $\mathbb{E}_{\tau}[\prod_{buying\ long}(close_{curr}/close_{pre}\ *\ cost) * \prod_{short\ selling}((2-close_{curr}/close_{pre})\ *\ cost)]$
+
+Experimental results show that this definition works better than the original gym-anytrading accumulated reward function: $\sum(close_{curr} - close_{pre})$.
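+
+A minimal numerical sketch of this equivalence (illustrative only, not part of the environment code; the prices below are toy values and the fee values are those used by `StocksEnv`):
+
+```python
+import numpy as np
+
+# two buying-long cycles, given as (price when opening, price when closing) the position
+trade_prices = [(100., 110.), (120., 115.)]
+cost = (1 - 0.005) * (1 - 0.01)   # (1 - ask fee) * (1 - bid fee)
+
+# reward collected when each cycle is closed: log(close_curr / close_pre) + log(cost)
+rewards = [np.log(close / open_) + np.log(cost) for open_, close in trade_prices]
+
+# accumulated profit of the same cycles: prod(close_curr / close_pre * cost)
+profit = np.prod([close / open_ * cost for open_, close in trade_prices])
+
+assert np.isclose(np.exp(sum(rewards)), profit)
+```
+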
+### Render Function
+
+As shown below, you can use the `render` method to plot the positions and profit of one episode.
+
+
+ - The position figure:
+ - The x-axis of the position figure is trading days. In this case, it is 252 trading days.
+ - The y-axis of the position figure is the closing price of each day.
+ - Besides, the red inverted triangle, the green upward triangle and the blue circle mark the Short, Long and Flat positions of the agent on each trading day, respectively.
+
+![position](./position.png)
+
+ - The profit figure:
+ - Similarly, the x-axis of the profit figure is trading days. In this case, it is 252 trading days (the two figures share the same time interval).
+ - The y-axis of the profit figure is the cumulative profit of each day: 1.5 means the account value is 1.5 times its initial value.
+
+![profit](./profit.png)
+
diff --git a/DI-engine/dizoo/gym_anytrading/envs/__init__.py b/DI-engine/dizoo/gym_anytrading/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..aae5455f6fda5ded9b436447e88f300e0627623e
--- /dev/null
+++ b/DI-engine/dizoo/gym_anytrading/envs/__init__.py
@@ -0,0 +1,2 @@
+from .trading_env import TradingEnv, Actions, Positions
+from .stocks_env import StocksEnv
diff --git a/DI-engine/dizoo/gym_anytrading/envs/data/README.md b/DI-engine/dizoo/gym_anytrading/envs/data/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..cb583350b4f30ab017f995d428340a5e50e3f57d
--- /dev/null
+++ b/DI-engine/dizoo/gym_anytrading/envs/data/README.md
@@ -0,0 +1,2 @@
+You can put stocks data here.
+Your data file needs to be named like "STOCKS_GOOGL.csv", i.e. it must end with the ".csv" suffix.
\ No newline at end of file
diff --git a/DI-engine/dizoo/gym_anytrading/envs/stocks_env.py b/DI-engine/dizoo/gym_anytrading/envs/stocks_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d34caa1827ec0f8c6ca34bda5b1d898aa3cf4fb
--- /dev/null
+++ b/DI-engine/dizoo/gym_anytrading/envs/stocks_env.py
@@ -0,0 +1,148 @@
+from pprint import pprint
+from typing import Any
+from copy import deepcopy
+import numpy as np
+
+from dizoo.gym_anytrading.envs.trading_env import TradingEnv, Actions, Positions, load_dataset
+from ding.utils import ENV_REGISTRY
+from ding.torch_utils import to_ndarray
+
+
+@ENV_REGISTRY.register('stocks-v0')
+class StocksEnv(TradingEnv):
+
+ def __init__(self, cfg):
+
+ super().__init__(cfg)
+
+ # ====== load Google stocks data =======
+ raw_data = load_dataset(self._cfg.stocks_data_filename, 'Date')
+ self.raw_prices = raw_data.loc[:, 'Close'].to_numpy()
+ EPS = 1e-10
+ self.df = deepcopy(raw_data)
+ if self.train_range is None or self.test_range is None:
+ self.df = self.df.apply(lambda x: (x - x.mean()) / (x.std() + EPS), axis=0)
+ else:
+ boundary = int(len(self.df) * self.train_range)
+ train_data = raw_data[:boundary].copy()
+ boundary = int(len(raw_data) * (1 + self.test_range))
+ test_data = raw_data[boundary:].copy()
+
+ train_data = train_data.apply(lambda x: (x - x.mean()) / (x.std() + EPS), axis=0)
+ test_data = test_data.apply(lambda x: (x - x.mean()) / (x.std() + EPS), axis=0)
+ self.df.loc[train_data.index, train_data.columns] = train_data
+ self.df.loc[test_data.index, test_data.columns] = test_data
+ # ======================================
+
+ # set cost
+ self.trade_fee_bid_percent = 0.01 # unit
+ self.trade_fee_ask_percent = 0.005 # unit
+
+ # override
+ def _process_data(self, start_idx: int = None) -> Any:
+ '''
+ Overview:
+ Used by ``env.reset()`` to process the raw data.
+ Arguments:
+ - start_idx (int): the start tick; if None, it is selected randomly.
+ Returns:
+ - prices: the close prices.
+ - signal_features: the feature map.
+ - feature_dim_len: the number of selected features.
+ '''
+
+ # ====== build feature map ========
+ all_feature_name = ['Close', 'Open', 'High', 'Low', 'Adj Close', 'Volume']
+ all_feature = {k: self.df.loc[:, k].to_numpy() for k in all_feature_name}
+ # add feature "Diff"
+ prices = self.df.loc[:, 'Close'].to_numpy()
+ diff = np.insert(np.diff(prices), 0, 0)
+ all_feature_name.append('Diff')
+ all_feature['Diff'] = diff
+ # =================================
+
+ # you can select features you want
+ selected_feature_name = ['Close', 'Diff', 'Volume']
+ selected_feature = np.column_stack([all_feature[k] for k in selected_feature_name])
+ feature_dim_len = len(selected_feature_name)
+
+ # validate index
+ if start_idx is None:
+ if self.train_range is None or self.test_range is None:
+ self.start_idx = np.random.randint(self.window_size - 1, len(self.df) - self._cfg.eps_length)
+ elif self._env_id[-1] == 'e':
+ boundary = int(len(self.df) * (1 + self.test_range))
+ assert len(self.df) - self._cfg.eps_length > boundary + self.window_size,\
+ "parameter test_range is too large!"
+ self.start_idx = np.random.randint(boundary + self.window_size, len(self.df) - self._cfg.eps_length)
+ else:
+ boundary = int(len(self.df) * self.train_range)
+ assert boundary - self._cfg.eps_length > self.window_size,\
+ "parameter test_range is too small!"
+ self.start_idx = np.random.randint(self.window_size, boundary - self._cfg.eps_length)
+ else:
+ self.start_idx = start_idx
+
+ self._start_tick = self.start_idx
+ self._end_tick = self._start_tick + self._cfg.eps_length - 1
+
+ return prices, selected_feature, feature_dim_len
+
+ # override
+ def _calculate_reward(self, action: int) -> np.float32:
+ step_reward = 0.
+ current_price = (self.raw_prices[self._current_tick])
+ last_trade_price = (self.raw_prices[self._last_trade_tick])
+ ratio = current_price / last_trade_price
+ cost = np.log((1 - self.trade_fee_ask_percent) * (1 - self.trade_fee_bid_percent))
+
+ if action == Actions.BUY and self._position == Positions.SHORT:
+ step_reward = np.log(2 - ratio) + cost
+
+ if action == Actions.SELL and self._position == Positions.LONG:
+ step_reward = np.log(ratio) + cost
+
+ if action == Actions.DOUBLE_SELL and self._position == Positions.LONG:
+ step_reward = np.log(ratio) + cost
+
+ if action == Actions.DOUBLE_BUY and self._position == Positions.SHORT:
+ step_reward = np.log(2 - ratio) + cost
+
+ step_reward = float(step_reward)
+
+ return step_reward
+
+ # override
+ def max_possible_profit(self) -> float:
+ current_tick = self._start_tick
+ last_trade_tick = current_tick - 1
+ profit = 1.
+
+ while current_tick <= self._end_tick:
+
+ if self.raw_prices[current_tick] < self.raw_prices[current_tick - 1]:
+ while (current_tick <= self._end_tick
+ and self.raw_prices[current_tick] < self.raw_prices[current_tick - 1]):
+ current_tick += 1
+
+ current_price = self.raw_prices[current_tick - 1]
+ last_trade_price = self.raw_prices[last_trade_tick]
+ tmp_profit = profit * (2 - (current_price / last_trade_price)) * (1 - self.trade_fee_ask_percent
+ ) * (1 - self.trade_fee_bid_percent)
+ profit = max(profit, tmp_profit)
+ else:
+ while (current_tick <= self._end_tick
+ and self.raw_prices[current_tick] >= self.raw_prices[current_tick - 1]):
+ current_tick += 1
+
+ current_price = self.raw_prices[current_tick - 1]
+ last_trade_price = self.raw_prices[last_trade_tick]
+ tmp_profit = profit * (current_price / last_trade_price) * (1 - self.trade_fee_ask_percent
+ ) * (1 - self.trade_fee_bid_percent)
+ profit = max(profit, tmp_profit)
+ last_trade_tick = current_tick - 1
+
+ return profit
+
+ def __repr__(self) -> str:
+ return "DI-engine Stocks Trading Env"
diff --git a/DI-engine/dizoo/gym_anytrading/envs/test_stocks_env.py b/DI-engine/dizoo/gym_anytrading/envs/test_stocks_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..fa76d780fa60ec7c5d8d79a4db34927110b34b0e
--- /dev/null
+++ b/DI-engine/dizoo/gym_anytrading/envs/test_stocks_env.py
@@ -0,0 +1,37 @@
+import pytest
+import numpy as np
+from easydict import EasyDict
+from dizoo.gym_anytrading.envs import StocksEnv
+
+
+@pytest.mark.envtest
+class TestStocksEnv:
+
+ def test_naive(self):
+ env = StocksEnv(EasyDict({"env_id": 'stocks-v0', "eps_length": 300,\
+ "window_size": 20, "train_range": None, "test_range": None, "stocks_data_filename": 'STOCKS_GOOGL'}))
+ env.seed(314, dynamic_seed=False)
+ assert env._seed == 314
+ obs = env.reset()
+ assert obs.shape == (62, )
+ for _ in range(5):
+ env.reset()
+ np.random.seed(314)
+ print('=' * 60)
+ for i in range(10):
+ # Both ``env.random_action()``, and utilizing ``np.random`` as well as action space,
+ # can generate legal random action.
+ if i < 5:
+ random_action = np.array([env.action_space.sample()])
+ else:
+ random_action = env.random_action()
+ timestep = env.step(random_action)
+ print(timestep)
+ assert isinstance(timestep.obs, np.ndarray)
+ assert isinstance(timestep.done, bool)
+ assert timestep.obs.shape == (62, )
+ assert timestep.reward.shape == (1, )
+ assert timestep.reward >= env.reward_space.low
+ assert timestep.reward <= env.reward_space.high
+ print(env.observation_space, env.action_space, env.reward_space)
+ env.close()
diff --git a/DI-engine/dizoo/gym_anytrading/envs/trading_env.py b/DI-engine/dizoo/gym_anytrading/envs/trading_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4ff57a057bf7fb5955986901295b5d80bb7a893
--- /dev/null
+++ b/DI-engine/dizoo/gym_anytrading/envs/trading_env.py
@@ -0,0 +1,302 @@
+from cmath import inf
+from typing import Any, List
+from easydict import EasyDict
+from abc import abstractmethod
+from gym import spaces
+from gym.utils import seeding
+from enum import Enum
+
+import os
+import gym
+import copy
+import pandas as pd
+import numpy as np
+
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.utils import ENV_REGISTRY
+from ding.torch_utils import to_ndarray
+
+
+def load_dataset(name, index_name):
+ base_dir = os.path.dirname(os.path.abspath(__file__))
+ path = os.path.join(base_dir, 'data', name + '.csv')
+ assert os.path.exists(
+ path
+ ), "You need to put the stock data under the \'DI-engine/dizoo/gym_anytrading/envs/data\' folder.\n \
+ if using StocksEnv, you can download Google stocks data at \
+ https://github.com/AminHP/gym-anytrading/blob/master/gym_anytrading/datasets/data/STOCKS_GOOGL.csv"
+
+ df = pd.read_csv(path, parse_dates=True, index_col=index_name)
+ return df
+
+
+class Actions(int, Enum):
+ DOUBLE_SELL = 0
+ SELL = 1
+ HOLD = 2
+ BUY = 3
+ DOUBLE_BUY = 4
+
+
+class Positions(int, Enum):
+ SHORT = -1.
+ FLAT = 0.
+ LONG = 1.
+
+
+def transform(position: Positions, action: int) -> Any:
+ '''
+ Overview:
+ Used by ``env.step()``.
+ This func transforms the env's position given
+ the input (position, action) pair, according to the state machine.
+ Arguments:
+ - position(Positions) : Long, Short or Flat
+ - action(int) : Double_Sell, Sell, Hold, Buy, Double_Buy
+ Returns:
+ - next_position(Positions), trade(bool) : the position after the transformation and whether a trade happens.
+ '''
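+ # For example, following the state machine above:
+ #   transform(Positions.FLAT, Actions.BUY)         -> (Positions.LONG, True)
+ #   transform(Positions.LONG, Actions.SELL)        -> (Positions.FLAT, False)
+ #   transform(Positions.FLAT, Actions.DOUBLE_SELL) -> (Positions.SHORT, True)
+ #   transform(Positions.SHORT, Actions.HOLD)       -> (Positions.SHORT, False)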
+ if action == Actions.SELL:
+
+ if position == Positions.LONG:
+ return Positions.FLAT, False
+
+ if position == Positions.FLAT:
+ return Positions.SHORT, True
+
+ if action == Actions.BUY:
+
+ if position == Positions.SHORT:
+ return Positions.FLAT, False
+
+ if position == Positions.FLAT:
+ return Positions.LONG, True
+
+ if action == Actions.DOUBLE_SELL and (position == Positions.LONG or position == Positions.FLAT):
+ return Positions.SHORT, True
+
+ if action == Actions.DOUBLE_BUY and (position == Positions.SHORT or position == Positions.FLAT):
+ return Positions.LONG, True
+
+ return position, False
+
+
+@ENV_REGISTRY.register('base_trading')
+class TradingEnv(BaseEnv):
+
+ def __init__(self, cfg: EasyDict) -> None:
+
+ self._cfg = cfg
+ self._env_id = cfg.env_id
+ #======== param to plot =========
+ self.cnt = 0
+
+ if 'plot_freq' not in self._cfg:
+ self.plot_freq = 10
+ else:
+ self.plot_freq = self._cfg.plot_freq
+ if 'save_path' not in self._cfg:
+ self.save_path = './'
+ else:
+ self.save_path = self._cfg.save_path
+ #================================
+
+ self.train_range = cfg.train_range
+ self.test_range = cfg.test_range
+ self.window_size = cfg.window_size
+ self.prices = None
+ self.signal_features = None
+ self.feature_dim_len = None
+ self.shape = (cfg.window_size, 3)
+
+ #======== param about episode =========
+ self._start_tick = 0
+ self._end_tick = 0
+ self._done = None
+ self._current_tick = None
+ self._last_trade_tick = None
+ self._position = None
+ self._position_history = None
+ self._total_reward = None
+ #======================================
+
+ self._init_flag = True
+ # init the following variables at the first reset.
+ self._action_space = None
+ self._observation_space = None
+ self._reward_space = None
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+ self.np_random, seed = seeding.np_random(seed)
+
+ def reset(self, start_idx: int = None) -> Any:
+ self.cnt += 1
+ self.prices, self.signal_features, self.feature_dim_len = self._process_data(start_idx)
+ if self._init_flag:
+ self.shape = (self.window_size, self.feature_dim_len)
+ self._action_space = spaces.Discrete(len(Actions))
+ self._observation_space = spaces.Box(low=-np.inf, high=np.inf, shape=self.shape, dtype=np.float64)
+ self._reward_space = gym.spaces.Box(-inf, inf, shape=(1, ), dtype=np.float32)
+ self._init_flag = False
+ self._done = False
+ self._current_tick = self._start_tick
+ self._last_trade_tick = self._current_tick - 1
+ self._position = Positions.FLAT
+ self._position_history = [self._position]
+ self._profit_history = [1.]
+ self._total_reward = 0.
+
+ return self._get_observation()
+
+ def random_action(self) -> Any:
+ return np.array([self.action_space.sample()])
+
+ def step(self, action: np.ndarray) -> BaseEnvTimestep:
+ assert isinstance(action, np.ndarray), type(action)
+ if action.shape == (1, ):
+ action = action.item() # 0-dim array
+
+ self._done = False
+ self._current_tick += 1
+
+ if self._current_tick >= self._end_tick:
+ self._done = True
+
+ step_reward = self._calculate_reward(action)
+ self._total_reward += step_reward
+
+ self._position, trade = transform(self._position, action)
+
+ if trade:
+ self._last_trade_tick = self._current_tick
+
+ self._position_history.append(self._position)
+ self._profit_history.append(float(np.exp(self._total_reward)))
+ observation = self._get_observation()
+ info = dict(
+ total_reward=self._total_reward,
+ position=self._position.value,
+ )
+
+ if self._done:
+ if self._env_id[-1] == 'e' and self.cnt % self.plot_freq == 0:
+ self.render()
+ info['max_possible_profit'] = np.log(self.max_possible_profit())
+ info['eval_episode_return'] = self._total_reward
+
+ step_reward = to_ndarray([step_reward]).astype(np.float32)
+ return BaseEnvTimestep(observation, step_reward, self._done, info)
+
+ def _get_observation(self) -> np.ndarray:
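+ # Final observation layout (flattened): window_size * feature_dim_len feature values, followed by the
+ # current position value and the normalized time since the last trade, e.g. 20 * 3 + 2 = 62 for the
+ # default stocks config (hence ``obs_shape=62`` in stocks_dqn_config).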
+ obs = to_ndarray(self.signal_features[(self._current_tick - self.window_size + 1):self._current_tick + 1]
+ ).reshape(-1).astype(np.float32)
+
+ tick = (self._current_tick - self._last_trade_tick) / self._cfg.eps_length
+ obs = np.hstack([obs, to_ndarray([self._position.value]), to_ndarray([tick])]).astype(np.float32)
+ return obs
+
+ def render(self) -> None:
+ import matplotlib.pyplot as plt
+ plt.clf()
+ plt.xlabel('trading days')
+ plt.ylabel('profit')
+ plt.plot(self._profit_history)
+ plt.savefig(self.save_path + str(self._env_id) + "-profit.png")
+
+ plt.clf()
+ plt.xlabel('trading days')
+ plt.ylabel('close price')
+ window_ticks = np.arange(len(self._position_history))
+ eps_price = self.raw_prices[self._start_tick:self._end_tick + 1]
+ plt.plot(eps_price)
+
+ short_ticks = []
+ long_ticks = []
+ flat_ticks = []
+ for i, tick in enumerate(window_ticks):
+ if self._position_history[i] == Positions.SHORT:
+ short_ticks.append(tick)
+ elif self._position_history[i] == Positions.LONG:
+ long_ticks.append(tick)
+ else:
+ flat_ticks.append(tick)
+
+ plt.plot(long_ticks, eps_price[long_ticks], 'g^', markersize=3, label="Long")
+ plt.plot(flat_ticks, eps_price[flat_ticks], 'bo', markersize=3, label="Flat")
+ plt.plot(short_ticks, eps_price[short_ticks], 'rv', markersize=3, label="Short")
+ plt.legend(loc='upper left', bbox_to_anchor=(0.05, 0.95))
+ plt.savefig(self.save_path + str(self._env_id) + '-price.png')
+
+ def close(self):
+ import matplotlib.pyplot as plt
+ plt.close()
+
+ # override
+ @staticmethod
+ def create_collector_env_cfg(cfg: dict) -> List[dict]:
+ """
+ Overview:
+ Return a list of env configs derived from the input config, used in the env manager \
+ (a series of vectorized envs); these envs are mainly responsible for collecting data.
+ In TradingEnv, this method renames every env_id and generates a different config for each env.
+ Arguments:
+ - cfg (:obj:`dict`): Original input env config, which needs to be transformed into the config actually used \
+ to create env instances and replicated for the corresponding number of envs.
+ Returns:
+ - env_cfg_list (:obj:`List[dict]`): List of ``cfg`` for all the collector envs.
+ .. note::
+ Elements(env config) in collector_env_cfg/evaluator_env_cfg can be different, such as server ip and port.
+ """
+ collector_env_num = cfg.pop('collector_env_num')
+ collector_env_cfg = [copy.deepcopy(cfg) for _ in range(collector_env_num)]
+ for i in range(collector_env_num):
+ collector_env_cfg[i]['env_id'] += ('-' + str(i) + 'e')
+ return collector_env_cfg
+
+ # override
+ @staticmethod
+ def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
+ """
+ Overview:
+ Return a list of env configs derived from the input config, used in the env manager \
+ (a series of vectorized envs); these envs are mainly responsible for evaluating performance.
+ In TradingEnv, this method renames every env_id and generates a different config for each env.
+ Arguments:
+ - cfg (:obj:`dict`): Original input env config, which needs to be transformed into the config actually used \
+ to create env instances and replicated for the corresponding number of envs.
+ Returns:
+ - env_cfg_list (:obj:`List[dict]`): List of ``cfg`` for all the evaluator envs.
+ """
+ evaluator_env_num = cfg.pop('evaluator_env_num')
+ evaluator_env_cfg = [copy.deepcopy(cfg) for _ in range(evaluator_env_num)]
+ for i in range(evaluator_env_num):
+ evaluator_env_cfg[i]['env_id'] += ('-' + str(i) + 'e')
+ return evaluator_env_cfg
+
+ @abstractmethod
+ def _process_data(self):
+ raise NotImplementedError
+
+ @abstractmethod
+ def _calculate_reward(self, action):
+ raise NotImplementedError
+
+ @abstractmethod
+ def max_possible_profit(self):
+ raise NotImplementedError
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ def __repr__(self) -> str:
+ return "DI-engine Trading Env"
diff --git a/DI-engine/dizoo/gym_anytrading/worker/__init__.py b/DI-engine/dizoo/gym_anytrading/worker/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ef31b522e020c8311c390db51d55182d4a60f77b
--- /dev/null
+++ b/DI-engine/dizoo/gym_anytrading/worker/__init__.py
@@ -0,0 +1,2 @@
+from .trading_serial_evaluator import *
diff --git a/DI-engine/dizoo/gym_anytrading/worker/trading_serial_evaluator.py b/DI-engine/dizoo/gym_anytrading/worker/trading_serial_evaluator.py
new file mode 100644
index 0000000000000000000000000000000000000000..4287d484230f94273119680de3b2a683445e357e
--- /dev/null
+++ b/DI-engine/dizoo/gym_anytrading/worker/trading_serial_evaluator.py
@@ -0,0 +1,222 @@
+from typing import Any, Optional, Callable, Tuple
+from collections import deque, namedtuple
+from easydict import EasyDict
+import torch
+import numpy as np
+
+from ding.envs import BaseEnvManager
+from ding.worker import VectorEvalMonitor, InteractionSerialEvaluator
+from ding.torch_utils import to_tensor, to_ndarray, to_item
+from ding.utils import SERIAL_EVALUATOR_REGISTRY, import_module
+
+
+@SERIAL_EVALUATOR_REGISTRY.register('trading_interaction')
+class TradingSerialEvaluator(InteractionSerialEvaluator):
+ """
+ Overview:
+ Trading interaction serial evaluator class, policy interacts with anytrading env.
+ Interfaces:
+ __init__, reset, reset_policy, reset_env, close, should_eval, eval
+ Property:
+ env, policy
+ """
+ config = dict(
+ # Evaluate every "eval_freq" training iterations.
+ eval_freq=1000,
+ render=dict(
+ # tensorboard video render is disabled by default
+ render_freq=-1,
+ mode='train_iter',
+ ),
+ type='trading_interaction',
+ )
+
+ def __init__(
+ self,
+ cfg: dict,
+ env: BaseEnvManager = None,
+ policy: namedtuple = None,
+ tb_logger: 'SummaryWriter' = None, # noqa
+ exp_name: Optional[str] = 'default_experiment',
+ instance_name: Optional[str] = 'evaluator',
+ ) -> None:
+ """
+ Overview:
+ Init method. Just init super class.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Configuration EasyDict.
+ """
+ super().__init__(cfg, env, policy, tb_logger, exp_name, instance_name)
+
+ def eval(
+ self,
+ save_ckpt_fn: Callable = None,
+ train_iter: int = -1,
+ envstep: int = -1,
+ n_episode: Optional[int] = None,
+ force_render: bool = False,
+ ) -> Tuple[bool, dict]:
+ '''
+ Overview:
+ Evaluate policy and store the best policy based on whether it reaches the highest historical reward.
+ Arguments:
+ - save_ckpt_fn (:obj:`Callable`): Saving ckpt function, which will be triggered by getting the best reward.
+ - train_iter (:obj:`int`): Current training iteration.
+ - envstep (:obj:`int`): Current env interaction step.
+ - n_episode (:obj:`int`): Number of evaluation episodes.
+ Returns:
+ - stop_flag (:obj:`bool`): Whether this training program can be ended.
+ - episode_info (:obj:`dict`): Current evaluation return information.
+ '''
+
+ if n_episode is None:
+ n_episode = self._default_n_episode
+ assert n_episode is not None, "please indicate eval n_episode"
+ envstep_count = 0
+ info = {}
+ eval_monitor = TradingEvalMonitor(self._env.env_num, n_episode)
+ self._env.reset()
+ self._policy.reset()
+
+ # force_render overwrite frequency constraint
+ render = force_render or self._should_render(envstep, train_iter)
+
+ with self._timer:
+ while not eval_monitor.is_finished():
+ obs = self._env.ready_obs
+ obs = to_tensor(obs, dtype=torch.float32)
+
+ # update videos
+ if render:
+ eval_monitor.update_video(self._env.ready_imgs)
+
+ policy_output = self._policy.forward(obs)
+ actions = {i: a['action'] for i, a in policy_output.items()}
+ actions = to_ndarray(actions)
+ timesteps = self._env.step(actions)
+ timesteps = to_tensor(timesteps, dtype=torch.float32)
+ for env_id, t in timesteps.items():
+ if t.info.get('abnormal', False):
+ # If there is an abnormal timestep, reset all the related variables(including this env).
+ self._policy.reset([env_id])
+ continue
+ if t.done:
+ # Env reset is done by env_manager automatically.
+ self._policy.reset([env_id])
+ reward = t.info['eval_episode_return']
+ eval_monitor.update_info(env_id, t.info)
+ eval_monitor.update_reward(env_id, reward)
+
+ #========== only used by anytrading =======
+ if 'max_possible_profit' in t.info:
+ max_profit = t.info['max_possible_profit']
+ eval_monitor.update_max_profit(env_id, max_profit)
+ #==========================================
+
+ self._logger.info(
+ "[EVALUATOR]env {} finish episode, final reward: {}, current episode: {}".format(
+ env_id, eval_monitor.get_latest_reward(env_id), eval_monitor.get_current_episode()
+ )
+ )
+ envstep_count += 1
+ duration = self._timer.value
+ episode_return = eval_monitor.get_episode_return()
+ info = {
+ 'train_iter': train_iter,
+ 'ckpt_name': 'iteration_{}.pth.tar'.format(train_iter),
+ 'episode_count': n_episode,
+ 'envstep_count': envstep_count,
+ 'avg_envstep_per_episode': envstep_count / n_episode,
+ 'evaluate_time': duration,
+ 'avg_envstep_per_sec': envstep_count / duration,
+ 'avg_time_per_episode': n_episode / duration,
+ 'reward_mean': np.mean(episode_return),
+ 'reward_std': np.std(episode_return),
+ 'reward_max': np.max(episode_return),
+ 'reward_min': np.min(episode_return),
+ # 'each_reward': episode_return,
+ }
+ episode_info = eval_monitor.get_episode_info()
+ if episode_info is not None:
+ info.update(episode_info)
+ self._logger.info(self._logger.get_tabulate_vars_hor(info))
+ # self._logger.info(self._logger.get_tabulate_vars(info))
+ for k, v in info.items():
+ if k in ['train_iter', 'ckpt_name', 'each_reward']:
+ continue
+ if not np.isscalar(v):
+ continue
+ self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
+ self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep)
+
+ #========== only used by anytrading =======
+ max_possible_profit = eval_monitor.get_max_episode_profit()
+ info_anytrading = {
+ 'max_possible_profit_max': np.max(max_possible_profit),
+ 'max_possible_profit_mean': np.mean(max_possible_profit),
+ 'max_possible_profit_min': np.min(max_possible_profit),
+ }
+ for k, v in info_anytrading.items():
+ if not np.isscalar(v):
+ continue
+ self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
+ self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, envstep)
+ #==========================================
+
+ if render:
+ video_title = '{}_{}/'.format(self._instance_name, self._render.mode)
+ videos = eval_monitor.get_video()
+ render_iter = envstep if self._render.mode == 'envstep' else train_iter
+ from ding.utils import fps
+ self._tb_logger.add_video(video_title, videos, render_iter, fps(self._env))
+
+ episode_return = np.mean(episode_return)
+ if episode_return > self._max_episode_return:
+ if save_ckpt_fn:
+ save_ckpt_fn('ckpt_best.pth.tar')
+ self._max_episode_return = episode_return
+ stop_flag = episode_return >= self._stop_value and train_iter > 0
+ if stop_flag:
+ self._logger.info(
+ "[DI-engine serial pipeline] " +
+ "Current episode_return: {} is greater than stop_value: {}".format(episode_return, self._stop_value) +
+ ", so your RL agent is converged, you can refer to 'log/evaluator/evaluator_logger.txt' for details."
+ )
+ episode_info = to_item(episode_info)
+ return stop_flag, episode_info
+
+
+class TradingEvalMonitor(VectorEvalMonitor):
+ """
+ Overview:
+ Inherit VectorEvalMonitor for trading env.
+ Add func update_max_profit and get_max_episode_profit in order to log the max_profit for every episode.
+ Interfaces:
+ In addition to (__init__, is_finished, update_info, update_reward, get_episode_return,\
+ get_latest_reward, get_current_episode, get_episode_info), it provides\
+ (update_max_profit, get_max_episode_profit).
+ """
+
+ def __init__(self, env_num: int, n_episode: int) -> None:
+ super().__init__(env_num, n_episode)
+
+ self._each_env_episode = [n_episode // env_num for _ in range(env_num)]
+ self._max_possible_profit = {
+ env_id: deque(maxlen=maxlen)
+ for env_id, maxlen in enumerate(self._each_env_episode)
+ }
+
+ def update_max_profit(self, env_id: int, max_profit: Any) -> None:
+ """
+ Overview:
+ Update the max profit indicated by env_id.
+ Arguments:
+ - env_id: (:obj:`int`): the id of the environment we need to update the max profit
+ - max_profit: (:obj:`Any`): the profit we need to update
+ """
+ if isinstance(max_profit, torch.Tensor):
+ max_profit = max_profit.item()
+ self._max_possible_profit[env_id].append(max_profit)
+
+ def get_max_episode_profit(self) -> list:
+ return sum([list(v) for v in self._max_possible_profit.values()], [])
diff --git a/DI-engine/dizoo/gym_hybrid/__init__.py b/DI-engine/dizoo/gym_hybrid/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/gym_hybrid/config/__init__.py b/DI-engine/dizoo/gym_hybrid/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/gym_hybrid/config/gym_hybrid_ddpg_config.py b/DI-engine/dizoo/gym_hybrid/config/gym_hybrid_ddpg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..854f5f393928ed899cf74b02c8c644870c7f79d5
--- /dev/null
+++ b/DI-engine/dizoo/gym_hybrid/config/gym_hybrid_ddpg_config.py
@@ -0,0 +1,71 @@
+from easydict import EasyDict
+
+gym_hybrid_ddpg_config = dict(
+ exp_name='gym_hybrid_ddpg_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ # (bool) Scale output action into legal range [-1, 1].
+ act_scale=True,
+ env_id='Moving-v0', # ['Sliding-v0', 'Moving-v0']
+ n_evaluator_episode=5,
+ stop_value=1.8,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ random_collect_size=0, # hybrid action space does not support random collect for now
+ action_space='hybrid',
+ model=dict(
+ obs_shape=10,
+ action_shape=dict(
+ action_type_shape=3,
+ action_args_shape=2,
+ ),
+ twin_critic=False,
+ action_space='hybrid',
+ ),
+ learn=dict(
+ update_per_collect=10, # 5~10
+ batch_size=32,
+ discount_factor=0.99,
+ learning_rate_actor=0.0003, # 0.001 ~ 0.0003
+ learning_rate_critic=0.001,
+ actor_update_freq=1,
+ noise=False,
+ ),
+ collect=dict(
+ n_sample=32,
+ noise_sigma=0.1,
+ collector=dict(collect_print_freq=1000, ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, ), ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.1,
+ decay=100000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+)
+gym_hybrid_ddpg_config = EasyDict(gym_hybrid_ddpg_config)
+main_config = gym_hybrid_ddpg_config
+
+gym_hybrid_ddpg_create_config = dict(
+ env=dict(
+ type='gym_hybrid',
+ import_names=['dizoo.gym_hybrid.envs.gym_hybrid_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ddpg'),
+)
+gym_hybrid_ddpg_create_config = EasyDict(gym_hybrid_ddpg_create_config)
+create_config = gym_hybrid_ddpg_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c gym_hybrid_ddpg_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0, max_env_step=int(1e7))
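+ # After training, the saved checkpoint can be evaluated and a replay recorded with
+ # dizoo/gym_hybrid/entry/gym_hybrid_ddpg_eval.py (see that script for the expected checkpoint path).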
diff --git a/DI-engine/dizoo/gym_hybrid/config/gym_hybrid_hppo_config.py b/DI-engine/dizoo/gym_hybrid/config/gym_hybrid_hppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..2011972e19c92c9089dd1d40be09903a7ee9b9ab
--- /dev/null
+++ b/DI-engine/dizoo/gym_hybrid/config/gym_hybrid_hppo_config.py
@@ -0,0 +1,64 @@
+from easydict import EasyDict
+
+gym_hybrid_hppo_config = dict(
+ exp_name='gym_hybrid_hppo_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ # (bool) Scale output action into legal range, usually [-1, 1].
+ act_scale=True,
+ env_id='Moving-v0', # ['Sliding-v0', 'Moving-v0']
+ n_evaluator_episode=5,
+ stop_value=1.8,
+ ),
+ policy=dict(
+ cuda=True,
+ action_space='hybrid',
+ recompute_adv=True,
+ model=dict(
+ obs_shape=10,
+ action_shape=dict(
+ action_type_shape=3,
+ action_args_shape=2,
+ ),
+ action_space='hybrid',
+ encoder_hidden_size_list=[256, 128, 64, 64],
+ sigma_type='fixed',
+ fixed_sigma_value=0.3,
+ bound_type='tanh',
+ ),
+ learn=dict(
+ epoch_per_collect=10,
+ batch_size=320,
+ learning_rate=3e-4,
+ entropy_weight=0.5,
+ adv_norm=True,
+ value_norm=True,
+ ),
+ collect=dict(
+ n_sample=3200,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ collector=dict(collect_print_freq=1000, ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=200, ), ),
+ ),
+)
+gym_hybrid_hppo_config = EasyDict(gym_hybrid_hppo_config)
+main_config = gym_hybrid_hppo_config
+
+gym_hybrid_hppo_create_config = dict(
+ env=dict(
+ type='gym_hybrid',
+ import_names=['dizoo.gym_hybrid.envs.gym_hybrid_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+)
+gym_hybrid_hppo_create_config = EasyDict(gym_hybrid_hppo_create_config)
+create_config = gym_hybrid_hppo_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c gym_hybrid_hppo_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy([main_config, create_config], seed=0, max_env_step=int(1e7))
diff --git a/DI-engine/dizoo/gym_hybrid/config/gym_hybrid_mpdqn_config.py b/DI-engine/dizoo/gym_hybrid/config/gym_hybrid_mpdqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd2bf323432c153e490878c8c4212d9a25c1c957
--- /dev/null
+++ b/DI-engine/dizoo/gym_hybrid/config/gym_hybrid_mpdqn_config.py
@@ -0,0 +1,78 @@
+from easydict import EasyDict
+
+gym_hybrid_mpdqn_config = dict(
+ exp_name='gym_hybrid_mpdqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ # (bool) Scale output action into legal range [-1, 1].
+ act_scale=True,
+ env_id='Moving-v0', # ['Sliding-v0', 'Moving-v0']
+ n_evaluator_episode=5,
+ stop_value=1.8,
+ ),
+ policy=dict(
+ cuda=True,
+ discount_factor=0.99,
+ nstep=1,
+ model=dict(
+ obs_shape=10,
+ action_shape=dict(
+ action_type_shape=3,
+ action_args_shape=2,
+ ),
+ multi_pass=True,
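+ # action_mask[i][j] = 1 means discrete action i uses continuous arg j
+ # (accelerate -> acceleration, turn -> rotation, break -> no arg); this matches the
+ # 'action_args_mask' info returned by GymHybridEnv.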
+ action_mask=[[1, 0], [0, 1], [0, 0]],
+ ),
+ learn=dict(
+ update_per_collect=500, # 10~500
+ batch_size=320,
+ learning_rate_dis=3e-4,
+ learning_rate_cont=3e-4,
+ target_theta=0.001,
+ update_circle=10,
+ ),
+ # collect_mode config
+ collect=dict(
+ # (int) Only one of [n_sample, n_episode] should be set
+ n_sample=3200,
+ # (int) Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ noise_sigma=0.1,
+ collector=dict(collect_print_freq=1000, ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, ), ),
+ # other config
+ other=dict(
+ # Epsilon greedy with decay.
+ eps=dict(
+ # (str) Decay type. Support ['exp', 'linear'].
+ type='exp',
+ start=1,
+ end=0.1,
+ # (int) Decay length (in env steps)
+ decay=int(1e5),
+ ),
+ replay_buffer=dict(replay_buffer_size=int(1e6), ),
+ ),
+ )
+)
+
+gym_hybrid_mpdqn_config = EasyDict(gym_hybrid_mpdqn_config)
+main_config = gym_hybrid_mpdqn_config
+
+gym_hybrid_mpdqn_create_config = dict(
+ env=dict(
+ type='gym_hybrid',
+ import_names=['dizoo.gym_hybrid.envs.gym_hybrid_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='pdqn'),
+)
+gym_hybrid_mpdqn_create_config = EasyDict(gym_hybrid_mpdqn_create_config)
+create_config = gym_hybrid_mpdqn_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c gym_hybrid_mpdqn_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0, max_env_step=int(1e7))
diff --git a/DI-engine/dizoo/gym_hybrid/config/gym_hybrid_pdqn_config.py b/DI-engine/dizoo/gym_hybrid/config/gym_hybrid_pdqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..79d28b07bd888bb98584de84958ac43a996ef84f
--- /dev/null
+++ b/DI-engine/dizoo/gym_hybrid/config/gym_hybrid_pdqn_config.py
@@ -0,0 +1,76 @@
+from easydict import EasyDict
+
+gym_hybrid_pdqn_config = dict(
+ exp_name='gym_hybrid_pdqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ # (bool) Scale output action into legal range [-1, 1].
+ act_scale=True,
+ env_id='Moving-v0', # ['Sliding-v0', 'Moving-v0']
+ n_evaluator_episode=5,
+ stop_value=1.8,
+ ),
+ policy=dict(
+ cuda=True,
+ discount_factor=0.99,
+ nstep=1,
+ model=dict(
+ obs_shape=10,
+ action_shape=dict(
+ action_type_shape=3,
+ action_args_shape=2,
+ ),
+ ),
+ learn=dict(
+ update_per_collect=500, # 10~500
+ batch_size=320,
+ learning_rate_dis=3e-4,
+ learning_rate_cont=3e-4,
+ target_theta=0.001,
+ update_circle=10,
+ ),
+ # collect_mode config
+ collect=dict(
+ # (int) Only one of [n_sample, n_episode] should be set
+ n_sample=3200, # 128,
+ # (int) Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ noise_sigma=0.1, # 0.05,
+ collector=dict(collect_print_freq=1000, ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, ), ),
+ # other config
+ other=dict(
+ # Epsilon greedy with decay.
+ eps=dict(
+ # (str) Decay type. Support ['exp', 'linear'].
+ type='exp',
+ start=1,
+ end=0.1,
+ # (int) Decay length (in env steps)
+ decay=int(1e5),
+ ),
+ replay_buffer=dict(replay_buffer_size=int(1e6), ),
+ ),
+ )
+)
+
+gym_hybrid_pdqn_config = EasyDict(gym_hybrid_pdqn_config)
+main_config = gym_hybrid_pdqn_config
+
+gym_hybrid_pdqn_create_config = dict(
+ env=dict(
+ type='gym_hybrid',
+ import_names=['dizoo.gym_hybrid.envs.gym_hybrid_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='pdqn'),
+)
+gym_hybrid_pdqn_create_config = EasyDict(gym_hybrid_pdqn_create_config)
+create_config = gym_hybrid_pdqn_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c gym_hybrid_pdqn_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0, max_env_step=int(1e7))
diff --git a/DI-engine/dizoo/gym_hybrid/entry/__init__.py b/DI-engine/dizoo/gym_hybrid/entry/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/gym_hybrid/entry/gym_hybrid_ddpg_eval.py b/DI-engine/dizoo/gym_hybrid/entry/gym_hybrid_ddpg_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..388928a61bff0d5255c412600bd3a7ef3ab6607b
--- /dev/null
+++ b/DI-engine/dizoo/gym_hybrid/entry/gym_hybrid_ddpg_eval.py
@@ -0,0 +1,51 @@
+import os
+import gym
+import torch
+from tensorboardX import SummaryWriter
+from easydict import EasyDict
+from functools import partial
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.envs import BaseEnvManager, DingEnvWrapper
+from ding.envs import get_vec_env_setting
+from ding.policy import DDPGPolicy
+from ding.model import ContinuousQAC
+from ding.utils import set_pkg_seed
+from ding.rl_utils import get_epsilon_greedy_fn
+from dizoo.gym_hybrid.config.gym_hybrid_ddpg_config import gym_hybrid_ddpg_config, gym_hybrid_ddpg_create_config
+
+
+def main(main_cfg, create_cfg, seed=0):
+ # Specify evaluation arguments
+ main_cfg.policy.load_path = './ckpt_best.pth.tar'
+ main_cfg.env.replay_path = './'
+ main_cfg.env.evaluator_env_num = 1 # only 1 env for saving the replay
+ cfg = compile_config(main_cfg, seed=seed, auto=True, create_cfg=create_cfg, save_cfg=True)
+ # Create main components: env, policy
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ evaluator_env = BaseEnvManager([partial(env_fn, cfg=c) for c in evaluator_env_cfg], cfg.env.manager)
+
+ evaluator_env.enable_save_replay(cfg.env.replay_path) # switch save replay interface
+
+ # Set random seed for all package and instance
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ # Set up RL Policy
+ model = ContinuousQAC(**cfg.policy.model)
+ policy = DDPGPolicy(cfg.policy, model=model)
+ policy.eval_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cpu'))
+
+ # evaluate
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator.eval()
+
+
+if __name__ == "__main__":
+ # gym_hybrid environment rendering uses the API from "gym.envs.classic_control.rendering",
+ # which was removed in gym >= 0.22.0, so please check the gym version before rendering.
+ main(gym_hybrid_ddpg_config, gym_hybrid_ddpg_create_config, seed=0)
diff --git a/DI-engine/dizoo/gym_hybrid/entry/gym_hybrid_ddpg_main.py b/DI-engine/dizoo/gym_hybrid/entry/gym_hybrid_ddpg_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..aaa8753d6bd23e134265413fc94aa8c9d1593a11
--- /dev/null
+++ b/DI-engine/dizoo/gym_hybrid/entry/gym_hybrid_ddpg_main.py
@@ -0,0 +1,95 @@
+import os
+import gym
+import gym_hybrid
+from tensorboardX import SummaryWriter
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.envs import BaseEnvManager
+from ding.policy import DDPGPolicy
+from ding.model import ContinuousQAC
+from ding.utils import set_pkg_seed
+from ding.rl_utils import get_epsilon_greedy_fn
+from dizoo.gym_hybrid.envs.gym_hybrid_env import GymHybridEnv
+from dizoo.gym_hybrid.config.gym_hybrid_ddpg_config import gym_hybrid_ddpg_config
+
+
+def main(cfg, seed=0):
+ cfg = compile_config(
+ cfg,
+ BaseEnvManager,
+ DDPGPolicy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator,
+ AdvancedReplayBuffer,
+ save_cfg=True
+ )
+
+ # Set up envs for collection and evaluation
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+ # You can either use `GymHybridEnv` directly (as below) or `DingEnvWrapper` to make a gym_hybrid env and therefore an env manager.
+ # == Use `DingEnvWrapper`
+ collector_env = BaseEnvManager(
+ env_fn=[lambda: GymHybridEnv(cfg=cfg.env) for _ in range(collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManager(
+ env_fn=[lambda: GymHybridEnv(cfg=cfg.env) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ # Set random seed for all package and instance
+ collector_env.seed(seed)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ # Set up RL Policy
+ model = ContinuousQAC(**cfg.policy.model)
+ policy = DDPGPolicy(cfg.policy, model=model)
+
+ # Set up collection, training and evaluation utilities
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = SampleSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
+
+ # Set up other modules, etc. epsilon greedy
+ eps_cfg = cfg.policy.other.eps
+ epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
+
+ # Training & Evaluation loop
+ while True:
+ # Evaluate at the beginning and with specific frequency
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Update other modules
+ eps = epsilon_greedy(collector.envstep)
+ # Collect data from environments
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ # Train
+ for i in range(cfg.policy.learn.update_per_collect):
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is None:
+ break
+ learner.train(train_data, collector.envstep)
+
+ # evaluate
+ evaluator_env = BaseEnvManager(
+ env_fn=[lambda: GymHybridEnv(cfg=cfg.env) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env.enable_save_replay(cfg.env.replay_path) # switch save replay interface
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+
+
+if __name__ == "__main__":
+ main(gym_hybrid_ddpg_config, seed=0)
diff --git a/DI-engine/dizoo/gym_hybrid/envs/README.md b/DI-engine/dizoo/gym_hybrid/envs/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..a89b9c113e02d86aa75da2bd5ab3366d61cb74ca
--- /dev/null
+++ b/DI-engine/dizoo/gym_hybrid/envs/README.md
@@ -0,0 +1,21 @@
+# Modified gym-hybrid
+
+The gym-hybrid directory is modified from https://github.com/thomashirtz/gym-hybrid.
+We additionally add the HardMove environment (please refer to https://arxiv.org/abs/2109.05490, Section 5.1, for details about the HardMove env).
+
+Specifically, the modified gym-hybrid contains the following three types of environments:
+
+- Moving-v0
+- Sliding-v0
+- HardMove-v0
+
+## Install Guide
+
+```bash
+cd DI-engine/dizoo/gym_hybrid/envs/gym-hybrid
+pip install -e .
+```
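+
+After installation, a quick sanity check of the newly added HardMove env (a minimal sketch; it assumes `gym_hybrid` is importable so that the env ids above get registered):
+
+```python
+import gym
+import gym_hybrid  # registers Moving-v0, Sliding-v0 and HardMove-v0
+
+env = gym.make('HardMove-v0')
+obs = env.reset()
+obs, reward, done, info = env.step(env.action_space.sample())
+print(len(obs), reward, done)  # 10-dim state, float reward, bool done
+env.close()
+```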
+
+## Acknowledgement
+
+https://github.com/thomashirtz/gym-hybrid
\ No newline at end of file
diff --git a/DI-engine/dizoo/gym_hybrid/envs/__init__.py b/DI-engine/dizoo/gym_hybrid/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..fea5c1f9586b6d73c6d6dc1cbbf55393ab4a27e0
--- /dev/null
+++ b/DI-engine/dizoo/gym_hybrid/envs/__init__.py
@@ -0,0 +1 @@
+from .gym_hybrid_env import GymHybridEnv
diff --git a/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/README.md b/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..153b07db967d4cdf7a6ed28d9473a154e8d0b8b2
--- /dev/null
+++ b/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/README.md
@@ -0,0 +1,147 @@
+# gym-hybrid
+
+Repository containing a collection of environments for reinforcement learning tasks with a discrete-continuous hybrid action space.
+
+## "Sliding-v0" and "Moving-v0"
+
+
+
+"Moving-v0" and "Sliding-v0" are sandbox environments for parameterized action-space algorithms. The goal of the agent is to stop inside a target area.
+
+The field is a square with a side length of 2. The target area is a circle with radius 0.1. There are three discrete actions: turn, accelerate, and break. In addition to the action id, there are two possible complementary parameters: acceleration and rotation.
+
+The episode terminates if one of the three conditions is met:
+* the agent stops inside the target area,
+* the agent leaves the field,
+* the step count exceeds the limit (set by default at 200).
+
+The moving environment doesn't take inertia into account, while the sliding environment does. `Sliding-v0` is therefore more realistic than `Moving-v0`.
+
+All the parameters, actions, states and rewards are the same between the two environments. Only the underlying physics changes.
+
+### State
+The [state](https://github.com/thomashirtz/gym-hybrid/blob/fee4bf5de2dc1dd0d2a5431498124b2c071a2344/gym_hybrid/environments.py#L126) consists of a list of 10 elements. The environment-related values are the current step divided by the maximum step, and the position of the target (x and y). The player-related values are the position (x and y), the speed, the direction (cosine and sine), the distance to the target, and an indicator that becomes 1 when the player is inside the target zone.
+```python
+state = [
+ agent.x,
+ agent.y,
+ agent.speed,
+ np.cos(agent.theta),
+ np.sin(agent.theta),
+ target.x,
+ target.y,
+ distance,
+ 0 if distance > target_radius else 1,
+ current_step / max_step
+]
+```
+
+### Reward
+The [reward](https://github.com/thomashirtz/gym-hybrid/blob/fee4bf5de2dc1dd0d2a5431498124b2c071a2344/gym_hybrid/environments.py#L141) is the distance of the agent from the target at the previous step minus the current distance. A penalty (set by default at a low value) is subtracted every step to incentivize the learning algorithm to score as quickly as possible. A bonus reward of one is added if the player manages to stop inside the target area. A penalty of one is applied if the step count exceeds the limit or if the player leaves the field.
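+
+In code form, the per-step reward described above is roughly the following (a sketch mirroring `get_reward` in `environments.py`; `stopped_inside_target` is a stand-in name for the success condition):
+
+```python
+reward = last_distance - distance - penalty + (1 if stopped_inside_target else 0)
+# leaving the field or exceeding the step limit ends the episode with a reward of -1 instead
+```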
+
+### Actions
+
+**The action ids (as used in the examples below) are:**
+* 0: Accelerate
+* 1: Turn
+* 2: Break
+
+**The parameters are:**
+1. Acceleration value
+2. Rotation value
+
+**There are two distinct ways to format an action:**
+
+Action with all the parameters (convenient if the model outputs all the parameters):
+```python
+action = (action_id, [acceleration_value, rotation_value])
+```
+Examples of valid actions:
+```python
+action = (0, [0.1, 0.4])
+action = (1, [0.0, 0.2])
+action = (2, [0.1, 0.3])
+```
+Note: Only the parameter related to the action chosen will be used.
+
+Action with only the parameter related to the action id (convenient for algorithms that output only the parameter
+of the chosen action, since it doesn't require padding the action):
+```python
+action = (0, [acceleration_value])
+action = (1, [rotation_value])
+action = (2, [])
+```
+Examples of valid actions:
+```python
+action = (0, [0.1])
+action = (1, [0.2])
+action = (2, [])
+```
+### Basics
+Make and initialize an environment:
+```python
+import gym
+import gym_hybrid
+
+sliding_env = gym.make('Sliding-v0')
+sliding_env.reset()
+
+moving_env = gym.make('Moving-v0')
+moving_env.reset()
+```
+
+Get the action space and the observation space (here `env` is one of the environments created above):
+```python
+ACTION_SPACE = env.action_space[0].n
+PARAMETERS_SPACE = env.action_space[1].shape[0]
+OBSERVATION_SPACE = env.observation_space.shape[0]
+```
+
+Run a random agent:
+```python
+done = False
+while not done:
+ state, reward, done, info = env.step(env.action_space.sample())
+ print(f'State: {state} Reward: {reward} Done: {done}')
+```
+### Parameters
+The parameters that can be modified during initialization are:
+* `seed` (default = None)
+* `max_turn`, the maximum angle in radians that can be achieved in one step (default = np.pi/2)
+* `max_acceleration`, the maximum acceleration that can be achieved in one step (if the input parameter is 1) (default = 0.5)
+* `delta_t`, the time duration of one step (default = 0.005)
+* `max_step`, the limit on the number of steps before the end of an episode (default = 200)
+* `penalty`, the value subtracted from the reward each step to incentivize the agent to finish the episode more quickly (default = 0.001)
+
+Initialization with custom parameters:
+```python
+env = gym.make(
+ 'Moving-v0',
+ seed=0,
+ max_turn=1,
+ max_acceleration=1.0,
+ delta_t=0.001,
+ max_step=500,
+ penalty=0.01
+)
+```
+
+### Render & Recording
+Two test files are available to show users how to render and record the environment:
+* [Python file example for recording](tests/record.py)
+* [Python file example for rendering](tests/render.py)
+
+## Disclaimer
+Even though the mechanics of the environment are done, the hyperparameters may still need further adjustment.
+
+## Reference
+This environment is described in several papers, such as:
+* [Parametrized Deep Q-Networks Learning, Xiong et al., 2018](https://arxiv.org/pdf/1810.06394.pdf)
+* [Hybrid Actor-Critic Reinforcement Learning in Parameterized Action Space, Fan et al., 2019](https://arxiv.org/pdf/1903.01344.pdf)
+
+## Installation
+
+Install directly from GitHub using pip by running this command:
+```shell
+pip install git+https://github.com/thomashirtz/gym-hybrid#egg=gym-hybrid
+```
diff --git a/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/gym_hybrid/__init__.py b/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/gym_hybrid/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..89cb5d7764e4c0da57bf5bc79acb9cb1b8183f13
--- /dev/null
+++ b/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/gym_hybrid/__init__.py
@@ -0,0 +1,17 @@
+from gym.envs.registration import register
+from gym_hybrid.environments import MovingEnv
+from gym_hybrid.environments import SlidingEnv
+from gym_hybrid.environments import HardMoveEnv
+
+register(
+ id='Moving-v0',
+ entry_point='gym_hybrid:MovingEnv',
+)
+register(
+ id='Sliding-v0',
+ entry_point='gym_hybrid:SlidingEnv',
+)
+register(
+ id='HardMove-v0',
+ entry_point='gym_hybrid:HardMoveEnv',
+)
diff --git a/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/gym_hybrid/agents.py b/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/gym_hybrid/agents.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b8669ee982a0c0585f61e22bfc7c1db84f943de
--- /dev/null
+++ b/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/gym_hybrid/agents.py
@@ -0,0 +1,117 @@
+from itertools import product
+
+import numpy as np
+
+
+class BaseAgent:
+
+ def __init__(self, break_value: float, delta_t: float):
+ self.x = None
+ self.y = None
+ self.phi = None # angle of the velocity vector
+ self.theta = None # direction of the agent
+ self.speed = None
+ self.delta_t = delta_t
+ self.break_value = break_value
+
+ def accelerate(self, value: float) -> None:
+ raise NotImplementedError
+
+ def break_(self) -> None:
+ raise NotImplementedError
+
+ def turn(self, value: float) -> None:
+ raise NotImplementedError
+
+ def reset(self, x: float, y: float, direction: float) -> None:
+ self.x = x
+ self.y = y
+ self.speed = 0
+ self.theta = direction
+
+ def _step(self) -> None:
+ angle = self.theta if self.phi is None else self.phi
+ self.x += self.delta_t * self.speed * np.cos(angle)
+ self.y += self.delta_t * self.speed * np.sin(angle)
+
+
+class MovingAgent(BaseAgent):
+
+ def __init__(self, break_value: float, delta_t: float):
+ super(MovingAgent, self).__init__(break_value, delta_t)
+
+ def accelerate(self, value: float) -> None:
+ self.speed += value
+ self._step()
+
+ def break_(self) -> None:
+ self.speed = 0 if self.speed < self.break_value else self.speed - self.break_value
+ self._step()
+
+ def turn(self, value: float) -> None:
+ self.theta = (self.theta + value) % (2 * np.pi)
+ self._step()
+
+
+class SlidingAgent(BaseAgent):
+
+ def __init__(self, break_value: float, delta_t: float):
+ super(SlidingAgent, self).__init__(break_value, delta_t)
+ self.phi = 0
+
+ def accelerate(self, value: float) -> None:
+ # Adding two polar vectors: https://math.stackexchange.com/a/1365938/849658
+ # phi_1, r_1 = self.theta, value # the direction of the agent and the magnitude induced by the action
+ # phi_2, r_2 = self.phi, self.speed # the direction of the velocity vector and its magnitude
+ speed = np.sqrt(value ** 2 + self.speed ** 2 + 2 * value * self.speed * np.cos(self.phi - self.theta))
+ angle = self.theta + np.arctan2(
+ self.speed * np.sin(self.phi - self.theta), value + self.speed * np.cos(self.phi - self.theta)
+ )
+ self.speed = speed
+ self.phi = angle
+ self._step()
+
+ def break_(self) -> None:
+ self.speed = 0 if self.speed < self.break_value else self.speed - self.break_value
+ self.phi = self.theta if self.speed == 0 else self.phi # not sure it is needed
+ self._step()
+
+ def turn(self, value: float) -> None:
+ self.theta = (self.theta + value) % (2 * np.pi)
+ self._step()
+
+
+class HardMoveAgent(BaseAgent):
+
+ def __init__(self, break_value: float, delta_t: float, num_actuators: int = 4):
+ super(HardMoveAgent, self).__init__(break_value, delta_t)
+ self.phi = 0
+ self.num_actuators = num_actuators
+ # NOTE: meta_to_mask maps each discrete meta-action index in {0, ..., 2 ** num_actuators - 1} to a binary on/off mask over the actuators.
+ self.K = 2 ** self.num_actuators
+ self.meta_to_mask = list(product(*[list(range(2)) for _ in range(self.num_actuators)]))
+
+ def accelerate(self, value: float) -> None:
+ pass
+
+ def break_(self) -> None:
+ pass
+
+ def turn(self, value: float) -> None:
+ pass
+
+ def move(self, move_direction_meta: int, move_distances: list) -> None:
+ move_directions_mask = self.meta_to_mask[int(move_direction_meta)]
+ self.move_vector = np.array(
+ [
+ move_directions_mask[i] * move_distances[i] *
+ np.array([np.cos(i * 2 * np.pi / self.num_actuators),
+ np.sin(i * 2 * np.pi / self.num_actuators)]) for i in range(len(move_distances))
+ ]
+ ).sum(0)
+ self._step()
+ self.theta = np.arctan(self.y / self.x) # direction of the agent, in radians
+
+ def _step(self) -> None:
+ self.x = self.x + self.move_vector[0]
+ self.y = self.y + self.move_vector[1]
diff --git a/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/gym_hybrid/environments.py b/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/gym_hybrid/environments.py
new file mode 100644
index 0000000000000000000000000000000000000000..9716bc4484e333c4eb97d45d6d41672ee3893c65
--- /dev/null
+++ b/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/gym_hybrid/environments.py
@@ -0,0 +1,405 @@
+from collections import namedtuple
+from typing import Optional
+from typing import Tuple
+
+import gym
+import numpy as np
+import cv2
+import os
+from gym import spaces
+from gym.utils import seeding
+
+# gym.logger.set_level(40) # noqa
+
+from .agents import BaseAgent, MovingAgent, SlidingAgent, HardMoveAgent
+
+# Action Id
+ACCELERATE = 0
+TURN = 1
+BREAK = 2
+
+Target = namedtuple('Target', ['x', 'y', 'radius'])
+
+
+class Action:
+ """"
+ Action class to store and standardize the action for the environment.
+ """
+
+ def __init__(self, id_: int, parameters: list):
+ """"
+ Initialization of an action.
+
+ Args:
+ id_: The id of the selected action.
+ parameters: The parameters of an action.
+ """
+ self.id = id_
+ self.parameters = parameters
+
+ @property
+ def parameter(self) -> float:
+ """"
+ Property method to return the parameter related to the action selected.
+
+ Returns:
+ The parameter related to this action_id
+ """
+ if len(self.parameters) == 2:
+ return self.parameters[self.id]
+ else:
+ return self.parameters[0]
+
+
+class BaseEnv(gym.Env):
+ """"
+ Gym environment parent class.
+ """
+
+ def __init__(
+ self,
+ seed: Optional[int] = None,
+ max_turn: float = np.pi / 2,
+ max_acceleration: float = 0.5,
+ delta_t: float = 0.005,
+ max_step: int = 200,
+ penalty: float = 0.001,
+ break_value: float = 0.1,
+ ):
+ """Initialization of the gym environment.
+
+ Args:
+ seed (int): Seed used to get reproducible results.
+ max_turn (float): Maximum turn during one step (in radian).
+ max_acceleration (float): Maximum acceleration during one step.
+ delta_t (float): Time duration of one step.
+ max_step (int): Maximum number of steps in one episode.
+ penalty (float): Score penalty given at the agent every step.
+ break_value (float): Break value when performing break action.
+ """
+ # Agent Parameters
+ self.max_turn = max_turn
+ self.max_acceleration = max_acceleration
+ self.break_value = break_value
+
+ # Environment Parameters
+ self.delta_t = delta_t
+ self.max_step = max_step
+ self.field_size = 1.0
+ self.target_radius = 0.1
+ self.penalty = penalty
+
+ # Initialization
+ self.seed(seed)
+ self.target = None
+ self.viewer = None
+ self.current_step = None
+ self.agent = BaseAgent(break_value=break_value, delta_t=delta_t)
+
+ parameters_min = np.array([0, -1])
+ parameters_max = np.array([1, +1])
+
+ self.action_space = spaces.Tuple((spaces.Discrete(3), spaces.Box(parameters_min, parameters_max)))
+ self.observation_space = spaces.Box(-np.ones(10), np.ones(10))
+ dirname = os.path.dirname(__file__)
+ self.bg = cv2.imread(os.path.join(dirname, 'bg.jpg'))
+ self.bg = cv2.cvtColor(self.bg, cv2.COLOR_BGR2RGB)
+ self.bg = cv2.resize(self.bg, (800, 800))
+ self.target_img = cv2.imread(os.path.join(dirname, 'target.png'), cv2.IMREAD_UNCHANGED)
+ self.target_img = cv2.resize(self.target_img, (60, 60))
+
+ def seed(self, seed: Optional[int] = None) -> list:
+ self.np_random, seed = seeding.np_random(seed) # noqa
+ return [seed]
+
+ def reset(self) -> list:
+ self.current_step = 0
+
+ limit = self.field_size - self.target_radius
+ low = [-limit, -limit, self.target_radius]
+ high = [limit, limit, self.target_radius]
+ self.target = Target(*self.np_random.uniform(low, high))
+
+ low = [-self.field_size, -self.field_size, 0]
+ high = [self.field_size, self.field_size, 2 * np.pi]
+ self.agent.reset(*self.np_random.uniform(low, high))
+
+ return self.get_state()
+
+ def step(self, raw_action: Tuple[int, list]) -> Tuple[list, float, bool, dict]:
+ action = Action(*raw_action)
+ last_distance = self.distance
+ self.current_step += 1
+
+ if action.id == TURN:
+ rotation = self.max_turn * max(min(action.parameter, 1), -1)
+ self.agent.turn(rotation)
+ elif action.id == ACCELERATE:
+ acceleration = self.max_acceleration * max(min(action.parameter, 1), 0)
+ self.agent.accelerate(acceleration)
+ elif action.id == BREAK:
+ self.agent.break_()
+
+ if self.distance < self.target_radius and self.agent.speed == 0:
+ reward = self.get_reward(last_distance, True)
+ done = True
+ elif abs(self.agent.x) > self.field_size or abs(self.agent.y
+ ) > self.field_size or self.current_step > self.max_step:
+ reward = -1
+ done = True
+ else:
+ reward = self.get_reward(last_distance)
+ done = False
+
+ return self.get_state(), reward, done, {}
+
+ def get_state(self) -> list:
+ state = [
+ self.agent.x, self.agent.y, self.agent.speed,
+ np.cos(self.agent.theta),
+ np.sin(self.agent.theta), self.target.x, self.target.y, self.distance,
+ 0 if self.distance > self.target_radius else 1, self.current_step / self.max_step
+ ]
+ return state
+
+ def get_reward(self, last_distance: float, goal: bool = False) -> float:
+ return last_distance - self.distance - self.penalty + (1 if goal else 0)
+
+ @property
+ def distance(self) -> float:
+ return self.get_distance(self.agent.x, self.agent.y, self.target.x, self.target.y)
+
+ @staticmethod
+ def get_distance(x1: float, y1: float, x2: float, y2: float) -> float:
+ return np.sqrt(((x1 - x2) ** 2) + ((y1 - y2) ** 2)).item()
+
+ def render(self, mode='human'):
+ screen_width = 400
+ screen_height = 400
+ unit_x = screen_width / 2
+ unit_y = screen_height / 2
+ agent_radius = 0.05
+
+ if self.viewer is None:
+ from gym.envs.classic_control import rendering
+ self.viewer = rendering.Viewer(screen_width, screen_height)
+
+ agent = rendering.make_circle(unit_x * agent_radius)
+ self.agent_trans = rendering.Transform(
+ translation=(unit_x * (1 + self.agent.x), unit_y * (1 + self.agent.y))
+ ) # noqa
+ agent.add_attr(self.agent_trans)
+ agent.set_color(0.1, 0.3, 0.9)
+ self.viewer.add_geom(agent)
+
+ t, r, m = 0.1 * unit_x, 0.04 * unit_y, 0.06 * unit_x
+ arrow = rendering.FilledPolygon([(t, 0), (m, r), (m, -r)])
+ self.arrow_trans = rendering.Transform(rotation=self.agent.theta) # noqa
+ arrow.add_attr(self.arrow_trans)
+ arrow.add_attr(self.agent_trans)
+ arrow.set_color(0, 0, 0)
+ self.viewer.add_geom(arrow)
+
+ target = rendering.make_circle(unit_x * self.target_radius, filled=False)
+ target_trans = rendering.Transform(translation=(unit_x * (1 + self.target.x), unit_y * (1 + self.target.y)))
+ target.add_attr(target_trans)
+ target.set_color(0, 0.6, 0)
+ self.viewer.add_geom(target)
+
+ self.arrow_trans.set_rotation(self.agent.theta)
+ self.agent_trans.set_translation(unit_x * (1 + self.agent.x), unit_y * (1 + self.agent.y))
+
+ ret = self.viewer.render(return_rgb_array=mode == 'rgb_array')
+ # add background
+ ret = np.where(ret == 255, self.bg, ret)
+ # add target logo
+ # # x, y = int(unit_x * (1 + self.target.x)), int(unit_y * (1 - self.target.y))
+ # # x, y = x - 20, y + 25 # seed0
+ # target_area = ret[x:x+60, y:y+60]
+ # rgb_img = cv2.cvtColor(self.target_img[..., :3], cv2.COLOR_BGR2RGB)
+ # target_area = np.where(self.target_img[..., -1:] == 0, target_area, rgb_img)
+ # ret[x:x+60, y:y+60] = target_area
+ # add frame
+ frames = np.array([60, 60, 30]).reshape(1, 1, -1)
+ ret[:6] = frames
+ ret[:, :6] = frames
+ ret[-6:] = frames
+ ret[:, -6:] = frames
+ return ret
+
+ def close(self):
+ if self.viewer:
+ self.viewer.close()
+ self.viewer = None
+
+
+class MovingEnv(BaseEnv):
+
+ def __init__(
+ self,
+ seed: int = None,
+ max_turn: float = np.pi / 2,
+ max_acceleration: float = 0.5,
+ delta_t: float = 0.005,
+ max_step: int = 200,
+ penalty: float = 0.001,
+ break_value: float = 0.1,
+ ):
+ super(MovingEnv, self).__init__(
+ seed=seed,
+ max_turn=max_turn,
+ max_acceleration=max_acceleration,
+ delta_t=delta_t,
+ max_step=max_step,
+ penalty=penalty,
+ break_value=break_value,
+ )
+
+ self.agent = MovingAgent(
+ break_value=break_value,
+ delta_t=delta_t,
+ )
+
+
+class SlidingEnv(BaseEnv):
+
+ def __init__(
+ self,
+ seed: int = None,
+ max_turn: float = np.pi / 2,
+ max_acceleration: float = 0.5,
+ delta_t: float = 0.005,
+ max_step: int = 200,
+ penalty: float = 0.001,
+ break_value: float = 0.1
+ ):
+ super(SlidingEnv, self).__init__(
+ seed=seed,
+ max_turn=max_turn,
+ max_acceleration=max_acceleration,
+ delta_t=delta_t,
+ max_step=max_step,
+ penalty=penalty,
+ break_value=break_value
+ )
+
+ self.agent = SlidingAgent(break_value=break_value, delta_t=delta_t)
+
+
+class HardMoveEnv(gym.Env):
+ """"
+ HardMove environment. Please refer to https://arxiv.org/abs/2109.05490 for details.
+ """
+
+ def __init__(
+ self,
+ num_actuators: int = 4,
+ seed: Optional[int] = None,
+ max_turn: float = np.pi / 2,
+ max_acceleration: float = 0.5,
+ delta_t: float = 0.005,
+ max_step: int = 25,
+ penalty: float = 0.001,
+ break_value: float = 0.1,
+ ):
+ """Initialization of the gym environment.
+
+ Args:
+ seed (int): Seed used to get reproducible results.
+ max_turn (float): Maximum turn during one step (in radian).
+ max_acceleration (float): Maximum acceleration during one step.
+ delta_t (float): Time duration of one step.
+ max_step (int): Maximum number of steps in one episode.
+ penalty (float): Score penalty given at the agent every step.
+ break_value (float): Break value when performing break action.
+ """
+ # Agent Parameters
+ self.num_actuators = num_actuators
+ self.max_turn = max_turn
+ self.max_acceleration = max_acceleration
+ self.break_value = break_value
+
+ # Environment Parameters
+ self.delta_t = delta_t
+ self.max_step = max_step
+ self.field_size = 1.0
+ self.target_radius = 0.1
+ self.penalty = penalty
+
+ # Initialization
+ self.seed(seed)
+ self.target = None
+ self.viewer = None
+ self.current_step = None
+ self.agent = HardMoveAgent(break_value=break_value, delta_t=delta_t, num_actuators=self.num_actuators)
+
+ parameters_min = np.array([-1 for i in range(self.num_actuators)])
+ parameters_max = np.array([+1 for i in range(self.num_actuators)])
+
+ self.action_space = spaces.Tuple(
+ (spaces.Discrete(int(2 ** self.num_actuators)), spaces.Box(parameters_min, parameters_max))
+ )
+ self.observation_space = spaces.Box(-np.ones(10), np.ones(10))
+
+ def seed(self, seed: Optional[int] = None) -> list:
+ self.np_random, seed = seeding.np_random(seed) # noqa
+ return [seed]
+
+ def reset(self) -> list:
+ self.current_step = 0
+
+ limit = self.field_size - self.target_radius
+ low = [-limit, -limit, self.target_radius]
+ high = [limit, limit, self.target_radius]
+ self.target = Target(*self.np_random.uniform(low, high))
+
+ low = [-self.field_size, -self.field_size, 0]
+ high = [self.field_size, self.field_size, 2 * np.pi]
+ self.agent.reset(*self.np_random.uniform(low, high))
+
+ return self.get_state()
+
+ def step(self, raw_action: Tuple[int, list]) -> Tuple[list, float, bool, dict]:
+ move_direction_meta = raw_action[0] # an int in {0, ..., 2 ** num_actuators - 1}
+ move_distances = raw_action[1] # shape (num_actuators, )
+ last_distance = self.distance
+ self.current_step += 1
+
+ self.agent.move(move_direction_meta, move_distances)
+ if self.distance < self.target_radius:
+ reward = self.get_reward(last_distance, True)
+ done = True
+ elif abs(self.agent.x) > self.field_size or abs(self.agent.y
+ ) > self.field_size or self.current_step > self.max_step:
+ reward = -1
+ done = True
+ else:
+ reward = self.get_reward(last_distance)
+ done = False
+
+ return self.get_state(), reward, done, {}
+
+ def get_state(self) -> list:
+ state = [
+ self.agent.x, self.agent.y, self.agent.speed,
+ np.cos(self.agent.theta),
+ np.sin(self.agent.theta), self.target.x, self.target.y, self.distance,
+ 0 if self.distance > self.target_radius else 1, self.current_step / self.max_step
+ ]
+ return state
+
+ def get_reward(self, last_distance: float, goal: bool = False) -> float:
+ return last_distance - self.distance - self.penalty + (1 if goal else 0)
+
+ @property
+ def distance(self) -> float:
+ return self.get_distance(self.agent.x, self.agent.y, self.target.x, self.target.y)
+
+ @staticmethod
+ def get_distance(x1: float, y1: float, x2: float, y2: float) -> float:
+ return np.sqrt(((x1 - x2) ** 2) + ((y1 - y2) ** 2)).item()
+
+ def close(self):
+ if self.viewer:
+ self.viewer.close()
+ self.viewer = None
diff --git a/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/setup.py b/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..248ccb453559e4127c556d45811238ba6dc36570
--- /dev/null
+++ b/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/setup.py
@@ -0,0 +1,8 @@
+from setuptools import setup
+
+setup(
+ name='gym_hybrid',
+ version='0.0.2', # original gym_hybrid version='0.0.1'
+ packages=['gym_hybrid'],
+ install_requires=['gym', 'numpy'],
+)
diff --git a/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/tests/hardmove.py b/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/tests/hardmove.py
new file mode 100644
index 0000000000000000000000000000000000000000..bde6b6eb8f43509c79dbcf6497801bbe7fe11432
--- /dev/null
+++ b/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/tests/hardmove.py
@@ -0,0 +1,17 @@
+import time
+import gym
+import gym_hybrid
+
+if __name__ == '__main__':
+ env = gym.make('HardMove-v0')
+ env.reset()
+
+ ACTION_SPACE = env.action_space[0].n
+ PARAMETERS_SPACE = env.action_space[1].shape[0]
+ OBSERVATION_SPACE = env.observation_space.shape[0]
+
+ done = False
+ while not done:
+ state, reward, done, info = env.step(env.action_space.sample())
+ print(f'State: {state} Reward: {reward} Done: {done}')
+ time.sleep(0.1)
diff --git a/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/tests/moving.py b/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/tests/moving.py
new file mode 100644
index 0000000000000000000000000000000000000000..52315decd914f820744174ad156fdd51cfc9d4aa
--- /dev/null
+++ b/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/tests/moving.py
@@ -0,0 +1,17 @@
+import time
+import gym
+import gym_hybrid
+
+if __name__ == '__main__':
+ env = gym.make('Moving-v0')
+ env.reset()
+
+ ACTION_SPACE = env.action_space[0].n
+ PARAMETERS_SPACE = env.action_space[1].shape[0]
+ OBSERVATION_SPACE = env.observation_space.shape[0]
+
+ done = False
+ while not done:
+ state, reward, done, info = env.step(env.action_space.sample())
+ print(f'State: {state} Reward: {reward} Done: {done}')
+ time.sleep(0.1)
diff --git a/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/tests/record.py b/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/tests/record.py
new file mode 100644
index 0000000000000000000000000000000000000000..d97eaa13b22dacb78ecc6113d3145e92243c5fe2
--- /dev/null
+++ b/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/tests/record.py
@@ -0,0 +1,14 @@
+import gym
+import gym_hybrid
+
+if __name__ == '__main__':
+ env = gym.make('Sliding-v0')
+ env = gym.wrappers.Monitor(env, "./video", force=True)
+ env.metadata["render.modes"] = ["human", "rgb_array"]
+ env.reset()
+
+ done = False
+ while not done:
+ _, _, done, _ = env.step(env.action_space.sample())
+
+ env.close()
diff --git a/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/tests/render.py b/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/tests/render.py
new file mode 100644
index 0000000000000000000000000000000000000000..a382525fc4f624c0af06b771b1ee0cb11560d09b
--- /dev/null
+++ b/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/tests/render.py
@@ -0,0 +1,16 @@
+import time
+import gym
+import gym_hybrid
+
+if __name__ == '__main__':
+ env = gym.make('Sliding-v0')
+ env.reset()
+
+ done = False
+ while not done:
+ _, _, done, _ = env.step(env.action_space.sample())
+ env.render()
+ time.sleep(0.1)
+
+ time.sleep(1)
+ env.close()
diff --git a/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/tests/sliding.py b/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/tests/sliding.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a44dc03329dbc793171949f7750abaf52421f51
--- /dev/null
+++ b/DI-engine/dizoo/gym_hybrid/envs/gym-hybrid/tests/sliding.py
@@ -0,0 +1,17 @@
+import time
+import gym
+import gym_hybrid
+
+if __name__ == '__main__':
+ env = gym.make('Sliding-v0')
+ env.reset()
+
+ ACTION_SPACE = env.action_space[0].n
+ PARAMETERS_SPACE = env.action_space[1].shape[0]
+ OBSERVATION_SPACE = env.observation_space.shape[0]
+
+ done = False
+ while not done:
+ state, reward, done, info = env.step(env.action_space.sample())
+ print(f'State: {state} Reward: {reward} Done: {done}')
+ time.sleep(0.1)
diff --git a/DI-engine/dizoo/gym_hybrid/envs/gym_hybrid_env.py b/DI-engine/dizoo/gym_hybrid/envs/gym_hybrid_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..9f02925d1a1ae2a869eb91501d24110d1d7794d9
--- /dev/null
+++ b/DI-engine/dizoo/gym_hybrid/envs/gym_hybrid_env.py
@@ -0,0 +1,160 @@
+import copy
+import os
+from typing import Dict, Optional
+
+import gym
+import gym_hybrid
+import matplotlib.pyplot as plt
+import numpy as np
+from easydict import EasyDict
+from matplotlib import animation
+
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.envs.common import affine_transform
+from ding.torch_utils import to_ndarray
+from ding.utils import ENV_REGISTRY
+
+
+@ENV_REGISTRY.register('gym_hybrid')
+class GymHybridEnv(BaseEnv):
+ default_env_id = ['Sliding-v0', 'Moving-v0', 'HardMove-v0']
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ cfg = EasyDict(copy.deepcopy(cls.config))
+ cfg.cfg_type = cls.__name__ + 'Dict'
+ return cfg
+
+ config = dict(
+ env_id='Moving-v0',
+ act_scale=True,
+ )
+
+ def __init__(self, cfg: EasyDict) -> None:
+ self._cfg = cfg
+ self._env_id = cfg.env_id
+ assert self._env_id in self.default_env_id
+ self._act_scale = cfg.act_scale
+ self._replay_path = None
+ self._save_replay = False
+ self._save_replay_count = 0
+ self._init_flag = False
+
+ def reset(self) -> np.ndarray:
+ if not self._init_flag:
+ if self._env_id == 'HardMove-v0':
+ self._env = gym.make(self._env_id, num_actuators=self._cfg.num_actuators)
+ else:
+ self._env = gym.make(self._env_id)
+ self._observation_space = self._env.observation_space
+ self._action_space = self._env.action_space
+ self._reward_space = gym.spaces.Box(
+ low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
+ )
+ self._init_flag = True
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+ self._eval_episode_return = 0
+ obs = self._env.reset()
+ obs = to_ndarray(obs).astype(np.float32)
+ return obs
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def step(self, action: Dict) -> BaseEnvTimestep:
+ if self._act_scale:
+ if self._env_id == 'HardMove-v0':
+ action = [
+ action['action_type'], [affine_transform(i, min_val=-1, max_val=1) for i in action['action_args']]
+ ]
+ else:
+ # acceleration_value.
+ action['action_args'][0] = affine_transform(action['action_args'][0], min_val=0, max_val=1)
+ # rotation_value. The following line can be omitted, because affine_transform
+ # already performs the clip(-1, 1) operation.
+ action['action_args'][1] = affine_transform(action['action_args'][1], min_val=-1, max_val=1)
+ action = [action['action_type'], action['action_args']]
+ if self._save_replay:
+ self._frames.append(self._env.render(mode='rgb_array'))
+ obs, rew, done, info = self._env.step(action)
+
+ obs = to_ndarray(obs)
+ if isinstance(obs, list): # corner case
+ for i in range(len(obs)):
+ if len(obs[i].shape) == 0:
+ obs[i] = np.array([obs[i]])
+ obs = np.concatenate(obs)
+ assert isinstance(obs, np.ndarray) and obs.shape == (10, )
+ obs = obs.astype(np.float32)
+
+ rew = to_ndarray([rew]) # wrapped to be transferred to a numpy array with shape (1,)
+ if isinstance(rew, list):
+ rew = rew[0]
+ assert isinstance(rew, np.ndarray) and rew.shape == (1, )
+ self._eval_episode_return += rew.item()
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+ if self._save_replay:
+ if self._env_id == 'HardMove-v0':
+ self._env_id = f'hardmove_n{self._cfg.num_actuators}'
+ path = os.path.join(
+ self._replay_path, '{}_episode_{}.gif'.format(self._env_id, self._save_replay_count)
+ )
+ self.display_frames_as_gif(self._frames, path)
+ self._frames = []
+ self._save_replay_count += 1
+ info['action_args_mask'] = np.array([[1, 0], [0, 1], [0, 0]])
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def random_action(self) -> Dict:
+ # action_type: 0, 1, 2
+ # action_args:
+ # - acceleration_value: [0, 1]
+ # - rotation_value: [-1, 1]
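+ # e.g. (illustrative sample) {'action_type': 1, 'action_args': array([0.37, -0.52], dtype=float32)}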
+ raw_action = self._action_space.sample()
+ return {'action_type': raw_action[0], 'action_args': raw_action[1]}
+
+ def __repr__(self) -> str:
+ return "DI-engine gym hybrid Env"
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ if replay_path is None:
+ replay_path = './video'
+ self._replay_path = replay_path
+ self._save_replay = True
+ self._save_replay_count = 0
+ self._frames = []
+
+ @staticmethod
+ def display_frames_as_gif(frames: list, path: str) -> None:
+ patch = plt.imshow(frames[0])
+ plt.axis('off')
+
+ def animate(i):
+ patch.set_data(frames[i])
+
+ anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=5)
+ anim.save(path, writer='imagemagick', fps=20)
diff --git a/DI-engine/dizoo/gym_hybrid/envs/test_gym_hybrid_env.py b/DI-engine/dizoo/gym_hybrid/envs/test_gym_hybrid_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..896987f33f9b721e5e50262357436b34df00e82b
--- /dev/null
+++ b/DI-engine/dizoo/gym_hybrid/envs/test_gym_hybrid_env.py
@@ -0,0 +1,40 @@
+import numpy as np
+import pytest
+from dizoo.gym_hybrid.envs import GymHybridEnv
+from easydict import EasyDict
+
+
+@pytest.mark.envtest
+class TestGymHybridEnv:
+
+ def test_naive(self):
+ env = GymHybridEnv(
+ EasyDict(
+ {
+ 'env_id': 'Moving-v0',
+ 'act_scale': False,
+ 'save_replay_gif': False,
+ 'replay_path_gif': None,
+ 'replay_path': None
+ }
+ )
+ )
+ env.enable_save_replay('./video')
+ env.seed(314, dynamic_seed=False)
+ assert env._seed == 314
+ obs = env.reset()
+ assert obs.shape == (10, )
+ for i in range(200):
+ random_action = env.random_action()
+ print('random_action', random_action)
+ timestep = env.step(random_action)
+ assert isinstance(timestep.obs, np.ndarray)
+ assert isinstance(timestep.done, bool)
+ assert timestep.obs.shape == (10, )
+ assert timestep.reward.shape == (1, )
+ assert timestep.info['action_args_mask'].shape == (3, 2)
+ if timestep.done:
+ print('reset env')
+ env.reset()
+ print(env.observation_space, env.action_space, env.reward_space)
+ env.close()
diff --git a/DI-engine/dizoo/gym_pybullet_drones/__init__.py b/DI-engine/dizoo/gym_pybullet_drones/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/gym_pybullet_drones/config/flythrugate_onppo_config.py b/DI-engine/dizoo/gym_pybullet_drones/config/flythrugate_onppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ee7bc55dc172cbe0d4a88ad4014ea44eee0107f4
--- /dev/null
+++ b/DI-engine/dizoo/gym_pybullet_drones/config/flythrugate_onppo_config.py
@@ -0,0 +1,61 @@
+from easydict import EasyDict
+
+flythrugate_ppo_config = dict(
+ exp_name='flythrugate_ppo_seed0',
+ env=dict(
+ manager=dict(shared_memory=False, reset_inplace=True),
+ env_id='flythrugate-aviary-v0',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=8,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=0,
+ action_type="VEL",
+ ),
+ policy=dict(
+ cuda=True,
+ recompute_adv=True,
+ # load_path="./flythrugate_ppo_seed0/ckpt/ckpt_best.pth.tar",
+ model=dict(
+ obs_shape=12,
+ action_shape=4,
+ action_space='continuous',
+ ),
+ action_space='continuous',
+ learn=dict(
+ epoch_per_collect=10,
+ batch_size=64,
+ learning_rate=3e-4,
+ value_weight=0.5,
+ entropy_weight=0.0,
+ clip_ratio=0.2,
+ adv_norm=True,
+ value_norm=True,
+ ),
+ collect=dict(
+ n_sample=2048,
+ gae_lambda=0.97,
+ ),
+ eval=dict(evaluator=dict(eval_freq=5000, )),
+ ),
+)
+flythrugate_ppo_config = EasyDict(flythrugate_ppo_config)
+main_config = flythrugate_ppo_config
+
+flythrugate_ppo_create_config = dict(
+ env=dict(
+ type='gym_pybullet_drones',
+ import_names=['dizoo.gym_pybullet_drones.envs.gym_pybullet_drones_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo', ),
+)
+flythrugate_ppo_create_config = EasyDict(flythrugate_ppo_create_config)
+create_config = flythrugate_ppo_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c flythrugate_ppo_config.py -s 0 --env-step 1e7`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/gym_pybullet_drones/config/takeoffaviary_onppo_config.py b/DI-engine/dizoo/gym_pybullet_drones/config/takeoffaviary_onppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe2f7bfa9c4230eff2b6c772911c1df19edfefdc
--- /dev/null
+++ b/DI-engine/dizoo/gym_pybullet_drones/config/takeoffaviary_onppo_config.py
@@ -0,0 +1,61 @@
+from easydict import EasyDict
+
+takeoffaviary_ppo_config = dict(
+ exp_name='takeoffaviary_ppo_seed0',
+ env=dict(
+ manager=dict(shared_memory=False, reset_inplace=True),
+ env_id='takeoff-aviary-v0',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=8,
+ evaluator_env_num=8,
+ use_act_scale=True,
+ n_evaluator_episode=8,
+ stop_value=0,
+ action_type="VEL",
+ ),
+ policy=dict(
+ cuda=True,
+ recompute_adv=True,
+ # load_path="./takeoffaviary_ppo_seed0/ckpt/ckpt_best.pth.tar",
+ model=dict(
+ obs_shape=12,
+ action_shape=4,
+ action_space='continuous',
+ ),
+ action_space='continuous',
+ learn=dict(
+ epoch_per_collect=10, # could be reduced
+ batch_size=64,
+ learning_rate=3e-4, # to tune; a pytorch lr scheduler may also help
+ value_weight=0.5,
+ entropy_weight=0.0, # alternative: 0.001
+ clip_ratio=0.2, # alternative: 0.1
+ adv_norm=True,
+ value_norm=True,
+ ),
+ collect=dict(
+ n_sample=2048,
+ gae_lambda=0.97,
+ ),
+ eval=dict(evaluator=dict(eval_freq=5000, )),
+ ),
+)
+takeoffaviary_ppo_config = EasyDict(takeoffaviary_ppo_config)
+main_config = takeoffaviary_ppo_config
+
+takeoffaviary_ppo_create_config = dict(
+ env=dict(
+ type='gym_pybullet_drones',
+ import_names=['dizoo.gym_pybullet_drones.envs.gym_pybullet_drones_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo', ),
+)
+takeoffaviary_ppo_create_config = EasyDict(takeoffaviary_ppo_create_config)
+create_config = takeoffaviary_ppo_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c takeoffaviary_ppo_config.py -s 0 --env-step 1e7`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/gym_pybullet_drones/entry/flythrugate_onppo_eval.py b/DI-engine/dizoo/gym_pybullet_drones/entry/flythrugate_onppo_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..05337e8c241690d7e662f150ce56c8593d2b841e
--- /dev/null
+++ b/DI-engine/dizoo/gym_pybullet_drones/entry/flythrugate_onppo_eval.py
@@ -0,0 +1,55 @@
+import os
+import gym
+import torch
+from tensorboardX import SummaryWriter
+from easydict import EasyDict
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, NaiveReplayBuffer
+from ding.envs import BaseEnvManager, DingEnvWrapper
+from ding.policy import PPOPolicy
+from ding.model import VAC
+from ding.utils import set_pkg_seed
+
+from dizoo.gym_pybullet_drones.envs.gym_pybullet_drones_env import GymPybulletDronesEnv
+from dizoo.gym_pybullet_drones.config.flythrugate_onppo_config import flythrugate_ppo_config
+
+
+def main(cfg, seed=0, max_iterations=int(1e10)):
+ cfg = compile_config(
+ cfg,
+ BaseEnvManager,
+ PPOPolicy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator,
+ NaiveReplayBuffer,
+ save_cfg=True
+ )
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+
+ info = cfg.env.manager
+
+ cfg.env['record'] = True
+ cfg.env['gui'] = True
+ cfg.env['print_debug_info'] = True
+ cfg.env['plot_observation'] = True
+
+ evaluator_env = BaseEnvManager(
+ env_fn=[lambda: GymPybulletDronesEnv(cfg.env) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ model = VAC(**cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
+ policy.eval_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cpu'))
+
+ tb_logger = SummaryWriter(os.path.join('./log/', 'serial'))
+ evaluator = InteractionSerialEvaluator(cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger)
+ evaluator.eval()
+
+
+if __name__ == "__main__":
+ main(flythrugate_ppo_config)
diff --git a/DI-engine/dizoo/gym_pybullet_drones/entry/takeoffaviary_onppo_eval.py b/DI-engine/dizoo/gym_pybullet_drones/entry/takeoffaviary_onppo_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..3ff48259cf9206e7f5feb0d55c8e2a44e83f8bc5
--- /dev/null
+++ b/DI-engine/dizoo/gym_pybullet_drones/entry/takeoffaviary_onppo_eval.py
@@ -0,0 +1,53 @@
+import os
+import gym
+import torch
+from tensorboardX import SummaryWriter
+from easydict import EasyDict
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, NaiveReplayBuffer
+from ding.envs import BaseEnvManager, DingEnvWrapper
+from ding.policy import PPOPolicy
+from ding.model import VAC
+from ding.utils import set_pkg_seed
+
+from dizoo.gym_pybullet_drones.envs.gym_pybullet_drones_env import GymPybulletDronesEnv
+from dizoo.gym_pybullet_drones.config.takeoffaviary_onppo_config import takeoffaviary_ppo_config
+
+
+def main(cfg, seed=0, max_iterations=int(1e10)):
+ cfg = compile_config(
+ cfg,
+ BaseEnvManager,
+ PPOPolicy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator,
+ NaiveReplayBuffer,
+ save_cfg=True
+ )
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+
+ cfg.env['record'] = True
+ cfg.env['gui'] = True
+ cfg.env['print_debug_info'] = True
+ cfg.env['plot_observation'] = True
+
+ evaluator_env = BaseEnvManager(
+ env_fn=[lambda: GymPybulletDronesEnv(cfg.env) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ model = VAC(**cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
+ policy.eval_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cpu'))
+
+ tb_logger = SummaryWriter(os.path.join('./log/', 'serial'))
+ evaluator = InteractionSerialEvaluator(cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger)
+ evaluator.eval()
+
+
+if __name__ == "__main__":
+ main(takeoffaviary_ppo_config)
diff --git a/DI-engine/dizoo/gym_pybullet_drones/envs/__init__.py b/DI-engine/dizoo/gym_pybullet_drones/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a2fba6a467251e8e5498ff8123942abb7ac5927
--- /dev/null
+++ b/DI-engine/dizoo/gym_pybullet_drones/envs/__init__.py
@@ -0,0 +1 @@
+from .gym_pybullet_drones_env import GymPybulletDronesEnv
diff --git a/DI-engine/dizoo/gym_pybullet_drones/envs/gym_pybullet_drones_env.py b/DI-engine/dizoo/gym_pybullet_drones/envs/gym_pybullet_drones_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..8b1ca6fcceab1c6315041f894ea044cfe4c9d8e6
--- /dev/null
+++ b/DI-engine/dizoo/gym_pybullet_drones/envs/gym_pybullet_drones_env.py
@@ -0,0 +1,270 @@
+from typing import Optional, Callable
+import numpy as np
+import copy
+import gym
+from gym.spaces import Box
+import gym_pybullet_drones
+from gym_pybullet_drones.utils.enums import DroneModel, Physics
+from gym_pybullet_drones.envs.single_agent_rl.BaseSingleAgentAviary import ActionType, ObservationType
+from gym_pybullet_drones.utils.Logger import Logger
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.torch_utils import to_ndarray
+from ding.utils import ENV_REGISTRY
+
+from easydict import EasyDict
+
+
+def gym_pybullet_drones_observation_space(dim, minimum=-np.inf, maximum=np.inf, dtype=np.float32) -> Callable:
+ lower_bound = np.repeat(minimum, dim).astype(dtype)
+ upper_bound = np.repeat(maximum, dim).astype(dtype)
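+    # clamp the lower bound of the third component (the drone's z position) to zero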
+ lower_bound[2] = 0.0
+ return Box(lower_bound, upper_bound, dtype=dtype)
+
+
+def drones_action_dim(type_of_action) -> int:
+ if type_of_action in [ActionType.RPM, ActionType.DYN, ActionType.VEL]:
+ return 4
+ elif type_of_action == ActionType.PID:
+ return 3
+ elif type_of_action == ActionType.TUN:
+ return 6
+ elif type_of_action in [ActionType.ONE_D_DYN, ActionType.ONE_D_PID, ActionType.ONE_D_RPM]:
+ return 1
+ else:
+ raise ValueError('Invalid action type.')
+
+
+def gym_pybullet_drones_action_space(drone_num=1, minimum=-1, maximum=1, dtype=np.float32) -> Callable:
+
+ def _gym_pybullet_drones_action_space(type_of_action) -> Box:
+ dim = drones_action_dim(type_of_action)
+ return Box(
+ np.repeat(minimum, dim * drone_num).astype(dtype),
+ np.repeat(maximum, dim * drone_num).astype(dtype),
+ dtype=dtype
+ )
+
+ return _gym_pybullet_drones_action_space
+
+
+def gym_pybullet_drones_reward_space(minimum=-10000, maximum=0, dtype=np.float32) -> Callable:
+ return Box(np.repeat(minimum, 1).astype(dtype), np.repeat(maximum, 1).astype(dtype), dtype=dtype)
+
+
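+# pre-defined observation/action/reward space information for the supported single-drone tasks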
+gym_pybullet_drones_env_info = {
+ "takeoff-aviary-v0": {
+ "observation_space": gym_pybullet_drones_observation_space(12, minimum=-1, maximum=1),
+ "action_space": gym_pybullet_drones_action_space(drone_num=1, minimum=-1, maximum=1),
+ "reward_space": gym_pybullet_drones_reward_space()
+ },
+ "flythrugate-aviary-v0": {
+ "observation_space": gym_pybullet_drones_observation_space(12, minimum=-1, maximum=1),
+ "action_space": gym_pybullet_drones_action_space(drone_num=1, minimum=-1, maximum=1),
+ "reward_space": gym_pybullet_drones_reward_space()
+ },
+}
+
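+# map the "action_type" string in the config to the corresponding gym_pybullet_drones ActionType enum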
+action_type = {
+ "PID": ActionType.PID,
+ "DYN": ActionType.DYN,
+ "VEL": ActionType.VEL,
+ "RPM": ActionType.RPM,
+ "TUN": ActionType.TUN,
+ "ONE_D_DYN": ActionType.ONE_D_DYN,
+ "ONE_D_PID": ActionType.ONE_D_PID,
+ "ONE_D_RPM": ActionType.ONE_D_RPM,
+}
+
+
+@ENV_REGISTRY.register('gym_pybullet_drones')
+class GymPybulletDronesEnv(BaseEnv):
+ """
+    Gym-Pybullet-Drones environment for training and simulating UAV drones in the PyBullet physics engine.
+    The tasks are registered as standard gym environments.
+ url: 'https://github.com/utiasDSL/gym-pybullet-drones'
+ """
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ cfg = EasyDict(copy.deepcopy(cls.config))
+ cfg.cfg_type = cls.__name__ + 'Dict'
+ return cfg
+
+ config = {
+ 'num_drones': 1,
+ 'print_debug_info': False,
+ 'output_folder': "./results",
+ 'plot_observation': False,
+ 'freq': 240,
+ 'aggregate_phy_steps': 1,
+ 'gui': False,
+ 'record': False,
+ "action_type": "RPM",
+ }
+
+ def __init__(self, cfg: dict = {}) -> None:
+ self.raw_cfg = copy.deepcopy(cfg)
+ for k, v in self.default_config().items():
+ if k not in cfg:
+ cfg[k] = v
+
+ if cfg["num_drones"] == 1:
+ self.env_kwargs = {
+ 'drone_model': DroneModel.CF2X,
+ 'initial_xyzs': None,
+ 'initial_rpys': None,
+ 'physics': Physics.PYB,
+ 'freq': 240,
+ 'aggregate_phy_steps': 1,
+ 'gui': False,
+ 'record': False,
+ 'obs': ObservationType.KIN,
+ 'act': ActionType.RPM
+ }
+ else:
+ # TODO(zjow): develop envs that support multi drones.
+ self.env_kwargs = {
+ 'drone_model': DroneModel.CF2X,
+ 'num_drones': 2,
+ 'neighbourhood_radius': np.inf,
+ 'initial_xyzs': None,
+ 'initial_rpys': None,
+ 'physics': Physics.PYB,
+ 'freq': 240,
+ 'aggregate_phy_steps': 1,
+ 'gui': False,
+ 'record': False,
+ 'obs': ObservationType.KIN,
+ 'act': ActionType.RPM
+ }
+
+ self._cfg = cfg
+
+ for k, _ in self.env_kwargs.items():
+ if k in cfg:
+ self.env_kwargs[k] = cfg[k]
+
+ self.env_kwargs["act"] = action_type[cfg["action_type"]]
+ self.action_type = self.env_kwargs["act"]
+
+ self._env_id = cfg.env_id
+ self._init_flag = False
+ self._replay_path = None
+
+ self._observation_space = gym_pybullet_drones_env_info[cfg.env_id]["observation_space"]
+ self._action_space = gym_pybullet_drones_env_info[cfg.env_id]["action_space"](self.action_type)
+ self._action_dim = drones_action_dim(self.action_type) * self._cfg["num_drones"]
+ self._reward_space = gym_pybullet_drones_env_info[cfg.env_id]["reward_space"]
+
+ self.env_step_count = 0
+
+ def reset(self) -> np.ndarray:
+ if not self._init_flag:
+
+ self._env = gym.make(self._env_id, **self.env_kwargs)
+
+ if self._cfg["plot_observation"]:
+ self.observation_logger = Logger(
+ logging_freq_hz=int(self._env.SIM_FREQ / self._env.AGGR_PHY_STEPS),
+ num_drones=1,
+ output_folder=self._cfg["output_folder"]
+ )
+
+ self._init_flag = True
+
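+        # with dynamic_seed, derive a fresh random seed at every reset; otherwise reuse the fixed seed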
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+
+ self._eval_episode_return = 0
+ obs = self._env.reset()
+ obs = to_ndarray(obs).astype(np.float32)
+ self.env_step_count = 0
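+        # pack the observation plus zero placeholders into the state layout expected by the drone Logger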
+ if self._cfg["plot_observation"]:
+ self.observation_logger.log(
+ drone=0,
+ timestamp=self.env_step_count / self._env.SIM_FREQ,
+ state=np.hstack([obs[0:3], np.zeros(4), obs[3:15],
+ np.resize(np.zeros(self._action_dim), (4))]),
+ control=np.zeros(12)
+ )
+ if self._cfg["print_debug_info"]:
+ if self.env_step_count % self._env.SIM_FREQ == 0:
+ self._env.render()
+ self.env_step_count += 1
+
+ return obs
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def step(self, action: np.ndarray) -> BaseEnvTimestep:
+ # action = action.astype('float32')
+ obs, rew, done, info = self._env.step(action)
+ if self._cfg["plot_observation"]:
+ self.observation_logger.log(
+ drone=0,
+ timestamp=self.env_step_count / self._env.SIM_FREQ,
+ state=np.hstack([obs[0:3], np.zeros(4), obs[3:15],
+ np.resize(action, (4))]),
+ control=np.zeros(12)
+ )
+
+ if self._cfg["print_debug_info"]:
+ if self.env_step_count % self._env.SIM_FREQ == 0:
+ self._env.render()
+ self.env_step_count += 1
+ self._eval_episode_return += rew
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+ if self._cfg["print_debug_info"]:
+ self.plot_observation_curve()
+
+ obs = to_ndarray(obs).astype(np.float32)
+        rew = to_ndarray([rew]).astype(np.float32)  # wrapped into an array with shape (1,)
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ if replay_path is None:
+ replay_path = './video'
+ self._replay_path = replay_path
+
+ def random_action(self) -> np.ndarray:
+ return self.action_space.sample().astype(np.float32)
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ if not self._init_flag:
+ return self._observation_space
+ else:
+ return self._env.observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ if not self._init_flag:
+ return self._action_space
+ else:
+ return self._env.action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ def __repr__(self) -> str:
+ return "DI-engine gym_pybullet_drones Env: " + self._cfg["env_id"]
+
+ def plot_observation_curve(self) -> None:
+ if self._cfg["plot_observation"]:
+ self.observation_logger.plot()
+
+ def clone(self, caller: str) -> 'GymPybulletDronesEnv':
+ return GymPybulletDronesEnv(self.raw_cfg)
diff --git a/DI-engine/dizoo/gym_pybullet_drones/envs/test_ding_env.py b/DI-engine/dizoo/gym_pybullet_drones/envs/test_ding_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..eae113ef7958658b4ca2810ec485c94066f9b090
--- /dev/null
+++ b/DI-engine/dizoo/gym_pybullet_drones/envs/test_ding_env.py
@@ -0,0 +1,32 @@
+import pytest
+from easydict import EasyDict
+import gym_pybullet_drones
+
+from ding.envs import BaseEnv, BaseEnvTimestep
+from dizoo.gym_pybullet_drones.envs.gym_pybullet_drones_env import GymPybulletDronesEnv
+
+
+@pytest.mark.envtest
+class TestGymPybulletDronesEnv:
+
+ def test_naive(self):
+ cfg = {"env_id": "takeoff-aviary-v0"}
+ cfg = EasyDict(cfg)
+ env = GymPybulletDronesEnv(cfg)
+
+ env.reset()
+ done = False
+ while not done:
+ action = env.action_space.sample()
+ assert action.shape[0] == 4
+
+ for i in range(action.shape[0]):
+ assert action[i] >= env.action_space.low[i] and action[i] <= env.action_space.high[i]
+
+ obs, reward, done, info = env.step(action)
+
+ assert obs.shape[0] == 12
+ for i in range(obs.shape[0]):
+ assert obs[i] >= env.observation_space.low[i] and obs[i] <= env.observation_space.high[i]
+
+ assert reward >= env.reward_space.low and reward <= env.reward_space.high
diff --git a/DI-engine/dizoo/gym_pybullet_drones/envs/test_ori_env.py b/DI-engine/dizoo/gym_pybullet_drones/envs/test_ori_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..bff4538ee9e38395a2fd6bd1ebbf0e160cfa630a
--- /dev/null
+++ b/DI-engine/dizoo/gym_pybullet_drones/envs/test_ori_env.py
@@ -0,0 +1,27 @@
+import pytest
+import gym
+import numpy as np
+
+import gym_pybullet_drones
+
+
+@pytest.mark.envtest
+class TestGymPybulletDronesOriEnv:
+
+ def test_naive(self):
+ env = gym.make("takeoff-aviary-v0")
+ env.reset()
+ done = False
+ while not done:
+ action = env.action_space.sample()
+ assert action.shape[0] == 4
+
+ for i in range(action.shape[0]):
+ assert action[i] >= env.action_space.low[i] and action[i] <= env.action_space.high[i]
+
+ obs, reward, done, info = env.step(action)
+ assert obs.shape[0] == 12
+ for i in range(obs.shape[0]):
+ assert obs[i] >= env.observation_space.low[i] and obs[i] <= env.observation_space.high[i]
+
+ assert reward >= env.reward_space.low and reward <= env.reward_space.high
diff --git a/DI-engine/dizoo/gym_soccer/__init__.py b/DI-engine/dizoo/gym_soccer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/gym_soccer/config/gym_soccer_pdqn_config.py b/DI-engine/dizoo/gym_soccer/config/gym_soccer_pdqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc2e409184b6a86f9f5d65bb06bd4f06509cf9be
--- /dev/null
+++ b/DI-engine/dizoo/gym_soccer/config/gym_soccer_pdqn_config.py
@@ -0,0 +1,79 @@
+from easydict import EasyDict
+
+gym_soccer_pdqn_config = dict(
+ exp_name='gym_soccer_pdqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ # (bool) Scale output action into legal range [-1, 1].
+ act_scale=True,
+ env_id='Soccer-v0', # ['Soccer-v0', 'SoccerEmptyGoal-v0', 'SoccerAgainstKeeper-v0']
+ n_evaluator_episode=5,
+ stop_value=0.99,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ # (bool) Whether use Importance Sampling Weight to correct biased update. If True, priority must be True.
+        # (bool) Whether to use Importance Sampling Weight to correct biased update. If True, priority must be True.
+ discount_factor=0.99,
+ nstep=1,
+ model=dict(
+ obs_shape=10,
+ action_shape=dict(
+ action_type_shape=3,
+ action_args_shape=5,
+ ),
+ ),
+ learn=dict(
+ update_per_collect=500, # 10 ~ 500
+ batch_size=320,
+ learning_rate_dis=3e-4,
+ learning_rate_cont=3e-4,
+ target_theta=0.001,
+ update_circle=10,
+ ),
+ # collect_mode config
+ collect=dict(
+            # (int) Only one of [n_sample, n_episode] should be set
+ n_sample=3200,
+ # (int) Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ noise_sigma=0.1,
+ collector=dict(collect_print_freq=1000, ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, ), ),
+ # other config
+ other=dict(
+ # Epsilon greedy with decay.
+ eps=dict(
+ # (str) Decay type. Support ['exp', 'linear'].
+ type='exp',
+ start=1,
+ end=0.1,
+ # (int) Decay length(env step)
+ decay=int(1e5),
+ ),
+ replay_buffer=dict(replay_buffer_size=int(1e6), ),
+ ),
+ )
+)
+
+gym_soccer_pdqn_config = EasyDict(gym_soccer_pdqn_config)
+main_config = gym_soccer_pdqn_config
+
+gym_soccer_pdqn_create_config = dict(
+ env=dict(
+ type='gym_soccer',
+ import_names=['dizoo.gym_soccer.envs.gym_soccer_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='pdqn'),
+)
+gym_soccer_pdqn_create_config = EasyDict(gym_soccer_pdqn_create_config)
+create_config = gym_soccer_pdqn_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c gym_soccer_pdqn_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/gym_soccer/envs/README.md b/DI-engine/dizoo/gym_soccer/envs/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..8046296b3856087ab9f5075b9ae28a0d484b4050
--- /dev/null
+++ b/DI-engine/dizoo/gym_soccer/envs/README.md
@@ -0,0 +1,11 @@
+# How to replay a log
+
+1. Set the log path where episode logs will be stored with the following command:
+
+ `env.enable_save_replay('./game_log')`
+
+2. After running the game, you can see some log files in the game_log directory.
+
+3. Execute the following command to replay a recorded log file (`*.rcg`):
+
+ ` env.replay_log("game_log/20211019011053-base_left_0-vs-base_right_0.rcg")`
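+
+Putting the three steps together, here is a minimal sketch that mirrors `test_gym_soccer_env.py`; the `'port': 6000` entry is an assumption (any free server port works), since `reset()` reads `cfg.port`:
+
+```python
+from easydict import EasyDict
+from dizoo.gym_soccer.envs.gym_soccer_env import GymSoccerEnv
+
+env = GymSoccerEnv(EasyDict({'env_id': 'Soccer-v0', 'act_scale': True, 'port': 6000}))
+env.enable_save_replay('./game_log')  # step 1: set the log path before reset
+env.reset()
+done = False
+while not done:  # step 2: run one episode to produce *.rcg logs under ./game_log
+    timestep = env.step(env.get_random_action())
+    done = timestep.done
+env.close()
+# step 3: replay one of the recorded log files
+env.replay_log("game_log/20211019011053-base_left_0-vs-base_right_0.rcg")
+```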
\ No newline at end of file
diff --git a/DI-engine/dizoo/gym_soccer/envs/__init__.py b/DI-engine/dizoo/gym_soccer/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/gym_soccer/envs/gym_soccer_env.py b/DI-engine/dizoo/gym_soccer/envs/gym_soccer_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..b0e759b495e864da9ed3c4c5e2c9a0e070525ffa
--- /dev/null
+++ b/DI-engine/dizoo/gym_soccer/envs/gym_soccer_env.py
@@ -0,0 +1,165 @@
+import sys
+from typing import Any, List, Optional, Union
+
+import gym
+import gym_soccer
+import numpy as np
+from ding.envs import BaseEnv, BaseEnvInfo, BaseEnvTimestep
+from ding.envs.common.common_function import affine_transform
+from ding.envs.common.env_element import EnvElementInfo
+from ding.torch_utils import to_list, to_ndarray, to_tensor
+from ding.utils import ENV_REGISTRY
+from gym.utils import seeding
+import copy
+
+
+@ENV_REGISTRY.register('gym_soccer')
+class GymSoccerEnv(BaseEnv):
+ default_env_id = ['Soccer-v0', 'SoccerEmptyGoal-v0', 'SoccerAgainstKeeper-v0']
+
+ def __init__(self, cfg: dict = {}) -> None:
+ self._cfg = cfg
+ self._act_scale = cfg.act_scale
+ self._env_id = cfg.env_id
+ assert self._env_id in self.default_env_id
+ self._init_flag = False
+ self._replay_path = './game_log'
+
+    def reset(self) -> np.ndarray:
+ if not self._init_flag:
+ self._env = gym.make(self._env_id, replay_path=self._replay_path, port=self._cfg.port) # TODO
+ self._init_flag = True
+ self._eval_episode_return = 0
+ obs = self._env.reset()
+ obs = to_ndarray(obs).astype(np.float32)
+ return obs
+
+ def step(self, action: List) -> BaseEnvTimestep:
+ if self._act_scale:
+            # Each continuous action parameter is a tensor of size (1,);
+            # we index at [0] to fetch it as a scalar value
+ action[1][0] = affine_transform(action[1][0], min_val=0, max_val=100)
+ action[2][0] = affine_transform(action[2][0], min_val=-180, max_val=180)
+ action[3][0] = affine_transform(action[3][0], min_val=-180, max_val=180)
+ action[4][0] = affine_transform(action[4][0], min_val=0, max_val=100)
+ action[5][0] = affine_transform(action[5][0], min_val=-180, max_val=180)
+
+ obs, rew, done, info = self._env.step(action)
+ self._eval_episode_return += rew
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+ obs = to_ndarray(obs).astype(np.float32)
+        # reward is wrapped into a numpy array with shape (1,)
+ rew = to_ndarray([rew])
+ # '1' indicates the discrete action is associated with the continuous parameters
+ info['action_args_mask'] = np.array([[1, 1, 0, 0, 0], [0, 0, 1, 0, 0], [0, 0, 0, 1, 1]])
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def close(self) -> None:
+ self._init_flag = False
+
+ def get_random_action(self):
+ # discrete action type: 0, 1, 2
+ # continuous action_args:
+ # - power: [0, 100]
+ # - direction: [-180, 180]
+ # the action space is (6,), the first indicates discrete action and the remaining indicates continuous action
+        # discrete action 0 is associated with the first and second continuous parameters
+        # discrete action 1 is associated with the third continuous parameter
+        # discrete action 2 is associated with the fourth and fifth continuous parameters
+ return self._env.action_space.sample()
+
+ def info(self) -> BaseEnvInfo:
+ T = EnvElementInfo
+ return BaseEnvInfo(
+ agent_num=1,
+ obs_space=T(
+ (59, ),
+ {
+ # [min, max]
+ 'min': -1,
+ 'max': 1,
+ 'dtype': np.float32,
+ },
+ ),
+ act_space=T(
+ # the discrete action shape is (3,)
+ # however, the continuous action shape is (5,), which is not revealed in the info
+ (
+ 3,
+ ),
+ {
+ # [min, max)
+ 'min': 0,
+ 'max': 3,
+ 'dtype': int,
+ },
+ ),
+ rew_space=T(
+ (1, ),
+ {
+ # [min, max)
+ 'min': 0,
+ 'max': 2.0,
+ 'dtype': int,
+ },
+ ),
+ use_wrappers=None,
+ )
+
+ def render(self, close=False):
+ self._env.render(close)
+
+ def __repr__(self) -> str:
+ return "DI-engine gym soccer Env"
+
+ def replay_log(self, log_path):
+ self._env.replay_log(log_path)
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ if replay_path is None:
+ replay_path = './game_log'
+ self._replay_path = replay_path
+
+ def create_collector_env_cfg(cfg: dict) -> List[dict]:
+ """
+ Overview:
+            Return a list of collector env configs derived from the input config.
+ Arguments:
+ - cfg (:obj:`Dict`) Env config, same config where ``self.__init__()`` takes arguments from
+ Returns:
+ - List of ``cfg`` including all of the collector env's config
+ """
+ cfg_list = []
+ collector_env_num = cfg.pop('collector_env_num')
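+        # sample a distinct server port for every collector environment so the instances do not conflict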
+ port_pool = list(range(6000, 9999))
+ port_candidates = np.random.choice(port_pool, size=collector_env_num, replace=False)
+ for i in range(collector_env_num):
+ cfg_copy = copy.deepcopy(cfg)
+ cfg_copy.port = port_candidates[i]
+ cfg_list.append(cfg_copy)
+ return cfg_list
+
+ def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
+ """
+ Overview:
+            Return a list of evaluator env configs derived from the input config.
+ Arguments:
+ - cfg (:obj:`Dict`) Env config, same config where ``self.__init__()`` takes arguments from
+ Returns:
+ - List of ``cfg`` including all of the evaluator env's config
+ """
+ cfg_list = []
+ evaluator_env_num = cfg.pop('evaluator_env_num')
+ port_pool = list(range(6000, 9999))
+ port_candidates = np.random.choice(port_pool, size=evaluator_env_num, replace=False)
+ for i in range(evaluator_env_num):
+ cfg_copy = copy.deepcopy(cfg)
+ cfg_copy.port = port_candidates[i]
+ cfg_list.append(cfg_copy)
+ return cfg_list
diff --git a/DI-engine/dizoo/gym_soccer/envs/test_gym_soccer_env.py b/DI-engine/dizoo/gym_soccer/envs/test_gym_soccer_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..50bedd89acb0c74d7d55353fb77ed56b23bd6e93
--- /dev/null
+++ b/DI-engine/dizoo/gym_soccer/envs/test_gym_soccer_env.py
@@ -0,0 +1,34 @@
+import numpy as np
+import pytest
+from dizoo.gym_soccer.envs.gym_soccer_env import GymSoccerEnv
+from easydict import EasyDict
+
+
+@pytest.mark.envtest
+class TestGymSoccerEnv:
+
+ def test_naive(self):
+ env = GymSoccerEnv(EasyDict({'env_id': 'Soccer-v0', 'act_scale': True}))
+ # env.enable_save_replay('./video')
+ env.seed(25, dynamic_seed=False)
+ assert env._seed == 25
+ obs = env.reset()
+ assert obs.shape == (59, )
+ for i in range(1000):
+ random_action = env.get_random_action()
+ # print('random_action', random_action)
+ timestep = env.step(random_action)
+ # env.render()
+ assert isinstance(timestep.obs, np.ndarray)
+ assert isinstance(timestep.done, bool)
+ assert timestep.obs.shape == (59, )
+ # print(timestep.obs)
+ assert timestep.reward.shape == (1, )
+ assert timestep.info['action_args_mask'].shape == (3, 5)
+ if timestep.done:
+ print('reset env')
+ env.reset()
+ assert env._eval_episode_return == 0
+ print(env.info())
+ # env.replay_log("./video/20211019011053-base_left_0-vs-base_right_0.rcg")
+ env.close()
diff --git a/DI-engine/dizoo/image_classification/__init__.py b/DI-engine/dizoo/image_classification/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/image_classification/data/__init__.py b/DI-engine/dizoo/image_classification/data/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..88b324a4e9617afe0d3467f6c1fab3ecda75cf2a
--- /dev/null
+++ b/DI-engine/dizoo/image_classification/data/__init__.py
@@ -0,0 +1,2 @@
+from .dataset import ImageNetDataset
+from .sampler import DistributedSampler
diff --git a/DI-engine/dizoo/image_classification/data/dataset.py b/DI-engine/dizoo/image_classification/data/dataset.py
new file mode 100644
index 0000000000000000000000000000000000000000..9bfa458c4678a057344321aada449416f3e85eb0
--- /dev/null
+++ b/DI-engine/dizoo/image_classification/data/dataset.py
@@ -0,0 +1,139 @@
+from typing import Callable, Tuple
+import os
+import re
+import math
+from PIL import Image
+import numpy as np
+import torch
+import torch.utils.data as data
+from torchvision import transforms
+
+
+class ToNumpy:
+
+ def __call__(self, pil_img):
+ np_img = np.array(pil_img, dtype=np.uint8)
+ if np_img.ndim < 3:
+ np_img = np.expand_dims(np_img, axis=-1)
+ np_img = np.rollaxis(np_img, 2) # HWC to CHW
+ return np_img
+
+
+def _pil_interp(method):
+ if method == 'bicubic':
+ return Image.BICUBIC
+ elif method == 'lanczos':
+ return Image.LANCZOS
+ elif method == 'hamming':
+ return Image.HAMMING
+ else:
+ # default bilinear, do we want to allow nearest?
+ return Image.BILINEAR
+
+
+def natural_key(string_):
+ return [int(s) if s.isdigit() else s for s in re.split(r'(\d+)', string_.lower())]
+
+
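+# walk the dataset folder, collect image file paths and map their (leaf) directory names to integer class targets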
+def find_images_and_targets(folder, types=('.png', '.jpg', '.jpeg'), class_to_idx=None, leaf_name_only=True, sort=True):
+ labels = []
+ filenames = []
+ for root, subdirs, files in os.walk(folder, topdown=False, followlinks=True):
+ rel_path = os.path.relpath(root, folder) if (root != folder) else ''
+ label = os.path.basename(rel_path) if leaf_name_only else rel_path.replace(os.path.sep, '_')
+ for f in files:
+ base, ext = os.path.splitext(f)
+ if ext.lower() in types:
+ filenames.append(os.path.join(root, f))
+ labels.append(label)
+ if class_to_idx is None:
+ # building class index
+ unique_labels = set(labels)
+ sorted_labels = list(sorted(unique_labels, key=natural_key))
+ class_to_idx = {c: idx for idx, c in enumerate(sorted_labels)}
+ images_and_targets = [(f, class_to_idx[l]) for f, l in zip(filenames, labels) if l in class_to_idx]
+ if sort:
+ images_and_targets = sorted(images_and_targets, key=lambda k: natural_key(k[0]))
+ return images_and_targets, class_to_idx
+
+
+IMAGENET_DEFAULT_MEAN = (0.485, 0.456, 0.406)
+IMAGENET_DEFAULT_STD = (0.229, 0.224, 0.225)
+DEFAULT_CROP_PCT = 0.875
+
+
+def transforms_noaug_train(
+ img_size=224,
+ interpolation='bilinear',
+ use_prefetcher=False,
+ mean=IMAGENET_DEFAULT_MEAN,
+ std=IMAGENET_DEFAULT_STD,
+):
+ if interpolation == 'random':
+ # random interpolation not supported with no-aug
+ interpolation = 'bilinear'
+ tfl = [transforms.Resize(img_size, _pil_interp(interpolation)), transforms.CenterCrop(img_size)]
+ if use_prefetcher:
+ # prefetcher and collate will handle tensor conversion and norm
+ tfl += [ToNumpy()]
+ else:
+ tfl += [transforms.ToTensor(), transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std))]
+ return transforms.Compose(tfl)
+
+
+def transforms_imagenet_eval(
+ img_size=224,
+ crop_pct=None,
+ interpolation='bilinear',
+ use_prefetcher=False,
+ mean=IMAGENET_DEFAULT_MEAN,
+ std=IMAGENET_DEFAULT_STD
+):
+ crop_pct = crop_pct or DEFAULT_CROP_PCT
+
+ if isinstance(img_size, (tuple, list)):
+ assert len(img_size) == 2
+ if img_size[-1] == img_size[-2]:
+ # fall-back to older behaviour so Resize scales to shortest edge if target is square
+ scale_size = int(math.floor(img_size[0] / crop_pct))
+ else:
+ scale_size = tuple([int(x / crop_pct) for x in img_size])
+ else:
+ scale_size = int(math.floor(img_size / crop_pct))
+
+ tfl = [
+ transforms.Resize(scale_size, _pil_interp(interpolation)),
+ transforms.CenterCrop(img_size),
+ ]
+ if use_prefetcher:
+ # prefetcher and collate will handle tensor conversion and norm
+ tfl += [ToNumpy()]
+ else:
+ tfl += [transforms.ToTensor(), transforms.Normalize(mean=torch.tensor(mean), std=torch.tensor(std))]
+
+ return transforms.Compose(tfl)
+
+
+class ImageNetDataset(data.Dataset):
+
+ def __init__(self, root: str, is_training: bool, transform: Callable = None) -> None:
+ self.root = root
+ if transform is None:
+ if is_training:
+ transform = transforms_noaug_train()
+ else:
+ transform = transforms_imagenet_eval()
+ self.transform = transform
+ self.data, _ = find_images_and_targets(root)
+
+ def __len__(self) -> int:
+ return len(self.data)
+
+    def __getitem__(self, index: int) -> Tuple[torch.Tensor, torch.Tensor]:
+ img, target = self.data[index]
+ img = Image.open(img).convert('RGB')
+ if self.transform is not None:
+ img = self.transform(img)
+ if target is None:
+ target = torch.tensor(-1, dtype=torch.long)
+ return img, target
diff --git a/DI-engine/dizoo/image_classification/data/sampler.py b/DI-engine/dizoo/image_classification/data/sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..6e60004ccc5be12a27c988241042021b71329fe2
--- /dev/null
+++ b/DI-engine/dizoo/image_classification/data/sampler.py
@@ -0,0 +1,65 @@
+import math
+import torch
+from torch.utils.data import Sampler
+from ding.utils import get_rank, get_world_size
+
+
+class DistributedSampler(Sampler):
+ """Sampler that restricts data loading to a subset of the dataset.
+
+ It is especially useful in conjunction with
+ :class:`torch.nn.parallel.DistributedDataParallel`. In such case, each
+ process can pass a DistributedSampler instance as a DataLoader sampler,
+ and load a subset of the original dataset that is exclusive to it.
+
+ .. note::
+ Dataset is assumed to be of constant size.
+
+ Arguments:
+ dataset: Dataset used for sampling.
+ world_size (optional): Number of processes participating in
+ distributed training.
+ rank (optional): Rank of the current process within world_size.
+ """
+
+ def __init__(self, dataset, world_size=None, rank=None, round_up=True):
+ if world_size is None:
+ world_size = get_world_size()
+ if rank is None:
+ rank = get_rank()
+ self.dataset = dataset
+ self.world_size = world_size
+ self.rank = rank
+ self.round_up = round_up
+ self.epoch = 0
+
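+        # each rank draws ceil(len(dataset) / world_size) samples; round_up pads the index list so all ranks match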
+ self.num_samples = int(math.ceil(len(self.dataset) * 1.0 / self.world_size))
+ if self.round_up:
+ self.total_size = self.num_samples * self.world_size
+ else:
+ self.total_size = len(self.dataset)
+
+ def __iter__(self):
+ # deterministically shuffle based on epoch
+ g = torch.Generator()
+ g.manual_seed(self.epoch)
+ indices = list(torch.randperm(len(self.dataset), generator=g))
+
+ # add extra samples to make it evenly divisible
+ if self.round_up:
+ indices += indices[:(self.total_size - len(indices))]
+ assert len(indices) == self.total_size
+
+ # subsample
+ offset = self.num_samples * self.rank
+ indices = indices[offset:offset + self.num_samples]
+ if self.round_up or (not self.round_up and self.rank < self.world_size - 1):
+ assert len(indices) == self.num_samples
+
+ return iter(indices)
+
+ def __len__(self):
+ return self.num_samples
+
+ def set_epoch(self, epoch):
+ self.epoch = epoch
diff --git a/DI-engine/dizoo/image_classification/entry/imagenet_res18_config.py b/DI-engine/dizoo/image_classification/entry/imagenet_res18_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd4f473dd6838a17143dab1e4fe1ecfa94775e39
--- /dev/null
+++ b/DI-engine/dizoo/image_classification/entry/imagenet_res18_config.py
@@ -0,0 +1,34 @@
+from easydict import EasyDict
+
+imagenet_res18_config = dict(
+ exp_name='imagenet_res18',
+ policy=dict(
+ cuda=True,
+ multi_gpu=True,
+ learn=dict(
+ bp_update_sync=True,
+ train_epoch=200,
+ batch_size=32,
+ learning_rate=0.01,
+ decay_epoch=30,
+ decay_rate=0.1,
+ warmup_lr=1e-4,
+ warmup_epoch=3,
+ weight_decay=1e-4,
+ learner=dict(
+ log_show_freq=10,
+ hook=dict(
+                    log_show_after_iter=int(1e9),  # disabled; a user-defined log hook is used instead
+ save_ckpt_after_iter=1000,
+ )
+ )
+ ),
+ collect=dict(
+ learn_data_path='/mnt/lustre/share/images/train',
+ eval_data_path='/mnt/lustre/share/images/val',
+ ),
+ eval=dict(batch_size=32, evaluator=dict(eval_freq=1, stop_value=dict(loss=0.5, acc1=75.0, acc5=95.0))),
+ ),
+ env=dict(),
+)
+imagenet_res18_config = EasyDict(imagenet_res18_config)
diff --git a/DI-engine/dizoo/image_classification/entry/imagenet_res18_main.py b/DI-engine/dizoo/image_classification/entry/imagenet_res18_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..7e49736ab2930c5ecc9b4412b817f81c5ffa3095
--- /dev/null
+++ b/DI-engine/dizoo/image_classification/entry/imagenet_res18_main.py
@@ -0,0 +1,172 @@
+from typing import Union, Optional, Tuple, List
+import time
+import os
+import torch
+from tensorboardX import SummaryWriter
+from torch.utils.data import DataLoader
+
+from ding.worker import BaseLearner, LearnerHook, MetricSerialEvaluator, IMetric
+from ding.config import read_config, compile_config
+from ding.torch_utils import resnet18
+from ding.utils import set_pkg_seed, get_rank, dist_init
+from dizoo.image_classification.policy import ImageClassificationPolicy
+from dizoo.image_classification.data import ImageNetDataset, DistributedSampler
+from dizoo.image_classification.entry.imagenet_res18_config import imagenet_res18_config
+
+
+class ImageClsLogShowHook(LearnerHook):
+
+ def __init__(self, *args, freq: int = 1, **kwargs) -> None:
+ super().__init__(*args, **kwargs)
+ self._freq = freq
+
+ def __call__(self, engine: 'BaseLearner') -> None: # noqa
+ # Only show log for rank 0 learner
+ if engine.rank != 0:
+ for k in engine.log_buffer:
+ engine.log_buffer[k].clear()
+ return
+ # For 'scalar' type variables: log_buffer -> tick_monitor -> monitor_time.step
+ for k, v in engine.log_buffer['scalar'].items():
+ setattr(engine.monitor, k, v)
+ engine.monitor.time.step()
+
+ iters = engine.last_iter.val
+ if iters % self._freq == 0:
+ # For 'scalar' type variables: tick_monitor -> var_dict -> text_logger & tb_logger
+ var_dict = {}
+ log_vars = engine.policy.monitor_vars()
+ attr = 'avg'
+ for k in log_vars:
+ k_attr = k + '_' + attr
+ var_dict[k_attr] = getattr(engine.monitor, attr)[k]()
+ # user-defined variable
+ var_dict['data_time_val'] = engine.data_time
+ epoch_info = engine.epoch_info
+ var_dict['epoch_val'] = epoch_info[0]
+ engine.logger.info(
+ 'Epoch: {} [{:>4d}/{}]\t'
+ 'Loss: {:>6.4f}\t'
+ 'Data Time: {:.3f}\t'
+ 'Forward Time: {:.3f}\t'
+ 'Backward Time: {:.3f}\t'
+ 'GradSync Time: {:.3f}\t'
+ 'LR: {:.3e}'.format(
+ var_dict['epoch_val'], epoch_info[1], epoch_info[2], var_dict['total_loss_avg'],
+ var_dict['data_time_val'], var_dict['forward_time_avg'], var_dict['backward_time_avg'],
+ var_dict['sync_time_avg'], var_dict['cur_lr_avg']
+ )
+ )
+ for k, v in var_dict.items():
+ engine.tb_logger.add_scalar('{}/'.format(engine.instance_name) + k, v, iters)
+ # For 'histogram' type variables: log_buffer -> tb_var_dict -> tb_logger
+ tb_var_dict = {}
+ for k in engine.log_buffer['histogram']:
+ new_k = '{}/'.format(engine.instance_name) + k
+ tb_var_dict[new_k] = engine.log_buffer['histogram'][k]
+ for k, v in tb_var_dict.items():
+ engine.tb_logger.add_histogram(k, v, iters)
+ for k in engine.log_buffer:
+ engine.log_buffer[k].clear()
+
+
+class ImageClassificationMetric(IMetric):
+
+ def __init__(self) -> None:
+ self.loss = torch.nn.CrossEntropyLoss()
+
+ @staticmethod
+ def accuracy(inputs: torch.Tensor, label: torch.Tensor, topk: Tuple = (1, 5)) -> dict:
+ """Computes the accuracy over the k top predictions for the specified values of k"""
+ maxk = max(topk)
+ batch_size = label.size(0)
+ _, pred = inputs.topk(maxk, 1, True, True)
+ pred = pred.t()
+ correct = pred.eq(label.reshape(1, -1).expand_as(pred))
+ return {'acc{}'.format(k): correct[:k].reshape(-1).float().sum(0) * 100. / batch_size for k in topk}
+
+ def eval(self, inputs: torch.Tensor, label: torch.Tensor) -> dict:
+ """
+ Returns:
+ - eval_result (:obj:`dict`): {'loss': xxx, 'acc1': xxx, 'acc5': xxx}
+ """
+ loss = self.loss(inputs, label)
+ output = self.accuracy(inputs, label)
+ output['loss'] = loss
+ for k in output:
+ output[k] = output[k].item()
+ return output
+
+ def reduce_mean(self, inputs: List[dict]) -> dict:
+ L = len(inputs)
+ output = {}
+ for k in inputs[0].keys():
+ output[k] = sum([t[k] for t in inputs]) / L
+ return output
+
+ def gt(self, metric1: dict, metric2: dict) -> bool:
+ if metric2 is None:
+ return True
+ for k in metric1:
+ if metric1[k] < metric2[k]:
+ return False
+ return True
+
+
+def main(cfg: dict, seed: int) -> None:
+ cfg = compile_config(cfg, seed=seed, policy=ImageClassificationPolicy, evaluator=MetricSerialEvaluator)
+ if cfg.policy.multi_gpu:
+ rank, world_size = dist_init()
+ else:
+ rank, world_size = 0, 1
+
+ # Random seed
+ set_pkg_seed(cfg.seed + rank, use_cuda=cfg.policy.cuda)
+
+ model = resnet18()
+ policy = ImageClassificationPolicy(cfg.policy, model=model, enable_field=['learn', 'eval'])
+ learn_dataset = ImageNetDataset(cfg.policy.collect.learn_data_path, is_training=True)
+ eval_dataset = ImageNetDataset(cfg.policy.collect.eval_data_path, is_training=False)
+ if cfg.policy.multi_gpu:
+ learn_sampler = DistributedSampler(learn_dataset)
+ eval_sampler = DistributedSampler(eval_dataset)
+ else:
+ learn_sampler, eval_sampler = None, None
+ learn_dataloader = DataLoader(learn_dataset, cfg.policy.learn.batch_size, sampler=learn_sampler, num_workers=3)
+ eval_dataloader = DataLoader(eval_dataset, cfg.policy.eval.batch_size, sampler=eval_sampler, num_workers=2)
+
+ # Main components
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ log_show_hook = ImageClsLogShowHook(
+ name='image_cls_log_show_hook', priority=0, position='after_iter', freq=cfg.policy.learn.learner.log_show_freq
+ )
+ learner.register_hook(log_show_hook)
+ eval_metric = ImageClassificationMetric()
+ evaluator = MetricSerialEvaluator(
+ cfg.policy.eval.evaluator, [eval_dataloader, eval_metric], policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ # ==========
+ # Main loop
+ # ==========
+ learner.call_hook('before_run')
+ end = time.time()
+
+ for epoch in range(cfg.policy.learn.train_epoch):
+ # Evaluate policy performance
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, epoch, 0)
+ if stop:
+ break
+ for i, train_data in enumerate(learn_dataloader):
+ learner.data_time = time.time() - end
+ learner.epoch_info = (epoch, i, len(learn_dataloader))
+ learner.train(train_data)
+ end = time.time()
+ learner.policy.get_attribute('lr_scheduler').step()
+
+ learner.call_hook('after_run')
+
+
+if __name__ == "__main__":
+ main(imagenet_res18_config, 0)
diff --git a/DI-engine/dizoo/image_classification/policy/__init__.py b/DI-engine/dizoo/image_classification/policy/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..61a97f7b6ee3a3c9108845badfe5eac56ca5906a
--- /dev/null
+++ b/DI-engine/dizoo/image_classification/policy/__init__.py
@@ -0,0 +1 @@
+from .policy import ImageClassificationPolicy
diff --git a/DI-engine/dizoo/image_classification/policy/policy.py b/DI-engine/dizoo/image_classification/policy/policy.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6eb9f60baf53950fe8b4748b45a540e8ba9c016
--- /dev/null
+++ b/DI-engine/dizoo/image_classification/policy/policy.py
@@ -0,0 +1,100 @@
+import math
+import torch
+import torch.nn as nn
+from torch.optim import SGD
+from torch.optim.lr_scheduler import LambdaLR
+
+from ding.policy import Policy
+from ding.model import model_wrap
+from ding.torch_utils import to_device
+from ding.utils import EasyTimer
+
+
+class ImageClassificationPolicy(Policy):
+ config = dict(
+ type='image_classification',
+ on_policy=False,
+ )
+
+ def _init_learn(self):
+ self._optimizer = SGD(
+ self._model.parameters(),
+ lr=self._cfg.learn.learning_rate,
+ weight_decay=self._cfg.learn.weight_decay,
+ momentum=0.9
+ )
+ self._timer = EasyTimer(cuda=True)
+
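+        # LambdaLR multiplies the base learning rate by the returned factor: keep warmup_lr for the warmup epochs,
+        # then decay by decay_rate every decay_epoch epochs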
+ def lr_scheduler_fn(epoch):
+ if epoch <= self._cfg.learn.warmup_epoch:
+ return self._cfg.learn.warmup_lr / self._cfg.learn.learning_rate
+ else:
+ ratio = epoch // self._cfg.learn.decay_epoch
+ return math.pow(self._cfg.learn.decay_rate, ratio)
+
+ self._lr_scheduler = LambdaLR(self._optimizer, lr_scheduler_fn)
+ self._lr_scheduler.step()
+ self._learn_model = model_wrap(self._model, 'base')
+ self._learn_model.reset()
+
+ self._ce_loss = nn.CrossEntropyLoss()
+
+ def _forward_learn(self, data):
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._learn_model.train()
+
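+        # time the forward pass, the backward pass and the (optional) multi-GPU gradient sync separately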
+ with self._timer:
+ img, target = data
+ logit = self._learn_model.forward(img)
+ loss = self._ce_loss(logit, target)
+ forward_time = self._timer.value
+
+ with self._timer:
+ self._optimizer.zero_grad()
+ loss.backward()
+ backward_time = self._timer.value
+
+ with self._timer:
+ if self._cfg.multi_gpu:
+ self.sync_gradients(self._learn_model)
+ sync_time = self._timer.value
+ self._optimizer.step()
+
+ cur_lr = [param_group['lr'] for param_group in self._optimizer.param_groups]
+ cur_lr = sum(cur_lr) / len(cur_lr)
+ return {
+ 'cur_lr': cur_lr,
+ 'total_loss': loss.item(),
+ 'forward_time': forward_time,
+ 'backward_time': backward_time,
+ 'sync_time': sync_time,
+ }
+
+ def _monitor_vars_learn(self):
+ return ['cur_lr', 'total_loss', 'forward_time', 'backward_time', 'sync_time']
+
+ def _init_eval(self):
+ self._eval_model = model_wrap(self._model, 'base')
+
+ def _forward_eval(self, data):
+ if self._cuda:
+ data = to_device(data, self._device)
+ self._eval_model.eval()
+ with torch.no_grad():
+ output = self._eval_model.forward(data)
+ if self._cuda:
+ output = to_device(output, 'cpu')
+ return output
+
+ def _init_collect(self):
+ pass
+
+ def _forward_collect(self, data):
+ pass
+
+ def _process_transition(self):
+ pass
+
+ def _get_train_sample(self):
+ pass
diff --git a/DI-engine/dizoo/league_demo/__init__.py b/DI-engine/dizoo/league_demo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/league_demo/demo_league.py b/DI-engine/dizoo/league_demo/demo_league.py
new file mode 100644
index 0000000000000000000000000000000000000000..07465efa3473eae53bd6c57b468e2846ca25a009
--- /dev/null
+++ b/DI-engine/dizoo/league_demo/demo_league.py
@@ -0,0 +1,43 @@
+import os
+import shutil
+from easydict import EasyDict
+from ding.league import BaseLeague, ActivePlayer
+
+
+class DemoLeague(BaseLeague):
+
+ def __init__(self, cfg):
+ super(DemoLeague, self).__init__(cfg)
+ self.reset_checkpoint_path = os.path.join(self.path_policy, 'reset_ckpt.pth')
+
+ # override
+ def _get_job_info(self, player: ActivePlayer, eval_flag: bool = False) -> dict:
+ assert isinstance(player, ActivePlayer), player.__class__
+ player_job_info = EasyDict(player.get_job(eval_flag))
+ return {
+ 'agent_num': 2,
+ 'launch_player': player.player_id,
+ 'player_id': [player.player_id, player_job_info.opponent.player_id],
+ 'checkpoint_path': [player.checkpoint_path, player_job_info.opponent.checkpoint_path],
+ 'player_active_flag': [isinstance(p, ActivePlayer) for p in [player, player_job_info.opponent]],
+ }
+
+ # override
+ def _mutate_player(self, player: ActivePlayer):
+ for p in self.active_players:
+ result = p.mutate({'reset_checkpoint_path': self.reset_checkpoint_path})
+ if result is not None:
+ p.rating = self.metric_env.create_rating()
+ self.load_checkpoint(p.player_id, result) # load_checkpoint is set by the caller of league
+ self.save_checkpoint(result, p.checkpoint_path)
+
+ # override
+ def _update_player(self, player: ActivePlayer, player_info: dict) -> None:
+ assert isinstance(player, ActivePlayer)
+ if 'learner_step' in player_info:
+ player.total_agent_step = player_info['learner_step']
+
+ # override
+ @staticmethod
+ def save_checkpoint(src_checkpoint_path: str, dst_checkpoint_path: str) -> None:
+ shutil.copy(src_checkpoint_path, dst_checkpoint_path)
diff --git a/DI-engine/dizoo/league_demo/game_env.py b/DI-engine/dizoo/league_demo/game_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c97d12616602af497b0d3d297cac74c2d18bc11
--- /dev/null
+++ b/DI-engine/dizoo/league_demo/game_env.py
@@ -0,0 +1,91 @@
+from typing import List
+import numpy as np
+import gym
+
+from ding.envs import BaseEnv, BaseEnvTimestep
+
+
+class GameEnv(BaseEnv):
+
+ def __init__(self, game_type: str = 'prisoner_dilemma') -> None:
+ self.game_type = game_type
+ assert self.game_type in ['zero_sum', 'prisoner_dilemma']
+ if self.game_type == 'prisoner_dilemma':
+ self.optimal_policy = [0, 1]
+ elif self.game_type == 'zero_sum':
+ self.optimal_policy = [0.375, 0.625]
+ self._observation_space = None
+ self._action_space = None
+ self._reward_space = None
+
+ def seed(self, seed: int, dynamic_seed: bool = False) -> None:
+ # ignore seed
+ pass
+
+ def reset(self) -> np.ndarray:
+ return np.array([[0, 1], [1, 0]]).astype(np.float32) # trivial observation
+
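+    # the reward pairs below encode the payoff matrix of the selected two-player matrix game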
+ def step(self, actions: List[int]) -> BaseEnvTimestep:
+ if self.game_type == 'zero_sum':
+ if actions == [0, 0]:
+ rewards = 3, -3
+ results = "wins", "losses"
+ elif actions == [0, 1]:
+ rewards = -2, 2
+ results = "losses", "wins"
+ elif actions == [1, 0]:
+ rewards = -2, 2
+ results = "losses", "wins"
+ elif actions == [1, 1]:
+ rewards = 1, -1
+ results = "wins", "losses"
+ else:
+ raise RuntimeError("invalid actions: {}".format(actions))
+ elif self.game_type == 'prisoner_dilemma':
+ if actions == [0, 0]:
+ rewards = -1, -1
+ results = "draws", "draws"
+ elif actions == [0, 1]:
+ rewards = -20, 0
+ results = "losses", "wins"
+ elif actions == [1, 0]:
+ rewards = 0, -20
+ results = "wins", "losses"
+ elif actions == [1, 1]:
+ rewards = -10, -10
+ results = 'draws', 'draws'
+ else:
+ raise RuntimeError("invalid actions: {}".format(actions))
+ observations = np.array([[0, 1], [1, 0]]).astype(np.float32)
+ rewards = np.array(rewards).astype(np.float32)
+ rewards = rewards[..., np.newaxis]
+ dones = True, True
+ infos = {
+ 'result': results[0],
+ 'eval_episode_return': rewards[0]
+ }, {
+ 'result': results[1],
+ 'eval_episode_return': rewards[1]
+ }
+ return BaseEnvTimestep(observations, rewards, True, infos)
+
+ def close(self) -> None:
+ pass
+
+ def __repr__(self) -> str:
+ return "DI-engine League Demo GameEnv"
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ def random_action(self) -> List[int]:
+ return [np.random.randint(0, 2) for _ in range(2)]
diff --git a/DI-engine/dizoo/league_demo/league_demo_collector.py b/DI-engine/dizoo/league_demo/league_demo_collector.py
new file mode 100644
index 0000000000000000000000000000000000000000..ce7985a6dccc7ee5ec593d7d6cf4660f79957a0a
--- /dev/null
+++ b/DI-engine/dizoo/league_demo/league_demo_collector.py
@@ -0,0 +1,353 @@
+from typing import Optional, Any, List, Tuple
+from collections import namedtuple, deque
+from easydict import EasyDict
+import numpy as np
+import torch
+
+from ding.envs import BaseEnvManager
+from ding.utils import build_logger, EasyTimer, SERIAL_COLLECTOR_REGISTRY, dicts_to_lists
+from ding.torch_utils import to_tensor, to_ndarray
+from ding.worker.collector.base_serial_collector import ISerialCollector, CachePool, TrajBuffer, INF, \
+ to_tensor_transitions
+
+
+@SERIAL_COLLECTOR_REGISTRY.register('league_demo')
+class LeagueDemoCollector(ISerialCollector):
+ """
+ Overview:
+        League demo collector, derived from BattleEpisodeSerialCollector, with action probs visualization added.
+ Interfaces:
+ __init__, reset, reset_env, reset_policy, collect, close
+ Property:
+ envstep
+ """
+
+ config = dict(deepcopy_obs=False, transform_obs=False, collect_print_freq=100, get_train_sample=False)
+
+ def __init__(
+ self,
+ cfg: EasyDict,
+ env: BaseEnvManager = None,
+ policy: List[namedtuple] = None,
+ tb_logger: 'SummaryWriter' = None, # noqa
+ exp_name: Optional[str] = 'default_experiment',
+ instance_name: Optional[str] = 'collector'
+ ) -> None:
+ """
+ Overview:
+ Initialization method.
+ Arguments:
+ - cfg (:obj:`EasyDict`): Config dict
+ - env (:obj:`BaseEnvManager`): the subclass of vectorized env_manager(BaseEnvManager)
+ - policy (:obj:`List[namedtuple]`): the api namedtuple of collect_mode policy
+ - tb_logger (:obj:`SummaryWriter`): tensorboard handle
+ """
+ self._exp_name = exp_name
+ self._instance_name = instance_name
+ self._collect_print_freq = cfg.collect_print_freq
+ self._deepcopy_obs = cfg.deepcopy_obs
+ self._transform_obs = cfg.transform_obs
+ self._cfg = cfg
+ self._timer = EasyTimer()
+ self._end_flag = False
+
+ if tb_logger is not None:
+ self._logger, _ = build_logger(
+ path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name, need_tb=False
+ )
+ self._tb_logger = tb_logger
+ else:
+ self._logger, self._tb_logger = build_logger(
+ path='./{}/log/{}'.format(self._exp_name, self._instance_name), name=self._instance_name
+ )
+ self._traj_len = float("inf")
+ self.reset(policy, env)
+
+ def reset_env(self, _env: Optional[BaseEnvManager] = None) -> None:
+ """
+ Overview:
+ Reset the environment.
+ If _env is None, reset the old environment.
+ If _env is not None, replace the old environment in the collector with the new passed \
+ in environment and launch.
+ Arguments:
+ - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
+ env_manager(BaseEnvManager)
+ """
+ if _env is not None:
+ self._env = _env
+ self._env.launch()
+ self._env_num = self._env.env_num
+ else:
+ self._env.reset()
+
+ def reset_policy(self, _policy: Optional[List[namedtuple]] = None) -> None:
+ """
+ Overview:
+ Reset the policy.
+ If _policy is None, reset the old policy.
+ If _policy is not None, replace the old policy in the collector with the new passed in policy.
+ Arguments:
+ - policy (:obj:`Optional[List[namedtuple]]`): the api namedtuple of collect_mode policy
+ """
+ assert hasattr(self, '_env'), "please set env first"
+ if _policy is not None:
+            assert len(_policy) == 2, "1v1 episode collector needs 2 policies, but found {}".format(len(_policy))
+ self._policy = _policy
+ self._default_n_episode = _policy[0].get_attribute('cfg').collect.get('n_episode', None)
+ self._unroll_len = _policy[0].get_attribute('unroll_len')
+ self._on_policy = _policy[0].get_attribute('cfg').on_policy
+ self._traj_len = INF
+ self._logger.debug(
+ 'Set default n_episode mode(n_episode({}), env_num({}), traj_len({}))'.format(
+ self._default_n_episode, self._env_num, self._traj_len
+ )
+ )
+ for p in self._policy:
+ p.reset()
+
+ def reset(self, _policy: Optional[List[namedtuple]] = None, _env: Optional[BaseEnvManager] = None) -> None:
+ """
+ Overview:
+ Reset the environment and policy.
+ If _env is None, reset the old environment.
+ If _env is not None, replace the old environment in the collector with the new passed \
+ in environment and launch.
+ If _policy is None, reset the old policy.
+ If _policy is not None, replace the old policy in the collector with the new passed in policy.
+ Arguments:
+ - policy (:obj:`Optional[List[namedtuple]]`): the api namedtuple of collect_mode policy
+ - env (:obj:`Optional[BaseEnvManager]`): instance of the subclass of vectorized \
+ env_manager(BaseEnvManager)
+ """
+ if _env is not None:
+ self.reset_env(_env)
+ if _policy is not None:
+ self.reset_policy(_policy)
+
+ self._obs_pool = CachePool('obs', self._env_num, deepcopy=self._deepcopy_obs)
+ self._policy_output_pool = CachePool('policy_output', self._env_num)
+ # _traj_buffer is {env_id: {policy_id: TrajBuffer}}, is used to store traj_len pieces of transitions
+ self._traj_buffer = {
+ env_id: {policy_id: TrajBuffer(maxlen=self._traj_len)
+ for policy_id in range(2)}
+ for env_id in range(self._env_num)
+ }
+ self._env_info = {env_id: {'time': 0., 'step': 0} for env_id in range(self._env_num)}
+
+ self._episode_info = []
+ self._total_envstep_count = 0
+ self._total_episode_count = 0
+ self._total_duration = 0
+ self._last_train_iter = 0
+ self._end_flag = False
+
+ def _reset_stat(self, env_id: int) -> None:
+ """
+ Overview:
+            Reset the collector's state for the given env_id, including the traj_buffer, obs_pool, \
+            policy_output_pool and env_info. You can refer to base_serial_collector \
+            for more details.
+ Arguments:
+ - env_id (:obj:`int`): the id where we need to reset the collector's state
+ """
+ for i in range(2):
+ self._traj_buffer[env_id][i].clear()
+ self._obs_pool.reset(env_id)
+ self._policy_output_pool.reset(env_id)
+ self._env_info[env_id] = {'time': 0., 'step': 0}
+
+ @property
+ def envstep(self) -> int:
+ """
+ Overview:
+            Get the total envstep count.
+ Return:
+ - envstep (:obj:`int`): the total envstep count
+ """
+ return self._total_envstep_count
+
+ def close(self) -> None:
+ """
+ Overview:
+ Close the collector. If end_flag is False, close the environment, flush the tb_logger\
+ and close the tb_logger.
+ """
+ if self._end_flag:
+ return
+ self._end_flag = True
+ self._env.close()
+ self._tb_logger.flush()
+ self._tb_logger.close()
+
+ def __del__(self) -> None:
+ """
+ Overview:
+ Execute the close command and close the collector. __del__ is automatically called to \
+ destroy the collector instance when the collector finishes its work
+ """
+ self.close()
+
+ def collect(self,
+ n_episode: Optional[int] = None,
+ train_iter: int = 0,
+ policy_kwargs: Optional[dict] = None) -> Tuple[List[Any], List[Any]]:
+ """
+ Overview:
+            Collect `n_episode` episodes of data with policy_kwargs from a policy trained for `train_iter` iterations.
+ Arguments:
+ - n_episode (:obj:`int`): the number of collecting data episode
+ - train_iter (:obj:`int`): the number of training iteration
+ - policy_kwargs (:obj:`dict`): the keyword args for policy forward
+ Returns:
+            - return_data (:obj:`Tuple[List, List]`): A tuple of training data and episode info; \
+                the former is a list of collected episodes if ``get_train_sample`` is False, \
+                otherwise a list of train samples split by unroll_len.
+ """
+ if n_episode is None:
+ if self._default_n_episode is None:
+ raise RuntimeError("Please specify collect n_episode")
+ else:
+ n_episode = self._default_n_episode
+ assert n_episode >= self._env_num, "Please make sure n_episode >= env_num"
+ if policy_kwargs is None:
+ policy_kwargs = {}
+ collected_episode = 0
+ return_data = [[] for _ in range(2)]
+ return_info = [[] for _ in range(2)]
+ ready_env_id = set()
+ remain_episode = n_episode
+
+ while True:
+ with self._timer:
+ # Get current env obs.
+ obs = self._env.ready_obs
+ new_available_env_id = set(obs.keys()).difference(ready_env_id)
+ ready_env_id = ready_env_id.union(set(list(new_available_env_id)[:remain_episode]))
+ remain_episode -= min(len(new_available_env_id), remain_episode)
+ obs = {env_id: obs[env_id] for env_id in ready_env_id}
+ # Policy forward.
+ self._obs_pool.update(obs)
+ if self._transform_obs:
+ obs = to_tensor(obs, dtype=torch.float32)
+ obs = dicts_to_lists(obs)
+ policy_output = [p.forward(obs[i], **policy_kwargs) for i, p in enumerate(self._policy)]
+ self._policy_output_pool.update(policy_output)
+ # Interact with env.
+ actions = {}
+ for env_id in ready_env_id:
+ actions[env_id] = []
+ for output in policy_output:
+ actions[env_id].append(output[env_id]['action'])
+ actions = to_ndarray(actions)
+                # temporarily computed here for action-probability visualization
+ probs0 = torch.softmax(torch.stack([o['logit'] for o in policy_output[0].values()], 0), 1).mean(0)
+ probs1 = torch.softmax(torch.stack([o['logit'] for o in policy_output[1].values()], 0), 1).mean(0)
+ timesteps = self._env.step(actions)
+
+ # TODO(nyz) this duration may be inaccurate in async env
+ interaction_duration = self._timer.value / len(timesteps)
+
+ # TODO(nyz) vectorize this for loop
+ for env_id, timestep in timesteps.items():
+ self._env_info[env_id]['step'] += 1
+ self._total_envstep_count += 1
+ with self._timer:
+ for policy_id, policy in enumerate(self._policy):
+ policy_timestep_data = [d[policy_id] if not isinstance(d, bool) else d for d in timestep]
+ policy_timestep = type(timestep)(*policy_timestep_data)
+ transition = self._policy[policy_id].process_transition(
+ self._obs_pool[env_id][policy_id], self._policy_output_pool[env_id][policy_id],
+ policy_timestep
+ )
+ transition['collect_iter'] = train_iter
+ self._traj_buffer[env_id][policy_id].append(transition)
+ # prepare data
+ if timestep.done:
+ transitions = to_tensor_transitions(self._traj_buffer[env_id][policy_id])
+ if self._cfg.get_train_sample:
+ train_sample = self._policy[policy_id].get_train_sample(transitions)
+ return_data[policy_id].extend(train_sample)
+ else:
+ return_data[policy_id].append(transitions)
+ self._traj_buffer[env_id][policy_id].clear()
+
+ self._env_info[env_id]['time'] += self._timer.value + interaction_duration
+
+ # If env is done, record episode info and reset
+ if timestep.done:
+ self._total_episode_count += 1
+ info = {
+ 'reward0': timestep.info[0]['eval_episode_return'],
+ 'reward1': timestep.info[1]['eval_episode_return'],
+ 'time': self._env_info[env_id]['time'],
+ 'step': self._env_info[env_id]['step'],
+ 'probs0': probs0,
+ 'probs1': probs1,
+ }
+ collected_episode += 1
+ self._episode_info.append(info)
+ for i, p in enumerate(self._policy):
+ p.reset([env_id])
+ self._reset_stat(env_id)
+ ready_env_id.remove(env_id)
+ for policy_id in range(2):
+ return_info[policy_id].append(timestep.info[policy_id])
+ if collected_episode >= n_episode:
+ break
+ # log
+ self._output_log(train_iter)
+ return return_data, return_info
+
+ def _output_log(self, train_iter: int) -> None:
+ """
+ Overview:
+ Print the output log information. You can refer to Docs/Best Practice/How to understand\
+ training generated folders/Serial mode/log/collector for more details.
+ Arguments:
+ - train_iter (:obj:`int`): the number of training iteration.
+ """
+ if (train_iter - self._last_train_iter) >= self._collect_print_freq and len(self._episode_info) > 0:
+ self._last_train_iter = train_iter
+ episode_count = len(self._episode_info)
+ envstep_count = sum([d['step'] for d in self._episode_info])
+ duration = sum([d['time'] for d in self._episode_info])
+ episode_return0 = [d['reward0'] for d in self._episode_info]
+ episode_return1 = [d['reward1'] for d in self._episode_info]
+ probs0 = [d['probs0'] for d in self._episode_info]
+ probs1 = [d['probs1'] for d in self._episode_info]
+ self._total_duration += duration
+ info = {
+ 'episode_count': episode_count,
+ 'envstep_count': envstep_count,
+ 'avg_envstep_per_episode': envstep_count / episode_count,
+ 'avg_envstep_per_sec': envstep_count / duration,
+ 'avg_episode_per_sec': episode_count / duration,
+ 'collect_time': duration,
+ 'reward0_mean': np.mean(episode_return0),
+ 'reward0_std': np.std(episode_return0),
+ 'reward0_max': np.max(episode_return0),
+ 'reward0_min': np.min(episode_return0),
+ 'reward1_mean': np.mean(episode_return1),
+ 'reward1_std': np.std(episode_return1),
+ 'reward1_max': np.max(episode_return1),
+ 'reward1_min': np.min(episode_return1),
+ 'total_envstep_count': self._total_envstep_count,
+ 'total_episode_count': self._total_episode_count,
+ 'total_duration': self._total_duration,
+ }
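+ # Mean per-action selection probabilities of both policies, logged to monitor how the
+ # action distributions evolve during league training.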
+ info.update(
+ {
+ 'probs0_select_action0': sum([p[0] for p in probs0]) / len(probs0),
+ 'probs0_select_action1': sum([p[1] for p in probs0]) / len(probs0),
+ 'probs1_select_action0': sum([p[0] for p in probs1]) / len(probs1),
+ 'probs1_select_action1': sum([p[1] for p in probs1]) / len(probs1),
+ }
+ )
+ self._episode_info.clear()
+ self._logger.info("collect end:\n{}".format('\n'.join(['{}: {}'.format(k, v) for k, v in info.items()])))
+ for k, v in info.items():
+ self._tb_logger.add_scalar('{}_iter/'.format(self._instance_name) + k, v, train_iter)
+ if k in ['total_envstep_count']:
+ continue
+ self._tb_logger.add_scalar('{}_step/'.format(self._instance_name) + k, v, self._total_envstep_count)
diff --git a/DI-engine/dizoo/league_demo/league_demo_ppo_config.py b/DI-engine/dizoo/league_demo/league_demo_ppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc9dfb6a2a53b965aeb1c7efff058272f8e68c58
--- /dev/null
+++ b/DI-engine/dizoo/league_demo/league_demo_ppo_config.py
@@ -0,0 +1,83 @@
+from easydict import EasyDict
+from torch.nn.modules.activation import Threshold
+
+league_demo_ppo_config = dict(
+ exp_name="league_demo_ppo",
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=10,
+ n_evaluator_episode=100,
+ env_type='prisoner_dilemma', # ['zero_sum', 'prisoner_dilemma']
+ stop_value=[-10.1, -5.05], # prisoner_dilemma
+ ),
+ policy=dict(
+ cuda=False,
+ action_space='discrete',
+ model=dict(
+ obs_shape=2,
+ action_shape=2,
+ action_space='discrete',
+ encoder_hidden_size_list=[32, 32],
+ critic_head_hidden_size=32,
+ actor_head_hidden_size=32,
+ share_encoder=False,
+ ),
+ learn=dict(
+ update_per_collect=3,
+ batch_size=32,
+ learning_rate=0.00001,
+ entropy_weight=0.0,
+ learner=dict(log_policy=False),
+ ),
+ collect=dict(
+ n_episode=128, unroll_len=1, discount_factor=1.0, gae_lambda=1.0, collector=dict(get_train_sample=True, )
+ ),
+ other=dict(
+ league=dict(
+ player_category=['default'],
+ path_policy="league_demo_ppo/policy",
+ active_players=dict(
+ main_player=1,
+ main_exploiter=1,
+ league_exploiter=1,
+ ),
+ main_player=dict(
+ one_phase_step=200,
+ branch_probs=dict(
+ pfsp=0.5,
+ sp=0.5,
+ ),
+ strong_win_rate=0.7,
+ ),
+ main_exploiter=dict(
+ one_phase_step=200,
+ branch_probs=dict(main_players=1.0, ),
+ strong_win_rate=0.7,
+ min_valid_win_rate=0.3,
+ ),
+ league_exploiter=dict(
+ one_phase_step=200,
+ branch_probs=dict(pfsp=1.0, ),
+ strong_win_rate=0.7,
+ mutate_prob=0.5,
+ ),
+ use_pretrain=False,
+ use_pretrain_init_historical=False,
+ payoff=dict(
+ type='battle',
+ decay=0.99,
+ min_win_rate_games=8,
+ ),
+ metric=dict(
+ mu=0,
+ sigma=25 / 3,
+ beta=25 / 3 / 2,
+ tau=0.0,
+ draw_probability=0.02,
+ ),
+ ),
+ ),
+ ),
+)
+league_demo_ppo_config = EasyDict(league_demo_ppo_config)
+# This config file can be executed by `dizoo/league_demo/league_demo_ppo_main.py`
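+# A minimal launch sketch (it mirrors the `__main__` entry of that script):
+#   from dizoo.league_demo.league_demo_ppo_main import main
+#   main(league_demo_ppo_config)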
diff --git a/DI-engine/dizoo/league_demo/league_demo_ppo_main.py b/DI-engine/dizoo/league_demo/league_demo_ppo_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..dffdd1ded7b040a48852ad508383ab54e5b75d1e
--- /dev/null
+++ b/DI-engine/dizoo/league_demo/league_demo_ppo_main.py
@@ -0,0 +1,247 @@
+import os
+import copy
+import gym
+import numpy as np
+import torch
+from tensorboardX import SummaryWriter
+from easydict import EasyDict
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, BattleInteractionSerialEvaluator, NaiveReplayBuffer
+from ding.envs import BaseEnvManager, DingEnvWrapper
+from ding.policy import PPOPolicy
+from ding.model import VAC
+from ding.utils import set_pkg_seed, Scheduler, deep_merge_dicts
+from dizoo.league_demo.game_env import GameEnv
+from dizoo.league_demo.demo_league import DemoLeague
+from dizoo.league_demo.league_demo_collector import LeagueDemoCollector
+from dizoo.league_demo.league_demo_ppo_config import league_demo_ppo_config
+
+
+class EvalPolicy1:
+
+ def __init__(self, optimal_policy: list) -> None:
+ assert len(optimal_policy) == 2
+ self.optimal_policy = optimal_policy
+
+ def forward(self, data: dict) -> dict:
+ return {
+ env_id: {
+ 'action': torch.from_numpy(np.random.choice([0, 1], p=self.optimal_policy, size=(1, )))
+ }
+ for env_id in data.keys()
+ }
+
+ def reset(self, data_id: list = []) -> None:
+ pass
+
+
+class EvalPolicy2:
+
+ def forward(self, data: dict) -> dict:
+ return {
+ env_id: {
+ 'action': torch.from_numpy(np.random.choice([0, 1], p=[0.5, 0.5], size=(1, )))
+ }
+ for env_id in data.keys()
+ }
+
+ def reset(self, data_id: list = []) -> None:
+ pass
+
+
+def main(cfg, seed=0, max_train_iter=int(1e8), max_env_step=int(1e8)):
+ cfg = compile_config(
+ cfg,
+ BaseEnvManager,
+ PPOPolicy,
+ BaseLearner,
+ LeagueDemoCollector,
+ BattleInteractionSerialEvaluator,
+ NaiveReplayBuffer,
+ save_cfg=True
+ )
+ env_type = cfg.env.env_type
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+ evaluator_env1 = BaseEnvManager(
+ env_fn=[lambda: GameEnv(env_type) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env2 = BaseEnvManager(
+ env_fn=[lambda: GameEnv(env_type) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env3 = BaseEnvManager(
+ env_fn=[lambda: GameEnv(env_type) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ evaluator_env1.seed(seed, dynamic_seed=False)
+ evaluator_env2.seed(seed, dynamic_seed=False)
+ evaluator_env3.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ league = DemoLeague(cfg.policy.other.league)
+ eval_policy1 = EvalPolicy1(evaluator_env1._env_ref.optimal_policy)
+ eval_policy2 = EvalPolicy2()
+ policies = {}
+ learners = {}
+ collectors = {}
+
+ for player_id in league.active_players_ids:
+ # By default, all players use the same model architecture (with different initial weights)
+ model = VAC(**cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
+ policies[player_id] = policy
+ collector_env = BaseEnvManager(
+ env_fn=[lambda: GameEnv(env_type) for _ in range(collector_env_num)], cfg=cfg.env.manager
+ )
+ collector_env.seed(seed)
+
+ learners[player_id] = BaseLearner(
+ cfg.policy.learn.learner,
+ policy.learn_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name,
+ instance_name=player_id + '_learner'
+ )
+ collectors[player_id] = LeagueDemoCollector(
+ cfg.policy.collect.collector,
+ collector_env,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name,
+ instance_name=player_id + '_collector',
+ )
+
+ model = VAC(**cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
+ policies['historical'] = policy
+ # use initial policy as another eval_policy
+ eval_policy3 = PPOPolicy(cfg.policy, model=copy.deepcopy(model)).collect_mode
+
+ main_key = [k for k in learners.keys() if k.startswith('main_player')][0]
+ main_player = league.get_player_by_id(main_key)
+ main_learner = learners[main_key]
+ main_collector = collectors[main_key]
+ # collect_mode PPO uses multinomial sampling to select actions
+ evaluator1_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
+ evaluator1_cfg.stop_value = cfg.env.stop_value[0]
+ evaluator1 = BattleInteractionSerialEvaluator(
+ evaluator1_cfg,
+ evaluator_env1, [policies[main_key].collect_mode, eval_policy1],
+ tb_logger,
+ exp_name=cfg.exp_name,
+ instance_name='fixed_evaluator'
+ )
+ evaluator2_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
+ evaluator2_cfg.stop_value = cfg.env.stop_value[1]
+ evaluator2 = BattleInteractionSerialEvaluator(
+ evaluator2_cfg,
+ evaluator_env2, [policies[main_key].collect_mode, eval_policy2],
+ tb_logger,
+ exp_name=cfg.exp_name,
+ instance_name='uniform_evaluator'
+ )
+ evaluator3_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
+ evaluator3_cfg.stop_value = 99999999 # stop_value of evaluator3 is a placeholder
+ evaluator3 = BattleInteractionSerialEvaluator(
+ evaluator3_cfg,
+ evaluator_env3, [policies[main_key].collect_mode, eval_policy3],
+ tb_logger,
+ exp_name=cfg.exp_name,
+ instance_name='init_evaluator'
+ )
+
+ def load_checkpoint_fn(player_id: str, ckpt_path: str):
+ state_dict = torch.load(ckpt_path)
+ policies[player_id].learn_mode.load_state_dict(state_dict)
+
+ torch.save(policies['historical'].learn_mode.state_dict(), league.reset_checkpoint_path)
+ league.load_checkpoint = load_checkpoint_fn
+ # snapshot the initial player as the first historical player
+ for player_id, player_ckpt_path in zip(league.active_players_ids, league.active_players_ckpts):
+ torch.save(policies[player_id].collect_mode.state_dict(), player_ckpt_path)
+ league.judge_snapshot(player_id, force=True)
+ init_main_player_rating = league.metric_env.create_rating(mu=0)
+
+ count = 0
+ while True:
+ if evaluator1.should_eval(main_learner.train_iter):
+ stop_flag1, episode_info = evaluator1.eval(
+ main_learner.save_checkpoint, main_learner.train_iter, main_collector.envstep
+ )
+ win_loss_result = [e['result'] for e in episode_info[0]]
+ # set the fixed NE policy's TrueSkill (exposure) to 10
+ main_player.rating = league.metric_env.rate_1vsC(
+ main_player.rating, league.metric_env.create_rating(mu=10, sigma=1e-8), win_loss_result
+ )
+
+ if evaluator2.should_eval(main_learner.train_iter):
+ stop_flag2, episode_info = evaluator2.eval(
+ main_learner.save_checkpoint, main_learner.train_iter, main_collector.envstep
+ )
+ win_loss_result = [e['result'] for e in episode_info[0]]
+ # set the random (uniform) policy's TrueSkill (exposure) to 0
+ main_player.rating = league.metric_env.rate_1vsC(
+ main_player.rating, league.metric_env.create_rating(mu=0, sigma=1e-8), win_loss_result
+ )
+ if evaluator3.should_eval(main_learner.train_iter):
+ _, episode_info = evaluator3.eval(
+ main_learner.save_checkpoint, main_learner.train_iter, main_collector.envstep
+ )
+ win_loss_result = [e['result'] for e in episode_info[0]]
+ # use init main player as another evaluator metric
+ main_player.rating, init_main_player_rating = league.metric_env.rate_1vs1(
+ main_player.rating, init_main_player_rating, win_loss_result
+ )
+ tb_logger.add_scalar(
+ 'league/init_main_player_trueskill', init_main_player_rating.exposure, main_collector.envstep
+ )
+ if stop_flag1 and stop_flag2:
+ break
+
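+ # One league iteration per active player: fetch a job from the league, load the assigned opponent
+ # (a frozen historical snapshot or another active player), collect head-to-head episodes, train PPO
+ # on the launch player's data (episode reward used directly as the advantage), then save a checkpoint
+ # and report the job result back to the league.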
+ for player_id, player_ckpt_path in zip(league.active_players_ids, league.active_players_ckpts):
+ tb_logger.add_scalar(
+ 'league/{}_trueskill'.format(player_id),
+ league.get_player_by_id(player_id).rating.exposure, main_collector.envstep
+ )
+ collector, learner = collectors[player_id], learners[player_id]
+ job = league.get_job_info(player_id)
+ opponent_player_id = job['player_id'][1]
+ # print('job player: {}'.format(job['player_id']))
+ if 'historical' in opponent_player_id:
+ opponent_policy = policies['historical'].collect_mode
+ opponent_path = job['checkpoint_path'][1]
+ opponent_policy.load_state_dict(torch.load(opponent_path, map_location='cpu'))
+ else:
+ opponent_policy = policies[opponent_player_id].collect_mode
+ collector.reset_policy([policies[player_id].collect_mode, opponent_policy])
+ train_data, episode_info = collector.collect(train_iter=learner.train_iter)
+ train_data, episode_info = train_data[0], episode_info[0] # only use launch player data for training
+ for d in train_data:
+ d['adv'] = d['reward']
+
+ for i in range(cfg.policy.learn.update_per_collect):
+ learner.train(train_data, collector.envstep)
+ torch.save(learner.policy.state_dict(), player_ckpt_path)
+
+ player_info = learner.learn_info
+ player_info['player_id'] = player_id
+ league.update_active_player(player_info)
+ league.judge_snapshot(player_id)
+ # set eval_flag=True to enable trueskill update
+ job_finish_info = {
+ 'eval_flag': True,
+ 'launch_player': job['launch_player'],
+ 'player_id': job['player_id'],
+ 'result': [e['result'] for e in episode_info],
+ }
+ league.finish_job(job_finish_info)
+
+ if main_collector.envstep >= max_env_step or main_learner.train_iter >= max_train_iter:
+ break
+ if count % 100 == 0:
+ print(repr(league.payoff))
+ count += 1
+
+
+if __name__ == "__main__":
+ main(league_demo_ppo_config)
diff --git a/DI-engine/dizoo/league_demo/selfplay_demo_ppo_config.py b/DI-engine/dizoo/league_demo/selfplay_demo_ppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..36a53d9f63c564e047503b78ac73011c20137c93
--- /dev/null
+++ b/DI-engine/dizoo/league_demo/selfplay_demo_ppo_config.py
@@ -0,0 +1,37 @@
+from easydict import EasyDict
+
+selfplay_demo_ppo_config = dict(
+ exp_name="selfplay_demo_ppo",
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=10,
+ n_evaluator_episode=100,
+ env_type='prisoner_dilemma', # ['zero_sum', 'prisoner_dilemma']
+ stop_value=[-10.1, -5.05], # prisoner_dilemma
+ manager=dict(shared_memory=False, ),
+ ),
+ policy=dict(
+ cuda=False,
+ action_space='discrete',
+ model=dict(
+ obs_shape=2,
+ action_shape=2,
+ action_space='discrete',
+ encoder_hidden_size_list=[32, 32],
+ critic_head_hidden_size=32,
+ actor_head_hidden_size=32,
+ share_encoder=False,
+ ),
+ learn=dict(
+ update_per_collect=3,
+ batch_size=32,
+ learning_rate=0.00001,
+ entropy_weight=0.0,
+ ),
+ collect=dict(
+ n_episode=128, unroll_len=1, discount_factor=1.0, gae_lambda=1.0, collector=dict(get_train_sample=True, )
+ ),
+ ),
+)
+selfplay_demo_ppo_config = EasyDict(selfplay_demo_ppo_config)
+# This config file can be executed by `dizoo/league_demo/selfplay_demo_ppo_main.py`
diff --git a/DI-engine/dizoo/league_demo/selfplay_demo_ppo_main.py b/DI-engine/dizoo/league_demo/selfplay_demo_ppo_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..e6aed59f397ac1bc3a5f23ce781ce98b8c9dfae6
--- /dev/null
+++ b/DI-engine/dizoo/league_demo/selfplay_demo_ppo_main.py
@@ -0,0 +1,129 @@
+import os
+import gym
+import numpy as np
+import copy
+import torch
+from tensorboardX import SummaryWriter
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, BattleInteractionSerialEvaluator, NaiveReplayBuffer
+from ding.envs import BaseEnvManager, DingEnvWrapper
+from ding.policy import PPOPolicy
+from ding.model import VAC
+from ding.utils import set_pkg_seed
+from dizoo.league_demo.game_env import GameEnv
+from dizoo.league_demo.league_demo_collector import LeagueDemoCollector
+from dizoo.league_demo.selfplay_demo_ppo_config import selfplay_demo_ppo_config
+
+
+class EvalPolicy1:
+
+ def forward(self, data: dict) -> dict:
+ return {env_id: {'action': torch.zeros(1)} for env_id in data.keys()}
+
+ def reset(self, data_id: list = []) -> None:
+ pass
+
+
+class EvalPolicy2:
+
+ def forward(self, data: dict) -> dict:
+ return {
+ env_id: {
+ 'action': torch.from_numpy(np.random.choice([0, 1], p=[0.5, 0.5], size=(1, )))
+ }
+ for env_id in data.keys()
+ }
+
+ def reset(self, data_id: list = []) -> None:
+ pass
+
+
+def main(cfg, seed=0, max_train_iter=int(1e8), max_env_step=int(1e8)):
+ cfg = compile_config(
+ cfg,
+ BaseEnvManager,
+ PPOPolicy,
+ BaseLearner,
+ LeagueDemoCollector,
+ BattleInteractionSerialEvaluator,
+ NaiveReplayBuffer,
+ save_cfg=True
+ )
+ env_type = cfg.env.env_type
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+ collector_env = BaseEnvManager(
+ env_fn=[lambda: GameEnv(env_type) for _ in range(collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env1 = BaseEnvManager(
+ env_fn=[lambda: GameEnv(env_type) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env2 = BaseEnvManager(
+ env_fn=[lambda: GameEnv(env_type) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ collector_env.seed(seed)
+ evaluator_env1.seed(seed, dynamic_seed=False)
+ evaluator_env2.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ model1 = VAC(**cfg.policy.model)
+ policy1 = PPOPolicy(cfg.policy, model=model1)
+ model2 = VAC(**cfg.policy.model)
+ policy2 = PPOPolicy(cfg.policy, model=model2)
+ eval_policy1 = EvalPolicy1()
+ eval_policy2 = EvalPolicy2()
+
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner1 = BaseLearner(
+ cfg.policy.learn.learner, policy1.learn_mode, tb_logger, exp_name=cfg.exp_name, instance_name='learner1'
+ )
+ learner2 = BaseLearner(
+ cfg.policy.learn.learner, policy2.learn_mode, tb_logger, exp_name=cfg.exp_name, instance_name='learner2'
+ )
+ collector = LeagueDemoCollector(
+ cfg.policy.collect.collector,
+ collector_env, [policy1.collect_mode, policy2.collect_mode],
+ tb_logger,
+ exp_name=cfg.exp_name
+ )
+ # collect_mode PPO uses multinomial sampling to select actions
+ evaluator1_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
+ evaluator1_cfg.stop_value = cfg.env.stop_value[0]
+ evaluator1 = BattleInteractionSerialEvaluator(
+ evaluator1_cfg,
+ evaluator_env1, [policy1.collect_mode, eval_policy1],
+ tb_logger,
+ exp_name=cfg.exp_name,
+ instance_name='fixed_evaluator'
+ )
+ evaluator2_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
+ evaluator2_cfg.stop_value = cfg.env.stop_value[1]
+ evaluator2 = BattleInteractionSerialEvaluator(
+ evaluator2_cfg,
+ evaluator_env2, [policy1.collect_mode, eval_policy2],
+ tb_logger,
+ exp_name=cfg.exp_name,
+ instance_name='uniform_evaluator'
+ )
+
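+ # Self-play loop: policy1 is periodically evaluated against a fixed and a uniform opponent; both
+ # policies then collect head-to-head episodes and each learner trains on its own side of the data,
+ # with the raw episode reward used as the advantage.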
+ while True:
+ if evaluator1.should_eval(learner1.train_iter):
+ stop_flag1, _ = evaluator1.eval(learner1.save_checkpoint, learner1.train_iter, collector.envstep)
+ if evaluator2.should_eval(learner1.train_iter):
+ stop_flag2, _ = evaluator2.eval(learner1.save_checkpoint, learner1.train_iter, collector.envstep)
+ if stop_flag1 and stop_flag2:
+ break
+ train_data, _ = collector.collect(train_iter=learner1.train_iter)
+ for data in train_data:
+ for d in data:
+ d['adv'] = d['reward']
+ for i in range(cfg.policy.learn.update_per_collect):
+ learner1.train(train_data[0], collector.envstep)
+ learner2.train(train_data[1], collector.envstep)
+ if collector.envstep >= max_env_step or learner1.train_iter >= max_train_iter:
+ break
+
+
+if __name__ == "__main__":
+ main(selfplay_demo_ppo_config)
diff --git a/DI-engine/dizoo/mario/__init__.py b/DI-engine/dizoo/mario/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/mario/mario_dqn_config.py b/DI-engine/dizoo/mario/mario_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..817edffa42e1f146e2222265a075ed245516e2f2
--- /dev/null
+++ b/DI-engine/dizoo/mario/mario_dqn_config.py
@@ -0,0 +1,49 @@
+from easydict import EasyDict
+
+mario_dqn_config = dict(
+ exp_name='mario_dqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=100000,
+ replay_path='mario_dqn_seed0/video',
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=[4, 84, 84],
+ action_shape=2,
+ encoder_hidden_size_list=[128, 128, 256],
+ dueling=True,
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=96, ),
+ eval=dict(evaluator=dict(eval_freq=2000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=250000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+)
+mario_dqn_config = EasyDict(mario_dqn_config)
+main_config = mario_dqn_config
+mario_dqn_create_config = dict(
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqn'),
+)
+mario_dqn_create_config = EasyDict(mario_dqn_create_config)
+create_config = mario_dqn_create_config
+# You can run this config with `python3 -u mario_dqn_main.py`
diff --git a/DI-engine/dizoo/mario/mario_dqn_example.py b/DI-engine/dizoo/mario/mario_dqn_example.py
new file mode 100644
index 0000000000000000000000000000000000000000..945ee33479599d979e400faa016b1b98ab09d07f
--- /dev/null
+++ b/DI-engine/dizoo/mario/mario_dqn_example.py
@@ -0,0 +1,66 @@
+import gym
+from ditk import logging
+from ding.model import DQN
+from ding.policy import DQNPolicy
+from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2
+from ding.envs.env_wrappers import MaxAndSkipWrapper, WarpFrameWrapper, ScaledFloatFrameWrapper, FrameStackWrapper, \
+ EvalEpisodeReturnWrapper, TimeLimitWrapper
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import OffPolicyLearner, StepCollector, interaction_evaluator, data_pusher, \
+ eps_greedy_handler, CkptSaver, nstep_reward_enhancer
+from ding.utils import set_pkg_seed
+from mario_dqn_config import main_config, create_config
+import gym_super_mario_bros
+from nes_py.wrappers import JoypadSpace
+
+
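+# Note: the JoypadSpace wrapper restricts the action set to two button combos ([right], [right + A]),
+# which matches `action_shape=2` in mario_dqn_config.py.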
+def wrapped_mario_env():
+ return DingEnvWrapper(
+ JoypadSpace(gym_super_mario_bros.make("SuperMarioBros-1-1-v0"), [["right"], ["right", "A"]]),
+ cfg={
+ 'env_wrapper': [
+ lambda env: MaxAndSkipWrapper(env, skip=4),
+ lambda env: WarpFrameWrapper(env, size=84),
+ lambda env: ScaledFloatFrameWrapper(env),
+ lambda env: FrameStackWrapper(env, n_frames=4),
+ lambda env: TimeLimitWrapper(env, max_limit=400),
+ lambda env: EvalEpisodeReturnWrapper(env),
+ ]
+ }
+ )
+
+
+def main():
+ filename = '{}/log.txt'.format(main_config.exp_name)
+ logging.getLogger(with_files=[filename]).setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+ collector_env = SubprocessEnvManagerV2(
+ env_fn=[wrapped_mario_env for _ in range(collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = SubprocessEnvManagerV2(
+ env_fn=[wrapped_mario_env for _ in range(evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = DQN(**cfg.policy.model)
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ policy = DQNPolicy(cfg.policy, model=model)
+
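+ # Middleware pipeline: evaluation -> eps-greedy schedule -> step collection -> n-step reward ->
+ # buffer push -> off-policy DQN learning -> periodic checkpointing.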
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(eps_greedy_handler(cfg))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(nstep_reward_enhancer(cfg))
+ task.use(data_pusher(cfg, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=1000))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/dizoo/mario/mario_dqn_main.py b/DI-engine/dizoo/mario/mario_dqn_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..7b7bfe1efb5ccc82b68050b5fe713555d9f106a1
--- /dev/null
+++ b/DI-engine/dizoo/mario/mario_dqn_main.py
@@ -0,0 +1,106 @@
+import os
+import gym
+from tensorboardX import SummaryWriter
+from easydict import EasyDict
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.envs import SyncSubprocessEnvManager, DingEnvWrapper, BaseEnvManager
+from ding.envs.env_wrappers import MaxAndSkipWrapper, WarpFrameWrapper, ScaledFloatFrameWrapper, FrameStackWrapper, \
+ EvalEpisodeReturnWrapper
+from ding.policy import DQNPolicy
+from ding.model import DQN
+from ding.utils import set_pkg_seed
+from ding.rl_utils import get_epsilon_greedy_fn
+from mario_dqn_config import mario_dqn_config
+import gym_super_mario_bros
+from nes_py.wrappers import JoypadSpace
+
+
+def wrapped_mario_env():
+ return DingEnvWrapper(
+ JoypadSpace(gym_super_mario_bros.make("SuperMarioBros-1-1-v0"), [["right"], ["right", "A"]]),
+ cfg={
+ 'env_wrapper': [
+ lambda env: MaxAndSkipWrapper(env, skip=4),
+ lambda env: WarpFrameWrapper(env, size=84),
+ lambda env: ScaledFloatFrameWrapper(env),
+ lambda env: FrameStackWrapper(env, n_frames=4),
+ lambda env: EvalEpisodeReturnWrapper(env),
+ ]
+ }
+ )
+
+
+def main(cfg, seed=0):
+ cfg = compile_config(
+ cfg,
+ SyncSubprocessEnvManager,
+ DQNPolicy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator,
+ AdvancedReplayBuffer,
+ save_cfg=True
+ )
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+ collector_env = SyncSubprocessEnvManager(
+ env_fn=[wrapped_mario_env for _ in range(collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = SyncSubprocessEnvManager(
+ env_fn=[wrapped_mario_env for _ in range(evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ # Set random seed for all packages and instances
+ collector_env.seed(seed)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ # Set up RL Policy
+ model = DQN(**cfg.policy.model)
+ policy = DQNPolicy(cfg.policy, model=model)
+
+ # Set up collection, training and evaluation utilities
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = SampleSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
+
+ # Set up other modules, e.g. epsilon-greedy exploration
+ eps_cfg = cfg.policy.other.eps
+ epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)
+
+ # Training & Evaluation loop
+ while True:
+ # Evaluate at the beginning and then at the configured frequency
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Update other modules
+ eps = epsilon_greedy(collector.envstep)
+ # Sampling data from environments
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs={'eps': eps})
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ # Training
+ for i in range(cfg.policy.learn.update_per_collect):
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is None:
+ break
+ learner.train(train_data, collector.envstep)
+ # Final evaluation with replay saving enabled
+ evaluator_env = BaseEnvManager(env_fn=[wrapped_mario_env for _ in range(evaluator_env_num)], cfg=cfg.env.manager)
+ evaluator_env.enable_save_replay(cfg.env.replay_path) # enable the replay-saving interface
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+
+
+if __name__ == "__main__":
+ main(mario_dqn_config)
diff --git a/DI-engine/dizoo/maze/__init__.py b/DI-engine/dizoo/maze/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..3bfa255bd54a460afac518f9473c7701e6432017
--- /dev/null
+++ b/DI-engine/dizoo/maze/__init__.py
@@ -0,0 +1,3 @@
+from gym.envs.registration import register
+
+register(id='Maze', entry_point='dizoo.maze.envs:Maze')
diff --git a/DI-engine/dizoo/maze/config/maze_bc_config.py b/DI-engine/dizoo/maze/config/maze_bc_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..18c9d6ade8ce1f3aeb1fb9294ca6febe0d1b92e8
--- /dev/null
+++ b/DI-engine/dizoo/maze/config/maze_bc_config.py
@@ -0,0 +1,56 @@
+from easydict import EasyDict
+
+maze_size = 16
+num_actions = 4
+maze_pc_config = dict(
+ exp_name="maze_bc_seed0",
+ env=dict(
+ collector_env_num=1,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ env_id='Maze',
+ size=maze_size,
+ wall_type='tunnel',
+ stop_value=1
+ ),
+ policy=dict(
+ cuda=True,
+ maze_size=maze_size,
+ num_actions=num_actions,
+ max_bfs_steps=100,
+ model=dict(
+ obs_shape=[3, maze_size, maze_size],
+ action_shape=num_actions,
+ encoder_hidden_size_list=[
+ 128,
+ 256,
+ 512,
+ 1024,
+ ],
+ strides=[1, 1, 1, 1]
+ ),
+ learn=dict(
+ # update_per_collect=4,
+ batch_size=256,
+ learning_rate=0.005,
+ train_epoch=5000,
+ optimizer='SGD',
+ ),
+ eval=dict(evaluator=dict(n_episode=5)),
+ collect=dict(),
+ ),
+)
+maze_pc_config = EasyDict(maze_pc_config)
+main_config = maze_pc_config
+maze_pc_create_config = dict(
+ env=dict(
+ type='maze',
+ import_names=['dizoo.maze.envs.maze_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='bc'),
+)
+maze_pc_create_config = EasyDict(maze_pc_create_config)
+create_config = maze_pc_create_config
+
+# Run `dizoo/maze/entry/maze_bc_main.py` to use this config.
diff --git a/DI-engine/dizoo/maze/config/maze_pc_config.py b/DI-engine/dizoo/maze/config/maze_pc_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a5f40b278c9c79a2cb40d2522d2bae98fc44186
--- /dev/null
+++ b/DI-engine/dizoo/maze/config/maze_pc_config.py
@@ -0,0 +1,57 @@
+from easydict import EasyDict
+
+maze_size = 16
+num_actions = 4
+maze_pc_config = dict(
+ exp_name="maze_pc_seed0",
+ train_seeds=5,
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ env_id='Maze',
+ size=maze_size,
+ wall_type='tunnel',
+ stop_value=1,
+ ),
+ policy=dict(
+ cuda=True,
+ maze_size=maze_size,
+ num_actions=num_actions,
+ max_bfs_steps=100,
+ model=dict(
+ obs_shape=[8, maze_size, maze_size],
+ action_shape=num_actions,
+ encoder_hidden_size_list=[
+ 128,
+ 256,
+ 512,
+ 1024,
+ ],
+ ),
+ learn=dict(
+ batch_size=32,
+ learning_rate=0.0005,
+ train_epoch=100,
+ optimizer='Adam',
+ ),
+ eval=dict(evaluator=dict(n_episode=5)),
+ collect=dict(),
+ ),
+)
+maze_pc_config = EasyDict(maze_pc_config)
+main_config = maze_pc_config
+maze_pc_create_config = dict(
+ env=dict(
+ type='maze',
+ import_names=['dizoo.maze.envs.maze_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='pc_bfs'),
+)
+maze_pc_create_config = EasyDict(maze_pc_create_config)
+create_config = maze_pc_create_config
+
+if __name__ == '__main__':
+ from ding.entry import serial_pipeline_pc
+ serial_pipeline_pc([maze_pc_config, maze_pc_create_config], seed=0)
diff --git a/DI-engine/dizoo/maze/entry/maze_bc_main.py b/DI-engine/dizoo/maze/entry/maze_bc_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..3a42d4e92128997eeede34ae093209138b4957d3
--- /dev/null
+++ b/DI-engine/dizoo/maze/entry/maze_bc_main.py
@@ -0,0 +1,200 @@
+from typing import Union, Optional, Tuple
+import os
+from functools import partial
+from copy import deepcopy
+
+import easydict
+import torch
+import numpy as np
+from tensorboardX import SummaryWriter
+from torch.utils.data import DataLoader, Dataset
+
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.worker import BaseLearner, InteractionSerialEvaluator
+from ding.config import read_config, compile_config
+from ding.policy import create_policy
+from ding.utils import set_pkg_seed
+from dizoo.maze.envs.maze_env import Maze
+
+
+# BFS algorithm
+def get_vi_sequence(env, observation):
+ """Returns [L, W, W] optimal actions."""
+ xy = np.where(observation[Ellipsis, -1] == 1)
+ start_x, start_y = xy[0][0], xy[1][0]
+ target_location = env.target_location
+ nav_map = env.nav_map
+ current_points = [target_location]
+ chosen_actions = {target_location: 0}
+ visited_points = {target_location: True}
+ vi_sequence = []
+
+ vi_map = np.full((env.size, env.size), fill_value=env.n_action, dtype=np.int32)
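+ # vi_map[x, y] stores the action that moves (x, y) one BFS step closer to the target;
+ # the fill value env.n_action marks cells that have not been reached yet.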
+
+ found_start = False
+ while current_points and not found_start:
+ next_points = []
+ for point_x, point_y in current_points:
+ for (action, (next_point_x, next_point_y)) in [(0, (point_x - 1, point_y)), (1, (point_x, point_y - 1)),
+ (2, (point_x + 1, point_y)), (3, (point_x, point_y + 1))]:
+
+ if (next_point_x, next_point_y) in visited_points:
+ continue
+
+ if not (0 <= next_point_x < len(nav_map) and 0 <= next_point_y < len(nav_map[next_point_x])):
+ continue
+
+ if nav_map[next_point_x][next_point_y] == 'x':
+ continue
+
+ next_points.append((next_point_x, next_point_y))
+ visited_points[(next_point_x, next_point_y)] = True
+ chosen_actions[(next_point_x, next_point_y)] = action
+ vi_map[next_point_x, next_point_y] = action
+
+ if next_point_x == start_x and next_point_y == start_y:
+ found_start = True
+ vi_sequence.append(vi_map.copy())
+ current_points = next_points
+ track_back = []
+ if found_start:
+ cur_x, cur_y = start_x, start_y
+ while cur_x != target_location[0] or cur_y != target_location[1]:
+ act = vi_sequence[-1][cur_x, cur_y]
+ track_back.append((torch.FloatTensor(env.process_states([cur_x, cur_y], env.get_maze_map())), act))
+ if act == 0:
+ cur_x += 1
+ elif act == 1:
+ cur_y += 1
+ elif act == 2:
+ cur_x -= 1
+ elif act == 3:
+ cur_y -= 1
+
+ return np.array(vi_sequence), track_back
+
+
+class BCDataset(Dataset):
+
+ def __init__(self, all_data):
+ self._data = all_data
+
+ def __getitem__(self, item):
+ return {'obs': self._data[item][0], 'action': self._data[item][1]}
+
+ def __len__(self):
+ return len(self._data)
+
+
+def load_bc_dataset(train_seeds=1, test_seeds=1, batch_size=32):
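+ """Build behaviour-cloning datasets: run BFS on `train_seeds + test_seeds` randomly generated mazes
+ and collect (observation, optimal action) pairs along the traced-back optimal path."""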
+
+ def load_env(seed):
+ ccc = easydict.EasyDict({'size': 16})
+ e = Maze(ccc)
+ e.seed(seed)
+ e.reset()
+ return e
+
+ envs = [load_env(i) for i in range(train_seeds + test_seeds)]
+ data_train = []
+ data_test = []
+
+ for idx, env in enumerate(envs):
+ if idx < train_seeds:
+ data = data_train
+ else:
+ data = data_test
+
+ start_obs = env.process_states(env._get_obs(), env.get_maze_map())
+ _, track_back = get_vi_sequence(env, start_obs)
+
+ data += track_back
+
+ train_data = BCDataset(data_train)
+ test_data = BCDataset(data_test)
+
+ train_dataset = DataLoader(train_data, batch_size=batch_size, shuffle=True)
+ test_dataset = DataLoader(test_data, batch_size=batch_size, shuffle=True)
+ return train_dataset, test_dataset
+
+
+def serial_pipeline_bc(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ seed: int = 0,
+ model: Optional[torch.nn.Module] = None,
+ max_iter=int(1e6),
+) -> Union['Policy', bool]: # noqa
+ r"""
+ Overview:
+ Serial pipeline entry of imitation learning.
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - max_iter (:obj:`int`): Max number of training iterations.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ Returns:
+ - policy (:obj:`Policy`): Converged policy.
+ - convergence (:obj:`bool`): whether the imitation learning training has converged
+ """
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ else:
+ cfg, create_cfg = deepcopy(input_cfg)
+ cfg = compile_config(cfg, seed=seed, auto=True, create_cfg=create_cfg)
+
+ # Env, Policy
+ env_fn, _, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+ # Random seed
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'eval'])
+
+ # Main components
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ dataloader, test_dataloader = load_bc_dataset()
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+
+ # ==========
+ # Main loop
+ # ==========
+ learner.call_hook('before_run')
+ stop = False
+ iter_cnt = 0
+ for epoch in range(cfg.policy.learn.train_epoch):
+ # Compute validation accuracy on the held-out BFS dataset
+ acc_list = []
+ for _, bat in enumerate(test_dataloader):
+ bat['action'] = bat['action'].long()
+ res = policy._forward_eval(bat['obs'])
+ res = torch.argmax(res['logit'], dim=1)
+ acc_list.append(torch.sum(res == bat['action'].squeeze(-1)).item() / bat['action'].shape[0])
+ label = 'validation_acc'
+ tb_logger.add_scalar(label, sum(acc_list) / len(acc_list), iter_cnt)
+ for i, train_data in enumerate(dataloader):
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter)
+ if stop:
+ break
+ train_data['action'] = train_data['action'].long()
+ learner.train(train_data)
+ iter_cnt += 1
+ if iter_cnt >= max_iter:
+ stop = True
+ break
+ if stop:
+ break
+
+ learner.call_hook('after_run')
+ print('final reward is: {}'.format(reward))
+ return policy, stop
+
+
+if __name__ == '__main__':
+ from dizoo.maze.config.maze_bc_config import main_config, create_config
+ serial_pipeline_bc([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/maze/envs/__init__.py b/DI-engine/dizoo/maze/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..ab42c5b39d6ce7e31ce7c1e0c392275cbc715ac6
--- /dev/null
+++ b/DI-engine/dizoo/maze/envs/__init__.py
@@ -0,0 +1 @@
+from .maze_env import Maze
diff --git a/DI-engine/dizoo/maze/envs/maze_env.py b/DI-engine/dizoo/maze/envs/maze_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..f441b7d698c7aa36d931294cdda15b7b45483ee5
--- /dev/null
+++ b/DI-engine/dizoo/maze/envs/maze_env.py
@@ -0,0 +1,380 @@
+from typing import List
+
+import copy
+import numpy as np
+import gym
+from gym import spaces
+from gym.utils import seeding
+
+from ding.envs import BaseEnvTimestep
+from ding.utils import ENV_REGISTRY
+
+
+@ENV_REGISTRY.register('maze')
+class Maze(gym.Env):
+ """
+ Environment with random maze layouts. The ASCII representation of the mazes includes the following objects:
+ - ` `: empty cell
+ - `x`: wall
+ - `S`: the start location (optional)
+ - `T`: the target location.
+ """
+ KEY_EMPTY = 0
+ KEY_WALL = 1
+ KEY_TARGET = 2
+ KEY_START = 3
+ ASCII_MAP = {
+ KEY_EMPTY: ' ',
+ KEY_WALL: 'x',
+ KEY_TARGET: 'T',
+ KEY_START: 'S',
+ }
+
+ def __init__(
+ self,
+ cfg,
+ ):
+ self._size = cfg.size
+ self._init_flag = False
+ self._random_start = True
+ self._seed = None
+ self._step = 0
+
+ def reset(self):
+ self.active_init()
+ obs = self._get_obs()
+ self._step = 0
+ return self.process_states(obs, self.get_maze_map())
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def active_init(
+ self,
+ tabular_obs=False,
+ reward_fn=lambda x, y, tx, ty: 1 if (x == tx and y == ty) else 0,
+ done_fn=lambda x, y, tx, ty: x == tx and y == ty
+ ):
+ self._maze = self.generate_maze(self.size, self._seed, 'tunnel')
+ self._num_maze_keys = len(Maze.ASCII_MAP.keys())
+ nav_map = self.maze_to_ascii(self._maze)
+ self._map = nav_map
+ self._tabular_obs = tabular_obs
+ self._reward_fn = reward_fn
+ self._done_fn = done_fn
+ if self._reward_fn is None:
+ self._reward_fn = lambda x, y, tx, ty: float(x == tx and y == ty)
+ if self._done_fn is None:
+ self._done_fn = lambda x, y, tx, ty: False
+
+ self._max_x = len(self._map)
+ if not self._max_x:
+ raise ValueError('Invalid map.')
+ self._max_y = len(self._map[0])
+ if not all(len(m) == self._max_y for m in self._map):
+ raise ValueError('Invalid map.')
+ self._start_x, self._start_y = self._find_initial_point()
+ self._target_x, self._target_y = self._find_target_point()
+ self._x, self._y = self._start_x, self._start_y
+
+ self._n_state = self._max_x * self._max_y
+ self._n_action = 4
+
+ if self._tabular_obs:
+ self.observation_space = spaces.Discrete(self._n_state)
+ else:
+ self.observation_space = spaces.Box(low=0.0, high=np.inf, shape=(16, 16, 3))
+
+ self.action_space = spaces.Discrete(self._n_action)
+ self.reward_space = spaces.Box(low=0, high=1, shape=(1, ), dtype=np.float32)
+
+ def random_start(self):
+ init_x, init_y = self._x, self._y
+ while True: # Find empty grid cell.
+ self._x = self.np_random.integers(self._max_x)
+ self._y = self.np_random.integers(self._max_y)
+ if self._map[self._x][self._y] != 'x':
+ break
+ ret = copy.deepcopy(self.process_states(self._get_obs(), self.get_maze_map()))
+ self._x, self._y = init_x, init_y
+ return ret
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ @property
+ def num_maze_keys(self):
+ return self._num_maze_keys
+
+ @property
+ def size(self):
+ return self._size
+
+ def process_states(self, observations, maze_maps):
+ """Returns [B, W, W, 3] binary values. Channels are (wall; goal; obs)"""
+ loc = np.eye(self._size * self._size, dtype=np.int64)[observations[0] * self._size + observations[1]]
+ loc = np.reshape(loc, [self._size, self._size])
+ maze_maps = maze_maps.astype(np.int64)
+
+ states = np.concatenate([maze_maps, loc[Ellipsis, None]], axis=-1, dtype=np.int64)
+ return states
+
+ def get_maze_map(self, stacked=True):
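+ """Return the raw maze grid, or (if ``stacked``) a [W, W, 2] array with separate wall and target
+ channels, where the start and target markers are cleared from the wall channel."""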
+ if not stacked:
+ return self._maze.copy()
+ wall = self._maze.copy()
+ target_x, target_y = self.target_location
+ assert wall[target_x][target_y] == Maze.KEY_TARGET
+ wall[target_x][target_y] = 0
+ target = np.zeros((self._size, self._size))
+ target[target_x][target_y] = 1
+ assert wall[self._start_x][self._start_y] == Maze.KEY_START
+ wall[self._start_x][self._start_y] = 0
+ return np.stack([wall, target], axis=-1)
+
+ def generate_maze(self, size, seed, wall_type):
+ rng, _ = seeding.np_random(seed)
+ maze = np.full((size, size), fill_value=Maze.KEY_EMPTY, dtype=int)
+
+ if wall_type == 'none':
+ maze[[0, -1], :] = Maze.KEY_WALL
+ maze[:, [0, -1]] = Maze.KEY_WALL
+ elif wall_type == 'tunnel':
+ self.sample_wall(maze, rng)
+ elif wall_type.startswith('blocks:'):
+ maze[[0, -1], :] = Maze.KEY_WALL
+ maze[:, [0, -1]] = Maze.KEY_WALL
+ self.sample_blocks(maze, rng, int(wall_type.split(':')[-1]))
+ else:
+ raise ValueError('Unknown wall type: %s' % wall_type)
+
+ loc_target = self.sample_location(maze, rng)
+ maze[loc_target] = Maze.KEY_TARGET
+
+ loc_start = self.sample_location(maze, rng)
+ maze[loc_start] = Maze.KEY_START
+ self._start_x, self._start_y = loc_start
+
+ return maze
+
+ def sample_blocks(self, maze, rng, num_blocks):
+ """Sample single-block 'wall' or 'obstacles'."""
+ for _ in range(num_blocks):
+ loc = self.sample_location(maze, rng)
+ maze[loc] = Maze.KEY_WALL
+
+ def sample_wall(
+ self, maze, rng, shortcut_prob=0.1, inner_wall_thickness=1, outer_wall_thickness=1, corridor_thickness=2
+ ):
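+ """Carve corridors with a recursive depth-first walk starting from the outer-wall corner,
+ occasionally connecting back to visited cells (``shortcut_prob``) so more than one path exists."""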
+ room = maze
+
+ # step 1: fill everything as wall
+ room[:] = Maze.KEY_WALL
+
+ # step 2: prepare
+ # we move two pixels at a time, because the walls also occupy pixels
+ delta = inner_wall_thickness + corridor_thickness
+ dx = [delta, -delta, 0, 0]
+ dy = [0, 0, delta, -delta]
+
+ def get_loc_type(y, x):
+ # remember there is an outer wall of 1 pixel surrounding the room
+ if (y < outer_wall_thickness or y + corridor_thickness - 1 >= room.shape[0] - outer_wall_thickness):
+ return 'invalid'
+ if (x < outer_wall_thickness or x + corridor_thickness - 1 >= room.shape[1] - outer_wall_thickness):
+ return 'invalid'
+ # already visited
+ if room[y, x] == Maze.KEY_EMPTY:
+ return 'occupied'
+ return 'valid'
+
+ def connect_pixel(y, x, ny, nx):
+ pixel = Maze.KEY_EMPTY
+ if ny == y:
+ room[y:y + corridor_thickness, min(x, nx):max(x, nx) + corridor_thickness] = pixel
+ else:
+ room[min(y, ny):max(y, ny) + corridor_thickness, x:x + corridor_thickness] = pixel
+
+ def carve_passage_from(y, x):
+ room[y, x] = Maze.KEY_EMPTY
+ for direction in rng.permutation(len(dx)):
+ ny = y + dy[direction]
+ nx = x + dx[direction]
+
+ loc_type = get_loc_type(ny, nx)
+ if loc_type == 'invalid':
+ continue
+ elif loc_type == 'valid':
+ connect_pixel(y, x, ny, nx)
+ # recursion
+ carve_passage_from(ny, nx)
+ else:
+ # occupied
+ # create a shortcut with some probability, so the maze is not
+ # restricted to a single feasible path.
+ if rng.random() < shortcut_prob:
+ connect_pixel(y, x, ny, nx)
+
+ carve_passage_from(outer_wall_thickness, outer_wall_thickness)
+
+ def sample_location(self, maze, rng):
+ for _ in range(1000):
+ x, y = rng.integers(low=1, high=self._size, size=2)
+ if maze[x, y] == Maze.KEY_EMPTY:
+ return x, y
+ raise ValueError('Cannot sample empty location, make maze bigger?')
+
+ @staticmethod
+ def key_to_ascii(key):
+ if key in Maze.ASCII_MAP:
+ return Maze.ASCII_MAP[key]
+ assert (Maze.KEY_OBJ <= key < Maze.KEY_OBJ + Maze.MAX_OBJ_TYPES)
+ return chr(ord('1') + key - Maze.KEY_OBJ)
+
+ def maze_to_ascii(self, maze):
+ return [[Maze.key_to_ascii(x) for x in row] for row in maze]
+
+ def tabular_obs_action(self, status_obs, action, include_maze_layout=False):
+ tabular_obs = self.get_tabular_obs(status_obs)
+ multiplier = self._n_action
+ if include_maze_layout:
+ multiplier += self._num_maze_keys
+ return multiplier * tabular_obs + action
+
+ @staticmethod
+ def create_collector_env_cfg(cfg: dict) -> List[dict]:
+ collector_env_num = cfg.pop('collector_env_num')
+ cfg = copy.deepcopy(cfg)
+ cfg.is_train = True
+ return [cfg for _ in range(collector_env_num)]
+
+ @staticmethod
+ def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
+ evaluator_env_num = cfg.pop('evaluator_env_num')
+ cfg = copy.deepcopy(cfg)
+ cfg.is_train = False
+ return [cfg for _ in range(evaluator_env_num)]
+
+ @property
+ def nav_map(self):
+ return self._map
+
+ @property
+ def n_state(self):
+ return self._n_state
+
+ @property
+ def n_action(self):
+ return self._n_action
+
+ @property
+ def target_location(self):
+ return self._target_x, self._target_y
+
+ @property
+ def tabular_obs(self):
+ return self._tabular_obs
+
+ def _find_initial_point(self):
+ for x in range(self._max_x):
+ for y in range(self._max_y):
+ if self._map[x][y] == 'S':
+ break
+ if self._map[x][y] == 'S':
+ break
+ else:
+ return None, None
+
+ return x, y
+
+ def _find_target_point(self):
+ for x in range(self._max_x):
+ for y in range(self._max_y):
+ if self._map[x][y] == 'T':
+ break
+ if self._map[x][y] == 'T':
+ break
+ else:
+ raise ValueError('Target point not found in map.')
+
+ return x, y
+
+ def _get_obs(self):
+ if self._tabular_obs:
+ return self._x * self._max_y + self._y
+ else:
+ return np.array([self._x, self._y])
+
+ def get_tabular_obs(self, status_obs):
+ return self._max_y * status_obs[..., 0] + status_obs[..., 1]
+
+ def get_xy(self, state):
+ x = state / self._max_y
+ y = state % self._max_y
+ return x, y
+
+ def step(self, action):
+ last_x, last_y = self._x, self._y
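+ # Action semantics: 0 -> +x, 1 -> +y, 2 -> -x, 3 -> -y; a move into a wall is reverted below.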
+ if action == 0:
+ if self._x < self._max_x - 1:
+ self._x += 1
+ elif action == 1:
+ if self._y < self._max_y - 1:
+ self._y += 1
+ elif action == 2:
+ if self._x > 0:
+ self._x -= 1
+ elif action == 3:
+ if self._y > 0:
+ self._y -= 1
+
+ if self._map[self._x][self._y] == 'x':
+ self._x, self._y = last_x, last_y
+ self._step += 1
+ reward = self._reward_fn(self._x, self._y, self._target_x, self._target_y)
+ done = self._done_fn(self._x, self._y, self._target_x, self._target_y)
+ info = {}
+ if self._step > 100:
+ done = True
+ if done:
+ info['final_eval_reward'] = reward
+ info['eval_episode_return'] = reward
+ return BaseEnvTimestep(self.process_states(self._get_obs(), self.get_maze_map()), reward, done, info)
+
+
+def get_value_map(env):
+ """Returns [W, W, A] one-hot VI actions."""
+ target_location = env.target_location
+ nav_map = env.nav_map
+ current_points = [target_location]
+ chosen_actions = {target_location: 0}
+ visited_points = {target_location: True}
+
+ while current_points:
+ next_points = []
+ for point_x, point_y in current_points:
+ for (action, (next_point_x, next_point_y)) in [(0, (point_x - 1, point_y)), (1, (point_x, point_y - 1)),
+ (2, (point_x + 1, point_y)), (3, (point_x, point_y + 1))]:
+
+ if (next_point_x, next_point_y) in visited_points:
+ continue
+
+ if not (0 <= next_point_x < len(nav_map) and 0 <= next_point_y < len(nav_map[next_point_x])):
+ continue
+
+ if nav_map[next_point_x][next_point_y] == 'x':
+ continue
+
+ next_points.append((next_point_x, next_point_y))
+ visited_points[(next_point_x, next_point_y)] = True
+ chosen_actions[(next_point_x, next_point_y)] = action
+ current_points = next_points
+
+ value_map = np.zeros([env.size, env.size, env.n_action])
+ for (x, y), action in chosen_actions.items():
+ value_map[x][y][action] = 1
+ return value_map
diff --git a/DI-engine/dizoo/maze/envs/test_maze_env.py b/DI-engine/dizoo/maze/envs/test_maze_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..b8350d46d360e8aaf87535eca2ef4e07bd6f874e
--- /dev/null
+++ b/DI-engine/dizoo/maze/envs/test_maze_env.py
@@ -0,0 +1,28 @@
+import pytest
+import os
+import numpy as np
+from dizoo.maze.envs.maze_env import Maze
+from easydict import EasyDict
+import copy
+
+
+@pytest.mark.envtest
+class TestMazeEnv:
+
+ def test_maze(self):
+ env = Maze(EasyDict({'size': 16}))
+ env.seed(314)
+ assert env._seed == 314
+ obs = env.reset()
+ assert obs.shape == (16, 16, 3)
+ min_val, max_val = 0, 3
+ for i in range(100):
+ random_action = np.random.randint(min_val, max_val, size=(1, ))
+ timestep = env.step(random_action)
+ print(timestep)
+ print(timestep.obs.max())
+ assert isinstance(timestep.obs, np.ndarray)
+ assert isinstance(timestep.done, bool)
+ if timestep.done:
+ env.reset()
+ env.close()
diff --git a/DI-engine/dizoo/metadrive/__init__.py b/DI-engine/dizoo/metadrive/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/metadrive/config/__init__.py b/DI-engine/dizoo/metadrive/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/metadrive/config/metadrive_onppo_config.py b/DI-engine/dizoo/metadrive/config/metadrive_onppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..05585b8a58f7638da96e933a9d1947e2ecf44504
--- /dev/null
+++ b/DI-engine/dizoo/metadrive/config/metadrive_onppo_config.py
@@ -0,0 +1,111 @@
+from easydict import EasyDict
+from functools import partial
+from tensorboardX import SummaryWriter
+import metadrive
+import gym
+from ding.envs import BaseEnvManager, SyncSubprocessEnvManager
+from ding.config import compile_config
+from ding.model.template import ContinuousQAC, VAC
+from ding.policy import PPOPolicy
+from ding.worker import SampleSerialCollector, InteractionSerialEvaluator, BaseLearner
+from dizoo.metadrive.env.drive_env import MetaDrivePPOOriginEnv
+from dizoo.metadrive.env.drive_wrapper import DriveEnvWrapper
+
+metadrive_basic_config = dict(
+ exp_name='metadrive_onppo_seed0',
+ env=dict(
+ metadrive=dict(
+ use_render=False,
+ traffic_density=0.10, # Density of vehicles occupying the roads, range in [0,1]
+ map='XSOS', # Int or string: an easy way to fill map_config
+ horizon=4000, # Max step number
+ driving_reward=1.0, # Reward to encourage agent to move forward.
+ speed_reward=0.1, # Reward to encourage agent to drive at a high speed
+ use_lateral_reward=False, # reward for lane keeping
+ out_of_road_penalty=40.0, # Penalty to discourage driving out of road
+ crash_vehicle_penalty=40.0, # Penalty to discourage collision
+ decision_repeat=20, # Reciprocal of decision frequency
+ out_of_route_done=True, # Game over if driving out of road
+ ),
+ manager=dict(
+ shared_memory=False,
+ max_retry=2,
+ context='spawn',
+ ),
+ n_evaluator_episode=16,
+ stop_value=255,
+ collector_env_num=8,
+ evaluator_env_num=8,
+ ),
+ policy=dict(
+ cuda=True,
+ action_space='continuous',
+ model=dict(
+ obs_shape=[5, 84, 84],
+ action_shape=2,
+ action_space='continuous',
+ bound_type='tanh',
+ encoder_hidden_size_list=[128, 128, 64],
+ ),
+ learn=dict(
+ epoch_per_collect=10,
+ batch_size=64,
+ learning_rate=3e-4,
+ entropy_weight=0.001,
+ value_weight=0.5,
+ clip_ratio=0.02,
+ adv_norm=False,
+ value_norm=True,
+ grad_clip_value=10,
+ ),
+ collect=dict(n_sample=3000, ),
+ eval=dict(evaluator=dict(eval_freq=1000, ), ),
+ ),
+)
+main_config = EasyDict(metadrive_basic_config)
+
+
+def wrapped_env(env_cfg, wrapper_cfg=None):
+ return DriveEnvWrapper(MetaDrivePPOOriginEnv(env_cfg), wrapper_cfg)
+
+
+def main(cfg):
+ cfg = compile_config(
+ cfg, SyncSubprocessEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator
+ )
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+ collector_env = SyncSubprocessEnvManager(
+ env_fn=[partial(wrapped_env, cfg.env.metadrive) for _ in range(collector_env_num)],
+ cfg=cfg.env.manager,
+ )
+ evaluator_env = SyncSubprocessEnvManager(
+ env_fn=[partial(wrapped_env, cfg.env.metadrive) for _ in range(evaluator_env_num)],
+ cfg=cfg.env.manager,
+ )
+ model = VAC(**cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
+ tb_logger = SummaryWriter('./log/{}/'.format(cfg.exp_name))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = SampleSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ learner.call_hook('before_run')
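+ # Standard serial on-policy loop: evaluate periodically, then collect n_sample transitions and
+ # run the configured PPO update epochs on them.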
+ while True:
+ if evaluator.should_eval(learner.train_iter):
+ stop, rate = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Sampling data from environments
+ new_data = collector.collect(cfg.policy.collect.n_sample, train_iter=learner.train_iter)
+ learner.train(new_data, collector.envstep)
+ learner.call_hook('after_run')
+ collector.close()
+ evaluator.close()
+ learner.close()
+
+
+if __name__ == '__main__':
+ main(main_config)
diff --git a/DI-engine/dizoo/metadrive/config/metadrive_onppo_eval_config.py b/DI-engine/dizoo/metadrive/config/metadrive_onppo_eval_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9dab89ed2765a440e48b3618b5150fccb8f4445
--- /dev/null
+++ b/DI-engine/dizoo/metadrive/config/metadrive_onppo_eval_config.py
@@ -0,0 +1,96 @@
+from easydict import EasyDict
+from functools import partial
+from tensorboardX import SummaryWriter
+import torch
+from ding.envs import BaseEnvManager, SyncSubprocessEnvManager
+from ding.config import compile_config
+from ding.model.template import VAC
+from ding.policy import PPOPolicy
+from ding.worker import SampleSerialCollector, InteractionSerialEvaluator, BaseLearner
+from dizoo.metadrive.env.drive_env import MetaDrivePPOOriginEnv
+from dizoo.metadrive.env.drive_wrapper import DriveEnvWrapper
+
+# Load the trained model from this path; if None, the policy is initialized from scratch
+model_dir = None
+metadrive_basic_config = dict(
+ exp_name='metadrive_onppo_eval_seed0',
+ env=dict(
+ metadrive=dict(
+ use_render=True,
+ traffic_density=0.10, # Density of vehicles occupying the roads, range in [0,1]
+ map='XSOS', # Int or string: an easy way to fill map_config
+ horizon=4000, # Max step number
+ driving_reward=1.0, # Reward to encourage agent to move forward.
+ speed_reward=0.10, # Reward to encourage agent to drive at a high speed
+ use_lateral_reward=False, # reward for lane keeping
+ out_of_road_penalty=40.0, # Penalty to discourage driving out of road
+ crash_vehicle_penalty=40.0, # Penalty to discourage collision
+ decision_repeat=20, # Reciprocal of decision frequency
+ out_of_route_done=True, # Game over if driving out of road
+ show_bird_view=False, # Only used in evaluation: whether to draw the five-channel bird-view image
+ ),
+ manager=dict(
+ shared_memory=False,
+ max_retry=2,
+ context='spawn',
+ ),
+ n_evaluator_episode=16,
+ stop_value=255,
+ collector_env_num=1,
+ evaluator_env_num=1,
+ ),
+ policy=dict(
+ cuda=True,
+ action_space='continuous',
+ model=dict(
+ obs_shape=[5, 84, 84],
+ action_shape=2,
+ action_space='continuous',
+ bound_type='tanh',
+ encoder_hidden_size_list=[128, 128, 64],
+ ),
+ learn=dict(
+ epoch_per_collect=10,
+ batch_size=64,
+ learning_rate=3e-4,
+ entropy_weight=0.001,
+ value_weight=0.5,
+ clip_ratio=0.02,
+ adv_norm=False,
+ value_norm=True,
+ grad_clip_value=10,
+ ),
+ collect=dict(n_sample=1000, ),
+ eval=dict(evaluator=dict(eval_freq=1000, ), ),
+ ),
+)
+main_config = EasyDict(metadrive_basic_config)
+
+
+def wrapped_env(env_cfg, wrapper_cfg=None):
+ return DriveEnvWrapper(MetaDrivePPOOriginEnv(env_cfg), wrapper_cfg)
+
+
+def main(cfg):
+ cfg = compile_config(cfg, BaseEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator)
+ evaluator_env_num = cfg.env.evaluator_env_num
+ show_bird_view = cfg.env.metadrive.show_bird_view
+ wrapper_cfg = {'show_bird_view': show_bird_view}
+ evaluator_env = BaseEnvManager(
+ env_fn=[partial(wrapped_env, cfg.env.metadrive, wrapper_cfg) for _ in range(evaluator_env_num)],
+ cfg=cfg.env.manager,
+ )
+ model = VAC(**cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
+ if model_dir is not None:
+ policy._load_state_dict_collect(torch.load(model_dir, map_location='cpu'))
+ tb_logger = SummaryWriter('./log/{}/'.format(cfg.exp_name))
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ stop, rate = evaluator.eval()
+ evaluator.close()
+
+
+if __name__ == '__main__':
+ main(main_config)
diff --git a/DI-engine/dizoo/metadrive/env/__init__.py b/DI-engine/dizoo/metadrive/env/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/metadrive/env/drive_env.py b/DI-engine/dizoo/metadrive/env/drive_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..87087b8f979e05bc4f08504fbce6d277802ac24f
--- /dev/null
+++ b/DI-engine/dizoo/metadrive/env/drive_env.py
@@ -0,0 +1,364 @@
+import copy
+import gym
+import numpy as np
+from ditk import logging
+from typing import Union, Dict, AnyStr, Tuple, Optional
+from gym.envs.registration import register
+from metadrive.manager.traffic_manager import TrafficMode
+from metadrive.obs.top_down_obs_multi_channel import TopDownMultiChannel
+from metadrive.constants import RENDER_MODE_NONE, DEFAULT_AGENT, REPLAY_DONE, TerminationState
+from metadrive.envs.base_env import BaseEnv
+from metadrive.component.map.base_map import BaseMap
+from metadrive.component.map.pg_map import parse_map_config, MapGenerateMethod
+from metadrive.component.pgblock.first_block import FirstPGBlock
+from metadrive.component.vehicle.base_vehicle import BaseVehicle
+from metadrive.utils import Config, merge_dicts, get_np_random, clip
+from metadrive.envs.base_env import BASE_DEFAULT_CONFIG
+from metadrive.component.road_network import Road
+from metadrive.component.algorithm.blocks_prob_dist import PGBlockDistConfig
+
+METADRIVE_DEFAULT_CONFIG = dict(
+ # ===== Generalization =====
+ start_seed=0,
+ environment_num=10,
+ decision_repeat=20,
+ block_dist_config=PGBlockDistConfig,
+
+ # ===== Map Config =====
+ map=3, # int or string: an easy way to fill map_config
+ random_lane_width=False,
+ random_lane_num=False,
+ map_config={
+ BaseMap.GENERATE_TYPE: MapGenerateMethod.BIG_BLOCK_NUM,
+ BaseMap.GENERATE_CONFIG: None, # it can be a file path / block num / block ID sequence
+ BaseMap.LANE_WIDTH: 3.5,
+ BaseMap.LANE_NUM: 3,
+ "exit_length": 50,
+ },
+
+ # ===== Traffic =====
+ traffic_density=0.1,
+ need_inverse_traffic=False,
+ traffic_mode=TrafficMode.Trigger, # "Respawn", "Trigger"
+ random_traffic=False, # Whether to randomize traffic; disabled by default
+ traffic_vehicle_config=dict(
+ show_navi_mark=False,
+ show_dest_mark=False,
+ enable_reverse=False,
+ show_lidar=False,
+ show_lane_line_detector=False,
+ show_side_detector=False,
+ ),
+
+ # ===== Object =====
+ accident_prob=0., # An accident may happen on each block with this probability, except multi-exit blocks
+
+ # ===== Others =====
+ use_AI_protector=False,
+ save_level=0.5,
+ is_multi_agent=False,
+ vehicle_config=dict(spawn_lane_index=(FirstPGBlock.NODE_1, FirstPGBlock.NODE_2, 0)),
+
+ # ===== Agent =====
+ random_spawn_lane_index=True,
+ target_vehicle_configs={
+ DEFAULT_AGENT: dict(
+ use_special_color=True,
+ spawn_lane_index=(FirstPGBlock.NODE_1, FirstPGBlock.NODE_2, 0),
+ )
+ },
+
+ # ===== Reward Scheme =====
+ # See: https://github.com/decisionforce/metadrive/issues/283
+ success_reward=10.0,
+ out_of_road_penalty=5.0,
+ crash_vehicle_penalty=5.0,
+ crash_object_penalty=5.0,
+ driving_reward=1.0,
+ speed_reward=0.1,
+ use_lateral_reward=False,
+
+ # ===== Cost Scheme =====
+ crash_vehicle_cost=1.0,
+ crash_object_cost=1.0,
+ out_of_road_cost=1.0,
+
+ # ===== Termination Scheme =====
+ out_of_route_done=False,
+ on_screen=False,
+ show_bird_view=False,
+)
+
+
+class MetaDrivePPOOriginEnv(BaseEnv):
+
+ @classmethod
+ def default_config(cls) -> "Config":
+ config = super(MetaDrivePPOOriginEnv, cls).default_config()
+ config.update(METADRIVE_DEFAULT_CONFIG)
+ config.register_type("map", str, int)
+ config["map_config"].register_type("config", None)
+ return config
+
+ def __init__(self, config: dict = None):
+ self.raw_cfg = config
+ self.default_config_copy = Config(self.default_config(), unchangeable=True)
+ self.init_flag = False
+
+ @property
+ def observation_space(self):
+ return gym.spaces.Box(0, 1, shape=(84, 84, 5), dtype=np.float32)
+
+ @property
+ def action_space(self):
+ return gym.spaces.Box(-1, 1, shape=(2, ), dtype=np.float32)
+
+ @property
+ def reward_space(self):
+ return gym.spaces.Box(-100, 100, shape=(1, ), dtype=np.float32)
+
+ def seed(self, seed, dynamic_seed=False):
+ # TODO implement dynamic_seed mechanism
+ super().seed(seed)
+
+ def reset(self):
+ if not self.init_flag:
+ super(MetaDrivePPOOriginEnv, self).__init__(self.raw_cfg)
+ self.start_seed = self.config["start_seed"]
+ self.env_num = self.config["environment_num"]
+ self.init_flag = True
+ obs = super().reset()
+ return obs
+
+ def _merge_extra_config(self, config: Union[dict, "Config"]) -> "Config":
+ config = self.default_config().update(config, allow_add_new_key=False)
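+ # If the lidar is configured to see farther than 50 m, raise the global max_distance to match.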
+ if config["vehicle_config"]["lidar"]["distance"] > 50:
+ config["max_distance"] = config["vehicle_config"]["lidar"]["distance"]
+ return config
+
+ def _post_process_config(self, config):
+ config = super(MetaDrivePPOOriginEnv, self)._post_process_config(config)
+ if not config["rgb_clip"]:
+ logging.warning(
+ "You have set rgb_clip = False, which means the observation will be uint8 values in [0, 255]. "
+ "Please make sure you have parsed them later before feeding them to network!"
+ )
+ config["map_config"] = parse_map_config(
+ easy_map_config=config["map"], new_map_config=config["map_config"], default_config=self.default_config_copy
+ )
+ config["vehicle_config"]["rgb_clip"] = config["rgb_clip"]
+ config["vehicle_config"]["random_agent_model"] = config["random_agent_model"]
+ if config.get("gaussian_noise", 0) > 0:
+ assert config["vehicle_config"]["lidar"]["gaussian_noise"] == 0, "You already provide config!"
+ assert config["vehicle_config"]["side_detector"]["gaussian_noise"] == 0, "You already provide config!"
+ assert config["vehicle_config"]["lane_line_detector"]["gaussian_noise"] == 0, "You already provide config!"
+ config["vehicle_config"]["lidar"]["gaussian_noise"] = config["gaussian_noise"]
+ config["vehicle_config"]["side_detector"]["gaussian_noise"] = config["gaussian_noise"]
+ config["vehicle_config"]["lane_line_detector"]["gaussian_noise"] = config["gaussian_noise"]
+ if config.get("dropout_prob", 0) > 0:
+ assert config["vehicle_config"]["lidar"]["dropout_prob"] == 0, "You already provide config!"
+ assert config["vehicle_config"]["side_detector"]["dropout_prob"] == 0, "You already provide config!"
+ assert config["vehicle_config"]["lane_line_detector"]["dropout_prob"] == 0, "You already provide config!"
+ config["vehicle_config"]["lidar"]["dropout_prob"] = config["dropout_prob"]
+ config["vehicle_config"]["side_detector"]["dropout_prob"] = config["dropout_prob"]
+ config["vehicle_config"]["lane_line_detector"]["dropout_prob"] = config["dropout_prob"]
+ target_v_config = copy.deepcopy(config["vehicle_config"])
+ if not config["is_multi_agent"]:
+ target_v_config.update(config["target_vehicle_configs"][DEFAULT_AGENT])
+ config["target_vehicle_configs"][DEFAULT_AGENT] = target_v_config
+ return config
+
+ def step(self, actions: Union[np.ndarray, Dict[AnyStr, np.ndarray]]):
+ actions = self._preprocess_actions(actions)
+ engine_info = self._step_simulator(actions)
+ o, r, d, i = self._get_step_return(actions, engine_info=engine_info)
+ return o, r, d, i
+
+ def cost_function(self, vehicle_id: str):
+ vehicle = self.vehicles[vehicle_id]
+ step_info = dict()
+ step_info["cost"] = 0
+ if self._is_out_of_road(vehicle):
+ step_info["cost"] = self.config["out_of_road_cost"]
+ elif vehicle.crash_vehicle:
+ step_info["cost"] = self.config["crash_vehicle_cost"]
+ elif vehicle.crash_object:
+ step_info["cost"] = self.config["crash_object_cost"]
+ return step_info['cost'], step_info
+
+ def _is_out_of_road(self, vehicle):
+ ret = vehicle.on_yellow_continuous_line or vehicle.on_white_continuous_line or \
+ (not vehicle.on_lane) or vehicle.crash_sidewalk
+ if self.config["out_of_route_done"]:
+ ret = ret or vehicle.out_of_route
+ return ret
+
+ def done_function(self, vehicle_id: str):
+ vehicle = self.vehicles[vehicle_id]
+ done = False
+ done_info = {
+ TerminationState.CRASH_VEHICLE: False,
+ TerminationState.CRASH_OBJECT: False,
+ TerminationState.CRASH_BUILDING: False,
+ TerminationState.OUT_OF_ROAD: False,
+ TerminationState.SUCCESS: False,
+ TerminationState.MAX_STEP: False,
+ TerminationState.ENV_SEED: self.current_seed,
+ }
+ if self._is_arrive_destination(vehicle):
+ done = True
+ logging.info("Episode ended! Reason: arrive_dest.")
+ done_info[TerminationState.SUCCESS] = True
+ if self._is_out_of_road(vehicle):
+ done = True
+ logging.info("Episode ended! Reason: out_of_road.")
+ done_info[TerminationState.OUT_OF_ROAD] = True
+ if vehicle.crash_vehicle:
+ done = True
+ logging.info("Episode ended! Reason: crash vehicle ")
+ done_info[TerminationState.CRASH_VEHICLE] = True
+ if vehicle.crash_object:
+ done = True
+ done_info[TerminationState.CRASH_OBJECT] = True
+ logging.info("Episode ended! Reason: crash object ")
+ if vehicle.crash_building:
+ done = True
+ done_info[TerminationState.CRASH_BUILDING] = True
+ logging.info("Episode ended! Reason: crash building ")
+ if self.config["max_step_per_agent"] is not None and \
+ self.episode_lengths[vehicle_id] >= self.config["max_step_per_agent"]:
+ done = True
+ done_info[TerminationState.MAX_STEP] = True
+ logging.info("Episode ended! Reason: max step ")
+
+ if self.config["horizon"] is not None and \
+ self.episode_lengths[vehicle_id] >= self.config["horizon"] and not self.is_multi_agent:
+ # single agent horizon has the same meaning as max_step_per_agent
+ done = True
+ done_info[TerminationState.MAX_STEP] = True
+ logging.info("Episode ended! Reason: max step ")
+
+ done_info[TerminationState.CRASH] = (
+ done_info[TerminationState.CRASH_VEHICLE] or done_info[TerminationState.CRASH_OBJECT]
+ or done_info[TerminationState.CRASH_BUILDING]
+ )
+ return done, done_info
+
+ def reward_function(self, vehicle_id: str):
+ """
+ Override this func to get a new reward function
+ :param vehicle_id: id of BaseVehicle
+ :return: reward
+ """
+ vehicle = self.vehicles[vehicle_id]
+ step_info = dict()
+
+ # Reward for moving forward in current lane
+ if vehicle.lane in vehicle.navigation.current_ref_lanes:
+ current_lane = vehicle.lane
+ positive_road = 1
+ else:
+ current_lane = vehicle.navigation.current_ref_lanes[0]
+ current_road = vehicle.navigation.current_road
+ positive_road = 1 if not current_road.is_negative_road() else -1
+ long_last, _ = current_lane.local_coordinates(vehicle.last_position)
+ long_now, lateral_now = current_lane.local_coordinates(vehicle.position)
+
+ # Reward for lane keeping; without it the vehicle can learn to overtake but fails to stay in its lane
+ if self.config["use_lateral_reward"]:
+ lateral_factor = clip(1 - 2 * abs(lateral_now) / vehicle.navigation.get_current_lane_width(), 0.0, 1.0)
+ else:
+ lateral_factor = 1.0
+
+ reward = 0.0
+ reward += self.config["driving_reward"] * (long_now - long_last) * lateral_factor * positive_road
+ reward += self.config["speed_reward"] * (vehicle.speed / vehicle.max_speed) * positive_road
+
+ step_info["step_reward"] = reward
+
+ if self._is_arrive_destination(vehicle):
+ reward = +self.config["success_reward"]
+ elif self._is_out_of_road(vehicle):
+ reward = -self.config["out_of_road_penalty"]
+ elif vehicle.crash_vehicle:
+ reward = -self.config["crash_vehicle_penalty"]
+ elif vehicle.crash_object:
+ reward = -self.config["crash_object_penalty"]
+ return reward, step_info
+
+ def _get_reset_return(self):
+ ret = {}
+ self.engine.after_step()
+ for v_id, v in self.vehicles.items():
+ self.observations[v_id].reset(self, v)
+ ret[v_id] = self.observations[v_id].observe(v)
+ return ret if self.is_multi_agent else self._wrap_as_single_agent(ret)
+
+ def switch_to_third_person_view(self) -> (str, BaseVehicle):
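+ # Cycle the chase camera to another agent: prefer the vehicle named in prefer_track_agent when it
+ # exists, otherwise pick a random vehicle other than the one currently tracked.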
+ if self.main_camera is None:
+ return
+ self.main_camera.reset()
+ if self.config["prefer_track_agent"] is not None and self.config["prefer_track_agent"] in self.vehicles.keys():
+ new_v = self.vehicles[self.config["prefer_track_agent"]]
+ current_track_vehicle = new_v
+ else:
+ if self.main_camera.is_bird_view_camera():
+ current_track_vehicle = self.current_track_vehicle
+ else:
+ vehicles = list(self.engine.agents.values())
+ if len(vehicles) <= 1:
+ return
+ if self.current_track_vehicle in vehicles:
+ vehicles.remove(self.current_track_vehicle)
+ new_v = get_np_random().choice(vehicles)
+ current_track_vehicle = new_v
+ self.main_camera.track(current_track_vehicle)
+ return
+
+ def switch_to_top_down_view(self):
+ self.main_camera.stop_track()
+
+ def setup_engine(self):
+ super(MetaDrivePPOOriginEnv, self).setup_engine()
+ self.engine.accept("b", self.switch_to_top_down_view)
+ self.engine.accept("q", self.switch_to_third_person_view)
+ from metadrive.manager.traffic_manager import TrafficManager
+ from metadrive.manager.map_manager import MapManager
+ self.engine.register_manager("map_manager", MapManager())
+ self.engine.register_manager("traffic_manager", TrafficManager())
+
+ def _is_arrive_destination(self, vehicle):
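+ # A vehicle is considered arrived when it is within 5 m of the end of its final lane (longitudinally)
+ # and laterally inside the drivable band of that lane group.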
+ long, lat = vehicle.navigation.final_lane.local_coordinates(vehicle.position)
+ flag = (vehicle.navigation.final_lane.length - 5 < long < vehicle.navigation.final_lane.length + 5) and (
+ vehicle.navigation.get_current_lane_width() / 2 >= lat >=
+ (0.5 - vehicle.navigation.get_current_lane_num()) * vehicle.navigation.get_current_lane_width()
+ )
+ return flag
+
+ def _reset_global_seed(self, force_seed=None):
+ """
+ Current seed is set to force seed if force_seed is not None.
+ Otherwise, current seed is randomly generated.
+ """
+ current_seed = force_seed if force_seed is not None else \
+ get_np_random(self._DEBUG_RANDOM_SEED).randint(self.start_seed, self.start_seed + self.env_num)
+ self.seed(current_seed)
+
+ def _get_observations(self):
+ return {DEFAULT_AGENT: self.get_single_observation(self.config["vehicle_config"])}
+
+ def get_single_observation(self, _=None):
+ return TopDownMultiChannel(
+ self.config["vehicle_config"],
+ self.config["on_screen"],
+ self.config["rgb_clip"],
+ frame_stack=3,
+ post_stack=10,
+ frame_skip=1,
+ resolution=(84, 84),
+ max_distance=36,
+ )
+
+ def clone(self, caller: str):
+ cfg = copy.deepcopy(self.raw_cfg)
+ return MetaDrivePPOOriginEnv(cfg)
diff --git a/DI-engine/dizoo/metadrive/env/drive_utils.py b/DI-engine/dizoo/metadrive/env/drive_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..2009e5a52d8022b18eed171809e38ed138c209bb
--- /dev/null
+++ b/DI-engine/dizoo/metadrive/env/drive_utils.py
@@ -0,0 +1,121 @@
+from typing import Optional, List
+from gym import utils
+from abc import ABC, abstractmethod
+from typing import Any, Dict, Optional
+from easydict import EasyDict
+from itertools import product
+import gym
+import copy
+import numpy as np
+import matplotlib.pyplot as plt
+from ding.utils.default_helper import deep_merge_dicts
+
+
+class AAA():
+
+ def __init__(self) -> None:
+ self.x = 0
+
+
+def deep_update(
+ original: dict,
+ new_dict: dict,
+ new_keys_allowed: bool = False,
+ whitelist: Optional[List[str]] = None,
+ override_all_if_type_changes: Optional[List[str]] = None
+):
+ """
+ Overview:
+ Updates original dict with values from new_dict recursively.
+
+ .. note::
+
+ If a new key is introduced in new_dict and new_keys_allowed is not
+ True, an error will be thrown. Further, for sub-dicts, if the key is
+ in the whitelist, new subkeys can be introduced.
+
+ Arguments:
+ - original (:obj:`dict`): Dictionary with default values.
+ - new_dict (:obj:`dict`): Dictionary with values to be updated
+ - new_keys_allowed (:obj:`bool`): Whether new keys are allowed.
+ - whitelist (Optional[List[str]]): List of keys that correspond to dict
+ values where new subkeys can be introduced. This is only at the top
+ level.
+ - override_all_if_type_changes(Optional[List[str]]): List of top level
+ keys with value=dict, for which we always simply override the
+ entire value (:obj:`dict`), if the "type" key in that value dict changes.
+ """
+ whitelist = whitelist or []
+ override_all_if_type_changes = override_all_if_type_changes or []
+ for k, value in new_dict.items():
+ if k not in original and not new_keys_allowed:
+ raise RuntimeError("Unknown config parameter `{}`. Base config has: {}.".format(k, original.keys()))
+ # Both original value and new one are dicts.
+ if isinstance(original.get(k), dict) and isinstance(value, dict):
+ # Check old type vs new one. If different, override the entire value.
+ if k in override_all_if_type_changes and \
+ "type" in value and "type" in original[k] and \
+ value["type"] != original[k]["type"]:
+ original[k] = value
+ # Whitelisted key -> ok to add new subkeys.
+ elif k in whitelist:
+ deep_update(original[k], value, True)
+ # Non-whitelisted key.
+ else:
+ deep_update(original[k], value, new_keys_allowed)
+ # Original value not a dict OR new value not a dict:
+ # Override entire value.
+ else:
+ original[k] = value
+ return original
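+# Minimal usage sketch of deep_update with made-up dictionaries (the keys below are illustrative,
+# not real DI-engine config fields):
+#   base = {'model': {'type': 'conv', 'hidden': 64}, 'lr': 1e-3}
+#   patch = {'model': {'dropout': 0.1}}
+#   deep_update(base, patch, new_keys_allowed=False, whitelist=['model'])
+#   -> {'model': {'type': 'conv', 'hidden': 64, 'dropout': 0.1}, 'lr': 1e-3}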
+
+
+class BaseDriveEnv(gym.Env, utils.EzPickle):
+ config = dict()
+
+ @abstractmethod
+ def __init__(self, cfg: Dict, **kwargs) -> None:
+ if 'cfg_type' not in cfg:
+ self._cfg = self.__class__.default_config()
+ self._cfg = deep_merge_dicts(self._cfg, cfg)
+ else:
+ self._cfg = cfg
+ utils.EzPickle.__init__(self)
+
+ @abstractmethod
+ def step(self, action: Any) -> Any:
+ """
+ Run one step of the environment and return the observation dict.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def reset(self, *args, **kwargs) -> Any:
+ """
+ Reset current environment.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def close(self) -> None:
+ """
+ Release all resources in environment and close.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def seed(self, seed: int) -> None:
+ """
+ Set random seed.
+ """
+ raise NotImplementedError
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ cfg = EasyDict(cls.config)
+ cfg.cfg_type = cls.__name__ + 'Config'
+ return copy.deepcopy(cfg)
+
+ @abstractmethod
+ def __repr__(self) -> str:
+ raise NotImplementedError
diff --git a/DI-engine/dizoo/metadrive/env/drive_wrapper.py b/DI-engine/dizoo/metadrive/env/drive_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b1a1373fdcd4bb4c7a832465dd43ece677a797b
--- /dev/null
+++ b/DI-engine/dizoo/metadrive/env/drive_wrapper.py
@@ -0,0 +1,149 @@
+from typing import Any, Dict, Optional
+from easydict import EasyDict
+import matplotlib.pyplot as plt
+import gym
+import copy
+import numpy as np
+from ding.envs.env.base_env import BaseEnvTimestep
+from ding.torch_utils.data_helper import to_ndarray
+from ding.utils.default_helper import deep_merge_dicts
+from dizoo.metadrive.env.drive_utils import BaseDriveEnv
+
+
+def draw_multi_channels_top_down_observation(obs, show_time=0.5):
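+ # Render the 5-channel top-down observation as a row of grayscale panels and auto-close the
+ # figure after ``show_time`` seconds via a matplotlib timer.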
+ num_channels = obs.shape[-1]
+ assert num_channels == 5
+ channel_names = [
+ "Road and navigation", "Ego now and previous pos", "Neighbor at step t", "Neighbor at step t-1",
+ "Neighbor at step t-2"
+ ]
+ fig, axs = plt.subplots(1, num_channels, figsize=(15, 4), dpi=80)
+ count = 0
+
+ def close_event():
+ plt.close()
+
+ timer = fig.canvas.new_timer(interval=show_time * 1000)
+ timer.add_callback(close_event)
+ for i, name in enumerate(channel_names):
+ count += 1
+ ax = axs[i]
+ ax.imshow(obs[..., i], cmap="bone")
+ ax.set_xticks([])
+ ax.set_yticks([])
+ ax.set_title(name)
+ fig.suptitle("Multi-channels Top-down Observation")
+ timer.start()
+ plt.show()
+ plt.close()
+
+
+class DriveEnvWrapper(gym.Wrapper):
+ """
+ Overview:
+ Environment wrapper to make ``gym.Env`` align with DI-engine definitions, so as to use utilities in DI-engine.
+ It changes the ``step``, ``reset`` and ``info`` methods of ``gym.Env``, while other methods are passed through unchanged.
+
+ Arguments:
+ - env (BaseDriveEnv): The environment to be wrapped.
+ - cfg (Dict): Config dict.
+ """
+ config = dict()
+
+ def __init__(self, env: BaseDriveEnv, cfg: Dict = None, **kwargs) -> None:
+ if cfg is None:
+ self._cfg = self.__class__.default_config()
+ elif 'cfg_type' not in cfg:
+ self._cfg = self.__class__.default_config()
+ self._cfg = deep_merge_dicts(self._cfg, cfg)
+ else:
+ self._cfg = cfg
+ self.env = env
+ if not hasattr(self.env, 'reward_space'):
+ self.reward_space = gym.spaces.Box(low=-float('inf'), high=float('inf'), shape=(1, ))
+ if 'show_bird_view' in self._cfg and self._cfg['show_bird_view'] is True:
+ self.show_bird_view = True
+ else:
+ self.show_bird_view = False
+ self.action_space = self.env.action_space
+ self.env = env
+
+ def reset(self, *args, **kwargs) -> Any:
+ """
+ Overview:
+ Wrapper of ``reset`` method in env. The observations are converted to ``np.ndarray`` and the
+ episode return accumulator is reset.
+ Returns:
+ - Any: Observations from environment
+ """
+ obs = self.env.reset(*args, **kwargs)
+ obs = to_ndarray(obs, dtype=np.float32)
+ if isinstance(obs, np.ndarray) and len(obs.shape) == 3:
+ obs = obs.transpose((2, 0, 1))
+ elif isinstance(obs, dict):
+ vehicle_state = obs['vehicle_state']
+ birdview = obs['birdview'].transpose((2, 0, 1))
+ obs = {'vehicle_state': vehicle_state, 'birdview': birdview}
+ self._eval_episode_return = 0.0
+ self._arrive_dest = False
+ return obs
+
+ def step(self, action: Any = None) -> BaseEnvTimestep:
+ """
+ Overview:
+ Wrapper of ``step`` method in env. This aims to convert the returns of the ``gym.Env`` step method into
+ those of ``ding.envs.BaseEnv``, from an ``(obs, reward, done, info)`` tuple to a ``BaseEnvTimestep``
+ namedtuple defined in DI-engine. It also converts actions, observations and rewards into
+ ``np.ndarray``, and checks legality if the action contains control signals.
+ Arguments:
+ - action (Any, optional): Actions sent to env. Defaults to None.
+ Returns:
+ - BaseEnvTimestep: DI-engine format of env step returns.
+ """
+ action = to_ndarray(action)
+ obs, rew, done, info = self.env.step(action)
+ if self.show_bird_view:
+ draw_multi_channels_top_down_observation(obs, show_time=0.5)
+ self._eval_episode_return += rew
+ obs = to_ndarray(obs, dtype=np.float32)
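+ # Image-like observations come back as HWC; transpose to CHW to match the (5, 84, 84)
+ # observation_space exposed by this wrapper.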
+ if isinstance(obs, np.ndarray) and len(obs.shape) == 3:
+ obs = obs.transpose((2, 0, 1))
+ elif isinstance(obs, dict):
+ vehicle_state = obs['vehicle_state']
+ birdview = obs['birdview'].transpose((2, 0, 1))
+ obs = {'vehicle_state': vehicle_state, 'birdview': birdview}
+ rew = to_ndarray([rew], dtype=np.float32)
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ @property
+ def observation_space(self):
+ return gym.spaces.Box(0, 1, shape=(5, 84, 84), dtype=np.float32)
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ if replay_path is None:
+ replay_path = './video'
+ self._replay_path = replay_path
+ self.env = gym.wrappers.Monitor(self.env, self._replay_path, video_callable=lambda episode_id: True, force=True)
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ cfg = EasyDict(cls.config)
+ cfg.cfg_type = cls.__name__ + 'Config'
+ return copy.deepcopy(cfg)
+
+ def __repr__(self) -> str:
+ return repr(self.env)
+
+ def render(self):
+ self.env.render()
+
+ def clone(self, caller: str):
+ cfg = copy.deepcopy(self._cfg)
+ return DriveEnvWrapper(self.env.clone(caller), cfg)
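+
+
+# Minimal usage sketch (mirroring the config files above; the action values are placeholders):
+#   env = DriveEnvWrapper(MetaDrivePPOOriginEnv(cfg.env.metadrive), {'show_bird_view': False})
+#   obs = env.reset()  # np.float32 array of shape (5, 84, 84)
+#   timestep = env.step(np.array([0.0, 0.1], dtype=np.float32))  # BaseEnvTimestep(obs, rew, done, info)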
diff --git a/DI-engine/dizoo/minigrid/__init__.py b/DI-engine/dizoo/minigrid/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..db6673867bedab3f4b20be4e0418affea76a893a
--- /dev/null
+++ b/DI-engine/dizoo/minigrid/__init__.py
@@ -0,0 +1,15 @@
+from gymnasium.envs.registration import register
+
+register(id='MiniGrid-AKTDT-7x7-1-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure_7x7_1')
+
+register(id='MiniGrid-AKTDT-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure')
+
+register(id='MiniGrid-AKTDT-13x13-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure_13x13')
+
+register(id='MiniGrid-AKTDT-13x13-1-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure_13x13_1')
+
+register(id='MiniGrid-AKTDT-19x19-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure_19x19')
+
+register(id='MiniGrid-AKTDT-19x19-3-v0', entry_point='dizoo.minigrid.envs:AppleKeyToDoorTreasure_19x19_3')
+
+register(id='MiniGrid-NoisyTV-v0', entry_point='dizoo.minigrid.envs:NoisyTVEnv')
\ No newline at end of file
diff --git a/DI-engine/dizoo/minigrid/config/minigrid_icm_offppo_config.py b/DI-engine/dizoo/minigrid/config/minigrid_icm_offppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..92465fd1fb676f252c1fd10cb8f13d06674320bf
--- /dev/null
+++ b/DI-engine/dizoo/minigrid/config/minigrid_icm_offppo_config.py
@@ -0,0 +1,81 @@
+from easydict import EasyDict
+
+minigrid_icm_offppo_config = dict(
+ exp_name='minigrid_fourroom_icm_offppo_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ # minigrid env id: 'MiniGrid-Empty-8x8-v0', 'MiniGrid-FourRooms-v0','MiniGrid-DoorKey-16x16-v0','MiniGrid-AKTDT-7x7-1-v0'
+ env_id='MiniGrid-FourRooms-v0',
+ max_step=100,
+ stop_value=0.96,
+ ),
+ reward_model=dict(
+ intrinsic_reward_type='add',
+ # intrinsic_reward_weight means the relative weight of the ICM intrinsic reward.
+ # Specifically for the sparse-reward env MiniGrid:
+ # if the agent reaches the goal it gets a reward of ~1, otherwise 0.
+ # We could set intrinsic_reward_weight approximately equal to the inverse of max_episode_steps.
+ # Please refer to rnd_reward_model for details.
+ intrinsic_reward_weight=0.001,
+ learning_rate=3e-4,
+ obs_shape=2835,
+ batch_size=320,
+ update_per_collect=50,
+ clear_buffer_per_iters=int(1e3),
+ obs_norm=True,
+ obs_norm_clamp_max=5,
+ obs_norm_clamp_min=-5,
+ extrinsic_reward_norm=True,
+ extrinsic_reward_norm_max=1,
+ ),
+ policy=dict(
+ cuda=True,
+ recompute_adv=True,
+ action_space='discrete',
+ model=dict(
+ obs_shape=2835,
+ action_shape=7,
+ action_space='discrete',
+ encoder_hidden_size_list=[256, 128, 64, 64],
+ critic_head_hidden_size=64,
+ actor_head_hidden_size=64,
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=320,
+ learning_rate=3e-4,
+ value_weight=0.5,
+ entropy_weight=0.001,
+ clip_ratio=0.2,
+ adv_norm=True,
+ value_norm=True,
+ ),
+ collect=dict(
+ n_sample=3200,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ eval=dict(evaluator=dict(eval_freq=200, )),
+ ),
+)
+minigrid_icm_offppo_config = EasyDict(minigrid_icm_offppo_config)
+main_config = minigrid_icm_offppo_config
+minigrid_icm_offppo_create_config = dict(
+ env=dict(
+ type='minigrid',
+ import_names=['dizoo.minigrid.envs.minigrid_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo_offpolicy'),
+ reward_model=dict(type='icm'),
+)
+minigrid_icm_offppo_create_config = EasyDict(minigrid_icm_offppo_create_config)
+create_config = minigrid_icm_offppo_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c minigrid_icm_config.py -s 0`
+ from ding.entry import serial_pipeline_reward_model_offpolicy
+ serial_pipeline_reward_model_offpolicy([main_config, create_config], seed=0, max_env_step=int(10e6))
diff --git a/DI-engine/dizoo/minigrid/config/minigrid_icm_onppo_config.py b/DI-engine/dizoo/minigrid/config/minigrid_icm_onppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..dc21fe5fc438bd5660eb03287c19cd6c0071c965
--- /dev/null
+++ b/DI-engine/dizoo/minigrid/config/minigrid_icm_onppo_config.py
@@ -0,0 +1,82 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 5
+minigrid_icm_onppo_config = dict(
+ exp_name='minigrid_AKTDT-7x7_icm_onppo_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ # minigrid env id: 'MiniGrid-Empty-8x8-v0', 'MiniGrid-FourRooms-v0','MiniGrid-DoorKey-16x16-v0','MiniGrid-AKTDT-7x7-1-v0'
+ env_id='MiniGrid-NoisyTV-v0',
+ max_step=100,
+ stop_value=12, # run fixed env_steps for MiniGrid-AKTDT-7x7-1-v0
+ # stop_value=0.96,
+ ),
+ reward_model=dict(
+ intrinsic_reward_type='add',
+ # intrinsic_reward_weight means the relative weight of the ICM intrinsic reward.
+ # Specifically for the sparse-reward env MiniGrid:
+ # if the agent reaches the goal it gets a reward of ~1, otherwise 0.
+ # We could set intrinsic_reward_weight approximately equal to the inverse of max_episode_steps.
+ # Please refer to rnd_reward_model for details.
+ intrinsic_reward_weight=0.003, # 1/300
+ learning_rate=3e-4,
+ obs_shape=2835, # 2715 in MiniGrid-AKTDT-7x7-1-v0 env
+ batch_size=320,
+ update_per_collect=50,
+ clear_buffer_per_iters=int(1e3),
+ extrinsic_reward_norm=True,
+ extrinsic_reward_norm_max=1,
+ ),
+ policy=dict(
+ cuda=True,
+ recompute_adv=True,
+ action_space='discrete',
+ model=dict(
+ obs_shape=2835, # 2715 in MiniGrid-AKTDT-7x7-1-v0 env
+ action_shape=7,
+ action_space='discrete',
+ encoder_hidden_size_list=[256, 128, 64, 64],
+ critic_head_hidden_size=64,
+ actor_head_hidden_size=64,
+ ),
+ learn=dict(
+ epoch_per_collect=10,
+ update_per_collect=1,
+ batch_size=320,
+ learning_rate=3e-4,
+ value_weight=0.5,
+ entropy_weight=0.001,
+ clip_ratio=0.2,
+ adv_norm=True,
+ value_norm=True,
+ ),
+ collect=dict(
+ n_sample=3200,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, )),
+ ),
+)
+minigrid_icm_onppo_config = EasyDict(minigrid_icm_onppo_config)
+main_config = minigrid_icm_onppo_config
+minigrid_icm_onppo_create_config = dict(
+ env=dict(
+ type='minigrid',
+ import_names=['dizoo.minigrid.envs.minigrid_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+ reward_model=dict(type='icm'),
+)
+minigrid_icm_onppo_create_config = EasyDict(minigrid_icm_onppo_create_config)
+create_config = minigrid_icm_onppo_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c minigrid_icm_onppo_config.py -s 0`
+ from ding.entry import serial_pipeline_reward_model_onpolicy
+ serial_pipeline_reward_model_onpolicy([main_config, create_config], seed=0, max_env_step=int(10e6))
\ No newline at end of file
diff --git a/DI-engine/dizoo/minigrid/config/minigrid_ngu_config.py b/DI-engine/dizoo/minigrid/config/minigrid_ngu_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1aa47e1ebe9afc7bac889c30b8c16f1352cf0e6
--- /dev/null
+++ b/DI-engine/dizoo/minigrid/config/minigrid_ngu_config.py
@@ -0,0 +1,129 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 8
+nstep = 5
+minigrid_ppo_ngu_config = dict(
+ exp_name='minigrid_doorkey_ngu_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ # typical MiniGrid env id:
+ # {'MiniGrid-Empty-8x8-v0', 'MiniGrid-FourRooms-v0', 'MiniGrid-DoorKey-8x8-v0','MiniGrid-DoorKey-16x16-v0'},
+ # please refer to https://github.com/Farama-Foundation/MiniGrid for details.
+ env_id='MiniGrid-DoorKey-16x16-v0',
+ obs_plus_prev_action_reward=True, # use specific env wrapper for ngu policy
+ max_step=300,
+ stop_value=0.96,
+ ),
+ rnd_reward_model=dict(
+ intrinsic_reward_type='add',
+ learning_rate=5e-4,
+ obs_shape=2835,
+ action_shape=7,
+ batch_size=320, # transitions
+ update_per_collect=10, # 32*100/320=10
+ only_use_last_five_frames_for_icm_rnd=False,
+ clear_buffer_per_iters=10,
+ nstep=nstep,
+ hidden_size_list=[128, 128, 64],
+ type='rnd-ngu',
+ ),
+ episodic_reward_model=dict(
+ # Whether to apply the rescale trick to the last non-zero reward
+ # when combining extrinsic and intrinsic rewards.
+ # The rescale trick is only used in:
+ # 1. sparse-reward envs such as MiniGrid, where the last non-zero reward is a strong positive signal;
+ # 2. envs where the last reward of each episode directly reflects the agent's completion of the task, e.g. LunarLander.
+ # Note that the NGU intrinsic reward is a positive value (max value is 5); in these envs,
+ # the last non-zero reward should not be overwhelmed by intrinsic rewards, so we need to rescale the
+ # original last non-zero extrinsic reward.
+ # please refer to ngu_reward_model for details.
+ last_nonzero_reward_rescale=True,
+ # means the rescale value for the last non-zero reward, only used when last_nonzero_reward_rescale is True
+ # please refer to ngu_reward_model for details.
+ last_nonzero_reward_weight=100,
+ intrinsic_reward_type='add',
+ learning_rate=5e-4,
+ obs_shape=2739,
+ action_shape=7,
+ batch_size=320, # transitions
+ update_per_collect=10, # 32*100/320=10
+ only_use_last_five_frames_for_icm_rnd=False,
+ clear_buffer_per_iters=10,
+ nstep=nstep,
+ hidden_size_list=[128, 128, 64],
+ type='episodic',
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ priority_IS_weight=True,
+ discount_factor=0.997,
+ nstep=nstep,
+ burnin_step=2,
+ # (int) learn_unroll_len is the total length of a sequence sample minus
+ # the length of its burn-in part,
+ # i.e., sequence sample length = unroll_len = burnin_step + learn_unroll_len
+ learn_unroll_len=298, # set this key according to the episode length
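+ # e.g. here each training sequence is burnin_step + learn_unroll_len = 2 + 298 = 300 timesteps,
+ # matching the env's max_step=300.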
+ model=dict(
+ obs_shape=2739,
+ action_shape=7,
+ encoder_hidden_size_list=[128, 128, 512],
+ collector_env_num=collector_env_num,
+ ),
+ learn=dict(
+ update_per_collect=16,
+ batch_size=64,
+ learning_rate=1e-4,
+ target_update_theta=0.001,
+ ),
+ collect=dict(
+ # NOTE: It is important to set the key traj_len_inf=True here,
+ # to make sure self._traj_len=INF in serial_sample_collector.py.
+ # In a sequence-based policy, for each collect_env,
+ # we want to collect data of length self._traj_len=INF
+ # unless the episode enters the 'done' state.
+ # In each collect phase, we collect a total of n_sample sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=1e5,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=30000,
+ # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
+ alpha=0.6,
+ # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+ beta=0.4,
+ )
+ ),
+ ),
+)
+minigrid_ppo_ngu_config = EasyDict(minigrid_ppo_ngu_config)
+main_config = minigrid_ppo_ngu_config
+minigrid_ppo_ngu_create_config = dict(
+ env=dict(
+ type='minigrid',
+ import_names=['dizoo.minigrid.envs.minigrid_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ngu'),
+ rnd_reward_model=dict(type='rnd-ngu'),
+ episodic_reward_model=dict(type='episodic'),
+)
+minigrid_ppo_ngu_create_config = EasyDict(minigrid_ppo_ngu_create_config)
+create_config = minigrid_ppo_ngu_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_ngu -c minigrid_ngu_config.py -s 0`
+ from ding.entry import serial_pipeline_ngu
+ serial_pipeline_ngu([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/minigrid/config/minigrid_offppo_config.py b/DI-engine/dizoo/minigrid/config/minigrid_offppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..9ae43f5a84dad44f6dddceb92583ca070724f5d2
--- /dev/null
+++ b/DI-engine/dizoo/minigrid/config/minigrid_offppo_config.py
@@ -0,0 +1,56 @@
+from easydict import EasyDict
+
+minigrid_ppo_config = dict(
+ exp_name="minigrid_empty8_offppo_seed0",
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ # typical MiniGrid env id:
+ # {'MiniGrid-Empty-8x8-v0', 'MiniGrid-FourRooms-v0', 'MiniGrid-DoorKey-8x8-v0','MiniGrid-DoorKey-16x16-v0'},
+ # please refer to https://github.com/Farama-Foundation/MiniGrid for details.
+ env_id='MiniGrid-Empty-8x8-v0',
+ n_evaluator_episode=5,
+ max_step=300,
+ stop_value=0.96,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=2835,
+ action_shape=7,
+ encoder_hidden_size_list=[256, 128, 64, 64],
+ ),
+ learn=dict(
+ update_per_collect=4,
+ batch_size=64,
+ learning_rate=0.0003,
+ value_weight=0.5,
+ entropy_weight=0.001,
+ clip_ratio=0.2,
+ adv_norm=False,
+ ),
+ collect=dict(
+ n_sample=128,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ ),
+)
+minigrid_ppo_config = EasyDict(minigrid_ppo_config)
+main_config = minigrid_ppo_config
+minigrid_ppo_create_config = dict(
+ env=dict(
+ type='minigrid',
+ import_names=['dizoo.minigrid.envs.minigrid_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo_offpolicy'),
+)
+minigrid_ppo_create_config = EasyDict(minigrid_ppo_create_config)
+create_config = minigrid_ppo_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c minigrid_offppo_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/minigrid/config/minigrid_onppo_config.py b/DI-engine/dizoo/minigrid/config/minigrid_onppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..de4f81cd8e8f6a3e687feb0d8f8ad75a691338ee
--- /dev/null
+++ b/DI-engine/dizoo/minigrid/config/minigrid_onppo_config.py
@@ -0,0 +1,63 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+minigrid_ppo_config = dict(
+ exp_name="minigrid_empty8_onppo_seed0",
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ # typical MiniGrid env id:
+ # {'MiniGrid-Empty-8x8-v0', 'MiniGrid-FourRooms-v0', 'MiniGrid-DoorKey-8x8-v0','MiniGrid-DoorKey-16x16-v0'},
+ # please refer to https://github.com/Farama-Foundation/MiniGrid for details.
+ env_id='MiniGrid-Empty-8x8-v0',
+ max_step=300,
+ stop_value=0.96,
+ ),
+ policy=dict(
+ cuda=True,
+ recompute_adv=True,
+ action_space='discrete',
+ model=dict(
+ obs_shape=2835,
+ action_shape=7,
+ action_space='discrete',
+ encoder_hidden_size_list=[256, 128, 64, 64],
+ ),
+ learn=dict(
+ epoch_per_collect=10,
+ update_per_collect=1,
+ batch_size=320,
+ learning_rate=3e-4,
+ value_weight=0.5,
+ entropy_weight=0.001,
+ clip_ratio=0.2,
+ adv_norm=True,
+ value_norm=True,
+ ),
+ collect=dict(
+ collector_env_num=collector_env_num,
+ n_sample=int(3200),
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ ),
+)
+minigrid_ppo_config = EasyDict(minigrid_ppo_config)
+main_config = minigrid_ppo_config
+minigrid_ppo_create_config = dict(
+ env=dict(
+ type='minigrid',
+ import_names=['dizoo.minigrid.envs.minigrid_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+)
+minigrid_ppo_create_config = EasyDict(minigrid_ppo_create_config)
+create_config = minigrid_ppo_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c minigrid_onppo_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/minigrid/config/minigrid_onppo_stdim_config.py b/DI-engine/dizoo/minigrid/config/minigrid_onppo_stdim_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4e025984cf89b2760f731620b5af233aa5ac1c8
--- /dev/null
+++ b/DI-engine/dizoo/minigrid/config/minigrid_onppo_stdim_config.py
@@ -0,0 +1,72 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 5
+minigrid_ppo_stdim_config = dict(
+ exp_name="minigrid_empty8_onppo_stdim_seed0",
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ # typical MiniGrid env id:
+ # {'MiniGrid-Empty-8x8-v0', 'MiniGrid-FourRooms-v0', 'MiniGrid-DoorKey-8x8-v0','MiniGrid-DoorKey-16x16-v0'},
+ # please refer to https://github.com/Farama-Foundation/MiniGrid for details.
+ env_id='MiniGrid-Empty-8x8-v0',
+ max_step=300,
+ stop_value=0.96,
+ ),
+ policy=dict(
+ cuda=True,
+ recompute_adv=True,
+ action_space='discrete',
+ model=dict(
+ obs_shape=2835,
+ action_shape=7,
+ action_space='discrete',
+ encoder_hidden_size_list=[256, 128, 64, 64],
+ ),
+ aux_model=dict(
+ encode_shape=64,
+ heads=[1, 1],
+ loss_type='infonce',
+ temperature=1.0,
+ ),
+ # the weight of the auxiliary (ST-DIM) contrastive loss relative to the main RL loss
+ aux_loss_weight=0.003,
+ learn=dict(
+ epoch_per_collect=10,
+ update_per_collect=1,
+ batch_size=320,
+ learning_rate=3e-4,
+ value_weight=0.5,
+ entropy_weight=0.001,
+ clip_ratio=0.2,
+ adv_norm=True,
+ value_norm=True,
+ ),
+ collect=dict(
+ collector_env_num=collector_env_num,
+ n_sample=int(3200),
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ ),
+)
+minigrid_ppo_stdim_config = EasyDict(minigrid_ppo_stdim_config)
+main_config = minigrid_ppo_stdim_config
+minigrid_ppo_stdim_create_config = dict(
+ env=dict(
+ type='minigrid',
+ import_names=['dizoo.minigrid.envs.minigrid_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo_stdim'),
+)
+minigrid_ppo_stdim_create_config = EasyDict(minigrid_ppo_stdim_create_config)
+create_config = minigrid_ppo_stdim_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c minigrid_onppo_stdim_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/minigrid/config/minigrid_r2d2_config.py b/DI-engine/dizoo/minigrid/config/minigrid_r2d2_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b650a503dfdfacb3f24c888f733be5e9e67c7130
--- /dev/null
+++ b/DI-engine/dizoo/minigrid/config/minigrid_r2d2_config.py
@@ -0,0 +1,91 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 5
+minigrid_r2d2_config = dict(
+ exp_name='debug_minigrid_doorkey_r2d2_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ # typical MiniGrid env id:
+ # {'MiniGrid-Empty-8x8-v0', 'MiniGrid-FourRooms-v0', 'MiniGrid-DoorKey-8x8-v0','MiniGrid-DoorKey-16x16-v0'},
+ # please refer to https://github.com/Farama-Foundation/MiniGrid for details.
+ env_id='MiniGrid-DoorKey-16x16-v0',
+ n_evaluator_episode=5,
+ max_step=300,
+ stop_value=0.96,
+ ),
+ policy=dict(
+ cuda=True,
+ on_policy=False,
+ priority=True,
+ priority_IS_weight=True,
+ model=dict(
+ obs_shape=2835,
+ action_shape=7,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ discount_factor=0.997,
+ nstep=5,
+ burnin_step=2,
+ # (int) the whole sequence length to unroll the RNN network minus
+ # the timesteps of burnin part,
+ # i.e., the whole sequence length = unroll_len = burnin_step + learn_unroll_len
+ learn_unroll_len=40,
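+ # e.g. here each training sequence is burnin_step + learn_unroll_len = 2 + 40 = 42 timesteps.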
+ learn=dict(
+ # according to the R2D2 paper, actor parameter update interval is 400
+ # environment timesteps, and in each collect phase we collect n_sample sequence
+ # samples, where the length of each sequence sample is burnin_step + learn_unroll_len,
+ # e.g. if n_sample=32 and each sequence sample is 100 timesteps long, then 32*100/400=8,
+ # we will set update_per_collect=8 in most environments.
+ update_per_collect=8,
+ batch_size=64,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ ),
+ collect=dict(
+ # NOTE: It is important to set the key traj_len_inf=True here,
+ # to make sure self._traj_len=INF in serial_sample_collector.py.
+ # In a sequence-based policy, for each collect_env,
+ # we want to collect data of length self._traj_len=INF
+ # unless the episode enters the 'done' state.
+ # In each collect phase, we collect a total of n_sample sequence samples.
+ n_sample=32,
+ traj_len_inf=True,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.95,
+ end=0.05,
+ decay=1e5,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=100000,
+ # (Float type) How much prioritization is used: 0 means no prioritization while 1 means full prioritization
+ alpha=0.6,
+ # (Float type) How much correction is used: 0 means no correction while 1 means full correction
+ beta=0.4,
+ )
+ ),
+ ),
+)
+minigrid_r2d2_config = EasyDict(minigrid_r2d2_config)
+main_config = minigrid_r2d2_config
+minigrid_r2d2_create_config = dict(
+ env=dict(
+ type='minigrid',
+ import_names=['dizoo.minigrid.envs.minigrid_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='r2d2'),
+)
+minigrid_r2d2_create_config = EasyDict(minigrid_r2d2_create_config)
+create_config = minigrid_r2d2_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c minigrid_r2d2_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/minigrid/config/minigrid_rnd_onppo_config.py b/DI-engine/dizoo/minigrid/config/minigrid_rnd_onppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..a39d3e4d8f26d988599dc091813e7cf11c5730fe
--- /dev/null
+++ b/DI-engine/dizoo/minigrid/config/minigrid_rnd_onppo_config.py
@@ -0,0 +1,89 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 5
+minigrid_ppo_rnd_config = dict(
+ exp_name='minigrid_doorkey8x8_rnd_onppo_seed0',
+ env=dict(
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ # typical MiniGrid env id:
+ # {'MiniGrid-Empty-8x8-v0', 'MiniGrid-FourRooms-v0', 'MiniGrid-DoorKey-8x8-v0','MiniGrid-DoorKey-16x16-v0'},
+ # please refer to https://github.com/Farama-Foundation/MiniGrid for details.
+ env_id='MiniGrid-DoorKey-8x8-v0',
+ # env_id='MiniGrid-AKTDT-7x7-1-v0',
+ max_step=100,
+ stop_value=20, # run fixed env_steps
+ # stop_value=0.96,
+ ),
+ reward_model=dict(
+ intrinsic_reward_type='add',
+ # intrinsic_reward_weight means the relative weight of the RND intrinsic reward.
+ # Specifically for the sparse-reward env MiniGrid:
+ # if the agent reaches the goal it gets a reward of ~1, otherwise 0.
+ # We could set intrinsic_reward_weight approximately equal to the inverse of max_episode_steps.
+ # Please refer to rnd_reward_model for details.
+ intrinsic_reward_weight=0.003, # 1/300
+ learning_rate=3e-4,
+ obs_shape=2835,
+ batch_size=320,
+ update_per_collect=50,
+ clear_buffer_per_iters=int(1e3),
+ obs_norm=False,
+ obs_norm_clamp_max=5,
+ obs_norm_clamp_min=-5,
+ extrinsic_reward_norm=True,
+ extrinsic_reward_norm_max=1,
+ ),
+ policy=dict(
+ recompute_adv=True,
+ cuda=True,
+ action_space='discrete',
+ model=dict(
+ obs_shape=2835,
+ action_shape=7,
+ action_space='discrete',
+ encoder_hidden_size_list=[256, 128, 64, 64],
+ critic_head_hidden_size=64,
+ actor_head_hidden_size=64,
+ ),
+ learn=dict(
+ epoch_per_collect=10,
+ update_per_collect=1,
+ batch_size=320,
+ learning_rate=3e-4,
+ value_weight=0.5,
+ entropy_weight=0.001,
+ clip_ratio=0.2,
+ adv_norm=True,
+ value_norm=True,
+ ),
+ collect=dict(
+ collector_env_num=collector_env_num,
+ n_sample=3200,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, )),
+ ),
+)
+minigrid_ppo_rnd_config = EasyDict(minigrid_ppo_rnd_config)
+main_config = minigrid_ppo_rnd_config
+minigrid_ppo_rnd_create_config = dict(
+ env=dict(
+ type='minigrid',
+ import_names=['dizoo.minigrid.envs.minigrid_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+ reward_model=dict(type='rnd'),
+)
+minigrid_ppo_rnd_create_config = EasyDict(minigrid_ppo_rnd_create_config)
+create_config = minigrid_ppo_rnd_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c minigrid_rnd_onppo_config.py -s 0`
+ from ding.entry import serial_pipeline_reward_model_onpolicy
+ serial_pipeline_reward_model_onpolicy([main_config, create_config], seed=0, max_env_step=int(10e6))
diff --git a/DI-engine/dizoo/minigrid/entry/minigrid_onppo_main.py b/DI-engine/dizoo/minigrid/entry/minigrid_onppo_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..eeb97b2fe76df8c88bd84e5cc9182991ba48f16c
--- /dev/null
+++ b/DI-engine/dizoo/minigrid/entry/minigrid_onppo_main.py
@@ -0,0 +1,87 @@
+import gym
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator
+from ding.model import VAC
+from ding.policy import PPOPolicy
+from ding.envs import DingEnvWrapper, EvalEpisodeReturnWrapper, BaseEnvManager
+from ding.config import compile_config
+from ding.utils import set_pkg_seed
+
+from dizoo.minigrid.config.minigrid_onppo_config import minigrid_ppo_config
+from minigrid.wrappers import FlatObsWrapper
+import numpy as np
+from tensorboardX import SummaryWriter
+import os
+import gymnasium
+
+
+class MinigridWrapper(gym.Wrapper):
+
+ def __init__(self, env):
+ super().__init__(env)
+ self._observation_space = gym.spaces.Box(low=float("-inf"), high=float("inf"), shape=(8, ), dtype=np.float32)
+ self._action_space = gym.spaces.Discrete(9)
+ self._action_space.seed(0) # default seed
+ self.reward_range = (float('-inf'), float('inf'))
+ self.max_steps = minigrid_ppo_config.env.max_step
+
+ def step(self, action):
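+ # Gymnasium's step returns (obs, reward, terminated, truncated, info); keep terminated as done,
+ # drop truncated, and enforce the configured step limit manually.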
+ obs, reward, done, _, info = self.env.step(action)
+ self.cur_step += 1
+ if self.cur_step > self.max_steps:
+ done = True
+ return obs, reward, done, info
+
+ def reset(self):
+ self.cur_step = 0
+ return self.env.reset()[0]
+
+
+def wrapped_minigrid_env():
+ return DingEnvWrapper(
+ gymnasium.make(minigrid_ppo_config.env.env_id),
+ cfg={
+ 'env_wrapper': [
+ lambda env: FlatObsWrapper(env),
+ lambda env: MinigridWrapper(env),
+ lambda env: EvalEpisodeReturnWrapper(env),
+ ]
+ }
+ )
+
+
+def main(cfg, seed=0, max_env_step=int(1e10), max_train_iter=int(1e10)):
+ cfg = compile_config(
+ cfg, BaseEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, save_cfg=True
+ )
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+ collector_env = BaseEnvManager(env_fn=[wrapped_minigrid_env for _ in range(collector_env_num)], cfg=cfg.env.manager)
+ evaluator_env = BaseEnvManager(env_fn=[wrapped_minigrid_env for _ in range(evaluator_env_num)], cfg=cfg.env.manager)
+
+ collector_env.seed(seed)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ model = VAC(**cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = SampleSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+
+ while True:
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ new_data = collector.collect(train_iter=learner.train_iter)
+ learner.train(new_data, collector.envstep)
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+
+
+if __name__ == '__main__':
+ main(minigrid_ppo_config)
diff --git a/DI-engine/dizoo/minigrid/envs/__init__.py b/DI-engine/dizoo/minigrid/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..02d73004ffbf04fb1e18e430068ca1f7daef8ddc
--- /dev/null
+++ b/DI-engine/dizoo/minigrid/envs/__init__.py
@@ -0,0 +1,3 @@
+from .minigrid_env import MiniGridEnv
+from dizoo.minigrid.envs.app_key_to_door_treasure import AppleKeyToDoorTreasure, AppleKeyToDoorTreasure_13x13, AppleKeyToDoorTreasure_19x19, AppleKeyToDoorTreasure_13x13_1, AppleKeyToDoorTreasure_19x19_3, AppleKeyToDoorTreasure_7x7_1
+from dizoo.minigrid.envs.noisy_tv import NoisyTVEnv
\ No newline at end of file
diff --git a/DI-engine/dizoo/minigrid/envs/app_key_to_door_treasure.py b/DI-engine/dizoo/minigrid/envs/app_key_to_door_treasure.py
new file mode 100644
index 0000000000000000000000000000000000000000..3c17db3ae6f2ff7f4ff03f6ef5d978014d96ba6e
--- /dev/null
+++ b/DI-engine/dizoo/minigrid/envs/app_key_to_door_treasure.py
@@ -0,0 +1,223 @@
+#!/usr/bin/env python
+# -*- coding: utf-8 -*-
+
+from minigrid.minigrid_env import *
+from minigrid.utils.rendering import *
+from minigrid.core.world_object import WorldObj
+
+
+class Ball(WorldObj):
+
+ def __init__(self, color='blue'):
+ super(Ball, self).__init__('ball', color)
+
+ def can_pickup(self):
+ return False
+
+ def render(self, img):
+ fill_coords(img, point_in_circle(0.5, 0.5, 0.31), COLORS[self.color])
+
+
+class AppleKeyToDoorTreasure(MiniGridEnv):
+ """
+ A four-rooms-style gridworld with apples, a yellow key, a locked door and a treasure (goal).
+ The agent and goal positions can be specified; if not, they are set at random.
+ """
+
+ def __init__(self, agent_pos=None, goal_pos=None, grid_size=19, apple=2):
+ self._agent_default_pos = agent_pos
+ self._goal_default_pos = goal_pos
+ self.apple = apple
+ mission_space = MissionSpace(mission_func=lambda: "Reach the goal")
+ super().__init__(mission_space=mission_space, grid_size=grid_size, max_steps=100)
+
+ def _gen_grid(
+ self, width, height
+ ): # Note: inherited from MiniGridEnv; if width and height are None, both default to grid_size
+ # Create the grid
+ self.grid = Grid(width, height)
+
+ # Generate the surrounding walls
+ self.grid.horz_wall(0, 0)
+ self.grid.horz_wall(0, height - 1)
+ self.grid.vert_wall(0, 0)
+ self.grid.vert_wall(width - 1, 0)
+
+ room_w = width // 2
+ room_h = height // 2
+
+ # For each row of rooms
+ for j in range(0, 2):
+
+ # For each column
+ for i in range(0, 2):
+ xL = i * room_w
+ yT = j * room_h
+ xR = xL + room_w
+ yB = yT + room_h
+
+ # Right wall and door
+ if i + 1 < 2:
+ if j + 1 < 2:
+ self.grid.vert_wall(xR, yT, room_h)
+ # pos = (xR, self._rand_int(yT + 1, yB))
+ else:
+ self.grid.vert_wall(xR, yT, room_h)
+ pos = (xR, self._rand_int(yT + 1, yB))
+ self.grid.set(*pos, None)
+
+ # Bottom wall and door
+ if j + 1 < 2:
+ if i + 1 < 2:
+ self.grid.horz_wall(xL, yB, room_w)
+ pos = (self._rand_int(xL + 1, xR), yB)
+ self.grid.set(*pos, None)
+ else:
+ self.grid.horz_wall(xL, yB, room_w)
+ pos = (self._rand_int(xL + 1, xR), yB)
+ self.put_obj(Door('yellow', is_locked=True), *pos)
+
+ # Place a yellow key on the left side
+ pos1 = (self._rand_int(room_w + 1, 2 * room_w), self._rand_int(room_h + 1, 2 * room_h)) # self._rand_int samples from the half-open interval [low, high)
+ self.put_obj(Key('yellow'), *pos1)
+ pos2_dummy_list = [] # to avoid overlap of apples
+ for i in range(self.apple):
+ pos2 = (self._rand_int(1, room_w), self._rand_int(1, room_h))
+ while pos2 in pos2_dummy_list:
+ pos2 = (self._rand_int(1, room_w), self._rand_int(1, room_h))
+ self.put_obj(Ball('red'), *pos2)
+ pos2_dummy_list.append(pos2)
+ # Randomize the player start position and orientation
+ if self._agent_default_pos is not None:
+ self.agent_pos = self._agent_default_pos
+ self.grid.set(*self._agent_default_pos, None)
+ self.agent_dir = self._rand_int(0, 4) # assuming random start direction
+ else:
+ self.place_agent()
+
+ if self._goal_default_pos is not None:
+ goal = Goal()
+ self.put_obj(goal, *self._goal_default_pos)
+ goal.init_pos, goal.cur_pos = self._goal_default_pos
+ else:
+ self.place_obj(Goal())
+
+ self.mission = 'Reach the goal'
+
+ def _reward_ball(self):
+ """
+ Compute the reward to be given upon finding the apple
+ """
+
+ return 1
+
+ def _reward_goal(self):
+ """
+ Compute the reward to be given upon success
+ """
+
+ return 10
+
+ def step(self, action):
+ self.step_count += 1
+
+ reward = 0
+ done = False
+
+ # Get the position in front of the agent
+ fwd_pos = self.front_pos
+
+ # Get the contents of the cell in front of the agent
+ fwd_cell = self.grid.get(*fwd_pos)
+
+ # Rotate left
+ if action == self.actions.left:
+ self.agent_dir -= 1
+ if self.agent_dir < 0:
+ self.agent_dir += 4
+
+ # Rotate right
+ elif action == self.actions.right:
+ self.agent_dir = (self.agent_dir + 1) % 4
+
+ # Move forward
+ elif action == self.actions.forward:
+ if fwd_cell is None or fwd_cell.can_overlap(): # Ball and Key objects return False from can_overlap()
+ self.agent_pos = fwd_pos
+ if fwd_cell is not None and fwd_cell.type == 'goal':
+ done = True
+ reward = self._reward_goal()
+ if fwd_cell is not None and fwd_cell.type == 'ball':
+ reward = self._reward_ball()
+ self.grid.set(*fwd_pos, None)
+ self.agent_pos = fwd_pos
+ if fwd_cell is not None and fwd_cell.type == 'lava':
+ done = True
+
+ # Pick up an object
+ elif action == self.actions.pickup:
+ if fwd_cell and fwd_cell.can_pickup():
+ if self.carrying is None:
+ self.carrying = fwd_cell
+ self.carrying.cur_pos = np.array([-1, -1])
+ self.grid.set(*fwd_pos, None)
+
+ # Drop an object
+ elif action == self.actions.drop:
+ if not fwd_cell and self.carrying:
+ self.grid.set(*fwd_pos, self.carrying)
+ self.carrying.cur_pos = fwd_pos
+ self.carrying = None
+
+ # Toggle/activate an object: Here, this will open the door if you have the right key
+ elif action == self.actions.toggle:
+ if fwd_cell:
+ fwd_cell.toggle(self, fwd_pos)
+
+ # Done action (not used by default)
+ elif action == self.actions.done:
+ pass
+
+ else:
+ assert False, "unknown action"
+
+ if self.step_count >= self.max_steps:
+ done = True
+
+ obs = self.gen_obs()
+ # return is (observation, reward, terminated, truncated, info)
+ return obs, reward, done, done, {}
+
+
+class AppleKeyToDoorTreasure_13x13(AppleKeyToDoorTreasure):
+
+ def __init__(self):
+ super().__init__(agent_pos=(2, 8), goal_pos=(7, 1), grid_size=13, apple=2)
+
+
+class AppleKeyToDoorTreasure_19x19(AppleKeyToDoorTreasure):
+
+ def __init__(self):
+ super().__init__(agent_pos=(2, 14), goal_pos=(10, 1), grid_size=19, apple=2)
+
+
+class AppleKeyToDoorTreasure_13x13_1(AppleKeyToDoorTreasure):
+
+ def __init__(self):
+ super().__init__(agent_pos=(2, 8), goal_pos=(7, 1), grid_size=13, apple=1)
+
+
+class AppleKeyToDoorTreasure_7x7_1(AppleKeyToDoorTreasure):
+
+ def __init__(self):
+ super().__init__(agent_pos=(1, 5), goal_pos=(4, 1), grid_size=7, apple=1)
+
+
+class AppleKeyToDoorTreasure_19x19_3(AppleKeyToDoorTreasure):
+
+ def __init__(self):
+ super().__init__(agent_pos=(2, 14), goal_pos=(10, 1), grid_size=19, apple=3)
+
+
+if __name__ == '__main__':
+    AppleKeyToDoorTreasure()._gen_grid(13, 13)  # note that minigrid handles random seeding automatically
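+
+    # A minimal interaction sketch (an assumption here: minigrid's Gymnasium-style
+    # reset()/step() API, which the step() override above follows).
+    env = AppleKeyToDoorTreasure_13x13()
+    obs, _ = env.reset()
+    obs, reward, terminated, truncated, info = env.step(env.actions.forward)
+    print(reward, terminated, truncated)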
diff --git a/DI-engine/dizoo/minigrid/envs/minigrid_env.py b/DI-engine/dizoo/minigrid/envs/minigrid_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..12bd64cae07bdc6f72f8f92cace771764b7dbe0d
--- /dev/null
+++ b/DI-engine/dizoo/minigrid/envs/minigrid_env.py
@@ -0,0 +1,179 @@
+from typing import Any, List, Union, Optional
+from collections import namedtuple
+from easydict import EasyDict
+import copy
+import os
+import time
+import gymnasium as gym
+
+import numpy as np
+from matplotlib import animation
+import matplotlib.pyplot as plt
+from minigrid.wrappers import FlatObsWrapper, RGBImgPartialObsWrapper, ImgObsWrapper
+from .minigrid_wrapper import ViewSizeWrapper
+from ding.envs import ObsPlusPrevActRewWrapper
+
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.torch_utils import to_ndarray, to_list
+from ding.utils import ENV_REGISTRY
+
+
+@ENV_REGISTRY.register('minigrid')
+class MiniGridEnv(BaseEnv):
+ config = dict(
+ env_id='MiniGrid-KeyCorridorS3R3-v0',
+ flat_obs=True,
+ )
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ cfg = EasyDict(copy.deepcopy(cls.config))
+ cfg.cfg_type = cls.__name__ + 'Dict'
+ return cfg
+
+ def __init__(self, cfg: dict) -> None:
+ self._cfg = cfg
+ self._init_flag = False
+ self._env_id = cfg.env_id
+ self._flat_obs = cfg.flat_obs
+ self._save_replay = False
+ self._max_step = cfg.max_step
+
+ def reset(self) -> np.ndarray:
+ if not self._init_flag:
+ if self._save_replay:
+ self._env = gym.make(self._env_id, render_mode="rgb_array") # using the Gymnasium make method
+ else:
+ self._env = gym.make(self._env_id)
+
+            if self._env_id in ['MiniGrid-AKTDT-13x13-v0', 'MiniGrid-AKTDT-13x13-1-v0']:
+                # Customize the agent's field of view size; note this must be an odd number.
+                # This also affects the observation space; see minigrid.wrappers for more details.
+ self._env = ViewSizeWrapper(self._env, agent_view_size=5)
+ if self._env_id == 'MiniGrid-AKTDT-7x7-1-v0':
+ self._env = ViewSizeWrapper(self._env, agent_view_size=3)
+ if self._flat_obs:
+ self._env = FlatObsWrapper(self._env)
+ # self._env = RGBImgPartialObsWrapper(self._env)
+ # self._env = ImgObsWrapper(self._env)
+ if hasattr(self._cfg, 'obs_plus_prev_action_reward') and self._cfg.obs_plus_prev_action_reward:
+ self._env = ObsPlusPrevActRewWrapper(self._env)
+ self._init_flag = True
+ if self._flat_obs:
+ self._observation_space = gym.spaces.Box(0, 1, shape=(2835, ), dtype=np.float32)
+ else:
+ self._observation_space = self._env.observation_space
+                # to be compatible with subprocess env manager
+ if isinstance(self._observation_space, gym.spaces.Dict):
+ self._observation_space['obs'].dtype = np.dtype('float32')
+ else:
+ self._observation_space.dtype = np.dtype('float32')
+ self._action_space = self._env.action_space
+ self._reward_space = gym.spaces.Box(
+ low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
+ )
+
+ self._eval_episode_return = 0
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._seed = self._seed + np_seed
+ obs, _ = self._env.reset(seed=self._seed) # using the reset method of Gymnasium env
+ elif hasattr(self, '_seed'):
+ obs, _ = self._env.reset(seed=self._seed)
+ else:
+ obs, _ = self._env.reset()
+ obs = to_ndarray(obs)
+ self._current_step = 0
+ if self._save_replay:
+ self._frames = []
+
+ return obs
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def step(self, action: np.ndarray) -> BaseEnvTimestep:
+ assert isinstance(action, np.ndarray), type(action)
+ if action.shape == (1, ):
+ action = action.squeeze() # 0-dim array
+ if self._save_replay:
+ self._frames.append(self._env.render())
+ # using the step method of Gymnasium env, return is (observation, reward, terminated, truncated, info)
+ obs, rew, done, _, info = self._env.step(action)
+ rew = float(rew)
+ self._eval_episode_return += rew
+ self._current_step += 1
+ if self._current_step >= self._max_step:
+ done = True
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+ info['current_step'] = self._current_step
+ info['max_step'] = self._max_step
+ if self._save_replay:
+ path = os.path.join(
+ self._replay_path, '{}_episode_{}.gif'.format(self._env_id, self._save_replay_count)
+ )
+ self.display_frames_as_gif(self._frames, path)
+ self._save_replay_count += 1
+ obs = to_ndarray(obs)
+        rew = to_ndarray([rew])  # wrapped into an array with shape (1, )
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def random_action(self) -> np.ndarray:
+ random_action = self.action_space.sample()
+ random_action = to_ndarray([random_action], dtype=np.int64)
+ return random_action
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ @staticmethod
+ def create_collector_env_cfg(cfg: dict) -> List[dict]:
+ collector_env_num = cfg.pop('collector_env_num')
+ cfg = copy.deepcopy(cfg)
+ cfg.is_train = True
+ return [cfg for _ in range(collector_env_num)]
+
+ @staticmethod
+ def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
+ evaluator_env_num = cfg.pop('evaluator_env_num')
+ cfg = copy.deepcopy(cfg)
+ cfg.is_train = False
+ return [cfg for _ in range(evaluator_env_num)]
+
+ def __repr__(self) -> str:
+ return "DI-engine MiniGrid Env({})".format(self._cfg.env_id)
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ if replay_path is None:
+ replay_path = './video'
+ self._save_replay = True
+ self._replay_path = replay_path
+ self._save_replay_count = 0
+
+ @staticmethod
+ def display_frames_as_gif(frames: list, path: str) -> None:
+ patch = plt.imshow(frames[0])
+ plt.axis('off')
+
+ def animate(i):
+ patch.set_data(frames[i])
+
+ anim = animation.FuncAnimation(plt.gcf(), animate, frames=len(frames), interval=5)
+ anim.save(path, writer='imagemagick', fps=20)
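+
+
+if __name__ == '__main__':
+    # A minimal smoke-test sketch. One assumption: ``max_step`` is not part of the
+    # default config above, so it is added explicitly before constructing the env.
+    main_cfg = MiniGridEnv.default_config()
+    main_cfg.max_step = 300
+    env = MiniGridEnv(main_cfg)
+    env.seed(0)
+    obs = env.reset()
+    timestep = env.step(env.random_action())
+    print(timestep.reward, timestep.done)
+    env.close()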
diff --git a/DI-engine/dizoo/minigrid/envs/minigrid_wrapper.py b/DI-engine/dizoo/minigrid/envs/minigrid_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..09a14c9c819a56334765f35379d65195cc6d55b2
--- /dev/null
+++ b/DI-engine/dizoo/minigrid/envs/minigrid_wrapper.py
@@ -0,0 +1,34 @@
+import gymnasium as gym
+from gymnasium import spaces
+from gymnasium.core import ObservationWrapper
+
+
+class ViewSizeWrapper(ObservationWrapper):
+ """
+ Wrapper to customize the agent field of view size.
+ This cannot be used with fully observable wrappers.
+ """
+
+ def __init__(self, env, agent_view_size=7):
+ super().__init__(env)
+
+ assert agent_view_size % 2 == 1
+ assert agent_view_size >= 3
+
+ self.agent_view_size = agent_view_size
+
+ # Compute observation space with specified view size
+ new_image_space = gym.spaces.Box(low=0, high=255, shape=(agent_view_size, agent_view_size, 3), dtype="uint8")
+
+        # Override the environment's observation space
+ self.observation_space = spaces.Dict({**self.observation_space.spaces, "image": new_image_space})
+
+ def observation(self, obs):
+ env = self.unwrapped
+ grid, vis_mask = env.gen_obs_grid(self.agent_view_size)
+
+ # Encode the partially observable view into a numpy array
+ image = grid.encode(vis_mask)
+ return {**obs, "image": image}
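+
+
+if __name__ == "__main__":
+    # A minimal usage sketch (assumptions: the minigrid package is installed and
+    # registers the standard 'MiniGrid-Empty-8x8-v0' id on import).
+    import minigrid  # noqa: F401  makes sure the MiniGrid env ids are registered
+    env = gym.make("MiniGrid-Empty-8x8-v0")
+    env = ViewSizeWrapper(env, agent_view_size=5)
+    obs, _ = env.reset()
+    print(obs["image"].shape)  # expected: (5, 5, 3)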
diff --git a/DI-engine/dizoo/minigrid/envs/noisy_tv.py b/DI-engine/dizoo/minigrid/envs/noisy_tv.py
new file mode 100644
index 0000000000000000000000000000000000000000..9c5c78e6139347736892493bbe24673535783f6d
--- /dev/null
+++ b/DI-engine/dizoo/minigrid/envs/noisy_tv.py
@@ -0,0 +1,216 @@
+from minigrid.core.grid import Grid
+from minigrid.core.mission import MissionSpace
+from minigrid.minigrid_env import *
+from minigrid.utils.rendering import *
+from minigrid.core.world_object import WorldObj
+import random
+
+
+class NoisyTVEnv(MiniGridEnv):
+ """
+ ### Description
+
+ Classic four room reinforcement learning environment with random noise. The agent must
+ navigate in a maze composed of four rooms interconnected by 4 gaps in the
+ walls. To obtain a reward, the agent must reach the green goal square. Both
+ the agent and the goal square are randomly placed in any of the four rooms.
+
+ ### Mission Space
+
+ "reach the goal"
+
+ ### Action Space
+
+ | Num | Name | Action |
+ |-----|--------------|--------------|
+ | 0 | left | Turn left |
+ | 1 | right | Turn right |
+ | 2 | forward | Move forward |
+ | 3 | pickup | Unused |
+ | 4 | drop | Unused |
+ | 5 | toggle | Unused |
+ | 6 | done | Unused |
+
+ ### Observation Encoding
+
+ - Each tile is encoded as a 3 dimensional tuple:
+ `(OBJECT_IDX, COLOR_IDX, STATE)`
+ - `OBJECT_TO_IDX` and `COLOR_TO_IDX` mapping can be found in
+ [minigrid/minigrid.py](minigrid/minigrid.py)
+ - `STATE` refers to the door state with 0=open, 1=closed and 2=locked
+
+ ### Rewards
+
+ A reward of '1' is given for success, and '0' for failure.
+    A small noisy reward, sampled uniformly from [0.05, 0.1), is given upon reaching a noisy tile.
+
+ ### Termination
+
+ The episode ends if any one of the following conditions is met:
+
+ 1. The agent reaches the goal.
+ 2. Timeout (see `max_steps`).
+
+ ### Registered Configurations
+
+ - `MiniGrid-NoisyTV-v0`
+
+ """
+
+ def __init__(self, agent_pos=None, goal_pos=None, noisy_tile_num=4, **kwargs):
+ self._agent_default_pos = agent_pos
+ self._goal_default_pos = goal_pos
+ self.size = 19
+ self._noisy_tile_num = noisy_tile_num
+ self._noisy_tile_pos = []
+ for i in range(self._noisy_tile_num):
+ pos2 = (self._rand_int(1, self.size - 1), self._rand_int(1, self.size - 1))
+ while pos2 in self._noisy_tile_pos:
+ pos2 = (self._rand_int(1, self.size - 1), self._rand_int(1, self.size - 1))
+ self._noisy_tile_pos.append(pos2)
+ mission_space = MissionSpace(mission_func=lambda: "reach the goal")
+
+ super().__init__(mission_space=mission_space, width=self.size, height=self.size, max_steps=100, **kwargs)
+
+ def _reward_noise(self):
+ """
+        Compute the reward to be given upon reaching a noisy tile
+ """
+ return self._rand_float(0.05, 0.1)
+
+ def _add_noise(self, obs):
+ """
+ Add noise to obs['image']
+ """
+ image = obs['image'].astype(float)
+ for pos in self._noisy_tile_pos:
+            if self.in_view(pos[0], pos[1]):  # if the noisy tile is within the agent's view (7x7 by default)
+ relative_pos = self.relative_coords(pos[0], pos[1])
+ image[relative_pos][1] += 0.5
+ obs['image'] = image
+ return obs
+
+ def _gen_grid(self, width, height):
+ # Create the grid
+ self.grid = Grid(width, height)
+
+ # Generate the surrounding walls
+ self.grid.horz_wall(0, 0)
+ self.grid.horz_wall(0, height - 1)
+ self.grid.vert_wall(0, 0)
+ self.grid.vert_wall(width - 1, 0)
+
+ room_w = width // 2
+ room_h = height // 2
+
+ # For each row of rooms
+ for j in range(0, 2):
+
+ # For each column
+ for i in range(0, 2):
+ xL = i * room_w
+ yT = j * room_h
+ xR = xL + room_w
+ yB = yT + room_h
+
+                # Right wall and door
+ if i + 1 < 2:
+ self.grid.vert_wall(xR, yT, room_h)
+ pos = (xR, self._rand_int(yT + 1, yB))
+ self.grid.set(*pos, None)
+
+ # Bottom wall and door
+ if j + 1 < 2:
+ self.grid.horz_wall(xL, yB, room_w)
+ pos = (self._rand_int(xL + 1, xR), yB)
+ self.grid.set(*pos, None)
+
+ # Randomize the player start position and orientation
+ if self._agent_default_pos is not None:
+ self.agent_pos = self._agent_default_pos
+ self.grid.set(*self._agent_default_pos, None)
+ # assuming random start direction
+ self.agent_dir = self._rand_int(0, 4)
+ else:
+ self.place_agent()
+
+ if self._goal_default_pos is not None:
+ goal = Goal()
+ self.put_obj(goal, *self._goal_default_pos)
+ goal.init_pos, goal.cur_pos = self._goal_default_pos
+ else:
+ self.place_obj(Goal())
+
+ def step(self, action):
+ self.step_count += 1
+
+ reward = 0
+ terminated = False
+ truncated = False
+
+ # Get the position in front of the agent
+ fwd_pos = self.front_pos
+
+ # Get the contents of the cell in front of the agent
+ fwd_cell = self.grid.get(*fwd_pos)
+
+ # Rotate left
+ if action == self.actions.left:
+ self.agent_dir -= 1
+ if self.agent_dir < 0:
+ self.agent_dir += 4
+
+ # Rotate right
+ elif action == self.actions.right:
+ self.agent_dir = (self.agent_dir + 1) % 4
+
+ # Move forward
+ elif action == self.actions.forward:
+ if fwd_cell is None or fwd_cell.can_overlap():
+ self.agent_pos = tuple(fwd_pos)
+ if fwd_cell is not None and fwd_cell.type == "goal":
+ terminated = True
+ reward = self._reward()
+ if fwd_cell is not None and fwd_cell.type == "lava":
+ terminated = True
+            # If the agent reaches a noisy tile, return a noisy reward.
+ if self.agent_pos in self._noisy_tile_pos:
+ reward = self._reward_noise()
+
+ # Pick up an object
+ elif action == self.actions.pickup:
+ if fwd_cell and fwd_cell.can_pickup():
+ if self.carrying is None:
+ self.carrying = fwd_cell
+ self.carrying.cur_pos = np.array([-1, -1])
+ self.grid.set(fwd_pos[0], fwd_pos[1], None)
+
+ # Drop an object
+ elif action == self.actions.drop:
+ if not fwd_cell and self.carrying:
+ self.grid.set(fwd_pos[0], fwd_pos[1], self.carrying)
+ self.carrying.cur_pos = fwd_pos
+ self.carrying = None
+
+ # Toggle/activate an object
+ elif action == self.actions.toggle:
+ if fwd_cell:
+ fwd_cell.toggle(self, fwd_pos)
+
+ # Done action (not used by default)
+ elif action == self.actions.done:
+ pass
+
+ else:
+ raise ValueError(f"Unknown action: {action}")
+
+ if self.step_count >= self.max_steps:
+ truncated = True
+
+ if self.render_mode == "human":
+ self.render()
+
+ obs = self.gen_obs()
+ obs = self._add_noise(obs)
+
+ return obs, reward, terminated, truncated, {}
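+
+
+if __name__ == '__main__':
+    # A minimal interaction sketch (an assumption: this class is registered elsewhere
+    # under 'MiniGrid-NoisyTV-v0'; here it is constructed directly instead).
+    env = NoisyTVEnv()
+    obs, _ = env.reset()
+    obs, reward, terminated, truncated, info = env.step(env.actions.forward)
+    print(reward, terminated, truncated)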
diff --git a/DI-engine/dizoo/minigrid/envs/test_minigrid_env.py b/DI-engine/dizoo/minigrid/envs/test_minigrid_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d0abb4e80c1f0b4475240df64ed8e423d128a25
--- /dev/null
+++ b/DI-engine/dizoo/minigrid/envs/test_minigrid_env.py
@@ -0,0 +1,110 @@
+import pytest
+import os
+import numpy as np
+from dizoo.minigrid.envs import MiniGridEnv
+from easydict import EasyDict
+import copy
+
+# The following two configs can be tested through TestMiniGridAKTDTEnv
+config = dict(
+ env_id='MiniGrid-AKTDT-13x13-v0',
+ flat_obs=True,
+)
+cfg = EasyDict(copy.deepcopy(config))
+cfg.cfg_type = 'MiniGridEnvDict'
+
+config2 = dict(
+ env_id='MiniGrid-AKTDT-7x7-1-v0',
+ flat_obs=True,
+)
+cfg2 = EasyDict(copy.deepcopy(config2))
+cfg2.cfg_type = 'MiniGridEnvDict'
+
+
+@pytest.mark.envtest
+class TestMiniGridEnv:
+
+ def test_naive(self):
+ env = MiniGridEnv(MiniGridEnv.default_config())
+ env.seed(314)
+ path = './video'
+ if not os.path.exists(path):
+ os.mkdir(path)
+ env.enable_save_replay(path)
+ assert env._seed == 314
+ obs = env.reset()
+ act_val = env.info().act_space.value
+ min_val, max_val = act_val['min'], act_val['max']
+ for i in range(env._max_step):
+ random_action = np.random.randint(min_val, max_val, size=(1, ))
+ timestep = env.step(random_action)
+ print(timestep)
+ print(timestep.obs.max())
+ assert isinstance(timestep.obs, np.ndarray)
+ assert isinstance(timestep.done, bool)
+ assert timestep.obs.shape == (2739, )
+ assert timestep.reward.shape == (1, )
+ assert timestep.reward >= env.info().rew_space.value['min']
+ assert timestep.reward <= env.info().rew_space.value['max']
+ if timestep.done:
+ env.reset()
+ print(env.info())
+ env.close()
+
+
+@pytest.mark.envtest
+class TestMiniGridAKTDTEnv:
+
+ def test_adtkt_13(self):
+        env = MiniGridEnv(cfg)
+ env.seed(314)
+ path = './video'
+ if not os.path.exists(path):
+ os.mkdir(path)
+ env.enable_save_replay(path)
+ assert env._seed == 314
+ obs = env.reset()
+ act_val = env.info().act_space.value
+ min_val, max_val = act_val['min'], act_val['max']
+ for i in range(env._max_step):
+ random_action = np.random.randint(min_val, max_val, size=(1, ))
+ timestep = env.step(random_action)
+ print(timestep)
+ print(timestep.obs.max())
+ assert isinstance(timestep.obs, np.ndarray)
+ assert isinstance(timestep.done, bool)
+ assert timestep.obs.shape == (2667, )
+ assert timestep.reward.shape == (1, )
+ assert timestep.reward >= env.info().rew_space.value['min']
+ assert timestep.reward <= env.info().rew_space.value['max']
+ if timestep.done:
+ env.reset()
+ print(env.info())
+ env.close()
+
+ def test_adtkt_7(self):
+ env = MiniGridEnv(cfg2)
+ env.seed(314)
+ path = './video'
+ if not os.path.exists(path):
+ os.mkdir(path)
+ env.enable_save_replay(path)
+ assert env._seed == 314
+ obs = env.reset()
+ act_val = env.info().act_space.value
+ min_val, max_val = act_val['min'], act_val['max']
+ for i in range(env._max_step):
+ random_action = np.random.randint(min_val, max_val, size=(1, ))
+ timestep = env.step(random_action)
+ print(timestep)
+ print(timestep.obs.max())
+ assert isinstance(timestep.obs, np.ndarray)
+ assert isinstance(timestep.done, bool)
+ assert timestep.obs.shape == (2619, )
+ assert timestep.reward.shape == (1, )
+ assert timestep.reward >= env.info().rew_space.value['min']
+ assert timestep.reward <= env.info().rew_space.value['max']
+ if timestep.done:
+ env.reset()
+ print(env.info())
+ env.close()
diff --git a/DI-engine/dizoo/minigrid/utils/eval.py b/DI-engine/dizoo/minigrid/utils/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3c6acb9fb73e4737d9a43538f5dbb75b935a1c8
--- /dev/null
+++ b/DI-engine/dizoo/minigrid/utils/eval.py
@@ -0,0 +1,97 @@
+from typing import Union, Optional, List, Any, Callable, Tuple
+import torch
+from ding.config import compile_config, read_config
+from ding.envs import get_vec_env_setting
+from ding.policy import create_policy
+from ding.utils import set_pkg_seed
+from ding.torch_utils import to_tensor, to_ndarray, tensor_to_list
+
+
+def eval(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ seed: int = 0,
+ model: Optional[torch.nn.Module] = None,
+ state_dict: Optional[dict] = None,
+ replay_path: Optional[str] = './video',
+) -> float:
+ r"""
+ Overview:
+ The evaluation entry for NGU policy.
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+        - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+        - state_dict (:obj:`Optional[dict]`): The state_dict of policy or model.
+        - replay_path (:obj:`Optional[str]`): Directory in which to save the replay video.
+ """
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ else:
+ cfg, create_cfg = input_cfg
+ create_cfg.policy.type += '_command'
+ cfg = compile_config(cfg, auto=True, create_cfg=create_cfg)
+ env_fn, _, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ env = env_fn(evaluator_env_cfg[0])
+ env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=model, enable_field=['eval']).eval_mode
+ if state_dict is None:
+ state_dict = torch.load(cfg.learner.load_path, map_location='cpu')
+ policy.load_state_dict(state_dict)
+ env.enable_save_replay(replay_path=replay_path)
+ obs = env.reset()
+ obs = {0: obs}
+ episode_return = 0.
+
+ beta_index = {i: 0 for i in range(1)}
+ beta_index = to_tensor(beta_index, dtype=torch.int64)
+ prev_action = {i: torch.tensor(-1) for i in range(1)}
+ prev_reward_e = {i: to_tensor(0, dtype=torch.float32) for i in range(1)}
+
+ while True:
+ # TODO(pu): r_i, reward embedding
+ policy_output = policy.forward(beta_index, obs, prev_action, prev_reward_e)
+
+ actions = {i: a['action'] for i, a in policy_output.items()}
+ actions = to_ndarray(actions)
+
+ action = policy_output[0]['action']
+ action = to_ndarray(action)
+ timestep = env.step(action)
+ # print(action)
+ # print(timestep.reward)
+
+ timesteps = {0: timestep}
+ timesteps = to_tensor(timesteps, dtype=torch.float32)
+
+ prev_reward_e = {env_id: timestep.reward for env_id, timestep in timesteps.items()}
+ prev_reward_e = to_ndarray(prev_reward_e)
+ prev_action = actions
+
+ timestep = timesteps[0]
+ # print(timestep.info)
+ episode_return += timestep.reward
+
+ obs = timestep.obs
+ obs = {0: obs}
+
+ if timestep.done:
+ print(timestep.info)
+ break
+ print('Eval is over! The performance of your RL policy is {}'.format(episode_return))
+
+
+if __name__ == "__main__":
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+    model_path = './debug_minigrid_doorkey_ngu_ul298_er01_n32_rbs3e4_fixepseval/ckpt/ckpt_best.pth.tar'
+ # model_path = 'model_path_placeholder',
+ cfg = '../config/minigrid_ngu_config.py'
+
+ state_dict = torch.load(model_path, map_location='cpu')
+ for i in range(0, 10):
+ eval(cfg, seed=i, state_dict=state_dict, replay_path='./video')
diff --git a/DI-engine/dizoo/mujoco/__init__.py b/DI-engine/dizoo/mujoco/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/mujoco/addition/install_mesa.sh b/DI-engine/dizoo/mujoco/addition/install_mesa.sh
new file mode 100644
index 0000000000000000000000000000000000000000..29d2bb78eaa33370d033eac12eb4d647f0b3be59
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/addition/install_mesa.sh
@@ -0,0 +1,4 @@
+mkdir -p ~/rpm
+yumdownloader --destdir ~/rpm --resolve mesa-libOSMesa.x86_64 mesa-libOSMesa-devel.x86_64 patchelf.x86_64
+cd ~/rpm
+for rpm in `ls`; do rpm2cpio $rpm | cpio -id ; done
diff --git a/DI-engine/dizoo/mujoco/config/__init__.py b/DI-engine/dizoo/mujoco/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/mujoco/config/ant_ddpg_config.py b/DI-engine/dizoo/mujoco/config/ant_ddpg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c8698ebcc3a455460d202a0b91fc6cc1a3e703e7
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/ant_ddpg_config.py
@@ -0,0 +1,70 @@
+from easydict import EasyDict
+
+ant_ddpg_config = dict(
+ exp_name='ant_ddpg_seed0',
+ env=dict(
+ env_id='Ant-v3',
+ env_wrapper='mujoco_default',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ manager=dict(shared_memory=False, ),
+ # The path to save the game replay
+ # replay_path='./ant_ddpg_seed0/video',
+ ),
+ policy=dict(
+ cuda=True,
+ load_path="./ant_ddpg_seed0/ckpt/ckpt_best.pth.tar",
+ random_collect_size=25000,
+ model=dict(
+ obs_shape=111,
+ action_shape=8,
+ twin_critic=False,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ action_space='regression',
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_actor=1e-3,
+ learning_rate_critic=1e-3,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99, # discount_factor: 0.97-0.99
+ actor_update_freq=1,
+ noise=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ noise_sigma=0.1,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ )
+)
+ant_ddpg_config = EasyDict(ant_ddpg_config)
+main_config = ant_ddpg_config
+
+ant_ddpg_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='ddpg',
+ import_names=['ding.policy.ddpg'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+ant_ddpg_create_config = EasyDict(ant_ddpg_create_config)
+create_config = ant_ddpg_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c ant_ddpg_config.py -s 0 --env-step 1e7`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/mujoco/config/ant_gail_sac_config.py b/DI-engine/dizoo/mujoco/config/ant_gail_sac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7e7cd7d06b93ae1cbd28b53c22053e18d59571c
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/ant_gail_sac_config.py
@@ -0,0 +1,101 @@
+from easydict import EasyDict
+
+obs_shape = 111
+act_shape = 8
+ant_sac_gail_config = dict(
+ exp_name='ant_sac_gail_seed0',
+ env=dict(
+ env_id='Ant-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ reward_model=dict(
+ input_size=obs_shape + act_shape,
+ hidden_size=256,
+ batch_size=64,
+ learning_rate=1e-3,
+ update_per_collect=100,
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ expert_model_path='model_path_placeholder',
+ # Path where to store the reward model
+ reward_model_path='data_path_placeholder+/reward_model/ckpt/ckpt_best.pth.tar',
+ # Users should add their own data path here. Data path should lead to a file to store data or load the stored data.
+ # Absolute path is recommended.
+ # In DI-engine, it is usually located in ``exp_name`` directory
+ data_path='data_path_placeholder',
+ collect_count=300000,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=25000,
+ model=dict(
+ obs_shape=obs_shape,
+ action_shape=act_shape,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_q=1e-3,
+ learning_rate_policy=1e-3,
+ learning_rate_alpha=3e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=False,
+ ),
+ collect=dict(
+ n_sample=64,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ ),
+)
+
+ant_sac_gail_config = EasyDict(ant_sac_gail_config)
+main_config = ant_sac_gail_config
+
+ant_sac_gail_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='sac',
+ import_names=['ding.policy.sac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+ reward_model=dict(type='gail'),
+)
+ant_sac_gail_create_config = EasyDict(ant_sac_gail_create_config)
+create_config = ant_sac_gail_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_gail -c ant_gail_sac_config.py -s 0`
+ # then input the config you used to generate your expert model in the path mentioned above
+    # e.g. ant_sac_config.py
+ from ding.entry import serial_pipeline_gail
+ from dizoo.mujoco.config.ant_sac_config import ant_sac_config, ant_sac_create_config
+
+ expert_main_config = ant_sac_config
+ expert_create_config = ant_sac_create_config
+ serial_pipeline_gail(
+ [main_config, create_config], [expert_main_config, expert_create_config],
+ max_env_step=10000000,
+ seed=0,
+ collect_data=True
+ )
diff --git a/DI-engine/dizoo/mujoco/config/ant_onppo_config.py b/DI-engine/dizoo/mujoco/config/ant_onppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..73d5ba344ba108031af63bcad6896cc6c33ab9f2
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/ant_onppo_config.py
@@ -0,0 +1,67 @@
+from easydict import EasyDict
+
+ant_ppo_config = dict(
+ exp_name="ant_onppo_seed0",
+ env=dict(
+ env_id='Ant-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=10,
+ evaluator_env_num=10,
+ n_evaluator_episode=10,
+ stop_value=6000,
+ manager=dict(shared_memory=False, )
+ ),
+ policy=dict(
+ cuda=True,
+ recompute_adv=True,
+ action_space='continuous',
+ model=dict(
+ action_space='continuous',
+ obs_shape=111,
+ action_shape=8,
+ ),
+ learn=dict(
+ epoch_per_collect=10,
+ update_per_collect=1,
+ batch_size=320,
+ learning_rate=3e-4,
+ value_weight=0.5,
+ entropy_weight=0.001,
+ clip_ratio=0.2,
+ adv_norm=True,
+ value_norm=True,
+ # When we recompute advantage, we need the key done in data to split trajectories, so we must
+ # use 'ignore_done=False' here, but when we add key 'traj_flag' in data as the backup for key done,
+ # we could choose to use 'ignore_done=True'. 'traj_flag' indicates termination of trajectory.
+ ignore_done=False,
+ grad_clip_type='clip_norm',
+ grad_clip_value=0.5,
+ ),
+ collect=dict(
+ n_sample=3200,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ eval=dict(evaluator=dict(eval_freq=5000, )),
+ ),
+)
+ant_ppo_config = EasyDict(ant_ppo_config)
+main_config = ant_ppo_config
+
+ant_ppo_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+)
+ant_ppo_create_config = EasyDict(ant_ppo_create_config)
+create_config = ant_ppo_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c ant_onppo_config.py -s 0 --env-step 1e7`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/mujoco/config/ant_ppo_config.py b/DI-engine/dizoo/mujoco/config/ant_ppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..25103592bb3cab0cc5c3858bd2bd533c6e6843bf
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/ant_ppo_config.py
@@ -0,0 +1,60 @@
+from easydict import EasyDict
+
+ant_ppo_config = dict(
+ exp_name='ant_ppo_seed0',
+ env=dict(
+ manager=dict(shared_memory=False, reset_inplace=True),
+ env_id='Ant-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=8,
+ evaluator_env_num=10,
+ n_evaluator_episode=10,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ recompute_adv=True,
+ model=dict(
+ obs_shape=111,
+ action_shape=8,
+ action_space='continuous',
+ ),
+ action_space='continuous',
+ learn=dict(
+ epoch_per_collect=10,
+ batch_size=64,
+ learning_rate=3e-4,
+ value_weight=0.5,
+ entropy_weight=0.0,
+ clip_ratio=0.2,
+ adv_norm=True,
+ value_norm=True,
+ ),
+ collect=dict(
+ n_sample=2048,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.97,
+ ),
+ eval=dict(evaluator=dict(eval_freq=5000, )),
+ ),
+)
+ant_ppo_config = EasyDict(ant_ppo_config)
+main_config = ant_ppo_config
+
+ant_ppo_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo', ),
+)
+ant_ppo_create_config = EasyDict(ant_ppo_create_config)
+create_config = ant_ppo_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c ant_ppo_config.py -s 0 --env-step 1e7`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/mujoco/config/ant_sac_config.py b/DI-engine/dizoo/mujoco/config/ant_sac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd2881e91f3c8da0e3ec585d79e18056ddff2621
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/ant_sac_config.py
@@ -0,0 +1,70 @@
+from easydict import EasyDict
+
+ant_sac_config = dict(
+ exp_name='ant_sac_seed0',
+ env=dict(
+ env_id='Ant-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ manager=dict(shared_memory=False, reset_inplace=True),
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=111,
+ action_shape=8,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_q=1e-3,
+ learning_rate_policy=1e-3,
+ learning_rate_alpha=3e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ ),
+)
+
+ant_sac_config = EasyDict(ant_sac_config)
+main_config = ant_sac_config
+
+ant_sac_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='sac',
+ import_names=['ding.policy.sac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+ant_sac_create_config = EasyDict(ant_sac_create_config)
+create_config = ant_sac_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c ant_sac_config.py -s 0 --env-step 1e7`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/mujoco/config/ant_td3_config.py b/DI-engine/dizoo/mujoco/config/ant_td3_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebcb3654ce21e993ec53b7bd34743f422cb7796d
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/ant_td3_config.py
@@ -0,0 +1,72 @@
+from easydict import EasyDict
+
+ant_td3_config = dict(
+ exp_name='ant_td3_seed0',
+ env=dict(
+ env_id='Ant-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ manager=dict(shared_memory=False, reset_inplace=True),
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=25000,
+ model=dict(
+ obs_shape=111,
+ action_shape=8,
+ twin_critic=True,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ action_space='regression',
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_actor=1e-3,
+ learning_rate_critic=1e-3,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.2,
+ noise_range=dict(
+ min=-0.5,
+ max=0.5,
+ ),
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ noise_sigma=0.1,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ )
+)
+
+ant_td3_config = EasyDict(ant_td3_config)
+main_config = ant_td3_config
+
+ant_td3_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='td3',
+ import_names=['ding.policy.td3'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+ant_td3_create_config = EasyDict(ant_td3_create_config)
+create_config = ant_td3_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c ant_td3_config.py -s 0 --env-step 1e7`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/mujoco/config/ant_trex_onppo_config.py b/DI-engine/dizoo/mujoco/config/ant_trex_onppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..f3d5e96b75e883582867a27c509f56a28b157ea7
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/ant_trex_onppo_config.py
@@ -0,0 +1,80 @@
+from easydict import EasyDict
+
+ant_trex_ppo_config = dict(
+ exp_name='ant_trex_onppo_seed0',
+ env=dict(
+ manager=dict(shared_memory=True, reset_inplace=True),
+ env_id='Ant-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=8,
+ evaluator_env_num=10,
+ n_evaluator_episode=10,
+ stop_value=6000,
+ ),
+ reward_model=dict(
+ type='trex',
+ min_snippet_length=10,
+ max_snippet_length=100,
+ checkpoint_min=100,
+ checkpoint_max=900,
+ checkpoint_step=100,
+ learning_rate=1e-5,
+ update_per_collect=1,
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ expert_model_path='model_path_placeholder',
+ # Path where to store the reward model
+ reward_model_path='abs_data_path + ./ant.params',
+ continuous=True,
+ # Path to the offline dataset
+ # See ding/entry/application_entry_trex_collect_data.py to collect the data
+ offline_data_path='abs_data_path',
+ ),
+ policy=dict(
+ cuda=True,
+ recompute_adv=True,
+ model=dict(
+ obs_shape=111,
+ action_shape=8,
+ action_space='continuous',
+ ),
+ action_space='continuous',
+ learn=dict(
+ epoch_per_collect=10,
+ batch_size=64,
+ learning_rate=3e-4,
+ value_weight=0.5,
+ entropy_weight=0.0,
+ clip_ratio=0.2,
+ adv_norm=True,
+ value_norm=True,
+ ),
+ collect=dict(
+ n_sample=2048,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.97,
+ ),
+ eval=dict(evaluator=dict(eval_freq=5000, )),
+ ),
+)
+ant_trex_ppo_config = EasyDict(ant_trex_ppo_config)
+main_config = ant_trex_ppo_config
+
+ant_trex_ppo_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo', ),
+)
+ant_trex_ppo_create_config = EasyDict(ant_trex_ppo_create_config)
+create_config = ant_trex_ppo_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c ant_trex_onppo_config.py -s 0`
+ from ding.entry import serial_pipeline_trex_onpolicy
+ serial_pipeline_trex_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/mujoco/config/ant_trex_sac_config.py b/DI-engine/dizoo/mujoco/config/ant_trex_sac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c0ef73097ccb49847b585aa1b246dd9f5566595
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/ant_trex_sac_config.py
@@ -0,0 +1,90 @@
+from easydict import EasyDict
+
+ant_trex_sac_config = dict(
+ exp_name='ant_trex_sac_seed0',
+ env=dict(
+ manager=dict(shared_memory=True, reset_inplace=True),
+ env_id='Ant-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ reward_model=dict(
+ type='trex',
+ min_snippet_length=30,
+ max_snippet_length=100,
+ checkpoint_min=1000,
+ checkpoint_max=9000,
+ checkpoint_step=1000,
+ learning_rate=1e-5,
+ update_per_collect=1,
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ expert_model_path='model_path_placeholder',
+ # Path where to store the reward model
+ reward_model_path='abs_data_path + ./ant.params',
+ continuous=True,
+ # Path to the offline dataset
+ # See ding/entry/application_entry_trex_collect_data.py to collect the data
+ offline_data_path='abs_data_path',
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=111,
+ action_shape=8,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_q=1e-3,
+ learning_rate_policy=1e-3,
+ learning_rate_alpha=3e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ ),
+)
+
+ant_trex_sac_config = EasyDict(ant_trex_sac_config)
+main_config = ant_trex_sac_config
+
+ant_trex_sac_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='sac',
+ import_names=['ding.policy.sac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+ant_trex_sac_create_config = EasyDict(ant_trex_sac_create_config)
+create_config = ant_trex_sac_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c ant_trex_sac_config.py -s 0`
+ from ding.entry import serial_pipeline_trex
+ serial_pipeline_trex((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/mujoco/config/halfcheetah_bco_config.py b/DI-engine/dizoo/mujoco/config/halfcheetah_bco_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..9253af06eb7321ddda5ffa1dbfe758dd58f6fad8
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/halfcheetah_bco_config.py
@@ -0,0 +1,92 @@
+from easydict import EasyDict
+
+halfcheetah_bco_config = dict(
+ exp_name='halfcheetah_bco_seed0',
+ env=dict(
+ env_id='HalfCheetah-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=12000,
+ ),
+ policy=dict(
+ # Whether to use cuda for network.
+ cuda=True,
+ continuous=True,
+ loss_type='l1_loss',
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ action_space='regression',
+ actor_head_hidden_size=128,
+ ),
+ learn=dict(
+ train_epoch=30,
+ batch_size=128,
+ learning_rate=0.01,
+ weight_decay=1e-5,
+ decay_epoch=1000,
+ decay_rate=0.5,
+ warmup_lr=1e-4,
+ warmup_epoch=3,
+ optimizer='SGD',
+ lr_decay=True,
+ momentum=0.9,
+ tanh_mask=True,
+ ),
+ collect=dict(
+ n_episode=100,
+ # control the number (alpha*n_episode) of post-demonstration environment interactions at each iteration.
+ # Notice: alpha * n_episode > collector_env_num
+ model_path='abs model path', # expert model path
+ data_path='abs data path', # expert data path
+ noise=True,
+ noise_sigma=dict(
+ start=0.5,
+ end=0.1,
+ decay=1000000,
+ type='exp',
+ ),
+ noise_range=dict(
+ min=-1,
+ max=1,
+ ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=40, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=100000, ), ),
+ ),
+ bco=dict(
+ learn=dict(idm_batch_size=128, idm_learning_rate=0.001, idm_weight_decay=0, idm_train_epoch=30),
+ model=dict(
+ action_space='regression',
+ idm_encoder_hidden_size_list=[60, 80, 100, 40],
+ ),
+ alpha=0.2,
+ )
+)
+
+halfcheetah_bco_config = EasyDict(halfcheetah_bco_config)
+main_config = halfcheetah_bco_config
+
+halfcheetah_bco_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='bc'),
+ collector=dict(type='episode'),
+)
+halfcheetah_bco_create_config = EasyDict(halfcheetah_bco_create_config)
+create_config = halfcheetah_bco_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_bco
+ from dizoo.mujoco.config.halfcheetah_sac_config import halfcheetah_sac_config, halfcheetah_sac_create_config
+ expert_main_config = halfcheetah_sac_config
+ expert_create_config = halfcheetah_sac_create_config
+ serial_pipeline_bco(
+ [main_config, create_config], [expert_main_config, expert_create_config], seed=0, max_env_step=3000000
+ )
diff --git a/DI-engine/dizoo/mujoco/config/halfcheetah_bdq_config.py b/DI-engine/dizoo/mujoco/config/halfcheetah_bdq_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..25fb65ba35097f9dd367fc1e21e270cc409239e0
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/halfcheetah_bdq_config.py
@@ -0,0 +1,71 @@
+from easydict import EasyDict
+
+halfcheetah_bdq_config = dict(
+ exp_name='halfcheetah_bdq_seed0',
+ env=dict(
+ env_id='HalfCheetah-v3',
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=12000,
+ action_bins_per_branch=2,
+ ),
+ policy=dict(
+ cuda=False,
+ priority=False,
+ discount_factor=0.99,
+ nstep=1,
+ model=dict(
+ obs_shape=17,
+ num_branches=6,
+            action_bins_per_branch=2,  # means the action shape is 6, with 2 discrete bins per action dimension
+ encoder_hidden_size_list=[256, 256, 128],
+ ),
+ learn=dict(
+ batch_size=512,
+ learning_rate=3e-4,
+ ignore_done=True,
+ target_update_freq=500,
+ update_per_collect=20,
+ ),
+ collect=dict(
+ n_sample=256,
+ unroll_len=1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, )),
+ other=dict(
+ # Epsilon greedy with decay.
+ eps=dict(
+ # Decay type. Support ['exp', 'linear'].
+ type='exp',
+ start=1,
+ end=0.05,
+ decay=int(1e5),
+ ),
+ replay_buffer=dict(replay_buffer_size=int(1e6), )
+ ),
+ ),
+)
+halfcheetah_bdq_config = EasyDict(halfcheetah_bdq_config)
+main_config = halfcheetah_bdq_config
+
+halfcheetah_bdq_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='bdq', ),
+)
+halfcheetah_bdq_create_config = EasyDict(halfcheetah_bdq_create_config)
+create_config = halfcheetah_bdq_create_config
+
+if __name__ == "__main__":
+    # or you can enter `ding -m serial -c halfcheetah_bdq_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline(
+ (main_config, create_config),
+ seed=0,
+ max_env_step=10000000,
+ )
diff --git a/DI-engine/dizoo/mujoco/config/halfcheetah_d4pg_config.py b/DI-engine/dizoo/mujoco/config/halfcheetah_d4pg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..154bc27a46e876a7110807fc16bea98143768bca
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/halfcheetah_d4pg_config.py
@@ -0,0 +1,69 @@
+from easydict import EasyDict
+
+halfcheetah_d4pg_config = dict(
+ exp_name='halfcheetah_d4pg_seed0',
+ env=dict(
+ env_id='HalfCheetah-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=4,
+ evaluator_env_num=4,
+ n_evaluator_episode=8,
+ stop_value=20000,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ nstep=5,
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ actor_head_hidden_size=512,
+ critic_head_hidden_size=512,
+ action_space='regression',
+ critic_head_type='categorical',
+ v_min=0,
+ v_max=5000, # v_max: [3000, 10000]
+ n_atom=51,
+ ),
+ learn=dict(
+ update_per_collect=4, # update_per_collect: [1, 4]
+ batch_size=256,
+ learning_rate_actor=3e-4,
+ learning_rate_critic=3e-4,
+ ignore_done=True,
+ target_theta=0.005,
+ discount_factor=0.99,
+ actor_update_freq=1,
+ noise=False,
+ ),
+ collect=dict(
+ n_sample=8,
+ unroll_len=1,
+ noise_sigma=0.2, # noise_sigma: [0.1, 0.2]
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ )
+)
+halfcheetah_d4pg_config = EasyDict(halfcheetah_d4pg_config)
+main_config = halfcheetah_d4pg_config
+
+halfcheetah_d4pg_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='d4pg',
+ import_names=['ding.policy.d4pg'],
+ ),
+)
+halfcheetah_d4pg_create_config = EasyDict(halfcheetah_d4pg_create_config)
+create_config = halfcheetah_d4pg_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c halfcheetah_d4pg_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/mujoco/config/halfcheetah_ddpg_config.py b/DI-engine/dizoo/mujoco/config/halfcheetah_ddpg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..640717b8c678019b9e31bce4dc59d3182503363b
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/halfcheetah_ddpg_config.py
@@ -0,0 +1,65 @@
+from easydict import EasyDict
+
+halfcheetah_ddpg_config = dict(
+    exp_name='halfcheetah_ddpg_seed0',
+    env=dict(
+ env_id='HalfCheetah-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=11000,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=25000,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ twin_critic=False,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ action_space='regression',
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_actor=1e-3,
+ learning_rate_critic=1e-3,
+ ignore_done=True,
+ target_theta=0.005,
+ discount_factor=0.99,
+ actor_update_freq=1,
+ noise=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ noise_sigma=0.1,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ )
+)
+halfcheetah_ddpg_config = EasyDict(halfcheetah_ddpg_config)
+main_config = halfcheetah_ddpg_config
+
+halfcheetah_ddpg_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='ddpg',
+ import_names=['ding.policy.ddpg'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+halfcheetah_ddpg_create_config = EasyDict(halfcheetah_ddpg_create_config)
+create_config = halfcheetah_ddpg_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c halfcheetah_ddpg_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
\ No newline at end of file
diff --git a/DI-engine/dizoo/mujoco/config/halfcheetah_gail_sac_config.py b/DI-engine/dizoo/mujoco/config/halfcheetah_gail_sac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf64cd8c640e071d01d0bf2f227b03a89d254b33
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/halfcheetah_gail_sac_config.py
@@ -0,0 +1,100 @@
+from easydict import EasyDict
+
+obs_shape = 17
+act_shape = 6
+halfcheetah_sac_gail_config = dict(
+ exp_name='halfcheetah_sac_gail_seed0',
+ env=dict(
+ env_id='HalfCheetah-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=12000,
+ ),
+ reward_model=dict(
+ input_size=obs_shape + act_shape,
+ hidden_size=256,
+ batch_size=64,
+ learning_rate=1e-3,
+ update_per_collect=100,
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ expert_model_path='model_path_placeholder',
+ # Path where to store the reward model
+ reward_model_path='data_path_placeholder+/reward_model/ckpt/ckpt_best.pth.tar',
+ # Users should add their own data path here. Data path should lead to a file to store data or load the stored data.
+ # Absolute path is recommended.
+ # In DI-engine, it is usually located in ``exp_name`` directory
+ data_path='data_path_placeholder',
+ collect_count=300000,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=25000,
+ model=dict(
+ obs_shape=obs_shape,
+ action_shape=act_shape,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_q=1e-3,
+ learning_rate_policy=1e-3,
+ learning_rate_alpha=3e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=False,
+ ),
+ collect=dict(
+ n_sample=64,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ ),
+)
+
+halfcheetah_sac_gail_config = EasyDict(halfcheetah_sac_gail_config)
+main_config = halfcheetah_sac_gail_config
+
+halfcheetah_sac_gail_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(
+ type='sac',
+ import_names=['ding.policy.sac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+halfcheetah_sac_gail_create_config = EasyDict(halfcheetah_sac_gail_create_config)
+create_config = halfcheetah_sac_gail_create_config
+
+if __name__ == "__main__":
+    # or you can enter `ding -m serial_gail -c halfcheetah_gail_sac_config.py -s 0`
+    # then input the config you used to generate your expert model in the path mentioned above
+    # e.g. halfcheetah_sac_config.py
+ from ding.entry import serial_pipeline_gail
+ from dizoo.mujoco.config.halfcheetah_sac_config import halfcheetah_sac_config, halfcheetah_sac_create_config
+
+ expert_main_config = halfcheetah_sac_config
+ expert_create_config = halfcheetah_sac_create_config
+ serial_pipeline_gail(
+ [main_config, create_config], [expert_main_config, expert_create_config],
+ max_env_step=10000000,
+ seed=0,
+ collect_data=True
+ )
diff --git a/DI-engine/dizoo/mujoco/config/halfcheetah_gcl_sac_config.py b/DI-engine/dizoo/mujoco/config/halfcheetah_gcl_sac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..367b7bcf03c73f492f2478d9ba6fbf01081877b8
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/halfcheetah_gcl_sac_config.py
@@ -0,0 +1,86 @@
+from easydict import EasyDict
+
+halfcheetah_gcl_sac_config = dict(
+ exp_name='halfcheetah_gcl_sac_seed0',
+ env=dict(
+ env_id='HalfCheetah-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=12000,
+ ),
+ reward_model=dict(
+ learning_rate=0.001,
+ input_size=23,
+ batch_size=32,
+ action_shape=6,
+ continuous=True,
+ update_per_collect=20,
+ ),
+ policy=dict(
+ cuda=False,
+ on_policy=False,
+ random_collect_size=0,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_q=1e-3,
+ learning_rate_policy=1e-3,
+ learning_rate_alpha=3e-4,
+ ignore_done=True,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=False,
+ ),
+ collect=dict(
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ model_path='model_path_placeholder',
+ # If you need the data collected by the collector to contain logit key which reflect the probability of
+ # the action, you can change the key to be True.
+ # In Guided cost Learning, we need to use logit to train the reward model, we change the key to be True.
+ collector_logit=True,
+ n_sample=256,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ ),
+)
+
+halfcheetah_gcl_sac_config = EasyDict(halfcheetah_gcl_sac_config)
+main_config = halfcheetah_gcl_sac_config
+
+halfcheetah_gcl_sac_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='sac',
+ import_names=['ding.policy.sac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+ reward_model=dict(type='guided_cost'),
+)
+halfcheetah_gcl_sac_create_config = EasyDict(halfcheetah_gcl_sac_create_config)
+create_config = halfcheetah_gcl_sac_create_config
+
+if __name__ == '__main__':
+ from ding.entry import serial_pipeline_guided_cost
+ serial_pipeline_guided_cost((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/mujoco/config/halfcheetah_onppo_config.py b/DI-engine/dizoo/mujoco/config/halfcheetah_onppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..87046ff6f5dcc7f1ef027cde1ee4aed1693d34e3
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/halfcheetah_onppo_config.py
@@ -0,0 +1,71 @@
+from easydict import EasyDict
+
+collector_env_num = 1
+evaluator_env_num = 1
+halfcheetah_ppo_config = dict(
+ exp_name='halfcheetah_onppo_seed0',
+ env=dict(
+ env_id='HalfCheetah-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=1,
+ stop_value=12000,
+ ),
+ policy=dict(
+ cuda=True,
+ recompute_adv=True,
+ action_space='continuous',
+ model=dict(
+ action_space='continuous',
+ obs_shape=17,
+ action_shape=6,
+ ),
+ learn=dict(
+ epoch_per_collect=10,
+ update_per_collect=1,
+ batch_size=320,
+ learning_rate=3e-4,
+ value_weight=0.5,
+ entropy_weight=0.001,
+ clip_ratio=0.2,
+ adv_norm=True,
+ value_norm=True,
+            # For on-policy PPO, recomputing the advantage requires the ``done`` key in the data to split
+            # trajectories, so ``ignore_done=False`` would normally be required.
+            # Since the ``traj_flag`` key is added to the data as a backup for ``done``, ``ignore_done=True``
+            # can be used here (HalfCheetah episodes have a fixed length of 1000).
+ ignore_done=True,
+ grad_clip_type='clip_norm',
+ grad_clip_value=0.5,
+ ),
+ collect=dict(
+ collector_env_num=collector_env_num,
+ n_sample=3200,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ ),
+)
+halfcheetah_ppo_config = EasyDict(halfcheetah_ppo_config)
+main_config = halfcheetah_ppo_config
+
+halfcheetah_ppo_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo', ),
+)
+halfcheetah_ppo_create_config = EasyDict(halfcheetah_ppo_create_config)
+create_config = halfcheetah_ppo_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c halfcheetah_onppo_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
\ No newline at end of file
diff --git a/DI-engine/dizoo/mujoco/config/halfcheetah_sac_config.py b/DI-engine/dizoo/mujoco/config/halfcheetah_sac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..67ace8134bb9589a4b4b79bba8972eea21497fba
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/halfcheetah_sac_config.py
@@ -0,0 +1,69 @@
+from easydict import EasyDict
+
+halfcheetah_sac_config = dict(
+ exp_name='halfcheetah_sac_seed0',
+ env=dict(
+ env_id='HalfCheetah-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=12000,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_q=1e-3,
+ learning_rate_policy=1e-3,
+ learning_rate_alpha=3e-4,
+ ignore_done=True,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ ),
+)
+
+halfcheetah_sac_config = EasyDict(halfcheetah_sac_config)
+main_config = halfcheetah_sac_config
+
+halfcheetah_sac_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='sac',
+ import_names=['ding.policy.sac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+halfcheetah_sac_create_config = EasyDict(halfcheetah_sac_create_config)
+create_config = halfcheetah_sac_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c halfcheetah_sac_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/mujoco/config/halfcheetah_sqil_sac_config.py b/DI-engine/dizoo/mujoco/config/halfcheetah_sqil_sac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ea6cb51b53705a2b1809e6572919d6cb30ed0c2f
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/halfcheetah_sqil_sac_config.py
@@ -0,0 +1,76 @@
+from easydict import EasyDict
+
+halfcheetah_sqil_config = dict(
+ exp_name='halfcheetah_sqil_sac_seed0',
+ env=dict(
+ env_id='HalfCheetah-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=12000,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=10000,
+ expert_random_collect_size=10000,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ nstep=1,
+ discount_factor=0.97,
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_q=1e-3,
+ learning_rate_policy=1e-3,
+ learning_rate_alpha=2e-4,
+ ignore_done=True,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=True,
+ ),
+ collect=dict(
+ n_sample=32,
+ # Users should add their own path here (path should lead to a well-trained model)
+ model_path='model_path_placeholder',
+ # Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ ),
+        eval=dict(evaluator=dict(eval_freq=500, )),  # note: frequency at which the evaluator is invoked
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ ),
+)
+halfcheetah_sqil_config = EasyDict(halfcheetah_sqil_config)
+main_config = halfcheetah_sqil_config
+halfcheetah_sqil_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='sqil_sac'),
+ replay_buffer=dict(type='naive', ),
+)
+halfcheetah_sqil_create_config = EasyDict(halfcheetah_sqil_create_config)
+create_config = halfcheetah_sqil_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_sqil -c halfcheetah_sqil_sac_config.py -s 0`
+ # then input the config you used to generate your expert model in the path mentioned above
+ # e.g. halfcheetah_sac_config.py
+    from dizoo.mujoco.config.halfcheetah_sac_config import halfcheetah_sac_config, halfcheetah_sac_create_config
+ from ding.entry import serial_pipeline_sqil
+ expert_main_config = halfcheetah_sac_config
+ expert_create_config = halfcheetah_sac_create_config
+ serial_pipeline_sqil(
+ [main_config, create_config], [expert_main_config, expert_create_config], seed=0, max_env_step=5000000
+ )
diff --git a/DI-engine/dizoo/mujoco/config/halfcheetah_td3_config.py b/DI-engine/dizoo/mujoco/config/halfcheetah_td3_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..47eb4ce5f1c438e0006594b2c212d371fdcc9fad
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/halfcheetah_td3_config.py
@@ -0,0 +1,71 @@
+from easydict import EasyDict
+
+halfcheetah_td3_config = dict(
+ exp_name='halfcheetah_td3_seed0',
+ env=dict(
+ env_id='HalfCheetah-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=11000,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=25000,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ twin_critic=True,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ action_space='regression',
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_actor=1e-3,
+ learning_rate_critic=1e-3,
+ ignore_done=True,
+ target_theta=0.005,
+ discount_factor=0.99,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.2,
+ noise_range=dict(
+ min=-0.5,
+ max=0.5,
+ ),
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ noise_sigma=0.1,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ )
+)
+
+halfcheetah_td3_config = EasyDict(halfcheetah_td3_config)
+main_config = halfcheetah_td3_config
+
+halfcheetah_td3_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='td3',
+ import_names=['ding.policy.td3'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+halfcheetah_td3_create_config = EasyDict(halfcheetah_td3_create_config)
+create_config = halfcheetah_td3_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c halfcheetah_td3_config.py -s 0 --env-step 1e7`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/mujoco/config/halfcheetah_trex_onppo_config.py b/DI-engine/dizoo/mujoco/config/halfcheetah_trex_onppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..6d635c212d035baa9fb4024a3f1350937acfc7ee
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/halfcheetah_trex_onppo_config.py
@@ -0,0 +1,101 @@
+from easydict import EasyDict
+
+halfCheetah_trex_ppo_config = dict(
+ exp_name='halfcheetah_trex_onppo_seed0',
+ env=dict(
+ env_id='HalfCheetah-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=8,
+ evaluator_env_num=10,
+ n_evaluator_episode=10,
+ stop_value=3000,
+ ),
+ reward_model=dict(
+ min_snippet_length=30,
+ max_snippet_length=100,
+ checkpoint_min=10000,
+ checkpoint_max=90000,
+ checkpoint_step=10000,
+ num_snippets=60000,
+ learning_rate=1e-5,
+ update_per_collect=1,
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+        # However, ``expert_model_path`` here should be set to the ``exp_name`` directory of the expert config.
+ expert_model_path='model_path_placeholder',
+ # Path where to store the reward model
+ reward_model_path='data_path_placeholder + /HalfCheetah.params',
+ # Users should add their own data path here. Data path should lead to a file to store data or load the stored data.
+ # Absolute path is recommended.
+ # In DI-engine, it is usually located in ``exp_name`` directory
+ # See ding/entry/application_entry_trex_collect_data.py to collect the data
+ data_path='data_path_placeholder',
+ ),
+ policy=dict(
+ cuda=True,
+ recompute_adv=True,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ action_space='continuous',
+ ),
+ action_space='continuous',
+ learn=dict(
+ epoch_per_collect=10,
+ batch_size=64,
+ learning_rate=3e-4,
+ value_weight=0.5,
+ entropy_weight=0.0,
+ clip_ratio=0.2,
+ adv_norm=True,
+ value_norm=True,
+            # For on-policy PPO, recomputing the advantage requires the ``done`` key in the data to split
+            # trajectories, so ``ignore_done=False`` would normally be required.
+            # Since the ``traj_flag`` key is added to the data as a backup for ``done``, ``ignore_done=True``
+            # can be used here (HalfCheetah episodes have a fixed length of 1000).
+ ignore_done=True,
+ grad_clip_type='clip_norm',
+ grad_clip_value=0.5,
+ ),
+ collect=dict(
+ n_sample=2048,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.97,
+ ),
+ eval=dict(evaluator=dict(eval_freq=5000, )),
+ ),
+)
+halfCheetah_trex_ppo_config = EasyDict(halfCheetah_trex_ppo_config)
+main_config = halfCheetah_trex_ppo_config
+
+halfCheetah_trex_ppo_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo', ),
+ reward_model=dict(type='trex'),
+)
+halfCheetah_trex_ppo_create_config = EasyDict(halfCheetah_trex_ppo_create_config)
+create_config = halfCheetah_trex_ppo_create_config
+
+if __name__ == '__main__':
+ # Users should first run ``halfcheetah_onppo_config.py`` to save models (or checkpoints).
+    # Note: Users should check that the generated checkpoints include iteration_{checkpoint_min}.pth.tar through
+    # iteration_{checkpoint_max}.pth.tar at intervals of checkpoint_step, where checkpoint_min, checkpoint_max
+    # and checkpoint_step are specified above.
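+    # With the values above (checkpoint_min=10000, checkpoint_max=90000, checkpoint_step=10000), the expected
+    # files are iteration_10000.pth.tar, iteration_20000.pth.tar, ..., iteration_90000.pth.tar under the expert
+    # run's ``exp_name/ckpt`` directory (an illustration of the naming scheme described above).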
+ import argparse
+ import torch
+ from ding.entry import trex_collecting_data
+ from ding.entry import serial_pipeline_trex_onpolicy
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
+ parser.add_argument('--seed', type=int, default=0)
+ parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
+ args = parser.parse_args()
+ # The function ``trex_collecting_data`` below is to collect episodic data for training the reward model in trex.
+ trex_collecting_data(args)
+ serial_pipeline_trex_onpolicy([main_config, create_config])
diff --git a/DI-engine/dizoo/mujoco/config/halfcheetah_trex_sac_config.py b/DI-engine/dizoo/mujoco/config/halfcheetah_trex_sac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f123682a0cd5a5174e06d851c4555e49e5e33b4
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/halfcheetah_trex_sac_config.py
@@ -0,0 +1,103 @@
+from easydict import EasyDict
+
+halfcheetah_trex_sac_config = dict(
+ exp_name='halfcheetah_trex_sac_seed0',
+ env=dict(
+ env_id='HalfCheetah-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=12000,
+ ),
+ reward_model=dict(
+ learning_rate=1e-5,
+ min_snippet_length=30,
+ max_snippet_length=100,
+ checkpoint_min=1000,
+ checkpoint_max=9000,
+ checkpoint_step=1000,
+ update_per_collect=1,
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+        # However, ``expert_model_path`` here should be set to the ``exp_name`` directory of the expert config.
+ expert_model_path='model_path_placeholder',
+ # Path where to store the reward model
+ reward_model_path='data_path_placeholder + /HalfCheetah.params',
+ # Users should add their own data path here. Data path should lead to a file to store data or load the stored data.
+ # Absolute path is recommended.
+ # In DI-engine, it is usually located in ``exp_name`` directory
+ # See ding/entry/application_entry_trex_collect_data.py to collect the data
+ data_path='data_path_placeholder',
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_q=1e-3,
+ learning_rate_policy=1e-3,
+ learning_rate_alpha=3e-4,
+ ignore_done=True,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ ),
+)
+
+halfcheetah_trex_sac_config = EasyDict(halfcheetah_trex_sac_config)
+main_config = halfcheetah_trex_sac_config
+
+halfcheetah_trex_sac_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='sac',
+ import_names=['ding.policy.sac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+ reward_model=dict(type='trex'),
+)
+halfcheetah_trex_sac_create_config = EasyDict(halfcheetah_trex_sac_create_config)
+create_config = halfcheetah_trex_sac_create_config
+
+if __name__ == '__main__':
+ # Users should first run ``halfcheetah_sac_config.py`` to save models (or checkpoints).
+    # Note: Users should check that the generated checkpoints include iteration_{checkpoint_min}.pth.tar through
+    # iteration_{checkpoint_max}.pth.tar at intervals of checkpoint_step, where checkpoint_min, checkpoint_max
+    # and checkpoint_step are specified above.
+ import argparse
+ import torch
+ from ding.entry import trex_collecting_data
+ from ding.entry import serial_pipeline_trex
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
+ parser.add_argument('--seed', type=int, default=0)
+ parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
+ args = parser.parse_args()
+ # The function ``trex_collecting_data`` below is to collect episodic data for training the reward model in trex.
+ trex_collecting_data(args)
+ serial_pipeline_trex([main_config, create_config])
diff --git a/DI-engine/dizoo/mujoco/config/hopper_bco_config.py b/DI-engine/dizoo/mujoco/config/hopper_bco_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..668e258e69577768ca829792f781a7d6f23845d4
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/hopper_bco_config.py
@@ -0,0 +1,91 @@
+from easydict import EasyDict
+
+hopper_bco_config = dict(
+ exp_name='hopper_bco_seed0',
+ env=dict(
+ env_id='Hopper-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ # Whether to use cuda for network.
+ cuda=True,
+ continuous=True,
+ loss_type='l1_loss',
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ action_space='regression',
+ actor_head_hidden_size=128,
+ ),
+ learn=dict(
+ train_epoch=20,
+ batch_size=128,
+ learning_rate=0.001,
+ weight_decay=1e-4,
+ momentum=0.9,
+ decay_epoch=30,
+ decay_rate=1,
+ warmup_lr=1e-4,
+ warmup_epoch=3,
+ optimizer='SGD',
+ lr_decay=True,
+ ),
+ collect=dict(
+ n_episode=100,
+            # ``alpha * n_episode`` controls the number of post-demonstration environment interactions at each
+            # iteration. Note: alpha * n_episode must be greater than collector_env_num (here 0.2 * 100 = 20 > 1).
+ model_path='abs model path', # expert model path
+ data_path='abs data path', # expert data path
+ noise=True,
+ noise_sigma=dict(
+ start=0.5,
+ end=0.1,
+ decay=1000000,
+ type='exp',
+ ),
+ noise_range=dict(
+ min=-1,
+ max=1,
+ ),
+ ),
+ eval=dict(evaluator=dict(eval_freq=40, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=100000, ), ),
+ ),
+ bco=dict(
+ learn=dict(idm_batch_size=256, idm_learning_rate=0.001, idm_weight_decay=0, idm_train_epoch=20),
+ model=dict(
+ action_space='regression',
+ idm_encoder_hidden_size_list=[60, 80, 100, 40],
+ ),
+ alpha=0.2,
+ )
+)
+
+hopper_bco_config = EasyDict(hopper_bco_config)
+main_config = hopper_bco_config
+
+hopper_bco_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='bc'),
+ collector=dict(type='episode'),
+)
+hopper_bco_create_config = EasyDict(hopper_bco_create_config)
+create_config = hopper_bco_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_bco
+ from dizoo.mujoco.config.hopper_sac_config import hopper_sac_config, hopper_sac_create_config
+ expert_main_config = hopper_sac_config
+ expert_create_config = hopper_sac_create_config
+ serial_pipeline_bco(
+ [main_config, create_config], [expert_main_config, expert_create_config], seed=0, max_env_step=3000000
+ )
diff --git a/DI-engine/dizoo/mujoco/config/hopper_bdq_config.py b/DI-engine/dizoo/mujoco/config/hopper_bdq_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..34dbe21664f01ecb64cd2518c910fbb7f54be2d3
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/hopper_bdq_config.py
@@ -0,0 +1,75 @@
+from easydict import EasyDict
+
+hopper_bdq_config = dict(
+ exp_name='hopper_bdq_seed0',
+ env=dict(
+ env_id='Hopper-v3',
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=int(1e6),
+ action_bins_per_branch=4,
+ ),
+ policy=dict(
+ cuda=False,
+ priority=False,
+ discount_factor=0.99,
+ nstep=3,
+ model=dict(
+ obs_shape=11,
+ num_branches=3,
+            action_bins_per_branch=4,  # i.e. the action has 3 dimensions (branches), each discretized into 4 bins
+ encoder_hidden_size_list=[256, 256, 128],
+ ),
+ learn=dict(
+ ignore_done=False,
+ batch_size=512,
+ learning_rate=3e-4,
+ # Frequency of target network update.
+ target_update_freq=500,
+ update_per_collect=20,
+ ),
+ collect=dict(
+ # You can use either "n_sample" or "n_episode" in collector.collect.
+ # Get "n_sample" samples per collect.
+ n_sample=256,
+ # Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, )),
+ other=dict(
+ # Epsilon greedy with decay.
+ eps=dict(
+ # Decay type. Support ['exp', 'linear'].
+ type='exp',
+ start=1,
+ end=0.05,
+ decay=int(1e5),
+ ),
+ replay_buffer=dict(replay_buffer_size=int(1e6), )
+ ),
+ ),
+)
+hopper_bdq_config = EasyDict(hopper_bdq_config)
+main_config = hopper_bdq_config
+
+hopper_bdq_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='bdq', ),
+)
+hopper_bdq_create_config = EasyDict(hopper_bdq_create_config)
+create_config = hopper_bdq_create_config
+
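+# The sketch below illustrates how per-branch bin indices could be mapped back to continuous actions, assuming
+# each of the 3 branches discretizes its action dimension into 4 evenly spaced bins over [-1, 1] (Hopper's
+# action range). This is an illustrative assumption, not a verbatim copy of DI-engine's internal mapping.
+def _example_bins_to_continuous(bin_indices, bins_per_branch=4, low=-1.0, high=1.0):
+    # e.g. bin_indices = [0, 2, 3] -> [-1.0, 0.33..., 1.0]
+    step = (high - low) / (bins_per_branch - 1)
+    return [low + idx * step for idx in bin_indices]
+
+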
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c hopper_bdq_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline(
+ [main_config, create_config],
+ seed=0,
+ max_env_step=10000000,
+ )
diff --git a/DI-engine/dizoo/mujoco/config/hopper_cql_config.py b/DI-engine/dizoo/mujoco/config/hopper_cql_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..7713d23381794f7867d7ac72ef947051aadbda43
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/hopper_cql_config.py
@@ -0,0 +1,69 @@
+from easydict import EasyDict
+
+hopper_cql_config = dict(
+ exp_name='hopper_cql_seed0',
+ env=dict(
+ env_id='Hopper-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_q=3e-4,
+ learning_rate_policy=1e-4,
+ learning_rate_alpha=1e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ auto_alpha=False,
+ with_lagrange=False,
+ lagrange_thresh=-1.0,
+ min_q_weight=5.0,
+ ),
+ collect=dict(
+ unroll_len=1,
+ data_type='naive',
+ # Users should add their own data path here. Data path should lead to a file to store data or load the stored data.
+ # Absolute path is recommended.
+ # In DI-engine, it is usually located in ``exp_name`` directory
+ data_path='data_path_placeholder',
+ ),
+ command=dict(),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+
+hopper_cql_config = EasyDict(hopper_cql_config)
+main_config = hopper_cql_config
+
+hopper_cql_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='cql',
+ import_names=['ding.policy.cql'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+hopper_cql_create_config = EasyDict(hopper_cql_create_config)
+create_config = hopper_cql_create_config
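+
+# CQL is an offline RL algorithm, so this config is meant to be consumed together with a previously generated
+# dataset (see e.g. ``hopper_sac_data_generation_config.py``) rather than run online. A minimal usage sketch,
+# assuming the ``serial_pipeline_offline`` entry; check its signature in ``ding.entry`` before use:
+# if __name__ == "__main__":
+#     from ding.entry import serial_pipeline_offline
+#     serial_pipeline_offline([main_config, create_config], seed=0)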
diff --git a/DI-engine/dizoo/mujoco/config/hopper_d4pg_config.py b/DI-engine/dizoo/mujoco/config/hopper_d4pg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e533ac684ecc0c085c97675ca9c120f137d87e29
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/hopper_d4pg_config.py
@@ -0,0 +1,69 @@
+from easydict import EasyDict
+
+hopper_d4pg_config = dict(
+ exp_name='hopper_d4pg_seed0',
+ env=dict(
+ env_id='Hopper-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=4,
+ evaluator_env_num=4,
+ n_evaluator_episode=8,
+ stop_value=5000,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ nstep=5,
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ actor_head_hidden_size=512,
+ critic_head_hidden_size=512,
+ action_space='regression',
+ critic_head_type='categorical',
+ v_min=0,
+ v_max=1000, # 1000 ~ 3000
+ n_atom=51,
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_actor=3e-4,
+ learning_rate_critic=3e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ actor_update_freq=1,
+ noise=False,
+ ),
+ collect=dict(
+ n_sample=8,
+ unroll_len=1,
+ noise_sigma=0.2, # 0.1 ~ 0.2
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ )
+)
+hopper_d4pg_config = EasyDict(hopper_d4pg_config)
+main_config = hopper_d4pg_config
+
+hopper_d4pg_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='d4pg',
+ import_names=['ding.policy.d4pg'],
+ ),
+)
+hopper_d4pg_create_config = EasyDict(hopper_d4pg_create_config)
+create_config = hopper_d4pg_create_config
+
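+# The categorical critic head represents the return distribution over ``n_atom`` fixed support points
+# (C51-style). With v_min=0, v_max=1000 and n_atom=51 the atoms are evenly spaced every
+# (1000 - 0) / (51 - 1) = 20, i.e. the equivalent of ``torch.linspace(0, 1000, 51)``.
+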
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c hopper_d4pg_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/mujoco/config/hopper_ddpg_config.py b/DI-engine/dizoo/mujoco/config/hopper_ddpg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c0a1a524e8d58f9aef997cd956aa5aecce2c73a0
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/hopper_ddpg_config.py
@@ -0,0 +1,65 @@
+from easydict import EasyDict
+
+hopper_ddpg_config = dict(
+ exp_name='hopper_ddpg_seed0',
+ env=dict(
+ env_id='Hopper-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=25000,
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ twin_critic=False,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ action_space='regression',
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_actor=1e-3,
+ learning_rate_critic=1e-3,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ actor_update_freq=1,
+ noise=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ noise_sigma=0.1,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ )
+)
+hopper_ddpg_config = EasyDict(hopper_ddpg_config)
+main_config = hopper_ddpg_config
+
+hopper_ddpg_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='ddpg',
+ import_names=['ding.policy.ddpg'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+hopper_ddpg_create_config = EasyDict(hopper_ddpg_create_config)
+create_config = hopper_ddpg_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c hopper_ddpg_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/mujoco/config/hopper_gail_sac_config.py b/DI-engine/dizoo/mujoco/config/hopper_gail_sac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..26ef8b3816ac4dfadcc72c08850836684bd59e07
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/hopper_gail_sac_config.py
@@ -0,0 +1,100 @@
+from easydict import EasyDict
+
+obs_shape = 11
+act_shape = 3
+hopper_gail_sac_config = dict(
+ exp_name='hopper_gail_sac_seed0',
+ env=dict(
+ env_id='Hopper-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ reward_model=dict(
+ input_size=obs_shape + act_shape,
+ hidden_size=256,
+ batch_size=64,
+ learning_rate=1e-3,
+ update_per_collect=100,
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ expert_model_path='model_path_placeholder',
+ # Path where to store the reward model
+ reward_model_path='data_path_placeholder+/reward_model/ckpt/ckpt_best.pth.tar',
+ # Users should add their own data path here. Data path should lead to a file to store data or load the stored data.
+ # Absolute path is recommended.
+ # In DI-engine, it is usually located in ``exp_name`` directory
+ data_path='data_path_placeholder',
+ collect_count=100000,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=25000,
+ model=dict(
+ obs_shape=obs_shape,
+ action_shape=act_shape,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_q=1e-3,
+ learning_rate_policy=1e-3,
+ learning_rate_alpha=3e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=False,
+ ),
+ collect=dict(
+ n_sample=64,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ ),
+)
+
+hopper_gail_sac_config = EasyDict(hopper_gail_sac_config)
+main_config = hopper_gail_sac_config
+
+hopper_gail_sac_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='sac',
+ import_names=['ding.policy.sac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+ reward_model=dict(type='gail'),
+)
+hopper_gail_sac_create_config = EasyDict(hopper_gail_sac_create_config)
+create_config = hopper_gail_sac_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_gail -c hopper_gail_sac_config.py -s 0`
+ # then input the config you used to generate your expert model in the path mentioned above
+ # e.g. hopper_sac_config.py
+ from ding.entry import serial_pipeline_gail
+ from dizoo.mujoco.config.hopper_sac_config import hopper_sac_config, hopper_sac_create_config
+ expert_main_config = hopper_sac_config
+ expert_create_config = hopper_sac_create_config
+ serial_pipeline_gail(
+ [main_config, create_config], [expert_main_config, expert_create_config],
+ max_env_step=1000000,
+ seed=0,
+ collect_data=True
+ )
diff --git a/DI-engine/dizoo/mujoco/config/hopper_gcl_config.py b/DI-engine/dizoo/mujoco/config/hopper_gcl_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..214f44dbf715b84202b83320baa213006098325b
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/hopper_gcl_config.py
@@ -0,0 +1,74 @@
+from easydict import EasyDict
+
+hopper_gcl_config = dict(
+ exp_name='hopper_gcl_seed0',
+ env=dict(
+ env_id='Hopper-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=4,
+ evaluator_env_num=10,
+ n_evaluator_episode=10,
+ stop_value=3000,
+ ),
+ reward_model=dict(
+ learning_rate=0.001,
+ input_size=14,
+ batch_size=32,
+ action_shape=3,
+ continuous=True,
+ update_per_collect=20,
+ ),
+ policy=dict(
+ cuda=False,
+ recompute_adv=True,
+ action_space='continuous',
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ action_space='continuous',
+ ),
+ learn=dict(
+ update_per_collect=10,
+ batch_size=64,
+ learning_rate=3e-4,
+ value_weight=0.5,
+ entropy_weight=0.0,
+ clip_ratio=0.2,
+ adv_norm=True,
+ ),
+ collect=dict(
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ model_path='model_path_placeholder',
+            # Set this key to True if the data collected by the collector should contain the ``logit`` key,
+            # which reflects the probability of the chosen action.
+            # In Guided Cost Learning the logit is needed to train the reward model, so it is set to True here.
+ collector_logit=True,
+ n_sample=2048,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.97,
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, )),
+ ),
+)
+hopper_gcl_config = EasyDict(hopper_gcl_config)
+main_config = hopper_gcl_config
+
+hopper_gcl_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo', ),
+ reward_model=dict(type='guided_cost'),
+)
+hopper_gcl_create_config = EasyDict(hopper_gcl_create_config)
+create_config = hopper_gcl_create_config
+
+if __name__ == '__main__':
+ from ding.entry import serial_pipeline_guided_cost
+ serial_pipeline_guided_cost((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/mujoco/config/hopper_onppo_config.py b/DI-engine/dizoo/mujoco/config/hopper_onppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..2cbf05a5532d3f3d33d618d4ec92881289c90a24
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/hopper_onppo_config.py
@@ -0,0 +1,67 @@
+from easydict import EasyDict
+
+hopper_onppo_config = dict(
+ exp_name='hopper_onppo_seed0',
+ env=dict(
+ env_id='Hopper-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=8,
+ evaluator_env_num=10,
+ n_evaluator_episode=10,
+ stop_value=4000,
+ ),
+ policy=dict(
+ cuda=True,
+ recompute_adv=True,
+ action_space='continuous',
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ action_space='continuous',
+ ),
+ learn=dict(
+ epoch_per_collect=10,
+ update_per_collect=1,
+ batch_size=320,
+ learning_rate=3e-4,
+ value_weight=0.5,
+ entropy_weight=0.001,
+ clip_ratio=0.2,
+ adv_norm=True,
+ value_norm=True,
+            # For on-policy PPO, recomputing the advantage requires the ``done`` key in the data to split
+            # trajectories, so ``ignore_done=False`` is used here.
+            # If the ``traj_flag`` key were added to the data as a backup for ``done``, ``ignore_done=True``
+            # could be used instead (as in the HalfCheetah config, where episodes have a fixed length of 1000).
+ ignore_done=False,
+ grad_clip_type='clip_norm',
+ grad_clip_value=0.5,
+ ),
+ collect=dict(
+ n_sample=3200,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ ),
+)
+hopper_onppo_config = EasyDict(hopper_onppo_config)
+main_config = hopper_onppo_config
+
+hopper_onppo_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo', ),
+)
+hopper_onppo_create_config = EasyDict(hopper_onppo_create_config)
+create_config = hopper_onppo_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c hopper_onppo_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/mujoco/config/hopper_sac_config.py b/DI-engine/dizoo/mujoco/config/hopper_sac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..9835aff0f46bffd0878c276ea4fe15dfc5534120
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/hopper_sac_config.py
@@ -0,0 +1,69 @@
+from easydict import EasyDict
+
+hopper_sac_config = dict(
+ exp_name='hopper_sac_seed0',
+ env=dict(
+ env_id='Hopper-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_q=1e-3,
+ learning_rate_policy=1e-3,
+ learning_rate_alpha=3e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ ),
+)
+
+hopper_sac_config = EasyDict(hopper_sac_config)
+main_config = hopper_sac_config
+
+hopper_sac_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='sac',
+ import_names=['ding.policy.sac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+hopper_sac_create_config = EasyDict(hopper_sac_create_config)
+create_config = hopper_sac_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c hopper_sac_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/mujoco/config/hopper_sac_data_generation_config.py b/DI-engine/dizoo/mujoco/config/hopper_sac_data_generation_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b32dd50c509ae394fb109bc1c5e44d70edb4eec
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/hopper_sac_data_generation_config.py
@@ -0,0 +1,78 @@
+from easydict import EasyDict
+
+hopper_sac_data_generation_config = dict(
+ exp_name='hopper_sac_data_generation_seed0',
+ env=dict(
+ env_id='Hopper-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=10,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_q=1e-3,
+ learning_rate_policy=1e-3,
+ learning_rate_alpha=3e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=False,
+ learner=dict(
+ # Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ load_path='model_path_placeholder',
+ hook=dict(
+ load_ckpt_before_run='model_path_placeholder',
+ save_ckpt_after_run=False,
+ )
+ ),
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ # Users should add their own data path here. Data path should lead to a file to store data or load the stored data.
+ # Absolute path is recommended.
+ # In DI-engine, it is usually located in ``exp_name`` directory
+ save_path='data_path_placeholder',
+ ),
+ command=dict(),
+ eval=dict(),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ ),
+)
+
+hopper_sac_data_generation_config = EasyDict(hopper_sac_data_generation_config)
+main_config = hopper_sac_data_generation_config
+
+hopper_sac_data_generation_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='sac',
+ import_names=['ding.policy.sac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+hopper_sac_data_generation_create_config = EasyDict(hopper_sac_data_generation_create_config)
+create_config = hopper_sac_data_generation_create_config
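+
+# This config describes rolling out a trained SAC expert (loaded from ``learn.learner.load_path``) and dumping
+# the collected transitions to ``collect.save_path``; it has no entry point of its own. A minimal usage sketch,
+# assuming the ``collect_demo_data`` entry in ``ding.entry`` (argument names may differ across DI-engine versions):
+# if __name__ == "__main__":
+#     from ding.entry import collect_demo_data
+#     # hypothetical call for illustration only; verify the signature before use
+#     collect_demo_data((main_config, create_config), seed=0, collect_count=10000)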
diff --git a/DI-engine/dizoo/mujoco/config/hopper_sqil_sac_config.py b/DI-engine/dizoo/mujoco/config/hopper_sqil_sac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..172cb44ea0c77b3d8ac31287b6d7837ba0b79123
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/hopper_sqil_sac_config.py
@@ -0,0 +1,82 @@
+from easydict import EasyDict
+
+obs_shape = 11
+act_shape = 3
+hopper_sqil_config = dict(
+ exp_name='hopper_sqil_sac_seed0',
+ env=dict(
+ env_id='Hopper-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=10000,
+ expert_random_collect_size=10000,
+ model=dict(
+ obs_shape=obs_shape,
+ action_shape=act_shape,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ nstep=1,
+ discount_factor=0.97,
+ learn=dict(
+ update_per_collect=1,
+ batch_size=64,
+ learning_rate_q=1e-3,
+ learning_rate_policy=1e-3,
+ learning_rate_alpha=3e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=True,
+ ),
+ collect=dict(
+ n_sample=16,
+ model_path='model_path_placeholder',
+ # Cut trajectories into pieces with length "unroll_len".
+ unroll_len=1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ ),
+)
+
+hopper_sqil_config = EasyDict(hopper_sqil_config)
+main_config = hopper_sqil_config
+
+hopper_sqil_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='sqil_sac', ),
+ replay_buffer=dict(type='naive', ),
+)
+hopper_sqil_create_config = EasyDict(hopper_sqil_create_config)
+create_config = hopper_sqil_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_sqil -c hopper_sqil_sac_config.py -s 0`
+ # then input the config you used to generate your expert model in the path mentioned above
+ # e.g. hopper_sac_config.py
+ from ding.entry import serial_pipeline_sqil
+ from dizoo.mujoco.config.hopper_sac_config import hopper_sac_config, hopper_sac_create_config
+ expert_main_config = hopper_sac_config
+ expert_create_config = hopper_sac_create_config
+ serial_pipeline_sqil(
+ [main_config, create_config],
+ [expert_main_config, expert_create_config],
+ max_env_step=3000000,
+ seed=0,
+ )
diff --git a/DI-engine/dizoo/mujoco/config/hopper_td3_bc_config.py b/DI-engine/dizoo/mujoco/config/hopper_td3_bc_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..00a04075a8e1e0939f67c00ebd7d46d50f94ba2f
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/hopper_td3_bc_config.py
@@ -0,0 +1,80 @@
+from easydict import EasyDict
+
+hopper_td3_bc_config = dict(
+ exp_name='hopper_td3_bc_seed0',
+ env=dict(
+ env_id='Hopper-v3',
+ norm_obs=dict(
+ use_norm=True,
+ offline_stats=dict(use_offline_stats=True, ),
+ ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ twin_critic=True,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ action_space='regression',
+ ),
+ learn=dict(
+ train_epoch=30000,
+ batch_size=256,
+ learning_rate_actor=3e-4,
+ learning_rate_critic=3e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.2,
+ noise_range=dict(
+ min=-0.5,
+ max=0.5,
+ ),
+ alpha=2.5,
+ ),
+ collect=dict(
+ unroll_len=1,
+ noise_sigma=0.1,
+ data_type='hdf5',
+ # Users should add their own data path here. Data path should lead to a file to store data or load the stored data.
+ # Absolute path is recommended.
+ # In DI-engine, it is usually located in ``exp_name`` directory
+ data_path='data_path_placeholder',
+ ),
+ command=dict(),
+ eval=dict(evaluator=dict(eval_freq=1000, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=2000000, ), ),
+ ),
+)
+
+hopper_td3_bc_config = EasyDict(hopper_td3_bc_config)
+main_config = hopper_td3_bc_config
+
+hopper_td3_bc_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='td3_bc',
+ import_names=['ding.policy.td3_bc'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+hopper_td3_bc_create_config = EasyDict(hopper_td3_bc_create_config)
+create_config = hopper_td3_bc_create_config
+
+# if __name__ == "__main__":
+# # or you can enter `ding -m serial -c hopper_td3_bc_config.py -s 0`
+# from ding.entry import serial_pipeline
+# serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/mujoco/config/hopper_td3_config.py b/DI-engine/dizoo/mujoco/config/hopper_td3_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f72930ea07c0f204198e6c32ed5c0060ee97c4f
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/hopper_td3_config.py
@@ -0,0 +1,71 @@
+from easydict import EasyDict
+
+hopper_td3_config = dict(
+ exp_name='hopper_td3_seed0',
+ env=dict(
+ env_id='Hopper-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=25000,
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ twin_critic=True,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ action_space='regression',
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_actor=1e-3,
+ learning_rate_critic=1e-3,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.2,
+ noise_range=dict(
+ min=-0.5,
+ max=0.5,
+ ),
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ noise_sigma=0.1,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ )
+)
+
+hopper_td3_config = EasyDict(hopper_td3_config)
+main_config = hopper_td3_config
+
+hopper_td3_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='td3',
+ import_names=['ding.policy.td3'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+hopper_td3_create_config = EasyDict(hopper_td3_create_config)
+create_config = hopper_td3_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c hopper_td3_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/mujoco/config/hopper_td3_data_generation_config.py b/DI-engine/dizoo/mujoco/config/hopper_td3_data_generation_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..97330419d4a92c2cbd73194118af35752e092e07
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/hopper_td3_data_generation_config.py
@@ -0,0 +1,81 @@
+from easydict import EasyDict
+
+hopper_td3_data_generation_config = dict(
+ exp_name='hopper_td3_data_generation_seed0',
+ env=dict(
+ env_id='Hopper-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=11000,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=25000,
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ twin_critic=True,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ action_space='regression',
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_actor=1e-3,
+ learning_rate_critic=1e-3,
+ ignore_done=True,
+ target_theta=0.005,
+ discount_factor=0.99,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.2,
+ noise_range=dict(
+ min=-0.5,
+ max=0.5,
+ ),
+ learner=dict(
+ # Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ load_path='model_path_placeholder',
+ hook=dict(
+ load_ckpt_before_run='model_path_placeholder',
+ save_ckpt_after_run=False,
+ )
+ ),
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ noise_sigma=0.1,
+ # Users should add their own data path here. Data path should lead to a file to store data or load the stored data.
+ # Absolute path is recommended.
+ # In DI-engine, it is usually located in ``exp_name`` directory
+ save_path='data_path_placeholder',
+ data_type='hdf5',
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ )
+)
+
+hopper_td3_data_generation_config = EasyDict(hopper_td3_data_generation_config)
+main_config = hopper_td3_data_generation_config
+
+hopper_td3_data_generation_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='td3',
+ import_names=['ding.policy.td3'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+hopper_td3_data_generation_create_config = EasyDict(hopper_td3_data_generation_create_config)
+create_config = hopper_td3_data_generation_create_config
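+
+# Like the SAC data-generation config above, this file has no entry point: it loads the TD3 checkpoint given in
+# ``learn.learner.load_path`` and writes the collected transitions to ``collect.save_path`` in HDF5 format
+# (``data_type='hdf5'``), which matches the ``data_type='hdf5'`` / ``data_path`` setting expected by
+# ``hopper_td3_bc_config.py``.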
diff --git a/DI-engine/dizoo/mujoco/config/hopper_trex_onppo_config.py b/DI-engine/dizoo/mujoco/config/hopper_trex_onppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69451fe3c0bcb25f4a3ab11a4ae6782626e4151
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/hopper_trex_onppo_config.py
@@ -0,0 +1,93 @@
+from easydict import EasyDict
+
+hopper_trex_onppo_config = dict(
+ exp_name='hopper_trex_onppo_seed0',
+ env=dict(
+ env_id='Hopper-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=8,
+ evaluator_env_num=10,
+ n_evaluator_episode=10,
+ stop_value=3000,
+ ),
+ reward_model=dict(
+ min_snippet_length=30,
+ max_snippet_length=100,
+ checkpoint_min=10000,
+ checkpoint_max=90000,
+ checkpoint_step=10000,
+ num_snippets=60000,
+ learning_rate=1e-5,
+ update_per_collect=1,
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+        # However, ``expert_model_path`` here should be set to the ``exp_name`` directory of the expert config.
+ expert_model_path='model_path_placeholder',
+ # Path where to store the reward model
+ reward_model_path='data_path_placeholder + /Hopper.params',
+ # Users should add their own data path here. Data path should lead to a file to store data or load the stored data.
+ # Absolute path is recommended.
+ # In DI-engine, it is usually located in ``exp_name`` directory
+ # See ding/entry/application_entry_trex_collect_data.py to collect the data
+ data_path='data_path_placeholder',
+ ),
+ policy=dict(
+ cuda=True,
+ recompute_adv=True,
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ action_space='continuous',
+ ),
+ action_space='continuous',
+ learn=dict(
+ epoch_per_collect=10,
+ batch_size=64,
+ learning_rate=3e-4,
+ value_weight=0.5,
+ entropy_weight=0.0,
+ clip_ratio=0.2,
+ adv_norm=True,
+ value_norm=True,
+ ),
+ collect=dict(
+ n_sample=2048,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.97,
+ ),
+ eval=dict(evaluator=dict(eval_freq=5000, )),
+ ),
+)
+hopper_trex_onppo_config = EasyDict(hopper_trex_onppo_config)
+main_config = hopper_trex_onppo_config
+
+hopper_trex_onppo_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo', ),
+)
+hopper_trex_onppo_create_config = EasyDict(hopper_trex_onppo_create_config)
+create_config = hopper_trex_onppo_create_config
+
+if __name__ == '__main__':
+ # Users should first run ``hopper_onppo_config.py`` to save models (or checkpoints).
+    # Note: Users should check that the generated checkpoints include iteration_{checkpoint_min}.pth.tar through
+    # iteration_{checkpoint_max}.pth.tar at intervals of checkpoint_step, where checkpoint_min, checkpoint_max
+    # and checkpoint_step are specified above.
+ import argparse
+ import torch
+ from ding.entry import trex_collecting_data
+ from ding.entry import serial_pipeline_trex_onpolicy
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
+ parser.add_argument('--seed', type=int, default=0)
+ parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
+ args = parser.parse_args()
+ # The function ``trex_collecting_data`` below is to collect episodic data for training the reward model in trex.
+ trex_collecting_data(args)
+ serial_pipeline_trex_onpolicy([main_config, create_config])
diff --git a/DI-engine/dizoo/mujoco/config/hopper_trex_sac_config.py b/DI-engine/dizoo/mujoco/config/hopper_trex_sac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5c4aa6f2c10aa6ebad0000c3cbcfecc34d33b40c
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/hopper_trex_sac_config.py
@@ -0,0 +1,102 @@
+from easydict import EasyDict
+
+hopper_trex_sac_config = dict(
+ exp_name='hopper_trex_sac_seed0',
+ env=dict(
+ env_id='Hopper-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ reward_model=dict(
+ learning_rate=1e-5,
+ min_snippet_length=30,
+ max_snippet_length=100,
+ checkpoint_min=1000,
+ checkpoint_max=9000,
+ checkpoint_step=1000,
+ update_per_collect=1,
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+        # However, ``expert_model_path`` here should be set to the ``exp_name`` directory of the expert config.
+ expert_model_path='model_path_placeholder',
+ # Path where to store the reward model
+ reward_model_path='data_path_placeholder + /Hopper.params',
+ # Users should add their own data path here. Data path should lead to a file to store data or load the stored data.
+ # Absolute path is recommended.
+ # In DI-engine, it is usually located in ``exp_name`` directory
+ # See ding/entry/application_entry_trex_collect_data.py to collect the data
+ data_path='data_path_placeholder',
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=11,
+ action_shape=3,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_q=1e-3,
+ learning_rate_policy=1e-3,
+ learning_rate_alpha=3e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ ),
+)
+
+hopper_trex_sac_config = EasyDict(hopper_trex_sac_config)
+main_config = hopper_trex_sac_config
+
+hopper_trex_sac_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='sac',
+ import_names=['ding.policy.sac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+hopper_trex_sac_create_config = EasyDict(hopper_trex_sac_create_config)
+create_config = hopper_trex_sac_create_config
+
+if __name__ == '__main__':
+ # Users should first run ``hopper_sac_config.py`` to save models (or checkpoints).
+    # Note: Users should check that the generated checkpoints include iteration_{checkpoint_min}.pth.tar through
+    # iteration_{checkpoint_max}.pth.tar at intervals of checkpoint_step, where checkpoint_min, checkpoint_max
+    # and checkpoint_step are specified above.
+ import argparse
+ import torch
+ from ding.entry import trex_collecting_data
+ from ding.entry import serial_pipeline_trex
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
+ parser.add_argument('--seed', type=int, default=0)
+ parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
+ args = parser.parse_args()
+ # The function ``trex_collecting_data`` below is to collect episodic data for training the reward model in trex.
+ trex_collecting_data(args)
+ serial_pipeline_trex([main_config, create_config])
diff --git a/DI-engine/dizoo/mujoco/config/mbrl/halfcheetah_mbsac_mbpo_config.py b/DI-engine/dizoo/mujoco/config/mbrl/halfcheetah_mbsac_mbpo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..1ee4ac165bb0cf75b32b480a2398b8bceeb2d096
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/mbrl/halfcheetah_mbsac_mbpo_config.py
@@ -0,0 +1,110 @@
+from easydict import EasyDict
+
+from ding.entry import serial_pipeline_dream
+
+# environment hypo
+env_id = 'HalfCheetah-v3'
+obs_shape = 17
+action_shape = 6
+
+# gpu
+cuda = True
+
+main_config = dict(
+    exp_name='halfcheetah_mbsac_mbpo_seed0',
+ env=dict(
+ env_id=env_id,
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=100000,
+ ),
+ policy=dict(
+ cuda=cuda,
+ # it is better to put random_collect_size in policy.other
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=obs_shape,
+ action_shape=action_shape,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ lambda_=0.8,
+ sample_state=False,
+ update_per_collect=40,
+ batch_size=256,
+ learning_rate_q=3e-4,
+ learning_rate_policy=3e-4,
+ learning_rate_alpha=3e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(evaluator=dict(eval_freq=500, )), # w.r.t envstep
+ other=dict(
+ # environment buffer
+ replay_buffer=dict(replay_buffer_size=1000000, periodic_thruput_seconds=60),
+ ),
+ ),
+ world_model=dict(
+ eval_freq=250, # w.r.t envstep
+ train_freq=250, # w.r.t envstep
+ cuda=cuda,
+ rollout_length_scheduler=dict(
+ type='linear',
+ rollout_start_step=20000,
+ rollout_end_step=40000,
+ rollout_length_min=1,
+ rollout_length_max=3,
+ ),
+ model=dict(
+ ensemble_size=7,
+ elite_size=5,
+ state_size=obs_shape, # has to be specified
+ action_size=action_shape, # has to be specified
+ reward_size=1,
+ hidden_size=200,
+ use_decay=True,
+ batch_size=256,
+ holdout_ratio=0.1,
+ max_epochs_since_update=5,
+ deterministic_rollout=True,
+ ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+
+create_config = dict(
+ env=dict(
+ type='mbmujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='mbsac',
+ import_names=['ding.policy.mbpolicy.mbsac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+ world_model=dict(
+ type='mbpo',
+ import_names=['ding.world_model.mbpo'],
+ ),
+)
+create_config = EasyDict(create_config)
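+
+# The ``rollout_length_scheduler`` above linearly interpolates the model rollout horizon from
+# ``rollout_length_min`` to ``rollout_length_max`` as the collected env steps grow from
+# ``rollout_start_step`` to ``rollout_end_step``. The helper below is only an illustrative sketch of
+# that schedule under the values above; DI-engine's world model builds the actual scheduler from the
+# same four fields and may differ in details such as rounding.
+def _linear_rollout_length(envstep, start_step=20000, end_step=40000, min_len=1, max_len=3):
+    ratio = min(max((envstep - start_step) / (end_step - start_step), 0.), 1.)
+    return int(min_len + ratio * (max_len - min_len))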
+
+if __name__ == '__main__':
+ serial_pipeline_dream((main_config, create_config), seed=0, max_env_step=100000)
diff --git a/DI-engine/dizoo/mujoco/config/mbrl/halfcheetah_sac_mbpo_config.py b/DI-engine/dizoo/mujoco/config/mbrl/halfcheetah_sac_mbpo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c22eb0aa1f517a990bf77f8bf6e6030f879aec1
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/mbrl/halfcheetah_sac_mbpo_config.py
@@ -0,0 +1,115 @@
+from easydict import EasyDict
+
+from ding.entry import serial_pipeline_dyna
+
+# environment hyperparameters
+env_id = 'HalfCheetah-v3'
+obs_shape = 17
+action_shape = 6
+
+# gpu
+cuda = True
+
+main_config = dict(
+    exp_name='halfcheetah_sac_mbpo_seed0',
+ env=dict(
+ env_id=env_id,
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=100000,
+ ),
+ policy=dict(
+ cuda=cuda,
+ # it is better to put random_collect_size in policy.other
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=obs_shape,
+ action_shape=action_shape,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=40,
+ batch_size=256,
+ learning_rate_q=3e-4,
+ learning_rate_policy=3e-4,
+ learning_rate_alpha=3e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(evaluator=dict(eval_freq=500, )), # w.r.t envstep
+ other=dict(
+ # environment buffer
+ replay_buffer=dict(replay_buffer_size=1000000, periodic_thruput_seconds=60),
+ ),
+ ),
+ world_model=dict(
+ eval_freq=250, # w.r.t envstep
+ train_freq=250, # w.r.t envstep
+ cuda=cuda,
+ rollout_length_scheduler=dict(
+ type='linear',
+ rollout_start_step=20000,
+ rollout_end_step=150000,
+ rollout_length_min=1,
+ rollout_length_max=1,
+ ),
+ model=dict(
+ ensemble_size=7,
+ elite_size=5,
+ state_size=obs_shape, # has to be specified
+ action_size=action_shape, # has to be specified
+ reward_size=1,
+ hidden_size=400,
+ use_decay=True,
+ batch_size=256,
+ holdout_ratio=0.1,
+ max_epochs_since_update=5,
+ deterministic_rollout=True,
+ ),
+ other=dict(
+ rollout_batch_size=100000,
+ rollout_retain=4,
+ real_ratio=0.05,
+ imagination_buffer=dict(replay_buffer_size=6000000, ),
+ ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+
+create_config = dict(
+ env=dict(
+ type='mbmujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='sac',
+ import_names=['ding.policy.sac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+ imagination_buffer=dict(type='elastic', ),
+ world_model=dict(
+ type='mbpo',
+ import_names=['ding.world_model.mbpo'],
+ ),
+)
+create_config = EasyDict(create_config)
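+
+# In this Dyna-style pipeline each training batch mixes real and imagined transitions: with
+# ``real_ratio=0.05`` roughly 5% of a batch is drawn from the environment replay buffer and the rest
+# from the imagination buffer. The helper below is only a sketch of that split under the values above;
+# the exact sampling is handled inside ``serial_pipeline_dyna``.
+def _split_batch(batch_size=256, real_ratio=0.05):
+    n_real = int(batch_size * real_ratio)
+    return n_real, batch_size - n_real  # (real transitions, imagined transitions)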
+
+if __name__ == '__main__':
+ serial_pipeline_dyna((main_config, create_config), seed=0, max_env_step=100000)
diff --git a/DI-engine/dizoo/mujoco/config/mbrl/halfcheetah_stevesac_mbpo_config.py b/DI-engine/dizoo/mujoco/config/mbrl/halfcheetah_stevesac_mbpo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c692252cfc35c119d986c222417ce9c9841b32c6
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/mbrl/halfcheetah_stevesac_mbpo_config.py
@@ -0,0 +1,109 @@
+from easydict import EasyDict
+
+from ding.entry import serial_pipeline_dream
+
+# environment hyperparameters
+env_id = 'HalfCheetah-v3'
+obs_shape = 17
+action_shape = 6
+
+# gpu
+cuda = True
+
+main_config = dict(
+ exp_name='halfcheetah_stevesac_mbpo_seed0',
+ env=dict(
+ env_id=env_id,
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=100000,
+ ),
+ policy=dict(
+ cuda=cuda,
+ # it is better to put random_collect_size in policy.other
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=obs_shape,
+ action_shape=action_shape,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ ensemble_size=7,
+ update_per_collect=40,
+ batch_size=256,
+ learning_rate_q=3e-4,
+ learning_rate_policy=3e-4,
+ learning_rate_alpha=3e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(evaluator=dict(eval_freq=500, )), # w.r.t envstep
+ other=dict(
+ # environment buffer
+ replay_buffer=dict(replay_buffer_size=1000000, periodic_thruput_seconds=60),
+ ),
+ ),
+ world_model=dict(
+ eval_freq=250, # w.r.t envstep
+ train_freq=250, # w.r.t envstep
+ cuda=cuda,
+ rollout_length_scheduler=dict(
+ type='linear',
+ rollout_start_step=20000,
+ rollout_end_step=40000,
+ rollout_length_min=1,
+ rollout_length_max=3,
+ ),
+ model=dict(
+ ensemble_size=7,
+ elite_size=5,
+ state_size=obs_shape, # has to be specified
+ action_size=action_shape, # has to be specified
+ reward_size=1,
+ hidden_size=200,
+ use_decay=True,
+ batch_size=256,
+ holdout_ratio=0.1,
+ max_epochs_since_update=5,
+ deterministic_rollout=True,
+ ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+
+create_config = dict(
+ env=dict(
+ type='mbmujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='stevesac',
+ import_names=['ding.policy.mbpolicy.mbsac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+ world_model=dict(
+ type='mbpo',
+ import_names=['ding.world_model.mbpo'],
+ ),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+ serial_pipeline_dream((main_config, create_config), seed=0, max_env_step=100000)
diff --git a/DI-engine/dizoo/mujoco/config/mbrl/hopper_mbsac_mbpo_config.py b/DI-engine/dizoo/mujoco/config/mbrl/hopper_mbsac_mbpo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..7a37fe91e09fa945ef475ddea5fd297fd9d35804
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/mbrl/hopper_mbsac_mbpo_config.py
@@ -0,0 +1,110 @@
+from easydict import EasyDict
+
+from ding.entry import serial_pipeline_dream
+
+# environment hyperparameters
+env_id = 'Hopper-v2'
+obs_shape = 11
+action_shape = 3
+
+# gpu
+cuda = True
+
+main_config = dict(
+ exp_name='hopper_mbsac_mbpo_seed0',
+ env=dict(
+ env_id=env_id,
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=100000,
+ ),
+ policy=dict(
+ cuda=cuda,
+ # it is better to put random_collect_size in policy.other
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=obs_shape,
+ action_shape=action_shape,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ lambda_=0.8,
+ sample_state=False,
+ update_per_collect=20,
+ batch_size=256,
+ learning_rate_q=3e-4,
+ learning_rate_policy=3e-4,
+ learning_rate_alpha=3e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(evaluator=dict(eval_freq=500, )), # w.r.t envstep
+ other=dict(
+ # environment buffer
+ replay_buffer=dict(replay_buffer_size=1000000, periodic_thruput_seconds=60),
+ ),
+ ),
+ world_model=dict(
+ eval_freq=250, # w.r.t envstep
+ train_freq=250, # w.r.t envstep
+ cuda=cuda,
+ rollout_length_scheduler=dict(
+ type='linear',
+ rollout_start_step=20000,
+ rollout_end_step=40000,
+ rollout_length_min=1,
+ rollout_length_max=3,
+ ),
+ model=dict(
+ ensemble_size=7,
+ elite_size=5,
+ state_size=obs_shape, # has to be specified
+ action_size=action_shape, # has to be specified
+ reward_size=1,
+ hidden_size=200,
+ use_decay=True,
+ batch_size=256,
+ holdout_ratio=0.1,
+ max_epochs_since_update=5,
+ deterministic_rollout=True,
+ ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+
+create_config = dict(
+ env=dict(
+ type='mbmujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='mbsac',
+ import_names=['ding.policy.mbpolicy.mbsac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+ world_model=dict(
+ type='mbpo',
+ import_names=['ding.world_model.mbpo'],
+ ),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+ serial_pipeline_dream((main_config, create_config), seed=0, max_env_step=100000)
diff --git a/DI-engine/dizoo/mujoco/config/mbrl/hopper_sac_mbpo_config.py b/DI-engine/dizoo/mujoco/config/mbrl/hopper_sac_mbpo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f22ffed8426f073a3c61f1066a0b2aed4c39b5a
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/mbrl/hopper_sac_mbpo_config.py
@@ -0,0 +1,115 @@
+from easydict import EasyDict
+
+from ding.entry import serial_pipeline_dyna
+
+# environment hyperparameters
+env_id = 'Hopper-v2'
+obs_shape = 11
+action_shape = 3
+
+# gpu
+cuda = True
+
+main_config = dict(
+ exp_name='hopper_sac_mbpo_seed0',
+ env=dict(
+ env_id=env_id,
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=100000,
+ ),
+ policy=dict(
+ cuda=cuda,
+ # it is better to put random_collect_size in policy.other
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=obs_shape,
+ action_shape=action_shape,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=20,
+ batch_size=256,
+ learning_rate_q=3e-4,
+ learning_rate_policy=3e-4,
+ learning_rate_alpha=3e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(evaluator=dict(eval_freq=500, )), # w.r.t envstep
+ other=dict(
+ # environment buffer
+ replay_buffer=dict(replay_buffer_size=1000000, periodic_thruput_seconds=60),
+ ),
+ ),
+ world_model=dict(
+ eval_freq=250, # w.r.t envstep
+ train_freq=250, # w.r.t envstep
+ cuda=cuda,
+ rollout_length_scheduler=dict(
+ type='linear',
+ rollout_start_step=20000,
+ rollout_end_step=150000,
+ rollout_length_min=1,
+ rollout_length_max=15,
+ ),
+ model=dict(
+ ensemble_size=7,
+ elite_size=5,
+ state_size=obs_shape, # has to be specified
+ action_size=action_shape, # has to be specified
+ reward_size=1,
+ hidden_size=200,
+ use_decay=True,
+ batch_size=256,
+ holdout_ratio=0.1,
+ max_epochs_since_update=5,
+ deterministic_rollout=True,
+ ),
+ other=dict(
+ rollout_batch_size=100000,
+ rollout_retain=4,
+ real_ratio=0.05,
+ imagination_buffer=dict(replay_buffer_size=6000000, ),
+ ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+
+create_config = dict(
+ env=dict(
+ type='mbmujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='sac',
+ import_names=['ding.policy.sac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+ imagination_buffer=dict(type='elastic', ),
+ world_model=dict(
+ type='mbpo',
+ import_names=['ding.world_model.mbpo'],
+ ),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+ serial_pipeline_dyna((main_config, create_config), seed=0, max_env_step=100000)
diff --git a/DI-engine/dizoo/mujoco/config/mbrl/hopper_stevesac_mbpo_config.py b/DI-engine/dizoo/mujoco/config/mbrl/hopper_stevesac_mbpo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d22a0e42f19b4920a1205b279dd940b7ff7f1698
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/mbrl/hopper_stevesac_mbpo_config.py
@@ -0,0 +1,109 @@
+from easydict import EasyDict
+
+from ding.entry import serial_pipeline_dream
+
+# environment hyperparameters
+env_id = 'Hopper-v2'
+obs_shape = 11
+action_shape = 3
+
+# gpu
+cuda = True
+
+main_config = dict(
+ exp_name='hopper_stevesac_mbpo_seed0',
+ env=dict(
+ env_id=env_id,
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=100000,
+ ),
+ policy=dict(
+ cuda=cuda,
+ # it is better to put random_collect_size in policy.other
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=obs_shape,
+ action_shape=action_shape,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ ensemble_size=7,
+ update_per_collect=20,
+ batch_size=256,
+ learning_rate_q=3e-4,
+ learning_rate_policy=3e-4,
+ learning_rate_alpha=3e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(evaluator=dict(eval_freq=500, )), # w.r.t envstep
+ other=dict(
+ # environment buffer
+ replay_buffer=dict(replay_buffer_size=1000000, periodic_thruput_seconds=60),
+ ),
+ ),
+ world_model=dict(
+ eval_freq=250, # w.r.t envstep
+ train_freq=250, # w.r.t envstep
+ cuda=cuda,
+ rollout_length_scheduler=dict(
+ type='linear',
+ rollout_start_step=20000,
+ rollout_end_step=40000,
+ rollout_length_min=1,
+ rollout_length_max=3,
+ ),
+ model=dict(
+ ensemble_size=7,
+ elite_size=5,
+ state_size=obs_shape, # has to be specified
+ action_size=action_shape, # has to be specified
+ reward_size=1,
+ hidden_size=200,
+ use_decay=True,
+ batch_size=256,
+ holdout_ratio=0.1,
+ max_epochs_since_update=5,
+ deterministic_rollout=True,
+ ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+
+create_config = dict(
+ env=dict(
+ type='mbmujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='stevesac',
+ import_names=['ding.policy.mbpolicy.mbsac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+ world_model=dict(
+ type='mbpo',
+ import_names=['ding.world_model.mbpo'],
+ ),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+ serial_pipeline_dream((main_config, create_config), seed=0, max_env_step=100000)
diff --git a/DI-engine/dizoo/mujoco/config/mbrl/walker2d_mbsac_mbpo_config.py b/DI-engine/dizoo/mujoco/config/mbrl/walker2d_mbsac_mbpo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e8a348c26813356751dd3690f41969c2bfd9a0b
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/mbrl/walker2d_mbsac_mbpo_config.py
@@ -0,0 +1,110 @@
+from easydict import EasyDict
+
+from ding.entry import serial_pipeline_dream
+
+# environment hyperparameters
+env_id = 'Walker2d-v2'
+obs_shape = 17
+action_shape = 6
+
+# gpu
+cuda = True
+
+main_config = dict(
+ exp_name='walker2d_mbsac_mbpo_seed0',
+ env=dict(
+ env_id=env_id,
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=4,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=100000,
+ ),
+ policy=dict(
+ cuda=cuda,
+ # it is better to put random_collect_size in policy.other
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=obs_shape,
+ action_shape=action_shape,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ lambda_=0.8,
+ sample_state=False,
+ update_per_collect=20,
+ batch_size=512,
+ learning_rate_q=3e-4,
+ learning_rate_policy=3e-4,
+ learning_rate_alpha=3e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=False,
+ ),
+ collect=dict(
+ n_sample=8,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(evaluator=dict(eval_freq=500, )), # w.r.t envstep
+ other=dict(
+ # environment buffer
+ replay_buffer=dict(replay_buffer_size=1000000, periodic_thruput_seconds=60),
+ ),
+ ),
+ world_model=dict(
+ eval_freq=250, # w.r.t envstep
+ train_freq=250, # w.r.t envstep
+ cuda=cuda,
+ rollout_length_scheduler=dict(
+ type='linear',
+ rollout_start_step=30000,
+ rollout_end_step=100000,
+ rollout_length_min=1,
+ rollout_length_max=3,
+ ),
+ model=dict(
+ ensemble_size=7,
+ elite_size=5,
+ state_size=obs_shape, # has to be specified
+ action_size=action_shape, # has to be specified
+ reward_size=1,
+ hidden_size=200,
+ use_decay=True,
+ batch_size=512,
+ holdout_ratio=0.1,
+ max_epochs_since_update=5,
+ deterministic_rollout=True,
+ ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+
+create_config = dict(
+ env=dict(
+ type='mbmujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='mbsac',
+ import_names=['ding.policy.mbpolicy.mbsac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+ world_model=dict(
+ type='mbpo',
+ import_names=['ding.world_model.mbpo'],
+ ),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+ serial_pipeline_dream((main_config, create_config), seed=0, max_env_step=300000)
diff --git a/DI-engine/dizoo/mujoco/config/mbrl/walker2d_sac_mbpo_config.py b/DI-engine/dizoo/mujoco/config/mbrl/walker2d_sac_mbpo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..654451c26632d480604d722edd7dcc9121dc2982
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/mbrl/walker2d_sac_mbpo_config.py
@@ -0,0 +1,115 @@
+from easydict import EasyDict
+
+from ding.entry import serial_pipeline_dyna
+
+# environment hyperparameters
+env_id = 'Walker2d-v2'
+obs_shape = 17
+action_shape = 6
+
+# gpu
+cuda = True
+
+main_config = dict(
+ exp_name='walker2d_sac_mbpo_seed0',
+ env=dict(
+ env_id=env_id,
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=4,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=100000,
+ ),
+ policy=dict(
+ cuda=cuda,
+ # it is better to put random_collect_size in policy.other
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=obs_shape,
+ action_shape=action_shape,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=20,
+ batch_size=512,
+ learning_rate_q=3e-4,
+ learning_rate_policy=3e-4,
+ learning_rate_alpha=3e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=False,
+ ),
+ collect=dict(
+ n_sample=8,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(evaluator=dict(eval_freq=500, )), # w.r.t envstep
+ other=dict(
+ # environment buffer
+ replay_buffer=dict(replay_buffer_size=1000000, periodic_thruput_seconds=60),
+ ),
+ ),
+ world_model=dict(
+ eval_freq=250, # w.r.t envstep
+ train_freq=250, # w.r.t envstep
+ cuda=cuda,
+ rollout_length_scheduler=dict(
+ type='linear',
+ rollout_start_step=20000,
+ rollout_end_step=150000,
+ rollout_length_min=1,
+ rollout_length_max=1,
+ ),
+ model=dict(
+ ensemble_size=7,
+ elite_size=5,
+ state_size=obs_shape, # has to be specified
+ action_size=action_shape, # has to be specified
+ reward_size=1,
+ hidden_size=200,
+ use_decay=True,
+ batch_size=512,
+ holdout_ratio=0.1,
+ max_epochs_since_update=5,
+ deterministic_rollout=True,
+ ),
+ other=dict(
+ rollout_batch_size=100000,
+ rollout_retain=4,
+ real_ratio=0.05,
+ imagination_buffer=dict(replay_buffer_size=6000000, ),
+ ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+
+create_config = dict(
+ env=dict(
+ type='mbmujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='sac',
+ import_names=['ding.policy.sac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+ imagination_buffer=dict(type='elastic', ),
+ world_model=dict(
+ type='mbpo',
+ import_names=['ding.world_model.mbpo'],
+ ),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+ serial_pipeline_dyna((main_config, create_config), seed=0, max_env_step=300000)
diff --git a/DI-engine/dizoo/mujoco/config/mbrl/walker2d_stevesac_mbpo_config.py b/DI-engine/dizoo/mujoco/config/mbrl/walker2d_stevesac_mbpo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..0b7502478fa5b4e9b3343bb91e6cc2c8300a1edb
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/mbrl/walker2d_stevesac_mbpo_config.py
@@ -0,0 +1,109 @@
+from easydict import EasyDict
+
+from ding.entry import serial_pipeline_dream
+
+# environment hyperparameters
+env_id = 'Walker2d-v2'
+obs_shape = 17
+action_shape = 6
+
+# gpu
+cuda = True
+
+main_config = dict(
+ exp_name='walker2d_stevesac_mbpo_seed0',
+ env=dict(
+ env_id=env_id,
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=4,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=100000,
+ ),
+ policy=dict(
+ cuda=cuda,
+ # it is better to put random_collect_size in policy.other
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=obs_shape,
+ action_shape=action_shape,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ ensemble_size=7,
+ update_per_collect=20,
+ batch_size=512,
+ learning_rate_q=3e-4,
+ learning_rate_policy=3e-4,
+ learning_rate_alpha=3e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=False,
+ ),
+ collect=dict(
+ n_sample=8,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(evaluator=dict(eval_freq=500, )), # w.r.t envstep
+ other=dict(
+ # environment buffer
+ replay_buffer=dict(replay_buffer_size=1000000, periodic_thruput_seconds=60),
+ ),
+ ),
+ world_model=dict(
+ eval_freq=250, # w.r.t envstep
+ train_freq=250, # w.r.t envstep
+ cuda=cuda,
+ rollout_length_scheduler=dict(
+ type='linear',
+ rollout_start_step=30000,
+ rollout_end_step=100000,
+ rollout_length_min=1,
+ rollout_length_max=3,
+ ),
+ model=dict(
+ ensemble_size=7,
+ elite_size=5,
+ state_size=obs_shape, # has to be specified
+ action_size=action_shape, # has to be specified
+ reward_size=1,
+ hidden_size=200,
+ use_decay=True,
+ batch_size=512,
+ holdout_ratio=0.1,
+ max_epochs_since_update=5,
+ deterministic_rollout=True,
+ ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+
+create_config = dict(
+ env=dict(
+ type='mbmujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='stevesac',
+ import_names=['ding.policy.mbpolicy.mbsac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+ world_model=dict(
+ type='mbpo',
+ import_names=['ding.world_model.mbpo'],
+ ),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+ serial_pipeline_dream((main_config, create_config), seed=0, max_env_step=300000)
diff --git a/DI-engine/dizoo/mujoco/config/walker2d_d4pg_config.py b/DI-engine/dizoo/mujoco/config/walker2d_d4pg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..31c6ff7d94157b856031099863e1ae48199e822b
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/walker2d_d4pg_config.py
@@ -0,0 +1,69 @@
+from easydict import EasyDict
+
+walker2d_d4pg_config = dict(
+ exp_name='walker2d_d4pg_seed0',
+ env=dict(
+ env_id='Walker2d-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=4,
+ evaluator_env_num=4,
+ n_evaluator_episode=8,
+ stop_value=7000,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=True,
+ nstep=5,
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ actor_head_hidden_size=512,
+ critic_head_hidden_size=512,
+ action_space='regression',
+ critic_head_type='categorical',
+ v_min=0,
+ v_max=2000, # [1000, 4000]
+ n_atom=51,
+ ),
+ learn=dict(
+ update_per_collect=3, # [1, 4]
+ batch_size=256,
+ learning_rate_actor=3e-4,
+ learning_rate_critic=3e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ actor_update_freq=1,
+ noise=False,
+ ),
+ collect=dict(
+ n_sample=8,
+ unroll_len=1,
+ noise_sigma=0.2, # [0.1, 0.2]
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ )
+)
+walker2d_d4pg_config = EasyDict(walker2d_d4pg_config)
+main_config = walker2d_d4pg_config
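+
+# The categorical critic above represents the return distribution on ``n_atom`` equally spaced atoms
+# between ``v_min`` and ``v_max``. The helper below is only an illustrative sketch of that support;
+# DI-engine's D4PG policy derives the same quantities internally from the model config.
+def _d4pg_value_support(v_min=0., v_max=2000., n_atom=51):
+    import torch
+    delta_z = (v_max - v_min) / (n_atom - 1)  # spacing between adjacent atoms
+    return torch.linspace(v_min, v_max, n_atom), delta_z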
+
+walker2d_d4pg_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='d4pg',
+ import_names=['ding.policy.d4pg'],
+ ),
+)
+walker2d_d4pg_create_config = EasyDict(walker2d_d4pg_create_config)
+create_config = walker2d_d4pg_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c walker2d_d4pg_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/mujoco/config/walker2d_ddpg_config.py b/DI-engine/dizoo/mujoco/config/walker2d_ddpg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..efe9bf9a391db874b78f6c56c5b895ef32395907
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/walker2d_ddpg_config.py
@@ -0,0 +1,65 @@
+from easydict import EasyDict
+
+walker2d_ddpg_config = dict(
+ exp_name='walker2d_ddpg_seed0',
+ env=dict(
+ env_id='Walker2d-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=25000,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ twin_critic=False,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ action_space='regression',
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_actor=1e-3,
+ learning_rate_critic=1e-3,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ actor_update_freq=1,
+ noise=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ noise_sigma=0.1,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ )
+)
+walker2d_ddpg_config = EasyDict(walker2d_ddpg_config)
+main_config = walker2d_ddpg_config
+
+walker2d_ddpg_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='ddpg',
+ import_names=['ding.policy.ddpg'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+walker2d_ddpg_create_config = EasyDict(walker2d_ddpg_create_config)
+create_config = walker2d_ddpg_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c walker2d_ddpg_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/mujoco/config/walker2d_gail_ddpg_config.py b/DI-engine/dizoo/mujoco/config/walker2d_gail_ddpg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..779f65f63b5f871a7664bcbf4962197d1014a974
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/walker2d_gail_ddpg_config.py
@@ -0,0 +1,99 @@
+from easydict import EasyDict
+
+walker2d_gail_ddpg_config = dict(
+ exp_name='walker2d_gail_ddpg_seed0',
+ env=dict(
+ env_id='Walker2d-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ reward_model=dict(
+ input_size=23,
+ hidden_size=256,
+ batch_size=64,
+ learning_rate=1e-3,
+ update_per_collect=100,
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ expert_model_path='model_path_placeholder',
+ # Path where to store the reward model
+ reward_model_path='data_path_placeholder+/reward_model/ckpt/ckpt_best.pth.tar',
+        # Users should add their own data path here. The data path should point to the file used to
+        # store or load the collected data.
+        # Absolute path is recommended.
+        # In DI-engine, it is usually located in the ``exp_name`` directory.
+ data_path='data_path_placeholder',
+ collect_count=100000,
+ ),
+ policy=dict(
+ # state_dict of the policy.
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ load_path='walker2d_ddpg_gail/ckpt/ckpt_best.pth.tar',
+ cuda=True,
+ on_policy=False,
+ random_collect_size=25000,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ twin_critic=False,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ action_space='regression',
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_actor=1e-3,
+ learning_rate_critic=1e-3,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ actor_update_freq=1,
+ noise=False,
+ ),
+ collect=dict(
+ n_sample=64,
+ unroll_len=1,
+ noise_sigma=0.1,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ )
+)
+walker2d_gail_ddpg_config = EasyDict(walker2d_gail_ddpg_config)
+main_config = walker2d_gail_ddpg_config
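+
+# The GAIL reward model scores concatenated (observation, action) pairs, which is why ``input_size=23``
+# above equals obs_shape (17) plus action_shape (6). Below is a minimal sketch of that concatenation for
+# batched tensors; DI-engine's gail reward model performs the equivalent step internally.
+def _gail_reward_model_input(obs, action):
+    import torch
+    return torch.cat([obs, action], dim=-1)  # shape: (batch, 17 + 6)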
+
+walker2d_gail_ddpg_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='ddpg',
+ import_names=['ding.policy.ddpg'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+walker2d_gail_ddpg_create_config = EasyDict(walker2d_gail_ddpg_create_config)
+create_config = walker2d_gail_ddpg_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_gail -c walker2d_gail_ddpg_config.py -s 0`
+ # then input the config you used to generate your expert model in the path mentioned above
+ # e.g. walker2d_ddpg_config.py
+ from ding.entry import serial_pipeline_gail
+ from dizoo.mujoco.config.walker2d_ddpg_config import walker2d_ddpg_config, walker2d_ddpg_create_config
+ expert_main_config = walker2d_ddpg_config
+ expert_create_config = walker2d_ddpg_create_config
+ serial_pipeline_gail(
+ [main_config, create_config], [expert_main_config, expert_create_config],
+ max_env_step=1000000,
+ seed=0,
+ collect_data=True
+ )
diff --git a/DI-engine/dizoo/mujoco/config/walker2d_gail_sac_config.py b/DI-engine/dizoo/mujoco/config/walker2d_gail_sac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..7bd2de9022d549726eb638d0c2ffdc7319c45815
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/walker2d_gail_sac_config.py
@@ -0,0 +1,101 @@
+from easydict import EasyDict
+
+obs_shape = 17
+act_shape = 6
+walker2d_sac_gail_config = dict(
+ exp_name='walker2d_sac_gail_seed0',
+ env=dict(
+ env_id='Walker2d-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ reward_model=dict(
+ input_size=obs_shape + act_shape,
+ hidden_size=256,
+ batch_size=64,
+ learning_rate=1e-3,
+ update_per_collect=100,
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ expert_model_path='model_path_placeholder',
+ # Path where to store the reward model
+ reward_model_path='data_path_placeholder+/reward_model/ckpt/ckpt_best.pth.tar',
+        # Users should add their own data path here. The data path should point to the file used to
+        # store or load the collected data.
+        # Absolute path is recommended.
+        # In DI-engine, it is usually located in the ``exp_name`` directory.
+ data_path='data_path_placeholder',
+ collect_count=100000,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=25000,
+ model=dict(
+ obs_shape=obs_shape,
+ action_shape=act_shape,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_q=1e-3,
+ learning_rate_policy=1e-3,
+ learning_rate_alpha=3e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=False,
+ ),
+ collect=dict(
+ n_sample=64,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ ),
+)
+
+walker2d_sac_gail_config = EasyDict(walker2d_sac_gail_config)
+main_config = walker2d_sac_gail_config
+
+walker2d_sac_gail_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='sac',
+ import_names=['ding.policy.sac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+ reward_model=dict(type='gail'),
+)
+walker2d_sac_gail_create_config = EasyDict(walker2d_sac_gail_create_config)
+create_config = walker2d_sac_gail_create_config
+
+if __name__ == "__main__":
+    # or you can enter `ding -m serial_gail -c walker2d_gail_sac_config.py -s 0`
+    # then input the config you used to generate your expert model in the path mentioned above,
+    # e.g. walker2d_sac_config.py
+ from ding.entry import serial_pipeline_gail
+ from dizoo.mujoco.config.walker2d_sac_config import walker2d_sac_config, walker2d_sac_create_config
+
+ expert_main_config = walker2d_sac_config
+ expert_create_config = walker2d_sac_create_config
+ serial_pipeline_gail(
+ [main_config, create_config], [expert_main_config, expert_create_config],
+ max_env_step=5000000,
+ seed=0,
+ collect_data=True
+ )
diff --git a/DI-engine/dizoo/mujoco/config/walker2d_gcl_config.py b/DI-engine/dizoo/mujoco/config/walker2d_gcl_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..1b0b56fa321a661eefe72aed8aaa906005b634b0
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/walker2d_gcl_config.py
@@ -0,0 +1,75 @@
+from easydict import EasyDict
+
+walker2d_gcl_config = dict(
+ exp_name='walker2d_gcl_seed0',
+ env=dict(
+ env_id='Walker2d-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=8,
+ evaluator_env_num=10,
+ n_evaluator_episode=10,
+ stop_value=3000,
+ ),
+ reward_model=dict(
+ learning_rate=0.001,
+ input_size=23,
+ batch_size=32,
+ action_shape=6,
+ continuous=True,
+ update_per_collect=20,
+ ),
+ policy=dict(
+ cuda=False,
+ recompute_adv=True,
+ action_space='continuous',
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ action_space='continuous',
+ ),
+ learn=dict(
+ update_per_collect=10,
+ batch_size=64,
+ learning_rate=3e-4,
+ value_weight=0.5,
+ entropy_weight=0.0,
+ clip_ratio=0.2,
+ adv_norm=True,
+ ),
+ collect=dict(
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ model_path='model_path_placeholder',
+            # If you need the data collected by the collector to contain the ``logit`` key, which reflects the
+            # probability of the action, set this key to True.
+            # In Guided Cost Learning, the logit is needed to train the reward model, so it is set to True here.
+ collector_logit=True,
+ n_sample=2048,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.97,
+ ),
+ eval=dict(evaluator=dict(eval_freq=100, )),
+ ),
+)
+walker2d_gcl_config = EasyDict(walker2d_gcl_config)
+main_config = walker2d_gcl_config
+
+walker2d_gcl_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo', ),
+ replay_buffer=dict(type='naive', ),
+ reward_model=dict(type='guided_cost'),
+)
+walker2d_gcl_create_config = EasyDict(walker2d_gcl_create_config)
+create_config = walker2d_gcl_create_config
+
+if __name__ == '__main__':
+ from ding.entry import serial_pipeline_guided_cost
+ serial_pipeline_guided_cost((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/mujoco/config/walker2d_onppo_config.py b/DI-engine/dizoo/mujoco/config/walker2d_onppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..035a9982862c75d206e9d8b6dfdf9f9a32b4836a
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/walker2d_onppo_config.py
@@ -0,0 +1,72 @@
+from easydict import EasyDict
+
+collector_env_num = 1
+evaluator_env_num = 1
+walker2d_onppo_config = dict(
+ exp_name='walker2d_onppo_seed0',
+ env=dict(
+ env_id='Walker2d-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=10,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ recompute_adv=True,
+ action_space='continuous',
+ model=dict(
+ action_space='continuous',
+ obs_shape=17,
+ action_shape=6,
+ ),
+ learn=dict(
+ epoch_per_collect=10,
+ update_per_collect=1,
+ batch_size=320,
+ learning_rate=3e-4,
+ value_weight=0.5,
+ entropy_weight=0.001,
+ clip_ratio=0.2,
+ adv_norm=True,
+ value_norm=True,
+            # For on-policy PPO with ``recompute_adv=True``, the ``done`` key is needed to split the collected
+            # data into trajectories, so ``ignore_done=False`` must be used here.
+            # If the key ``traj_flag`` is added to the data as a backup for ``done``, ``ignore_done=True`` becomes
+            # a valid option (e.g. for HalfCheetah, whose episodes always have length 1000).
+            # A condensed sketch of the GAE recursion that uses these flags appears after ``main_config`` below.
+            # ignore_done=True,
+            ignore_done=False,
+ grad_clip_type='clip_norm',
+ grad_clip_value=0.5,
+ ),
+ collect=dict(
+ collector_env_num=collector_env_num,
+ n_sample=3200,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ ),
+)
+walker2d_onppo_config = EasyDict(walker2d_onppo_config)
+main_config = walker2d_onppo_config
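+
+# With ``recompute_adv=True`` the advantages are re-estimated over the whole collected batch before each
+# learning epoch, and the ``done`` flags split that batch back into trajectories. Below is a condensed
+# sketch of the GAE recursion under the ``discount_factor`` and ``gae_lambda`` above; the actual
+# implementation is provided by DI-engine (see ``ding.rl_utils``).
+def _gae(rewards, values, next_values, dones, gamma=0.99, lam=0.95):
+    adv, last = [0.] * len(rewards), 0.
+    for t in reversed(range(len(rewards))):
+        nonterminal = 1. - float(dones[t])
+        delta = rewards[t] + gamma * next_values[t] * nonterminal - values[t]
+        last = delta + gamma * lam * nonterminal * last
+        adv[t] = last
+    return adv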
+
+walker2d_onppo_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='base'),
+ # env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo', ),
+)
+walker2d_onppo_create_config = EasyDict(walker2d_onppo_create_config)
+create_config = walker2d_onppo_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c walker2d_onppo_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/mujoco/config/walker2d_sac_config.py b/DI-engine/dizoo/mujoco/config/walker2d_sac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..f5a5a3127e07bd3ee22f5784597e852ae42d1bda
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/walker2d_sac_config.py
@@ -0,0 +1,69 @@
+from easydict import EasyDict
+
+walker2d_sac_config = dict(
+ exp_name='walker2d_sac_seed0',
+ env=dict(
+ env_id='Walker2d-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_q=1e-3,
+ learning_rate_policy=1e-3,
+ learning_rate_alpha=3e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ ),
+)
+
+walker2d_sac_config = EasyDict(walker2d_sac_config)
+main_config = walker2d_sac_config
+
+walker2d_sac_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='sac',
+ import_names=['ding.policy.sac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+walker2d_sac_create_config = EasyDict(walker2d_sac_create_config)
+create_config = walker2d_sac_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c walker2d_sac_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/mujoco/config/walker2d_sqil_sac_config.py b/DI-engine/dizoo/mujoco/config/walker2d_sqil_sac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..59967f9f3349b697950cac1eed9afbfcf1896545
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/walker2d_sqil_sac_config.py
@@ -0,0 +1,82 @@
+from easydict import EasyDict
+
+obs_shape = 17
+act_shape = 6
+walker2d_sqil_config = dict(
+ exp_name='walker2d_sqil_sac_seed0',
+ env=dict(
+ env_id='Walker2d-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=25000,
+ expert_random_collect_size=10000,
+ model=dict(
+ obs_shape=obs_shape,
+ action_shape=act_shape,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ nstep=1,
+ discount_factor=0.97,
+ learn=dict(
+ update_per_collect=1,
+ batch_size=64,
+ learning_rate_q=1e-3,
+ learning_rate_policy=1e-3,
+ learning_rate_alpha=3e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=True,
+ ),
+ collect=dict(
+ n_sample=16,
+ unroll_len=1,
+ model_path='model_path_placeholder',
+ ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ ),
+)
+
+walker2d_sqil_config = EasyDict(walker2d_sqil_config)
+main_config = walker2d_sqil_config
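+
+# SQIL trains the SAC agent on a mixture of expert demonstrations and self-collected transitions, with
+# rewards relabelled to a constant 1 for expert data and 0 for agent data. The helper below is only a
+# sketch of that relabelling idea; the ``sqil_sac`` policy and entry handle it internally.
+def _sqil_relabel(transition, is_expert):
+    transition = dict(transition)
+    transition['reward'] = 1.0 if is_expert else 0.0
+    return transition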
+
+walker2d_sqil_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='sqil_sac', ),
+ replay_buffer=dict(type='naive', ),
+)
+walker2d_sqil_create_config = EasyDict(walker2d_sqil_create_config)
+create_config = walker2d_sqil_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_sqil -c walker2d_sqil_sac_config.py -s 0`
+ # then input the config you used to generate your expert model in the path mentioned above
+ # e.g. walker2d_sac_config.py
+ from ding.entry import serial_pipeline_sqil
+ from dizoo.mujoco.config.walker2d_sac_config import walker2d_sac_config, walker2d_sac_create_config
+
+ expert_main_config = walker2d_sac_config
+ expert_create_config = walker2d_sac_create_config
+ serial_pipeline_sqil(
+ [main_config, create_config],
+ [expert_main_config, expert_create_config],
+ max_env_step=5000000,
+ seed=0,
+ )
diff --git a/DI-engine/dizoo/mujoco/config/walker2d_td3_config.py b/DI-engine/dizoo/mujoco/config/walker2d_td3_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c1bffcedff25de41cb2550b6886d8ca2f91cb84
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/walker2d_td3_config.py
@@ -0,0 +1,71 @@
+from easydict import EasyDict
+
+walker2d_td3_config = dict(
+ exp_name='walker2d_td3_seed0',
+ env=dict(
+ env_id='Walker2d-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=25000,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ twin_critic=True,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ action_space='regression',
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_actor=1e-3,
+ learning_rate_critic=1e-3,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ actor_update_freq=2,
+ noise=True,
+ noise_sigma=0.2,
+ noise_range=dict(
+ min=-0.5,
+ max=0.5,
+ ),
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ noise_sigma=0.1,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ )
+)
+
+walker2d_td3_config = EasyDict(walker2d_td3_config)
+main_config = walker2d_td3_config
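+
+# ``noise=True`` with ``noise_sigma`` and ``noise_range`` in ``learn`` configures TD3's target policy
+# smoothing: Gaussian noise is added to the target action and clipped to the given range. The helper
+# below is only a sketch of that step under the values above; DI-engine applies it inside the TD3 policy.
+def _smooth_target_action(target_action, sigma=0.2, noise_min=-0.5, noise_max=0.5):
+    import torch
+    noise = (torch.randn_like(target_action) * sigma).clamp(noise_min, noise_max)
+    # in practice the result is additionally clipped to the environment's action bounds
+    return target_action + noise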
+
+walker2d_td3_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='td3',
+ import_names=['ding.policy.td3'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+walker2d_td3_create_config = EasyDict(walker2d_td3_create_config)
+create_config = walker2d_td3_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial -c walker2d_td3_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/mujoco/config/walker2d_trex_onppo_config.py b/DI-engine/dizoo/mujoco/config/walker2d_trex_onppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c53c1efb4b86468a8c8e6be94babf037f073555c
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/walker2d_trex_onppo_config.py
@@ -0,0 +1,93 @@
+from easydict import EasyDict
+
+walker2d_trex_onppo_config = dict(
+ exp_name='walker2d_trex_onppo_seed0',
+ env=dict(
+ env_id='Walker2d-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=8,
+ evaluator_env_num=10,
+ n_evaluator_episode=10,
+ stop_value=3000,
+ ),
+ reward_model=dict(
+ min_snippet_length=30,
+ max_snippet_length=100,
+ checkpoint_min=10000,
+ checkpoint_max=90000,
+ checkpoint_step=10000,
+ num_snippets=60000,
+ learning_rate=1e-5,
+ update_per_collect=1,
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ # However, here in ``expert_model_path``, it is ``exp_name`` of the expert config.
+ expert_model_path='model_path_placeholder',
+ # Path where to store the reward model
+ reward_model_path='data_path_placeholder + /Walker2d.params',
+        # Users should add their own data path here. The data path should point to the file used to
+        # store or load the collected data.
+        # Absolute path is recommended.
+        # In DI-engine, it is usually located in the ``exp_name`` directory.
+ # See ding/entry/application_entry_trex_collect_data.py to collect the data
+ data_path='data_path_placeholder',
+ ),
+ policy=dict(
+ cuda=True,
+ recompute_adv=True,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ action_space='continuous',
+ ),
+ action_space='continuous',
+ learn=dict(
+ epoch_per_collect=10,
+ batch_size=64,
+ learning_rate=3e-4,
+ value_weight=0.5,
+ entropy_weight=0.0,
+ clip_ratio=0.2,
+ adv_norm=True,
+ value_norm=True,
+ ),
+ collect=dict(
+ n_sample=2048,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.97,
+ ),
+ eval=dict(evaluator=dict(eval_freq=5000, )),
+ ),
+)
+walker2d_trex_onppo_config = EasyDict(walker2d_trex_onppo_config)
+main_config = walker2d_trex_onppo_config
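+
+# TREX fits its reward model from preferences between trajectory snippets whose lengths fall between
+# ``min_snippet_length`` and ``max_snippet_length``: a snippet from a later expert checkpoint is assumed
+# better than one from an earlier checkpoint, and the model is trained with a ranking (cross-entropy)
+# loss over predicted snippet returns. The function below is only a sketch of that loss for a single
+# snippet pair (it assumes ``reward_model`` returns per-step rewards); DI-engine's trex reward model
+# implements the full training procedure.
+def _trex_pair_loss(reward_model, worse_snippet, better_snippet):
+    import torch
+    import torch.nn.functional as F
+    returns = torch.stack([reward_model(worse_snippet).sum(), reward_model(better_snippet).sum()])
+    # label 1 marks the snippet taken from the later (assumed better) checkpoint
+    return F.cross_entropy(returns.unsqueeze(0), torch.tensor([1]))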
+
+walker2d_trex_onppo_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo', ),
+)
+walker2d_trex_onppo_create_config = EasyDict(walker2d_trex_onppo_create_config)
+create_config = walker2d_trex_onppo_create_config
+
+if __name__ == '__main__':
+    # Users should first run ``walker2d_onppo_config.py`` to save expert models (checkpoints).
+    # Note: Users should check that the generated checkpoints include iteration_{checkpoint_min}.pth.tar
+    # through iteration_{checkpoint_max}.pth.tar at intervals of checkpoint_step, where checkpoint_min,
+    # checkpoint_max and checkpoint_step are specified in the reward_model config above.
+ import argparse
+ import torch
+ from ding.entry import trex_collecting_data
+ from ding.entry import serial_pipeline_trex_onpolicy
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
+ parser.add_argument('--seed', type=int, default=0)
+ parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
+ args = parser.parse_args()
+ # The function ``trex_collecting_data`` below is to collect episodic data for training the reward model in trex.
+ trex_collecting_data(args)
+ serial_pipeline_trex_onpolicy([main_config, create_config])
diff --git a/DI-engine/dizoo/mujoco/config/walker2d_trex_sac_config.py b/DI-engine/dizoo/mujoco/config/walker2d_trex_sac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..fdd1cab65efd7f13e7561524aa1116a190649c79
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/config/walker2d_trex_sac_config.py
@@ -0,0 +1,102 @@
+from easydict import EasyDict
+
+walker2d_trex_sac_config = dict(
+ exp_name='walker2d_trex_sac_seed0',
+ env=dict(
+ env_id='Walker2d-v3',
+ norm_obs=dict(use_norm=False, ),
+ norm_reward=dict(use_norm=False, ),
+ collector_env_num=1,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ reward_model=dict(
+ learning_rate=1e-5,
+ min_snippet_length=30,
+ max_snippet_length=100,
+ checkpoint_min=1000,
+ checkpoint_max=9000,
+ checkpoint_step=1000,
+ update_per_collect=1,
+ # Users should add their own model path here. Model path should lead to a model.
+ # Absolute path is recommended.
+ # In DI-engine, it is ``exp_name/ckpt/ckpt_best.pth.tar``.
+ # However, here in ``expert_model_path``, it is ``exp_name`` of the expert config.
+ expert_model_path='model_path_placeholder',
+ # Path where to store the reward model
+ reward_model_path='data_path_placeholder + /Walker2d.params',
+        # Users should add their own data path here. The data path should point to the file used to
+        # store or load the collected data.
+        # Absolute path is recommended.
+        # In DI-engine, it is usually located in the ``exp_name`` directory.
+ # See ding/entry/application_entry_trex_collect_data.py to collect the data
+ data_path='data_path_placeholder',
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=10000,
+ model=dict(
+ obs_shape=17,
+ action_shape=6,
+ twin_critic=True,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=1,
+ batch_size=256,
+ learning_rate_q=1e-3,
+ learning_rate_policy=1e-3,
+ learning_rate_alpha=3e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ reparameterization=True,
+ auto_alpha=False,
+ ),
+ collect=dict(
+ n_sample=1,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(),
+ other=dict(replay_buffer=dict(replay_buffer_size=1000000, ), ),
+ ),
+)
+
+walker2d_trex_sac_config = EasyDict(walker2d_trex_sac_config)
+main_config = walker2d_trex_sac_config
+
+walker2d_trex_sac_create_config = dict(
+ env=dict(
+ type='mujoco',
+ import_names=['dizoo.mujoco.envs.mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(
+ type='sac',
+ import_names=['ding.policy.sac'],
+ ),
+ replay_buffer=dict(type='naive', ),
+)
+walker2d_trex_sac_create_config = EasyDict(walker2d_trex_sac_create_config)
+create_config = walker2d_trex_sac_create_config
+
+if __name__ == '__main__':
+    # Users should first run ``walker2d_sac_config.py`` to save expert models (checkpoints).
+    # Note: Users should check that the generated checkpoints include iteration_{checkpoint_min}.pth.tar
+    # through iteration_{checkpoint_max}.pth.tar at intervals of checkpoint_step, where checkpoint_min,
+    # checkpoint_max and checkpoint_step are specified in the reward_model config above.
+ import argparse
+ import torch
+ from ding.entry import trex_collecting_data
+ from ding.entry import serial_pipeline_trex
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--cfg', type=str, default='please enter abs path for this file')
+ parser.add_argument('--seed', type=int, default=0)
+ parser.add_argument('--device', type=str, default='cuda' if torch.cuda.is_available() else 'cpu')
+ args = parser.parse_args()
+ # The function ``trex_collecting_data`` below is to collect episodic data for training the reward model in trex.
+ trex_collecting_data(args)
+ serial_pipeline_trex([main_config, create_config])
diff --git a/DI-engine/dizoo/mujoco/entry/__init__.py b/DI-engine/dizoo/mujoco/entry/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/mujoco/entry/mujoco_cql_generation_main.py b/DI-engine/dizoo/mujoco/entry/mujoco_cql_generation_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..96d32e1db12299dccc064d49dec8f315d6dc452c
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/entry/mujoco_cql_generation_main.py
@@ -0,0 +1,32 @@
+from dizoo.mujoco.config.hopper_sac_data_generation_config import main_config, create_config
+from ding.entry import collect_demo_data, eval
+import torch
+import copy
+
+
+def eval_ckpt(args):
+ config = copy.deepcopy([main_config, create_config])
+ eval(config, seed=args.seed, load_path=main_config.policy.learn.learner.hook.load_ckpt_before_run)
+
+
+def generate(args):
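+    # Load the trained expert checkpoint and roll it out to collect demonstration data, which is saved
+    # to ``policy.collect.save_path`` for later offline training (e.g. CQL).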
+ config = copy.deepcopy([main_config, create_config])
+ state_dict = torch.load(main_config.policy.learn.learner.load_path, map_location='cpu')
+ collect_demo_data(
+ config,
+ collect_count=main_config.policy.other.replay_buffer.replay_buffer_size,
+ seed=args.seed,
+ expert_data_path=main_config.policy.collect.save_path,
+ state_dict=state_dict
+ )
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--seed', '-s', type=int, default=0)
+ args = parser.parse_args()
+
+ eval_ckpt(args)
+ generate(args)
diff --git a/DI-engine/dizoo/mujoco/entry/mujoco_cql_main.py b/DI-engine/dizoo/mujoco/entry/mujoco_cql_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..19c1e51acdebc7fe7ae42a507ebaab814b9f64d0
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/entry/mujoco_cql_main.py
@@ -0,0 +1,17 @@
+from dizoo.mujoco.config.hopper_cql_config import main_config, create_config
+from ding.entry import serial_pipeline_offline
+
+
+def train(args):
+ config = [main_config, create_config]
+ serial_pipeline_offline(config, seed=args.seed)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--seed', '-s', type=int, default=10)
+ args = parser.parse_args()
+
+ train(args)
diff --git a/DI-engine/dizoo/mujoco/entry/mujoco_d4pg_main.py b/DI-engine/dizoo/mujoco/entry/mujoco_d4pg_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bb4b72e35825e58bc60aa6fbbbf847e9c796baa
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/entry/mujoco_d4pg_main.py
@@ -0,0 +1,67 @@
+import os
+import gym
+from tensorboardX import SummaryWriter
+from easydict import EasyDict
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.envs import BaseEnvManager, DingEnvWrapper
+from ding.policy import D4PGPolicy
+from ding.model.template.qac_dist import QACDIST
+from ding.utils import set_pkg_seed
+from dizoo.mujoco.envs.mujoco_env import MujocoEnv
+from dizoo.classic_control.pendulum.config.pendulum_ppo_config import pendulum_ppo_config
+from dizoo.mujoco.config.hopper_d4pg_config import hopper_d4pg_config
+
+
+def main(cfg, seed=0, max_iterations=int(1e10)):
+ cfg = compile_config(
+ cfg,
+ BaseEnvManager,
+ D4PGPolicy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator,
+ AdvancedReplayBuffer,
+ MujocoEnv,
+ save_cfg=True
+ )
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+ collector_env = BaseEnvManager(
+ env_fn=[lambda: MujocoEnv(cfg.env) for _ in range(collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManager(
+ env_fn=[lambda: MujocoEnv(cfg.env) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ collector_env.seed(seed, dynamic_seed=True)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ model = QACDIST(**cfg.policy.model)
+ policy = D4PGPolicy(cfg.policy, model=model)
+ tb_logger = SummaryWriter(os.path.join('./log/', 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger)
+ collector = SampleSerialCollector(cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger)
+ evaluator = InteractionSerialEvaluator(cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger)
+ replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
+
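+ # Serial off-policy loop: evaluate when due, collect samples, then run ``update_per_collect`` training
+ # steps, feeding the learner's updated priorities back into the prioritized replay buffer.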
+ for _ in range(max_iterations):
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Collect data from environments
+ new_data = collector.collect(train_iter=learner.train_iter)
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ # Train
+ for i in range(cfg.policy.learn.update_per_collect):
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is None:
+ break
+ learner.train(train_data, collector.envstep)
+ replay_buffer.update(learner.priority_info)
+
+
+if __name__ == "__main__":
+ main(hopper_d4pg_config)
diff --git a/DI-engine/dizoo/mujoco/entry/mujoco_ddpg_eval.py b/DI-engine/dizoo/mujoco/entry/mujoco_ddpg_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..c3092e24e01c1704a52745486998b3b13e7e4bcb
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/entry/mujoco_ddpg_eval.py
@@ -0,0 +1,57 @@
+import os
+import gym
+import torch
+from tensorboardX import SummaryWriter
+from easydict import EasyDict
+from functools import partial
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.envs import BaseEnvManager
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.policy import DDPGPolicy
+from ding.model import ContinuousQAC
+from ding.utils import set_pkg_seed
+from ding.rl_utils import get_epsilon_greedy_fn
+from dizoo.mujoco.envs.mujoco_env import MujocoEnv
+from dizoo.mujoco.config.ant_ddpg_config import ant_ddpg_config
+
+
+def main(main_cfg, seed=0):
+ cfg = compile_config(
+ main_cfg,
+ BaseEnvManager,
+ DDPGPolicy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator,
+ AdvancedReplayBuffer,
+ MujocoEnv,
+ save_cfg=True
+ )
+
+ # Create main components: env, policy
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+
+ evaluator_env.enable_save_replay(cfg.env.replay_path)  # save evaluation replay videos under cfg.env.replay_path
+
+ # Set random seed for all package and instance
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ # Set up RL Policy
+ model = ContinuousQAC(**cfg.policy.model)
+ policy = DDPGPolicy(cfg.policy, model=model)
+ policy.eval_mode.load_state_dict(torch.load(cfg.policy.load_path, map_location='cpu'))
+
+ # evaluate
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator.eval()
+
+
+if __name__ == "__main__":
+ main(ant_ddpg_config, seed=0)
diff --git a/DI-engine/dizoo/mujoco/entry/mujoco_ddpg_main.py b/DI-engine/dizoo/mujoco/entry/mujoco_ddpg_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8313ec4eda79cf2a0ce2292a0a5b6119024f4ff
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/entry/mujoco_ddpg_main.py
@@ -0,0 +1,65 @@
+import os
+import gym
+from tensorboardX import SummaryWriter
+from easydict import EasyDict
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, AdvancedReplayBuffer
+from ding.envs import BaseEnvManager, DingEnvWrapper
+from ding.policy import DDPGPolicy
+from ding.model import ContinuousQAC
+from ding.utils import set_pkg_seed
+from dizoo.mujoco.envs.mujoco_env import MujocoEnv
+from dizoo.mujoco.config.hopper_ddpg_config import hopper_ddpg_config
+
+
+def main(cfg, seed=0, max_iterations=int(1e10)):
+ cfg = compile_config(
+ cfg,
+ BaseEnvManager,
+ DDPGPolicy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator,
+ AdvancedReplayBuffer,
+ MujocoEnv,
+ save_cfg=True
+ )
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+ collector_env = BaseEnvManager(
+ env_fn=[lambda: MujocoEnv(cfg.env) for _ in range(collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManager(
+ env_fn=[lambda: MujocoEnv(cfg.env) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ collector_env.seed(seed, dynamic_seed=True)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ model = ContinuousQAC(**cfg.policy.model)
+ policy = DDPGPolicy(cfg.policy, model=model)
+ tb_logger = SummaryWriter(os.path.join('./log/', 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger)
+ collector = SampleSerialCollector(cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger)
+ evaluator = InteractionSerialEvaluator(cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger)
+ replay_buffer = AdvancedReplayBuffer(cfg.policy.other.replay_buffer, tb_logger, exp_name=cfg.exp_name)
+
+ for _ in range(max_iterations):
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ # Collect data from environments
+ new_data = collector.collect(train_iter=learner.train_iter)
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+ # Train
+ for i in range(cfg.policy.learn.update_per_collect):
+ train_data = replay_buffer.sample(learner.policy.get_attribute('batch_size'), learner.train_iter)
+ if train_data is None:
+ break
+ learner.train(train_data, collector.envstep)
+
+
+if __name__ == "__main__":
+ main(hopper_ddpg_config)
diff --git a/DI-engine/dizoo/mujoco/entry/mujoco_ppo_main.py b/DI-engine/dizoo/mujoco/entry/mujoco_ppo_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..830ea8a94b61750767e9e0e7840d4338825a7ac5
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/entry/mujoco_ppo_main.py
@@ -0,0 +1,58 @@
+import os
+import gym
+from tensorboardX import SummaryWriter
+from easydict import EasyDict
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, NaiveReplayBuffer
+from ding.envs import BaseEnvManager, DingEnvWrapper
+from ding.policy import PPOPolicy
+from ding.model import VAC
+from ding.utils import set_pkg_seed
+from dizoo.classic_control.pendulum.envs import PendulumEnv
+from dizoo.mujoco.envs.mujoco_env import MujocoEnv
+from dizoo.classic_control.pendulum.config.pendulum_ppo_config import pendulum_ppo_config
+from dizoo.mujoco.config.hopper_onppo_config import hopper_onppo_config
+
+
+def main(cfg, seed=0, max_iterations=int(1e10)):
+ cfg = compile_config(
+ cfg,
+ BaseEnvManager,
+ PPOPolicy,
+ BaseLearner,
+ SampleSerialCollector,
+ InteractionSerialEvaluator,
+ NaiveReplayBuffer,
+ save_cfg=True
+ )
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+ collector_env = BaseEnvManager(
+ env_fn=[lambda: MujocoEnv(cfg.env) for _ in range(collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManager(
+ env_fn=[lambda: MujocoEnv(cfg.env) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ collector_env.seed(seed, dynamic_seed=True)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ model = VAC(**cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
+ tb_logger = SummaryWriter(os.path.join('./log/', 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger)
+ collector = SampleSerialCollector(cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger)
+ evaluator = InteractionSerialEvaluator(cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger)
+
+ for _ in range(max_iterations):
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ new_data = collector.collect(train_iter=learner.train_iter)
+ learner.train(new_data, collector.envstep)
+
+
+if __name__ == "__main__":
+ main(hopper_onppo_config)
diff --git a/DI-engine/dizoo/mujoco/entry/mujoco_td3_bc_main.py b/DI-engine/dizoo/mujoco/entry/mujoco_td3_bc_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..ebedfb82501101d6d4ce9970a39200fb0a9756a0
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/entry/mujoco_td3_bc_main.py
@@ -0,0 +1,60 @@
+import torch
+from copy import deepcopy
+
+from dizoo.mujoco.config.hopper_td3_data_generation_config import main_config, create_config
+from ding.entry import serial_pipeline_offline, collect_demo_data, eval, serial_pipeline
+
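+# Rough pipeline implemented below (a reading of this entry, not an official doc): train an online TD3
+# expert, evaluate its best checkpoint, collect expert demonstrations to an HDF5 dataset, and finally
+# train TD3-BC offline on that dataset; the relevant paths are overridden inside each stage function.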
+
+def train_td3_bc(args):
+ from dizoo.mujoco.config.hopper_td3_bc_config import main_config, create_config
+ main_config.exp_name = 'td3_bc'
+ main_config.policy.collect.data_path = './td3/expert_demos.hdf5'
+ main_config.policy.collect.data_type = 'hdf5'
+ config = deepcopy([main_config, create_config])
+ serial_pipeline_offline(config, seed=args.seed)
+
+
+def eval_ckpt(args):
+ main_config.exp_name = 'td3'
+ main_config.policy.learn.learner.load_path = './td3/ckpt/ckpt_best.pth.tar'
+ main_config.policy.learn.learner.hook.load_ckpt_before_run = './td3/ckpt/ckpt_best.pth.tar'
+ state_dict = torch.load(main_config.policy.learn.learner.load_path, map_location='cpu')
+ config = deepcopy([main_config, create_config])
+ eval(config, seed=args.seed, load_path=main_config.policy.learn.learner.hook.load_ckpt_before_run)
+ # eval(config, seed=args.seed, state_dict=state_dict)
+
+
+def generate(args):
+ main_config.exp_name = 'td3'
+ main_config.policy.learn.learner.load_path = './td3/ckpt/ckpt_best.pth.tar'
+ main_config.policy.collect.save_path = './td3/expert.pkl'
+ main_config.policy.collect.data_type = 'hdf5'
+ config = deepcopy([main_config, create_config])
+ state_dict = torch.load(main_config.policy.learn.learner.load_path, map_location='cpu')
+ collect_demo_data(
+ config,
+ collect_count=main_config.policy.other.replay_buffer.replay_buffer_size,
+ seed=args.seed,
+ expert_data_path=main_config.policy.collect.save_path,
+ state_dict=state_dict
+ )
+
+
+def train_expert(args):
+ from dizoo.mujoco.config.hopper_td3_config import main_config, create_config
+ main_config.exp_name = 'td3'
+ config = deepcopy([main_config, create_config])
+ serial_pipeline(config, seed=args.seed, max_iterations=int(1e6))
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--seed', '-s', type=int, default=0)
+ args = parser.parse_args()
+
+ train_expert(args)
+ eval_ckpt(args)
+ generate(args)
+ train_td3_bc(args)
diff --git a/DI-engine/dizoo/mujoco/envs/__init__.py b/DI-engine/dizoo/mujoco/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0ed00b309af13330b5c68c1f167bb0055f119f34
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/envs/__init__.py
@@ -0,0 +1,2 @@
+from .mujoco_env import MujocoEnv
+from .mujoco_disc_env import MujocoDiscEnv
diff --git a/DI-engine/dizoo/mujoco/envs/mujoco_disc_env.py b/DI-engine/dizoo/mujoco/envs/mujoco_disc_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..442b5b2535ba628c732e532a935555f6a673283b
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/envs/mujoco_disc_env.py
@@ -0,0 +1,166 @@
+import copy
+import os
+from itertools import product
+from typing import Union, List, Optional
+
+import gym
+import numpy as np
+from easydict import EasyDict
+
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.envs.common import save_frames_as_gif
+from ding.torch_utils import to_ndarray
+from ding.utils import ENV_REGISTRY
+from .mujoco_wrappers import wrap_mujoco
+
+
+@ENV_REGISTRY.register('mujoco-disc')
+class MujocoDiscEnv(BaseEnv):
+ """
+ Overview:
+ A modified Mujoco environment with a manually discretized action space: each dimension of the
+ original continuous action is equally divided into ``each_dim_disc_size`` bins, and the Cartesian
+ product of these bins forms the handcrafted discrete action set.
+ """
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ cfg = EasyDict(copy.deepcopy(cls.config))
+ cfg.cfg_type = cls.__name__ + 'Dict'
+ return cfg
+
+ config = dict(
+ action_clip=False,
+ delay_reward_step=0,
+ replay_path=None,
+ save_replay_gif=False,
+ replay_path_gif=None,
+ )
+
+ def __init__(self, cfg: dict) -> None:
+ self._cfg = cfg
+ self._action_clip = cfg.action_clip
+ self._delay_reward_step = cfg.delay_reward_step
+ self._init_flag = False
+ self._replay_path = None
+ self._replay_path_gif = cfg.replay_path_gif
+ self._save_replay_gif = cfg.save_replay_gif
+
+ def reset(self) -> np.ndarray:
+ if not self._init_flag:
+ self._env = self._make_env()
+ self._env.observation_space.dtype = np.float32 # To unify the format of envs in DI-engine
+ self._observation_space = self._env.observation_space
+ self._raw_action_space = self._env.action_space
+ self._reward_space = gym.spaces.Box(
+ low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
+ )
+ self._init_flag = True
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+ if self._replay_path is not None:
+ self._env = gym.wrappers.RecordVideo(
+ self._env,
+ video_folder=self._replay_path,
+ episode_trigger=lambda episode_id: True,
+ name_prefix='rl-video-{}'.format(id(self))
+ )
+ if self._save_replay_gif:
+ self._frames = []
+ obs = self._env.reset()
+ obs = to_ndarray(obs).astype('float32')
+
+ # disc_to_cont: transform discrete action index to original continuous action
+ self.m = self._raw_action_space.shape[0]
+ self.n = self._cfg.each_dim_disc_size
+ self.K = self.n ** self.m
+ self.disc_to_cont = list(product(*[list(range(self.n)) for _ in range(self.m)]))
+ self._eval_episode_return = 0.
+ # the modified discrete action space
+ self._action_space = gym.spaces.Discrete(self.K)
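+ # Illustrative example (not part of the original file): with m = 3 action dims and n = 5 bins,
+ # K = 125 discrete actions; ``self.disc_to_cont[0] == (0, 0, 0)`` and ``step`` later decodes bin k of
+ # each dimension to the continuous value ``-1 + 2 * k / n``.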
+
+ return obs
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep:
+ # disc_to_cont: transform discrete action index to original continuous action
+ action = [-1 + 2 / self.n * k for k in self.disc_to_cont[int(action)]]
+ action = to_ndarray(action)
+
+ if self._save_replay_gif:
+ self._frames.append(self._env.render(mode='rgb_array'))
+ if self._action_clip:
+ action = np.clip(action, -1, 1)
+ obs, rew, done, info = self._env.step(action)
+ self._eval_episode_return += rew
+
+ if done:
+ if self._save_replay_gif:
+ path = os.path.join(
+ self._replay_path_gif, '{}_episode_{}.gif'.format(self._cfg.env_id, self._save_replay_count)
+ )
+ save_frames_as_gif(self._frames, path)
+ self._save_replay_count += 1
+ info['eval_episode_return'] = self._eval_episode_return
+
+ obs = to_ndarray(obs).astype(np.float32)
+ rew = to_ndarray([rew]).astype(np.float32)
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def _make_env(self):
+ return wrap_mujoco(
+ self._cfg.env_id,
+ norm_obs=self._cfg.get('norm_obs', None),
+ norm_reward=self._cfg.get('norm_reward', None),
+ delay_reward_step=self._delay_reward_step
+ )
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ if replay_path is None:
+ replay_path = './video'
+ self._replay_path = replay_path
+ self._save_replay = True
+ self._save_replay_count = 0
+
+ def random_action(self) -> np.ndarray:
+ return self.action_space.sample()
+
+ def __repr__(self) -> str:
+ return "DI-engine modified Mujoco Env({}) with manually discretized action space".format(self._cfg.env_id)
+
+ @staticmethod
+ def create_collector_env_cfg(cfg: dict) -> List[dict]:
+ collector_cfg = copy.deepcopy(cfg)
+ collector_env_num = collector_cfg.pop('collector_env_num', 1)
+ return [collector_cfg for _ in range(collector_env_num)]
+
+ @staticmethod
+ def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
+ evaluator_cfg = copy.deepcopy(cfg)
+ evaluator_env_num = evaluator_cfg.pop('evaluator_env_num', 1)
+ evaluator_cfg.norm_reward.use_norm = False
+ return [evaluator_cfg for _ in range(evaluator_env_num)]
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
diff --git a/DI-engine/dizoo/mujoco/envs/mujoco_env.py b/DI-engine/dizoo/mujoco/envs/mujoco_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..c150581a5b56bc81137dc9f38c91fe1b576df1de
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/envs/mujoco_env.py
@@ -0,0 +1,229 @@
+import copy
+import os
+from typing import Union, List, Optional
+
+import gym
+import numpy as np
+import torch
+from easydict import EasyDict
+
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.envs.common import save_frames_as_gif
+from ding.torch_utils import to_ndarray
+from ding.utils import ENV_REGISTRY
+from .mujoco_wrappers import wrap_mujoco
+
+
+@ENV_REGISTRY.register('mujoco')
+class MujocoEnv(BaseEnv):
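+ """
+ Overview:
+ DI-engine wrapper of the single-agent Mujoco environments (e.g. Hopper, Walker2d, HalfCheetah),
+ built on ``wrap_mujoco``. It supports optional action clipping, delayed reward, per-branch action
+ discretization (``action_bins_per_branch``) and saving replays as videos or GIFs.
+ """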
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ cfg = EasyDict(copy.deepcopy(cls.config))
+ cfg.cfg_type = cls.__name__ + 'Dict'
+ return cfg
+
+ config = dict(
+ action_clip=False,
+ delay_reward_step=0,
+ replay_path=None,
+ save_replay_gif=False,
+ replay_path_gif=None,
+ action_bins_per_branch=None,
+ )
+
+ def __init__(self, cfg: dict) -> None:
+ self._cfg = cfg
+ self._action_clip = cfg.action_clip
+ self._delay_reward_step = cfg.delay_reward_step
+ self._init_flag = False
+ self._replay_path = None
+ self._replay_path_gif = cfg.replay_path_gif
+ self._save_replay_gif = cfg.save_replay_gif
+ self._action_bins_per_branch = cfg.action_bins_per_branch
+
+ def map_action(self, action: Union[np.ndarray, list]) -> Union[np.ndarray, list]:
+ """
+ Overview:
+ Map the discretized action index to the action in the original action space.
+ Arguments:
+ - action (:obj:`np.ndarray or list`): The discretized action index. \
+ The value range is {0, 1, ..., self._action_bins_per_branch - 1}.
+ Returns:
+ - outputs (:obj:`list`): The action in the original action space. \
+ The value range is [-1, 1].
+ Examples:
+ >>> inputs = [2, 0, 4]
+ >>> self._action_bins_per_branch = 5
+ >>> outputs = map_action(inputs)
+ >>> assert isinstance(outputs, list) and outputs == [0.0, -1.0, 1.0]
+ """
+ return [2 * x / (self._action_bins_per_branch - 1) - 1 for x in action]
+
+ def reset(self) -> np.ndarray:
+ if not self._init_flag:
+ self._env = self._make_env()
+ if self._replay_path is not None:
+ self._env = gym.wrappers.RecordVideo(
+ self._env,
+ video_folder=self._replay_path,
+ episode_trigger=lambda episode_id: True,
+ name_prefix='rl-video-{}'.format(id(self))
+ )
+
+ self._env.observation_space.dtype = np.float32 # To unify the format of envs in DI-engine
+ self._observation_space = self._env.observation_space
+ self._action_space = self._env.action_space
+ self._reward_space = gym.spaces.Box(
+ low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
+ )
+ self._init_flag = True
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+ obs = self._env.reset()
+ obs = to_ndarray(obs).astype('float32')
+ self._eval_episode_return = 0.
+
+ return obs
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep:
+ if self._action_bins_per_branch:
+ action = self.map_action(action)
+ action = to_ndarray(action)
+ if self._save_replay_gif:
+ self._frames.append(self._env.render(mode='rgb_array'))
+ if self._action_clip:
+ action = np.clip(action, -1, 1)
+ obs, rew, done, info = self._env.step(action)
+ self._eval_episode_return += rew
+ if done:
+ if self._save_replay_gif:
+ path = os.path.join(
+ self._replay_path_gif, '{}_episode_{}.gif'.format(self._cfg.env_id, self._save_replay_count)
+ )
+ save_frames_as_gif(self._frames, path)
+ self._save_replay_count += 1
+ info['eval_episode_return'] = self._eval_episode_return
+
+ obs = to_ndarray(obs).astype(np.float32)
+ rew = to_ndarray([rew]).astype(np.float32)
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def _make_env(self):
+ return wrap_mujoco(
+ self._cfg.env_id,
+ norm_obs=self._cfg.get('norm_obs', None),
+ norm_reward=self._cfg.get('norm_reward', None),
+ delay_reward_step=self._delay_reward_step
+ )
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ if replay_path is None:
+ replay_path = './video'
+ self._replay_path = replay_path
+
+ def random_action(self) -> np.ndarray:
+ return self.action_space.sample()
+
+ def __repr__(self) -> str:
+ return "DI-engine Mujoco Env({})".format(self._cfg.env_id)
+
+ @staticmethod
+ def create_collector_env_cfg(cfg: dict) -> List[dict]:
+ collector_cfg = copy.deepcopy(cfg)
+ collector_env_num = collector_cfg.pop('collector_env_num', 1)
+ return [collector_cfg for _ in range(collector_env_num)]
+
+ @staticmethod
+ def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
+ evaluator_cfg = copy.deepcopy(cfg)
+ evaluator_env_num = evaluator_cfg.pop('evaluator_env_num', 1)
+ evaluator_cfg.norm_reward.use_norm = False
+ return [evaluator_cfg for _ in range(evaluator_env_num)]
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+
+@ENV_REGISTRY.register('mbmujoco')
+class MBMujocoEnv(MujocoEnv):
+
+ def termination_fn(self, next_obs: torch.Tensor) -> torch.Tensor:
+ """
+ Overview:
+ This function determines whether each state is a terminated state.
+ .. note::
+ This is a collection of termination functions for the Mujoco environments used in MBPO (arXiv: 1906.08253),\
+ directly copied from MBPO repo https://github.com/jannerm/mbpo/tree/master/mbpo/static.
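+ Arguments:
+ - next_obs (:obj:`torch.Tensor`): A batch of next observations with shape ``(B, obs_dim)``.
+ Returns:
+ - done (:obj:`torch.Tensor`): A bool tensor with shape ``(B, )``, where True marks terminated states.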
+ """
+ assert len(next_obs.shape) == 2
+ if self._cfg.env_id == "Hopper-v2":
+ height = next_obs[:, 0]
+ angle = next_obs[:, 1]
+ not_done = torch.isfinite(next_obs).all(-1) \
+ * (torch.abs(next_obs[:, 1:]) < 100).all(-1) \
+ * (height > .7) \
+ * (torch.abs(angle) < .2)
+
+ done = ~not_done
+ return done
+ elif self._cfg.env_id == "Walker2d-v2":
+ height = next_obs[:, 0]
+ angle = next_obs[:, 1]
+ not_done = (height > 0.8) \
+ * (height < 2.0) \
+ * (angle > -1.0) \
+ * (angle < 1.0)
+ done = ~not_done
+ return done
+ elif 'walker_' in self._cfg.env_id:
+ torso_height = next_obs[:, -2]
+ torso_ang = next_obs[:, -1]
+ if 'walker_7' in self._cfg.env_id or 'walker_5' in self._cfg.env_id:
+ offset = 0.
+ else:
+ offset = 0.26
+ not_done = (torso_height > 0.8 - offset) \
+ * (torso_height < 2.0 - offset) \
+ * (torso_ang > -1.0) \
+ * (torso_ang < 1.0)
+ done = ~not_done
+ return done
+ elif self._cfg.env_id == "HalfCheetah-v3":
+ done = torch.zeros_like(next_obs.sum(-1)).bool()
+ return done
+ elif self._cfg.env_id in ['Ant-v2', 'AntTruncatedObs-v2']:
+ x = next_obs[:, 0]
+ not_done = torch.isfinite(next_obs).all(axis=-1) \
+ * (x >= 0.2) \
+ * (x <= 1.0)
+ done = ~not_done
+ return done
+ elif self._cfg.env_id in ['Humanoid-v2', 'HumanoidTruncatedObs-v2']:
+ z = next_obs[:, 0]
+ done = (z < 1.0) + (z > 2.0)
+ return done
+ else:
+ raise KeyError("not implemented env_id: {}".format(self._cfg.env_id))
diff --git a/DI-engine/dizoo/mujoco/envs/mujoco_gym_env.py b/DI-engine/dizoo/mujoco/envs/mujoco_gym_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..fbc31ecbfba767cc262c70b92817a922f4fabc8a
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/envs/mujoco_gym_env.py
@@ -0,0 +1,66 @@
+import numpy as np
+
+import gym
+from gym.envs.mujoco.ant import AntEnv
+from gym.envs.mujoco.humanoid import HumanoidEnv
+
+
+def gym_env_register(id, max_episode_steps=1000):
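+ """
+ Overview:
+ Decorator that registers the decorated environment class in gym under ``id`` with the given
+ ``max_episode_steps``, so that it can later be created via ``gym.make(id)``.
+ """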
+
+ def register(gym_env):
+ spec = {
+ 'id': id,
+ 'entry_point': (f'dizoo.mujoco.envs.mujoco_gym_env:{gym_env.__name__}'),
+ 'max_episode_steps': max_episode_steps
+ }
+ gym.register(**spec)
+ return gym_env
+
+ return register
+
+
+@gym_env_register('AntTruncatedObs-v2')
+class AntTruncatedObsEnv(AntEnv):
+ """
+ Overview:
+ Modified ant with observation dim truncated to 27, which is used in MBPO (arXiv: 1906.08253).
+ .. note::
+ External forces (sim.data.cfrc_ext) are removed from the observation.
+ Otherwise identical to Ant-v2.
+ """
+
+ def _get_obs(self):
+ return np.concatenate(
+ [
+ self.sim.data.qpos.flat[2:],
+ self.sim.data.qvel.flat,
+ # np.clip(self.sim.data.cfrc_ext, -1, 1).flat,
+ ]
+ )
+
+
+@gym_env_register('HumanoidTruncatedObs-v2')
+class HumanoidTruncatedObsEnv(HumanoidEnv):
+ """
+ Overview:
+ Modified humanoid with observation dim truncated to 45, which is used in MBPO (arXiv: 1906.08253).
+ .. note::
+ COM inertia (cinert), COM velocity (cvel), actuator forces (qfrc_actuator),\
+ and external forces (cfrc_ext) are removed from the observation.
+ Otherwise identical to Humanoid-v2.
+ """
+
+ def _get_obs(self):
+ data = self.sim.data
+ return np.concatenate(
+ [
+ data.qpos.flat[2:],
+ data.qvel.flat,
+ # data.cinert.flat,
+ # data.cvel.flat,
+ # data.qfrc_actuator.flat,
+ # data.cfrc_ext.flat
+ ]
+ )
diff --git a/DI-engine/dizoo/mujoco/envs/mujoco_wrappers.py b/DI-engine/dizoo/mujoco/envs/mujoco_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..377172f2f8f393ad83ec5b6212d0eb0b22105b88
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/envs/mujoco_wrappers.py
@@ -0,0 +1,36 @@
+from typing import Dict
+import gym
+import numpy as np
+
+from ding.envs import ObsNormWrapper, RewardNormWrapper, DelayRewardWrapper, EvalEpisodeReturnWrapper
+
+
+def wrap_mujoco(
+ env_id,
+ norm_obs: Dict = dict(use_norm=False, ),
+ norm_reward: Dict = dict(use_norm=False, ),
+ delay_reward_step: int = 1
+) -> gym.Env:
+ r"""
+ Overview:
+ Wrap Mujoco Env to preprocess env step's return info, e.g. observation normalization, reward normalization, etc.
+ Arguments:
+ - env_id (:obj:`str`): Mujoco environment id, for example "HalfCheetah-v3"
+ - norm_obs (:obj:`EasyDict`): Whether to normalize observation or not
+ - norm_reward (:obj:`EasyDict`): Whether to normalize reward or not. For the evaluator, the environment's \
+ reward should not be normalized: set ``norm_reward`` to None or ``norm_reward.use_norm`` to False.
+ Returns:
+ - wrapped_env (:obj:`gym.Env`): The wrapped mujoco environment
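+ Examples:
+ >>> # minimal sketch with default arguments: only ``EvalEpisodeReturnWrapper`` is applied
+ >>> env = wrap_mujoco("Hopper-v3")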
+ """
+ # import customized gym environment
+ from . import mujoco_gym_env
+ env = gym.make(env_id)
+ env = EvalEpisodeReturnWrapper(env)
+ if norm_obs is not None and norm_obs.use_norm:
+ env = ObsNormWrapper(env)
+ if norm_reward is not None and norm_reward.use_norm:
+ env = RewardNormWrapper(env, norm_reward.reward_discount)
+ if delay_reward_step > 1:
+ env = DelayRewardWrapper(env, delay_reward_step)
+
+ return env
diff --git a/DI-engine/dizoo/mujoco/envs/test/test_mujoco_disc_env.py b/DI-engine/dizoo/mujoco/envs/test/test_mujoco_disc_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8a39cc9c24b01a6b4353de243092fb3a25f85d5
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/envs/test/test_mujoco_disc_env.py
@@ -0,0 +1,44 @@
+import pytest
+import numpy as np
+from easydict import EasyDict
+
+from ding.utils import set_pkg_seed
+from dizoo.mujoco.envs import MujocoDiscEnv
+
+
+@pytest.mark.envtest
+def test_mujoco_env_eval_episode_return():
+ set_pkg_seed(1234, use_cuda=False)
+ each_dim_disc_size = 2
+ env = MujocoDiscEnv(
+ EasyDict(
+ {
+ 'env_id': 'Ant-v3',
+ 'action_clip': False,
+ 'each_dim_disc_size': each_dim_disc_size,
+ 'delay_reward_step': 4,
+ 'save_replay_gif': False,
+ 'replay_path_gif': None
+ }
+ )
+ )
+ env.seed(1234)
+ env.reset()
+ action_dim = env._raw_action_space.shape
+ eval_episode_return = np.array([0.], dtype=np.float32)
+ while True:
+ action = np.random.randint(0, each_dim_disc_size ** action_dim[0], 1)
+ timestep = env.step(action)
+ eval_episode_return += timestep.reward
+ # print("{}(dtype: {})".format(timestep.reward, timestep.reward.dtype))
+ if timestep.done:
+ print(
+ "{}({}), {}({})".format(
+ timestep.info['eval_episode_return'], type(timestep.info['eval_episode_return']),
+ eval_episode_return, type(eval_episode_return)
+ )
+ )
+ # timestep.reward is accumulated with floating-point error, so compare against the wrapper's
+ # eval_episode_return using a relative tolerance rather than exact equality.
+ assert abs(timestep.info['eval_episode_return'].item() - eval_episode_return.item()) / \
+ abs(timestep.info['eval_episode_return'].item()) < 1e-5
+ break
diff --git a/DI-engine/dizoo/mujoco/envs/test/test_mujoco_env.py b/DI-engine/dizoo/mujoco/envs/test/test_mujoco_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..34bd311850e7e1989294da70e453b6edf9b4b786
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/envs/test/test_mujoco_env.py
@@ -0,0 +1,73 @@
+import os
+import pytest
+import numpy as np
+from easydict import EasyDict
+
+from ding.utils import set_pkg_seed
+from dizoo.mujoco.envs import MujocoEnv
+
+
+@pytest.mark.envtest
+@pytest.mark.parametrize('delay_reward_step', [1, 10])
+def test_mujoco_env_delay_reward(delay_reward_step):
+ set_pkg_seed(1234, use_cuda=False)
+ env = MujocoEnv(
+ EasyDict(
+ {
+ 'env_id': 'Ant-v3',
+ 'action_clip': False,
+ 'delay_reward_step': delay_reward_step,
+ 'save_replay_gif': False,
+ 'replay_path_gif': None
+ }
+ )
+ )
+ env.seed(1234)
+ env.reset()
+ action_dim = env.action_space.shape
+ for i in range(25):
+ # Both ``env.random_action()`` and sampling via ``np.random`` over the action space
+ # can generate a legal random action.
+ if i < 10:
+ action = np.random.random(size=action_dim)
+ else:
+ action = env.random_action()
+ timestep = env.step(action)
+ print(timestep.reward)
+ assert timestep.reward.shape == (1, ), timestep.reward.shape
+
+
+@pytest.mark.envtest
+def test_mujoco_env_eval_episode_return():
+ set_pkg_seed(1234, use_cuda=False)
+ env = MujocoEnv(
+ EasyDict(
+ {
+ 'env_id': 'Ant-v3',
+ 'action_clip': False,
+ 'delay_reward_step': 4,
+ 'save_replay_gif': False,
+ 'replay_path_gif': None
+ }
+ )
+ )
+ env.seed(1234)
+ env.reset()
+ action_dim = env.action_space.shape
+ eval_episode_return = np.array([0.], dtype=np.float32)
+ while True:
+ action = np.random.random(size=action_dim)
+ timestep = env.step(action)
+ eval_episode_return += timestep.reward
+ # print("{}(dtype: {})".format(timestep.reward, timestep.reward.dtype))
+ if timestep.done:
+ print(
+ "{}({}), {}({})".format(
+ timestep.info['eval_episode_return'], type(timestep.info['eval_episode_return']),
+ eval_episode_return, type(eval_episode_return)
+ )
+ )
+ # timestep.reward is accumulated with floating-point error, so compare against the wrapper's
+ # eval_episode_return using a relative tolerance rather than exact equality.
+ assert abs(timestep.info['eval_episode_return'].item() - eval_episode_return.item()) / \
+ abs(timestep.info['eval_episode_return'].item()) < 1e-5
+ break
diff --git a/DI-engine/dizoo/mujoco/envs/test/test_mujoco_gym_env.py b/DI-engine/dizoo/mujoco/envs/test/test_mujoco_gym_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..8d47effbb6315cc5ebf20b1a566a32edba096466
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/envs/test/test_mujoco_gym_env.py
@@ -0,0 +1,13 @@
+import pytest
+import gym
+
+
+@pytest.mark.envtest
+def test_shapes():
+ from dizoo.mujoco.envs import mujoco_gym_env
+ ant = gym.make('AntTruncatedObs-v2')
+ assert ant.observation_space.shape == (27, )
+ assert ant.action_space.shape == (8, )
+ humanoid = gym.make('HumanoidTruncatedObs-v2')
+ assert humanoid.observation_space.shape == (45, )
+ assert humanoid.action_space.shape == (17, )
diff --git a/DI-engine/dizoo/mujoco/example/mujoco_bc_main.py b/DI-engine/dizoo/mujoco/example/mujoco_bc_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..c48d4a1e94308e7248b548af56ad272321c15235
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/example/mujoco_bc_main.py
@@ -0,0 +1,77 @@
+from ding.entry import serial_pipeline_bc, serial_pipeline, collect_demo_data
+from dizoo.mujoco.config.halfcheetah_td3_config import main_config, create_config
+from copy import deepcopy
+from typing import Union, Optional, List, Any, Tuple
+import os
+import torch
+import logging
+from functools import partial
+from tensorboardX import SummaryWriter
+import torch.nn as nn
+from ding.envs import get_vec_env_setting, create_env_manager
+from ding.worker import BaseLearner, InteractionSerialEvaluator, BaseSerialCommander, create_buffer, \
+ create_serial_collector
+from ding.config import read_config, compile_config
+from ding.policy import create_policy
+from ding.utils import set_pkg_seed
+from ding.entry.utils import random_collect
+from ding.entry import collect_demo_data, collect_episodic_demo_data, episode_to_transitions
+import pickle
+
+
+def load_policy(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ load_path: str,
+ seed: int = 0,
+ env_setting: Optional[List[Any]] = None,
+ model: Optional[torch.nn.Module] = None,
+) -> 'Policy': # noqa
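+ """
+ Overview:
+ Build a policy from the given config and load a pretrained collect-mode state dict from ``load_path``;
+ used below to roll out the TD3 expert when collecting demonstration data.
+ """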
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ else:
+ cfg, create_cfg = input_cfg
+ create_cfg.policy.type = create_cfg.policy.type + '_command'
+ env_fn = None if env_setting is None else env_setting[0]
+ cfg = compile_config(cfg, seed=seed, env=env_fn, auto=True, create_cfg=create_cfg, save_cfg=True)
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval', 'command'])
+ sd = torch.load(load_path, map_location='cpu')
+ policy.collect_mode.load_state_dict(sd)
+ return policy
+
+
+def main():
+ half_td3_config, half_td3_create_config = main_config, create_config
+ train_config = [deepcopy(half_td3_config), deepcopy(half_td3_create_config)]
+ exp_path = 'DI-engine/halfcheetah_td3_seed0/ckpt/ckpt_best.pth.tar'
+ expert_policy = load_policy(train_config, load_path=exp_path, seed=0)
+
+ # collect expert demo data
+ collect_count = 100
+ expert_data_path = 'expert_data.pkl'
+ state_dict = expert_policy.collect_mode.state_dict()
+ collect_config = [deepcopy(half_td3_config), deepcopy(half_td3_create_config)]
+
+ collect_episodic_demo_data(
+ deepcopy(collect_config),
+ seed=0,
+ state_dict=state_dict,
+ expert_data_path=expert_data_path,
+ collect_count=collect_count
+ )
+
+ episode_to_transitions(expert_data_path, expert_data_path, nstep=1)
+
+ # imitation learning: train a continuous BC policy on the collected expert transitions
+ il_config = [deepcopy(half_td3_config), deepcopy(half_td3_create_config)]
+ il_config[0].policy.learn.train_epoch = 1000000
+ il_config[0].policy.type = 'bc'
+ il_config[0].policy.continuous = True
+ il_config[0].exp_name = "continuous_bc_seed0"
+ il_config[0].env.stop_value = 50000
+ il_config[0].multi_agent = False
+ bc_policy, converge_stop_flag = serial_pipeline_bc(il_config, seed=314, data_path=expert_data_path, max_iter=4e6)
+ return bc_policy
+
+
+if __name__ == '__main__':
+ policy = main()
diff --git a/DI-engine/dizoo/mujoco/example/mujoco_sac.py b/DI-engine/dizoo/mujoco/example/mujoco_sac.py
new file mode 100644
index 0000000000000000000000000000000000000000..f2833187f01448b2fc1c3d32b28206f8a8bc8dc6
--- /dev/null
+++ b/DI-engine/dizoo/mujoco/example/mujoco_sac.py
@@ -0,0 +1,45 @@
+from ditk import logging
+from ding.model import ContinuousQAC
+from ding.policy import SACPolicy
+from ding.envs import DingEnvWrapper, SubprocessEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import data_pusher, StepCollector, interaction_evaluator, \
+ CkptSaver, OffPolicyLearner, termination_checker
+from ding.utils import set_pkg_seed
+from dizoo.mujoco.envs.mujoco_env import MujocoEnv
+from dizoo.mujoco.config.hopper_sac_config import main_config, create_config
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = SubprocessEnvManagerV2(
+ env_fn=[lambda: MujocoEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = SubprocessEnvManagerV2(
+ env_fn=[lambda: MujocoEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ model = ContinuousQAC(**cfg.policy.model)
+ buffer_ = DequeBuffer(size=cfg.policy.other.replay_buffer.replay_buffer_size)
+ policy = SACPolicy(cfg.policy, model=model)
+
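+ # Middleware pipeline: evaluate periodically, collect env steps (with a random warm-up of
+ # ``random_collect_size``), push data into the deque buffer, run off-policy SAC updates, save
+ # checkpoints every 500 training iterations, and stop after 3e6 env steps.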
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(
+ StepCollector(cfg, policy.collect_mode, collector_env, random_collect_size=cfg.policy.random_collect_size)
+ )
+ task.use(data_pusher(cfg, buffer_))
+ task.use(OffPolicyLearner(cfg, policy.learn_mode, buffer_))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=500))
+ task.use(termination_checker(max_env_step=int(3e6)))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/dizoo/multiagent_mujoco/README.md b/DI-engine/dizoo/multiagent_mujoco/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..6e82c6ecb1f79706c785a8270a304448df8ab5c2
--- /dev/null
+++ b/DI-engine/dizoo/multiagent_mujoco/README.md
@@ -0,0 +1,7 @@
+## Multi Agent Mujoco Env
+
+Multi-Agent Mujoco is an environment suite for continuous multi-agent robotic control, built on top of OpenAI's Mujoco Gym environments.
+
+The environment is described in the paper [Deep Multi-Agent Reinforcement Learning for Decentralized Continuous Cooperative Control](https://arxiv.org/abs/2003.06709) by Christian Schroeder de Witt, Bei Peng, Pierre-Alexandre Kamienny, Philip Torr, Wendelin Böhmer and Shimon Whiteson (Torr Vision Group and Whiteson Research Lab, University of Oxford, 2020).
+
+You can find more details in the [Multi-Agent Mujoco repository](https://github.com/schroederdewitt/multiagent_mujoco).
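+
+A minimal usage sketch (a hedged illustration: it assumes the `MujocoMulti` wrapper under `dizoo/multiagent_mujoco/envs` accepts an `EasyDict` config with the same keys as the configs in this directory):
+
+```python
+from easydict import EasyDict
+from dizoo.multiagent_mujoco.envs import MujocoMulti
+
+# Two agents, each controlling 4 joints of Ant-v2 and observing joints up to 2 hops away.
+env = MujocoMulti(EasyDict(scenario='Ant-v2', agent_conf='2x4d', agent_obsk=2, add_agent_id=False, episode_limit=1000))
+obs = env.reset()
+```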
diff --git a/DI-engine/dizoo/multiagent_mujoco/__init__.py b/DI-engine/dizoo/multiagent_mujoco/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/multiagent_mujoco/config/ant_maddpg_config.py b/DI-engine/dizoo/multiagent_mujoco/config/ant_maddpg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ed6744e818f7cf937eb103a993f70d7d614a39af
--- /dev/null
+++ b/DI-engine/dizoo/multiagent_mujoco/config/ant_maddpg_config.py
@@ -0,0 +1,63 @@
+from easydict import EasyDict
+
+ant_ddpg_default_config = dict(
+ exp_name='multi_mujoco_ant_2x4_ddpg',
+ env=dict(
+ scenario='Ant-v2',
+ agent_conf="2x4d",
+ agent_obsk=2,
+ add_agent_id=False,
+ episode_limit=1000,
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=0,
+ multi_agent=True,
+ model=dict(
+ agent_obs_shape=54,
+ global_obs_shape=111,
+ action_shape=4,
+ action_space='regression',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=10,
+ batch_size=256,
+ learning_rate_actor=1e-3,
+ learning_rate_critic=1e-3,
+ target_theta=0.005,
+ discount_factor=0.99,
+ ),
+ collect=dict(
+ n_sample=400,
+ noise_sigma=0.1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=100000, ), ),
+ ),
+)
+
+ant_ddpg_default_config = EasyDict(ant_ddpg_default_config)
+main_config = ant_ddpg_default_config
+
+ant_ddpg_default_create_config = dict(
+ env=dict(
+ type='mujoco_multi',
+ import_names=['dizoo.multiagent_mujoco.envs.multi_mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ddpg'),
+ replay_buffer=dict(type='naive', ),
+)
+ant_ddpg_default_create_config = EasyDict(ant_ddpg_default_create_config)
+create_config = ant_ddpg_default_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c ant_maddpg_config.py -s 0`
+ from ding.entry.serial_entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0, max_env_step=int(1e7))
diff --git a/DI-engine/dizoo/multiagent_mujoco/config/ant_mappo_config.py b/DI-engine/dizoo/multiagent_mujoco/config/ant_mappo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d11c31be8d495e6d8ad8a29be0efbde4ec0d427b
--- /dev/null
+++ b/DI-engine/dizoo/multiagent_mujoco/config/ant_mappo_config.py
@@ -0,0 +1,80 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='multi_mujoco_ant_2x4_ppo',
+ env=dict(
+ scenario='Ant-v2',
+ agent_conf="2x4d",
+ agent_obsk=2,
+ add_agent_id=False,
+ episode_limit=1000,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ multi_agent=True,
+ action_space='continuous',
+ model=dict(
+ # (int) agent_num: The number of agents; 2 for the Ant "2x4d" decomposition.
+ agent_num=2,
+ # (int) agent_obs_shape: The dimension of each agent's local observation.
+ agent_obs_shape=54,
+ # (int) global_obs_shape: The dimension of the global observation.
+ global_obs_shape=111,
+ # (int) action_shape: The dimension of each agent's action.
+ action_shape=4,
+ # (List[int]) The size of hidden layer
+ # hidden_size_list=[64],
+ action_space='continuous'
+ ),
+ # used in state_num of hidden_state
+ learn=dict(
+ epoch_per_collect=3,
+ batch_size=800,
+ learning_rate=5e-4,
+ # ==============================================================
+ # The following configs is algorithm-specific
+ # ==============================================================
+ # (float) The loss weight of value network, policy network weight is set to 1
+ value_weight=0.5,
+ # (float) The loss weight of entropy regularization, policy network weight is set to 1
+ entropy_weight=0.001,
+ # (float) PPO clip ratio, defaults to 0.2
+ clip_ratio=0.2,
+ # (bool) Whether to use advantage norm in a whole training batch
+ adv_norm=True,
+ value_norm=True,
+ ppo_param_init=True,
+ grad_clip_type='clip_norm',
+ grad_clip_value=5,
+ ),
+ collect=dict(env_num=collector_env_num, n_sample=3200),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=1000, )),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='mujoco_multi',
+ import_names=['dizoo.multiagent_mujoco.envs.multi_mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0, max_env_step=int(1e7))
diff --git a/DI-engine/dizoo/multiagent_mujoco/config/ant_masac_config.py b/DI-engine/dizoo/multiagent_mujoco/config/ant_masac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..9316b095c0867116d4ac18203bde1241cd8b56f4
--- /dev/null
+++ b/DI-engine/dizoo/multiagent_mujoco/config/ant_masac_config.py
@@ -0,0 +1,61 @@
+from easydict import EasyDict
+
+ant_sac_default_config = dict(
+ exp_name='multi_mujoco_ant_2x4_sac',
+ env=dict(
+ scenario='Ant-v2',
+ agent_conf="2x4d",
+ agent_obsk=2,
+ add_agent_id=False,
+ episode_limit=1000,
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=0,
+ multi_agent=True,
+ model=dict(
+ agent_obs_shape=54,
+ global_obs_shape=111,
+ action_shape=4,
+ action_space='reparameterization',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=10,
+ batch_size=256,
+ learning_rate_q=1e-3,
+ learning_rate_policy=1e-3,
+ learning_rate_alpha=3e-4,
+ target_theta=0.005,
+ discount_factor=0.99,
+ ),
+ collect=dict(n_sample=400, ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=100000, ), ),
+ ),
+)
+
+ant_sac_default_config = EasyDict(ant_sac_default_config)
+main_config = ant_sac_default_config
+
+ant_sac_default_create_config = dict(
+ env=dict(
+ type='mujoco_multi',
+ import_names=['dizoo.multiagent_mujoco.envs.multi_mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='sac'),
+ replay_buffer=dict(type='naive', ),
+)
+ant_sac_default_create_config = EasyDict(ant_sac_default_create_config)
+create_config = ant_sac_default_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c ant_masac_config.py -s 0`
+ from ding.entry.serial_entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0, max_env_step=int(1e7))
diff --git a/DI-engine/dizoo/multiagent_mujoco/config/ant_matd3_config.py b/DI-engine/dizoo/multiagent_mujoco/config/ant_matd3_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4575f40de58dbaf1935dd5d9015a4fe28a934faa
--- /dev/null
+++ b/DI-engine/dizoo/multiagent_mujoco/config/ant_matd3_config.py
@@ -0,0 +1,66 @@
+from easydict import EasyDict
+
+ant_td3_default_config = dict(
+ exp_name='multi_mujoco_ant_2x4_td3',
+ env=dict(
+ scenario='Ant-v2',
+ agent_conf="2x4d",
+ agent_obsk=2,
+ add_agent_id=False,
+ episode_limit=1000,
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=0,
+ multi_agent=True,
+ model=dict(
+ agent_obs_shape=54,
+ global_obs_shape=111,
+ action_shape=4,
+ action_space='regression',
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ twin_critic=True,
+ ),
+ learn=dict(
+ update_per_collect=10,
+ batch_size=256,
+ learning_rate_actor=1e-3,
+ learning_rate_critic=1e-3,
+ target_theta=0.005,
+ discount_factor=0.99,
+ actor_update_freq=2,
+ noise=True,
+ ),
+ collect=dict(
+ n_sample=400,
+ noise_sigma=0.1,
+ ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ other=dict(replay_buffer=dict(replay_buffer_size=100000, ), ),
+ ),
+)
+
+ant_td3_default_config = EasyDict(ant_td3_default_config)
+main_config = ant_td3_default_config
+
+ant_td3_default_create_config = dict(
+ env=dict(
+ type='mujoco_multi',
+ import_names=['dizoo.multiagent_mujoco.envs.multi_mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='td3'),
+ replay_buffer=dict(type='naive', ),
+)
+ant_td3_default_create_config = EasyDict(ant_td3_default_create_config)
+create_config = ant_td3_default_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c ant_matd3_config.py -s 0`
+ from ding.entry.serial_entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0, max_env_step=int(1e7))
diff --git a/DI-engine/dizoo/multiagent_mujoco/config/halfcheetah_happo_config.py b/DI-engine/dizoo/multiagent_mujoco/config/halfcheetah_happo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c849551d94b9a99c697cf1a61e4e9de1732568b7
--- /dev/null
+++ b/DI-engine/dizoo/multiagent_mujoco/config/halfcheetah_happo_config.py
@@ -0,0 +1,83 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 8
+n_agent = 2
+
+main_config = dict(
+ exp_name='HAPPO_result/debug/multi_mujoco_halfcheetah_2x3_happo',
+ env=dict(
+ scenario='HalfCheetah-v2',
+ agent_conf="2x3",
+ agent_obsk=2,
+ add_agent_id=False,
+ episode_limit=1000,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ multi_agent=True,
+ agent_num=n_agent,
+ action_space='continuous',
+ model=dict(
+ action_space='continuous',
+ agent_num=n_agent,
+ agent_obs_shape=8,
+ global_obs_shape=17,
+ action_shape=3,
+ use_lstm=False,
+ ),
+ learn=dict(
+ epoch_per_collect=5,
+ # batch_size=3200,
+ batch_size=800,
+ learning_rate=5e-4,
+ critic_learning_rate=5e-4,
+ # ==============================================================
+ # The following configs is algorithm-specific
+ # ==============================================================
+ # (float) The loss weight of value network, policy network weight is set to 1
+ value_weight=0.5,
+ # (float) The loss weight of entropy regularization, policy network weight is set to 1
+ # entropy_weight=0.001,
+ entropy_weight=0.001,
+ # (float) PPO clip ratio, defaults to 0.2
+ clip_ratio=0.2,
+ # (bool) Whether to use advantage norm in a whole training batch
+ adv_norm=True,
+ value_norm=True,
+ ppo_param_init=True,
+ grad_clip_type='clip_norm',
+ grad_clip_value=3,
+ ignore_done=True,
+ # ignore_done=False,
+ ),
+ collect=dict(
+ n_sample=3200,
+ unroll_len=1,
+ env_num=collector_env_num,
+ ),
+ eval=dict(
+ env_num=evaluator_env_num,
+ evaluator=dict(eval_freq=1000, ),
+ ),
+ other=dict(),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='mujoco_multi',
+ import_names=['dizoo.multiagent_mujoco.envs.multi_mujoco_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='happo'),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0, max_env_step=int(1e7))
diff --git a/DI-engine/dizoo/multiagent_mujoco/config/halfcheetah_mappo_config.py b/DI-engine/dizoo/multiagent_mujoco/config/halfcheetah_mappo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6db3feea7ba7346d1fc66c9d93d340b96d88838
--- /dev/null
+++ b/DI-engine/dizoo/multiagent_mujoco/config/halfcheetah_mappo_config.py
@@ -0,0 +1,80 @@
+from easydict import EasyDict
+
+collector_env_num = 8
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='HAPPO_result/multi_mujoco_halfcheetah_2x3_mappo',
+ env=dict(
+ scenario='HalfCheetah-v2',
+ agent_conf="2x3",
+ agent_obsk=2,
+ add_agent_id=False,
+ episode_limit=1000,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ multi_agent=True,
+ action_space='continuous',
+ model=dict(
+ # (int) agent_num: The number of agents; 2 for the HalfCheetah "2x3" decomposition.
+ agent_num=2,
+ # (int) agent_obs_shape: The dimension of each agent's local observation.
+ agent_obs_shape=8,
+ # (int) global_obs_shape: The dimension of the global observation.
+ global_obs_shape=17,
+ # (int) action_shape: The dimension of each agent's action.
+ action_shape=3,
+ # (List[int]) The size of hidden layer
+ # hidden_size_list=[64],
+ action_space='continuous'
+ ),
+ # used in state_num of hidden_state
+ learn=dict(
+ epoch_per_collect=5,
+ batch_size=800,
+ learning_rate=5e-4,
+ # ==============================================================
+ # The following configs is algorithm-specific
+ # ==============================================================
+ # (float) The loss weight of value network, policy network weight is set to 1
+ value_weight=0.5,
+ # (float) The loss weight of entropy regularization, policy network weight is set to 1
+ entropy_weight=0.001,
+ # (float) PPO clip ratio, defaults to 0.2
+ clip_ratio=0.2,
+ # (bool) Whether to use advantage norm in a whole training batch
+ adv_norm=True,
+ value_norm=True,
+ ppo_param_init=True,
+ grad_clip_type='clip_norm',
+ grad_clip_value=5,
+ ),
+ collect=dict(env_num=collector_env_num, n_sample=3200),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=1000, )),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='mujoco_multi',
+ import_names=['dizoo.multiagent_mujoco.envs.multi_mujoco_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0, max_env_step=int(1e7))
diff --git a/DI-engine/dizoo/multiagent_mujoco/config/walker2d_happo_config.py b/DI-engine/dizoo/multiagent_mujoco/config/walker2d_happo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..a947a255892885508d6aa29e34ed382412f4b609
--- /dev/null
+++ b/DI-engine/dizoo/multiagent_mujoco/config/walker2d_happo_config.py
@@ -0,0 +1,91 @@
+from easydict import EasyDict
+import os
+collector_env_num = 8
+evaluator_env_num = 8
+n_agent = 2
+
+main_config = dict(
+ exp_name='HAPPO_result/debug/multi_mujoco_walker_2x3_happo',
+ env=dict(
+ scenario='Walker2d-v2',
+ agent_conf="2x3",
+ agent_obsk=2,
+ add_agent_id=False,
+ episode_limit=1000,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=8,
+ stop_value=6000,
+ ),
+ policy=dict(
+ cuda=True,
+ multi_agent=True,
+ agent_num=n_agent,
+ action_space='continuous',
+ model=dict(
+ action_space='continuous',
+ agent_num=n_agent,
+ agent_obs_shape=8,
+ global_obs_shape=17,
+ action_shape=3,
+ use_lstm=False,
+ ),
+ learn=dict(
+ epoch_per_collect=5,
+ # batch_size=3200,
+ # batch_size=800,
+ batch_size=320,
+ # batch_size=100,
+ learning_rate=5e-4,
+ critic_learning_rate=5e-3,
+ # learning_rate=3e-3,
+ # ==============================================================
+ # The following configs is algorithm-specific
+ # ==============================================================
+ # (float) The loss weight of value network, policy network weight is set to 1
+ # value_weight=0.5,
+ value_weight=1,
+ # (float) The loss weight of entropy regularization, policy network weight is set to 1
+ # entropy_weight=0.001,
+ entropy_weight=0.003,
+ # entropy_weight=0.005,
+ # entropy_weight=0.01,
+ # (float) PPO clip ratio, defaults to 0.2
+ clip_ratio=0.2,
+ # (bool) Whether to use advantage norm in a whole training batch
+ adv_norm=True,
+ value_norm=True,
+ ppo_param_init=True,
+ grad_clip_type='clip_norm',
+ # grad_clip_value=5,
+ grad_clip_value=10,
+ # ignore_done=True,
+ ignore_done=False,
+ ),
+ collect=dict(
+ n_sample=3200,
+ # n_sample=4000,
+ unroll_len=1,
+ env_num=collector_env_num,
+ ),
+ eval=dict(
+ env_num=evaluator_env_num,
+ evaluator=dict(eval_freq=1000, ),
+ ),
+ other=dict(),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='mujoco_multi',
+ import_names=['dizoo.multiagent_mujoco.envs.multi_mujoco_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='happo'),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0, max_env_step=int(1e7))
diff --git a/DI-engine/dizoo/multiagent_mujoco/envs/__init__.py b/DI-engine/dizoo/multiagent_mujoco/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..0a828ba4e982da51f5fc218a79e597e131a57567
--- /dev/null
+++ b/DI-engine/dizoo/multiagent_mujoco/envs/__init__.py
@@ -0,0 +1,4 @@
+from .mujoco_multi import MujocoMulti
+from .coupled_half_cheetah import CoupledHalfCheetah
+from .manyagent_swimmer import ManyAgentSwimmerEnv
+from .manyagent_ant import ManyAgentAntEnv
diff --git a/DI-engine/dizoo/multiagent_mujoco/envs/assets/.gitignore b/DI-engine/dizoo/multiagent_mujoco/envs/assets/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..eb0d0a0f1a89ef2ca8e1433ffbe77cb361e0cf11
--- /dev/null
+++ b/DI-engine/dizoo/multiagent_mujoco/envs/assets/.gitignore
@@ -0,0 +1 @@
+*.auto.xml
diff --git a/DI-engine/dizoo/multiagent_mujoco/envs/assets/__init__.py b/DI-engine/dizoo/multiagent_mujoco/envs/assets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/multiagent_mujoco/envs/assets/coupled_half_cheetah.xml b/DI-engine/dizoo/multiagent_mujoco/envs/assets/coupled_half_cheetah.xml
new file mode 100644
index 0000000000000000000000000000000000000000..b8c2f9f626b5969edc98f5984e13ca5a3bab36f7
--- /dev/null
+++ b/DI-engine/dizoo/multiagent_mujoco/envs/assets/coupled_half_cheetah.xml
@@ -0,0 +1,140 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ -
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/DI-engine/dizoo/multiagent_mujoco/envs/assets/manyagent_ant.xml b/DI-engine/dizoo/multiagent_mujoco/envs/assets/manyagent_ant.xml
new file mode 100644
index 0000000000000000000000000000000000000000..103c74452687b247a06e7c5bd43d7d0582dc23d3
--- /dev/null
+++ b/DI-engine/dizoo/multiagent_mujoco/envs/assets/manyagent_ant.xml
@@ -0,0 +1,134 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/DI-engine/dizoo/multiagent_mujoco/envs/assets/manyagent_ant.xml.template b/DI-engine/dizoo/multiagent_mujoco/envs/assets/manyagent_ant.xml.template
new file mode 100644
index 0000000000000000000000000000000000000000..3b6b4eb85a14d9416c398a01fd4ab4bc6d397575
--- /dev/null
+++ b/DI-engine/dizoo/multiagent_mujoco/envs/assets/manyagent_ant.xml.template
@@ -0,0 +1,54 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ body }}
+
+
+
+ {{ actuators }}
+
+
\ No newline at end of file
diff --git a/DI-engine/dizoo/multiagent_mujoco/envs/assets/manyagent_ant__stage1.xml b/DI-engine/dizoo/multiagent_mujoco/envs/assets/manyagent_ant__stage1.xml
new file mode 100644
index 0000000000000000000000000000000000000000..c6ef416f3c33575eb088742242d339613a651e23
--- /dev/null
+++ b/DI-engine/dizoo/multiagent_mujoco/envs/assets/manyagent_ant__stage1.xml
@@ -0,0 +1,85 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/DI-engine/dizoo/multiagent_mujoco/envs/assets/manyagent_swimmer.xml.template b/DI-engine/dizoo/multiagent_mujoco/envs/assets/manyagent_swimmer.xml.template
new file mode 100644
index 0000000000000000000000000000000000000000..9fb49a95230e5dc8983ef5c81788a5463ef9d99e
--- /dev/null
+++ b/DI-engine/dizoo/multiagent_mujoco/envs/assets/manyagent_swimmer.xml.template
@@ -0,0 +1,34 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ {{ body }}
+
+
+
+
+{{ actuators }}
+
+
\ No newline at end of file
diff --git a/DI-engine/dizoo/multiagent_mujoco/envs/assets/manyagent_swimmer__bckp2.xml b/DI-engine/dizoo/multiagent_mujoco/envs/assets/manyagent_swimmer__bckp2.xml
new file mode 100644
index 0000000000000000000000000000000000000000..bce5149599c5eec4cae496030c0523a58ba33b53
--- /dev/null
+++ b/DI-engine/dizoo/multiagent_mujoco/envs/assets/manyagent_swimmer__bckp2.xml
@@ -0,0 +1,48 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/DI-engine/dizoo/multiagent_mujoco/envs/assets/manyagent_swimmer_bckp.xml b/DI-engine/dizoo/multiagent_mujoco/envs/assets/manyagent_swimmer_bckp.xml
new file mode 100644
index 0000000000000000000000000000000000000000..3477813790a32e81d4db1bc7b9a997d90f70c58b
--- /dev/null
+++ b/DI-engine/dizoo/multiagent_mujoco/envs/assets/manyagent_swimmer_bckp.xml
@@ -0,0 +1,43 @@
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
\ No newline at end of file
diff --git a/DI-engine/dizoo/multiagent_mujoco/envs/coupled_half_cheetah.py b/DI-engine/dizoo/multiagent_mujoco/envs/coupled_half_cheetah.py
new file mode 100644
index 0000000000000000000000000000000000000000..6fe0a68507fd5272ff1c3d6bc7ea827e9fbac7eb
--- /dev/null
+++ b/DI-engine/dizoo/multiagent_mujoco/envs/coupled_half_cheetah.py
@@ -0,0 +1,48 @@
+import numpy as np
+from gym import utils
+from gym.envs.mujoco import mujoco_env
+import os
+
+
+class CoupledHalfCheetah(mujoco_env.MujocoEnv, utils.EzPickle):
+
+ def __init__(self, **kwargs):
+ mujoco_env.MujocoEnv.__init__(
+ self, os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'coupled_half_cheetah.xml'), 5
+ )
+ utils.EzPickle.__init__(self)
+
+ def step(self, action):
+ xposbefore1 = self.sim.data.qpos[0]
+ xposbefore2 = self.sim.data.qpos[len(self.sim.data.qpos) // 2]
+ self.do_simulation(action, self.frame_skip)
+ xposafter1 = self.sim.data.qpos[0]
+ xposafter2 = self.sim.data.qpos[len(self.sim.data.qpos) // 2]
+ ob = self._get_obs()
+ reward_ctrl1 = -0.1 * np.square(action[0:len(action) // 2]).sum()
+ reward_ctrl2 = -0.1 * np.square(action[len(action) // 2:]).sum()
+ reward_run1 = (xposafter1 - xposbefore1) / self.dt
+ reward_run2 = (xposafter2 - xposbefore2) / self.dt
+ reward = (reward_ctrl1 + reward_ctrl2) / 2.0 + (reward_run1 + reward_run2) / 2.0
+ done = False
+ return ob, reward, done, dict(
+ reward_run1=reward_run1, reward_ctrl1=reward_ctrl1, reward_run2=reward_run2, reward_ctrl2=reward_ctrl2
+ )
+
+ def _get_obs(self):
+ return np.concatenate([
+ self.sim.data.qpos.flat[1:],
+ self.sim.data.qvel.flat,
+ ])
+
+ def reset_model(self):
+ qpos = self.init_qpos + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq)
+ qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1
+ self.set_state(qpos, qvel)
+ return self._get_obs()
+
+ def viewer_setup(self):
+ self.viewer.cam.distance = self.model.stat.extent * 0.5
+
+ def get_env_info(self):
+ return {"episode_limit": self.episode_limit}
diff --git a/DI-engine/dizoo/multiagent_mujoco/envs/manyagent_ant.py b/DI-engine/dizoo/multiagent_mujoco/envs/manyagent_ant.py
new file mode 100644
index 0000000000000000000000000000000000000000..5bfb793780fa3c8ab53131038f594dfee730aab5
--- /dev/null
+++ b/DI-engine/dizoo/multiagent_mujoco/envs/manyagent_ant.py
@@ -0,0 +1,120 @@
+import numpy as np
+from gym import utils
+from gym.envs.mujoco import mujoco_env
+from jinja2 import Template
+import os
+
+
+class ManyAgentAntEnv(mujoco_env.MujocoEnv, utils.EzPickle):
+
+ def __init__(self, **kwargs):
+ agent_conf = kwargs.get("agent_conf")
+ n_agents = int(agent_conf.split("x")[0])
+ n_segs_per_agents = int(agent_conf.split("x")[1])
+ n_segs = n_agents * n_segs_per_agents
+
+ # Check whether asset file exists already, otherwise create it
+ asset_path = os.path.join(
+ os.path.dirname(os.path.abspath(__file__)), 'assets',
+ 'manyagent_ant_{}_agents_each_{}_segments.auto.xml'.format(n_agents, n_segs_per_agents)
+ )
+ #if not os.path.exists(asset_path):
+ print("Auto-Generating Manyagent Ant asset with {} segments at {}.".format(n_segs, asset_path))
+ self._generate_asset(n_segs=n_segs, asset_path=asset_path)
+
+ #asset_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets',
+ # 'manyagent_swimmer.xml')
+
+ mujoco_env.MujocoEnv.__init__(self, asset_path, 4)
+ utils.EzPickle.__init__(self)
+
+ def _generate_asset(self, n_segs, asset_path):
+ template_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets', 'manyagent_ant.xml.template')
+ with open(template_path, "r") as f:
+ t = Template(f.read())
+ body_str_template = """
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+
+ """
+
+ body_close_str_template = "\n"
+ actuator_str_template = """\t
+
+
+ \n"""
+
+ body_str = ""
+ for i in range(1, n_segs):
+ body_str += body_str_template.format(*([i] * 16))
+ body_str += body_close_str_template * (n_segs - 1)
+
+ actuator_str = ""
+ for i in range(n_segs):
+ actuator_str += actuator_str_template.format(*([i] * 8))
+
+ rt = t.render(body=body_str, actuators=actuator_str)
+ with open(asset_path, "w") as f:
+ f.write(rt)
+ pass
+
+ def step(self, a):
+ xposbefore = self.get_body_com("torso_0")[0]
+ self.do_simulation(a, self.frame_skip)
+ xposafter = self.get_body_com("torso_0")[0]
+ forward_reward = (xposafter - xposbefore) / self.dt
+ ctrl_cost = .5 * np.square(a).sum()
+ contact_cost = 0.5 * 1e-3 * np.sum(np.square(np.clip(self.sim.data.cfrc_ext, -1, 1)))
+ survive_reward = 1.0
+ reward = forward_reward - ctrl_cost - contact_cost + survive_reward
+ state = self.state_vector()
+ notdone = np.isfinite(state).all() \
+ and state[2] >= 0.2 and state[2] <= 1.0
+ done = not notdone
+ ob = self._get_obs()
+ return ob, reward, done, dict(
+ reward_forward=forward_reward,
+ reward_ctrl=-ctrl_cost,
+ reward_contact=-contact_cost,
+ reward_survive=survive_reward
+ )
+
+ def _get_obs(self):
+ return np.concatenate(
+ [
+ self.sim.data.qpos.flat[2:],
+ self.sim.data.qvel.flat,
+ np.clip(self.sim.data.cfrc_ext, -1, 1).flat,
+ ]
+ )
+
+ def reset_model(self):
+ qpos = self.init_qpos + self.np_random.uniform(size=self.model.nq, low=-.1, high=.1)
+ qvel = self.init_qvel + self.np_random.randn(self.model.nv) * .1
+ self.set_state(qpos, qvel)
+ return self._get_obs()
+
+ def viewer_setup(self):
+ self.viewer.cam.distance = self.model.stat.extent * 0.5
diff --git a/DI-engine/dizoo/multiagent_mujoco/envs/manyagent_swimmer.py b/DI-engine/dizoo/multiagent_mujoco/envs/manyagent_swimmer.py
new file mode 100644
index 0000000000000000000000000000000000000000..70e8677a01347b0522f5644de50be0e2ca071757
--- /dev/null
+++ b/DI-engine/dizoo/multiagent_mujoco/envs/manyagent_swimmer.py
@@ -0,0 +1,89 @@
+import numpy as np
+from gym import utils
+from gym.envs.mujoco import mujoco_env
+import os
+from jinja2 import Template
+
+
+class ManyAgentSwimmerEnv(mujoco_env.MujocoEnv, utils.EzPickle):
+
+ def __init__(self, **kwargs):
+ agent_conf = kwargs.get("agent_conf")
+ n_agents = int(agent_conf.split("x")[0])
+ n_segs_per_agents = int(agent_conf.split("x")[1])
+ n_segs = n_agents * n_segs_per_agents
+
+ # Check whether asset file exists already, otherwise create it
+ asset_path = os.path.join(
+ os.path.dirname(os.path.abspath(__file__)), 'assets',
+ 'manyagent_swimmer_{}_agents_each_{}_segments.auto.xml'.format(n_agents, n_segs_per_agents)
+ )
+ # if not os.path.exists(asset_path):
+ print("Auto-Generating Manyagent Swimmer asset with {} segments at {}.".format(n_segs, asset_path))
+ self._generate_asset(n_segs=n_segs, asset_path=asset_path)
+
+ #asset_path = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'assets',
+ # 'manyagent_swimmer.xml')
+
+ mujoco_env.MujocoEnv.__init__(self, asset_path, 4)
+ utils.EzPickle.__init__(self)
+
+ def _generate_asset(self, n_segs, asset_path):
+ template_path = os.path.join(
+ os.path.dirname(os.path.abspath(__file__)), 'assets', 'manyagent_swimmer.xml.template'
+ )
+ with open(template_path, "r") as f:
+ t = Template(f.read())
+ body_str_template = """
+
+
+
+ """
+
+ body_end_str_template = """
+
+
+
+
+ """
+
+ body_close_str_template = "\n"
+ actuator_str_template = """\t \n"""
+
+ body_str = ""
+ for i in range(1, n_segs - 1):
+ body_str += body_str_template.format(i, (-1) ** (i + 1), i)
+ body_str += body_end_str_template.format(n_segs - 1)
+ body_str += body_close_str_template * (n_segs - 2)
+
+ actuator_str = ""
+ for i in range(n_segs):
+ actuator_str += actuator_str_template.format(i)
+
+ rt = t.render(body=body_str, actuators=actuator_str)
+ with open(asset_path, "w") as f:
+ f.write(rt)
+ pass
+
+ def step(self, a):
+ ctrl_cost_coeff = 0.0001
+ xposbefore = self.sim.data.qpos[0]
+ self.do_simulation(a, self.frame_skip)
+ xposafter = self.sim.data.qpos[0]
+ reward_fwd = (xposafter - xposbefore) / self.dt
+ reward_ctrl = -ctrl_cost_coeff * np.square(a).sum()
+ reward = reward_fwd + reward_ctrl
+ ob = self._get_obs()
+ return ob, reward, False, dict(reward_fwd=reward_fwd, reward_ctrl=reward_ctrl)
+
+ def _get_obs(self):
+ qpos = self.sim.data.qpos
+ qvel = self.sim.data.qvel
+ return np.concatenate([qpos.flat[2:], qvel.flat])
+
+ def reset_model(self):
+ self.set_state(
+ self.init_qpos + self.np_random.uniform(low=-.1, high=.1, size=self.model.nq),
+ self.init_qvel + self.np_random.uniform(low=-.1, high=.1, size=self.model.nv)
+ )
+ return self._get_obs()
diff --git a/DI-engine/dizoo/multiagent_mujoco/envs/mujoco_multi.py b/DI-engine/dizoo/multiagent_mujoco/envs/mujoco_multi.py
new file mode 100755
index 0000000000000000000000000000000000000000..bd777a6da0570654e822ecb86350fc8ca7177d92
--- /dev/null
+++ b/DI-engine/dizoo/multiagent_mujoco/envs/mujoco_multi.py
@@ -0,0 +1,250 @@
+from functools import partial
+import gym
+from gym.spaces import Box
+from gym.wrappers import TimeLimit
+import numpy as np
+
+from .multiagentenv import MultiAgentEnv
+from .obsk import get_joints_at_kdist, get_parts_and_edges, build_obs
+
+
+# using code from https://github.com/ikostrikov/pytorch-ddpg-naf
+class NormalizedActions(gym.ActionWrapper):
+
+ def _action(self, action):
+ action = (action + 1) / 2
+ action *= (self.action_space.high - self.action_space.low)
+ action += self.action_space.low
+ return action
+
+ def action(self, action_):
+ return self._action(action_)
+
+ def _reverse_action(self, action):
+ action -= self.action_space.low
+ action /= (self.action_space.high - self.action_space.low)
+ action = action * 2 - 1
+ return action
+
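+# Rescaling performed by NormalizedActions: `_action` maps a normalized action
+# in [-1, 1] onto [low, high]. For example, with low = -0.4 and high = 0.4, a
+# normalized action of 0.5 becomes
+# (0.5 + 1) / 2 * (0.4 - (-0.4)) + (-0.4) = 0.2; `_reverse_action` inverts
+# this mapping.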
+
+class MujocoMulti(MultiAgentEnv):
+
+ def __init__(self, batch_size=None, **kwargs):
+ super().__init__(batch_size, **kwargs)
+ self.add_agent_id = kwargs["env_args"]["add_agent_id"]
+ self.scenario = kwargs["env_args"]["scenario"] # e.g. Ant-v2
+ self.agent_conf = kwargs["env_args"]["agent_conf"] # e.g. '2x3'
+
+ self.agent_partitions, self.mujoco_edges, self.mujoco_globals = get_parts_and_edges(
+ self.scenario, self.agent_conf
+ )
+
+ self.n_agents = len(self.agent_partitions)
+ self.n_actions = max([len(l) for l in self.agent_partitions])
+ self.obs_add_global_pos = kwargs["env_args"].get("obs_add_global_pos", False)
+
+ self.agent_obsk = kwargs["env_args"].get(
+ "agent_obsk", None
+ ) # if None, fully observable else k>=0 implies observe nearest k agents or joints
+ self.agent_obsk_agents = kwargs["env_args"].get(
+ "agent_obsk_agents", False
+ ) # observe full k nearest agents (True) or just single joints (False)
+
+ if self.agent_obsk is not None:
+ self.k_categories_label = kwargs["env_args"].get("k_categories")
+ if self.k_categories_label is None:
+ if self.scenario in ["Ant-v2", "manyagent_ant"]:
+ self.k_categories_label = "qpos,qvel,cfrc_ext|qpos"
+ elif self.scenario in ["Humanoid-v2", "HumanoidStandup-v2"]:
+ self.k_categories_label = "qpos,qvel,cfrc_ext,cvel,cinert,qfrc_actuator|qpos"
+ elif self.scenario in ["Reacher-v2"]:
+ self.k_categories_label = "qpos,qvel,fingertip_dist|qpos"
+ elif self.scenario in ["coupled_half_cheetah"]:
+ self.k_categories_label = "qpos,qvel,ten_J,ten_length,ten_velocity|"
+ else:
+ self.k_categories_label = "qpos,qvel|qpos"
+
+ k_split = self.k_categories_label.split("|")
+ self.k_categories = [k_split[k if k < len(k_split) else -1].split(",") for k in range(self.agent_obsk + 1)]
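+ # e.g. with the default label "qpos,qvel|qpos" and agent_obsk=1 this gives
+ # [["qpos", "qvel"], ["qpos"]]: an agent's own joints expose qpos and qvel,
+ # while joints at distance 1 only expose qpos.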
+
+ self.global_categories_label = kwargs["env_args"].get("global_categories")
+ self.global_categories = self.global_categories_label.split(
+ ","
+ ) if self.global_categories_label is not None else []
+
+ if self.agent_obsk is not None:
+ self.k_dicts = [
+ get_joints_at_kdist(
+ agent_id,
+ self.agent_partitions,
+ self.mujoco_edges,
+ k=self.agent_obsk,
+ kagents=False,
+ ) for agent_id in range(self.n_agents)
+ ]
+
+ # load scenario from script
+ self.episode_limit = self.args.episode_limit
+
+ self.env_version = kwargs["env_args"].get("env_version", 2)
+ if self.env_version == 2:
+ try:
+ self.wrapped_env = NormalizedActions(gym.make(self.scenario))
+ except gym.error.Error: # env not in gym
+ if self.scenario in ["manyagent_ant"]:
+ from .manyagent_ant import ManyAgentAntEnv as this_env
+ elif self.scenario in ["manyagent_swimmer"]:
+ from .manyagent_swimmer import ManyAgentSwimmerEnv as this_env
+ elif self.scenario in ["coupled_half_cheetah"]:
+ from .coupled_half_cheetah import CoupledHalfCheetah as this_env
+ else:
+ raise NotImplementedError('Custom env not implemented!')
+ self.wrapped_env = NormalizedActions(
+ TimeLimit(this_env(**kwargs["env_args"]), max_episode_steps=self.episode_limit)
+ )
+ else:
+ assert False, "not implemented!"
+ self.timelimit_env = self.wrapped_env.env
+ self.timelimit_env._max_episode_steps = self.episode_limit
+ if gym.version.VERSION > '0.22.0': # for compatibility
+ # get original no wrapped env
+ self.env = self.timelimit_env.env.env.env.env
+ else:
+ self.env = self.timelimit_env.env
+ self.timelimit_env.reset()
+ self.obs_size = self.get_obs_size()
+
+ # COMPATIBILITY
+ self.n = self.n_agents
+ self.observation_space = [
+ Box(low=np.array([-10] * self.n_agents), high=np.array([10] * self.n_agents)) for _ in range(self.n_agents)
+ ]
+
+ acdims = [len(ap) for ap in self.agent_partitions]
+ self.action_space = tuple(
+ [
+ Box(
+ self.env.action_space.low[sum(acdims[:a]):sum(acdims[:a + 1])],
+ self.env.action_space.high[sum(acdims[:a]):sum(acdims[:a + 1])]
+ ) for a in range(self.n_agents)
+ ]
+ )
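+ # e.g. for agent_conf "2x3" the per-agent action dims are acdims = [3, 3],
+ # so agent 0 controls dims [0:3] and agent 1 controls dims [3:6] of the
+ # underlying single-agent action space.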
+
+ def step(self, actions):
+
+ # need to remove dummy actions that arise due to unequal action vector sizes across agents
+ flat_actions = np.concatenate([actions[i][:self.action_space[i].low.shape[0]] for i in range(self.n_agents)])
+ obs_n, reward_n, done_n, info_n = self.wrapped_env.step(flat_actions)
+ self.steps += 1
+
+ info = {}
+ info.update(info_n)
+
+ if done_n:
+ if self.steps < self.episode_limit:
+ info["episode_limit"] = False # the next state will be masked out
+ else:
+ info["episode_limit"] = True # the next state will not be masked out
+
+ obs = {'agent_state': self.get_obs(), 'global_state': self.get_state()}
+
+ return obs, reward_n, done_n, info
+
+ def get_obs(self):
+ """ Returns all agent observat3ions in a list """
+ obs_n = []
+ for a in range(self.n_agents):
+ obs_n.append(self.get_obs_agent(a))
+ return np.array(obs_n).astype(np.float32)
+
+ def get_obs_agent(self, agent_id):
+ if self.agent_obsk is None:
+ return self.env._get_obs()
+ else:
+ return build_obs(
+ self.env,
+ self.k_dicts[agent_id],
+ self.k_categories,
+ self.mujoco_globals,
+ self.global_categories,
+ vec_len=getattr(self, "obs_size", None)
+ )
+
+ def get_obs_size(self):
+ """ Returns the shape of the observation """
+ if self.agent_obsk is None:
+ return self.get_obs_agent(0).size
+ else:
+ return max([len(self.get_obs_agent(agent_id)) for agent_id in range(self.n_agents)])
+
+ def get_state(self, team=None):
+ # TODO: May want global states for different teams (so cannot see what the other team is communicating e.g.)
+ state_n = []
+ if self.add_agent_id:
+ state = self.env._get_obs()
+ for a in range(self.n_agents):
+ agent_id_feats = np.zeros(self.n_agents, dtype=np.float32)
+ agent_id_feats[a] = 1.0
+ state_i = np.concatenate([state, agent_id_feats])
+ state_n.append(state_i)
+ else:
+ for a in range(self.n_agents):
+ state_n.append(self.env._get_obs())
+ return np.array(state_n).astype(np.float32)
+
+ def get_state_size(self):
+ """ Returns the shape of the state"""
+ return len(self.get_state())
+
+ def get_avail_actions(self): # all actions are always available
+ return np.ones(shape=(
+ self.n_agents,
+ self.n_actions,
+ ))
+
+ def get_avail_agent_actions(self, agent_id):
+ """ Returns the available actions for agent_id """
+ return np.ones(shape=(self.n_actions, ))
+
+ def get_total_actions(self):
+ """ Returns the total number of actions an agent could ever take """
+ return self.n_actions # CAREFUL! - for continuous dims, this is the action space dim rather than the number of actions
+ # return self.env.action_space.shape[0]
+
+ def get_stats(self):
+ return {}
+
+ # TODO: Temp hack
+ def get_agg_stats(self, stats):
+ return {}
+
+ def reset(self, **kwargs):
+ """ Returns initial observations and states"""
+ self.steps = 0
+ self.timelimit_env.reset()
+ obs = {'agent_state': self.get_obs(), 'global_state': self.get_state()}
+ return obs
+
+ def render(self, **kwargs):
+ self.env.render(**kwargs)
+
+ def close(self):
+ pass
+ #raise NotImplementedError
+
+ def seed(self, args):
+ pass
+
+ def get_env_info(self):
+
+ env_info = {
+ "state_shape": self.get_state_size(),
+ "obs_shape": self.get_obs_size(),
+ "n_actions": self.get_total_actions(),
+ "n_agents": self.n_agents,
+ "episode_limit": self.episode_limit,
+ "action_spaces": self.action_space,
+ "actions_dtype": np.float32,
+ "normalise_actions": False
+ }
+ return env_info
diff --git a/DI-engine/dizoo/multiagent_mujoco/envs/multi_mujoco_env.py b/DI-engine/dizoo/multiagent_mujoco/envs/multi_mujoco_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ed538909c0f1c38d0bcda0de811d3bf222194c1
--- /dev/null
+++ b/DI-engine/dizoo/multiagent_mujoco/envs/multi_mujoco_env.py
@@ -0,0 +1,109 @@
+from typing import Any, Union, List
+import copy
+import numpy as np
+from numpy import dtype
+import gym
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.envs.common.common_function import affine_transform
+from ding.torch_utils import to_ndarray, to_list
+from ding.utils import ENV_REGISTRY
+from .mujoco_multi import MujocoMulti
+
+
+@ENV_REGISTRY.register('mujoco_multi')
+class MujocoEnv(BaseEnv):
+
+ def __init__(self, cfg: dict) -> None:
+ self._cfg = cfg
+ self._init_flag = False
+
+ def reset(self) -> np.ndarray:
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._cfg.seed = self._seed + np_seed
+ elif hasattr(self, '_seed'):
+ self._cfg.seed = self._seed
+ if not self._init_flag:
+ self._env = MujocoMulti(env_args=self._cfg)
+ self._init_flag = True
+ obs = self._env.reset()
+ self._eval_episode_return = 0.
+
+ # TODO:
+ # self.env_info for scenario='Ant-v2', agent_conf="2x4d",
+ # {'state_shape': 2, 'obs_shape': 54,...}
+ # 'state_shape' is wrong, it should be 111
+ self.env_info = self._env.get_env_info()
+ # self._env.observation_space[agent].shape equals above 'state_shape'
+
+ self._num_agents = self.env_info['n_agents']
+ self._agents = [i for i in range(self._num_agents)]
+ self._observation_space = gym.spaces.Dict(
+ {
+ 'agent_state': gym.spaces.Box(
+ low=float("-inf"), high=float("inf"), shape=obs['agent_state'].shape, dtype=np.float32
+ ),
+ 'global_state': gym.spaces.Box(
+ low=float("-inf"), high=float("inf"), shape=obs['global_state'].shape, dtype=np.float32
+ ),
+ }
+ )
+ self._action_space = gym.spaces.Dict({agent: self._env.action_space[agent] for agent in self._agents})
+ single_agent_action_space = self._env.action_space[self._agents[0]]
+ if isinstance(single_agent_action_space, gym.spaces.Box):
+ self._action_dim = single_agent_action_space.shape
+ elif isinstance(single_agent_action_space, gym.spaces.Discrete):
+ self._action_dim = (single_agent_action_space.n, )
+ else:
+ raise Exception('Only `Box` or `Discrete` action spaces are supported for a single agent.')
+ self._reward_space = gym.spaces.Dict(
+ {
+ agent: gym.spaces.Box(low=float("-inf"), high=float("inf"), shape=(1, ), dtype=np.float32)
+ for agent in self._agents
+ }
+ )
+
+ return obs
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep:
+ action = to_ndarray(action)
+ obs, rew, done, info = self._env.step(action)
+ self._eval_episode_return += rew
+ rew = to_ndarray([rew]) # wrapped to be transferred to an array with shape (1, )
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def random_action(self) -> np.ndarray:
+ random_action = self.action_space.sample()
+ random_action = to_ndarray([random_action], dtype=np.int64)
+ return random_action
+
+ @property
+ def num_agents(self) -> Any:
+ return self._num_agents
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ def __repr__(self) -> str:
+ return "DI-engine Multi-agent Mujoco Env({})".format(self._cfg.env_id)
diff --git a/DI-engine/dizoo/multiagent_mujoco/envs/multiagentenv.py b/DI-engine/dizoo/multiagent_mujoco/envs/multiagentenv.py
new file mode 100755
index 0000000000000000000000000000000000000000..07e65fc549a98d6a85d49d3ab77d7614ed9e7fca
--- /dev/null
+++ b/DI-engine/dizoo/multiagent_mujoco/envs/multiagentenv.py
@@ -0,0 +1,85 @@
+from collections import namedtuple
+import numpy as np
+
+
+def convert(dictionary):
+ return namedtuple('GenericDict', dictionary.keys())(**dictionary)
+
+
+class MultiAgentEnv(object):
+
+ def __init__(self, batch_size=None, **kwargs):
+ # Unpack arguments from sacred
+ args = kwargs["env_args"]
+ if isinstance(args, dict):
+ args = convert(args)
+ self.args = args
+
+ if getattr(args, "seed", None) is not None:
+ self.seed = args.seed
+ self.rs = np.random.RandomState(self.seed) # initialise numpy random state
+
+ def step(self, actions):
+ """ Returns reward, terminated, info """
+ raise NotImplementedError
+
+ def get_obs(self):
+ """ Returns all agent observations in a list """
+ raise NotImplementedError
+
+ def get_obs_agent(self, agent_id):
+ """ Returns observation for agent_id """
+ raise NotImplementedError
+
+ def get_obs_size(self):
+ """ Returns the shape of the observation """
+ raise NotImplementedError
+
+ def get_state(self):
+ raise NotImplementedError
+
+ def get_state_size(self):
+ """ Returns the shape of the state"""
+ raise NotImplementedError
+
+ def get_avail_actions(self):
+ raise NotImplementedError
+
+ def get_avail_agent_actions(self, agent_id):
+ """ Returns the available actions for agent_id """
+ raise NotImplementedError
+
+ def get_total_actions(self):
+ """ Returns the total number of actions an agent could ever take """
+ # TODO: This is only suitable for a discrete 1 dimensional action space for each agent
+ raise NotImplementedError
+
+ def get_stats(self):
+ raise NotImplementedError
+
+ # TODO: Temp hack
+ def get_agg_stats(self, stats):
+ return {}
+
+ def reset(self):
+ """ Returns initial observations and states"""
+ raise NotImplementedError
+
+ def render(self):
+ raise NotImplementedError
+
+ def close(self):
+ raise NotImplementedError
+
+ def seed(self, seed):
+ raise NotImplementedError
+
+ def get_env_info(self):
+ env_info = {
+ "state_shape": self.get_state_size(),
+ "obs_shape": self.get_obs_size(),
+ "n_actions": self.get_total_actions(),
+ "n_agents": self.n_agents,
+ "episode_limit": self.episode_limit
+ }
+ return env_info
diff --git a/DI-engine/dizoo/multiagent_mujoco/envs/obsk.py b/DI-engine/dizoo/multiagent_mujoco/envs/obsk.py
new file mode 100644
index 0000000000000000000000000000000000000000..404f455abe0711a53febe8025c71f46584e5b70f
--- /dev/null
+++ b/DI-engine/dizoo/multiagent_mujoco/envs/obsk.py
@@ -0,0 +1,662 @@
+import itertools
+import numpy as np
+from copy import deepcopy
+
+
+class Node():
+
+ def __init__(self, label, qpos_ids, qvel_ids, act_ids, body_fn=None, bodies=None, extra_obs=None, tendons=None):
+ self.label = label
+ self.qpos_ids = qpos_ids
+ self.qvel_ids = qvel_ids
+ self.act_ids = act_ids
+ self.bodies = bodies
+ self.extra_obs = {} if extra_obs is None else extra_obs
+ self.body_fn = body_fn
+ self.tendons = tendons
+ pass
+
+ def __str__(self):
+ return self.label
+
+ def __repr__(self):
+ return self.label
+
+
+class HyperEdge():
+
+ def __init__(self, *edges):
+ self.edges = set(edges)
+
+ def __contains__(self, item):
+ return item in self.edges
+
+ def __str__(self):
+ return "HyperEdge({})".format(self.edges)
+
+ def __repr__(self):
+ return "HyperEdge({})".format(self.edges)
+
+
+def get_joints_at_kdist(
+ agent_id,
+ agent_partitions,
+ hyperedges,
+ k=0,
+ kagents=False,
+):
+ """ Identify all joints at distance <= k from agent agent_id
+
+ :param agent_id: id of agent to be considered
+ :param agent_partitions: list of joint tuples in order of agentids
+ :param hyperedges: list of HyperEdge objects connecting the joints
+ :param k: kth degree
+ :param kagents: True (observe all joints of an agent if a single one is) or False (individual joint granularity)
+ :return:
+ dict with k as key, and list of joints at that distance
+ """
+ assert not kagents, "kagents not implemented!"
+
+ agent_joints = agent_partitions[agent_id]
+
+ def _adjacent(lst, kagents=False):
+ # return all sets adjacent to any element in lst
+ ret = set([])
+ for l in lst:
+ ret = ret.union(set(itertools.chain(*[e.edges.difference({l}) for e in hyperedges if l in e])))
+ return ret
+
+ seen = set([])
+ new = set([])
+ k_dict = {}
+ for _k in range(k + 1):
+ if not _k:
+ new = set(agent_joints)
+ else:
+ print(hyperedges)
+ new = _adjacent(new) - seen
+ seen = seen.union(new)
+ k_dict[_k] = sorted(list(new), key=lambda x: x.label)
+ return k_dict
+
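+# For example, for HalfCheetah-v2 with partitioning "2x3", agent 0 owns
+# (bfoot, bshin, bthigh), and get_joints_at_kdist(0, parts, edges, k=1)
+# returns {0: [bfoot, bshin, bthigh], 1: [fthigh]}, since fthigh is the only
+# joint one hyperedge away from that partition.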
+
+def build_obs(env, k_dict, k_categories, global_dict, global_categories, vec_len=None):
+ """Given a k_dict from get_joints_at_kdist, extract observation vector.
+
+ :param env: the wrapped mujoco env (exposes sim.data and the extra_obs callables)
+ :param k_dict: output of get_joints_at_kdist
+ :param k_categories: observation categories to include for each distance k
+ :param global_dict: globally observed joints/bodies shared by all agents
+ :param global_categories: observation categories for the global entries
+ :param vec_len: if None no padding, else zero-pad to vec_len
+ :return:
+ observation vector
+ """
+
+ # TODO: This needs to be fixed, it was designed for half-cheetah only!
+ #if add_global_pos:
+ # obs_qpos_lst.append(global_qpos)
+ # obs_qvel_lst.append(global_qvel)
+
+ body_set_dict = {}
+ obs_lst = []
+ # Add parts attributes
+ for k in sorted(list(k_dict.keys())):
+ cats = k_categories[k]
+ for _t in k_dict[k]:
+ for c in cats:
+ if c in _t.extra_obs:
+ items = _t.extra_obs[c](env).tolist()
+ obs_lst.extend(items if isinstance(items, list) else [items])
+ else:
+ if c in ["qvel", "qpos"]: # this is a "joint position/velocity" item
+ items = getattr(env.sim.data, c)[getattr(_t, "{}_ids".format(c))]
+ obs_lst.extend(items if isinstance(items, list) else [items])
+ elif c in ["qfrc_actuator"]: # this is a "vel position" item
+ items = getattr(env.sim.data, c)[getattr(_t, "{}_ids".format("qvel"))]
+ obs_lst.extend(items if isinstance(items, list) else [items])
+ elif c in ["cvel", "cinert", "cfrc_ext"]: # this is a "body position" item
+ if _t.bodies is not None:
+ for b in _t.bodies:
+ if c not in body_set_dict:
+ body_set_dict[c] = set()
+ if b not in body_set_dict[c]:
+ items = getattr(env.sim.data, c)[b].tolist()
+ items = getattr(_t, "body_fn", lambda _id, x: x)(b, items)
+ obs_lst.extend(items if isinstance(items, list) else [items])
+ body_set_dict[c].add(b)
+
+ # Add global attributes
+ body_set_dict = {}
+ for c in global_categories:
+ if c in ["qvel", "qpos"]: # this is a "joint position" item
+ for j in global_dict.get("joints", []):
+ items = getattr(env.sim.data, c)[getattr(j, "{}_ids".format(c))]
+ obs_lst.extend(items if isinstance(items, list) else [items])
+ else:
+ for b in global_dict.get("bodies", []):
+ if c not in body_set_dict:
+ body_set_dict[c] = set()
+ if b not in body_set_dict[c]:
+ obs_lst.extend(getattr(env.sim.data, c)[b].tolist())
+ body_set_dict[c].add(b)
+
+ if vec_len is not None:
+ pad = np.array((vec_len - len(obs_lst)) * [0])
+ if len(pad):
+ return np.concatenate([np.array(obs_lst), pad])
+ return np.array(obs_lst)
+
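+# Note on padding in build_obs: when vec_len is given and fewer features were
+# collected, the vector is right-padded with zeros so that every agent ends up
+# with an observation of the same length (see get_obs_size in MujocoMulti).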
+
+def build_actions(agent_partitions, k_dict):
+ # Composes agent actions output from networks
+ # into coherent joint action vector to be sent to the env.
+ pass
+
+
+def get_parts_and_edges(label, partitioning):
+ if label in ["half_cheetah", "HalfCheetah-v2"]:
+
+ # define Mujoco graph
+ bthigh = Node("bthigh", -6, -6, 0)
+ bshin = Node("bshin", -5, -5, 1)
+ bfoot = Node("bfoot", -4, -4, 2)
+ fthigh = Node("fthigh", -3, -3, 3)
+ fshin = Node("fshin", -2, -2, 4)
+ ffoot = Node("ffoot", -1, -1, 5)
+
+ edges = [
+ HyperEdge(bfoot, bshin),
+ HyperEdge(bshin, bthigh),
+ HyperEdge(bthigh, fthigh),
+ HyperEdge(fthigh, fshin),
+ HyperEdge(fshin, ffoot)
+ ]
+
+ root_x = Node("root_x", 0, 0, -1, extra_obs={"qpos": lambda env: np.array([])})
+ root_z = Node("root_z", 1, 1, -1)
+ root_y = Node("root_y", 2, 2, -1)
+ globals = {"joints": [root_x, root_y, root_z]}
+
+ if partitioning == "2x3":
+ parts = [(bfoot, bshin, bthigh), (ffoot, fshin, fthigh)]
+ elif partitioning == "6x1":
+ parts = [(bfoot, ), (bshin, ), (bthigh, ), (ffoot, ), (fshin, ), (fthigh, )]
+ else:
+ raise Exception("UNKNOWN partitioning config: {}".format(partitioning))
+
+ return parts, edges, globals
+
+ elif label in ["Ant-v2"]:
+
+ # define Mujoco graph
+ torso = 1
+ front_left_leg = 2
+ aux_1 = 3
+ ankle_1 = 4
+ front_right_leg = 5
+ aux_2 = 6
+ ankle_2 = 7
+ back_leg = 8
+ aux_3 = 9
+ ankle_3 = 10
+ right_back_leg = 11
+ aux_4 = 12
+ ankle_4 = 13
+
+ hip1 = Node(
+ "hip1", -8, -8, 2, bodies=[torso, front_left_leg], body_fn=lambda _id, x: np.clip(x, -1, 1).tolist()
+ ) #
+ ankle1 = Node(
+ "ankle1",
+ -7,
+ -7,
+ 3,
+ bodies=[front_left_leg, aux_1, ankle_1],
+ body_fn=lambda _id, x: np.clip(x, -1, 1).tolist()
+ ) #,
+ hip2 = Node(
+ "hip2", -6, -6, 4, bodies=[torso, front_right_leg], body_fn=lambda _id, x: np.clip(x, -1, 1).tolist()
+ ) #,
+ ankle2 = Node(
+ "ankle2",
+ -5,
+ -5,
+ 5,
+ bodies=[front_right_leg, aux_2, ankle_2],
+ body_fn=lambda _id, x: np.clip(x, -1, 1).tolist()
+ ) #,
+ hip3 = Node("hip3", -4, -4, 6, bodies=[torso, back_leg], body_fn=lambda _id, x: np.clip(x, -1, 1).tolist()) #,
+ ankle3 = Node(
+ "ankle3", -3, -3, 7, bodies=[back_leg, aux_3, ankle_3], body_fn=lambda _id, x: np.clip(x, -1, 1).tolist()
+ ) #,
+ hip4 = Node(
+ "hip4", -2, -2, 0, bodies=[torso, right_back_leg], body_fn=lambda _id, x: np.clip(x, -1, 1).tolist()
+ ) #,
+ ankle4 = Node(
+ "ankle4",
+ -1,
+ -1,
+ 1,
+ bodies=[right_back_leg, aux_4, ankle_4],
+ body_fn=lambda _id, x: np.clip(x, -1, 1).tolist()
+ ) #,
+
+ edges = [
+ HyperEdge(ankle4, hip4),
+ HyperEdge(ankle1, hip1),
+ HyperEdge(ankle2, hip2),
+ HyperEdge(ankle3, hip3),
+ HyperEdge(hip4, hip1, hip2, hip3),
+ ]
+
+ free_joint = Node(
+ "free",
+ 0,
+ 0,
+ -1,
+ extra_obs={
+ "qpos": lambda env: env.sim.data.qpos[:7],
+ "qvel": lambda env: env.sim.data.qvel[:6],
+ "cfrc_ext": lambda env: np.clip(env.sim.data.cfrc_ext[0:1], -1, 1)
+ }
+ )
+ globals = {"joints": [free_joint]}
+
+ if partitioning == "2x4": # neighbouring legs together
+ parts = [(hip1, ankle1, hip2, ankle2), (hip3, ankle3, hip4, ankle4)]
+ elif partitioning == "2x4d": # diagonal legs together
+ parts = [(hip1, ankle1, hip3, ankle3), (hip2, ankle2, hip4, ankle4)]
+ elif partitioning == "4x2":
+ parts = [(hip1, ankle1), (hip2, ankle2), (hip3, ankle3), (hip4, ankle4)]
+ else:
+ raise Exception("UNKNOWN partitioning config: {}".format(partitioning))
+
+ return parts, edges, globals
+
+ elif label in ["Hopper-v2"]:
+
+ # define Mujoco-Graph
+ thigh_joint = Node(
+ "thigh_joint",
+ -3,
+ -3,
+ 0,
+ extra_obs={"qvel": lambda env: np.clip(np.array([env.sim.data.qvel[-3]]), -10, 10)}
+ )
+ leg_joint = Node(
+ "leg_joint", -2, -2, 1, extra_obs={"qvel": lambda env: np.clip(np.array([env.sim.data.qvel[-2]]), -10, 10)}
+ )
+ foot_joint = Node(
+ "foot_joint",
+ -1,
+ -1,
+ 2,
+ extra_obs={"qvel": lambda env: np.clip(np.array([env.sim.data.qvel[-1]]), -10, 10)}
+ )
+
+ edges = [HyperEdge(foot_joint, leg_joint), HyperEdge(leg_joint, thigh_joint)]
+
+ root_x = Node(
+ "root_x",
+ 0,
+ 0,
+ -1,
+ extra_obs={
+ "qpos": lambda env: np.array([]),
+ "qvel": lambda env: np.clip(np.array([env.sim.data.qvel[1]]), -10, 10)
+ }
+ )
+ root_z = Node(
+ "root_z", 1, 1, -1, extra_obs={"qvel": lambda env: np.clip(np.array([env.sim.data.qvel[1]]), -10, 10)}
+ )
+ root_y = Node(
+ "root_y", 2, 2, -1, extra_obs={"qvel": lambda env: np.clip(np.array([env.sim.data.qvel[2]]), -10, 10)}
+ )
+ globals = {"joints": [root_x, root_y, root_z]}
+
+ if partitioning == "3x1":
+ parts = [(thigh_joint, ), (leg_joint, ), (foot_joint, )]
+
+ else:
+ raise Exception("UNKNOWN partitioning config: {}".format(partitioning))
+
+ return parts, edges, globals
+
+ elif label in ["Humanoid-v2", "HumanoidStandup-v2"]:
+
+ # define Mujoco-Graph
+ abdomen_y = Node("abdomen_y", -16, -16, 0) # act ordering bug in env -- double check!
+ abdomen_z = Node("abdomen_z", -17, -17, 1)
+ abdomen_x = Node("abdomen_x", -15, -15, 2)
+ right_hip_x = Node("right_hip_x", -14, -14, 3)
+ right_hip_z = Node("right_hip_z", -13, -13, 4)
+ right_hip_y = Node("right_hip_y", -12, -12, 5)
+ right_knee = Node("right_knee", -11, -11, 6)
+ left_hip_x = Node("left_hip_x", -10, -10, 7)
+ left_hip_z = Node("left_hip_z", -9, -9, 8)
+ left_hip_y = Node("left_hip_y", -8, -8, 9)
+ left_knee = Node("left_knee", -7, -7, 10)
+ right_shoulder1 = Node("right_shoulder1", -6, -6, 11)
+ right_shoulder2 = Node("right_shoulder2", -5, -5, 12)
+ right_elbow = Node("right_elbow", -4, -4, 13)
+ left_shoulder1 = Node("left_shoulder1", -3, -3, 14)
+ left_shoulder2 = Node("left_shoulder2", -2, -2, 15)
+ left_elbow = Node("left_elbow", -1, -1, 16)
+
+ edges = [
+ HyperEdge(abdomen_x, abdomen_y, abdomen_z),
+ HyperEdge(right_hip_x, right_hip_y, right_hip_z),
+ HyperEdge(left_hip_x, left_hip_y, left_hip_z),
+ HyperEdge(left_elbow, left_shoulder1, left_shoulder2),
+ HyperEdge(right_elbow, right_shoulder1, right_shoulder2),
+ HyperEdge(left_knee, left_hip_x, left_hip_y, left_hip_z),
+ HyperEdge(right_knee, right_hip_x, right_hip_y, right_hip_z),
+ HyperEdge(left_shoulder1, left_shoulder2, abdomen_x, abdomen_y, abdomen_z),
+ HyperEdge(right_shoulder1, right_shoulder2, abdomen_x, abdomen_y, abdomen_z),
+ HyperEdge(abdomen_x, abdomen_y, abdomen_z, left_hip_x, left_hip_y, left_hip_z),
+ HyperEdge(abdomen_x, abdomen_y, abdomen_z, right_hip_x, right_hip_y, right_hip_z),
+ ]
+
+ globals = {}
+
+ if partitioning == "9|8": # 17 in total, so one action is a dummy (to be handled by pymarl)
+ # isolate upper and lower body
+ parts = [
+ (
+ left_shoulder1, left_shoulder2, abdomen_x, abdomen_y, abdomen_z, right_shoulder1, right_shoulder2,
+ right_elbow, left_elbow
+ ), (left_hip_x, left_hip_y, left_hip_z, right_hip_x, right_hip_y, right_hip_z, right_knee, left_knee)
+ ]
+ # TODO: There could be tons of decompositions here
+
+ else:
+ raise Exception("UNKNOWN partitioning config: {}".format(partitioning))
+
+ return parts, edges, globals
+
+ elif label in ["Reacher-v2"]:
+
+ # define Mujoco-Graph
+ body0 = 1
+ body1 = 2
+ fingertip = 3
+ joint0 = Node(
+ "joint0",
+ -4,
+ -4,
+ 0,
+ bodies=[body0, body1],
+ extra_obs={"qpos": (lambda env: np.array([np.sin(env.sim.data.qpos[-4]),
+ np.cos(env.sim.data.qpos[-4])]))}
+ )
+ joint1 = Node(
+ "joint1",
+ -3,
+ -3,
+ 1,
+ bodies=[body1, fingertip],
+ extra_obs={
+ "fingertip_dist": (lambda env: env.get_body_com("fingertip") - env.get_body_com("target")),
+ "qpos": (lambda env: np.array([np.sin(env.sim.data.qpos[-3]),
+ np.cos(env.sim.data.qpos[-3])]))
+ }
+ )
+ edges = [HyperEdge(joint0, joint1)]
+
+ worldbody = 0
+ target = 4
+ target_x = Node("target_x", -2, -2, -1, extra_obs={"qvel": (lambda env: np.array([]))})
+ target_y = Node("target_y", -1, -1, -1, extra_obs={"qvel": (lambda env: np.array([]))})
+ globals = {"bodies": [worldbody, target], "joints": [target_x, target_y]}
+
+ if partitioning == "2x1":
+ # isolate upper and lower arms
+ parts = [(joint0, ), (joint1, )]
+ # TODO: There could be tons of decompositions here
+
+ else:
+ raise Exception("UNKNOWN partitioning config: {}".format(partitioning))
+
+ return parts, edges, globals
+
+ elif label in ["Swimmer-v2"]:
+
+ # define Mujoco-Graph
+ joint0 = Node("rot2", -2, -2, 0) # TODO: double-check ids
+ joint1 = Node("rot3", -1, -1, 1)
+
+ edges = [HyperEdge(joint0, joint1)]
+ globals = {}
+
+ if partitioning == "2x1":
+ # isolate upper and lower body
+ parts = [(joint0, ), (joint1, )]
+ # TODO: There could be tons of decompositions here
+
+ else:
+ raise Exception("UNKNOWN partitioning config: {}".format(partitioning))
+
+ return parts, edges, globals
+
+ elif label in ["Walker2d-v2"]:
+
+ # define Mujoco-Graph
+ thigh_joint = Node("thigh_joint", -6, -6, 0)
+ leg_joint = Node("leg_joint", -5, -5, 1)
+ foot_joint = Node("foot_joint", -4, -4, 2)
+ thigh_left_joint = Node("thigh_left_joint", -3, -3, 3)
+ leg_left_joint = Node("leg_left_joint", -2, -2, 4)
+ foot_left_joint = Node("foot_left_joint", -1, -1, 5)
+
+ edges = [
+ HyperEdge(foot_joint, leg_joint),
+ HyperEdge(leg_joint, thigh_joint),
+ HyperEdge(foot_left_joint, leg_left_joint),
+ HyperEdge(leg_left_joint, thigh_left_joint),
+ HyperEdge(thigh_joint, thigh_left_joint)
+ ]
+ globals = {}
+
+ if partitioning == "2x3":
+ # isolate upper and lower body
+ parts = [(foot_joint, leg_joint, thigh_joint), (
+ foot_left_joint,
+ leg_left_joint,
+ thigh_left_joint,
+ )]
+ # TODO: There could be tons of decompositions here
+
+ else:
+ raise Exception("UNKNOWN partitioning config: {}".format(partitioning))
+
+ return parts, edges, globals
+
+ elif label in ["coupled_half_cheetah"]:
+
+ # define Mujoco graph
+ tendon = 0
+
+ bthigh = Node(
+ "bthigh",
+ -6,
+ -6,
+ 0,
+ tendons=[tendon],
+ extra_obs={
+ "ten_J": lambda env: env.sim.data.ten_J[tendon],
+ "ten_length": lambda env: env.sim.data.ten_length,
+ "ten_velocity": lambda env: env.sim.data.ten_velocity
+ }
+ )
+ bshin = Node("bshin", -5, -5, 1)
+ bfoot = Node("bfoot", -4, -4, 2)
+ fthigh = Node("fthigh", -3, -3, 3)
+ fshin = Node("fshin", -2, -2, 4)
+ ffoot = Node("ffoot", -1, -1, 5)
+
+ bthigh2 = Node(
+ "bthigh2",
+ -6,
+ -6,
+ 0,
+ tendons=[tendon],
+ extra_obs={
+ "ten_J": lambda env: env.sim.data.ten_J[tendon],
+ "ten_length": lambda env: env.sim.data.ten_length,
+ "ten_velocity": lambda env: env.sim.data.ten_velocity
+ }
+ )
+ bshin2 = Node("bshin2", -5, -5, 1)
+ bfoot2 = Node("bfoot2", -4, -4, 2)
+ fthigh2 = Node("fthigh2", -3, -3, 3)
+ fshin2 = Node("fshin2", -2, -2, 4)
+ ffoot2 = Node("ffoot2", -1, -1, 5)
+
+ edges = [
+ HyperEdge(bfoot, bshin),
+ HyperEdge(bshin, bthigh),
+ HyperEdge(bthigh, fthigh),
+ HyperEdge(fthigh, fshin),
+ HyperEdge(fshin, ffoot),
+ HyperEdge(bfoot2, bshin2),
+ HyperEdge(bshin2, bthigh2),
+ HyperEdge(bthigh2, fthigh2),
+ HyperEdge(fthigh2, fshin2),
+ HyperEdge(fshin2, ffoot2)
+ ]
+ globals = {}
+
+ root_x = Node("root_x", 0, 0, -1, extra_obs={"qpos": lambda env: np.array([])})
+ root_z = Node("root_z", 1, 1, -1)
+ root_y = Node("root_y", 2, 2, -1)
+ globals = {"joints": [root_x, root_y, root_z]}
+
+ if partitioning == "1p1":
+ parts = [(bfoot, bshin, bthigh, ffoot, fshin, fthigh), (bfoot2, bshin2, bthigh2, ffoot2, fshin2, fthigh2)]
+ else:
+ raise Exception("UNKNOWN partitioning config: {}".format(partitioning))
+
+ return parts, edges, globals
+
+ elif label in ["manyagent_swimmer"]:
+
+ # Generate asset file
+ try:
+ n_agents = int(partitioning.split("x")[0])
+ n_segs_per_agents = int(partitioning.split("x")[1])
+ n_segs = n_agents * n_segs_per_agents
+ except Exception as e:
+ raise Exception("UNKNOWN partitioning config: {}".format(partitioning))
+
+ # Note: Default Swimmer corresponds to n_segs = 3
+
+ # define Mujoco-Graph
+ joints = [Node("rot{:d}".format(i), -n_segs + i, -n_segs + i, i) for i in range(0, n_segs)]
+ edges = [HyperEdge(joints[i], joints[i + 1]) for i in range(n_segs - 1)]
+ globals = {}
+
+ parts = [tuple(joints[i * n_segs_per_agents:(i + 1) * n_segs_per_agents]) for i in range(n_agents)]
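+ # e.g. partitioning "2x2" gives n_segs=4 joints rot0..rot3 and
+ # parts = [(rot0, rot1), (rot2, rot3)]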
+ return parts, edges, globals
+
+ elif label in ["manyagent_ant"]: # TODO: FIX!
+
+ # Generate asset file
+ try:
+ n_agents = int(partitioning.split("x")[0])
+ n_segs_per_agents = int(partitioning.split("x")[1])
+ n_segs = n_agents * n_segs_per_agents
+ except Exception as e:
+ raise Exception("UNKNOWN partitioning config: {}".format(partitioning))
+
+ # # define Mujoco graph
+ # torso = 1
+ # front_left_leg = 2
+ # aux_1 = 3
+ # ankle_1 = 4
+ # right_back_leg = 11
+ # aux_4 = 12
+ # ankle_4 = 13
+ #
+ # off = -4*(n_segs-1)
+ # hip1 = Node("hip1", -4-off, -4-off, 2, bodies=[torso, front_left_leg], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist()) #
+ # ankle1 = Node("ankle1", -3-off, -3-off, 3, bodies=[front_left_leg, aux_1, ankle_1], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist())#,
+ # hip4 = Node("hip4", -2-off, -2-off, 0, bodies=[torso, right_back_leg], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist())#,
+ # ankle4 = Node("ankle4", -1-off, -1-off, 1, bodies=[right_back_leg, aux_4, ankle_4], body_fn=lambda _id, x:np.clip(x, -1, 1).tolist())#,
+ #
+ # edges = [HyperEdge(ankle4, hip4),
+ # HyperEdge(ankle1, hip1),
+ # HyperEdge(hip4, hip1),
+ # ]
+
+ edges = []
+ joints = []
+ for si in range(n_segs):
+
+ torso = 1 + si * 7
+ front_right_leg = 2 + si * 7
+ aux1 = 3 + si * 7
+ ankle1 = 4 + si * 7
+ back_leg = 5 + si * 7
+ aux2 = 6 + si * 7
+ ankle2 = 7 + si * 7
+
+ off = -4 * (n_segs - 1 - si)
+ hip1n = Node(
+ "hip1_{:d}".format(si),
+ -4 - off,
+ -4 - off,
+ 2 + 4 * si,
+ bodies=[torso, front_right_leg],
+ body_fn=lambda _id, x: np.clip(x, -1, 1).tolist()
+ )
+ ankle1n = Node(
+ "ankle1_{:d}".format(si),
+ -3 - off,
+ -3 - off,
+ 3 + 4 * si,
+ bodies=[front_right_leg, aux1, ankle1],
+ body_fn=lambda _id, x: np.clip(x, -1, 1).tolist()
+ )
+ hip2n = Node(
+ "hip2_{:d}".format(si),
+ -2 - off,
+ -2 - off,
+ 0 + 4 * si,
+ bodies=[torso, back_leg],
+ body_fn=lambda _id, x: np.clip(x, -1, 1).tolist()
+ )
+ ankle2n = Node(
+ "ankle2_{:d}".format(si),
+ -1 - off,
+ -1 - off,
+ 1 + 4 * si,
+ bodies=[back_leg, aux2, ankle2],
+ body_fn=lambda _id, x: np.clip(x, -1, 1).tolist()
+ )
+
+ edges += [HyperEdge(ankle1n, hip1n), HyperEdge(ankle2n, hip2n), HyperEdge(hip1n, hip2n)]
+ if si:
+ edges += [HyperEdge(hip1m, hip2m, hip1n, hip2n)]
+
+ hip1m = deepcopy(hip1n)
+ hip2m = deepcopy(hip2n)
+ joints.append([hip1n, ankle1n, hip2n, ankle2n])
+
+ free_joint = Node(
+ "free",
+ 0,
+ 0,
+ -1,
+ extra_obs={
+ "qpos": lambda env: env.sim.data.qpos[:7],
+ "qvel": lambda env: env.sim.data.qvel[:6],
+ "cfrc_ext": lambda env: np.clip(env.sim.data.cfrc_ext[0:1], -1, 1)
+ }
+ )
+ globals = {"joints": [free_joint]}
+
+ parts = [
+ [x for sublist in joints[i * n_segs_per_agents:(i + 1) * n_segs_per_agents] for x in sublist]
+ for i in range(n_agents)
+ ]
+
+ return parts, edges, globals
diff --git a/DI-engine/dizoo/overcooked/README.md b/DI-engine/dizoo/overcooked/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c22f51766f11fe95aaf2ee23ff8f6f564901aad0
--- /dev/null
+++ b/DI-engine/dizoo/overcooked/README.md
@@ -0,0 +1,3 @@
+This is the Overcooked-AI environment made compatible with DI-engine.
+
+The original code is adapted from [Overcooked-AI](https://github.com/HumanCompatibleAI/overcooked_ai), a benchmark environment for fully cooperative human-AI task performance, based on the wildly popular video game [Overcooked](http://www.ghosttowngames.com/overcooked/).
\ No newline at end of file
diff --git a/DI-engine/dizoo/overcooked/__init__.py b/DI-engine/dizoo/overcooked/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/overcooked/config/__init__.py b/DI-engine/dizoo/overcooked/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..50cdadfeb47cb0b5b12c640cfc3a852fc5331f51
--- /dev/null
+++ b/DI-engine/dizoo/overcooked/config/__init__.py
@@ -0,0 +1 @@
+from .overcooked_demo_ppo_config import overcooked_demo_ppo_config
diff --git a/DI-engine/dizoo/overcooked/config/overcooked_ppo_config.py b/DI-engine/dizoo/overcooked/config/overcooked_ppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..2068a222a9a7f602704b192dbc1f0007f023e745
--- /dev/null
+++ b/DI-engine/dizoo/overcooked/config/overcooked_ppo_config.py
@@ -0,0 +1,82 @@
+from easydict import EasyDict
+import torch.nn as nn
+
+overcooked_ppo_config = dict(
+ exp_name="overcooked_ppo_seed0",
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=10,
+ n_evaluator_episode=10,
+ concat_obs=False, # stack 2 agents' obs in channel dim
+ stop_value=80,
+ ),
+ policy=dict(
+ cuda=True,
+ multi_agent=True,
+ action_space='discrete',
+ model=dict(
+ obs_shape=(26, 5, 4),
+ action_shape=6,
+ action_space='discrete',
+ ),
+ learn=dict(
+ epoch_per_collect=4,
+ batch_size=128,
+ learning_rate=0.0005,
+ entropy_weight=0.01,
+ value_norm=True,
+ ),
+ collect=dict(
+ n_sample=1024,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ ),
+)
+overcooked_ppo_config = EasyDict(overcooked_ppo_config)
+main_config = overcooked_ppo_config
+overcooked_ppo_create_config = dict(
+ env=dict(
+ type='overcooked_game',
+ import_names=['dizoo.overcooked.envs.overcooked_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+)
+overcooked_ppo_create_config = EasyDict(overcooked_ppo_create_config)
+create_config = overcooked_ppo_create_config
+
+
+class OEncoder(nn.Module):
+
+ def __init__(self, obs_shape):
+ super(OEncoder, self).__init__()
+ self.act = nn.ReLU()
+ self.main = nn.Sequential(
+ *[
+ nn.Conv2d(obs_shape[0], 64, 3, 1, 1),
+ self.act,
+ nn.Conv2d(64, 64, 3, 1, 1),
+ self.act,
+ nn.Conv2d(64, 64, 3, 1, 1),
+ self.act,
+ nn.AdaptiveAvgPool2d((1, 1)),
+ nn.Flatten(),
+ ]
+ )
+
+ def forward(self, x):
+ x = x.float()
+ B, A = x.shape[:2]
+ x = x.view(-1, *x.shape[2:])
+ x = self.main(x)
+ return x.view(B, A, 64)
+
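+# Shape note for OEncoder.forward: with obs_shape (26, 5, 4) the input is
+# expected as (B, A, 26, 5, 4) (batch, agent, channel, H, W); the conv stack
+# plus adaptive average pooling reduces each agent observation to a 64-dim
+# embedding, so the output has shape (B, A, 64).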
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_onpolicy
+ from ding.model.template import VAC
+ m = main_config.policy.model
+ encoder = OEncoder(obs_shape=m.obs_shape)
+ model = VAC(obs_shape=m.obs_shape, action_shape=m.action_shape, action_space=m.action_space, encoder=encoder)
+ serial_pipeline_onpolicy([main_config, create_config], seed=0, model=model)
diff --git a/DI-engine/dizoo/overcooked/envs/__init__.py b/DI-engine/dizoo/overcooked/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..719920ac917d859b2093d2525954f6b1037f85a6
--- /dev/null
+++ b/DI-engine/dizoo/overcooked/envs/__init__.py
@@ -0,0 +1 @@
+from .overcooked_env import OvercookEnv, OvercookGameEnv
diff --git a/DI-engine/dizoo/overcooked/envs/overcooked_env.py b/DI-engine/dizoo/overcooked/envs/overcooked_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..2de769f0b30ddd3d4b738434d565ab534f98b8f4
--- /dev/null
+++ b/DI-engine/dizoo/overcooked/envs/overcooked_env.py
@@ -0,0 +1,326 @@
+from typing import Any, Union, List
+from collections import namedtuple
+from easydict import EasyDict
+import gym
+import copy
+import numpy as np
+
+from overcooked_ai_py.mdp.actions import Action, Direction
+from overcooked_ai_py.mdp.overcooked_mdp import PlayerState, OvercookedGridworld, OvercookedState, ObjectState, \
+ SoupState, Recipe
+from overcooked_ai_py.mdp.overcooked_env import OvercookedEnv, DEFAULT_ENV_PARAMS
+
+from ding.envs import BaseEnv
+from ding.utils import ENV_REGISTRY, deep_merge_dicts
+
+OvercookEnvTimestep = namedtuple('OvercookEnvTimestep', ['obs', 'reward', 'done', 'info'])
+
+# n, s = Direction.NORTH, Direction.SOUTH
+# e, w = Direction.EAST, Direction.WEST
+# stay, interact = Action.STAY, Action.INTERACT
+# Action.ALL_ACTIONS: [n, s, e, w, stay, interact]
+
+
+@ENV_REGISTRY.register('overcooked')
+class OvercookEnv(BaseEnv):
+ config = EasyDict(
+ dict(
+ env_name="cramped_room",
+ horizon=400,
+ concat_obs=False,
+ action_mask=True,
+ shape_reward=True,
+ )
+ )
+
+ def __init__(self, cfg) -> None:
+ self._cfg = deep_merge_dicts(self.config, cfg)
+ self._env_name = self._cfg.env_name
+ self._horizon = self._cfg.horizon
+ self._concat_obs = self._cfg.concat_obs
+ self._action_mask = self._cfg.action_mask
+ self._shape_reward = self._cfg.shape_reward
+ self.mdp = OvercookedGridworld.from_layout_name(self._env_name)
+ self.base_env = OvercookedEnv.from_mdp(self.mdp, horizon=self._horizon, info_level=0)
+
+ # right now the Overcooked environment encoding only supports the 2-agent game
+ self.agent_num = 2
+ self.action_dim = len(Action.ALL_ACTIONS)
+ self.action_space = gym.spaces.Discrete(len(Action.ALL_ACTIONS))
+ # set up obs shape
+ featurize_fn = lambda mdp, state: mdp.lossless_state_encoding(state)
+ self.featurize_fn = featurize_fn
+ dummy_mdp = self.base_env.mdp
+ dummy_state = dummy_mdp.get_standard_start_state()
+ obs_shape = self.featurize_fn(dummy_mdp, dummy_state)[0].shape # (5, 4, 26)
+ obs_shape = (obs_shape[-1], *obs_shape[:-1]) # permute channel first
+ if self._concat_obs:
+ obs_shape = (obs_shape[0] * 2, *obs_shape[1:])
+ else:
+ obs_shape = (2, ) + obs_shape
+ self.observation_space = gym.spaces.Box(low=0, high=1, shape=obs_shape, dtype=np.int64)
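+ # e.g. for cramped_room the per-player lossless encoding is (5, 4, 26); after
+ # the channel-first permute, agent_state is (2, 26, 5, 4), or (52, 5, 4) when
+ # concat_obs is True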
+ if self._action_mask:
+ self.observation_space = gym.spaces.Dict(
+ {
+ 'agent_state': self.observation_space,
+ 'action_mask': gym.spaces.Box(
+ low=0, high=1, shape=(self.agent_num, self.action_dim), dtype=np.int64
+ )
+ }
+ )
+ self.reward_space = gym.spaces.Box(low=0, high=100, shape=(1, ), dtype=np.float32)
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def close(self) -> None:
+ # Note: the real env instance only has an empty close method, so we simply pass
+ pass
+
+ def random_action(self):
+ return [self.action_space.sample() for _ in range(self.agent_num)]
+
+ def step(self, action):
+ assert all(self.action_space.contains(a) for a in action), "%r (%s) invalid" % (action, type(action))
+ agent_action, other_agent_action = [Action.INDEX_TO_ACTION[a] for a in action]
+
+ if self.agent_idx == 0:
+ joint_action = (agent_action, other_agent_action)
+ else:
+ joint_action = (other_agent_action, agent_action)
+
+ next_state, reward, done, env_info = self.base_env.step(joint_action)
+ reward = np.array([float(reward)])
+ self._eval_episode_return += reward
+ if self._shape_reward:
+ self._eval_episode_return += sum(env_info['shaped_r_by_agent'])
+ reward += sum(env_info['shaped_r_by_agent'])
+
+ ob_p0, ob_p1 = self.featurize_fn(self.mdp, next_state)
+ ob_p0, ob_p1 = self.obs_preprocess(ob_p0), self.obs_preprocess(ob_p1)
+ if self.agent_idx == 0:
+ both_agents_ob = [ob_p0, ob_p1]
+ else:
+ both_agents_ob = [ob_p1, ob_p0]
+ if self._concat_obs:
+ both_agents_ob = np.concatenate(both_agents_ob)
+ else:
+ both_agents_ob = np.stack(both_agents_ob)
+
+ env_info["policy_agent_idx"] = self.agent_idx
+ env_info["eval_episode_return"] = self._eval_episode_return
+ env_info["other_agent_env_idx"] = 1 - self.agent_idx
+
+ action_mask = self.get_action_mask()
+ if self._action_mask:
+ obs = {
+ "agent_state": both_agents_ob,
+ # "overcooked_state": self.base_env.state,
+ "action_mask": action_mask
+ }
+ else:
+ obs = both_agents_ob
+ return OvercookEnvTimestep(obs, reward, done, env_info)
+
+ def obs_preprocess(self, obs):
+ obs = obs.transpose(2, 0, 1)
+ return obs
+
+ def reset(self):
+ self.base_env.reset()
+ self._eval_episode_return = 0
+ self.mdp = self.base_env.mdp
+ # random init agent index
+ self.agent_idx = np.random.choice([0, 1])
+ ob_p0, ob_p1 = self.featurize_fn(self.mdp, self.base_env.state)
+ ob_p0, ob_p1 = self.obs_preprocess(ob_p0), self.obs_preprocess(ob_p1)
+
+ if self.agent_idx == 0:
+ both_agents_ob = [ob_p0, ob_p1]
+ else:
+ both_agents_ob = [ob_p1, ob_p0]
+ if self._concat_obs:
+ both_agents_ob = np.concatenate(both_agents_ob)
+ else:
+ both_agents_ob = np.stack(both_agents_ob)
+
+ action_mask = self.get_action_mask()
+
+ if self._action_mask:
+ obs = {"agent_state": both_agents_ob, "action_mask": action_mask}
+ else:
+ obs = both_agents_ob
+ return obs
+
+ def get_available_actions(self):
+ return self.mdp.get_actions(self.base_env.state)
+
+ def get_action_mask(self):
+ available_actions = self.get_available_actions()
+
+ action_masks = np.zeros((self.agent_num, self.action_dim)).astype(np.int64)
+
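+        # Set entry i to 1 for each agent whose currently available Overcooked actions include action i;
+        # unavailable actions stay 0. For example (hypothetical), an agent that can take every action
+        # except INTERACT would get the row [1, 1, 1, 1, 1, 0] under the 6 default Overcooked actions.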
+ for i in range(self.action_dim):
+ if Action.INDEX_TO_ACTION[i] in available_actions[0]:
+ action_masks[0][i] = 1
+ if Action.INDEX_TO_ACTION[i] in available_actions[1]:
+ action_masks[1][i] = 1
+
+ return action_masks
+
+ def __repr__(self):
+ return "DI-engine Overcooked Env"
+
+
+@ENV_REGISTRY.register('overcooked_game')
+class OvercookGameEnv(BaseEnv):
+ config = EasyDict(
+ dict(
+ env_name="cramped_room",
+ horizon=400,
+ concat_obs=False,
+ action_mask=False,
+ shape_reward=True,
+ )
+ )
+
+ def __init__(self, cfg) -> None:
+ self._cfg = deep_merge_dicts(self.config, cfg)
+ self._env_name = self._cfg.env_name
+ self._horizon = self._cfg.horizon
+ self._concat_obs = self._cfg.concat_obs
+ self._action_mask = self._cfg.action_mask
+ self._shape_reward = self._cfg.shape_reward
+ self.mdp = OvercookedGridworld.from_layout_name(self._env_name)
+ self.base_env = OvercookedEnv.from_mdp(self.mdp, horizon=self._horizon, info_level=0)
+
+        # Right now the Overcooked environment encoding only supports 2-agent games.
+ self.agent_num = 2
+ self.action_dim = len(Action.ALL_ACTIONS)
+ self.action_space = gym.spaces.Discrete(len(Action.ALL_ACTIONS))
+ # set up obs shape
+ featurize_fn = lambda mdp, state: mdp.lossless_state_encoding(state)
+ self.featurize_fn = featurize_fn
+ dummy_mdp = self.base_env.mdp
+ dummy_state = dummy_mdp.get_standard_start_state()
+ obs_shape = self.featurize_fn(dummy_mdp, dummy_state)[0].shape # (5, 4, 26)
+ obs_shape = (obs_shape[-1], *obs_shape[:-1]) # permute channel first
+ if self._concat_obs:
+ obs_shape = (obs_shape[0] * 2, *obs_shape[1:])
+ else:
+ obs_shape = (2, ) + obs_shape
+ self.observation_space = gym.spaces.Box(low=0, high=1, shape=obs_shape, dtype=np.int64)
+ if self._action_mask:
+ self.observation_space = gym.spaces.Dict(
+ {
+ 'agent_state': self.observation_space,
+ 'action_mask': gym.spaces.Box(
+ low=0, high=1, shape=(self.agent_num, self.action_dim), dtype=np.int64
+ )
+ }
+ )
+
+ self.reward_space = gym.spaces.Box(low=0, high=100, shape=(1, ), dtype=np.float32)
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def close(self) -> None:
+        # Note: the underlying env instance only has an empty close method, so we simply pass.
+ pass
+
+ def random_action(self):
+ return [self.action_space.sample() for _ in range(self.agent_num)]
+
+ def step(self, action):
+ assert all(self.action_space.contains(a) for a in action), "%r (%s) invalid" % (action, type(action))
+ agent_action, other_agent_action = [Action.INDEX_TO_ACTION[a] for a in action]
+
+ if self.agent_idx == 0:
+ joint_action = (agent_action, other_agent_action)
+ else:
+ joint_action = (other_agent_action, agent_action)
+
+ next_state, reward, done, env_info = self.base_env.step(joint_action)
+
+ reward = np.array([float(reward)])
+ self._eval_episode_return += reward
+ if self._shape_reward:
+ self._eval_episode_return += sum(env_info['shaped_r_by_agent'])
+ reward += sum(env_info['shaped_r_by_agent'])
+ ob_p0, ob_p1 = self.featurize_fn(self.mdp, next_state)
+ ob_p0, ob_p1 = self.obs_preprocess(ob_p0), self.obs_preprocess(ob_p1)
+ if self.agent_idx == 0:
+ both_agents_ob = [ob_p0, ob_p1]
+ else:
+ both_agents_ob = [ob_p1, ob_p0]
+ if self._concat_obs:
+ both_agents_ob = np.concatenate(both_agents_ob)
+ else:
+ both_agents_ob = np.stack(both_agents_ob)
+
+ env_info["policy_agent_idx"] = self.agent_idx
+ env_info["eval_episode_return"] = self._eval_episode_return
+ env_info["other_agent_env_idx"] = 1 - self.agent_idx
+
+ action_mask = self.get_action_mask()
+ if self._action_mask:
+ obs = {"agent_state": both_agents_ob, "action_mask": action_mask}
+ else:
+ obs = both_agents_ob
+ return OvercookEnvTimestep(obs, reward, done, env_info)
+
+ def obs_preprocess(self, obs):
+ obs = obs.transpose(2, 0, 1)
+ return obs
+
+ def reset(self):
+ self.base_env.reset()
+ self._eval_episode_return = 0
+ self.mdp = self.base_env.mdp
+        # randomly initialize the agent index
+        self.agent_idx = np.random.choice([0, 1])
+        # fix the initial agent index to 0 (overrides the random choice above)
+        self.agent_idx = 0
+ ob_p0, ob_p1 = self.featurize_fn(self.mdp, self.base_env.state)
+ ob_p0, ob_p1 = self.obs_preprocess(ob_p0), self.obs_preprocess(ob_p1)
+
+ if self.agent_idx == 0:
+ both_agents_ob = [ob_p0, ob_p1]
+ else:
+ both_agents_ob = [ob_p1, ob_p0]
+ if self._concat_obs:
+ both_agents_ob = np.concatenate(both_agents_ob)
+ else:
+ both_agents_ob = np.stack(both_agents_ob)
+
+ action_mask = self.get_action_mask()
+
+ if self._action_mask:
+ obs = {"agent_state": both_agents_ob, "action_mask": action_mask}
+ else:
+ obs = both_agents_ob
+ return obs
+
+ def get_available_actions(self):
+ return self.mdp.get_actions(self.base_env.state)
+
+ def get_action_mask(self):
+ available_actions = self.get_available_actions()
+
+ action_masks = np.zeros((self.agent_num, self.action_dim)).astype(np.int64)
+
+ for i in range(self.action_dim):
+ if Action.INDEX_TO_ACTION[i] in available_actions[0]:
+ action_masks[0][i] = 1
+ if Action.INDEX_TO_ACTION[i] in available_actions[1]:
+ action_masks[1][i] = 1
+
+ return action_masks
+
+ def __repr__(self):
+ return "DI-engine Overcooked GameEnv"
diff --git a/DI-engine/dizoo/overcooked/envs/test_overcooked_env.py b/DI-engine/dizoo/overcooked/envs/test_overcooked_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e238a8ced9143768998b00752f9095cacd919bf
--- /dev/null
+++ b/DI-engine/dizoo/overcooked/envs/test_overcooked_env.py
@@ -0,0 +1,44 @@
+from time import time
+from easydict import EasyDict
+import pytest
+import numpy as np
+from dizoo.overcooked.envs import OvercookEnv, OvercookGameEnv
+
+
+@pytest.mark.envtest
+class TestOvercooked:
+
+ @pytest.mark.parametrize("action_mask", [True, False])
+ def test_overcook(self, action_mask):
+ num_agent = 2
+ sum_rew = 0.0
+ env = OvercookEnv(EasyDict({'concat_obs': True, 'action_mask': action_mask}))
+ obs = env.reset()
+ for _ in range(env._horizon):
+ action = env.random_action()
+ timestep = env.step(action)
+ obs = timestep.obs
+ if action_mask:
+ for k, v in obs.items():
+ if k not in ['agent_state', 'action_mask']:
+ assert False
+ assert v.shape == env.observation_space[k].shape
+ else:
+ assert obs.shape == env.observation_space.shape
+ assert timestep.done
+ sum_rew += timestep.info['eval_episode_return'][0]
+ print("sum reward is:", sum_rew)
+
+ @pytest.mark.parametrize("concat_obs", [True, False])
+ def test_overcook_game(self, concat_obs):
+ env = OvercookGameEnv(EasyDict({'concat_obs': concat_obs}))
+ print('observation space: {}'.format(env.observation_space.shape))
+ obs = env.reset()
+ for _ in range(env._horizon):
+ action = env.random_action()
+ timestep = env.step(action)
+ obs = timestep.obs
+ assert obs.shape == env.observation_space.shape
+ assert timestep.done
+ print("agent 0 sum reward is:", timestep.info[0]['eval_episode_return'])
+ print("agent 1 sum reward is:", timestep.info[1]['eval_episode_return'])
diff --git a/DI-engine/dizoo/petting_zoo/__init__.py b/DI-engine/dizoo/petting_zoo/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/petting_zoo/config/__init__.py b/DI-engine/dizoo/petting_zoo/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9838dcaa0065cde6cbf72ae16e21b8d13e182d0c
--- /dev/null
+++ b/DI-engine/dizoo/petting_zoo/config/__init__.py
@@ -0,0 +1,9 @@
+from .ptz_simple_spread_atoc_config import ptz_simple_spread_atoc_config, ptz_simple_spread_atoc_create_config
+from .ptz_simple_spread_collaq_config import ptz_simple_spread_collaq_config, ptz_simple_spread_collaq_create_config
+from .ptz_simple_spread_coma_config import ptz_simple_spread_coma_config, ptz_simple_spread_coma_create_config
+from .ptz_simple_spread_mappo_config import ptz_simple_spread_mappo_config, ptz_simple_spread_mappo_create_config
+from .ptz_simple_spread_qmix_config import ptz_simple_spread_qmix_config, ptz_simple_spread_qmix_create_config
+from .ptz_simple_spread_qtran_config import ptz_simple_spread_qtran_config, ptz_simple_spread_qtran_create_config
+from .ptz_simple_spread_vdn_config import ptz_simple_spread_vdn_config, ptz_simple_spread_vdn_create_config
+from .ptz_simple_spread_wqmix_config import ptz_simple_spread_wqmix_config, ptz_simple_spread_wqmix_create_config
+from .ptz_simple_spread_madqn_config import ptz_simple_spread_madqn_config, ptz_simple_spread_madqn_create_config # noqa
diff --git a/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_atoc_config.py b/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_atoc_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..f9a7c944828189ce9ec83e5e87362b8a888ba468
--- /dev/null
+++ b/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_atoc_config.py
@@ -0,0 +1,75 @@
+from easydict import EasyDict
+
+n_agent = 5
+n_landmark = n_agent  # In simple_spread_v2, n_landmark must equal n_agent
+collector_env_num = 8
+evaluator_env_num = 8
+communication = True
+ptz_simple_spread_atoc_config = dict(
+ exp_name='ptz_simple_spread_atoc_seed0',
+ env=dict(
+ env_family='mpe',
+ env_id='simple_spread_v2',
+ n_agent=n_agent,
+ n_landmark=n_landmark,
+ max_cycles=100,
+ agent_obs_only=True,
+ continuous_actions=True,
+ act_scale=True,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ stop_value=0,
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2,
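+            # The obs_shape above follows the per-agent observation layout documented in
+            # petting_zoo_simple_spread_env.py: self velocity (2) + self position (2) + landmark relative
+            # positions (n_landmark * 2) + other agents' relative positions ((n_agent - 1) * 2)
+            # + communication signals ((n_agent - 1) * 2).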
+ action_shape=5,
+ n_agent=n_agent,
+ communication=communication,
+ thought_size=16,
+ agent_per_group=min(n_agent // 2, 5),
+ ),
+ learn=dict(
+ update_per_collect=5,
+ batch_size=32,
+ learning_rate_actor=0.001,
+ learning_rate_critic=0.001,
+ ignore_done=True,
+ target_theta=0.005,
+ discount_factor=0.9,
+ communication=communication,
+ actor_update_freq=1,
+ noise=True,
+ noise_sigma=0.15,
+ noise_range=dict(
+ min=-0.5,
+ max=0.5,
+ ),
+ ),
+ collect=dict(
+ n_sample=500,
+ noise_sigma=0.4,
+ ),
+ other=dict(replay_buffer=dict(replay_buffer_size=100000, ), )
+ ),
+)
+ptz_simple_spread_atoc_config = EasyDict(ptz_simple_spread_atoc_config)
+main_config = ptz_simple_spread_atoc_config
+ptz_simple_spread_atoc_create_config = dict(
+ env=dict(
+ import_names=['dizoo.petting_zoo.envs.petting_zoo_simple_spread_env'],
+ type='petting_zoo',
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='atoc'),
+)
+ptz_simple_spread_atoc_create_config = EasyDict(ptz_simple_spread_atoc_create_config)
+create_config = ptz_simple_spread_atoc_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c ptz_simple_spread_atoc_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_collaq_config.py b/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_collaq_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e791283e322e9bed3b166f88062f4a42e6ee3dbc
--- /dev/null
+++ b/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_collaq_config.py
@@ -0,0 +1,74 @@
+from easydict import EasyDict
+
+n_agent = 5
+n_landmark = n_agent
+collector_env_num = 8
+evaluator_env_num = 8
+ptz_simple_spread_collaq_config = dict(
+ exp_name='ptz_simple_spread_collaq_seed0',
+ env=dict(
+ env_family='mpe',
+ env_id='simple_spread_v2',
+ n_agent=n_agent,
+ n_landmark=n_landmark,
+ max_cycles=100,
+ agent_obs_only=False,
+ continuous_actions=False,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ stop_value=0,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ agent_num=n_agent,
+ obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2,
+ alone_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2,
+ global_obs_shape=n_agent * 4 + n_landmark * 2 + n_agent * (n_agent - 1) * 2,
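+            # alone_obs_shape drops the other agents' relative positions from obs_shape; global_obs_shape
+            # matches the 'global_state' built in petting_zoo_simple_spread_env.py: every agent's velocity
+            # and position (n_agent * 4) + all landmark positions (n_landmark * 2) + all communication
+            # signals (n_agent * (n_agent - 1) * 2).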
+ action_shape=5,
+ hidden_size_list=[128, 128, 64],
+ attention=True,
+ self_feature_range=[2, 4], # placeholder
+ ally_feature_range=[4, n_agent * 2 + 2], # placeholder
+ attention_size=32,
+ ),
+ agent_num=n_agent,
+ learn=dict(
+ update_per_collect=100,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_theta=0.001,
+ discount_factor=0.99,
+ ),
+ collect=dict(
+ n_sample=600,
+ unroll_len=16,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(eps=dict(
+ type='exp',
+ start=1.0,
+ end=0.05,
+ decay=100000,
+ ), ),
+ ),
+)
+ptz_simple_spread_collaq_config = EasyDict(ptz_simple_spread_collaq_config)
+main_config = ptz_simple_spread_collaq_config
+ptz_simple_spread_collaq_create_config = dict(
+ env=dict(
+ import_names=['dizoo.petting_zoo.envs.petting_zoo_simple_spread_env'],
+ type='petting_zoo',
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='collaq'),
+)
+ptz_simple_spread_collaq_create_config = EasyDict(ptz_simple_spread_collaq_create_config)
+create_config = ptz_simple_spread_collaq_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c ptz_simple_spread_collaq_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_coma_config.py b/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_coma_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..6ab557988e15a07188d27e6891b67163ba6cd930
--- /dev/null
+++ b/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_coma_config.py
@@ -0,0 +1,82 @@
+from easydict import EasyDict
+
+n_agent = 5
+n_landmark = n_agent
+collector_env_num = 8
+evaluator_env_num = 8
+ptz_simple_spread_coma_config = dict(
+ exp_name='ptz_simple_spread_coma_seed0',
+ env=dict(
+ env_family='mpe',
+ env_id='simple_spread_v2',
+ n_agent=n_agent,
+ n_landmark=n_landmark,
+ max_cycles=100,
+ agent_obs_only=False,
+ continuous_actions=False,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ stop_value=0,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ agent_num=n_agent,
+ obs_shape=dict(
+ agent_state=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2,
+ global_state=n_agent * 4 + n_landmark * 2 + n_agent * (n_agent - 1) * 2,
+ ),
+ action_shape=[
+ 5,
+ ],
+ actor_hidden_size_list=[128, 128, 64],
+ ),
+ agent_num=n_agent,
+ learn=dict(
+ update_per_collect=1,
+ batch_size=32,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ discount_factor=0.99,
+ td_lambda=0.8,
+ value_weight=1.0,
+ entropy_weight=0.01,
+ ),
+ collect=dict(
+ n_sample=600,
+ unroll_len=16,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.5,
+ end=0.01,
+ decay=100000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=64,
+ max_use=10,
+ ),
+ ),
+ ),
+)
+ptz_simple_spread_coma_config = EasyDict(ptz_simple_spread_coma_config)
+main_config = ptz_simple_spread_coma_config
+ptz_simple_spread_coma_create_config = dict(
+ env=dict(
+ import_names=['dizoo.petting_zoo.envs.petting_zoo_simple_spread_env'],
+ type='petting_zoo',
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='coma'),
+)
+ptz_simple_spread_coma_create_config = EasyDict(ptz_simple_spread_coma_create_config)
+create_config = ptz_simple_spread_coma_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c ptz_simple_spread_coma_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_happo_config.py b/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_happo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d1ff088326b1367a00184a77fef2a1f7a0c33283
--- /dev/null
+++ b/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_happo_config.py
@@ -0,0 +1,88 @@
+from easydict import EasyDict
+
+n_agent = 3
+n_landmark = n_agent
+collector_env_num = 8
+evaluator_env_num = 8
+main_config = dict(
+ exp_name='ptz_simple_spread_happo_seed0',
+ env=dict(
+ env_family='mpe',
+ env_id='simple_spread_v2',
+ n_agent=n_agent,
+ n_landmark=n_landmark,
+ max_cycles=25,
+ agent_obs_only=False,
+ agent_specific_global_state=True,
+ continuous_actions=False,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ stop_value=0,
+ ),
+ policy=dict(
+ cuda=True,
+ multi_agent=True,
+ agent_num=n_agent,
+ action_space='discrete',
+ model=dict(
+ action_space='discrete',
+ agent_num=n_agent,
+ agent_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2,
+ global_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) +
+ n_landmark * 2 + n_agent * (n_agent - 1) * 2,
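+            # With agent_specific_global_state=True, the env concatenates each agent's own observation with
+            # the shared global state, so global_obs_shape = agent_obs_shape + n_agent * 4 + n_landmark * 2
+            # + n_agent * (n_agent - 1) * 2; for n_agent = n_landmark = 3 this is 18 + 30 = 48.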
+ action_shape=5,
+ use_lstm=False,
+ ),
+ learn=dict(
+ multi_gpu=False,
+ epoch_per_collect=5,
+ batch_size=3200,
+ learning_rate=5e-4,
+ critic_learning_rate=5e-4,
+ # ==============================================================
+            # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) The loss weight of value network, policy network weight is set to 1
+ value_weight=0.5,
+ # (float) The loss weight of entropy regularization, policy network weight is set to 1
+ entropy_weight=0.01,
+ # (float) PPO clip ratio, defaults to 0.2
+ clip_ratio=0.2,
+ # (bool) Whether to use advantage norm in a whole training batch
+ adv_norm=False,
+ value_norm=True,
+ ppo_param_init=True,
+ grad_clip_type='clip_norm',
+ grad_clip_value=10,
+ ignore_done=False,
+ ),
+ collect=dict(
+ n_sample=3200,
+ unroll_len=1,
+ env_num=collector_env_num,
+ ),
+ eval=dict(
+ env_num=evaluator_env_num,
+ evaluator=dict(eval_freq=50, ),
+ ),
+ other=dict(),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ import_names=['dizoo.petting_zoo.envs.petting_zoo_simple_spread_env'],
+ type='petting_zoo',
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='happo'),
+)
+create_config = EasyDict(create_config)
+ptz_simple_spread_happo_config = main_config
+ptz_simple_spread_happo_create_config = create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial_onpolicy -c ptz_simple_spread_happo_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_madqn_config.py b/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_madqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..8ddb636abf6ba7f0e1efadb643dc0945250c7583
--- /dev/null
+++ b/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_madqn_config.py
@@ -0,0 +1,83 @@
+from easydict import EasyDict
+
+n_agent = 3
+n_landmark = n_agent
+collector_env_num = 8
+evaluator_env_num = 8
+main_config = dict(
+ exp_name='ptz_simple_spread_madqn_seed0',
+ env=dict(
+ env_family='mpe',
+ env_id='simple_spread_v2',
+ n_agent=n_agent,
+ n_landmark=n_landmark,
+ max_cycles=25,
+ agent_obs_only=False,
+ agent_specific_global_state=True,
+ continuous_actions=False,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ stop_value=0,
+ ),
+ policy=dict(
+ cuda=True,
+ nstep=3,
+ model=dict(
+ obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2,
+ global_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) +
+ n_landmark * 2 + n_agent * (n_agent - 1) * 2,
+ agent_num=n_agent,
+ action_shape=5,
+ global_cooperation=True,
+ hidden_size_list=[256, 256],
+ ),
+ learn=dict(
+ update_per_collect=20,
+ batch_size=32,
+ learning_rate=0.0005,
+ clip_value=5,
+ target_update_theta=0.008,
+ discount_factor=0.95,
+ ),
+ collect=dict(
+ collector=dict(get_train_sample=True, ),
+ n_episode=32,
+ unroll_len=10,
+ env_num=collector_env_num,
+ ),
+ command=dict(),
+ eval=dict(
+ env_num=evaluator_env_num,
+ evaluator=dict(eval_freq=1000, ),
+ ),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=15000, ),
+ ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ import_names=['dizoo.petting_zoo.envs.petting_zoo_simple_spread_env'],
+ type='petting_zoo',
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='madqn'),
+ collector=dict(type='episode'),
+)
+create_config = EasyDict(create_config)
+ptz_simple_spread_madqn_config = main_config
+ptz_simple_spread_madqn_create_config = create_config
+
+if __name__ == '__main__':
+    # or you can enter `ding -m serial -c ptz_simple_spread_madqn_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_mappo_config.py b/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_mappo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5eb1095a5a925508c511c119945fab185d0f06d7
--- /dev/null
+++ b/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_mappo_config.py
@@ -0,0 +1,85 @@
+from easydict import EasyDict
+
+n_agent = 3
+n_landmark = n_agent
+collector_env_num = 8
+evaluator_env_num = 8
+main_config = dict(
+ exp_name='ptz_simple_spread_mappo_seed0',
+ env=dict(
+ env_family='mpe',
+ env_id='simple_spread_v2',
+ n_agent=n_agent,
+ n_landmark=n_landmark,
+ max_cycles=25,
+ agent_obs_only=False,
+ agent_specific_global_state=True,
+ continuous_actions=False,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ stop_value=0,
+ ),
+ policy=dict(
+ cuda=True,
+ multi_agent=True,
+ action_space='discrete',
+ model=dict(
+ action_space='discrete',
+ agent_num=n_agent,
+ agent_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2,
+ global_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) +
+ n_landmark * 2 + n_agent * (n_agent - 1) * 2,
+ action_shape=5,
+ ),
+ learn=dict(
+ multi_gpu=False,
+ epoch_per_collect=5,
+ batch_size=3200,
+ learning_rate=5e-4,
+ # ==============================================================
+            # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) The loss weight of value network, policy network weight is set to 1
+ value_weight=0.5,
+ # (float) The loss weight of entropy regularization, policy network weight is set to 1
+ entropy_weight=0.01,
+ # (float) PPO clip ratio, defaults to 0.2
+ clip_ratio=0.2,
+ # (bool) Whether to use advantage norm in a whole training batch
+ adv_norm=False,
+ value_norm=True,
+ ppo_param_init=True,
+ grad_clip_type='clip_norm',
+ grad_clip_value=10,
+ ignore_done=False,
+ ),
+ collect=dict(
+ n_sample=3200,
+ unroll_len=1,
+ env_num=collector_env_num,
+ ),
+ eval=dict(
+ env_num=evaluator_env_num,
+ evaluator=dict(eval_freq=50, ),
+ ),
+ other=dict(),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ import_names=['dizoo.petting_zoo.envs.petting_zoo_simple_spread_env'],
+ type='petting_zoo',
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+)
+create_config = EasyDict(create_config)
+ptz_simple_spread_mappo_config = main_config
+ptz_simple_spread_mappo_create_config = create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial_onpolicy -c ptz_simple_spread_mappo_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_masac_config.py b/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_masac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3782138aa76d56d99465fecd115fdef71741edc
--- /dev/null
+++ b/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_masac_config.py
@@ -0,0 +1,84 @@
+from easydict import EasyDict
+
+n_agent = 3
+n_landmark = n_agent
+collector_env_num = 8
+evaluator_env_num = 8
+main_config = dict(
+ exp_name='ptz_simple_spread_masac_seed0',
+ env=dict(
+ env_family='mpe',
+ env_id='simple_spread_v2',
+ n_agent=n_agent,
+ n_landmark=n_landmark,
+ max_cycles=25,
+ agent_obs_only=False,
+ agent_specific_global_state=True,
+ continuous_actions=False,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ stop_value=0,
+ ),
+ policy=dict(
+ cuda=True,
+ on_policy=False,
+ multi_agent=True,
+ random_collect_size=5000,
+ model=dict(
+ agent_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2,
+ global_obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) +
+ n_landmark * 2 + n_agent * (n_agent - 1) * 2,
+ action_shape=5,
+ twin_critic=True,
+ ),
+ learn=dict(
+ update_per_collect=50,
+ batch_size=320,
+ # learning_rates
+ learning_rate_q=5e-4,
+ learning_rate_policy=5e-4,
+ learning_rate_alpha=5e-5,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ auto_alpha=True,
+ target_entropy=-2,
+ ),
+ collect=dict(
+ n_sample=1600,
+ env_num=collector_env_num,
+ ),
+ eval=dict(
+ env_num=evaluator_env_num,
+ evaluator=dict(eval_freq=50, ),
+ ),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=100000,
+ ),
+ replay_buffer=dict(replay_buffer_size=int(1e6), )
+ ),
+ ),
+)
+
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ import_names=['dizoo.petting_zoo.envs.petting_zoo_simple_spread_env'],
+ type='petting_zoo',
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='sac_discrete'),
+)
+create_config = EasyDict(create_config)
+ptz_simple_spread_masac_config = main_config
+ptz_simple_spread_masac_create_config = create_config
+
+if __name__ == '__main__':
+    # or you can enter `ding -m serial -c ptz_simple_spread_masac_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_qmix_config.py b/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_qmix_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..fcb52277cf3fcb3d6e2a478b04be449620c799c2
--- /dev/null
+++ b/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_qmix_config.py
@@ -0,0 +1,71 @@
+from easydict import EasyDict
+
+n_agent = 3
+n_landmark = n_agent
+collector_env_num = 8
+evaluator_env_num = 8
+main_config = dict(
+ exp_name='ptz_simple_spread_qmix_seed0',
+ env=dict(
+ env_family='mpe',
+ env_id='simple_spread_v2',
+ n_agent=n_agent,
+ n_landmark=n_landmark,
+ max_cycles=25,
+ agent_obs_only=False,
+ continuous_actions=False,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ stop_value=0,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ agent_num=n_agent,
+ obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2,
+ global_obs_shape=n_agent * 4 + n_landmark * 2 + n_agent * (n_agent - 1) * 2,
+ action_shape=5,
+ hidden_size_list=[128, 128, 64],
+ mixer=True,
+ ),
+ learn=dict(
+ update_per_collect=100,
+ batch_size=32,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ discount_factor=0.99,
+ double_q=True,
+ ),
+ collect=dict(
+ n_sample=600,
+ unroll_len=16,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(eps=dict(
+ type='exp',
+ start=1.0,
+ end=0.05,
+ decay=100000,
+ ), ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ import_names=['dizoo.petting_zoo.envs.petting_zoo_simple_spread_env'],
+ type='petting_zoo',
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='qmix'),
+)
+create_config = EasyDict(create_config)
+
+ptz_simple_spread_qmix_config = main_config
+ptz_simple_spread_qmix_create_config = create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c ptz_simple_spread_qmix_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_qtran_config.py b/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_qtran_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..afe6e0ac2ef2d1f762fec27aced9ca3b93e49971
--- /dev/null
+++ b/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_qtran_config.py
@@ -0,0 +1,84 @@
+from easydict import EasyDict
+
+n_agent = 5
+n_landmark = n_agent
+collector_env_num = 8
+evaluator_env_num = 8
+main_config = dict(
+ exp_name='ptz_simple_spread_qtran_seed0',
+ env=dict(
+ env_family='mpe',
+ env_id='simple_spread_v2',
+ n_agent=n_agent,
+ n_landmark=n_landmark,
+ max_cycles=100,
+ agent_obs_only=False,
+ continuous_actions=False,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ stop_value=0,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ agent_num=n_agent,
+ obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2,
+ global_obs_shape=n_agent * 4 + n_landmark * 2 + n_agent * (n_agent - 1) * 2,
+ action_shape=5,
+ hidden_size_list=[128],
+ embedding_size=64,
+ lstm_type='gru',
+ dueling=False,
+ ),
+ learn=dict(
+ update_per_collect=100,
+ batch_size=32,
+ learning_rate=0.0005,
+ double_q=True,
+ target_update_theta=0.001,
+ discount_factor=0.99,
+ td_weight=1,
+ opt_weight=0.1,
+ nopt_min_weight=0.0001,
+ ),
+ collect=dict(
+ n_sample=600,
+ unroll_len=16,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.0,
+ end=0.05,
+ decay=100000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=15000,
+ # (int) The maximum reuse times of each data
+ max_reuse=1e+9,
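+                # (int) The maximum staleness (gap between a sample's collect iteration and the current
+                # train iteration) tolerated before it is discarded; both limits are effectively disabled
+                # by these very large values.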
+ max_staleness=1e+9,
+ ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ import_names=['dizoo.petting_zoo.envs.petting_zoo_simple_spread_env'],
+ type='petting_zoo',
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='qtran'),
+)
+create_config = EasyDict(create_config)
+
+ptz_simple_spread_qtran_config = main_config
+ptz_simple_spread_qtran_create_config = create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c ptz_simple_spread_qtran_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_vdn_config.py b/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_vdn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4aeae888802d4e3f514c882904e592219dab1179
--- /dev/null
+++ b/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_vdn_config.py
@@ -0,0 +1,70 @@
+from easydict import EasyDict
+
+n_agent = 5
+n_landmark = n_agent
+collector_env_num = 8
+evaluator_env_num = 8
+ptz_simple_spread_vdn_config = dict(
+ exp_name='ptz_simple_spread_vdn_seed0',
+ env=dict(
+ env_family='mpe',
+ env_id='simple_spread_v2',
+ n_agent=n_agent,
+ n_landmark=n_landmark,
+ max_cycles=100,
+ agent_obs_only=False,
+ continuous_actions=False,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ stop_value=0,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ agent_num=n_agent,
+ obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2,
+ global_obs_shape=n_agent * 4 + n_landmark * 2 + n_agent * (n_agent - 1) * 2,
+ action_shape=5,
+ hidden_size_list=[128, 128, 64],
+ mixer=False,
+ ),
+ agent_num=n_agent,
+ learn=dict(
+ update_per_collect=100,
+ batch_size=32,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ discount_factor=0.99,
+ ),
+ collect=dict(
+ n_sample=600,
+ unroll_len=16,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(eps=dict(
+ type='exp',
+ start=1.0,
+ end=0.05,
+ decay=100000,
+ ), ),
+ ),
+)
+ptz_simple_spread_vdn_config = EasyDict(ptz_simple_spread_vdn_config)
+main_config = ptz_simple_spread_vdn_config
+ptz_simple_spread_vdn_create_config = dict(
+ env=dict(
+ import_names=['dizoo.petting_zoo.envs.petting_zoo_simple_spread_env'],
+ type='petting_zoo',
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='qmix'),
+)
+ptz_simple_spread_vdn_create_config = EasyDict(ptz_simple_spread_vdn_create_config)
+create_config = ptz_simple_spread_vdn_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c ptz_simple_spread_vdn_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_wqmix_config.py b/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_wqmix_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b32160929e781be56a0f39526aba672395f30c3
--- /dev/null
+++ b/DI-engine/dizoo/petting_zoo/config/ptz_simple_spread_wqmix_config.py
@@ -0,0 +1,75 @@
+from easydict import EasyDict
+
+n_agent = 5
+n_landmark = n_agent
+collector_env_num = 8
+evaluator_env_num = 8
+ptz_simple_spread_wqmix_config = dict(
+ exp_name='ptz_simple_spread_wqmix_seed0',
+ env=dict(
+ env_family='mpe',
+ env_id='simple_spread_v2',
+ n_agent=n_agent,
+ n_landmark=n_landmark,
+ max_cycles=100,
+ agent_obs_only=False,
+ continuous_actions=False,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=evaluator_env_num,
+ stop_value=0,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ agent_num=n_agent,
+ obs_shape=2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2,
+ global_obs_shape=n_agent * 4 + n_landmark * 2 + n_agent * (n_agent - 1) * 2,
+ action_shape=5,
+ hidden_size_list=[128, 128, 64],
+ ),
+ agent_num=n_agent,
+ learn=dict(
+ update_per_collect=100,
+ batch_size=32,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ discount_factor=0.99,
+            # for OW (Optimistically-Weighted) WQMIX
+            wqmix_ow=True,
+            alpha=0.5,
+            # for CW (Centrally-Weighted) WQMIX, use instead:
+            # wqmix_ow=False,
+            # alpha=0.75,
+ ),
+ collect=dict(
+ n_sample=600,
+ unroll_len=16,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, ),
+ other=dict(eps=dict(
+ type='exp',
+ start=1.0,
+ end=0.05,
+ decay=100000,
+ ), ),
+ ),
+)
+ptz_simple_spread_wqmix_config = EasyDict(ptz_simple_spread_wqmix_config)
+main_config = ptz_simple_spread_wqmix_config
+ptz_simple_spread_wqmix_create_config = dict(
+ env=dict(
+ import_names=['dizoo.petting_zoo.envs.petting_zoo_simple_spread_env'],
+ type='petting_zoo',
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='wqmix'),
+)
+ptz_simple_spread_wqmix_create_config = EasyDict(ptz_simple_spread_wqmix_create_config)
+create_config = ptz_simple_spread_wqmix_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c ptz_simple_spread_wqmix_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/petting_zoo/entry/ptz_simple_spread_eval.py b/DI-engine/dizoo/petting_zoo/entry/ptz_simple_spread_eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..e7f298ad39bc7c4856c0df3334b8cd5b655689c5
--- /dev/null
+++ b/DI-engine/dizoo/petting_zoo/entry/ptz_simple_spread_eval.py
@@ -0,0 +1,12 @@
+from dizoo.petting_zoo.config.ptz_simple_spread_mappo_config import main_config, create_config
+from ding.entry import eval
+
+
+def main():
+ ckpt_path = './ckpt_best.pth.tar'
+ replay_path = './replay_videos'
+ eval((main_config, create_config), seed=0, load_path=ckpt_path, replay_path=replay_path)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/dizoo/petting_zoo/envs/__init__.py b/DI-engine/dizoo/petting_zoo/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py b/DI-engine/dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..bde84685f00f98c68d011613a227166f29ab3609
--- /dev/null
+++ b/DI-engine/dizoo/petting_zoo/envs/petting_zoo_simple_spread_env.py
@@ -0,0 +1,419 @@
+from typing import Any, List, Union, Optional, Dict
+import gymnasium as gym
+import numpy as np
+import pettingzoo
+from functools import reduce
+
+from ding.envs import BaseEnv, BaseEnvTimestep, FrameStackWrapper
+from ding.torch_utils import to_ndarray, to_list
+from ding.envs.common.common_function import affine_transform
+from ding.utils import ENV_REGISTRY, import_module
+from pettingzoo.utils.conversions import parallel_wrapper_fn
+from pettingzoo.mpe._mpe_utils.simple_env import SimpleEnv, make_env
+from pettingzoo.mpe.simple_spread.simple_spread import Scenario
+
+
+class PTZRecordVideo(gym.wrappers.RecordVideo):
+ def step(self, action):
+ """Steps through the environment using action, recording observations if :attr:`self.recording`."""
+ # gymnasium==0.27.1
+ (
+ observations,
+ rewards,
+ terminateds,
+ truncateds,
+ infos,
+ ) = self.env.step(action)
+
+ # Because pettingzoo returns a dict of terminated and truncated, we need to check if any of the values are True
+ if not (self.terminated is True or self.truncated is True): # the first location for modifications
+ # increment steps and episodes
+ self.step_id += 1
+ if not self.is_vector_env:
+ if terminateds or truncateds:
+ self.episode_id += 1
+ self.terminated = terminateds
+ self.truncated = truncateds
+ elif terminateds[0] or truncateds[0]:
+ self.episode_id += 1
+ self.terminated = terminateds[0]
+ self.truncated = truncateds[0]
+
+ if self.recording:
+ assert self.video_recorder is not None
+ self.video_recorder.capture_frame()
+ self.recorded_frames += 1
+ if self.video_length > 0:
+ if self.recorded_frames > self.video_length:
+ self.close_video_recorder()
+ else:
+ if not self.is_vector_env:
+ if terminateds is True or truncateds is True: # the second location for modifications
+ self.close_video_recorder()
+ elif terminateds[0] or truncateds[0]:
+ self.close_video_recorder()
+
+ elif self._video_enabled():
+ self.start_video_recorder()
+
+ return observations, rewards, terminateds, truncateds, infos
+
+
+@ENV_REGISTRY.register('petting_zoo')
+class PettingZooEnv(BaseEnv):
+ # Now only supports simple_spread_v2.
+ # All agents' observations should have the same shape.
+
+ def __init__(self, cfg: dict) -> None:
+ self._cfg = cfg
+ self._init_flag = False
+ self._replay_path = None
+ self._env_family = self._cfg.env_family
+ self._env_id = self._cfg.env_id
+ self._num_agents = self._cfg.n_agent
+ self._num_landmarks = self._cfg.n_landmark
+ self._continuous_actions = self._cfg.get('continuous_actions', False)
+ self._max_cycles = self._cfg.get('max_cycles', 25)
+ self._act_scale = self._cfg.get('act_scale', False)
+ self._agent_specific_global_state = self._cfg.get('agent_specific_global_state', False)
+ if self._act_scale:
+ assert self._continuous_actions, 'Only continuous action space env needs act_scale'
+
+ def reset(self) -> np.ndarray:
+ if not self._init_flag:
+ # In order to align with the simple spread in Multiagent Particle Env (MPE),
+ # instead of adopting the pettingzoo interface directly,
+ # we have redefined the way rewards are calculated
+
+ # import_module(['pettingzoo.{}.{}'.format(self._env_family, self._env_id)])
+ # self._env = pettingzoo.__dict__[self._env_family].__dict__[self._env_id].parallel_env(
+ # N=self._cfg.n_agent, continuous_actions=self._continuous_actions, max_cycles=self._max_cycles
+ # )
+
+ # init parallel_env wrapper
+ _env = make_env(simple_spread_raw_env)
+ parallel_env = parallel_wrapper_fn(_env)
+ # init env
+ self._env = parallel_env(
+ N=self._cfg.n_agent, continuous_actions=self._continuous_actions, max_cycles=self._max_cycles
+ )
+ self._env.reset()
+ self._agents = self._env.agents
+
+ self._action_space = gym.spaces.Dict({agent: self._env.action_space(agent) for agent in self._agents})
+            single_agent_action_space = self._env.action_space(self._agents[0])
+            if isinstance(single_agent_action_space, gym.spaces.Box):
+                self._action_dim = single_agent_action_space.shape
+            elif isinstance(single_agent_action_space, gym.spaces.Discrete):
+                self._action_dim = (single_agent_action_space.n, )
+            else:
+                raise Exception('Only support `Box` or `Discrete` action space for a single agent.')
+
+ # only for env 'simple_spread_v2', n_agent = 5
+            # now only for the case that each agent in the team has the same obs structure and corresponding shape.
+ if not self._cfg.agent_obs_only:
+ self._observation_space = gym.spaces.Dict(
+ {
+ 'agent_state': gym.spaces.Box(
+ low=float("-inf"),
+ high=float("inf"),
+ shape=(self._num_agents,
+ self._env.observation_space('agent_0').shape[0]), # (self._num_agents, 30)
+ dtype=np.float32
+ ),
+ 'global_state': gym.spaces.Box(
+ low=float("-inf"),
+ high=float("inf"),
+ shape=(
+ 4 * self._num_agents + 2 * self._num_landmarks + 2 * self._num_agents *
+ (self._num_agents - 1),
+ ),
+ dtype=np.float32
+ ),
+ 'agent_alone_state': gym.spaces.Box(
+ low=float("-inf"),
+ high=float("inf"),
+ shape=(self._num_agents, 4 + 2 * self._num_landmarks + 2 * (self._num_agents - 1)),
+ dtype=np.float32
+ ),
+ 'agent_alone_padding_state': gym.spaces.Box(
+ low=float("-inf"),
+ high=float("inf"),
+ shape=(self._num_agents,
+ self._env.observation_space('agent_0').shape[0]), # (self._num_agents, 30)
+ dtype=np.float32
+ ),
+ 'action_mask': gym.spaces.Box(
+ low=float("-inf"),
+ high=float("inf"),
+ shape=(self._num_agents, self._action_dim[0]), # (self._num_agents, 5)
+ dtype=np.float32
+ )
+ }
+ )
+            # Whether to use agent_specific_global_state.
+            # It is usually used in actor-critic multi-agent algorithms, e.g., MAPPO and MASAC.
+ if self._agent_specific_global_state:
+                agent_specific_global_state = gym.spaces.Box(
+ low=float("-inf"),
+ high=float("inf"),
+ shape=(
+ self._num_agents, self._env.observation_space('agent_0').shape[0] + 4 * self._num_agents +
+ 2 * self._num_landmarks + 2 * self._num_agents * (self._num_agents - 1)
+ ),
+ dtype=np.float32
+ )
+                self._observation_space['global_state'] = agent_specific_global_state
+ else:
+ # for case when env.agent_obs_only=True
+ self._observation_space = gym.spaces.Box(
+ low=float("-inf"),
+ high=float("inf"),
+ shape=(self._num_agents, self._env.observation_space('agent_0').shape[0]),
+ dtype=np.float32
+ )
+
+ self._reward_space = gym.spaces.Dict(
+ {
+ agent: gym.spaces.Box(low=float("-inf"), high=float("inf"), shape=(1, ), dtype=np.float32)
+ for agent in self._agents
+ }
+ )
+ if self._replay_path is not None:
+ self._env.render_mode = 'rgb_array'
+                self._env = PTZRecordVideo(
+                    self._env, self._replay_path, name_prefix=f'rl-video-{id(self)}', disable_logger=True
+                )
+ self._init_flag = True
+ if hasattr(self, '_seed'):
+ obs = self._env.reset(seed=self._seed)
+ else:
+ obs = self._env.reset()
+ # self._eval_episode_return = {agent: 0. for agent in self._agents}
+ self._eval_episode_return = 0.
+ self._step_count = 0
+ obs_n = self._process_obs(obs)
+ return obs_n
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def render(self) -> None:
+ self._env.render()
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def step(self, action: np.ndarray) -> BaseEnvTimestep:
+ self._step_count += 1
+ assert isinstance(action, np.ndarray), type(action)
+ action = self._process_action(action)
+ if self._act_scale:
+ for agent in self._agents:
+ # print(action[agent])
+ # print(self.action_space[agent])
+ # print(self.action_space[agent].low, self.action_space[agent].high)
+ action[agent] = affine_transform(
+ action[agent], min_val=self.action_space[agent].low, max_val=self.action_space[agent].high
+ )
+
+ obs, rew, done, trunc, info = self._env.step(action)
+ obs_n = self._process_obs(obs)
+ rew_n = np.array([sum([rew[agent] for agent in self._agents])])
+ rew_n = rew_n.astype(np.float32)
+ # collide_sum = 0
+ # for i in range(self._num_agents):
+ # collide_sum += info['n'][i][1]
+ # collide_penalty = self._cfg.get('collide_penal', self._num_agent)
+ # rew_n += collide_sum * (1.0 - collide_penalty)
+ # rew_n = rew_n / (self._cfg.get('max_cycles', 25) * self._num_agent)
+ self._eval_episode_return += rew_n.item()
+
+ # occupied_landmarks = info['n'][0][3]
+ # if self._step_count >= self._max_step or occupied_landmarks >= self._n_agent \
+ # or occupied_landmarks >= self._num_landmarks:
+ # done_n = True
+ # else:
+ # done_n = False
+ done_n = reduce(lambda x, y: x and y, done.values()) or self._step_count >= self._max_cycles
+
+ # for agent in self._agents:
+ # self._eval_episode_return[agent] += rew[agent]
+ if done_n: # or reduce(lambda x, y: x and y, done.values())
+ info['eval_episode_return'] = self._eval_episode_return
+ # for agent in rew:
+ # rew[agent] = to_ndarray([rew[agent]])
+ return BaseEnvTimestep(obs_n, rew_n, done_n, info)
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ if replay_path is None:
+ replay_path = './video'
+ self._replay_path = replay_path
+
+ def _process_obs(self, obs: 'torch.Tensor') -> np.ndarray: # noqa
+ obs = np.array([obs[agent] for agent in self._agents]).astype(np.float32)
+ if self._cfg.get('agent_obs_only', False):
+ return obs
+ ret = {}
+ # Raw agent observation structure is --
+ # [self_vel, self_pos, landmark_rel_positions, other_agent_rel_positions, communication]
+        # where `communication` consists of signals from other agents (two per agent in `simple_spread_v2`)
+
+ # agent_state: Shape (n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2).
+ # Stacked observation. Contains
+ # - agent itself's state(velocity + position)
+ # - position of items that the agent can observe(e.g. other agents, landmarks)
+ # - communication
+ ret['agent_state'] = obs
+ # global_state: Shape (n_agent * (2 + 2) + n_landmark * 2 + n_agent * (n_agent - 1) * 2, ).
+ # 1-dim vector. Contains
+ # - all agents' state(velocity + position) +
+ # - all landmarks' position +
+ # - all agents' communication
+ ret['global_state'] = np.concatenate(
+ [
+ obs[0, 2:-(self._num_agents - 1) * 2], # all agents' position + all landmarks' position
+ obs[:, 0:2].flatten(), # all agents' velocity
+ obs[:, -(self._num_agents - 1) * 2:].flatten() # all agents' communication
+ ]
+ )
+        # agent_specific_global_state: Shape (n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2
+        #                                     + n_agent * (2 + 2) + n_landmark * 2 + n_agent * (n_agent - 1) * 2).
+ # 2-dim vector. contains
+ # - agent_state info
+ # - global_state info
+ if self._agent_specific_global_state:
+ ret['global_state'] = np.concatenate(
+ [ret['agent_state'],
+ np.expand_dims(ret['global_state'], axis=0).repeat(self._num_agents, axis=0)],
+ axis=1
+ )
+ # agent_alone_state: Shape (n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2).
+ # Stacked observation. Exclude other agents' positions from agent_state. Contains
+ # - agent itself's state(velocity + position) +
+ # - landmarks' positions (do not include other agents' positions)
+ # - communication
+ ret['agent_alone_state'] = np.concatenate(
+ [
+ obs[:, 0:(4 + self._num_agents * 2)], # agent itself's state + landmarks' position
+ obs[:, -(self._num_agents - 1) * 2:], # communication
+ ],
+ 1
+ )
+ # agent_alone_padding_state: Shape (n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2).
+ # Contains the same information as agent_alone_state;
+ # But 0-padding other agents' positions.
+ ret['agent_alone_padding_state'] = np.concatenate(
+ [
+ obs[:, 0:(4 + self._num_agents * 2)], # agent itself's state + landmarks' position
+ np.zeros((self._num_agents,
+ (self._num_agents - 1) * 2), np.float32), # Other agents' position(0-padding)
+ obs[:, -(self._num_agents - 1) * 2:] # communication
+ ],
+ 1
+ )
+        # action_mask: every action is usable, so the mask (shape (n_agent, 5) for both the discrete and
+        # continuous variants) is all ones.
+ ret['action_mask'] = np.ones((self._num_agents, *self._action_dim)).astype(np.float32)
+ return ret
+
+ def _process_action(self, action: 'torch.Tensor') -> Dict[str, np.ndarray]: # noqa
+ dict_action = {}
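+        # For example (hypothetical values), a flat action array np.array([2, 4, 0]) for three agents
+        # becomes {'agent_0': 2, 'agent_1': 4, 'agent_2': 0}; (1, )-shaped entries are squeezed to
+        # 0-dim arrays below.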
+ for i, agent in enumerate(self._agents):
+ agent_action = action[i]
+ if agent_action.shape == (1, ):
+ agent_action = agent_action.squeeze() # 0-dim array
+ dict_action[agent] = agent_action
+ return dict_action
+
+ def random_action(self) -> np.ndarray:
+ random_action = self.action_space.sample()
+ for k in random_action:
+ if isinstance(random_action[k], np.ndarray):
+ pass
+ elif isinstance(random_action[k], int):
+ random_action[k] = to_ndarray([random_action[k]], dtype=np.int64)
+ return random_action
+
+ def __repr__(self) -> str:
+ return "DI-engine PettingZoo Env"
+
+ @property
+ def agents(self) -> List[str]:
+ return self._agents
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+
+class simple_spread_raw_env(SimpleEnv):
+
+ def __init__(self, N=3, local_ratio=0.5, max_cycles=25, continuous_actions=False):
+ assert 0. <= local_ratio <= 1., "local_ratio is a proportion. Must be between 0 and 1."
+ scenario = Scenario()
+ world = scenario.make_world(N)
+ super().__init__(scenario, world, max_cycles, continuous_actions=continuous_actions, local_ratio=local_ratio)
+ self.render_mode = 'rgb_array'
+ self.metadata['name'] = "simple_spread_v2"
+
+ def _execute_world_step(self):
+ # set action for each agent
+ for i, agent in enumerate(self.world.agents):
+ action = self.current_actions[i]
+ scenario_action = []
+ if agent.movable:
+ mdim = self.world.dim_p * 2 + 1
+ if self.continuous_actions:
+ scenario_action.append(action[0:mdim])
+ action = action[mdim:]
+ else:
+ scenario_action.append(action % mdim)
+ action //= mdim
+ if not agent.silent:
+ scenario_action.append(action)
+ self._set_action(scenario_action, agent, self.action_spaces[agent.name])
+
+ self.world.step()
+
+ global_reward = 0.
+ if self.local_ratio is not None:
+ global_reward = float(self.scenario.global_reward(self.world))
+
+ for agent in self.world.agents:
+ agent_reward = float(self.scenario.reward(agent, self.world))
+ if self.local_ratio is not None:
+                # we changed the reward calculation to stay consistent with the original MPE; pettingzoo computes:
+ # reward = global_reward * (1 - self.local_ratio) + agent_reward * self.local_ratio
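+                # e.g. with local_ratio=0.5, global_reward=-2.0 and agent_reward=-1.0 (hypothetical values),
+                # pettingzoo's weighted formula would give -1.5, while the plain sum used below gives -3.0.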
+ reward = global_reward + agent_reward
+ else:
+ reward = agent_reward
+
+ self.rewards[agent.name] = reward
+
+ def render(self):
+ if self.render_mode is None:
+ gym.logger.warn(
+ "You are calling render method without specifying any render mode."
+ )
+ return
+ import pygame
+
+ self.enable_render(self.render_mode)
+
+ self.draw()
+ observation = np.array(pygame.surfarray.pixels3d(self.screen))
+ if self.render_mode == "human":
+ pygame.display.flip()
+ return (
+ np.transpose(observation, axes=(1, 0, 2))
+ if self.render_mode == "rgb_array"
+ else None
+ )
diff --git a/DI-engine/dizoo/petting_zoo/envs/test_petting_zoo_simple_spread_env.py b/DI-engine/dizoo/petting_zoo/envs/test_petting_zoo_simple_spread_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..22117cf85fd31e709919bbee4488d50f3cf74c6d
--- /dev/null
+++ b/DI-engine/dizoo/petting_zoo/envs/test_petting_zoo_simple_spread_env.py
@@ -0,0 +1,133 @@
+from easydict import EasyDict
+import pytest
+import numpy as np
+import pettingzoo
+from ding.utils import import_module
+
+from dizoo.petting_zoo.envs.petting_zoo_simple_spread_env import PettingZooEnv
+
+
+@pytest.mark.envtest
+class TestPettingZooEnv:
+
+ def test_agent_obs_only(self):
+ n_agent = 5
+ n_landmark = n_agent
+ env = PettingZooEnv(
+ EasyDict(
+ dict(
+ env_family='mpe',
+ env_id='simple_spread_v2',
+ n_agent=n_agent,
+ n_landmark=n_landmark,
+ max_step=100,
+ agent_obs_only=True,
+ continuous_actions=True,
+ )
+ )
+ )
+ env.seed(123)
+ assert env._seed == 123
+ obs = env.reset()
+ assert obs.shape == (n_agent, 2 + 2 + (n_agent - 1) * 2 + n_agent * 2 + (n_agent - 1) * 2)
+ for i in range(10):
+ random_action = env.random_action()
+ random_action = np.array([random_action[agent] for agent in random_action])
+ timestep = env.step(random_action)
+ print(timestep)
+ assert isinstance(timestep.obs, np.ndarray), timestep.obs
+ assert timestep.obs.shape == (n_agent, 2 + 2 + (n_agent - 1) * 2 + n_agent * 2 + (n_agent - 1) * 2)
+ assert isinstance(timestep.done, bool), timestep.done
+ assert isinstance(timestep.reward, np.ndarray), timestep.reward
+ assert timestep.reward.dtype == np.float32
+ print(env.observation_space, env.action_space, env.reward_space)
+ env.close()
+
+ def test_dict_obs(self):
+ n_agent = 5
+ n_landmark = n_agent
+ env = PettingZooEnv(
+ EasyDict(
+ dict(
+ env_family='mpe',
+ env_id='simple_spread_v2',
+ n_agent=n_agent,
+ n_landmark=n_landmark,
+ max_step=100,
+ agent_obs_only=False,
+ continuous_actions=True,
+ )
+ )
+ )
+ env.seed(123)
+ assert env._seed == 123
+ obs = env.reset()
+ for k, v in obs.items():
+ print(k, v.shape)
+ for i in range(10):
+ random_action = env.random_action()
+ random_action = np.array([random_action[agent] for agent in random_action])
+ timestep = env.step(random_action)
+ print(timestep)
+ assert isinstance(timestep.obs, dict), timestep.obs
+ assert isinstance(timestep.obs['agent_state'], np.ndarray), timestep.obs
+ assert timestep.obs['agent_state'].shape == (
+ n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2
+ )
+ assert timestep.obs['global_state'].shape == (
+ n_agent * (2 + 2) + n_landmark * 2 + n_agent * (n_agent - 1) * 2,
+ )
+ assert timestep.obs['agent_alone_state'].shape == (n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2)
+ assert timestep.obs['agent_alone_padding_state'].shape == (
+ n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2
+ )
+ assert timestep.obs['action_mask'].dtype == np.float32
+ assert isinstance(timestep.done, bool), timestep.done
+ assert isinstance(timestep.reward, np.ndarray), timestep.reward
+ print(env.observation_space, env.action_space, env.reward_space)
+ env.close()
+
+ def test_agent_specific_global_state(self):
+ n_agent = 5
+ n_landmark = n_agent
+ env = PettingZooEnv(
+ EasyDict(
+ dict(
+ env_family='mpe',
+ env_id='simple_spread_v2',
+ n_agent=n_agent,
+ n_landmark=n_landmark,
+ max_step=100,
+ agent_obs_only=False,
+ agent_specific_global_state=True,
+ continuous_actions=True,
+ )
+ )
+ )
+ env.seed(123)
+ assert env._seed == 123
+ obs = env.reset()
+ for k, v in obs.items():
+ print(k, v.shape)
+ for i in range(10):
+ random_action = env.random_action()
+ random_action = np.array([random_action[agent] for agent in random_action])
+ timestep = env.step(random_action)
+ print(timestep)
+ assert isinstance(timestep.obs, dict), timestep.obs
+ assert isinstance(timestep.obs['agent_state'], np.ndarray), timestep.obs
+ assert timestep.obs['agent_state'].shape == (
+ n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2
+ )
+ assert timestep.obs['global_state'].shape == (
+ n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2 + n_agent * (2 + 2) +
+ n_landmark * 2 + n_agent * (n_agent - 1) * 2
+ )
+ assert timestep.obs['agent_alone_state'].shape == (n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2)
+ assert timestep.obs['agent_alone_padding_state'].shape == (
+ n_agent, 2 + 2 + n_landmark * 2 + (n_agent - 1) * 2 + (n_agent - 1) * 2
+ )
+ assert isinstance(timestep.done, bool), timestep.done
+ assert isinstance(timestep.reward, np.ndarray), timestep.reward
+ print(env.observation_space, env.action_space, env.reward_space)
+ env.close()
diff --git a/DI-engine/dizoo/pomdp/__init__.py b/DI-engine/dizoo/pomdp/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/pomdp/config/pomdp_dqn_config.py b/DI-engine/dizoo/pomdp/config/pomdp_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..eac3bde86db158423b6764c9aedfbff1fdf68574
--- /dev/null
+++ b/DI-engine/dizoo/pomdp/config/pomdp_dqn_config.py
@@ -0,0 +1,64 @@
+from easydict import EasyDict
+
+pong_dqn_config = dict(
+ exp_name='pomdp_dqn_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=20,
+ env_id='Pong-ramNoFrameskip-v4',
+ frame_stack=4,
+ warp_frame=False,
+ use_ram=True,
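+        # POMDP corruption settings (interpreted by FrameStackWrapperRam in dizoo/pomdp/envs/atari_wrappers.py):
+        # noise_scale is the std of Gaussian observation noise, zero_p the flickering (blank-out) probability,
+        # reward_noise the std of Gaussian reward noise, and duplicate_p the frame-freeze probability.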
+ pomdp=dict(noise_scale=0.01, zero_p=0.2, reward_noise=0.01, duplicate_p=0.2),
+ manager=dict(shared_memory=False, )
+ ),
+ policy=dict(
+ cuda=True,
+ priority=False,
+ model=dict(
+ obs_shape=[
+ 512,
+ ],
+ action_shape=6,
+ encoder_hidden_size_list=[128, 128, 512],
+ ),
+ nstep=3,
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=10,
+ batch_size=32,
+ learning_rate=0.0001,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=4000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=250000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+)
+pong_dqn_config = EasyDict(pong_dqn_config)
+main_config = pong_dqn_config
+pong_dqn_create_config = dict(
+ env=dict(
+ type='pomdp',
+ import_names=['dizoo.pomdp.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='dqn'),
+)
+pong_dqn_create_config = EasyDict(pong_dqn_create_config)
+create_config = pong_dqn_create_config
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c pomdp_dqn_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/pomdp/config/pomdp_ppo_config.py b/DI-engine/dizoo/pomdp/config/pomdp_ppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..30d0b6318a138394fb38d707419c0d982d5cee1d
--- /dev/null
+++ b/DI-engine/dizoo/pomdp/config/pomdp_ppo_config.py
@@ -0,0 +1,67 @@
+from easydict import EasyDict
+
+pong_ppo_config = dict(
+ exp_name='pomdp_ppo_seed0',
+ env=dict(
+ collector_env_num=16,
+ evaluator_env_num=8,
+ n_evaluator_episode=8,
+ stop_value=20,
+ env_id='Pong-ramNoFrameskip-v4',
+ frame_stack=4,
+ warp_frame=False,
+ use_ram=True,
+ pomdp=dict(noise_scale=0.01, zero_p=0.2, reward_noise=0.01, duplicate_p=0.2),
+ manager=dict(shared_memory=False, )
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=[
+ 512,
+ ],
+ action_shape=6,
+ encoder_hidden_size_list=[512, 512, 256],
+ actor_head_hidden_size=256,
+ actor_head_layer_num=2,
+ critic_head_hidden_size=256,
+ critic_head_layer_num=2,
+ ),
+ learn=dict(
+ update_per_collect=16,
+ batch_size=128,
+ adv_norm=False,
+ learning_rate=0.0001,
+ value_weight=0.5,
+ entropy_weight=0.03,
+ clip_ratio=0.1,
+ ),
+ collect=dict(
+ n_sample=1024,
+ gae_lambda=0.97,
+ discount_factor=0.99,
+ ),
+ eval=dict(evaluator=dict(eval_freq=200, )),
+ other=dict(replay_buffer=dict(
+ replay_buffer_size=100000,
+ max_use=3,
+ min_sample_ratio=1,
+ ), ),
+ ),
+)
+main_config = EasyDict(pong_ppo_config)
+
+pong_ppo_create_config = dict(
+ env=dict(
+ type='pomdp',
+ import_names=['dizoo.pomdp.envs.atari_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo_offpolicy'),
+)
+create_config = EasyDict(pong_ppo_create_config)
+
+if __name__ == '__main__':
+ # or you can enter `ding -m serial -c pomdp_ppo_config.py -s 0`
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/pomdp/envs/__init__.py b/DI-engine/dizoo/pomdp/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f77fa4d2161ee8323720b768d34d571c02c7208a
--- /dev/null
+++ b/DI-engine/dizoo/pomdp/envs/__init__.py
@@ -0,0 +1 @@
+from .atari_env import PomdpAtariEnv
diff --git a/DI-engine/dizoo/pomdp/envs/atari_env.py b/DI-engine/dizoo/pomdp/envs/atari_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..d73ec6008d8272f85c381e124791a2c8881a116c
--- /dev/null
+++ b/DI-engine/dizoo/pomdp/envs/atari_env.py
@@ -0,0 +1,121 @@
+from typing import Any, List, Union, Sequence
+import copy
+import gym
+import numpy as np
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.utils import ENV_REGISTRY
+from ding.torch_utils import to_ndarray, to_list
+from .atari_wrappers import wrap_deepmind
+
+from pprint import pprint
+
+
+def PomdpEnv(cfg, only_info=False):
+    '''
+    For debugging purposes, create an env that follows the OpenAI Gym standard, so it can be widely tested
+    by other libraries with the same environment settings as in DI-engine.
+    Usage:
+        env = PomdpEnv(cfg)
+        obs = env.reset()
+        obs, reward, done, info = env.step(action)
+    '''
+ env = wrap_deepmind(
+ cfg.env_id,
+ frame_stack=cfg.frame_stack,
+ episode_life=cfg.is_train,
+ clip_rewards=cfg.is_train,
+ warp_frame=cfg.warp_frame,
+ use_ram=cfg.use_ram,
+ render=cfg.render,
+ pomdp=cfg.pomdp,
+ only_info=only_info,
+ )
+ return env
+
+
+@ENV_REGISTRY.register('pomdp')
+class PomdpAtariEnv(BaseEnv):
+
+ def __init__(self, cfg: dict) -> None:
+ self._cfg = cfg
+ self._init_flag = False
+
+ def reset(self) -> Sequence:
+ if not self._init_flag:
+ self._env = self._make_env(only_info=False)
+ self._init_flag = True
+ self._observation_space = self._env.observation_space
+ self._action_space = self._env.action_space
+ self._reward_space = gym.spaces.Box(
+ low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
+ )
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+ obs = self._env.reset()
+ obs = to_ndarray(obs)
+ self._eval_episode_return = 0.
+ return obs
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def step(self, action: np.ndarray) -> BaseEnvTimestep:
+ assert isinstance(action, np.ndarray), type(action)
+ action = action.item()
+ obs, rew, done, info = self._env.step(action)
+ self._eval_episode_return += rew
+ obs = to_ndarray(obs)
+        rew = to_ndarray([rew])  # wrap the scalar reward into an array with shape (1, )
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def _make_env(self, only_info=False):
+ return wrap_deepmind(
+ self._cfg.env_id,
+ episode_life=self._cfg.is_train,
+ clip_rewards=self._cfg.is_train,
+ pomdp=self._cfg.pomdp,
+ frame_stack=self._cfg.frame_stack,
+ warp_frame=self._cfg.warp_frame,
+ use_ram=self._cfg.use_ram,
+ only_info=only_info,
+ )
+
+ def __repr__(self) -> str:
+ return "DI-engine POMDP Atari Env({})".format(self._cfg.env_id)
+
+ @staticmethod
+ def create_collector_env_cfg(cfg: dict) -> List[dict]:
+ collector_env_num = cfg.pop('collector_env_num', 1)
+ cfg = copy.deepcopy(cfg)
+ cfg.is_train = True
+ return [cfg for _ in range(collector_env_num)]
+
+ @staticmethod
+ def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
+ evaluator_env_num = cfg.pop('evaluator_env_num', 1)
+ cfg = copy.deepcopy(cfg)
+ cfg.is_train = False
+ return [cfg for _ in range(evaluator_env_num)]
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
diff --git a/DI-engine/dizoo/pomdp/envs/atari_wrappers.py b/DI-engine/dizoo/pomdp/envs/atari_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..a85bce1b45b1e4dd89057fdcdd5806ac848b690c
--- /dev/null
+++ b/DI-engine/dizoo/pomdp/envs/atari_wrappers.py
@@ -0,0 +1,196 @@
+# Borrow a lot from openai baselines:
+# https://github.com/openai/baselines/blob/master/baselines/common/atari_wrappers.py
+
+import cv2
+import gym
+import numpy as np
+from collections import deque
+from copy import deepcopy
+from torch import float32
+import matplotlib.pyplot as plt
+
+from ding.envs import RamWrapper, NoopResetWrapper, MaxAndSkipWrapper, EpisodicLifeWrapper, FireResetWrapper, WarpFrameWrapper, ClipRewardWrapper, FrameStackWrapper
+
+
+class ScaledFloatFrameWrapper(gym.ObservationWrapper):
+ """Normalize observations to -1~1.
+
+ :param gym.Env env: the environment to wrap.
+ """
+
+ def __init__(self, env):
+ super().__init__(env)
+ low = np.min(env.observation_space.low)
+ high = np.max(env.observation_space.high)
+ self.bias = low
+ self.scale = high - low
+ self.observation_space = gym.spaces.Box(low=-1., high=1., shape=env.observation_space.shape, dtype=np.float32)
+
+ def observation(self, observation):
+ # use fixed scale and bias temporarily
+ return (observation - 128) / 128
+ # return (observation - self.bias) / self.scale
+
+
+class FrameStackWrapperRam(gym.Wrapper):
+ """Stack n_frames last frames.
+ :param gym.Env env: the environment to wrap.
+ :param int n_frames: the number of frames to stack.
+ """
+
+ def __init__(
+ self,
+ env,
+ n_frames,
+ pomdp={
+ "noise_scale": 0.01,
+ "zero_p": 0.2,
+ "duplicate_p": 0.2,
+ "reward_noise": 0.01
+ },
+ render=False
+ ):
+ super().__init__(env)
+ self.n_frames = n_frames
+ self.n_dims = env.observation_space.shape[0]
+ self._pomdp = pomdp
+ self._render = render
+ self.frames = deque([], maxlen=n_frames)
+ self._images = deque([], maxlen=n_frames)
+ self.viewer = None
+
+ shape = (n_frames * self.n_dims, )
+ self.observation_space = gym.spaces.Box(
+ low=np.min(env.observation_space.low),
+ high=np.max(env.observation_space.high),
+ shape=shape,
+ dtype=env.observation_space.dtype
+ )
+
+ def reset(self):
+ obs = self.env.reset()
+ for _ in range(self.n_frames):
+ self.frames.append(obs)
+ return self._get_ob()
+
+ def step(self, action):
+ obs, reward, done, info = self.env.step(action)
+ self.frames.append(obs)
+ reward = reward + self._pomdp["reward_noise"] * np.random.randn()
+
+ if self._render:
+ _img = self.env.unwrapped._get_image()
+ _img = _img.mean(axis=-1, keepdims=True).astype(np.uint8)
+ self._images.append(_img)
+ self.render()
+
+ return self._get_ob(), reward, done, info
+
+ def render(self):
+ from gym.envs.classic_control import rendering
+ state = np.stack(self._images, axis=0)
+ obs = self._pomdp_preprocess(state, img=True).astype(np.uint8)
+ obs = np.tile(obs[-1], (1, 1, 3))
+ if self.viewer is None:
+ self.viewer = rendering.SimpleImageViewer()
+ self.viewer.imshow(obs)
+ return self.viewer.isopen
+
+ def _get_ob(self):
+        # The original wrapper uses `LazyFrames`, but since we use a numpy buffer here,
+        # it has no effect.
+ state = np.stack(self.frames, axis=0)
+ obs = self._pomdp_preprocess(state)
+
+ return obs.flatten()
+
+ def _pomdp_preprocess(self, state, img=False):
+ obs = deepcopy(state)
+ # POMDP process
+ if np.random.random() > (1 - self._pomdp["duplicate_p"]):
+            # choose a point after which no new observation arrives (subsequent frames are duplicated)
+            update_end_point = np.random.randint(1, self.n_frames)
+ _s = (self.n_frames - update_end_point, 1, 1, 1)
+ obs[update_end_point:, ] = np.tile(obs[update_end_point, ], _s)
+
+ if img:
+ pomdp_noise_mask = self._pomdp["noise_scale"] * np.random.randn(*obs.shape) * 128
+ else:
+ pomdp_noise_mask = self._pomdp["noise_scale"] * np.random.randn(*obs.shape)
+
+ # Flickering Atari game
+ obs = obs * int(np.random.random() > self._pomdp["zero_p"]) + pomdp_noise_mask
+ return obs.astype(np.float32)
+
+
+def wrap_deepmind(
+ env_id,
+ episode_life=True,
+ clip_rewards=True,
+ pomdp={},
+ frame_stack=4,
+ scale=True,
+ warp_frame=True,
+ use_ram=False,
+ render=False,
+ only_info=False
+):
+ """Configure environment for DeepMind-style Atari. The observation is
+ channel-first: (c, h, w) instead of (h, w, c).
+
+ :param str env_id: the atari environment id.
+ :param bool episode_life: wrap the episode life wrapper.
+ :param bool clip_rewards: wrap the reward clipping wrapper.
+ :param int frame_stack: wrap the frame stacking wrapper.
+ :param bool scale: wrap the scaling observation wrapper.
+ :param bool warp_frame: wrap the grayscale + resize observation wrapper.
+    :param bool use_ram: use the RAM observation instead of image frames.
+    :param dict pomdp: parameters controlling the POMDP corruption (noise, flickering, frame freeze, reward noise).
+    :param bool render: whether to render frames for visualization.
+    :param bool only_info: if True, only return the names of the wrappers that would be applied.
+    :return: the wrapped atari environment.
+ """
+ assert 'NoFrameskip' in env_id
+ if not only_info:
+ env = gym.make(env_id)
+ env = RamWrapper(env)
+ env = NoopResetWrapper(env, noop_max=30)
+ env = MaxAndSkipWrapper(env, skip=4)
+ if episode_life:
+ env = EpisodicLifeWrapper(env)
+ if 'FIRE' in env.unwrapped.get_action_meanings():
+ env = FireResetWrapper(env)
+ if warp_frame:
+ env = WarpFrameWrapper(env)
+ if scale:
+ env = ScaledFloatFrameWrapper(env)
+ if clip_rewards:
+ env = ClipRewardWrapper(env)
+
+ if frame_stack:
+ if use_ram:
+ env = FrameStackWrapperRam(env, frame_stack, pomdp, render)
+ else:
+ env = FrameStackWrapper(env, frame_stack)
+
+ return env
+ else:
+ wrapper_info = RamWrapper.__name__ + '\n'
+ wrapper_info += NoopResetWrapper.__name__ + '\n'
+ wrapper_info += MaxAndSkipWrapper.__name__ + '\n'
+        if episode_life:
+            wrapper_info += EpisodicLifeWrapper.__name__ + '\n'
+        if 'Pong' in env_id or 'Qbert' in env_id or 'SpaceInvader' in env_id or 'Montezuma' in env_id:
+            wrapper_info += FireResetWrapper.__name__ + '\n'
+        if warp_frame:
+            wrapper_info += WarpFrameWrapper.__name__ + '\n'
+        if scale:
+            wrapper_info += ScaledFloatFrameWrapper.__name__ + '\n'
+        if clip_rewards:
+            wrapper_info += ClipRewardWrapper.__name__ + '\n'
+
+        if frame_stack:
+            if use_ram:
+                wrapper_info += FrameStackWrapperRam.__name__ + '\n'
+            else:
+                wrapper_info += FrameStackWrapper.__name__ + '\n'
+
+ return wrapper_info
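+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (assumption: the Atari ROMs required by gym are installed locally).
+    # This builds the RAM-based POMDP Pong env with the same options as dizoo/pomdp/config/pomdp_dqn_config.py;
+    # with frame_stack=4 and use_ram=True the flattened observation has shape (512, ).
+    env = wrap_deepmind(
+        'Pong-ramNoFrameskip-v4',
+        episode_life=True,
+        clip_rewards=True,
+        pomdp=dict(noise_scale=0.01, zero_p=0.2, reward_noise=0.01, duplicate_p=0.2),
+        frame_stack=4,
+        warp_frame=False,
+        use_ram=True,
+    )
+    obs = env.reset()
+    print(obs.shape)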
diff --git a/DI-engine/dizoo/pomdp/envs/test_atari_env.py b/DI-engine/dizoo/pomdp/envs/test_atari_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..d18a98fac173cb62f71c802a789e1efa85cd2c76
--- /dev/null
+++ b/DI-engine/dizoo/pomdp/envs/test_atari_env.py
@@ -0,0 +1,35 @@
+import pytest
+import gym
+import numpy as np
+from easydict import EasyDict
+from dizoo.pomdp.envs import PomdpAtariEnv
+
+
+@pytest.mark.envtest
+def test_env():
+ cfg = {
+ 'env_id': 'Pong-ramNoFrameskip-v4',
+ 'frame_stack': 4,
+ 'is_train': True,
+ 'warp_frame': False,
+ 'clip_reward': False,
+ 'use_ram': True,
+ 'render': False,
+ 'pomdp': dict(noise_scale=0.001, zero_p=0.1, reward_noise=0.01, duplicate_p=0.2)
+ }
+
+ cfg = EasyDict(cfg)
+ pong_env = PomdpAtariEnv(cfg)
+ pong_env.seed(0)
+ obs = pong_env.reset()
+    act_dim = pong_env.action_space.n
+ while True:
+ random_action = np.random.choice(range(act_dim), size=(1, ))
+ timestep = pong_env.step(random_action)
+ assert timestep.obs.shape == (512, )
+ assert timestep.reward.shape == (1, )
+ # assert isinstance(timestep, tuple)
+ if timestep.done:
+ assert 'eval_episode_return' in timestep.info, timestep.info
+ break
+ pong_env.close()
diff --git a/DI-engine/dizoo/procgen/README.md b/DI-engine/dizoo/procgen/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..4582510a0540779dae21e6f2a6df0edb00dc20b6
--- /dev/null
+++ b/DI-engine/dizoo/procgen/README.md
@@ -0,0 +1,33 @@
+## Coinrun Environment
+
+Coinrun is a simple platformer. The goal is to collect the coin at the far right of the level, and the player spawns on the far left.
+The player must dodge stationary saw obstacles, enemies that pace back and forth, and chasms that lead to death.
+Collecting the coin awards 10 points. The game ends if the player dies or the game time exceeds the maximum allowed time.
+Note that while the previously released version of CoinRun painted velocity information directly onto observations, the current version does not. This makes the environment significantly more difficult.
+Procedural generation controls the number of platform sections, their corresponding types, the location of crates, and the location and types of obstacles.
+
+![original](./coinrun.png)
+
+## Train Coinrun with DI-engine
+
+DI-engine can achieve an average return of 10 within 2M episodes using DQN. The tuned config can be found in `dizoo/procgen/config/coinrun_dqn_config.py`. The training episode return is shown below.
+
+![tb](./coinrun_dqn.svg)
+
+DI-engine can achieve an average return of 10 within 2M episodes using PPO. The tuned config can be found in `dizoo/procgen/config/coinrun_ppo_config.py`. The training episode return is shown below.
+
+![tb](./coinrun_ppo.svg)
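+
+As a minimal launch sketch (assuming the standard `ding.entry.serial_pipeline` entry point used elsewhere in DI-engine), training can be started from the tuned config like this:
+
+```python
+# Hypothetical launch script; the import path follows the config files added in this repository.
+from ding.entry import serial_pipeline
+from dizoo.procgen.config.coinrun_dqn_config import main_config, create_config
+
+if __name__ == '__main__':
+    serial_pipeline((main_config, create_config), seed=0)
+```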
+
+## Maze Environment
+
+The player, a mouse, must navigate a maze to find the sole piece of cheese and earn a reward. The player may move up, down, left or right to navigate the maze.
+Collecting the cheese awards 10 points. The game ends if the game time exceeds the maximum allowed time.
+Procedural generation controls the level layout by generating mazes using Kruskal’s algorithm (Kruskal, 1956), uniformly ranging in size from 3x3 to 25x25.
+
+![original](./maze.png)
+
+## Train Maze with DI-engine
+
+DI-engine can achieve an average return of 10 within 7M episodes using DQN. The tuned config can be found in `dizoo/procgen/config/maze_dqn_config.py`. The training episode return is shown below.
+
+![tb](./maze_dqn.svg)
diff --git a/DI-engine/dizoo/procgen/__init__.py b/DI-engine/dizoo/procgen/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/procgen/config/__init__.py b/DI-engine/dizoo/procgen/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7c0177d80cdf06ba3b145eb61d954aaaf54a40ab
--- /dev/null
+++ b/DI-engine/dizoo/procgen/config/__init__.py
@@ -0,0 +1,2 @@
+from .coinrun_dqn_config import main_config, create_config
+from .coinrun_ppo_config import main_config, create_config
diff --git a/DI-engine/dizoo/procgen/config/bigfish_plr_config.py b/DI-engine/dizoo/procgen/config/bigfish_plr_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d39afd709794f6addbb0e65edd30a081eb61fa78
--- /dev/null
+++ b/DI-engine/dizoo/procgen/config/bigfish_plr_config.py
@@ -0,0 +1,62 @@
+from easydict import EasyDict
+
+bigfish_plr_config = dict(
+ exp_name='bigfish_plr_seed1',
+ env=dict(
+ is_train=True,
+ control_level=False,
+ env_id='bigfish',
+ collector_env_num=64,
+ evaluator_env_num=10,
+ n_evaluator_episode=50,
+ stop_value=40,
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=[3, 64, 64],
+ action_shape=15,
+ encoder_hidden_size_list=[16, 32, 32],
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ impala_cnn_encoder=True,
+ ),
+ learn=dict(
+ learning_rate=0.0005,
+ actor_epoch_per_collect=1,
+ critic_epoch_per_collect=1,
+ value_norm=True,
+ batch_size=16384,
+ value_weight=0.5,
+ entropy_weight=0.01,
+ clip_ratio=0.2,
+ aux_freq=1,
+ ),
+ collect=dict(n_sample=16384, ),
+ eval=dict(evaluator=dict(eval_freq=96, )),
+ other=dict(),
+ ),
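+    # Prioritized Level Replay (PLR) sampling settings (assumption: standard PLR semantics): levels are
+    # scored with the 'min_margin' strategy, scores are rank-transformed, and levels are sampled with the
+    # given temperature.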
+ level_replay=dict(
+ strategy='min_margin',
+ score_transform='rank',
+ temperature=0.1,
+ ),
+)
+bigfish_plr_config = EasyDict(bigfish_plr_config)
+main_config = bigfish_plr_config
+
+bigfish_plr_create_config = dict(
+ env=dict(
+ type='procgen',
+ import_names=['dizoo.procgen.envs.procgen_env'],
+ ),
+ env_manager=dict(type='subprocess', ),
+ policy=dict(type='ppg'),
+)
+bigfish_plr_create_config = EasyDict(bigfish_plr_create_config)
+create_config = bigfish_plr_create_config
+
+if __name__ == "__main__":
+
+ from ding.entry.serial_entry_plr import serial_pipeline_plr
+ serial_pipeline_plr([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/procgen/config/bigfish_ppg_config.py b/DI-engine/dizoo/procgen/config/bigfish_ppg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..19ac4be698e8b380c4266581cc279a2cb1a4bfa7
--- /dev/null
+++ b/DI-engine/dizoo/procgen/config/bigfish_ppg_config.py
@@ -0,0 +1,57 @@
+from easydict import EasyDict
+
+bigfish_ppg_config = dict(
+ exp_name='bigfish_ppg_seed0',
+ env=dict(
+ is_train=True,
+ env_id='bigfish',
+ collector_env_num=64,
+ evaluator_env_num=10,
+ n_evaluator_episode=50,
+ stop_value=40,
+ manager=dict(shared_memory=True, ),
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=[3, 64, 64],
+ action_shape=15,
+ encoder_hidden_size_list=[16, 32, 32],
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ impala_cnn_encoder=True,
+ ),
+ learn=dict(
+ learning_rate=0.0005,
+ actor_epoch_per_collect=1,
+ critic_epoch_per_collect=1,
+ value_norm=True,
+ batch_size=16384,
+ value_weight=0.5,
+ entropy_weight=0.01,
+ clip_ratio=0.2,
+ aux_freq=1,
+ ),
+ collect=dict(n_sample=16384, ),
+ eval=dict(evaluator=dict(eval_freq=96, )),
+ other=dict(),
+ ),
+)
+bigfish_ppg_config = EasyDict(bigfish_ppg_config)
+main_config = bigfish_ppg_config
+
+bigfish_ppg_create_config = dict(
+ env=dict(
+ type='procgen',
+ import_names=['dizoo.procgen.envs.procgen_env'],
+ ),
+ env_manager=dict(type='subprocess', ),
+ policy=dict(type='ppg'),
+)
+bigfish_ppg_create_config = EasyDict(bigfish_ppg_create_config)
+create_config = bigfish_ppg_create_config
+
+if __name__ == "__main__":
+
+ from ding.entry import serial_pipeline_onpolicy_ppg
+ serial_pipeline_onpolicy_ppg([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/procgen/config/coinrun_dqn_config.py b/DI-engine/dizoo/procgen/config/coinrun_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4e788672b654af4d9640b28565ee9ff7cccd1ca5
--- /dev/null
+++ b/DI-engine/dizoo/procgen/config/coinrun_dqn_config.py
@@ -0,0 +1,51 @@
+from easydict import EasyDict
+
+coinrun_dqn_config = dict(
+ env=dict(
+ env_id='coinrun',
+ collector_env_num=4,
+ evaluator_env_num=4,
+ n_evaluator_episode=4,
+ stop_value=10,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=[3, 64, 64],
+ action_shape=15,
+ encoder_hidden_size_list=[128, 128, 512],
+ dueling=False,
+ ),
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=20,
+ batch_size=32,
+ learning_rate=0.0005,
+ target_update_freq=500,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=5000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=250000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+)
+coinrun_dqn_config = EasyDict(coinrun_dqn_config)
+main_config = coinrun_dqn_config
+
+coinrun_dqn_create_config = dict(
+ env=dict(
+ type='procgen',
+ import_names=['dizoo.procgen.envs.procgen_env'],
+ ),
+ env_manager=dict(type='subprocess', ),
+ policy=dict(type='dqn'),
+)
+coinrun_dqn_create_config = EasyDict(coinrun_dqn_create_config)
+create_config = coinrun_dqn_create_config
diff --git a/DI-engine/dizoo/procgen/config/coinrun_ppg_config.py b/DI-engine/dizoo/procgen/config/coinrun_ppg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..793c2128b0889cc9b4e57817ae02b0ffbda8e211
--- /dev/null
+++ b/DI-engine/dizoo/procgen/config/coinrun_ppg_config.py
@@ -0,0 +1,56 @@
+from easydict import EasyDict
+
+coinrun_ppg_config = dict(
+ exp_name='coinrun_ppg_seed0',
+ env=dict(
+ is_train=True,
+ env_id='coinrun',
+ collector_env_num=64,
+ evaluator_env_num=10,
+ n_evaluator_episode=50,
+ stop_value=10,
+ manager=dict(shared_memory=True, ),
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=[3, 64, 64],
+ action_shape=15,
+ encoder_hidden_size_list=[16, 32, 32],
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ impala_cnn_encoder=True,
+ ),
+ learn=dict(
+ learning_rate=0.0005,
+ actor_epoch_per_collect=1,
+ critic_epoch_per_collect=1,
+ value_norm=False,
+ batch_size=2048,
+ value_weight=0.5,
+ entropy_weight=0.00,
+ clip_ratio=0.2,
+ aux_freq=1,
+ ),
+ collect=dict(n_sample=16384, ),
+ eval=dict(evaluator=dict(eval_freq=96, )),
+ other=dict(),
+ ),
+)
+coinrun_ppg_config = EasyDict(coinrun_ppg_config)
+main_config = coinrun_ppg_config
+
+coinrun_ppg_create_config = dict(
+ env=dict(
+ type='procgen',
+ import_names=['dizoo.procgen.envs.procgen_env'],
+ ),
+ env_manager=dict(type='subprocess', ),
+ policy=dict(type='ppg'),
+)
+coinrun_ppg_create_config = EasyDict(coinrun_ppg_create_config)
+create_config = coinrun_ppg_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_onpolicy_ppg
+ serial_pipeline_onpolicy_ppg([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/procgen/config/coinrun_ppo_config.py b/DI-engine/dizoo/procgen/config/coinrun_ppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4a04396b42d68539cd5aa9aa50c1cd43c0b4f03f
--- /dev/null
+++ b/DI-engine/dizoo/procgen/config/coinrun_ppo_config.py
@@ -0,0 +1,54 @@
+from easydict import EasyDict
+
+coinrun_ppo_config = dict(
+ env=dict(
+ is_train=True,
+ env_id='coinrun',
+ collector_env_num=4,
+ evaluator_env_num=4,
+ n_evaluator_episode=4,
+ stop_value=10,
+ ),
+ policy=dict(
+ cuda=False,
+ action_space='discrete',
+ model=dict(
+ obs_shape=[3, 64, 64],
+ action_space='discrete',
+ action_shape=15,
+ encoder_hidden_size_list=[32, 32, 64],
+ ),
+ learn=dict(
+ learning_rate=0.0001,
+ update_per_collect=5,
+ batch_size=64,
+ value_weight=0.5,
+ entropy_weight=0.01,
+ clip_ratio=0.2,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=5000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=250000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+)
+coinrun_ppo_config = EasyDict(coinrun_ppo_config)
+main_config = coinrun_ppo_config
+
+coinrun_ppo_create_config = dict(
+ env=dict(
+ type='procgen',
+ import_names=['dizoo.procgen.envs.procgen_env'],
+ ),
+ env_manager=dict(type='subprocess', ),
+ policy=dict(type='ppo'),
+)
+coinrun_ppo_create_config = EasyDict(coinrun_ppo_create_config)
+create_config = coinrun_ppo_create_config
diff --git a/DI-engine/dizoo/procgen/config/maze_dqn_config.py b/DI-engine/dizoo/procgen/config/maze_dqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4cf674ca164c7d6c1320f8a80b93eb59c4084e6
--- /dev/null
+++ b/DI-engine/dizoo/procgen/config/maze_dqn_config.py
@@ -0,0 +1,52 @@
+from easydict import EasyDict
+
+maze_dqn_config = dict(
+ env=dict(
+ collector_env_num=4,
+ env_id='maze',
+ evaluator_env_num=4,
+ n_evaluator_episode=4,
+ stop_value=10,
+ ),
+ policy=dict(
+ cuda=False,
+ model=dict(
+ obs_shape=[3, 64, 64],
+ action_shape=15,
+ encoder_hidden_size_list=[128, 128, 512],
+ dueling=False,
+ ),
+ discount_factor=0.99,
+ learn=dict(
+ update_per_collect=20,
+ batch_size=32,
+ learning_rate=0.0005,
+ target_update_freq=500,
+ discount_factor=0.99,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=5000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=250000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+)
+maze_dqn_config = EasyDict(maze_dqn_config)
+main_config = maze_dqn_config
+
+maze_dqn_create_config = dict(
+ env=dict(
+ type='procgen',
+ import_names=['dizoo.procgen.envs.procgen_env'],
+ ),
+ env_manager=dict(type='subprocess', ),
+ policy=dict(type='dqn'),
+)
+maze_dqn_create_config = EasyDict(maze_dqn_create_config)
+create_config = maze_dqn_create_config
diff --git a/DI-engine/dizoo/procgen/config/maze_ppg_config.py b/DI-engine/dizoo/procgen/config/maze_ppg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..bb51dfd85d90b8ac05e3cbc939fc995806e07b03
--- /dev/null
+++ b/DI-engine/dizoo/procgen/config/maze_ppg_config.py
@@ -0,0 +1,59 @@
+from easydict import EasyDict
+
+maze_ppg_config = dict(
+ exp_name='maze_ppg_seed0',
+ env=dict(
+ is_train=True,
+ env_id='maze',
+ collector_env_num=64,
+ evaluator_env_num=10,
+ n_evaluator_episode=50,
+ stop_value=10,
+ manager=dict(shared_memory=True, ),
+ ),
+ policy=dict(
+ cuda=True,
+ model=dict(
+ obs_shape=[3, 64, 64],
+ action_shape=15,
+ encoder_hidden_size_list=[16, 32, 32],
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ impala_cnn_encoder=True,
+ ),
+ learn=dict(
+ learning_rate=0.0005,
+ actor_epoch_per_collect=1,
+ critic_epoch_per_collect=1,
+ value_norm=False,
+ batch_size=2048,
+ value_weight=1.0,
+ entropy_weight=0.01,
+ clip_ratio=0.2,
+ aux_freq=1,
+ ),
+ collect=dict(
+ n_sample=16384,
+ discount_factor=0.99,
+ ),
+ eval=dict(evaluator=dict(eval_freq=24, )),
+ other=dict(),
+ ),
+)
+maze_ppg_config = EasyDict(maze_ppg_config)
+main_config = maze_ppg_config
+
+maze_ppg_create_config = dict(
+ env=dict(
+ type='procgen',
+ import_names=['dizoo.procgen.envs.procgen_env'],
+ ),
+ env_manager=dict(type='subprocess', ),
+ policy=dict(type='ppg'),
+)
+maze_ppg_create_config = EasyDict(maze_ppg_create_config)
+create_config = maze_ppg_create_config
+
+if __name__ == "__main__":
+ from ding.entry import serial_pipeline_onpolicy_ppg
+ serial_pipeline_onpolicy_ppg([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/procgen/config/maze_ppo_config.py b/DI-engine/dizoo/procgen/config/maze_ppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d116305f76a790ec7d1dbad0233cd52e1f5247ed
--- /dev/null
+++ b/DI-engine/dizoo/procgen/config/maze_ppo_config.py
@@ -0,0 +1,55 @@
+from easydict import EasyDict
+
+maze_ppo_config = dict(
+ env=dict(
+ # frame_stack=4,
+ is_train=True,
+ env_id='maze',
+ collector_env_num=4,
+ evaluator_env_num=4,
+ n_evaluator_episode=4,
+ stop_value=10,
+ ),
+ policy=dict(
+ cuda=False,
+ action_space='discrete',
+ model=dict(
+ obs_shape=[3, 64, 64],
+ action_shape=15,
+ action_space='discrete',
+ encoder_hidden_size_list=[32, 32, 64],
+ ),
+ learn=dict(
+ update_per_collect=5,
+ batch_size=64,
+ value_weight=0.5,
+ entropy_weight=0.01,
+ clip_ratio=0.2,
+ learning_rate=0.0001,
+ ),
+ collect=dict(n_sample=100, ),
+ eval=dict(evaluator=dict(eval_freq=5000, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=1.,
+ end=0.05,
+ decay=250000,
+ ),
+ replay_buffer=dict(replay_buffer_size=100000, ),
+ ),
+ ),
+)
+maze_ppo_config = EasyDict(maze_ppo_config)
+main_config = maze_ppo_config
+
+maze_ppo_create_config = dict(
+ env=dict(
+ type='procgen',
+ import_names=['dizoo.procgen.envs.procgen_env'],
+ ),
+ env_manager=dict(type='subprocess', ),
+ policy=dict(type='ppo'),
+)
+maze_ppo_create_config = EasyDict(maze_ppo_create_config)
+create_config = maze_ppo_create_config
diff --git a/DI-engine/dizoo/procgen/entry/coinrun_onppo_main.py b/DI-engine/dizoo/procgen/entry/coinrun_onppo_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca132b1fa73110263e5d141565c16e5984442768
--- /dev/null
+++ b/DI-engine/dizoo/procgen/entry/coinrun_onppo_main.py
@@ -0,0 +1,113 @@
+import os
+from functools import partial
+
+import gym
+import numpy as np
+from easydict import EasyDict
+from tensorboardX import SummaryWriter
+
+from ding.torch_utils import to_ndarray
+from ding.worker import BaseLearner, SampleSerialCollector, InteractionSerialEvaluator
+from ding.model import VAC
+from ding.policy import PPOPolicy
+from ding.envs import DingEnvWrapper, EvalEpisodeReturnWrapper, BaseEnvManager
+from ding.config import compile_config
+from ding.utils import set_pkg_seed
+from dizoo.procgen.config.coinrun_ppo_config import coinrun_ppo_config
+
+
+class CoinrunWrapper(gym.Wrapper):
+
+ def __init__(self, env, cfg):
+ super().__init__(env)
+ cfg = EasyDict(cfg)
+ self._cfg = cfg
+ self._observation_space = gym.spaces.Box(
+ low=np.zeros(shape=(3, 64, 64)), high=np.ones(shape=(3, 64, 64)) * 255, shape=(3, 64, 64), dtype=np.float32
+ )
+ self._action_space = gym.spaces.Discrete(15)
+ self._reward_space = gym.spaces.Box(low=float("-inf"), high=float("inf"), shape=(1, ), dtype=np.float32)
+
+ def _process_obs(self, obs):
+ obs = to_ndarray(obs)
+ obs = np.transpose(obs, (2, 0, 1))
+ obs = obs.astype(np.float32)
+ return obs
+
+ def step(self, action):
+ obs, reward, done, info = self.env.step(action)
+ return self._process_obs(obs), reward, bool(done), info
+
+ def reset(self):
+ obs = self.env.reset()
+ return self._process_obs(obs)
+
+
+def wrapped_procgen_env(cfg):
+ default_cfg = dict(
+ control_level=True,
+ start_level=0,
+ num_levels=0,
+ env_id='coinrun',
+ )
+ default_cfg.update(cfg)
+ default_cfg = EasyDict(default_cfg)
+
+ return DingEnvWrapper(
+ gym.make(
+ 'procgen:procgen-' + default_cfg.env_id + '-v0',
+ start_level=default_cfg.start_level,
+ num_levels=default_cfg.num_levels
+ ) if default_cfg.control_level else
+ gym.make('procgen:procgen-' + default_cfg.env_id + '-v0', start_level=0, num_levels=1),
+ cfg={
+ 'env_wrapper': [
+ lambda env: CoinrunWrapper(env, default_cfg),
+ lambda env: EvalEpisodeReturnWrapper(env),
+ ]
+ }
+ )
+
+
+def main(cfg, seed=0, max_env_step=int(1e10), max_train_iter=int(1e10)):
+ cfg = compile_config(
+ cfg, BaseEnvManager, PPOPolicy, BaseLearner, SampleSerialCollector, InteractionSerialEvaluator, save_cfg=True
+ )
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+ collector_env = BaseEnvManager(
+ env_fn=[partial(wrapped_procgen_env, cfg=coinrun_ppo_config.env) for _ in range(collector_env_num)],
+ cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManager(
+ env_fn=[partial(wrapped_procgen_env, cfg=coinrun_ppo_config.env) for _ in range(evaluator_env_num)],
+ cfg=cfg.env.manager
+ )
+
+ collector_env.seed(seed)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ model = VAC(**cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ collector = SampleSerialCollector(
+ cfg.policy.collect.collector, collector_env, policy.collect_mode, tb_logger, exp_name=cfg.exp_name
+ )
+ evaluator = InteractionSerialEvaluator(
+ cfg.policy.eval.evaluator, evaluator_env, policy.eval_mode, tb_logger, exp_name=cfg.exp_name
+ )
+
+ while True:
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+ new_data = collector.collect(train_iter=learner.train_iter)
+ learner.train(new_data, collector.envstep)
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+
+
+if __name__ == '__main__':
+ main(coinrun_ppo_config)
diff --git a/DI-engine/dizoo/procgen/envs/__init__.py b/DI-engine/dizoo/procgen/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..39559458757c532cf8348cb79d75278512ec699f
--- /dev/null
+++ b/DI-engine/dizoo/procgen/envs/__init__.py
@@ -0,0 +1 @@
+from .procgen_env import ProcgenEnv
diff --git a/DI-engine/dizoo/procgen/envs/procgen_env.py b/DI-engine/dizoo/procgen/envs/procgen_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..4b194f1d821317a0d52f40f388b714c2327df5df
--- /dev/null
+++ b/DI-engine/dizoo/procgen/envs/procgen_env.py
@@ -0,0 +1,114 @@
+from typing import Any, List, Union, Optional
+from easydict import EasyDict
+import time
+import gym
+import numpy as np
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.envs.common.env_element import EnvElement, EnvElementInfo
+from ding.torch_utils import to_ndarray, to_list
+from ding.utils import ENV_REGISTRY, deep_merge_dicts
+
+
+@ENV_REGISTRY.register('procgen')
+class ProcgenEnv(BaseEnv):
+
+    # If control_level is True, you can control the specific levels of the generated environments
+    # via start_level and num_levels.
+ config = dict(
+ control_level=True,
+ start_level=0,
+ num_levels=0,
+ env_id='coinrun',
+ )
+
+ def __init__(self, cfg: dict) -> None:
+ cfg = deep_merge_dicts(EasyDict(self.config), cfg)
+ self._cfg = cfg
+ self._seed = 0
+ self._init_flag = False
+ self._observation_space = gym.spaces.Box(
+ low=np.zeros(shape=(3, 64, 64)), high=np.ones(shape=(3, 64, 64)) * 255, shape=(3, 64, 64), dtype=np.float32
+ )
+
+ self._action_space = gym.spaces.Discrete(15)
+
+ self._reward_space = gym.spaces.Box(low=float("-inf"), high=float("inf"), shape=(1, ), dtype=np.float32)
+ self._control_level = self._cfg.control_level
+ self._start_level = self._cfg.start_level
+ self._num_levels = self._cfg.num_levels
+ self._env_name = 'procgen:procgen-' + self._cfg.env_id + '-v0'
+        # In procgen envs, the seed controls the level, so the numpy seed is fixed to 0.
+ np.random.seed(0)
+
+ def reset(self) -> np.ndarray:
+ if not self._init_flag:
+ if self._control_level:
+ self._env = gym.make(self._env_name, start_level=self._start_level, num_levels=self._num_levels)
+ else:
+ self._env = gym.make(self._env_name, start_level=0, num_levels=1)
+ self._init_flag = True
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.close()
+ if self._control_level:
+ self._env = gym.make(self._env_name, start_level=self._start_level, num_levels=self._num_levels)
+ else:
+ self._env = gym.make(self._env_name, start_level=self._seed + np_seed, num_levels=1)
+ elif hasattr(self, '_seed'):
+ self._env.close()
+ if self._control_level:
+ self._env = gym.make(self._env_name, start_level=self._start_level, num_levels=self._num_levels)
+ else:
+ self._env = gym.make(self._env_name, start_level=self._seed, num_levels=1)
+ self._eval_episode_return = 0
+ obs = self._env.reset()
+ obs = to_ndarray(obs)
+ obs = np.transpose(obs, (2, 0, 1))
+ obs = obs.astype(np.float32)
+ return obs
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+
+ def step(self, action: np.ndarray) -> BaseEnvTimestep:
+ assert isinstance(action, np.ndarray), type(action)
+ if action.shape == (1, ):
+ action = action.squeeze() # 0-dim array
+ obs, rew, done, info = self._env.step(action)
+ self._eval_episode_return += rew
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+ obs = to_ndarray(obs)
+ obs = np.transpose(obs, (2, 0, 1))
+ obs = obs.astype(np.float32)
+        rew = to_ndarray([rew])  # wrap the scalar reward into an array with shape (1, )
+ rew = rew.astype(np.float32)
+ return BaseEnvTimestep(obs, rew, bool(done), info)
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ def __repr__(self) -> str:
+ return "DI-engine CoinRun Env"
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ if replay_path is None:
+ replay_path = './video'
+ self._replay_path = replay_path
+ self._env = gym.wrappers.Monitor(
+ self._env, self._replay_path, video_callable=lambda episode_id: True, force=True
+ )
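+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (assumption: the `procgen` package is installed); control_level pins the level set.
+    env = ProcgenEnv(EasyDict(dict(env_id='coinrun', control_level=True, start_level=5, num_levels=1)))
+    env.seed(0)
+    obs = env.reset()
+    print(obs.shape)  # (3, 64, 64), channel-first float32
+    timestep = env.step(np.array([env.action_space.sample()]))
+    print(timestep.reward, timestep.done)
+    env.close()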
diff --git a/DI-engine/dizoo/procgen/envs/test_coinrun_env.py b/DI-engine/dizoo/procgen/envs/test_coinrun_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..bc0c6aebc4664e98e0d6667c26dd1c162b250f2c
--- /dev/null
+++ b/DI-engine/dizoo/procgen/envs/test_coinrun_env.py
@@ -0,0 +1,25 @@
+import pytest
+import numpy as np
+from easydict import EasyDict
+from dizoo.procgen.envs import ProcgenEnv
+
+
+@pytest.mark.envtest
+class TestProcgenEnv:
+
+ def test_naive(self):
+ env = ProcgenEnv(EasyDict({}))
+ env.seed(314)
+ assert env._seed == 314
+ obs = env.reset()
+ assert obs.shape == (3, 64, 64)
+ for i in range(10):
+            random_action = np.array([env.action_space.sample()])
+            timestep = env.step(random_action)
+            assert timestep.obs.shape == (3, 64, 64)
+            assert timestep.reward.shape == (1, )
+            assert timestep.reward >= env.reward_space.low
+            assert timestep.reward <= env.reward_space.high
+            # assert isinstance(timestep, tuple)
+            print(env.observation_space, env.action_space, env.reward_space)
+ env.close()
diff --git a/DI-engine/dizoo/pybullet/__init__.py b/DI-engine/dizoo/pybullet/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/pybullet/envs/__init__.py b/DI-engine/dizoo/pybullet/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..28580bf2b148a6adbcd8dcda9859704571bb582e
--- /dev/null
+++ b/DI-engine/dizoo/pybullet/envs/__init__.py
@@ -0,0 +1 @@
+from .pybullet_env import PybulletEnv
diff --git a/DI-engine/dizoo/pybullet/envs/pybullet_env.py b/DI-engine/dizoo/pybullet/envs/pybullet_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..25def74a23fe6ef48a513f49668c680ad5a72632
--- /dev/null
+++ b/DI-engine/dizoo/pybullet/envs/pybullet_env.py
@@ -0,0 +1,376 @@
+from typing import Any, Union, List
+import copy
+import numpy as np
+from ditk import logging
+
+from ding.envs import BaseEnv, BaseEnvTimestep, BaseEnvInfo, update_shape
+from ding.envs.common.env_element import EnvElement, EnvElementInfo
+from ding.envs.common.common_function import affine_transform
+from ding.torch_utils import to_ndarray, to_list
+from ding.utils import ENV_REGISTRY
+from .pybullet_wrappers import wrap_pybullet
+
+Pybullet_INFO_DICT = {
+ # pybullet env
+ 'InvertedPendulumMuJoCoEnv-v0': BaseEnvInfo(
+ agent_num=1,
+ obs_space=EnvElementInfo(
+ shape=(4, ),
+ value={
+ 'min': np.float64("-inf"),
+ 'max': np.float64("inf"),
+ 'dtype': np.float32
+ },
+ ),
+ act_space=EnvElementInfo(
+ shape=(1, ),
+ value={
+ 'min': -1.0,
+ 'max': 1.0,
+ 'dtype': np.float32
+ },
+ ),
+ rew_space=EnvElementInfo(
+ shape=1,
+ value={
+ 'min': np.float64("-inf"),
+ 'max': np.float64("inf")
+ },
+ ),
+ use_wrappers=None,
+ ),
+ 'InvertedDoublePendulumMuJoCoEnv-v0': BaseEnvInfo(
+ agent_num=1,
+ obs_space=EnvElementInfo(
+ shape=(11, ),
+ value={
+ 'min': np.float64("-inf"),
+ 'max': np.float64("inf"),
+ 'dtype': np.float32
+ },
+ ),
+ act_space=EnvElementInfo(
+ shape=(1, ),
+ value={
+ 'min': -1.0,
+ 'max': 1.0,
+ 'dtype': np.float32
+ },
+ ),
+ rew_space=EnvElementInfo(
+ shape=1,
+ value={
+ 'min': np.float64("-inf"),
+ 'max': np.float64("inf")
+ },
+ ),
+ use_wrappers=None,
+ ),
+ 'Walker2DMuJoCoEnv-v0': BaseEnvInfo(
+ agent_num=1,
+ obs_space=EnvElementInfo(
+ shape=(17, ),
+ value={
+ 'min': np.float64("-inf"),
+ 'max': np.float64("inf"),
+ 'dtype': np.float32
+ },
+ ),
+ act_space=EnvElementInfo(
+ shape=(6, ),
+ value={
+ 'min': -1.0,
+ 'max': 1.0,
+ 'dtype': np.float32
+ },
+ ),
+ rew_space=EnvElementInfo(
+ shape=1,
+ value={
+ 'min': np.float64("-inf"),
+ 'max': np.float64("inf")
+ },
+ ),
+ use_wrappers=None,
+ ),
+ 'Walker2DPyBulletEnv-v0': BaseEnvInfo(
+ agent_num=1,
+ obs_space=EnvElementInfo(
+ shape=(22, ),
+ value={
+ 'min': np.float64("-inf"),
+ 'max': np.float64("inf"),
+ 'dtype': np.float32
+ },
+ ),
+ act_space=EnvElementInfo(
+ shape=(6, ),
+ value={
+ 'min': -1.0,
+ 'max': 1.0,
+ 'dtype': np.float32
+ },
+ ),
+ rew_space=EnvElementInfo(
+ shape=1,
+ value={
+ 'min': np.float64("-inf"),
+ 'max': np.float64("inf")
+ },
+ ),
+ use_wrappers=None,
+ ),
+ 'HalfCheetahMuJoCoEnv-v0': BaseEnvInfo(
+ agent_num=1,
+ obs_space=EnvElementInfo(
+ shape=(17, ),
+ value={
+ 'min': np.float64("-inf"),
+ 'max': np.float64("inf"),
+ 'dtype': np.float32
+ },
+ ),
+ act_space=EnvElementInfo(
+ shape=(6, ),
+ value={
+ 'min': -1.0,
+ 'max': 1.0,
+ 'dtype': np.float32
+ },
+ ),
+ rew_space=EnvElementInfo(
+ shape=1,
+ value={
+ 'min': np.float64("-inf"),
+ 'max': np.float64("inf")
+ },
+ ),
+ use_wrappers=None,
+ ),
+ 'HalfCheetahPyBulletEnv-v0': BaseEnvInfo(
+ agent_num=1,
+ obs_space=EnvElementInfo(
+ shape=(26, ),
+ value={
+ 'min': np.float64("-inf"),
+ 'max': np.float64("inf"),
+ 'dtype': np.float32
+ },
+ ),
+ act_space=EnvElementInfo(
+ shape=(6, ),
+ value={
+ 'min': -1.0,
+ 'max': 1.0,
+ 'dtype': np.float32
+ },
+ ),
+ rew_space=EnvElementInfo(
+ shape=1,
+ value={
+ 'min': np.float64("-inf"),
+ 'max': np.float64("inf")
+ },
+ ),
+ use_wrappers=None,
+ ),
+ 'AntMuJoCoEnv-v0': BaseEnvInfo(
+ agent_num=1,
+ obs_space=EnvElementInfo(
+ shape=(111, ),
+ value={
+ 'min': np.float64("-inf"),
+ 'max': np.float64("inf"),
+ 'dtype': np.float32
+ },
+ ),
+ act_space=EnvElementInfo(
+ shape=(8, ),
+ value={
+ 'min': -1.0,
+ 'max': 1.0,
+ 'dtype': np.float32
+ },
+ ),
+ rew_space=EnvElementInfo(
+ shape=1,
+ value={
+ 'min': np.float64("-inf"),
+ 'max': np.float64("inf")
+ },
+ ),
+ use_wrappers=None,
+ ),
+ 'AntPyBulletEnv-v0': BaseEnvInfo(
+ agent_num=1,
+ obs_space=EnvElementInfo(
+ shape=(28, ),
+ value={
+ 'min': np.float64("-inf"),
+ 'max': np.float64("inf"),
+ 'dtype': np.float32
+ },
+ ),
+ act_space=EnvElementInfo(
+ shape=(8, ),
+ value={
+ 'min': -1.0,
+ 'max': 1.0,
+ 'dtype': np.float32
+ },
+ ),
+ rew_space=EnvElementInfo(
+ shape=1,
+ value={
+ 'min': np.float64("-inf"),
+ 'max': np.float64("inf")
+ },
+ ),
+ use_wrappers=None,
+ ),
+ 'HopperMuJoCoEnv-v0': BaseEnvInfo(
+ agent_num=1,
+ obs_space=EnvElementInfo(
+ shape=(11, ),
+ value={
+ 'min': np.float64("-inf"),
+ 'max': np.float64("inf"),
+ 'dtype': np.float32
+ },
+ ),
+ act_space=EnvElementInfo(
+ shape=(3, ),
+ value={
+ 'min': -1.0,
+ 'max': 1.0,
+ 'dtype': np.float32
+ },
+ ),
+ rew_space=EnvElementInfo(
+ shape=1,
+ value={
+ 'min': np.float64("-inf"),
+ 'max': np.float64("inf")
+ },
+ ),
+ use_wrappers=None,
+ ),
+ 'HopperPyBulletEnv-v0': BaseEnvInfo(
+ agent_num=1,
+ obs_space=EnvElementInfo(
+ shape=(15, ),
+ value={
+ 'min': np.float64("-inf"),
+ 'max': np.float64("inf"),
+ 'dtype': np.float32
+ },
+ ),
+ act_space=EnvElementInfo(
+ shape=(3, ),
+ value={
+ 'min': -1.0,
+ 'max': 1.0,
+ 'dtype': np.float32
+ },
+ ),
+ rew_space=EnvElementInfo(
+ shape=1,
+ value={
+ 'min': np.float64("-inf"),
+ 'max': np.float64("inf")
+ },
+ ),
+ use_wrappers=None,
+ ),
+}
+
+
+@ENV_REGISTRY.register('pybullet')
+class PybulletEnv(BaseEnv):
+ """
+ Note:
+        Since MuJoCo is now open source, DI-engine will deprecate the PyBullet env. If anyone needs it, \
+        please open a new issue and we will continue to maintain it.
+ """
+
+ def __init__(self, cfg: dict) -> None:
+ logging.warning('PybulletEnv is deprecated, if anyone needs it, please add a new issue.')
+ self._cfg = cfg
+ self._use_act_scale = cfg.use_act_scale
+ self._init_flag = False
+
+ def reset(self) -> np.ndarray:
+ if not self._init_flag:
+ self._env = self._make_env(only_info=False)
+ self._init_flag = True
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+ obs = self._env.reset()
+ obs = to_ndarray(obs).astype('float32')
+ self._eval_episode_return = 0.
+ return obs
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def step(self, action: Union[np.ndarray, list]) -> BaseEnvTimestep:
+ action = to_ndarray(action)
+ if self._use_act_scale:
+ action_range = self.info().act_space.value
+ action = affine_transform(action, min_val=action_range['min'], max_val=action_range['max'])
+ obs, rew, done, info = self._env.step(action)
+ self._eval_episode_return += rew
+ obs = to_ndarray(obs).astype('float32')
+        rew = to_ndarray([rew])  # wrap the scalar reward into an array with shape (1, )
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def info(self) -> BaseEnvInfo:
+ if self._cfg.env_id in Pybullet_INFO_DICT:
+ info = copy.deepcopy(Pybullet_INFO_DICT[self._cfg.env_id])
+ info.use_wrappers = self._make_env(only_info=True)
+ obs_shape, act_shape, rew_shape = update_shape(
+ info.obs_space.shape, info.act_space.shape, info.rew_space.shape, info.use_wrappers.split('\n')
+ )
+ info.obs_space.shape = obs_shape
+ info.act_space.shape = act_shape
+ info.rew_space.shape = rew_shape
+ return info
+ else:
+ keys = Pybullet_INFO_DICT.keys()
+ raise NotImplementedError('{} not found in Pybullet_INFO_DICT [{}]'.format(self._cfg.env_id, keys))
+
+ def _make_env(self, only_info=False):
+ return wrap_pybullet(
+ self._cfg.env_id,
+ norm_obs=self._cfg.get('norm_obs', None),
+ norm_reward=self._cfg.get('norm_reward', None),
+ only_info=only_info
+ )
+
+ def __repr__(self) -> str:
+ return "DI-engine Pybullet Env({})".format(self._cfg.env_id)
+
+ @staticmethod
+ def create_collector_env_cfg(cfg: dict) -> List[dict]:
+ collector_cfg = copy.deepcopy(cfg)
+ collector_env_num = collector_cfg.pop('collector_env_num', 1)
+ return [collector_cfg for _ in range(collector_env_num)]
+
+ @staticmethod
+ def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
+ evaluator_cfg = copy.deepcopy(cfg)
+ evaluator_env_num = evaluator_cfg.pop('evaluator_env_num', 1)
+ evaluator_cfg.norm_reward.use_norm = False
+ return [evaluator_cfg for _ in range(evaluator_env_num)]
diff --git a/DI-engine/dizoo/pybullet/envs/pybullet_wrappers.py b/DI-engine/dizoo/pybullet/envs/pybullet_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e844a2ff9961043247b91257380e3ce62fb26ed
--- /dev/null
+++ b/DI-engine/dizoo/pybullet/envs/pybullet_wrappers.py
@@ -0,0 +1,38 @@
+import gym
+import numpy as np
+from ditk import logging
+
+from ding.envs import ObsNormWrapper, RewardNormWrapper
+
+try:
+    import pybulletgym  # register PyBullet environments with OpenAI Gym
+except ImportError:
+ logging.warning("not found pybullet env, please install it, refer to https://github.com/benelot/pybullet-gym")
+
+
+def wrap_pybullet(env_id, norm_obs=True, norm_reward=True, only_info=False) -> gym.Env:
+ r"""
+ Overview:
+ Wrap Pybullet Env to preprocess env step's return info, e.g. observation normalization, reward normalization, etc.
+ Arguments:
+        - env_id (:obj:`str`): Pybullet environment id, for example "HopperPyBulletEnv-v0"
+ - norm_obs (:obj:`EasyDict`): Whether to normalize observation or not
+ - norm_reward (:obj:`EasyDict`): Whether to normalize reward or not. For evaluator, environment's reward \
+ should not be normalized: Either ``norm_reward`` is None or ``norm_reward.use_norm`` is False can do this.
+ Returns:
+ - wrapped_env (:obj:`gym.Env`): The wrapped Pybullet environment
+ """
+ if not only_info:
+ env = gym.make(env_id)
+ if norm_obs is not None and norm_obs.use_norm:
+ env = ObsNormWrapper(env)
+ if norm_reward is not None and norm_reward.use_norm:
+ env = RewardNormWrapper(env, norm_reward.reward_discount)
+ return env
+ else:
+ wrapper_info = ''
+ if norm_obs is not None and norm_obs.use_norm:
+            wrapper_info += ObsNormWrapper.__name__ + '\n'
+        if norm_reward is not None and norm_reward.use_norm:
+            wrapper_info += RewardNormWrapper.__name__ + '\n'
+ return wrapper_info
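+
+
+if __name__ == '__main__':
+    # Minimal usage sketch (assumption: pybullet-gym is installed); the EasyDict fields mirror the
+    # norm_obs / norm_reward configs documented above.
+    from easydict import EasyDict
+    env = wrap_pybullet(
+        'HopperPyBulletEnv-v0',
+        norm_obs=EasyDict(use_norm=True),
+        norm_reward=EasyDict(use_norm=True, reward_discount=0.99),
+    )
+    obs = env.reset()
+    print(obs.shape, env.action_space)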
diff --git a/DI-engine/dizoo/rocket/README.md b/DI-engine/dizoo/rocket/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..c9e49e47d2d854ce319b2297c2429e1743cd3bee
--- /dev/null
+++ b/DI-engine/dizoo/rocket/README.md
@@ -0,0 +1,10 @@
+# Install
+
+```shell
+pip install git+https://github.com/nighood/rocket-recycling@master#egg=rocket_recycling
+```
+
+# Check Install
+```shell
+pytest -sv test_rocket_env.py
+```
diff --git a/DI-engine/dizoo/rocket/__init__.py b/DI-engine/dizoo/rocket/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/rocket/config/__init__.py b/DI-engine/dizoo/rocket/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/rocket/config/rocket_hover_ppo_config.py b/DI-engine/dizoo/rocket/config/rocket_hover_ppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ad4f902ff13ed54710fef2292b5804621019bf6c
--- /dev/null
+++ b/DI-engine/dizoo/rocket/config/rocket_hover_ppo_config.py
@@ -0,0 +1,61 @@
+from easydict import EasyDict
+
+rocket_ppo_config = dict(
+ exp_name='rocket_hovering_onppo_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=330,
+ task='hover',
+ max_steps=800,
+ replay_path='rocket_onppo_seed0/video',
+ ),
+ policy=dict(
+ cuda=True,
+ action_space='discrete',
+ model=dict(
+ obs_shape=8,
+ action_shape=9,
+ action_space='discrete',
+ encoder_hidden_size_list=[64, 64, 128],
+ critic_head_hidden_size=128,
+ actor_head_hidden_size=128,
+ ),
+ learn=dict(
+ epoch_per_collect=10,
+ batch_size=64,
+ learning_rate=3e-4,
+ value_weight=0.5,
+ entropy_weight=0.01,
+ clip_ratio=0.2,
+ adv_norm=False,
+ value_norm=False,
+ learner=dict(hook=dict(save_ckpt_after_iter=100)),
+ ),
+ collect=dict(
+ n_sample=2048,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, ), ),
+ ),
+)
+rocket_ppo_config = EasyDict(rocket_ppo_config)
+main_config = rocket_ppo_config
+rocket_ppo_create_config = dict(
+ env=dict(
+ type='rocket',
+ import_names=['dizoo.rocket.envs.rocket_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+)
+rocket_ppo_create_config = EasyDict(rocket_ppo_create_config)
+create_config = rocket_ppo_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c rocket_hover_ppo_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/rocket/config/rocket_landing_ppo_config.py b/DI-engine/dizoo/rocket/config/rocket_landing_ppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..825c7f039d9886a5c00169f02fc2fcb290b13c01
--- /dev/null
+++ b/DI-engine/dizoo/rocket/config/rocket_landing_ppo_config.py
@@ -0,0 +1,61 @@
+from easydict import EasyDict
+
+rocket_ppo_config = dict(
+ exp_name='rocket_landing_onppo_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ stop_value=2200,
+ task='landing',
+ max_steps=800,
+ replay_path='rocket_landing_onppo_seed0/video',
+ ),
+ policy=dict(
+ cuda=True,
+ action_space='discrete',
+ model=dict(
+ obs_shape=8,
+ action_shape=9,
+ action_space='discrete',
+ encoder_hidden_size_list=[64, 64, 128],
+ critic_head_hidden_size=128,
+ actor_head_hidden_size=128,
+ ),
+ learn=dict(
+ epoch_per_collect=10,
+ batch_size=64,
+ learning_rate=3e-4,
+ value_weight=0.5,
+ entropy_weight=0.01,
+ clip_ratio=0.2,
+ adv_norm=True,
+ value_norm=True,
+ learner=dict(hook=dict(save_ckpt_after_iter=100)),
+ ),
+ collect=dict(
+ n_sample=2048,
+ unroll_len=1,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ eval=dict(evaluator=dict(eval_freq=1000, ), ),
+ ),
+)
+rocket_ppo_config = EasyDict(rocket_ppo_config)
+main_config = rocket_ppo_config
+rocket_ppo_create_config = dict(
+ env=dict(
+ type='rocket',
+ import_names=['dizoo.rocket.envs.rocket_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+)
+rocket_ppo_create_config = EasyDict(rocket_ppo_create_config)
+create_config = rocket_ppo_create_config
+
+if __name__ == "__main__":
+ # or you can enter `ding -m serial_onpolicy -c rocket_landing_ppo_config.py -s 0`
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/rocket/entry/__init__.py b/DI-engine/dizoo/rocket/entry/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/rocket/entry/rocket_hover_onppo_main_v2.py b/DI-engine/dizoo/rocket/entry/rocket_hover_onppo_main_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..503312f47b04c30576b655d6cd70d673b53cd469
--- /dev/null
+++ b/DI-engine/dizoo/rocket/entry/rocket_hover_onppo_main_v2.py
@@ -0,0 +1,95 @@
+import os
+import gym
+import numpy as np
+from tensorboardX import SummaryWriter
+import torch
+from rocket_recycling.rocket import Rocket
+
+from ditk import logging
+from ding.model import VAC
+from ding.policy import PPOPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2, EvalEpisodeReturnWrapper
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \
+ gae_estimator, termination_checker
+from ding.utils import set_pkg_seed
+from dizoo.rocket.config.rocket_hover_ppo_config import main_config, create_config
+
+
+class RocketHoverWrapper(gym.Wrapper):
+
+ def __init__(self, env):
+ super().__init__(env)
+ self._observation_space = gym.spaces.Box(low=float("-inf"), high=float("inf"), shape=(8, ), dtype=np.float32)
+ self._action_space = gym.spaces.Discrete(9)
+ self._action_space.seed(0) # default seed
+ self.reward_range = (float('-inf'), float('inf'))
+
+
+def wrapped_rocket_env(task, max_steps):
+ return DingEnvWrapper(
+ Rocket(task=task, max_steps=max_steps),
+ cfg={'env_wrapper': [
+ lambda env: RocketHoverWrapper(env),
+ lambda env: EvalEpisodeReturnWrapper(env),
+ ]}
+ )
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ main_config.policy.cuda = True
+ print('torch.cuda.is_available(): ', torch.cuda.is_available())
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ num_seed = 3
+ for seed_i in range(num_seed):
+ main_config.exp_name = f'task_rocket_hovering_onppo_seed{seed_i}'
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'seed' + str(seed_i)))
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[
+ lambda: wrapped_rocket_env(cfg.env.task, cfg.env.max_steps)
+ for _ in range(cfg.env.collector_env_num)
+ ],
+ cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[
+ lambda: wrapped_rocket_env(cfg.env.task, cfg.env.max_steps)
+ for _ in range(cfg.env.evaluator_env_num)
+ ],
+ cfg=cfg.env.manager
+ )
+
+ # evaluator_env.enable_save_replay()
+
+ set_pkg_seed(seed_i, use_cuda=cfg.policy.cuda)
+
+ model = VAC(**cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
+
+ def _add_scalar(ctx):
+ if ctx.eval_value != -np.inf:
+ tb_logger.add_scalar('evaluator_step/reward', ctx.eval_value, global_step=ctx.env_step)
+ collector_rewards = [ctx.trajectories[i]['reward'] for i in range(len(ctx.trajectories))]
+ collector_mean_reward = sum(collector_rewards) / len(ctx.trajectories)
+ collector_max_reward = max(collector_rewards)
+ collector_min_reward = min(collector_rewards)
+ tb_logger.add_scalar('collector_step/mean_reward', collector_mean_reward, global_step=ctx.env_step)
+ tb_logger.add_scalar('collector_step/max_reward', collector_max_reward, global_step=ctx.env_step)
+ tb_logger.add_scalar('collector_step/min_reward', collector_min_reward, global_step=ctx.env_step)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(_add_scalar)
+ task.use(gae_estimator(cfg, policy.collect_mode))
+ task.use(multistep_trainer(cfg, policy.learn_mode))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.use(termination_checker(max_env_step=int(10e7)))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/dizoo/rocket/entry/rocket_hover_ppo_main.py b/DI-engine/dizoo/rocket/entry/rocket_hover_ppo_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..13f571448379b848ee082024dbdf832d755482fa
--- /dev/null
+++ b/DI-engine/dizoo/rocket/entry/rocket_hover_ppo_main.py
@@ -0,0 +1,68 @@
+import gym
+from ditk import logging
+from ding.model import VAC
+from ding.policy import PPOPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \
+ gae_estimator, termination_checker
+from ding.utils import set_pkg_seed
+from dizoo.rocket.envs.rocket_env import RocketEnv
+from dizoo.rocket.config.rocket_hover_ppo_config import main_config, create_config
+import numpy as np
+from tensorboardX import SummaryWriter
+import os
+import torch
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ main_config.policy.cuda = True
+ print('torch.cuda.is_available(): ', torch.cuda.is_available())
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ num_seed = 3
+ for seed_i in range(num_seed):
+ main_config.exp_name = f'task_rocket_hovering_onppo_seed{seed_i}'
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'seed' + str(seed_i)))
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: RocketEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: RocketEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ # evaluator_env.enable_save_replay()
+
+ set_pkg_seed(seed_i, use_cuda=cfg.policy.cuda)
+
+ model = VAC(**cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
+
+ def _add_scalar(ctx):
+ if ctx.eval_value != -np.inf:
+ tb_logger.add_scalar('evaluator_step/reward', ctx.eval_value, global_step=ctx.env_step)
+ collector_rewards = [ctx.trajectories[i]['reward'] for i in range(len(ctx.trajectories))]
+ collector_mean_reward = sum(collector_rewards) / len(ctx.trajectories)
+ collector_max_reward = max(collector_rewards)
+ collector_min_reward = min(collector_rewards)
+ tb_logger.add_scalar('collector_step/mean_reward', collector_mean_reward, global_step=ctx.env_step)
+ tb_logger.add_scalar('collector_step/max_reward', collector_max_reward, global_step=ctx.env_step)
+ tb_logger.add_scalar('collector_step/min_reward', collector_min_reward, global_step=ctx.env_step)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(_add_scalar)
+ task.use(gae_estimator(cfg, policy.collect_mode))
+ task.use(multistep_trainer(cfg, policy.learn_mode))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.use(termination_checker(max_env_step=int(10e7)))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/dizoo/rocket/entry/rocket_landing_onppo_main_v2.py b/DI-engine/dizoo/rocket/entry/rocket_landing_onppo_main_v2.py
new file mode 100644
index 0000000000000000000000000000000000000000..bd682ecf42a8989a060ae03e4aa8d3430e19e147
--- /dev/null
+++ b/DI-engine/dizoo/rocket/entry/rocket_landing_onppo_main_v2.py
@@ -0,0 +1,95 @@
+import os
+import torch
+import gym
+import numpy as np
+from tensorboardX import SummaryWriter
+from rocket_recycling.rocket import Rocket
+
+from ditk import logging
+from ding.model import VAC
+from ding.policy import PPOPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2, EvalEpisodeReturnWrapper
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \
+ gae_estimator, termination_checker
+from ding.utils import set_pkg_seed
+from dizoo.rocket.config.rocket_landing_ppo_config import main_config, create_config
+
+
+class RocketLandingWrapper(gym.Wrapper):
+
+ def __init__(self, env):
+ super().__init__(env)
+ self._observation_space = gym.spaces.Box(low=float("-inf"), high=float("inf"), shape=(8, ), dtype=np.float32)
+ self._action_space = gym.spaces.Discrete(9)
+ self._action_space.seed(0) # default seed
+ self.reward_range = (float('-inf'), float('inf'))
+
+
+def wrapped_rocket_env(task, max_steps):
+ return DingEnvWrapper(
+ Rocket(task=task, max_steps=max_steps),
+ cfg={'env_wrapper': [
+ lambda env: RocketLandingWrapper(env),
+ lambda env: EvalEpisodeReturnWrapper(env),
+ ]}
+ )
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ main_config.exp_name = 'rocket_landing_ppo_nseed'
+ main_config.policy.cuda = True
+ print('torch.cuda.is_available(): ', torch.cuda.is_available())
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ num_seed = 4
+ for seed_i in range(num_seed):
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'seed' + str(seed_i)))
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[
+ lambda: wrapped_rocket_env(cfg.env.task, cfg.env.max_steps)
+ for _ in range(cfg.env.collector_env_num)
+ ],
+ cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[
+ lambda: wrapped_rocket_env(cfg.env.task, cfg.env.max_steps)
+ for _ in range(cfg.env.evaluator_env_num)
+ ],
+ cfg=cfg.env.manager
+ )
+
+ # evaluator_env.enable_save_replay()
+
+ set_pkg_seed(seed_i, use_cuda=cfg.policy.cuda)
+
+ model = VAC(**cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
+
+ def _add_scalar(ctx):
+ if ctx.eval_value != -np.inf:
+ tb_logger.add_scalar('evaluator_step/reward', ctx.eval_value, global_step=ctx.env_step)
+ collector_rewards = [ctx.trajectories[i]['reward'] for i in range(len(ctx.trajectories))]
+ collector_mean_reward = sum(collector_rewards) / len(ctx.trajectories)
+ collector_max_reward = max(collector_rewards)
+ collector_min_reward = min(collector_rewards)
+ tb_logger.add_scalar('collector_step/mean_reward', collector_mean_reward, global_step=ctx.env_step)
+ tb_logger.add_scalar('collector_step/max_reward', collector_max_reward, global_step=ctx.env_step)
+ tb_logger.add_scalar('collector_step/min_reward', collector_min_reward, global_step=ctx.env_step)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(gae_estimator(cfg, policy.collect_mode))
+ task.use(multistep_trainer(cfg, policy.learn_mode))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ # task.use(_add_scalar)
+ task.use(termination_checker(max_env_step=int(3e6)))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/dizoo/rocket/entry/rocket_landing_ppo_main.py b/DI-engine/dizoo/rocket/entry/rocket_landing_ppo_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf8ebb51625b0a5137f2ac478749bec15c849921
--- /dev/null
+++ b/DI-engine/dizoo/rocket/entry/rocket_landing_ppo_main.py
@@ -0,0 +1,68 @@
+import gym
+from ditk import logging
+from ding.model import VAC
+from ding.policy import PPOPolicy
+from ding.envs import DingEnvWrapper, BaseEnvManagerV2
+from ding.data import DequeBuffer
+from ding.config import compile_config
+from ding.framework import task
+from ding.framework.context import OnlineRLContext
+from ding.framework.middleware import multistep_trainer, StepCollector, interaction_evaluator, CkptSaver, \
+ gae_estimator, termination_checker
+from ding.utils import set_pkg_seed
+from dizoo.rocket.envs.rocket_env import RocketEnv
+from dizoo.rocket.config.rocket_landing_ppo_config import main_config, create_config
+import numpy as np
+from tensorboardX import SummaryWriter
+import os
+import torch
+
+
+def main():
+ logging.getLogger().setLevel(logging.INFO)
+ main_config.exp_name = 'rocket_landing_ppo_nseed'
+ main_config.policy.cuda = True
+ print('torch.cuda.is_available(): ', torch.cuda.is_available())
+ cfg = compile_config(main_config, create_cfg=create_config, auto=True)
+ num_seed = 4
+ for seed_i in range(num_seed):
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'seed' + str(seed_i)))
+ with task.start(async_mode=False, ctx=OnlineRLContext()):
+ collector_env = BaseEnvManagerV2(
+ env_fn=[lambda: RocketEnv(cfg.env) for _ in range(cfg.env.collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = BaseEnvManagerV2(
+ env_fn=[lambda: RocketEnv(cfg.env) for _ in range(cfg.env.evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ # evaluator_env.enable_save_replay()
+
+ set_pkg_seed(seed_i, use_cuda=cfg.policy.cuda)
+
+ model = VAC(**cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
+
+ def _add_scalar(ctx):
+ if ctx.eval_value != -np.inf:
+ tb_logger.add_scalar('evaluator_step/reward', ctx.eval_value, global_step=ctx.env_step)
+ collector_rewards = [ctx.trajectories[i]['reward'] for i in range(len(ctx.trajectories))]
+ collector_mean_reward = sum(collector_rewards) / len(ctx.trajectories)
+ collector_max_reward = max(collector_rewards)
+ collector_min_reward = min(collector_rewards)
+ tb_logger.add_scalar('collector_step/mean_reward', collector_mean_reward, global_step=ctx.env_step)
+ tb_logger.add_scalar('collector_step/max_reward', collector_max_reward, global_step=ctx.env_step)
+ tb_logger.add_scalar('collector_step/min_reward', collector_min_reward, global_step=ctx.env_step)
+
+ task.use(interaction_evaluator(cfg, policy.eval_mode, evaluator_env))
+ task.use(StepCollector(cfg, policy.collect_mode, collector_env))
+ task.use(gae_estimator(cfg, policy.collect_mode))
+ task.use(multistep_trainer(cfg, policy.learn_mode))
+ task.use(CkptSaver(policy, cfg.exp_name, train_freq=100))
+ task.use(_add_scalar)
+ task.use(termination_checker(max_env_step=int(3e6)))
+ task.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/DI-engine/dizoo/rocket/envs/__init__.py b/DI-engine/dizoo/rocket/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..47b076e8afda29a7e0437292704f2ceb2e6a19dd
--- /dev/null
+++ b/DI-engine/dizoo/rocket/envs/__init__.py
@@ -0,0 +1 @@
+from .rocket_env import RocketEnv
diff --git a/DI-engine/dizoo/rocket/envs/rocket_env.py b/DI-engine/dizoo/rocket/envs/rocket_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..dd77fdafa7b802629425dd1166e9a26d8342f4aa
--- /dev/null
+++ b/DI-engine/dizoo/rocket/envs/rocket_env.py
@@ -0,0 +1,111 @@
+from typing import Any, List, Union, Optional
+import time
+import os
+import imageio
+import gym
+import copy
+import numpy as np
+from easydict import EasyDict
+from rocket_recycling.rocket import Rocket
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.torch_utils import to_ndarray, to_list
+from ding.utils import ENV_REGISTRY
+from ding.envs import ObsPlusPrevActRewWrapper
+
+
+@ENV_REGISTRY.register('rocket', force_overwrite=True)
+class RocketEnv(BaseEnv):
+
+ def __init__(self, cfg: dict = {}) -> None:
+ self._cfg = cfg
+ self._init_flag = False
+ self._save_replay = False
+ self._observation_space = gym.spaces.Box(low=float("-inf"), high=float("inf"), shape=(8, ), dtype=np.float32)
+ self._action_space = gym.spaces.Discrete(9)
+ self._action_space.seed(0) # default seed
+ self._reward_space = gym.spaces.Box(low=float("-inf"), high=float("inf"), shape=(1, ), dtype=np.float32)
+
+ def reset(self) -> np.ndarray:
+ if not self._init_flag:
+ self._env = Rocket(task=self._cfg.task, max_steps=self._cfg.max_steps)
+ self._init_flag = True
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ self._action_space.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+ self._action_space.seed(self._seed)
+ self._eval_episode_return = 0
+ obs = self._env.reset()
+ obs = to_ndarray(obs)
+ if self._save_replay:
+ self._frames = []
+ return obs
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def step(self, action: Union[int, np.ndarray]) -> BaseEnvTimestep:
+ if isinstance(action, np.ndarray) and action.shape == (1, ):
+ action = action.squeeze() # 0-dim array
+
+ obs, rew, done, info = self._env.step(action)
+ self._env.render()
+ self._eval_episode_return += rew
+
+ if self._save_replay:
+ self._frames.extend(self._env.render())
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+ if self._save_replay:
+ path = os.path.join(self._replay_path, '{}_episode.gif'.format(self._save_replay_count))
+ self.display_frames_as_gif(self._frames, path)
+ self._save_replay_count += 1
+ obs = to_ndarray(obs)
+ # wrap the scalar reward into an array with shape (1, )
+ rew = to_ndarray([rew]).astype(np.float32)
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ if replay_path is None:
+ replay_path = './video'
+ self._save_replay = True
+ if not os.path.exists(replay_path):
+ os.makedirs(replay_path)
+ self._replay_path = replay_path
+ self._save_replay_count = 0
+
+ def random_action(self) -> np.ndarray:
+ random_action = self.action_space.sample()
+ random_action = to_ndarray([random_action], dtype=np.int64)
+ return random_action
+
+ def clone(self, caller: str) -> 'RocketEnv':
+ return RocketEnv(copy.deepcopy(self._cfg))
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ def __repr__(self) -> str:
+ return "DI-engine Rocket Env"
+
+ @staticmethod
+ def display_frames_as_gif(frames: list, path: str) -> None:
+ imageio.mimsave(path, frames, fps=20)
diff --git a/DI-engine/dizoo/rocket/envs/test_rocket_env.py b/DI-engine/dizoo/rocket/envs/test_rocket_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..a8bf030fe774fc52113672ec6c522dacb8166d46
--- /dev/null
+++ b/DI-engine/dizoo/rocket/envs/test_rocket_env.py
@@ -0,0 +1,36 @@
+import pytest
+import numpy as np
+from dizoo.rocket.envs import RocketEnv
+from easydict import EasyDict
+
+
+@pytest.mark.envtest
+class TestRocketEnv:
+
+ def test_hover(self):
+ env = RocketEnv(EasyDict({'task': 'hover', 'max_steps': 800}))
+ env.seed(314, dynamic_seed=False)
+ assert env._seed == 314
+ obs = env.reset()
+ assert obs.shape == (8, )
+ for _ in range(5):
+ env.reset()
+ np.random.seed(314)
+ print('=' * 60)
+ for i in range(10):
+ # Both ``env.random_action()`` and sampling from the action space together with ``np.random``
+ # can generate legal random actions.
+ if i < 5:
+ random_action = np.array([env.action_space.sample()])
+ else:
+ random_action = env.random_action()
+ timestep = env.step(random_action)
+ print('timestep', timestep, '\n')
+ assert isinstance(timestep.obs, np.ndarray)
+ assert isinstance(timestep.done, bool)
+ assert timestep.obs.shape == (8, )
+ assert timestep.reward.shape == (1, )
+ assert timestep.reward >= env.reward_space.low
+ assert timestep.reward <= env.reward_space.high
+ print(env.observation_space, env.action_space, env.reward_space)
+ env.close()
diff --git a/DI-engine/dizoo/slime_volley/__init__.py b/DI-engine/dizoo/slime_volley/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/slime_volley/config/slime_volley_league_ppo_config.py b/DI-engine/dizoo/slime_volley/config/slime_volley_league_ppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5675a768f1637907d4f4285b0c31be49844179c5
--- /dev/null
+++ b/DI-engine/dizoo/slime_volley/config/slime_volley_league_ppo_config.py
@@ -0,0 +1,70 @@
+from easydict import EasyDict
+
+league_demo_ppo_config = dict(
+ exp_name="slime_volley_league_ppo_seed0",
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ # The agent_vs_agent field is not set here because each entry file sets it for its own usage.
+ stop_value=5, # each episode consists of 5 lives, so 5 is the best possible episode return
+ env_id="SlimeVolley-v0",
+ ),
+ policy=dict(
+ cuda=True,
+ action_space='discrete',
+ model=dict(
+ obs_shape=12,
+ action_shape=6,
+ action_space='discrete',
+ encoder_hidden_size_list=[64, 64],
+ critic_head_hidden_size=64,
+ actor_head_hidden_size=64,
+ share_encoder=False, # sharing the encoder is not recommended for low-dimensional observations.
+ ),
+ learn=dict(
+ epoch_per_collect=5,
+ batch_size=256,
+ learning_rate=3e-4,
+ entropy_weight=0.001, # [0.01, 0.001, 0.0]
+ clip_ratio=0.2,
+ ),
+ collect=dict(
+ n_episode=16,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ other=dict(
+ league=dict(
+ player_category=['default'],
+ # path to save the league players' policies; users can specify this field, e.g.:
+ # path_policy="slime_volley_league_ppo_seed0/policy"
+ path_policy="policy_path_placeholder",
+ active_players=dict(main_player=1, ),
+ main_player=dict(
+ one_phase_step=20000,
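+ # probabilities of the match branches; here 'pfsp' is assumed to denote prioritized fictitious self-play and 'sp' naive self-play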
+ branch_probs=dict(pfsp=0.2, sp=0.8),
+ strong_win_rate=0.7,
+ ),
+ use_pretrain=False,
+ use_pretrain_init_historical=False,
+ payoff=dict(
+ type='battle',
+ decay=0.99,
+ min_win_rate_games=4,
+ ),
+ metric=dict(
+ mu=0,
+ sigma=25 / 3,
+ beta=25 / 3 / 2,
+ tau=0.0,
+ draw_probability=0.02,
+ ),
+ ),
+ ),
+ ),
+)
+main_config = EasyDict(league_demo_ppo_config)
+# this config can be executed by two entry function for different usage
+# - dizoo/slime_volley/entry/slime_volley_selfplay_ppo_main.py
+# - dizoo/slime_volley/entry/slime_volley_league_ppo_main.py
diff --git a/DI-engine/dizoo/slime_volley/config/slime_volley_ppo_config.py b/DI-engine/dizoo/slime_volley/config/slime_volley_ppo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..90e73b76cb8a2a7073b8da7db392cec640f54a6e
--- /dev/null
+++ b/DI-engine/dizoo/slime_volley/config/slime_volley_ppo_config.py
@@ -0,0 +1,53 @@
+from easydict import EasyDict
+from ding.entry import serial_pipeline_onpolicy
+
+slime_volley_ppo_config = dict(
+ exp_name='slime_volley_ppo_seed0',
+ env=dict(
+ collector_env_num=8,
+ evaluator_env_num=5,
+ n_evaluator_episode=5,
+ agent_vs_agent=False, # vs bot
+ stop_value=5, # each episode consists of 5 lives, so 5 is the best possible episode return
+ env_id="SlimeVolley-v0",
+ ),
+ policy=dict(
+ cuda=True,
+ action_space='discrete',
+ model=dict(
+ obs_shape=12,
+ action_shape=6,
+ action_space='discrete',
+ encoder_hidden_size_list=[64, 64],
+ critic_head_hidden_size=64,
+ actor_head_hidden_size=64,
+ share_encoder=False, # sharing the encoder is not recommended for low-dimensional observations.
+ ),
+ learn=dict(
+ epoch_per_collect=5,
+ batch_size=64,
+ learning_rate=3e-4,
+ entropy_weight=0.0, # [0.01, 0.0]
+ ),
+ collect=dict(
+ n_sample=4096,
+ discount_factor=0.99,
+ gae_lambda=0.95,
+ ),
+ ),
+)
+slime_volley_ppo_config = EasyDict(slime_volley_ppo_config)
+main_config = slime_volley_ppo_config
+slime_volley_ppo_create_config = dict(
+ env=dict(
+ type='slime_volley',
+ import_names=['dizoo.slime_volley.envs.slime_volley_env'],
+ ),
+ env_manager=dict(type='subprocess'), # to save replay videos, the 'base' env manager must be used instead
+ policy=dict(type='ppo'),
+)
+slime_volley_ppo_create_config = EasyDict(slime_volley_ppo_create_config)
+create_config = slime_volley_ppo_create_config
+
+if __name__ == "__main__":
+ serial_pipeline_onpolicy([main_config, create_config], seed=0)
diff --git a/DI-engine/dizoo/slime_volley/entry/slime_volley_league_ppo_main.py b/DI-engine/dizoo/slime_volley/entry/slime_volley_league_ppo_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..8af1fc0656e9bb40738a990ad4250baf7819cdcd
--- /dev/null
+++ b/DI-engine/dizoo/slime_volley/entry/slime_volley_league_ppo_main.py
@@ -0,0 +1,191 @@
+import os
+import gym
+import numpy as np
+import copy
+import shutil
+import torch
+from tensorboardX import SummaryWriter
+from functools import partial
+from easydict import EasyDict
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, BattleEpisodeSerialCollector, NaiveReplayBuffer, InteractionSerialEvaluator
+from ding.envs import SyncSubprocessEnvManager
+from ding.policy import PPOPolicy
+from ding.model import VAC
+from ding.utils import set_pkg_seed
+from ding.league import BaseLeague, ActivePlayer
+from dizoo.slime_volley.envs import SlimeVolleyEnv
+from dizoo.slime_volley.config.slime_volley_league_ppo_config import main_config
+
+
+class MyLeague(BaseLeague):
+ # override
+ def _get_job_info(self, player: ActivePlayer, eval_flag: bool = False) -> dict:
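+ # assemble a 1v1 job: the launching active player is matched against the opponent sampled by player.get_job()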
+ assert isinstance(player, ActivePlayer), player.__class__
+ player_job_info = EasyDict(player.get_job(eval_flag))
+ return {
+ 'agent_num': 2,
+ 'launch_player': player.player_id,
+ 'player_id': [player.player_id, player_job_info.opponent.player_id],
+ 'checkpoint_path': [player.checkpoint_path, player_job_info.opponent.checkpoint_path],
+ 'player_active_flag': [isinstance(p, ActivePlayer) for p in [player, player_job_info.opponent]],
+ }
+
+ # override
+ def _mutate_player(self, player: ActivePlayer):
+ # no mutate operation
+ pass
+
+ # override
+ def _update_player(self, player: ActivePlayer, player_info: dict) -> None:
+ assert isinstance(player, ActivePlayer)
+ if 'learner_step' in player_info:
+ player.total_agent_step = player_info['learner_step']
+
+ # override
+ @staticmethod
+ def save_checkpoint(src_checkpoint_path: str, dst_checkpoint_path: str) -> None:
+ shutil.copy(src_checkpoint_path, dst_checkpoint_path)
+
+
+def main(cfg, seed=0):
+ cfg = compile_config(
+ cfg,
+ SyncSubprocessEnvManager,
+ PPOPolicy,
+ BaseLearner,
+ BattleEpisodeSerialCollector,
+ InteractionSerialEvaluator,
+ NaiveReplayBuffer,
+ save_cfg=True
+ )
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+ collector_env_cfg = copy.deepcopy(cfg.env)
+ collector_env_cfg.agent_vs_agent = True
+ evaluator_env_cfg = copy.deepcopy(cfg.env)
+ evaluator_env_cfg.agent_vs_agent = False
+ evaluator_env = SyncSubprocessEnvManager(
+ env_fn=[partial(SlimeVolleyEnv, evaluator_env_cfg) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env.seed(seed, dynamic_seed=False)
+
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ league = MyLeague(cfg.policy.other.league)
+ policies, learners, collectors = {}, {}, {}
+
+ for player_id in league.active_players_ids:
+ model = VAC(**cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
+ policies[player_id] = policy
+ collector_env = SyncSubprocessEnvManager(
+ env_fn=[partial(SlimeVolleyEnv, collector_env_cfg) for _ in range(collector_env_num)], cfg=cfg.env.manager
+ )
+ collector_env.seed(seed)
+
+ learners[player_id] = BaseLearner(
+ cfg.policy.learn.learner,
+ policy.learn_mode,
+ tb_logger,
+ exp_name=cfg.exp_name,
+ instance_name=player_id + '_learner'
+ )
+ collectors[player_id] = BattleEpisodeSerialCollector(
+ cfg.policy.collect.collector,
+ collector_env, [policy.collect_mode, policy.collect_mode],
+ tb_logger,
+ exp_name=cfg.exp_name,
+ instance_name=player_id + '_collector'
+ )
+ model = VAC(**cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
+ policies['historical'] = policy
+ main_key = [k for k in learners.keys() if k.startswith('main_player')][0]
+ main_player = league.get_player_by_id(main_key)
+ main_learner = learners[main_key]
+ main_collector = collectors[main_key]
+
+ # eval vs bot
+ evaluator_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
+ evaluator_cfg.stop_value = cfg.env.stop_value
+ evaluator = InteractionSerialEvaluator(
+ evaluator_cfg,
+ evaluator_env,
+ policy.eval_mode,
+ tb_logger,
+ exp_name=cfg.exp_name,
+ instance_name='builtin_ai_evaluator'
+ )
+
+ def load_checkpoint_fn(player_id: str, ckpt_path: str):
+ state_dict = torch.load(ckpt_path)
+ policies[player_id].learn_mode.load_state_dict(state_dict)
+
+ league.load_checkpoint = load_checkpoint_fn
+ # snapshot the initial player as the first historical player
+ for player_id, player_ckpt_path in zip(league.active_players_ids, league.active_players_ckpts):
+ torch.save(policies[player_id].collect_mode.state_dict(), player_ckpt_path)
+ league.judge_snapshot(player_id, force=True)
+
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ count = 0
+ while True:
+ if evaluator.should_eval(main_learner.train_iter):
+ stop_flag, eval_episode_info = evaluator.eval(
+ main_learner.save_checkpoint, main_learner.train_iter, main_collector.envstep
+ )
+ win_loss_result = [e['result'] for e in eval_episode_info]
+ # treat the built-in bot as a fixed opponent with rating mu=100 (sigma=1e-8) in the TrueSkill update
+ main_player.rating = league.metric_env.rate_1vsC(
+ main_player.rating, league.metric_env.create_rating(mu=100, sigma=1e-8), win_loss_result
+ )
+ if stop_flag:
+ break
+ for player_id, player_ckpt_path in zip(league.active_players_ids, league.active_players_ckpts):
+ tb_logger.add_scalar(
+ 'league/{}_trueskill'.format(player_id),
+ league.get_player_by_id(player_id).rating.exposure, main_collector.envstep
+ )
+ collector, learner = collectors[player_id], learners[player_id]
+
+ job = league.get_job_info(player_id)
+ opponent_player_id = job['player_id'][1]
+ # print('job player: {}'.format(job['player_id']))
+ if 'historical' in opponent_player_id:
+ opponent_policy = policies['historical'].collect_mode
+ opponent_path = job['checkpoint_path'][1]
+ opponent_policy.load_state_dict(torch.load(opponent_path, map_location='cpu'))
+ else:
+ opponent_policy = policies[opponent_player_id].collect_mode
+ collector.reset_policy([policies[player_id].collect_mode, opponent_policy])
+
+ new_data, episode_info = collector.collect(
+ train_iter=learner.train_iter, n_episode=cfg.policy.collect.n_episode
+ )
+ train_data = sum(new_data[0], []) # flatten the transitions from all collected episodes
+ learner.train(train_data, collector.envstep)
+
+ player_info = learner.learn_info
+ player_info['player_id'] = player_id
+ league.update_active_player(player_info)
+ league.judge_snapshot(player_id)
+ # set eval_flag=True to enable trueskill update
+ job_finish_info = {
+ 'eval_flag': True,
+ 'launch_player': job['launch_player'],
+ 'player_id': job['player_id'],
+ # result is from `info` returned from env.step
+ 'result': [e['result'] for e in episode_info[0]],
+ }
+ league.finish_job(job_finish_info)
+ if count % 50 == 0:
+ payoff_string = repr(league.payoff)
+ rank_string = league.player_rank(string=True)
+ tb_logger.add_text('payoff_step', payoff_string, main_collector.envstep)
+ tb_logger.add_text('rank_step', rank_string, main_collector.envstep)
+ count += 1
+
+
+if __name__ == "__main__":
+ main(main_config)
diff --git a/DI-engine/dizoo/slime_volley/entry/slime_volley_selfplay_ppo_main.py b/DI-engine/dizoo/slime_volley/entry/slime_volley_selfplay_ppo_main.py
new file mode 100644
index 0000000000000000000000000000000000000000..e27c19b5ebbf7beab40f1d27a60cca9a9d253e38
--- /dev/null
+++ b/DI-engine/dizoo/slime_volley/entry/slime_volley_selfplay_ppo_main.py
@@ -0,0 +1,87 @@
+import os
+import gym
+import numpy as np
+import copy
+import torch
+from tensorboardX import SummaryWriter
+from functools import partial
+
+from ding.config import compile_config
+from ding.worker import BaseLearner, BattleSampleSerialCollector, NaiveReplayBuffer, InteractionSerialEvaluator
+from ding.envs import SyncSubprocessEnvManager
+from ding.policy import PPOPolicy
+from ding.model import VAC
+from ding.utils import set_pkg_seed
+from dizoo.slime_volley.envs import SlimeVolleyEnv
+from dizoo.slime_volley.config.slime_volley_ppo_config import main_config
+
+
+def main(cfg, seed=0, max_iterations=int(1e10)):
+ """
+ Overview:
+ Naive self-play without any historical players.
+ """
+ cfg = compile_config(
+ cfg,
+ SyncSubprocessEnvManager,
+ PPOPolicy,
+ BaseLearner,
+ BattleSampleSerialCollector,
+ InteractionSerialEvaluator,
+ NaiveReplayBuffer,
+ save_cfg=True
+ )
+ collector_env_num, evaluator_env_num = cfg.env.collector_env_num, cfg.env.evaluator_env_num
+ collector_env_cfg = copy.deepcopy(cfg.env)
+ collector_env_cfg.agent_vs_agent = True
+ evaluator_env_cfg = copy.deepcopy(cfg.env)
+ evaluator_env_cfg.agent_vs_agent = False
+ collector_env = SyncSubprocessEnvManager(
+ env_fn=[partial(SlimeVolleyEnv, collector_env_cfg) for _ in range(collector_env_num)], cfg=cfg.env.manager
+ )
+ evaluator_env = SyncSubprocessEnvManager(
+ env_fn=[partial(SlimeVolleyEnv, evaluator_env_cfg) for _ in range(evaluator_env_num)], cfg=cfg.env.manager
+ )
+
+ collector_env.seed(seed)
+ evaluator_env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+
+ model = VAC(**cfg.policy.model)
+ policy = PPOPolicy(cfg.policy, model=model)
+
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(
+ cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name, instance_name='learner1'
+ )
+ collector = BattleSampleSerialCollector(
+ cfg.policy.collect.collector,
+ collector_env, [policy.collect_mode, policy.collect_mode],
+ tb_logger,
+ exp_name=cfg.exp_name
+ )
+ evaluator_cfg = copy.deepcopy(cfg.policy.eval.evaluator)
+ evaluator_cfg.stop_value = cfg.env.stop_value
+ evaluator = InteractionSerialEvaluator(
+ evaluator_cfg,
+ evaluator_env,
+ policy.eval_mode,
+ tb_logger,
+ exp_name=cfg.exp_name,
+ instance_name='builtin_ai_evaluator'
+ )
+
+ learner.call_hook('before_run')
+ for _ in range(max_iterations):
+ if evaluator.should_eval(learner.train_iter):
+ stop_flag, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop_flag:
+ break
+ new_data, _ = collector.collect(train_iter=learner.train_iter)
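+ # merge the samples collected by both sides of the self-play match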
+ train_data = new_data[0] + new_data[1]
+ learner.train(train_data, collector.envstep)
+ learner.call_hook('after_run')
+
+
+if __name__ == "__main__":
+ main(main_config)
diff --git a/DI-engine/dizoo/slime_volley/envs/__init__.py b/DI-engine/dizoo/slime_volley/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..7fc6e04830a7d4b7ea4dddc1c5e19416d2f82086
--- /dev/null
+++ b/DI-engine/dizoo/slime_volley/envs/__init__.py
@@ -0,0 +1 @@
+from .slime_volley_env import SlimeVolleyEnv
diff --git a/DI-engine/dizoo/slime_volley/envs/slime_volley_env.py b/DI-engine/dizoo/slime_volley/envs/slime_volley_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..866cc8e5d7561d35e551b0dd82e3de75dacdc173
--- /dev/null
+++ b/DI-engine/dizoo/slime_volley/envs/slime_volley_env.py
@@ -0,0 +1,197 @@
+import numpy as np
+import gym
+from typing import Any, Union, List, Optional
+import copy
+import slimevolleygym
+from gym.envs.registration import registry
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.utils import ENV_REGISTRY
+from ding.torch_utils import to_ndarray
+
+
+@ENV_REGISTRY.register('slime_volley')
+class SlimeVolleyEnv(BaseEnv):
+
+ def __init__(self, cfg) -> None:
+ self._cfg = cfg
+ self._init_flag = False
+ self._replay_path = None
+ # agent_vs_bot is a single-agent env: obs, action, done and info are all single.
+ # agent_vs_agent is a two-agent env: obs, action and info come in pairs, while done is still single.
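+ # For example, with obs_shape=12, agent_vs_agent returns stacked obs of shape (2, 12) and reward of shape (2, 1),
+ # while agent_vs_bot returns obs of shape (12, ) and reward of shape (1, ).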
+ self._agent_vs_agent = cfg.agent_vs_agent
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def step(self, action: Union[np.ndarray, List[np.ndarray]]) -> BaseEnvTimestep:
+ if self._agent_vs_agent:
+ assert isinstance(action, List) and all([isinstance(e, np.ndarray) for e in action])
+ action1, action2 = action[0], action[1]
+ else:
+ assert isinstance(action, np.ndarray)
+ action1, action2 = action, None
+ assert isinstance(action1, np.ndarray), type(action1)
+ assert action2 is None or isinstance(action2, np.ndarray), type(action2)
+ if action1.shape == (1, ):
+ action1 = action1.squeeze() # 0-dim array
+ if action2 is not None and action2.shape == (1, ):
+ action2 = action2.squeeze() # 0-dim array
+ action1 = SlimeVolleyEnv._process_action(action1)
+ action2 = SlimeVolleyEnv._process_action(action2)
+ # gym version >= 0.22.0 only supports passing the action as a single variable,
+ # so the two actions are packed into one tuple.
+ obs1, rew, done, info = self._env.step((action1, action2))
+ obs1 = to_ndarray(obs1).astype(np.float32)
+ self._eval_episode_return += rew
+ # original info keys: 'ale.lives', 'ale.otherLives', 'otherObs', 'state', 'otherState'
+ if self._agent_vs_agent:
+ info = [
+ {
+ 'ale.lives': info['ale.lives'],
+ 'state': info['state']
+ }, {
+ 'ale.lives': info['ale.otherLives'],
+ 'state': info['otherState'],
+ 'obs': info['otherObs']
+ }
+ ]
+ if done:
+ info[0]['eval_episode_return'] = self._eval_episode_return
+ info[1]['eval_episode_return'] = -self._eval_episode_return
+ info[0]['result'] = self.get_episode_result(self._eval_episode_return)
+ info[1]['result'] = self.get_episode_result(-self._eval_episode_return)
+ else:
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+ info['result'] = self.get_episode_result(self._eval_episode_return)
+ reward = to_ndarray([rew]).astype(np.float32)
+ if self._agent_vs_agent:
+ obs2 = info[1]['obs']
+ obs2 = to_ndarray(obs2).astype(np.float32)
+ observations = np.stack([obs1, obs2], axis=0)
+ rewards = to_ndarray([rew, -rew]).astype(np.float32)
+ rewards = rewards[..., np.newaxis]
+ return BaseEnvTimestep(observations, rewards, done, info)
+ else:
+ return BaseEnvTimestep(obs1, reward, done, info)
+
+ def get_episode_result(self, eval_episode_return: float):
+ if eval_episode_return > 0: # each episode consists of 5 games (lives), so the episode return can never be zero.
+ return "wins"
+ else:
+ return "losses"
+
+ def reset(self):
+ if not self._init_flag:
+ self._env = gym.make(self._cfg.env_id)
+
+ if self._replay_path is not None:
+ if gym.version.VERSION > '0.22.0':
+ # Gym removed classic control rendering in favor of pygame,
+ # so slime_volleyball currently does not support rendering.
+ self._env.metadata.update({'render_modes': ["human"]})
+ else:
+ self._env.metadata.update({'render.modes': ["human"]})
+ self._env = gym.wrappers.RecordVideo(
+ self._env,
+ video_folder=self._replay_path,
+ episode_trigger=lambda episode_id: True,
+ name_prefix='rl-video-{}'.format(id(self))
+ )
+ self._env.start_video_recorder()
+
+ ori_shape = self._env.observation_space.shape
+ self._observation_space = gym.spaces.Box(
+ low=float("-inf"),
+ high=float("inf"),
+ shape=(len(self.agents), ) + ori_shape if len(self.agents) >= 2 else ori_shape,
+ dtype=np.float32
+ )
+ self._action_space = gym.spaces.Discrete(6)
+ self._reward_space = gym.spaces.Box(low=-5, high=5, shape=(1, ), dtype=np.float32)
+ self._init_flag = True
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+ self._eval_episode_return = 0
+ obs = self._env.reset()
+ obs = to_ndarray(obs).astype(np.float32)
+ if self._agent_vs_agent:
+ obs = np.stack([obs, obs], axis=0)
+ return obs
+ else:
+ return obs
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
+
+ @property
+ def agents(self) -> List[str]:
+ if self._agent_vs_agent:
+ return ['home', 'away']
+ else:
+ return ['home']
+
+ def random_action(self) -> np.ndarray:
+ high = self.action_space.n
+ if self._agent_vs_agent:
+ return [np.random.randint(0, high, size=(1, )) for _ in range(2)]
+ else:
+ return np.random.randint(0, high, size=(1, ))
+
+ def __repr__(self):
+ return "DI-engine Slime Volley Env"
+
+ def enable_save_replay(self, replay_path: Optional[str] = None) -> None:
+ if replay_path is None:
+ replay_path = './video'
+ self._replay_path = replay_path
+
+ @staticmethod
+ def _process_action(action: np.ndarray, _type: str = "binary") -> np.ndarray:
+ if action is None:
+ return None
+ action = action.item()
+ # The env receives an integer action in [0, 5], which can be translated into either:
+ # 1) "binary" type: np.array([0, 1, 0])
+ # 2) "atari" type: NOOP, LEFT, UPLEFT, UP, UPRIGHT, RIGHT
+ to_atari_action = {
+ 0: 0, # NOOP
+ 1: 4, # LEFT
+ 2: 7, # UPLEFT
+ 3: 2, # UP
+ 4: 6, # UPRIGHT
+ 5: 3, # RIGHT
+ }
+ to_binary_action = {
+ 0: [0, 0, 0], # NOOP
+ 1: [1, 0, 0], # LEFT (forward)
+ 2: [1, 0, 1], # UPLEFT (forward jump)
+ 3: [0, 0, 1], # UP (jump)
+ 4: [0, 1, 1], # UPRIGHT (backward jump)
+ 5: [0, 1, 0], # RIGHT (backward)
+ }
+ if _type == "binary":
+ return to_ndarray(to_binary_action[action])
+ elif _type == "atari":
+ return to_atari_action[action]
+ else:
+ raise NotImplementedError
diff --git a/DI-engine/dizoo/slime_volley/envs/test_slime_volley_env.py b/DI-engine/dizoo/slime_volley/envs/test_slime_volley_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..88a089a7eb33c623bae7c28adbb127a5ceba2fe3
--- /dev/null
+++ b/DI-engine/dizoo/slime_volley/envs/test_slime_volley_env.py
@@ -0,0 +1,33 @@
+import pytest
+import numpy as np
+from easydict import EasyDict
+
+from dizoo.slime_volley.envs.slime_volley_env import SlimeVolleyEnv
+
+
+@pytest.mark.envtest
+class TestSlimeVolley:
+
+ @pytest.mark.parametrize('agent_vs_agent', [True, False])
+ def test_slime_volley(self, agent_vs_agent):
+ total_return = 0
+ env = SlimeVolleyEnv(EasyDict({'env_id': 'SlimeVolley-v0', 'agent_vs_agent': agent_vs_agent}))
+ # env.enable_save_replay('replay_video')
+ obs1 = env.reset()
+ print(env.observation_space)
+ print('observation is like:', obs1)
+ done = False
+ while not done:
+ action = env.random_action()
+ observations, rewards, done, infos = env.step(action)
+ if agent_vs_agent:
+ total_return += rewards[0]
+ else:
+ total_return += rewards
+ obs1, obs2 = observations[0], observations[1]
+ assert obs1.shape == obs2.shape, (obs1.shape, obs2.shape)
+ if agent_vs_agent:
+ agent_lives, opponent_lives = infos[0]['ale.lives'], infos[1]['ale.lives']
+ if agent_vs_agent:
+ assert agent_lives == 0 or opponent_lives == 0, (agent_lives, opponent_lives)
+ print("total return is:", total_return)
diff --git a/DI-engine/dizoo/smac/README.md b/DI-engine/dizoo/smac/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..96c7ff21b984c454c4e1ed273055a99ed86d78ee
--- /dev/null
+++ b/DI-engine/dizoo/smac/README.md
@@ -0,0 +1,98 @@
+## PYSC2 Env
+DI-engine uses the standard pysc2 env, which you can install as follows:
+```shell
+pip install pysc2
+```
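+
+You can sanity-check the installation with a quick import (this only verifies the Python package, not the StarCraft II game itself):
+```shell
+python -c "import pysc2; print('pysc2 is installed')"
+```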
+
+## SMAC Benchmark
+
+**Setting: SC2 version=4.6.2.69232, difficulty=7, 2M env steps**
+
+
+| 3s5z | pymarl | |DI-engine | | cfg |
+| :----: | :------: | :--: | :------: | :------: | :----------------------------------------------------------: |
+| | win rate | time | win rate | time | |
+| qmix | 1 | 9.5h | **1** | **3.2h** | dizoo/smac/config/smac_3s5z_qmix_config.py |
+| collaq | 1 | 28h | 0.9 | **8.5h** | dizoo/smac/config/smac_3s5z_collaq_config.py |
+| coma | 0 | 2.7h | **0.9** | **2.9h** | dizoo/smac/config/smac_3s5z_coma_config.py |
+| qtran | 0.1 | 11.5h | **0.9** | **4h** | dizoo/smac/config/smac_3s5z_qtran_config.py |
+| ippo | 0.15 | 10.5h | **0.8** | **6.8h** | |
+| mappo(ours) | - | - | **1** | **2.4h** | dizoo/smac/config/smac_3s5z_mappo_config.py |
+| masac(ours) | - | - | **1** | **4.4h** | dizoo/smac/config/smac_3s5z_masac_config.py |
+
+| 5m_vs_6m | pymarl | |DI-engine | | cfg |
+| :-------: | :------: | :--: | :------: | :------: | :----------------------------------------------------------: |
+| | win rate | time | win rate | time | |
+| qmix | **0.76** | 7.5h | 0.6 | **6.5h** | dizoo/smac/config/smac_5m6m_qmix_config.py |
+| collaq | 0.8 | 24h | 0.7 | **9.5h** | dizoo/smac/config/smac_5m6m_collaq_config.py |
+| coma | 0 | 2.5h | 0 | - | |
+| qtran | 0.7 | 7h | 0.55 | **5.5h** | dizoo/smac/config/smac_5m6m_qtran_config.py |
+| ippo | 0 | 9.2h | **0.75** | **6.9h** | |
+| mappo(ours) | - | - | **0.75** | **3.2h** | dizoo/smac/config/smac_5m6m_mappo_config.py |
+| masac(ours) | - | - | **1** | **5.2h** | dizoo/smac/config/smac_5m6m_masac_config.py |
+
+| MMM | pymarl | |DI-engine | | cfg |
+| :----: | :------: | :--: | :------: | :------: | :----------------------------------------------------------: |
+| | win rate | time | win rate | time | |
+| qmix | 1 | 9.5h | **1** | **3.5h** | dizoo/smac/config/smac_MMM_qmix_config.py |
+| collaq | 1 | 38h | **1** | **6.7h** | dizoo/smac/config/smac_MMM_collaq_config.py |
+| coma | 0.1 | 3h | **0.9** | **2.6h** | dizoo/smac/config/smac_MMM_coma_config.py |
+| qtran | 1 | 8.5h | **1** | **5.5h** | dizoo/smac/config/smac_MMM_qtran_config.py |
+| ippo | 0.33 | 7.2h | **1** | **4.7h** | |
+| mappo(ours) | - | - | **1** | **2.7h** | dizoo/smac/config/smac_MMM_mappo_config.py |
+| masac(ours) | - | - | **1** | **5.2h** | dizoo/smac/config/smac_MMM_masac_config.py |
+
+
+| MMM2 | pymarl | |DI-engine | | cfg |
+| :----: | :------: | :--: | :------: | :------: | :----------------------------------------------------------: |
+| | win rate | time | win rate | time | |
+| qmix | 0.7 | 10h | 0.4 | **5.5h** | dizoo/smac/config/smac_MMM2_qmix_config.py |
+| collaq | 0.9 | 24h | 0.6 | **13h** | dizoo/smac/config/smac_MMM2_collaq_config.py |
+| coma | 0 | 3h | **0.2** | 3.5h | dizoo/smac/config/smac_MMM2_coma_config.py |
+| qtran | 0 | 8.5h | 0 | - | |
+| ippo | 0 | 8.3h | **0.875** | **6h** | |
+| mappo(ours) | - | - | **1** | **3.8h** | dizoo/smac/config/smac_MMM2_mappo_config.py |
+| masac(ours) | - | - | **1** | **7.2h** | dizoo/smac/config/smac_MMM2_masac_config.py |
+
+
+| 3s5z_vs_3s6z | MAPPO(Wu) | |DI-engine | | cfg |
+| :----: | :------: | :--: | :------: | :------: | :----------------------------------------------------------: |
+| | win rate | time | win rate | time | |
+| mappo(ours) | - | - | **0.88** | **3.8h** | dizoo/smac/config/smac_3s5zvs3s6z_mappo_config.py |
+| masac(ours) | - | - | **1** | **7.2h** | dizoo/smac/config/smac_3s5zvs3s6z_masac_config.py |
+
+| 8m_vs_9m | MAPPO(Wu) | |DI-engine | | cfg |
+| :----: | :------: | :--: | :------: | :------: | :----------------------------------------------------------: |
+| | win rate | time | win rate | time | |
+| mappo(ours) | - | - | **1** | **3.6h** | dizoo/smac/config/smac_3s5zvs3s6z_mappo_config.py |
+| masac(ours) | - | - | **1** | **6.7h** | dizoo/smac/config/smac_3s5zvs3s6z_masac_config.py |
+
+| 10m_vs_11m | MAPPO(Wu) | |DI-engine | | cfg |
+| :----: | :------: | :--: | :------: | :------: | :----------------------------------------------------------: |
+| | win rate | time | win rate | time | |
+| mappo(ours) | - | - | **1** | **3.9h** | dizoo/smac/config/smac_10m11m_mappo_config.py |
+| masac(ours) | - | - | **1** | **6.9h** | dizoo/smac/config/smac_10m11m_masac_config.py |
+
+
+| 25m | MAPPO(Wu) | |DI-engine | | cfg |
+| :----: | :------: | :--: | :------: | :------: | :----------------------------------------------------------: |
+| | win rate | time | win rate | time | |
+| mappo(ours) | - | - | **1** | **3.7h** | dizoo/smac/config/smac_25m_mappo_config.py |
+| masac(ours) | - | - | **1** | **6.4h** | dizoo/smac/config/smac_25m_masac_config.py |
+
+
+| 2c_vs_64zg | MAPPO(Wu) | |DI-engine | | cfg |
+| :----: | :------: | :--: | :------: | :------: | :----------------------------------------------------------: |
+| | win rate | time | win rate | time | |
+| mappo(ours) | - | - | **1** | **3.2h** | dizoo/smac/config/smac_2c64zg_mappo_config.py |
+| masac(ours) | - | - | **1** | **6.1h** | dizoo/smac/config/smac_2c64zg_masac_config.py |
+
+
+| corridor | MAPPO(Wu) | |DI-engine | | cfg |
+| :----: | :------: | :--: | :------: | :------: | :----------------------------------------------------------: |
+| | win rate | time | win rate | time | |
+| mappo(ours) | - | - | **1** | **2.9h** | dizoo/smac/config/smac_corridor_mappo_config.py |
+| masac(ours) | - | - | **1** | **5.9h** | dizoo/smac/config/smac_corridor_masac_config.py |
+
+
+Note: the time reported in each table is the wall-clock time needed to run 2M env steps.
diff --git a/DI-engine/dizoo/smac/__init__.py b/DI-engine/dizoo/smac/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/smac/config/smac_10m11m_mappo_config.py b/DI-engine/dizoo/smac/config/smac_10m11m_mappo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..0dd28240eea09532db2748c0f3738d2c244b98bb
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_10m11m_mappo_config.py
@@ -0,0 +1,95 @@
+from easydict import EasyDict
+
+agent_num = 10
+collector_env_num = 8
+evaluator_env_num = 8
+special_global_state = True
+
+main_config = dict(
+ exp_name='smac_10m11m_mappo_seed0',
+ env=dict(
+ map_name='10m_vs_11m',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=32,
+ stop_value=0.99,
+ death_mask=False,
+ special_global_state=special_global_state,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ cuda=True,
+ on_policy=True,
+ multi_agent=True,
+ continuous=False,
+ model=dict(
+ # (int) agent_num: The number of agents.
+ # For SMAC 3s5z, agent_num=8; for 2c_vs_64zg, agent_num=2.
+ agent_num=agent_num,
+ # (int) agent_obs_shape: The dimension of each agent's observation.
+ # For 3s5z, agent_obs_shape=150; for 2c_vs_64zg, agent_obs_shape=404.
+ # (int) global_obs_shape: The dimension of the global observation.
+ # For 3s5z, global_obs_shape=216; for 2c_vs_64zg, global_obs_shape=342.
+ agent_obs_shape=132,
+ global_obs_shape=347,
+ # (int) action_shape: The number of actions each agent can take.
+ # action_shape = the number of common actions (6) + the number of enemies.
+ # For 3s5z, action_shape=14 (6+8); for 2c_vs_64zg, action_shape=70 (6+64).
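+ # For 10m_vs_11m here: 6 common actions + 11 enemies = 17.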
+ action_shape=17,
+ # (List[int]) The sizes of hidden layers, e.g.:
+ # hidden_size_list=[64],
+ # (the encoder is deleted in the model code)
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=512,
+ ),
+ # used in state_num of hidden_state
+ learn=dict(
+ epoch_per_collect=5,
+ batch_size=3200,
+ learning_rate=5e-4,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) The loss weight of the value network; the policy network weight is fixed to 1
+ value_weight=0.5,
+ # (float) The loss weight of entropy regularization; the policy network weight is fixed to 1
+ entropy_weight=0.01,
+ # (float) PPO clip ratio, defaults to 0.2
+ clip_ratio=0.2,
+ # (bool) Whether to use advantage norm in a whole training batch
+ adv_norm=False,
+ value_norm=True,
+ ppo_param_init=True,
+ grad_clip_type='clip_norm',
+ grad_clip_value=10,
+ ignore_done=False,
+ ),
+ collect=dict(env_num=collector_env_num, n_sample=3200),
+ eval=dict(
+ evaluator=dict(eval_freq=100, ),
+ env_num=evaluator_env_num,
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_10m11m_masac_config.py b/DI-engine/dizoo/smac/config/smac_10m11m_masac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..af80d1db0d0736e0caea73181719723772bc57ef
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_10m11m_masac_config.py
@@ -0,0 +1,91 @@
+from easydict import EasyDict
+
+agent_num = 10
+collector_env_num = 8
+evaluator_env_num = 8
+special_global_state = True
+
+SMAC_10m11m_masac_default_config = dict(
+ exp_name='smac_10m11m_masac_seed0',
+ env=dict(
+ map_name='10m_vs_11m',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=32,
+ stop_value=0.99,
+ death_mask=False,
+ special_global_state=special_global_state,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ cuda=True,
+ on_policy=False,
+ random_collect_size=0,
+ model=dict(
+ agent_obs_shape=132,
+ global_obs_shape=347,
+ action_shape=17,
+ twin_critic=True,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=50,
+ batch_size=320,
+ learning_rate_q=5e-4,
+ learning_rate_policy=5e-4,
+ learning_rate_alpha=5e-5,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ auto_alpha=True,
+ log_space=True,
+ ),
+ collect=dict(
+ env_num=collector_env_num,
+ n_sample=1600,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(
+ evaluator=dict(eval_freq=50, ),
+ env_num=evaluator_env_num,
+ ),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=100000,
+ ),
+ replay_buffer=dict(replay_buffer_size=50000, ),
+ ),
+ ),
+)
+
+SMAC_10m11m_masac_default_config = EasyDict(SMAC_10m11m_masac_default_config)
+main_config = SMAC_10m11m_masac_default_config
+
+SMAC_10m11m_masac_default_create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='sac_discrete', ),
+)
+SMAC_10m11m_masac_default_create_config = EasyDict(SMAC_10m11m_masac_default_create_config)
+create_config = SMAC_10m11m_masac_default_create_config
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_25m_mappo_config.py b/DI-engine/dizoo/smac/config/smac_25m_mappo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..fd9e6638a888e56d65a7ac91342d2d501b4f77ba
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_25m_mappo_config.py
@@ -0,0 +1,95 @@
+from easydict import EasyDict
+
+agent_num = 25
+collector_env_num = 8
+evaluator_env_num = 8
+special_global_state = True
+
+main_config = dict(
+ exp_name='smac_25m_mappo_seed0',
+ env=dict(
+ map_name='25m',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=32,
+ stop_value=0.99,
+ death_mask=False,
+ special_global_state=special_global_state,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ cuda=True,
+ on_policy=True,
+ multi_agent=True,
+ continuous=False,
+ model=dict(
+ # (int) agent_num: The number of agents.
+ # For SMAC 3s5z, agent_num=8; for 2c_vs_64zg, agent_num=2.
+ agent_num=agent_num,
+ # (int) agent_obs_shape: The dimension of each agent's observation.
+ # For 3s5z, agent_obs_shape=150; for 2c_vs_64zg, agent_obs_shape=404.
+ # (int) global_obs_shape: The dimension of the global observation.
+ # For 3s5z, global_obs_shape=216; for 2c_vs_64zg, global_obs_shape=342.
+ agent_obs_shape=306,
+ global_obs_shape=1199,
+ # (int) action_shape: The number of actions each agent can take.
+ # action_shape = the number of common actions (6) + the number of enemies.
+ # For 3s5z, action_shape=14 (6+8); for 2c_vs_64zg, action_shape=70 (6+64).
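+ # For 25m here: 6 common actions + 25 enemies = 31.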
+ action_shape=31,
+ # (List[int]) The sizes of hidden layers, e.g.:
+ # hidden_size_list=[64],
+ # (the encoder is deleted in the model code)
+ actor_head_hidden_size=512,
+ critic_head_hidden_size=1024,
+ ),
+ # used in state_num of hidden_state
+ learn=dict(
+ epoch_per_collect=5,
+ batch_size=3200,
+ learning_rate=5e-4,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) The loss weight of the value network; the policy network weight is fixed to 1
+ value_weight=0.5,
+ # (float) The loss weight of entropy regularization; the policy network weight is fixed to 1
+ entropy_weight=0.01,
+ # (float) PPO clip ratio, defaults to 0.2
+ clip_ratio=0.2,
+ # (bool) Whether to use advantage norm in a whole training batch
+ adv_norm=False,
+ value_norm=True,
+ ppo_param_init=True,
+ grad_clip_type='clip_norm',
+ grad_clip_value=10,
+ ignore_done=False,
+ ),
+ collect=dict(env_num=collector_env_num, n_sample=3200),
+ eval=dict(
+ evaluator=dict(eval_freq=100, ),
+ env_num=evaluator_env_num,
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_25m_masac_config.py b/DI-engine/dizoo/smac/config/smac_25m_masac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b6e279a0ba1e3205cb79eac5b21c6ebb93c4afe
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_25m_masac_config.py
@@ -0,0 +1,91 @@
+from easydict import EasyDict
+
+agent_num = 25
+collector_env_num = 8
+evaluator_env_num = 8
+special_global_state = True
+
+SMAC_25m_masac_default_config = dict(
+ exp_name='smac_25m_masac_seed0',
+ env=dict(
+ map_name='25m',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=32,
+ stop_value=0.99,
+ death_mask=False,
+ special_global_state=special_global_state,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ cuda=True,
+ on_policy=False,
+ random_collect_size=0,
+ model=dict(
+ agent_obs_shape=306,
+ global_obs_shape=1199,
+ action_shape=31,
+ twin_critic=True,
+ actor_head_hidden_size=512,
+ critic_head_hidden_size=1024,
+ ),
+ learn=dict(
+ update_per_collect=50,
+ batch_size=320,
+ learning_rate_q=5e-4,
+ learning_rate_policy=5e-4,
+ learning_rate_alpha=5e-5,
+ ignore_done=False,
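+ # (float) target_theta below is the soft (Polyak) target-network update coefficient; assuming the
+ # usual momentum rule, target_param <- (1 - target_theta) * target_param + target_theta * online_param.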
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ auto_alpha=True,
+ log_space=True,
+ ),
+ collect=dict(
+ env_num=collector_env_num,
+ n_sample=1600,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(
+ evaluator=dict(eval_freq=50, ),
+ env_num=evaluator_env_num,
+ ),
+ other=dict(
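+ # Note on the eps schedule below: with type='linear' the exploration rate is assumed to decay
+ # linearly, roughly eps(t) = max(end, start - (start - end) * t / decay), i.e. from 1.0 down to
+ # 0.05 over about `decay` steps (the exact step counter is determined by the entry pipeline).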
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=100000,
+ ),
+ replay_buffer=dict(replay_buffer_size=50000, ),
+ ),
+ ),
+)
+
+SMAC_25m_masac_default_config = EasyDict(SMAC_25m_masac_default_config)
+main_config = SMAC_25m_masac_default_config
+
+SMAC_25m_masac_default_create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='sac_discrete', ),
+)
+SMAC_25m_masac_default_create_config = EasyDict(SMAC_25m_masac_default_create_config)
+create_config = SMAC_25m_masac_default_create_config
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_27m30m_mappo_config.py b/DI-engine/dizoo/smac/config/smac_27m30m_mappo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..14caadd256c6fe26221f0536093b4f45c956d9da
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_27m30m_mappo_config.py
@@ -0,0 +1,95 @@
+from easydict import EasyDict
+
+agent_num = 27
+collector_env_num = 8
+evaluator_env_num = 8
+special_global_state = True
+
+main_config = dict(
+ exp_name='smac_27m30m_mappo_seed0',
+ env=dict(
+ map_name='27m_vs_30m',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=32,
+ stop_value=0.99,
+ death_mask=False,
+ special_global_state=special_global_state,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ cuda=True,
+ on_policy=True,
+ multi_agent=True,
+ continuous=False,
+ model=dict(
+ # (int) agent_num: The number of agents.
+ # For SMAC 3s5z, agent_num=8; for 2c_vs_64zg, agent_num=2.
+ agent_num=agent_num,
+ # (int) agent_obs_shape: The dimension of each agent's local observation.
+ # For 3s5z, agent_obs_shape=150; for 2c_vs_64zg, agent_obs_shape=404.
+ # (int) global_obs_shape: The dimension of the global observation.
+ # For 3s5z, global_obs_shape=216; for 2c_vs_64zg, global_obs_shape=342.
+ agent_obs_shape=348,
+ global_obs_shape=1454,
+ # (int) action_shape: The number of actions each agent can take.
+ # action_shape = the number of common actions (6) + the number of enemies.
+ # For 3s5z, action_shape=14 (6+8); for 2c_vs_64zg, action_shape=70 (6+64).
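+ # Here, for the 27m_vs_30m map (30 enemy Marines), action_shape = 6 + 30 = 36, matching the value below.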
+ action_shape=36,
+ # (List[int]) The size of hidden layer
+ # hidden_size_list=[64],
+ # delete encode in code
+ actor_head_hidden_size=512,
+ critic_head_hidden_size=1024,
+ ),
+ # used in state_num of hidden_state
+ learn=dict(
+ epoch_per_collect=5,
+ batch_size=3200,
+ learning_rate=5e-4,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) The loss weight of the value network; the policy network's weight is fixed to 1
+ value_weight=0.5,
+ # (float) The loss weight of entropy regularization; the policy network's weight is fixed to 1
+ entropy_weight=0.01,
+ # (float) PPO clip ratio, defaults to 0.2
+ clip_ratio=0.2,
+ # (bool) Whether to normalize advantages over the whole training batch
+ adv_norm=False,
+ value_norm=True,
+ ppo_param_init=True,
+ grad_clip_type='clip_norm',
+ grad_clip_value=10,
+ ignore_done=False,
+ ),
+ collect=dict(env_num=collector_env_num, n_sample=3200),
+ eval=dict(
+ evaluator=dict(eval_freq=100, ),
+ env_num=evaluator_env_num,
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_2c64zg_mappo_config.py b/DI-engine/dizoo/smac/config/smac_2c64zg_mappo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..ac29489d9c46f017ce962c3e8b533106907392ca
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_2c64zg_mappo_config.py
@@ -0,0 +1,90 @@
+from easydict import EasyDict
+
+agent_num = 2
+collector_env_num = 8
+evaluator_env_num = 8
+special_global_state = True
+
+main_config = dict(
+ exp_name='smac_2c64zg_mappo_seed0',
+ env=dict(
+ map_name='2c_vs_64zg',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=32,
+ stop_value=0.99,
+ death_mask=True,
+ special_global_state=special_global_state,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ cuda=True,
+ multi_agent=True,
+ continuous=False,
+ model=dict(
+ # (int) agent_num: The number of agents.
+ # For SMAC 3s5z, agent_num=8; for 2c_vs_64zg, agent_num=2.
+ agent_num=agent_num,
+ # (int) agent_obs_shape: The dimension of each agent's local observation.
+ # For 3s5z, agent_obs_shape=150; for 2c_vs_64zg, agent_obs_shape=404.
+ # (int) global_obs_shape: The dimension of the global observation.
+ # For 3s5z, global_obs_shape=216; for 2c_vs_64zg, global_obs_shape=342.
+ agent_obs_shape=404,
+ global_obs_shape=671,
+ # (int) action_shape: The number of actions each agent can take.
+ # action_shape = the number of common actions (6) + the number of enemies.
+ # For 3s5z, action_shape=14 (6+8); for 2c_vs_64zg, action_shape=70 (6+64).
+ action_shape=70,
+ # (List[int]) The size of hidden layer
+ # hidden_size_list=[64],
+ action_space='discrete',
+ ),
+ # used in state_num of hidden_state
+ learn=dict(
+ epoch_per_collect=5,
+ batch_size=3200,
+ learning_rate=5e-4,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) The loss weight of the value network; the policy network's weight is fixed to 1
+ value_weight=0.5,
+ # (float) The loss weight of entropy regularization; the policy network's weight is fixed to 1
+ entropy_weight=0.01,
+ # (float) PPO clip ratio, defaults to 0.2
+ clip_ratio=0.2,
+ # (bool) Whether to normalize advantages over the whole training batch
+ adv_norm=False,
+ value_norm=True,
+ ppo_param_init=True,
+ grad_clip_type='clip_norm',
+ grad_clip_value=10,
+ ignore_done=False,
+ ),
+ on_policy=True,
+ collect=dict(env_num=collector_env_num, n_sample=3200),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=50, )),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_2c64zg_masac_config.py b/DI-engine/dizoo/smac/config/smac_2c64zg_masac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..fc729c23a5119ad3bfae16d4af2d1a7c622c070e
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_2c64zg_masac_config.py
@@ -0,0 +1,91 @@
+from easydict import EasyDict
+
+agent_num = 2
+collector_env_num = 8
+evaluator_env_num = 8
+special_global_state = True
+
+SMAC_2c64zg_masac_default_config = dict(
+ exp_name='smac_2c64zg_masac_seed0',
+ env=dict(
+ map_name='2c_vs_64zg',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=32,
+ stop_value=0.99,
+ death_mask=True,
+ special_global_state=special_global_state,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ cuda=True,
+ on_policy=False,
+ random_collect_size=0,
+ model=dict(
+ agent_obs_shape=404,
+ global_obs_shape=671,
+ action_shape=70,
+ twin_critic=True,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=50,
+ batch_size=320,
+ learning_rate_q=5e-4,
+ learning_rate_policy=5e-4,
+ learning_rate_alpha=5e-5,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ auto_alpha=True,
+ log_space=True,
+ ),
+ collect=dict(
+ env_num=collector_env_num,
+ n_sample=1600,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(
+ evaluator=dict(eval_freq=50, ),
+ env_num=evaluator_env_num,
+ ),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=int(1e5),
+ ),
+ replay_buffer=dict(replay_buffer_size=int(1e6), ),
+ ),
+ ),
+)
+
+SMAC_2c64zg_masac_default_config = EasyDict(SMAC_2c64zg_masac_default_config)
+main_config = SMAC_2c64zg_masac_default_config
+
+SMAC_2c64zg_masac_default_create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='sac_discrete', ),
+)
+SMAC_2c64zg_masac_default_create_config = EasyDict(SMAC_2c64zg_masac_default_create_config)
+create_config = SMAC_2c64zg_masac_default_create_config
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_2c64zg_qmix_config.py b/DI-engine/dizoo/smac/config/smac_2c64zg_qmix_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..228d8b878a8acedbb6a829d3b5934e13d7404a65
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_2c64zg_qmix_config.py
@@ -0,0 +1,81 @@
+from easydict import EasyDict
+
+agent_num = 2
+collector_env_num = 16
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_2c64zg_qmix_seed0',
+ env=dict(
+ map_name='2c_vs_64zg',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ model=dict(
+ agent_num=agent_num,
+ obs_shape=404,
+ global_obs_shape=342,
+ action_shape=70,
+ hidden_size_list=[64],
+ mixer=True,
+ lstm_type='gru',
+ dueling=False,
+ ),
+ learn=dict(
+ update_per_collect=20,
+ batch_size=32,
+ learning_rate=0.0005,
+ clip_value=5,
+ double_q=False,
+ target_update_theta=0.005,
+ discount_factor=0.99,
+ ),
+ collect=dict(
+ n_episode=32,
+ unroll_len=10,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=100, )),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=10000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=15000,
+ # (int) The maximum number of times each sample can be reused before it is removed from the buffer
+ max_reuse=1e+9,
+ max_staleness=1e+9,
+ ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='qmix'),
+ collector=dict(type='episode', get_train_sample=True),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_2s3z_qmix_config.py b/DI-engine/dizoo/smac/config/smac_2s3z_qmix_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..8796048da0ecffe7b83f3286190fa93f1bb47faa
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_2s3z_qmix_config.py
@@ -0,0 +1,81 @@
+from easydict import EasyDict
+
+agent_num = 5
+collector_env_num = 16
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_2s3z_qmix_seed0',
+ env=dict(
+ map_name='2s3z',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ model=dict(
+ agent_num=agent_num,
+ obs_shape=96,
+ global_obs_shape=120,
+ action_shape=11,
+ hidden_size_list=[128],
+ mixer=True,
+ lstm_type='gru',
+ dueling=False,
+ ),
+ learn=dict(
+ update_per_collect=20,
+ batch_size=32,
+ learning_rate=0.0005,
+ clip_value=5,
+ double_q=False,
+ target_update_theta=0.01,
+ discount_factor=0.95,
+ ),
+ collect=dict(
+ n_episode=32,
+ unroll_len=10,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=50, )),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=5000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=15000,
+ # (int) The maximum number of times each sample can be reused before it is removed from the buffer
+ max_reuse=1e+9,
+ max_staleness=1e+9,
+ ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='qmix'),
+ collector=dict(type='episode', get_train_sample=True),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_2s3z_qtran_config.py b/DI-engine/dizoo/smac/config/smac_2s3z_qtran_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..0c66056a1b450e9920538b6d704ea37a7513c4bb
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_2s3z_qtran_config.py
@@ -0,0 +1,84 @@
+from easydict import EasyDict
+
+agent_num = 5
+collector_env_num = 16
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_2s3z_qtran_seed0',
+ env=dict(
+ map_name='2s3z',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ model=dict(
+ agent_num=agent_num,
+ obs_shape=96,
+ global_obs_shape=120,
+ action_shape=11,
+ hidden_size_list=[64],
+ embedding_size=64,
+ lstm_type='gru',
+ dueling=False,
+ ),
+ learn=dict(
+ update_per_collect=20,
+ batch_size=32,
+ learning_rate=0.0005,
+ clip_value=5,
+ double_q=True,
+ target_update_theta=0.01,
+ discount_factor=0.95,
+ td_weight=1,
+ opt_weight=0.1,
+ nopt_min_weight=0.0001,
+ ),
+ collect=dict(
+ n_episode=32,
+ unroll_len=10,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=50, )),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=5000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=15000,
+ # (int) The maximum number of times each sample can be reused before it is removed from the buffer
+ max_reuse=1e+9,
+ max_staleness=1e+9,
+ ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='qtran'),
+ collector=dict(type='episode', get_train_sample=True),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_3m_masac_config.py b/DI-engine/dizoo/smac/config/smac_3m_masac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..54d63d3c5119bd44e09f048abb0def7bdabde49d
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_3m_masac_config.py
@@ -0,0 +1,89 @@
+from easydict import EasyDict
+
+agent_num = 3
+collector_env_num = 8
+evaluator_env_num = 8
+special_global_state = True
+
+SMAC_3m_masac_default_config = dict(
+ exp_name='smac_3m_masac_seed0',
+ env=dict(
+ map_name='3m',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=16,
+ stop_value=0.99,
+ death_mask=True,
+ special_global_state=special_global_state,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=0,
+ model=dict(
+ agent_obs_shape=42,
+ global_obs_shape=77,
+ action_shape=9,
+ twin_critic=True,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=50,
+ batch_size=320,
+ learning_rate_q=5e-4,
+ learning_rate_policy=5e-4,
+ learning_rate_alpha=5e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ auto_alpha=True,
+ log_space=True,
+ ),
+ collect=dict(
+ env_num=collector_env_num,
+ n_sample=3200,
+ unroll_len=1,
+ ),
+ eval=dict(
+ evaluator=dict(eval_freq=50, ),
+ env_num=evaluator_env_num,
+ ),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=100000,
+ ), # TODO(pu)
+ replay_buffer=dict(replay_buffer_size=1000000, ),
+ ),
+ ),
+)
+
+SMAC_3m_masac_default_config = EasyDict(SMAC_3m_masac_default_config)
+main_config = SMAC_3m_masac_default_config
+
+SMAC_3m_masac_default_create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='sac_discrete', ),
+)
+SMAC_3m_masac_default_create_config = EasyDict(SMAC_3m_masac_default_create_config)
+create_config = SMAC_3m_masac_default_create_config
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_3s5z_collaq_config.py b/DI-engine/dizoo/smac/config/smac_3s5z_collaq_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..7f1bf124b441e6ac6670dcf86c3c175b6db58e0a
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_3s5z_collaq_config.py
@@ -0,0 +1,88 @@
+from easydict import EasyDict
+
+agent_num = 8
+collector_env_num = 16
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_3s5z_collaq_seed0',
+ env=dict(
+ map_name='3s5z',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ obs_alone=True,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ model=dict(
+ agent_num=agent_num,
+ obs_shape=150,
+ alone_obs_shape=94,
+ global_obs_shape=216,
+ action_shape=14,
+ hidden_size_list=[128],
+ attention=False,
+ self_feature_range=[124, 128], # placeholder 4
+ ally_feature_range=[68, 124], # placeholder 8*7
+ attention_size=32,
+ mixer=True,
+ lstm_type='gru',
+ dueling=False,
+ ),
+ learn=dict(
+ update_per_collect=20,
+ batch_size=32,
+ learning_rate=0.0005,
+ clip_value=5,
+ double_q=False,
+ target_update_theta=0.008,
+ discount_factor=0.95,
+ collaq_loss_weight=1.0,
+ ),
+ collect=dict(
+ n_episode=32,
+ unroll_len=10,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=100, )),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=10000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=15000,
+ # (int) The maximum number of times each sample can be reused before it is removed from the buffer
+ max_reuse=1e+9,
+ max_staleness=1e+9,
+ ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='collaq'),
+ collector=dict(type='episode', get_train_sample=True),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_3s5z_collaq_per_config.py b/DI-engine/dizoo/smac/config/smac_3s5z_collaq_per_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5530e53724250ccbf8e6c68ada7149445c7f8442
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_3s5z_collaq_per_config.py
@@ -0,0 +1,92 @@
+from easydict import EasyDict
+
+agent_num = 8
+collector_env_num = 16
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_3s5z_collaq_per_seed0',
+ env=dict(
+ map_name='3s5z',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ obs_alone=True,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ # (bool) Whether to use prioritized experience replay (priority sampling, IS weight, priority update)
+ priority=True,
+ # (bool) Whether to use Importance Sampling weights to correct the biased update. If True, priority must also be True.
+ priority_IS_weight=True,
+ model=dict(
+ agent_num=agent_num,
+ obs_shape=150,
+ alone_obs_shape=94,
+ global_obs_shape=216,
+ action_shape=14,
+ hidden_size_list=[128],
+ attention=False,
+ self_feature_range=[124, 128], # placeholder 4
+ ally_feature_range=[68, 124], # placeholder 8*7
+ attention_size=32,
+ mixer=True,
+ lstm_type='gru',
+ dueling=False,
+ ),
+ learn=dict(
+ update_per_collect=20,
+ batch_size=32,
+ learning_rate=0.0005,
+ clip_value=5,
+ double_q=False,
+ target_update_theta=0.008,
+ discount_factor=0.95,
+ collaq_loss_weight=1.0,
+ ),
+ collect=dict(
+ n_episode=32,
+ unroll_len=10,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=100, )),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=10000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=15000,
+ # (int) The maximum number of times each sample can be reused before it is removed from the buffer
+ max_reuse=1e+9,
+ max_staleness=1e+9,
+ ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='collaq'),
+ collector=dict(type='episode', get_train_sample=True),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_3s5z_coma_config.py b/DI-engine/dizoo/smac/config/smac_3s5z_coma_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5100797a11c76bb7524fae778b042deb1dfea43a
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_3s5z_coma_config.py
@@ -0,0 +1,81 @@
+from easydict import EasyDict
+
+agent_num = 8
+collector_env_num = 16
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_3s5z_coma_seed0',
+ env=dict(
+ map_name='3s5z',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ model=dict(
+ # (int) agent_num: The number of agents.
+ # For SMAC 3s5z, agent_num=8; for 2c_vs_64zg, agent_num=2.
+ agent_num=agent_num,
+ # obs_shape: The dimensions of the observations, given as a dict below.
+ # agent_state is each agent's local observation; global_state is the global observation.
+ # For 3s5z, agent_state=150 and global_state=216;
+ # for 2c_vs_64zg, agent_state=404 and global_state=342.
+ obs_shape=dict(
+ agent_state=150,
+ global_state=216,
+ ),
+ # (int) action_shape: The number of actions each agent can take.
+ # action_shape = the number of common actions (6) + the number of enemies.
+ # For 3s5z, action_shape=14 (6+8); for 2c_vs_64zg, action_shape=70 (6+64).
+ action_shape=14,
+ # (List[int]) The size of hidden layer
+ actor_hidden_size_list=[64],
+ ),
+ # used in state_num of hidden_state
+ collect=dict(
+ n_episode=32,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=100, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.5,
+ end=0.01,
+ decay=200000,
+ ),
+ replay_buffer=dict(
+ # (int) max size of replay buffer
+ replay_buffer_size=5000,
+ # (int) The maximum use count of each sample; once exceeded, the sample is removed from the buffer
+ max_use=10,
+ ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='coma'),
+ collector=dict(type='episode', get_train_sample=True),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_3s5z_madqn_config.py b/DI-engine/dizoo/smac/config/smac_3s5z_madqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e771baf097091cd44e3bfff6183b75c1bc12719
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_3s5z_madqn_config.py
@@ -0,0 +1,84 @@
+from ding.entry import serial_pipeline
+from easydict import EasyDict
+
+agent_num = 8
+collector_env_num = 4
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_3s5z_madqn_seed0',
+ env=dict(
+ map_name='3s5z',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ special_global_state=True,
+ manager=dict(shared_memory=False, ),
+ ),
+ policy=dict(
+ nstep=1,
+ model=dict(
+ agent_num=agent_num,
+ obs_shape=150,
+ global_obs_shape=295,
+ global_cooperation=True,
+ action_shape=14,
+ hidden_size_list=[256, 256],
+ ),
+ learn=dict(
+ update_per_collect=20,
+ batch_size=64,
+ learning_rate=0.0005,
+ clip_value=5,
+ target_update_theta=0.008,
+ discount_factor=0.95,
+ ),
+ collect=dict(
+ collector=dict(get_train_sample=True, ),
+ n_episode=32,
+ unroll_len=10,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=1000, )),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=15000, ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='madqn'),
+ collector=dict(type='episode'),
+)
+create_config = EasyDict(create_config)
+
+
+def train(args):
+ config = [main_config, create_config]
+ serial_pipeline(config, seed=args.seed, max_env_step=1e7)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--seed', '-s', type=int, default=1)
+ args = parser.parse_args()
+
+ train(args)
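+
+# Usage sketch (assuming DI-engine and the SMAC/StarCraft II environment are installed):
+#   python dizoo/smac/config/smac_3s5z_madqn_config.py --seed 0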
diff --git a/DI-engine/dizoo/smac/config/smac_3s5z_mappo_config.py b/DI-engine/dizoo/smac/config/smac_3s5z_mappo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d57055b93dedca80d551e00eb9c97c00df5617a7
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_3s5z_mappo_config.py
@@ -0,0 +1,91 @@
+from easydict import EasyDict
+
+agent_num = 8
+collector_env_num = 8
+evaluator_env_num = 8
+special_global_state = True
+
+main_config = dict(
+ exp_name='smac_3s5z_mappo_seed0',
+ env=dict(
+ map_name='3s5z',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=32,
+ stop_value=0.99,
+ death_mask=False,
+ special_global_state=special_global_state,
+ # save_replay_episodes = 1,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ cuda=True,
+ multi_agent=True,
+ action_space='discrete',
+ model=dict(
+ # (int) agent_num: The number of agents.
+ # For SMAC 3s5z, agent_num=8; for 2c_vs_64zg, agent_num=2.
+ agent_num=agent_num,
+ # (int) agent_obs_shape: The dimension of each agent's local observation.
+ # For 3s5z, agent_obs_shape=150; for 2c_vs_64zg, agent_obs_shape=404.
+ # (int) global_obs_shape: The dimension of the global observation.
+ # For 3s5z, global_obs_shape=216; for 2c_vs_64zg, global_obs_shape=342.
+ agent_obs_shape=150,
+ #global_obs_shape=216,
+ global_obs_shape=295,
+ # (int) action_shape: The number of actions each agent can take.
+ # action_shape = the number of common actions (6) + the number of enemies.
+ # For 3s5z, action_shape=14 (6+8); for 2c_vs_64zg, action_shape=70 (6+64).
+ action_shape=14,
+ # (List[int]) The size of hidden layer
+ # hidden_size_list=[64],
+ action_space='discrete'
+ ),
+ # used in state_num of hidden_state
+ learn=dict(
+ epoch_per_collect=5,
+ batch_size=3200,
+ learning_rate=5e-4,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) The loss weight of the value network; the policy network's weight is fixed to 1
+ value_weight=0.5,
+ # (float) The loss weight of entropy regularization; the policy network's weight is fixed to 1
+ entropy_weight=0.0,
+ # (float) PPO clip ratio, defaults to 0.2
+ clip_ratio=0.5,
+ # (bool) Whether to normalize advantages over the whole training batch
+ adv_norm=False,
+ value_norm=True,
+ ppo_param_init=True,
+ grad_clip_type='clip_norm',
+ grad_clip_value=10,
+ ignore_done=False,
+ ),
+ collect=dict(env_num=collector_env_num, n_sample=3200),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=50, )),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='ppo'),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_3s5z_masac_config.py b/DI-engine/dizoo/smac/config/smac_3s5z_masac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b133a4f5a7578266a1aa6a32a40b39782dd601b6
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_3s5z_masac_config.py
@@ -0,0 +1,89 @@
+from easydict import EasyDict
+
+agent_num = 10
+collector_env_num = 8
+evaluator_env_num = 8
+special_global_state = True
+
+smac_3s5z_masac_default_config = dict(
+ exp_name='smac_3s5z_masac_seed0',
+ env=dict(
+ map_name='3s5z',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=32,
+ stop_value=0.99,
+ death_mask=True,
+ special_global_state=special_global_state,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=0,
+ model=dict(
+ agent_obs_shape=150,
+ global_obs_shape=295,
+ action_shape=14,
+ twin_critic=True,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=50,
+ batch_size=320,
+ learning_rate_q=5e-4,
+ learning_rate_policy=5e-4,
+ learning_rate_alpha=5e-5,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ auto_alpha=True,
+ log_space=True,
+ ),
+ collect=dict(
+ env_num=collector_env_num,
+ n_sample=1600,
+ unroll_len=1,
+ ),
+ eval=dict(
+ evaluator=dict(eval_freq=50, ),
+ env_num=evaluator_env_num,
+ ),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=100000,
+ ),
+ replay_buffer=dict(replay_buffer_size=1000000, ),
+ ),
+ ),
+)
+
+smac_3s5z_masac_default_config = EasyDict(smac_3s5z_masac_default_config)
+main_config = smac_3s5z_masac_default_config
+
+smac_3s5z_masac_default_create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='sac_discrete', ),
+)
+smac_3s5z_masac_default_create_config = EasyDict(smac_3s5z_masac_default_create_config)
+create_config = smac_3s5z_masac_default_create_config
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_3s5z_qmix_config.py b/DI-engine/dizoo/smac/config/smac_3s5z_qmix_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..cd622ae5f6c9a5295df4c72f0e9e5fbd25675235
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_3s5z_qmix_config.py
@@ -0,0 +1,81 @@
+from easydict import EasyDict
+
+agent_num = 8
+collector_env_num = 16
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_3s5z_qmix_seed0',
+ env=dict(
+ map_name='3s5z',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ model=dict(
+ agent_num=agent_num,
+ obs_shape=150,
+ global_obs_shape=216,
+ action_shape=14,
+ hidden_size_list=[64],
+ mixer=True,
+ lstm_type='gru',
+ dueling=False,
+ ),
+ learn=dict(
+ update_per_collect=20,
+ batch_size=64,
+ learning_rate=0.0005,
+ clip_value=5,
+ double_q=False,
+ target_update_theta=0.008,
+ discount_factor=0.95,
+ ),
+ collect=dict(
+ n_episode=32,
+ unroll_len=10,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=100, )),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=10000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=15000,
+ # (int) The maximum number of times each sample can be reused before it is removed from the buffer
+ max_reuse=1e+9,
+ max_staleness=1e+9,
+ ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='qmix'),
+ collector=dict(type='episode', get_train_sample=True),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_3s5z_qtran_config.py b/DI-engine/dizoo/smac/config/smac_3s5z_qtran_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..119097735453712e8eead042fa3d1db975f4d4fe
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_3s5z_qtran_config.py
@@ -0,0 +1,83 @@
+from easydict import EasyDict
+
+agent_num = 8
+collector_env_num = 16
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_3s5z_qtran_seed0',
+ env=dict(
+ map_name='3s5z',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ model=dict(
+ agent_num=agent_num,
+ obs_shape=150,
+ global_obs_shape=216,
+ action_shape=14,
+ hidden_size_list=[64],
+ embedding_size=64,
+ lstm_type='gru',
+ dueling=False,
+ ),
+ learn=dict(
+ update_per_collect=20,
+ batch_size=32,
+ learning_rate=0.0005,
+ double_q=True,
+ target_update_theta=0.006,
+ discount_factor=0.95,
+ td_weight=1,
+ opt_weight=0.1,
+ nopt_min_weight=0.0001,
+ ),
+ collect=dict(
+ n_episode=32,
+ unroll_len=10,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=100, )),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=10000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=15000,
+ # (int) The maximum number of times each sample can be reused before it is removed from the buffer
+ max_reuse=1e+9,
+ max_staleness=1e+9,
+ ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='qtran'),
+ collector=dict(type='episode', get_train_sample=True),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_3s5z_wqmix_config.py b/DI-engine/dizoo/smac/config/smac_3s5z_wqmix_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..92552f930b3f90d454db818d239da69c579ef72a
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_3s5z_wqmix_config.py
@@ -0,0 +1,86 @@
+from easydict import EasyDict
+
+agent_num = 8
+collector_env_num = 16
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_3s5z_wqmix_seed0',
+ env=dict(
+ map_name='3s5z',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ model=dict(
+ agent_num=agent_num,
+ obs_shape=150,
+ global_obs_shape=216,
+ action_shape=14,
+ hidden_size_list=[64],
+ lstm_type='gru',
+ dueling=False,
+ ),
+ learn=dict(
+ update_per_collect=20,
+ batch_size=32,
+ learning_rate=0.0005,
+ clip_value=5,
+ target_update_theta=0.008,
+ discount_factor=0.95,
+
+ # OW (Optimistically-Weighted) WQMIX:
+ wqmix_ow=True,
+ alpha=0.5,
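+ # alpha is the weight applied to the down-weighted TD samples; assuming the weighting scheme of the
+ # WQMIX paper (OW variant), samples whose current joint Q already overestimates the TD target get
+ # weight alpha, while underestimated ones keep weight 1.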
+ # CW (Centrally-Weighted) WQMIX: use the two commented-out lines below instead.
+ # wqmix_ow = False,
+ # alpha = 0.75,
+ ),
+ collect=dict(
+ n_episode=32,
+ unroll_len=10,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=100, )),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=15000,
+ # (int) The maximum number of times each sample can be reused before it is removed from the buffer
+ max_reuse=1e+9,
+ max_staleness=1e+9,
+ ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='wqmix'),
+ collector=dict(type='episode', get_train_sample=True),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_3s5zvs3s6z_madqn_config.py b/DI-engine/dizoo/smac/config/smac_3s5zvs3s6z_madqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..438025241f4466625e559e4e4213931b4b57c9e6
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_3s5zvs3s6z_madqn_config.py
@@ -0,0 +1,84 @@
+from ding.entry import serial_pipeline
+from easydict import EasyDict
+
+agent_num = 8
+collector_env_num = 4
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_3s5zvs3s6z_madqn_seed0',
+ env=dict(
+ map_name='3s5z_vs_3s6z',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ special_global_state=True,
+ manager=dict(shared_memory=False, ),
+ ),
+ policy=dict(
+ nstep=3,
+ model=dict(
+ agent_num=agent_num,
+ obs_shape=159,
+ global_obs_shape=314,
+ global_cooperation=True,
+ action_shape=15,
+ hidden_size_list=[256, 256],
+ ),
+ learn=dict(
+ update_per_collect=40,
+ batch_size=32,
+ learning_rate=0.0005,
+ clip_value=5,
+ target_update_theta=0.008,
+ discount_factor=0.95,
+ ),
+ collect=dict(
+ collector=dict(get_train_sample=True, ),
+ n_episode=32,
+ unroll_len=10,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=1000, )),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=100000,
+ ),
+ replay_buffer=dict(replay_buffer_size=30000, ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='madqn'),
+ collector=dict(type='episode'),
+)
+create_config = EasyDict(create_config)
+
+
+def train(args):
+ config = [main_config, create_config]
+ serial_pipeline(config, seed=args.seed, max_env_step=1e7)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--seed', '-s', type=int, default=0)
+ args = parser.parse_args()
+
+ train(args)
diff --git a/DI-engine/dizoo/smac/config/smac_3s5zvs3s6z_mappo_config.py b/DI-engine/dizoo/smac/config/smac_3s5zvs3s6z_mappo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..a80527eeba1bc5ea3aecf5e782b807b4e8aa12eb
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_3s5zvs3s6z_mappo_config.py
@@ -0,0 +1,89 @@
+from easydict import EasyDict
+
+agent_num = 8
+collector_env_num = 8
+evaluator_env_num = 8
+special_global_state = True
+
+main_config = dict(
+ exp_name='smac_3s5z_vs_3s6z_mappo_seed0',
+ env=dict(
+ map_name='3s5z_vs_3s6z',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=32,
+ stop_value=0.99,
+ death_mask=True,
+ special_global_state=special_global_state,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ cuda=True,
+ multi_agent=True,
+ continuous=False,
+ model=dict(
+ # (int) agent_num: The number of agents.
+ # For SMAC 3s5z, agent_num=8; for 2c_vs_64zg, agent_num=2.
+ agent_num=agent_num,
+ # (int) agent_obs_shape: The dimension of each agent's local observation.
+ # For 3s5z, agent_obs_shape=150; for 2c_vs_64zg, agent_obs_shape=404.
+ # (int) global_obs_shape: The dimension of the global observation.
+ # For 3s5z, global_obs_shape=216; for 2c_vs_64zg, global_obs_shape=342.
+ agent_obs_shape=159,
+ global_obs_shape=314,
+ # (int) action_shape: The number of actions each agent can take.
+ # action_shape = the number of common actions (6) + the number of enemies.
+ # For 3s5z, action_shape=14 (6+8); for 2c_vs_64zg, action_shape=70 (6+64).
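+ # Here, for the 3s5z_vs_3s6z map (3 Stalkers + 6 Zealots = 9 enemies), action_shape = 6 + 9 = 15, matching the value below.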
+ action_shape=15,
+ # (List[int]) The size of hidden layer
+ # hidden_size_list=[64],
+ ),
+ # used in state_num of hidden_state
+ learn=dict(
+ epoch_per_collect=5,
+ batch_size=3200,
+ learning_rate=5e-4,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) The loss weight of the value network; the policy network's weight is fixed to 1
+ value_weight=0.5,
+ # (float) The loss weight of entropy regularization; the policy network's weight is fixed to 1
+ entropy_weight=0.01,
+ # (float) PPO clip ratio, defaults to 0.2
+ clip_ratio=0.2,
+ # (bool) Whether to normalize advantages over the whole training batch
+ adv_norm=False,
+ value_norm=True,
+ ppo_param_init=True,
+ grad_clip_type='clip_norm',
+ grad_clip_value=10,
+ ignore_done=False,
+ ),
+ on_policy=True,
+ collect=dict(env_num=collector_env_num, n_sample=3200),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=50, )),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_3s5zvs3s6z_masac_config.py b/DI-engine/dizoo/smac/config/smac_3s5zvs3s6z_masac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..a2dbdf87c56f8d62705952e9d08ef9c214654de4
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_3s5zvs3s6z_masac_config.py
@@ -0,0 +1,92 @@
+from easydict import EasyDict
+
+agent_num = 8
+collector_env_num = 8
+evaluator_env_num = 8
+special_global_state = True
+
+smac_3s5zvs3s6z_masac_default_config = dict(
+ exp_name='smac_3s5z_vs_3s6z_masac_seed0',
+ env=dict(
+ map_name='3s5z_vs_3s6z',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=32,
+ stop_value=0.99,
+ death_mask=True,
+ special_global_state=special_global_state,
+ # save_replay_episodes = 1,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ cuda=True,
+ on_policy=False,
+ random_collect_size=0,
+ model=dict(
+ agent_obs_shape=159,
+ global_obs_shape=314,
+ action_shape=15,
+ twin_critic=True,
+ actor_head_hidden_size=512,
+ critic_head_hidden_size=1024,
+ ),
+ learn=dict(
+ update_per_collect=50,
+ batch_size=320,
+ learning_rate_q=5e-4,
+ learning_rate_policy=5e-4,
+ learning_rate_alpha=5e-5,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ auto_alpha=True,
+ log_space=True,
+ ),
+ collect=dict(
+ env_num=collector_env_num,
+ n_sample=1600,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(
+ evaluator=dict(eval_freq=2000, ),
+ env_num=evaluator_env_num,
+ ),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=100000,
+ ),
+ replay_buffer=dict(replay_buffer_size=1000000, ),
+ ),
+ ),
+)
+
+smac_3s5zvs3s6z_masac_default_config = EasyDict(smac_3s5zvs3s6z_masac_default_config)
+main_config = smac_3s5zvs3s6z_masac_default_config
+
+smac_3s5zvs3s6z_masac_default_create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='sac_discrete', ),
+)
+smac_3s5zvs3s6z_masac_default_create_config = EasyDict(smac_3s5zvs3s6z_masac_default_create_config)
+create_config = smac_3s5zvs3s6z_masac_default_create_config
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_5m6m_collaq_config.py b/DI-engine/dizoo/smac/config/smac_5m6m_collaq_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f775b9dcac5397086caca2a4421d93ec3a27df4
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_5m6m_collaq_config.py
@@ -0,0 +1,94 @@
+from easydict import EasyDict
+
+agent_num = 5
+collector_env_num = 16
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_5m6m_collaq_seed0',
+ env=dict(
+ map_name='5m_vs_6m',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ obs_alone=True,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ model=dict(
+ agent_num=agent_num,
+ obs_shape=72,
+ alone_obs_shape=52,
+ global_obs_shape=98,
+ action_shape=12,
+ hidden_size_list=[128],
+ attention=True,
+ self_feature_range=[54, 55], # placeholder 4
+ ally_feature_range=[34, 54], # placeholder 8*7
+ attention_size=32,
+ mixer=True,
+ lstm_type='gru',
+ dueling=False,
+ ),
+ learn=dict(
+ update_per_collect=20,
+ batch_size=32,
+ learning_rate=0.0005,
+ clip_value=4,
+ double_q=False,
+ target_update_theta=0.005,
+ discount_factor=0.95,
+ collaq_loss_weight=1.0,
+ ),
+ collect=dict(
+ n_episode=32,
+ unroll_len=10,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=100, )),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=30000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=50000,
+ # (int) The maximum number of times each sample can be reused before it is removed from the buffer
+ max_reuse=1e+9,
+ max_staleness=1e+9,
+ ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='collaq'),
+ collector=dict(type='episode', get_train_sample=True),
+)
+create_config = EasyDict(create_config)
+
+
+def train(args):
+ from ding.entry import serial_pipeline
+ config = [main_config, create_config]
+ serial_pipeline(config, seed=args.seed)
+
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_5m6m_madqn_config.py b/DI-engine/dizoo/smac/config/smac_5m6m_madqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d05bb23dcb7b1d64705b709d9302502baf9cc77e
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_5m6m_madqn_config.py
@@ -0,0 +1,98 @@
+from ding.entry import serial_pipeline
+from easydict import EasyDict
+
+agent_num = 5
+collector_env_num = 16
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_5m6m_madqn_seed0',
+ env=dict(
+ map_name='5m_vs_6m',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ shared_memory=False,
+ special_global_state=True,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ ),
+ policy=dict(
+ nstep=3,
+ model=dict(
+ agent_num=agent_num,
+ obs_shape=72,
+ global_obs_shape=152,
+ action_shape=12,
+ hidden_size_list=[256, 256],
+ ),
+ learn=dict(
+ update_per_collect=40,
+ batch_size=32,
+ learning_rate=0.0005,
+ clip_value=10,
+ target_update_theta=0.008,
+ discount_factor=0.95,
+ ),
+ collect=dict(
+ collector=dict(get_train_sample=True, ),
+ n_episode=32,
+ unroll_len=10,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=1000, )),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=50000,
+ ),
+ replay_buffer=dict(replay_buffer_size=50000, ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='madqn'),
+ collector=dict(type='episode'),
+)
+create_config = EasyDict(create_config)
+
+
+def train(args):
+ config = [main_config, create_config]
+ serial_pipeline(config, seed=args.seed, max_env_step=1e7)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--seed', '-s', type=int, default=0)
+ args = parser.parse_args()
+
+ train(args)
diff --git a/DI-engine/dizoo/smac/config/smac_5m6m_mappo_config.py b/DI-engine/dizoo/smac/config/smac_5m6m_mappo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4e2d968e320196693a9c1e1b07f61d37f9e3395
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_5m6m_mappo_config.py
@@ -0,0 +1,90 @@
+from easydict import EasyDict
+
+agent_num = 5
+collector_env_num = 8
+evaluator_env_num = 8
+special_global_state = True
+
+main_config = dict(
+ exp_name='smac_5m6m_mappo_seed0',
+ env=dict(
+ map_name='5m_vs_6m',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=32,
+ stop_value=0.99,
+ death_mask=True,
+ special_global_state=special_global_state,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ cuda=True,
+ multi_agent=True,
+ action_space='discrete',
+ model=dict(
+ # (int) agent_num: The number of agents.
+ # For SMAC 3s5z, agent_num=8; for 2c_vs_64zg, agent_num=2.
+ agent_num=agent_num,
+ # (int) agent_obs_shape: The dimension of each agent's local observation.
+ # For 3s5z, agent_obs_shape=150; for 2c_vs_64zg, agent_obs_shape=404.
+ # (int) global_obs_shape: The dimension of the global observation.
+ # For 3s5z, global_obs_shape=216; for 2c_vs_64zg, global_obs_shape=342.
+ agent_obs_shape=72,
+ #global_obs_shape=216,
+ global_obs_shape=152,
+ # (int) action_shape: The number of actions each agent can take.
+ # action_shape = the number of common actions (6) + the number of enemies.
+ # For 3s5z, action_shape=14 (6+8); for 2c_vs_64zg, action_shape=70 (6+64).
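+ # Here, for the 5m_vs_6m map (6 enemy Marines), action_shape = 6 + 6 = 12, matching the value below.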
+ action_shape=12,
+ # (List[int]) The size of hidden layer
+ # hidden_size_list=[64],
+ action_space='discrete',
+ ),
+ # used in state_num of hidden_state
+ learn=dict(
+ epoch_per_collect=10,
+ batch_size=3200,
+ learning_rate=5e-4,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) The loss weight of the value network; the policy network's weight is fixed to 1
+ value_weight=0.5,
+ # (float) The loss weight of entropy regularization; the policy network's weight is fixed to 1
+ entropy_weight=0.01,
+ # (float) PPO clip ratio, defaults to 0.2
+ clip_ratio=0.05,
+ # (bool) Whether to normalize advantages over the whole training batch
+ adv_norm=False,
+ value_norm=True,
+ ppo_param_init=True,
+ grad_clip_type='clip_norm',
+ grad_clip_value=10,
+ ignore_done=False,
+ ),
+ collect=dict(env_num=collector_env_num, n_sample=3200),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=50, )),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='ppo'),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_5m6m_masac_config.py b/DI-engine/dizoo/smac/config/smac_5m6m_masac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..8dad040e33243342a99d4a829ef8f873deb5557d
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_5m6m_masac_config.py
@@ -0,0 +1,89 @@
+from easydict import EasyDict
+
+agent_num = 5
+collector_env_num = 8
+evaluator_env_num = 8
+special_global_state = True
+
+SMAC_5m6m_masac_default_config = dict(
+ exp_name='smac_5m6m_masac_seed0',
+ env=dict(
+ map_name='5m_vs_6m',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=32,
+ stop_value=0.99,
+ death_mask=True,
+ special_global_state=special_global_state,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=0,
+ model=dict(
+ agent_obs_shape=72,
+ global_obs_shape=152,
+ action_shape=12,
+ twin_critic=True,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=50,
+ batch_size=320,
+ learning_rate_q=5e-4,
+ learning_rate_policy=5e-4,
+ learning_rate_alpha=5e-5,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ auto_alpha=True,
+ log_space=True,
+ ),
+ collect=dict(
+ env_num=collector_env_num,
+ n_sample=1600,
+ unroll_len=1,
+ ),
+ eval=dict(
+ evaluator=dict(eval_freq=50, ),
+ env_num=evaluator_env_num,
+ ),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=50000,
+ ),
+ replay_buffer=dict(replay_buffer_size=1000000, ),
+ ),
+ ),
+)
+
+SMAC_5m6m_masac_default_config = EasyDict(SMAC_5m6m_masac_default_config)
+main_config = SMAC_5m6m_masac_default_config
+
+SMAC_5m6m_masac_default_create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='sac_discrete', ),
+)
+SMAC_5m6m_masac_default_create_config = EasyDict(SMAC_5m6m_masac_default_create_config)
+create_config = SMAC_5m6m_masac_default_create_config
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_5m6m_qmix_config.py b/DI-engine/dizoo/smac/config/smac_5m6m_qmix_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..b52a019ad8346104e4966d5443b61ce145273f1b
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_5m6m_qmix_config.py
@@ -0,0 +1,81 @@
+from easydict import EasyDict
+
+agent_num = 5
+collector_env_num = 16
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_5m6m_qmix_seed0',
+ env=dict(
+ map_name='5m_vs_6m',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ model=dict(
+ agent_num=agent_num,
+ obs_shape=72,
+ global_obs_shape=98,
+ action_shape=12,
+ hidden_size_list=[64],
+ mixer=True,
+ lstm_type='gru',
+ dueling=False,
+ ),
+ learn=dict(
+ update_per_collect=20,
+ batch_size=32,
+ learning_rate=0.0005,
+ clip_value=5,
+ double_q=False,
+ target_update_theta=0.008,
+ discount_factor=0.95,
+ ),
+ collect=dict(
+ n_episode=32,
+ unroll_len=20,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=100, )),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=50000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=5000,
+ # (int) The maximum number of times each data sample can be reused
+ max_reuse=1e+9,
+ max_staleness=1e+9,
+ ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='qmix'),
+ collector=dict(type='episode', get_train_sample=True),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_5m6m_qtran_config.py b/DI-engine/dizoo/smac/config/smac_5m6m_qtran_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..960e8affb8484c2b771bad2802730beeff2ec59c
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_5m6m_qtran_config.py
@@ -0,0 +1,83 @@
+from easydict import EasyDict
+
+agent_num = 5
+collector_env_num = 16
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_5m6m_qtran_seed0',
+ env=dict(
+ map_name='5m_vs_6m',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ model=dict(
+ agent_num=agent_num,
+ obs_shape=72,
+ global_obs_shape=98,
+ action_shape=12,
+ hidden_size_list=[128],
+ embedding_size=128,
+ lstm_type='gru',
+ dueling=False,
+ ),
+ learn=dict(
+ update_per_collect=20,
+ batch_size=32,
+ learning_rate=0.0005,
+ double_q=False,
+ target_update_theta=0.008,
+ discount_factor=0.95,
+ td_weight=1,
+ opt_weight=0.1,
+ nopt_min_weight=0.001,
+ ),
+ collect=dict(
+ n_episode=32,
+ unroll_len=10,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=100, )),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=10000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=15000,
+ # (int) The maximum number of times each data sample can be reused
+ max_reuse=1e+9,
+ max_staleness=1e+9,
+ ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='qtran'),
+ collector=dict(type='episode', get_train_sample=True),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_5m6m_wqmix_config.py b/DI-engine/dizoo/smac/config/smac_5m6m_wqmix_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..2e7586b5aef0c95cf32a8d676213e1847c635d2b
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_5m6m_wqmix_config.py
@@ -0,0 +1,87 @@
+from easydict import EasyDict
+
+agent_num = 5
+collector_env_num = 16
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_5m6m_wqmix_seed0',
+ env=dict(
+ map_name='5m_vs_6m',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ model=dict(
+ agent_num=agent_num,
+ obs_shape=72,
+ global_obs_shape=98,
+ action_shape=12,
+ hidden_size_list=[64],
+ lstm_type='gru',
+ dueling=False,
+ ),
+ learn=dict(
+ update_per_collect=20,
+ batch_size=32,
+ learning_rate=0.0005,
+ clip_value=5,
+ target_update_theta=0.008,
+ discount_factor=0.95,
+
+ ## for OW (Optimistically-Weighted) QMIX
+ wqmix_ow=True,
+ alpha=0.5,
+ ## for CW (Centrally-Weighted) QMIX
+ # wqmix_ow = False,
+ # alpha = 0.75,
+ ),
+ collect=dict(
+ n_episode=32,
+ unroll_len=10,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=100, )),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=1000000,
+ #decay=50000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=15000,
+ # (int) The maximum number of times each data sample can be reused
+ max_reuse=1e+9,
+ max_staleness=1e+9,
+ ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='wqmix'),
+ collector=dict(type='episode', get_train_sample=True),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_8m9m_madqn_config.py b/DI-engine/dizoo/smac/config/smac_8m9m_madqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..672330df241ff8fa1c54de5ee7d6bef8bdc1a310
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_8m9m_madqn_config.py
@@ -0,0 +1,98 @@
+from ding.entry import serial_pipeline
+from easydict import EasyDict
+
+agent_num = 8
+collector_env_num = 16
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_8m9m_madqn_seed0',
+ env=dict(
+ map_name='8m_vs_9m',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ shared_memory=False,
+ special_global_state=True,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ ),
+ policy=dict(
+ nstep=3,
+ model=dict(
+ agent_num=agent_num,
+ obs_shape=108,
+ global_obs_shape=263,
+ action_shape=15,
+ hidden_size_list=[256, 256],
+ ),
+ learn=dict(
+ update_per_collect=40,
+ batch_size=32,
+ learning_rate=0.0005,
+ clip_value=10,
+ target_update_theta=0.008,
+ discount_factor=0.95,
+ ),
+ collect=dict(
+ collector=dict(get_train_sample=True, ),
+ n_episode=32,
+ unroll_len=20,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=1000, )),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=50000,
+ ),
+ replay_buffer=dict(replay_buffer_size=20000, ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='madqn'),
+ collector=dict(type='episode'),
+)
+create_config = EasyDict(create_config)
+
+
+def train(args):
+ config = [main_config, create_config]
+ serial_pipeline(config, seed=args.seed, max_env_step=1e7)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--seed', '-s', type=int, default=0)
+ args = parser.parse_args()
+
+ train(args)
diff --git a/DI-engine/dizoo/smac/config/smac_8m9m_mappo_config.py b/DI-engine/dizoo/smac/config/smac_8m9m_mappo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e46f059f4f2b91ad0dd9c479f014d5d1a4d3f5c9
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_8m9m_mappo_config.py
@@ -0,0 +1,92 @@
+from easydict import EasyDict
+
+agent_num = 8
+collector_env_num = 8
+evaluator_env_num = 8
+special_global_state = True
+
+main_config = dict(
+ exp_name='smac_8m9m_mappo_seed0',
+ env=dict(
+ map_name='8m_vs_9m',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=32,
+ stop_value=0.99,
+ death_mask=False,
+ special_global_state=special_global_state,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ cuda=True,
+ on_policy=True,
+ multi_agent=True,
+ continuous=False,
+ model=dict(
+ # (int) agent_num: The number of agents.
+ # For SMAC 3s5z, agent_num=8; for 2c_vs_64zg, agent_num=2.
+ agent_num=agent_num,
+ # (int) obs_shape: The dimension of each agent's observation.
+ # For 3s5z, obs_shape=150; for 2c_vs_64zg, obs_shape=404.
+ # (int) global_obs_shape: The dimension of the global observation.
+ # For 3s5z, global_obs_shape=216; for 2c_vs_64zg, global_obs_shape=342.
+ agent_obs_shape=108,
+ global_obs_shape=263,
+ action_shape=15,
+ # (List[int]) The size of hidden layer
+ # hidden_size_list=[64],
+ # delete encode in code
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ # used in state_num of hidden_state
+ learn=dict(
+ epoch_per_collect=5,
+ batch_size=3200,
+ learning_rate=5e-4,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) The loss weight of value network, policy network weight is set to 1
+ value_weight=0.5,
+ # (float) The loss weight of entropy regularization, policy network weight is set to 1
+ entropy_weight=0.01,
+ # (float) PPO clip ratio, defaults to 0.2
+ clip_ratio=0.2,
+ # (bool) Whether to use advantage norm in a whole training batch
+ adv_norm=False,
+ value_norm=True,
+ ppo_param_init=True,
+ grad_clip_type='clip_norm',
+ grad_clip_value=10,
+ ignore_done=False,
+ ),
+ collect=dict(env_num=collector_env_num, n_sample=3200),
+ eval=dict(
+ evaluator=dict(eval_freq=100, ),
+ env_num=evaluator_env_num,
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_8m9m_masac_config.py b/DI-engine/dizoo/smac/config/smac_8m9m_masac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e394b8fc6030b9d0e3f1c1e5f2cee92073ecc781
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_8m9m_masac_config.py
@@ -0,0 +1,92 @@
+from easydict import EasyDict
+from ding.entry import serial_pipeline
+
+agent_num = 8
+collector_env_num = 8
+evaluator_env_num = 8
+special_global_state = True
+
+SMAC_8m9m_masac_default_config = dict(
+ exp_name='smac_8m9m_masac_seed0',
+ env=dict(
+ map_name='8m_vs_9m',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=32,
+ stop_value=0.99,
+ death_mask=False,
+ special_global_state=special_global_state,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ cuda=True,
+ on_policy=False,
+ random_collect_size=0,
+ model=dict(
+ agent_obs_shape=108,
+ global_obs_shape=263,
+ action_shape=15,
+ twin_critic=True,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=512,
+ ),
+ learn=dict(
+ update_per_collect=50,
+ batch_size=320,
+ learning_rate_q=5e-4,
+ learning_rate_policy=5e-4,
+ learning_rate_alpha=5e-5,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ auto_alpha=True,
+ log_space=True,
+ ),
+ collect=dict(
+ env_num=collector_env_num,
+ n_sample=1600,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(
+ evaluator=dict(eval_freq=500, ),
+ env_num=evaluator_env_num,
+ ),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=100000,
+ ),
+ replay_buffer=dict(replay_buffer_size=50000, ),
+ ),
+ ),
+)
+
+SMAC_8m9m_masac_default_config = EasyDict(SMAC_8m9m_masac_default_config)
+main_config = SMAC_8m9m_masac_default_config
+
+SMAC_8m9m_masac_default_create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='sac_discrete', ),
+)
+SMAC_8m9m_masac_default_create_config = EasyDict(SMAC_8m9m_masac_default_create_config)
+create_config = SMAC_8m9m_masac_default_create_config
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_MMM2_collaq_config.py b/DI-engine/dizoo/smac/config/smac_MMM2_collaq_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5f20178dff5092a09f6ab55638bf37824fd60ca9
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_MMM2_collaq_config.py
@@ -0,0 +1,92 @@
+from easydict import EasyDict
+
+agent_num = 10
+collector_env_num = 16
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_MMM2_collaq_seed0',
+ env=dict(
+ map_name='MMM2',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ obs_alone=True,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ model=dict(
+ agent_num=agent_num,
+ obs_shape=204,
+ alone_obs_shape=132,
+ global_obs_shape=322,
+ action_shape=18,
+ hidden_size_list=[128],
+ attention=True,
+ # obs_shape = move_feature(4) + enemy_feats(enemy_feat_dim*enemy_num)
+ # + ally_feats(ally_feat_dim*ally_num) + own_feats + agent_id_feats (agent_num)
+ # 4+8*12+8*9+22+10
+ # please see the function of get_obs_agent in smac_env.py
+ self_feature_range=[172, 194],
+ ally_feature_range=[100, 172],
+ attention_size=32,
+ mixer=True,
+ lstm_type='gru',
+ dueling=False,
+ ),
+ learn=dict(
+ update_per_collect=20,
+ batch_size=32,
+ learning_rate=0.0005,
+ clip_value=5,
+ double_q=False,
+ target_update_theta=0.008,
+ discount_factor=0.93,
+ collaq_loss_weight=1.0,
+ ),
+ collect=dict(
+ n_episode=32,
+ unroll_len=10,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=100, )),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=10000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=15000,
+ # (int) The maximum number of times each data sample can be reused
+ max_reuse=1e+9,
+ max_staleness=1e+9,
+ ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='collaq'),
+ collector=dict(type='episode', get_train_sample=True),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_MMM2_coma_config.py b/DI-engine/dizoo/smac/config/smac_MMM2_coma_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..d9305c59acd2514e3b756bc50a01f27c4c622866
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_MMM2_coma_config.py
@@ -0,0 +1,80 @@
+from easydict import EasyDict
+
+agent_num = 10
+collector_env_num = 16
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_MMM2_coma_seed0',
+ env=dict(
+ map_name='MMM2',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ model=dict(
+ agent_num=agent_num,
+ obs_shape=dict(
+ agent_state=204,
+ global_state=322,
+ ),
+ action_shape=18,
+ actor_hidden_size_list=[64],
+ ),
+ learn=dict(
+ update_per_collect=20,
+ batch_size=32,
+ learning_rate=0.0005,
+ target_update_theta=0.001,
+ discount_factor=0.99,
+ td_lambda=0.9,
+ policy_weight=0.001,
+ value_weight=1,
+ entropy_weight=0.01,
+ ),
+ collect=dict(
+ n_episode=32,
+ unroll_len=10,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=100, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.5,
+ end=0.01,
+ decay=200000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=5000,
+ max_use=10,
+ ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='coma'),
+ collector=dict(type='episode', get_train_sample=True),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_MMM2_madqn_config.py b/DI-engine/dizoo/smac/config/smac_MMM2_madqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe8e96501c241dafcde4b35b4c1c53c09b5db2bc
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_MMM2_madqn_config.py
@@ -0,0 +1,84 @@
+from ding.entry import serial_pipeline
+from easydict import EasyDict
+
+agent_num = 10
+collector_env_num = 4
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_MMM2_madqn_seed0',
+ env=dict(
+ map_name='MMM2',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ special_global_state=True,
+ manager=dict(shared_memory=False, ),
+ ),
+ policy=dict(
+ nstep=1,
+ model=dict(
+ agent_num=agent_num,
+ obs_shape=204,
+ global_obs_shape=431,
+ global_cooperation=True,
+ action_shape=18,
+ hidden_size_list=[256, 256],
+ ),
+ learn=dict(
+ update_per_collect=40,
+ batch_size=64,
+ learning_rate=0.0005,
+ clip_value=5,
+ target_update_theta=0.008,
+ discount_factor=0.95,
+ ),
+ collect=dict(
+ collector=dict(get_train_sample=True, ),
+ n_episode=32,
+ unroll_len=20,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=1000, )),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=100000,
+ ),
+ replay_buffer=dict(replay_buffer_size=30000, ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='madqn'),
+ collector=dict(type='episode'),
+)
+create_config = EasyDict(create_config)
+
+
+def train(args):
+ config = [main_config, create_config]
+ serial_pipeline(config, seed=args.seed, max_env_step=1e7)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--seed', '-s', type=int, default=0)
+ args = parser.parse_args()
+
+ train(args)
diff --git a/DI-engine/dizoo/smac/config/smac_MMM2_mappo_config.py b/DI-engine/dizoo/smac/config/smac_MMM2_mappo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..711ddb8f42578f4904b46d808beba40e5c9efc1b
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_MMM2_mappo_config.py
@@ -0,0 +1,89 @@
+from easydict import EasyDict
+
+agent_num = 10
+collector_env_num = 8
+evaluator_env_num = 8
+special_global_state = True
+
+main_config = dict(
+ exp_name='smac_MMM2_mappo_seed0',
+ env=dict(
+ map_name='MMM2',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=32,
+ stop_value=0.99,
+ death_mask=True,
+ special_global_state=special_global_state,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ cuda=True,
+ multi_agent=True,
+ action_space='discrete',
+ model=dict(
+ # (int) agent_num: The number of agents.
+ # For SMAC 3s5z, agent_num=8; for 2c_vs_64zg, agent_num=2.
+ agent_num=agent_num,
+ # (int) obs_shape: The dimension of each agent's observation.
+ # For 3s5z, obs_shape=150; for 2c_vs_64zg, obs_shape=404.
+ # (int) global_obs_shape: The dimension of the global observation.
+ # For 3s5z, global_obs_shape=216; for 2c_vs_64zg, global_obs_shape=342.
+ agent_obs_shape=204,
+ global_obs_shape=431,
+ # (int) action_shape: The number of actions each agent can take.
+ # action_shape = the number of common actions (6) + the number of enemies.
+ # For 3s5z, action_shape=14 (6+8); for 2c_vs_64zg, action_shape=70 (6+64).
+ action_shape=18,
+ # (List[int]) The size of hidden layer
+ # hidden_size_list=[64],
+ action_space='discrete',
+ ),
+ # used in state_num of hidden_state
+ learn=dict(
+ epoch_per_collect=5,
+ batch_size=1600,
+ learning_rate=5e-4,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) The loss weight of value network, policy network weight is set to 1
+ value_weight=0.5,
+ # (float) The loss weight of entropy regularization, policy network weight is set to 1
+ entropy_weight=0.01,
+ # (float) PPO clip ratio, defaults to 0.2
+ clip_ratio=0.5,
+ # (bool) Whether to use advantage norm in a whole training batch
+ adv_norm=False,
+ value_norm=True,
+ ppo_param_init=True,
+ grad_clip_type='clip_norm',
+ grad_clip_value=10,
+ ignore_done=False,
+ ),
+ collect=dict(env_num=collector_env_num, n_sample=3200),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=50, )),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='ppo'),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_MMM2_masac_config.py b/DI-engine/dizoo/smac/config/smac_MMM2_masac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e99fc464e31b9db91517c5a9e61c8a8739f0a46
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_MMM2_masac_config.py
@@ -0,0 +1,89 @@
+from easydict import EasyDict
+
+agent_num = 10
+collector_env_num = 8
+evaluator_env_num = 8
+special_global_state = True
+
+SMAC_MMM2_masac_default_config = dict(
+ exp_name='smac_MMM2_masac_seed0',
+ env=dict(
+ map_name='MMM2',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=32,
+ stop_value=0.99,
+ death_mask=True,
+ special_global_state=special_global_state,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=0,
+ model=dict(
+ agent_obs_shape=204,
+ global_obs_shape=431,
+ action_shape=18,
+ twin_critic=True,
+ actor_head_hidden_size=512,
+ critic_head_hidden_size=1024,
+ ),
+ learn=dict(
+ update_per_collect=50,
+ batch_size=320,
+ learning_rate_q=5e-4,
+ learning_rate_policy=5e-4,
+ learning_rate_alpha=5e-5,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ auto_alpha=True,
+ log_space=True,
+ ),
+ collect=dict(
+ env_num=collector_env_num,
+ n_sample=1600,
+ unroll_len=1,
+ ),
+ eval=dict(
+ evaluator=dict(eval_freq=50, ),
+ env_num=evaluator_env_num,
+ ),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=100000,
+ ),
+ replay_buffer=dict(replay_buffer_size=1000000, ),
+ ),
+ ),
+)
+
+SMAC_MMM2_masac_default_config = EasyDict(SMAC_MMM2_masac_default_config)
+main_config = SMAC_MMM2_masac_default_config
+
+SMAC_MMM2_masac_default_create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='sac_discrete', ),
+)
+SMAC_MMM2_masac_default_create_config = EasyDict(SMAC_MMM2_masac_default_create_config)
+create_config = SMAC_MMM2_masac_default_create_config
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_MMM2_qmix_config.py b/DI-engine/dizoo/smac/config/smac_MMM2_qmix_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b2c21d926c6e1fdef801f8cdf83b2127f39cc9d
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_MMM2_qmix_config.py
@@ -0,0 +1,81 @@
+from easydict import EasyDict
+
+agent_num = 10
+collector_env_num = 16
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_MMM2_qmix_seed0',
+ env=dict(
+ map_name='MMM2',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ model=dict(
+ agent_num=agent_num,
+ obs_shape=204,
+ global_obs_shape=322,
+ action_shape=18,
+ hidden_size_list=[64],
+ mixer=True,
+ lstm_type='gru',
+ dueling=False,
+ ),
+ learn=dict(
+ update_per_collect=20,
+ batch_size=32,
+ learning_rate=0.0005,
+ clip_value=5,
+ double_q=False,
+ target_update_theta=0.005,
+ discount_factor=0.99,
+ ),
+ collect=dict(
+ n_episode=32,
+ unroll_len=10,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=100, )),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=10000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=15000,
+ # (int) The maximum number of times each data sample can be reused
+ max_reuse=1e+9,
+ max_staleness=1e+9,
+ ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='qmix'),
+ collector=dict(type='episode', get_train_sample=True),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_MMM2_wqmix_config.py b/DI-engine/dizoo/smac/config/smac_MMM2_wqmix_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..50c0acbad385e83d86b0821ff6047bbd9b587d90
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_MMM2_wqmix_config.py
@@ -0,0 +1,86 @@
+from easydict import EasyDict
+
+agent_num = 10
+collector_env_num = 16
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_MMM2_wqmix_seed0',
+ env=dict(
+ map_name='MMM2',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ model=dict(
+ agent_num=agent_num,
+ obs_shape=204,
+ global_obs_shape=322,
+ action_shape=18,
+ hidden_size_list=[64],
+ lstm_type='gru',
+ dueling=False,
+ ),
+ learn=dict(
+ update_per_collect=20,
+ batch_size=32,
+ learning_rate=0.0005,
+ clip_value=5,
+ target_update_theta=0.01,
+ discount_factor=0.95,
+
+ ## for OW (Optimistically-Weighted) QMIX
+ wqmix_ow=True,
+ alpha=0.5,
+ ## for CW (Centrally-Weighted) QMIX
+ # wqmix_ow = False,
+ # alpha = 0.75,
+ ),
+ collect=dict(
+ n_episode=32,
+ unroll_len=10,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=100, )),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=10000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=15000,
+ # (int) The maximum number of times each data sample can be reused
+ max_reuse=1e+9,
+ max_staleness=1e+9,
+ ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='wqmix'),
+ collector=dict(type='episode', get_train_sample=True),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_MMM_collaq_config.py b/DI-engine/dizoo/smac/config/smac_MMM_collaq_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..c2ed63335975530f4322d9315761dc5c38241930
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_MMM_collaq_config.py
@@ -0,0 +1,92 @@
+from easydict import EasyDict
+
+agent_num = 10
+collector_env_num = 16
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_MMM_collaq_seed0',
+ env=dict(
+ map_name='MMM',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ shared_memory=False,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ obs_alone=True,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ model=dict(
+ agent_num=agent_num,
+ obs_shape=186,
+ alone_obs_shape=114,
+ global_obs_shape=290,
+ action_shape=16,
+ hidden_size_list=[128],
+ attention=False,
+ # obs_shape = move_feature(4) + enemy_feats(enemy_feat_dim*enemy_num)
+ # + ally_feats(ally_feat_dim*ally_num) + own_feats + agent_id_feats (agent_num)
+ # please see the function of get_obs_agent in smac_env.py
+ self_feature_range=[156, 176],
+ ally_feature_range=[84, 156],
+ attention_size=32,
+ mixer=True,
+ lstm_type='gru',
+ dueling=False,
+ ),
+ learn=dict(
+ update_per_collect=20,
+ batch_size=32,
+ learning_rate=0.0005,
+ clip_value=5,
+ double_q=True,
+ target_update_theta=0.008,
+ discount_factor=0.95,
+ collaq_loss_weight=1.0,
+ ),
+ collect=dict(
+ n_episode=32,
+ unroll_len=10,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=100, )),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=10000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=15000,
+ # (int) The maximum number of times each data sample can be reused
+ max_reuse=1e+9,
+ max_staleness=1e+9,
+ ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='collaq'),
+ collector=dict(type='episode', get_train_sample=True),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_MMM_coma_config.py b/DI-engine/dizoo/smac/config/smac_MMM_coma_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e362fb3d6263030cf10678dc2c1c8ec6522e9b98
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_MMM_coma_config.py
@@ -0,0 +1,82 @@
+from easydict import EasyDict
+
+agent_num = 10
+collector_env_num = 16
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_MMM_coma_seed0',
+ env=dict(
+ map_name='MMM',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ model=dict(
+ # (int) agent_num: The number of agents.
+ # For SMAC 3s5z, agent_num=8; for 2c_vs_64zg, agent_num=2.
+ agent_num=agent_num,
+ # (int) obs_shape: The dimension of each agent's observation.
+ # For 3s5z, obs_shape=150; for 2c_vs_64zg, obs_shape=404.
+ # (int) global_obs_shape: The dimension of the global observation.
+ # For 3s5z, global_obs_shape=216; for 2c_vs_64zg, global_obs_shape=342.
+ obs_shape=dict(
+ agent_state=186,
+ global_state=290,
+ ),
+ # (int) action_shape: The number of actions each agent can take.
+ # action_shape = the number of common actions (6) + the number of enemies.
+ # For 3s5z, action_shape=14 (6+8); for 2c_vs_64zg, action_shape=70 (6+64).
+ action_shape=16,
+ # (List[int]) The size of hidden layer
+ actor_hidden_size_list=[64],
+ ),
+ # used in state_num of hidden_state
+ collect=dict(
+ n_episode=32,
+ unroll_len=10,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=100, )),
+ other=dict(
+ eps=dict(
+ type='exp',
+ start=0.5,
+ end=0.01,
+ decay=200000,
+ ),
+ replay_buffer=dict(
+ # (int) max size of replay buffer
+ replay_buffer_size=5000,
+ # (int) Max use count of each data sample; once the count exceeds this value, the data is removed from the buffer
+ max_use=10,
+ ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='coma'),
+ collector=dict(type='episode', get_train_sample=True),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_MMM_madqn_config.py b/DI-engine/dizoo/smac/config/smac_MMM_madqn_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..892f1f5217fbc56c1bd422f5af23ae614ea06a3b
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_MMM_madqn_config.py
@@ -0,0 +1,84 @@
+from ding.entry import serial_pipeline
+from easydict import EasyDict
+
+agent_num = 10
+collector_env_num = 4
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_MMM_madqn_seed0',
+ env=dict(
+ map_name='MMM',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ special_global_state=True,
+ manager=dict(shared_memory=False, ),
+ ),
+ policy=dict(
+ nstep=1,
+ model=dict(
+ agent_num=agent_num,
+ obs_shape=186,
+ global_obs_shape=389,
+ global_cooperation=True,
+ action_shape=16,
+ hidden_size_list=[256, 256],
+ ),
+ learn=dict(
+ update_per_collect=20,
+ batch_size=64,
+ learning_rate=0.0005,
+ clip_value=5,
+ target_update_theta=0.008,
+ discount_factor=0.95,
+ ),
+ collect=dict(
+ collector=dict(get_train_sample=True, ),
+ n_episode=32,
+ unroll_len=10,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=1000, )),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=10000,
+ ),
+ replay_buffer=dict(replay_buffer_size=15000, ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='madqn'),
+ collector=dict(type='episode'),
+)
+create_config = EasyDict(create_config)
+
+
+def train(args):
+ config = [main_config, create_config]
+ serial_pipeline(config, seed=args.seed, max_env_step=1e7)
+
+
+if __name__ == "__main__":
+ import argparse
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument('--seed', '-s', type=int, default=0)
+ args = parser.parse_args()
+
+ train(args)
diff --git a/DI-engine/dizoo/smac/config/smac_MMM_mappo_config.py b/DI-engine/dizoo/smac/config/smac_MMM_mappo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..8559052e8120833651f621fa273304c5c81bcdd0
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_MMM_mappo_config.py
@@ -0,0 +1,90 @@
+from easydict import EasyDict
+
+agent_num = 10
+collector_env_num = 8
+evaluator_env_num = 8
+special_global_state = True
+
+main_config = dict(
+ exp_name='smac_MMM_mappo_seed0',
+ env=dict(
+ map_name='MMM',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=32,
+ stop_value=0.99,
+ death_mask=False,
+ special_global_state=special_global_state,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ cuda=True,
+ multi_agent=True,
+ action_space='discrete',
+ model=dict(
+ # (int) agent_num: The number of agents.
+ # For SMAC 3s5z, agent_num=8; for 2c_vs_64zg, agent_num=2.
+ agent_num=agent_num,
+ # (int) obs_shape: The dimension of each agent's observation.
+ # For 3s5z, obs_shape=150; for 2c_vs_64zg, obs_shape=404.
+ # (int) global_obs_shape: The dimension of the global observation.
+ # For 3s5z, global_obs_shape=216; for 2c_vs_64zg, global_obs_shape=342.
+ agent_obs_shape=186,
+ #global_obs_shape=216,
+ global_obs_shape=389,
+ # (int) action_shape: The number of actions each agent can take.
+ # action_shape = the number of common actions (6) + the number of enemies.
+ # For 3s5z, action_shape=14 (6+8); for 2c_vs_64zg, action_shape=70 (6+64).
+ action_shape=16,
+ # (List[int]) The size of hidden layer
+ # hidden_size_list=[64],
+ action_space='discrete',
+ ),
+ # used in state_num of hidden_state
+ learn=dict(
+ epoch_per_collect=5,
+ batch_size=320,
+ learning_rate=5e-4,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) The loss weight of value network, policy network weight is set to 1
+ value_weight=0.5,
+ # (float) The loss weight of entropy regularization, policy network weight is set to 1
+ entropy_weight=0.01,
+ # (float) PPO clip ratio, defaults to 0.2
+ clip_ratio=0.2,
+ # (bool) Whether to use advantage norm in a whole training batch
+ adv_norm=False,
+ value_norm=True,
+ ppo_param_init=True,
+ grad_clip_type='clip_norm',
+ grad_clip_value=10,
+ ignore_done=False,
+ ),
+ collect=dict(env_num=collector_env_num, n_sample=3200),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=50, )),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='ppo'),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_MMM_masac_config.py b/DI-engine/dizoo/smac/config/smac_MMM_masac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..2a1b0bd8e44532cd0e5d55c61ce65e0c10e13fcb
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_MMM_masac_config.py
@@ -0,0 +1,89 @@
+from easydict import EasyDict
+
+agent_num = 10
+collector_env_num = 8
+evaluator_env_num = 8
+special_global_state = True
+
+MMM_masac_default_config = dict(
+ exp_name='smac_MMM_masac_seed0',
+ env=dict(
+ map_name='MMM',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=32,
+ stop_value=0.99,
+ death_mask=True,
+ special_global_state=special_global_state,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ cuda=True,
+ random_collect_size=0,
+ model=dict(
+ agent_obs_shape=186,
+ global_obs_shape=389,
+ action_shape=16,
+ twin_critic=True,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=256,
+ ),
+ learn=dict(
+ update_per_collect=50,
+ batch_size=320,
+ learning_rate_q=5e-4,
+ learning_rate_policy=5e-4,
+ learning_rate_alpha=5e-4,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.2,
+ auto_alpha=True,
+ log_space=True,
+ ),
+ collect=dict(
+ env_num=collector_env_num,
+ n_sample=1600,
+ unroll_len=1,
+ ),
+ eval=dict(
+ evaluator=dict(eval_freq=50, ),
+ env_num=evaluator_env_num,
+ ),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=100000,
+ ),
+ replay_buffer=dict(replay_buffer_size=1000000, ),
+ ),
+ ),
+)
+
+MMM_masac_default_config = EasyDict(MMM_masac_default_config)
+main_config = MMM_masac_default_config
+
+MMM_masac_default_create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='sac_discrete', ),
+)
+MMM_masac_default_create_config = EasyDict(MMM_masac_default_create_config)
+create_config = MMM_masac_default_create_config
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_MMM_qmix_config.py b/DI-engine/dizoo/smac/config/smac_MMM_qmix_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ac50e3c3489f7560b6ad7cf548566b4eb183962
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_MMM_qmix_config.py
@@ -0,0 +1,81 @@
+from easydict import EasyDict
+
+agent_num = 10
+collector_env_num = 16
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_MMM_qmix_seed0',
+ env=dict(
+ map_name='MMM',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ model=dict(
+ agent_num=agent_num,
+ obs_shape=186,
+ global_obs_shape=290,
+ action_shape=16,
+ hidden_size_list=[64],
+ mixer=True,
+ lstm_type='gru',
+ dueling=False,
+ ),
+ learn=dict(
+ update_per_collect=20,
+ batch_size=32,
+ learning_rate=0.0005,
+ clip_value=5,
+ double_q=False,
+ target_update_theta=0.005,
+ discount_factor=0.99,
+ ),
+ collect=dict(
+ n_episode=32,
+ unroll_len=10,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=100, )),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=10000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=15000,
+ # (int) The maximum number of times each data sample can be reused
+ max_reuse=1e+9,
+ max_staleness=1e+9,
+ ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='qmix'),
+ collector=dict(type='episode', get_train_sample=True),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_MMM_qtran_config.py b/DI-engine/dizoo/smac/config/smac_MMM_qtran_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..5b0a5bcd13631c79214dc9869ec024865e39647e
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_MMM_qtran_config.py
@@ -0,0 +1,83 @@
+from easydict import EasyDict
+
+agent_num = 10
+collector_env_num = 16
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_MMM_qtran_seed0',
+ env=dict(
+ map_name='MMM',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ model=dict(
+ agent_num=agent_num,
+ obs_shape=186,
+ global_obs_shape=290,
+ action_shape=16,
+ hidden_size_list=[256],
+ embedding_size=256,
+ lstm_type='gru',
+ dueling=False,
+ ),
+ learn=dict(
+ update_per_collect=20,
+ batch_size=32,
+ learning_rate=0.0005,
+ double_q=True,
+ target_update_theta=0.006,
+ discount_factor=0.95,
+ td_weight=1,
+ opt_weight=0.01,
+ nopt_min_weight=0.0001,
+ ),
+ collect=dict(
+ n_episode=32,
+ unroll_len=10,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=100, )),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=10000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=15000,
+ # (int) The maximum number of times each data sample can be reused
+ max_reuse=1e+9,
+ max_staleness=1e+9,
+ ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='qtran'),
+ collector=dict(type='episode', get_train_sample=True),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_MMM_wqmix_config.py b/DI-engine/dizoo/smac/config/smac_MMM_wqmix_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..a4fa3b4c36ec17eac399bd0377d45126c4b511a5
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_MMM_wqmix_config.py
@@ -0,0 +1,86 @@
+from easydict import EasyDict
+
+agent_num = 10
+collector_env_num = 16
+evaluator_env_num = 8
+
+main_config = dict(
+ exp_name='smac_MMM_wqmix_seed0',
+ env=dict(
+ map_name='MMM',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ stop_value=0.999,
+ n_evaluator_episode=32,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ model=dict(
+ agent_num=agent_num,
+ obs_shape=186,
+ global_obs_shape=290,
+ action_shape=16,
+ hidden_size_list=[64],
+ lstm_type='gru',
+ dueling=False,
+ ),
+ learn=dict(
+ update_per_collect=20,
+ batch_size=32,
+ learning_rate=0.0005,
+ clip_value=5,
+ target_update_theta=0.008,
+ discount_factor=0.95,
+
+ ## for OW (Optimistically-Weighted) QMIX
+ wqmix_ow=True,
+ alpha=0.5,
+ ## for CW (Centrally-Weighted) QMIX
+ # wqmix_ow = False,
+ # alpha = 0.75,
+ ),
+ collect=dict(
+ n_episode=32,
+ unroll_len=10,
+ env_num=collector_env_num,
+ ),
+ eval=dict(env_num=evaluator_env_num, evaluator=dict(eval_freq=100, )),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=1000000,
+ ),
+ replay_buffer=dict(
+ replay_buffer_size=15000,
+ # (int) The maximum number of times each data sample can be reused
+ max_reuse=1e+9,
+ max_staleness=1e+9,
+ ),
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='wqmix'),
+ collector=dict(type='episode', get_train_sample=True),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_corridor_mappo_config.py b/DI-engine/dizoo/smac/config/smac_corridor_mappo_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..e160c0c106afad8b13e6f08f969a434c6aafd076
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_corridor_mappo_config.py
@@ -0,0 +1,95 @@
+from easydict import EasyDict
+
+agent_num = 6
+collector_env_num = 8
+evaluator_env_num = 8
+special_global_state = True
+
+main_config = dict(
+ exp_name='smac_corridor_mappo_seed0',
+ env=dict(
+ map_name='corridor',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=32,
+ stop_value=0.99,
+ death_mask=True,
+ special_global_state=special_global_state,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ cuda=True,
+ on_policy=True,
+ multi_agent=True,
+ continuous=False,
+ model=dict(
+ # (int) agent_num: The number of agents.
+ # For SMAC 3s5z, agent_num=8; for 2c_vs_64zg, agent_num=2.
+ agent_num=agent_num,
+ # (int) obs_shape: The dimension of each agent's observation.
+ # For 3s5z, obs_shape=150; for 2c_vs_64zg, obs_shape=404.
+ # (int) global_obs_shape: The dimension of the global observation.
+ # For 3s5z, global_obs_shape=216; for 2c_vs_64zg, global_obs_shape=342.
+ agent_obs_shape=192,
+ global_obs_shape=431,
+ # (int) action_shape: The number of actions each agent can take.
+ # action_shape = the number of common actions (6) + the number of enemies.
+ # For 3s5z, action_shape=14 (6+8); for 2c_vs_64zg, action_shape=70 (6+64).
+ action_shape=30,
+ # (List[int]) The size of hidden layer
+ # hidden_size_list=[64],
+ # delete encode in code
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=512,
+ ),
+ # used in state_num of hidden_state
+ learn=dict(
+ epoch_per_collect=5,
+ batch_size=3200,
+ learning_rate=5e-4,
+ # ==============================================================
+ # The following configs are algorithm-specific
+ # ==============================================================
+ # (float) The loss weight of value network, policy network weight is set to 1
+ value_weight=0.5,
+ # (float) The loss weight of entropy regularization, policy network weight is set to 1
+ entropy_weight=0.01,
+ # (float) PPO clip ratio, defaults to 0.2
+ clip_ratio=0.2,
+ # (bool) Whether to use advantage norm in a whole training batch
+ adv_norm=False,
+ value_norm=True,
+ ppo_param_init=True,
+ grad_clip_type='clip_norm',
+ grad_clip_value=10,
+ ignore_done=False,
+ ),
+ collect=dict(env_num=collector_env_num, n_sample=3200),
+ eval=dict(
+ evaluator=dict(eval_freq=100, ),
+ env_num=evaluator_env_num,
+ ),
+ ),
+)
+main_config = EasyDict(main_config)
+create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='ppo'),
+)
+create_config = EasyDict(create_config)
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/config/smac_corridor_masac_config.py b/DI-engine/dizoo/smac/config/smac_corridor_masac_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..405c7638cef8186c26808c97c60ea50ce1266e5f
--- /dev/null
+++ b/DI-engine/dizoo/smac/config/smac_corridor_masac_config.py
@@ -0,0 +1,92 @@
+from easydict import EasyDict
+from ding.entry import serial_pipeline
+
+agent_num = 6
+collector_env_num = 8
+evaluator_env_num = 8
+special_global_state = True
+
+smac_corridor_masac_default_config = dict(
+ exp_name='smac_corridor_masac_seed0',
+ env=dict(
+ map_name='corridor',
+ difficulty=7,
+ reward_only_positive=True,
+ mirror_opponent=False,
+ agent_num=agent_num,
+ collector_env_num=collector_env_num,
+ evaluator_env_num=evaluator_env_num,
+ n_evaluator_episode=32,
+ stop_value=0.99,
+ death_mask=False,
+ special_global_state=special_global_state,
+ manager=dict(
+ shared_memory=False,
+ reset_timeout=6000,
+ ),
+ ),
+ policy=dict(
+ cuda=True,
+ on_policy=False,
+ random_collect_size=0,
+ model=dict(
+ agent_obs_shape=192,
+ global_obs_shape=431,
+ action_shape=30,
+ twin_critic=True,
+ actor_head_hidden_size=256,
+ critic_head_hidden_size=512,
+ ),
+ learn=dict(
+ update_per_collect=50,
+ batch_size=320,
+ learning_rate_q=5e-4,
+ learning_rate_policy=5e-4,
+ learning_rate_alpha=5e-5,
+ ignore_done=False,
+ target_theta=0.005,
+ discount_factor=0.99,
+ alpha=0.005,
+ auto_alpha=False,
+ log_space=False,
+ ),
+ collect=dict(
+ env_num=collector_env_num,
+ n_sample=1600,
+ unroll_len=1,
+ ),
+ command=dict(),
+ eval=dict(
+ evaluator=dict(eval_freq=1000, ),
+ env_num=evaluator_env_num,
+ ),
+ other=dict(
+ eps=dict(
+ type='linear',
+ start=1,
+ end=0.05,
+ decay=100000,
+ ),
+ replay_buffer=dict(replay_buffer_size=50000, ),
+ ),
+ ),
+)
+
+smac_corridor_masac_default_config = EasyDict(smac_corridor_masac_default_config)
+main_config = smac_corridor_masac_default_config
+
+smac_corridor_masac_default_create_config = dict(
+ env=dict(
+ type='smac',
+ import_names=['dizoo.smac.envs.smac_env'],
+ ),
+ env_manager=dict(type='subprocess'),
+ policy=dict(type='sac_discrete', ),
+)
+smac_corridor_masac_default_create_config = EasyDict(smac_corridor_masac_default_create_config)
+create_config = smac_corridor_masac_default_create_config
+
+if __name__ == '__main__':
+
+ from ding.entry import serial_pipeline
+ serial_pipeline((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/smac/envs/__init__.py b/DI-engine/dizoo/smac/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..756ee4b97571526d4cdf5085dc9346f9e4225162
--- /dev/null
+++ b/DI-engine/dizoo/smac/envs/__init__.py
@@ -0,0 +1,8 @@
+import warnings
+
+from .fake_smac_env import FakeSMACEnv
+try:
+ from .smac_env import SMACEnv
+except ImportError:
+ warnings.warn("pysc2 env not found, please install it")
+ SMACEnv = None
diff --git a/DI-engine/dizoo/smac/envs/fake_smac_env.py b/DI-engine/dizoo/smac/envs/fake_smac_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..cc0199d57fd47e6ddca16941972c425717c98fdd
--- /dev/null
+++ b/DI-engine/dizoo/smac/envs/fake_smac_env.py
@@ -0,0 +1,52 @@
+from collections import namedtuple
+import numpy as np
+
+from ding.envs import BaseEnv, BaseEnvTimestep
+from ding.utils import ENV_REGISTRY
+
+FakeSMACEnvTimestep = namedtuple('FakeSMACEnvTimestep', ['obs', 'reward', 'done', 'info'])
+FakeSMACEnvInfo = namedtuple('FakeSMACEnvInfo', ['agent_num', 'obs_space', 'act_space', 'rew_space'])
+
+
+@ENV_REGISTRY.register('fake_smac')
+class FakeSMACEnv(BaseEnv):
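+ """A fake SMAC environment that returns random data with the same observation structure as the real SMACEnv (agent_state, global_state, action_mask, ...), so pipelines can be tested without a StarCraft II installation."""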
+
+ def __init__(self, cfg=None):
+ self.agent_num = 8
+ self.action_dim = 6 + self.agent_num
+ self.obs_dim = 248
+ self.obs_alone_dim = 216
+ self.global_obs_dim = 216
+
+ def reset(self):
+ self.step_count = 0
+ return self._get_obs()
+
+ def _get_obs(self):
+ return {
+ 'agent_state': np.random.random((self.agent_num, self.obs_dim)),
+ 'agent_alone_state': np.random.random((self.agent_num, self.obs_alone_dim)),
+ 'agent_alone_padding_state': np.random.random((self.agent_num, self.obs_dim)),
+ 'global_state': np.random.random((self.global_obs_dim)),
+ 'action_mask': np.random.randint(0, 2, size=(self.agent_num, self.action_dim)),
+ }
+
+ def step(self, action):
+ assert action.shape == (self.agent_num, ), action.shape
+ obs = self._get_obs()
+ reward = np.random.randint(0, 10, size=(1, ))
+ done = self.step_count >= 314
+ info = {}
+ if done:
+ info['eval_episode_return'] = 0.71
+ self.step_count += 1
+ return FakeSMACEnvTimestep(obs, reward, done, info)
+
+ def close(self):
+ pass
+
+ def seed(self, _seed):
+ pass
+
+ def __repr__(self):
+ return 'FakeSMACEnv'
diff --git a/DI-engine/dizoo/smac/envs/maps/README.md b/DI-engine/dizoo/smac/envs/maps/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..3cf1a28dc86f145702f35f94641c272e8d1cebdb
--- /dev/null
+++ b/DI-engine/dizoo/smac/envs/maps/README.md
@@ -0,0 +1,15 @@
+# Notes on Two Player Maps
+
+Before starting, you need to do the following things:
+
+1. copy the maps in `maps/SMAC_Maps_two_player/*.SC2Map` to the directory `StarCraft II/Maps/SMAC_Maps_two_player/`.
+2. copy the maps in `maps/SMAC_Maps/*.SC2Map` to the directory `StarCraft II/Maps/SMAC_Maps/`.
+
+A convenient bash script is:
+
+```bash
+# On Linux
+cp -r SMAC_Maps_two_player/ ~/StarCraftII/Maps/
+cp -r SMAC_Maps/ ~/StarCraftII/Maps/
+```
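+
+To double-check that the maps are in place (assuming the default `~/StarCraftII` install path used in the script above), you can simply list the two directories:
+
+```bash
+ls ~/StarCraftII/Maps/SMAC_Maps/
+ls ~/StarCraftII/Maps/SMAC_Maps_two_player/
+```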
+
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/10m_vs_11m.SC2Map b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/10m_vs_11m.SC2Map
new file mode 100755
index 0000000000000000000000000000000000000000..1dc2286dfd39380feafa6f8a1819248c9f2c9e3b
Binary files /dev/null and b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/10m_vs_11m.SC2Map differ
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/1c3s5z.SC2Map b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/1c3s5z.SC2Map
new file mode 100755
index 0000000000000000000000000000000000000000..07dfe38062b880307a95cd7722c17fa7ea740a24
Binary files /dev/null and b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/1c3s5z.SC2Map differ
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/25m.SC2Map b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/25m.SC2Map
new file mode 100755
index 0000000000000000000000000000000000000000..fcfdeb09dfc47c0b6376447608b3681a2ef8964a
Binary files /dev/null and b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/25m.SC2Map differ
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/27m_vs_30m.SC2Map b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/27m_vs_30m.SC2Map
new file mode 100755
index 0000000000000000000000000000000000000000..861c7f7069125d9d22e098056f7b430e83917acc
Binary files /dev/null and b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/27m_vs_30m.SC2Map differ
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/2c_vs_64zg.SC2Map b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/2c_vs_64zg.SC2Map
new file mode 100755
index 0000000000000000000000000000000000000000..b740b6c3d5fdc71d94f6e9b992206dbbfc5f495d
Binary files /dev/null and b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/2c_vs_64zg.SC2Map differ
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/2m_vs_1z.SC2Map b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/2m_vs_1z.SC2Map
new file mode 100755
index 0000000000000000000000000000000000000000..f4c05c40b1e7be6ae542c9fcae1bf319d28cf7c3
Binary files /dev/null and b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/2m_vs_1z.SC2Map differ
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/2s3z.SC2Map b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/2s3z.SC2Map
new file mode 100755
index 0000000000000000000000000000000000000000..59846ccf27a67450c03bea41c1fe5efbff4f0ad2
Binary files /dev/null and b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/2s3z.SC2Map differ
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/2s_vs_1sc.SC2Map b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/2s_vs_1sc.SC2Map
new file mode 100755
index 0000000000000000000000000000000000000000..c03328db237d440fb98d5e8a8b147a872bb30571
Binary files /dev/null and b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/2s_vs_1sc.SC2Map differ
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/3m.SC2Map b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/3m.SC2Map
new file mode 100755
index 0000000000000000000000000000000000000000..b35ec1008349e64afb63e23caeaa8c73f6028a03
Binary files /dev/null and b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/3m.SC2Map differ
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/3s5z.SC2Map b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/3s5z.SC2Map
new file mode 100755
index 0000000000000000000000000000000000000000..e5a4313a216031f463383957fce249dba8a94fe4
Binary files /dev/null and b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/3s5z.SC2Map differ
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/3s5z_vs_3s6z.SC2Map b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/3s5z_vs_3s6z.SC2Map
new file mode 100755
index 0000000000000000000000000000000000000000..3927ca4f45afacbba09c1abf46aa5576fb8b5345
Binary files /dev/null and b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/3s5z_vs_3s6z.SC2Map differ
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/3s_vs_3z.SC2Map b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/3s_vs_3z.SC2Map
new file mode 100755
index 0000000000000000000000000000000000000000..4de7cf80e75f7a2e11878f767a74199cf8630a4f
Binary files /dev/null and b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/3s_vs_3z.SC2Map differ
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/3s_vs_4z.SC2Map b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/3s_vs_4z.SC2Map
new file mode 100755
index 0000000000000000000000000000000000000000..8db2dfc6aa08e310ad99f4f7f0b0b7b84811ffc2
Binary files /dev/null and b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/3s_vs_4z.SC2Map differ
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/3s_vs_5z.SC2Map b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/3s_vs_5z.SC2Map
new file mode 100755
index 0000000000000000000000000000000000000000..70c99d29635670f8ab9c5c6433c12b235278f2b5
Binary files /dev/null and b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/3s_vs_5z.SC2Map differ
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/5m_vs_6m.SC2Map b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/5m_vs_6m.SC2Map
new file mode 100755
index 0000000000000000000000000000000000000000..f2ae42c2da5c1d11683d63f458f1a554e2da66fc
Binary files /dev/null and b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/5m_vs_6m.SC2Map differ
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/6h_vs_8z.SC2Map b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/6h_vs_8z.SC2Map
new file mode 100755
index 0000000000000000000000000000000000000000..df01eb64749ef42cd68e51fb20e53303a476ffe0
Binary files /dev/null and b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/6h_vs_8z.SC2Map differ
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/8m.SC2Map b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/8m.SC2Map
new file mode 100755
index 0000000000000000000000000000000000000000..6593c72ffaeffada5a45e973967574f46ab0ec12
Binary files /dev/null and b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/8m.SC2Map differ
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/8m_vs_9m.SC2Map b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/8m_vs_9m.SC2Map
new file mode 100755
index 0000000000000000000000000000000000000000..5b8815f69c84c7ed1244d11bb1654286ba5d52c4
Binary files /dev/null and b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/8m_vs_9m.SC2Map differ
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/MMM.SC2Map b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/MMM.SC2Map
new file mode 100755
index 0000000000000000000000000000000000000000..ed26fe446731b821aae6853e310ecce782a89864
Binary files /dev/null and b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/MMM.SC2Map differ
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/MMM2.SC2Map b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/MMM2.SC2Map
new file mode 100755
index 0000000000000000000000000000000000000000..ab25a02bb391b1da5045a2ef73873c42305c429a
Binary files /dev/null and b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/MMM2.SC2Map differ
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/__init__.py b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/bane_vs_bane.SC2Map b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/bane_vs_bane.SC2Map
new file mode 100755
index 0000000000000000000000000000000000000000..bb81284cc1e7396278b972d39aeb14c2f86d906d
Binary files /dev/null and b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/bane_vs_bane.SC2Map differ
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/corridor.SC2Map b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/corridor.SC2Map
new file mode 100755
index 0000000000000000000000000000000000000000..90daed607b72b7d0cc14f4d2ccad7302ecc3bd01
Binary files /dev/null and b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/corridor.SC2Map differ
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/infestor_viper.SC2Map b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/infestor_viper.SC2Map
new file mode 100644
index 0000000000000000000000000000000000000000..88a8b2cb6278f4d28a7e8753cfef72bd53a9d21c
Binary files /dev/null and b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/infestor_viper.SC2Map differ
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/so_many_baneling.SC2Map b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/so_many_baneling.SC2Map
new file mode 100755
index 0000000000000000000000000000000000000000..6a184e355eb0f724a53877c12ca055dc22182d67
Binary files /dev/null and b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps/so_many_baneling.SC2Map differ
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps_two_player/3m.SC2Map b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps_two_player/3m.SC2Map
new file mode 100755
index 0000000000000000000000000000000000000000..3fb426d93a7d00dae662da2fee566e7883781fd8
Binary files /dev/null and b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps_two_player/3m.SC2Map differ
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps_two_player/3s5z.SC2Map b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps_two_player/3s5z.SC2Map
new file mode 100755
index 0000000000000000000000000000000000000000..5a18cd0392ca9ff34bfa1133e6b8d135068b0974
Binary files /dev/null and b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps_two_player/3s5z.SC2Map differ
diff --git a/DI-engine/dizoo/smac/envs/maps/SMAC_Maps_two_player/__init__.py b/DI-engine/dizoo/smac/envs/maps/SMAC_Maps_two_player/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/smac/envs/maps/__init__.py b/DI-engine/dizoo/smac/envs/maps/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/smac/envs/smac_action.py b/DI-engine/dizoo/smac/envs/smac_action.py
new file mode 100644
index 0000000000000000000000000000000000000000..aaceb32e9777eaf73ed542d6ae037d894d0fc412
--- /dev/null
+++ b/DI-engine/dizoo/smac/envs/smac_action.py
@@ -0,0 +1,426 @@
+import enum
+import math
+
+import numpy as np
+from collections import namedtuple
+from s2clientprotocol import common_pb2 as sc_common, sc2api_pb2 as sc_pb, raw_pb2 as r_pb
+
+ORIGINAL_AGENT = "me"
+OPPONENT_AGENT = "opponent"
+
+MOVE_EAST = 4
+MOVE_WEST = 5
+
+actions = {
+ "move": 16, # target: PointOrUnit
+ "attack": 23, # target: PointOrUnit
+ "stop": 4, # target: None
+ "heal": 386, # Unit
+ "parasitic_bomb": 2542, # target: Unit
+ 'fungal_growth': 74, # target: PointOrUnit
+}
+
+
+class Direction(enum.IntEnum):
+ NORTH = 0
+ SOUTH = 1
+ EAST = 2
+ WEST = 3
+
+
+def distance(x1, y1, x2, y2):
+ """Distance between two points."""
+ return math.hypot(x2 - x1, y2 - y1)
+
+
+class SMACAction:
+ info_template = namedtuple('EnvElementInfo', ['shape', 'value', 'to_agent_processor', 'from_agent_processor'])
+
+ def __init__(self, n_agents, n_enemies, two_player=False, mirror_opponent=True):
+ self.obs_pathing_grid = False
+ self.obs_terrain_height = False
+ self.state_last_action = True
+ self.state_timestep_number = False
+ self.n_obs_pathing = 8
+ self.n_obs_height = 9
+ self._move_amount = 2
+ self.n_actions_no_attack = 6
+ self.n_actions_move = 4
+ self.n_actions = self.n_actions_no_attack + n_enemies
+ self.map_x = 0
+ self.map_y = 0
+
+ # Status tracker
+ self.last_action = np.zeros((n_agents, self.n_actions))
+ self.last_action_opponent = np.zeros((n_enemies, self.n_actions))
+ self.n_agents = n_agents
+ self.n_enemies = n_enemies
+
+ self.two_player = two_player
+ self.mirror_opponent = mirror_opponent
+
+ def reset(self):
+ self.last_action.fill(0)
+ self.last_action_opponent.fill(0)
+
+ def update(self, map_info, map_x, map_y):
+ if map_info.pathing_grid.bits_per_pixel == 1:
+ vals = np.array(list(map_info.pathing_grid.data)).reshape(map_x, int(map_y / 8))
+ self.pathing_grid = np.transpose(
+                np.array([[(b >> i) & 1 for b in row for i in range(7, -1, -1)] for row in vals], dtype=bool)
+ )
+ else:
+ self.pathing_grid = np.invert(
+ np.flip(
+                    np.transpose(np.array(list(map_info.pathing_grid.data), dtype=bool).reshape(map_x, map_y)),
+ axis=1
+ )
+ )
+
+ self.terrain_height = np.flip(
+ np.transpose(np.array(list(map_info.terrain_height.data)).reshape(map_x, map_y)), 1
+ ) / 255
+ self.map_x = map_x
+ self.map_y = map_y
+
+ def _parse_single(self, actions, engine, is_opponent=False):
+        actions = np.asarray(actions, dtype=int)
+ assert len(actions) == (self.n_enemies if is_opponent else self.n_agents)
+
+ actions_int = [int(a) for a in actions]
+ # Make them one-hot
+ if is_opponent:
+ self.last_action_opponent = np.eye(self.n_actions)[np.array(actions_int)]
+ else:
+ self.last_action = np.eye(self.n_actions)[np.array(actions_int)]
+
+ sc_actions = []
+ for a_id, action in enumerate(actions_int):
+ sc_action = self.get_agent_action(a_id, action, engine, is_opponent)
+ if sc_action:
+ sc_actions.append(sc_action)
+ return sc_actions
+
+ def get_action(self, actions, engine):
+ if self.two_player:
+ # ========= Two player mode ==========
+ assert self.two_player
+ assert isinstance(actions, dict)
+ assert ORIGINAL_AGENT in actions
+ assert OPPONENT_AGENT in actions
+
+ if self.mirror_opponent:
+ actions[OPPONENT_AGENT] = [self._transform_action(a) for a in actions[OPPONENT_AGENT]]
+
+ sc_actions_me = self._parse_single(actions[ORIGINAL_AGENT], engine, is_opponent=False)
+ sc_actions_opponent = self._parse_single(actions[OPPONENT_AGENT], engine, is_opponent=True)
+
+ return {ORIGINAL_AGENT: sc_actions_me, OPPONENT_AGENT: sc_actions_opponent}
+ else:
+ assert not isinstance(actions, dict)
+ sc_actions = self._parse_single(actions, engine, is_opponent=False)
+ return sc_actions
+
+ def get_unit_by_id(self, a_id, engine, is_opponent=False):
+ """Get unit by ID."""
+ if is_opponent:
+ return engine.enemies[a_id]
+ return engine.agents[a_id]
+
+ def get_agent_action(self, a_id, action, engine, is_opponent=False):
+ """Construct the action for agent a_id.
+ The input action here is *absolute* and is not mirrored!
+ We use skip_mirror=True in get_avail_agent_actions to avoid error.
+ """
+ avail_actions = self.get_avail_agent_actions(a_id, engine, is_opponent=is_opponent, skip_mirror=True)
+ try:
+            assert avail_actions[action] == 1, \
+                "Agent {} cannot perform action {}; available actions: {}".format(a_id, action, avail_actions)
+        except Exception as e:
+            # The requested action is unavailable; fall back to "stop" (action 1)
+            # rather than raising, so that rollouts are not interrupted.
+            action = 1
+ # TODO
+ # raise e
+ unit = self.get_unit_by_id(a_id, engine, is_opponent=is_opponent)
+
+ # if is_opponent:
+ # action = avail_actions[0] if avail_actions[0] else avail_actions[1]
+
+ # ===== The follows is intact to the original =====
+ tag = unit.tag
+ type_id = unit.unit_type
+ x = unit.pos.x
+ y = unit.pos.y
+
+ # if is_opponent:
+ # print(f"The given unit tag {tag}, x {x}, y {y} and action {action}")
+
+ if action == 0:
+ # no-op (valid only when dead)
+ assert unit.health == 0, "No-op only available for dead agents."
+ return None
+ elif action == 1:
+ # stop
+ cmd = r_pb.ActionRawUnitCommand(ability_id=actions["stop"], unit_tags=[tag], queue_command=False)
+
+ elif action == 2:
+ # move north
+ cmd = r_pb.ActionRawUnitCommand(
+ ability_id=actions["move"],
+ target_world_space_pos=sc_common.Point2D(x=x, y=y + self._move_amount),
+ unit_tags=[tag],
+ queue_command=False
+ )
+
+ elif action == 3:
+ # move south
+ cmd = r_pb.ActionRawUnitCommand(
+ ability_id=actions["move"],
+ target_world_space_pos=sc_common.Point2D(x=x, y=y - self._move_amount),
+ unit_tags=[tag],
+ queue_command=False
+ )
+
+ elif action == 4:
+ # move east
+ cmd = r_pb.ActionRawUnitCommand(
+ ability_id=actions["move"],
+ target_world_space_pos=sc_common.Point2D(x=x + self._move_amount, y=y),
+ unit_tags=[tag],
+ queue_command=False
+ )
+
+ elif action == 5:
+ # move west
+ cmd = r_pb.ActionRawUnitCommand(
+ ability_id=actions["move"],
+ target_world_space_pos=sc_common.Point2D(x=x - self._move_amount, y=y),
+ unit_tags=[tag],
+ queue_command=False
+ )
+ else:
+ # attack/heal units that are in range
+ target_id = action - self.n_actions_no_attack
+ if engine.map_type == "MMM" and unit.unit_type == (engine.medivac_id_opponent
+ if is_opponent else engine.medivac_id):
+ target_unit = (engine.enemies[target_id] if is_opponent else engine.agents[target_id])
+ action_name = "heal"
+ elif engine.map_type == 'infestor_viper':
+ # viper
+ if type_id == 499:
+ target_unit = engine.enemies[target_id]
+ action_name = "parasitic_bomb"
+ # infestor
+ else:
+ target_unit = engine.enemies[target_id]
+ target_loc = (target_unit.pos.x, target_unit.pos.y)
+ action_name = "fungal_growth"
+ target_loc = sc_common.Point2D(x=target_loc[0], y=target_loc[1])
+ cmd = r_pb.ActionRawUnitCommand(
+ ability_id=actions[action_name],
+ target_world_space_pos=target_loc,
+ unit_tags=[tag],
+ queue_command=False
+ )
+ return sc_pb.Action(action_raw=r_pb.ActionRaw(unit_command=cmd))
+ else:
+ target_unit = (engine.agents[target_id] if is_opponent else engine.enemies[target_id])
+ action_name = "attack"
+
+ action_id = actions[action_name]
+ target_tag = target_unit.tag
+
+ cmd = r_pb.ActionRawUnitCommand(
+ ability_id=action_id, target_unit_tag=target_tag, unit_tags=[tag], queue_command=False
+ )
+
+ sc_action = sc_pb.Action(action_raw=r_pb.ActionRaw(unit_command=cmd))
+ return sc_action
+
+ def get_avail_agent_actions(self, agent_id, engine, is_opponent=False, skip_mirror=False):
+ """Returns the available actions for agent_id."""
+ medivac_id = engine.medivac_id_opponent if is_opponent else engine.medivac_id
+ unit = self.get_unit_by_id(agent_id, engine, is_opponent)
+ if unit.health > 0:
+ # cannot choose no-op when alive
+ avail_actions = [0] * self.n_actions
+
+ # stop should be allowed
+ avail_actions[1] = 1
+
+ # see if we can move
+ if self.can_move(unit, Direction.NORTH):
+ avail_actions[2] = 1
+ if self.can_move(unit, Direction.SOUTH):
+ avail_actions[3] = 1
+ if self.can_move(unit, Direction.EAST):
+ avail_actions[4] = 1
+ if self.can_move(unit, Direction.WEST):
+ avail_actions[5] = 1
+
+            # Can only attack units that are alive and within shooting range
+ shoot_range = self.unit_shoot_range(unit)
+
+ target_items = engine.enemies.items() if not is_opponent else engine.agents.items()
+ self_items = engine.agents.items() if not is_opponent else engine.enemies.items()
+ if engine.map_type == "MMM" and unit.unit_type == medivac_id:
+ # Medivacs cannot heal themselves or other flying units
+ target_items = [(t_id, t_unit) for (t_id, t_unit) in self_items if t_unit.unit_type != medivac_id]
+
+ for t_id, t_unit in target_items:
+ if t_unit.health > 0:
+ dist = distance(unit.pos.x, unit.pos.y, t_unit.pos.x, t_unit.pos.y)
+ if dist <= shoot_range:
+ if engine.map_type == "infestor_viper":
+ value = 0
+ # viper
+ if unit.unit_type == 499:
+ if unit.energy >= 125:
+ value = 1
+ # infestor
+ else:
+ if unit.energy >= 50:
+ value = 1
+ avail_actions[t_id + self.n_actions_no_attack] = value
+ else:
+ avail_actions[t_id + self.n_actions_no_attack] = 1
+
+ else:
+ # only no-op allowed
+ avail_actions = [1] + [0] * (self.n_actions - 1)
+
+ if (not skip_mirror) and self.mirror_opponent and is_opponent:
+ avail_actions[MOVE_EAST], avail_actions[MOVE_WEST] = \
+ avail_actions[MOVE_WEST], avail_actions[MOVE_EAST]
+
+ return avail_actions
+
+ def can_move(self, unit, direction):
+ """Whether a unit can move in a given direction."""
+ m = self._move_amount / 2
+
+ if direction == Direction.NORTH:
+ x, y = int(unit.pos.x), int(unit.pos.y + m)
+ elif direction == Direction.SOUTH:
+ x, y = int(unit.pos.x), int(unit.pos.y - m)
+ elif direction == Direction.EAST:
+ x, y = int(unit.pos.x + m), int(unit.pos.y)
+ else:
+ x, y = int(unit.pos.x - m), int(unit.pos.y)
+
+ if self.check_bounds(x, y) and self.pathing_grid[x, y]:
+ return True
+
+ return False
+
+ def check_bounds(self, x, y):
+ """Whether a point is within the map bounds."""
+ return 0 <= x < self.map_x and 0 <= y < self.map_y
+
+ def get_surrounding_pathing(self, unit):
+ """Returns pathing values of the grid surrounding the given unit."""
+ points = self.get_surrounding_points(unit, include_self=False)
+ vals = [self.pathing_grid[x, y] if self.check_bounds(x, y) else 1 for x, y in points]
+ return vals
+
+ def get_surrounding_height(self, unit):
+ """Returns height values of the grid surrounding the given unit."""
+ points = self.get_surrounding_points(unit, include_self=True)
+ vals = [self.terrain_height[x, y] if self.check_bounds(x, y) else 1 for x, y in points]
+ return vals
+
+    def unit_shoot_range(self, unit):
+        """Returns the shooting range for an agent."""
+        type_id = unit.unit_type
+        if type_id == 499:
+            # Viper (see the infestor_viper handling above); 8 appears to be its ability range.
+            return 8
+        elif type_id == 111:
+            # Presumably the Infestor on the same map; 10 appears to be its ability range.
+            return 10
+        else:
+            # Default SMAC shooting range for ordinary combat units.
+            return 6
+
+ def get_surrounding_points(self, unit, include_self=False):
+ """Returns the surrounding points of the unit in 8 directions."""
+ x = int(unit.pos.x)
+ y = int(unit.pos.y)
+
+ ma = self._move_amount
+
+ points = [
+ (x, y + 2 * ma),
+ (x, y - 2 * ma),
+ (x + 2 * ma, y),
+ (x - 2 * ma, y),
+ (x + ma, y + ma),
+ (x - ma, y - ma),
+ (x + ma, y - ma),
+ (x - ma, y + ma),
+ ]
+
+ if include_self:
+ points.append((x, y))
+
+ return points
+
+ def get_movement_features(self, agent_id, engine, is_opponent=False):
+ unit = self.get_unit_by_id(agent_id, engine, is_opponent=is_opponent)
+ move_feats_dim = self.get_obs_move_feats_size()
+ move_feats = np.zeros(move_feats_dim, dtype=np.float32)
+
+ if unit.health > 0: # otherwise dead, return all zeros
+ # Movement features
+ avail_actions = self.get_avail_agent_actions(agent_id, engine, is_opponent=is_opponent)
+ for m in range(self.n_actions_move):
+ move_feats[m] = avail_actions[m + 2]
+
+ ind = self.n_actions_move
+
+ if self.obs_pathing_grid:
+                # TODO: confirm that self.n_obs_pathing is the intended slice width here
+                move_feats[ind:ind + self.n_obs_pathing] = self.get_surrounding_pathing(unit)
+ ind += self.n_obs_pathing
+
+ if self.obs_terrain_height:
+ move_feats[ind:] = self.get_surrounding_height(unit)
+ return move_feats
+
+ def get_obs_move_feats_size(self):
+        """Returns the size of the vector containing the agent's movement-related features."""
+ move_feats = self.n_actions_move
+ if self.obs_pathing_grid:
+ move_feats += self.n_obs_pathing
+ if self.obs_terrain_height:
+ move_feats += self.n_obs_height
+
+ return move_feats
+
+ def get_last_action(self, is_opponent=False):
+ if is_opponent:
+ ret = self.last_action_opponent
+ if self.mirror_opponent:
+ ret[:, MOVE_EAST], ret[:, MOVE_WEST] = \
+ ret[:, MOVE_WEST].copy(), ret[:, MOVE_EAST].copy()
+ else:
+ ret = self.last_action
+ return ret
+
+ def get_avail_actions(self, engine, is_opponent=False):
+ return [
+ self.get_avail_agent_actions(agent_id, engine, is_opponent=is_opponent)
+ for agent_id in range(self.n_agents if not is_opponent else self.n_enemies)
+ ]
+
+ @staticmethod
+ def _transform_action(a):
+ if a == MOVE_EAST: # intend to move east
+ a = MOVE_WEST
+ elif a == MOVE_WEST: # intend to move west
+ a = MOVE_EAST
+ return a
+
+ def info(self):
+ shape = (self.n_actions, )
+ value = {'min': 0, 'max': 1}
+ return SMACAction.info_template(shape, value, None, None)
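+
+
+# A minimal sanity check (an illustrative sketch, not part of the original module). It touches
+# only the parts of SMACAction that do not need a live SC2 engine, namely the action-space size
+# and the east/west mirroring applied to the opponent in two-player mode; it assumes numpy and
+# s2clientprotocol are importable.
+if __name__ == '__main__':
+    helper = SMACAction(n_agents=3, n_enemies=4)
+    # 6 non-attack actions (no-op, stop and 4 moves) plus one attack action per enemy.
+    assert helper.n_actions == 6 + 4
+    assert helper.info().shape == (10, )
+    # Mirroring swaps the east/west move actions and leaves everything else unchanged.
+    assert SMACAction._transform_action(MOVE_EAST) == MOVE_WEST
+    assert SMACAction._transform_action(1) == 1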
diff --git a/DI-engine/dizoo/smac/envs/smac_env.py b/DI-engine/dizoo/smac/envs/smac_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..f08d5e096f9e829e67d0948df3f9cb67bd995752
--- /dev/null
+++ b/DI-engine/dizoo/smac/envs/smac_env.py
@@ -0,0 +1,1748 @@
+import copy
+import enum
+from collections import namedtuple
+from operator import attrgetter
+from functools import reduce
+
+import numpy as np
+import math
+import random
+from ditk import logging
+from easydict import EasyDict
+import pysc2.env.sc2_env as sc2_env
+from pysc2.env.sc2_env import SC2Env, Agent, MAX_STEP_COUNT, get_default, crop_and_deduplicate_names
+from pysc2.lib import protocol
+from s2clientprotocol import common_pb2 as sc_common
+from s2clientprotocol import debug_pb2 as d_pb
+from s2clientprotocol import sc2api_pb2 as sc_pb
+from ding.envs import BaseEnv
+from ding.envs.common.env_element import EnvElement, EnvElementInfo
+from ding.utils import ENV_REGISTRY, deep_merge_dicts
+
+from .smac_map import get_map_params
+from .smac_action import SMACAction, distance
+from .smac_reward import SMACReward
+
+races = {
+ "R": sc_common.Random,
+ "P": sc_common.Protoss,
+ "T": sc_common.Terran,
+ "Z": sc_common.Zerg,
+}
+
+ORIGINAL_AGENT = "me"
+OPPONENT_AGENT = "opponent"
+
+SUPPORT_MAPS = [
+ "SMAC_Maps_two_player/3s5z.SC2Map",
+ "SMAC_Maps_two_player/3m.SC2Map",
+ "GBU_Maps/infestor_viper.sc2map",
+]
+
+FORCE_RESTART_INTERVAL = 50000
+
+
+class Direction(enum.IntEnum):
+ NORTH = 0
+ SOUTH = 1
+ EAST = 2
+ WEST = 3
+
+
+@ENV_REGISTRY.register('smac')
+class SMACEnv(SC2Env, BaseEnv):
+ """
+    This environment provides the interface for both the single-player (agent vs. bot) and the
+    two-player (agent vs. agent) SMAC settings on top of the SC2 environment.
+ """
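+
+    # A rough usage sketch (hedged: it assumes StarCraft II, pysc2 and the SMAC maps are
+    # installed, and that '3m' is a map registered in smac_map):
+    #
+    #   env = SMACEnv(EasyDict(dict(map_name='3m')))
+    #   obs = env.reset()          # dict with 'agent_state', 'global_state', 'action_mask'
+    #   timestep = env.step(np.ones(env.n_agents, dtype=int))   # action 1 = "stop"
+    #   env.close()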
+
+ SMACTimestep = namedtuple('SMACTimestep', ['obs', 'reward', 'done', 'info', 'episode_steps'])
+ SMACEnvInfo = namedtuple('SMACEnvInfo', ['agent_num', 'obs_space', 'act_space', 'rew_space', 'episode_limit'])
+ config = dict(
+ two_player=False,
+ mirror_opponent=False,
+ reward_type="original",
+ save_replay_episodes=None,
+ difficulty=7,
+ reward_death_value=10,
+ reward_win=200,
+ obs_alone=False,
+ game_steps_per_episode=None,
+ reward_only_positive=True,
+ death_mask=False,
+ special_global_state=False,
+        # whether to add the map's center location point
+ add_center_xy=True,
+ independent_obs=False,
+        # whether to add the agent's id information to the special global state
+ state_agent_id=True,
+ )
+
+ def __init__(
+ self,
+ cfg,
+ ):
+ cfg = deep_merge_dicts(EasyDict(self.config), cfg)
+ self.cfg = cfg
+ self.save_replay_episodes = cfg.save_replay_episodes
+ assert (self.save_replay_episodes is None) or isinstance(
+ self.save_replay_episodes, int
+ ) # Denote the number of replays to save
+ self.two_player = cfg.two_player
+ self.difficulty = cfg.difficulty
+ self.obs_alone = cfg.obs_alone
+ self.game_steps_per_episode = cfg.game_steps_per_episode
+
+ map_name = cfg.map_name
+ assert map_name is not None
+ map_params = get_map_params(map_name)
+ self.reward_only_positive = cfg.reward_only_positive
+ self.difficulty = cfg.difficulty
+ self.obs_alone = cfg.obs_alone
+ self.players, self.num_players = self._get_players(
+ "agent_vs_agent" if self.two_player else "game_vs_bot",
+ player1_race=map_params["a_race"],
+ player2_race=map_params["b_race"]
+ )
+ self._map_name = map_name
+
+ # SMAC used
+ self.n_agents = map_params["n_agents"]
+ self.n_enemies = map_params["n_enemies"]
+ self.episode_limit = map_params["limit"]
+
+ self._agent_race = map_params["a_race"]
+ self._bot_race = map_params["b_race"]
+ self.shield_bits_ally = 1 if self._agent_race == "P" else 0
+ self.shield_bits_enemy = 1 if self._bot_race == "P" else 0
+ self.unit_type_bits = map_params["unit_type_bits"]
+ self.map_type = map_params["map_type"]
+
+ self.agents = {}
+ self.enemies = {}
+ self._episode_count = 0
+ self._episode_steps = 0
+ self._total_steps = 0
+ self._next_reset_steps = FORCE_RESTART_INTERVAL
+
+ self._obs = None
+ self.battles_won = 0
+ self.battles_game = 0
+ self.timeouts = 0
+ self.force_restarts = 0
+ self.last_stats = None
+
+ self._min_unit_type = 0
+ self.marine_id = self.marauder_id = self.medivac_id = 0
+ self.hydralisk_id = self.zergling_id = self.baneling_id = 0
+ self.stalker_id = self.colossus_id = self.zealot_id = 0
+
+ self.add_center_xy = cfg.add_center_xy
+ self.state_agent_id = cfg.state_agent_id
+ self.death_mask = cfg.death_mask
+ self.special_global_state = cfg.special_global_state
+
+ # reward
+ self.reward_death_value = cfg.reward_death_value
+ self.reward_win = cfg.reward_win
+ self.reward_defeat = 0
+ self.reward_negative_scale = 0.5
+ self.reward_type = cfg.reward_type
+ self.max_reward = (self.n_enemies * self.reward_death_value + self.reward_win)
+ self.obs_pathing_grid = False
+ self.obs_own_health = True
+ self.obs_all_health = True
+ self.obs_instead_of_state = False
+ self.obs_last_action = True
+ self.obs_terrain_height = False
+ self.obs_timestep_number = False
+ self.state_last_action = True
+ self.state_timestep_number = False
+ if self.obs_all_health:
+ self.obs_own_health = True
+ self.n_obs_pathing = 8
+ self.n_obs_height = 9
+ self._move_amount = 2
+ self.continuing_episode = False
+
+ self._seed = None
+ self._launch_env_flag = True
+ self.just_force_restarts = False
+
+ # Set to false if you need structured observation / state
+ self.flatten_observation = True
+ self.mirror_opponent = cfg.mirror_opponent
+ if self.mirror_opponent:
+ self.flatten_observation = False
+
+ # Opponent related variables
+ self.battles_won_opponent = 0
+ self.battles_defeat = 0
+ self._min_unit_type_opponent = 0
+ self.marine_id_opponent = self.marauder_id_opponent = self.medivac_id_opponent = 0
+ self.hydralisk_id_opponent = self.zergling_id_opponent = self.baneling_id_opponent = 0
+ self.stalker_id_opponent = self.colossus_id_opponent = self.zealot_id_opponent = 0
+ self.max_distance_x = 0
+ self.max_distance_y = 0
+ self.map_x = 0
+ self.map_y = 0
+
+ self.previous_ally_units = None
+ self.previous_enemy_units = None
+
+ self.independent_obs = cfg.independent_obs
+
+ self.action_helper = SMACAction(self.n_agents, self.n_enemies, self.two_player, self.mirror_opponent)
+ self.reward_helper = SMACReward(
+ self.n_agents,
+ self.n_enemies,
+ self.two_player,
+ self.reward_type,
+ self.max_reward,
+ reward_only_positive=self.reward_only_positive
+ )
+
+ self._observation_space = self.get_obs_space()
+        self._action_space = self.action_helper.info()
+        self._reward_space = self.reward_helper.info()
+
+ def seed(self, seed, dynamic_seed=False):
+ self._seed = seed
+
+ def _create_join(self):
+ if self.two_player:
+ for m in self._maps:
+ m.directory = "SMAC_Maps_two_player"
+ map_path = m.path
+                assert map_path in SUPPORT_MAPS, \
+                    "We only support the following maps: {}. Please move the maps in " \
+                    "evaluate/sources/SMAC_Maps_two_player to the maps folder of SC2.".format(SUPPORT_MAPS)
+ # copy and overwrite original implementation
+ map_inst = random.choice(self._maps)
+ self._map_name = map_inst.name
+
+ self._step_mul = max(1, self._default_step_mul or map_inst.step_mul)
+ self._score_index = get_default(self._default_score_index, map_inst.score_index)
+ self._score_multiplier = get_default(self._default_score_multiplier, map_inst.score_multiplier)
+ self._episode_length = get_default(self._default_episode_length, map_inst.game_steps_per_episode)
+ if self._episode_length <= 0 or self._episode_length > MAX_STEP_COUNT:
+ self._episode_length = MAX_STEP_COUNT
+
+ # Create the game. Set the first instance as the host.
+ create = sc_pb.RequestCreateGame(disable_fog=self._disable_fog, realtime=self._realtime)
+
+ if self._battle_net_map:
+ create.battlenet_map_name = map_inst.battle_net
+ else:
+ create.local_map.map_path = map_inst.path
+ map_data = map_inst.data(self._run_config)
+ if self._num_agents == 1:
+ create.local_map.map_data = map_data
+ else:
+ # Save the maps so they can access it. Don't do it in parallel since SC2
+ # doesn't respect tmpdir on windows, which leads to a race condition:
+ # https://github.com/Blizzard/s2client-proto/issues/102
+ for c in self._controllers:
+ c.save_map(map_inst.path, map_data)
+ if self._random_seed is not None:
+ create.random_seed = self._random_seed
+ for p in self._players:
+ if isinstance(p, Agent):
+ create.player_setup.add(type=sc_pb.Participant)
+ else:
+ create.player_setup.add(
+ type=sc_pb.Computer,
+ race=random.choice(p.race),
+ difficulty=p.difficulty,
+ ai_build=random.choice(p.build)
+ )
+ if self._num_agents > 1:
+ self._controllers[1].create_game(create)
+ else:
+ self._controllers[0].create_game(create)
+
+ # Create the join requests.
+ agent_players = [p for p in self._players if isinstance(p, Agent)]
+ self.sanitized_names = crop_and_deduplicate_names(p.name for p in agent_players)
+ join_reqs = []
+ for p, name, interface in zip(agent_players, self.sanitized_names, self._interface_options):
+ join = sc_pb.RequestJoinGame(options=interface)
+ join.race = random.choice(p.race)
+ join.player_name = name
+ if self._ports:
+ join.shared_port = 0 # unused
+ join.server_ports.game_port = self._ports[0]
+ join.server_ports.base_port = self._ports[1]
+ for i in range(self._num_agents - 1):
+ join.client_ports.add(game_port=self._ports[i * 2 + 2], base_port=self._ports[i * 2 + 3])
+ join_reqs.append(join)
+
+ # Join the game. This must be run in parallel because Join is a blocking
+ # call to the game that waits until all clients have joined.
+ self._parallel.run((c.join_game, join) for c, join in zip(self._controllers, join_reqs))
+
+ self._game_info = self._parallel.run(c.game_info for c in self._controllers)
+ for g, interface in zip(self._game_info, self._interface_options):
+ if g.options.render != interface.render:
+ logging.warning(
+ "Actual interface options don't match requested options:\n"
+ "Requested:\n%s\n\nActual:\n%s", interface, g.options
+ )
+
+ # original pysc2 case
+ # if require_features:
+ # self._features = [
+ # features.features_from_game_info(
+ # game_info=g, agent_interface_format=aif, map_name=self._map_name)
+ # for g, aif in zip(self._game_info, self._interface_formats)]
+ # smac case
+ self._features = None
+
+ def _get_players(self, game_type, player1_race, player2_race):
+ if game_type == 'game_vs_bot':
+ agent_num = 1
+ print('difficulty', self.difficulty)
+ players = [sc2_env.Agent(races[player1_race]), sc2_env.Bot(races[player2_race], self.difficulty)]
+ elif game_type == 'agent_vs_agent':
+ agent_num = 2
+ players = [sc2_env.Agent(races[player1_race]), sc2_env.Agent(races[player2_race])]
+ else:
+ raise KeyError("invalid game_type: {}".format(game_type))
+ return players, agent_num
+
+ def _launch(self):
+
+ print("*****LAUNCH FUNCTION CALLED*****")
+
+ # necessary for compatibility with pysc2
+ from absl import flags
+ flags.FLAGS(['smac'])
+ agent_interface_format = sc2_env.parse_agent_interface_format(use_raw_units=True)
+
+ SC2Env.__init__(
+ self,
+ map_name=self.map_name,
+ battle_net_map=False,
+ players=self.players,
+ agent_interface_format=agent_interface_format,
+ discount=None,
+ discount_zero_after_timeout=False,
+ visualize=False,
+ step_mul=8,
+ realtime=False,
+ save_replay_episodes=self.save_replay_episodes,
+ replay_dir=None if self.save_replay_episodes is None else ".",
+ replay_prefix=None,
+ game_steps_per_episode=self.game_steps_per_episode,
+ score_index=None,
+ score_multiplier=None,
+ random_seed=self._seed,
+ disable_fog=False,
+ ensure_available_actions=True,
+ version=None
+ )
+
+ self._launch_env_flag = True
+
+ game_info = self._game_info[0]
+ map_info = game_info.start_raw
+ map_play_area_min = map_info.playable_area.p0
+ map_play_area_max = map_info.playable_area.p1
+ self.max_distance_x = map_play_area_max.x - map_play_area_min.x
+ self.max_distance_y = map_play_area_max.y - map_play_area_min.y
+ self.map_x = map_info.map_size.x
+ self.map_y = map_info.map_size.y
+
+ self.action_helper.update(map_info, self.map_x, self.map_y)
+
+ def _restart_episode(self):
+ """Restart the environment by killing all units on the map.
+ There is a trigger in the SC2Map file, which restarts the
+ episode when there are no units left.
+ """
+ try:
+ run_commands = [
+ (
+ self._controllers[0].debug,
+ d_pb.DebugCommand(
+ kill_unit=d_pb.DebugKillUnit(
+ tag=[unit.tag for unit in self.agents.values() if unit.health > 0] +
+ [unit.tag for unit in self.enemies.values() if unit.health > 0]
+ )
+ )
+ )
+ ]
+ if self.two_player:
+ run_commands.append(
+ (self._controllers[1].debug, d_pb.DebugCommand(kill_unit=d_pb.DebugKillUnit(tag=[])))
+ )
+ # Kill all units on the map.
+ self._parallel.run(run_commands)
+ # Forward 2 step to make sure all units revive.
+ ret = self._parallel.run((c.step, 2) for c in self._controllers)
+ except (protocol.ProtocolError, protocol.ConnectionError) as e:
+ print("Error happen in _restart. Error: ", e)
+ self.full_restart()
+
+ def full_restart(self):
+ self.close()
+ self._launch()
+ self.force_restarts += 1
+ self.just_force_restarts = True
+
+ def reset(self):
+ self._episode_steps = 0
+ self._final_eval_fake_reward = 0.
+ old_unit_tags = set(u.tag for u in self.agents.values()).union(set(u.tag for u in self.enemies.values()))
+
+ if self.just_force_restarts:
+ old_unit_tags = set()
+ self.just_force_restarts = False
+
+ if self._launch_env_flag:
+ # Launch StarCraft II
+ print("*************LAUNCH TOTAL GAME********************")
+ self._launch()
+ self._launch_env_flag = False
+ elif (self._total_steps > self._next_reset_steps) or (self.save_replay_episodes is not None):
+ # Avoid hitting the real episode limit of SC2 env
+ print("We are full restarting the environment! save_replay_episodes: ", self.save_replay_episodes)
+ self.full_restart()
+ old_unit_tags = set()
+ self._next_reset_steps += FORCE_RESTART_INTERVAL
+ else:
+ self._restart_episode()
+
+ # Information kept for counting the reward
+ self.win_counted = False
+ self.defeat_counted = False
+
+ self.action_helper.reset()
+
+ self.previous_ally_units = None
+ self.previous_enemy_units = None
+
+ # if self.heuristic_ai:
+ # self.heuristic_targets = [None] * self.n_agents
+
+ count = 0
+ while count <= 5:
+ self._update_obs()
+ #print("INTERNAL INIT UNIT BEGIN")
+ init_flag = self.init_units(old_unit_tags)
+ #print("INTERNAL INIT UNIT OVER", init_flag)
+ count += 1
+ if init_flag:
+ break
+ else:
+ old_unit_tags = set()
+ if count >= 5:
+ raise RuntimeError("reset 5 times error")
+
+ self.reward_helper.reset(self.max_reward)
+
+ assert all(u.health > 0 for u in self.agents.values())
+ assert all(u.health > 0 for u in self.enemies.values())
+
+ if not self.two_player:
+ if self.obs_alone:
+ agent_state, agent_alone_state, agent_alone_padding_state = self.get_obs()
+ return {
+ 'agent_state': agent_state,
+ 'agent_alone_state': agent_alone_state,
+ 'agent_alone_padding_state': agent_alone_padding_state,
+ 'global_state': self.get_state(),
+ 'action_mask': self.get_avail_actions()
+ }
+ elif self.independent_obs:
+ return {
+ 'agent_state': self.get_obs(),
+ 'global_state': self.get_obs(),
+ 'action_mask': self.get_avail_actions(),
+ }
+ elif self.special_global_state:
+ return {
+ 'agent_state': self.get_obs(),
+ 'global_state': self.get_global_special_state(),
+ 'action_mask': self.get_avail_actions(),
+ }
+ else:
+ return {
+ 'agent_state': self.get_obs(),
+ 'global_state': self.get_state(),
+ 'action_mask': self.get_avail_actions(),
+ }
+
+ return {
+ 'agent_state': {
+ ORIGINAL_AGENT: self.get_obs(),
+ OPPONENT_AGENT: self.get_obs(True)
+ },
+ 'global_state': {
+ ORIGINAL_AGENT: self.get_state(),
+ OPPONENT_AGENT: self.get_state(True)
+ },
+ 'action_mask': {
+ ORIGINAL_AGENT: self.get_avail_actions(),
+ OPPONENT_AGENT: self.get_avail_actions(True),
+ },
+ }
+
+ def _submit_actions(self, actions):
+ if self.two_player:
+ # actions is a dict with 'me' and 'opponent' keys.
+ actions_me, actions_opponent = actions[ORIGINAL_AGENT], actions[OPPONENT_AGENT]
+ self._parallel.run(
+ [
+ (self._controllers[0].actions, sc_pb.RequestAction(actions=actions_me)),
+ (self._controllers[1].actions, sc_pb.RequestAction(actions=actions_opponent))
+ ]
+ )
+ step_mul = self._step_mul
+ if step_mul <= 0:
+ raise ValueError("step_mul should be positive, got {}".format(step_mul))
+ if not any(c.status_ended for c in self._controllers): # May already have ended.
+ self._parallel.run((c.step, step_mul) for c in self._controllers)
+ self._update_obs(target_game_loop=self._episode_steps + step_mul)
+ else:
+ # actions is a sequence
+ # Send action request
+ req_actions = sc_pb.RequestAction(actions=actions)
+ self._controllers[0].actions(req_actions)
+ self._controllers[0].step(self._step_mul)
+ self._update_obs()
+
+ def _get_empty_action(self, old_action):
+ me_act = []
+ for a_id in range(self.n_agents):
+ no_op = self.action_helper.get_avail_agent_actions(a_id, self, is_opponent=False)[0]
+ me_act.append(0 if no_op else 1)
+
+ if isinstance(old_action, dict):
+ op_act = []
+ for a_id in range(self.n_enemies):
+ no_op = self.action_helper.get_avail_agent_actions(a_id, self, is_opponent=False)[0]
+ op_act.append(0 if no_op else 1)
+ new_action = {ORIGINAL_AGENT: me_act, OPPONENT_AGENT: op_act}
+ else:
+ new_action = me_act
+ return new_action
+
+ def step(self, actions, force_return_two_player=False):
+ processed_actions = self.action_helper.get_action(actions, self)
+ # self._submit_actions(processed_actions)
+ try:
+ # print("Submitting actions: ", actions)
+ self._submit_actions(processed_actions)
+ # raise ValueError() # To test the functionality of restart
+ except (protocol.ProtocolError, protocol.ConnectionError, ValueError) as e:
+ print("Error happen in step! Error: ", e)
+ self.full_restart()
+ info = {'abnormal': True}
+ return self.SMACTimestep(obs=None, reward=None, done=True, info=info, episode_steps=self._episode_steps)
+
+ # Update units
+ game_end_code = self.update_units()
+ rewards, terminates, infos = self._collect_step_data(game_end_code, actions)
+
+ infos["draw"] = int(not (infos["me"]["battle_won"] or infos["opponent"]["battle_won"]))
+
+ if (not self.two_player) and (not force_return_two_player):
+ rewards, terminates, new_infos = rewards[ORIGINAL_AGENT], terminates[ORIGINAL_AGENT], infos[ORIGINAL_AGENT]
+ self._final_eval_fake_reward += rewards
+ new_infos["battle_lost"] = infos[OPPONENT_AGENT]["battle_won"]
+ new_infos["draw"] = infos["draw"]
+ new_infos['eval_episode_return'] = infos['eval_episode_return']
+ if 'episode_info' in infos:
+ new_infos['episode_info'] = infos['episode_info']
+ new_infos['fake_eval_episode_return'] = infos['fake_eval_episode_return']
+ infos = new_infos
+ if self.obs_alone:
+ agent_state, agent_alone_state, agent_alone_padding_state = self.get_obs()
+ obs = {
+ 'agent_state': agent_state,
+ 'agent_alone_state': agent_alone_state,
+ 'agent_alone_padding_state': agent_alone_padding_state,
+ 'global_state': self.get_state(),
+ 'action_mask': self.get_avail_actions()
+ }
+ elif self.independent_obs:
+ obs = {
+ 'agent_state': self.get_obs(),
+ 'global_state': self.get_obs(),
+ 'action_mask': self.get_avail_actions(),
+ }
+ elif self.special_global_state:
+ obs = {
+ 'agent_state': self.get_obs(),
+ 'global_state': self.get_global_special_state(),
+ 'action_mask': self.get_avail_actions(),
+ }
+ else:
+ obs = {
+ 'agent_state': self.get_obs(),
+ 'global_state': self.get_state(),
+ 'action_mask': self.get_avail_actions(),
+ }
+ else:
+ raise NotImplementedError
+
+ return self.SMACTimestep(
+ obs=copy.deepcopy(obs), reward=rewards, done=terminates, info=infos, episode_steps=self._episode_steps
+ )
+
+ def _collect_step_data(self, game_end_code, action):
+        """This function is called only once per step, regardless of whether the opponent is treated as an agent.
+        Every returned term is a dict keyed by player, as in the multi-agent scenario.
+ """
+ self._total_steps += 1
+ self._episode_steps += 1
+
+ terminated = False
+
+ reward = self.reward_helper.get_reward(self, action, game_end_code, self.win_counted, self.defeat_counted)
+ for k in reward:
+ reward[k] = np.array(reward[k]).astype(np.float32)
+
+ info = {
+ ORIGINAL_AGENT: {
+ "battle_won": False
+ },
+ OPPONENT_AGENT: {
+ "battle_won": False
+ },
+ 'eval_episode_return': 0.,
+ 'fake_eval_episode_return': 0.
+ }
+
+ if game_end_code is not None:
+ # Battle is over
+ terminated = True
+ self.battles_game += 1
+ if game_end_code == 1 and not self.win_counted:
+ # The original agent win the game.
+ self.battles_won += 1
+ self.win_counted = True
+ info[ORIGINAL_AGENT]["battle_won"] = True
+ info[OPPONENT_AGENT]["battle_won"] = False
+ info['eval_episode_return'] = 1.
+ elif game_end_code == -1 and not self.defeat_counted:
+ self.defeat_counted = True
+ info[ORIGINAL_AGENT]["battle_won"] = False
+ info[OPPONENT_AGENT]["battle_won"] = True
+
+ elif self._episode_steps >= self.episode_limit:
+ # Episode limit reached
+ terminated = True
+ if self.continuing_episode:
+ info[ORIGINAL_AGENT]["episode_limit"] = True
+ info[OPPONENT_AGENT]["episode_limit"] = True
+ self.battles_game += 1
+ self.timeouts += 1
+ # info['eval_episode_return'] = -0.5
+
+ # if sum(u.health + u.shield for u in self.agents.values()) >= \
+ # sum(u.health + u.shield for u in self.enemies.values()):
+ # # lj fix
+ # reward[ORIGINAL_AGENT] += 1
+ # reward[OPPONENT_AGENT] += -1
+ # else:
+ # reward[ORIGINAL_AGENT] += -1
+ # reward[OPPONENT_AGENT] += 1
+
+ if terminated:
+ self._episode_count += 1
+ # 1-dim to 0-dim
+ # count units that are still alive
+ dead_allies, dead_enemies = 0, 0
+ for al_id, al_unit in self.agents.items():
+ if al_unit.health == 0:
+ dead_allies += 1
+ for e_id, e_unit in self.enemies.items():
+ if e_unit.health == 0:
+ dead_enemies += 1
+
+ info['episode_info'] = {
+ 'final_eval_fake_reward': self._final_eval_fake_reward[0],
+ 'dead_allies': dead_allies,
+ 'dead_enemies': dead_enemies
+ }
+ self._final_eval_fake_reward = 0.
+
+ # PZH: Zero at first step
+ if self._episode_steps == 1:
+ for k in reward.keys():
+ reward[k] *= 0.0
+ if terminated:
+                print("WARNING! Should not terminate at the first step!")
+
+ # Test purpose
+ # reward = {k: 0 * v + 100 for k, v in reward.items()}
+ info['fake_eval_episode_return'] = reward[ORIGINAL_AGENT]
+ return reward, {ORIGINAL_AGENT: terminated, OPPONENT_AGENT: terminated, "__all__": terminated}, info
+
+ def close(self):
+ SC2Env.close(self)
+
+ def init_units(self, old_unit_tags):
+ count = 0
+ while count < 10:
+ # Sometimes not all units have yet been created by SC2
+ self.agents = {}
+ self.enemies = {}
+
+ ally_units = [
+ unit for unit in self._obs.observation.raw_data.units
+ if (unit.owner == 1) and (unit.tag not in old_unit_tags)
+ ]
+ ally_units_sorted = sorted(
+ ally_units,
+ key=attrgetter("unit_type", "pos.x", "pos.y"),
+ reverse=False,
+ )
+
+ for i in range(len(ally_units_sorted)):
+ self.agents[i] = ally_units_sorted[i]
+
+ self.max_reward = self.n_enemies * self.reward_death_value + self.reward_win
+ for unit in self._obs.observation.raw_data.units:
+ if (unit.owner == 2) and (unit.tag not in old_unit_tags):
+ self.enemies[len(self.enemies)] = unit
+ # if self._episode_count == 0:
+ self.max_reward += unit.health_max + unit.shield_max
+
+ all_agents_created = (len(self.agents) == self.n_agents)
+ all_enemies_created = (len(self.enemies) == self.n_enemies)
+
+ all_agents_health = all(u.health > 0 for u in self.agents.values())
+ all_enemies_health = all(u.health > 0 for u in self.enemies.values())
+
+ if all_agents_created and all_enemies_created \
+ and all_agents_health and all_enemies_health: # all good
+ if self._episode_count == 0:
+ min_unit_type = min(unit.unit_type for unit in self.agents.values())
+ min_unit_type_opponent = min(unit.unit_type for unit in self.enemies.values())
+ self._init_ally_unit_types(min_unit_type)
+ self._init_enemy_unit_types(min_unit_type_opponent)
+ return True
+ else:
+ print(
+ "***ALL GOOD FAIL***", all_agents_created, all_enemies_created, all_agents_health,
+ all_enemies_health, len(self._obs.observation.raw_data.units)
+ )
+ print(
+ (len(self.agents) == self.n_agents), (len(self.enemies) == self.n_enemies), len(self.agents),
+ self.n_agents, len(self.enemies), self.n_enemies
+ )
+ self._restart_episode()
+ count += 1
+
+ try:
+ self._parallel.run((c.step, 1) for c in self._controllers)
+ self._update_obs()
+
+ except (protocol.ProtocolError, protocol.ConnectionError) as e:
+ print("Error happen in init_units.", e)
+ self.full_restart()
+ return False
+ if count >= 10:
+ self.full_restart()
+ return False
+
+ def _init_enemy_unit_types(self, min_unit_type_opponent):
+ """Initialise ally unit types. Should be called once from the
+        """Initialise enemy (opponent) unit types. Should be called once from the
+        init_units function.
+ self._min_unit_type_opponent = min_unit_type_opponent
+ if self.map_type == "marines":
+ self.marine_id_opponent = min_unit_type_opponent
+ elif self.map_type == "stalkers_and_zealots":
+ self.stalker_id_opponent = min_unit_type_opponent
+ self.zealot_id_opponent = min_unit_type_opponent + 1
+ elif self.map_type == "colossi_stalkers_zealots":
+ self.colossus_id_opponent = min_unit_type_opponent
+ self.stalker_id_opponent = min_unit_type_opponent + 1
+ self.zealot_id_opponent = min_unit_type_opponent + 2
+ elif self.map_type == "MMM":
+ self.marauder_id_opponent = min_unit_type_opponent
+ self.marine_id_opponent = min_unit_type_opponent + 1
+ self.medivac_id_opponent = min_unit_type_opponent + 2
+ elif self.map_type == "zealots":
+ self.zealot_id_opponent = min_unit_type_opponent
+ elif self.map_type == "hydralisks":
+ self.hydralisk_id_opponent = min_unit_type_opponent
+ elif self.map_type == "stalkers":
+ self.stalker_id_opponent = min_unit_type_opponent
+ elif self.map_type == "colossus":
+ self.colossus_id_opponent = min_unit_type_opponent
+ elif self.map_type == "bane":
+ self.baneling_id_opponent = min_unit_type_opponent
+ self.zergling_id_opponent = min_unit_type_opponent + 1
+
+ # ================
+ def unit_max_shield(self, unit, is_opponent=False):
+ """Returns maximal shield for a given unit."""
+ stalker_id = self.stalker_id_opponent if is_opponent else self.stalker_id
+ zealot_id = self.zealot_id_opponent if is_opponent else self.zealot_id
+ colossus_id = self.colossus_id_opponent if is_opponent else self.colossus_id
+ if unit.unit_type == 74 or unit.unit_type == stalker_id:
+ return 80 # Protoss's Stalker
+ if unit.unit_type == 73 or unit.unit_type == zealot_id:
+            return 50  # Protoss's Zealot
+ if unit.unit_type == 4 or unit.unit_type == colossus_id:
+ return 150 # Protoss's Colossus
+
+ def get_unit_type_id(self, unit, ally, is_opponent=False):
+ if is_opponent and ally:
+ return unit.unit_type - self._min_unit_type_opponent
+ else:
+ if ally: # use new SC2 unit types
+ if self.map_type == "infestor_viper":
+ if unit.unit_type == 393:
+ type_id = 0
+ else:
+ type_id = 1
+ else:
+ type_id = unit.unit_type - self._min_unit_type
+ else: # use default SC2 unit types
+ if self.map_type == "stalkers_and_zealots":
+ # id(Stalker) = 74, id(Zealot) = 73
+ type_id = unit.unit_type - 73
+ elif self.map_type == "colossi_stalkers_zealots":
+ # id(Stalker) = 74, id(Zealot) = 73, id(Colossus) = 4
+ if unit.unit_type == 4:
+ type_id = 0
+ elif unit.unit_type == 74:
+ type_id = 1
+ else:
+ type_id = 2
+ elif self.map_type == "bane":
+ if unit.unit_type == 9:
+ type_id = 0
+ else:
+ type_id = 1
+ elif self.map_type == "MMM":
+ if unit.unit_type == 51:
+ type_id = 0
+ elif unit.unit_type == 48:
+ type_id = 1
+ else:
+ type_id = 2
+ elif self.map_type == "infestor_viper":
+ if unit.unit_type == 393:
+ type_id = 0
+ else:
+ type_id = 1
+ else:
+ raise ValueError()
+ return type_id
+
+ def _update_obs(self, target_game_loop=0):
+ # Transform in the thread so it runs while waiting for other observations.
+ # def parallel_observe(c, f):
+
+ if self.two_player:
+
+ def parallel_observe(c):
+ obs = c.observe(target_game_loop=target_game_loop)
+ # agent_obs = f.transform_obs(obs)
+ return obs
+
+ # with self._metrics.measure_observation_time():
+ self._obses = self._parallel.run((parallel_observe, c) for c in self._controllers)
+ else:
+ self._obses = [self._controllers[0].observe()]
+
+ self._obs = self._obses[0]
+
+ def _init_ally_unit_types(self, min_unit_type):
+ """Initialise ally unit types. Should be called once from the
+ init_units function.
+ """
+ self._min_unit_type = min_unit_type
+ if self.map_type == "marines":
+ self.marine_id = min_unit_type
+ elif self.map_type == "stalkers_and_zealots":
+ self.stalker_id = min_unit_type
+ self.zealot_id = min_unit_type + 1
+ elif self.map_type == "colossi_stalkers_zealots":
+ self.colossus_id = min_unit_type
+ self.stalker_id = min_unit_type + 1
+ self.zealot_id = min_unit_type + 2
+ elif self.map_type == "MMM":
+ self.marauder_id = min_unit_type
+ self.marine_id = min_unit_type + 1
+ self.medivac_id = min_unit_type + 2
+ elif self.map_type == "zealots":
+ self.zealot_id = min_unit_type
+ elif self.map_type == "hydralisks":
+ self.hydralisk_id = min_unit_type
+ elif self.map_type == "stalkers":
+ self.stalker_id = min_unit_type
+ elif self.map_type == "colossus":
+ self.colossus_id = min_unit_type
+ elif self.map_type == "bane":
+ self.baneling_id = min_unit_type
+ self.zergling_id = min_unit_type + 1
+
+ def get_obs(self, is_opponent=False):
+ """Returns all agent observations in a list.
+ NOTE: Agents should have access only to their local observations
+ during decentralised execution.
+ """
+ agents_obs_list = [self.get_obs_agent(i, is_opponent) for i in range(self.n_agents)]
+
+ if self.mirror_opponent and is_opponent:
+ assert not self.flatten_observation
+ new_obs = list()
+ for agent_obs in agents_obs_list:
+ new_agent_obs = dict()
+ for key, feat in agent_obs.items():
+ feat = feat.copy()
+
+ if key == "move_feats":
+ can_move_right = feat[2]
+ can_move_left = feat[3]
+ feat[3] = can_move_right
+ feat[2] = can_move_left
+
+ elif key == "enemy_feats" or key == "ally_feats":
+ for unit_id in range(feat.shape[0]):
+ # Relative x
+ feat[unit_id, 2] = -feat[unit_id, 2]
+
+ new_agent_obs[key] = feat
+ new_obs.append(new_agent_obs)
+ agents_obs_list = new_obs
+
+ if not self.flatten_observation:
+ agents_obs_list = self._flatten_obs(agents_obs_list)
+ if self.obs_alone:
+ agents_obs_list, agents_obs_alone_list, agents_obs_alone_padding_list = list(zip(*agents_obs_list))
+ return np.array(agents_obs_list).astype(np.float32), np.array(agents_obs_alone_list).astype(
+ np.float32
+ ), np.array(agents_obs_alone_padding_list).astype(np.float32)
+ else:
+ return np.array(agents_obs_list).astype(np.float32)
+
+ def get_obs_agent(self, agent_id, is_opponent=False):
+ unit = self.get_unit_by_id(agent_id, is_opponent=is_opponent)
+
+ # TODO All these function should have an opponent version
+ enemy_feats_dim = self.get_obs_enemy_feats_size()
+ ally_feats_dim = self.get_obs_ally_feats_size()
+ own_feats_dim = self.get_obs_own_feats_size()
+
+ enemy_feats = np.zeros(enemy_feats_dim, dtype=np.float32)
+ ally_feats = np.zeros(ally_feats_dim, dtype=np.float32)
+ own_feats = np.zeros(own_feats_dim, dtype=np.float32)
+
+ move_feats = self.action_helper.get_movement_features(agent_id, self, is_opponent)
+
+ if unit.health > 0: # otherwise dead, return all zeros
+ x = unit.pos.x
+ y = unit.pos.y
+ sight_range = self.unit_sight_range(agent_id)
+ avail_actions = self.action_helper.get_avail_agent_actions(agent_id, self, is_opponent)
+
+ # Enemy features
+ if is_opponent:
+ enemy_items = self.agents.items()
+ else:
+ enemy_items = self.enemies.items()
+ for e_id, e_unit in enemy_items:
+ e_x = e_unit.pos.x
+ e_y = e_unit.pos.y
+ dist = distance(x, y, e_x, e_y)
+
+ if (dist < sight_range and e_unit.health > 0): # visible and alive
+ # Sight range > shoot range
+ enemy_feats[e_id, 0] = avail_actions[self.action_helper.n_actions_no_attack + e_id] # available
+ enemy_feats[e_id, 1] = dist / sight_range # distance
+ enemy_feats[e_id, 2] = (e_x - x) / sight_range # relative X
+ enemy_feats[e_id, 3] = (e_y - y) / sight_range # relative Y
+
+ ind = 4
+ if self.obs_all_health:
+ enemy_feats[e_id, ind] = (e_unit.health / e_unit.health_max) # health
+ ind += 1
+ if self.shield_bits_enemy > 0:
+ max_shield = self.unit_max_shield(e_unit, not is_opponent)
+ enemy_feats[e_id, ind] = (e_unit.shield / max_shield) # shield
+ ind += 1
+
+ if self.unit_type_bits > 0:
+                        # When the enemy is the built-in computer, use ally=False; in
+                        # two-player mode the enemy is another agent, so use ally=True.
+ if self.two_player:
+ type_id = self.get_unit_type_id(e_unit, True, not is_opponent)
+ else:
+ type_id = self.get_unit_type_id(e_unit, False, False)
+ enemy_feats[e_id, ind + type_id] = 1 # unit type
+
+ # Ally features
+ al_ids = [
+ al_id for al_id in range((self.n_agents if not is_opponent else self.n_enemies)) if al_id != agent_id
+ ]
+ for i, al_id in enumerate(al_ids):
+
+ al_unit = self.get_unit_by_id(al_id, is_opponent=is_opponent)
+ al_x = al_unit.pos.x
+ al_y = al_unit.pos.y
+ dist = distance(x, y, al_x, al_y)
+
+ if (dist < sight_range and al_unit.health > 0): # visible and alive
+ ally_feats[i, 0] = 1 # visible
+ ally_feats[i, 1] = dist / sight_range # distance
+ ally_feats[i, 2] = (al_x - x) / sight_range # relative X
+ ally_feats[i, 3] = (al_y - y) / sight_range # relative Y
+
+ ind = 4
+ if self.obs_all_health:
+ ally_feats[i, ind] = (al_unit.health / al_unit.health_max) # health
+ ind += 1
+ if self.shield_bits_ally > 0:
+ max_shield = self.unit_max_shield(al_unit, is_opponent)
+ ally_feats[i, ind] = (al_unit.shield / max_shield) # shield
+ ind += 1
+
+ if self.unit_type_bits > 0:
+ type_id = self.get_unit_type_id(al_unit, True, is_opponent)
+ ally_feats[i, ind + type_id] = 1
+ ind += self.unit_type_bits
+
+ # LJ fix
+ # if self.obs_last_action:
+ # ally_feats[i, ind:] = self.action_helper.get_last_action(is_opponent)[al_id]
+
+ # Own features
+ ind = 0
+ if self.obs_own_health:
+ own_feats[ind] = unit.health / unit.health_max
+ ind += 1
+ if self.shield_bits_ally > 0:
+ max_shield = self.unit_max_shield(unit, is_opponent)
+ own_feats[ind] = unit.shield / max_shield
+ ind += 1
+
+ if self.unit_type_bits > 0:
+ type_id = self.get_unit_type_id(unit, True, is_opponent)
+ own_feats[ind + type_id] = 1
+ ind += self.unit_type_bits
+ if self.obs_last_action:
+ own_feats[ind:] = self.action_helper.get_last_action(is_opponent)[agent_id]
+
+ if is_opponent:
+ agent_id_feats = np.zeros(self.n_enemies)
+ else:
+ agent_id_feats = np.zeros(self.n_agents)
+ agent_id_feats[agent_id] = 1
+ # Only set to False by an outside wrapper
+ if self.flatten_observation:
+ agent_obs = np.concatenate(
+ (
+ move_feats.flatten(),
+ enemy_feats.flatten(),
+ ally_feats.flatten(),
+ own_feats.flatten(),
+ agent_id_feats,
+ )
+ )
+ if self.obs_timestep_number:
+ agent_obs = np.append(agent_obs, self._episode_steps / self.episode_limit)
+ if self.obs_alone:
+ agent_obs_alone = np.concatenate(
+ (
+ move_feats.flatten(),
+ enemy_feats.flatten(),
+ own_feats.flatten(),
+ agent_id_feats,
+ )
+ )
+ agent_obs_alone_padding = np.concatenate(
+ (
+ move_feats.flatten(),
+ enemy_feats.flatten(),
+ np.zeros_like(ally_feats.flatten()),
+ own_feats.flatten(),
+ agent_id_feats,
+ )
+ )
+ if self.obs_timestep_number:
+ agent_obs_alone = np.append(agent_obs_alone, self._episode_steps / self.episode_limit)
+ agent_obs_alone_padding = np.append(
+ agent_obs_alone_padding, self._episode_steps / self.episode_limit
+ )
+ return agent_obs, agent_obs_alone, agent_obs_alone_padding
+ else:
+ return agent_obs
+ else:
+ agent_obs = dict(
+ move_feats=move_feats,
+ enemy_feats=enemy_feats,
+ ally_feats=ally_feats,
+ own_feats=own_feats,
+ agent_id_feats=agent_id_feats
+ )
+ if self.obs_timestep_number:
+ agent_obs["obs_timestep_number"] = self._episode_steps / self.episode_limit
+
+ return agent_obs
+
+ def get_unit_by_id(self, a_id, is_opponent=False):
+ """Get unit by ID."""
+ if is_opponent:
+ return self.enemies[a_id]
+ return self.agents[a_id]
+
+ def get_obs_enemy_feats_size(self):
+ """ Returns the dimensions of the matrix containing enemy features.
+ Size is n_enemies x n_features.
+ """
+ nf_en = 4 + self.unit_type_bits
+
+ if self.obs_all_health:
+ nf_en += 1 + self.shield_bits_enemy
+
+ return self.n_enemies, nf_en
+
+ def get_obs_ally_feats_size(self):
+ """Returns the dimensions of the matrix containing ally features.
+ Size is n_allies x n_features.
+ """
+ nf_al = 4 + self.unit_type_bits
+
+ if self.obs_all_health:
+ nf_al += 1 + self.shield_bits_ally
+
+ # LJ fix
+ # if self.obs_last_action:
+ # nf_al += self.n_actions
+
+ return self.n_agents - 1, nf_al
+
+ def get_obs_own_feats_size(self):
+ """Returns the size of the vector containing the agents' own features.
+ """
+ own_feats = self.unit_type_bits
+ if self.obs_own_health:
+ own_feats += 1 + self.shield_bits_ally
+ if self.obs_timestep_number:
+ own_feats += 1
+ if self.obs_last_action:
+ own_feats += self.n_actions
+
+ return own_feats
+
+ def get_obs_move_feats_size(self):
+ """Returns the size of the vector containing the agents's movement-related features."""
+ return self.action_helper.get_obs_move_feats_size()
+
+ def get_state_size(self, is_opponent=False):
+ """Returns the size of the global state."""
+ if self.obs_instead_of_state:
+ return self.get_obs_size(is_opponent) * self.n_agents
+
+ nf_al = 4 + self.shield_bits_ally + self.unit_type_bits
+ nf_en = 3 + self.shield_bits_enemy + self.unit_type_bits
+
+ enemy_state = self.n_enemies * nf_en
+ ally_state = self.n_agents * nf_al
+
+ size = enemy_state + ally_state
+
+ if self.state_last_action:
+ if is_opponent:
+ size += self.n_enemies * self.n_actions_opponent
+ else:
+ size += self.n_agents * self.n_actions
+ if self.state_timestep_number:
+ size += 1
+
+ return size
+
+ def get_obs_size(self, is_opponent=False):
+ # TODO: assume the agent formation is the same for both the opponent and me. This can be extended in the future.
+ """Returns the size of the observation."""
+ own_feats = self.get_obs_own_feats_size()
+ move_feats = self.get_obs_move_feats_size()
+
+ n_enemies, n_enemy_feats = self.get_obs_enemy_feats_size()
+ n_allies, n_ally_feats = self.get_obs_ally_feats_size()
+
+ enemy_feats = n_enemies * n_enemy_feats
+ ally_feats = n_allies * n_ally_feats
+
+ if is_opponent:
+ agent_id_feats = self.n_enemies
+ else:
+ agent_id_feats = self.n_agents
+ return move_feats + enemy_feats + ally_feats + own_feats + agent_id_feats
+
+ def get_obs_alone_size(self, is_opponent=False):
+ # TODO: assume the agent formation is the same for both the opponent and me. This can be extended in the future.
+ """Returns the size of the observation."""
+ own_feats = self.get_obs_own_feats_size()
+ move_feats = self.get_obs_move_feats_size()
+
+ n_enemies, n_enemy_feats = self.get_obs_enemy_feats_size()
+
+ enemy_feats = n_enemies * n_enemy_feats
+
+ if is_opponent:
+ agent_id_feats = self.n_enemies
+ else:
+ agent_id_feats = self.n_agents
+ return move_feats + enemy_feats + own_feats + agent_id_feats
+
+ def get_state(self, is_opponent=False):
+ if self.obs_instead_of_state:
+ obs_concat = np.concatenate(self.get_obs(), axis=0).astype(np.float32)
+ return obs_concat
+
+ nf_al = 4 + self.shield_bits_ally + self.unit_type_bits
+ nf_en = 3 + self.shield_bits_enemy + self.unit_type_bits
+
+ ally_state = np.zeros((self.n_agents, nf_al))
+ enemy_state = np.zeros((self.n_enemies, nf_en))
+
+ center_x = self.map_x / 2
+ center_y = self.map_y / 2
+
+ if is_opponent:
+ iterator = self.enemies.items()
+ else:
+ iterator = self.agents.items()
+
+ for al_id, al_unit in iterator:
+ if al_unit.health > 0:
+ x = al_unit.pos.x
+ y = al_unit.pos.y
+ max_cd = self.unit_max_cooldown(al_unit, is_opponent=is_opponent)
+
+ ally_state[al_id, 0] = (al_unit.health / al_unit.health_max) # health
+ if (self.map_type == "MMM"
+ and al_unit.unit_type == (self.medivac_id_opponent if is_opponent else self.medivac_id)):
+ ally_state[al_id, 1] = al_unit.energy / max_cd # energy
+ else:
+ ally_state[al_id, 1] = (al_unit.weapon_cooldown / max_cd) # cooldown
+ ally_state[al_id, 2] = (x - center_x) / self.max_distance_x # relative X
+ ally_state[al_id, 3] = (y - center_y) / self.max_distance_y # relative Y
+
+ ind = 4
+ if self.shield_bits_ally > 0:
+ max_shield = self.unit_max_shield(al_unit, is_opponent=is_opponent)
+ ally_state[al_id, ind] = (al_unit.shield / max_shield) # shield
+ ind += 1
+
+ if self.unit_type_bits > 0:
+ type_id = self.get_unit_type_id(al_unit, True, is_opponent=is_opponent)
+ ally_state[al_id, ind + type_id] = 1
+
+ if is_opponent:
+ iterator = self.agents.items()
+ else:
+ iterator = self.enemies.items()
+ for e_id, e_unit in iterator:
+ if e_unit.health > 0:
+ x = e_unit.pos.x
+ y = e_unit.pos.y
+
+ enemy_state[e_id, 0] = (e_unit.health / e_unit.health_max) # health
+ enemy_state[e_id, 1] = (x - center_x) / self.max_distance_x # relative X
+ enemy_state[e_id, 2] = (y - center_y) / self.max_distance_y # relative Y
+
+ ind = 3
+ if self.shield_bits_enemy > 0:
+ max_shield = self.unit_max_shield(e_unit, is_opponent=False)
+ enemy_state[e_id, ind] = (e_unit.shield / max_shield) # shield
+ ind += 1
+
+ if self.unit_type_bits > 0:
+ type_id = self.get_unit_type_id(e_unit, True if self.two_player else False, is_opponent=False)
+ enemy_state[e_id, ind + type_id] = 1
+
+ last_action = self.action_helper.get_last_action(is_opponent)
+ if self.flatten_observation:
+ state = np.append(ally_state.flatten(), enemy_state.flatten())
+ if self.state_last_action:
+ state = np.append(state, last_action.flatten())
+ if self.state_timestep_number:
+ state = np.append(state, self._episode_steps / self.episode_limit)
+ state = state.astype(dtype=np.float32)
+ else:
+ state = dict(ally_state=ally_state, enemy_state=enemy_state)
+ if self.state_last_action:
+ state["last_action"] = last_action
+ if self.state_timestep_number:
+ state["state_timestep_number"] = self._episode_steps / self.episode_limit
+
+ if self.mirror_opponent and is_opponent:
+ assert not self.flatten_observation
+
+ new_state = dict()
+ for key, s in state.items():
+ s = s.copy()
+
+ if key == "ally_state":
+ # relative x
+ for unit_id in range(s.shape[0]):
+ s[unit_id, 2] = -s[unit_id, 2]
+
+ elif key == "enemy_state":
+ # relative x
+ for unit_id in range(s.shape[0]):
+ s[unit_id, 1] = -s[unit_id, 1]
+
+ # key == "last_action" is processed in SMACAction
+ new_state[key] = s
+ state = new_state
+
+ if not self.flatten_observation:
+ state = self._flatten_state(state)
+ return np.array(state).astype(np.float32)
+
+ def get_global_special_state(self, is_opponent=False):
+ """Returns all agent observations in a list.
+ NOTE: Agents should have access only to their local observations
+ during decentralised execution.
+ """
+ agents_obs_list = [self.get_state_agent(i, is_opponent) for i in range(self.n_agents)]
+
+ return np.array(agents_obs_list).astype(np.float32)
+
+ def get_global_special_state_size(self, is_opponent=False):
+ enemy_feats_dim = self.get_state_enemy_feats_size()
+ enemy_feats_dim = reduce(lambda x, y: x * y, enemy_feats_dim)
+ ally_feats_dim = self.get_state_ally_feats_size()
+ ally_feats_dim = reduce(lambda x, y: x * y, ally_feats_dim)
+ own_feats_dim = self.get_state_own_feats_size()
+ size = enemy_feats_dim + ally_feats_dim + own_feats_dim + self.n_agents
+ if self.state_timestep_number:
+ size += 1
+ return size
+
+ def get_state_agent(self, agent_id, is_opponent=False):
+ """Returns observation for agent_id. The observation is composed of:
+
+ - agent movement features (where it can move to, height information and pathing grid)
+ - enemy features (available_to_attack, health, relative_x, relative_y, shield, unit_type)
+ - ally features (visible, distance, relative_x, relative_y, shield, unit_type)
+ - agent unit features (health, shield, unit_type)
+
+ All of this information is flattened and concatenated into a list,
+ in the aforementioned order. To know the sizes of each of the
+ features inside the final list of features, take a look at the
+ functions ``get_state_enemy_feats_size()``,
+ ``get_state_ally_feats_size()`` and ``get_state_own_feats_size()``.
+
+ The size of the observation vector may vary, depending on the
+ environment configuration and type of units present in the map.
+ For instance, non-Protoss units will not have shields, movement
+ features may or may not include terrain height and pathing grid,
+ unit_type is not included if there is only one type of unit in the
+ map, etc.
+
+ NOTE: Agents should have access only to their local observations
+ during decentralised execution.
+ """
+ if self.obs_instead_of_state:
+ obs_concat = np.concatenate(self.get_obs(), axis=0).astype(np.float32)
+ return obs_concat
+
+ unit = self.get_unit_by_id(agent_id)
+
+ enemy_feats_dim = self.get_state_enemy_feats_size()
+ ally_feats_dim = self.get_state_ally_feats_size()
+ own_feats_dim = self.get_state_own_feats_size()
+
+ enemy_feats = np.zeros(enemy_feats_dim, dtype=np.float32)
+ ally_feats = np.zeros(ally_feats_dim, dtype=np.float32)
+ own_feats = np.zeros(own_feats_dim, dtype=np.float32)
+ agent_id_feats = np.zeros(self.n_agents, dtype=np.float32)
+
+ center_x = self.map_x / 2
+ center_y = self.map_y / 2
+
+ if (self.death_mask and unit.health > 0) or (not self.death_mask): # otherwise dead, return all zeros
+ x = unit.pos.x
+ y = unit.pos.y
+ sight_range = self.unit_sight_range(agent_id)
+ last_action = self.action_helper.get_last_action(is_opponent)
+
+ # Movement features
+ avail_actions = self.get_avail_agent_actions(agent_id)
+
+ # Enemy features
+ for e_id, e_unit in self.enemies.items():
+ e_x = e_unit.pos.x
+ e_y = e_unit.pos.y
+ dist = self.distance(x, y, e_x, e_y)
+
+ if e_unit.health > 0: # alive (visibility is handled below)
+ # Sight range > shoot range
+ if unit.health > 0:
+ enemy_feats[e_id, 0] = avail_actions[self.action_helper.n_actions_no_attack + e_id] # available
+ enemy_feats[e_id, 1] = dist / sight_range # distance
+ enemy_feats[e_id, 2] = (e_x - x) / sight_range # relative X
+ enemy_feats[e_id, 3] = (e_y - y) / sight_range # relative Y
+ if dist < sight_range:
+ enemy_feats[e_id, 4] = 1 # visible
+
+ ind = 5
+ if self.obs_all_health:
+ enemy_feats[e_id, ind] = (e_unit.health / e_unit.health_max) # health
+ ind += 1
+ if self.shield_bits_enemy > 0:
+ max_shield = self.unit_max_shield(e_unit)
+ enemy_feats[e_id, ind] = (e_unit.shield / max_shield) # shield
+ ind += 1
+
+ if self.unit_type_bits > 0:
+ type_id = self.get_unit_type_id(e_unit, False)
+ enemy_feats[e_id, ind + type_id] = 1 # unit type
+ ind += self.unit_type_bits
+
+ if self.add_center_xy:
+ enemy_feats[e_id, ind] = (e_x - center_x) / self.max_distance_x # center X
+ enemy_feats[e_id, ind + 1] = (e_y - center_y) / self.max_distance_y # center Y
+
+ # Ally features
+ al_ids = [al_id for al_id in range(self.n_agents) if al_id != agent_id]
+ for i, al_id in enumerate(al_ids):
+
+ al_unit = self.get_unit_by_id(al_id)
+ al_x = al_unit.pos.x
+ al_y = al_unit.pos.y
+ dist = self.distance(x, y, al_x, al_y)
+ max_cd = self.unit_max_cooldown(al_unit)
+
+ if al_unit.health > 0: # alive (visibility is handled below)
+ if unit.health > 0:
+ if dist < sight_range:
+ ally_feats[i, 0] = 1 # visible
+ ally_feats[i, 1] = dist / sight_range # distance
+ ally_feats[i, 2] = (al_x - x) / sight_range # relative X
+ ally_feats[i, 3] = (al_y - y) / sight_range # relative Y
+
+ if (self.map_type == "MMM" and al_unit.unit_type == self.medivac_id):
+ ally_feats[i, 4] = al_unit.energy / max_cd # energy
+ else:
+ ally_feats[i, 4] = (al_unit.weapon_cooldown / max_cd) # cooldown
+
+ ind = 5
+ if self.obs_all_health:
+ ally_feats[i, ind] = (al_unit.health / al_unit.health_max) # health
+ ind += 1
+ if self.shield_bits_ally > 0:
+ max_shield = self.unit_max_shield(al_unit)
+ ally_feats[i, ind] = (al_unit.shield / max_shield) # shield
+ ind += 1
+
+ if self.add_center_xy:
+ ally_feats[i, ind] = (al_x - center_x) / self.max_distance_x # center X
+ ally_feats[i, ind + 1] = (al_y - center_y) / self.max_distance_y # center Y
+ ind += 2
+
+ if self.unit_type_bits > 0:
+ type_id = self.get_unit_type_id(al_unit, True)
+ ally_feats[i, ind + type_id] = 1
+ ind += self.unit_type_bits
+
+ if self.state_last_action:
+ ally_feats[i, ind:] = last_action[al_id]
+
+ # Own features
+ ind = 0
+ own_feats[0] = 1 # visible
+ own_feats[1] = 0 # distance
+ own_feats[2] = 0 # X
+ own_feats[3] = 0 # Y
+ ind = 4
+ if self.obs_own_health:
+ own_feats[ind] = unit.health / unit.health_max
+ ind += 1
+ if self.shield_bits_ally > 0:
+ max_shield = self.unit_max_shield(unit)
+ own_feats[ind] = unit.shield / max_shield
+ ind += 1
+
+ if self.add_center_xy:
+ own_feats[ind] = (x - center_x) / self.max_distance_x # center X
+ own_feats[ind + 1] = (y - center_y) / self.max_distance_y # center Y
+ ind += 2
+
+ if self.unit_type_bits > 0:
+ type_id = self.get_unit_type_id(unit, True)
+ own_feats[ind + type_id] = 1
+ ind += self.unit_type_bits
+
+ if self.state_last_action:
+ own_feats[ind:] = last_action[agent_id]
+
+ state = np.concatenate((ally_feats.flatten(), enemy_feats.flatten(), own_feats.flatten()))
+
+ # Agent id features
+ if self.state_agent_id:
+ agent_id_feats[agent_id] = 1.
+ state = np.append(state, agent_id_feats.flatten())
+
+ if self.state_timestep_number:
+ state = np.append(state, self._episode_steps / self.episode_limit)
+
+ return state
+
+ def get_state_enemy_feats_size(self):
+ """ Returns the dimensions of the matrix containing enemy features.
+ Size is n_enemies x n_features.
+ """
+ nf_en = 5 + self.unit_type_bits
+
+ if self.obs_all_health:
+ nf_en += 1 + self.shield_bits_enemy
+
+ if self.add_center_xy:
+ nf_en += 2
+
+ return self.n_enemies, nf_en
+
+ def get_state_ally_feats_size(self):
+ """Returns the dimensions of the matrix containing ally features.
+ Size is n_allies x n_features.
+ """
+ nf_al = 5 + self.unit_type_bits
+
+ if self.obs_all_health:
+ nf_al += 1 + self.shield_bits_ally
+
+ if self.state_last_action:
+ nf_al += self.n_actions
+
+ if self.add_center_xy:
+ nf_al += 2
+
+ return self.n_agents - 1, nf_al
+
+ def get_state_own_feats_size(self):
+ """Returns the size of the vector containing the agents' own features.
+ """
+ own_feats = 4 + self.unit_type_bits
+ if self.obs_own_health:
+ own_feats += 1 + self.shield_bits_ally
+
+ if self.state_last_action:
+ own_feats += self.n_actions
+
+ if self.add_center_xy:
+ own_feats += 2
+
+ return own_feats
+
+ @staticmethod
+ def distance(x1, y1, x2, y2):
+ """Distance between two points."""
+ return math.hypot(x2 - x1, y2 - y1)
+
+ def unit_max_cooldown(self, unit, is_opponent=False):
+ """Returns the maximal cooldown for a unit."""
+ if is_opponent:
+ switcher = {
+ self.marine_id_opponent: 15,
+ self.marauder_id_opponent: 25,
+ self.medivac_id_opponent: 200, # max energy
+ self.stalker_id_opponent: 35,
+ self.zealot_id_opponent: 22,
+ self.colossus_id_opponent: 24,
+ self.hydralisk_id_opponent: 10,
+ self.zergling_id_opponent: 11,
+ self.baneling_id_opponent: 1
+ }
+ else:
+ switcher = {
+ self.marine_id: 15,
+ self.marauder_id: 25,
+ self.medivac_id: 200, # max energy
+ self.stalker_id: 35,
+ self.zealot_id: 22,
+ self.colossus_id: 24,
+ self.hydralisk_id: 10,
+ self.zergling_id: 11,
+ self.baneling_id: 1
+ }
+ return switcher.get(unit.unit_type, 15)
+
+ def update_units(self):
+ """Update units after an environment step.
+ This function assumes that self._obs is up-to-date.
+ """
+ n_ally_alive = 0
+ n_enemy_alive = 0
+
+ # Store previous state
+ self.previous_ally_units = copy.deepcopy(self.agents)
+ self.previous_enemy_units = copy.deepcopy(self.enemies)
+
+ for al_id, al_unit in self.agents.items():
+ updated = False
+ for unit in self._obs.observation.raw_data.units:
+ if al_unit.tag == unit.tag:
+ self.agents[al_id] = unit
+ updated = True
+ n_ally_alive += 1
+ break
+
+ if not updated: # dead
+ al_unit.health = 0
+
+ for e_id, e_unit in self.enemies.items():
+ updated = False
+ for unit in self._obs.observation.raw_data.units:
+ if e_unit.tag == unit.tag:
+ self.enemies[e_id] = unit
+ updated = True
+ n_enemy_alive += 1
+ break
+
+ if not updated: # dead
+ e_unit.health = 0
+
+ if (n_ally_alive == 0 and n_enemy_alive > 0 or self.only_medivac_left(ally=True)):
+ return -1 # lost
+ if (n_ally_alive > 0 and n_enemy_alive == 0 or self.only_medivac_left(ally=False)):
+ return 1 # won
+ if n_ally_alive == 0 and n_enemy_alive == 0:
+ return 0
+
+ return None
+
+ def only_medivac_left(self, ally):
+ """Check if only Medivac units are left."""
+ if self.map_type != "MMM":
+ return False
+
+ if ally:
+ units_alive = [
+ a for a in self.agents.values()
+ if (a.health > 0 and a.unit_type != self.medivac_id and a.unit_type != self.medivac_id_opponent
+ ) # also exclude medivac_id_opponent
+ ]
+ if len(units_alive) == 0:
+ return True
+ return False
+ else:
+ units_alive = [
+ a for a in self.enemies.values()
+ if (a.health > 0 and a.unit_type != self.medivac_id and a.unit_type != self.medivac_id_opponent)
+ ]
+ if len(units_alive) == 1 and units_alive[0].unit_type == 54: # 54 is the Terran Medivac unit type id
+ return True
+ return False
+
+ @property
+ def n_actions(self):
+ return self.action_helper.n_actions
+
+ @property
+ def n_actions_opponent(self):
+ return self.n_actions
+
+ # Workaround
+ def get_avail_agent_actions(self, agent_id, is_opponent=False):
+ return self.action_helper.get_avail_agent_actions(agent_id, self, is_opponent)
+
+ def unit_sight_range(self, agent_id=None):
+ """Returns the sight range for an agent."""
+ return 9
+
+ @staticmethod
+ def _flatten_obs(obs):
+
+ def _get_keys(agent_obs):
+ keys = ["move_feats", "enemy_feats", "ally_feats", "own_feats", "agent_id_feats"]
+ if "obs_timestep_number" in agent_obs:
+ keys.append("obs_timestep_number")
+ return keys
+
+ return _flatten(obs, _get_keys)
+
+ @staticmethod
+ def _flatten_state(state):
+
+ def _get_keys(s):
+ keys = ["ally_state", "enemy_state"]
+ if "last_action" in s:
+ keys.append("last_action")
+ if "state_timestep_number" in s:
+ keys.append("state_timestep_number")
+ return keys
+
+ return _flatten([state], _get_keys)[0]
+
+ def get_avail_actions(self, is_opponent=False):
+ ava_action = self.action_helper.get_avail_actions(self, is_opponent)
+ ava_action = np.array(ava_action).astype(np.float32)
+ return ava_action
+
+ def get_obs_space(self, is_opponent=False):
+ T = EnvElementInfo
+ agent_num = self.n_enemies if is_opponent else self.n_agents
+ if self.obs_alone:
+ obs_space = T(
+ {
+ 'agent_state': (agent_num, self.get_obs_size(is_opponent)),
+ 'agent_alone_state': (agent_num, self.get_obs_alone_size(is_opponent)),
+ 'agent_alone_padding_state': (agent_num, self.get_obs_size(is_opponent)),
+ 'global_state': (self.get_state_size(is_opponent), ),
+ 'action_mask': (agent_num, *self.action_helper.info().shape),
+ },
+ None,
+ )
+ else:
+ if self.special_global_state:
+ obs_space = T(
+ {
+ 'agent_state': (agent_num, self.get_obs_size(is_opponent)),
+ 'global_state': (agent_num, self.get_global_special_state_size(is_opponent)),
+ 'action_mask': (agent_num, *self.action_helper.info().shape),
+ },
+ None,
+ )
+ else:
+ obs_space = T(
+ {
+ 'agent_state': (agent_num, self.get_obs_size(is_opponent)),
+ 'global_state': (self.get_state_size(is_opponent), ),
+ 'action_mask': (agent_num, *self.action_helper.info().shape),
+ },
+ None,
+ )
+ return obs_space
+
+ @property
+ def observation_space(self):
+ return self._observation_space
+
+ @property
+ def action_space(self):
+ return self._action_space
+
+ @property
+ def reward_space(self):
+ return self._reward_space
+
+ def __repr__(self):
+ return "DI-engine SMAC Env"
+
+
+def _flatten(obs, get_keys):
+ new_obs = list()
+ for agent_obs in obs:
+ keys = get_keys(agent_obs)
+ new_agent_obs = np.concatenate([agent_obs[feat_key].flatten() for feat_key in keys])
+ new_obs.append(new_agent_obs)
+ return new_obs
+
+
+SMACTimestep = SMACEnv.SMACTimestep
+SMACEnvInfo = SMACEnv.SMACEnvInfo
diff --git a/DI-engine/dizoo/smac/envs/smac_map.py b/DI-engine/dizoo/smac/envs/smac_map.py
new file mode 100644
index 0000000000000000000000000000000000000000..4810aa8e01ea02e94408624874a446d969c358c6
--- /dev/null
+++ b/DI-engine/dizoo/smac/envs/smac_map.py
@@ -0,0 +1,238 @@
+from pysc2.maps import lib
+import os
+
+
+class SMACMap(lib.Map):
+ directory = os.path.join(os.path.dirname(__file__), "maps/SMAC_Maps")
+ download = "https://github.com/oxwhirl/smac#smac-maps"
+ players = 2
+ step_mul = 8
+ game_steps_per_episode = 0
+
+
+# Copied from smac/env/starcraft2/maps/smac_maps.py
+map_param_registry = {
+ "3m": {
+ "n_agents": 3,
+ "n_enemies": 3,
+ "limit": 60,
+ "a_race": "T",
+ "b_race": "T",
+ "unit_type_bits": 0,
+ "map_type": "marines",
+ },
+ "8m": {
+ "n_agents": 8,
+ "n_enemies": 8,
+ "limit": 120,
+ "a_race": "T",
+ "b_race": "T",
+ "unit_type_bits": 0,
+ "map_type": "marines",
+ },
+ "25m": {
+ "n_agents": 25,
+ "n_enemies": 25,
+ "limit": 150,
+ "a_race": "T",
+ "b_race": "T",
+ "unit_type_bits": 0,
+ "map_type": "marines",
+ },
+ "5m_vs_6m": {
+ "n_agents": 5,
+ "n_enemies": 6,
+ "limit": 70,
+ "a_race": "T",
+ "b_race": "T",
+ "unit_type_bits": 0,
+ "map_type": "marines",
+ },
+ "8m_vs_9m": {
+ "n_agents": 8,
+ "n_enemies": 9,
+ "limit": 120,
+ "a_race": "T",
+ "b_race": "T",
+ "unit_type_bits": 0,
+ "map_type": "marines",
+ },
+ "10m_vs_11m": {
+ "n_agents": 10,
+ "n_enemies": 11,
+ "limit": 150,
+ "a_race": "T",
+ "b_race": "T",
+ "unit_type_bits": 0,
+ "map_type": "marines",
+ },
+ "27m_vs_30m": {
+ "n_agents": 27,
+ "n_enemies": 30,
+ "limit": 180,
+ "a_race": "T",
+ "b_race": "T",
+ "unit_type_bits": 0,
+ "map_type": "marines",
+ },
+ "MMM": {
+ "n_agents": 10,
+ "n_enemies": 10,
+ "limit": 150,
+ "a_race": "T",
+ "b_race": "T",
+ "unit_type_bits": 3,
+ "map_type": "MMM",
+ },
+ "MMM2": {
+ "n_agents": 10,
+ "n_enemies": 12,
+ "limit": 180,
+ "a_race": "T",
+ "b_race": "T",
+ "unit_type_bits": 3,
+ "map_type": "MMM",
+ },
+ "2s3z": {
+ "n_agents": 5,
+ "n_enemies": 5,
+ "limit": 120,
+ "a_race": "P",
+ "b_race": "P",
+ "unit_type_bits": 2,
+ "map_type": "stalkers_and_zealots",
+ },
+ "3s5z": {
+ "n_agents": 8,
+ "n_enemies": 8,
+ "limit": 150,
+ "a_race": "P",
+ "b_race": "P",
+ "unit_type_bits": 2,
+ "map_type": "stalkers_and_zealots",
+ },
+ "infestor_viper": {
+ "n_agents": 2,
+ "n_enemies": 9,
+ "limit": 150,
+ "a_race": "Z",
+ "b_race": "Z",
+ "unit_type_bits": 2,
+ "map_type": "infestor_viper"
+ },
+ "3s5z_vs_3s6z": {
+ "n_agents": 8,
+ "n_enemies": 9,
+ "limit": 170,
+ "a_race": "P",
+ "b_race": "P",
+ "unit_type_bits": 2,
+ "map_type": "stalkers_and_zealots",
+ },
+ "3s_vs_3z": {
+ "n_agents": 3,
+ "n_enemies": 3,
+ "limit": 150,
+ "a_race": "P",
+ "b_race": "P",
+ "unit_type_bits": 0,
+ "map_type": "stalkers",
+ },
+ "3s_vs_4z": {
+ "n_agents": 3,
+ "n_enemies": 4,
+ "limit": 200,
+ "a_race": "P",
+ "b_race": "P",
+ "unit_type_bits": 0,
+ "map_type": "stalkers",
+ },
+ "3s_vs_5z": {
+ "n_agents": 3,
+ "n_enemies": 5,
+ "limit": 250,
+ "a_race": "P",
+ "b_race": "P",
+ "unit_type_bits": 0,
+ "map_type": "stalkers",
+ },
+ "1c3s5z": {
+ "n_agents": 9,
+ "n_enemies": 9,
+ "limit": 180,
+ "a_race": "P",
+ "b_race": "P",
+ "unit_type_bits": 3,
+ "map_type": "colossi_stalkers_zealots",
+ },
+ "2m_vs_1z": {
+ "n_agents": 2,
+ "n_enemies": 1,
+ "limit": 150,
+ "a_race": "T",
+ "b_race": "P",
+ "unit_type_bits": 0,
+ "map_type": "marines",
+ },
+ "corridor": {
+ "n_agents": 6,
+ "n_enemies": 24,
+ "limit": 400,
+ "a_race": "P",
+ "b_race": "Z",
+ "unit_type_bits": 0,
+ "map_type": "zealots",
+ },
+ "6h_vs_8z": {
+ "n_agents": 6,
+ "n_enemies": 8,
+ "limit": 150,
+ "a_race": "Z",
+ "b_race": "P",
+ "unit_type_bits": 0,
+ "map_type": "hydralisks",
+ },
+ "2s_vs_1sc": {
+ "n_agents": 2,
+ "n_enemies": 1,
+ "limit": 300,
+ "a_race": "P",
+ "b_race": "Z",
+ "unit_type_bits": 0,
+ "map_type": "stalkers",
+ },
+ "so_many_baneling": {
+ "n_agents": 7,
+ "n_enemies": 32,
+ "limit": 100,
+ "a_race": "P",
+ "b_race": "Z",
+ "unit_type_bits": 0,
+ "map_type": "zealots",
+ },
+ "bane_vs_bane": {
+ "n_agents": 24,
+ "n_enemies": 24,
+ "limit": 200,
+ "a_race": "Z",
+ "b_race": "Z",
+ "unit_type_bits": 2,
+ "map_type": "bane",
+ },
+ "2c_vs_64zg": {
+ "n_agents": 2,
+ "n_enemies": 64,
+ "limit": 400,
+ "a_race": "P",
+ "b_race": "Z",
+ "unit_type_bits": 0,
+ "map_type": "colossus",
+ },
+}
+
+for name in map_param_registry.keys():
+ globals()[name] = type(name, (SMACMap, ), dict(filename=name))
+
+
+def get_map_params(map_name):
+ return map_param_registry[map_name]
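+
+# Illustrative usage: get_map_params("3m") returns the registered parameter dict above,
+# e.g. {"n_agents": 3, "n_enemies": 3, "limit": 60, "map_type": "marines", ...}.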
diff --git a/DI-engine/dizoo/smac/envs/smac_reward.py b/DI-engine/dizoo/smac/envs/smac_reward.py
new file mode 100644
index 0000000000000000000000000000000000000000..d41921ca91a350d666c531d578ee37cfb0e09c26
--- /dev/null
+++ b/DI-engine/dizoo/smac/envs/smac_reward.py
@@ -0,0 +1,209 @@
+from collections import namedtuple
+import numpy as np
+
+ORIGINAL_AGENT = "me"
+OPPONENT_AGENT = "opponent"
+
+
+class SMACReward:
+ info_template = namedtuple('EnvElementInfo', ['shape', 'value', 'to_agent_processor', 'from_agent_processor'])
+
+ def __init__(
+ self,
+ n_agents,
+ n_enemies,
+ two_player,
+ reward_type,
+ max_reward,
+ reward_scale=True,
+ reduce_agent=True,
+ reward_only_positive=True
+ ):
+ self.reward_only_positive = reward_only_positive
+ self.reward_scale = reward_scale
+ self.max_reward = max_reward
+ self.reward_death_value = 10
+ self.reward_win = 200
+ self.reward_defeat = 0
+ self.reward_negative_scale = 0.5
+ self.reward_scale_rate = 20
+ self.reduce_agent = reduce_agent
+ self.reward_type = reward_type
+ assert self.reward_type in ['sparse', 'original', 'new']
+ self.n_agents = n_agents
+ self.n_enemies = n_enemies
+
+ self.death_tracker_ally = np.zeros(n_agents)
+ self.death_tracker_enemy = np.zeros(n_enemies)
+
+ self.two_player = two_player
+
+ def reset(self, max_reward):
+ self.max_reward = max_reward
+ if self.reward_type == 'original':
+ self.info().value['max'] = self.max_reward / self.reward_scale_rate
+ self.death_tracker_ally.fill(0)
+ self.death_tracker_enemy.fill(0)
+
+ def get_reward(self, engine, action, game_end_code, win_counted, defeat_counted):
+ reward = {
+ ORIGINAL_AGENT: np.asarray(self.reward_battle_split(engine, action, is_opponent=False)),
+ OPPONENT_AGENT: np.asarray(self.reward_battle_split(engine, action, is_opponent=True))
+ }
+ for k in reward:
+ if reward[k].shape == ():
+ reward[k] = np.expand_dims(reward[k], 0)
+
+ if game_end_code is not None:
+ # Battle is over
+ if game_end_code == 1 and not win_counted:
+ if self.reward_type != "sparse":
+ reward[ORIGINAL_AGENT] += self.reward_win
+ reward[OPPONENT_AGENT] += self.reward_defeat
+ else:
+ reward[ORIGINAL_AGENT] += 1
+ reward[OPPONENT_AGENT] += -1
+ elif game_end_code == -1 and not defeat_counted:
+ if self.reward_type != "sparse":
+ reward[ORIGINAL_AGENT] += self.reward_defeat
+ reward[OPPONENT_AGENT] += self.reward_win
+ else:
+ reward[ORIGINAL_AGENT] += -1
+ reward[OPPONENT_AGENT] += 1
+ # Note: if a draw happens, the game_end_code may still be None.
+
+ if self.reward_scale:
+ # rescale to 0~1
+ min_val, max_val = self.info().value['min'], self.info().value['max']
+ reward[ORIGINAL_AGENT] = (reward[ORIGINAL_AGENT] - min_val) / (max_val - min_val)
+ reward[OPPONENT_AGENT] = (reward[OPPONENT_AGENT] - min_val) / (max_val - min_val)
+
+ return reward
+
+ def reward_battle_split(self, engine, action, is_opponent=False):
+ """Reward function when self.reward_type != 'sparse'.
+ Returns cumulative hit/shield point damage dealt to the enemy
+ + reward_death_value per enemy unit killed, and, in case
+ self.reward_only_positive == False, - (damage dealt to ally units
+ + reward_death_value per ally unit killed) * self.reward_negative_scale
+ """
+
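+ # Summary of the computation below: sum the damage dealt to enemy units plus
+ # reward_death_value per enemy kill, together with the (neg_scale-weighted) ally-death
+ # penalties; if reward_only_positive, the absolute value of that sum is returned,
+ # otherwise the scaled ally damage is subtracted as well.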
+ num_agents = engine.n_agents if not is_opponent else engine.n_enemies
+ num_enemies = engine.n_agents if is_opponent else engine.n_enemies
+
+ if self.reward_type == 'sparse':
+ if self.reduce_agent:
+ return 0.
+ else:
+ return np.zeros(num_agents)
+
+ # if self.reward_type != 'original':
+ assert self.reward_type == 'original', 'reward_type={} is not supported!'.format(self.reward_type)
+ delta_deaths = np.zeros([num_agents])
+ reward = np.zeros([num_agents])
+ delta_ally = np.zeros([num_agents])
+ delta_enemy = np.zeros([num_enemies])
+ delta_death_enemy = np.zeros([num_enemies])
+
+ neg_scale = self.reward_negative_scale
+
+ # update deaths
+ if is_opponent:
+ iterator = engine.enemies.items()
+ previous_units = engine.previous_enemy_units
+ death_tracker = self.death_tracker_enemy
+ else:
+ iterator = engine.agents.items()
+ previous_units = engine.previous_ally_units
+ death_tracker = self.death_tracker_ally
+
+ num_players = 2 if self.two_player else 1
+ for al_id, al_unit in iterator:
+ if death_tracker[al_id] < num_players:
+ # did not die so far
+ prev_health = (previous_units[al_id].health + previous_units[al_id].shield)
+ if al_unit.health == 0:
+ # just died
+ death_tracker[al_id] += 1
+ delta_deaths[al_id] -= self.reward_death_value * neg_scale
+ delta_ally[al_id] += prev_health * neg_scale
+ else:
+ # still alive
+ delta_ally[al_id] += neg_scale * (prev_health - al_unit.health - al_unit.shield)
+
+ # Calculate the damage to opponent.
+ if is_opponent:
+ iterator = engine.agents.items()
+ previous_units = engine.previous_ally_units
+ death_tracker = self.death_tracker_ally
+ else:
+ iterator = engine.enemies.items()
+ previous_units = engine.previous_enemy_units
+ death_tracker = self.death_tracker_enemy
+
+ for e_id, e_unit in iterator:
+ if death_tracker[e_id] < num_players:
+ prev_health = (previous_units[e_id].health + previous_units[e_id].shield)
+ if e_unit.health == 0:
+ death_tracker[e_id] += 1
+ delta_death_enemy[e_id] += self.reward_death_value
+ delta_enemy[e_id] += prev_health
+ else:
+ delta_enemy[e_id] += prev_health - e_unit.health - e_unit.shield
+ # if e_unit.health == 0:
+ # death_tracker[e_id] += 1
+ # delta_death_enemy[e_id] += self.reward_death_value
+ # normed_delta_health = prev_health / (e_unit.health_max + e_unit.shield_max)
+ # delta_enemy[e_id] += normed_delta_health * self.reward_death_value
+ # else:
+ # normed_delta_health = (prev_health - e_unit.health -
+ # e_unit.shield) / (e_unit.health_max + e_unit.shield_max)
+ # delta_enemy[e_id] += normed_delta_health * self.reward_death_value
+
+ # if self.reward_type == 'original':
+ # if self.reduce_agent:
+ # total_reward = sum(delta_deaths) + sum(delta_death_enemy) + sum(delta_enemy)
+ # return total_reward
+ # else:
+ # total_reward = sum(delta_deaths) + sum(delta_death_enemy) + sum(delta_enemy) / num_agents
+ # return np.ones(num_agents) * total_reward
+
+ # Attacking reward
+ # if isinstance(action, dict):
+ # my_action = action["me"] if not is_opponent else action["opponent"]
+ # else:
+ # my_action = action
+ # for my_id, my_action in enumerate(my_action):
+ # if my_action > 5:
+ # reward[my_id] += 2
+
+ if self.reward_only_positive:
+ # reward = abs((delta_deaths + delta_death_enemy + delta_enemy).sum())
+ reward = abs(delta_deaths.sum() + delta_death_enemy.sum() + delta_enemy.sum())
+ else:
+ reward = delta_deaths.sum() + delta_death_enemy.sum() + delta_enemy.sum() - delta_ally.sum()
+
+ return reward
+
+ def info(self):
+ if self.reward_type == 'sparse':
+ value = {'min': -1, 'max': 1}
+ elif self.reward_type == 'original':
+ value = {'min': 0, 'max': self.max_reward / self.reward_scale_rate}
+ # value = {'min': 0, 'max': 75.5}
+ # value = {'min': 0, 'max': self.max_reward / 75.5}
+ # # TODO(nyz) health + shield range
+ # if self.reduce_agent:
+ # value = {'min': 0, 'max': (self.reward_win + self.reward_death_value * self.n_enemies +1230)/20}
+ # else:
+ # value = {'min': 0, 'max': self.reward_win + self.reward_death_value * self.n_enemies / self.n_agents}
+ # elif self.reward_type == 'new':
+ # if self.reduce_agent:
+ # value = {'min': 0, 'max': self.reward_win + 2 + self.reward_death_value * self.n_enemies}
+ # else:
+ # value = {
+ # 'min': 0,
+ # 'max': self.reward_win + 2 + self.reward_death_value * self.n_enemies / self.n_agents
+ # }
+ shape = (1, ) if self.reduce_agent else (self.n_agents, )
+ return SMACReward.info_template(shape, value, None, None)
diff --git a/DI-engine/dizoo/smac/envs/test_smac_env.py b/DI-engine/dizoo/smac/envs/test_smac_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..bf3b99ee3753e149cc971a7537e0212c931fed78
--- /dev/null
+++ b/DI-engine/dizoo/smac/envs/test_smac_env.py
@@ -0,0 +1,158 @@
+import pytest
+import numpy as np
+from easydict import EasyDict
+
+from dizoo.smac.envs import SMACEnv
+
+MOVE_EAST = 4
+MOVE_WEST = 5
+
+
+def automation(env, n_agents):
+ actions = {"me": [], "opponent": []}
+ for agent_id in range(n_agents):
+ avail_actions = env.get_avail_agent_actions(agent_id, is_opponent=False)
+ avail_actions_ind = np.nonzero(avail_actions)[0]
+ action = np.random.choice(avail_actions_ind)
+ if avail_actions[0] != 0:
+ action = 0
+ elif len(np.nonzero(avail_actions[6:])[0]) == 0:
+ if avail_actions[MOVE_EAST] != 0:
+ action = MOVE_EAST
+ else:
+ action = np.random.choice(avail_actions_ind)
+ else:
+ action = np.random.choice(avail_actions_ind)
+ # if MOVE_EAST in avail_actions_ind:
+ # action = MOVE_EAST
+ # Let OPPONENT attack ME at the first place
+ # if sum(avail_actions[6:]) > 0:
+ # action = max(avail_actions_ind)
+ # print("ME start attacking OP")
+ # print("Available action for ME: ", avail_actions_ind)
+ actions["me"].append(action)
+ print('ava', avail_actions, action)
+ for agent_id in range(n_agents):
+ avail_actions = env.get_avail_agent_actions(agent_id, is_opponent=True)
+ avail_actions_ind = np.nonzero(avail_actions)[0]
+ action = np.random.choice(avail_actions_ind)
+ if MOVE_EAST in avail_actions_ind:
+ action = MOVE_EAST
+ # Let OPPONENT attack ME first
+ if sum(avail_actions[6:]) > 0:
+ # print("OP start attacking ME")
+ action = max(avail_actions_ind)
+ actions["opponent"].append(action)
+ return actions
+
+
+def random_policy(env, n_agents):
+ actions = {"me": [], "opponent": []}
+ for agent_id in range(n_agents):
+ avail_actions = env.get_avail_agent_actions(agent_id, is_opponent=False)
+ avail_actions_ind = np.nonzero(avail_actions)[0]
+ action = np.random.choice(avail_actions_ind)
+ actions["me"].append(action)
+ for agent_id in range(n_agents):
+ avail_actions = env.get_avail_agent_actions(agent_id, is_opponent=True)
+ avail_actions_ind = np.nonzero(avail_actions)[0]
+ # The opponent also acts randomly here
+ action = np.random.choice(avail_actions_ind)
+ actions["opponent"].append(action)
+ return actions
+
+
+def fix_policy(env, n_agents, me=0, opponent=0):
+ actions = {"me": [], "opponent": []}
+ for agent_id in range(n_agents):
+ avail_actions = env.get_avail_agent_actions(agent_id, is_opponent=False)
+ avail_actions_ind = np.nonzero(avail_actions)[0]
+ action = me
+ if action not in avail_actions_ind:
+ action = avail_actions_ind[0]
+ actions["me"].append(action)
+
+ for agent_id in range(n_agents):
+ avail_actions = env.get_avail_agent_actions(agent_id, is_opponent=True)
+ avail_actions_ind = np.nonzero(avail_actions)[0]
+ action = opponent
+ if action not in avail_actions_ind:
+ action = avail_actions_ind[0]
+ actions["opponent"].append(action)
+ return actions
+
+
+def main(policy, map_name="3m", two_player=False):
+ cfg = EasyDict({'two_player': two_player, 'map_name': map_name, 'save_replay_episodes': None, 'obs_alone': True})
+ env = SMACEnv(cfg)
+ if map_name == "3s5z":
+ n_agents = 8
+ elif map_name == "3m":
+ n_agents = 3
+ elif map_name == "infestor_viper":
+ n_agents = 2
+ else:
+ raise ValueError(f"invalid type: {map_name}")
+ n_episodes = 20
+ me_win = 0
+ draw = 0
+ op_win = 0
+
+ for e in range(n_episodes):
+ print("Now reset the environment for {} episode.".format(e))
+ env.reset()
+ print('reset over')
+ terminated = False
+ episode_return_me = 0
+ episode_return_op = 0
+
+ env_info = env.info()
+ print('begin new episode')
+ while not terminated:
+ actions = policy(env, n_agents)
+ if not two_player:
+ actions = actions["me"]
+ t = env.step(actions)
+ obs, reward, terminated, infos = t.obs, t.reward, t.done, t.info
+ assert set(obs.keys()) == set(
+ ['agent_state', 'global_state', 'action_mask', 'agent_alone_state', 'agent_alone_padding_state']
+ )
+ assert isinstance(obs['agent_state'], np.ndarray)
+ assert obs['agent_state'].shape == env_info.obs_space.shape['agent_state'] # n_agents, agent_state_dim
+ assert isinstance(obs['agent_alone_state'], np.ndarray)
+ assert obs['agent_alone_state'].shape == env_info.obs_space.shape['agent_alone_state']
+ assert isinstance(obs['global_state'], np.ndarray)
+ assert obs['global_state'].shape == env_info.obs_space.shape['global_state'] # global_state_dim
+ assert isinstance(reward, np.ndarray)
+ assert reward.shape == (1, )
+ print('reward', reward)
+ assert isinstance(terminated, bool)
+ episode_return_me += reward["me"] if two_player else reward
+ episode_return_op += reward["opponent"] if two_player else 0
+ terminated = terminated["me"] if two_player else terminated
+
+ if two_player:
+ me_win += int(infos["me"]["battle_won"])
+ op_win += int(infos["opponent"]["battle_won"])
+ draw += int(infos["draw"])
+ else:
+ me_win += int(infos["battle_won"])
+ op_win += int(infos["battle_lost"])
+ draw += int(infos["draw"])
+
+ print(
+ "Total return in episode {} = {} (me), {} (opponent). Me win {}, Draw {}, Opponent win {}, total {}."
+ "".format(e, episode_return_me, episode_return_op, me_win, draw, op_win, e + 1)
+ )
+
+ env.close()
+
+
+@pytest.mark.env_test
+def test_automation():
+ # main(automation, map_name="3m", two_player=False)
+ main(automation, map_name="infestor_viper", two_player=False)
+
+
+if __name__ == "__main__":
+ test_automation()
diff --git a/DI-engine/dizoo/smac/utils/eval.py b/DI-engine/dizoo/smac/utils/eval.py
new file mode 100644
index 0000000000000000000000000000000000000000..1e112e84a7473319422fb504ca648b474eeb51fe
--- /dev/null
+++ b/DI-engine/dizoo/smac/utils/eval.py
@@ -0,0 +1,69 @@
+from typing import Union, Optional, List, Any, Callable, Tuple
+import pickle
+import torch
+from functools import partial
+
+from ding.config import compile_config, read_config
+from ding.envs import get_vec_env_setting
+from ding.policy import create_policy
+from ding.utils import set_pkg_seed
+
+
+def eval(
+ input_cfg: Union[str, Tuple[dict, dict]],
+ seed: int = 0,
+ env_setting: Optional[List[Any]] = None,
+ model: Optional[torch.nn.Module] = None,
+ state_dict: Optional[dict] = None,
+) -> float:
+ r"""
+ Overview:
+ Pure evaluation entry.
+ Arguments:
+ - input_cfg (:obj:`Union[str, Tuple[dict, dict]]`): Config in dict type. \
+ ``str`` type means config file path. \
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - env_setting (:obj:`Optional[List[Any]]`): A list with 3 elements: \
+ ``BaseEnv`` subclass, collector env config, and evaluator env config.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - state_dict (:obj:`Optional[dict]`): The state_dict of policy or model.
+ """
+ if isinstance(input_cfg, str):
+ cfg, create_cfg = read_config(input_cfg)
+ else:
+ cfg, create_cfg = input_cfg
+ create_cfg.policy.type += '_command'
+ cfg = compile_config(cfg, auto=True, create_cfg=create_cfg)
+
+ env_fn, _, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ env = env_fn(evaluator_env_cfg[0])
+ env.seed(seed, dynamic_seed=False)
+ set_pkg_seed(seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=model, enable_field=['eval']).eval_mode
+ if state_dict is None:
+ state_dict = torch.load(cfg.learner.load_path, map_location='cpu')
+ policy.load_state_dict(state_dict)
+
+ obs = env.reset()
+ episode_return = 0.
+ while True:
+ policy_output = policy.forward({0: obs})
+ action = policy_output[0]['action']
+ print(action)
+ timestep = env.step(action)
+ episode_return += timestep.reward
+ obs = timestep.obs
+ if timestep.done:
+ print(timestep.info)
+ break
+
+ env.save_replay(replay_dir='.', prefix=env._map_name)
+ print('Eval is over! The performance of your RL policy is {}'.format(episode_return))
+
+
+if __name__ == "__main__":
+ path = '../exp/MMM/qmix/1/ckpt_BaseLearner_Wed_Jul_14_22_16_56_2021/iteration_9900.pth.tar'
+ cfg = '../config/smac_MMM_qmix_config.py'
+ state_dict = torch.load(path, map_location='cpu')
+ eval(cfg, seed=0, state_dict=state_dict)
diff --git a/DI-engine/dizoo/sokoban/__init__.py b/DI-engine/dizoo/sokoban/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/sokoban/envs/__init__.py b/DI-engine/dizoo/sokoban/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c90db771d7cec6db936a9a792c7115246c488453
--- /dev/null
+++ b/DI-engine/dizoo/sokoban/envs/__init__.py
@@ -0,0 +1 @@
+from .sokoban_env import SokobanEnv
diff --git a/DI-engine/dizoo/sokoban/envs/sokoban_env.py b/DI-engine/dizoo/sokoban/envs/sokoban_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..295259e702a50a1f67031a79d6dbdfaba259e419
--- /dev/null
+++ b/DI-engine/dizoo/sokoban/envs/sokoban_env.py
@@ -0,0 +1,111 @@
+import gym
+import copy
+import numpy as np
+from typing import List
+from easydict import EasyDict
+from ding.utils import ENV_REGISTRY
+from ding.torch_utils import to_ndarray
+from ding.envs import BaseEnv, BaseEnvTimestep
+from .sokoban_wrappers import wrap_sokoban
+
+
+@ENV_REGISTRY.register('sokoban')
+class SokobanEnv(BaseEnv):
+
+ def __init__(self, cfg: dict) -> None:
+ self._cfg = cfg
+ self._env_id = cfg.env_id
+ self._init_flag = False
+ self._save_replay = False
+
+ def seed(self, seed: int, dynamic_seed: bool = True) -> None:
+ self._seed = seed
+ self._dynamic_seed = dynamic_seed
+ np.random.seed(self._seed)
+
+ def reset(self) -> np.ndarray:
+ if not self._init_flag:
+ self._env = self._make_env(only_info=False)
+ self._init_flag = True
+
+ if self._save_replay:
+ self._env = gym.wrappers.RecordVideo(
+ self._env,
+ video_folder=self._replay_path,
+ episode_trigger=lambda episode_id: True,
+ name_prefix='rl-video-{}'.format(id(self))
+ )
+
+ self._env.observation_space.dtype = np.float32 # To unify the format of envs in DI-engine
+ self._observation_space = self._env.observation_space
+ self._action_space = self._env.action_space
+ self._reward_space = gym.spaces.Box(
+ low=self._env.reward_range[0], high=self._env.reward_range[1], shape=(1, ), dtype=np.float32
+ )
+
+ if hasattr(self, '_seed') and hasattr(self, '_dynamic_seed') and self._dynamic_seed:
+ np_seed = 100 * np.random.randint(1, 1000)
+ self._env.seed(self._seed + np_seed)
+ elif hasattr(self, '_seed'):
+ self._env.seed(self._seed)
+ obs = self._env.reset()
+ obs = to_ndarray(obs).astype('float32')
+ self._eval_episode_return = 0.
+ return obs
+
+ def step(self, action: np.array):
+ action = to_ndarray(action)
+ obs, rew, done, info = self._env.step(int(action))
+ self._eval_episode_return += rew
+ obs = to_ndarray(obs).astype('float32')
+ rew = to_ndarray([rew]) # wrapped into an array with shape (1,)
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def _make_env(self, only_info=False):
+ return wrap_sokoban(
+ self._env_id,
+ norm_obs=self._cfg.get('norm_obs', EasyDict(use_norm=False, )),
+ norm_reward=self._cfg.get('norm_reward', EasyDict(use_norm=False, )),
+ only_info=only_info
+ )
+
+ def close(self) -> None:
+ if self._init_flag:
+ self._env.close()
+ self._init_flag = False
+
+ def enable_save_replay(self, replay_path) -> None:
+ if replay_path is None:
+ replay_path = './video'
+ self._save_replay = True
+ self._replay_path = replay_path
+
+ def __repr__(self) -> str:
+ return "DI-engine Sokoban Env({})".format(self._cfg.env_id)
+
+ @staticmethod
+ def create_collector_env_cfg(cfg: dict) -> List[dict]:
+ collector_cfg = copy.deepcopy(cfg)
+ collector_env_num = collector_cfg.pop('collector_env_num', 1)
+ return [collector_cfg for _ in range(collector_env_num)]
+
+ @staticmethod
+ def create_evaluator_env_cfg(cfg: dict) -> List[dict]:
+ evaluator_cfg = copy.deepcopy(cfg)
+ evaluator_env_num = evaluator_cfg.pop('evaluator_env_num', 1)
+ evaluator_cfg.norm_reward = EasyDict(use_norm=False, )
+ return [evaluator_cfg for _ in range(evaluator_env_num)]
+
+ @property
+ def observation_space(self) -> gym.spaces.Space:
+ return self._observation_space
+
+ @property
+ def action_space(self) -> gym.spaces.Space:
+ return self._action_space
+
+ @property
+ def reward_space(self) -> gym.spaces.Space:
+ return self._reward_space
diff --git a/DI-engine/dizoo/sokoban/envs/sokoban_wrappers.py b/DI-engine/dizoo/sokoban/envs/sokoban_wrappers.py
new file mode 100644
index 0000000000000000000000000000000000000000..31bf7e989906c779a22590929b2e0a4ed4003c55
--- /dev/null
+++ b/DI-engine/dizoo/sokoban/envs/sokoban_wrappers.py
@@ -0,0 +1,39 @@
+from typing import Dict
+import gym
+from ditk import logging
+from ding.envs import ObsNormWrapper, RewardNormWrapper
+
+try:
+ import gym_sokoban
+except ImportError:
+ logging.warning("not found sokoban env, please install it, refer to https://github.com/mpSchrader/gym-sokoban")
+
+
+def wrap_sokoban(
+ env_id, norm_obs: Dict = dict(use_norm=False, ), norm_reward: Dict = dict(use_norm=False, ), only_info=False
+) -> gym.Env:
+ r"""
+ Overview:
+ Wrap Sokoban Env to preprocess env step's return info, e.g. observation normalization, reward normalization, etc.
+ Arguments:
+ - env_id (:obj:`str`): Sokoban environment id, for example "Sokoban-v0"
+ - norm_obs (:obj:`EasyDict`): Whether to normalize observation or not
+ - norm_reward (:obj:`EasyDict`): Whether to normalize reward or not. For the evaluator, the environment's \
+ reward should not be normalized: either ``norm_reward`` is None or ``norm_reward.use_norm`` is False.
+ Returns:
+ - wrapped_env (:obj:`gym.Env`): The wrapped Sokoban environment
+ """
+ if not only_info:
+ env = gym.make(env_id)
+ if norm_obs is not None and norm_obs.use_norm:
+ env = ObsNormWrapper(env)
+ if norm_reward is not None and norm_reward.use_norm:
+ env = RewardNormWrapper(env, norm_reward.reward_discount)
+ return env
+ else:
+ wrapper_info = ''
+ if norm_obs is not None and norm_obs.use_norm:
+ wrapper_info = ObsNormWrapper.__name__ + '\n'
+ if norm_reward is not None and norm_reward.use_norm:
+ wrapper_info += RewardNormWrapper.__name__ + '\n'
+ return wrapper_info
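+
+# Illustrative usage (a sketch, assuming gym_sokoban is installed), mirroring how
+# SokobanEnv calls this helper:
+#   from easydict import EasyDict
+#   env = wrap_sokoban('Sokoban-v0', norm_obs=EasyDict(use_norm=False), norm_reward=EasyDict(use_norm=False))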
diff --git a/DI-engine/dizoo/sokoban/envs/test_sokoban_env.py b/DI-engine/dizoo/sokoban/envs/test_sokoban_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..3e40e5a003c630b2a0defd697cd6f5c6b63f0f74
--- /dev/null
+++ b/DI-engine/dizoo/sokoban/envs/test_sokoban_env.py
@@ -0,0 +1,25 @@
+from easydict import EasyDict
+import pytest
+import numpy as np
+from dizoo.sokoban.envs.sokoban_env import SokobanEnv
+
+
+@pytest.mark.envtest
+class TestSokoban:
+
+ def test_sokoban(self):
+ env = SokobanEnv(EasyDict({'env_id': 'Sokoban-v0'}))
+ env.reset()
+ for i in range(100):
+ action = np.random.randint(8)
+ timestep = env.step(np.array(action))
+ print(timestep)
+ print(timestep.obs.max())
+ assert isinstance(timestep.obs, np.ndarray)
+ assert isinstance(timestep.done, bool)
+ assert timestep.obs.shape == (160, 160, 3)
+ print(timestep.info)
+ assert timestep.reward.shape == (1, )
+ if timestep.done:
+ env.reset()
+ env.close()
diff --git a/DI-engine/dizoo/tabmwp/README.md b/DI-engine/dizoo/tabmwp/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..410aed8e6f30fe5ee1d10de6ba183a2585d1a19c
--- /dev/null
+++ b/DI-engine/dizoo/tabmwp/README.md
@@ -0,0 +1,16 @@
+## TabMWP Env
+
+## Dataset
+
+The **TabMWP** dataset contains 38,431 tabular math word problems. Each question in **TabMWP** is aligned with a tabular context, which is presented as an image, semi-structured text, and a structured table. There are two types of questions: *free-text* and *multi-choice*, and each problem is annotated with gold solutions to reveal the multi-step reasoning process.
+
+The environment is described in the paper [Dynamic Prompt Learning via Policy Gradient for Semi-structured Mathematical Reasoning](https://arxiv.org/abs/2209.14610) by Pan Lu, Liang Qiu, Kai-Wei Chang, Ying Nian Wu, Song-Chun Zhu, Tanmay Rajpurohit, Peter Clark, Ashwin Kalyan, 2023.
+
+You can find more details in [Prompt PG](https://github.com/lupantech/PromptPG)
+
+## Benchmark
+
+- We collect the responses of GPT-3 on a reduced dataset with 80 training samples and 16 candidates. In this way, users do not need to query GPT-3 with an OpenAI API key.
+- You can directly reproduce the benchmark by running ``python dizoo/tabmwp/config/tabmwp_pg_config.py`` (a programmatic sketch is shown below).
+
+![origin](./benchmark.png)
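+
+As a minimal sketch (mirroring the ``__main__`` entry of ``dizoo/tabmwp/config/tabmwp_pg_config.py``, and assuming DI-engine and the ``dizoo`` package are importable), the benchmark can also be launched from Python:
+
+```python
+# Minimal sketch: run the TabMWP prompt-PG benchmark programmatically.
+# Equivalent to `python dizoo/tabmwp/config/tabmwp_pg_config.py`; seed fixed for reproducibility.
+from ding.entry import serial_pipeline_onpolicy
+from dizoo.tabmwp.config.tabmwp_pg_config import main_config, create_config
+
+serial_pipeline_onpolicy((main_config, create_config), seed=0)
+```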
diff --git a/DI-engine/dizoo/tabmwp/__init__.py b/DI-engine/dizoo/tabmwp/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/tabmwp/config/tabmwp_pg_config.py b/DI-engine/dizoo/tabmwp/config/tabmwp_pg_config.py
new file mode 100644
index 0000000000000000000000000000000000000000..acda7bcdbd1e152071736f971b13615354c3ad59
--- /dev/null
+++ b/DI-engine/dizoo/tabmwp/config/tabmwp_pg_config.py
@@ -0,0 +1,66 @@
+from easydict import EasyDict
+
+tabmwp_prompt_pg_config = dict(
+ exp_name='tabmwp_prompt_pg_seed0',
+ env=dict(
+ collector_env_num=1,
+ evaluator_env_num=1,
+ n_evaluator_episode=1,
+ stop_value=1,
+ cand_number=16,
+ train_number=80,
+ engine='text-davinci-002',
+ temperature=0.,
+ max_tokens=512,
+ top_p=1.,
+ frequency_penalty=0.,
+ presence_penalty=0.,
+ option_inds=["A", "B", "C", "D", "E", "F"],
+ # The API-key of openai. You can get your key in this website: https://platform.openai.com/
+ api_key='',
+ enable_replay=True,
+ prompt_format='TQ-A',
+ seed=0,
+ ),
+ policy=dict(
+ cuda=True,
+ shot_number=2,
+ model=dict(
+ model_name="bert-base-uncased",
+ add_linear=True,
+ freeze_encoder=True,
+ embedding_size=128,
+ ),
+ learn=dict(
+ batch_size=10,
+ # (float) Learning rate of the optimizer.
+ learning_rate=0.001,
+ # (float) Weight of the entropy regularization term; the policy loss weight is set to 1.
+ entropy_weight=0.001,
+ weight_decay=5e-3,
+ grad_norm=0.5,
+ ),
+ collect=dict(
+ # (int) Collect n_sample samples, then train the model once.
+ n_sample=20,
+ discount_factor=0.,
+ ),
+ eval=dict(evaluator=dict(eval_freq=500, )),
+ ),
+)
+main_config = EasyDict(tabmwp_prompt_pg_config)
+
+tabmwp_prompt_pg_config = dict(
+ env=dict(
+ type='tabmwp',
+ import_names=['dizoo.tabmwp.envs.tabmwp_env'],
+ ),
+ env_manager=dict(type='base'),
+ policy=dict(type='prompt_pg'),
+ replay_buffer=dict(type='naive'),
+)
+create_config = EasyDict(tabmwp_prompt_pg_config)
+
+if __name__ == '__main__':
+ from ding.entry import serial_pipeline_onpolicy
+ serial_pipeline_onpolicy((main_config, create_config), seed=0)
diff --git a/DI-engine/dizoo/tabmwp/envs/__init__.py b/DI-engine/dizoo/tabmwp/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/DI-engine/dizoo/tabmwp/envs/tabmwp_env.py b/DI-engine/dizoo/tabmwp/envs/tabmwp_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..fe32e02b3546bac66ab8fb7ce12a0a321927b0f9
--- /dev/null
+++ b/DI-engine/dizoo/tabmwp/envs/tabmwp_env.py
@@ -0,0 +1,266 @@
+import os
+from functools import lru_cache
+
+import gym
+import openai
+import numpy as np
+
+from ding.utils import ENV_REGISTRY
+from ding.envs import BaseEnv, BaseEnvTimestep
+from dizoo.tabmwp.envs.utils import create_example_from_pid, build_prompt, get_gpt3_output, calc_rwkv, calc_internlm,\
+ extract_prediction, normalize_answer, load_data
+
+
+@ENV_REGISTRY.register('tabmwp')
+class TabMWP(BaseEnv):
+ model = None
+ tokenizer = None
+
+ def __init__(self, cfg):
+ self.cfg = cfg
+ self.enable_replay = cfg.enable_replay
+ self._init_flag = False
+ self.problems, self.cand_pids, self.train_pids = None, None, None
+ self.problem_id = 0
+ self.cand_examples = []
+ openai.api_key = cfg.api_key
+ self.observation_space = gym.spaces.Dict()
+ self.action_space = gym.spaces.Discrete(self.cfg.cand_number * (self.cfg.cand_number - 1))
+ self.reward_space = gym.spaces.Box(low=-1, high=1, shape=(1, ), dtype=np.float32)
+ self.correct_num = 0
+
+ # Initialize language model if needed.
+ assert self.cfg.engine in ['text-davinci-002', 'glm-10B', 'rwkv-7B', 'internlm-7B']
+
+ try:
+ if self.cfg.engine == 'glm-10B' and TabMWP.model is None:
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
+ TabMWP.tokenizer = AutoTokenizer.from_pretrained("THUDM/glm-10b", trust_remote_code=True)
+ model = AutoModelForSeq2SeqLM.from_pretrained("THUDM/glm-10b", trust_remote_code=True)
+ TabMWP.model = model.half()
+ elif self.cfg.engine == 'rwkv-7B' and TabMWP.model is None:
+ from transformers import AutoTokenizer, RwkvForCausalLM
+ TabMWP.tokenizer = AutoTokenizer.from_pretrained("sgugger/rwkv-7b-pile", trust_remote_code=True)
+ model = RwkvForCausalLM.from_pretrained("sgugger/rwkv-7b-pile")
+ TabMWP.model = model.half()
+ elif self.cfg.engine == 'internlm-7B' and TabMWP.model is None:
+ from transformers import AutoTokenizer, AutoModelForCausalLM
+ TabMWP.tokenizer = AutoTokenizer.from_pretrained("internlm/internlm-7b", trust_remote_code=True)
+ model = AutoModelForCausalLM.from_pretrained("internlm/internlm-7b", trust_remote_code=True)
+ TabMWP.model = model.eval()
+ except ImportError:
+ import sys
+ from ditk import logging
+ logging.warning("not found transformer, please install it using: pip install transformers")
+ sys.exit(1)
+
+ @lru_cache(maxsize=10000)
+ def get_output(self, inp: str) -> str:
+ inputs = TabMWP.tokenizer(inp + " [MASK].", return_tensors="pt")
+ inputs = TabMWP.tokenizer.build_inputs_for_generation(inputs, max_gen_length=512)
+ inputs = {key: value.cuda() for key, value in inputs.items()}
+ outputs = TabMWP.model.generate(
+ **inputs,
+ max_length=512,
+ eos_token_id=TabMWP.tokenizer.eop_token_id,
+ pad_token_id=TabMWP.tokenizer.eos_token_id
+ )
+ outputs = TabMWP.tokenizer.decode(outputs[0].tolist())
+
+ t0 = outputs.find('<|startofpiece|>') + 16
+ t1 = outputs.find('<|endofpiece|>')
+
+ return outputs[t0:t1]
+
+ def seed(self, seed: int, dynamic_seed: bool = False) -> None:
+ self.cfg.seed = seed
+
+ def reset(self) -> dict:
+ self.problems, self.cand_pids, self.train_pids = load_data(self.cfg)
+ if TabMWP.model is not None:
+ TabMWP.model = TabMWP.model.cuda()
+ if self.enable_replay:
+ self.cand_pids = [
+ '32889', '8044', '16892', '5408', '4051', '37355', '17962', '25807', '30602', '5514', '19270', '23713',
+ '17209', '33379', '34987', '11177'
+ ]
+ if self.cfg.seed == 0: # train
+ self.train_pids = [
+ '14229', '3409', '29980', '799', '5086', '21778', '36441', '34146', '69', '33433', '26979', '18135',
+ '13347', '17679', '38426', '3454', '10432', '31011', '12162', '13063', '7812', '29661', '24482',
+ '4970', '4405', '17405', '27781', '26724', '5993', '16442', '30148', '15895', '6855', '29903',
+ '18107', '29504', '11106', '32964', '29891', '32104', '15712', '24287', '4997', '32581', '21020',
+ '17247', '31455', '13245', '15850', '10011', '10313', '10158', '1817', '33479', '35842', '14198',
+ '26039', '3791', '4909', '37056', '7144', '8185', '2131', '4398', '38199', '29520', '37329',
+ '21388', '28659', '15044', '28510', '12903', '11794', '37095', '32229', '22918', '31680', '15024',
+ '24607', '26930'
+ ]
+ model_io_path = 'dizoo/tabmwp/data/model_in_out_train.txt'
+ if not os.path.exists(model_io_path):
+ os.system(
+ f'wget https://opendilab.net/download/DI-zoo/tabmwp/model_in_out_train.txt -O ' +
+ model_io_path + ' --no-check-certificate'
+ )
+ else:
+ self.train_pids = [
+ '21037', '22976', '2224', '14145', '27962', '26553', '22110', '16541', '26044', '19492', '31882',
+ '11991', '27594', '7637', '15394', '7666', '5177', '33761', '13703', '29105'
+ ]
+ model_io_path = 'dizoo/tabmwp/data/model_in_out_eval.txt'
+ os.system(
+ f'wget https://opendilab.net/download/DI-zoo/tabmwp/model_in_out_eval.txt -O ' + model_io_path +
+ ' --no-check-certificate'
+ )
+
+ self.cfg.cand_number = len(self.cand_pids)
+ self.cfg.train_number = len(self.train_pids)
+
+ self.results_memory = []
+ with open(model_io_path, encoding="ISO-8859-1") as f:
+ tmp = f.read().split('\n')
+ for tt in tmp:
+ if len(tt.strip()) == 0:
+ continue
+ self.results_memory.append(eval(tt))
+
+ self.cand_examples = []
+ self.correct_num = 0
+ for pid in self.cand_pids:
+ example = create_example_from_pid(pid, self.problems, self.cfg, test=True)
+ self.cand_examples.append(example)
+
+ self._init_flag = True
+ self.problem_id = 0
+ train_sample = create_example_from_pid(self.train_pids[self.problem_id], self.problems, self.cfg, test=True)
+ obs = {'train_sample': train_sample, 'candidate_samples': self.cand_examples}
+ return obs
+
+ def search_answer(self, pid, pids):
+ for item in self.results_memory:
+ if item['pid'] != pid:
+ continue
+ if item['shot_pids'] == pids:
+ return item['output']
+
+        raise ValueError('item does not exist.')
+
+ def parse_all_answers(self):
+ self.cand_pids = [
+ '32889', '8044', '16892', '5408', '4051', '37355', '17962', '25807', '30602', '5514', '19270', '23713',
+ '17209', '33379', '34987', '11177', '30218', '26066', '24169', '28492'
+ ]
+ self.train_pids = [
+ '14229', '3409', '29980', '799', '5086', '21778', '36441', '34146', '69', '33433', '26979', '18135',
+ '13347', '17679', '38426', '3454', '10432', '31011', '12162', '13063', '7812', '29661', '24482', '4970',
+ '4405', '17405', '27781', '26724', '5993', '16442', '30148', '15895', '6855', '29903', '18107', '29504',
+ '11106', '32964', '29891', '32104', '15712', '24287', '4997', '32581', '21020', '17247', '31455', '13245',
+ '15850', '10011', '10313', '10158', '1817', '33479', '35842', '14198', '26039', '3791', '4909', '37056',
+ '7144', '8185', '2131', '4398', '38199', '29520', '37329', '21388', '28659', '15044', '28510', '12903',
+ '11794', '37095', '32229', '22918', '31680', '15024', '24607', '26930'
+ ]
+ self.problem_id = 0
+ self.cfg.train_number = len(self.train_pids)
+ n = len(self.cand_pids)
+
+ with open('sampled_pid.txt', 'w') as f:
+ f.write(str(self.cand_pids) + '\n')
+ f.write(str(self.train_pids) + '\n')
+
+ with open('model_in_out.txt', 'w') as f:
+ while self.problem_id < self.cfg.train_number:
+ for i in range(n):
+ for j in range(n):
+ if i == j:
+ continue
+ shot_pids = [self.cand_pids[i], self.cand_pids[j]]
+ pid = self.train_pids[self.problem_id]
+
+ # generate the prompt input
+ prompt = build_prompt(self.problems, shot_pids, pid, self.cfg)
+
+ # get the output from LM
+ # assert self._args.engine == 'text-davinci-002'
+ output = get_gpt3_output(prompt, self.cfg)
+
+ output_txt = {'shot_pids': shot_pids, 'pid': pid, 'prompt': prompt, 'output': output}
+ f.write(str(output_txt) + '\n')
+ print(self.problem_id, i, j)
+
+ self.problem_id += 1
+
+ def close(self) -> None:
+ self._init_flag = False
+
+    def step(self, action: np.ndarray) -> BaseEnvTimestep:
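+        # The action selects candidate prompt examples: the chosen candidate ids become
+        # the few-shot examples (shot pids) prepended to the current problem's prompt.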
+ shot_pids = [self.cand_pids[cid] for cid in action]
+ pid = self.train_pids[self.problem_id]
+
+ # generate the prompt input
+ prompt = build_prompt(self.problems, shot_pids, pid, self.cfg)
+
+ # get the output from LM
+ if self.enable_replay:
+ output = self.search_answer(pid, shot_pids)
+ elif self.cfg.engine == 'text-davinci-002':
+ output = get_gpt3_output(prompt, self.cfg)
+ elif self.cfg.engine == 'rwkv-7B':
+ output = calc_rwkv(self.model, self.tokenizer, prompt)
+ elif self.cfg.engine == 'internlm-7B':
+ output = calc_internlm(self.model, self.tokenizer, prompt, self.cfg)
+ else:
+ output = self.get_output(prompt)
+
+ # extract the prediction from the output
+ prediction = extract_prediction(output, self.problems[pid]['choices'], self.cfg.option_inds)
+
+ # normalize the number in the text
+ prediction_norm = normalize_answer(prediction, self.problems[pid]['unit'])
+
+ if prediction_norm.lower() == normalize_answer(self.problems[pid]['answer'],
+ self.problems[pid]['unit']).lower():
+ reward = 1
+ self.correct_num += 1
+ else:
+ reward = -1
+
+ self.problem_id += 1
+ if self.problem_id == self.cfg.train_number:
+ done = True
+ info = {'eval_episode_return': self.correct_num / self.cfg.train_number}
+ else:
+ done = False
+ info = {}
+
+ train_sample = create_example_from_pid(pid, self.problems, self.cfg, test=True)
+ obs = {'train_sample': train_sample, 'candidate_samples': self.cand_examples}
+
+ return BaseEnvTimestep(obs, reward, done, info)
+
+ def __repr__(self) -> str:
+ return "DI-engine tabmwp Env"
+
+
+if __name__ == '__main__':
+ from easydict import EasyDict
+ env_cfg = EasyDict(
+ dict(
+ cand_number=16,
+ train_number=20,
+ engine='text-davinci-002',
+ temperature=0.,
+ max_tokens=512,
+ top_p=1.,
+ frequency_penalty=0.,
+ presence_penalty=0.,
+ option_inds=["A", "B", "C", "D", "E", "F"],
+ api_key='xxx',
+ prompt_format='TQ-A',
+ enable_replay=True,
+ seed=0,
+ )
+ )
+ env = TabMWP(env_cfg)
+ env.seed(0)
+ env.reset()
+ env.parse_all_answers()
+ env.search_answer('22976', ['32889', '8044'])
diff --git a/DI-engine/dizoo/tabmwp/envs/test_tabmwp_env.py b/DI-engine/dizoo/tabmwp/envs/test_tabmwp_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..ca9020d971ecbaaac761b100e7420a091e78e869
--- /dev/null
+++ b/DI-engine/dizoo/tabmwp/envs/test_tabmwp_env.py
@@ -0,0 +1,25 @@
+from easydict import EasyDict
+import pytest
+from dizoo.tabmwp.envs.tabmwp_env import TabMWP
+
+
+@pytest.mark.envtest
+class TestTabMWP:
+
+ def test_tabmwp(self):
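+        # Construct the env with a placeholder config; no language-model request is issued in this test.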
+ config = dict(
+ cand_number=20,
+ train_number=100,
+ engine='text-davinci-002',
+ temperature=0.,
+ max_tokens=512,
+ top_p=1.,
+ frequency_penalty=0.,
+ presence_penalty=0.,
+ option_inds=["A", "B", "C", "D", "E", "F"],
+ api_key='',
+ )
+ config = EasyDict(config)
+ env = TabMWP(config)
+ env.seed(0)
+ env.close()
diff --git a/DI-engine/dizoo/tabmwp/envs/utils.py b/DI-engine/dizoo/tabmwp/envs/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..c97c18393528332ac2735ff045786f71573eed7c
--- /dev/null
+++ b/DI-engine/dizoo/tabmwp/envs/utils.py
@@ -0,0 +1,354 @@
+import json
+import os
+import random
+import re
+import time
+from functools import lru_cache
+import torch
+
+import numpy as np
+import openai
+try:
+ import transformers
+except ImportError:
+ import sys
+ from ditk import logging
+    logging.warning("transformers not found, please install it with: pip install transformers")
+ sys.exit(1)
+
+
+def sample_logits(out: torch.Tensor, temperature: float = 1.0, top_p: float = 0.8) -> int:
+ # Sample an action given the logits.
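+    # Top-p (nucleus) filtering: zero out tokens below the cumulative-probability cutoff,
+    # optionally sharpen with temperature, renormalize, then sample a token index.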
+ probs = torch.softmax(out, dim=-1).cpu().numpy()
+ sorted_probs = np.sort(probs)[::-1]
+ cumulative_probs = np.cumsum(sorted_probs)
+ cutoff = float(sorted_probs[np.argmax(cumulative_probs > top_p)])
+ probs[probs < cutoff] = 0
+ if temperature != 1.0:
+        probs = np.power(probs, 1.0 / temperature)
+ probs = probs / np.sum(probs)
+ out = np.random.choice(a=len(probs), p=probs)
+ return out
+
+
+def calc_rwkv(
+ model: transformers.RwkvForCausalLM,
+ tokenizer: transformers.AutoTokenizer,
+ prompt: str,
+ max_len: int = 10
+) -> str:
+ # Use RWKV to generate sentence.
+ orig_len = len(prompt)
+ inputs = tokenizer(prompt, return_tensors="pt").to('cuda')
+ outputs = model(**inputs, labels=inputs["input_ids"])
+ out, state = outputs.logits, outputs.state
+ # Recurrent generation.
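+    # Note: for simplicity, the growing prompt is re-encoded at every step instead of
+    # reusing the returned recurrent state.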
+ with torch.no_grad():
+ for i in range(max_len):
+ token = sample_logits(out[0, -1])
+ tmp = tokenizer.decode([token])
+ prompt = prompt + tmp
+ inputs = tokenizer(prompt, return_tensors="pt").to('cuda')
+ outputs = model(**inputs, labels=inputs["input_ids"])
+ out, state = outputs.logits, outputs.state
+ return prompt[orig_len:]
+
+
+def calc_internlm(model, tokenizer, prompt: str, args):
+ inputs = tokenizer(prompt, return_tensors="pt")
+ for k, v in inputs.items():
+ inputs[k] = v.cuda()
+ gen_kwargs = {
+ "max_length": args.max_tokens,
+ "top_p": args.top_p,
+ "temperature": args.temperature,
+ "do_sample": True,
+ "repetition_penalty": args.frequency_penalty
+ }
+ output = model.generate(**inputs, **gen_kwargs)
+    output = tokenizer.decode(output[0])
+ return output
+
+
+def load_data(args: dict) -> tuple:
+ # Load tabmwp dataset.
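+    # Download problems_train.json if it is missing, then randomly split the problem ids
+    # into training ids and candidate (few-shot example) ids.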
+ random.seed(args.seed)
+ data_root = 'dizoo/tabmwp/data'
+
+ if not os.path.exists(data_root):
+ os.mkdir(data_root)
+
+ if not os.path.exists(os.path.join(data_root, f'problems_train.json')):
+ os.system(
+ f'wget https://opendilab.net/download/DI-zoo/tabmwp/problems_train.json -O ' +
+ os.path.join(data_root, f'problems_train.json') + ' --no-check-certificate'
+ )
+ problems = json.load(open(os.path.join(data_root, f'problems_train.json')))
+
+ pids = list(problems.keys())
+ samples = random.sample(pids, args.train_number + args.cand_number) # random sample
+ train_pids = samples[:args.train_number]
+ cand_pids = samples[args.train_number:]
+ return problems, cand_pids, train_pids
+
+
+def get_gpt3_output(prompt: str, args: dict) -> str:
+ return call_gpt3(
+ args.engine, prompt, args.temperature, args.max_tokens, args.top_p, args.frequency_penalty,
+ args.presence_penalty
+ )
+
+
+@lru_cache(maxsize=10000)
+def call_gpt3(
+ engine: str, prompt: str, temperature: float, max_tokens: int, top_p: float, frequency_penalty: float,
+ presence_penalty: float
+) -> str:
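+    # Retry on transient API errors with limited patience; identical requests are
+    # memoized by the lru_cache decorator above.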
+ patience = 100
+ while True:
+ try:
+ response = openai.Completion.create(
+ engine=engine,
+ prompt=prompt,
+ temperature=temperature,
+ max_tokens=max_tokens,
+ top_p=top_p,
+ frequency_penalty=frequency_penalty,
+ presence_penalty=presence_penalty,
+ stop=["\n"]
+ )
+ output = response["choices"][0]["text"].strip()
+ break
+        except Exception:
+            patience -= 1
+            if patience <= 0:
+                print("!!! running out of patience waiting for OpenAI")
+                output = ""
+                break
+            time.sleep(0.1)
+    return output
+
+
+def get_table_text(problem: dict) -> str:
+ table = problem['table']
+ title = problem['table_title']
+ if title and len(title) > 0:
+ table = f"[TITLE]: {title}\n{table}"
+ return table
+
+
+def get_question_text(problem: dict, option_inds: list) -> str:
+ question = problem['question']
+
+ unit = problem['unit']
+ if unit and len(unit) > 0:
+ question = f"{question} (Unit: {unit})"
+
+ choices = problem['choices']
+ if choices and len(choices) > 0:
+ choice_list = []
+ for i, c in enumerate(choices):
+ choice_list.append("({}) {}".format(option_inds[i], c))
+ options = " ".join(choice_list)
+ question = f"{question}\nOptions: {options}"
+
+ return question
+
+
+def get_answer(problem: dict) -> str:
+ return problem['answer']
+
+
+def get_solution_text(problem: dict) -> str:
+ # GPT-3 can generate the solution with more tokens
+ solution = problem['solution'].replace("\n", "\\n")
+ return solution
+
+
+def create_one_example(
+ format: str, table: str, question: str, answer: str, solution: str, test_example: bool = True
+) -> str:
+ # Using template to generate one prompt example.
+ input_format, output_format = format.split("-") # e.g., "TQ-A"
+
+ elements = {
+ "Q": f"Question: {question}",
+ "T": f"Table: {table}",
+ "S": f"Solution: {solution}",
+ "A": f"Answer: The answer is {answer}.",
+ "AS": f"Answer: The answer is {answer}. BECAUSE: {solution}",
+ "SA": f"Answer: {solution} The answer is {answer}."
+ }
+
+ # Input
+ input = "\n".join(elements[label] for label in input_format)
+
+ # Output
+ if test_example:
+ output = "Answer:"
+ else:
+ output = elements[output_format]
+
+ # Prompt text
+ text = input + "\n" + output
+    text = text.replace("  ", " ").strip()  # collapse double spaces
+
+ return text
+
+
+def build_prompt(problems: dict, shot_pids: list, test_pid: str, args: dict) -> str:
+ # Given ids, generate the complete prompt. That is, the input to LM.
+ examples = []
+ pids = shot_pids + [test_pid]
+
+ # n-shot training examples
+ for pid in pids:
+ problem = problems[pid]
+ table = get_table_text(problem)
+ question = get_question_text(problem, args.option_inds)
+ answer = get_answer(problem)
+ solution = get_solution_text(problems[pid])
+
+ if pid == test_pid:
+ assert pid not in shot_pids
+ example = create_one_example(args.prompt_format, table, question, answer, solution, test_example=True)
+ else:
+ example = create_one_example(args.prompt_format, table, question, answer, solution, test_example=False)
+
+ examples.append(example)
+
+ # create the prompt input
+ prompt_input = '\n\n'.join(examples)
+
+ return prompt_input
+
+
+def extract_prediction(output: str, options: list, option_inds: list) -> str:
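+    # Keep only the first line of the LM output; if it contains an '=', keep the text after it.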
+ idx = output.find('\n')
+ if idx > 0:
+ output = output[:idx]
+ idx = output.find('=')
+ if idx > 0:
+ output = output[idx + 1:].strip()
+ # $\\frac{16}{95}$ -> 16/95
+ output = re.sub(r"\$?\\frac\{([\d\.\,\-]+)\}\{([\d\.\,]+)\}\$?", r"\1/\2", output)
+
+ output = re.sub(r"(? 0:
+ pred = res[0].upper() # e.g., "B"
+ if pred in option_inds:
+ ind = option_inds.index(pred) # 1
+ if ind >= len(options):
+ ind = random.choice(range(len(options)))
+ predition = options[ind]
+ return predition
+
+ # find the most similar options
+ scores = [score_string_similarity(x, output) for x in options]
+ max_idx = int(np.argmax(scores)) # json does not recognize NumPy data types
+        prediction = options[max_idx]
+        return prediction
+
+ else:
+ # free_text QA problems, numeric answer
+ patterns = [
+ # r'^\([A-Za-z]\) ([\s\S]+)$', # "(A) XXXXX"
+ # r'[Th]he answer is \([A-Za-z]\) ([\s\S]+)$', # "The answer is (B) XXXXX."
+ r'[Th]he answer is ([\s\S]+)$', # "The answer is XXXXX.",
+ r'[Th]he table shows that ([\d\$\.\,\/\:]+) ',
+ r' = ([\d\$\.\,\/\:]+)', # "= $1.40"
+ r'(?<= be| is) ([\-\d\$\.\,\/\:]{0,}[\d]+)', # "will be $1.40"
+ r'(?<= are| was) ([\-\d\$\.\,\/\:]{0,}[\d]+)', # "are $1.40"
+ r'(?<= were) ([\-\d\$\.\,\/\:]{0,}[\d]+)', # "are $1.40"
+ r' ([\d\$\.\,\/\:]+ [AP]\.M\.)', # 7:25 P.M.
+ r'([\-\d\$\.\,\/\:]{0,}[\d]+)', # 14.5
+ ]
+
+ for p in patterns:
+ pattern = re.compile(p)
+ res = pattern.findall(output)
+ if len(res) > 0:
+                prediction = res[-1].strip()
+                if prediction.endswith(".") and ".M." not in prediction:
+                    prediction = prediction[:-1]
+                return prediction
+
+ return output
+
+
+def normalize_answer(text: str, unit: str) -> str:
+ # ["1,000", "123", "3/4", "56.456", "$56.4", "-3", "-10.02", "-3/2"]
+
+ text = re.sub("^[\$]", "", text)
+ text = re.sub("[\,\.\,\/]$", "", text)
+ result = re.match("^[-+]?[\d,./]+$", text)
+
+ if result is not None:
+ # is number?
+ text = text.replace(",", "")
+ result = re.match("[-+]?\d+$", text)
+ try:
+ if result is not None:
+ number = int(text)
+ elif "/" in text:
+ nums = text.split("/")
+ number = round(float(nums[0]) / float(nums[1]), 3)
+ else:
+ number = round(float(text), 3)
+ number = str(number)
+ number = re.sub(r"\.[0]+$", "", number)
+ return number
+ except:
+ return text
+ else:
+ # is text
+ if unit:
+ text = text.replace(unit, "").strip()
+ return text
+
+
+def score_string_similarity(str1: str, str2: str) -> float:
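+    # Word-overlap similarity used to map free-form LM output back onto an answer option;
+    # an exact match scores highest.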
+ if str1 == str2:
+ return 2.0
+ if " " in str1 or " " in str2:
+ str1_split = str1.split(" ")
+ str2_split = str2.split(" ")
+ overlap = list(set(str1_split) & set(str2_split))
+ return len(overlap) / max(len(str1_split), len(str2_split))
+ else:
+ if str1 == str2:
+ return 1.0
+ else:
+ return 0.0
+
+
+def create_example_from_pid(pid: str, problems: dict, args: dict, test: bool = False) -> str:
+ problem = problems[pid]
+ table = get_table_text(problem)
+ question = get_question_text(problem, args.option_inds)
+ answer = get_answer(problem)
+ solution = get_solution_text(problems[pid])
+
+ if test:
+ example = create_one_example(args.prompt_format, table, question, answer, solution, test_example=True)
+ else:
+ example = create_one_example(args.prompt_format, table, question, answer, solution, test_example=False)
+
+ return example
diff --git a/DI-engine/docker/Dockerfile.base b/DI-engine/docker/Dockerfile.base
new file mode 100644
index 0000000000000000000000000000000000000000..6e5599040c79de3f5ec39b5f081ae3b168a55096
--- /dev/null
+++ b/DI-engine/docker/Dockerfile.base
@@ -0,0 +1,59 @@
+FROM pytorch/pytorch:1.12.1-cuda11.3-cudnn8-runtime as base
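+# Example build command (the image tag below is illustrative):
+#   docker build -f docker/Dockerfile.base --target base -t opendilab/ding:nightly .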
+
+WORKDIR /ding
+
+RUN apt update \
+ && apt install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl git vim gcc \g++ make wget locales dnsutils zip unzip cmake -y \
+ && apt clean \
+ && rm -rf /var/cache/apt/* \
+ && sed -i '/en_US.UTF-8/s/^# //g' /etc/locale.gen \
+ && locale-gen
+
+ENV LANG en_US.UTF-8
+ENV LANGUAGE en_US:UTF-8
+ENV LC_ALL en_US.UTF-8
+
+ADD setup.py setup.py
+ADD dizoo dizoo
+ADD ding ding
+ADD README.md README.md
+
+RUN python3 -m pip install --upgrade pip \
+ && python3 -m pip install --ignore-installed 'PyYAML<6.0' \
+ && python3 -m pip install --no-cache-dir .[fast,test]
+
+FROM ubuntu:20.04 as doc
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+WORKDIR /ding
+
+RUN apt-get update && \
+ apt-get install --no-install-recommends -y \
+ python3.8 python3-pip python3.8-dev
+
+RUN apt update \
+ && apt install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl git vim gcc \g++ make wget locales dnsutils zip unzip cmake -y \
+ && apt clean \
+ && rm -rf /var/cache/apt/* \
+ && sed -i '/en_US.UTF-8/s/^# //g' /etc/locale.gen \
+ && locale-gen
+
+ENV LANG en_US.UTF-8
+ENV LANGUAGE en_US:UTF-8
+ENV LC_ALL en_US.UTF-8
+
+ADD setup.py setup.py
+ADD dizoo dizoo
+ADD ding ding
+ADD README.md README.md
+
+RUN python3 -m pip install --upgrade pip \
+ && python3 -m pip install --ignore-installed 'PyYAML<6.0' \
+ && python3 -m pip install --no-cache-dir .[fast]
+
+WORKDIR /ding_doc
+
+RUN git clone -b main https://github.com/opendilab/DI-engine-docs.git \
+ && cd DI-engine-docs \
+ && python3 -m pip install -r requirements.txt
diff --git a/DI-engine/docker/Dockerfile.env b/DI-engine/docker/Dockerfile.env
new file mode 100644
index 0000000000000000000000000000000000000000..dbf89c7f3e04d8dab47a4c2203f2723f839cb6f1
--- /dev/null
+++ b/DI-engine/docker/Dockerfile.env
@@ -0,0 +1,147 @@
+FROM opendilab/ding:nightly as atari
+
+WORKDIR /ding
+
+RUN python3 -m pip install --upgrade pip \
+ && python3 -m pip install --no-cache-dir .[common_env] \
+ && pip install autorom \
+ && AutoROM --accept-license
+
+FROM opendilab/ding:nightly as mujoco
+
+WORKDIR /ding
+
+RUN apt update \
+ && apt install -y \
+ build-essential \
+ libgl1-mesa-dev \
+ libgl1-mesa-glx \
+ libglew-dev \
+ libosmesa6-dev \
+ libglfw3 \
+ libglfw3-dev \
+ libsdl2-dev \
+ libsdl2-image-dev \
+ libglm-dev \
+ libfreetype6-dev \
+ patchelf
+
+RUN mkdir -p /root/.mujoco \
+ && wget https://mujoco.org/download/mujoco210-linux-x86_64.tar.gz -O mujoco.tar.gz \
+ && tar -xf mujoco.tar.gz -C /root/.mujoco \
+ && rm mujoco.tar.gz \
+ && echo "export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:/root/.mujoco/mjpro210/bin:/root/.mujoco/mujoco210/bin" >> /root/.bashrc
+
+ENV LD_LIBRARY_PATH /root/.mujoco/mjpro210/bin:/root/.mujoco/mujoco210/bin:${LD_LIBRARY_PATH}
+
+RUN python3 -m pip install --upgrade pip \
+ && pip3 install "cython<3" \
+ && pip3 install --no-cache-dir numpy \
+ && pip3 install --no-cache-dir -U "gym[mujoco,mujoco_py]==0.25.1" --user \
+ && pip install gymnasium[mujoco] \
+ && python -c "import mujoco_py"
+
+FROM opendilab/di-star:latest as smac
+
+WORKDIR /ding
+
+ADD setup.py setup.py
+ADD dizoo dizoo
+ADD ding ding
+ADD README.md README.md
+
+RUN python3 -m pip install --upgrade pip \
+ && python3 -m pip install --no-cache-dir .[fast]
+
+ENV SC2PATH=/root/StarCraftII_4.10.0
+
+FROM opendilab/ding:nightly as grf
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+WORKDIR /ding
+
+RUN apt-get update && apt-get install git build-essential libgl1-mesa-dev libsdl2-dev \
+ libsdl2-image-dev libsdl2-ttf-dev libsdl2-gfx-dev libboost-all-dev \
+ libdirectfb-dev libst-dev mesa-utils xvfb x11vnc -y \
+ && apt clean \
+ && rm -rf /var/cache/apt/*
+
+RUN python3 -m pip install --upgrade pip setuptools psutil wheel \
+ && python3 -m pip install --no-cache-dir gfootball
+
+FROM opendilab/ding:nightly as dmc2gym
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+WORKDIR /ding
+
+RUN apt-get update && apt-get install glew-utils freeglut3 freeglut3-dev libosmesa6 wget zip ffmpeg -y
+
+ENV MUJOCO_GL "egl"
+
+RUN wget https://codeload.github.com/denisyarats/dmc2gym/zip/refs/heads/master -O dmc2gym-master.zip \
+ && unzip dmc2gym-master.zip \
+ && python3 -m pip install --no-cache-dir ./dmc2gym-master/ \
+ && rm -rf dmc2gym-master \
+ && rm dmc2gym-master.zip
+
+FROM opendilab/ding:nightly-mujoco as metaworld
+
+WORKDIR /ding
+
+RUN mkdir tempfile \
+ && cd tempfile \
+ && python3 -m pip install --no-cache-dir git+https://github.com/Farama-Foundation/Metaworld.git@b2a4cbb98e20081412cb4cc7ae3d4afc456a732a \
+ && cd .. \
+ && rm -rf tempfile
+
+RUN apt-get install xvfb ffmpeg -y \
+ && rm -rf /opt/conda/bin/ffmpeg \
+ && ln -s /usr/bin/ffmpeg /opt/conda/bin/ffmpeg
+
+FROM opendilab/ding:nightly as cityflow
+
+WORKDIR /ding
+
+RUN apt update \
+ && apt install -y \
+ build-essential
+
+RUN mkdir -p /root/.cityflow \
+ && cd /root/.cityflow \
+ && git clone https://github.com/cityflow-project/CityFlow \
+ && cd CityFlow \
+ && pip install -e .
+
+RUN mkdir -p /root/.smartcross \
+ && cd /root/.smartcross \
+ && git clone https://github.com/opendilab/DI-smartcross \
+ && cd DI-smartcross \
+ && pip install -e .
+
+
+FROM opendilab/ding:nightly as evogym
+
+WORKDIR /ding
+
+RUN apt update \
+ && apt install -y \
+ build-essential libglew-dev libglu1-mesa-dev xorg-dev
+
+RUN mkdir -p /root/.evogym \
+ && cd /root/.evogym \
+ && git clone --recurse-submodules https://github.com/PaParaZz1/evogym.git \
+ && cd evogym \
+ && pip3 install -r requirements.txt
+
+RUN cd /root/.evogym/evogym && python3 setup.py install
+
+FROM opendilab/ding:nightly-mujoco as d4rl
+
+WORKDIR /ding
+
+RUN git clone https://github.com/PaParaZz1/D4RL.git
+
+RUN cd D4RL \
+ && pip install -e .
diff --git a/DI-engine/docker/Dockerfile.hpc b/DI-engine/docker/Dockerfile.hpc
new file mode 100644
index 0000000000000000000000000000000000000000..cf432fc8008219b6b2d17834f8a16ee63811893f
--- /dev/null
+++ b/DI-engine/docker/Dockerfile.hpc
@@ -0,0 +1,77 @@
+FROM opendilab/di-hpc:develop as ding-hpc-develop
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+WORKDIR /ding
+
+RUN apt update \
+ && apt install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl git vim gcc \g++ make locales -y \
+ && apt clean \
+ && rm -rf /var/cache/apt/* \
+ && sed -i '/en_US.UTF-8/s/^# //g' /etc/locale.gen \
+ && locale-gen
+
+ENV LANG en_US.UTF-8
+ENV LANGUAGE en_US:UTF-8
+ENV LC_ALL en_US.UTF-8
+ENV ENABLE_DI_HPC true
+
+ADD setup.py setup.py
+ADD dizoo dizoo
+ADD ding ding
+
+RUN python3 -m pip install --upgrade pip \
+ && python3 -m pip install --no-cache-dir 'PyYAML<6.0' \
+ && python3 -m pip install --no-cache-dir .[fast]
+
+FROM opendilab/di-hpc:runtime as ding-hpc-runtime
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+WORKDIR /ding
+
+RUN apt update \
+ && apt install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl git vim gcc \g++ make locales -y \
+ && apt clean \
+ && rm -rf /var/cache/apt/* \
+ && sed -i '/en_US.UTF-8/s/^# //g' /etc/locale.gen \
+ && locale-gen
+
+ENV LANG en_US.UTF-8
+ENV LANGUAGE en_US:UTF-8
+ENV LC_ALL en_US.UTF-8
+ENV ENABLE_DI_HPC true
+
+ADD setup.py setup.py
+ADD dizoo dizoo
+ADD ding ding
+
+RUN python3 -m pip install --upgrade pip \
+ && python3 -m pip install --no-cache-dir 'PyYAML<6.0' \
+ && python3 -m pip install --no-cache-dir .[fast]
+
+FROM opendilab/di-hpc:nightly as ding-hpc
+
+ENV DEBIAN_FRONTEND=noninteractive
+
+WORKDIR /ding
+
+RUN apt update \
+ && apt install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl git vim gcc \g++ make locales -y \
+ && apt clean \
+ && rm -rf /var/cache/apt/* \
+ && sed -i '/en_US.UTF-8/s/^# //g' /etc/locale.gen \
+ && locale-gen
+
+ENV LANG en_US.UTF-8
+ENV LANGUAGE en_US:UTF-8
+ENV LC_ALL en_US.UTF-8
+ENV ENABLE_DI_HPC true
+
+ADD setup.py setup.py
+ADD dizoo dizoo
+ADD ding ding
+
+RUN python3 -m pip install --upgrade pip \
+ && python3 -m pip install --no-cache-dir 'PyYAML<6.0' \
+ && python3 -m pip install --no-cache-dir .[fast]
diff --git a/DI-engine/docker/Dockerfile.rpc b/DI-engine/docker/Dockerfile.rpc
new file mode 100644
index 0000000000000000000000000000000000000000..b9e9496548ae6ce547f318fecc37c840ce1fe792
--- /dev/null
+++ b/DI-engine/docker/Dockerfile.rpc
@@ -0,0 +1,23 @@
+FROM snsao/pytorch:tensorpipe-fix as base
+
+WORKDIR /ding
+
+RUN apt update \
+ && apt install libgl1-mesa-glx libglib2.0-0 libsm6 libxext6 libxrender-dev swig curl git vim gcc \g++ make wget locales dnsutils -y \
+ && apt clean \
+ && rm -rf /var/cache/apt/* \
+ && sed -i '/en_US.UTF-8/s/^# //g' /etc/locale.gen \
+ && locale-gen
+
+ENV LANG en_US.UTF-8
+ENV LANGUAGE en_US:UTF-8
+ENV LC_ALL en_US.UTF-8
+
+ADD setup.py setup.py
+ADD dizoo dizoo
+ADD ding ding
+ADD README.md README.md
+
+RUN python3 -m pip install --upgrade pip \
+ && python3 -m pip install --ignore-installed 'PyYAML<6.0' \
+ && python3 -m pip install --no-cache-dir .[fast,test]
diff --git a/DI-engine/format.sh b/DI-engine/format.sh
new file mode 100755
index 0000000000000000000000000000000000000000..506ac0243aa6ea14c3c3093077624a27897fe1a9
--- /dev/null
+++ b/DI-engine/format.sh
@@ -0,0 +1,20 @@
+#!/usr/bin/env bash
+# Usage: at the root dir >> bash scripts/format.sh .
+
+# Check yapf version. (20200318 latest is 0.29.0. Format might be changed in future version.)
+ver=$(yapf --version)
+if ! echo $ver | grep -q 0.29.0; then
+ echo "Wrong YAPF version installed: 0.29.0 is required, not $ver. $YAPF_DOWNLOAD_COMMAND_MSG"
+ exit 1
+fi
+
+yapf --in-place --recursive -p --verbose --style .style.yapf $1
+
+if [[ "$2" == '--test' ]]; then # Only for CI usage, user should not use --test flag.
+ if ! git diff --quiet &>/dev/null; then
+ echo '*** You have not reformatted your codes! Please run [bash format.sh] at root directory before commit! Thanks! ***'
+ exit 1
+ else
+ echo "Code style test passed!"
+ fi
+fi
diff --git a/DI-engine/pytest.ini b/DI-engine/pytest.ini
new file mode 100644
index 0000000000000000000000000000000000000000..efdeaba0237ce5d813916813d06fb54a28495ca3
--- /dev/null
+++ b/DI-engine/pytest.ini
@@ -0,0 +1,14 @@
+[pytest]
+execution_timeout = 600
+markers =
+ unittest
+ platformtest
+ envtest
+ cudatest
+ algotest
+ benchmark
+ envpooltest
+ other
+ tmp
+
+norecursedirs = ding/hpc_rl/tests
diff --git a/DI-engine/setup.py b/DI-engine/setup.py
new file mode 100644
index 0000000000000000000000000000000000000000..96ac6f153d9fc2c1ccc112a4aaa61d994473f592
--- /dev/null
+++ b/DI-engine/setup.py
@@ -0,0 +1,192 @@
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS-IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Module setuptools script."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import os
+
+from setuptools import setup, find_packages
+from importlib import import_module
+
+here = os.path.abspath(os.path.dirname(__file__))
+meta_module = import_module('ding')
+meta = meta_module.__dict__
+with open('README.md', mode='r', encoding='utf-8') as f:
+ readme = f.read()
+
+setup(
+ name=meta['__TITLE__'],
+ version=meta['__VERSION__'],
+ description=meta['__DESCRIPTION__'],
+ long_description=readme,
+ long_description_content_type='text/markdown',
+ author=meta['__AUTHOR__'],
+ author_email=meta['__AUTHOR_EMAIL__'],
+ url='https://github.com/opendilab/DI-engine',
+ license='Apache License, Version 2.0',
+ keywords='Decision AI Engine',
+ packages=[
+ # framework
+ *find_packages(include=('ding', "ding.*")),
+ # application
+        *find_packages(include=('dizoo',
+                                'dizoo.*')),
+ ],
+ package_data={
+ package_name: ['*.yaml', '*.xml', '*cfg', '*SC2Map']
+ for package_name in find_packages(include=('ding.*'))
+ },
+ python_requires=">=3.7",
+ install_requires=[
+ 'setuptools<=66.1.1',
+ 'yapf==0.29.0',
+ 'gym==0.25.1', # pypy incompatible; some environments only support gym==0.22.0
+ 'gymnasium',
+ 'torch>=1.1.0',
+ 'numpy>=1.18.0',
+ 'DI-treetensor>=0.4.0',
+ 'DI-toolkit>=0.1.0',
+ 'trueskill',
+ 'tensorboardX>=2.2',
+ 'wandb',
+ 'matplotlib',
+ 'easydict==1.9',
+ 'pyyaml',
+ 'enum_tools',
+ 'cloudpickle',
+ 'hickle',
+ 'tabulate',
+ 'click>=7.0.0',
+ 'requests>=2.25.1', # interaction
+ 'flask~=1.1.2', # interaction
+ 'responses~=0.12.1', # interaction
+ 'URLObject>=2.4.0', # interaction
+ 'MarkupSafe==2.0.1', # interaction, compatibility
+ 'pynng', # parallel
+ 'sniffio', # parallel
+ 'redis', # parallel
+ 'mpire>=2.3.5', # parallel
+ ],
+ extras_require={
+ 'test': [
+ 'coverage>=5,<=7.0.1',
+ 'mock>=4.0.3',
+ 'pytest~=7.0.1', # required by gym>=0.25.0
+ 'pytest-cov~=3.0.0',
+ 'pytest-mock~=3.6.1',
+ 'pytest-xdist>=1.34.0',
+ 'pytest-rerunfailures~=10.2',
+ 'pytest-timeout~=2.0.2',
+ 'readerwriterlock',
+ 'pandas',
+ 'lz4',
+ 'h5py',
+ 'scipy',
+ 'scikit-learn',
+ 'gym[box2d]==0.25.1',
+ 'pettingzoo<=1.22.3',
+ 'opencv-python', # pypy incompatible
+ ],
+ 'style': [
+ 'yapf==0.29.0',
+ 'flake8<=3.9.2',
+ 'importlib-metadata<5.0.0', # compatibility
+ ],
+ 'fast': [
+ 'numpy-stl',
+ 'numba>=0.53.0',
+ ],
+ 'video': [
+ 'moviepy',
+ 'imageio[ffmpeg]',
+ ],
+ 'dist': [
+ 'redis-py-cluster==2.1.0',
+ ],
+ 'common_env': [
+ 'ale-py', # >=0.7.5', # atari
+ 'autorom',
+ 'gym[all]==0.25.1',
+ 'cmake>=3.18.4',
+ 'opencv-python', # pypy incompatible
+ ],
+ 'gfootball_env': [
+ 'gfootball',
+ 'kaggle-environments',
+ ],
+ 'procgen_env': [
+ 'procgen',
+ ],
+ 'bsuite_env': [
+ 'bsuite',
+ ],
+ 'minigrid_env': [
+ 'minigrid>=2.0.0',
+ ],
+ # 'd4rl_env': [
+ # 'd4rl @ git+https://github.com/rail-berkeley/d4rl@master#egg=d4rl',
+ # ],
+ # 'pybulletgym_env': [
+ # 'pybulletgym @ git+https://github.com/benelot/pybullet-gym@master#egg=pybulletgym',
+ # ],
+ # 'gym_hybrid_env': [
+ # 'gym-hybrid @ git+https://github.com/thomashirtz/gym-hybrid@master#egg=gym-hybrid',
+ # ],
+
+ # 'gobigger_env': [
+ # 'gobigger @ git+https://github.com/opendilab/GoBigger@main#egg=gobigger',
+ # ],
+ # 'gym_soccer_env': [
+ # 'gym-soccer @ git+https://github.com/LikeJulia/gym-soccer@dev-install-packages#egg=gym-soccer',
+ # ],
+ 'slimevolleygym_env': [
+ 'slimevolleygym',
+ ],
+ 'smac_env': [
+ 'pysc2',
+ ],
+ 'k8s': [
+ 'kubernetes',
+ ],
+ 'envpool': [
+ 'envpool',
+ ],
+ # 'dmc2gym': [
+ # 'dmc2gym @ git+https://github.com/denisyarats/dmc2gym@master#egg=dmc2gym',
+ # ],
+ # 'rocket_recycling': [
+ # 'rocket_recycling @ git+https://github.com/nighood/rocket-recycling@master#egg=rocket_recycling',
+ # ],
+ 'sokoban': [
+ 'gym-sokoban',
+ ],
+ 'mario': [
+ 'gym-super-mario-bros>=7.3.0',
+ ],
+ },
+ entry_points={'console_scripts': ['ding=ding.entry.cli:cli', 'ditask=ding.entry.cli_ditask:cli_ditask']},
+ classifiers=[
+ 'Development Status :: 5 - Production/Stable',
+ "Intended Audience :: Science/Research",
+ 'License :: OSI Approved :: Apache Software License',
+ 'Operating System :: POSIX :: Linux',
+ 'Operating System :: Microsoft :: Windows',
+ 'Operating System :: MacOS :: MacOS X',
+ 'Programming Language :: Python :: 3.7',
+ 'Programming Language :: Python :: 3.8',
+ 'Programming Language :: Python :: 3.9',
+ 'Topic :: Scientific/Engineering :: Artificial Intelligence',
+ ],
+)
diff --git a/LightZero b/LightZero
deleted file mode 160000
index 3d338ae891b54c955f34b90be1de1a0a14f56477..0000000000000000000000000000000000000000
--- a/LightZero
+++ /dev/null
@@ -1 +0,0 @@
-Subproject commit 3d338ae891b54c955f34b90be1de1a0a14f56477
diff --git a/LightZero/.coveragerc b/LightZero/.coveragerc
new file mode 100644
index 0000000000000000000000000000000000000000..d9a48b4bbc10773784afed7f465458ec5dc9a7c3
--- /dev/null
+++ b/LightZero/.coveragerc
@@ -0,0 +1,2 @@
+[run]
+plugins = Cython.Coverage
diff --git a/LightZero/.gitignore b/LightZero/.gitignore
new file mode 100644
index 0000000000000000000000000000000000000000..c164950f6dc759d1a13ed2d2f3dc5c9bb60fdd9d
--- /dev/null
+++ b/LightZero/.gitignore
@@ -0,0 +1,1447 @@
+# Created by .ignore support plugin (hsz.mobi)
+### ArchLinuxPackages template
+*.tar
+*.tar.*
+*.jar
+*.exe
+*.msi
+*.zip
+*.tgz
+*.log
+*.log.*
+*.sig
+*.mov
+*.pkl
+data_*
+*.so
+*.gv
+*.png
+*.csv
+
+pkg/
+src/
+
+### CVS template
+/CVS/*
+**/CVS/*
+.cvsignore
+*/.cvsignore
+
+### LibreOffice template
+# LibreOffice locks
+.~lock.*#
+
+### CUDA template
+*.i
+*.ii
+*.gpu
+*.ptx
+*.cubin
+*.fatbin
+
+### Eclipse template
+*.bin
+.metadata
+bin/
+bc/
+*.tmp
+*.bak
+*.swp
+*~.nib
+local.properties
+.settings/
+.loadpath
+.recommenders
+
+# External tool builders
+.externalToolBuilders/
+
+# Locally stored "Eclipse launch configurations"
+*.launch
+
+# PyDev specific (Python IDE for Eclipse)
+*.pydevproject
+
+# CDT-specific (C/C++ Development Tooling)
+.cproject
+
+# CDT- autotools
+.autotools
+
+# Java annotation processor (APT)
+.factorypath
+
+# PDT-specific (PHP Development Tools)
+.buildpath
+
+# sbteclipse plugin
+.target
+
+# Tern plugin
+.tern-project
+
+# TeXlipse plugin
+.texlipse
+
+# STS (Spring Tool Suite)
+.springBeans
+
+# Code Recommenders
+.recommenders/
+
+# Annotation Processing
+.apt_generated/
+.apt_generated_test/
+
+# Scala IDE specific (Scala & Java development for Eclipse)
+.cache-main
+.scala_dependencies
+.worksheet
+
+# Uncomment this line if you wish to ignore the project description file.
+# Typically, this file would be tracked if it contains build/dependency configurations:
+#.project
+
+### SVN template
+.svn/
+
+### Images template
+# JPEG
+*.jpg
+*.jpeg
+*.jpe
+*.jif
+*.jfif
+*.jfi
+
+# JPEG 2000
+*.jp2
+*.j2k
+*.jpf
+*.jpx
+*.jpm
+*.mj2
+
+# JPEG XR
+*.jxr
+*.hdp
+*.wdp
+
+# Graphics Interchange Format
+*.gif
+
+# RAW
+*.raw
+
+# Web P
+*.webp
+
+# Portable Network Graphics
+#*.png
+
+# Animated Portable Network Graphics
+*.apng
+
+# Multiple-image Network Graphics
+*.mng
+
+# Tagged Image File Format
+*.tiff
+*.tif
+
+# Scalable Vector Graphics
+*.svg
+*.svgz
+
+# Portable Document Format
+*.pdf
+
+# X BitMap
+*.xbm
+
+# BMP
+*.bmp
+*.dib
+
+# ICO
+*.ico
+
+# 3D Images
+*.3dm
+*.max
+
+### Diff template
+*.patch
+*.diff
+
+### JetBrains template
+# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
+# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
+
+# User-specific stuff
+.idea/**/workspace.xml
+.idea/**/tasks.xml
+.idea/**/usage.statistics.xml
+.idea/**/dictionaries
+.idea/**/shelf
+
+# Generated files
+.idea/**/contentModel.xml
+
+# Sensitive or high-churn files
+.idea/**/dataSources/
+.idea/**/dataSources.ids
+.idea/**/dataSources.local.xml
+.idea/**/sqlDataSources.xml
+.idea/**/dynamic.xml
+.idea/**/uiDesigner.xml
+.idea/**/dbnavigator.xml
+
+# Gradle
+.idea/**/gradle.xml
+.idea/**/libraries
+
+# Gradle and Maven with auto-import
+# When using Gradle or Maven with auto-import, you should exclude module files,
+# since they will be recreated, and may cause churn. Uncomment if using
+# auto-import.
+# .idea/artifacts
+# .idea/compiler.xml
+# .idea/jarRepositories.xml
+# .idea/modules.xml
+# .idea/*.iml
+# .idea/modules
+# *.iml
+# *.ipr
+
+# CMake
+cmake-build-*/
+
+# Mongo Explorer plugin
+.idea/**/mongoSettings.xml
+
+# File-based project format
+*.iws
+
+# IntelliJ
+out/
+
+# mpeltonen/sbt-idea plugin
+.idea_modules/
+
+# JIRA plugin
+atlassian-ide-plugin.xml
+
+# Cursive Clojure plugin
+.idea/replstate.xml
+
+# Crashlytics plugin (for Android Studio and IntelliJ)
+com_crashlytics_export_strings.xml
+crashlytics.properties
+crashlytics-build.properties
+fabric.properties
+
+# Editor-based Rest Client
+.idea/httpRequests
+
+# Android studio 3.1+ serialized cache file
+.idea/caches/build_file_checksums.ser
+
+### CodeIgniter template
+*/config/development
+*/logs/log-*.php
+!*/logs/index.html
+*/cache/*
+!*/cache/index.html
+!*/cache/.htaccess
+
+user_guide_src/build/*
+user_guide_src/cilexer/build/*
+user_guide_src/cilexer/dist/*
+user_guide_src/cilexer/pycilexer.egg-info/*
+
+#codeigniter 3
+application/logs/*
+!application/logs/index.html
+!application/logs/.htaccess
+/vendor/
+
+### Emacs template
+# -*- mode: gitignore; -*-
+*~
+\#*\#
+/.emacs.desktop
+/.emacs.desktop.lock
+*.elc
+auto-save-list
+tramp
+.\#*
+
+# Org-mode
+.org-id-locations
+*_archive
+
+# flymake-mode
+*_flymake.*
+
+# eshell files
+/eshell/history
+/eshell/lastdir
+
+# elpa packages
+/elpa/
+
+# reftex files
+*.rel
+
+# AUCTeX auto folder
+/auto/
+
+# cask packages
+.cask/
+dist/
+
+# Flycheck
+flycheck_*.el
+
+# server auth directory
+/server/
+
+# projectiles files
+.projectile
+
+# directory configuration
+.dir-locals.el
+
+# network security
+/network-security.data
+
+
+### Windows template
+# Windows thumbnail cache files
+Thumbs.db
+Thumbs.db:encryptable
+ehthumbs.db
+ehthumbs_vista.db
+
+# Dump file
+*.stackdump
+
+# Folder config file
+[Dd]esktop.ini
+
+# Recycle Bin used on file shares
+$RECYCLE.BIN/
+
+# Windows Installer files
+*.cab
+*.msix
+*.msm
+*.msp
+
+# Windows shortcuts
+*.lnk
+
+### VisualStudioCode template
+.vscode/*
+!.vscode/settings.json
+!.vscode/tasks.json
+!.vscode/launch.json
+!.vscode/extensions.json
+*.code-workspace
+
+# Local History for Visual Studio Code
+.history/
+
+### CMake template
+CMakeLists.txt.user
+CMakeCache.txt
+CMakeFiles
+CMakeScripts
+Testing
+cmake_install.cmake
+install_manifest.txt
+compile_commands.json
+CTestTestfile.cmake
+_deps
+
+### VisualStudio template
+## Ignore Visual Studio temporary files, build results, and
+## files generated by popular Visual Studio add-ons.
+##
+## Get latest from https://github.com/github/gitignore/blob/master/VisualStudio.gitignore
+
+# User-specific files
+*.rsuser
+*.suo
+*.user
+*.userosscache
+*.sln.docstates
+
+# User-specific files (MonoDevelop/Xamarin Studio)
+*.userprefs
+
+# Mono auto generated files
+mono_crash.*
+
+# Build results
+[Dd]ebug/
+[Dd]ebugPublic/
+[Rr]elease/
+[Rr]eleases/
+x64/
+x86/
+[Ww][Ii][Nn]32/
+[Aa][Rr][Mm]/
+[Aa][Rr][Mm]64/
+bld/
+[Bb]in/
+[Oo]bj/
+[Ll]og/
+[Ll]ogs/
+
+# Visual Studio 2015/2017 cache/options directory
+.vs/
+# Uncomment if you have tasks that create the project's static files in wwwroot
+#wwwroot/
+
+# Visual Studio 2017 auto generated files
+Generated\ Files/
+
+# MSTest test Results
+[Tt]est[Rr]esult*/
+[Bb]uild[Ll]og.*
+
+# NUnit
+*.VisualState.xml
+TestResult.xml
+nunit-*.xml
+
+# Build Results of an ATL Project
+[Dd]ebugPS/
+[Rr]eleasePS/
+dlldata.c
+
+# Benchmark Results
+BenchmarkDotNet.Artifacts/
+
+# .NET Core
+project.lock.json
+project.fragment.lock.json
+artifacts/
+
+# ASP.NET Scaffolding
+ScaffoldingReadMe.txt
+
+# StyleCop
+StyleCopReport.xml
+
+# Files built by Visual Studio
+*_i.c
+*_p.c
+*_h.h
+*.ilk
+*.meta
+*.obj
+*.iobj
+*.pch
+*.pdb
+*.ipdb
+*.pgc
+*.pgd
+*.rsp
+*.sbr
+*.tlb
+*.tli
+*.tlh
+*.tmp_proj
+*_wpftmp.csproj
+*.vspscc
+*.vssscc
+.builds
+*.pidb
+*.svclog
+*.scc
+
+# Chutzpah Test files
+_Chutzpah*
+
+# Visual C++ cache files
+ipch/
+*.aps
+*.ncb
+*.opendb
+*.opensdf
+*.sdf
+*.cachefile
+*.VC.db
+*.VC.VC.opendb
+
+# Visual Studio profiler
+*.psess
+*.vsp
+*.vspx
+*.sap
+
+# Visual Studio Trace Files
+*.e2e
+
+# TFS 2012 Local Workspace
+$tf/
+
+# Guidance Automation Toolkit
+*.gpState
+
+# ReSharper is a .NET coding add-in
+_ReSharper*/
+*.[Rr]e[Ss]harper
+*.DotSettings.user
+
+# TeamCity is a build add-in
+_TeamCity*
+
+# DotCover is a Code Coverage Tool
+*.dotCover
+
+# AxoCover is a Code Coverage Tool
+.axoCover/*
+!.axoCover/settings.json
+
+# Coverlet is a free, cross platform Code Coverage Tool
+coverage*.json
+coverage*.xml
+coverage*.info
+
+# Visual Studio code coverage results
+*.coverage
+*.coveragexml
+
+# NCrunch
+_NCrunch_*
+.*crunch*.local.xml
+nCrunchTemp_*
+
+# MightyMoose
+*.mm.*
+AutoTest.Net/
+
+# Web workbench (sass)
+.sass-cache/
+
+# Installshield output folder
+[Ee]xpress/
+
+# DocProject is a documentation generator add-in
+DocProject/buildhelp/
+DocProject/Help/*.HxT
+DocProject/Help/*.HxC
+DocProject/Help/*.hhc
+DocProject/Help/*.hhk
+DocProject/Help/*.hhp
+DocProject/Help/Html2
+DocProject/Help/html
+
+# Click-Once directory
+publish/
+
+# Publish Web Output
+*.[Pp]ublish.xml
+*.azurePubxml
+# Note: Comment the next line if you want to checkin your web deploy settings,
+# but database connection strings (with potential passwords) will be unencrypted
+*.pubxml
+*.publishproj
+
+# Microsoft Azure Web App publish settings. Comment the next line if you want to
+# checkin your Azure Web App publish settings, but sensitive information contained
+# in these scripts will be unencrypted
+PublishScripts/
+
+# NuGet Packages
+*.nupkg
+# NuGet Symbol Packages
+*.snupkg
+# The packages folder can be ignored because of Package Restore
+**/[Pp]ackages/*
+# except build/, which is used as an MSBuild target.
+!**/[Pp]ackages/build/
+# Uncomment if necessary however generally it will be regenerated when needed
+#!**/[Pp]ackages/repositories.config
+# NuGet v3's project.json files produces more ignorable files
+*.nuget.props
+*.nuget.targets
+
+# Microsoft Azure Build Output
+csx/
+*.build.csdef
+
+# Microsoft Azure Emulator
+ecf/
+rcf/
+
+# Windows Store app package directories and files
+AppPackages/
+BundleArtifacts/
+Package.StoreAssociation.xml
+_pkginfo.txt
+*.appx
+*.appxbundle
+*.appxupload
+
+# Visual Studio cache files
+# files ending in .cache can be ignored
+*.[Cc]ache
+# but keep track of directories ending in .cache
+!?*.[Cc]ache/
+
+# Others
+ClientBin/
+~$*
+*.dbmdl
+*.dbproj.schemaview
+*.jfm
+*.pfx
+*.publishsettings
+orleans.codegen.cs
+
+# Including strong name files can present a security risk
+# (https://github.com/github/gitignore/pull/2483#issue-259490424)
+#*.snk
+
+# Since there are multiple workflows, uncomment next line to ignore bower_components
+# (https://github.com/github/gitignore/pull/1529#issuecomment-104372622)
+#bower_components/
+
+# RIA/Silverlight projects
+Generated_Code/
+
+# Backup & report files from converting an old project file
+# to a newer Visual Studio version. Backup files are not needed,
+# because we have git ;-)
+_UpgradeReport_Files/
+Backup*/
+UpgradeLog*.XML
+UpgradeLog*.htm
+ServiceFabricBackup/
+*.rptproj.bak
+
+# SQL Server files
+*.mdf
+*.ldf
+*.ndf
+
+# Business Intelligence projects
+*.rdl.data
+*.bim.layout
+*.bim_*.settings
+*.rptproj.rsuser
+*- [Bb]ackup.rdl
+*- [Bb]ackup ([0-9]).rdl
+*- [Bb]ackup ([0-9][0-9]).rdl
+
+# Microsoft Fakes
+FakesAssemblies/
+
+# GhostDoc plugin setting file
+*.GhostDoc.xml
+
+# Node.js Tools for Visual Studio
+.ntvs_analysis.dat
+node_modules/
+
+# Visual Studio 6 build log
+*.plg
+
+# Visual Studio 6 workspace options file
+*.opt
+
+# Visual Studio 6 auto-generated workspace file (contains which files were open etc.)
+*.vbw
+
+# Visual Studio LightSwitch build output
+**/*.HTMLClient/GeneratedArtifacts
+**/*.DesktopClient/GeneratedArtifacts
+**/*.DesktopClient/ModelManifest.xml
+**/*.Server/GeneratedArtifacts
+**/*.Server/ModelManifest.xml
+_Pvt_Extensions
+
+# Paket dependency manager
+.paket/paket.exe
+paket-files/
+
+# FAKE - F# Make
+.fake/
+
+# CodeRush personal settings
+.cr/personal
+
+# Python Tools for Visual Studio (PTVS)
+__pycache__/
+*.pyc
+
+# Cake - Uncomment if you are using it
+# tools/**
+# !tools/packages.config
+
+# Tabs Studio
+*.tss
+
+# Telerik's JustMock configuration file
+*.jmconfig
+
+# BizTalk build output
+*.btp.cs
+*.btm.cs
+*.odx.cs
+*.xsd.cs
+
+# OpenCover UI analysis results
+OpenCover/
+
+# Azure Stream Analytics local run output
+ASALocalRun/
+
+# MSBuild Binary and Structured Log
+*.binlog
+
+# NVidia Nsight GPU debugger configuration file
+*.nvuser
+
+# MFractors (Xamarin productivity tool) working folder
+.mfractor/
+
+# Local History for Visual Studio
+.localhistory/
+
+# BeatPulse healthcheck temp database
+healthchecksdb
+
+# Backup folder for Package Reference Convert tool in Visual Studio 2017
+MigrationBackup/
+
+# Ionide (cross platform F# VS Code tools) working folder
+.ionide/
+
+# Fody - auto-generated XML schema
+FodyWeavers.xsd
+
+### Python template
+# Byte-compiled / optimized / DLL files
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+venv/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+### Backup template
+*.gho
+*.ori
+*.orig
+
+### Node template
+# Logs
+logs
+npm-debug.log*
+yarn-debug.log*
+yarn-error.log*
+lerna-debug.log*
+
+# Diagnostic reports (https://nodejs.org/api/report.html)
+report.[0-9]*.[0-9]*.[0-9]*.[0-9]*.json
+
+# Runtime data
+pids
+*.pid
+*.seed
+*.pid.lock
+
+# Directory for instrumented libs generated by jscoverage/JSCover
+lib-cov
+
+# Coverage directory used by tools like istanbul
+coverage
+*.lcov
+
+# nyc test coverage
+.nyc_output
+
+# Grunt intermediate storage (https://gruntjs.com/creating-plugins#storing-task-files)
+.grunt
+
+# Bower dependency directory (https://bower.io/)
+bower_components
+
+# node-waf configuration
+.lock-wscript
+
+# Compiled binary addons (https://nodejs.org/api/addons.html)
+build/Release
+
+# Dependency directories
+jspm_packages/
+
+# Snowpack dependency directory (https://snowpack.dev/)
+web_modules/
+
+# TypeScript cache
+*.tsbuildinfo
+
+# Optional npm cache directory
+.npm
+
+# Optional eslint cache
+.eslintcache
+
+# Microbundle cache
+.rpt2_cache/
+.rts2_cache_cjs/
+.rts2_cache_es/
+.rts2_cache_umd/
+
+# Optional REPL history
+.node_repl_history
+
+# Output of 'npm pack'
+
+# Yarn Integrity file
+.yarn-integrity
+
+# dotenv environment variables file
+.env.test
+
+# parcel-bundler cache (https://parceljs.org/)
+.parcel-cache
+
+# Next.js build output
+.next
+out
+
+# Nuxt.js build / generate output
+.nuxt
+dist
+
+# Gatsby files
+.cache/
+# Comment in the public line in if your project uses Gatsby and not Next.js
+# https://nextjs.org/blog/next-9-1#public-directory-support
+# public
+
+# vuepress build output
+.vuepress/dist
+
+# Serverless directories
+.serverless/
+
+# FuseBox cache
+.fusebox/
+
+# DynamoDB Local files
+.dynamodb/
+
+# TernJS port file
+.tern-port
+
+# Stores VSCode versions used for testing VSCode extensions
+.vscode-test
+
+# yarn v2
+.yarn/cache
+.yarn/unplugged
+.yarn/build-state.yml
+.yarn/install-state.gz
+.pnp.*
+
+### VirtualEnv template
+# Virtualenv
+# http://iamzed.com/2009/05/07/a-primer-on-virtualenv/
+[Bb]in
+[Ii]nclude
+[Ll]ib
+[Ll]ib64
+[Ll]ocal
+pyvenv.cfg
+pip-selfcheck.json
+
+### macOS template
+# General
+.DS_Store
+.AppleDouble
+.LSOverride
+
+# Icon must end with two \r
+Icon
+
+# Thumbnails
+._*
+
+# Files that might appear in the root of a volume
+.DocumentRevisions-V100
+.fseventsd
+.Spotlight-V100
+.TemporaryItems
+.Trashes
+.VolumeIcon.icns
+.com.apple.timemachine.donotpresent
+
+# Directories potentially created on remote AFP share
+.AppleDB
+.AppleDesktop
+Network Trash Folder
+Temporary Items
+.apdisk
+
+### Go template
+# Binaries for programs and plugins
+*.exe~
+*.dll
+*.dylib
+
+# Test binary, built with `go test -c`
+*.test
+
+# Output of the go coverage tool, specifically when used with LiteIDE
+*.out
+
+# Dependency directories (remove the comment below to include it)
+# vendor/
+
+### C template
+# Prerequisites
+*.d
+
+# Object files
+*.o
+*.ko
+*.elf
+
+# Linker output
+*.map
+*.exp
+
+# Precompiled Headers
+*.gch
+
+# Libraries
+*.lib
+*.a
+*.la
+*.lo
+
+# Shared objects (inc. Windows DLLs)
+*.so.*
+
+# Executables
+*.app
+*.i*86
+*.x86_64
+*.hex
+
+# Debug files
+*.dSYM/
+*.su
+*.idb
+
+# Kernel Module Compile Results
+*.mod*
+*.cmd
+.tmp_versions/
+modules.order
+Module.symvers
+Mkfile.old
+dkms.conf
+
+### Example user template template
+### Example user template
+
+# IntelliJ project files
+.idea
+*.iml
+gen
+### TextMate template
+*.tmproj
+*.tmproject
+tmtags
+
+### Anjuta template
+# Local configuration folder and symbol database
+/.anjuta/
+/.anjuta_sym_db.db
+
+### XilinxISE template
+# intermediate build files
+*.bgn
+*.bit
+*.bld
+*.cmd_log
+*.drc
+*.ll
+*.lso
+*.msd
+*.msk
+*.ncd
+*.ngc
+*.ngd
+*.ngr
+*.pad
+*.par
+*.pcf
+*.prj
+*.ptwx
+*.rbb
+*.rbd
+*.stx
+*.syr
+*.twr
+*.twx
+*.unroutes
+*.ut
+*.xpi
+*.xst
+*_bitgen.xwbt
+*_envsettings.html
+*_map.map
+*_map.mrp
+*_map.ngm
+*_map.xrpt
+*_ngdbuild.xrpt
+*_pad.csv
+*_pad.txt
+*_par.xrpt
+*_summary.html
+*_summary.xml
+*_usage.xml
+*_xst.xrpt
+
+# iMPACT generated files
+_impactbatch.log
+impact.xsl
+impact_impact.xwbt
+ise_impact.cmd
+webtalk_impact.xml
+
+# Core Generator generated files
+xaw2verilog.log
+
+# project-wide generated files
+*.gise
+par_usage_statistics.html
+usage_statistics_webtalk.html
+webtalk.log
+webtalk_pn.xml
+
+# generated folders
+iseconfig/
+xlnx_auto_0_xdb/
+xst/
+_ngo/
+_xmsgs/
+
+### TortoiseGit template
+# Project-level settings
+/.tgitconfig
+
+### C++ template
+# Prerequisites
+
+# Compiled Object files
+*.slo
+
+# Precompiled Headers
+
+# Compiled Dynamic libraries
+
+# Fortran module files
+*.mod
+*.smod
+
+# Compiled Static libraries
+*.lai
+
+# Executables
+
+### SublimeText template
+# Cache files for Sublime Text
+*.tmlanguage.cache
+*.tmPreferences.cache
+*.stTheme.cache
+
+# Workspace files are user-specific
+*.sublime-workspace
+
+# Project files should be checked into the repository, unless a significant
+# proportion of contributors will probably not be using Sublime Text
+# *.sublime-project
+
+# SFTP configuration file
+sftp-config.json
+sftp-config-alt*.json
+
+# Package control specific files
+Package Control.last-run
+Package Control.ca-list
+Package Control.ca-bundle
+Package Control.system-ca-bundle
+Package Control.cache/
+Package Control.ca-certs/
+Package Control.merged-ca-bundle
+Package Control.user-ca-bundle
+oscrypto-ca-bundle.crt
+bh_unicode_properties.cache
+
+# Sublime-github package stores a github token in this file
+# https://packagecontrol.io/packages/sublime-github
+GitHub.sublime-settings
+
+### Vim template
+# Swap
+[._]*.s[a-v][a-z]
+!*.svg # comment out if you don't need vector files
+[._]*.sw[a-p]
+[._]s[a-rt-v][a-z]
+[._]ss[a-gi-z]
+[._]sw[a-p]
+
+# Session
+Session.vim
+Sessionx.vim
+
+# Temporary
+.netrwhist
+# Auto-generated tag files
+tags
+# Persistent undo
+[._]*.un~
+
+### Autotools template
+# http://www.gnu.org/software/automake
+
+Makefile.in
+/ar-lib
+/mdate-sh
+/py-compile
+/test-driver
+/ylwrap
+.deps/
+.dirstamp
+
+# http://www.gnu.org/software/autoconf
+
+autom4te.cache
+/autoscan.log
+/autoscan-*.log
+/aclocal.m4
+/compile
+/config.guess
+/config.h.in
+/config.log
+/config.status
+/config.sub
+/configure
+/configure.scan
+/depcomp
+/install-sh
+/missing
+/stamp-h1
+
+# https://www.gnu.org/software/libtool/
+
+/ltmain.sh
+
+# http://www.gnu.org/software/texinfo
+
+/texinfo.tex
+
+# http://www.gnu.org/software/m4/
+
+m4/libtool.m4
+m4/ltoptions.m4
+m4/ltsugar.m4
+m4/ltversion.m4
+m4/lt~obsolete.m4
+
+# Generated Makefile
+# (meta build system like autotools,
+# can automatically generate from config.status script
+# (which is called by configure script))
+
+### Lua template
+# Compiled Lua sources
+luac.out
+
+# luarocks build files
+*.src.rock
+*.tar.gz
+
+# Object files
+*.os
+
+# Precompiled Headers
+
+# Libraries
+*.def
+
+# Shared objects (inc. Windows DLLs)
+
+# Executables
+
+
+### Vagrant template
+# General
+.vagrant/
+
+# Log files (if you are creating logs in debug mode, uncomment this)
+# *.log
+
+### Xcode template
+# Xcode
+#
+# gitignore contributors: remember to update Global/Xcode.gitignore, Objective-C.gitignore & Swift.gitignore
+
+## User settings
+xcuserdata/
+
+## compatibility with Xcode 8 and earlier (ignoring not required starting Xcode 9)
+*.xcscmblueprint
+*.xccheckout
+
+## compatibility with Xcode 3 and earlier (ignoring not required starting Xcode 4)
+DerivedData/
+*.moved-aside
+*.pbxuser
+!default.pbxuser
+*.mode1v3
+!default.mode1v3
+*.mode2v3
+!default.mode2v3
+*.perspectivev3
+!default.perspectivev3
+
+## Gcc Patch
+/*.gcno
+
+### Linux template
+
+# temporary files which can be created if a process still has a handle open of a deleted file
+.fuse_hidden*
+
+# KDE directory preferences
+.directory
+
+# Linux trash folder which might appear on any partition or disk
+.Trash-*
+
+# .nfs files are created when an open file is removed but is still being accessed
+.nfs*
+
+### GitBook template
+# Node rules:
+## Grunt intermediate storage (http://gruntjs.com/creating-plugins#storing-task-files)
+
+## Dependency directory
+## Commenting this out is preferred by some people, see
+## https://docs.npmjs.com/misc/faq#should-i-check-my-node_modules-folder-into-git
+node_modules
+
+# Book build output
+_book
+
+# eBook build output
+*.epub
+*.mobi
+
+### CodeSniffer template
+# gitignore for the PHP Codesniffer framework
+# website: https://github.com/squizlabs/PHP_CodeSniffer
+#
+# Recommended template: PHP.gitignore
+
+/wpcs/*
+
+### PuTTY template
+# Private key
+*.ppk
+*_pb2.py
+*.pth
+*.pth.tar
+*.pt
+*.npy
+__pycache__
+*.egg-info
+experiment_config.yaml
+api-log/
+log/
+htmlcov
+*.lock
+.coverage*
+!.coveragerc
+#/test_*
+.python-version
+/name.txt
+/summary_log
+policy_*
+/data
+.vscode
+formatted_*
+**/exp
+**/benchmark
+**/model_zoo
+*ckpt*
+log*
+*.puml.png
+*.puml.eps
+*.puml.svg
+default*
+events.*
+
+# DI-engine special key
+*default_logger.txt
+*default_tb_logger
+*evaluate.txt
+*total_config.py
+eval_config.py
+collect_demo_data_config.py
+!ding/**/*.py
+events.*
+/test_*
+# LightZero special key
+/zoo/board_games/**/*.c
+/zoo/board_games/**/*.cpp
+/lzero/mcts/**/*.cpp
+/zoo/**/*.c
+/lzero/mcts/**/*.so
+/lzero/mcts/**/*.h
+!/lzero/mcts/**/lib
+!/lzero/mcts/**/lib/*.cpp
+!/lzero/mcts/**/lib/*.hpp
+!/lzero/mcts/**/lib/*.h
+**/tb/*
+**/mcts/ctree/tests_cpp/*
+**/*tmp*
\ No newline at end of file
diff --git a/LightZero/.gitmodules b/LightZero/.gitmodules
new file mode 100644
index 0000000000000000000000000000000000000000..fa22d552b6885d0b474a82123185fa203d220689
--- /dev/null
+++ b/LightZero/.gitmodules
@@ -0,0 +1,3 @@
+[submodule "pybind11"]
+ path = lzero/mcts/ctree/ctree_alphazero/pybind11
+ url = https://github.com/pybind/pybind11.git
\ No newline at end of file
diff --git a/LightZero/.gitpod.Dockerfile b/LightZero/.gitpod.Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..17c0afe45867df8c868c08bae8093b14d640ce1d
--- /dev/null
+++ b/LightZero/.gitpod.Dockerfile
@@ -0,0 +1,30 @@
+# Start from Ubuntu 20.04
+FROM ubuntu:20.04
+
+# Set the working directory in the Docker image
+WORKDIR /opendilab
+
+# Install Python 3.8 and other dependencies
+# We update the apt package list, install Python 3.8, pip, compilers and other necessary tools.
+# After installing, we clean up the apt cache and remove unnecessary lists to save space.
+RUN apt-get update && \
+ apt-get install -y python3.8 python3-pip gcc g++ swig git && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+# Create a symbolic link for Python and pip
+# This makes it easy to call python and pip from any location in the container.
+RUN ln -s /usr/bin/python3.8 /usr/local/bin/python && \
+ ln -s /usr/bin/pip3 /usr/local/bin/pip
+
+# Update pip and setuptools to the latest version
+# This step ensures that we have the latest tools for installing Python packages.
+RUN python -m pip install --upgrade pip setuptools
+
+# Clone the LightZero repository from GitHub
+# This step downloads the latest version of LightZero to our Docker image.
+RUN git clone https://github.com/opendilab/LightZero.git
+
+# Install the LightZero package in editable mode
+# The -e option allows us to edit the source code without needing to reinstall the package.
+RUN pip install -e ./LightZero
diff --git a/LightZero/.gitpod.yml b/LightZero/.gitpod.yml
new file mode 100644
index 0000000000000000000000000000000000000000..c2e9c331a817d3c9b601c6e413d2b54ed335897a
--- /dev/null
+++ b/LightZero/.gitpod.yml
@@ -0,0 +1,14 @@
+# You should adapt it to your project's needs (see https://www.gitpod.io/docs/introduction/learn-gitpod/gitpod-yaml).
+# After you've adjusted this file to your liking, commit it to your remote git repository to share the Gitpod configuration with others.
+
+# If you need to start from a template, Gitpod provides ready-to-use ones: https://www.gitpod.io/docs/introduction/getting-started/quickstart
+
+image:
+ file: .gitpod.Dockerfile
+
+tasks:
+ # The 'init' command is run once at the start of workspace creation.
+ # It is typically used for installing project dependencies, as in this case.
+ - init: |
+ pip install -r requirements.txt
+ # Add any other necessary commands here
diff --git a/LightZero/.style.yapf b/LightZero/.style.yapf
new file mode 100644
index 0000000000000000000000000000000000000000..edd867c28237606d759f83a8242d93ec821557b4
--- /dev/null
+++ b/LightZero/.style.yapf
@@ -0,0 +1,11 @@
+[style]
+# For explanation and more information: https://github.com/google/yapf
+BASED_ON_STYLE=pep8
+DEDENT_CLOSING_BRACKETS=True
+SPLIT_BEFORE_FIRST_ARGUMENT=True
+ALLOW_SPLIT_BEFORE_DICT_VALUE=False
+JOIN_MULTIPLE_LINES=False
+COLUMN_LIMIT=120
+BLANK_LINE_BEFORE_NESTED_CLASS_OR_DEF=True
+BLANK_LINES_AROUND_TOP_LEVEL_DEFINITION=2
+SPACES_AROUND_POWER_OPERATOR=True
diff --git a/LightZero/CHANGELOG.md b/LightZero/CHANGELOG.md
new file mode 100644
index 0000000000000000000000000000000000000000..2ecc245430e36c5e302a4fc11c5291ab5f5060a1
--- /dev/null
+++ b/LightZero/CHANGELOG.md
@@ -0,0 +1,54 @@
+2023.12.07 (v0.0.3)
+- env: MiniGrid env (#110)
+- env: Bsuite env (#110)
+- env: GoBigger env (#39)
+- algo: RND+MuZero (#110)
+- algo: Sampled AlphaZero (#141)
+- algo: Multi-Agent MuZero/EfficientZero (#39)
+- feature: add ctree version of mcts in alphazero (#142)
+- feature: upgrade the dependency on gym with gymnasium (#150)
+- feature: add agent class to support LightZero's HuggingFace Model Zoo (#163)
+- feature: add recent MCTS-related papers in readme (#159)
+- feature: add muzero config for connect4 (#107)
+- feature: added CONTRIBUTING.md (#119)
+- feature: added .gitpod.yml and .gitpod.Dockerfile (#123)
+- feature: added contributors subsection in README (#132)
+- feature: added CODE_OF_CONDUCT.md (#127)
+- polish: refine comments and render_eval configs for various common envs (#154) (#161)
+- polish: polish action_type and env_type, fix test.yml, fix unittest (#160)
+- polish: update env and algo tutorial doc (#106)
+- polish: polish gomoku env (#141)
+- polish: add random_policy support for continuous env (#118)
+- polish: polish simulation method of ptree_az (#120)
+- polish: polish comments of game_segment_to_array
+- fix: fix render method for various common envs (#154) (#161)
+- fix: fix gumbel muzero collector bug, fix gumbel typo (#144)
+- fix: fix assert bug in game_segment.py (#138)
+- fix: fix visit_count_distributions name in muzero_evaluator
+- fix: fix mcts and alphabeta bot unittest (#120)
+- fix: fix typos in ptree_mz.py (#113)
+- fix: fix root_sampled_actions_tmp shape bug in sez ptree
+- fix: fix policy utils unittest
+- fix: fix typo in readme and add a 'back to top' button in readme (#104) (#109) (#111)
+- style: add nips2023 paper link
+
+2023.09.21 (v0.0.2)
+- env: MuJoCo env (#50)
+- env: 2048 env (#64)
+- env: Connect4 env (#63)
+- algo: Gumbel MuZero (#22)
+- algo: Stochastic MuZero (#64)
+- feature: add Dockerfile and its usage instructions (#95)
+- feature: add doc about how to customize envs and algos (#78)
+- feature: add pytorch ddp support (#68)
+- feature: add eps greedy and random collect option in train_muzero_entry (#54)
+- feature: add atari visualization option (#40)
+- feature: add log_buffer_memory_usage utils (#30)
+- polish: polish mcts and ptree_az (#57) (#61)
+- polish: polish readme (#36) (#47) (#51) (#77) (#95) (#96)
+- polish: update paper notes (#89) (#91)
+- polish: polish model and configs (#26) (#27) (#50)
+- fix: fix priority bug in muzero collector (#74)
+- style: update github action (#71) (#72) (#73) (#81) (#83) (#84) (#90)
+
+2023.04.14 (v0.0.1)
\ No newline at end of file
diff --git a/LightZero/CODE_OF_CONDUCT.md b/LightZero/CODE_OF_CONDUCT.md
new file mode 100644
index 0000000000000000000000000000000000000000..c70133d09b68011a90de2a78f5a8d4cb002e028f
--- /dev/null
+++ b/LightZero/CODE_OF_CONDUCT.md
@@ -0,0 +1,128 @@
+# Contributor Covenant Code of Conduct
+
+## Our Pledge
+
+We as members, contributors, and leaders pledge to make participation in our
+community a harassment-free experience for everyone, regardless of age, body
+size, visible or invisible disability, ethnicity, sex characteristics, gender
+identity and expression, level of experience, education, socio-economic status,
+nationality, personal appearance, race, religion, or sexual identity
+and orientation.
+
+We pledge to act and interact in ways that contribute to an open, welcoming,
+diverse, inclusive, and healthy community.
+
+## Our Standards
+
+Examples of behavior that contributes to a positive environment for our
+community include:
+
+* Demonstrating empathy and kindness toward other people
+* Being respectful of differing opinions, viewpoints, and experiences
+* Giving and gracefully accepting constructive feedback
+* Accepting responsibility and apologizing to those affected by our mistakes,
+ and learning from the experience
+* Focusing on what is best not just for us as individuals, but for the
+ overall community
+
+Examples of unacceptable behavior include:
+
+* The use of sexualized language or imagery, and sexual attention or
+ advances of any kind
+* Trolling, insulting or derogatory comments, and personal or political attacks
+* Public or private harassment
+* Publishing others' private information, such as a physical or email
+ address, without their explicit permission
+* Other conduct which could reasonably be considered inappropriate in a
+ professional setting
+
+## Enforcement Responsibilities
+
+Community leaders are responsible for clarifying and enforcing our standards of
+acceptable behavior and will take appropriate and fair corrective action in
+response to any behavior that they deem inappropriate, threatening, offensive,
+or harmful.
+
+Community leaders have the right and responsibility to remove, edit, or reject
+comments, commits, code, wiki edits, issues, and other contributions that are
+not aligned to this Code of Conduct, and will communicate reasons for moderation
+decisions when appropriate.
+
+## Scope
+
+This Code of Conduct applies within all community spaces, and also applies when
+an individual is officially representing the community in public spaces.
+Examples of representing our community include using an official e-mail address,
+posting via an official social media account, or acting as an appointed
+representative at an online or offline event.
+
+## Enforcement
+
+Instances of abusive, harassing, or otherwise unacceptable behavior may be
+reported to the community leaders responsible for enforcement at
+.
+All complaints will be reviewed and investigated promptly and fairly.
+
+All community leaders are obligated to respect the privacy and security of the
+reporter of any incident.
+
+## Enforcement Guidelines
+
+Community leaders will follow these Community Impact Guidelines in determining
+the consequences for any action they deem in violation of this Code of Conduct:
+
+### 1. Correction
+
+**Community Impact**: Use of inappropriate language or other behavior deemed
+unprofessional or unwelcome in the community.
+
+**Consequence**: A private, written warning from community leaders, providing
+clarity around the nature of the violation and an explanation of why the
+behavior was inappropriate. A public apology may be requested.
+
+### 2. Warning
+
+**Community Impact**: A violation through a single incident or series
+of actions.
+
+**Consequence**: A warning with consequences for continued behavior. No
+interaction with the people involved, including unsolicited interaction with
+those enforcing the Code of Conduct, for a specified period of time. This
+includes avoiding interactions in community spaces as well as external channels
+like social media. Violating these terms may lead to a temporary or
+permanent ban.
+
+### 3. Temporary Ban
+
+**Community Impact**: A serious violation of community standards, including
+sustained inappropriate behavior.
+
+**Consequence**: A temporary ban from any sort of interaction or public
+communication with the community for a specified period of time. No public or
+private interaction with the people involved, including unsolicited interaction
+with those enforcing the Code of Conduct, is allowed during this period.
+Violating these terms may lead to a permanent ban.
+
+### 4. Permanent Ban
+
+**Community Impact**: Demonstrating a pattern of violation of community
+standards, including sustained inappropriate behavior, harassment of an
+individual, or aggression toward or disparagement of classes of individuals.
+
+**Consequence**: A permanent ban from any sort of public interaction within
+the community.
+
+## Attribution
+
+This Code of Conduct is adapted from the [Contributor Covenant][homepage],
+version 2.0, available at
+https://www.contributor-covenant.org/version/2/0/code_of_conduct.html.
+
+Community Impact Guidelines were inspired by [Mozilla's code of conduct
+enforcement ladder](https://github.com/mozilla/diversity).
+
+[homepage]: https://www.contributor-covenant.org
+
+For answers to common questions about this code of conduct, see the FAQ at
+https://www.contributor-covenant.org/faq. Translations are available at
+https://www.contributor-covenant.org/translations.
diff --git a/LightZero/CONTRIBUTING.md b/LightZero/CONTRIBUTING.md
new file mode 100644
index 0000000000000000000000000000000000000000..6e8823425ae296a1cd07ed99e895286535e69353
--- /dev/null
+++ b/LightZero/CONTRIBUTING.md
@@ -0,0 +1,63 @@
+# 🚀 Welcome to LightZero! 🌟
+
+We're thrilled that you want to contribute to LightZero. Your help is invaluable, and we appreciate your efforts to make this project even better. 😄
+
+## 📝 How to Contribute
+
+1. **Fork the Repository** 🍴
+ - Click on the "Fork" button at the top right of the [LightZero repository](https://github.com/opendilab/LightZero).
+
+2. **Clone your Fork** 💻
+ - `git clone https://github.com/your-username/LightZero.git`
+
+3. **Create a New Branch** 🌿
+ - `git checkout -b your-new-feature`
+
+4. **Make Your Awesome Changes** 💥
+ - Add some cool features.
+ - Fix a bug.
+ - Improve the documentation.
+ - Anything that adds value!
+
+5. **Commit Your Changes** 📦
+ - `git commit -m "Your descriptive commit message"`
+
+6. **Push to Your Fork** 🚢
+ - `git push origin your-new-feature`
+
+7. **Create a Pull Request** 🎉
+ - Go to the [LightZero repository](https://github.com/opendilab/LightZero).
+ - Click on "New Pull Request."
+ - Fill in the details and submit your PR.
+ - Please make sure your PR has a clear title and description.
+
+8. **Review & Collaborate** 🤝
+ - Be prepared to answer questions or make changes to your PR as requested by the maintainers.
+
+9. **Celebrate! 🎉** Your contribution has been added to LightZero.
+
+## 📦 Reporting Issues
+
+If you encounter a bug or have an idea for an improvement, please create an issue in the [Issues](https://github.com/opendilab/LightZero/issues) section. Make sure to include details about the problem and how to reproduce it.
+
+## 🛠 Code Style and Guidelines
+
+We follow a few simple guidelines:
+- Keep your code clean and readable.
+- Use meaningful variable and function names.
+- Comment your code when necessary.
+- Ensure your code adheres to existing coding styles and standards.
+
+For detailed information on code style, unit testing, and code review, please refer to our documentation:
+
+- [Code Style](https://di-engine-docs.readthedocs.io/en/latest/21_code_style/index.html)
+- [Unit Test](https://di-engine-docs.readthedocs.io/en/latest/22_test/index.html)
+- [Code Review](https://di-engine-docs.readthedocs.io/en/latest/24_cooperation/git_guide.html)
+
+## 🤖 Code of Conduct
+
+Please be kind and respectful when interacting with other contributors. We have a [Code of Conduct](CODE_OF_CONDUCT.md) to ensure a positive and welcoming environment for everyone.
+
+## 🙌 Thank You! 🙏
+
+Your contribution helps make LightZero even better. We appreciate your dedication to the project. Keep coding and stay awesome! 😃
diff --git a/LightZero/Dockerfile b/LightZero/Dockerfile
new file mode 100644
index 0000000000000000000000000000000000000000..d36b5398c83c600703df41c9f8be84e4765bdc6a
--- /dev/null
+++ b/LightZero/Dockerfile
@@ -0,0 +1,52 @@
+# This Dockerfile describes the process of creating a Docker image that includes
+# the necessary environment to run the LightZero library.
+
+# The Docker image is based on Ubuntu 20.04, and it installs Python 3.8 and other
+# necessary dependencies. It then clones the LightZero library from its GitHub
+# repository and installs it in an editable mode.
+
+# Before building the Docker image, create a new empty directory, move this Dockerfile into it,
+# and navigate into this directory. This is to avoid sending unnecessary files to the Docker daemon
+# during the build. You can then build the Docker image using the following command in your terminal:
+# docker build -t ubuntu-py38-lz:latest -f ./Dockerfile .
+
+# To run a container from the image in interactive mode with a Bash shell, you can use:
+# docker run -dit --rm ubuntu-py38-lz:latest /bin/bash
+
+# Once you're inside the container, you can run the example Python script with:
+# python ./LightZero/zoo/classic_control/cartpole/config/cartpole_muzero_config.py
+
+# Note: The working directory inside the Docker image is /opendilab, so you don't need
+# to change your current directory before running the Python script.
+
+
+# Start from Ubuntu 20.04
+FROM ubuntu:20.04
+
+# Set the working directory in the Docker image
+WORKDIR /opendilab
+
+# Install Python 3.8 and other dependencies
+# We update the apt package list, install Python 3.8, pip, compilers and other necessary tools.
+# After installing, we clean up the apt cache and remove unnecessary lists to save space.
+RUN apt-get update && \
+ apt-get install -y python3.8 python3-pip gcc g++ swig git && \
+ apt-get clean && \
+ rm -rf /var/lib/apt/lists/*
+
+# Create a symbolic link for Python and pip
+# This makes it easy to call python and pip from any location in the container.
+RUN ln -s /usr/bin/python3.8 /usr/local/bin/python && \
+ ln -s /usr/bin/pip3 /usr/local/bin/pip
+
+# Update pip and setuptools to the latest version
+# This step ensures that we have the latest tools for installing Python packages.
+RUN python -m pip install --upgrade pip setuptools
+
+# Clone the LightZero repository from GitHub
+# This step downloads the latest version of LightZero to our Docker image.
+RUN git clone https://github.com/opendilab/LightZero.git
+
+# Install the LightZero package in editable mode
+# The -e option allows us to edit the source code without needing to reinstall the package.
+RUN pip install -e ./LightZero
\ No newline at end of file
diff --git a/LightZero/LICENSE b/LightZero/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64
--- /dev/null
+++ b/LightZero/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/LightZero/Makefile b/LightZero/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..4f8936e6c0a9a61effb9285b4e9c00fac88e8212
--- /dev/null
+++ b/LightZero/Makefile
@@ -0,0 +1,71 @@
+.PHONY: docs test unittest build clean benchmark zip
+
+NO_DEBUG ?=
+NO_DOCSTRING ?=
+NO_DEBUG_CMD := $(if ${NO_DOCSTRING},-OO,$(if ${NO_DEBUG},-O,))
+PYTHON ?= $(shell which python) ${NO_DEBUG_CMD}
+
+DOC_DIR := ./docs
+DIST_DIR := ./dist
+WHEELHOUSE_DIR := ./wheelhouse
+BENCHMARK_DIR := ./benchmark
+SRC_DIR := ./lzero
+RUNS_DIR := ./runs
+
+RANGE_DIR ?= .
+RANGE_TEST_DIR := ${SRC_DIR}/${RANGE_DIR}
+RANGE_BENCH_DIR := ${BENCHMARK_DIR}/${RANGE_DIR}
+RANGE_SRC_DIR := ${SRC_DIR}/${RANGE_DIR}
+
+CYTHON_FILES := $(shell find ${SRC_DIR} -name '*.pyx')
+CYTHON_RELATED := \
+ $(addsuffix .c, $(basename ${CYTHON_FILES})) \
+ $(addsuffix .cpp, $(basename ${CYTHON_FILES})) \
+ $(addsuffix .h, $(basename ${CYTHON_FILES})) \
+
+COV_TYPES ?= xml term-missing
+COMPILE_PLATFORM ?= manylinux_2_24_x86_64
+
+
+build:
+ $(PYTHON) setup.py build_ext --inplace \
+ $(if ${LINETRACE},--define CYTHON_TRACE,)
+
+zip:
+ $(PYTHON) -m build --sdist --outdir ${DIST_DIR}
+
+package:
+ $(PYTHON) -m build --sdist --wheel --outdir ${DIST_DIR}
+ for whl in `ls ${DIST_DIR}/*.whl`; do \
+ auditwheel repair $$whl -w ${WHEELHOUSE_DIR} --plat ${COMPILE_PLATFORM} && \
+ cp `ls ${WHEELHOUSE_DIR}/*.whl` ${DIST_DIR} && \
+ rm -rf $$whl ${WHEELHOUSE_DIR}/* \
+ ; done
+
+clean:
+ rm -rf $(shell find ${SRC_DIR} -name '*.so') \
+ $(if ${CYTHON_RELATED},$(shell ls ${CYTHON_RELATED} 2> /dev/null),)
+ rm -rf ${DIST_DIR} ${WHEELHOUSE_DIR}
+
+test: unittest benchmark
+
+unittest:
+ $(PYTHON) -m pytest "${RANGE_TEST_DIR}" \
+ -sv -m unittest \
+ $(shell for type in ${COV_TYPES}; do echo "--cov-report=$$type"; done) \
+ --cov="${RANGE_SRC_DIR}" \
+ $(if ${MIN_COVERAGE},--cov-fail-under=${MIN_COVERAGE},) \
+ $(if ${WORKERS},-n ${WORKERS},)
+
+minitest:
+ $(PYTHON) -m pytest "${SRC_DIR}/mcts/tests/test_game_block.py" \
+ -sv -m unittest \
+ $(shell for type in ${COV_TYPES}; do echo "--cov-report=$$type"; done) \
+ --cov="${SRC_DIR}/mcts/tests/test_game_block.py" \
+ $(if ${MIN_COVERAGE},--cov-fail-under=${MIN_COVERAGE},) \
+ $(if ${WORKERS},-n ${WORKERS},)
+
+docs:
+ $(MAKE) -C "${DOC_DIR}" build
+pdocs:
+ $(MAKE) -C "${DOC_DIR}" prod
diff --git a/LightZero/README.md b/LightZero/README.md
new file mode 100644
index 0000000000000000000000000000000000000000..04d34cedd03d6caba346d7865890d0959cef7884
--- /dev/null
+++ b/LightZero/README.md
@@ -0,0 +1,537 @@
+
+
+# LightZero
+
+
+
+
+
+---
+
+[![Twitter](https://img.shields.io/twitter/url?style=social&url=https%3A%2F%2Ftwitter.com%2Fopendilab)](https://twitter.com/opendilab)
+[![PyPI](https://img.shields.io/pypi/v/LightZero)](https://pypi.org/project/LightZero/)
+![PyPI - Python Version](https://img.shields.io/pypi/pyversions/LightZero)
+![Loc](https://img.shields.io/endpoint?url=https://gist.githubusercontent.com/HansBug/e002642132ec758e99264118c66778a4/raw/loc.json)
+![Comments](https://img.shields.io/endpoint?url=https://gist.githubusercontent.com/HansBug/e002642132ec758e99264118c66778a4/raw/comments.json)
+
+[![Code Test](https://github.com/opendilab/LightZero/workflows/Code%20Test/badge.svg)](https://github.com/opendilab/LightZero/actions?query=workflow%3A%22Code+Test%22)
+[![Badge Creation](https://github.com/opendilab/LightZero/workflows/Badge%20Creation/badge.svg)](https://github.com/opendilab/LightZero/actions?query=workflow%3A%22Badge+Creation%22)
+[![Package Release](https://github.com/opendilab/LightZero/workflows/Package%20Release/badge.svg)](https://github.com/opendilab/LightZero/actions?query=workflow%3A%22Package+Release%22)
+
+![GitHub Org's stars](https://img.shields.io/github/stars/opendilab)
+[![GitHub stars](https://img.shields.io/github/stars/opendilab/LightZero)](https://github.com/opendilab/LightZero/stargazers)
+[![GitHub forks](https://img.shields.io/github/forks/opendilab/LightZero)](https://github.com/opendilab/LightZero/network)
+![GitHub commit activity](https://img.shields.io/github/commit-activity/m/opendilab/LightZero)
+[![GitHub issues](https://img.shields.io/github/issues/opendilab/LightZero)](https://github.com/opendilab/LightZero/issues)
+[![GitHub pulls](https://img.shields.io/github/issues-pr/opendilab/LightZero)](https://github.com/opendilab/LightZero/pulls)
+[![Contributors](https://img.shields.io/github/contributors/opendilab/LightZero)](https://github.com/opendilab/LightZero/graphs/contributors)
+[![GitHub license](https://img.shields.io/github/license/opendilab/LightZero)](https://github.com/opendilab/LightZero/blob/master/LICENSE)
+
+Updated on 2023.12.07 LightZero-v0.0.3
+
+> LightZero is a lightweight, efficient, and easy-to-understand open-source algorithm toolkit that combines Monte Carlo Tree Search (MCTS) and Deep Reinforcement Learning (RL).
+
+English | [简体中文(Simplified Chinese)](https://github.com/opendilab/LightZero/blob/main/README.zh.md) | [Paper](https://arxiv.org/pdf/2310.08348.pdf)
+
+## Background
+
+The integration of Monte Carlo Tree Search and Deep Reinforcement Learning,
+exemplified by AlphaZero and MuZero,
+has achieved unprecedented performance levels in various games, including Go and Atari.
+This advanced methodology has also made significant strides in scientific domains like protein structure prediction and the search for matrix multiplication algorithms.
+The following is an overview of the historical evolution of the Monte Carlo Tree Search algorithm series:
+![pipeline](assets/mcts_rl_evolution_overview.png)
+
+## Overview
+
+**LightZero** is an open-source algorithm toolkit that combines MCTS and RL for PyTorch. It provides support for a range of MCTS-based RL algorithms and applications with the following advantages:
+- Lightweight.
+- Efficient.
+- Easy-to-understand.
+
+For further details, please refer to [Features](#features), [Framework Structure](#framework-structure) and [Integrated Algorithms](#integrated-algorithms).
+
+**LightZero** aims to **promote the standardization of the MCTS+RL algorithm family to accelerate related research and applications**. A performance comparison of all implemented algorithms under a unified framework is presented in the [Benchmark](#benchmark).
+
+### Outline
+
+- [Overview](#overview)
+ - [Outline](#outline)
+ - [Features](#features)
+ - [Framework Structure](#framework-structure)
+ - [Integrated Algorithms](#integrated-algorithms)
+- [Installation](#installation)
+- [Quick Start](#quick-start)
+- [Benchmark](#benchmark)
+- [Awesome-MCTS Notes](#awesome-mcts-notes)
+ - [Paper Notes](#paper-notes)
+ - [Algo. Overview](#algo-overview)
+- [Awesome-MCTS Papers](#awesome-mcts-papers)
+ - [Key Papers](#key-papers)
+ - [Other Papers](#other-papers)
+- [Feedback and Contribution](#feedback-and-contribution)
+- [Citation](#citation)
+- [Acknowledgments](#acknowledgments)
+- [License](#license)
+
+### Features
+
+**Lightweight**: LightZero integrates multiple MCTS algorithm families and can solve decision-making problems with various attributes in a lightweight framework. The algorithms and environments implemented in LightZero can be found [here](#integrated-algorithms).
+
+**Efficient**: LightZero uses mixed heterogeneous computing (Python together with Cython/C++) to improve the computational efficiency of the most time-consuming parts of MCTS algorithms.
+
+**Easy-to-understand**: LightZero provides detailed documentation and algorithm framework diagrams for all integrated algorithms to help users understand the algorithm's core and compare the differences and similarities between algorithms under the same paradigm. LightZero also provides function call graphs and network structure diagrams for algorithm code implementation, making it easier for users to locate critical code. All the documentation can be found [here](#paper-notes).
+
+### Framework Structure
+
+
+The above picture is the framework pipeline of LightZero. We briefly introduce the three core modules below:
+
+**Model**:
+``Model`` is used to define the network structure, including the ``__init__`` function for initializing the network structure and the ``forward`` function for computing the network's forward propagation.
+
+**Policy**:
+``Policy`` defines the way the network is updated and interacts with the environment, including three processes: the ``learning`` process, the ``collecting`` process, and the ``evaluation`` process.
+
+**MCTS**:
+``MCTS`` defines the structure of the Monte Carlo search tree and the way it interacts with the Policy. The implementation of MCTS includes two languages: Python and C++, implemented in ``ptree`` and ``ctree``, respectively.
+
+For the file structure of LightZero, please refer to [lightzero_file_structure](https://github.com/opendilab/LightZero/blob/main/assets/lightzero_file_structure.svg).
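+
+To make this division of responsibilities concrete, below is a deliberately simplified, self-contained sketch of how such modules typically cooperate in a PyTorch-based MCTS+RL pipeline. The class names and method bodies are illustrative only and are not LightZero's actual ``Model``/``Policy``/``MCTS`` classes:
+
+```python
+# Schematic sketch (hypothetical names, not the real LightZero API): the Policy
+# owns a Model and would normally query an MCTS search to refine the raw
+# network outputs during collection and evaluation.
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch.distributions import Categorical
+
+
+class ToyModel(nn.Module):
+    """Defines the network structure: __init__ builds layers, forward runs them."""
+
+    def __init__(self, obs_dim: int = 4, action_dim: int = 2):
+        super().__init__()
+        self.trunk = nn.Sequential(nn.Linear(obs_dim, 64), nn.ReLU())
+        self.policy_head = nn.Linear(64, action_dim)
+        self.value_head = nn.Linear(64, 1)
+
+    def forward(self, obs: torch.Tensor):
+        h = self.trunk(obs)
+        return self.policy_head(h), self.value_head(h)
+
+
+class ToyPolicy:
+    """Defines how the network is updated (learn) and how it interacts with
+    the environment (collect / evaluate)."""
+
+    def __init__(self, model: nn.Module):
+        self.model = model
+        self.optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+
+    def collect(self, obs: torch.Tensor) -> int:
+        # In a real pipeline, an MCTS search guided by the model would refine
+        # these priors; here we simply sample from the raw policy head.
+        logits, _ = self.model(obs)
+        return Categorical(logits=logits).sample().item()
+
+    def learn(self, obs: torch.Tensor, target_value: torch.Tensor) -> float:
+        _, value = self.model(obs)
+        loss = F.mse_loss(value.squeeze(-1), target_value)
+        self.optimizer.zero_grad()
+        loss.backward()
+        self.optimizer.step()
+        return loss.item()
+
+
+if __name__ == "__main__":
+    policy = ToyPolicy(ToyModel())
+    obs = torch.randn(1, 4)
+    print("sampled action:", policy.collect(obs))
+    print("value loss:", policy.learn(obs, torch.zeros(1)))
+```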
+
+### Integrated Algorithms
+LightZero is a library with a [PyTorch](https://pytorch.org/) implementation of MCTS algorithms (sometimes combined with Cython and C++), including:
+- [AlphaZero](https://www.science.org/doi/10.1126/science.aar6404)
+- [MuZero](https://arxiv.org/abs/1911.08265)
+- [Sampled MuZero](https://arxiv.org/abs/2104.06303)
+- [Stochastic MuZero](https://openreview.net/pdf?id=X6D9bAHhBQ1)
+- [EfficientZero](https://arxiv.org/abs/2111.00210)
+- [Gumbel MuZero](https://openreview.net/pdf?id=bERaNdoegnO&)
+
+The environments and algorithms currently supported by LightZero are shown in the table below:
+
+| Env./Algo. | AlphaZero | MuZero | EfficientZero | Sampled EfficientZero | Gumbel MuZero | Stochastic MuZero |
+|---------------| --------- | ------ |-------------| ------------------ | ---------- |----------------|
+| TicTacToe | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |
+| Gomoku | ✔ | ✔ | 🔒 | 🔒 | ✔ | 🔒 |
+| Connect4 | ✔ | ✔ | 🔒 | 🔒 | 🔒 | 🔒 |
+| 2048 | ✔ | ✔ | 🔒 | 🔒 | 🔒 | ✔ |
+| Chess | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |
+| Go | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 | 🔒 |
+| CartPole | --- | ✔ | ✔ | ✔ | ✔ | ✔ |
+| Pendulum | --- | ✔ | ✔ | ✔ | ✔ | ✔ |
+| LunarLander | --- | ✔ | ✔ | ✔ | ✔ | ✔ |
+| BipedalWalker | --- | ✔ | ✔ | ✔ | ✔ | 🔒 |
+| Atari | --- | ✔ | ✔ | ✔ | ✔ | ✔ |
+| MuJoCo | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |
+| MiniGrid | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |
+| Bsuite | --- | ✔ | ✔ | ✔ | 🔒 | 🔒 |
+
+(1): "✔" means that the corresponding item is finished and well-tested.
+
+(2): "🔒" means that the corresponding item is in the waiting-list (Work In Progress).
+
+(3): "---" means that this algorithm doesn't support this environment.
+
+
+## Installation
+
+You can install the latest development version of LightZero from the GitHub source code with the following commands:
+
+```bash
+git clone https://github.com/opendilab/LightZero.git
+cd LightZero
+pip3 install -e .
+```
+
+Kindly note that LightZero currently supports compilation only on `Linux` and `macOS` platforms.
+We are actively working towards extending this support to the `Windows` platform.
+Your patience during this transition is greatly appreciated.
+
+## Installation with Docker
+
+We also provide a Dockerfile that sets up an environment with all dependencies needed to run the LightZero library. This Docker image is based on Ubuntu 20.04 and installs Python 3.8, along with other necessary tools and libraries.
+Here's how to use our Dockerfile to build a Docker image, run a container from this image, and execute LightZero code inside the container.
+1. **Download the Dockerfile**: The Dockerfile is located in the root directory of the LightZero repository. Download this [file](https://github.com/opendilab/LightZero/blob/main/Dockerfile) to your local machine.
+2. **Prepare the build context**: Create a new empty directory on your local machine, move the Dockerfile into this directory, and navigate into this directory. This step helps to avoid sending unnecessary files to the Docker daemon during the build process.
+ ```bash
+ mkdir lightzero-docker
+ mv Dockerfile lightzero-docker/
+ cd lightzero-docker/
+ ```
+3. **Build the Docker image**: Use the following command to build the Docker image. This command should be run from inside the directory that contains the Dockerfile.
+ ```bash
+ docker build -t ubuntu-py38-lz:latest -f ./Dockerfile .
+ ```
+4. **Run a container from the image**: Use the following command to start a container from the image in interactive mode with a Bash shell.
+ ```bash
+ docker run -dit --rm ubuntu-py38-lz:latest /bin/bash
+ ```
+5. **Execute LightZero code inside the container**: Once you're inside the container, you can run the example Python script with the following command:
+ ```bash
+ python ./LightZero/zoo/classic_control/cartpole/config/cartpole_muzero_config.py
+ ```
+
+[comment]: <> (- [AlphaGo Zero](https://www.nature.com/articles/nature24270) )
+
+## Quick Start
+
+Train a MuZero agent to play [CartPole](https://gymnasium.farama.org/environments/classic_control/cart_pole/):
+
+```bash
+cd LightZero
+python3 -u zoo/classic_control/cartpole/config/cartpole_muzero_config.py
+```
+
+Train a MuZero agent to play [Pong](https://gymnasium.farama.org/environments/atari/pong/):
+
+```bash
+cd LightZero
+python3 -u zoo/atari/config/atari_muzero_config.py
+```
+
+Train a MuZero agent to play [TicTacToe](https://en.wikipedia.org/wiki/Tic-tac-toe):
+
+```bash
+cd LightZero
+python3 -u zoo/board_games/tictactoe/config/tictactoe_muzero_bot_mode_config.py
+```
+
+## Benchmark
+
+
+- Below are the benchmark results of [AlphaZero](https://github.com/opendilab/LightZero/blob/main/lzero/policy/alphazero.py) and [MuZero](https://github.com/opendilab/LightZero/blob/main/lzero/policy/muzero.py) on three board games: [TicTacToe](https://github.com/opendilab/LightZero/blob/main/zoo/board_games/tictactoe/envs/tictactoe_env.py), [Connect4](https://github.com/opendilab/LightZero/blob/main/zoo/board_games/connect4/envs/connect4_env.py), [Gomoku](https://github.com/opendilab/LightZero/blob/main/zoo/board_games/gomoku/envs/gomoku_env.py).
+
+
+
+
+
+
+- Below are the benchmark results of [MuZero](https://github.com/opendilab/LightZero/blob/main/lzero/policy/muzero.py), [MuZero w/ SSL](https://github.com/opendilab/LightZero/blob/main/lzero/policy/muzero.py) , [EfficientZero](https://github.com/opendilab/LightZero/blob/main/lzero/policy/efficientzero.py) and [Sampled EfficientZero](https://github.com/opendilab/LightZero/blob/main/lzero/policy/sampled_efficientzero.py) on three discrete action space games in [Atari](https://github.com/opendilab/LightZero/blob/main/zoo/atari/envs/atari_lightzero_env.py).
+
+
+
+
+
+
+
+
+- Below are the benchmark results of [Sampled EfficientZero](https://github.com/opendilab/LightZero/blob/main/lzero/policy/sampled_efficientzero.py) with ``Factored/Gaussian`` policy representation on three classic continuous action space games: [Pendulum-v1](https://github.com/opendilab/LightZero/blob/main/zoo/classic_control/pendulum/envs/pendulum_lightzero_env.py), [LunarLanderContinuous-v2](https://github.com/opendilab/LightZero/blob/main/zoo/box2d/lunarlander/envs/lunarlander_env.py), [BipedalWalker-v3](https://github.com/opendilab/LightZero/blob/main/zoo/box2d/bipedalwalker/envs/bipedalwalker_env.py)
+and two MuJoCo continuous action space games: [Hopper-v3](https://github.com/opendilab/LightZero/blob/main/zoo/mujoco/envs/mujoco_lightzero_env.py), [Walker2d-v3](https://github.com/opendilab/LightZero/blob/main/zoo/mujoco/envs/mujoco_lightzero_env.py).
+> "Factored Policy" indicates that the agent learns a policy network that outputs a categorical distribution. After manual discretization, the dimensions of the action space for the five environments are 11, 49 (7^2), 256 (4^4), 64 (4^3), and 4096 (4^6), respectively. On the other hand, "Gaussian Policy" refers to the agent learning a policy network that directly outputs parameters (mu and sigma) for a Gaussian distribution.
+
+
+
+
+
+
+
+
+
+
+
+- Below are the benchmark results of [GumbelMuZero](https://github.com/opendilab/LightZero/blob/main/lzero/policy/gumbel_muzero.py) and [MuZero](https://github.com/opendilab/LightZero/blob/main/lzero/policy/muzero.py) (under different simulation costs) on four environments: [PongNoFrameskip-v4](https://github.com/opendilab/LightZero/blob/main/zoo/atari/envs/atari_lightzero_env.py), [MsPacmanNoFrameskip-v4](https://github.com/opendilab/LightZero/blob/main/zoo/atari/envs/atari_lightzero_env.py), [Gomoku](https://github.com/opendilab/LightZero/blob/main/zoo/board_games/gomoku/envs/gomoku_env.py), and [LunarLanderContinuous-v2](https://github.com/opendilab/LightZero/blob/main/zoo/box2d/lunarlander/envs/lunarlander_env.py).
+
+
+
+
+
+
+
+- Below are the benchmark results of [StochasticMuZero](https://github.com/opendilab/LightZero/blob/main/lzero/policy/stochastic_muzero.py) and [MuZero](https://github.com/opendilab/LightZero/blob/main/lzero/policy/muzero.py) on [2048 environment](https://github.com/opendilab/LightZero/blob/main/zoo/game_2048/envs/game_2048_env.py) with varying levels of chance (num_chances=2 and 5).
+
+
+
+
+
+- Below are the benchmark results of various MCTS exploration mechanisms of [MuZero w/ SSL](https://github.com/opendilab/LightZero/blob/main/lzero/policy/muzero.py) in the [MiniGrid environment](https://github.com/opendilab/LightZero/blob/main/zoo/minigrid/envs/minigrid_lightzero_env.py).
+
+
+
+
+
+
+
+
+## Awesome-MCTS Notes
+
+### Paper Notes
+The following are the detailed paper notes (in Chinese) of the above algorithms:
+
+
+
+- [AlphaZero](https://github.com/opendilab/LightZero/blob/main/assets/paper_notes/AlphaZero.pdf)
+- [MuZero](https://github.com/opendilab/LightZero/blob/main/assets/paper_notes/MuZero.pdf)
+- [EfficientZero](https://github.com/opendilab/LightZero/blob/main/assets/paper_notes/EfficientZero.pdf)
+- [SampledMuZero](https://github.com/opendilab/LightZero/blob/main/assets/paper_notes/SampledMuZero.pdf)
+- [GumbelMuZero](https://github.com/opendilab/LightZero/blob/main/assets/paper_notes/GumbelMuZero.pdf)
+- [StochasticMuZero](https://github.com/opendilab/LightZero/blob/main/assets/paper_notes/StochasticMuZero.pdf)
+- [NotationTable](https://github.com/opendilab/LightZero/blob/main/assets/paper_notes/SymbolTable.pdf)
+
+
+
+### Algo. Overview
+
+The following are overview diagrams of the MCTS principles behind the above algorithms:
+
+
+- [MCTS](https://github.com/opendilab/LightZero/blob/main/assets/algo_overview/mcts_overview.pdf)
+- [AlphaZero](https://github.com/opendilab/LightZero/blob/main/assets/algo_overview/alphazero_overview.pdf)
+- [MuZero](https://github.com/opendilab/LightZero/blob/main/assets/algo_overview/muzero_overview.pdf)
+- [EfficientZero](https://github.com/opendilab/LightZero/blob/main/assets/algo_overview/efficientzero_overview.pdf)
+- [SampledMuZero](https://github.com/opendilab/LightZero/blob/main/assets/algo_overview/sampled_muzero_overview.pdf)
+- [GumbelMuZero](https://github.com/opendilab/LightZero/blob/main/assets/algo_overview/gumbel_muzero_overview.pdf)
+
+
+
+## Awesome-MCTS Papers
+
+Here is a collection of research papers about **Monte Carlo Tree Search**.
+[This Section](#awesome-mcts-papers) will be continuously updated to track the frontier of MCTS.
+
+### Key Papers
+
+
+#### LightZero Implemented series
+
+- [2018 _Science_ AlphaZero: A general reinforcement learning algorithm that masters chess, shogi, and Go through self-play](https://www.science.org/doi/10.1126/science.aar6404)
+- [2019 MuZero: Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model](https://arxiv.org/abs/1911.08265)
+- [2021 EfficientZero: Mastering Atari Games with Limited Data](https://arxiv.org/abs/2111.00210)
+- [2021 Sampled MuZero: Learning and Planning in Complex Action Spaces](https://arxiv.org/abs/2104.06303)
+- [2022 Stochastic MuZero: Planning in Stochastic Environments with A Learned Model](https://openreview.net/pdf?id=X6D9bAHhBQ1)
+- [2022 Gumbel MuZero: Policy Improvement by Planning with Gumbel](https://openreview.net/pdf?id=bERaNdoegnO&)
+
+#### AlphaGo series
+- [2015 _Nature_ AlphaGo Mastering the game of Go with deep neural networks and tree search](https://www.nature.com/articles/nature16961)
+- [2017 _Nature_ AlphaGo Zero Mastering the game of Go without human knowledge](https://www.nature.com/articles/nature24270)
+- [2019 ELF OpenGo: An Analysis and Open Reimplementation of AlphaZero](https://arxiv.org/abs/1902.04522)
+ - [Code](https://github.com/pytorch/ELF)
+- [2023 Student of Games: A unified learning algorithm for both perfect and imperfect information games](https://www.science.org/doi/10.1126/sciadv.adg3256)
+
+#### MuZero series
+- [2022 Online and Offline Reinforcement Learning by Planning with a Learned Model](https://arxiv.org/abs/2104.06294)
+- [2021 Vector Quantized Models for Planning](https://arxiv.org/abs/2106.04615)
+- [2021 Muesli: Combining Improvements in Policy Optimization](https://arxiv.org/abs/2104.06159)
+#### MCTS Analysis
+- [2020 Monte-Carlo Tree Search as Regularized Policy Optimization](https://arxiv.org/abs/2007.12509)
+- [2021 Self-Consistent Models and Values](https://arxiv.org/abs/2110.12840)
+- [2022 Adversarial Policies Beat Professional-Level Go AIs](https://arxiv.org/abs/2211.00241)
+- [2022 _PNAS_ Acquisition of Chess Knowledge in AlphaZero.](https://arxiv.org/abs/2111.09259)
+
+#### MCTS Application
+- [2023 Symbolic Physics Learner: Discovering governing equations via Monte Carlo tree search](https://openreview.net/pdf?id=ZTK3SefE8_Z)
+- [2022 _Nature_ Discovering faster matrix multiplication algorithms with reinforcement learning](https://www.nature.com/articles/s41586-022-05172-4)
+ - [Code](https://github.com/deepmind/alphatensor)
+- [2022 MuZero with Self-competition for Rate Control in VP9 Video Compression](https://arxiv.org/abs/2202.06626)
+- [2021 DouZero: Mastering DouDizhu with Self-Play Deep Reinforcement Learning](https://arxiv.org/abs/2106.06135)
+- [2019 Combining Planning and Deep Reinforcement Learning in Tactical Decision Making for Autonomous Driving](https://arxiv.org/pdf/1905.02680.pdf)
+
+
+
+### Other Papers
+
+
+#### ICML
+- [Scalable Safe Policy Improvement via Monte Carlo Tree Search](https://openreview.net/pdf?id=tevbBSzSfK) 2023
+ - Alberto Castellini, Federico Bianchi, Edoardo Zorzi, Thiago D. Simão, Alessandro Farinelli, Matthijs T. J. Spaan
+ - Key: online safe policy improvement using an MCTS-based strategy, Safe Policy Improvement with Baseline Bootstrapping
+ - ExpEnv: Gridworld and SysAdmin
+- [Efficient Learning for AlphaZero via Path Consistency](https://proceedings.mlr.press/v162/zhao22h/zhao22h.pdf) 2022
+ - Dengwei Zhao, Shikui Tu, Lei Xu
+ - Key: limited amount of self-plays, path consistency (PC) optimality
+ - ExpEnv: Go, Othello, Gomoku
+- [Visualizing MuZero Models](https://arxiv.org/abs/2102.12924) 2021
+ - Joery A. de Vries, Ken S. Voskuil, Thomas M. Moerland, Aske Plaat
+ - Key: visualizing the value equivalent dynamics model, action trajectories diverge, two regularization techniques
+ - ExpEnv: CartPole and MountainCar.
+- [Convex Regularization in Monte-Carlo Tree Search](https://arxiv.org/pdf/2007.00391.pdf) 2021
+ - Tuan Dam, Carlo D'Eramo, Jan Peters, Joni Pajarinen
+ - Key: entropy-regularization backup operators, regret analysis, Tsallis entropy
+ - ExpEnv: synthetic tree, Atari
+- [Information Particle Filter Tree: An Online Algorithm for POMDPs with Belief-Based Rewards on Continuous Domains](http://proceedings.mlr.press/v119/fischer20a/fischer20a.pdf) 2020
+ - Johannes Fischer, Ömer Sahin Tas
+ - Key: Continuous POMDP, Particle Filter Tree, information-based reward shaping, Information Gathering.
+ - ExpEnv: POMDPs.jl framework
+ - [Code](https://github.com/johannes-fischer/icml2020_ipft)
+- [Retro*: Learning Retrosynthetic Planning with Neural Guided A* Search](http://proceedings.mlr.press/v119/chen20k/chen20k.pdf) 2020
+ - Binghong Chen, Chengtao Li, Hanjun Dai, Le Song
+ - Key: chemical retrosynthetic planning, neural-based A*-like algorithm, AND-OR tree
+ - ExpEnv: USPTO datasets
+ - [Code](https://github.com/binghong-ml/retro_star)
+#### ICLR
+- [Become a Proficient Player with Limited Data through Watching Pure Videos](https://openreview.net/pdf?id=Sy-o2N0hF4f) 2023
+ - Weirui Ye, Yunsheng Zhang, Pieter Abbeel, Yang Gao
+ - Key: pre-training from action-free videos, forward-inverse cycle consistency (FICC) objective based on vector quantization, pre-training phase, fine-tuning phase.
+ - ExpEnv: Atari
+- [Policy-Based Self-Competition for Planning Problems](https://arxiv.org/abs/2306.04403) 2023
+ - Jonathan Pirnay, Quirin Göttl, Jakob Burger, Dominik Gerhard Grimm
+ - Key: self-competition, find strong trajectories by planning against possible strategies of its past self.
+ - ExpEnv: Traveling Salesman Problem and the Job-Shop Scheduling Problem.
+- [Explaining Temporal Graph Models through an Explorer-Navigator Framework](https://openreview.net/pdf?id=BR_ZhvcYbGJ) 2023
+ - Wenwen Xia, Mincai Lai, Caihua Shan, Yao Zhang, Xinnan Dai, Xiang Li, Dongsheng Li
+ - Key: Temporal GNN Explainer, an explorer to find the event subsets with MCTS, a navigator that learns the correlations between events and helps reduce the search space.
+ - ExpEnv: Wikipedia and Reddit, Synthetic datasets
+- [SpeedyZero: Mastering Atari with Limited Data and Time](https://openreview.net/pdf?id=Mg5CLXZgvLJ) 2023
+ - Yixuan Mei, Jiaxuan Gao, Weirui Ye, Shaohuai Liu, Yang Gao, Yi Wu
+ - Key: distributed RL system, Priority Refresh, Clipped LARS
+ - ExpEnv: Atari
+- [Efficient Offline Policy Optimization with a Learned Model](https://openreview.net/pdf?id=Yt-yM-JbYFO) 2023
+ - Zichen Liu, Siyi Li, Wee Sun Lee, Shuicheng YAN, Zhongwen Xu
+ - Key: Regularized One-Step Model-based algorithm for Offline-RL
+ - ExpEnv: Atari, BSuite
+ - [Code](https://github.com/sail-sg/rosmo/tree/main)
+- [Enabling Arbitrary Translation Objectives with Adaptive Tree Search](https://arxiv.org/pdf/2202.11444.pdf) 2022
+ - Wang Ling, Wojciech Stokowiec, Domenic Donato, Chris Dyer, Lei Yu, Laurent Sartran, Austin Matthews
+ - Key: adaptive tree search, translation models, autoregressive models,
+ - ExpEnv: Chinese–English and Pashto–English tasks from WMT2020, German–English from WMT2014
+- [What's Wrong with Deep Learning in Tree Search for Combinatorial Optimization](https://arxiv.org/abs/2201.10494) 2022
+ - Maximilian Böther, Otto Kißig, Martin Taraz, Sarel Cohen, Karen Seidel, Tobias Friedrich
+ - Key: combinatorial optimization, open-source benchmark suite for the NP-hard maximum independent set problem, an in-depth analysis of the popular guided tree search algorithm, compare the tree search implementations to other solvers
+ - ExpEnv: NP-hard MAXIMUM INDEPENDENT SET.
+ - [Code](https://github.com/maxiboether/mis-benchmark-framework)
+- [Monte-Carlo Planning and Learning with Language Action Value Estimates](https://openreview.net/pdf?id=7_G8JySGecm) 2021
+ - Youngsoo Jang, Seokin Seo, Jongmin Lee, Kee-Eung Kim
+ - Key: Monte-Carlo tree search with language-driven exploration, locally optimistic language value estimates.
+ - ExpEnv: Interactive Fiction (IF) games
+- [Practical Massively Parallel Monte-Carlo Tree Search Applied to Molecular Design](https://arxiv.org/abs/2006.10504) 2021
+ - Xiufeng Yang, Tanuj Kr Aasawat, Kazuki Yoshizoe
+ - Key: massively parallel Monte-Carlo Tree Search, molecular design, Hash-driven parallel search,
+ - ExpEnv: octanol-water partition coefficient (logP) penalized by the synthetic accessibility (SA) and large Ring Penalty score.
+- [Watch the Unobserved: A Simple Approach to Parallelizing Monte Carlo Tree Search](https://arxiv.org/pdf/1810.11755.pdf) 2020
+ - Anji Liu, Jianshu Chen, Mingze Yu, Yu Zhai, Xuewen Zhou, Ji Liu
+ - Key: parallel Monte-Carlo Tree Search, partition the tree into sub-trees efficiently, compare the observation ratio of each processor.
+ - ExpEnv: speedup and performance comparison on JOY-CITY game, average episode return on atari game
+ - [Code](https://github.com/liuanji/WU-UCT)
+- [Learning to Plan in High Dimensions via Neural Exploration-Exploitation Trees](https://openreview.net/pdf?id=rJgJDAVKvB) 2020
+ - Binghong Chen, Bo Dai, Qinjie Lin, Guo Ye, Han Liu, Le Song
+ - Key: meta path planning algorithm, exploits a novel neural architecture which can learn promising search directions from problem structures.
+ - ExpEnv: a 2d workspace with a 2 DoF (degrees of freedom) point robot, a 3 DoF stick robot and a 5 DoF snake robot
+#### NeurIPS
+- [LightZero: A Unified Benchmark for Monte Carlo Tree Search in General Sequential Decision Scenarios](https://openreview.net/pdf?id=oIUXpBnyjv) 2023
+ - Yazhe Niu, Yuan Pu, Zhenjie Yang, Xueyan Li, Tong Zhou, Jiyuan Ren, Shuai Hu, Hongsheng Li, Yu Liu
+ - Key: the first unified benchmark for deploying MCTS/MuZero in general sequential decision scenarios.
+ - ExpEnv: ClassicControl, Box2D, Atari, MuJoCo, GoBigger, MiniGrid, TicTacToe, ConnectFour, Gomoku, 2048, etc.
+- [Large Language Models as Commonsense Knowledge for Large-Scale Task Planning](https://openreview.net/pdf?id=Wjp1AYB8lH) 2023
+ - Zirui Zhao, Wee Sun Lee, David Hsu
+ - Key: world model (LLM) and the LLM-induced policy can be combined in MCTS, to scale up task planning.
+ - ExpEnv: multiplication, travel planning, object rearrangement
+- [Monte Carlo Tree Search with Boltzmann Exploration](https://openreview.net/pdf?id=NG4DaApavi) 2023
+ - Michael Painter, Mohamed Baioumy, Nick Hawes, Bruno Lacerda
+ - Key: Boltzmann exploration with MCTS, optimal actions for the maximum entropy objective do not necessarily correspond to optimal actions for the original objective, two improved algorithms.
+ - ExpEnv: the Frozen Lake environment, the Sailing Problem, Go
+- [Generalized Weighted Path Consistency for Mastering Atari Games](https://openreview.net/pdf?id=vHRLS8HhK1) 2023
+ - Dengwei Zhao, Shikui Tu, Lei Xu
+ - Key: Generalized Weighted Path Consistency, A weighting mechanism.
+ - ExpEnv: Atari
+- [Accelerating Monte Carlo Tree Search with Probability Tree State Abstraction](https://openreview.net/pdf?id=0zeLTZAqaJ) 2023
+ - Yangqing Fu, Ming Sun, Buqing Nie, Yue Gao
+ - Key: probability tree state abstraction, transitivity and aggregation error bound
+ - ExpEnv: Atari, CartPole, LunarLander, Gomoku
+- [Planning for Sample Efficient Imitation Learning](https://openreview.net/forum?id=BkN5UoAqF7) 2022
+ - Zhao-Heng Yin, Weirui Ye, Qifeng Chen, Yang Gao
+ - Key: Behavioral Cloning, Adversarial Imitation Learning (AIL), MCTS-based RL.
+ - ExpEnv: DeepMind Control Suite
+ - [Code](https://github.com/zhaohengyin/EfficientImitate)
+- [Evaluation Beyond Task Performance: Analyzing Concepts in AlphaZero in Hex](https://openreview.net/pdf?id=dwKwB2Cd-Km) 2022
+ - Charles Lovering, Jessica Zosa Forde, George Konidaris, Ellie Pavlick, Michael L. Littman
+ - Key: AlphaZero’s internal representations, model probing and behavioral tests, how these concepts are captured in the network.
+ - ExpEnv: Hex
+- [Are AlphaZero-like Agents Robust to Adversarial Perturbations?](https://openreview.net/pdf?id=yZ_JlZaOCzv) 2022
+ - Li-Cheng Lan, Huan Zhang, Ti-Rong Wu, Meng-Yu Tsai, I-Chen Wu, Cho-Jui Hsieh
+ - Key: adversarial states, first adversarial attack on Go AIs.
+ - ExpEnv: Go
+- [Monte Carlo Tree Descent for Black-Box Optimization](https://openreview.net/pdf?id=FzdmrTUyZ4g) 2022
+ - Yaoguang Zhai, Sicun Gao
+ - Key: Black-Box Optimization, how to further integrate sample-based descent for faster optimization.
+ - ExpEnv: synthetic functions for nonlinear optimization, reinforcement learning problems in MuJoCo locomotion environments, and optimization problems in Neural Architecture Search (NAS).
+- [Monte Carlo Tree Search based Variable Selection for High Dimensional Bayesian Optimization](https://openreview.net/pdf?id=SUzPos_pUC) 2022
+ - Lei Song, Ke Xue, Xiaobin Huang, Chao Qian
+ - Key: a low-dimensional subspace via MCTS, optimizes in the subspace with any Bayesian optimization algorithm.
+ - ExpEnv: NAS-bench problems and MuJoCo locomotion
+- [Monte Carlo Tree Search With Iteratively Refining State Abstractions](https://proceedings.neurips.cc/paper/2021/file/9b0ead00a217ea2c12e06a72eec4923f-Paper.pdf) 2021
+ - Samuel Sokota, Caleb Ho, Zaheen Ahmad, J. Zico Kolter
+ - Key: stochastic environments, Progressive widening, abstraction refining
+ - ExpEnv: Blackjack, Trap, five by five Go.
+- [Deep Synoptic Monte Carlo Planning in Reconnaissance Blind Chess](https://proceedings.neurips.cc/paper/2021/file/215a71a12769b056c3c32e7299f1c5ed-Paper.pdf) 2021
+ - Gregory Clark
+ - Key: imperfect information, belief state with an unweighted particle filter, a novel stochastic abstraction of information states.
+ - ExpEnv: reconnaissance blind chess
+- [POLY-HOOT: Monte-Carlo Planning in Continuous Space MDPs with Non-Asymptotic Analysis](https://proceedings.neurips.cc/paper/2020/file/30de24287a6d8f07b37c716ad51623a7-Paper.pdf) 2020
+ - Weichao Mao, Kaiqing Zhang, Qiaomin Xie, Tamer Başar
+ - Key: continuous state-action spaces, Hierarchical Optimistic Optimization.
+ - ExpEnv: CartPole, Inverted Pendulum, Swing-up, and LunarLander.
+- [Learning Search Space Partition for Black-box Optimization using Monte Carlo Tree Search](https://proceedings.neurips.cc/paper/2020/file/e2ce14e81dba66dbff9cbc35ecfdb704-Paper.pdf) 2020
+ - Linnan Wang, Rodrigo Fonseca, Yuandong Tian
+ - Key: learns the partition of the search space using a few samples, a nonlinear decision boundary and learns a local model to pick good candidates.
+ - ExpEnv: MuJoCo locomotion tasks, Small-scale Benchmarks,
+- [Mix and Match: An Optimistic Tree-Search Approach for Learning Models from Mixture Distributions](https://arxiv.org/abs/1907.10154) 2020
+ - Matthew Faw, Rajat Sen, Karthikeyan Shanmugam, Constantine Caramanis, Sanjay Shakkottai
+ - Key: covariate shift problem, Mix&Match combines stochastic gradient descent (SGD) with optimistic tree search and model re-use (evolving partially trained models with samples from different mixture distributions)
+ - [Code](https://github.com/matthewfaw/mixnmatch)
+
+#### Other Conference or Journal
+- [On Monte Carlo Tree Search and Reinforcement Learning](https://www.jair.org/index.php/jair/article/download/11099/26289/20632) Journal of Artificial Intelligence Research 2017.
+- [Sample-Efficient Neural Architecture Search by Learning Actions for Monte Carlo Tree Search](https://arxiv.org/pdf/1906.06832) IEEE Transactions on Pattern Analysis and Machine Intelligence 2022.
+
+
+
+## Feedback and Contribution
+- [File an issue](https://github.com/opendilab/LightZero/issues/new/choose) on Github
+- Contact our email (opendilab@pjlab.org.cn)
+
+- We appreciate all feedback and contributions that help improve LightZero, in both algorithms and system design.
+
+[comment]: <> (- Contributes to our future plan [Roadmap](https://github.com/opendilab/LightZero/projects))
+
+[comment]: <> (And `CONTRIBUTING.md` offers some necessary information.)
+
+
+## Citation
+```latex
+@misc{lightzero,
+ title={LightZero: A Unified Benchmark for Monte Carlo Tree Search in General Sequential Decision Scenarios},
+ author={Yazhe Niu and Yuan Pu and Zhenjie Yang and Xueyan Li and Tong Zhou and Jiyuan Ren and Shuai Hu and Hongsheng Li and Yu Liu},
+ year={2023},
+ eprint={2310.08348},
+ archivePrefix={arXiv},
+ primaryClass={cs.LG}
+}
+```
+
+## Acknowledgments
+
+This project has been developed partially based on the following pioneering GitHub repositories.
+We express our profound gratitude for these foundational resources:
+- https://github.com/opendilab/DI-engine
+- https://github.com/deepmind/mctx
+- https://github.com/YeWR/EfficientZero
+- https://github.com/werner-duvaud/muzero-general
+
+We would like to extend our special thanks to the following contributors [@PaParaZz1](https://github.com/PaParaZz1), [@karroyan](https://github.com/karroyan), [@nighood](https://github.com/nighood),
+[@jayyoung0802](https://github.com/jayyoung0802), [@timothijoe](https://github.com/timothijoe), [@TuTuHuss](https://github.com/TuTuHuss), [@HarryXuancy](https://github.com/HarryXuancy), [@puyuan1996](https://github.com/puyuan1996), [@HansBug](https://github.com/HansBug) for their valuable contributions and support to this algorithm library.
+
+Thanks to all who contributed to this project:
+
+
+
+
+
+## License
+All code within this repository is under [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
+
+
+
+
+
+## MCTS-related Notes
+
+### Paper Notes
+
+The following are detailed Chinese-language notes on the algorithms integrated in LightZero:
+
+
+[AlphaZero](https://github.com/opendilab/LightZero/blob/main/assets/paper_notes/AlphaZero.pdf)
+
+[MuZero](https://github.com/opendilab/LightZero/blob/main/assets/paper_notes/MuZero.pdf)
+
+[EfficientZero](https://github.com/opendilab/LightZero/blob/main/assets/paper_notes/EfficientZero.pdf)
+
+[SampledMuZero](https://github.com/opendilab/LightZero/blob/main/assets/paper_notes/SampledMuZero.pdf)
+
+[GumbelMuZero](https://github.com/opendilab/LightZero/blob/main/assets/paper_notes/GumbelMuZero.pdf)
+
+[StochasticMuZero](https://github.com/opendilab/LightZero/blob/main/assets/paper_notes/StochasticMuZero.pdf)
+
+[Notation table for the algorithm overview diagrams](https://github.com/opendilab/LightZero/blob/main/assets/paper_notes/NotationTable.pdf)
+
+
+
+### Algorithm Overview Diagrams
+
+The following are framework overview diagrams of the algorithms integrated in LightZero:
+
+
+
+[MCTS](https://github.com/opendilab/LightZero/blob/main/assets/algo_overview/mcts_overview.pdf)
+
+[AlphaZero](https://github.com/opendilab/LightZero/blob/main/assets/algo_overview/alphazero_overview.pdf)
+
+[MuZero](https://github.com/opendilab/LightZero/blob/main/assets/algo_overview/muzero_overview.pdf)
+
+[EfficientZero](https://github.com/opendilab/LightZero/blob/main/assets/algo_overview/efficientzero_overview.pdf)
+
+[SampledMuZero](https://github.com/opendilab/LightZero/blob/main/assets/algo_overview/sampled_muzero_overview.pdf)
+
+[GumbelMuZero](https://github.com/opendilab/LightZero/blob/main/assets/algo_overview/gumbel_muzero_overview.pdf)
+
+
+
+## MCTS-related Papers
+
+The following is a collection of papers related to **MCTS**. [This section](#mcts-related-papers) will be continuously updated to track the latest advances in MCTS.
+
+### Key Papers
+
+
+
+#### LightZero Implemented series
+
+- [2018 _Science_ AlphaZero: A general reinforcement learning algorithm that masters chess, shogi, and Go through self-play](https://www.science.org/doi/10.1126/science.aar6404)
+- [2019 MuZero: Mastering Atari, Go, Chess and Shogi by Planning with a Learned Model](https://arxiv.org/abs/1911.08265)
+- [2021 EfficientZero: Mastering Atari Games with Limited Data](https://arxiv.org/abs/2111.00210)
+- [2021 Sampled MuZero: Learning and Planning in Complex Action Spaces](https://arxiv.org/abs/2104.06303)
+- [2022 Stochastic MuZero: Planning in Stochastic Environments with a Learned Model](https://openreview.net/pdf?id=X6D9bAHhBQ1)
+- [2022 Gumbel MuZero: Policy Improvement by Planning with Gumbel](https://openreview.net/pdf?id=bERaNdoegnO&)
+
+
+#### AlphaGo series
+
+- [2015 _Nature_ AlphaGo Mastering the game of Go with deep neural networks and tree search](https://www.nature.com/articles/nature16961)
+- [2017 _Nature_ AlphaGo Zero Mastering the game of Go without human knowledge](https://www.nature.com/articles/nature24270)
+- [2019 ELF OpenGo: An Analysis and Open Reimplementation of AlphaZero](https://arxiv.org/abs/1902.04522)
+ - [Code](https://github.com/pytorch/ELF)
+- [2023 Student of Games: A unified learning algorithm for both perfect and imperfect information games](https://www.science.org/doi/10.1126/sciadv.adg3256)
+
+#### MuZero series
+- [2022 Online and Offline Reinforcement Learning by Planning with a Learned Model](https://arxiv.org/abs/2104.06294)
+- [2021 Vector Quantized Models for Planning](https://arxiv.org/abs/2106.04615)
+- [2021 Muesli: Combining Improvements in Policy Optimization. ](https://arxiv.org/abs/2104.06159)
+
+#### MCTS Analysis
+- [2020 Monte-Carlo Tree Search as Regularized Policy Optimization](https://arxiv.org/abs/2007.12509)
+- [2021 Self-Consistent Models and Values](https://arxiv.org/abs/2110.12840)
+- [2022 Adversarial Policies Beat Professional-Level Go AIs](https://arxiv.org/abs/2211.00241)
+- [2022 _PNAS_ Acquisition of Chess Knowledge in AlphaZero.](https://arxiv.org/abs/2111.09259)
+
+#### MCTS Application
+- [2023 Symbolic Physics Learner: Discovering governing equations via Monte Carlo tree search](https://openreview.net/pdf?id=ZTK3SefE8_Z)
+- [2022 _Nature_ Discovering faster matrix multiplication algorithms with reinforcement learning](https://www.nature.com/articles/s41586-022-05172-4)
+ - [Code](https://github.com/deepmind/alphatensor)
+- [2022 MuZero with Self-competition for Rate Control in VP9 Video Compression](https://arxiv.org/abs/2202.06626)
+- [2021 DouZero: Mastering DouDizhu with Self-Play Deep Reinforcement Learning](https://arxiv.org/abs/2106.06135)
+- [2019 Combining Planning and Deep Reinforcement Learning in Tactical Decision Making for Autonomous Driving](https://arxiv.org/pdf/1905.02680.pdf)
+
+
+
+### Other Papers
+
+
+
+#### ICML
+- [Scalable Safe Policy Improvement via Monte Carlo Tree Search](https://openreview.net/pdf?id=tevbBSzSfK) 2023
+ - Alberto Castellini, Federico Bianchi, Edoardo Zorzi, Thiago D. Simão, Alessandro Farinelli, Matthijs T. J. Spaan
+ - Key: safe policy improvement online using a MCTS based strategy, Safe Policy Improvement with Baseline Bootstrapping
+ - ExpEnv: Gridworld and SysAdmin
+- [Efficient Learning for AlphaZero via Path Consistency](https://proceedings.mlr.press/v162/zhao22h/zhao22h.pdf) 2022
+ - Dengwei Zhao, Shikui Tu, Lei Xu
+ - Key: limited amount of self-plays, path consistency (PC) optimality
+ - ExpEnv: Go, Othello, Gomoku
+- [Visualizing MuZero Models](https://arxiv.org/abs/2102.12924) 2021
+ - Joery A. de Vries, Ken S. Voskuil, Thomas M. Moerland, Aske Plaat
+ - Key: visualizing the value equivalent dynamics model and internal state transition dynamics, action trajectories diverge, two regularization techniques
+ - ExpEnv: CartPole and MountainCar.
+- [Convex Regularization in Monte-Carlo Tree Search](https://arxiv.org/pdf/2007.00391.pdf) 2021
+ - Tuan Dam, Carlo D'Eramo, Jan Peters, Joni Pajarinen
+ - Key: entropy-regularization backup operators, regret analysis, Tsallis entropy
+ - ExpEnv: synthetic tree, Atari
+- [Information Particle Filter Tree: An Online Algorithm for POMDPs with Belief-Based Rewards on Continuous Domains](http://proceedings.mlr.press/v119/fischer20a/fischer20a.pdf) 2020
+ - Johannes Fischer, Ömer Sahin Tas
+ - Key: Continuous POMDP, Particle Filter Tree, information-based reward shaping, Information Gathering.
+ - ExpEnv: POMDPs.jl framework
+ - [Code](https://github.com/johannes-fischer/icml2020_ipft)
+- [Retro*: Learning Retrosynthetic Planning with Neural Guided A* Search](http://proceedings.mlr.press/v119/chen20k/chen20k.pdf) 2020
+ - Binghong Chen, Chengtao Li, Hanjun Dai, Le Song
+ - Key: chemical retrosynthetic planning, neural-based A*-like algorithm, AND-OR tree
+ - ExpEnv: USPTO datasets
+ - [Code](https://github.com/binghong-ml/retro_star)
+#### ICLR
+- [Become a Proficient Player with Limited Data through Watching Pure Videos](https://openreview.net/pdf?id=Sy-o2N0hF4f) 2023
+ - Weirui Ye, Yunsheng Zhang, Pieter Abbeel, Yang Gao
+ - Key: pre-training from action-free videos, forward-inverse cycle consistency (FICC) objective based on vector quantization, pre-training phase, fine-tuning phase.
+ - ExpEnv: Atari
+- [Policy-Based Self-Competition for Planning Problems](https://arxiv.org/abs/2306.04403) 2023
+ - Jonathan Pirnay, Quirin Göttl, Jakob Burger, Dominik Gerhard Grimm
+ - Key: self-competition, find strong trajectories by planning against possible strategies of its past self.
+ - ExpEnv: Traveling Salesman Problem and the Job-Shop Scheduling Problem.
+- [Explaining Temporal Graph Models through an Explorer-Navigator Framework](https://openreview.net/pdf?id=BR_ZhvcYbGJ) 2023
+ - Wenwen Xia, Mincai Lai, Caihua Shan, Yao Zhang, Xinnan Dai, Xiang Li, Dongsheng Li
+ - Key: Temporal GNN Explainer, an explorer to find the event subsets with MCTS, a navigator that learns the correlations between events and helps reduce the search space.
+ - ExpEnv: Wikipedia and Reddit, Synthetic datasets
+- [SpeedyZero: Mastering Atari with Limited Data and Time](https://openreview.net/pdf?id=Mg5CLXZgvLJ) 2023
+ - Yixuan Mei, Jiaxuan Gao, Weirui Ye, Shaohuai Liu, Yang Gao, Yi Wu
+ - Key: distributed RL system, Priority Refresh, Clipped LARS
+ - ExpEnv: Atari
+- [Efficient Offline Policy Optimization with a Learned Model](https://openreview.net/pdf?id=Yt-yM-JbYFO) 2023
+ - Zichen Liu, Siyi Li, Wee Sun Lee, Shuicheng YAN, Zhongwen Xu
+ - Key: Regularized One-Step Model-based algorithm for Offline-RL
+ - ExpEnv: Atari,BSuite
+ - [Code](https://github.com/sail-sg/rosmo/tree/main)
+- [Enabling Arbitrary Translation Objectives with Adaptive Tree Search](https://arxiv.org/pdf/2202.11444.pdf) 2022
+ - Wang Ling, Wojciech Stokowiec, Domenic Donato, Chris Dyer, Lei Yu, Laurent Sartran, Austin Matthews
+ - Key: adaptive tree search, translation models, autoregressive models,
+ - ExpEnv: Chinese–English and Pashto–English tasks from WMT2020, German–English from WMT2014
+- [What's Wrong with Deep Learning in Tree Search for Combinatorial Optimization](https://arxiv.org/abs/2201.10494) 2022
+ - Maximilian Böther, Otto Kißig, Martin Taraz, Sarel Cohen, Karen Seidel, Tobias Friedrich
+ - Key: Combinatorial optimization, open-source benchmark suite for the NP-hard MAXIMUM INDEPENDENT SET problem, an in-depth analysis of the popular guided tree search algorithm, compare the tree search implementations to other solvers
+ - ExpEnv: NP-hard MAXIMUM INDEPENDENT SET.
+ - [Code](https://github.com/maxiboether/mis-benchmark-framework)
+- [Monte-Carlo Planning and Learning with Language Action Value Estimates](https://openreview.net/pdf?id=7_G8JySGecm) 2021
+ - Youngsoo Jang, Seokin Seo, Jongmin Lee, Kee-Eung Kim
+ - Key: Monte-Carlo tree search with language-driven exploration, locally optimistic language value estimates,
+ - ExpEnv: Interactive Fiction (IF) games
+- [Practical Massively Parallel Monte-Carlo Tree Search Applied to Molecular Design](https://arxiv.org/abs/2006.10504) 2021
+ - Xiufeng Yang, Tanuj Kr Aasawat, Kazuki Yoshizoe
+ - Key: massively parallel Monte-Carlo Tree Search, molecular design, Hash-driven parallel search,
+ - ExpEnv: octanol-water partition coefficient (logP) penalized by the synthetic accessibility (SA) and large Ring Penalty score.
+- [Watch the Unobserved: A Simple Approach to Parallelizing Monte Carlo Tree Search](https://arxiv.org/pdf/1810.11755.pdf) 2020
+ - Anji Liu, Jianshu Chen, Mingze Yu, Yu Zhai, Xuewen Zhou, Ji Liu
+ - Key: parallel Monte-Carlo Tree Search, partition the tree into sub-trees efficiently, compare the observation ratio of each processor
+ - ExpEnv: speedup and performance comparison on JOY-CITY game, average episode return on atari game
+ - [Code](https://github.com/liuanji/WU-UCT)
+- [Learning to Plan in High Dimensions via Neural Exploration-Exploitation Trees](https://openreview.net/pdf?id=rJgJDAVKvB) 2020
+ - Binghong Chen, Bo Dai, Qinjie Lin, Guo Ye, Han Liu, Le Song
+ - Key: meta path planning algorithm, exploits a novel neural architecture which can learn promising search directions from problem structures.
+ - ExpEnv: a 2d workspace with a 2 DoF (degrees of freedom) point robot, a 3 DoF stick robot and a 5 DoF snake robot
+#### NeurIPS
+
+- [LightZero: A Unified Benchmark for Monte Carlo Tree Search in General Sequential Decision Scenarios](https://openreview.net/pdf?id=oIUXpBnyjv) 2023
+ - Yazhe Niu, Yuan Pu, Zhenjie Yang, Xueyan Li, Tong Zhou, Jiyuan Ren, Shuai Hu, Hongsheng Li, Yu Liu
+ - Key: the first unified benchmark for deploying MCTS/MuZero in general sequential decision scenarios.
+ - ExpEnv: ClassicControl, Box2D, Atari, MuJoCo, GoBigger, MiniGrid, TicTacToe, ConnectFour, Gomoku, 2048, etc.
+- [Large Language Models as Commonsense Knowledge for Large-Scale Task Planning](https://openreview.net/pdf?id=Wjp1AYB8lH) 2023
+ - Zirui Zhao, Wee Sun Lee, David Hsu
+ - Key: world model (LLM) and the LLM-induced policy can be combined in MCTS, to scale up task planning.
+ - ExpEnv: multiplication, travel planning, object rearrangement
+- [Monte Carlo Tree Search with Boltzmann Exploration](https://openreview.net/pdf?id=NG4DaApavi) 2023
+ - Michael Painter, Mohamed Baioumy, Nick Hawes, Bruno Lacerda
+ - Key: Boltzmann exploration with MCTS, optimal actions for the maximum entropy objective do not necessarily correspond to optimal actions for the original objective, two improved algorithms.
+ - ExpEnv: the Frozen Lake environment, the Sailing Problem, Go
+- [Generalized Weighted Path Consistency for Mastering Atari Games](https://openreview.net/pdf?id=vHRLS8HhK1) 2023
+ - Dengwei Zhao, Shikui Tu, Lei Xu
+ - Key: Generalized Weighted Path Consistency, A weighting mechanism.
+ - ExpEnv: Atari
+- [Accelerating Monte Carlo Tree Search with Probability Tree State Abstraction](https://openreview.net/pdf?id=0zeLTZAqaJ) 2023
+ - Yangqing Fu, Ming Sun, Buqing Nie, Yue Gao
+ - Key: probability tree state abstraction, transitivity and aggregation error bound
+ - ExpEnv: Atari, CartPole, LunarLander, Gomoku
+- [Planning for Sample Efficient Imitation Learning](https://openreview.net/forum?id=BkN5UoAqF7) 2022
+ - Zhao-Heng Yin, Weirui Ye, Qifeng Chen, Yang Gao
+ - Key: Behavioral Cloning, Adversarial Imitation Learning (AIL), MCTS-based RL
+ - ExpEnv: DeepMind Control Suite
+ - [Code](https://github.com/zhaohengyin/EfficientImitate)
+- [Evaluation Beyond Task Performance: Analyzing Concepts in AlphaZero in Hex](https://openreview.net/pdf?id=dwKwB2Cd-Km) 2022
+ - Charles Lovering, Jessica Zosa Forde, George Konidaris, Ellie Pavlick, Michael L. Littman
+ - Key: AlphaZero’s internal representations, model probing and behavioral tests, how these concepts are captured in the network.
+ - ExpEnv: Hex
+- [Are AlphaZero-like Agents Robust to Adversarial Perturbations?](https://openreview.net/pdf?id=yZ_JlZaOCzv) 2022
+ - Li-Cheng Lan, Huan Zhang, Ti-Rong Wu, Meng-Yu Tsai, I-Chen Wu, Cho-Jui Hsieh
+ - Key: adversarial states, first adversarial attack on Go AIs
+ - ExpEnv: Go
+- [Monte Carlo Tree Descent for Black-Box Optimization](https://openreview.net/pdf?id=FzdmrTUyZ4g) 2022
+ - Yaoguang Zhai, Sicun Gao
+ - Key: Black-Box Optimization, how to further integrate sample-based descent for faster optimization.
+ - ExpEnv: synthetic functions for nonlinear optimization, reinforcement learning problems in MuJoCo locomotion environments, and optimization problems in Neural Architecture Search (NAS).
+- [Monte Carlo Tree Search based Variable Selection for High Dimensional Bayesian Optimization](https://openreview.net/pdf?id=SUzPos_pUC) 2022
+ - Lei Song, Ke Xue, Xiaobin Huang, Chao Qian
+ - Key: a low-dimensional subspace via MCTS, optimizes in the subspace with any Bayesian optimization algorithm.
+ - ExpEnv: NAS-bench problems and MuJoCo locomotion
+- [Monte Carlo Tree Search With Iteratively Refining State Abstractions](https://proceedings.neurips.cc/paper/2021/file/9b0ead00a217ea2c12e06a72eec4923f-Paper.pdf) 2021
+ - Samuel Sokota, Caleb Ho, Zaheen Ahmad, J. Zico Kolter
+ - Key: stochastic environments, Progressive widening, abstraction refining,
+ - ExpEnv: Blackjack, Trap, five by five Go.
+- [Deep Synoptic Monte Carlo Planning in Reconnaissance Blind Chess](https://proceedings.neurips.cc/paper/2021/file/215a71a12769b056c3c32e7299f1c5ed-Paper.pdf) 2021
+ - Gregory Clark
+ - Key: imperfect information, belief state with an unweighted particle filter, a novel stochastic abstraction of information states.
+ - ExpEnv: reconnaissance blind chess
+- [POLY-HOOT: Monte-Carlo Planning in Continuous Space MDPs with Non-Asymptotic Analysis](https://proceedings.neurips.cc/paper/2020/file/30de24287a6d8f07b37c716ad51623a7-Paper.pdf) 2020
+ - Weichao Mao, Kaiqing Zhang, Qiaomin Xie, Tamer Başar
+ - Key: continuous state-action spaces, Hierarchical Optimistic Optimization,
+ - ExpEnv: CartPole, Inverted Pendulum, Swing-up, and LunarLander.
+- [Learning Search Space Partition for Black-box Optimization using Monte Carlo Tree Search](https://proceedings.neurips.cc/paper/2020/file/e2ce14e81dba66dbff9cbc35ecfdb704-Paper.pdf) 2020
+ - Linnan Wang, Rodrigo Fonseca, Yuandong Tian
+ - Key: learns the partition of the search space using a few samples, a nonlinear decision boundary and learns a local model to pick good candidates.
+ - ExpEnv: MuJoCo locomotion tasks, Small-scale Benchmarks,
+- [Mix and Match: An Optimistic Tree-Search Approach for Learning Models from Mixture Distributions](https://arxiv.org/abs/1907.10154) 2020
+ - Matthew Faw, Rajat Sen, Karthikeyan Shanmugam, Constantine Caramanis, Sanjay Shakkottai
+ - Key: covariate shift problem, Mix&Match combines stochastic gradient descent (SGD) with optimistic tree search and model re-use (evolving partially trained models with samples from different mixture distributions)
+ - [Code](https://github.com/matthewfaw/mixnmatch)
+
+#### Other Conference or Journal
+- [On Monte Carlo Tree Search and Reinforcement Learning](https://www.jair.org/index.php/jair/article/download/11099/26289/20632) Journal of Artificial Intelligence Research 2017.
+- [Sample-Efficient Neural Architecture Search by Learning Actions for Monte Carlo Tree Search](https://arxiv.org/pdf/1906.06832) IEEE Transactions on Pattern Analysis and Machine Intelligence 2022.
+
+
+## Feedback and Contribution
+- For any questions or suggestions, feel free to [file an issue](https://github.com/opendilab/LightZero/issues/new/choose) directly on GitHub
+- Or contact us by email (opendilab@pjlab.org.cn)
+
+- We appreciate all feedback and suggestions on both the algorithms and the system design; they help make LightZero better.
+
+
+## Citation
+
+```latex
+@misc{lightzero,
+ title={LightZero: A Unified Benchmark for Monte Carlo Tree Search in General Sequential Decision Scenarios},
+ author={Yazhe Niu and Yuan Pu and Zhenjie Yang and Xueyan Li and Tong Zhou and Jiyuan Ren and Shuai Hu and Hongsheng Li and Yu Liu},
+ year={2023},
+ eprint={2310.08348},
+ archivePrefix={arXiv},
+ primaryClass={cs.LG}
+}
+```
+
+## Acknowledgments
+The implementation of this algorithm library is partially based on the following GitHub repositories. Many thanks for these pioneering works:
+- https://github.com/opendilab/DI-engine
+- https://github.com/deepmind/mctx
+- https://github.com/YeWR/EfficientZero
+- https://github.com/werner-duvaud/muzero-general
+
+Special thanks to the following contributors [@PaParaZz1](https://github.com/PaParaZz1), [@karroyan](https://github.com/karroyan), [@nighood](https://github.com/nighood),
+[@jayyoung0802](https://github.com/jayyoung0802), [@timothijoe](https://github.com/timothijoe), [@TuTuHuss](https://github.com/TuTuHuss), [@HarryXuancy](https://github.com/HarryXuancy), [@puyuan1996](https://github.com/puyuan1996), [@HansBug](https://github.com/HansBug) for their contributions and support to this project.
+
+Thanks to everyone who has contributed to this project:
+
+
+
+
+## License
+
+All code in this repository is licensed under [Apache License 2.0](https://www.apache.org/licenses/LICENSE-2.0).
+
+
+
diff --git a/LightZero/cloc.sh b/LightZero/cloc.sh
new file mode 100755
index 0000000000000000000000000000000000000000..2dc336fc8aa81350fbe9a03c543927734ff00c2b
--- /dev/null
+++ b/LightZero/cloc.sh
@@ -0,0 +1,65 @@
+#!/bin/bash
+
+# This script counts the lines of code and comments in all source files
+# and prints the results to the command line. It uses the command-line tool
+# "cloc". You can either pass --loc, --comments or --percentage to show the
+# respective values only.
+# Some parts below need to be adapted to your project!
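+#
+# Example usage (output values are illustrative and depend on the code base):
+#   ./cloc.sh               # print lines of code, comment lines and comment percentage
+#   ./cloc.sh --loc         # print only the lines of code, e.g. "12.3k"
+#   ./cloc.sh --comments    # print only the comment lines, e.g. "2.2k"
+#   ./cloc.sh --percentage  # print only the comment percentage, e.g. "20.8"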
+
+# Get the location of this script.
+SCRIPT_DIR="$(cd "$(dirname "$0")" && pwd)"
+
+# Run cloc - this counts code lines, blank lines and comment lines
+# for the specified languages. You will need to change this accordingly.
+# For C++, you could use "C++,C/C++ Header" for example.
+# We are only interested in the summary, therefore the tail -1
+SUMMARY="$(cloc "${SCRIPT_DIR}" --include-lang="Python" --md | tail -1)"
+
+# The $SUMMARY is one line of a markdown table and looks like this:
+# SUM:|101|3123|2238|10783
+# We use the following command to split it into an array.
+IFS='|' read -r -a TOKENS <<<"$SUMMARY"
+
+# Store the individual tokens for better readability.
+NUMBER_OF_FILES=${TOKENS[1]}
+COMMENT_LINES=${TOKENS[3]}
+LINES_OF_CODE=${TOKENS[4]}
+
+# To make the estimate of commented lines more accurate, we have to
+# subtract any copyright header which is included in each file.
+# For Fly-Pie, this header has the length of five lines.
+# All dumb comments like those /////////// or those // ------------
+# are also subtracted. As cloc does not count inline comments,
+# the overall estimate should be rather conservative.
+# Change the lines below according to your project.
+# DUMB_COMMENTS="$(grep -r -E '//////|// -----' "${SCRIPT_DIR}" | wc -l)"
+# COMMENT_LINES=$(($COMMENT_LINES - 5 * $NUMBER_OF_FILES - $DUMB_COMMENTS))
+
+# Print all results if no arguments are given.
+if [[ $# -eq 0 ]]; then
+ awk -v a=$LINES_OF_CODE \
+ 'BEGIN {printf "Lines of source code: %6.1fk\n", a/1000}'
+ awk -v a=$COMMENT_LINES \
+ 'BEGIN {printf "Lines of comments: %6.1fk\n", a/1000}'
+ awk -v a=$COMMENT_LINES -v b=$LINES_OF_CODE \
+ 'BEGIN {printf "Comment Percentage: %6.1f%\n", 100*a/b}'
+ exit 0
+fi
+
+# Show lines of code if --loc is given.
+if [[ $* == *--loc* ]]; then
+ awk -v a=$LINES_OF_CODE \
+ 'BEGIN {printf "%.1fk\n", a/1000}'
+fi
+
+# Show lines of comments if --comments is given.
+if [[ $* == *--comments* ]]; then
+ awk -v a=$COMMENT_LINES \
+ 'BEGIN {printf "%.1fk\n", a/1000}'
+fi
+
+# Show percentage of comments if --percentage is given.
+if [[ $* == *--percentage* ]]; then
+ awk -v a=$COMMENT_LINES -v b=$LINES_OF_CODE \
+ 'BEGIN {printf "%.1f\n", 100*a/b}'
+fi
diff --git a/LightZero/docs/Makefile b/LightZero/docs/Makefile
new file mode 100644
index 0000000000000000000000000000000000000000..bc1f9f6fa9a6d5e2b384719f0623c609f2c13363
--- /dev/null
+++ b/LightZero/docs/Makefile
@@ -0,0 +1,62 @@
+# You can set these variables from the command line, and also
+# from the environment for the first two.
+SPHINXOPTS ?=
+SPHINXBUILD ?= $(shell which sphinx-build)
+SPHINXMULTIVERSION ?= $(shell which sphinx-multiversion)
+SOURCEDIR ?= source
+BUILDDIR ?= build
+
+# Minimal makefile for Sphinx documentation
+DIAGRAMS_MK := ${SOURCEDIR}/diagrams.mk
+DIAGRAMS := $(MAKE) -f "${DIAGRAMS_MK}" SOURCE=${SOURCEDIR}
+GRAPHVIZ_MK := ${SOURCEDIR}/graphviz.mk
+GRAPHVIZ := $(MAKE) -f "${GRAPHVIZ_MK}" SOURCE=${SOURCEDIR}
+DEMOS_MK := ${SOURCEDIR}/demos.mk
+DEMOS := $(MAKE) -f "${DEMOS_MK}" SOURCE=${SOURCEDIR}
+NOTEBOOK_MK := ${SOURCEDIR}/notebook.mk
+NOTEBOOK := $(MAKE) -f "${NOTEBOOK_MK}" SOURCE=${SOURCEDIR}
+
+_CURRENT_PATH := ${PATH}
+_PROJ_DIR := $(shell readlink -f ${CURDIR}/..)
+_LIBS_DIR := $(shell readlink -f ${SOURCEDIR}/_libs)
+_SHIMS_DIR := $(shell readlink -f ${SOURCEDIR}/_shims)
+
+.EXPORT_ALL_VARIABLES:
+
+PYTHONPATH = ${_PROJ_DIR}:${_LIBS_DIR}
+PATH = ${_SHIMS_DIR}:${_CURRENT_PATH}
+NO_CONTENTS_BUILD = true
+
+# Catch-all target: route all unknown targets to Sphinx using the new
+# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
+# Put it first so that "make" without argument is like "make help".
+.PHONY: help contents build html prod clean sourcedir builddir Makefile
+
+help:
+ @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+contents:
+ @$(DIAGRAMS) build
+ @$(GRAPHVIZ) build
+ @$(DEMOS) build
+ @$(NOTEBOOK) build
+build: html
+html: contents
+ @$(SPHINXBUILD) -M html "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+ @touch "$(BUILDDIR)/html/.nojekyll"
+prod:
+ @NO_CONTENTS_BUILD='' $(SPHINXMULTIVERSION) "$(SOURCEDIR)" "$(BUILDDIR)/html" $(SPHINXOPTS) $(O)
+ @cp main_page.html "$(BUILDDIR)/html/index.html"
+ @touch "$(BUILDDIR)/html/.nojekyll"
+
+clean:
+ @$(DIAGRAMS) clean
+ @$(GRAPHVIZ) clean
+ @$(DEMOS) clean
+ @$(NOTEBOOK) clean
+ @$(SPHINXBUILD) -M clean "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
+
+sourcedir:
+ @echo $(shell readlink -f ${SOURCEDIR})
+builddir:
+ @echo $(shell readlink -f ${BUILDDIR}/html)
\ No newline at end of file
diff --git a/LightZero/docs/main_page.html b/LightZero/docs/main_page.html
new file mode 100644
index 0000000000000000000000000000000000000000..05be0a427c9e9bb56fd725b6c75be82add4d3704
--- /dev/null
+++ b/LightZero/docs/main_page.html
@@ -0,0 +1,9 @@
+
+
+
+ Redirecting to master branch
+
+
+
+
+
\ No newline at end of file
diff --git a/LightZero/docs/source/_libs/.keep b/LightZero/docs/source/_libs/.keep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/LightZero/docs/source/_shims/.keep b/LightZero/docs/source/_shims/.keep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/LightZero/docs/source/_static/.keep b/LightZero/docs/source/_static/.keep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/LightZero/docs/source/_templates/.keep b/LightZero/docs/source/_templates/.keep
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/LightZero/docs/source/_templates/page.html b/LightZero/docs/source/_templates/page.html
new file mode 100644
index 0000000000000000000000000000000000000000..006fb47dd50a22d2e273b5058139a1d4bf2a9ca7
--- /dev/null
+++ b/LightZero/docs/source/_templates/page.html
@@ -0,0 +1,19 @@
+{% extends "!page.html" %}
+{% block body %}
+ {% if current_version and latest_version and current_version != latest_version %}
+
+
+ {% if current_version.is_released %}
+ You're reading an old version of this documentation.
+ If you want up-to-date information, please have a look at
+ {{ latest_version.name }}.
+ {% else %}
+ You're reading the documentation for a development version.
+ For the latest released version, please have a look at
+ {{ latest_version.name }}.
+ {% endif %}
+
+
+ {% endif %}
+ {{ super() }}
+{% endblock %}%
\ No newline at end of file
diff --git a/LightZero/docs/source/_templates/versions.html b/LightZero/docs/source/_templates/versions.html
new file mode 100644
index 0000000000000000000000000000000000000000..37480dd3275ce9f683ddb3b6b8aa245b938c7eb8
--- /dev/null
+++ b/LightZero/docs/source/_templates/versions.html
@@ -0,0 +1,27 @@
+{%- if current_version %}
+
+{%- endif %}
\ No newline at end of file
diff --git a/LightZero/docs/source/api_doc/config/index.rst b/LightZero/docs/source/api_doc/config/index.rst
new file mode 100644
index 0000000000000000000000000000000000000000..e8cbeb06253107f4ad7e1fda441565729cde0f18
--- /dev/null
+++ b/LightZero/docs/source/api_doc/config/index.rst
@@ -0,0 +1,7 @@
+lzero.config
+=====================
+
+.. toctree::
+ :maxdepth: 3
+
+ meta
diff --git a/LightZero/docs/source/api_doc/config/meta.rst b/LightZero/docs/source/api_doc/config/meta.rst
new file mode 100644
index 0000000000000000000000000000000000000000..3eb0bb87da23291720a7ecf6ef645f721e004392
--- /dev/null
+++ b/LightZero/docs/source/api_doc/config/meta.rst
@@ -0,0 +1,38 @@
+lzero.config.meta
+==========================
+
+.. automodule:: lzero.config.meta
+
+\_\_TITLE\_\_
+------------------
+
+.. autodata:: lzero.config.meta.__TITLE__
+ :annotation:
+
+
+\_\_VERSION\_\_
+------------------
+
+.. autodata:: lzero.config.meta.__VERSION__
+ :annotation:
+
+
+\_\_DESCRIPTION\_\_
+----------------------
+
+.. autodata:: lzero.config.meta.__DESCRIPTION__
+ :annotation:
+
+
+\_\_AUTHOR\_\_
+------------------
+
+.. autodata:: lzero.config.meta.__AUTHOR__
+ :annotation:
+
+
+\_\_AUTHOR_EMAIL\_\_
+----------------------
+
+.. autodata:: lzero.config.meta.__AUTHOR_EMAIL__
+ :annotation:
diff --git a/LightZero/docs/source/api_doc/entry/eval_alphazero.rst b/LightZero/docs/source/api_doc/entry/eval_alphazero.rst
new file mode 100644
index 0000000000000000000000000000000000000000..ea1b89d84a4a67a996f600beb963912a042685bd
--- /dev/null
+++ b/LightZero/docs/source/api_doc/entry/eval_alphazero.rst
@@ -0,0 +1,15 @@
+lzero.entry.eval_alphazero
+==============================
+
+
+.. automodule:: lzero.entry.eval_alphazero
+.. py:currentmodule:: lzero.entry.eval_alphazero
+
+
+eval_alphazero
+----------------------
+
+.. autofunction:: eval_alphazero
+
+
+
diff --git a/LightZero/docs/source/api_doc/entry/index.rst b/LightZero/docs/source/api_doc/entry/index.rst
new file mode 100644
index 0000000000000000000000000000000000000000..c7c02aa8a5362018f67c727de46857d4d71d5000
--- /dev/null
+++ b/LightZero/docs/source/api_doc/entry/index.rst
@@ -0,0 +1,7 @@
+lzero.entry
+==============================
+
+.. toctree::
+ :maxdepth: 3
+
+ eval_alphazero
diff --git a/LightZero/docs/source/conf.py b/LightZero/docs/source/conf.py
new file mode 100644
index 0000000000000000000000000000000000000000..e920feffefc48f1510d741dc6c00a08ad7f607ea
--- /dev/null
+++ b/LightZero/docs/source/conf.py
@@ -0,0 +1,170 @@
+# Configuration file for the Sphinx documentation builder.
+#
+# This file only contains a selection of the most common options. For a full
+# list see the documentation:
+# https://www.sphinx-doc.org/en/master/usage/configuration.html
+
+# -- Path setup --------------------------------------------------------------
+
+# If extensions (or modules to document with autodoc) are in another directory,
+# add these directories to sys.path here. If the directory is relative to the
+# documentation root, use os.path.abspath to make it absolute, like shown here.
+#
+# import os
+# import sys
+# sys.path.insert(0, os.path.abspath('.'))
+
+
+# -- Project information -----------------------------------------------------
+
+import os
+import sys
+from datetime import datetime
+from subprocess import Popen
+
+import where
+from packaging import version as version_
+
+# Get current location
+_DOC_PATH = os.path.dirname(os.path.abspath(__file__))
+_PROJ_PATH = os.path.abspath(os.path.join(_DOC_PATH, '..', '..'))
+_LIBS_PATH = os.path.join(_DOC_PATH, '_libs')
+_SHIMS_PATH = os.path.join(_DOC_PATH, '_shims')
+os.chdir(_PROJ_PATH)
+
+# Set environment, remove the pre-installed package
+sys.path.insert(0, _PROJ_PATH)
+modnames = [mname for mname in sys.modules if mname.startswith('lzero')]
+for modname in modnames:
+ del sys.modules[modname]
+
+# Build dependencies if needed
+if not os.environ.get("NO_CONTENTS_BUILD"):
+ _env = dict(os.environ)
+ _env.update(dict(
+ PYTHONPATH=':'.join([_PROJ_PATH, _LIBS_PATH]),
+ PATH=':'.join([_SHIMS_PATH, os.environ.get('PATH', '')]),
+ ))
+
+ if os.path.exists(os.path.join(_PROJ_PATH, 'requirements-build.txt')):
+ pip_build_cmd = (where.first('pip'), 'install', '-r', os.path.join(_PROJ_PATH, 'requirements-build.txt'))
+ print("Install pip requirements {cmd}...".format(cmd=repr(pip_build_cmd)))
+ pip_build = Popen(pip_build_cmd, stdout=sys.stdout, stderr=sys.stderr, env=_env, cwd=_PROJ_PATH)
+ if pip_build.wait() != 0:
+ raise ChildProcessError("Pip install failed with %d." % (pip_build.returncode,))
+
+ make_build_cmd = (where.first('make'), 'clean', 'build')
+ print("Try building extensions {cmd}...".format(cmd=repr(make_build_cmd)))
+ make_build = Popen(make_build_cmd, stdout=sys.stdout, stderr=sys.stderr, env=_env, cwd=_PROJ_PATH)
+ if make_build.wait() != 0:
+ raise ChildProcessError("Extension build failed with %d." % (make_build.returncode,))
+
+ pip_cmd = (where.first('pip'), 'install', '-r', os.path.join(_PROJ_PATH, 'requirements.txt'))
+ print("Install pip requirements {cmd}...".format(cmd=repr(pip_cmd)))
+ pip = Popen(pip_cmd, stdout=sys.stdout, stderr=sys.stderr, env=_env, cwd=_PROJ_PATH)
+ if pip.wait() != 0:
+ raise ChildProcessError("Pip install failed with %d." % (pip.returncode,))
+
+ pip_docs_cmd = (where.first('pip'), 'install', '-r', os.path.join(_PROJ_PATH, 'requirements-doc.txt'))
+ print("Install pip docs requirements {cmd}...".format(cmd=repr(pip_docs_cmd)))
+ pip_docs = Popen(pip_docs_cmd, stdout=sys.stdout, stderr=sys.stderr, env=_env, cwd=_PROJ_PATH)
+ if pip_docs.wait() != 0:
+        raise ChildProcessError("Pip docs install failed with %d." % (pip_docs.returncode,))
+
+ diagrams_cmd = (where.first('make'), '-f', "diagrams.mk", "build")
+ print("Building diagrams {cmd} at {cp}...".format(cmd=repr(diagrams_cmd), cp=repr(_DOC_PATH)))
+ diagrams = Popen(diagrams_cmd, stdout=sys.stdout, stderr=sys.stderr, env=_env, cwd=_DOC_PATH)
+ if diagrams.wait() != 0:
+ raise ChildProcessError("Diagrams failed with %d." % (diagrams.returncode,))
+
+ graphviz_cmd = (where.first('make'), '-f', "graphviz.mk", "build")
+ print("Building graphs {cmd} at {cp}...".format(cmd=repr(graphviz_cmd), cp=repr(_DOC_PATH)))
+ graphviz = Popen(graphviz_cmd, stdout=sys.stdout, stderr=sys.stderr, env=_env, cwd=_DOC_PATH)
+ if graphviz.wait() != 0:
+ raise ChildProcessError("Graphviz failed with %d." % (graphviz.returncode,))
+
+ demos_cmd = (where.first('make'), '-f', "demos.mk", "build")
+ print("Building demos {cmd} at {cp}...".format(cmd=repr(demos_cmd), cp=repr(_DOC_PATH)))
+ demos = Popen(demos_cmd, stdout=sys.stdout, stderr=sys.stderr, env=_env, cwd=_DOC_PATH)
+ if demos.wait() != 0:
+ raise ChildProcessError("Demos failed with %d." % (demos.returncode,))
+
+ notebook_cmd = (where.first('make'), '-f', "notebook.mk", "build")
+ print("Executing notebooks {cmd} at {cp}...".format(cmd=repr(notebook_cmd), cp=repr(_DOC_PATH)))
+ demos = Popen(notebook_cmd, stdout=sys.stdout, stderr=sys.stderr, env=_env, cwd=_DOC_PATH)
+ if demos.wait() != 0:
+ raise ChildProcessError("Notebook failed with %d." % (demos.returncode,))
+
+ print("Build of contents complete.")
+
+from lzero.config.meta import __TITLE__, __AUTHOR__, __VERSION__
+
+project = __TITLE__
+copyright = '{year}, {author}'.format(year=datetime.now().year, author=__AUTHOR__)
+author = __AUTHOR__
+
+# The short X.Y version
+version = version_.parse(__VERSION__).base_version
+# The full version, including alpha/beta/rc tags
+release = __VERSION__
+
+# -- General configuration ---------------------------------------------------
+
+# Add any Sphinx extension module names here, as strings. They can be
+# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
+# ones.
+extensions = [
+ 'sphinx.ext.autodoc',
+ 'sphinx.ext.doctest',
+ 'sphinx.ext.mathjax',
+ 'sphinx.ext.ifconfig',
+ 'sphinx.ext.viewcode',
+ 'sphinx.ext.githubpages',
+ 'sphinx.ext.todo',
+ 'sphinx.ext.graphviz',
+ 'enum_tools.autoenum',
+ "sphinx_multiversion",
+ 'nbsphinx',
+]
+
+# Add any paths that contain templates here, relative to this directory.
+templates_path = ['_templates']
+
+# The language for content autogenerated by Sphinx. Refer to documentation
+# for a list of supported languages.
+#
+# This is also used if you do content translation via gettext catalogs.
+# Usually you set "language" from the command line for these cases.
+language = None
+
+# List of patterns, relative to source directory, that match files and
+# directories to ignore when looking for source files.
+# This pattern also affects html_static_path and html_extra_path.
+exclude_patterns = []
+
+# -- Options for HTML output -------------------------------------------------
+
+# The theme to use for HTML and HTML Help pages. See the documentation for
+# a list of builtin themes.
+#
+html_theme = 'sphinx_rtd_theme'
+htmlhelp_basename = 'LightZero'
+
+# Add any paths that contain custom static files (such as style sheets) here,
+# relative to this directory. They are copied after the builtin static files,
+# so a file named "default.css" will overwrite the builtin "default.css".
+html_static_path = ['_static']
+
+epub_title = project
+epub_exclude_files = ['search.html']
+
+# Whitelist pattern for tags (set to None to ignore all tags)
+smv_tag_whitelist = r'^v.*$' # Include all tags start with 'v'
+smv_branch_whitelist = r'^.*$' # Include all branches
+smv_remote_whitelist = r'^.*$' # Use branches from all remotes
+smv_released_pattern = r'^tags/.*$' # Tags only
+smv_outputdir_format = '{ref.name}' # Use the branch/tag name
+
+if not os.environ.get("ENV_PROD"):
+ todo_include_todos = True
+ todo_emit_warnings = True
diff --git a/LightZero/docs/source/demos.mk b/LightZero/docs/source/demos.mk
new file mode 100644
index 0000000000000000000000000000000000000000..e4bd6e03b3914fe9e2896cb4fc16f37f7014fb8c
--- /dev/null
+++ b/LightZero/docs/source/demos.mk
@@ -0,0 +1,48 @@
+PYTHON := $(shell which python)
+
+SOURCE ?= .
+PYTHON_DEMOS := $(shell find ${SOURCE} -name *.demo.py)
+PYTHON_DEMOXS := $(shell find ${SOURCE} -name *.demox.py)
+PYTHON_RESULTS := $(addsuffix .py.txt, $(basename ${PYTHON_DEMOS} ${PYTHON_DEMOXS}))
+
+SHELL_DEMOS := $(shell find ${SOURCE} -name *.demo.sh)
+SHELL_DEMOXS := $(shell find ${SOURCE} -name *.demox.sh)
+SHELL_RESULTS := $(addsuffix .sh.txt, $(basename ${SHELL_DEMOS} ${SHELL_DEMOXS}))
+
+%.demo.py.txt: %.demo.py
+ cd "$(shell dirname $(shell readlink -f $<))" && \
+ PYTHONPATH="$(shell dirname $(shell readlink -f $<)):${PYTHONPATH}" \
+ $(PYTHON) "$(shell readlink -f $<)" > "$(shell readlink -f $@)"
+
+%.demox.py.txt: %.demox.py
+ cd "$(shell dirname $(shell readlink -f $<))" && \
+ PYTHONPATH="$(shell dirname $(shell readlink -f $<)):${PYTHONPATH}" \
+ $(PYTHON) "$(shell readlink -f $<)" 1> "$(shell readlink -f $@)" \
+ 2> "$(shell readlink -f $(addsuffix .err, $(basename $@)))"; \
+ echo $$? > "$(shell readlink -f $(addsuffix .exitcode, $(basename $@)))"
+
+%.demo.sh.txt: %.demo.sh
+ cd "$(shell dirname $(shell readlink -f $<))" && \
+ PYTHONPATH="$(shell dirname $(shell readlink -f $<)):${PYTHONPATH}" \
+ $(SHELL) "$(shell readlink -f $<)" > "$(shell readlink -f $@)"
+
+%.demox.sh.txt: %.demox.sh
+ cd "$(shell dirname $(shell readlink -f $<))" && \
+ PYTHONPATH="$(shell dirname $(shell readlink -f $<)):${PYTHONPATH}" \
+ $(SHELL) "$(shell readlink -f $<)" 1> "$(shell readlink -f $@)" \
+ 2> "$(shell readlink -f $(addsuffix .err, $(basename $@)))"; \
+ echo $$? > "$(shell readlink -f $(addsuffix .exitcode, $(basename $@)))"
+
+build: ${PYTHON_RESULTS} ${SHELL_RESULTS}
+
+all: build
+
+clean:
+ rm -rf \
+ $(shell find ${SOURCE} -name *.py.txt) \
+ $(shell find ${SOURCE} -name *.py.err) \
+ $(shell find ${SOURCE} -name *.py.exitcode) \
+ $(shell find ${SOURCE} -name *.sh.txt) \
+ $(shell find ${SOURCE} -name *.sh.err) \
+ $(shell find ${SOURCE} -name *.sh.exitcode) \
+ $(shell find ${SOURCE} -name *.dat.*)
diff --git a/LightZero/docs/source/diagrams.mk b/LightZero/docs/source/diagrams.mk
new file mode 100644
index 0000000000000000000000000000000000000000..2d7b838d0079ca3a82f36e82199dfc5be4fad2da
--- /dev/null
+++ b/LightZero/docs/source/diagrams.mk
@@ -0,0 +1,21 @@
+PLANTUMLCLI ?= $(shell which plantumlcli)
+
+SOURCE ?= .
+PUMLS := $(shell find ${SOURCE} -name *.puml)
+PNGS := $(addsuffix .puml.png, $(basename ${PUMLS}))
+SVGS := $(addsuffix .puml.svg, $(basename ${PUMLS}))
+
+%.puml.png: %.puml
+ $(PLANTUMLCLI) -t png -o "$(shell readlink -f $@)" "$(shell readlink -f $<)"
+
+%.puml.svg: %.puml
+ $(PLANTUMLCLI) -t svg -o "$(shell readlink -f $@)" "$(shell readlink -f $<)"
+
+build: ${SVGS} ${PNGS}
+
+all: build
+
+clean:
+ rm -rf \
+ $(shell find ${SOURCE} -name *.puml.svg) \
+ $(shell find ${SOURCE} -name *.puml.png) \
diff --git a/LightZero/docs/source/graphviz.mk b/LightZero/docs/source/graphviz.mk
new file mode 100644
index 0000000000000000000000000000000000000000..e8235d3dc815ac861d2fb1812fe69e64bbf4a5a0
--- /dev/null
+++ b/LightZero/docs/source/graphviz.mk
@@ -0,0 +1,21 @@
+DOT := $(shell which dot)
+
+SOURCE ?= .
+GVS := $(shell find ${SOURCE} -name *.gv)
+PNGS := $(addsuffix .gv.png, $(basename ${GVS}))
+SVGS := $(addsuffix .gv.svg, $(basename ${GVS}))
+
+%.gv.png: %.gv
+ $(DOT) -Tpng -o"$(shell readlink -f $@)" "$(shell readlink -f $<)"
+
+%.gv.svg: %.gv
+ $(DOT) -Tsvg -o"$(shell readlink -f $@)" "$(shell readlink -f $<)"
+
+build: ${SVGS} ${PNGS}
+
+all: build
+
+clean:
+ rm -rf \
+ $(shell find ${SOURCE} -name *.gv.svg) \
+ $(shell find ${SOURCE} -name *.gv.png) \
diff --git a/LightZero/docs/source/index.rst b/LightZero/docs/source/index.rst
new file mode 100644
index 0000000000000000000000000000000000000000..405903bcf7e309cb07f20910a322dc4752afe35d
--- /dev/null
+++ b/LightZero/docs/source/index.rst
@@ -0,0 +1,25 @@
+Welcome to LightZero's Documentation
+=====================================
+
+Overview
+-------------
+
+``LightZero`` is a unified benchmark for deploying Monte Carlo Tree Search (MCTS) based
+reinforcement learning algorithms (such as AlphaZero, MuZero, and their variants)
+in general sequential decision scenarios.
+
+.. toctree::
+ :maxdepth: 2
+ :caption: Tutorials
+
+ tutorials/installation/index
+
+
+.. toctree::
+ :maxdepth: 2
+ :caption: API Documentation
+
+ api_doc/config/index
+ api_doc/entry/index
+
diff --git a/LightZero/docs/source/notebook.mk b/LightZero/docs/source/notebook.mk
new file mode 100644
index 0000000000000000000000000000000000000000..d4cab1222f6ff30b8aae07a2359c00bb8dc84589
--- /dev/null
+++ b/LightZero/docs/source/notebook.mk
@@ -0,0 +1,25 @@
+JUPYTER ?= $(shell which jupyter)
+NBCONVERT ?= ${JUPYTER} nbconvert
+
+SOURCE ?= .
+IPYNBS := $(shell find ${SOURCE} -name *.ipynb -not -name *.result.ipynb)
+RESULTS := $(addsuffix .result.ipynb, $(basename ${IPYNBS}))
+
+%.result.ipynb: %.ipynb
+ cp "$(shell readlink -f $<)" "$(shell readlink -f $@)" && \
+ cd "$(shell dirname $(shell readlink -f $<))" && \
+ PYTHONPATH="$(shell dirname $(shell readlink -f $<)):${PYTHONPATH}" \
+ $(NBCONVERT) --to notebook --inplace --execute "$(shell readlink -f $@)"
+
+build: ${RESULTS}
+
+all: build
+
+clean:
+ rm -rf \
+ $(shell find ${SOURCE} -name *.result.ipynb)
+ for nb in ${IPYNBS}; do \
+ if [ -f $$nb ]; then \
+ $(NBCONVERT) --clear-output --inplace $$nb; \
+ fi; \
+ done;
\ No newline at end of file
diff --git a/LightZero/docs/source/tutorials/algos/customize_algos.md b/LightZero/docs/source/tutorials/algos/customize_algos.md
new file mode 100644
index 0000000000000000000000000000000000000000..044336d48c56000fa3890fb46eb963517f0cba9f
--- /dev/null
+++ b/LightZero/docs/source/tutorials/algos/customize_algos.md
@@ -0,0 +1,163 @@
+# **How to Customize Your Algorithms in LightZero?**
+
+LightZero is an MCTS+RL reinforcement learning framework that provides a set of high-level APIs, enabling users to customize their algorithms within it. Here are some steps and considerations on how to customize an algorithm in LightZero.
+
+## **Basic Steps**
+
+### 1. Understand the Framework Structure
+
+Before you start coding your custom algorithms, you need to have a basic understanding of the LightZero framework's structure. The LightZero pipeline is illustrated in the following diagram.
+
+
+
+The repository consists primarily of two parts: `lzero` and `zoo`. The `lzero` folder contains the core modules required for the LightZero framework's workflow, while the `zoo` folder provides a set of predefined environments (`envs`) and their corresponding configuration (`config`) files. The core modules in `lzero` include `policy`, `model`, `worker`, and `entry`; these modules work together to implement complex reinforcement learning algorithms.
+
+- In this architecture, the `policy` module is responsible for implementing the algorithm's decision-making logic, such as action selection during agent-environment interaction and how to update the policy based on collected data. The `model` module is responsible for implementing the neural network structures required by the algorithm.
+
+- The `worker` module consists of two classes: Collector and Evaluator. An instance of the Collector class handles the agent-environment interaction to collect the necessary data for training, while an instance of the Evaluator class evaluates the performance of the current policy.
+
+- The `entry` module is responsible for initializing the environment, model, policy, etc., and its main loop implements core processes such as data collection, model training, and policy evaluation.
+
+- There are close interactions among these modules. Specifically, the `entry` module calls the Collector and Evaluator of the `worker` module to perform data collection and algorithm evaluation. The decision functions of the `policy` module are called by the Collector and Evaluator to determine the agent's actions in a specific environment. The neural network models implemented in the `model` module are embedded in the `policy` object for action generation during interaction and for updates during the training process.
+
+- In the `policy` module, you can find implementations of various algorithms. For example, the MuZero policy is implemented in the `muzero.py` file.
+
+
+### 2. Create a New Policy File
+Create a new Python file under the `lzero/policy` directory. This file will contain your algorithm implementation. For example, if your algorithm is called MyAlgorithm, you can create a file named `my_algorithm.py`.
+
+### 3. Implement Your Policy
+
+Within your policy file, you need to define a class that implements your policy. This class should inherit from the `Policy` class in DI-engine and implement the required methods. Below is a basic skeleton for a policy class:
+
+
+```python
+@POLICY_REGISTRY.register('my_algorithm')
+class MyAlgorithmPolicy(Policy):
+ """
+ Overview:
+ The policy class for MyAlgorithm.
+ """
+
+ config = dict(
+ # Add your config here
+ )
+
+ def __init__(self, cfg, **kwargs):
+ super().__init__(cfg, **kwargs)
+ # Initialize your policy here
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ # Set the default model name and the import path so that the default model can be loaded during policy initialization
+
+ def _init_learn(self):
+ # Initialize the learn mode here
+
+ def _forward_learn(self, data):
+ # Implement the forward function for learning mode here
+
+ def _init_collect(self):
+ # Initialize the collect mode here
+
+ def _forward_collect(self, data, **kwargs):
+ # Implement the forward function for collect mode here
+
+ def _init_eval(self):
+ # Initialize the eval mode here
+
+ def _forward_eval(self, data, **kwargs):
+ # Implement the forward function for eval mode here
+```
+
+#### Data Collection and Model Evaluation
+
+- In `default_model`, set the class name of the default model used by the current policy and its corresponding import path, so that the default model can be loaded during policy initialization.
+- The `_init_collect` and `_init_eval` functions are responsible for instantiating the action-selection policy; the resulting policy instances are then called by the `_forward_collect` and `_forward_eval` functions.
+- The `_forward_collect` function takes the current state of the environment and selects an action for the current step by calling the policy instantiated in `_init_collect`. The function returns the list of selected actions and other relevant information. During training, this function is called through the `collector.collect` method of the Collector object created by the Entry file.
+- The logic of the `_forward_eval` function is similar to that of the `_forward_collect` function. The main difference is that the policy used in `_forward_collect` focuses more on exploration, to collect diverse training data, while the policy used in `_forward_eval` focuses more on exploitation, to obtain the best performance of the current policy (a schematic sketch of this distinction is given below). During training, this function is called through the `evaluator.eval` method of the Evaluator object created by the Entry file.
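+
+To make the collect/eval distinction concrete, here is a small, self-contained sketch. It is schematic only and is not LightZero's actual implementation (the real `_forward_collect`/`_forward_eval` run MCTS and return richer metadata); the function names and the temperature parameter are purely illustrative:
+
+```python
+import torch
+
+def forward_collect_sketch(policy_logits: torch.Tensor, temperature: float = 1.0) -> torch.Tensor:
+    # Collect mode: sample from a temperature-softened action distribution
+    # so that diverse data is gathered for training.
+    probs = torch.softmax(policy_logits / temperature, dim=-1)
+    return torch.multinomial(probs, num_samples=1).squeeze(-1)
+
+def forward_eval_sketch(policy_logits: torch.Tensor) -> torch.Tensor:
+    # Eval mode: act greedily to measure the best performance of the current policy.
+    return policy_logits.argmax(dim=-1)
+
+if __name__ == "__main__":
+    logits = torch.randn(4, 6)  # a batch of 4 environments with 6 discrete actions
+    print(forward_collect_sketch(logits, temperature=1.0))  # stochastic actions
+    print(forward_eval_sketch(logits))                      # greedy actions
+```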
+
+#### Policy Learning
+
+- The `_init_learn` function initializes the network model, optimizer, and other objects required during training, using the policy parameters passed in from the config file, such as the learning rate, update frequency, and optimizer type.
+- The `_forward_learn` function is responsible for updating the network. Typically, it receives the data collected by the Collector, calculates the loss on this data, and performs a gradient update. The function returns the losses computed during the update and the relevant parameters used for the update, so that they can be logged for the experiment (a minimal sketch is given below). During training, this function is called through the `learner.train` method of the Learner object created by the Entry file.
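+
+As a minimal, self-contained illustration of this division of labor (schematic only; a real LightZero policy computes model-based losses such as policy, value, and reward losses over unrolled steps), `_init_learn` would create objects like the model and optimizer from the config, and `_forward_learn` would consume a batch, update the network, and return logging information:
+
+```python
+import torch
+import torch.nn as nn
+
+# Schematic stand-ins for what _init_learn would create from the config.
+model = nn.Linear(8, 4)  # a toy "policy network"
+optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
+
+def forward_learn_sketch(batch_obs: torch.Tensor, batch_target: torch.Tensor) -> dict:
+    # Compute the loss on a batch of collected data ...
+    loss = nn.functional.mse_loss(model(batch_obs), batch_target)
+    # ... perform one gradient update ...
+    optimizer.zero_grad()
+    loss.backward()
+    optimizer.step()
+    # ... and return the values to be logged for the experiment.
+    return {'total_loss': loss.item(), 'cur_lr': optimizer.param_groups[0]['lr']}
+
+if __name__ == "__main__":
+    obs, target = torch.randn(16, 8), torch.randn(16, 4)
+    print(forward_learn_sketch(obs, target))
+```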
+
+### 4. Register Your Policy
+To make LightZero recognize your policy, you need to use the `@POLICY_REGISTRY.register('my_algorithm')` decorator above your policy class to register your policy. This way, LightZero can refer to your policy by the name 'my_algorithm'. Specifically, in the experiment's configuration file, the corresponding algorithm is specified through the `create_config` section:
+
+```Python
+create_config = dict(
+ ...
+ policy=dict(
+ type='my_algorithm',
+ import_names=['lzero.policy.my_algorithm'],
+ ),
+ ...
+)
+```
+
+Here, `type` should be set to the registered policy name, and `import_names` should be set to the location of the policy package.
+
+### 5. Possible Other Modifications
+- **Model**: The LightZero `model.common` package provides some common network structures, such as the `RepresentationNetwork` that maps 2D images to a latent space representation and the `PredictionNetwork` used in MCTS for predicting probabilities and node values. If a custom policy requires a specific network model, you need to implement the corresponding model under the `model` folder. For example, the model for the MuZero algorithm is saved in the `muzero_model.py` file, which implements the `DynamicsNetwork` required by the MuZero algorithm and ultimately creates the `MuZeroModel` by calling the existing network structures in the `model.common` package.
+- **Worker**: LightZero provides corresponding `worker` classes for AlphaZero and MuZero. Subsequent algorithms like EfficientZero and GumbelMuzero inherit the `worker` from MuZero. If your algorithm uses different logic for data collection, you need to implement the corresponding `worker`. For example, if your algorithm requires preprocessing of collected transitions, you can add a segment like the following to the collector's `collect` method, in which the `get_train_sample` function implements the specific data processing (a hypothetical sketch of such a function is shown after the snippet below).
+
+```Python
+if timestep.done:
+ # Prepare trajectory data.
+ transitions = to_tensor_transitions(self._traj_buffer[env_id])
+ # Use ``get_train_sample`` to process the data.
+ train_sample = self._policy.get_train_sample(transitions)
+ return_data.extend(train_sample)
+ self._traj_buffer[env_id].clear()
+```
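+
+For illustration, a hypothetical `get_train_sample` might slice a finished trajectory into fixed-length training segments. The following is only a sketch of the kind of processing such a function could perform, not LightZero's actual implementation:
+
+```python
+from typing import Any, Dict, List
+
+def get_train_sample_sketch(transitions: List[Dict[str, Any]], unroll_len: int = 5) -> List[List[Dict[str, Any]]]:
+    # Split one trajectory (a list of transition dicts) into fixed-length segments
+    # that the learner can unroll; the final, shorter segment is kept as-is.
+    return [transitions[i:i + unroll_len] for i in range(0, len(transitions), unroll_len)]
+
+if __name__ == "__main__":
+    traj = [{'obs': i, 'action': i % 3, 'reward': 1.0, 'done': i == 11} for i in range(12)]
+    segments = get_train_sample_sketch(traj, unroll_len=5)
+    print([len(seg) for seg in segments])  # [5, 5, 2]
+```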
+
+### 6. Test Your Policy
+After implementing your policy, it is crucial to ensure its correctness and effectiveness. To do so, you should write some unit tests to verify that your policy functions correctly. For example, you can test whether the policy can execute in a specific environment and whether its output matches the expected results. You can refer to the [documentation](https://di-engine-docs.readthedocs.io/zh_CN/latest/22_test/index_zh.html) of DI-engine for guidance on how to write unit tests. You can add your tests under the `lzero/policy/tests` directory. When writing tests, try to consider all possible scenarios and boundary conditions to ensure your policy runs properly in various situations.
+
+Here is an example of a unit test in LightZero. In this example, we test the `inverse_scalar_transform` function and the `InverseScalarTransform` class. Both invert the transform applied to a given value, but they have different implementations. The unit test applies both to the same batch of data and compares the outputs. If the results are identical, the test passes.
+
+```Python
+import pytest
+import torch
+from lzero.policy.scaling_transform import inverse_scalar_transform, InverseScalarTransform
+
+@pytest.mark.unittest
+def test_scaling_transform():
+    import time
+    logit = torch.randn(16, 601)
+    start = time.time()
+    output_1 = inverse_scalar_transform(logit, 300)
+    print('t1', time.time() - start)
+    handle = InverseScalarTransform(300)
+    start = time.time()
+    output_2 = handle(logit)
+    print('t2', time.time() - start)
+    assert output_1.shape == output_2.shape == (16, 1)
+    assert (output_1 == output_2).all()
+```
+
+In the unit test file, mark the tests with `@pytest.mark.unittest` so that they are picked up by the testing framework. You can then run the unit test file directly with `pytest -sv xxx.py`. The `-sv` options print detailed information to the terminal during test execution for easier inspection.
+
+### 7. Comprehensive Testing and Running
+
+- After ensuring the basic functionality of the policy, you need to use classic environments such as CartPole to conduct comprehensive correctness and convergence tests on your policy. This verifies that your policy works effectively not only in unit tests but also in real environments.
+- You can write the related configuration file and entry program by referring to `cartpole_muzero_config.py`, as sketched below. During testing, make sure to record the policy's performance data, such as the return of each episode and the convergence speed, for later analysis and improvement.
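+
+A minimal entry script might look as follows, assuming your custom policy can reuse the existing `train_muzero` entry and that your config module defines `main_config` and `create_config` in the same style as `cartpole_muzero_config.py`; the environment step budget is illustrative.
+
+```Python
+# For your own algorithm, replace this import with your own config module.
+from zoo.classic_control.cartpole.config.cartpole_muzero_config import main_config, create_config
+from lzero.entry import train_muzero
+
+if __name__ == "__main__":
+    train_muzero([main_config, create_config], seed=0, max_env_step=int(1e5))
+```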
+
+### 8. Contribution
+
+- After completing all the above steps, if you wish to contribute your policy to the LightZero repository, you can submit a Pull Request to the official repository. Before submission, ensure your code complies with the repository's coding standards, all tests pass, and there is sufficient documentation and commenting to explain your code and policy.
+
+- In the description of the PR, explain your policy in detail, including its working principle, your implementation method, and its performance in tests. This will help others understand your contribution and speed up the PR review process.
+
+### 9. Share, Discuss, and Improve
+
+- After implementing and testing the policy, consider sharing your results and experiences with the community. You can post your policy and test results on forums, blogs, or social media and invite others to review and discuss your work. This not only allows you to receive feedback from others but also helps you build a professional network and may trigger new ideas and collaborations.
+- Based on your test results and community feedback, continuously improve and optimize your policy. This may involve adjusting policy parameters, improving code performance, or solving problems and bugs that arise. Remember, policy development is an iterative process, and there's always room for improvement.
+
+## **Considerations**
+
+- Ensure that your code complies with the Python PEP 8 coding standard.
+- When implementing methods such as `_forward_learn`, `_forward_collect`, and `_forward_eval`, make sure you handle the input and returned data correctly.
+- When writing your policy, take different environment types into account; your policy should be able to handle various environments.
+- When implementing your policy, keep your code as modular as possible so that others can understand and reuse it.
+- Write clear documentation and comments describing how your policy works and how your code implements it.
\ No newline at end of file
diff --git a/LightZero/docs/source/tutorials/algos/customize_algos_zh.md b/LightZero/docs/source/tutorials/algos/customize_algos_zh.md
new file mode 100644
index 0000000000000000000000000000000000000000..4d115aefa4b59dedf1bad17a847b954e72576dee
--- /dev/null
+++ b/LightZero/docs/source/tutorials/algos/customize_algos_zh.md
@@ -0,0 +1,166 @@
+# LightZero 中如何自定义算法?
+
+LightZero 是一个 MCTS+RL 强化学习框架,它提供了一组高级 API,使得用户可以在其中自定义自己的算法。以下是一些关于如何在 LightZero 中自定义算法的步骤和注意事项。
+
+## 基本步骤
+
+### 1. 理解框架结构
+
+在开始编写自定义算法之前,你需要对 LightZero 的框架结构有一个基本的理解,LightZero 的流程如图所示。
+
+
+
+
+
+仓库的文件夹主要由 `lzero` 和 `zoo` 这两部分组成。`lzero` 中实现了LightZero框架流程所需的核心模块。而 `zoo` 提供了一系列预定义的环境(`envs`)以及对应的配置(`config`)文件。
+`lzero` 文件夹下包括多个核心模块,包括策略(`policy`)、模型(`model`)、工作件(`worker`)以及入口(`entry`)等。这些模块在一起协同工作,实现复杂的强化学习算法。
+- 在此架构中,`policy` 模块负责实现算法的决策逻辑,如在智能体与环境交互时的动作选择,以及如何根据收集到的数据更新策略。 `model` 模块则负责实现算法所需的神经网络结构。
+- `worker` 模块包含 Collector 和 Evaluator 两个类。 Collector 实例负责执行智能体与环境的交互,以收集训练所需的数据,而 Evaluator 实例则负责评估当前策略的性能。
+- `entry` 模块负责初始化环境、模型、策略等,并在其主循环中负责实现数据收集、模型训练以及策略评估等核心过程。
+- 在这些模块之间,存在着紧密的交互关系。具体来说, `entry` 模块会调用 `worker` 模块的Collector和Evaluator来完成数据收集和算法评估。同时, `policy` 模块的决策函数会被Collector和Evaluator调用,以决定智能体在特定环境中的行动。而 `model` 模块实现的神经网络模型,则被嵌入到 `policy` 对象中,用于在交互过程中生成动作,以及在训练过程中进行更新。
+- 在 `policy` 模块中,你可以找到多种算法的实现,例如,MuZero策略就在 `muzero.py` 文件中实现。
+
+### 2. 创建新的策略文件
+
+在 `lzero/policy` 目录下创建一个新的 Python 文件。这个文件将包含你的算法实现。例如,如果你的算法名为 `MyAlgorithm` ,你可以创建一个名为 `my_algorithm.py` 的文件。
+
+### 3. 实现你的策略
+
+在你的策略文件中,你需要定义一个类来实现你的策略。这个类应该继承自 DI-engine中的 `Policy` 类,并实现所需的方法。
+
+以下是一个基本的策略类的框架:
+
+```Python
+@POLICY_REGISTRY.register('my_algorithm')
+class MyAlgorithmPolicy(Policy):
+ """
+ Overview:
+ The policy class for MyAlgorithm.
+ """
+
+ config = dict(
+ # Add your config here
+ )
+
+ def __init__(self, cfg, **kwargs):
+ super().__init__(cfg, **kwargs)
+ # Initialize your policy here
+
+ def default_model(self) -> Tuple[str, List[str]]:
+ # Set the default model name and the import path so that the default model can be loaded during policy initialization
+
+ def _init_learn(self):
+ # Initialize the learn mode here
+
+ def _forward_learn(self, data):
+ # Implement the forward function for learning mode here
+
+ def _init_collect(self):
+ # Initialize the collect mode here
+
+ def _forward_collect(self, data, **kwargs):
+ # Implement the forward function for collect mode here
+
+ def _init_eval(self):
+ # Initialize the eval mode here
+
+ def _forward_eval(self, data, **kwargs):
+ # Implement the forward function for eval mode here
+```
+
+#### 收集数据与评估模型
+
+- 在 `default_model` 中设置当前策略使用的默认模型的类名和相应的引用路径。
+- `_init_collect` 和 `_init_eval` 函数均负责实例化动作选取策略,相应的策略实例会被 `_forward_collect` 和 `_forward_eval` 函数调用。
+- `_forward_collect` 函数会接收当前环境的状态,并通过调用 `_init_collect` 中实例化的策略来选择一步动作。函数会返回所选的动作列表以及其他相关信息。在训练期间,该函数会通过由Entry文件创建的Collector对象的 `collector.collect` 方法进行调用。
+- `_forward_eval` 函数的逻辑与 `_forward_collect` 函数基本一致。唯一的区别在于, `_forward_collect` 中采用的策略更侧重于探索,以收集尽可能多样的训练信息;而在 `_forward_eval` 函数中,所采用的策略更侧重于利用,以获取当前策略的最优性能。在训练期间,该函数会通过由Entry文件创建的Evaluator对象的 `evaluator.eval` 方法进行调用。
+
+#### 策略的学习
+
+- `_init_learn` 函数会利用 config 文件传入的学习率、更新频率、优化器类型等策略的关联参数初始化网络模型、优化器以及训练过程中所需的其他对象。
+- `_forward_learn` 函数则负责实现网络的更新。通常, `_forward_learn` 函数会接收 Collector 所收集的数据,根据这些数据计算损失函数并进行梯度更新。函数会返回更新过程中的各项损失以及更新所采用的相关参数,以便进行实验记录。在训练期间,该函数会通过由 Entry 文件创建的 Learner 对象的 `learner.train` 方法进行调用。
+
+### 4. 注册你的策略
+
+为了让 LightZero 能够识别你的策略,你需要在你的策略类上方使用 `@POLICY_REGISTRY.register('my_algorithm')` 这个装饰器来注册你的策略。这样, LightZero 就可以通过 `'my_algorithm'` 这个名字来引用你的策略了。
+具体而言,在实验的配置文件中,通过 `create_config` 部分来指定相应的算法:
+
+```Python
+create_config = dict(
+    ...
+    policy=dict(
+        type='my_algorithm',
+        import_names=['lzero.policy.my_algorithm'],
+    ),
+    ...
+)
+```
+
+其中 `type` 要设定为所注册的策略名, `import_names` 则设置为策略包的位置。
+
+### 5. **可能的其他更改**
+- **模型(model)**:在 LightZero 的 `model.common` 包中提供了一些通用的网络结构,例如将2D图像映射到隐空间中的表征网络 `RepresentationNetwork` ,在MCTS中用于预测概率和节点价值的预测网络 `PredictionNetwork` 等。如果自定义的策略需要专门的网络模型,则需要自行在 `model` 文件夹下实现相应的模型。例如 Muzero 算法的模型保存在 `muzero_model.py` 文件中,该文件实现了 Muzero 算法所需要的 `DynamicsNetwork` ,并通过调用 `model.common` 包中现成的网络结构最终实现了 `MuZeroModel` 。
+- **工作件(worker)**:在 LightZero 中实现了 AlphaZero 和 MuZero 的相应 `worker` 。后续的 EfficientZero 和 GumbelMuzero 等算法沿用了 MuZero 的 `worker` 。如果你的算法在数据采集的逻辑上有所不同,则需要自行实现相应的 `worker` 。例如,如果你的算法需要对采集到的`transitions` 进行预处理,可以在 collector 文件中的 `collect` 函数下加入下面这一片段。其中 `get_train_sample` 函数实现了具体的数据处理过程。
+
+```Python
+if timestep.done:
+    # Prepare trajectory data.
+    transitions = to_tensor_transitions(self._traj_buffer[env_id])
+    # Use ``get_train_sample`` to process the data.
+    train_sample = self._policy.get_train_sample(transitions)
+    return_data.extend(train_sample)
+    self._traj_buffer[env_id].clear()
+```
+
+### 6. **测试你的策略**
+
+在你实现你的策略之后,确保策略的正确性和有效性是非常重要的。为此,你应该编写一些单元测试来验证你的策略是否正常工作。比如,你可以测试策略是否能在特定的环境中执行,策略的输出是否符合预期等。单元测试的编写及意义可以参考 DI-engine 中的[单元测试指南](https://di-engine-docs.readthedocs.io/zh_CN/latest/22_test/index_zh.html) ,你可以在 `lzero/policy/tests` 目录下添加你的测试。在编写测试时,尽可能考虑到所有可能的场景和边界条件,确保你的策略在各种情况下都能正常运行。
+下面是一个 LightZero 中单元测试的例子。在这个例子中,所测试的对象是 `inverse_scalar_transform` 和 `InverseScalarTransform` 方法。这两个方法都将经过变换的 value 逆变换为原本的值,但是采取了不同的实现。单元测试时,用这两个方法对同一组数据进行处理,并比较输出的结果是否相同。如果相同,则会通过测试。
+
+```Python
+import pytest
+import torch
+from lzero.policy.scaling_transform import inverse_scalar_transform, InverseScalarTransform
+
+@pytest.mark.unittest
+def test_scaling_transform():
+    import time
+    logit = torch.randn(16, 601)
+    start = time.time()
+    output_1 = inverse_scalar_transform(logit, 300)
+    print('t1', time.time() - start)
+    handle = InverseScalarTransform(300)
+    start = time.time()
+    output_2 = handle(logit)
+    print('t2', time.time() - start)
+    assert output_1.shape == output_2.shape == (16, 1)
+    assert (output_1 == output_2).all()
+```
+
+在单元测试文件中,要将测试通过 `@pytest.mark.unittest` 标记到python的测试框架中,这样就可以通过在命令行输入 `pytest -sv xxx.py` 直接运行单元测试文件。其中 `-sv` 是一个命令选项,表示在测试运行过程中将详细的信息打印到终端以便查看。
+
+### 7. **完整测试与运行**
+
+在确保策略的基本功能正常之后,你需要利用如 cartpole 等经典环境,对你的策略进行完整的正确性和收敛性测试。这是为了验证你的策略不仅能在单元测试中工作,而且能在实际游戏环境中有效工作。
+
+你可以仿照 [cartpole_muzero_config.py](https://github.com/opendilab/LightZero/blob/main/zoo/classic_control/cartpole/config/cartpole_muzero_config.py) 编写相关的配置文件和入口程序。在测试过程中,注意记录策略的性能数据,如每轮的得分、策略的收敛速度等,以便于分析和改进。
+
+### 8. **贡献**
+
+在你完成了所有以上步骤后,如果你希望把你的策略贡献到 LightZero 仓库中,你可以在官方仓库上提交 Pull Request 。在提交之前,请确保你的代码符合仓库的编码规范,所有测试都已通过,并且已经有足够的文档和注释来解释你的代码和策略。
+
+在 PR 的描述中,详细说明你的策略,包括它的工作原理,你的实现方法,以及在测试中的表现。这会帮助其他人理解你的贡献,并加速 PR 的审查过程。
+
+### 9. **分享讨论,反馈改进**
+
+完成策略实现和测试后,考虑将你的结果和经验分享给社区。你可以在论坛、博客或者社交媒体上发布你的策略和测试结果,邀请其他人对你的工作进行评价和讨论。这不仅可以得到其他人的反馈,还能帮助你建立专业网络,并可能引发新的想法和合作。
+
+基于你的测试结果和社区的反馈,不断改进和优化你的策略。这可能涉及到调整策略的参数,改进代码的性能,或者解决出现的问题和 bug 。记住,策略的开发是一个迭代的过程,永远有提升的空间。
+
+## 注意事项
+
+- 请确保你的代码符合 python PEP8 编码规范。
+- 当你在实现 `_forward_learn` 、 `_forward_collect` 和 `_forward_eval` 等方法时,请确保正确处理输入和返回的数据。
+- 在编写策略时,请确保考虑到不同的环境类型。你的策略应该能够处理不同的环境。
+- 在实现你的策略时,请尽可能使你的代码模块化,以便于其他人理解和重用你的代码。
+- 请编写清晰的文档和注释,描述你的策略如何工作,以及你的代码是如何实现这个策略的。
\ No newline at end of file
diff --git a/LightZero/docs/source/tutorials/envs/customize_envs.md b/LightZero/docs/source/tutorials/envs/customize_envs.md
new file mode 100644
index 0000000000000000000000000000000000000000..5d914307a72bfbd3dedf50552f155bbde89e9d6a
--- /dev/null
+++ b/LightZero/docs/source/tutorials/envs/customize_envs.md
@@ -0,0 +1,242 @@
+# **How to Customize Your Environments in LightZero?**
+
+When conducting reinforcement learning research or applications with LightZero, you may need to create a custom environment. A custom environment can better fit a specific problem or task, so that reinforcement learning algorithms can be trained effectively in that setting.
+
+For a typical environment in LightZero, please refer to [atari_lightzero_env.py](https://github.com/opendilab/LightZero/blob/main/zoo/atari/envs/atari_lightzero_env.py). The environment design of LightZero is largely based on the `BaseEnv` class of DI-engine. When creating a custom environment, we follow basic steps similar to those in [DI-engine](https://di-engine-docs.readthedocs.io/en/latest/04_best_practice/ding_env.html).
+
+## Major Differences from BaseEnv
+
+LightZero contains many board game environments. Because players act alternately and the set of legal moves changes over time, the observation in a board game environment should include not only the board information but also an action mask and the current player. Therefore, in LightZero the `obs` is no longer an array as in DI-engine but a dictionary: the `observation` key corresponds to the DI-engine `obs`, and the dictionary additionally contains entries such as `action_mask` and `to_play`. For code compatibility, LightZero requires non-board game environments to return an `obs` dictionary with `action_mask`, `to_play`, and similar entries as well. A minimal sketch contrasting the two formats follows the list below.
+
+In the specific implementation, these differences are primarily manifested in the following aspects:
+
+- In the `reset()` method, `LightZeroEnv` returns a dictionary `lightzero_obs_dict = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}`.
+  - For non-board game environments:
+    - Setting of `to_play`: since non-board game environments generally have only one player, `to_play` is set to `-1`. (The algorithm uses this value to decide whether to run the single-player logic (`to_play=-1`) or the multi-player logic (`to_play=N`).)
+    - Setting of `action_mask`:
+      - Discrete action space: `action_mask = np.ones(self.env.action_space.n, 'int8')` is a numpy array of ones, indicating that all actions are legal.
+      - Continuous action space: `action_mask = None`; the special value `None` indicates that the environment has a continuous action space.
+  - For board game environments: to facilitate the subsequent MCTS process, the `lightzero_obs_dict` may also include variables such as the board information `board` and the current player index `current_player_index`.
+- In the `step` method, `BaseEnvTimestep(lightzero_obs_dict, rew, done, info)` is returned, where `lightzero_obs_dict` contains the updated observation.
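+
+The sketch below makes the difference concrete; the observation values and the two-action discrete space are purely illustrative.
+
+```python
+import numpy as np
+
+# DI-engine style: the observation is returned directly, e.g. as an array.
+di_engine_obs = np.array([0.01, -0.02, 0.03, 0.04], dtype=np.float32)
+
+# LightZero style: the observation is wrapped in a dict together with the
+# action mask and the current player (to_play = -1 for single-player envs).
+lightzero_obs_dict = {
+    'observation': di_engine_obs,
+    'action_mask': np.ones(2, dtype=np.int8),
+    'to_play': -1,
+}
+```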
+
+## Basic Steps
+
+Here are the basic steps to create a custom LightZero environment:
+
+### 1. Create the Environment Class
+First, you need to create a new environment class that inherits from the `BaseEnv` class in DI-engine. For example:
+
+```python
+from ding.envs import BaseEnv
+
+class MyCustomEnv(BaseEnv):
+    pass
+```
+
+### 2. **__init__ Method**
+In your custom environment class, you need to define an initialization method `__init__`. In this method, you need to set some basic properties of the environment, such as observation space, action space, reward space, etc. For example:
+
+```python
+def __init__(self, cfg=None):
+    self.cfg = cfg
+    self._init_flag = False
+    # set other properties...
+```
+
+### 3. **Reset Method**
+The `reset` method is used to reset the environment to an initial state. This method should return the initial observation of the environment. For example:
+
+```python
+def reset(self):
+    # reset the environment...
+    obs = self._env.reset()
+    # get the action_mask according to the legal action
+    ...
+    lightzero_obs_dict = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}
+    return lightzero_obs_dict
+```
+
+### 4. **Step Method**
+The `step` method takes an action as input, executes this action, and returns a tuple containing the new observation, reward, whether it's done, and other information. For example:
+
+```python
+def step(self, action):
+    # The core original env step.
+    obs, rew, done, info = self.env.step(action)
+
+    if self.cfg.continuous:
+        action_mask = None
+    else:
+        # get the action_mask according to the legal action
+        action_mask = np.ones(self.env.action_space.n, 'int8')
+
+    lightzero_obs_dict = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}
+
+    self._eval_episode_return += rew
+    if done:
+        info['eval_episode_return'] = self._eval_episode_return
+
+    return BaseEnvTimestep(lightzero_obs_dict, rew, done, info)
+```
+
+### 5. **Observation Space and Action Space**
+In a custom environment, you need to provide properties for observation space and action space. These properties are `gym.Space` objects that describe the shape and type of observations and actions. For example:
+
+```python
+@property
+def observation_space(self):
+    return self.env.observation_space
+
+@property
+def action_space(self):
+    return self.env.action_space
+```
+
+### 6. **Render Method**
+The `render` method displays the gameplay for users to observe. For environments that implement `render`, users can choose whether to call it during the execution of the `step` function to render the game state at each step.
+
+```python
+def render(self, mode: str = 'image_savefile_mode') -> None:
+ """
+ Overview:
+ Renders the game environment.
+ Arguments:
+ - mode (:obj:`str`): The rendering mode. Options are
+ 'state_realtime_mode',
+ 'image_realtime_mode',
+ or 'image_savefile_mode'.
+ """
+ # In 'state_realtime_mode' mode, print the current game board for rendering.
+ if mode == "state_realtime_mode":
+ ...
+ # In other two modes, use a screen for rendering.
+ # Draw the screen.
+ ...
+ if mode == "image_realtime_mode":
+ # Render the picture to user's window.
+ ...
+ elif mode == "image_savefile_mode":
+ # Save the picture to frames.
+ ...
+ self.frames.append(self.screen)
+ return None
+```
+
+In the `render` function, there are three different modes available:
+
+- In `state_realtime_mode`, `render` directly prints the current state.
+- In `image_realtime_mode`, `render` draws the environment state with graphical assets, creating a visual interface that is displayed in a real-time window.
+- In `image_savefile_mode`, `render` saves the rendered images in `self.frames` and converts them into files using `save_render_output` at the end of the game.
+
+At runtime, the mode used by `render` depends on the value of `self.render_mode`. If `self.render_mode` is set to `None`, the environment does not call the `render` method at all.
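+
+For example, a common pattern (a sketch, assuming your environment stores the configured mode in `self.render_mode`) is to call `render` at the end of each `step`:
+
+```python
+# Inside step(), after computing the new state:
+if self.render_mode is not None:
+    self.render(mode=self.render_mode)
+```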
+
+### 7. **Other Methods**
+Depending on the requirement, you might also need to define other methods, such as `close` (for closing the environment and performing cleanup), etc.
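+
+As a sketch (assuming the `_env` and `_init_flag` attributes from the `__init__` example above), `close` typically just releases the wrapped environment:
+
+```python
+def close(self) -> None:
+    # Release the underlying environment's resources once it is no longer needed.
+    if self._init_flag:
+        self._env.close()
+    self._init_flag = False
+```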
+
+### 8. **Register the Environment**
+Lastly, you need to use the `ENV_REGISTRY.register` decorator to register your new environment so that it can be used in the configuration file. For example:
+
+```python
+from ding.utils import ENV_REGISTRY
+
+@ENV_REGISTRY.register('my_custom_env')
+class MyCustomEnv(BaseEnv):
+    # ...
+```
+
+Once the environment is registered, you can specify the creation of the corresponding environment in the `create_config` section of the configuration file:
+
+```python
+create_config = dict(
+    env=dict(
+        type='my_custom_env',
+        import_names=['zoo.board_games.my_custom_env.envs.my_custom_env'],
+    ),
+    ...
+)
+```
+
+In the configuration, the `type` should be set to the registered environment name, while the `import_names` should be set to the location of the environment package.
+
+Creating a custom environment may require a deep understanding of the specific task and reinforcement learning. When implementing a custom environment, you may need to experiment and adjust to make the environment effectively support reinforcement learning training.
+
+## **Special Methods for Board Game Environments**
+
+Here are the additional steps for creating custom board game environments in LightZero:
+
+1. There are three different modes for board game environments in LightZero: `self_play_mode`, `play_with_bot_mode`, and `eval_mode`. Here is an explanation of these modes:
+    - `self_play_mode`: In this mode, the environment follows the classical setup of board games. Each call to the `step` function places a move in the environment based on the provided action. At the time step when the game is decided, a reward of +1 is returned. At all other time steps, where the game is not yet decided, the reward is 0.
+    - `play_with_bot_mode`: In this mode, each call to the `step` function places a move in the environment based on the provided action, after which the bot generates an action and a move is placed for it. In other words, the agent plays as player 1, and the bot plays as player 2 against the agent. At the end of the game, a reward of +1 is returned if the agent wins, -1 if the bot wins, and 0 for a draw. At all other time steps, where the game is not yet decided, the reward is 0.
+    - `eval_mode`: This mode is used to evaluate the level of the current agent. There are two evaluation methods: bot evaluation and human evaluation. In bot evaluation, as in `play_with_bot_mode`, the bot plays as player 2 against the agent, and the agent's win rate is calculated from the results. In human evaluation, the user plays as player 2 and interacts with the agent by entering actions on the command line.
+
+    In each mode, at the end of the game, the `eval_episode_return` from the perspective of player 1 is recorded (1 if player 1 wins, -1 if player 1 loses, 0 for a draw) and logged in the last time step. The reward convention is sketched below.
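+
+The following sketch summarizes this reward convention for `play_with_bot_mode`; the `winner` variable and player indices are illustrative, not the actual implementation:
+
+```python
+# Schematic reward logic inside step(): the agent is player 1, the bot is player 2.
+if done:
+    if winner == 1:
+        rew = 1    # the agent wins
+    elif winner == 2:
+        rew = -1   # the bot wins
+    else:
+        rew = 0    # draw
+    # eval_episode_return is always recorded from player 1's perspective.
+    info['eval_episode_return'] = rew
+else:
+    rew = 0        # the game is still undecided
+```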
+
+2. In board game environments, the set of available actions shrinks as the game progresses, so it is necessary to implement the `legal_actions` method. This method can be used to validate the actions provided by the players and to generate child nodes during the MCTS process. Taking the Connect4 environment as an example, the method checks which columns of the board are not yet full and returns the list of column indices in which a move can currently be made.
+
+```python
+def legal_actions(self) -> List[int]:
+    return [i for i in range(7) if self.board[i] == 0]
+```
+
+3. In LightZero's board game environments, additional action generation methods need to be implemented, such as `bot_action` and `random_action`. The `bot_action` method selects the bot type according to `self.bot_action_type` and generates an action using the algorithm pre-implemented in that bot, while `random_action` returns a random action from the current list of legal actions. `bot_action` implements the interaction with the bot in `play_with_bot_mode`, and `random_action` is called with a certain probability when the agent or the bot selects an action, to increase the randomness of the game samples.
+
+```python
+def bot_action(self) -> int:
+    if np.random.rand() < self.prob_random_action_in_bot:
+        return self.random_action()
+    else:
+        if self.bot_action_type == 'rule':
+            return self.rule_bot.get_rule_bot_action(self.board, self._current_player)
+        elif self.bot_action_type == 'mcts':
+            return self.mcts_bot.get_actions(self.board, player_index=self.current_player_index)
+```
+
+## **LightZeroEnvWrapper**
+
+We provide a [LightZeroEnvWrapper](https://github.com/opendilab/LightZero/blob/main/lzero/envs/wrappers/lightzero_env_wrapper.py) in the `lzero/envs/wrappers` directory. It wraps `classic_control` and `box2d` environments into the format required by LightZero. An original environment is passed to the `LightZeroEnvWrapper` instance at initialization, and the instance is initialized through the parent class `gym.Wrapper`, which allows it to call methods such as `render`, `close`, and `seed` of the original environment. On top of this, the `LightZeroEnvWrapper` class overrides the `step` and `reset` methods so that their outputs are wrapped into a dictionary `lightzero_obs_dict` that conforms to LightZero's requirements. As a result, the wrapped environment instance meets the requirements of LightZero's custom environments.
+
+```python
+class LightZeroEnvWrapper(gym.Wrapper):
+    # overview comments
+    def __init__(self, env: gym.Env, cfg: EasyDict) -> None:
+        # overview comments
+        super().__init__(env)
+        ...
+```
+Specifically, use the following function to wrap a gym environment into the format required by LightZero via `LightZeroEnvWrapper`. The `get_wrappered_env` function returns an anonymous function that creates a `DingEnvWrapper` instance each time it is called; `LightZeroEnvWrapper` is passed in as an env wrapper (a lambda), so the original environment is wrapped into the format required by LightZero inside that instance.
+
+```python
+def get_wrappered_env(wrapper_cfg: EasyDict, env_name: str):
+    # overview comments
+    ...
+    if wrapper_cfg.manually_discretization:
+        return lambda: DingEnvWrapper(
+            gym.make(env_name),
+            cfg={
+                'env_wrapper': [
+                    lambda env: ActionDiscretizationEnvWrapper(env, wrapper_cfg), lambda env:
+                    LightZeroEnvWrapper(env, wrapper_cfg)
+                ]
+            }
+        )
+    else:
+        return lambda: DingEnvWrapper(
+            gym.make(env_name), cfg={'env_wrapper': [lambda env: LightZeroEnvWrapper(env, wrapper_cfg)]}
+        )
+```
+
+Then call the `train_muzero_with_gym_env` method in the main entry point of the algorithm, and you can use the wrapped env for training:
+
+```python
+if __name__ == "__main__":
+ """
+ Overview:
+ The ``train_muzero_with_gym_env`` entry means that the environment used in the training process is generated by wrapping the original gym environment with LightZeroEnvWrapper.
+ Users can refer to lzero/envs/wrappers for more details.
+ """
+ from lzero.entry import train_muzero_with_gym_env
+ train_muzero_with_gym_env([main_config, create_config], seed=0, max_env_step=max_env_step)
+```
+
+## **Considerations**
+
+1. **State Representation**: Consider how to represent the environment state as an observation space. For simple environments, you can directly use low-dimensional continuous states; for complex environments, you might need to use images or other high-dimensional discrete states.
+2. **Preprocessing the Observation Space**: Depending on the type of the observation space, apply appropriate preprocessing to the input data, such as scaling, cropping, grayscale conversion, or normalization. Preprocessing reduces the dimensionality of the input data and accelerates learning.
+3. **Reward Design**: Design a reasonable reward function that aligns with the goal. For example, try to normalize the extrinsic reward given by the environment to \[0, 1\]; a normalized extrinsic reward makes it easier to set the intrinsic-reward weight and other hyperparameters when using algorithms such as RND.
\ No newline at end of file
diff --git a/LightZero/docs/source/tutorials/envs/customize_envs_zh.md b/LightZero/docs/source/tutorials/envs/customize_envs_zh.md
new file mode 100644
index 0000000000000000000000000000000000000000..c995f9fe7daeaba1e3787ffe04e7c7240b13d437
--- /dev/null
+++ b/LightZero/docs/source/tutorials/envs/customize_envs_zh.md
@@ -0,0 +1,255 @@
+# LightZero 中如何自定义环境?
+
+- 在使用 LightZero 进行强化学习的研究或应用时,可能需要创建自定义的环境。创建自定义环境可以更好地适应特定的问题或任务,使得强化学习算法能够在特定环境中进行有效的训练。
+- 一个典型的 LightZero 中的环境,请参考 [atari_lightzero_env.py](https://github.com/opendilab/LightZero/blob/main/zoo/atari/envs/atari_lightzero_env.py) 。LightZero的环境设计大致基于DI-engine的`BaseEnv`类。在创建自定义环境时,我们遵循了与DI-engine相似的基本步骤。以下是 DI-engine 中创建自定义环境的文档
+ - https://di-engine-docs.readthedocs.io/zh_CN/latest/04_best_practice/ding_env_zh.html
+
+## 与 BaseEnv 的主要差异
+
+在 LightZero 中,有很多棋类环境。棋类环境由于存在玩家交替执行动作,合法动作在变化的情况,所以环境的观测状态除了棋面信息,还应包含动作掩码,当前玩家等信息。因此,LightZero 中的 `obs` 不再像 DI-engine 中那样是一个数组,而是一个字典。字典中的 `'observation'` 对应于 DI-engine 中的 `obs`,此外字典中还包含了 `'action_mask'`、`'to_play'` 等信息。为了代码的兼容性,对于非棋类环境,LightZero 同样要求环境返回的 `obs` 包含`'action_mask'`、`'to_play'` 等信息。
+
+在具体的方法实现中,这种差异主要体现在下面几点:
+
+- 在 `reset` 方法中,LightZeroEnv 返回的是一个字典 `lightzero_obs_dict = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}` 。
+ - 对于非棋类环境
+ - `to_play` 的设置:由于非棋类环境一般只有一个玩家,因此设置 `to_play` =-1 。(我们在算法中根据该值,判断执行单player的算法逻辑 (`to_play` =-1) ,还是多player的算法逻辑 (`to_play`=N) )
+ - 对于 `action_mask` 的设置
+ - 离散动作空间: `action_mask`= np.ones(self.env.action_space.n, 'int8') 是一个全1的numpy数组,表示所有动作都是合法动作。
+ - 连续动作空间: `action_mask` = None ,特殊的 None 表示环境是连续动作空间。
+ - 对于棋类环境:为了方便后续 MCTS 流程, `lightzero_obs_dict` 中可能还会增加棋面信息 `board` 和当前玩家 `current_player_index` 等变量。
+- 在 `step` 方法中,返回的是 `BaseEnvTimestep(lightzero_obs_dict, rew, done, info)` ,其中的 `lightzero_obs_dict` 包含了更新后的观察结果。
+
+## 基本步骤
+
+以下是创建自定义 LightZero 环境的基本步骤:
+
+### 1. 创建环境类
+
+首先,需要创建一个新的环境类,该类需要继承自 DI-engine 的 BaseEnv 类。例如:
+
+```Python
+from ding.envs import BaseEnv
+
+class MyCustomEnv(BaseEnv):
+    pass
+```
+
+### 2. __init__方法
+
+在自定义环境类中,需要定义一个初始化方法 `__init__` 。在这个方法中,需要设置一些环境的基本属性,例如观察空间、动作空间、奖励空间等。例如:
+
+```Python
+def __init__(self, cfg=None):
+    self.cfg = cfg
+    self._init_flag = False
+    # set other properties...
+```
+
+### 3. Reset 方法
+
+`reset` 方法用于重置环境到一个初始状态。这个方法应该返回环境的初始观察。例如:
+
+```Python
+def reset(self):
+    # reset the environment...
+    obs = self._env.reset()
+    # get the action_mask according to the legal action
+    ...
+    lightzero_obs_dict = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}
+    return lightzero_obs_dict
+```
+
+### 4. Step 方法
+
+`step` 方法接受一个动作作为输入,执行这个动作,并返回一个元组,包含新的观察、奖励、是否完成和其他信息。例如:
+
+```Python
+def step(self, action):
+    # The core original env step.
+    obs, rew, done, info = self.env.step(action)
+
+    if self.cfg.continuous:
+        action_mask = None
+    else:
+        # get the action_mask according to the legal action
+        action_mask = np.ones(self.env.action_space.n, 'int8')
+
+    lightzero_obs_dict = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}
+
+    self._eval_episode_return += rew
+    if done:
+        info['eval_episode_return'] = self._eval_episode_return
+
+    return BaseEnvTimestep(lightzero_obs_dict, rew, done, info)
+```
+
+### 5. 观察空间和动作空间
+
+在自定义环境中,需要提供观察空间和动作空间的属性。这些属性是 `gym.Space` 对象,描述了观察和动作的形状和类型。例如:
+
+```Python
+@property
+def observation_space(self):
+    return self._observation_space
+
+@property
+def action_space(self):
+    return self._action_space
+
+@property
+def legal_actions(self):
+    # get the actual legal actions
+    return np.arange(self._action_space.n)
+```
+
+### 6. render 方法
+
+`render` 方法会将游戏的对局演示出来,供用户查看。对于实现了 `render` 方法的环境,用户可以选择是否在执行 `step` 函数时调用 `render` 来实现每一步游戏状态的渲染。
+
+```Python
+def render(self, mode: str = 'image_savefile_mode') -> None:
+ """
+ Overview:
+ Renders the game environment.
+ Arguments:
+ - mode (:obj:`str`): The rendering mode. Options are
+ 'state_realtime_mode',
+ 'image_realtime_mode',
+ or 'image_savefile_mode'.
+ """
+ # In 'state_realtime_mode' mode, print the current game board for rendering.
+ if mode == "state_realtime_mode":
+ ...
+ # In other two modes, use a screen for rendering.
+ # Draw the screen.
+ ...
+ if mode == "image_realtime_mode":
+ # Render the picture to user's window.
+ ...
+ elif mode == "image_savefile_mode":
+ # Save the picture to frames.
+ ...
+ self.frames.append(self.screen)
+ return None
+```
+
+在 `render` 中,有三种不同的模式。
+- 在 `state_realtime_mode` 下,`render` 会直接打印当前状态。
+- 在 `image_realtime_mode` 下, `render` 会根据一些图形素材将环境状态渲染出来,形成可视化的界面,并弹出实时的窗口展示。
+- 在 `image_savefile_mode` 下, `render` 会将渲染的图像保存在 `self.frames` 中,并在对局结束时通过 `save_render_output` 将其转化为文件保存下来。
+
+在运行时, `render` 所采取的模式取决于 `self.render_mode` 的取值。当 `self.render_mode` 取值为 `None` 时,环境不会调用 `render` 方法。
+
+### 7. 其他方法
+
+根据需要,可能还需要定义其他方法,例如 `close` (用于关闭环境并进行清理)等。
+
+### 8. 注册环境
+
+最后,需要使用 `ENV_REGISTRY.register` 装饰器来注册新的环境,使得可以在配置文件中使用它。例如:
+
+```Python
+from ding.utils import ENV_REGISTRY
+
+@ENV_REGISTRY.register('my_custom_env')
+class MyCustomEnv(BaseEnv):
+    # ...
+```
+
+当环境注册好之后,可以在配置文件中的 `create_config` 部分指定生成相应的环境:
+
+```Python
+create_config = dict(
+    env=dict(
+        type='my_custom_env',
+        import_names=['zoo.board_games.my_custom_env.envs.my_custom_env'],
+    ),
+    ...
+)
+```
+
+其中 `type` 要设定为所注册的环境名, `import_names` 则设置为环境包的位置。
+
+创建自定义环境可能需要对具体的任务和强化学习有深入的理解。在实现自定义环境时,可能需要进行一些试验和调整,以使环境能够有效地支持强化学习的训练。
+
+## 棋类环境的特殊方法
+
+以下是创建自定义 LightZero 棋类环境的额外步骤:
+1. LightZero中的棋类环境有三种不同的模式: `self_play_mode` , `play_with_bot_mode` , `eval_mode` 。这三种模式的说明如下:
+ - `self_play_mode`:该模式下,采取棋类环境的经典设置,每调用一次 `step` 函数,会根据传入的动作在环境中落子一次。在分出胜负的时间步,会返回+1的 reward 。在没有分出胜负的所有时间步, reward 均为0。
+ - `play_with_bot_mode`:该模式下,每调用一次 `step` 函数,会根据传入的动作在环境中落子一次,随后调用环境中的 bot 产生一个动作,并根据 bot 的动作再落子一次。也就是说, agent 扮演了1号玩家的角色,而 bot 扮演了2号玩家的角色和 agent 对抗。在对局结束时,如果 agent 胜利,则返回+1的 reward ,如果 bot 胜利,则返回-1的 reward ,平局则 reward 为0。在其余没有分出胜负的时间步, reward 均为0。
+ - `eval_mode`:该模式用于评估当前的 agent 的水平。具体有 bot 和 human 两种评估方法。采取 bot 评估时,和 play_with_bot_mode 中一样,会让 bot 扮演2号玩家和 agent 对抗,并根据结果计算 agent 的胜率。采取 human 模式时,则让用户扮演2号玩家,在命令行输入动作和 agent 对打。
+
+ 每种模式下,在棋局结束后,都会从1号玩家的视角记录本局的 `eval_episode_return` 信息(如果1号玩家赢了,则 `eval_episode_return` 为1,如果输了为-1,平局为0),并记录在最后一个时间步中。
+2. 在棋类环境中,随着对局的推进,可以采取的动作会不断变少,因此还需要实现 `legal_action` 方法。该方法可以用于检验玩家输入的动作是否合法,以及在 MCTS 过程中根据合法动作生成子节点。以 Connect4 环境为例,该方法会检查棋盘中的每一列是否下满,然后返回一个列表。该列表在可以落子的列取值为1,其余位置取值为0。
+
+```Python
+def legal_actions(self) -> List[int]:
+    return [i for i in range(7) if self.board[i] == 0]
+```
+
+3. LightZero的棋类环境中,还需要实现一些动作生成方法,例如 `bot_action` 和 `random_action` 。其中 `bot_action` 会根据 `self.bot_action_type` 的值调取相应种类的 bot ,通过 bot 中预实现的算法生成一个动作。而 `random_action` 则会从当前的合法动作列表中随机选取一个动作返回。 `bot_action` 用于实现环境的 `play_with_bot_mode` ,而 `random_action` 则会在 agent 和 bot 选取动作时依一定概率被调用,来增加对局样本的随机性。
+
+```Python
+def bot_action(self) -> int:
+    if np.random.rand() < self.prob_random_action_in_bot:
+        return self.random_action()
+    else:
+        if self.bot_action_type == 'rule':
+            return self.rule_bot.get_rule_bot_action(self.board, self._current_player)
+        elif self.bot_action_type == 'mcts':
+            return self.mcts_bot.get_actions(self.board, player_index=self.current_player_index)
+```
+
+## LightZeroEnvWrapper
+
+我们在 lzero/envs/wrappers 中提供了一个 [LightZeroEnvWrapper](https://github.com/opendilab/LightZero/blob/main/lzero/envs/wrappers/lightzero_env_wrapper.py)。它能够将经典的 `classic_control`, `box2d` 环境包装成 LightZero 所需要的环境格式。在初始化实例时,会传入一个原始环境,这个原始环境通过父类 `gym.Wrapper` 被初始化,这使得实例可以调用原始环境中的 `render` , `close` , `seed` 等方法。在此基础上, `LightZeroEnvWrapper` 类重写了 `step` 和 `reset` 方法,将其输出封装成符合 LightZero 要求的字典 `lightzero_obs_dict` 。这样一来,封装后的新环境实例就满足了 LightZero 自定义环境的要求。
+
+```Python
+class LightZeroEnvWrapper(gym.Wrapper):
+    # overview comments
+    def __init__(self, env: gym.Env, cfg: EasyDict) -> None:
+        # overview comments
+        super().__init__(env)
+        ...
+```
+
+具体使用时,使用下面的函数,将一个 gym 环境,通过 `LightZeroEnvWrapper` 包装成 LightZero 所需要的环境格式。 `get_wrappered_env` 会返回一个匿名函数,该匿名函数每次调用都会产生一个 `DingEnvWrapper` 实例,该实例会将 `LightZeroEnvWrapper` 作为匿名函数传入,并在实例内部将原始环境封装成 LightZero 所需的格式。
+
+```Python
+def get_wrappered_env(wrapper_cfg: EasyDict, env_name: str):
+    # overview comments
+    ...
+    if wrapper_cfg.manually_discretization:
+        return lambda: DingEnvWrapper(
+            gym.make(env_name),
+            cfg={
+                'env_wrapper': [
+                    lambda env: ActionDiscretizationEnvWrapper(env, wrapper_cfg), lambda env:
+                    LightZeroEnvWrapper(env, wrapper_cfg)
+                ]
+            }
+        )
+    else:
+        return lambda: DingEnvWrapper(
+            gym.make(env_name), cfg={'env_wrapper': [lambda env: LightZeroEnvWrapper(env, wrapper_cfg)]}
+        )
+```
+
+然后在算法的主入口处中调用 `train_muzero_with_gym_env` 方法,即可使用上述包装后的 env 用于训练:
+
+```Python
+if __name__ == "__main__":
+ """
+ Overview:
+ The ``train_muzero_with_gym_env`` entry means that the environment used in the training process is generated by wrapping the original gym environment with LightZeroEnvWrapper.
+ Users can refer to lzero/envs/wrappers for more details.
+ """
+ from lzero.entry import train_muzero_with_gym_env
+ train_muzero_with_gym_env([main_config, create_config], seed=0, max_env_step=max_env_step)
+```
+
+## 注意事项
+
+- 状态表示:思考如何将环境状态表示为观察空间。对于简单的环境,可以直接使用低维连续状态;对于复杂的环境,可能需要使用图像或其他高维离散状态表示。
+- 观察空间预处理:根据观察空间的类型,对输入数据进行适当的预处理操作,例如缩放、裁剪、灰度化、归一化等。预处理可以减少输入数据的维度,加速学习过程。
+- 奖励设计:设计合理的、符合目标的奖励函数。例如,环境给出的外在奖励尽量归一化到 [0, 1]。通过归一化环境给出的外在奖励,能更好地确定 RND 算法中的内在奖励权重等超参数。
diff --git a/LightZero/docs/source/tutorials/installation/index.rst b/LightZero/docs/source/tutorials/installation/index.rst
new file mode 100644
index 0000000000000000000000000000000000000000..27f35b17d6ff1a9414ec977084fd768cefc5d3f7
--- /dev/null
+++ b/LightZero/docs/source/tutorials/installation/index.rst
@@ -0,0 +1,22 @@
+Installation
+===================
+
+LightZero is currently hosted on PyPI. It requires Python >= 3.7.
+
+You can simply install LightZero from PyPI with the following command:
+
+.. code:: shell
+
+    pip install LightZero
+
+You can also install with the newest version through GitHub:
+
+.. code:: shell
+
+    pip install -U git+https://github.com/opendilab/LightZero.git@main
+
+
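+To quickly verify the installation, you can print the package version from Python
+(the ``__version__`` attribute is defined in ``lzero/config/meta.py``):
+
+.. code:: python
+
+    import lzero
+    print(lzero.__version__)
+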
+In the newest version of LightZero, a command line interface (CLI) is provided for some data processing tasks.
+
+LightZero is still under development. You can also check out the documents of the stable version at `the official documentation site <https://opendilab.github.io/LightZero/>`_.
diff --git a/LightZero/format.sh b/LightZero/format.sh
new file mode 100755
index 0000000000000000000000000000000000000000..506ac0243aa6ea14c3c3093077624a27897fe1a9
--- /dev/null
+++ b/LightZero/format.sh
@@ -0,0 +1,20 @@
+#!/usr/bin/env bash
+# Usage: at the root dir >> bash scripts/format.sh .
+
+# Check yapf version. (20200318 latest is 0.29.0. Format might be changed in future version.)
+ver=$(yapf --version)
+if ! echo $ver | grep -q 0.29.0; then
+ echo "Wrong YAPF version installed: 0.29.0 is required, not $ver. $YAPF_DOWNLOAD_COMMAND_MSG"
+ exit 1
+fi
+
+yapf --in-place --recursive -p --verbose --style .style.yapf $1
+
+if [[ "$2" == '--test' ]]; then # Only for CI usage, user should not use --test flag.
+ if ! git diff --quiet &>/dev/null; then
+ echo '*** You have not reformatted your codes! Please run [bash format.sh] at root directory before commit! Thanks! ***'
+ exit 1
+ else
+ echo "Code style test passed!"
+ fi
+fi
diff --git a/LightZero/lzero/__init__.py b/LightZero/lzero/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c752f04a7bac07453aee789d7a2b5562cf363ec8
--- /dev/null
+++ b/LightZero/lzero/__init__.py
@@ -0,0 +1 @@
+from .config.meta import __VERSION__ as __version__
diff --git a/LightZero/lzero/agent/__init__.py b/LightZero/lzero/agent/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6cbb38df89f375ec9162f9a79f07b99e6db20e2
--- /dev/null
+++ b/LightZero/lzero/agent/__init__.py
@@ -0,0 +1 @@
+from .muzero import MuZeroAgent
diff --git a/LightZero/lzero/agent/config/__init__.py b/LightZero/lzero/agent/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/LightZero/lzero/agent/config/muzero/__init__.py b/LightZero/lzero/agent/config/muzero/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e8b937645e5891d4a0723ef7454d3a5d3b536bf4
--- /dev/null
+++ b/LightZero/lzero/agent/config/muzero/__init__.py
@@ -0,0 +1,8 @@
+from easydict import EasyDict
+from . import gym_cartpole_v0
+
+supported_env_cfg = {
+    gym_cartpole_v0.cfg.main_config.env.env_id: gym_cartpole_v0.cfg,
+}
+
+supported_env_cfg = EasyDict(supported_env_cfg)
diff --git a/LightZero/lzero/agent/config/muzero/gym_cartpole_v0.py b/LightZero/lzero/agent/config/muzero/gym_cartpole_v0.py
new file mode 100644
index 0000000000000000000000000000000000000000..c46a67a0dabf1ec8fbed2f3f63cbb1485799071b
--- /dev/null
+++ b/LightZero/lzero/agent/config/muzero/gym_cartpole_v0.py
@@ -0,0 +1,76 @@
+from easydict import EasyDict
+
+# ==============================================================
+# begin of the most frequently changed config specified by the user
+# ==============================================================
+collector_env_num = 8
+n_episode = 8
+evaluator_env_num = 3
+num_simulations = 25
+update_per_collect = 100
+batch_size = 256
+max_env_step = int(1e5)
+reanalyze_ratio = 0
+# ==============================================================
+# end of the most frequently changed config specified by the user
+# ==============================================================
+
+cfg = dict(
+    main_config=dict(
+        exp_name='CartPole-v0-MuZero',
+        seed=0,
+        env=dict(
+            env_id='CartPole-v0',
+            continuous=False,
+            manually_discretization=False,
+            collector_env_num=collector_env_num,
+            evaluator_env_num=evaluator_env_num,
+            n_evaluator_episode=evaluator_env_num,
+            manager=dict(shared_memory=False, ),
+        ),
+        policy=dict(
+            model=dict(
+                observation_shape=4,
+                action_space_size=2,
+                model_type='mlp',
+                lstm_hidden_size=128,
+                latent_state_dim=128,
+                self_supervised_learning_loss=True,  # NOTE: default is False.
+                discrete_action_encoding_type='one_hot',
+                norm_type='BN',
+            ),
+            cuda=True,
+            env_type='not_board_games',
+            game_segment_length=50,
+            update_per_collect=update_per_collect,
+            batch_size=batch_size,
+            optim_type='Adam',
+            lr_piecewise_constant_decay=False,
+            learning_rate=0.003,
+            ssl_loss_weight=2,  # NOTE: default is 0.
+            num_simulations=num_simulations,
+            reanalyze_ratio=reanalyze_ratio,
+            n_episode=n_episode,
+            eval_freq=int(2e2),
+            replay_buffer_size=int(1e6),  # the size/capacity of replay_buffer, in the terms of transitions.
+            collector_env_num=collector_env_num,
+            evaluator_env_num=evaluator_env_num,
+        ),
+        wandb_logger=dict(
+            gradient_logger=False, video_logger=False, plot_logger=False, action_logger=False, return_logger=False
+        ),
+    ),
+    create_config=dict(
+        env=dict(
+            type='cartpole_lightzero',
+            import_names=['zoo.classic_control.cartpole.envs.cartpole_lightzero_env'],
+        ),
+        env_manager=dict(type='subprocess'),
+        policy=dict(
+            type='muzero',
+            import_names=['lzero.policy.muzero'],
+        ),
+    ),
+)
+
+cfg = EasyDict(cfg)
diff --git a/LightZero/lzero/agent/muzero.py b/LightZero/lzero/agent/muzero.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ddb436aecf34962b97e9283feed590ddc4ea98e
--- /dev/null
+++ b/LightZero/lzero/agent/muzero.py
@@ -0,0 +1,422 @@
+import os
+from functools import partial
+from typing import Optional, Union, List
+
+import numpy as np
+import torch
+from ding.bonus.common import TrainingReturn, EvalReturn
+from ding.config import save_config_py, compile_config
+from ding.envs import create_env_manager
+from ding.envs import get_vec_env_setting
+from ding.policy import create_policy
+from ding.rl_utils import get_epsilon_greedy_fn
+from ding.utils import set_pkg_seed, get_rank
+from ding.worker import BaseLearner
+from ditk import logging
+from easydict import EasyDict
+from tensorboardX import SummaryWriter
+
+from lzero.agent.config.muzero import supported_env_cfg
+from lzero.entry.utils import log_buffer_memory_usage, random_collect
+from lzero.mcts import MuZeroGameBuffer
+from lzero.policy import visit_count_temperature
+from lzero.policy.muzero import MuZeroPolicy
+from lzero.policy.random_policy import LightZeroRandomPolicy
+from lzero.worker import MuZeroCollector as Collector
+from lzero.worker import MuZeroEvaluator as Evaluator
+
+
+class MuZeroAgent:
+ """
+ Overview:
+ Agent class for executing MuZero algorithms which include methods for training, deployment, and batch evaluation.
+ Interfaces:
+ __init__, train, deploy, batch_evaluate
+ Properties:
+ best
+
+ .. note::
+ This agent class is tailored for use with the HuggingFace Model Zoo for LightZero
+ (e.g. https://huggingface.co/OpenDILabCommunity/CartPole-v0-MuZero),
+ and provides methods such as "train" and "deploy".
+ """
+
+ supported_env_list = list(supported_env_cfg.keys())
+
+ def __init__(
+ self,
+ env_id: str = None,
+ seed: int = 0,
+ exp_name: str = None,
+ model: Optional[torch.nn.Module] = None,
+ cfg: Optional[Union[EasyDict, dict]] = None,
+ policy_state_dict: str = None,
+ ) -> None:
+ """
+ Overview:
+ Initialize the MuZeroAgent instance with environment parameters, model, and configuration.
+ Arguments:
+ - env_id (:obj:`str`): Identifier for the environment to be used, registered in gym.
+ - seed (:obj:`int`): Random seed for reproducibility. Defaults to 0.
+ - exp_name (:obj:`Optional[str]`): Name for the experiment. Defaults to None.
+ - model (:obj:`Optional[torch.nn.Module]`): PyTorch module to be used as the model. If None, a default model is created. Defaults to None.
+ - cfg (:obj:`Optional[Union[EasyDict, dict]]`): Configuration for the agent. If None, default configuration will be used. Defaults to None.
+ - policy_state_dict (:obj:`Optional[str]`): Path to a pre-trained model state dictionary. If provided, state dict will be loaded. Defaults to None.
+
+ .. note::
+ - If `env_id` is not specified, it must be included in `cfg`.
+ - The `supported_env_list` contains all the environment IDs that are supported by this agent.
+ """
+ assert env_id is not None or cfg["main_config"]["env_id"] is not None, "Please specify env_id or cfg."
+
+ if cfg is not None and not isinstance(cfg, EasyDict):
+ cfg = EasyDict(cfg)
+
+ if env_id is not None:
+ assert env_id in MuZeroAgent.supported_env_list, "Please use supported envs: {}".format(
+ MuZeroAgent.supported_env_list
+ )
+ if cfg is None:
+ cfg = supported_env_cfg[env_id]
+ else:
+ assert cfg.main_config.env.env_id == env_id, "env_id in cfg should be the same as env_id in args."
+ else:
+ assert hasattr(cfg.main_config.env, "env_id"), "Please specify env_id in cfg."
+ assert cfg.main_config.env.env_id in MuZeroAgent.supported_env_list, "Please use supported envs: {}".format(
+ MuZeroAgent.supported_env_list
+ )
+ default_policy_config = EasyDict({"policy": MuZeroPolicy.default_config()})
+ default_policy_config.policy.update(cfg.main_config.policy)
+ cfg.main_config.policy = default_policy_config.policy
+
+ if exp_name is not None:
+ cfg.main_config.exp_name = exp_name
+ self.origin_cfg = cfg
+ self.cfg = compile_config(
+ cfg.main_config, seed=seed, env=None, auto=True, policy=MuZeroPolicy, create_cfg=cfg.create_config
+ )
+ self.exp_name = self.cfg.exp_name
+
+ logging.getLogger().setLevel(logging.INFO)
+ self.seed = seed
+ set_pkg_seed(self.seed, use_cuda=self.cfg.policy.cuda)
+ if not os.path.exists(self.exp_name):
+ os.makedirs(self.exp_name)
+ save_config_py(cfg, os.path.join(self.exp_name, 'policy_config.py'))
+ if model is None:
+ if self.cfg.policy.model.model_type == 'mlp':
+ from lzero.model.muzero_model_mlp import MuZeroModelMLP
+ model = MuZeroModelMLP(**self.cfg.policy.model)
+ elif self.cfg.policy.model.model_type == 'conv':
+ from lzero.model.muzero_model import MuZeroModel
+ model = MuZeroModel(**self.cfg.policy.model)
+ else:
+ raise NotImplementedError
+ if self.cfg.policy.cuda and torch.cuda.is_available():
+ self.cfg.policy.device = 'cuda'
+ else:
+ self.cfg.policy.device = 'cpu'
+ self.policy = create_policy(self.cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
+ if policy_state_dict is not None:
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
+ self.checkpoint_save_dir = os.path.join(self.exp_name, "ckpt")
+
+ self.env_fn, self.collector_env_cfg, self.evaluator_env_cfg = get_vec_env_setting(self.cfg.env)
+
+ def train(
+ self,
+ step: int = int(1e7),
+ ) -> TrainingReturn:
+ """
+ Overview:
+ Train the agent through interactions with the environment.
+ Arguments:
+ - step (:obj:`int`): Total number of environment steps to train for. Defaults to 10 million (1e7).
+ Returns:
+ - A `TrainingReturn` object containing training information, such as logs and potentially a URL to a training dashboard.
+ .. note::
+ The method involves interacting with the environment, collecting experience, and optimizing the model.
+ """
+
+ collector_env = create_env_manager(
+ self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.collector_env_cfg]
+ )
+ evaluator_env = create_env_manager(
+ self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.evaluator_env_cfg]
+ )
+
+ collector_env.seed(self.cfg.seed)
+ evaluator_env.seed(self.cfg.seed, dynamic_seed=False)
+ set_pkg_seed(self.cfg.seed, use_cuda=self.cfg.policy.cuda)
+
+ # Create worker components: learner, collector, evaluator, replay buffer, commander.
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(self.cfg.exp_name), 'serial')
+ ) if get_rank() == 0 else None
+ learner = BaseLearner(
+ self.cfg.policy.learn.learner, self.policy.learn_mode, tb_logger, exp_name=self.cfg.exp_name
+ )
+
+ # ==============================================================
+ # MCTS+RL algorithms related core code
+ # ==============================================================
+ policy_config = self.cfg.policy
+ batch_size = policy_config.batch_size
+ # specific game buffer for MCTS+RL algorithms
+ replay_buffer = MuZeroGameBuffer(policy_config)
+ collector = Collector(
+ env=collector_env,
+ policy=self.policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=self.cfg.exp_name,
+ policy_config=policy_config
+ )
+ evaluator = Evaluator(
+ eval_freq=self.cfg.policy.eval_freq,
+ n_evaluator_episode=self.cfg.env.n_evaluator_episode,
+ stop_value=self.cfg.env.stop_value,
+ env=evaluator_env,
+ policy=self.policy.eval_mode,
+ tb_logger=tb_logger,
+ exp_name=self.cfg.exp_name,
+ policy_config=policy_config
+ )
+
+ # ==============================================================
+ # Main loop
+ # ==============================================================
+ # Learner's before_run hook.
+ learner.call_hook('before_run')
+
+ if self.cfg.policy.update_per_collect is not None:
+ update_per_collect = self.cfg.policy.update_per_collect
+
+ # The purpose of collecting random data before training:
+ # Exploration: Collecting random data helps the agent explore the environment and avoid getting stuck in a suboptimal policy prematurely.
+ # Comparison: By observing the agent's performance during random action-taking, we can establish a baseline to evaluate the effectiveness of reinforcement learning algorithms.
+ if self.cfg.policy.random_collect_episode_num > 0:
+ random_collect(self.cfg.policy, self.policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)
+
+ while True:
+ log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
+ collect_kwargs = {}
+ # set temperature for visit count distributions according to the train_iter,
+ # please refer to Appendix D in MuZero paper for details.
+ collect_kwargs['temperature'] = visit_count_temperature(
+ policy_config.manual_temperature_decay,
+ policy_config.fixed_temperature_value,
+ policy_config.threshold_training_steps_for_final_temperature,
+ trained_steps=learner.train_iter
+ )
+
+ if policy_config.eps.eps_greedy_exploration_in_collect:
+ epsilon_greedy_fn = get_epsilon_greedy_fn(
+ start=policy_config.eps.start,
+ end=policy_config.eps.end,
+ decay=policy_config.eps.decay,
+ type_=policy_config.eps.type
+ )
+ collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)
+ else:
+ collect_kwargs['epsilon'] = 0.0
+
+ # Evaluate policy performance.
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+
+ # Collect data by default config n_sample/n_episode.
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+ if self.cfg.policy.update_per_collect is None:
+ # update_per_collect is None, then update_per_collect is set to the number of collected transitions multiplied by the model_update_ratio.
+ collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
+ update_per_collect = int(collected_transitions_num * self.cfg.policy.model_update_ratio)
+ # save returned new_data collected by the collector
+ replay_buffer.push_game_segments(new_data)
+ # remove the oldest data if the replay buffer is full.
+ replay_buffer.remove_oldest_data_to_fit()
+
+ # Learn policy from collected data.
+ for i in range(update_per_collect):
+ # Learner will train ``update_per_collect`` times in one iteration.
+ if replay_buffer.get_num_of_transitions() > batch_size:
+ train_data = replay_buffer.sample(batch_size, self.policy)
+ else:
+ logging.warning(
+ f'The data in replay_buffer is not sufficient to sample a mini-batch: '
+ f'batch_size: {batch_size}, '
+ f'{replay_buffer} '
+ f'continue to collect now ....'
+ )
+ break
+
+ # The core train steps for MCTS+RL algorithms.
+ log_vars = learner.train(train_data, collector.envstep)
+
+ if self.cfg.policy.use_priority:
+ replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])
+
+ if collector.envstep >= step:
+ break
+
+ # Learner's after_run hook.
+ learner.call_hook('after_run')
+
+ return TrainingReturn(wandb_url=None)
+
+ def deploy(
+ self,
+ enable_save_replay: bool = False,
+ concatenate_all_replay: bool = False,
+ replay_save_path: str = None,
+ seed: Optional[Union[int, List]] = None,
+ debug: bool = False
+ ) -> EvalReturn:
+ """
+ Overview:
+ Deploy the agent for evaluation in the environment, with optional replay saving. The performance of the
+ agent will be evaluated. Average return and standard deviation of the return will be returned.
+ If `enable_save_replay` is True, replay videos are saved in the specified `replay_save_path`.
+ Arguments:
+ - enable_save_replay (:obj:`bool`): Flag to enable saving of replay footage. Defaults to False.
+ - concatenate_all_replay (:obj:`bool`): Whether to concatenate all replay videos into one file. Defaults to False.
+ - replay_save_path (:obj:`Optional[str]`): Directory path to save replay videos. Defaults to None, which sets a default path.
+ - seed (:obj:`Optional[Union[int, List[int]]]`): Seed or list of seeds for environment reproducibility. Defaults to None.
+ - debug (:obj:`bool`): Whether to enable the debug mode. Default to False.
+ Returns:
+ - An `EvalReturn` object containing evaluation metrics such as mean and standard deviation of returns.
+ """
+
+ deply_configs = [self.evaluator_env_cfg[0]]
+
+ if type(seed) == int:
+ seed_list = [seed]
+ elif seed:
+ seed_list = seed
+ else:
+ seed_list = [0]
+
+ reward_list = []
+
+ if enable_save_replay:
+ replay_save_path = replay_save_path if replay_save_path is not None else os.path.join(
+ self.exp_name, 'videos'
+ )
+ deply_configs[0]['replay_path'] = replay_save_path
+
+ for seed in seed_list:
+
+ evaluator_env = create_env_manager(self.cfg.env.manager, [partial(self.env_fn, cfg=deply_configs[0])])
+
+ evaluator_env.seed(seed if seed is not None else self.cfg.seed, dynamic_seed=False)
+ set_pkg_seed(seed if seed is not None else self.cfg.seed, use_cuda=self.cfg.policy.cuda)
+
+ # ==============================================================
+ # MCTS+RL algorithms related core code
+ # ==============================================================
+ policy_config = self.cfg.policy
+
+ evaluator = Evaluator(
+ eval_freq=self.cfg.policy.eval_freq,
+ n_evaluator_episode=1,
+ stop_value=self.cfg.env.stop_value,
+ env=evaluator_env,
+ policy=self.policy.eval_mode,
+ exp_name=self.cfg.exp_name,
+ policy_config=policy_config
+ )
+
+ # ==============================================================
+ # Main loop
+ # ==============================================================
+
+ stop, reward = evaluator.eval()
+ reward_list.extend(reward['eval_episode_return'])
+
+ if enable_save_replay:
+ files = os.listdir(replay_save_path)
+ files = [file for file in files if file.endswith('0.mp4')]
+ files.sort()
+ if concatenate_all_replay:
+ # create a file named 'files.txt' to store the names of all mp4 files
+ with open(os.path.join(replay_save_path, 'files.txt'), 'w') as f:
+ for file in files:
+ f.write("file '{}'\n".format(file))
+
+ # combine all the mp4 files into one mp4 file, rename it as 'deploy.mp4'
+ os.system(
+ 'ffmpeg -f concat -safe 0 -i {} -c copy {}/deploy.mp4'.format(
+ os.path.join(replay_save_path, 'files.txt'), replay_save_path
+ )
+ )
+
+ return EvalReturn(eval_value=np.mean(reward_list), eval_value_std=np.std(reward_list))
+
+ def batch_evaluate(
+ self,
+ n_evaluator_episode: int = None,
+ ) -> EvalReturn:
+ """
+ Overview:
+ Perform a batch evaluation of the agent over a specified number of episodes: ``n_evaluator_episode``.
+ Arguments:
+ - n_evaluator_episode (:obj:`Optional[int]`): Number of episodes to run the evaluation.
+ If None, uses default value from configuration. Defaults to None.
+ Returns:
+ - An `EvalReturn` object with evaluation results such as mean and standard deviation of returns.
+
+ .. note::
+ This method evaluates the agent's performance across multiple episodes to gauge its effectiveness.
+ """
+ evaluator_env = create_env_manager(
+ self.cfg.env.manager, [partial(self.env_fn, cfg=c) for c in self.evaluator_env_cfg]
+ )
+
+ evaluator_env.seed(self.cfg.seed, dynamic_seed=False)
+ set_pkg_seed(self.cfg.seed, use_cuda=self.cfg.policy.cuda)
+
+ # ==============================================================
+ # MCTS+RL algorithms related core code
+ # ==============================================================
+ policy_config = self.cfg.policy
+
+ evaluator = Evaluator(
+ eval_freq=self.cfg.policy.eval_freq,
+ n_evaluator_episode=self.cfg.env.n_evaluator_episode
+ if n_evaluator_episode is None else n_evaluator_episode,
+ stop_value=self.cfg.env.stop_value,
+ env=evaluator_env,
+ policy=self.policy.eval_mode,
+ exp_name=self.cfg.exp_name,
+ policy_config=policy_config
+ )
+
+ # ==============================================================
+ # Main loop
+ # ==============================================================
+
+ stop, reward = evaluator.eval()
+
+ return EvalReturn(
+ eval_value=np.mean(reward['eval_episode_return']), eval_value_std=np.std(reward['eval_episode_return'])
+ )
+
+ @property
+ def best(self):
+ """
+ Overview:
+ Provides access to the best model according to evaluation metrics.
+ Returns:
+ - The agent with the best model loaded.
+
+ .. note::
+ The best model is saved in the path `./exp_name/ckpt/ckpt_best.pth.tar`.
+ When this property is accessed, the agent instance will load the best model state.
+ """
+
+ best_model_file_path = os.path.join(self.checkpoint_save_dir, "ckpt_best.pth.tar")
+ # Load best model if it exists
+ if os.path.exists(best_model_file_path):
+ policy_state_dict = torch.load(best_model_file_path, map_location=torch.device("cpu"))
+ self.policy.learn_mode.load_state_dict(policy_state_dict)
+ return self
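+
+# A minimal usage sketch (illustrative, not part of the original file): assuming ``agent``
+# is an already-constructed instance of this agent class whose ``exp_name`` directory
+# contains checkpoints, deployment and batch evaluation look roughly like:
+#
+#     result = agent.best.deploy(enable_save_replay=True, seed=0)
+#     print(result.eval_value, result.eval_value_std)
+#     batch_result = agent.batch_evaluate(n_evaluator_episode=8)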
diff --git a/LightZero/lzero/config/__init__.py b/LightZero/lzero/config/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/LightZero/lzero/config/meta.py b/LightZero/lzero/config/meta.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ac1898e3c815f824e10128a3249e192310a3757
--- /dev/null
+++ b/LightZero/lzero/config/meta.py
@@ -0,0 +1,19 @@
+"""
+Overview:
+ Meta information for LightZero package.
+"""
+
+#: Title of this project (should be `LightZero`).
+__TITLE__ = "LightZero"
+
+#: Version of this project.
+__VERSION__ = "0.0.3"
+
+#: Short description of the project, will be included in ``setup.py``.
+__DESCRIPTION__ = 'A lightweight and efficient MCTS/AlphaZero/MuZero algorithm toolkit.'
+
+#: Author of this project.
+__AUTHOR__ = "LightZero's Contributors"
+
+#: Email of the authors.
+__AUTHOR_EMAIL__ = "opendilab@opendilab.net"
diff --git a/LightZero/lzero/config/utils.py b/LightZero/lzero/config/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..d4c5f8d952505b6894dbc5f776cd9ae4b311ea9e
--- /dev/null
+++ b/LightZero/lzero/config/utils.py
@@ -0,0 +1,18 @@
+import numpy as np
+from ding.utils import get_world_size
+from easydict import EasyDict
+
+
+def lz_to_ddp_config(cfg: EasyDict) -> EasyDict:
+ r"""
+ Overview:
+ Convert the LightZero-style config to a DDP config by scaling the batch size and episode number by the world size.
+ Arguments:
+ - cfg (:obj:`EasyDict`): The config to be converted
+ Returns:
+ - cfg (:obj:`EasyDict`): The converted config
+ """
+ w = get_world_size()
+ cfg.policy.batch_size = int(np.ceil(cfg.policy.batch_size / w))
+ cfg.policy.n_episode = int(np.ceil(cfg.policy.n_episode / w))
+ return cfg
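+
+# A small illustrative example (not part of the original file): with a DDP world size of 2,
+# a per-run batch_size of 256 and n_episode of 8 become 128 and 4 per process:
+#
+#     cfg = EasyDict(dict(policy=dict(batch_size=256, n_episode=8)))
+#     ddp_cfg = lz_to_ddp_config(cfg)  # policy.batch_size -> 128, policy.n_episode -> 4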
diff --git a/LightZero/lzero/entry/__init__.py b/LightZero/lzero/entry/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..9b50ed6ec47381d243775359eca9a26f46e0876b
--- /dev/null
+++ b/LightZero/lzero/entry/__init__.py
@@ -0,0 +1,7 @@
+from .train_alphazero import train_alphazero
+from .eval_alphazero import eval_alphazero
+from .train_muzero import train_muzero
+from .train_muzero_with_reward_model import train_muzero_with_reward_model
+from .eval_muzero import eval_muzero
+from .eval_muzero_with_gym_env import eval_muzero_with_gym_env
+from .train_muzero_with_gym_env import train_muzero_with_gym_env
\ No newline at end of file
diff --git a/LightZero/lzero/entry/eval_alphazero.py b/LightZero/lzero/entry/eval_alphazero.py
new file mode 100644
index 0000000000000000000000000000000000000000..486e2e6e52fbba2e5df5ee4c556b785fdcf31756
--- /dev/null
+++ b/LightZero/lzero/entry/eval_alphazero.py
@@ -0,0 +1,96 @@
+import os
+from functools import partial
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+from tensorboardX import SummaryWriter
+
+from ding.config import compile_config
+from ding.envs import create_env_manager
+from ding.envs import get_vec_env_setting
+from ding.policy import create_policy
+from ding.utils import set_pkg_seed
+from lzero.worker import AlphaZeroEvaluator
+
+
+def eval_alphazero(
+ input_cfg: Tuple[dict, dict],
+ seed: int = 0,
+ model: Optional[torch.nn.Module] = None,
+ model_path: Optional[str] = None,
+ num_episodes_each_seed: int = 1,
+ print_seed_details: bool = False,
+) -> Tuple[float, np.ndarray]:
+ """
+ Overview:
+ The eval entry for AlphaZero.
+ Arguments:
+ - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - model_path (:obj:`Optional[str]`): The pretrained model path, which should
+ point to the ckpt file of the pretrained model, and an absolute path is recommended.
+ In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
+ - num_episodes_each_seed (:obj:`int`): Number of evaluation episodes to run for the given seed.
+ - print_seed_details (:obj:`bool`): Whether to print the per-seed evaluation details.
+ Returns:
+ - (:obj:`Tuple[float, np.ndarray]`): The mean episode return and the array of episode returns.
+ """
+ cfg, create_cfg = input_cfg
+ create_cfg.policy.type = create_cfg.policy.type
+
+ if cfg.policy.cuda and torch.cuda.is_available():
+ cfg.policy.device = 'cuda'
+ else:
+ cfg.policy.device = 'cpu'
+
+ cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
+ # Create main components: env, policy
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+ collector_env.seed(cfg.seed)
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
+
+ # load pretrained model
+ if model_path is not None:
+ policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
+
+ # Create worker components: learner, collector, evaluator, replay buffer, commander.
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+
+ evaluator = AlphaZeroEvaluator(
+ eval_freq=cfg.policy.eval_freq,
+ n_evaluator_episode=cfg.env.n_evaluator_episode,
+ stop_value=cfg.env.stop_value,
+ env=evaluator_env,
+ policy=policy.eval_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name,
+ )
+
+ while True:
+ # ==============================================================
+ # eval trained model
+ # ==============================================================
+ returns = []
+ for i in range(num_episodes_each_seed):
+ stop_flag, episode_info = evaluator.eval()
+ returns.append(episode_info['eval_episode_return'])
+
+ returns = np.array(returns)
+
+ if print_seed_details:
+ print("=" * 20)
+ print(f'In seed {seed}, returns: {returns}')
+ if cfg.policy.simulation_env_name in ['tictactoe', 'connect4', 'gomoku', 'chess']:
+ print(
+ f'win rate: {len(np.where(returns == 1.)[0]) / num_episodes_each_seed}, draw rate: {len(np.where(returns == 0.)[0]) / num_episodes_each_seed}, lose rate: {len(np.where(returns == -1.)[0]) / num_episodes_each_seed}'
+ )
+ print("=" * 20)
+
+ return returns.mean(), returns
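+
+# A minimal usage sketch (illustrative; the config import path below is a placeholder for
+# any AlphaZero zoo config that exposes a ``(main_config, create_config)`` pair):
+#
+#     from zoo.some_game.config.some_alphazero_config import main_config, create_config
+#     mean_return, returns = eval_alphazero(
+#         (main_config, create_config),
+#         seed=0,
+#         model_path='./exp_name/ckpt/ckpt_best.pth.tar',
+#         num_episodes_each_seed=5,
+#         print_seed_details=True,
+#     )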
diff --git a/LightZero/lzero/entry/eval_muzero.py b/LightZero/lzero/entry/eval_muzero.py
new file mode 100644
index 0000000000000000000000000000000000000000..e3e96dcddfb86679446d3091cf46119f126c4405
--- /dev/null
+++ b/LightZero/lzero/entry/eval_muzero.py
@@ -0,0 +1,108 @@
+import os
+from functools import partial
+from typing import Optional, Tuple
+
+import numpy as np
+import torch
+from tensorboardX import SummaryWriter
+
+from ding.config import compile_config
+from ding.envs import create_env_manager
+from ding.envs import get_vec_env_setting
+from ding.policy import create_policy
+from ding.utils import set_pkg_seed
+from ding.worker import BaseLearner
+from lzero.worker import MuZeroEvaluator
+
+
+def eval_muzero(
+ input_cfg: Tuple[dict, dict],
+ seed: int = 0,
+ model: Optional[torch.nn.Module] = None,
+ model_path: Optional[str] = None,
+ num_episodes_each_seed: int = 1,
+ print_seed_details: bool = False,
+) -> Tuple[float, np.ndarray]:
+ """
+ Overview:
+ The eval entry for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero.
+ Arguments:
+ - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - model_path (:obj:`Optional[str]`): The pretrained model path, which should
+ point to the ckpt file of the pretrained model, and an absolute path is recommended.
+ In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
+ - num_episodes_each_seed (:obj:`int`): Number of evaluation episodes to run for the given seed.
+ - print_seed_details (:obj:`bool`): Whether to print the per-seed evaluation details.
+ Returns:
+ - (:obj:`Tuple[float, np.ndarray]`): The mean episode return and the array of episode returns.
+ """
+ cfg, create_cfg = input_cfg
+ assert create_cfg.policy.type in ['efficientzero', 'muzero', 'stochastic_muzero', 'gumbel_muzero', 'sampled_efficientzero'], \
+ "LightZero now only support the following algo.: 'efficientzero', 'muzero', 'stochastic_muzero', 'gumbel_muzero', 'sampled_efficientzero'"
+
+ if cfg.policy.cuda and torch.cuda.is_available():
+ cfg.policy.device = 'cuda'
+ else:
+ cfg.policy.device = 'cpu'
+
+ cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
+ # Create main components: env, policy
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
+
+ # load pretrained model
+ if model_path is not None:
+ policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
+
+ # Create worker components: learner, collector, evaluator, replay buffer, commander.
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+
+ # ==============================================================
+ # MCTS+RL algorithms related core code
+ # ==============================================================
+ policy_config = cfg.policy
+ evaluator = MuZeroEvaluator(
+ eval_freq=cfg.policy.eval_freq,
+ n_evaluator_episode=cfg.env.n_evaluator_episode,
+ stop_value=cfg.env.stop_value,
+ env=evaluator_env,
+ policy=policy.eval_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name,
+ policy_config=policy_config
+ )
+
+ # ==========
+ # Main loop
+ # ==========
+ # Learner's before_run hook.
+ learner.call_hook('before_run')
+
+ while True:
+ # ==============================================================
+ # eval trained model
+ # ==============================================================
+ returns = []
+ for i in range(num_episodes_each_seed):
+ stop_flag, episode_info = evaluator.eval(learner.save_checkpoint, learner.train_iter)
+ returns.append(episode_info['eval_episode_return'])
+
+ returns = np.array(returns)
+
+ if print_seed_details:
+ print("=" * 20)
+ print(f'In seed {seed}, returns: {returns}')
+ if cfg.policy.env_type == 'board_games':
+ print(
+ f'win rate: {len(np.where(returns == 1.)[0]) / num_episodes_each_seed}, draw rate: {len(np.where(returns == 0.)[0]) / num_episodes_each_seed}, lose rate: {len(np.where(returns == -1.)[0]) / num_episodes_each_seed}'
+ )
+ print("=" * 20)
+
+ return returns.mean(), returns
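+
+# A minimal usage sketch (illustrative; the config import path below is a placeholder for
+# any MuZero-family zoo config that exposes a ``(main_config, create_config)`` pair):
+#
+#     from zoo.some_env.config.some_muzero_config import main_config, create_config
+#     mean_return, returns = eval_muzero(
+#         (main_config, create_config),
+#         seed=0,
+#         model_path='./exp_name/ckpt/ckpt_best.pth.tar',
+#         num_episodes_each_seed=2,
+#         print_seed_details=True,
+#     )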
diff --git a/LightZero/lzero/entry/eval_muzero_with_gym_env.py b/LightZero/lzero/entry/eval_muzero_with_gym_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..663b4945aecb53d3ed93b5e79921c506e357d6bd
--- /dev/null
+++ b/LightZero/lzero/entry/eval_muzero_with_gym_env.py
@@ -0,0 +1,118 @@
+import os
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+import torch
+from tensorboardX import SummaryWriter
+
+from ding.config import compile_config
+from ding.envs import DingEnvWrapper, BaseEnvManager
+from ding.policy import create_policy
+from ding.utils import set_pkg_seed
+from ding.worker import BaseLearner
+from lzero.envs.get_wrapped_env import get_wrappered_env
+from lzero.worker import MuZeroEvaluator
+
+
+def eval_muzero_with_gym_env(
+ input_cfg: Tuple[dict, dict],
+ seed: int = 0,
+ model: Optional[torch.nn.Module] = None,
+ model_path: Optional[str] = None,
+ num_episodes_each_seed: int = 1,
+ print_seed_details: bool = False,
+) -> Tuple[float, np.ndarray]:
+ """
+ Overview:
+ The eval entry for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero.
+ We create a gym environment using the ``env_name`` parameter, and then convert it to the
+ format required by LightZero using the ``LightZeroEnvWrapper`` class.
+ Please refer to the ``get_wrappered_env`` method for more details.
+ Arguments:
+ - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - model_path (:obj:`Optional[str]`): The pretrained model path, which should
+ point to the ckpt file of the pretrained model, and an absolute path is recommended.
+ In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
+ - num_episodes_each_seed (:obj:`int`): Number of evaluation episodes to run for the given seed.
+ - print_seed_details (:obj:`bool`): Whether to print the per-seed evaluation details.
+ Returns:
+ - (:obj:`Tuple[float, np.ndarray]`): The mean episode return and the array of episode returns.
+ """
+ cfg, create_cfg = input_cfg
+ assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero'], \
+ "LightZero noow only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero'"
+
+ if cfg.policy.cuda and torch.cuda.is_available():
+ cfg.policy.device = 'cuda'
+ else:
+ cfg.policy.device = 'cpu'
+
+ cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
+
+ # Create main components: env, policy
+ collector_env_cfg = DingEnvWrapper.create_collector_env_cfg(cfg.env)
+ evaluator_env_cfg = DingEnvWrapper.create_evaluator_env_cfg(cfg.env)
+ collector_env = BaseEnvManager(
+ [get_wrappered_env(c, cfg.env.env_name) for c in collector_env_cfg], cfg=BaseEnvManager.default_config()
+ )
+ evaluator_env = BaseEnvManager(
+ [get_wrappered_env(c, cfg.env.env_name) for c in evaluator_env_cfg], cfg=BaseEnvManager.default_config()
+ )
+ collector_env.seed(cfg.seed)
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
+
+ # load pretrained model
+ if model_path is not None:
+ policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
+
+ # Create worker components: learner, collector, evaluator, replay buffer, commander.
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+
+ # ==============================================================
+ # MCTS+RL algorithms related core code
+ # ==============================================================
+ policy_config = cfg.policy
+ # specific game buffer for MCTS+RL algorithms
+ evaluator = MuZeroEvaluator(
+ eval_freq=cfg.policy.eval_freq,
+ n_evaluator_episode=cfg.env.n_evaluator_episode,
+ stop_value=cfg.env.stop_value,
+ env=evaluator_env,
+ policy=policy.eval_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name,
+ policy_config=policy_config
+ )
+
+ # ==========
+ # Main loop
+ # ==========
+ # Learner's before_run hook.
+ learner.call_hook('before_run')
+
+ while True:
+ # ==============================================================
+ # eval trained model
+ # ==============================================================
+ returns = []
+ for i in range(num_episodes_each_seed):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter)
+ returns.append(reward)
+ returns = np.array(returns)
+
+ if print_seed_details:
+ print("=" * 20)
+ print(f'In seed {seed}, returns: {returns}')
+ if cfg.policy.env_type == 'board_games':
+ print(
+ f'win rate: {len(np.where(returns == 1.)[0]) / num_episodes_each_seed}, draw rate: {len(np.where(returns == 0.)[0]) / num_episodes_each_seed}, lose rate: {len(np.where(returns == -1.)[0]) / num_episodes_each_seed}'
+ )
+ print("=" * 20)
+
+ return returns.mean(), returns
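+
+# A minimal usage sketch (illustrative): the config is expected to carry a standard gym id in
+# ``cfg.env.env_name`` plus the wrapper flags (``manually_discretization``, ``continuous``,
+# ``each_dim_disc_size``); the import path below is a placeholder.
+#
+#     from zoo.some_env.config.some_muzero_gym_config import main_config, create_config
+#     mean_return, returns = eval_muzero_with_gym_env((main_config, create_config), seed=0)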
diff --git a/LightZero/lzero/entry/train_alphazero.py b/LightZero/lzero/entry/train_alphazero.py
new file mode 100644
index 0000000000000000000000000000000000000000..3b455adb1d63be7b692bf80b495d6d247ef5f8eb
--- /dev/null
+++ b/LightZero/lzero/entry/train_alphazero.py
@@ -0,0 +1,144 @@
+import logging
+import os
+from functools import partial
+from typing import Optional, Tuple
+
+import torch
+from ding.config import compile_config
+from ding.envs import create_env_manager
+from ding.envs import get_vec_env_setting
+from ding.policy import create_policy
+from ding.utils import set_pkg_seed
+from ding.worker import BaseLearner, create_buffer
+from tensorboardX import SummaryWriter
+
+from lzero.policy import visit_count_temperature
+from lzero.worker import AlphaZeroCollector, AlphaZeroEvaluator
+
+
+def train_alphazero(
+ input_cfg: Tuple[dict, dict],
+ seed: int = 0,
+ model: Optional[torch.nn.Module] = None,
+ model_path: Optional[str] = None,
+ max_train_iter: Optional[int] = int(1e10),
+ max_env_step: Optional[int] = int(1e10),
+) -> 'Policy': # noqa
+ """
+ Overview:
+ The train entry for AlphaZero.
+ Arguments:
+ - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - model_path (:obj:`Optional[str]`): The pretrained model path, which should
+ point to the ckpt file of the pretrained model, and an absolute path is recommended.
+ In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
+ - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+ - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+ Returns:
+ - policy (:obj:`Policy`): Converged policy.
+ """
+ cfg, create_cfg = input_cfg
+ create_cfg.policy.type = create_cfg.policy.type
+
+ if cfg.policy.cuda and torch.cuda.is_available():
+ cfg.policy.device = 'cuda'
+ else:
+ cfg.policy.device = 'cpu'
+
+ cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
+ # Create main components: env, policy
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+ collector_env.seed(cfg.seed)
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
+
+ # load pretrained model
+ if model_path is not None:
+ policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
+
+ # Create worker components: learner, collector, evaluator, replay buffer, commander.
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+ replay_buffer = create_buffer(cfg.policy.other.replay_buffer, tb_logger=tb_logger, exp_name=cfg.exp_name)
+
+ policy_config = cfg.policy
+ batch_size = policy_config.batch_size
+ collector = AlphaZeroCollector(
+ env=collector_env,
+ policy=policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name,
+ )
+ evaluator = AlphaZeroEvaluator(
+ eval_freq=cfg.policy.eval_freq,
+ n_evaluator_episode=cfg.env.n_evaluator_episode,
+ stop_value=cfg.env.stop_value,
+ env=evaluator_env,
+ policy=policy.eval_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name,
+ )
+
+ # ==============================================================
+ # Main loop
+ # ==============================================================
+ # Learner's before_run hook.
+ learner.call_hook('before_run')
+ if cfg.policy.update_per_collect is not None:
+ update_per_collect = cfg.policy.update_per_collect
+ while True:
+ collect_kwargs = {}
+ # set temperature for visit count distributions according to the train_iter,
+ # please refer to Appendix D in MuZero paper for details.
+ collect_kwargs['temperature'] = visit_count_temperature(
+ policy_config.manual_temperature_decay,
+ policy_config.fixed_temperature_value,
+ policy_config.threshold_training_steps_for_final_temperature,
+ trained_steps=learner.train_iter
+ )
+
+ # Evaluate policy performance
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(
+ learner.save_checkpoint,
+ learner.train_iter,
+ collector.envstep,
+ )
+ if stop:
+ break
+
+ # Collect data by default config n_sample/n_episode
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+ new_data = sum(new_data, [])
+ if cfg.policy.update_per_collect is None:
+ # If update_per_collect is None, set it to the number of collected transitions multiplied by model_update_ratio.
+ collected_transitions_num = len(new_data)
+ update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio)
+ replay_buffer.push(new_data, cur_collector_envstep=collector.envstep)
+
+ # Learn policy from collected data
+ for i in range(update_per_collect):
+ # Learner will train ``update_per_collect`` times in one iteration.
+ train_data = replay_buffer.sample(batch_size, learner.train_iter)
+ if train_data is None:
+ logging.warning(
+ 'The data in replay_buffer is not sufficient to sample a mini-batch, '
+ 'continue to collect now ....'
+ )
+ break
+
+ learner.train(train_data, collector.envstep)
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+
+ # Learner's after_run hook.
+ learner.call_hook('after_run')
+ return policy
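+
+# A minimal usage sketch (illustrative; the config import path is a placeholder):
+#
+#     from zoo.some_game.config.some_alphazero_config import main_config, create_config
+#     policy = train_alphazero((main_config, create_config), seed=0, max_env_step=int(1e6))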
diff --git a/LightZero/lzero/entry/train_muzero.py b/LightZero/lzero/entry/train_muzero.py
new file mode 100644
index 0000000000000000000000000000000000000000..5e397606362f1708529215055cc693003f08cbc1
--- /dev/null
+++ b/LightZero/lzero/entry/train_muzero.py
@@ -0,0 +1,195 @@
+import logging
+import os
+from functools import partial
+from typing import Optional, Tuple
+
+import torch
+from ding.config import compile_config
+from ding.envs import create_env_manager
+from ding.envs import get_vec_env_setting
+from ding.policy import create_policy
+from ding.utils import set_pkg_seed, get_rank
+from ding.rl_utils import get_epsilon_greedy_fn
+from ding.worker import BaseLearner
+from tensorboardX import SummaryWriter
+
+from lzero.entry.utils import log_buffer_memory_usage
+from lzero.policy import visit_count_temperature
+from lzero.policy.random_policy import LightZeroRandomPolicy
+from lzero.worker import MuZeroCollector as Collector
+from lzero.worker import MuZeroEvaluator as Evaluator
+from .utils import random_collect
+
+
+def train_muzero(
+ input_cfg: Tuple[dict, dict],
+ seed: int = 0,
+ model: Optional[torch.nn.Module] = None,
+ model_path: Optional[str] = None,
+ max_train_iter: Optional[int] = int(1e10),
+ max_env_step: Optional[int] = int(1e10),
+) -> 'Policy': # noqa
+ """
+ Overview:
+ The train entry for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero, Gumbel Muzero.
+ Arguments:
+ - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - model_path (:obj:`Optional[str]`): The pretrained model path, which should
+ point to the ckpt file of the pretrained model, and an absolute path is recommended.
+ In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
+ - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+ - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+ Returns:
+ - policy (:obj:`Policy`): Converged policy.
+ """
+
+ cfg, create_cfg = input_cfg
+ assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero', 'stochastic_muzero'], \
+ "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero', 'gumbel_muzero'"
+
+ if create_cfg.policy.type == 'muzero':
+ from lzero.mcts import MuZeroGameBuffer as GameBuffer
+ elif create_cfg.policy.type == 'efficientzero':
+ from lzero.mcts import EfficientZeroGameBuffer as GameBuffer
+ elif create_cfg.policy.type == 'sampled_efficientzero':
+ from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer
+ elif create_cfg.policy.type == 'gumbel_muzero':
+ from lzero.mcts import GumbelMuZeroGameBuffer as GameBuffer
+ elif create_cfg.policy.type == 'stochastic_muzero':
+ from lzero.mcts import StochasticMuZeroGameBuffer as GameBuffer
+
+ if cfg.policy.cuda and torch.cuda.is_available():
+ cfg.policy.device = 'cuda'
+ else:
+ cfg.policy.device = 'cpu'
+
+ cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
+ # Create main components: env, policy
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+
+ collector_env.seed(cfg.seed)
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
+
+ # load pretrained model
+ if model_path is not None:
+ policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
+
+ # Create worker components: learner, collector, evaluator, replay buffer, commander.
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial')) if get_rank() == 0 else None
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+
+ # ==============================================================
+ # MCTS+RL algorithms related core code
+ # ==============================================================
+ policy_config = cfg.policy
+ batch_size = policy_config.batch_size
+ # specific game buffer for MCTS+RL algorithms
+ replay_buffer = GameBuffer(policy_config)
+ collector = Collector(
+ env=collector_env,
+ policy=policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name,
+ policy_config=policy_config
+ )
+ evaluator = Evaluator(
+ eval_freq=cfg.policy.eval_freq,
+ n_evaluator_episode=cfg.env.n_evaluator_episode,
+ stop_value=cfg.env.stop_value,
+ env=evaluator_env,
+ policy=policy.eval_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name,
+ policy_config=policy_config
+ )
+
+ # ==============================================================
+ # Main loop
+ # ==============================================================
+ # Learner's before_run hook.
+ learner.call_hook('before_run')
+
+ if cfg.policy.update_per_collect is not None:
+ update_per_collect = cfg.policy.update_per_collect
+
+ # The purpose of collecting random data before training:
+ # Exploration: Collecting random data helps the agent explore the environment and avoid getting stuck in a suboptimal policy prematurely.
+ # Comparison: By observing the agent's performance during random action-taking, we can establish a baseline to evaluate the effectiveness of reinforcement learning algorithms.
+ if cfg.policy.random_collect_episode_num > 0:
+ random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)
+
+ while True:
+ log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
+ collect_kwargs = {}
+ # set temperature for visit count distributions according to the train_iter,
+ # please refer to Appendix D in MuZero paper for details.
+ collect_kwargs['temperature'] = visit_count_temperature(
+ policy_config.manual_temperature_decay,
+ policy_config.fixed_temperature_value,
+ policy_config.threshold_training_steps_for_final_temperature,
+ trained_steps=learner.train_iter
+ )
+
+ if policy_config.eps.eps_greedy_exploration_in_collect:
+ epsilon_greedy_fn = get_epsilon_greedy_fn(
+ start=policy_config.eps.start,
+ end=policy_config.eps.end,
+ decay=policy_config.eps.decay,
+ type_=policy_config.eps.type
+ )
+ collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)
+ else:
+ collect_kwargs['epsilon'] = 0.0
+
+ # Evaluate policy performance.
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+
+ # Collect data by default config n_sample/n_episode.
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+ if cfg.policy.update_per_collect is None:
+ # If update_per_collect is None, set it to the number of collected transitions multiplied by model_update_ratio.
+ collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
+ update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio)
+ # save returned new_data collected by the collector
+ replay_buffer.push_game_segments(new_data)
+ # remove the oldest data if the replay buffer is full.
+ replay_buffer.remove_oldest_data_to_fit()
+
+ # Learn policy from collected data.
+ for i in range(update_per_collect):
+ # Learner will train ``update_per_collect`` times in one iteration.
+ if replay_buffer.get_num_of_transitions() > batch_size:
+ train_data = replay_buffer.sample(batch_size, policy)
+ else:
+ logging.warning(
+ f'The data in replay_buffer is not sufficient to sample a mini-batch: '
+ f'batch_size: {batch_size}, '
+ f'{replay_buffer} '
+ f'continue to collect now ....'
+ )
+ break
+
+ # The core train steps for MCTS+RL algorithms.
+ log_vars = learner.train(train_data, collector.envstep)
+
+ if cfg.policy.use_priority:
+ replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])
+
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+
+ # Learner's after_run hook.
+ learner.call_hook('after_run')
+ return policy
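+
+# A minimal usage sketch (illustrative; the config import path is a placeholder for any
+# MuZero/EfficientZero/Sampled EfficientZero/Gumbel MuZero/Stochastic MuZero zoo config):
+#
+#     from zoo.some_env.config.some_muzero_config import main_config, create_config
+#     policy = train_muzero((main_config, create_config), seed=0, max_env_step=int(1e6))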
diff --git a/LightZero/lzero/entry/train_muzero_with_gym_env.py b/LightZero/lzero/entry/train_muzero_with_gym_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..1bfd855c551c23a1052ca6be4f1b89af0712f9fa
--- /dev/null
+++ b/LightZero/lzero/entry/train_muzero_with_gym_env.py
@@ -0,0 +1,172 @@
+import logging
+import os
+from typing import Optional
+from typing import Tuple
+
+import torch
+from tensorboardX import SummaryWriter
+
+from ding.config import compile_config
+from ding.envs import DingEnvWrapper, BaseEnvManager
+from ding.policy import create_policy
+from ding.utils import set_pkg_seed
+from ding.worker import BaseLearner
+from lzero.envs.get_wrapped_env import get_wrappered_env
+from lzero.policy import visit_count_temperature
+from lzero.worker import MuZeroCollector, MuZeroEvaluator
+
+
+def train_muzero_with_gym_env(
+ input_cfg: Tuple[dict, dict],
+ seed: int = 0,
+ model: Optional[torch.nn.Module] = None,
+ model_path: Optional[str] = None,
+ max_train_iter: Optional[int] = int(1e10),
+ max_env_step: Optional[int] = int(1e10),
+) -> 'Policy': # noqa
+ """
+ Overview:
+ The train entry for MCTS+RL algorithms, including MuZero, EfficientZero, Sampled EfficientZero.
+ We create a gym environment using the ``env_name`` parameter, and then convert it to the format required by LightZero using the ``LightZeroEnvWrapper`` class.
+ Please refer to the ``get_wrappered_env`` method for more details.
+ Arguments:
+ - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - model_path (:obj:`Optional[str]`): The pretrained model path, which should
+ point to the ckpt file of the pretrained model, and an absolute path is recommended.
+ In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
+ - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+ - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+ Returns:
+ - policy (:obj:`Policy`): Converged policy.
+ """
+
+ cfg, create_cfg = input_cfg
+ assert create_cfg.policy.type in ['efficientzero', 'muzero', 'sampled_efficientzero'], \
+ "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero'"
+
+ if create_cfg.policy.type == 'muzero':
+ from lzero.mcts import MuZeroGameBuffer as GameBuffer
+ elif create_cfg.policy.type == 'efficientzero':
+ from lzero.mcts import EfficientZeroGameBuffer as GameBuffer
+ elif create_cfg.policy.type == 'sampled_efficientzero':
+ from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer
+
+ if cfg.policy.cuda and torch.cuda.is_available():
+ cfg.policy.device = 'cuda'
+ else:
+ cfg.policy.device = 'cpu'
+
+ cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
+
+ # Create main components: env, policy
+ collector_env_cfg = DingEnvWrapper.create_collector_env_cfg(cfg.env)
+ evaluator_env_cfg = DingEnvWrapper.create_evaluator_env_cfg(cfg.env)
+ collector_env = BaseEnvManager(
+ [get_wrappered_env(c, cfg.env.env_name) for c in collector_env_cfg], cfg=BaseEnvManager.default_config()
+ )
+ evaluator_env = BaseEnvManager(
+ [get_wrappered_env(c, cfg.env.env_name) for c in evaluator_env_cfg], cfg=BaseEnvManager.default_config()
+ )
+ collector_env.seed(cfg.seed)
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
+
+ # load pretrained model
+ if model_path is not None:
+ policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
+
+ # Create worker components: learner, collector, evaluator, replay buffer, commander.
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+
+ # ==============================================================
+ # MCTS+RL algorithms related core code
+ # ==============================================================
+ policy_config = cfg.policy
+ batch_size = policy_config.batch_size
+ # specific game buffer for MCTS+RL algorithms
+ replay_buffer = GameBuffer(policy_config)
+ collector = MuZeroCollector(
+ env=collector_env,
+ policy=policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name,
+ policy_config=policy_config
+ )
+ evaluator = MuZeroEvaluator(
+ eval_freq=cfg.policy.eval_freq,
+ n_evaluator_episode=cfg.env.n_evaluator_episode,
+ stop_value=cfg.env.stop_value,
+ env=evaluator_env,
+ policy=policy.eval_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name,
+ policy_config=policy_config
+ )
+
+ # ==============================================================
+ # Main loop
+ # ==============================================================
+ # Learner's before_run hook.
+ learner.call_hook('before_run')
+ if cfg.policy.update_per_collect is not None:
+ update_per_collect = cfg.policy.update_per_collect
+ while True:
+ collect_kwargs = {}
+ # set temperature for visit count distributions according to the train_iter,
+ # please refer to Appendix D in MuZero paper for details.
+ collect_kwargs['temperature'] = visit_count_temperature(
+ policy_config.manual_temperature_decay,
+ policy_config.fixed_temperature_value,
+ policy_config.threshold_training_steps_for_final_temperature,
+ trained_steps=learner.train_iter
+ )
+
+ # Evaluate policy performance.
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+
+ # Collect data by default config n_sample/n_episode.
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+ if cfg.policy.update_per_collect is None:
+ # If update_per_collect is None, set it to the number of collected transitions multiplied by model_update_ratio.
+ collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
+ update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio)
+ # save returned new_data collected by the collector
+ replay_buffer.push_game_segments(new_data)
+ # remove the oldest data if the replay buffer is full.
+ replay_buffer.remove_oldest_data_to_fit()
+
+ # Learn policy from collected data.
+ for i in range(update_per_collect):
+ # Learner will train ``update_per_collect`` times in one iteration.
+ if replay_buffer.get_num_of_transitions() > batch_size:
+ train_data = replay_buffer.sample(batch_size, policy)
+ else:
+ logging.warning(
+ f'The data in replay_buffer is not sufficient to sample a mini-batch: '
+ f'batch_size: {batch_size}, '
+ f'{replay_buffer} '
+ f'continue to collect now ....'
+ )
+ break
+
+ # The core train steps for MCTS+RL algorithms.
+ log_vars = learner.train(train_data, collector.envstep)
+
+ if cfg.policy.use_priority:
+ replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])
+
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+
+ # Learner's after_run hook.
+ learner.call_hook('after_run')
+ return policy
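+
+# A minimal usage sketch (illustrative): as in the eval entry above, the env config must
+# provide ``env_name`` and the wrapper flags consumed by ``get_wrappered_env``; the import
+# path is a placeholder.
+#
+#     from zoo.some_env.config.some_muzero_gym_config import main_config, create_config
+#     policy = train_muzero_with_gym_env((main_config, create_config), seed=0, max_env_step=int(1e6))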
diff --git a/LightZero/lzero/entry/train_muzero_with_reward_model.py b/LightZero/lzero/entry/train_muzero_with_reward_model.py
new file mode 100644
index 0000000000000000000000000000000000000000..2ae4096017ec91581936931f6976c3af70efdac4
--- /dev/null
+++ b/LightZero/lzero/entry/train_muzero_with_reward_model.py
@@ -0,0 +1,210 @@
+import logging
+import os
+from functools import partial
+from typing import Optional, Tuple
+
+import torch
+from ding.config import compile_config
+from ding.envs import create_env_manager
+from ding.envs import get_vec_env_setting
+from ding.policy import create_policy
+from ding.rl_utils import get_epsilon_greedy_fn
+from ding.utils import set_pkg_seed
+from ding.worker import BaseLearner
+from tensorboardX import SummaryWriter
+
+from lzero.entry.utils import log_buffer_memory_usage, random_collect
+from lzero.policy import visit_count_temperature
+from lzero.policy.random_policy import LightZeroRandomPolicy
+from lzero.reward_model.rnd_reward_model import RNDRewardModel
+from lzero.worker import MuZeroCollector, MuZeroEvaluator
+
+
+def train_muzero_with_reward_model(
+ input_cfg: Tuple[dict, dict],
+ seed: int = 0,
+ model: Optional[torch.nn.Module] = None,
+ model_path: Optional[str] = None,
+ max_train_iter: Optional[int] = int(1e10),
+ max_env_step: Optional[int] = int(1e10),
+) -> 'Policy': # noqa
+ """
+ Overview:
+ The train entry for MCTS+RL algorithms augmented with reward_model.
+ Arguments:
+ - input_cfg (:obj:`Tuple[dict, dict]`): Config in dict type.
+ ``Tuple[dict, dict]`` type means [user_config, create_cfg].
+ - seed (:obj:`int`): Random seed.
+ - model (:obj:`Optional[torch.nn.Module]`): Instance of torch.nn.Module.
+ - model_path (:obj:`Optional[str]`): The pretrained model path, which should
+ point to the ckpt file of the pretrained model, and an absolute path is recommended.
+ In LightZero, the path is usually something like ``exp_name/ckpt/ckpt_best.pth.tar``.
+ - max_train_iter (:obj:`Optional[int]`): Maximum policy update iterations in training.
+ - max_env_step (:obj:`Optional[int]`): Maximum collected environment interaction steps.
+ Returns:
+ - policy (:obj:`Policy`): Converged policy.
+ """
+
+ cfg, create_cfg = input_cfg
+ assert create_cfg.policy.type in ['efficientzero', 'muzero', 'muzero_rnd', 'sampled_efficientzero'], \
+ "train_muzero entry now only support the following algo.: 'efficientzero', 'muzero', 'sampled_efficientzero'"
+
+ if create_cfg.policy.type in ['muzero', 'muzero_rnd']:
+ from lzero.mcts import MuZeroGameBuffer as GameBuffer
+ elif create_cfg.policy.type == 'efficientzero':
+ from lzero.mcts import EfficientZeroGameBuffer as GameBuffer
+ elif create_cfg.policy.type == 'sampled_efficientzero':
+ from lzero.mcts import SampledEfficientZeroGameBuffer as GameBuffer
+
+ if cfg.policy.cuda and torch.cuda.is_available():
+ cfg.policy.device = 'cuda'
+ else:
+ cfg.policy.device = 'cpu'
+
+ cfg = compile_config(cfg, seed=seed, env=None, auto=True, create_cfg=create_cfg, save_cfg=True)
+ # Create main components: env, policy
+ env_fn, collector_env_cfg, evaluator_env_cfg = get_vec_env_setting(cfg.env)
+
+ collector_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in collector_env_cfg])
+ evaluator_env = create_env_manager(cfg.env.manager, [partial(env_fn, cfg=c) for c in evaluator_env_cfg])
+
+ collector_env.seed(cfg.seed)
+ evaluator_env.seed(cfg.seed, dynamic_seed=False)
+ set_pkg_seed(cfg.seed, use_cuda=cfg.policy.cuda)
+
+ policy = create_policy(cfg.policy, model=model, enable_field=['learn', 'collect', 'eval'])
+
+ # load pretrained model
+ if model_path is not None:
+ policy.learn_mode.load_state_dict(torch.load(model_path, map_location=cfg.policy.device))
+
+ # Create worker components: learner, collector, evaluator, replay buffer, commander.
+ tb_logger = SummaryWriter(os.path.join('./{}/log/'.format(cfg.exp_name), 'serial'))
+ learner = BaseLearner(cfg.policy.learn.learner, policy.learn_mode, tb_logger, exp_name=cfg.exp_name)
+
+ # ==============================================================
+ # MCTS+RL algorithms related core code
+ # ==============================================================
+ policy_config = cfg.policy
+ batch_size = policy_config.batch_size
+ # specific game buffer for MCTS+RL algorithms
+ replay_buffer = GameBuffer(policy_config)
+ collector = MuZeroCollector(
+ env=collector_env,
+ policy=policy.collect_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name,
+ policy_config=policy_config
+ )
+ evaluator = MuZeroEvaluator(
+ eval_freq=cfg.policy.eval_freq,
+ n_evaluator_episode=cfg.env.n_evaluator_episode,
+ stop_value=cfg.env.stop_value,
+ env=evaluator_env,
+ policy=policy.eval_mode,
+ tb_logger=tb_logger,
+ exp_name=cfg.exp_name,
+ policy_config=policy_config
+ )
+ # create reward_model
+ reward_model = RNDRewardModel(cfg.reward_model, policy.collect_mode.get_attribute('device'), tb_logger,
+ policy._learn_model.representation_network,
+ policy._target_model_for_intrinsic_reward.representation_network,
+ cfg.policy.use_momentum_representation_network
+ )
+
+ # ==============================================================
+ # Main loop
+ # ==============================================================
+ # Learner's before_run hook.
+ learner.call_hook('before_run')
+ if cfg.policy.update_per_collect is not None:
+ update_per_collect = cfg.policy.update_per_collect
+
+ # The purpose of collecting random data before training:
+ # Exploration: Collecting random data helps the agent explore the environment and avoid getting stuck in a suboptimal policy prematurely.
+ # Comparison: By observing the agent's performance during random action-taking, we can establish a baseline to evaluate the effectiveness of reinforcement learning algorithms.
+ if cfg.policy.random_collect_episode_num > 0:
+ random_collect(cfg.policy, policy, LightZeroRandomPolicy, collector, collector_env, replay_buffer)
+
+ while True:
+ log_buffer_memory_usage(learner.train_iter, replay_buffer, tb_logger)
+ collect_kwargs = {}
+ # set temperature for visit count distributions according to the train_iter,
+ # please refer to Appendix D in MuZero paper for details.
+ collect_kwargs['temperature'] = visit_count_temperature(
+ policy_config.manual_temperature_decay,
+ policy_config.fixed_temperature_value,
+ policy_config.threshold_training_steps_for_final_temperature,
+ trained_steps=learner.train_iter,
+ )
+
+ if policy_config.eps.eps_greedy_exploration_in_collect:
+ epsilon_greedy_fn = get_epsilon_greedy_fn(start=policy_config.eps.start, end=policy_config.eps.end,
+ decay=policy_config.eps.decay, type_=policy_config.eps.type)
+ collect_kwargs['epsilon'] = epsilon_greedy_fn(collector.envstep)
+ else:
+ collect_kwargs['epsilon'] = 0.0
+
+ # Evaluate policy performance.
+ if evaluator.should_eval(learner.train_iter):
+ stop, reward = evaluator.eval(learner.save_checkpoint, learner.train_iter, collector.envstep)
+ if stop:
+ break
+
+ # Collect data by default config n_sample/n_episode.
+ new_data = collector.collect(train_iter=learner.train_iter, policy_kwargs=collect_kwargs)
+
+ # ****** reward_model related code ******
+ # collect data for reward_model training
+ reward_model.collect_data(new_data)
+ # update reward_model
+ if reward_model.cfg.input_type == 'latent_state':
+ # train reward_model with latent_state
+ if len(reward_model.train_latent_state) > reward_model.cfg.batch_size:
+ reward_model.train_with_data()
+ elif reward_model.cfg.input_type == 'obs':
+ # train reward_model with obs
+ if len(reward_model.train_obs) > reward_model.cfg.batch_size:
+ reward_model.train_with_data()
+ # clear old data in reward_model
+ reward_model.clear_old_data()
+
+ if cfg.policy.update_per_collect is None:
+ # If update_per_collect is None, set it to the number of collected transitions multiplied by model_update_ratio.
+ collected_transitions_num = sum([len(game_segment) for game_segment in new_data[0]])
+ update_per_collect = int(collected_transitions_num * cfg.policy.model_update_ratio)
+ # save returned new_data collected by the collector
+ replay_buffer.push_game_segments(new_data)
+ # remove the oldest data if the replay buffer is full.
+ replay_buffer.remove_oldest_data_to_fit()
+
+ # Learn policy from collected data.
+ for i in range(update_per_collect):
+ # Learner will train ``update_per_collect`` times in one iteration.
+ if replay_buffer.get_num_of_transitions() > batch_size:
+ train_data = replay_buffer.sample(batch_size, policy)
+ else:
+ logging.warning(
+ f'The data in replay_buffer is not sufficient to sample a mini-batch: '
+ f'batch_size: {batch_size}, '
+ f'{replay_buffer} '
+ f'continue to collect now ....'
+ )
+ break
+
+ # update train_data reward using the augmented reward
+ train_data_augmented = reward_model.estimate(train_data)
+
+ # The core train steps for MCTS+RL algorithms.
+ log_vars = learner.train(train_data_augmented, collector.envstep)
+
+ if cfg.policy.use_priority:
+ replay_buffer.update_priority(train_data, log_vars[0]['value_priority_orig'])
+
+ if collector.envstep >= max_env_step or learner.train_iter >= max_train_iter:
+ break
+
+ # Learner's after_run hook.
+ learner.call_hook('after_run')
+ return policy
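+
+# A minimal usage sketch (illustrative): in addition to the usual policy/env sections, the
+# config is expected to carry a ``cfg.reward_model`` section consumed by ``RNDRewardModel``;
+# the import path is a placeholder.
+#
+#     from zoo.some_env.config.some_muzero_rnd_config import main_config, create_config
+#     policy = train_muzero_with_reward_model((main_config, create_config), seed=0, max_env_step=int(1e6))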
diff --git a/LightZero/lzero/entry/utils.py b/LightZero/lzero/entry/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..8e26bc5064f2297e21a07346e3c896cc7f8538ec
--- /dev/null
+++ b/LightZero/lzero/entry/utils.py
@@ -0,0 +1,75 @@
+import os
+
+import psutil
+from pympler.asizeof import asizeof
+from tensorboardX import SummaryWriter
+from typing import Optional, Callable
+
+
+def random_collect(
+ policy_cfg: 'EasyDict', # noqa
+ policy: 'Policy', # noqa
+ RandomPolicy: 'Policy', # noqa
+ collector: 'ISerialCollector', # noqa
+ collector_env: 'BaseEnvManager', # noqa
+ replay_buffer: 'IBuffer', # noqa
+ postprocess_data_fn: Optional[Callable] = None
+) -> None: # noqa
+ assert policy_cfg.random_collect_episode_num > 0
+
+ random_policy = RandomPolicy(cfg=policy_cfg, action_space=collector_env.env_ref.action_space)
+ # set the policy to random policy
+ collector.reset_policy(random_policy.collect_mode)
+
+ # set temperature for visit count distributions according to the train_iter,
+ # please refer to Appendix D in MuZero paper for details.
+ collect_kwargs = {'temperature': 1, 'epsilon': 0.0}
+
+ # Collect data by default config n_sample/n_episode.
+ new_data = collector.collect(n_episode=policy_cfg.random_collect_episode_num, train_iter=0, policy_kwargs=collect_kwargs)
+
+ if postprocess_data_fn is not None:
+ new_data = postprocess_data_fn(new_data)
+
+ # save returned new_data collected by the collector
+ replay_buffer.push_game_segments(new_data)
+ # remove the oldest data if the replay buffer is full.
+ replay_buffer.remove_oldest_data_to_fit()
+
+ # restore the policy
+ collector.reset_policy(policy.collect_mode)
+
+
+def log_buffer_memory_usage(train_iter: int, buffer: "GameBuffer", writer: SummaryWriter) -> None:
+ """
+ Overview:
+ Log the memory usage of the buffer and the current process to TensorBoard.
+ Arguments:
+ - train_iter (:obj:`int`): The current training iteration.
+ - buffer (:obj:`GameBuffer`): The game buffer.
+ - writer (:obj:`SummaryWriter`): The TensorBoard writer.
+ """
+ writer.add_scalar('Buffer/num_of_all_collected_episodes', buffer.num_of_collected_episodes, train_iter)
+ writer.add_scalar('Buffer/num_of_game_segments', len(buffer.game_segment_buffer), train_iter)
+ writer.add_scalar('Buffer/num_of_transitions', len(buffer.game_segment_game_pos_look_up), train_iter)
+
+ game_segment_buffer = buffer.game_segment_buffer
+
+ # Calculate the amount of memory occupied by self.game_segment_buffer (in bytes).
+ buffer_memory_usage = asizeof(game_segment_buffer)
+
+ # Convert buffer_memory_usage to megabytes (MB).
+ buffer_memory_usage_mb = buffer_memory_usage / (1024 * 1024)
+
+ # Record the memory usage of self.game_segment_buffer to TensorBoard.
+ writer.add_scalar('Buffer/memory_usage/game_segment_buffer', buffer_memory_usage_mb, train_iter)
+
+ # Get the amount of memory currently used by the process (in bytes).
+ process = psutil.Process(os.getpid())
+ process_memory_usage = process.memory_info().rss
+
+ # Convert process_memory_usage to megabytes (MB).
+ process_memory_usage_mb = process_memory_usage / (1024 * 1024)
+
+ # Record the memory usage of the process to TensorBoard.
+ writer.add_scalar('Buffer/memory_usage/process', process_memory_usage_mb, train_iter)
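+
+# Typical call sites (illustrative): ``random_collect`` is invoked once before the main
+# training loop when ``policy_cfg.random_collect_episode_num > 0``, and
+# ``log_buffer_memory_usage`` once per training iteration, e.g.
+#
+#     writer = SummaryWriter('./some_exp/log/serial')
+#     log_buffer_memory_usage(learner.train_iter, replay_buffer, writer)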
diff --git a/LightZero/lzero/envs/__init__.py b/LightZero/lzero/envs/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/LightZero/lzero/envs/get_wrapped_env.py b/LightZero/lzero/envs/get_wrapped_env.py
new file mode 100644
index 0000000000000000000000000000000000000000..41e9262dbeced7831889587b712b393154d50452
--- /dev/null
+++ b/LightZero/lzero/envs/get_wrapped_env.py
@@ -0,0 +1,31 @@
+import gym
+from easydict import EasyDict
+
+from ding.envs import DingEnvWrapper
+from lzero.envs.wrappers import ActionDiscretizationEnvWrapper, LightZeroEnvWrapper
+
+
+def get_wrappered_env(wrapper_cfg: EasyDict, env_name: str):
+ """
+ Overview:
+ Returns a new environment with one or more wrappers applied to it.
+ Arguments:
+ - wrapper_cfg (:obj:`EasyDict`): A dictionary containing configuration settings for the wrappers.
+ - env_name (:obj:`str`): The name of the environment to create.
+ Returns:
+ A callable that creates the wrapped environment.
+ """
+ if wrapper_cfg.manually_discretization:
+ return lambda: DingEnvWrapper(
+ gym.make(env_name),
+ cfg={
+ 'env_wrapper': [
+ lambda env: ActionDiscretizationEnvWrapper(env, wrapper_cfg), lambda env:
+ LightZeroEnvWrapper(env, wrapper_cfg)
+ ]
+ }
+ )
+ else:
+ return lambda: DingEnvWrapper(
+ gym.make(env_name), cfg={'env_wrapper': [lambda env: LightZeroEnvWrapper(env, wrapper_cfg)]}
+ )
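+
+# A minimal usage sketch (illustrative), mirroring the wrapper tests below: wrap a
+# continuous-control gym env so that it returns LightZero-style dict observations.
+#
+#     cfg = EasyDict(dict(env_name='Pendulum-v1', manually_discretization=False,
+#                         continuous=True, each_dim_disc_size=None, is_train=True))
+#     env_fn = get_wrappered_env(cfg, cfg.env_name)
+#     obs = env_fn().reset()  # dict with 'observation', 'action_mask', 'to_play'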
diff --git a/LightZero/lzero/envs/tests/__init__.py b/LightZero/lzero/envs/tests/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/LightZero/lzero/envs/tests/test_ding_env_wrapper.py b/LightZero/lzero/envs/tests/test_ding_env_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..5a5d850a6c28b56fff9cf008253a2db929792ae9
--- /dev/null
+++ b/LightZero/lzero/envs/tests/test_ding_env_wrapper.py
@@ -0,0 +1,31 @@
+from easydict import EasyDict
+import pytest
+import gymnasium as gym
+import numpy as np
+
+from ding.envs import DingEnvWrapper
+
+
+@pytest.mark.unittest
+class TestDingEnvWrapper:
+
+ def test(self):
+ env_id = 'Pendulum-v1'
+ env = gym.make(env_id)
+ ding_env = DingEnvWrapper(env=env)
+ print(ding_env.observation_space, ding_env.action_space, ding_env.reward_space)
+ cfg = EasyDict(dict(
+ collector_env_num=16,
+ evaluator_env_num=3,
+ is_train=True,
+ ))
+ l1 = ding_env.create_collector_env_cfg(cfg)
+ assert isinstance(l1, list)
+ l1 = ding_env.create_evaluator_env_cfg(cfg)
+ assert isinstance(l1, list)
+
+ obs = ding_env.reset()
+
+ assert isinstance(obs[0], np.ndarray)
+ action = ding_env.random_action()
+ print('random_action: {}, action_space: {}'.format(action.shape, ding_env.action_space))
diff --git a/LightZero/lzero/envs/tests/test_lightzero_env_wrapper.py b/LightZero/lzero/envs/tests/test_lightzero_env_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..6440ef848d10716d2fa6a7a2c005c1024fe15ae9
--- /dev/null
+++ b/LightZero/lzero/envs/tests/test_lightzero_env_wrapper.py
@@ -0,0 +1,139 @@
+import pytest
+
+from ding.envs import DingEnvWrapper
+from lzero.envs.wrappers import ActionDiscretizationEnvWrapper, LightZeroEnvWrapper
+from easydict import EasyDict
+import gym
+import numpy as np
+
+
+@pytest.mark.unittest
+class TestLightZeroEnvWrapper:
+
+ def test_continuous_pendulum(self):
+ env_cfg = EasyDict(
+ dict(
+ env_name='Pendulum-v1',
+ manually_discretization=False,
+ continuous=True,
+ each_dim_disc_size=None,
+ is_train=True,
+ )
+ )
+
+ lightzero_env = DingEnvWrapper(
+ gym.make(env_cfg.env_name), cfg={'env_wrapper': [
+ lambda env: LightZeroEnvWrapper(env, env_cfg),
+ ]}
+ )
+
+ obs = lightzero_env.reset()
+ print("obs: ", obs)
+
+ print(lightzero_env.observation_space, lightzero_env.action_space, lightzero_env.reward_space)
+
+ assert isinstance(obs, dict)
+ assert isinstance(obs['observation'], np.ndarray) and obs['observation'].shape == (3, )
+ assert obs['action_mask'] is None and obs['to_play'] == -1
+
+ action = lightzero_env.random_action()
+
+ print('random_action: {}, action_space: {}'.format(action.shape, lightzero_env.action_space))
+
+ def test_discretization_pendulum(self):
+ env_cfg = EasyDict(
+ dict(
+ env_name='Pendulum-v1',
+ manually_discretization=True,
+ continuous=False,
+ each_dim_disc_size=11,
+ is_train=True,
+ )
+ )
+
+ lightzero_env = DingEnvWrapper(
+ gym.make(env_cfg.env_name),
+ cfg={
+ 'env_wrapper': [
+ lambda env: ActionDiscretizationEnvWrapper(env, env_cfg),
+ lambda env: LightZeroEnvWrapper(env, env_cfg),
+ ]
+ }
+ )
+
+ obs = lightzero_env.reset()
+ print("obs: ", obs)
+
+ print(lightzero_env.observation_space, lightzero_env.action_space, lightzero_env.reward_space)
+
+ assert isinstance(obs, dict)
+ assert isinstance(obs['observation'], np.ndarray) and obs['observation'].shape == (3, )
+ assert obs['action_mask'].sum() == 11 and obs['to_play'] == -1
+
+ action = lightzero_env.random_action()
+
+ print('random_action: {}, action_space: {}'.format(action.shape, lightzero_env.action_space))
+
+ def test_continuous_bipedalwalker(self):
+ env_cfg = EasyDict(
+ dict(
+ env_name='BipedalWalker-v3',
+ manually_discretization=False,
+ continuous=True,
+ each_dim_disc_size=4,
+ is_train=True,
+ )
+ )
+
+ lightzero_env = DingEnvWrapper(
+ gym.make(env_cfg.env_name), cfg={'env_wrapper': [
+ lambda env: LightZeroEnvWrapper(env, env_cfg),
+ ]}
+ )
+
+ obs = lightzero_env.reset()
+ print("obs: ", obs)
+
+ print(lightzero_env.observation_space, lightzero_env.action_space, lightzero_env.reward_space)
+
+ assert isinstance(obs, dict)
+ assert isinstance(obs['observation'], np.ndarray) and obs['observation'].shape == (24, )
+ assert obs['action_mask'] is None and obs['to_play'] == -1
+
+ action = lightzero_env.random_action()
+
+ print('random_action: {}, action_space: {}'.format(action.shape, lightzero_env.action_space))
+
+ def test_discretization_bipedalwalker(self):
+ env_cfg = EasyDict(
+ dict(
+ env_name='BipedalWalker-v3',
+ manually_discretization=True,
+ continuous=False,
+ each_dim_disc_size=4,
+ is_train=True,
+ )
+ )
+
+ lightzero_env = DingEnvWrapper(
+ gym.make(env_cfg.env_name),
+ cfg={
+ 'env_wrapper': [
+ lambda env: ActionDiscretizationEnvWrapper(env, env_cfg),
+ lambda env: LightZeroEnvWrapper(env, env_cfg),
+ ]
+ }
+ )
+
+ obs = lightzero_env.reset()
+ print("obs: ", obs)
+
+ print(lightzero_env.observation_space, lightzero_env.action_space, lightzero_env.reward_space)
+
+ assert isinstance(obs, dict)
+ assert isinstance(obs['observation'], np.ndarray) and obs['observation'].shape == (24, )
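+        # BipedalWalker-v3 has a 4-dimensional continuous action, so with each_dim_disc_size=4 the
+        # discretized action space contains 4 ** 4 = 256 actions; the all-ones action mask sums to 256.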
+ assert obs['action_mask'].sum() == 256 and obs['to_play'] == -1
+
+ action = lightzero_env.random_action()
+
+ print('random_action: {}, action_space: {}'.format(action.shape, lightzero_env.action_space))
diff --git a/LightZero/lzero/envs/wrappers/__init__.py b/LightZero/lzero/envs/wrappers/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d925b80bf515764bbc7e38f44aa3fca9fe9ed59c
--- /dev/null
+++ b/LightZero/lzero/envs/wrappers/__init__.py
@@ -0,0 +1,2 @@
+from .action_discretization_env_wrapper import *
+from .lightzero_env_wrapper import *
diff --git a/LightZero/lzero/envs/wrappers/action_discretization_env_wrapper.py b/LightZero/lzero/envs/wrappers/action_discretization_env_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..efd1e0fe5c486439969adc690b75e10004849ade
--- /dev/null
+++ b/LightZero/lzero/envs/wrappers/action_discretization_env_wrapper.py
@@ -0,0 +1,92 @@
+from itertools import product
+
+import gym
+import numpy as np
+from easydict import EasyDict
+
+from ding.envs import BaseEnvTimestep
+from ding.torch_utils import to_ndarray
+from ding.utils import ENV_WRAPPER_REGISTRY
+
+
+@ENV_WRAPPER_REGISTRY.register('action_discretization_env_wrapper')
+class ActionDiscretizationEnvWrapper(gym.Wrapper):
+ """
+ Overview:
+        The modified environment with a manually discretized action space. Each dimension of the original
+        continuous action space is equally divided into ``each_dim_disc_size`` bins, and the Cartesian product
+        of the per-dimension bins yields the handcrafted discrete action set.
+ Interface:
+ ``__init__``, ``reset``, ``step``
+ Properties:
+ - env (:obj:`gym.Env`): the environment to wrap.
+ """
+
+ def __init__(self, env: gym.Env, cfg: EasyDict) -> None:
+ """
+ Overview:
+            Initialize the wrapper. See ``help(type(self))`` for the accurate signature; \
+            set up the wrapper's properties according to the given config ``cfg``.
+ Arguments:
+ - env (:obj:`gym.Env`): the environment to wrap.
+ """
+ super().__init__(env)
+        assert 'is_train' in cfg, '`is_train` flag must be set in the config of env'
+ self.is_train = cfg.is_train
+ self.cfg = cfg
+ self.env_name = cfg.env_name
+ self.continuous = cfg.continuous
+
+ def reset(self, **kwargs):
+ """
+ Overview:
+            Reset the state of the environment and re-initialize the wrapper's properties.
+        Arguments:
+            - kwargs (:obj:`Dict`): keyword arguments passed to the wrapped environment's ``reset``.
+        Returns:
+            - observation (:obj:`Any`): the new observation after reset.
+ """
+ obs = self.env.reset(**kwargs)
+ self._raw_action_space = self.env.action_space
+
+ if self.cfg.manually_discretization:
+ # disc_to_cont: transform discrete action index to original continuous action
+ self.m = self._raw_action_space.shape[0]
+ self.n = self.cfg.each_dim_disc_size
+ self.K = self.n ** self.m
+ self.disc_to_cont = list(product(*[list(range(self.n)) for dim in range(self.m)]))
+ # the modified discrete action space
+ self._action_space = gym.spaces.Discrete(self.K)
+
+ return obs
+
+ def step(self, action):
+ """
+ Overview:
+            Step the environment with the given action. If ``manually_discretization`` is enabled, the \
+            discrete action index is first converted back to the corresponding original continuous action \
+            before being passed to the wrapped environment.
+        Arguments:
+            - action (:obj:`Any`): the given action to step with.
+        Returns:
+            - obs (:obj:`Any`): the observation returned by the wrapped environment after the input action.
+            - reward (:obj:`Any`): amount of reward returned after the previous action.
+            - done (:obj:`Bool`): whether the episode has ended, in which case further \
+                step() calls will return undefined results.
+            - info (:obj:`Dict`): contains auxiliary diagnostic information (helpful \
+                for debugging, and sometimes learning).
+
+ """
+ if self.cfg.manually_discretization:
+ # disc_to_cont: transform discrete action index to original continuous action
+ action = [-1 + 2 / self.n * k for k in self.disc_to_cont[int(action)]]
+ action = to_ndarray(action)
+
+ # The core original env step.
+ obs, rew, done, info = self.env.step(action)
+
+ return BaseEnvTimestep(obs, rew, done, info)
+
+ def __repr__(self) -> str:
+ return "Action Discretization Env."
diff --git a/LightZero/lzero/envs/wrappers/lightzero_env_wrapper.py b/LightZero/lzero/envs/wrappers/lightzero_env_wrapper.py
new file mode 100644
index 0000000000000000000000000000000000000000..82318eb05257848db0080aaf28f7ee1b7377a5fb
--- /dev/null
+++ b/LightZero/lzero/envs/wrappers/lightzero_env_wrapper.py
@@ -0,0 +1,113 @@
+import gym
+import numpy as np
+from easydict import EasyDict
+
+from ding.envs import BaseEnvTimestep
+from ding.utils import ENV_WRAPPER_REGISTRY
+
+
+@ENV_WRAPPER_REGISTRY.register('lightzero_env_wrapper')
+class LightZeroEnvWrapper(gym.Wrapper):
+ """
+ Overview:
+        Package the classic_control and box2d environments into the format required by LightZero.
+        The observation is wrapped as a dict containing the keys: ``observation``, ``action_mask`` and ``to_play``.
+ Interface:
+ ``__init__``, ``reset``, ``step``
+ Properties:
+ - env (:obj:`gym.Env`): the environment to wrap.
+ """
+
+ def __init__(self, env: gym.Env, cfg: EasyDict) -> None:
+ """
+ Overview:
+            Initialize the wrapper. See ``help(type(self))`` for the accurate signature; \
+            set up the wrapper's properties according to the given config ``cfg``.
+ Arguments:
+ - env (:obj:`gym.Env`): the environment to wrap.
+ """
+ super().__init__(env)
+        assert 'is_train' in cfg, '`is_train` flag must be set in the config of env'
+ self.is_train = cfg.is_train
+ self.cfg = cfg
+ self.env_name = cfg.env_name
+ self.continuous = cfg.continuous
+
+ def reset(self, **kwargs):
+ """
+ Overview:
+            Reset the state of the environment and re-initialize the wrapper's properties.
+        Arguments:
+            - kwargs (:obj:`Dict`): keyword arguments passed to the wrapped environment's ``reset``.
+        Returns:
+            - lightzero_obs_dict (:obj:`Dict`): the LightZero-format observation dict after reset.
+ """
+ # The core original env reset.
+ obs = self.env.reset(**kwargs)
+ self._eval_episode_return = 0.
+ self._raw_observation_space = self.env.observation_space
+
+ if self.cfg.continuous:
+ action_mask = None
+ else:
+ action_mask = np.ones(self.env.action_space.n, 'int8')
+
+ if self.cfg.continuous:
+ self._observation_space = gym.spaces.Dict(
+ {
+ 'observation': self._raw_observation_space,
+                    'action_mask': gym.spaces.Box(low=-np.inf, high=np.inf,
+ shape=(1, )), # TODO: gym.spaces.Constant(None)
+ 'to_play': gym.spaces.Box(low=-1, high=-1, shape=(1, )), # TODO: gym.spaces.Constant(-1)
+ }
+ )
+ else:
+ self._observation_space = gym.spaces.Dict(
+ {
+ 'observation': self._raw_observation_space,
+ 'action_mask': gym.spaces.MultiDiscrete([2 for _ in range(self.env.action_space.n)])
+ if isinstance(self.env.action_space, gym.spaces.Discrete) else
+ gym.spaces.MultiDiscrete([2 for _ in range(self.env.action_space.shape[0])]), # {0,1}
+ 'to_play': gym.spaces.Box(low=-1, high=-1, shape=(1, )), # TODO: gym.spaces.Constant(-1)
+ }
+ )
+
+ lightzero_obs_dict = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}
+ return lightzero_obs_dict
+
+ def step(self, action):
+ """
+ Overview:
+            Step the wrapped environment with the given action, accumulate the episode return, and wrap the \
+            returned observation into the LightZero-format dict with the keys ``observation``, \
+            ``action_mask`` and ``to_play``.
+        Arguments:
+            - action (:obj:`Any`): the given action to step with.
+        Returns:
+            - lightzero_obs_dict (:obj:`Dict`): the LightZero-format observation dict after the input action.
+            - reward (:obj:`Any`): amount of reward returned after the previous action.
+            - done (:obj:`Bool`): whether the episode has ended, in which case further \
+                step() calls will return undefined results.
+            - info (:obj:`Dict`): contains auxiliary diagnostic information (helpful \
+                for debugging, and sometimes learning).
+
+ """
+ # The core original env step.
+ obs, rew, done, info = self.env.step(action)
+
+ if self.cfg.continuous:
+ action_mask = None
+ else:
+ action_mask = np.ones(self.env.action_space.n, 'int8')
+
+ lightzero_obs_dict = {'observation': obs, 'action_mask': action_mask, 'to_play': -1}
+
+ self._eval_episode_return += rew
+ if done:
+ info['eval_episode_return'] = self._eval_episode_return
+
+ return BaseEnvTimestep(lightzero_obs_dict, rew, done, info)
+
+ def __repr__(self) -> str:
+ return "LightZero Env."
\ No newline at end of file
diff --git a/LightZero/lzero/mcts/__init__.py b/LightZero/lzero/mcts/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a6d259580c33f82df248e808ea8097192355a74a
--- /dev/null
+++ b/LightZero/lzero/mcts/__init__.py
@@ -0,0 +1,4 @@
+from .buffer import *
+from .ctree import *
+from .tree_search import *
+from .utils import *
diff --git a/LightZero/lzero/mcts/buffer/__init__.py b/LightZero/lzero/mcts/buffer/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..31680a75e1591cce05de3c27984050e0136da7d0
--- /dev/null
+++ b/LightZero/lzero/mcts/buffer/__init__.py
@@ -0,0 +1,5 @@
+from .game_buffer_muzero import MuZeroGameBuffer
+from .game_buffer_efficientzero import EfficientZeroGameBuffer
+from .game_buffer_sampled_efficientzero import SampledEfficientZeroGameBuffer
+from .game_buffer_gumbel_muzero import GumbelMuZeroGameBuffer
+from .game_buffer_stochastic_muzero import StochasticMuZeroGameBuffer
diff --git a/LightZero/lzero/mcts/buffer/game_buffer.py b/LightZero/lzero/mcts/buffer/game_buffer.py
new file mode 100644
index 0000000000000000000000000000000000000000..3d760cc21881ff61c98ed21505d96f44a6055e2f
--- /dev/null
+++ b/LightZero/lzero/mcts/buffer/game_buffer.py
@@ -0,0 +1,406 @@
+import copy
+import time
+from abc import ABC, abstractmethod
+from typing import Any, List, Tuple, Optional, Union, TYPE_CHECKING
+
+import numpy as np
+from ding.torch_utils.data_helper import to_list
+from ding.utils import BUFFER_REGISTRY
+from easydict import EasyDict
+
+if TYPE_CHECKING:
+ from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy, GumbelMuZeroPolicy
+
+
+@BUFFER_REGISTRY.register('game_buffer')
+class GameBuffer(ABC, object):
+ """
+ Overview:
+ The base game buffer class for MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy, GumbelMuZeroPolicy.
+ """
+
+ @classmethod
+ def default_config(cls: type) -> EasyDict:
+ cfg = EasyDict(copy.deepcopy(cls.config))
+ cfg.cfg_type = cls.__name__ + 'Dict'
+ return cfg
+
+ # Default configuration for GameBuffer.
+ config = dict(
+ # (int) The size/capacity of the replay buffer in terms of transitions.
+ replay_buffer_size=int(1e6),
+ # (float) The ratio of experiences required for the reanalyzing part in a minibatch.
+ reanalyze_ratio=0.3,
+ # (bool) Whether to consider outdated experiences for reanalyzing. If True, we first sort the data in the minibatch by the time it was produced
+ # and only reanalyze the oldest ``reanalyze_ratio`` fraction.
+ reanalyze_outdated=True,
+ # (bool) Whether to use the root value in the reanalyzing part. Please refer to EfficientZero paper for details.
+ use_root_value=False,
+ # (int) The number of samples required for mini inference.
+ mini_infer_size=256,
+ )
+
+ def __init__(self, cfg: dict):
+ super().__init__()
+ """
+ Overview:
+ Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key
+ in the default configuration, the user-provided value will override the default configuration. Otherwise,
+ the default configuration will be used.
+ """
+ default_config = self.default_config()
+ default_config.update(cfg)
+        self._cfg = default_config
+ assert self._cfg.env_type in ['not_board_games', 'board_games']
+ assert self._cfg.action_type in ['fixed_action_space', 'varied_action_space']
+
+ self.replay_buffer_size = self._cfg.replay_buffer_size
+ self.batch_size = self._cfg.batch_size
+ self._alpha = self._cfg.priority_prob_alpha
+ self._beta = self._cfg.priority_prob_beta
+
+ self.game_segment_buffer = []
+ self.game_pos_priorities = []
+ self.game_segment_game_pos_look_up = []
+
+ self.keep_ratio = 1
+ self.num_of_collected_episodes = 0
+ self.base_idx = 0
+ self.clear_time = 0
+
+ @abstractmethod
+ def sample(
+ self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy", "GumbelMuZeroPolicy"]
+ ) -> List[Any]:
+ """
+ Overview:
+ sample data from ``GameBuffer`` and prepare the current and target batch for training.
+ Arguments:
+ - batch_size (:obj:`int`): batch size.
+ - policy (:obj:`Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy", "GumbelMuZeroPolicy"]`): policy.
+ Returns:
+ - train_data (:obj:`List`): List of train data, including current_batch and target_batch.
+ """
+
+ @abstractmethod
+ def _make_batch(self, orig_data: Any, reanalyze_ratio: float) -> Tuple[Any]:
+ """
+ Overview:
+ prepare the context of a batch
+ reward_value_context: the context of reanalyzed value targets
+ policy_re_context: the context of reanalyzed policy targets
+ policy_non_re_context: the context of non-reanalyzed policy targets
+ current_batch: the inputs of batch
+ Arguments:
+ orig_data: Any batch context from replay buffer
+ reanalyze_ratio: float ratio of reanalyzed policy (value is 100% reanalyzed)
+ Returns:
+ - context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch
+ """
+ pass
+
+ def _sample_orig_data(self, batch_size: int) -> Tuple:
+ """
+ Overview:
+ sample orig_data that contains:
+ game_segment_list: a list of game segments
+ pos_in_game_segment_list: transition index in game (relative index)
+ batch_index_list: the index of start transition of sampled minibatch in replay buffer
+ weights_list: the weight concerning the priority
+ make_time: the time the batch is made (for correctly updating replay buffer when data is deleted)
+ Arguments:
+ - batch_size (:obj:`int`): batch size
+        Note: the PER exponents alpha and beta are taken from ``self._alpha`` and ``self._beta`` rather than passed as arguments.
+ """
+ assert self._beta > 0
+ num_of_transitions = self.get_num_of_transitions()
+ if self._cfg.use_priority is False:
+ self.game_pos_priorities = np.ones_like(self.game_pos_priorities)
+
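+        # Prioritized-sampling recap of the code below: transition i is drawn with probability
+        # P(i) = p_i ** alpha / sum_j p_j ** alpha, and its importance-sampling weight is
+        # w_i = (N * P(i)) ** (-beta), normalized by the largest weight in the sampled batch.
+        # With uniform priorities (use_priority=False), every P(i) = 1 / N and every normalized weight is 1.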
+ # +1e-6 for numerical stability
+ probs = self.game_pos_priorities ** self._alpha + 1e-6
+ probs /= probs.sum()
+
+ # sample according to transition index
+ # TODO(pu): replace=True
+ batch_index_list = np.random.choice(num_of_transitions, batch_size, p=probs, replace=False)
+
+ if self._cfg.reanalyze_outdated is True:
+ # NOTE: used in reanalyze part
+ batch_index_list.sort()
+
+ weights_list = (num_of_transitions * probs[batch_index_list]) ** (-self._beta)
+ weights_list /= weights_list.max()
+
+ game_segment_list = []
+ pos_in_game_segment_list = []
+
+ for idx in batch_index_list:
+ game_segment_idx, pos_in_game_segment = self.game_segment_game_pos_look_up[idx]
+ game_segment_idx -= self.base_idx
+ game_segment = self.game_segment_buffer[game_segment_idx]
+
+ game_segment_list.append(game_segment)
+ pos_in_game_segment_list.append(pos_in_game_segment)
+
+ make_time = [time.time() for _ in range(len(batch_index_list))]
+
+ orig_data = (game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time)
+ return orig_data
+
+ def _preprocess_to_play_and_action_mask(
+ self, game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list
+ ):
+ """
+ Overview:
+ prepare the to_play and action_mask for the target obs in ``value_obs_list``
+ - to_play: {list: game_segment_batch_size * (num_unroll_steps+1)}
+ - action_mask: {list: game_segment_batch_size * (num_unroll_steps+1)}
+ """
+ to_play = []
+ for bs in range(game_segment_batch_size):
+ to_play_tmp = list(
+ to_play_segment[bs][pos_in_game_segment_list[bs]:pos_in_game_segment_list[bs] +
+ self._cfg.num_unroll_steps + 1]
+ )
+ if len(to_play_tmp) < self._cfg.num_unroll_steps + 1:
+ # NOTE: the effective to play index is {1,2}, for null padding data, we set to_play=-1
+ to_play_tmp += [-1 for _ in range(self._cfg.num_unroll_steps + 1 - len(to_play_tmp))]
+ to_play.append(to_play_tmp)
+ to_play = sum(to_play, [])
+
+ if self._cfg.model.continuous_action_space is True:
+ # when the action space of the environment is continuous, action_mask[:] is None.
+ return to_play, None
+
+ action_mask = []
+ for bs in range(game_segment_batch_size):
+ action_mask_tmp = list(
+ action_mask_segment[bs][pos_in_game_segment_list[bs]:pos_in_game_segment_list[bs] +
+ self._cfg.num_unroll_steps + 1]
+ )
+ if len(action_mask_tmp) < self._cfg.num_unroll_steps + 1:
+ action_mask_tmp += [
+ list(np.ones(self._cfg.model.action_space_size, dtype=np.int8))
+ for _ in range(self._cfg.num_unroll_steps + 1 - len(action_mask_tmp))
+ ]
+ action_mask.append(action_mask_tmp)
+ action_mask = to_list(action_mask)
+ action_mask = sum(action_mask, [])
+
+ return to_play, action_mask
+
+ @abstractmethod
+ def _prepare_reward_value_context(
+ self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[Any],
+ total_transitions: int
+ ) -> List[Any]:
+ """
+ Overview:
+ prepare the context of rewards and values for calculating TD value target in reanalyzing part.
+ Arguments:
+ - batch_index_list (:obj:`list`): the index of start transition of sampled minibatch in replay buffer
+ - game_segment_list (:obj:`list`): list of game segments
+ - pos_in_game_segment_list (:obj:`list`): list of transition index in game_segment
+ - total_transitions (:obj:`int`): number of collected transitions
+ Returns:
+ - reward_value_context (:obj:`list`): value_obs_lst, value_mask, state_index_lst, rewards_lst, game_segment_lens,
+ td_steps_lst, action_mask_segment, to_play_segment
+ """
+ pass
+
+ @abstractmethod
+ def _prepare_policy_non_reanalyzed_context(
+ self, batch_index_list: List[int], game_segment_list: List[Any], pos_in_game_segment_list: List[int]
+ ) -> List[Any]:
+ """
+ Overview:
+ prepare the context of policies for calculating policy target in non-reanalyzing part, just return the policy in self-play
+ Arguments:
+ - batch_index_list (:obj:`list`): the index of start transition of sampled minibatch in replay buffer
+ - game_segment_list (:obj:`list`): list of game segments
+ - pos_in_game_segment_list (:obj:`list`): list transition index in game
+ Returns:
+ - policy_non_re_context (:obj:`list`): state_index_lst, child_visits, game_segment_lens, action_mask_segment, to_play_segment
+ """
+ pass
+
+ @abstractmethod
+ def _prepare_policy_reanalyzed_context(
+ self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[str]
+ ) -> List[Any]:
+ """
+ Overview:
+ prepare the context of policies for calculating policy target in reanalyzing part.
+ Arguments:
+            - batch_index_list (:obj:`list`): start transition index in the replay buffer
+            - game_segment_list (:obj:`list`): list of game segments
+            - pos_in_game_segment_list (:obj:`list`): position of the transition index in one game history
+ Returns:
+ - policy_re_context (:obj:`list`): policy_obs_lst, policy_mask, state_index_lst, indices,
+ child_visits, game_segment_lens, action_mask_segment, to_play_segment
+ """
+ pass
+
+ @abstractmethod
+ def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any) -> List[np.ndarray]:
+ """
+ Overview:
+ prepare reward and value targets from the context of rewards and values.
+ Arguments:
+            - reward_value_context (:obj:`list`): the reward value context
+            - model (:obj:`torch.nn.Module`): the target model
+        Returns:
+            - batch_value_prefixs (:obj:`np.ndarray`): batch of value prefixes
+            - batch_target_values (:obj:`np.ndarray`): batch of value estimations
+ """
+ pass
+
+ @abstractmethod
+ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any) -> np.ndarray:
+ """
+ Overview:
+ prepare policy targets from the reanalyzed context of policies
+ Arguments:
+ - policy_re_context (:obj:`List`): List of policy context to reanalyzed
+ Returns:
+ - batch_target_policies_re
+ """
+ pass
+
+ @abstractmethod
+ def _compute_target_policy_non_reanalyzed(
+ self, policy_non_re_context: List[Any], policy_shape: Optional[int]
+ ) -> np.ndarray:
+ """
+ Overview:
+ prepare policy targets from the non-reanalyzed context of policies
+ Arguments:
+ - policy_non_re_context (:obj:`List`): List containing:
+ - pos_in_game_segment_list
+ - child_visits
+ - game_segment_lens
+ - action_mask_segment
+ - to_play_segment
+ Returns:
+ - batch_target_policies_non_re
+ """
+ pass
+
+ @abstractmethod
+ def update_priority(
+ self, train_data: Optional[List[Optional[np.ndarray]]], batch_priorities: Optional[Any]
+ ) -> None:
+ """
+ Overview:
+ Update the priority of training data.
+ Arguments:
+            - train_data (:obj:`Optional[List[Optional[np.ndarray]]]`): training data whose priorities are to be updated.
+            - batch_priorities (:obj:`Optional[Any]`): the new priorities to assign to the corresponding transitions.
+ """
+ pass
+
+ def push_game_segments(self, data_and_meta: Any) -> None:
+ """
+ Overview:
+            Push game segments and their meta information into the buffer.
+            Each game segment is saved individually via ``_push_game_segment``.
+        Arguments:
+            - data_and_meta (:obj:`Tuple`): a tuple of (data, meta), where:
+ - data (:obj:`Any`): The data (game segments) which will be pushed into buffer.
+ - meta (:obj:`dict`): Meta information, e.g. priority, count, staleness.
+ """
+ data, meta = data_and_meta
+ for (data_game, meta_game) in zip(data, meta):
+ self._push_game_segment(data_game, meta_game)
+
+ def _push_game_segment(self, data: Any, meta: Optional[dict] = None) -> None:
+ """
+ Overview:
+            Push a single game segment and its meta information into the buffer.
+        Arguments:
+            - data (:obj:`Any`): the data (a game segment) which will be pushed into the buffer.
+            - meta (:obj:`dict`): meta information with the following keys:
+                - done (:obj:`bool`): whether the game is finished.
+                - unroll_plus_td_steps (:obj:`int`): if the game is not finished, only the transitions whose targets can still be computed are treated as valid.
+                - priorities (:obj:`list`): the priorities corresponding to the transitions in the game segment; if None, the current max priority of the buffer is used for the valid part.
+ """
+ if meta['done']:
+ self.num_of_collected_episodes += 1
+ valid_len = len(data)
+ else:
+ valid_len = len(data) - meta['unroll_plus_td_steps']
+
+ if meta['priorities'] is None:
+ max_prio = self.game_pos_priorities.max() if self.game_segment_buffer else 1
+            # if no 'priorities' are provided, set the valid part of the newly added game segment to max_prio
+ self.game_pos_priorities = np.concatenate(
+ (
+ self.game_pos_priorities, [max_prio
+ for _ in range(valid_len)] + [0. for _ in range(valid_len, len(data))]
+ )
+ )
+ else:
+            assert len(data) == len(meta['priorities']), "priorities should have the same length as the game steps"
+ priorities = meta['priorities'].copy().reshape(-1)
+ priorities[valid_len:len(data)] = 0.
+ self.game_pos_priorities = np.concatenate((self.game_pos_priorities, priorities))
+
+ self.game_segment_buffer.append(data)
+ self.game_segment_game_pos_look_up += [
+ (self.base_idx + len(self.game_segment_buffer) - 1, step_pos) for step_pos in range(len(data))
+ ]
+
+ def remove_oldest_data_to_fit(self) -> None:
+ """
+ Overview:
+            remove the oldest game segments when the total number of transitions exceeds the replay buffer capacity.
+ """
+ assert self.replay_buffer_size > self._cfg.batch_size, "replay buffer size should be larger than batch size"
+ nums_of_game_segments = self.get_num_of_game_segments()
+ total_transition = self.get_num_of_transitions()
+ if total_transition > self.replay_buffer_size:
+ index = 0
+ for i in range(nums_of_game_segments):
+ total_transition -= len(self.game_segment_buffer[i])
+ if total_transition <= self.replay_buffer_size * self.keep_ratio:
+ # find the max game_segment index to keep in the buffer
+ index = i
+ break
+ if total_transition >= self._cfg.batch_size:
+ self._remove(index + 1)
+
+    def _remove(self, excess_game_segment_index: int) -> None:
+ """
+ Overview:
+            delete the game segments with indices in [0, excess_game_segment_index)
+        Arguments:
+            - excess_game_segment_index (:obj:`int`): the number of oldest game segments to delete.
+ """
+ excess_game_positions = sum(
+ [len(game_segment) for game_segment in self.game_segment_buffer[:excess_game_segment_index]]
+ )
+ del self.game_segment_buffer[:excess_game_segment_index]
+ self.game_pos_priorities = self.game_pos_priorities[excess_game_positions:]
+ del self.game_segment_game_pos_look_up[:excess_game_positions]
+ self.base_idx += excess_game_segment_index
+ self.clear_time = time.time()
+
+ def get_num_of_episodes(self) -> int:
+ # number of collected episodes
+ return self.num_of_collected_episodes
+
+ def get_num_of_game_segments(self) -> int:
+ # num of game segments
+ return len(self.game_segment_buffer)
+
+ def get_num_of_transitions(self) -> int:
+ # total number of transitions
+ return len(self.game_segment_game_pos_look_up)
+
+ def __repr__(self):
+ return f'current buffer statistics is: num_of_all_collected_episodes: {self.num_of_collected_episodes}, num of game segments: {len(self.game_segment_buffer)}, number of transitions: {len(self.game_segment_game_pos_look_up)}'
diff --git a/LightZero/lzero/mcts/buffer/game_buffer_efficientzero.py b/LightZero/lzero/mcts/buffer/game_buffer_efficientzero.py
new file mode 100644
index 0000000000000000000000000000000000000000..8bca8f269c3662a721622da62e55dcac714f976c
--- /dev/null
+++ b/LightZero/lzero/mcts/buffer/game_buffer_efficientzero.py
@@ -0,0 +1,441 @@
+from typing import Any, List
+
+import numpy as np
+import torch
+from ding.utils import BUFFER_REGISTRY
+
+from lzero.mcts.tree_search.mcts_ctree import EfficientZeroMCTSCtree as MCTSCtree
+from lzero.mcts.tree_search.mcts_ptree import EfficientZeroMCTSPtree as MCTSPtree
+from lzero.mcts.utils import prepare_observation
+from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform
+from .game_buffer_muzero import MuZeroGameBuffer
+
+
+@BUFFER_REGISTRY.register('game_buffer_efficientzero')
+class EfficientZeroGameBuffer(MuZeroGameBuffer):
+ """
+ Overview:
+ The specific game buffer for EfficientZero policy.
+ """
+
+ def __init__(self, cfg: dict):
+ super().__init__(cfg)
+ """
+ Overview:
+ Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key
+ in the default configuration, the user-provided value will override the default configuration. Otherwise,
+ the default configuration will be used.
+ """
+ default_config = self.default_config()
+ default_config.update(cfg)
+ self._cfg = default_config
+ assert self._cfg.env_type in ['not_board_games', 'board_games']
+ assert self._cfg.action_type in ['fixed_action_space', 'varied_action_space']
+ self.replay_buffer_size = self._cfg.replay_buffer_size
+ self.batch_size = self._cfg.batch_size
+ self._alpha = self._cfg.priority_prob_alpha
+ self._beta = self._cfg.priority_prob_beta
+
+ self.game_segment_buffer = []
+ self.game_pos_priorities = []
+ self.game_segment_game_pos_look_up = []
+
+ self.keep_ratio = 1
+ self.num_of_collected_episodes = 0
+ self.base_idx = 0
+ self.clear_time = 0
+
+ def sample(self, batch_size: int, policy: Any) -> List[Any]:
+ """
+ Overview:
+ sample data from ``GameBuffer`` and prepare the current and target batch for training
+ Arguments:
+ - batch_size (:obj:`int`): batch size
+            - policy (:obj:`Any`): the policy whose ``_target_model`` is used to compute the training targets
+ Returns:
+ - train_data (:obj:`List`): List of train data
+ """
+ policy._target_model.to(self._cfg.device)
+ policy._target_model.eval()
+
+ # obtain the current_batch and prepare target context
+ reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch(
+ batch_size, self._cfg.reanalyze_ratio
+ )
+
+ # target value_prefixs, target value
+ batch_value_prefixs, batch_target_values = self._compute_target_reward_value(
+ reward_value_context, policy._target_model
+ )
+ # target policy
+ batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model)
+ batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed(
+ policy_non_re_context, self._cfg.model.action_space_size
+ )
+
+ if 0 < self._cfg.reanalyze_ratio < 1:
+ batch_target_policies = np.concatenate([batch_target_policies_re, batch_target_policies_non_re])
+ elif self._cfg.reanalyze_ratio == 1:
+ batch_target_policies = batch_target_policies_re
+ elif self._cfg.reanalyze_ratio == 0:
+ batch_target_policies = batch_target_policies_non_re
+
+ target_batch = [batch_value_prefixs, batch_target_values, batch_target_policies]
+ # a batch contains the current_batch and the target_batch
+ train_data = [current_batch, target_batch]
+ return train_data
+
+ def _prepare_reward_value_context(
+ self, batch_index_list: List[str], game_segment_list: List[Any], pos_in_game_segment_list: List[Any],
+ total_transitions: int
+ ) -> List[Any]:
+ """
+ Overview:
+ prepare the context of rewards and values for calculating TD value target in reanalyzing part.
+ Arguments:
+ - batch_index_list (:obj:`list`): the index of start transition of sampled minibatch in replay buffer
+ - game_segment_list (:obj:`list`): list of game segments
+ - pos_in_game_segment_list (:obj:`list`): list of transition index in game_segment
+ - total_transitions (:obj:`int`): number of collected transitions
+ Returns:
+ - reward_value_context (:obj:`list`): value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens,
+ td_steps_list, action_mask_segment, to_play_segment
+ """
+ zero_obs = game_segment_list[0].zero_obs()
+ value_obs_list = []
+ # the value is valid or not (out of trajectory)
+ value_mask = []
+ rewards_list = []
+ game_segment_lens = []
+ # for two_player board games
+ action_mask_segment, to_play_segment = [], []
+
+ td_steps_list = []
+ for game_segment, state_index, idx in zip(game_segment_list, pos_in_game_segment_list, batch_index_list):
+ game_segment_len = len(game_segment)
+ game_segment_lens.append(game_segment_len)
+
+ # ==============================================================
+ # EfficientZero related core code
+ # ==============================================================
+ # TODO(pu):
+ # for atari, off-policy correction: shorter horizon of td steps
+ # delta_td = (total_transitions - idx) // config.auto_td_steps
+ # td_steps = config.td_steps - delta_td
+ # td_steps = np.clip(td_steps, 1, 5).astype(np.int)
+ td_steps = np.clip(self._cfg.td_steps, 1, max(1, game_segment_len - state_index)).astype(np.int32)
+
+ # prepare the corresponding observations for bootstrapped values o_{t+k}
+ # o[t+ td_steps, t + td_steps + stack frames + num_unroll_steps]
+ # t=2+3 -> o[2+3, 2+3+4+5] -> o[5, 14]
+ game_obs = game_segment.get_unroll_obs(state_index + td_steps, self._cfg.num_unroll_steps)
+
+ rewards_list.append(game_segment.reward_segment)
+
+ # for two_player board games
+ action_mask_segment.append(game_segment.action_mask_segment)
+ to_play_segment.append(game_segment.to_play_segment)
+
+ for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
+ # get the bootstrapped target obs
+ td_steps_list.append(td_steps)
+ # index of bootstrapped obs o_{t+td_steps}
+ bootstrap_index = current_index + td_steps
+
+ if bootstrap_index < game_segment_len:
+ value_mask.append(1)
+ # beg_index = bootstrap_index - (state_index + td_steps), max of beg_index is num_unroll_steps
+ beg_index = current_index - state_index
+ end_index = beg_index + self._cfg.model.frame_stack_num
+ # the stacked obs in time t
+ obs = game_obs[beg_index:end_index]
+ else:
+ value_mask.append(0)
+ obs = zero_obs
+
+ value_obs_list.append(obs)
+
+ reward_value_context = [
+ value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list,
+ action_mask_segment, to_play_segment
+ ]
+ return reward_value_context
+
+ def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any) -> List[np.ndarray]:
+ """
+ Overview:
+ prepare reward and value targets from the context of rewards and values.
+ Arguments:
+            - reward_value_context (:obj:`list`): the reward value context
+            - model (:obj:`torch.nn.Module`): the target model
+        Returns:
+            - batch_value_prefixs (:obj:`np.ndarray`): batch of value prefixes
+            - batch_target_values (:obj:`np.ndarray`): batch of value estimations
+ """
+ value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, \
+ to_play_segment = reward_value_context # noqa
+ # transition_batch_size = game_segment_batch_size * (num_unroll_steps+1)
+ transition_batch_size = len(value_obs_list)
+ game_segment_batch_size = len(pos_in_game_segment_list)
+
+ to_play, action_mask = self._preprocess_to_play_and_action_mask(
+ game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list
+ )
+
+ legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
+
+ # ==============================================================
+ # EfficientZero related core code
+ # ==============================================================
+ batch_target_values, batch_value_prefixs = [], []
+ with torch.no_grad():
+ value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type)
+ # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors
+ slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size))
+ network_output = []
+ for i in range(slices):
+ beg_index = self._cfg.mini_infer_size * i
+ end_index = self._cfg.mini_infer_size * (i + 1)
+
+ m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float()
+
+ # calculate the target value
+ m_output = model.initial_inference(m_obs)
+ if not model.training:
+ # ==============================================================
+ # EfficientZero related core code
+ # ==============================================================
+ # if not in training, obtain the scalars of the value/reward
+ [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
+ [
+ m_output.latent_state,
+ inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
+ m_output.policy_logits
+ ]
+ )
+ m_output.reward_hidden_state = (
+ m_output.reward_hidden_state[0].detach().cpu().numpy(),
+ m_output.reward_hidden_state[1].detach().cpu().numpy()
+ )
+ network_output.append(m_output)
+
+ # concat the output slices after model inference
+ if self._cfg.use_root_value:
+                # use the root values from MCTS, as in EfficientZero;
+                # the root values bring limited improvement but require many more GPU actors.
+ _, value_prefix_pool, policy_logits_pool, latent_state_roots, reward_hidden_state_roots = concat_output(
+ network_output, data_type='efficientzero'
+ )
+ value_prefix_pool = value_prefix_pool.squeeze().tolist()
+ policy_logits_pool = policy_logits_pool.tolist()
+ noises = [
+ np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size
+ ).astype(np.float32).tolist() for _ in range(transition_batch_size)
+ ]
+ if self._cfg.mcts_ctree:
+ # cpp mcts_tree
+ roots = MCTSCtree.roots(transition_batch_size, legal_actions)
+ roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play)
+ # do MCTS for a new policy with the recent target model
+ MCTSCtree(self._cfg).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play)
+ else:
+ # python mcts_tree
+ roots = MCTSPtree.roots(transition_batch_size, legal_actions)
+ roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play)
+ # do MCTS for a new policy with the recent target model
+ MCTSPtree(self._cfg).search(
+ roots, model, latent_state_roots, reward_hidden_state_roots, to_play=to_play
+ )
+ roots_values = roots.get_values()
+ value_list = np.array(roots_values)
+ else:
+ # use the predicted values
+ value_list = concat_output_value(network_output)
+
+ # get last state value
+ if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]:
+ # TODO(pu): for board_games, very important, to check
+ value_list = value_list.reshape(-1) * np.array(
+ [
+ self._cfg.discount_factor ** td_steps_list[i] if int(td_steps_list[i]) %
+ 2 == 0 else -self._cfg.discount_factor ** td_steps_list[i]
+ for i in range(transition_batch_size)
+ ]
+ )
+ else:
+ value_list = value_list.reshape(-1) * (
+ np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list
+ )
+
+ value_list = value_list * np.array(value_mask)
+ value_list = value_list.tolist()
+ horizon_id, value_index = 0, 0
+ for game_segment_len_non_re, reward_list, state_index, to_play_list in zip(game_segment_lens, rewards_list,
+ pos_in_game_segment_list,
+ to_play_segment):
+ target_values = []
+ target_value_prefixs = []
+ value_prefix = 0.0
+ base_index = state_index
+ for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
+ bootstrap_index = current_index + td_steps_list[value_index]
+ for i, reward in enumerate(reward_list[current_index:bootstrap_index]):
+ if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]:
+ # TODO(pu): for board_games, very important, to check
+ if to_play_list[base_index] == to_play_list[i]:
+ value_list[value_index] += reward * self._cfg.discount_factor ** i
+ else:
+ value_list[value_index] += -reward * self._cfg.discount_factor ** i
+ else:
+ value_list[value_index] += reward * self._cfg.discount_factor ** i
+
+ # reset every lstm_horizon_len
+ if horizon_id % self._cfg.lstm_horizon_len == 0:
+ value_prefix = 0.0
+ base_index = current_index
+ horizon_id += 1
+
+ if current_index < game_segment_len_non_re:
+ target_values.append(value_list[value_index])
+                    # TODO: since the horizon is small and the discount_factor is close to 1,
+                    # the reward sum is used to approximate the value prefix for simplification.
+ value_prefix += reward_list[current_index
+ ] # * self._cfg.discount_factor ** (current_index - base_index)
+ target_value_prefixs.append(value_prefix)
+ else:
+ target_values.append(0)
+ target_value_prefixs.append(value_prefix)
+ value_index += 1
+ batch_value_prefixs.append(target_value_prefixs)
+ batch_target_values.append(target_values)
+ batch_value_prefixs = np.asarray(batch_value_prefixs, dtype=object)
+ batch_target_values = np.asarray(batch_target_values, dtype=object)
+
+ return batch_value_prefixs, batch_target_values
+
+ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any) -> np.ndarray:
+ """
+ Overview:
+ prepare policy targets from the reanalyzed context of policies
+ Arguments:
+ - policy_re_context (:obj:`List`): List of policy context to reanalyzed
+ Returns:
+ - batch_target_policies_re
+ """
+ if policy_re_context is None:
+ return []
+ batch_target_policies_re = []
+
+ policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, action_mask_segment, \
+ to_play_segment = policy_re_context # noqa
+ # transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1)
+ transition_batch_size = len(policy_obs_list)
+ game_segment_batch_size = len(pos_in_game_segment_list)
+ to_play, action_mask = self._preprocess_to_play_and_action_mask(
+ game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list
+ )
+
+ legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
+ with torch.no_grad():
+ policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type)
+ # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors
+ slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size))
+ network_output = []
+ for i in range(slices):
+ beg_index = self._cfg.mini_infer_size * i
+ end_index = self._cfg.mini_infer_size * (i + 1)
+ m_obs = torch.from_numpy(policy_obs_list[beg_index:end_index]).to(self._cfg.device).float()
+
+ m_output = model.initial_inference(m_obs)
+
+ if not model.training:
+ # if not in training, obtain the scalars of the value/reward
+ [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
+ [
+ m_output.latent_state,
+ inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
+ m_output.policy_logits
+ ]
+ )
+ m_output.reward_hidden_state = (
+ m_output.reward_hidden_state[0].detach().cpu().numpy(),
+ m_output.reward_hidden_state[1].detach().cpu().numpy()
+ )
+
+ network_output.append(m_output)
+
+ _, value_prefix_pool, policy_logits_pool, latent_state_roots, reward_hidden_state_roots = concat_output(
+ network_output, data_type='efficientzero'
+ )
+ value_prefix_pool = value_prefix_pool.squeeze().tolist()
+ policy_logits_pool = policy_logits_pool.tolist()
+ noises = [
+ np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size
+ ).astype(np.float32).tolist() for _ in range(transition_batch_size)
+ ]
+ if self._cfg.mcts_ctree:
+ # cpp mcts_tree
+ roots = MCTSCtree.roots(transition_batch_size, legal_actions)
+ roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play)
+ # do MCTS for a new policy with the recent target model
+ MCTSCtree(self._cfg).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play)
+ else:
+ # python mcts_tree
+ roots = MCTSPtree.roots(transition_batch_size, legal_actions)
+ roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play)
+ # do MCTS for a new policy with the recent target model
+ MCTSPtree(self._cfg).search(
+ roots, model, latent_state_roots, reward_hidden_state_roots, to_play=to_play
+ )
+
+ roots_legal_actions_list = legal_actions
+ roots_distributions = roots.get_distributions()
+ policy_index = 0
+ for state_index, game_index in zip(pos_in_game_segment_list, batch_index_list):
+ target_policies = []
+ for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
+ distributions = roots_distributions[policy_index]
+ if policy_mask[policy_index] == 0:
+                    # NOTE: for the invalid padding target policy, the all-zero vector makes sure the corresponding cross_entropy_loss is 0
+ target_policies.append([0 for _ in range(self._cfg.model.action_space_size)])
+ else:
+ if distributions is None:
+ # if at some obs, the legal_action is None, add the fake target_policy
+ target_policies.append(
+ list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size)
+ )
+ else:
+ if self._cfg.mcts_ctree:
+ # cpp mcts_tree
+ if self._cfg.action_type == 'fixed_action_space':
+ sum_visits = sum(distributions)
+ policy = [visit_count / sum_visits for visit_count in distributions]
+ target_policies.append(policy)
+ else:
+ # for two_player board games
+ policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)]
+ # to make sure target_policies have the same dimension
+ sum_visits = sum(distributions)
+ policy = [visit_count / sum_visits for visit_count in distributions]
+ for index, legal_action in enumerate(roots_legal_actions_list[policy_index]):
+ policy_tmp[legal_action] = policy[index]
+ target_policies.append(policy_tmp)
+ else:
+ # python mcts_tree
+ if self._cfg.action_type == 'fixed_action_space':
+ sum_visits = sum(distributions)
+ policy = [visit_count / sum_visits for visit_count in distributions]
+ target_policies.append(policy)
+ else:
+ # for two_player board games
+ policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)]
+ # to make sure target_policies have the same dimension
+ sum_visits = sum(distributions)
+ policy = [visit_count / sum_visits for visit_count in distributions]
+ for index, legal_action in enumerate(roots_legal_actions_list[policy_index]):
+ policy_tmp[legal_action] = policy[index]
+ target_policies.append(policy_tmp)
+ policy_index += 1
+ batch_target_policies_re.append(target_policies)
+ batch_target_policies_re = np.array(batch_target_policies_re)
+ return batch_target_policies_re
diff --git a/LightZero/lzero/mcts/buffer/game_buffer_gumbel_muzero.py b/LightZero/lzero/mcts/buffer/game_buffer_gumbel_muzero.py
new file mode 100644
index 0000000000000000000000000000000000000000..4ede5822bee23ad9f7da61fa1be433c4354db44b
--- /dev/null
+++ b/LightZero/lzero/mcts/buffer/game_buffer_gumbel_muzero.py
@@ -0,0 +1,123 @@
+from typing import Any, Tuple
+
+import numpy as np
+from ding.utils import BUFFER_REGISTRY
+
+from lzero.mcts.buffer import MuZeroGameBuffer
+from lzero.mcts.utils import prepare_observation
+
+
+@BUFFER_REGISTRY.register('game_buffer_gumbel_muzero')
+class GumbelMuZeroGameBuffer(MuZeroGameBuffer):
+ """
+ Overview:
+ The specific game buffer for Gumbel MuZero policy.
+ """
+
+ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
+ """
+ Overview:
+ first sample orig_data through ``_sample_orig_data()``,
+ then prepare the context of a batch:
+ reward_value_context: the context of reanalyzed value targets
+ policy_re_context: the context of reanalyzed policy targets
+ policy_non_re_context: the context of non-reanalyzed policy targets
+ current_batch: the inputs of batch
+ Arguments:
+ - batch_size (:obj:`int`): the batch size of orig_data from replay buffer.
+ - reanalyze_ratio (:obj:`float`): ratio of reanalyzed policy (value is 100% reanalyzed)
+ Returns:
+ - context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch
+ """
+ # obtain the batch context from replay buffer
+ orig_data = self._sample_orig_data(batch_size)
+ game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data
+ batch_size = len(batch_index_list)
+
+ # ==============================================================
+ # The core difference between GumbelMuZero and MuZero
+ # ==============================================================
+ # The main difference between Gumbel MuZero and MuZero lies in the preprocessing of improved_policy.
+ obs_list, action_list, improved_policy_list, mask_list = [], [], [], []
+ # prepare the inputs of a batch
+ for i in range(batch_size):
+ game = game_segment_list[i]
+ pos_in_game_segment = pos_in_game_segment_list[i]
+
+ actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment +
+ self._cfg.num_unroll_steps].tolist()
+
+ _improved_policy = game.improved_policy_probs[
+ pos_in_game_segment:pos_in_game_segment + self._cfg.num_unroll_steps]
+ if not isinstance(_improved_policy, list):
+ _improved_policy = _improved_policy.tolist()
+
+ # add mask for invalid actions (out of trajectory)
+ mask_tmp = [1. for i in range(len(actions_tmp))]
+ mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))]
+
+ # pad random action
+ actions_tmp += [
+ np.random.randint(0, game.action_space_size)
+ for _ in range(self._cfg.num_unroll_steps - len(actions_tmp))
+ ]
+
+            # pad the improved policy with random probability vectors (Dirichlet samples), each of which sums to 1
+ _improved_policy.extend(np.random.dirichlet(np.ones(game.action_space_size),
+ size=self._cfg.num_unroll_steps + 1 - len(_improved_policy)))
+
+ # obtain the input observations
+ # pad if length of obs in game_segment is less than stack+num_unroll_steps
+ # e.g. stack+num_unroll_steps = 4+5
+ obs_list.append(
+ game_segment_list[i].get_unroll_obs(
+ pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
+ )
+ )
+ action_list.append(actions_tmp)
+ improved_policy_list.append(_improved_policy)
+ mask_list.append(mask_tmp)
+
+ # formalize the input observations
+ obs_list = prepare_observation(obs_list, self._cfg.model.model_type)
+
+ # formalize the inputs of a batch
+ current_batch = [obs_list, action_list, improved_policy_list, mask_list, batch_index_list, weights_list,
+ make_time_list]
+ for i in range(len(current_batch)):
+ current_batch[i] = np.asarray(current_batch[i])
+
+ total_transitions = self.get_num_of_transitions()
+
+ # obtain the context of value targets
+ reward_value_context = self._prepare_reward_value_context(
+ batch_index_list, game_segment_list, pos_in_game_segment_list, total_transitions
+ )
+ """
+ only reanalyze recent reanalyze_ratio (e.g. 50%) data
+ if self._cfg.reanalyze_outdated is True, batch_index_list is sorted according to its generated env_steps
+ 0: reanalyze_num -> reanalyzed policy, reanalyze_num:end -> non reanalyzed policy
+ """
+ reanalyze_num = int(batch_size * reanalyze_ratio)
+ # reanalyzed policy
+ if reanalyze_num > 0:
+ # obtain the context of reanalyzed policy targets
+ policy_re_context = self._prepare_policy_reanalyzed_context(
+ batch_index_list[:reanalyze_num], game_segment_list[:reanalyze_num],
+ pos_in_game_segment_list[:reanalyze_num]
+ )
+ else:
+ policy_re_context = None
+
+ # non reanalyzed policy
+ if reanalyze_num < batch_size:
+ # obtain the context of non-reanalyzed policy targets
+ policy_non_re_context = self._prepare_policy_non_reanalyzed_context(
+ batch_index_list[reanalyze_num:], game_segment_list[reanalyze_num:],
+ pos_in_game_segment_list[reanalyze_num:]
+ )
+ else:
+ policy_non_re_context = None
+
+ context = reward_value_context, policy_re_context, policy_non_re_context, current_batch
+ return context
diff --git a/LightZero/lzero/mcts/buffer/game_buffer_muzero.py b/LightZero/lzero/mcts/buffer/game_buffer_muzero.py
new file mode 100644
index 0000000000000000000000000000000000000000..43bf9c73599f66c3c58f6133e1b1db9c8db548e3
--- /dev/null
+++ b/LightZero/lzero/mcts/buffer/game_buffer_muzero.py
@@ -0,0 +1,699 @@
+from typing import Any, List, Tuple, Union, TYPE_CHECKING, Optional
+
+import numpy as np
+import torch
+from ding.utils import BUFFER_REGISTRY
+
+from lzero.mcts.tree_search.mcts_ctree import MuZeroMCTSCtree as MCTSCtree
+from lzero.mcts.tree_search.mcts_ptree import MuZeroMCTSPtree as MCTSPtree
+from lzero.mcts.utils import prepare_observation
+from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform
+from .game_buffer import GameBuffer
+
+if TYPE_CHECKING:
+ from lzero.policy import MuZeroPolicy, EfficientZeroPolicy, SampledEfficientZeroPolicy
+
+
+@BUFFER_REGISTRY.register('game_buffer_muzero')
+class MuZeroGameBuffer(GameBuffer):
+ """
+ Overview:
+ The specific game buffer for MuZero policy.
+ """
+
+ def __init__(self, cfg: dict):
+ super().__init__(cfg)
+ """
+ Overview:
+ Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key
+ in the default configuration, the user-provided value will override the default configuration. Otherwise,
+ the default configuration will be used.
+ """
+ default_config = self.default_config()
+ default_config.update(cfg)
+ self._cfg = default_config
+ assert self._cfg.env_type in ['not_board_games', 'board_games']
+ assert self._cfg.action_type in ['fixed_action_space', 'varied_action_space']
+ self.replay_buffer_size = self._cfg.replay_buffer_size
+ self.batch_size = self._cfg.batch_size
+ self._alpha = self._cfg.priority_prob_alpha
+ self._beta = self._cfg.priority_prob_beta
+
+ self.keep_ratio = 1
+ self.model_update_interval = 10
+ self.num_of_collected_episodes = 0
+ self.base_idx = 0
+ self.clear_time = 0
+
+ self.game_segment_buffer = []
+ self.game_pos_priorities = []
+ self.game_segment_game_pos_look_up = []
+
+ def sample(
+ self, batch_size: int, policy: Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]
+ ) -> List[Any]:
+ """
+ Overview:
+ sample data from ``GameBuffer`` and prepare the current and target batch for training.
+ Arguments:
+ - batch_size (:obj:`int`): batch size.
+ - policy (:obj:`Union["MuZeroPolicy", "EfficientZeroPolicy", "SampledEfficientZeroPolicy"]`): policy.
+ Returns:
+ - train_data (:obj:`List`): List of train data, including current_batch and target_batch.
+ """
+ policy._target_model.to(self._cfg.device)
+ policy._target_model.eval()
+
+ # obtain the current_batch and prepare target context
+ reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch(
+ batch_size, self._cfg.reanalyze_ratio
+ )
+ # target reward, target value
+ batch_rewards, batch_target_values = self._compute_target_reward_value(
+ reward_value_context, policy._target_model
+ )
+ # target policy
+ batch_target_policies_re = self._compute_target_policy_reanalyzed(policy_re_context, policy._target_model)
+ batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed(
+ policy_non_re_context, self._cfg.model.action_space_size
+ )
+
+        # fuse batch_target_policies_re and batch_target_policies_non_re into batch_target_policies
+ if 0 < self._cfg.reanalyze_ratio < 1:
+ batch_target_policies = np.concatenate([batch_target_policies_re, batch_target_policies_non_re])
+ elif self._cfg.reanalyze_ratio == 1:
+ batch_target_policies = batch_target_policies_re
+ elif self._cfg.reanalyze_ratio == 0:
+ batch_target_policies = batch_target_policies_non_re
+
+ target_batch = [batch_rewards, batch_target_values, batch_target_policies]
+
+ # a batch contains the current_batch and the target_batch
+ train_data = [current_batch, target_batch]
+ return train_data
+
+ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
+ """
+ Overview:
+ first sample orig_data through ``_sample_orig_data()``,
+ then prepare the context of a batch:
+ reward_value_context: the context of reanalyzed value targets
+ policy_re_context: the context of reanalyzed policy targets
+ policy_non_re_context: the context of non-reanalyzed policy targets
+ current_batch: the inputs of batch
+ Arguments:
+ - batch_size (:obj:`int`): the batch size of orig_data from replay buffer.
+ - reanalyze_ratio (:obj:`float`): ratio of reanalyzed policy (value is 100% reanalyzed)
+ Returns:
+ - context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch
+ """
+ # obtain the batch context from replay buffer
+ orig_data = self._sample_orig_data(batch_size)
+ game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data
+ batch_size = len(batch_index_list)
+ obs_list, action_list, mask_list = [], [], []
+ # prepare the inputs of a batch
+ for i in range(batch_size):
+ game = game_segment_list[i]
+ pos_in_game_segment = pos_in_game_segment_list[i]
+
+ actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment +
+ self._cfg.num_unroll_steps].tolist()
+ # add mask for invalid actions (out of trajectory), 1 for valid, 0 for invalid
+ mask_tmp = [1. for i in range(len(actions_tmp))]
+ mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))]
+
+ # pad random action
+ actions_tmp += [
+ np.random.randint(0, game.action_space_size)
+ for _ in range(self._cfg.num_unroll_steps - len(actions_tmp))
+ ]
+
+ # obtain the input observations
+ # pad if length of obs in game_segment is less than stack+num_unroll_steps
+ # e.g. stack+num_unroll_steps = 4+5
+ obs_list.append(
+ game_segment_list[i].get_unroll_obs(
+ pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
+ )
+ )
+ action_list.append(actions_tmp)
+ mask_list.append(mask_tmp)
+
+ # formalize the input observations
+ obs_list = prepare_observation(obs_list, self._cfg.model.model_type)
+
+ # formalize the inputs of a batch
+ current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list]
+ for i in range(len(current_batch)):
+ current_batch[i] = np.asarray(current_batch[i])
+
+ total_transitions = self.get_num_of_transitions()
+
+ # obtain the context of value targets
+ reward_value_context = self._prepare_reward_value_context(
+ batch_index_list, game_segment_list, pos_in_game_segment_list, total_transitions
+ )
+ """
+ only reanalyze recent reanalyze_ratio (e.g. 50%) data
+ if self._cfg.reanalyze_outdated is True, batch_index_list is sorted according to its generated env_steps
+ 0: reanalyze_num -> reanalyzed policy, reanalyze_num:end -> non reanalyzed policy
+ """
+ reanalyze_num = int(batch_size * reanalyze_ratio)
+ # reanalyzed policy
+ if reanalyze_num > 0:
+ # obtain the context of reanalyzed policy targets
+ policy_re_context = self._prepare_policy_reanalyzed_context(
+ batch_index_list[:reanalyze_num], game_segment_list[:reanalyze_num],
+ pos_in_game_segment_list[:reanalyze_num]
+ )
+ else:
+ policy_re_context = None
+
+ # non reanalyzed policy
+ if reanalyze_num < batch_size:
+ # obtain the context of non-reanalyzed policy targets
+ policy_non_re_context = self._prepare_policy_non_reanalyzed_context(
+ batch_index_list[reanalyze_num:], game_segment_list[reanalyze_num:],
+ pos_in_game_segment_list[reanalyze_num:]
+ )
+ else:
+ policy_non_re_context = None
+
+ context = reward_value_context, policy_re_context, policy_non_re_context, current_batch
+ return context
+
+ def _prepare_reward_value_context(
+        self, batch_index_list: List[int], game_segment_list: List[Any], pos_in_game_segment_list: List[Any],
+ total_transitions: int
+ ) -> List[Any]:
+ """
+ Overview:
+ prepare the context of rewards and values for calculating TD value target in reanalyzing part.
+ Arguments:
+ - batch_index_list (:obj:`list`): the index of start transition of sampled minibatch in replay buffer
+ - game_segment_list (:obj:`list`): list of game segments
+ - pos_in_game_segment_list (:obj:`list`): list of transition index in game_segment
+ - total_transitions (:obj:`int`): number of collected transitions
+ Returns:
+ - reward_value_context (:obj:`list`): value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens,
+ td_steps_list, action_mask_segment, to_play_segment
+ """
+ zero_obs = game_segment_list[0].zero_obs()
+ value_obs_list = []
+ # the value is valid or not (out of game_segment)
+ value_mask = []
+ rewards_list = []
+ game_segment_lens = []
+ # for board games
+ action_mask_segment, to_play_segment = [], []
+
+ td_steps_list = []
+ for game_segment, state_index, idx in zip(game_segment_list, pos_in_game_segment_list, batch_index_list):
+ game_segment_len = len(game_segment)
+ game_segment_lens.append(game_segment_len)
+
+ td_steps = np.clip(self._cfg.td_steps, 1, max(1, game_segment_len - state_index)).astype(np.int32)
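+            # e.g. self._cfg.td_steps = 5 but only 3 transitions remain after state_index -> td_steps is clipped to 3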
+
+ # prepare the corresponding observations for bootstrapped values o_{t+k}
+ # o[t+ td_steps, t + td_steps + stack frames + num_unroll_steps]
+ # t=2+3 -> o[2+3, 2+3+4+5] -> o[5, 14]
+ game_obs = game_segment.get_unroll_obs(state_index + td_steps, self._cfg.num_unroll_steps)
+
+ rewards_list.append(game_segment.reward_segment)
+
+ # for board games
+ action_mask_segment.append(game_segment.action_mask_segment)
+ to_play_segment.append(game_segment.to_play_segment)
+
+ for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
+ # get the bootstrapped target obs
+ td_steps_list.append(td_steps)
+ # index of bootstrapped obs o_{t+td_steps}
+ bootstrap_index = current_index + td_steps
+
+ if bootstrap_index < game_segment_len:
+ value_mask.append(1)
+ # beg_index = bootstrap_index - (state_index + td_steps), max of beg_index is num_unroll_steps
+ beg_index = current_index - state_index
+ end_index = beg_index + self._cfg.model.frame_stack_num
+ # the stacked obs in time t
+ obs = game_obs[beg_index:end_index]
+ else:
+ value_mask.append(0)
+ obs = zero_obs
+
+ value_obs_list.append(obs)
+
+ reward_value_context = [
+ value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list,
+ action_mask_segment, to_play_segment
+ ]
+ return reward_value_context
+
+ def _prepare_policy_non_reanalyzed_context(
+ self, batch_index_list: List[int], game_segment_list: List[Any], pos_in_game_segment_list: List[int]
+ ) -> List[Any]:
+ """
+ Overview:
+ prepare the context of policies for calculating policy target in non-reanalyzing part, just return the policy in self-play
+ Arguments:
+ - batch_index_list (:obj:`list`): the index of start transition of sampled minibatch in replay buffer
+ - game_segment_list (:obj:`list`): list of game segments
+            - pos_in_game_segment_list (:obj:`list`): list of transition indices in the game segments
+ Returns:
+ - policy_non_re_context (:obj:`list`): pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment
+ """
+ child_visits = []
+ game_segment_lens = []
+ # for board games
+ action_mask_segment, to_play_segment = [], []
+
+ for game_segment, state_index, idx in zip(game_segment_list, pos_in_game_segment_list, batch_index_list):
+ game_segment_len = len(game_segment)
+ game_segment_lens.append(game_segment_len)
+ # for board games
+ action_mask_segment.append(game_segment.action_mask_segment)
+ to_play_segment.append(game_segment.to_play_segment)
+
+ child_visits.append(game_segment.child_visit_segment)
+
+ policy_non_re_context = [
+ pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment
+ ]
+ return policy_non_re_context
+
+ def _prepare_policy_reanalyzed_context(
+        self, batch_index_list: List[int], game_segment_list: List[Any], pos_in_game_segment_list: List[int]
+ ) -> List[Any]:
+ """
+ Overview:
+ prepare the context of policies for calculating policy target in reanalyzing part.
+ Arguments:
+            - batch_index_list (:obj:`list`): start transition index in the replay buffer
+            - game_segment_list (:obj:`list`): list of game segments
+            - pos_in_game_segment_list (:obj:`list`): position of transition index in one game history
+ Returns:
+ - policy_re_context (:obj:`list`): policy_obs_list, policy_mask, pos_in_game_segment_list, indices,
+ child_visits, game_segment_lens, action_mask_segment, to_play_segment
+ """
+ zero_obs = game_segment_list[0].zero_obs()
+ with torch.no_grad():
+ # for policy
+ policy_obs_list = []
+ policy_mask = []
+ # 0 -> Invalid target policy for padding outside of game segments,
+ # 1 -> Previous target policy for game segments.
+ rewards, child_visits, game_segment_lens = [], [], []
+ # for board games
+ action_mask_segment, to_play_segment = [], []
+ for game_segment, state_index in zip(game_segment_list, pos_in_game_segment_list):
+ game_segment_len = len(game_segment)
+ game_segment_lens.append(game_segment_len)
+ rewards.append(game_segment.reward_segment)
+ # for board games
+ action_mask_segment.append(game_segment.action_mask_segment)
+ to_play_segment.append(game_segment.to_play_segment)
+
+ child_visits.append(game_segment.child_visit_segment)
+ # prepare the corresponding observations
+ game_obs = game_segment.get_unroll_obs(state_index, self._cfg.num_unroll_steps)
+ for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
+
+ if current_index < game_segment_len:
+ policy_mask.append(1)
+ beg_index = current_index - state_index
+ end_index = beg_index + self._cfg.model.frame_stack_num
+ obs = game_obs[beg_index:end_index]
+ else:
+ policy_mask.append(0)
+ obs = zero_obs
+ policy_obs_list.append(obs)
+
+ policy_re_context = [
+ policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens,
+ action_mask_segment, to_play_segment
+ ]
+ return policy_re_context
+
+ def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any) -> Tuple[Any, Any]:
+ """
+ Overview:
+ prepare reward and value targets from the context of rewards and values.
+ Arguments:
+            - reward_value_context (:obj:`list`): the reward and value context
+            - model (:obj:`torch.nn.Module`): the target model
+        Returns:
+            - batch_rewards (:obj:`np.ndarray`): batch of target rewards
+            - batch_target_values (:obj:`np.ndarray`): batch of target values
+ """
+ value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, \
+ to_play_segment = reward_value_context # noqa
+ # transition_batch_size = game_segment_batch_size * (num_unroll_steps+1)
+ transition_batch_size = len(value_obs_list)
+ game_segment_batch_size = len(pos_in_game_segment_list)
+
+ to_play, action_mask = self._preprocess_to_play_and_action_mask(
+ game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list
+ )
+ if self._cfg.model.continuous_action_space is True:
+ # when the action space of the environment is continuous, action_mask[:] is None.
+ action_mask = [
+ list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size)
+ ]
+ # NOTE: in continuous action space env: we set all legal_actions as -1
+ legal_actions = [
+ [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size)
+ ]
+ else:
+ legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
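+            # e.g. action_mask[j] = [1, 0, 1, 1] -> legal_actions[j] = [0, 2, 3]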
+
+ batch_target_values, batch_rewards = [], []
+ with torch.no_grad():
+ value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type)
+ # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors
+ slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size))
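+            # e.g. transition_batch_size = 1536 and mini_infer_size = 256 -> 6 inference slices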
+ network_output = []
+ for i in range(slices):
+ beg_index = self._cfg.mini_infer_size * i
+ end_index = self._cfg.mini_infer_size * (i + 1)
+
+ m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float()
+
+ # calculate the target value
+ m_output = model.initial_inference(m_obs)
+
+ if not model.training:
+ # if not in training, obtain the scalars of the value/reward
+ [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
+ [
+ m_output.latent_state,
+ inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
+ m_output.policy_logits
+ ]
+ )
+
+ network_output.append(m_output)
+
+ # concat the output slices after model inference
+ if self._cfg.use_root_value:
+                # use the root values from MCTS, as in EfficientZero;
+                # the root values give limited improvement but require many more GPU actors
+ _, reward_pool, policy_logits_pool, latent_state_roots = concat_output(
+ network_output, data_type='muzero'
+ )
+ reward_pool = reward_pool.squeeze().tolist()
+ policy_logits_pool = policy_logits_pool.tolist()
+ noises = [
+ np.random.dirichlet([self._cfg.root_dirichlet_alpha] * int(sum(action_mask[j]))
+ ).astype(np.float32).tolist() for j in range(transition_batch_size)
+ ]
+ if self._cfg.mcts_ctree:
+ # cpp mcts_tree
+ roots = MCTSCtree.roots(transition_batch_size, legal_actions)
+ roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play)
+ # do MCTS for a new policy with the recent target model
+ MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play)
+ else:
+ # python mcts_tree
+ roots = MCTSPtree.roots(transition_batch_size, legal_actions)
+ roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play)
+ # do MCTS for a new policy with the recent target model
+ MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play)
+
+ roots_values = roots.get_values()
+ value_list = np.array(roots_values)
+ else:
+ # use the predicted values
+ value_list = concat_output_value(network_output)
+
+ # get last state value
+ if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]:
+ # TODO(pu): for board_games, very important, to check
+ value_list = value_list.reshape(-1) * np.array(
+ [
+ self._cfg.discount_factor ** td_steps_list[i] if int(td_steps_list[i]) %
+ 2 == 0 else -self._cfg.discount_factor ** td_steps_list[i]
+ for i in range(transition_batch_size)
+ ]
+ )
+ else:
+ value_list = value_list.reshape(-1) * (
+ np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list
+ )
+
+ value_list = value_list * np.array(value_mask)
+ value_list = value_list.tolist()
+ horizon_id, value_index = 0, 0
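+            # The n-step TD target assembled below (illustrative):
+            #   z_t = r_t + gamma * r_{t+1} + ... + gamma^{k-1} * r_{t+k-1} + gamma^k * v_{t+k}, with k = td_steps
+            # The (masked, discounted) bootstrap term gamma^k * v_{t+k} is already contained in value_list;
+            # the loop below adds the discounted reward sum for each unroll position.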
+
+ for game_segment_len_non_re, reward_list, state_index, to_play_list in zip(game_segment_lens, rewards_list,
+ pos_in_game_segment_list,
+ to_play_segment):
+ target_values = []
+ target_rewards = []
+ base_index = state_index
+ for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
+ bootstrap_index = current_index + td_steps_list[value_index]
+ # for i, reward in enumerate(game.rewards[current_index:bootstrap_index]):
+ for i, reward in enumerate(reward_list[current_index:bootstrap_index]):
+ if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]:
+ # TODO(pu): for board_games, very important, to check
+ if to_play_list[base_index] == to_play_list[i]:
+ value_list[value_index] += reward * self._cfg.discount_factor ** i
+ else:
+ value_list[value_index] += -reward * self._cfg.discount_factor ** i
+ else:
+ value_list[value_index] += reward * self._cfg.discount_factor ** i
+ horizon_id += 1
+
+ if current_index < game_segment_len_non_re:
+ target_values.append(value_list[value_index])
+ target_rewards.append(reward_list[current_index])
+ else:
+ target_values.append(0)
+ target_rewards.append(0.0)
+ # TODO: check
+ # target_rewards.append(reward)
+ value_index += 1
+
+ batch_rewards.append(target_rewards)
+ batch_target_values.append(target_values)
+
+ batch_rewards = np.asarray(batch_rewards, dtype=object)
+ batch_target_values = np.asarray(batch_target_values, dtype=object)
+ return batch_rewards, batch_target_values
+
+ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any) -> np.ndarray:
+ """
+ Overview:
+ prepare policy targets from the reanalyzed context of policies
+ Arguments:
+ - policy_re_context (:obj:`List`): List of policy context to reanalyzed
+ Returns:
+ - batch_target_policies_re
+ """
+ if policy_re_context is None:
+ return []
+ batch_target_policies_re = []
+
+ # for board games
+ policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, action_mask_segment, \
+ to_play_segment = policy_re_context
+ # transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1)
+ transition_batch_size = len(policy_obs_list)
+ game_segment_batch_size = len(pos_in_game_segment_list)
+
+ to_play, action_mask = self._preprocess_to_play_and_action_mask(
+ game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list
+ )
+
+ if self._cfg.model.continuous_action_space is True:
+ # when the action space of the environment is continuous, action_mask[:] is None.
+ action_mask = [
+ list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size)
+ ]
+ # NOTE: in continuous action space env: we set all legal_actions as -1
+ legal_actions = [
+ [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size)
+ ]
+ else:
+ legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
+
+ with torch.no_grad():
+ policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type)
+ # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors
+ slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size))
+ network_output = []
+ for i in range(slices):
+ beg_index = self._cfg.mini_infer_size * i
+ end_index = self._cfg.mini_infer_size * (i + 1)
+ m_obs = torch.from_numpy(policy_obs_list[beg_index:end_index]).to(self._cfg.device).float()
+ m_output = model.initial_inference(m_obs)
+ if not model.training:
+ # if not in training, obtain the scalars of the value/reward
+ [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
+ [
+ m_output.latent_state,
+ inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
+ m_output.policy_logits
+ ]
+ )
+
+ network_output.append(m_output)
+
+ _, reward_pool, policy_logits_pool, latent_state_roots = concat_output(network_output, data_type='muzero')
+ reward_pool = reward_pool.squeeze().tolist()
+ policy_logits_pool = policy_logits_pool.tolist()
+ noises = [
+ np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.action_space_size
+ ).astype(np.float32).tolist() for _ in range(transition_batch_size)
+ ]
+ if self._cfg.mcts_ctree:
+ # cpp mcts_tree
+ roots = MCTSCtree.roots(transition_batch_size, legal_actions)
+ roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play)
+ # do MCTS for a new policy with the recent target model
+ MCTSCtree(self._cfg).search(roots, model, latent_state_roots, to_play)
+ else:
+ # python mcts_tree
+ roots = MCTSPtree.roots(transition_batch_size, legal_actions)
+ roots.prepare(self._cfg.root_noise_weight, noises, reward_pool, policy_logits_pool, to_play)
+ # do MCTS for a new policy with the recent target model
+ MCTSPtree(self._cfg).search(roots, model, latent_state_roots, to_play)
+
+ roots_legal_actions_list = legal_actions
+ roots_distributions = roots.get_distributions()
+ policy_index = 0
+ for state_index, game_index in zip(pos_in_game_segment_list, batch_index_list):
+ target_policies = []
+
+ for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
+ distributions = roots_distributions[policy_index]
+
+ if policy_mask[policy_index] == 0:
+                        # NOTE: for the padded (invalid) positions, the all-zero target policy makes the corresponding cross_entropy_loss = 0
+ target_policies.append([0 for _ in range(self._cfg.model.action_space_size)])
+ else:
+ if distributions is None:
+                            # if the MCTS distribution is unavailable at this obs, fall back to a uniform target policy
+ target_policies.append(
+ list(np.ones(self._cfg.model.action_space_size) / self._cfg.model.action_space_size)
+ )
+ else:
+ if self._cfg.action_type == 'fixed_action_space':
+ # for atari/classic_control/box2d environments that only have one player.
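+                                # e.g. visit counts [2, 6, 2] -> target policy [0.2, 0.6, 0.2]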
+ sum_visits = sum(distributions)
+ policy = [visit_count / sum_visits for visit_count in distributions]
+ target_policies.append(policy)
+ else:
+                                # for board games that have two players and a varied set of legal actions
+ policy_tmp = [0 for _ in range(self._cfg.model.action_space_size)]
+ # to make sure target_policies have the same dimension
+ sum_visits = sum(distributions)
+ policy = [visit_count / sum_visits for visit_count in distributions]
+ for index, legal_action in enumerate(roots_legal_actions_list[policy_index]):
+ policy_tmp[legal_action] = policy[index]
+ target_policies.append(policy_tmp)
+
+ policy_index += 1
+
+ batch_target_policies_re.append(target_policies)
+
+ batch_target_policies_re = np.array(batch_target_policies_re)
+
+ return batch_target_policies_re
+
+ def _compute_target_policy_non_reanalyzed(
+ self, policy_non_re_context: List[Any], policy_shape: Optional[int]
+ ) -> np.ndarray:
+ """
+ Overview:
+ prepare policy targets from the non-reanalyzed context of policies
+ Arguments:
+ - policy_non_re_context (:obj:`List`): List containing:
+ - pos_in_game_segment_list
+ - child_visits
+ - game_segment_lens
+ - action_mask_segment
+ - to_play_segment
+ - policy_shape: self._cfg.model.action_space_size
+ Returns:
+ - batch_target_policies_non_re
+ """
+ batch_target_policies_non_re = []
+ if policy_non_re_context is None:
+ return batch_target_policies_non_re
+
+ pos_in_game_segment_list, child_visits, game_segment_lens, action_mask_segment, to_play_segment = policy_non_re_context
+ game_segment_batch_size = len(pos_in_game_segment_list)
+ transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1)
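+        # e.g. 64 game segments with num_unroll_steps = 5 -> 64 * (5 + 1) = 384 transitions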
+
+ to_play, action_mask = self._preprocess_to_play_and_action_mask(
+ game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list
+ )
+
+ if self._cfg.model.continuous_action_space is True:
+ # when the action space of the environment is continuous, action_mask[:] is None.
+ action_mask = [
+ list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size)
+ ]
+ # NOTE: in continuous action space env: we set all legal_actions as -1
+ legal_actions = [
+ [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size)
+ ]
+ else:
+ legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
+
+ with torch.no_grad():
+ policy_index = 0
+ # 0 -> Invalid target policy for padding outside of game segments,
+ # 1 -> Previous target policy for game segments.
+ policy_mask = []
+ for game_segment_len, child_visit, state_index in zip(game_segment_lens, child_visits,
+ pos_in_game_segment_list):
+ target_policies = []
+
+ for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
+ if current_index < game_segment_len:
+ policy_mask.append(1)
+ # NOTE: child_visit is already a distribution
+ distributions = child_visit[current_index]
+ if self._cfg.action_type == 'fixed_action_space':
+ # for atari/classic_control/box2d environments that only have one player.
+ target_policies.append(distributions)
+ else:
+ # for board games that have two players.
+ policy_tmp = [0 for _ in range(policy_shape)]
+ for index, legal_action in enumerate(legal_actions[policy_index]):
+                                # only actions in ``legal_actions`` get a nonzero target policy value
+ policy_tmp[legal_action] = distributions[index]
+ target_policies.append(policy_tmp)
+ else:
+                        # NOTE: for the padded (invalid) positions, the all-zero target policy makes the corresponding cross_entropy_loss = 0
+ policy_mask.append(0)
+ target_policies.append([0 for _ in range(policy_shape)])
+
+ policy_index += 1
+
+ batch_target_policies_non_re.append(target_policies)
+ batch_target_policies_non_re = np.asarray(batch_target_policies_non_re)
+ return batch_target_policies_non_re
+
+ def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) -> None:
+ """
+ Overview:
+ Update the priority of training data.
+ Arguments:
+            - train_data (:obj:`List[np.ndarray]`): the training data whose priorities will be updated.
+            - batch_priorities (:obj:`Any`): the new priorities to assign.
+ NOTE:
+ train_data = [current_batch, target_batch]
+ current_batch = [obs_list, action_list, improved_policy_list(only in Gumbel MuZero), mask_list, batch_index_list, weights, make_time_list]
+ """
+ indices = train_data[0][-3]
+ metas = {'make_time': train_data[0][-1], 'batch_priorities': batch_priorities}
+ # only update the priorities for data still in replay buffer
+ for i in range(len(indices)):
+ if metas['make_time'][i] > self.clear_time:
+ idx, prio = indices[i], metas['batch_priorities'][i]
+ self.game_pos_priorities[idx] = prio
diff --git a/LightZero/lzero/mcts/buffer/game_buffer_sampled_efficientzero.py b/LightZero/lzero/mcts/buffer/game_buffer_sampled_efficientzero.py
new file mode 100644
index 0000000000000000000000000000000000000000..970cf23d7abcb3d4ce01f7b4378013451cb86a85
--- /dev/null
+++ b/LightZero/lzero/mcts/buffer/game_buffer_sampled_efficientzero.py
@@ -0,0 +1,584 @@
+from typing import Any, List, Tuple
+
+import numpy as np
+import torch
+from ding.utils import BUFFER_REGISTRY
+
+from lzero.mcts.tree_search.mcts_ctree_sampled import SampledEfficientZeroMCTSCtree as MCTSCtree
+from lzero.mcts.tree_search.mcts_ptree_sampled import SampledEfficientZeroMCTSPtree as MCTSPtree
+from lzero.mcts.utils import prepare_observation, generate_random_actions_discrete
+from lzero.policy import to_detach_cpu_numpy, concat_output, concat_output_value, inverse_scalar_transform
+from .game_buffer_efficientzero import EfficientZeroGameBuffer
+
+
+@BUFFER_REGISTRY.register('game_buffer_sampled_efficientzero')
+class SampledEfficientZeroGameBuffer(EfficientZeroGameBuffer):
+ """
+ Overview:
+ The specific game buffer for Sampled EfficientZero policy.
+ """
+
+ def __init__(self, cfg: dict):
+ super().__init__(cfg)
+ """
+ Overview:
+ Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key
+ in the default configuration, the user-provided value will override the default configuration. Otherwise,
+ the default configuration will be used.
+ """
+ default_config = self.default_config()
+ default_config.update(cfg)
+ self._cfg = default_config
+ assert self._cfg.env_type in ['not_board_games', 'board_games']
+ assert self._cfg.action_type in ['fixed_action_space', 'varied_action_space']
+ self.replay_buffer_size = self._cfg.replay_buffer_size
+ self.batch_size = self._cfg.batch_size
+ self._alpha = self._cfg.priority_prob_alpha
+ self._beta = self._cfg.priority_prob_beta
+
+ self.game_segment_buffer = []
+ self.game_pos_priorities = []
+ self.game_segment_game_pos_look_up = []
+
+ self.keep_ratio = 1
+ self.num_of_collected_episodes = 0
+ self.base_idx = 0
+ self.clear_time = 0
+
+ def sample(self, batch_size: int, policy: Any) -> List[Any]:
+ """
+ Overview:
+ sample data from ``GameBuffer`` and prepare the current and target batch for training
+ Arguments:
+ - batch_size (:obj:`int`): batch size
+            - policy (:obj:`Any`): the policy whose ``_target_model`` is used to compute the targets
+ Returns:
+ - train_data (:obj:`List`): List of train data
+ """
+
+ policy._target_model.to(self._cfg.device)
+ policy._target_model.eval()
+
+ reward_value_context, policy_re_context, policy_non_re_context, current_batch = self._make_batch(
+ batch_size, self._cfg.reanalyze_ratio
+ )
+
+ # target reward, target value
+ batch_value_prefixs, batch_target_values = self._compute_target_reward_value(
+ reward_value_context, policy._target_model
+ )
+
+ batch_target_policies_non_re = self._compute_target_policy_non_reanalyzed(
+ policy_non_re_context, self._cfg.model.num_of_sampled_actions
+ )
+
+ if self._cfg.reanalyze_ratio > 0:
+ # target policy
+ batch_target_policies_re, root_sampled_actions = self._compute_target_policy_reanalyzed(
+ policy_re_context, policy._target_model
+ )
+ # ==============================================================
+ # fix reanalyze in sez:
+ # use the latest root_sampled_actions after the reanalyze process,
+            # because batch_target_policies_re corresponds to the latest root_sampled_actions
+ # ==============================================================
+
+ assert (self._cfg.reanalyze_ratio > 0 and self._cfg.reanalyze_outdated is True), \
+ "in sampled effiicientzero, if self._cfg.reanalyze_ratio>0, you must set self._cfg.reanalyze_outdated=True"
+ # current_batch = [obs_list, action_list, root_sampled_actions_list, mask_list, batch_index_list, weights_list, make_time_list]
+ if self._cfg.model.continuous_action_space:
+ current_batch[2][:int(batch_size * self._cfg.reanalyze_ratio)] = root_sampled_actions.reshape(
+ int(batch_size * self._cfg.reanalyze_ratio), self._cfg.num_unroll_steps + 1,
+ self._cfg.model.num_of_sampled_actions, self._cfg.model.action_space_size
+ )
+ else:
+ current_batch[2][:int(batch_size * self._cfg.reanalyze_ratio)] = root_sampled_actions.reshape(
+ int(batch_size * self._cfg.reanalyze_ratio), self._cfg.num_unroll_steps + 1,
+ self._cfg.model.num_of_sampled_actions, 1
+ )
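+                # in the discrete case each sampled action is a single index, hence the trailing dimension of 1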
+
+ if 0 < self._cfg.reanalyze_ratio < 1:
+            try:
+                batch_target_policies = np.concatenate([batch_target_policies_re, batch_target_policies_non_re])
+            except Exception as error:
+                # the reanalyzed and non-reanalyzed targets must have matching non-batch dimensions
+                print(error)
+                raise
+ elif self._cfg.reanalyze_ratio == 1:
+ batch_target_policies = batch_target_policies_re
+ elif self._cfg.reanalyze_ratio == 0:
+ batch_target_policies = batch_target_policies_non_re
+
+ target_batch = [batch_value_prefixs, batch_target_values, batch_target_policies]
+ # a batch contains the current_batch and the target_batch
+ train_data = [current_batch, target_batch]
+ return train_data
+
+ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
+ """
+ Overview:
+ first sample orig_data through ``_sample_orig_data()``,
+ then prepare the context of a batch:
+ reward_value_context: the context of reanalyzed value targets
+ policy_re_context: the context of reanalyzed policy targets
+ policy_non_re_context: the context of non-reanalyzed policy targets
+ current_batch: the inputs of batch
+ Arguments:
+ - batch_size (:obj:`int`): the batch size of orig_data from replay buffer.
+ - reanalyze_ratio (:obj:`float`): ratio of reanalyzed policy (value is 100% reanalyzed)
+ Returns:
+ - context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch
+ """
+ # obtain the batch context from replay buffer
+ orig_data = self._sample_orig_data(batch_size)
+ game_lst, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data
+ batch_size = len(batch_index_list)
+ obs_list, action_list, mask_list = [], [], []
+ root_sampled_actions_list = []
+ # prepare the inputs of a batch
+ for i in range(batch_size):
+ game = game_lst[i]
+ pos_in_game_segment = pos_in_game_segment_list[i]
+ # ==============================================================
+ # sampled related core code
+ # ==============================================================
+ actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment +
+ self._cfg.num_unroll_steps].tolist()
+
+ # NOTE: self._cfg.num_unroll_steps + 1
+ root_sampled_actions_tmp = game.root_sampled_actions[pos_in_game_segment:pos_in_game_segment +
+ self._cfg.num_unroll_steps + 1]
+
+ # add mask for invalid actions (out of trajectory), 1 for valid, 0 for invalid
+ mask_tmp = [1. for i in range(len(root_sampled_actions_tmp))]
+ mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps + 1 - len(mask_tmp))]
+
+ # pad random action
+ if self._cfg.model.continuous_action_space:
+ actions_tmp += [
+ np.random.randn(self._cfg.model.action_space_size)
+ for _ in range(self._cfg.num_unroll_steps - len(actions_tmp))
+ ]
+ root_sampled_actions_tmp += [
+ np.random.rand(self._cfg.model.num_of_sampled_actions, self._cfg.model.action_space_size)
+ for _ in range(self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp))
+ ]
+ else:
+ # generate random `padded actions_tmp`
+ actions_tmp += generate_random_actions_discrete(
+ self._cfg.num_unroll_steps - len(actions_tmp),
+ self._cfg.model.action_space_size,
+ 1 # Number of sampled actions for actions_tmp is 1
+ )
+
+ # generate random padded `root_sampled_actions_tmp`
+ # root_sampled_action have different shape in mcts_ctree and mcts_ptree, thus we need to pad differently
+ reshape = True if self._cfg.mcts_ctree else False
+ root_sampled_actions_tmp += generate_random_actions_discrete(
+ self._cfg.num_unroll_steps + 1 - len(root_sampled_actions_tmp),
+ self._cfg.model.action_space_size,
+ self._cfg.model.num_of_sampled_actions,
+ reshape=reshape
+ )
+
+ # obtain the input observations
+ # stack+num_unroll_steps = 4+5
+ # pad if length of obs in game_segment is less than stack+num_unroll_steps
+ obs_list.append(
+ game_lst[i].get_unroll_obs(
+ pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
+ )
+ )
+ action_list.append(actions_tmp)
+ root_sampled_actions_list.append(root_sampled_actions_tmp)
+
+ mask_list.append(mask_tmp)
+
+ # formalize the input observations
+ obs_list = prepare_observation(obs_list, self._cfg.model.model_type)
+ # ==============================================================
+ # sampled related core code
+ # ==============================================================
+ # formalize the inputs of a batch
+ current_batch = [
+ obs_list, action_list, root_sampled_actions_list, mask_list, batch_index_list, weights_list, make_time_list
+ ]
+
+ for i in range(len(current_batch)):
+ current_batch[i] = np.asarray(current_batch[i])
+
+ total_transitions = self.get_num_of_transitions()
+
+ # obtain the context of value targets
+ reward_value_context = self._prepare_reward_value_context(
+ batch_index_list, game_lst, pos_in_game_segment_list, total_transitions
+ )
+ """
+ only reanalyze recent reanalyze_ratio (e.g. 50%) data
+ if self._cfg.reanalyze_outdated is True, batch_index_list is sorted according to its generated env_steps
+ 0: reanalyze_num -> reanalyzed policy, reanalyze_num:end -> non reanalyzed policy
+ """
+ reanalyze_num = int(batch_size * reanalyze_ratio)
+ # reanalyzed policy
+ if reanalyze_num > 0:
+ # obtain the context of reanalyzed policy targets
+ policy_re_context = self._prepare_policy_reanalyzed_context(
+ batch_index_list[:reanalyze_num], game_lst[:reanalyze_num], pos_in_game_segment_list[:reanalyze_num]
+ )
+ else:
+ policy_re_context = None
+
+ # non reanalyzed policy
+ if reanalyze_num < batch_size:
+ # obtain the context of non-reanalyzed policy targets
+ policy_non_re_context = self._prepare_policy_non_reanalyzed_context(
+ batch_index_list[reanalyze_num:], game_lst[reanalyze_num:], pos_in_game_segment_list[reanalyze_num:]
+ )
+ else:
+ policy_non_re_context = None
+
+ context = reward_value_context, policy_re_context, policy_non_re_context, current_batch
+ return context
+
+ def _compute_target_reward_value(self, reward_value_context: List[Any], model: Any) -> List[np.ndarray]:
+ """
+ Overview:
+ prepare reward and value targets from the context of rewards and values.
+ Arguments:
+            - reward_value_context (:obj:`list`): the reward and value context
+            - model (:obj:`torch.nn.Module`): the target model
+        Returns:
+            - batch_value_prefixs (:obj:`np.ndarray`): batch of target value prefixes
+            - batch_target_values (:obj:`np.ndarray`): batch of target values
+ """
+ value_obs_list, value_mask, pos_in_game_segment_list, rewards_list, game_segment_lens, td_steps_list, action_mask_segment, \
+ to_play_segment = reward_value_context # noqa
+
+ # transition_batch_size = game_segment_batch_size * (num_unroll_steps+1)
+ transition_batch_size = len(value_obs_list)
+ game_segment_batch_size = len(pos_in_game_segment_list)
+
+ to_play, action_mask = self._preprocess_to_play_and_action_mask(
+ game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list
+ )
+ if self._cfg.model.continuous_action_space is True:
+ # when the action space of the environment is continuous, action_mask[:] is None.
+ action_mask = [
+ list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size)
+ ]
+ # NOTE: in continuous action space env: we set all legal_actions as -1
+ legal_actions = [
+ [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size)
+ ]
+ else:
+ legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
+
+ batch_target_values, batch_value_prefixs = [], []
+ with torch.no_grad():
+ value_obs_list = prepare_observation(value_obs_list, self._cfg.model.model_type)
+ # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors
+ slices = int(np.ceil(transition_batch_size / self._cfg.mini_infer_size))
+ network_output = []
+ for i in range(slices):
+ beg_index = self._cfg.mini_infer_size * i
+ end_index = self._cfg.mini_infer_size * (i + 1)
+ m_obs = torch.from_numpy(value_obs_list[beg_index:end_index]).to(self._cfg.device).float()
+
+ # calculate the target value
+ m_output = model.initial_inference(m_obs)
+
+ # TODO(pu)
+ if not model.training:
+ # if not in training, obtain the scalars of the value/reward
+ [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
+ [
+ m_output.latent_state,
+ inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
+ m_output.policy_logits
+ ]
+ )
+ m_output.reward_hidden_state = (
+ m_output.reward_hidden_state[0].detach().cpu().numpy(),
+ m_output.reward_hidden_state[1].detach().cpu().numpy()
+ )
+
+ network_output.append(m_output)
+
+ # concat the output slices after model inference
+ if self._cfg.use_root_value:
+ # use the root values from MCTS
+ # the root values have limited improvement but require much more GPU actors;
+ _, value_prefix_pool, policy_logits_pool, latent_state_roots, reward_hidden_state_roots = concat_output(
+ network_output, data_type='efficientzero'
+ )
+ value_prefix_pool = value_prefix_pool.squeeze().tolist()
+ policy_logits_pool = policy_logits_pool.tolist()
+ # generate the noises for the root nodes
+ noises = [
+ np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.num_of_sampled_actions
+ ).astype(np.float32).tolist() for _ in range(transition_batch_size)
+ ]
+
+ if self._cfg.mcts_ctree:
+ # cpp mcts_tree
+ # prepare the root nodes for MCTS
+ roots = MCTSCtree.roots(
+ transition_batch_size, legal_actions, self._cfg.model.action_space_size,
+ self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space
+ )
+
+ roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play)
+ # do MCTS for a new policy with the recent target model
+ MCTSCtree(self._cfg).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play)
+ else:
+ # python mcts_tree
+ roots = MCTSPtree.roots(
+ transition_batch_size, legal_actions, self._cfg.model.action_space_size,
+ self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space
+ )
+ roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play)
+ # do MCTS for a new policy with the recent target model
+                    MCTSPtree(self._cfg).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play)
+
+ roots_values = roots.get_values()
+ value_list = np.array(roots_values)
+ else:
+ # use the predicted values
+ value_list = concat_output_value(network_output)
+
+ # get last state value
+ if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]:
+ # TODO(pu): for board_games, very important, to check
+ value_list = value_list.reshape(-1) * np.array(
+ [
+ self._cfg.discount_factor ** td_steps_list[i] if int(td_steps_list[i]) %
+ 2 == 0 else -self._cfg.discount_factor ** td_steps_list[i]
+ for i in range(transition_batch_size)
+ ]
+ )
+ else:
+ value_list = value_list.reshape(-1) * (
+ np.array([self._cfg.discount_factor for _ in range(transition_batch_size)]) ** td_steps_list
+ )
+
+ value_list = value_list * np.array(value_mask)
+ value_list = value_list.tolist()
+
+ horizon_id, value_index = 0, 0
+ for game_segment_len_non_re, reward_list, state_index, to_play_list in zip(game_segment_lens, rewards_list,
+ pos_in_game_segment_list,
+ to_play_segment):
+ target_values = []
+ target_value_prefixs = []
+
+ value_prefix = 0.0
+ base_index = state_index
+ for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
+ bootstrap_index = current_index + td_steps_list[value_index]
+ # for i, reward in enumerate(game.rewards[current_index:bootstrap_index]):
+ for i, reward in enumerate(reward_list[current_index:bootstrap_index]):
+ if self._cfg.env_type == 'board_games' and to_play_segment[0][0] in [1, 2]:
+ # TODO(pu): for board_games, very important, to check
+ if to_play_list[base_index] == to_play_list[i]:
+ value_list[value_index] += reward * self._cfg.discount_factor ** i
+ else:
+ value_list[value_index] += -reward * self._cfg.discount_factor ** i
+ else:
+ value_list[value_index] += reward * self._cfg.discount_factor ** i
+                            # TODO(pu): why doesn't the value use the discount_factor here
+
+ # reset every lstm_horizon_len
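+                    # e.g. lstm_horizon_len = 5 -> the value prefix is reset to 0 every 5 unroll steps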
+ if horizon_id % self._cfg.lstm_horizon_len == 0:
+ value_prefix = 0.0
+ base_index = current_index
+ horizon_id += 1
+
+ if current_index < game_segment_len_non_re:
+ target_values.append(value_list[value_index])
+                        # Since the horizon is small and the discount_factor is close to 1,
+                        # the reward sum is used to approximate the value prefix for simplification.
+                        value_prefix += reward_list[current_index]  # * config.discount_factor ** (current_index - base_index)
+ target_value_prefixs.append(value_prefix)
+ else:
+ target_values.append(0)
+ target_value_prefixs.append(value_prefix)
+
+ value_index += 1
+
+ batch_value_prefixs.append(target_value_prefixs)
+ batch_target_values.append(target_values)
+
+ batch_value_prefixs = np.asarray(batch_value_prefixs, dtype=object)
+ batch_target_values = np.asarray(batch_target_values, dtype=object)
+
+ return batch_value_prefixs, batch_target_values
+
+ def _compute_target_policy_reanalyzed(self, policy_re_context: List[Any], model: Any) -> np.ndarray:
+ """
+ Overview:
+ prepare policy targets from the reanalyzed context of policies
+ Arguments:
+ - policy_re_context (:obj:`List`): List of policy context to reanalyzed
+ Returns:
+ - batch_target_policies_re
+ """
+ if policy_re_context is None:
+ return []
+ batch_target_policies_re = []
+
+ policy_obs_list, policy_mask, pos_in_game_segment_list, batch_index_list, child_visits, game_segment_lens, action_mask_segment, \
+ to_play_segment = policy_re_context # noqa
+ # transition_batch_size = game_segment_batch_size * (self._cfg.num_unroll_steps + 1)
+ transition_batch_size = len(policy_obs_list)
+ game_segment_batch_size = len(pos_in_game_segment_list)
+
+ to_play, action_mask = self._preprocess_to_play_and_action_mask(
+ game_segment_batch_size, to_play_segment, action_mask_segment, pos_in_game_segment_list
+ )
+ if self._cfg.model.continuous_action_space is True:
+ # when the action space of the environment is continuous, action_mask[:] is None.
+ action_mask = [
+ list(np.ones(self._cfg.model.action_space_size, dtype=np.int8)) for _ in range(transition_batch_size)
+ ]
+ # NOTE: in continuous action space env, we set all legal_actions as -1
+ legal_actions = [
+ [-1 for _ in range(self._cfg.model.action_space_size)] for _ in range(transition_batch_size)
+ ]
+ else:
+ legal_actions = [[i for i, x in enumerate(action_mask[j]) if x == 1] for j in range(transition_batch_size)]
+
+ with torch.no_grad():
+ policy_obs_list = prepare_observation(policy_obs_list, self._cfg.model.model_type)
+ # split a full batch into slices of mini_infer_size: to save the GPU memory for more GPU actors
+ slices = np.ceil(transition_batch_size / self._cfg.mini_infer_size).astype(np.int_)
+ network_output = []
+ for i in range(slices):
+ beg_index = self._cfg.mini_infer_size * i
+ end_index = self._cfg.mini_infer_size * (i + 1)
+ m_obs = torch.from_numpy(policy_obs_list[beg_index:end_index]).to(self._cfg.device).float()
+
+ m_output = model.initial_inference(m_obs)
+
+ if not model.training:
+ # if not in training, obtain the scalars of the value/reward
+ [m_output.latent_state, m_output.value, m_output.policy_logits] = to_detach_cpu_numpy(
+ [
+ m_output.latent_state,
+ inverse_scalar_transform(m_output.value, self._cfg.model.support_scale),
+ m_output.policy_logits
+ ]
+ )
+ m_output.reward_hidden_state = (
+ m_output.reward_hidden_state[0].detach().cpu().numpy(),
+ m_output.reward_hidden_state[1].detach().cpu().numpy()
+ )
+
+ network_output.append(m_output)
+
+ _, value_prefix_pool, policy_logits_pool, latent_state_roots, reward_hidden_state_roots = concat_output(
+ network_output, data_type='efficientzero'
+ )
+
+ value_prefix_pool = value_prefix_pool.squeeze().tolist()
+ policy_logits_pool = policy_logits_pool.tolist()
+ noises = [
+ np.random.dirichlet([self._cfg.root_dirichlet_alpha] * self._cfg.model.num_of_sampled_actions
+ ).astype(np.float32).tolist() for _ in range(transition_batch_size)
+ ]
+ if self._cfg.mcts_ctree:
+ # ==============================================================
+ # sampled related core code
+ # ==============================================================
+ # cpp mcts_tree
+ roots = MCTSCtree.roots(
+ transition_batch_size, legal_actions, self._cfg.model.action_space_size,
+ self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space
+ )
+ roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play)
+ # do MCTS for a new policy with the recent target model
+ MCTSCtree(self._cfg).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play)
+ else:
+ # python mcts_tree
+ roots = MCTSPtree.roots(
+ transition_batch_size, legal_actions, self._cfg.model.action_space_size,
+ self._cfg.model.num_of_sampled_actions, self._cfg.model.continuous_action_space
+ )
+ roots.prepare(self._cfg.root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play)
+ # do MCTS for a new policy with the recent target model
+                MCTSPtree(self._cfg).search(roots, model, latent_state_roots, reward_hidden_state_roots, to_play)
+
+ roots_legal_actions_list = legal_actions
+ roots_distributions = roots.get_distributions()
+
+ # ==============================================================
+ # fix reanalyze in sez
+ # ==============================================================
+ roots_sampled_actions = roots.get_sampled_actions()
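+            # the two tree-search backends return sampled actions in different formats
+            # (objects exposing ``.value`` vs. plain arrays), hence the try/except fallback below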
+ try:
+ root_sampled_actions = np.array([action.value for action in roots_sampled_actions])
+ except Exception:
+ root_sampled_actions = np.array([action for action in roots_sampled_actions])
+
+ policy_index = 0
+ for state_index, game_idx in zip(pos_in_game_segment_list, batch_index_list):
+ target_policies = []
+ for current_index in range(state_index, state_index + self._cfg.num_unroll_steps + 1):
+ distributions = roots_distributions[policy_index]
+ # ==============================================================
+ # sampled related core code
+ # ==============================================================
+ if policy_mask[policy_index] == 0:
+                        # NOTE: for the padded (invalid) positions, the all-zero target policy makes the corresponding cross_entropy_loss = 0
+ target_policies.append([0 for _ in range(self._cfg.model.num_of_sampled_actions)])
+ else:
+ if distributions is None:
+                            # if the MCTS distribution is unavailable at this obs, fall back to a uniform target policy
+ target_policies.append(
+ list(
+ np.ones(self._cfg.model.num_of_sampled_actions) /
+ self._cfg.model.num_of_sampled_actions
+ )
+ )
+ else:
+ if self._cfg.action_type == 'fixed_action_space':
+ sum_visits = sum(distributions)
+ policy = [visit_count / sum_visits for visit_count in distributions]
+ target_policies.append(policy)
+ else:
+ # for two_player board games
+ policy_tmp = [0 for _ in range(self._cfg.model.num_of_sampled_actions)]
+ # to make sure target_policies have the same dimension
+ sum_visits = sum(distributions)
+ policy = [visit_count / sum_visits for visit_count in distributions]
+ for index, legal_action in enumerate(roots_legal_actions_list[policy_index]):
+ policy_tmp[legal_action] = policy[index]
+ target_policies.append(policy_tmp)
+
+ policy_index += 1
+
+ batch_target_policies_re.append(target_policies)
+
+ batch_target_policies_re = np.array(batch_target_policies_re)
+
+ return batch_target_policies_re, root_sampled_actions
+
+ def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) -> None:
+ """
+ Overview:
+ Update the priority of training data.
+ Arguments:
+            - train_data (:obj:`Optional[List[Optional[np.ndarray]]]`): the training data whose priorities will be updated.
+            - batch_priorities (:obj:`Any`): the new priorities to assign.
+ NOTE:
+ train_data = [current_batch, target_batch]
+ current_batch = [obs_list, action_list, root_sampled_actions_list, mask_list, batch_index_list, weights_list, make_time_list]
+ """
+
+ batch_index_list = train_data[0][4]
+ metas = {'make_time': train_data[0][6], 'batch_priorities': batch_priorities}
+ # only update the priorities for data still in replay buffer
+ for i in range(len(batch_index_list)):
+ if metas['make_time'][i] > self.clear_time:
+ idx, prio = batch_index_list[i], metas['batch_priorities'][i]
+ self.game_pos_priorities[idx] = prio
diff --git a/LightZero/lzero/mcts/buffer/game_buffer_stochastic_muzero.py b/LightZero/lzero/mcts/buffer/game_buffer_stochastic_muzero.py
new file mode 100644
index 0000000000000000000000000000000000000000..b2b72ba1d74fc2ec048df9751c799ec183d3fc58
--- /dev/null
+++ b/LightZero/lzero/mcts/buffer/game_buffer_stochastic_muzero.py
@@ -0,0 +1,172 @@
+from typing import Any, Tuple, List
+
+import numpy as np
+from ding.utils import BUFFER_REGISTRY
+
+from lzero.mcts.utils import prepare_observation
+from .game_buffer_muzero import MuZeroGameBuffer
+
+
+@BUFFER_REGISTRY.register('game_buffer_stochastic_muzero')
+class StochasticMuZeroGameBuffer(MuZeroGameBuffer):
+ """
+ Overview:
+ The specific game buffer for Stochastic MuZero policy.
+ """
+
+ def __init__(self, cfg: dict):
+ super().__init__(cfg)
+ """
+ Overview:
+ Use the default configuration mechanism. If a user passes in a cfg with a key that matches an existing key
+ in the default configuration, the user-provided value will override the default configuration. Otherwise,
+ the default configuration will be used.
+ """
+ default_config = self.default_config()
+ default_config.update(cfg)
+ self._cfg = default_config
+ assert self._cfg.env_type in ['not_board_games', 'board_games']
+ assert self._cfg.action_type in ['fixed_action_space', 'varied_action_space']
+ self.replay_buffer_size = self._cfg.replay_buffer_size
+ self.batch_size = self._cfg.batch_size
+ self._alpha = self._cfg.priority_prob_alpha
+ self._beta = self._cfg.priority_prob_beta
+
+ self.keep_ratio = 1
+ self.model_update_interval = 10
+ self.num_of_collected_episodes = 0
+ self.base_idx = 0
+ self.clear_time = 0
+
+ self.game_segment_buffer = []
+ self.game_pos_priorities = []
+ self.game_segment_game_pos_look_up = []
+
+ def _make_batch(self, batch_size: int, reanalyze_ratio: float) -> Tuple[Any]:
+ """
+ Overview:
+ first sample orig_data through ``_sample_orig_data()``,
+ then prepare the context of a batch:
+ reward_value_context: the context of reanalyzed value targets
+ policy_re_context: the context of reanalyzed policy targets
+ policy_non_re_context: the context of non-reanalyzed policy targets
+ current_batch: the inputs of batch
+ Arguments:
+ - batch_size (:obj:`int`): the batch size of orig_data from replay buffer.
+ - reanalyze_ratio (:obj:`float`): ratio of reanalyzed policy (value is 100% reanalyzed)
+ Returns:
+ - context (:obj:`Tuple`): reward_value_context, policy_re_context, policy_non_re_context, current_batch
+ """
+ # obtain the batch context from replay buffer
+ orig_data = self._sample_orig_data(batch_size)
+ game_segment_list, pos_in_game_segment_list, batch_index_list, weights_list, make_time_list = orig_data
+ batch_size = len(batch_index_list)
+ obs_list, action_list, mask_list = [], [], []
+ if self._cfg.use_ture_chance_label_in_chance_encoder:
+ chance_list = []
+ # prepare the inputs of a batch
+ for i in range(batch_size):
+ game = game_segment_list[i]
+ pos_in_game_segment = pos_in_game_segment_list[i]
+
+ actions_tmp = game.action_segment[pos_in_game_segment:pos_in_game_segment +
+ self._cfg.num_unroll_steps].tolist()
+ if self._cfg.use_ture_chance_label_in_chance_encoder:
+ chances_tmp = game.chance_segment[1 + pos_in_game_segment:1 + pos_in_game_segment +
+ self._cfg.num_unroll_steps].tolist()
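+                # NOTE: the chance labels are sliced with an offset of +1 relative to the action slice above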
+ # add mask for invalid actions (out of trajectory)
+ mask_tmp = [1. for i in range(len(actions_tmp))]
+ mask_tmp += [0. for _ in range(self._cfg.num_unroll_steps - len(mask_tmp))]
+
+ # pad random action
+ actions_tmp += [
+ np.random.randint(0, game.action_space_size)
+ for _ in range(self._cfg.num_unroll_steps - len(actions_tmp))
+ ]
+ if self._cfg.use_ture_chance_label_in_chance_encoder:
+ chances_tmp += [
+ np.random.randint(0, game.action_space_size)
+ for _ in range(self._cfg.num_unroll_steps - len(chances_tmp))
+ ]
+ # obtain the input observations
+ # pad if length of obs in game_segment is less than stack+num_unroll_steps
+            # e.g. stack+num_unroll_steps = 4+5
+ obs_list.append(
+ game_segment_list[i].get_unroll_obs(
+ pos_in_game_segment_list[i], num_unroll_steps=self._cfg.num_unroll_steps, padding=True
+ )
+ )
+ action_list.append(actions_tmp)
+ mask_list.append(mask_tmp)
+ if self._cfg.use_ture_chance_label_in_chance_encoder:
+ chance_list.append(chances_tmp)
+
+ # formalize the input observations
+ obs_list = prepare_observation(obs_list, self._cfg.model.model_type)
+
+ # formalize the inputs of a batch
+ if self._cfg.use_ture_chance_label_in_chance_encoder:
+ current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list,
+ chance_list]
+ else:
+ current_batch = [obs_list, action_list, mask_list, batch_index_list, weights_list, make_time_list]
+ for i in range(len(current_batch)):
+ current_batch[i] = np.asarray(current_batch[i])
+
+ total_transitions = self.get_num_of_transitions()
+
+ # obtain the context of value targets
+ reward_value_context = self._prepare_reward_value_context(
+ batch_index_list, game_segment_list, pos_in_game_segment_list, total_transitions
+ )
+ """
+ only reanalyze recent reanalyze_ratio (e.g. 50%) data
+ if self._cfg.reanalyze_outdated is True, batch_index_list is sorted according to its generated env_steps
+ 0: reanalyze_num -> reanalyzed policy, reanalyze_num:end -> non reanalyzed policy
+ """
+ reanalyze_num = int(batch_size * reanalyze_ratio)
+ # reanalyzed policy
+ if reanalyze_num > 0:
+ # obtain the context of reanalyzed policy targets
+ policy_re_context = self._prepare_policy_reanalyzed_context(
+ batch_index_list[:reanalyze_num], game_segment_list[:reanalyze_num],
+ pos_in_game_segment_list[:reanalyze_num]
+ )
+ else:
+ policy_re_context = None
+
+ # non reanalyzed policy
+ if reanalyze_num < batch_size:
+ # obtain the context of non-reanalyzed policy targets
+ policy_non_re_context = self._prepare_policy_non_reanalyzed_context(
+ batch_index_list[reanalyze_num:], game_segment_list[reanalyze_num:],
+ pos_in_game_segment_list[reanalyze_num:]
+ )
+ else:
+ policy_non_re_context = None
+
+ context = reward_value_context, policy_re_context, policy_non_re_context, current_batch
+ return context
+
+ def update_priority(self, train_data: List[np.ndarray], batch_priorities: Any) -> None:
+ """
+ Overview:
+ Update the priority of training data.
+ Arguments:
+            - train_data (:obj:`Optional[List[Optional[np.ndarray]]]`): the training data whose priorities will be updated.
+            - batch_priorities (:obj:`Any`): the new priorities to assign.
+ NOTE:
+ train_data = [current_batch, target_batch]
+ if self._cfg.use_ture_chance_label_in_chance_encoder:
+ obs_batch_orig, action_batch, mask_batch, indices, weights, make_time, chance_batch = current_batch
+ else:
+ obs_batch_orig, action_batch, mask_batch, indices, weights, make_time = current_batch
+
+ """
+ indices = train_data[0][3]
+ metas = {'make_time': train_data[0][5], 'batch_priorities': batch_priorities}
+ # only update the priorities for data still in replay buffer
+ for i in range(len(indices)):
+ if metas['make_time'][i] > self.clear_time:
+ idx, prio = indices[i], metas['batch_priorities'][i]
+ self.game_pos_priorities[idx] = prio
diff --git a/LightZero/lzero/mcts/buffer/game_segment.py b/LightZero/lzero/mcts/buffer/game_segment.py
new file mode 100644
index 0000000000000000000000000000000000000000..ae926092168d006d07554b6274967ac9795e0dde
--- /dev/null
+++ b/LightZero/lzero/mcts/buffer/game_segment.py
@@ -0,0 +1,334 @@
+import copy
+from typing import List, Tuple
+
+import numpy as np
+from easydict import EasyDict
+
+from ding.utils.compression_helper import jpeg_data_decompressor
+
+
+class GameSegment:
+ """
+ Overview:
+ A game segment from a full episode trajectory.
+
+ The length of one episode in (Atari) games is often quite large. This class represents a single game segment
+ within a larger trajectory, split into several blocks.
+
+ Interfaces:
+ - __init__
+ - __len__
+ - reset
+ - pad_over
+ - is_full
+ - legal_actions
+ - append
+ - get_observation
+ - zero_obs
+ - step_obs
+ - get_targets
+ - game_segment_to_array
+ - store_search_stats
+ """
+
+ def __init__(self, action_space: int, game_segment_length: int = 200, config: EasyDict = None) -> None:
+ """
+ Overview:
+ Init the ``GameSegment`` according to the provided arguments.
+ Arguments:
+            - action_space (:obj:`int`): the action space of the environment
+ - game_segment_length (:obj:`int`): the transition number of one ``GameSegment`` block
+ """
+ self.action_space = action_space
+ self.game_segment_length = game_segment_length
+ self.num_unroll_steps = config.num_unroll_steps
+ self.td_steps = config.td_steps
+ self.frame_stack_num = config.model.frame_stack_num
+ self.discount_factor = config.discount_factor
+ self.action_space_size = config.model.action_space_size
+ self.gray_scale = config.gray_scale
+ self.transform2string = config.transform2string
+ self.sampled_algo = config.sampled_algo
+ self.gumbel_algo = config.gumbel_algo
+ self.use_ture_chance_label_in_chance_encoder = config.use_ture_chance_label_in_chance_encoder
+
+ if isinstance(config.model.observation_shape, int) or len(config.model.observation_shape) == 1:
+ # for vector obs input, e.g. classical control and box2d environments
+ self.zero_obs_shape = config.model.observation_shape
+ elif len(config.model.observation_shape) == 3:
+ # image obs input, e.g. atari environments
+ self.zero_obs_shape = (
+ config.model.observation_shape[-2], config.model.observation_shape[-1], config.model.image_channel
+ )
+
+ self.obs_segment = []
+ self.action_segment = []
+ self.reward_segment = []
+
+ self.child_visit_segment = []
+ self.root_value_segment = []
+
+ self.action_mask_segment = []
+ self.to_play_segment = []
+
+ self.target_values = []
+ self.target_rewards = []
+ self.target_policies = []
+
+ self.improved_policy_probs = []
+
+ if self.sampled_algo:
+ self.root_sampled_actions = []
+ if self.use_ture_chance_label_in_chance_encoder:
+ self.chance_segment = []
+
+
+ def get_unroll_obs(self, timestep: int, num_unroll_steps: int = 0, padding: bool = False) -> np.ndarray:
+ """
+ Overview:
+ Get an observation slice in the correct format: o[t: t + frame_stack_num + num_unroll_steps].
+ Arguments:
+ - timestep (int): The time step.
+ - num_unroll_steps (int): The extra length of the observation frames.
+ - padding (bool): If True, pad frames if (t + stack frames) is outside of the trajectory.
+ """
+ stacked_obs = self.obs_segment[timestep:timestep + self.frame_stack_num + num_unroll_steps]
+ if padding:
+ pad_len = self.frame_stack_num + num_unroll_steps - len(stacked_obs)
+ if pad_len > 0:
+ pad_frames = np.array([stacked_obs[-1] for _ in range(pad_len)])
+ stacked_obs = np.concatenate((stacked_obs, pad_frames))
+ if self.transform2string:
+ stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs]
+ return stacked_obs
+
+ def zero_obs(self) -> List:
+ """
+ Overview:
+ Return a list of ``frame_stack_num`` observation frames filled with zeros.
+ Returns:
+ - List: a list of ``frame_stack_num`` zero-filled frames with shape ``zero_obs_shape``.
+ """
+ return [np.zeros(self.zero_obs_shape, dtype=np.float32) for _ in range(self.frame_stack_num)]
+
+ def get_obs(self) -> List:
+ """
+ Overview:
+ Return an observation in the correct format for model inference.
+ Returns:
+ stacked_obs (List): An observation in the correct format for model inference.
+ """
+ timestep_obs = len(self.obs_segment) - self.frame_stack_num
+ timestep_reward = len(self.reward_segment)
+ assert timestep_obs == timestep_reward, "timestep_obs: {}, timestep_reward: {}".format(
+ timestep_obs, timestep_reward
+ )
+ timestep = timestep_reward
+ stacked_obs = self.obs_segment[timestep:timestep + self.frame_stack_num]
+ if self.transform2string:
+ stacked_obs = [jpeg_data_decompressor(obs, self.gray_scale) for obs in stacked_obs]
+ return stacked_obs
+
+ def append(
+ self,
+ action: np.ndarray,
+ obs: np.ndarray,
+ reward: np.ndarray,
+ action_mask: np.ndarray = None,
+ to_play: int = -1,
+ chance: int = 0,
+ ) -> None:
+ """
+ Overview:
+ Append a transition tuple, including a_t, o_{t+1}, r_{t}, action_mask_{t}, to_play_{t}.
+ """
+ self.action_segment.append(action)
+ self.obs_segment.append(obs)
+ self.reward_segment.append(reward)
+
+ self.action_mask_segment.append(action_mask)
+ self.to_play_segment.append(to_play)
+ if self.use_ture_chance_label_in_chance_encoder:
+ self.chance_segment.append(chance)
+
+ def pad_over(
+ self, next_segment_observations: List, next_segment_rewards: List, next_segment_root_values: List,
+ next_segment_child_visits: List, next_segment_improved_policy: List = None, next_chances: List = None,
+ ) -> None:
+ """
+ Overview:
+ To ensure the correctness of the value targets, we need to append (o_t, r_t, etc.) from the next game_segment,
+ which is required to bootstrap the values at the end states of the previous game_segment.
+ e.g.: len = 100; target value v_100 = r_100 + gamma^1 r_101 + ... + gamma^4 r_104 + gamma^5 v_105,
+ where r_101, r_102, ... come from the next game_segment.
+ Arguments:
+ - next_segment_observations (:obj:`list`): o_t from the next game_segment
+ - next_segment_rewards (:obj:`list`): r_t from the next game_segment
+ - next_segment_root_values (:obj:`list`): root values of MCTS from the next game_segment
+ - next_segment_child_visits (:obj:`list`): root visit count distributions of MCTS from the next game_segment
+ - next_segment_improved_policy (:obj:`list`): root children select policy of MCTS from the next game_segment (Only used in Gumbel MuZero)
+ """
+ assert len(next_segment_observations) <= self.num_unroll_steps
+ assert len(next_segment_child_visits) <= self.num_unroll_steps
+ assert len(next_segment_root_values) <= self.num_unroll_steps + self.td_steps
+ assert len(next_segment_rewards) <= self.num_unroll_steps + self.td_steps - 1
+ # ==============================================================
+ # The core difference between GumbelMuZero and MuZero
+ # ==============================================================
+ if self.gumbel_algo:
+ assert len(next_segment_improved_policy) <= self.num_unroll_steps + self.td_steps
+
+ # NOTE: next block observation should start from (stacked_observation - 1) in next trajectory
+ for observation in next_segment_observations:
+ self.obs_segment.append(copy.deepcopy(observation))
+
+ for reward in next_segment_rewards:
+ self.reward_segment.append(reward)
+
+ for value in next_segment_root_values:
+ self.root_value_segment.append(value)
+
+ for child_visits in next_segment_child_visits:
+ self.child_visit_segment.append(child_visits)
+
+ if self.gumbel_algo:
+ for improved_policy in next_segment_improved_policy:
+ self.improved_policy_probs.append(improved_policy)
+ if self.use_ture_chance_label_in_chance_encoder:
+ for chances in next_chances:
+ self.chance_segment.append(chances)
+
+ def get_targets(self, timestep: int) -> Tuple:
+ """
+ Overview:
+ Return the value/reward/policy targets at step ``timestep``.
+ """
+ return self.target_values[timestep], self.target_rewards[timestep], self.target_policies[timestep]
+
+ def store_search_stats(
+ self, visit_counts: List, root_value: List, root_sampled_actions: List = None, improved_policy: List = None, idx: int = None
+ ) -> None:
+ """
+ Overview:
+ Store the visit count distribution and the value of the root node after MCTS.
+ """
+ sum_visits = sum(visit_counts)
+ if idx is None:
+ self.child_visit_segment.append([visit_count / sum_visits for visit_count in visit_counts])
+ self.root_value_segment.append(root_value)
+ if self.sampled_algo:
+ self.root_sampled_actions.append(root_sampled_actions)
+ # store the improved policy in Gumbel MuZero: \pi' = softmax(logits + \sigma(CompletedQ))
+ if self.gumbel_algo:
+ self.improved_policy_probs.append(improved_policy)
+ else:
+ self.child_visit_segment[idx] = [visit_count / sum_visits for visit_count in visit_counts]
+ self.root_value_segment[idx] = root_value
+ self.improved_policy_probs[idx] = improved_policy
+
+ def game_segment_to_array(self) -> None:
+ """
+ Overview:
+ Post-process the data when a `GameSegment` block is full. This function converts various game segment
+ elements into numpy arrays for easier manipulation and processing.
+ Structure:
+ The structure and shapes of different game segment elements are as follows. Let's assume
+ `game_segment_length`=20, `stack`=4, `num_unroll_steps`=5, `td_steps`=5:
+
+ - obs: game_segment_length + stack + num_unroll_steps, 20+4+5
+ - action: game_segment_length -> 20
+ - reward: game_segment_length + num_unroll_steps + td_steps - 1 -> 20+5+5-1
+ - root_values: game_segment_length + num_unroll_steps + td_steps -> 20+5+5
+ - child_visits: game_segment_length + num_unroll_steps -> 20+5
+ - to_play: game_segment_length -> 20
+ - action_mask: game_segment_length -> 20
+ Examples:
+ Here is an illustration of the structure of `obs` and `rew` for two consecutive game segments
+ (game_segment_i and game_segment_i+1):
+
+ - game_segment_i (obs): 4 20 5
+ ----|----...----|-----|
+ - game_segment_i+1 (obs): 4 20 5
+ ----|----...----|-----|
+
+ - game_segment_i (rew): 20 5 4
+ ----...----|------|-----|
+ - game_segment_i+1 (rew): 20 5 4
+ ----...----|------|-----|
+
+ Postprocessing:
+ - self.obs_segment (:obj:`numpy.ndarray`): A numpy array version of the original obs_segment.
+ - self.action_segment (:obj:`numpy.ndarray`): A numpy array version of the original action_segment.
+ - self.reward_segment (:obj:`numpy.ndarray`): A numpy array version of the original reward_segment.
+ - self.child_visit_segment (:obj:`numpy.ndarray`): A numpy array version of the original child_visit_segment.
+ - self.root_value_segment (:obj:`numpy.ndarray`): A numpy array version of the original root_value_segment.
+ - self.improved_policy_probs (:obj:`numpy.ndarray`): A numpy array version of the original improved_policy_probs.
+ - self.action_mask_segment (:obj:`numpy.ndarray`): A numpy array version of the original action_mask_segment.
+ - self.to_play_segment (:obj:`numpy.ndarray`): A numpy array version of the original to_play_segment.
+ - self.chance_segment (:obj:`numpy.ndarray`, optional): A numpy array version of the original chance_segment. Only
+ created if `self.use_ture_chance_label_in_chance_encoder` is True.
+
+ .. note::
+ For environments with a variable action space, such as board games, the elements in `child_visit_segment` may have
+ different lengths. In such scenarios, it is necessary to use the object data type for `self.child_visit_segment`.
+ """
+ self.obs_segment = np.array(self.obs_segment)
+ self.action_segment = np.array(self.action_segment)
+ self.reward_segment = np.array(self.reward_segment)
+
+ # Check if all elements in self.child_visit_segment have the same length
+ if all(len(x) == len(self.child_visit_segment[0]) for x in self.child_visit_segment):
+ self.child_visit_segment = np.array(self.child_visit_segment)
+ else:
+ # In the case of environments with a variable action space, such as board games,
+ # the elements in child_visit_segment may have different lengths.
+ # In such scenarios, it is necessary to use the object data type.
+ self.child_visit_segment = np.array(self.child_visit_segment, dtype=object)
+
+ self.root_value_segment = np.array(self.root_value_segment)
+ self.improved_policy_probs = np.array(self.improved_policy_probs)
+
+ self.action_mask_segment = np.array(self.action_mask_segment)
+ self.to_play_segment = np.array(self.to_play_segment)
+ if self.use_ture_chance_label_in_chance_encoder:
+ self.chance_segment = np.array(self.chance_segment)
+
+ def reset(self, init_observations: np.ndarray) -> None:
+ """
+ Overview:
+ Initialize the game segment with ``init_observations``,
+ which are the ``frame_stack_num`` stacked frames from the previous time steps.
+ Arguments:
+ - init_observations (:obj:`list`): list of the stack observations in the previous time steps.
+ """
+ self.obs_segment = []
+ self.action_segment = []
+ self.reward_segment = []
+
+ self.child_visit_segment = []
+ self.root_value_segment = []
+
+ self.action_mask_segment = []
+ self.to_play_segment = []
+ if self.use_ture_chance_label_in_chance_encoder:
+ self.chance_segment = []
+
+ assert len(init_observations) == self.frame_stack_num
+
+ for observation in init_observations:
+ self.obs_segment.append(copy.deepcopy(observation))
+
+ def is_full(self) -> bool:
+ """
+ Overview:
+ Check whether the current game segment is full, i.e. its length has reached ``game_segment_length``.
+ Returns:
+ bool: True if the game segment is full, False otherwise.
+ """
+ return len(self.action_segment) >= self.game_segment_length
+
+ def legal_actions(self):
+ return [_ for _ in range(self.action_space.n)]
+
+ def __len__(self):
+ return len(self.action_segment)
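As a quick orientation to the class above, here is a minimal usage sketch. It assumes an `EasyDict` config exposing exactly the fields read in `__init__`; all concrete values (shapes, lengths, visit counts) are placeholders.

```python
# Minimal usage sketch for GameSegment; all values below are placeholders.
import numpy as np
from easydict import EasyDict

from lzero.mcts.buffer.game_segment import GameSegment

cfg = EasyDict(dict(
    num_unroll_steps=5,
    td_steps=5,
    discount_factor=0.997,
    gray_scale=False,
    transform2string=False,
    sampled_algo=False,
    gumbel_algo=False,
    use_ture_chance_label_in_chance_encoder=False,
    model=dict(frame_stack_num=4, action_space_size=6, observation_shape=8),
))

segment = GameSegment(action_space=6, game_segment_length=20, config=cfg)

# Reset with frame_stack_num initial (vector) observations.
segment.reset([np.zeros(8, dtype=np.float32) for _ in range(cfg.model.frame_stack_num)])

# Append transitions and the corresponding MCTS statistics until the block is full.
while not segment.is_full():
    segment.append(
        action=np.int64(0),
        obs=np.zeros(8, dtype=np.float32),
        reward=np.float32(0.0),
        action_mask=np.ones(6, dtype=np.int8),
        to_play=-1,
    )
    segment.store_search_stats(visit_counts=[5, 3, 2, 0, 0, 0], root_value=0.1)

# Convert the accumulated Python lists into numpy arrays.
segment.game_segment_to_array()
print(len(segment), segment.obs_segment.shape)  # -> 20, (24, 8)
```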
diff --git a/LightZero/lzero/mcts/ctree/__init__.py b/LightZero/lzero/mcts/ctree/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/LightZero/lzero/mcts/ctree/ctree_alphazero/CMakeLists.txt b/LightZero/lzero/mcts/ctree/ctree_alphazero/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..fbac7e0b9872c6063f0a48c521a1778c2e7557cd
--- /dev/null
+++ b/LightZero/lzero/mcts/ctree/ctree_alphazero/CMakeLists.txt
@@ -0,0 +1,32 @@
+# Declare the minimum version of CMake that can be used
+# To understand and build the project
+cmake_minimum_required(VERSION 3.4...3.18)
+
+# Set the project name to mcts_alphazero and set the version to 1.0
+project(mcts_alphazero VERSION 1.0)
+
+# Find and get the details of Python package
+# This is required for embedding Python in the project
+find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
+
+# Add pybind11 as a subdirectory,
+# so that its build files are generated alongside the current project.
+# This is necessary because the current project depends on pybind11
+add_subdirectory(pybind11)
+
+# Add two .cpp files to the mcts_alphazero module
+# These files are compiled and linked into the module
+pybind11_add_module(mcts_alphazero mcts_alphazero.cpp node_alphazero.cpp)
+
+# Add the Python header file paths to the include paths
+# of the mcts_alphazero library. This is necessary for the
+# project to find the Python header files it needs to include
+target_include_directories(mcts_alphazero PRIVATE ${Python3_INCLUDE_DIRS})
+
+# Link the mcts_alphazero library with the pybind11::module target.
+# This is necessary for the mcts_alphazero library to use the functions and classes defined by pybind11
+target_link_libraries(mcts_alphazero PRIVATE pybind11::module)
+
+# Set the Python standard to the version of Python found by find_package(Python3)
+# This ensures that the code will be compiled against the correct version of Python
+set_target_properties(mcts_alphazero PROPERTIES PYTHON_STANDARD ${Python3_VERSION})
\ No newline at end of file
diff --git a/LightZero/lzero/mcts/ctree/ctree_alphazero/CMakeLists_mcts.txt b/LightZero/lzero/mcts/ctree/ctree_alphazero/CMakeLists_mcts.txt
new file mode 100644
index 0000000000000000000000000000000000000000..884570116acb0b54824d8dcb541174219550835d
--- /dev/null
+++ b/LightZero/lzero/mcts/ctree/ctree_alphazero/CMakeLists_mcts.txt
@@ -0,0 +1,29 @@
+# Declare the minimum version of CMake that can be used
+# To understand and build the project
+cmake_minimum_required(VERSION 3.4...3.18)
+
+# Set the project name to mcts_alphazero and set the version to 1.0
+project(mcts_alphazero VERSION 1.0)
+
+# Find and get the details of Python package
+# This is required for embedding Python in the project
+find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
+
+# Add pybind11 as a subdirectory,
+# so that its build files are generated alongside the current project.
+# This is necessary because the current project depends on pybind11
+add_subdirectory(pybind11)
+pybind11_add_module(mcts_alphazero mcts_alphazero.cpp)
+
+# Add the Python header file paths to the include paths
+# of the mcts_alphazero library. This is necessary for the
+# project to find the Python header files it needs to include
+target_include_directories(mcts_alphazero PRIVATE ${Python3_INCLUDE_DIRS})
+
+# Link the mcts_alphazero library with the pybind11::module target.
+# This is necessary for the mcts_alphazero library to use the functions and classes defined by pybind11
+target_link_libraries(mcts_alphazero PRIVATE pybind11::module)
+
+# Set the Python standard to the version of Python found by find_package(Python3)
+# This ensures that the code will be compiled against the correct version of Python
+set_target_properties(mcts_alphazero PROPERTIES PYTHON_STANDARD ${Python3_VERSION})
\ No newline at end of file
diff --git a/LightZero/lzero/mcts/ctree/ctree_alphazero/CMakeLists_node.txt b/LightZero/lzero/mcts/ctree/ctree_alphazero/CMakeLists_node.txt
new file mode 100644
index 0000000000000000000000000000000000000000..b1b59d80972e578b4e980ea50aab456ce8cd82bb
--- /dev/null
+++ b/LightZero/lzero/mcts/ctree/ctree_alphazero/CMakeLists_node.txt
@@ -0,0 +1,29 @@
+# Declare the minimum version of CMake that can be used
+# To understand and build the project
+cmake_minimum_required(VERSION 3.4...3.18)
+
+# Set the project name to node_alphazero and set the version to 1.0
+project(node_alphazero VERSION 1.0)
+
+# Find and get the details of Python package
+# This is required for embedding Python in the project
+find_package(Python3 COMPONENTS Interpreter Development REQUIRED)
+
+# Add pybind11 as a subdirectory,
+# so that its build files are generated alongside the current project.
+# This is necessary because the current project depends on pybind11
+add_subdirectory(pybind11)
+pybind11_add_module(node_alphazero node_alphazero.cpp)
+
+# Add the Python header file paths to the include paths
+# of the node_alphazero library. This is necessary for the
+# project to find the Python header files it needs to include
+target_include_directories(node_alphazero PRIVATE ${Python3_INCLUDE_DIRS})
+
+# Link the node_alphazero library with the pybind11::module target.
+# This is necessary for the node_alphazero library to use the functions and classes defined by pybind11
+target_link_libraries(node_alphazero PRIVATE pybind11::module)
+
+# Set the Python standard to the version of Python found by find_package(Python3)
+# This ensures that the code will be compiled against the correct version of Python
+set_target_properties(node_alphazero PROPERTIES PYTHON_STANDARD ${Python3_VERSION})
\ No newline at end of file
diff --git a/LightZero/lzero/mcts/ctree/ctree_alphazero/make.sh b/LightZero/lzero/mcts/ctree/ctree_alphazero/make.sh
new file mode 100644
index 0000000000000000000000000000000000000000..1f64b5c541ba809bb615cb214d0715ec48103d7e
--- /dev/null
+++ b/LightZero/lzero/mcts/ctree/ctree_alphazero/make.sh
@@ -0,0 +1,22 @@
+"""
+This script compiles the ctree_alphazero project. The compiled files are stored in the "build" directory.
+
+In summary, this script automates the process of creating a new build directory, navigating into it,
+running cmake to generate build files suitable for the arm64 architecture, and running make to compile the project.
+"""
+
+# Navigate to the project directory
+cd /Users//code/LightZero/lzero/mcts/ctree/ctree_alphazero/
+
+# Create a new directory named "build." The build directory is where the compiled files will be stored.
+mkdir build
+
+# Navigate into the "build" directory
+cd build
+
+# Run cmake on the parent directory. The ".." refers to the parent directory of the current directory.
+# The -DCMAKE_OSX_ARCHITECTURES="arm64" flag specifies that the generated build files should be suitable for the arm64 architecture.
+cmake .. -DCMAKE_OSX_ARCHITECTURES="arm64"
+
+# Run the "make" command. This command uses the files generated by cmake to compile the project.
+make
\ No newline at end of file
diff --git a/LightZero/lzero/mcts/ctree/ctree_alphazero/test/eval_mcts_alphazero.py b/LightZero/lzero/mcts/ctree/ctree_alphazero/test/eval_mcts_alphazero.py
new file mode 100644
index 0000000000000000000000000000000000000000..da7432d1dd27077730b97c3b48120ce6c0a165f3
--- /dev/null
+++ b/LightZero/lzero/mcts/ctree/ctree_alphazero/test/eval_mcts_alphazero.py
@@ -0,0 +1,34 @@
+"""
+This is an illustrative example of Python interfacing with a MCTS (Monte Carlo Tree Search) object implemented in C++.
+Please note that this code is not designed for actual execution.
+It serves as a conceptual demonstration, providing an understanding of how Python can interact with C++ objects,
+specifically within the context of MCTS.
+"""
+import sys
+from typing import Dict, Tuple
+
+import numpy as np
+import torch
+
+sys.path.append('./LightZero/lzero/mcts/ctree/ctree_alphazero/build')
+
+import mcts_alphazero
+mcts_alphazero = mcts_alphazero.MCTS()
+
+def _policy_value_fn(self, env: 'Env') -> Tuple[Dict[int, np.ndarray], float]: # noqa
+ legal_actions = env.legal_actions
+ current_state, current_state_scale = env.current_state()
+ current_state_scale = torch.from_numpy(current_state_scale).to(
+ device=self._device, dtype=torch.float
+ ).unsqueeze(0)
+ with torch.no_grad():
+ action_probs, value = self._policy_model.compute_policy_value(current_state_scale)
+ action_probs_dict = dict(zip(legal_actions, action_probs.squeeze(0)[legal_actions].detach().cpu().numpy()))
+ return action_probs_dict, value.item()
+
+action, mcts_probs = mcts_alphazero.get_next_action(
+ simulate_env=simulate_env,
+ policy_value_func=_policy_value_fn,
+ temperature=1,
+ sample=True,
+)
+
+print(action, mcts_probs)
\ No newline at end of file
diff --git a/LightZero/lzero/mcts/ctree/ctree_alphazero/test/eval_node_alphazero.py b/LightZero/lzero/mcts/ctree/ctree_alphazero/test/eval_node_alphazero.py
new file mode 100644
index 0000000000000000000000000000000000000000..52a78144d5b52bd321d9d8a6c5fdcf139b30a9d6
--- /dev/null
+++ b/LightZero/lzero/mcts/ctree/ctree_alphazero/test/eval_node_alphazero.py
@@ -0,0 +1,9 @@
+import sys
+sys.path.append('./LightZero/lzero/mcts/ctree/ctree_alphazero/build')
+
+import mcts_alphazero
+n = mcts_alphazero.Node()
+print(n.is_leaf())
+print(n.update(5.0))
+# print(n.value())
+print(n)
diff --git a/LightZero/lzero/mcts/ctree/ctree_efficientzero/__init__.py b/LightZero/lzero/mcts/ctree/ctree_efficientzero/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/LightZero/lzero/mcts/ctree/ctree_efficientzero/ez_tree.pxd b/LightZero/lzero/mcts/ctree/ctree_efficientzero/ez_tree.pxd
new file mode 100644
index 0000000000000000000000000000000000000000..9151993498781d451a3f3c8054d69c536755d688
--- /dev/null
+++ b/LightZero/lzero/mcts/ctree/ctree_efficientzero/ez_tree.pxd
@@ -0,0 +1,97 @@
+# distutils:language=c++
+# cython:language_level=3
+from libcpp.vector cimport vector
+
+
+cdef extern from "../common_lib/cminimax.cpp":
+ pass
+
+
+cdef extern from "../common_lib/cminimax.h" namespace "tools":
+ cdef cppclass CMinMaxStats:
+ CMinMaxStats() except +
+ float maximum, minimum, value_delta_max
+
+ void set_delta(float value_delta_max)
+ void update(float value)
+ void clear()
+ float normalize(float value)
+
+ cdef cppclass CMinMaxStatsList:
+ CMinMaxStatsList() except +
+ CMinMaxStatsList(int num) except +
+ int num
+ vector[CMinMaxStats] stats_lst
+
+ void set_delta(float value_delta_max)
+
+cdef extern from "lib/cnode.cpp":
+ pass
+
+
+cdef extern from "lib/cnode.h" namespace "tree":
+ cdef cppclass CNode:
+ CNode() except +
+ CNode(float prior, vector[int] & legal_actions) except +
+ int visit_count, to_play, current_latent_state_index, batch_index, best_action
+ float value_prefixs, prior, value_sum, parent_value_prefix
+
+ void expand(int to_play, int current_latent_state_index, int batch_index, float value_prefixs,
+ vector[float] policy_logits)
+ void add_exploration_noise(float exploration_fraction, vector[float] noises)
+ float compute_mean_q(int isRoot, float parent_q, float discount_factor)
+
+ int expanded()
+ float value()
+ vector[int] get_trajectory()
+ vector[int] get_children_distribution()
+ CNode * get_child(int action)
+
+ cdef cppclass CRoots:
+ CRoots() except +
+ CRoots(int root_num, vector[vector[int]] legal_actions_list) except +
+ int root_num
+ vector[CNode] roots
+
+ void prepare(float root_noise_weight, const vector[vector[float]] & noises,
+ const vector[float] & value_prefixs, const vector[vector[float]] & policies,
+ vector[int] to_play_batch)
+ void prepare_no_noise(const vector[float] & value_prefixs, const vector[vector[float]] & policies,
+ vector[int] to_play_batch)
+ void clear()
+ vector[vector[int]] get_trajectories()
+ vector[vector[int]] get_distributions()
+ vector[float] get_values()
+ # visualize related code
+ # CNode* get_root(int index)
+
+ cdef cppclass CSearchResults:
+ CSearchResults() except +
+ CSearchResults(int num) except +
+ int num
+ vector[int] latent_state_index_in_search_path, latent_state_index_in_batch, last_actions, search_lens
+ vector[int] virtual_to_play_batchs
+ vector[CNode *] nodes
+
+ cdef void cbackpropagate(vector[CNode *] & search_path, CMinMaxStats & min_max_stats,
+ int to_play, float value, float discount_factor)
+ void cbatch_backpropagate(int current_latent_state_index, float discount_factor, vector[float] value_prefixs,
+ vector[float] values, vector[vector[float]] policies,
+ CMinMaxStatsList *min_max_stats_lst, CSearchResults & results,
+ vector[int] is_reset_list, vector[int] & to_play_batch)
+ void cbatch_traverse(CRoots *roots, int pb_c_base, float pb_c_init, float discount_factor,
+ CMinMaxStatsList *min_max_stats_lst, CSearchResults & results,
+ vector[int] & virtual_to_play_batch)
+
+cdef class MinMaxStatsList:
+ cdef CMinMaxStatsList *cmin_max_stats_lst
+
+cdef class ResultsWrapper:
+ cdef CSearchResults cresults
+
+cdef class Roots:
+ cdef readonly int root_num
+ cdef CRoots *roots
+
+cdef class Node:
+ cdef CNode cnode
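The `.pxd` file above only declares the C++ interface; it is compiled together with the `ez_tree.pyx` implementation that follows. Below is a rough build sketch using Cython and setuptools. LightZero ships its own build tooling, so the module name, source paths, and compile flags here are assumptions for illustration only.

```python
# Hypothetical build sketch for the ez_tree extension (Cython + setuptools).
from setuptools import Extension, setup
from Cython.Build import cythonize

ext = Extension(
    name="ez_tree",
    sources=["ez_tree.pyx"],            # cnode.cpp / cminimax.cpp are included via the `cdef extern from` blocks
    language="c++",
    extra_compile_args=["-std=c++11"],
)

# language_level=3 is already requested by the "# cython:language_level=3" header of the .pyx/.pxd files.
setup(ext_modules=cythonize([ext]))
```

Running `python setup.py build_ext --inplace` from the package directory would then produce the importable extension next to the sources.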
diff --git a/LightZero/lzero/mcts/ctree/ctree_efficientzero/ez_tree.pyx b/LightZero/lzero/mcts/ctree/ctree_efficientzero/ez_tree.pyx
new file mode 100644
index 0000000000000000000000000000000000000000..8149f8569d50a7be5fbbc9ca4827be6f1779243a
--- /dev/null
+++ b/LightZero/lzero/mcts/ctree/ctree_efficientzero/ez_tree.pyx
@@ -0,0 +1,100 @@
+# distutils:language=c++
+# cython:language_level=3
+import cython
+from libcpp.vector cimport vector
+
+cdef class MinMaxStatsList:
+ @cython.binding
+ def __cinit__(self, int num):
+ self.cmin_max_stats_lst = new CMinMaxStatsList(num)
+
+ @cython.binding
+ def set_delta(self, float value_delta_max):
+ self.cmin_max_stats_lst[0].set_delta(value_delta_max)
+
+ def __dealloc__(self):
+ del self.cmin_max_stats_lst
+
+cdef class ResultsWrapper:
+ @cython.binding
+ def __cinit__(self, int num):
+ self.cresults = CSearchResults(num)
+
+ @cython.binding
+ def get_search_len(self):
+ return self.cresults.search_lens
+
+cdef class Roots:
+ @cython.binding
+ def __cinit__(self, int root_num, vector[vector[int]] legal_actions_list):
+ self.root_num = root_num
+ self.roots = new CRoots(root_num, legal_actions_list)
+
+ @cython.binding
+ def prepare(self, float root_noise_weight, list noises, list value_prefix_pool,
+ list policy_logits_pool, vector[int] & to_play_batch):
+ self.roots[0].prepare(root_noise_weight, noises, value_prefix_pool, policy_logits_pool, to_play_batch)
+
+ @cython.binding
+ def prepare_no_noise(self, list value_prefix_pool, list policy_logits_pool, vector[int] & to_play_batch):
+ self.roots[0].prepare_no_noise(value_prefix_pool, policy_logits_pool, to_play_batch)
+
+ @cython.binding
+ def get_trajectories(self):
+ return self.roots[0].get_trajectories()
+
+ @cython.binding
+ def get_distributions(self):
+ return self.roots[0].get_distributions()
+
+ @cython.binding
+ def get_values(self):
+ return self.roots[0].get_values()
+
+ # visualize related code
+ #def get_root(self, int index):
+ # return self.roots[index]
+
+ @cython.binding
+ def clear(self):
+ self.roots[0].clear()
+
+ def __dealloc__(self):
+ del self.roots
+
+ @property
+ def num(self):
+ return self.root_num
+
+cdef class Node:
+ def __cinit__(self):
+ pass
+
+ def __cinit__(self, float prior, vector[int] & legal_actions):
+ pass
+
+ @cython.binding
+ def expand(self, int to_play, int current_latent_state_index, int batch_index, float value_prefix,
+ list policy_logits):
+ cdef vector[float] cpolicy = policy_logits
+ self.cnode.expand(to_play, current_latent_state_index, batch_index, value_prefix, cpolicy)
+
+@cython.binding
+def batch_backpropagate(int current_latent_state_index, float discount_factor, list value_prefixs, list values, list policies,
+ MinMaxStatsList min_max_stats_lst, ResultsWrapper results, list is_reset_list,
+ list to_play_batch):
+ cdef int i
+ cdef vector[float] cvalue_prefixs = value_prefixs
+ cdef vector[float] cvalues = values
+ cdef vector[vector[float]] cpolicies = policies
+
+ cbatch_backpropagate(current_latent_state_index, discount_factor, cvalue_prefixs, cvalues, cpolicies,
+ min_max_stats_lst.cmin_max_stats_lst, results.cresults, is_reset_list, to_play_batch)
+
+@cython.binding
+def batch_traverse(Roots roots, int pb_c_base, float pb_c_init, float discount_factor, MinMaxStatsList min_max_stats_lst,
+ ResultsWrapper results, list virtual_to_play_batch):
+ cbatch_traverse(roots.roots, pb_c_base, pb_c_init, discount_factor, min_max_stats_lst.cmin_max_stats_lst,
+ results.cresults, virtual_to_play_batch)
+
+ return results.cresults.latent_state_index_in_search_path, results.cresults.latent_state_index_in_batch, results.cresults.last_actions, results.cresults.virtual_to_play_batchs
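To illustrate how the wrapper above is typically driven from Python, here is a hedged sketch of an EfficientZero-style search loop. Only the tree API (`Roots`, `MinMaxStatsList`, `ResultsWrapper`, `batch_traverse`, `batch_backpropagate`) mirrors `ez_tree.pyx`; the import path, the network outputs, and the hyperparameter values are placeholders.

```python
# Hedged sketch of an EfficientZero-style search loop driving the bindings above.
import numpy as np

from lzero.mcts.ctree.ctree_efficientzero import ez_tree  # assumed import path of the compiled module

batch_size, action_space, num_simulations = 8, 6, 25
legal_actions = [list(range(action_space)) for _ in range(batch_size)]

roots = ez_tree.Roots(batch_size, legal_actions)
noises = [np.random.dirichlet([0.3] * action_space).tolist() for _ in range(batch_size)]
value_prefix_pool = [0.0] * batch_size                      # would come from the prediction network
policy_logits_pool = [[0.0] * action_space for _ in range(batch_size)]
to_play = [-1] * batch_size                                 # single-player placeholder

roots.prepare(0.25, noises, value_prefix_pool, policy_logits_pool, to_play)

min_max_stats_lst = ez_tree.MinMaxStatsList(batch_size)
min_max_stats_lst.set_delta(0.01)

for simulation_index in range(num_simulations):
    results = ez_tree.ResultsWrapper(batch_size)
    # Select a leaf for every root according to the UCT rule implemented in C++.
    # latent_index_in_batch / last_actions identify which latent states and actions
    # should be fed to the dynamics / prediction networks.
    _, latent_index_in_batch, last_actions, virtual_to_play = ez_tree.batch_traverse(
        roots, 19652, 1.25, 0.997, min_max_stats_lst, results, to_play
    )

    # ... run the dynamics / prediction networks on the selected leaves here ...
    value_prefixs = [0.0] * batch_size
    values = [0.0] * batch_size
    policies = [[0.0] * action_space for _ in range(batch_size)]
    is_reset_list = [0] * batch_size

    # Expand the selected leaves and back up the predicted values along the search paths.
    ez_tree.batch_backpropagate(
        simulation_index + 1, 0.997, value_prefixs, values, policies,
        min_max_stats_lst, results, is_reset_list, virtual_to_play
    )

print(roots.get_distributions(), roots.get_values())
```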
diff --git a/LightZero/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.cpp b/LightZero/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.cpp
new file mode 100644
index 0000000000000000000000000000000000000000..59846f1ac16dd1b1a21cdb3b6f220881afac9529
--- /dev/null
+++ b/LightZero/lzero/mcts/ctree/ctree_efficientzero/lib/cnode.cpp
@@ -0,0 +1,792 @@
+// C++11
+
+#include
+#include "cnode.h"
+#include
+#include