# gomoku / DI-engine / ding/policy/command_mode_policy_instance.py
# (Header scraped from the hosting web page: author zjowowen, commit 079c32c "init space", file size 13.1 kB.)
from ding.utils import POLICY_REGISTRY
from ding.rl_utils import get_epsilon_greedy_fn
from .base_policy import CommandModePolicy
from .dqn import DQNPolicy, DQNSTDIMPolicy
from .mdqn import MDQNPolicy
from .c51 import C51Policy
from .qrdqn import QRDQNPolicy
from .iqn import IQNPolicy
from .fqf import FQFPolicy
from .rainbow import RainbowDQNPolicy
from .r2d2 import R2D2Policy
from .r2d2_gtrxl import R2D2GTrXLPolicy
from .r2d2_collect_traj import R2D2CollectTrajPolicy
from .sqn import SQNPolicy
from .ppo import PPOPolicy, PPOOffPolicy, PPOPGPolicy, PPOSTDIMPolicy
from .offppo_collect_traj import OffPPOCollectTrajPolicy
from .ppg import PPGPolicy, PPGOffPolicy
from .pg import PGPolicy
from .a2c import A2CPolicy
from .impala import IMPALAPolicy
from .ngu import NGUPolicy
from .ddpg import DDPGPolicy
from .td3 import TD3Policy
from .td3_vae import TD3VAEPolicy
from .td3_bc import TD3BCPolicy
from .sac import SACPolicy, DiscreteSACPolicy, SQILSACPolicy
from .mbpolicy.mbsac import MBSACPolicy, STEVESACPolicy
from .mbpolicy.dreamer import DREAMERPolicy
from .qmix import QMIXPolicy
from .wqmix import WQMIXPolicy
from .collaq import CollaQPolicy
from .coma import COMAPolicy
from .atoc import ATOCPolicy
from .acer import ACERPolicy
from .qtran import QTRANPolicy
from .sql import SQLPolicy
from .bc import BehaviourCloningPolicy
from .ibc import IBCPolicy
from .dqfd import DQFDPolicy
from .r2d3 import R2D3Policy
from .d4pg import D4PGPolicy
from .cql import CQLPolicy, DiscreteCQLPolicy
from .dt import DTPolicy
from .pdqn import PDQNPolicy
from .madqn import MADQNPolicy
from .bdq import BDQPolicy
from .bcq import BCQPolicy
from .edac import EDACPolicy
from .prompt_pg import PromptPGPolicy
from .plan_diffuser import PDPolicy
from .happo import HAPPOPolicy
class EpsCommandModePolicy(CommandModePolicy):
    r"""
    Overview:
        Command-mode mixin that supplies the collector's exploration ``eps`` from an
        epsilon-greedy decay schedule built from ``cfg.other.eps``. Learn and eval
        modes receive no extra settings.
    """

    def _init_command(self) -> None:
        r"""
        Overview:
            Command mode init method. Called by ``self.__init__``.
            Build the eps-greedy schedule from ``self._cfg.other.eps`` using its
            ``start``, ``end``, ``decay`` and ``type`` fields.
        """
        eps_cfg = self._cfg.other.eps
        self.epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)

    def _get_setting_collect(self, command_info: dict) -> dict:
        r"""
        Overview:
            Collect mode setting information including eps.
        Arguments:
            - command_info (:obj:`dict`): Dict type, must include at least the key ``'envstep'``.
        Returns:
            - collect_setting (:obj:`dict`): ``{'eps': eps}``, the schedule value at the current env step.
        Raises:
            - KeyError: If ``'envstep'`` is missing from ``command_info``.
        """
        # eps decays with the number of collected env steps; keying the schedule on
        # ``command_info['learner_train_iter']`` instead is a possible alternative.
        step = command_info['envstep']
        return {'eps': self.epsilon_greedy(step)}

    def _get_setting_learn(self, command_info: dict) -> dict:
        r"""No extra learn-mode settings."""
        return {}

    def _get_setting_eval(self, command_info: dict) -> dict:
        r"""No extra eval-mode settings."""
        return {}
class DummyCommandModePolicy(CommandModePolicy):
    r"""
    Overview:
        No-op command-mode mixin: nothing is initialized, and collect, learn and
        eval modes each receive an empty settings dict.
    """

    def _init_command(self) -> None:
        r"""Nothing to initialize for the dummy command mode."""

    def _get_setting_collect(self, command_info: dict) -> dict:
        r"""Return an empty dict: no extra collect-mode settings."""
        return dict()

    def _get_setting_learn(self, command_info: dict) -> dict:
        r"""Return an empty dict: no extra learn-mode settings."""
        return dict()

    def _get_setting_eval(self, command_info: dict) -> dict:
        r"""Return an empty dict: no extra eval-mode settings."""
        return dict()
@POLICY_REGISTRY.register('bdq_command')
class BDQCommandModePolicy(BDQPolicy, EpsCommandModePolicy):
    """Command-mode ``BDQPolicy`` (key ``'bdq_command'``) mixed with ``EpsCommandModePolicy``."""
@POLICY_REGISTRY.register('mdqn_command')
class MDQNCommandModePolicy(MDQNPolicy, EpsCommandModePolicy):
    """Command-mode ``MDQNPolicy`` (key ``'mdqn_command'``) mixed with ``EpsCommandModePolicy``."""
@POLICY_REGISTRY.register('dqn_command')
class DQNCommandModePolicy(DQNPolicy, EpsCommandModePolicy):
    """Command-mode ``DQNPolicy`` (key ``'dqn_command'``) mixed with ``EpsCommandModePolicy``."""
@POLICY_REGISTRY.register('dqn_stdim_command')
class DQNSTDIMCommandModePolicy(DQNSTDIMPolicy, EpsCommandModePolicy):
    """Command-mode ``DQNSTDIMPolicy`` (key ``'dqn_stdim_command'``) mixed with ``EpsCommandModePolicy``."""
@POLICY_REGISTRY.register('dqfd_command')
class DQFDCommandModePolicy(DQFDPolicy, EpsCommandModePolicy):
    """Command-mode ``DQFDPolicy`` (key ``'dqfd_command'``) mixed with ``EpsCommandModePolicy``."""
@POLICY_REGISTRY.register('c51_command')
class C51CommandModePolicy(C51Policy, EpsCommandModePolicy):
    """Command-mode ``C51Policy`` (key ``'c51_command'``) mixed with ``EpsCommandModePolicy``."""
@POLICY_REGISTRY.register('qrdqn_command')
class QRDQNCommandModePolicy(QRDQNPolicy, EpsCommandModePolicy):
    """Command-mode ``QRDQNPolicy`` (key ``'qrdqn_command'``) mixed with ``EpsCommandModePolicy``."""
@POLICY_REGISTRY.register('iqn_command')
class IQNCommandModePolicy(IQNPolicy, EpsCommandModePolicy):
    """Command-mode ``IQNPolicy`` (key ``'iqn_command'``) mixed with ``EpsCommandModePolicy``."""
@POLICY_REGISTRY.register('fqf_command')
class FQFCommandModePolicy(FQFPolicy, EpsCommandModePolicy):
    """Command-mode ``FQFPolicy`` (key ``'fqf_command'``) mixed with ``EpsCommandModePolicy``."""
@POLICY_REGISTRY.register('rainbow_command')
class RainbowDQNCommandModePolicy(RainbowDQNPolicy, EpsCommandModePolicy):
    """Command-mode ``RainbowDQNPolicy`` (key ``'rainbow_command'``) mixed with ``EpsCommandModePolicy``."""
@POLICY_REGISTRY.register('r2d2_command')
class R2D2CommandModePolicy(R2D2Policy, EpsCommandModePolicy):
    """Command-mode ``R2D2Policy`` (key ``'r2d2_command'``) mixed with ``EpsCommandModePolicy``."""
@POLICY_REGISTRY.register('r2d2_gtrxl_command')
class R2D2GTrXLCommandModePolicy(R2D2GTrXLPolicy, EpsCommandModePolicy):
    """Command-mode ``R2D2GTrXLPolicy`` (key ``'r2d2_gtrxl_command'``) mixed with ``EpsCommandModePolicy``."""
@POLICY_REGISTRY.register('r2d2_collect_traj_command')
class R2D2CollectTrajCommandModePolicy(R2D2CollectTrajPolicy, DummyCommandModePolicy):
    """Command-mode ``R2D2CollectTrajPolicy`` (key ``'r2d2_collect_traj_command'``) with no-op command settings."""
@POLICY_REGISTRY.register('r2d3_command')
class R2D3CommandModePolicy(R2D3Policy, EpsCommandModePolicy):
    """Command-mode ``R2D3Policy`` (key ``'r2d3_command'``) mixed with ``EpsCommandModePolicy``."""
@POLICY_REGISTRY.register('sqn_command')
class SQNCommandModePolicy(SQNPolicy, DummyCommandModePolicy):
    """Command-mode ``SQNPolicy`` (key ``'sqn_command'``) with no-op command settings."""
@POLICY_REGISTRY.register('sql_command')
class SQLCommandModePolicy(SQLPolicy, EpsCommandModePolicy):
    """Command-mode ``SQLPolicy`` (key ``'sql_command'``) mixed with ``EpsCommandModePolicy``."""
@POLICY_REGISTRY.register('ppo_command')
class PPOCommandModePolicy(PPOPolicy, DummyCommandModePolicy):
    """Command-mode ``PPOPolicy`` (key ``'ppo_command'``) with no-op command settings."""
@POLICY_REGISTRY.register('happo_command')
class HAPPOCommandModePolicy(HAPPOPolicy, DummyCommandModePolicy):
    """Command-mode ``HAPPOPolicy`` (key ``'happo_command'``) with no-op command settings."""
@POLICY_REGISTRY.register('ppo_stdim_command')
class PPOSTDIMCommandModePolicy(PPOSTDIMPolicy, DummyCommandModePolicy):
    """Command-mode ``PPOSTDIMPolicy`` (key ``'ppo_stdim_command'``) with no-op command settings."""
@POLICY_REGISTRY.register('ppo_pg_command')
class PPOPGCommandModePolicy(PPOPGPolicy, DummyCommandModePolicy):
    """Command-mode ``PPOPGPolicy`` (key ``'ppo_pg_command'``) with no-op command settings."""
@POLICY_REGISTRY.register('ppo_offpolicy_command')
class PPOOffCommandModePolicy(PPOOffPolicy, DummyCommandModePolicy):
    """Command-mode ``PPOOffPolicy`` (key ``'ppo_offpolicy_command'``) with no-op command settings."""
@POLICY_REGISTRY.register('offppo_collect_traj_command')
class PPOOffCollectTrajCommandModePolicy(OffPPOCollectTrajPolicy, DummyCommandModePolicy):
    """Command-mode ``OffPPOCollectTrajPolicy`` (key ``'offppo_collect_traj_command'``) with no-op settings."""
@POLICY_REGISTRY.register('pg_command')
class PGCommandModePolicy(PGPolicy, DummyCommandModePolicy):
    """Command-mode ``PGPolicy`` (key ``'pg_command'``) with no-op command settings."""
@POLICY_REGISTRY.register('a2c_command')
class A2CCommandModePolicy(A2CPolicy, DummyCommandModePolicy):
    """Command-mode ``A2CPolicy`` (key ``'a2c_command'``) with no-op command settings."""
@POLICY_REGISTRY.register('impala_command')
class IMPALACommandModePolicy(IMPALAPolicy, DummyCommandModePolicy):
    """Command-mode ``IMPALAPolicy`` (key ``'impala_command'``) with no-op command settings."""
@POLICY_REGISTRY.register('ppg_offpolicy_command')
class PPGOffCommandModePolicy(PPGOffPolicy, DummyCommandModePolicy):
    """Command-mode ``PPGOffPolicy`` (key ``'ppg_offpolicy_command'``) with no-op command settings."""
@POLICY_REGISTRY.register('ppg_command')
class PPGCommandModePolicy(PPGPolicy, DummyCommandModePolicy):
    """Command-mode ``PPGPolicy`` (key ``'ppg_command'``) with no-op command settings."""
@POLICY_REGISTRY.register('madqn_command')
class MADQNCommandModePolicy(MADQNPolicy, EpsCommandModePolicy):
    """Command-mode ``MADQNPolicy`` (key ``'madqn_command'``) mixed with ``EpsCommandModePolicy``."""
@POLICY_REGISTRY.register('ddpg_command')
class DDPGCommandModePolicy(DDPGPolicy, CommandModePolicy):
    r"""
    Overview:
        Command-mode DDPG policy, registered as ``'ddpg_command'``. Only a hybrid action
        space gets an eps-greedy collect schedule; otherwise no extra settings are produced.
    """

    def _init_command(self) -> None:
        r"""
        Overview:
            Command mode init method. Called by ``self.__init__``.
            With a hybrid action space, build the eps-greedy schedule from
            ``self._cfg.other.eps``; for any other action space this is an empty method.
        """
        if self._cfg.action_space != 'hybrid':
            return
        schedule_cfg = self._cfg.other.eps
        self.epsilon_greedy = get_epsilon_greedy_fn(
            schedule_cfg.start, schedule_cfg.end, schedule_cfg.decay, schedule_cfg.type
        )

    def _get_setting_collect(self, command_info: dict) -> dict:
        r"""
        Overview:
            Collect mode setting information: ``eps`` for a hybrid action space, nothing otherwise.
        Arguments:
            - command_info (:obj:`dict`): Dict type; ``'envstep'`` is read when the action space is hybrid.
        Returns:
            - collect_setting (:obj:`dict`): ``{'eps': ...}`` for a hybrid action space, else an empty dict.
        """
        if self._cfg.action_space != 'hybrid':
            return {}
        # eps decays with collected env steps; ``command_info['learner_step']`` would be the alternative clock.
        return {'eps': self.epsilon_greedy(command_info['envstep'])}

    def _get_setting_learn(self, command_info: dict) -> dict:
        r"""No extra learn-mode settings."""
        return {}

    def _get_setting_eval(self, command_info: dict) -> dict:
        r"""No extra eval-mode settings."""
        return {}
@POLICY_REGISTRY.register('td3_command')
class TD3CommandModePolicy(TD3Policy, DummyCommandModePolicy):
    """Command-mode ``TD3Policy`` (key ``'td3_command'``) with no-op command settings."""
@POLICY_REGISTRY.register('td3_vae_command')
class TD3VAECommandModePolicy(TD3VAEPolicy, DummyCommandModePolicy):
    """Command-mode ``TD3VAEPolicy`` (key ``'td3_vae_command'``) with no-op command settings."""
@POLICY_REGISTRY.register('td3_bc_command')
class TD3BCCommandModePolicy(TD3BCPolicy, DummyCommandModePolicy):
    """Command-mode ``TD3BCPolicy`` (key ``'td3_bc_command'``) with no-op command settings."""
@POLICY_REGISTRY.register('sac_command')
class SACCommandModePolicy(SACPolicy, DummyCommandModePolicy):
    """Command-mode ``SACPolicy`` (key ``'sac_command'``) with no-op command settings."""
@POLICY_REGISTRY.register('mbsac_command')
class MBSACCommandModePolicy(MBSACPolicy, DummyCommandModePolicy):
    """Command-mode ``MBSACPolicy`` (key ``'mbsac_command'``) with no-op command settings."""
@POLICY_REGISTRY.register('stevesac_command')
class STEVESACCommandModePolicy(STEVESACPolicy, DummyCommandModePolicy):
    """Command-mode ``STEVESACPolicy`` (key ``'stevesac_command'``) with no-op command settings."""
@POLICY_REGISTRY.register('dreamer_command')
class DREAMERCommandModePolicy(DREAMERPolicy, DummyCommandModePolicy):
    """Command-mode ``DREAMERPolicy`` (key ``'dreamer_command'``) with no-op command settings."""
@POLICY_REGISTRY.register('cql_command')
class CQLCommandModePolicy(CQLPolicy, DummyCommandModePolicy):
    """Command-mode ``CQLPolicy`` (key ``'cql_command'``) with no-op command settings."""
@POLICY_REGISTRY.register('discrete_cql_command')
class DiscreteCQLCommandModePolicy(DiscreteCQLPolicy, EpsCommandModePolicy):
    """Command-mode ``DiscreteCQLPolicy`` (key ``'discrete_cql_command'``) mixed with ``EpsCommandModePolicy``."""
@POLICY_REGISTRY.register('dt_command')
class DTCommandModePolicy(DTPolicy, DummyCommandModePolicy):
    """Command-mode ``DTPolicy`` (key ``'dt_command'``) with no-op command settings."""
@POLICY_REGISTRY.register('qmix_command')
class QMIXCommandModePolicy(QMIXPolicy, EpsCommandModePolicy):
    """Command-mode ``QMIXPolicy`` (key ``'qmix_command'``) mixed with ``EpsCommandModePolicy``."""
@POLICY_REGISTRY.register('wqmix_command')
class WQMIXCommandModePolicy(WQMIXPolicy, EpsCommandModePolicy):
    """Command-mode ``WQMIXPolicy`` (key ``'wqmix_command'``) mixed with ``EpsCommandModePolicy``."""
@POLICY_REGISTRY.register('collaq_command')
class CollaQCommandModePolicy(CollaQPolicy, EpsCommandModePolicy):
    """Command-mode ``CollaQPolicy`` (key ``'collaq_command'``) mixed with ``EpsCommandModePolicy``."""
@POLICY_REGISTRY.register('coma_command')
class COMACommandModePolicy(COMAPolicy, EpsCommandModePolicy):
    """Command-mode ``COMAPolicy`` (key ``'coma_command'``) mixed with ``EpsCommandModePolicy``."""
@POLICY_REGISTRY.register('atoc_command')
class ATOCCommandModePolicy(ATOCPolicy, DummyCommandModePolicy):
    """Command-mode ``ATOCPolicy`` (key ``'atoc_command'``) with no-op command settings."""
@POLICY_REGISTRY.register('acer_command')
class ACERCommandModePolicy(ACERPolicy, DummyCommandModePolicy):
    """Command-mode ``ACERPolicy`` (key ``'acer_command'``) with no-op command settings."""


# Backward-compatible alias for the original, misspelled class name ("Polisy").
ACERCommandModePolisy = ACERCommandModePolicy
@POLICY_REGISTRY.register('qtran_command')
class QTRANCommandModePolicy(QTRANPolicy, EpsCommandModePolicy):
    """Command-mode ``QTRANPolicy`` (key ``'qtran_command'``) mixed with ``EpsCommandModePolicy``."""
@POLICY_REGISTRY.register('ngu_command')
class NGUCommandModePolicy(NGUPolicy, EpsCommandModePolicy):
    """Command-mode ``NGUPolicy`` (key ``'ngu_command'``) mixed with ``EpsCommandModePolicy``."""
@POLICY_REGISTRY.register('d4pg_command')
class D4PGCommandModePolicy(D4PGPolicy, DummyCommandModePolicy):
    """Command-mode ``D4PGPolicy`` (key ``'d4pg_command'``) with no-op command settings."""
@POLICY_REGISTRY.register('pdqn_command')
class PDQNCommandModePolicy(PDQNPolicy, EpsCommandModePolicy):
    """Command-mode ``PDQNPolicy`` (key ``'pdqn_command'``) mixed with ``EpsCommandModePolicy``."""
@POLICY_REGISTRY.register('discrete_sac_command')
class DiscreteSACCommandModePolicy(DiscreteSACPolicy, EpsCommandModePolicy):
    """Command-mode ``DiscreteSACPolicy`` (key ``'discrete_sac_command'``) mixed with ``EpsCommandModePolicy``."""
@POLICY_REGISTRY.register('sqil_sac_command')
class SQILSACCommandModePolicy(SQILSACPolicy, DummyCommandModePolicy):
    """Command-mode ``SQILSACPolicy`` (key ``'sqil_sac_command'``) with no-op command settings."""
@POLICY_REGISTRY.register('ibc_command')
class IBCCommandModePolicy(IBCPolicy, DummyCommandModePolicy):
    """Command-mode ``IBCPolicy`` (key ``'ibc_command'``) with no-op command settings."""
@POLICY_REGISTRY.register('bcq_command')
class BCQCommandModePolicy(BCQPolicy, DummyCommandModePolicy):
    """Command-mode ``BCQPolicy`` (key ``'bcq_command'``) with no-op command settings."""


# Backward-compatible alias for the original name, which broke the ``*CommandModePolicy`` convention.
BCQCommandModelPolicy = BCQCommandModePolicy
@POLICY_REGISTRY.register('edac_command')
class EDACCommandModePolicy(EDACPolicy, DummyCommandModePolicy):
    """Command-mode ``EDACPolicy`` (key ``'edac_command'``) with no-op command settings."""


# Backward-compatible alias for the original name, which broke the ``*CommandModePolicy`` convention.
EDACCommandModelPolicy = EDACCommandModePolicy
@POLICY_REGISTRY.register('pd_command')
class PDCommandModePolicy(PDPolicy, DummyCommandModePolicy):
    """Command-mode ``PDPolicy`` (key ``'pd_command'``) with no-op command settings."""


# Backward-compatible alias for the original name, which broke the ``*CommandModePolicy`` convention.
PDCommandModelPolicy = PDCommandModePolicy
@POLICY_REGISTRY.register('bc_command')
class BCCommandModePolicy(BehaviourCloningPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Command-mode behaviour cloning policy, registered as ``'bc_command'``.
        When ``cfg.continuous`` is true, collect mode gets a decaying noise ``sigma``
        (schedule from ``cfg.collect.noise_sigma``); otherwise it gets a decaying ``eps``
        (schedule from ``cfg.other.eps``). Learn and eval modes receive no extra settings.
    """

    def _init_command(self) -> None:
        r"""
        Overview:
            Command mode init method. Called by ``self.__init__``.
            Build the collect-mode decay schedule: from ``self._cfg.collect.noise_sigma`` when
            ``self._cfg.continuous`` is true, otherwise from ``self._cfg.other.eps``.
        """
        # The continuous branch reuses the eps-greedy schedule helper to decay the noise sigma,
        # which is why the attribute is named ``epsilon_greedy`` in both cases.
        if self._cfg.continuous:
            noise_cfg = self._cfg.collect.noise_sigma
            self.epsilon_greedy = get_epsilon_greedy_fn(noise_cfg.start, noise_cfg.end, noise_cfg.decay, noise_cfg.type)
        else:
            eps_cfg = self._cfg.other.eps
            self.epsilon_greedy = get_epsilon_greedy_fn(eps_cfg.start, eps_cfg.end, eps_cfg.decay, eps_cfg.type)

    def _get_setting_collect(self, command_info: dict) -> dict:
        r"""
        Overview:
            Collect mode setting information: the current noise ``sigma`` (continuous)
            or exploration ``eps`` (discrete).
        Arguments:
            - command_info (:obj:`dict`): Dict type. Must include ``'learner_step'`` when
              ``cfg.continuous`` is true, otherwise ``'envstep'``.
        Returns:
            - collect_setting (:obj:`dict`): ``{'sigma': ...}`` in the continuous case, ``{'eps': ...}`` otherwise.
        Raises:
            - KeyError: If the required step key is missing from ``command_info``.
        """
        if self._cfg.continuous:
            # Noise sigma decays with learner iterations.
            step = command_info['learner_step']
            return {'sigma': self.epsilon_greedy(step)}
        else:
            # Exploration eps decays with collected env steps.
            step = command_info['envstep']
            return {'eps': self.epsilon_greedy(step)}

    def _get_setting_learn(self, command_info: dict) -> dict:
        r"""No extra learn-mode settings."""
        return {}

    def _get_setting_eval(self, command_info: dict) -> dict:
        r"""No extra eval-mode settings."""
        return {}
@POLICY_REGISTRY.register('prompt_pg_command')
class PromptPGCommandModePolicy(PromptPGPolicy, DummyCommandModePolicy):
    """Command-mode ``PromptPGPolicy`` (key ``'prompt_pg_command'``) with no-op command settings."""