File size: 13,082 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
from ding.utils import POLICY_REGISTRY
from ding.rl_utils import get_epsilon_greedy_fn
from .base_policy import CommandModePolicy

from .dqn import DQNPolicy, DQNSTDIMPolicy
from .mdqn import MDQNPolicy
from .c51 import C51Policy
from .qrdqn import QRDQNPolicy
from .iqn import IQNPolicy
from .fqf import FQFPolicy
from .rainbow import RainbowDQNPolicy
from .r2d2 import R2D2Policy
from .r2d2_gtrxl import R2D2GTrXLPolicy
from .r2d2_collect_traj import R2D2CollectTrajPolicy
from .sqn import SQNPolicy
from .ppo import PPOPolicy, PPOOffPolicy, PPOPGPolicy, PPOSTDIMPolicy
from .offppo_collect_traj import OffPPOCollectTrajPolicy
from .ppg import PPGPolicy, PPGOffPolicy
from .pg import PGPolicy
from .a2c import A2CPolicy
from .impala import IMPALAPolicy
from .ngu import NGUPolicy
from .ddpg import DDPGPolicy
from .td3 import TD3Policy
from .td3_vae import TD3VAEPolicy
from .td3_bc import TD3BCPolicy
from .sac import SACPolicy, DiscreteSACPolicy, SQILSACPolicy
from .mbpolicy.mbsac import MBSACPolicy, STEVESACPolicy
from .mbpolicy.dreamer import DREAMERPolicy
from .qmix import QMIXPolicy
from .wqmix import WQMIXPolicy
from .collaq import CollaQPolicy
from .coma import COMAPolicy
from .atoc import ATOCPolicy
from .acer import ACERPolicy
from .qtran import QTRANPolicy
from .sql import SQLPolicy
from .bc import BehaviourCloningPolicy
from .ibc import IBCPolicy

from .dqfd import DQFDPolicy
from .r2d3 import R2D3Policy

from .d4pg import D4PGPolicy
from .cql import CQLPolicy, DiscreteCQLPolicy
from .dt import DTPolicy
from .pdqn import PDQNPolicy
from .madqn import MADQNPolicy
from .bdq import BDQPolicy
from .bcq import BCQPolicy
from .edac import EDACPolicy
from .prompt_pg import PromptPGPolicy
from .plan_diffuser import PDPolicy
from .happo import HAPPOPolicy


class EpsCommandModePolicy(CommandModePolicy):
    r"""
    Overview:
        Command-mode base whose collect setting carries an ``eps`` value produced by a schedule
        built from ``cfg.other.eps``. Learn and eval settings are empty.
    """

    def _init_command(self) -> None:
        r"""
        Overview:
            Command mode init method.
            Build the epsilon schedule from ``self._cfg.other.eps`` (fields ``start``, ``end``, ``decay``,
            ``type``) and keep it as ``self.epsilon_greedy``.
        """
        schedule_cfg = self._cfg.other.eps
        self.epsilon_greedy = get_epsilon_greedy_fn(
            schedule_cfg.start,
            schedule_cfg.end,
            schedule_cfg.decay,
            schedule_cfg.type,
        )

    def _get_setting_collect(self, command_info: dict) -> dict:
        r"""
        Overview:
            Collect mode setting information, i.e. the epsilon for the current step.
        Arguments:
            - command_info (:obj:`dict`): Dict type, must contain the key ``'envstep'``.
        Returns:
           - collect_setting (:obj:`dict`): ``{'eps': value}``, the schedule evaluated at ``envstep``.
        """
        # The schedule is indexed by environment steps (not by learner iterations).
        env_step = command_info['envstep']
        eps_value = self.epsilon_greedy(env_step)
        return {'eps': eps_value}

    def _get_setting_learn(self, command_info: dict) -> dict:
        r"""
        Overview:
            Learn mode needs no command setting; always returns an empty dict.
        """
        return {}

    def _get_setting_eval(self, command_info: dict) -> dict:
        r"""
        Overview:
            Eval mode needs no command setting; always returns an empty dict.
        """
        return {}


class DummyCommandModePolicy(CommandModePolicy):
    r"""
    Overview:
        Command-mode base with no-op hooks: nothing is initialised and every mode's setting is an
        empty dict.
    """

    def _init_command(self) -> None:
        # Nothing to set up for the command mode.
        pass

    def _get_setting_collect(self, command_info: dict) -> dict:
        # ``command_info`` is ignored; no collect setting is produced.
        return {}

    def _get_setting_learn(self, command_info: dict) -> dict:
        # ``command_info`` is ignored; no learn setting is produced.
        return {}

    def _get_setting_eval(self, command_info: dict) -> dict:
        # ``command_info`` is ignored; no eval setting is produced.
        return {}


@POLICY_REGISTRY.register('bdq_command')
class BDQCommandModePolicy(BDQPolicy, EpsCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'bdq_command'``.
        Subclasses ``BDQPolicy`` and ``EpsCommandModePolicy`` (decayed ``eps`` setting for collect mode).
    """
    pass


@POLICY_REGISTRY.register('mdqn_command')
class MDQNCommandModePolicy(MDQNPolicy, EpsCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'mdqn_command'``.
        Subclasses ``MDQNPolicy`` and ``EpsCommandModePolicy`` (decayed ``eps`` setting for collect mode).
    """
    pass


@POLICY_REGISTRY.register('dqn_command')
class DQNCommandModePolicy(DQNPolicy, EpsCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'dqn_command'``.
        Subclasses ``DQNPolicy`` and ``EpsCommandModePolicy`` (decayed ``eps`` setting for collect mode).
    """
    pass


@POLICY_REGISTRY.register('dqn_stdim_command')
class DQNSTDIMCommandModePolicy(DQNSTDIMPolicy, EpsCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'dqn_stdim_command'``.
        Subclasses ``DQNSTDIMPolicy`` and ``EpsCommandModePolicy`` (decayed ``eps`` setting for collect mode).
    """
    pass


@POLICY_REGISTRY.register('dqfd_command')
class DQFDCommandModePolicy(DQFDPolicy, EpsCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'dqfd_command'``.
        Subclasses ``DQFDPolicy`` and ``EpsCommandModePolicy`` (decayed ``eps`` setting for collect mode).
    """
    pass


@POLICY_REGISTRY.register('c51_command')
class C51CommandModePolicy(C51Policy, EpsCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'c51_command'``.
        Subclasses ``C51Policy`` and ``EpsCommandModePolicy`` (decayed ``eps`` setting for collect mode).
    """
    pass


@POLICY_REGISTRY.register('qrdqn_command')
class QRDQNCommandModePolicy(QRDQNPolicy, EpsCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'qrdqn_command'``.
        Subclasses ``QRDQNPolicy`` and ``EpsCommandModePolicy`` (decayed ``eps`` setting for collect mode).
    """
    pass


@POLICY_REGISTRY.register('iqn_command')
class IQNCommandModePolicy(IQNPolicy, EpsCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'iqn_command'``.
        Subclasses ``IQNPolicy`` and ``EpsCommandModePolicy`` (decayed ``eps`` setting for collect mode).
    """
    pass


@POLICY_REGISTRY.register('fqf_command')
class FQFCommandModePolicy(FQFPolicy, EpsCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'fqf_command'``.
        Subclasses ``FQFPolicy`` and ``EpsCommandModePolicy`` (decayed ``eps`` setting for collect mode).
    """
    pass


@POLICY_REGISTRY.register('rainbow_command')
class RainbowDQNCommandModePolicy(RainbowDQNPolicy, EpsCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'rainbow_command'``.
        Subclasses ``RainbowDQNPolicy`` and ``EpsCommandModePolicy`` (decayed ``eps`` setting for collect mode).
    """
    pass


@POLICY_REGISTRY.register('r2d2_command')
class R2D2CommandModePolicy(R2D2Policy, EpsCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'r2d2_command'``.
        Subclasses ``R2D2Policy`` and ``EpsCommandModePolicy`` (decayed ``eps`` setting for collect mode).
    """
    pass


@POLICY_REGISTRY.register('r2d2_gtrxl_command')
class R2D2GTrXLCommandModePolicy(R2D2GTrXLPolicy, EpsCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'r2d2_gtrxl_command'``.
        Subclasses ``R2D2GTrXLPolicy`` and ``EpsCommandModePolicy`` (decayed ``eps`` setting for collect mode).
    """
    pass


@POLICY_REGISTRY.register('r2d2_collect_traj_command')
class R2D2CollectTrajCommandModePolicy(R2D2CollectTrajPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'r2d2_collect_traj_command'``.
        Subclasses ``R2D2CollectTrajPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


@POLICY_REGISTRY.register('r2d3_command')
class R2D3CommandModePolicy(R2D3Policy, EpsCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'r2d3_command'``.
        Subclasses ``R2D3Policy`` and ``EpsCommandModePolicy`` (decayed ``eps`` setting for collect mode).
    """
    pass


@POLICY_REGISTRY.register('sqn_command')
class SQNCommandModePolicy(SQNPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'sqn_command'``.
        Subclasses ``SQNPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


@POLICY_REGISTRY.register('sql_command')
class SQLCommandModePolicy(SQLPolicy, EpsCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'sql_command'``.
        Subclasses ``SQLPolicy`` and ``EpsCommandModePolicy`` (decayed ``eps`` setting for collect mode).
    """
    pass


@POLICY_REGISTRY.register('ppo_command')
class PPOCommandModePolicy(PPOPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'ppo_command'``.
        Subclasses ``PPOPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


@POLICY_REGISTRY.register('happo_command')
class HAPPOCommandModePolicy(HAPPOPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'happo_command'``.
        Subclasses ``HAPPOPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


@POLICY_REGISTRY.register('ppo_stdim_command')
class PPOSTDIMCommandModePolicy(PPOSTDIMPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'ppo_stdim_command'``.
        Subclasses ``PPOSTDIMPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


@POLICY_REGISTRY.register('ppo_pg_command')
class PPOPGCommandModePolicy(PPOPGPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'ppo_pg_command'``.
        Subclasses ``PPOPGPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


@POLICY_REGISTRY.register('ppo_offpolicy_command')
class PPOOffCommandModePolicy(PPOOffPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'ppo_offpolicy_command'``.
        Subclasses ``PPOOffPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


@POLICY_REGISTRY.register('offppo_collect_traj_command')
class PPOOffCollectTrajCommandModePolicy(OffPPOCollectTrajPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'offppo_collect_traj_command'``.
        Subclasses ``OffPPOCollectTrajPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


@POLICY_REGISTRY.register('pg_command')
class PGCommandModePolicy(PGPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'pg_command'``.
        Subclasses ``PGPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


@POLICY_REGISTRY.register('a2c_command')
class A2CCommandModePolicy(A2CPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'a2c_command'``.
        Subclasses ``A2CPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


@POLICY_REGISTRY.register('impala_command')
class IMPALACommandModePolicy(IMPALAPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'impala_command'``.
        Subclasses ``IMPALAPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


@POLICY_REGISTRY.register('ppg_offpolicy_command')
class PPGOffCommandModePolicy(PPGOffPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'ppg_offpolicy_command'``.
        Subclasses ``PPGOffPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


@POLICY_REGISTRY.register('ppg_command')
class PPGCommandModePolicy(PPGPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'ppg_command'``.
        Subclasses ``PPGPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


@POLICY_REGISTRY.register('madqn_command')
class MADQNCommandModePolicy(MADQNPolicy, EpsCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'madqn_command'``.
        Subclasses ``MADQNPolicy`` and ``EpsCommandModePolicy`` (decayed ``eps`` setting for collect mode).
    """
    pass


@POLICY_REGISTRY.register('ddpg_command')
class DDPGCommandModePolicy(DDPGPolicy, CommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'ddpg_command'``. Subclasses ``DDPGPolicy`` and
        ``CommandModePolicy``. An epsilon schedule is only used when ``cfg.action_space`` is ``'hybrid'``;
        otherwise all command settings are empty.
    """

    def _init_command(self) -> None:
        r"""
        Overview:
            Command mode init method.
            For a hybrid action space, build the epsilon schedule from ``self._cfg.other.eps`` and keep it
            as ``self.epsilon_greedy``. For any other action space nothing is created.
        """
        uses_eps = self._cfg.action_space == 'hybrid'
        if not uses_eps:
            return
        schedule_cfg = self._cfg.other.eps
        self.epsilon_greedy = get_epsilon_greedy_fn(
            schedule_cfg.start,
            schedule_cfg.end,
            schedule_cfg.decay,
            schedule_cfg.type,
        )

    def _get_setting_collect(self, command_info: dict) -> dict:
        r"""
        Overview:
            Collect mode setting information; carries ``eps`` only for a hybrid action space.
        Arguments:
            - command_info (:obj:`dict`): Dict type; the key ``'envstep'`` is read when the action space \
                is hybrid.
        Returns:
           - collect_setting (:obj:`dict`): ``{'eps': value}`` for a hybrid action space, otherwise ``{}``.
        """
        uses_eps = self._cfg.action_space == 'hybrid'
        if not uses_eps:
            return {}
        # The schedule is indexed by environment steps (not by learner steps).
        env_step = command_info['envstep']
        return {'eps': self.epsilon_greedy(env_step)}

    def _get_setting_learn(self, command_info: dict) -> dict:
        r"""
        Overview:
            Learn mode needs no command setting; always returns an empty dict.
        """
        return {}

    def _get_setting_eval(self, command_info: dict) -> dict:
        r"""
        Overview:
            Eval mode needs no command setting; always returns an empty dict.
        """
        return {}


@POLICY_REGISTRY.register('td3_command')
class TD3CommandModePolicy(TD3Policy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'td3_command'``.
        Subclasses ``TD3Policy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


@POLICY_REGISTRY.register('td3_vae_command')
class TD3VAECommandModePolicy(TD3VAEPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'td3_vae_command'``.
        Subclasses ``TD3VAEPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


@POLICY_REGISTRY.register('td3_bc_command')
class TD3BCCommandModePolicy(TD3BCPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'td3_bc_command'``.
        Subclasses ``TD3BCPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


@POLICY_REGISTRY.register('sac_command')
class SACCommandModePolicy(SACPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'sac_command'``.
        Subclasses ``SACPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


@POLICY_REGISTRY.register('mbsac_command')
class MBSACCommandModePolicy(MBSACPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'mbsac_command'``.
        Subclasses ``MBSACPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


@POLICY_REGISTRY.register('stevesac_command')
class STEVESACCommandModePolicy(STEVESACPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'stevesac_command'``.
        Subclasses ``STEVESACPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


@POLICY_REGISTRY.register('dreamer_command')
class DREAMERCommandModePolicy(DREAMERPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'dreamer_command'``.
        Subclasses ``DREAMERPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


@POLICY_REGISTRY.register('cql_command')
class CQLCommandModePolicy(CQLPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'cql_command'``.
        Subclasses ``CQLPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


@POLICY_REGISTRY.register('discrete_cql_command')
class DiscreteCQLCommandModePolicy(DiscreteCQLPolicy, EpsCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'discrete_cql_command'``.
        Subclasses ``DiscreteCQLPolicy`` and ``EpsCommandModePolicy`` (decayed ``eps`` setting for collect mode).
    """
    pass


@POLICY_REGISTRY.register('dt_command')
class DTCommandModePolicy(DTPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'dt_command'``.
        Subclasses ``DTPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


@POLICY_REGISTRY.register('qmix_command')
class QMIXCommandModePolicy(QMIXPolicy, EpsCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'qmix_command'``.
        Subclasses ``QMIXPolicy`` and ``EpsCommandModePolicy`` (decayed ``eps`` setting for collect mode).
    """
    pass


@POLICY_REGISTRY.register('wqmix_command')
class WQMIXCommandModePolicy(WQMIXPolicy, EpsCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'wqmix_command'``.
        Subclasses ``WQMIXPolicy`` and ``EpsCommandModePolicy`` (decayed ``eps`` setting for collect mode).
    """
    pass


@POLICY_REGISTRY.register('collaq_command')
class CollaQCommandModePolicy(CollaQPolicy, EpsCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'collaq_command'``.
        Subclasses ``CollaQPolicy`` and ``EpsCommandModePolicy`` (decayed ``eps`` setting for collect mode).
    """
    pass


@POLICY_REGISTRY.register('coma_command')
class COMACommandModePolicy(COMAPolicy, EpsCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'coma_command'``.
        Subclasses ``COMAPolicy`` and ``EpsCommandModePolicy`` (decayed ``eps`` setting for collect mode).
    """
    pass


@POLICY_REGISTRY.register('atoc_command')
class ATOCCommandModePolicy(ATOCPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'atoc_command'``.
        Subclasses ``ATOCPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


# NOTE(review): the class name is spelled "Polisy" (not "Policy"). It is left unchanged here because
# renaming would alter the public name; the registry key 'acer_command' is unaffected by the spelling.
@POLICY_REGISTRY.register('acer_command')
class ACERCommandModePolisy(ACERPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'acer_command'``.
        Subclasses ``ACERPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


@POLICY_REGISTRY.register('qtran_command')
class QTRANCommandModePolicy(QTRANPolicy, EpsCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'qtran_command'``.
        Subclasses ``QTRANPolicy`` and ``EpsCommandModePolicy`` (decayed ``eps`` setting for collect mode).
    """
    pass


@POLICY_REGISTRY.register('ngu_command')
class NGUCommandModePolicy(NGUPolicy, EpsCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'ngu_command'``.
        Subclasses ``NGUPolicy`` and ``EpsCommandModePolicy`` (decayed ``eps`` setting for collect mode).
    """
    pass


@POLICY_REGISTRY.register('d4pg_command')
class D4PGCommandModePolicy(D4PGPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'d4pg_command'``.
        Subclasses ``D4PGPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


@POLICY_REGISTRY.register('pdqn_command')
class PDQNCommandModePolicy(PDQNPolicy, EpsCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'pdqn_command'``.
        Subclasses ``PDQNPolicy`` and ``EpsCommandModePolicy`` (decayed ``eps`` setting for collect mode).
    """
    pass


@POLICY_REGISTRY.register('discrete_sac_command')
class DiscreteSACCommandModePolicy(DiscreteSACPolicy, EpsCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'discrete_sac_command'``.
        Subclasses ``DiscreteSACPolicy`` and ``EpsCommandModePolicy`` (decayed ``eps`` setting for collect mode).
    """
    pass


@POLICY_REGISTRY.register('sqil_sac_command')
class SQILSACCommandModePolicy(SQILSACPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'sqil_sac_command'``.
        Subclasses ``SQILSACPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


@POLICY_REGISTRY.register('ibc_command')
class IBCCommandModePolicy(IBCPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'ibc_command'``.
        Subclasses ``IBCPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


# NOTE(review): this class uses the suffix "CommandModelPolicy" while most classes in this file use
# "CommandModePolicy". Left unchanged because renaming would alter the public name.
@POLICY_REGISTRY.register('bcq_command')
class BCQCommandModelPolicy(BCQPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'bcq_command'``.
        Subclasses ``BCQPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


# NOTE(review): this class uses the suffix "CommandModelPolicy" while most classes in this file use
# "CommandModePolicy". Left unchanged because renaming would alter the public name.
@POLICY_REGISTRY.register('edac_command')
class EDACCommandModelPolicy(EDACPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'edac_command'``.
        Subclasses ``EDACPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


# NOTE(review): this class uses the suffix "CommandModelPolicy" while most classes in this file use
# "CommandModePolicy". Left unchanged because renaming would alter the public name.
@POLICY_REGISTRY.register('pd_command')
class PDCommandModelPolicy(PDPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'pd_command'``.
        Subclasses ``PDPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass


@POLICY_REGISTRY.register('bc_command')
class BCCommandModePolicy(BehaviourCloningPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'bc_command'``. Subclasses ``BehaviourCloningPolicy`` and
        ``DummyCommandModePolicy``, overriding all four command hooks. The collect setting is a decayed
        ``sigma`` when ``cfg.continuous`` is truthy and a decayed ``eps`` otherwise.
    """

    def _init_command(self) -> None:
        r"""
        Overview:
            Command mode init method.
            Build the decay schedule and keep it as ``self.epsilon_greedy``. The schedule parameters come
            from ``self._cfg.collect.noise_sigma`` when ``self._cfg.continuous`` is truthy, and from
            ``self._cfg.other.eps`` otherwise.
        """
        # Only the config node of the active branch is accessed.
        schedule_cfg = self._cfg.collect.noise_sigma if self._cfg.continuous else self._cfg.other.eps
        self.epsilon_greedy = get_epsilon_greedy_fn(
            schedule_cfg.start,
            schedule_cfg.end,
            schedule_cfg.decay,
            schedule_cfg.type,
        )

    def _get_setting_collect(self, command_info: dict) -> dict:
        r"""
        Overview:
            Collect mode setting information: the schedule value for the current step.
        Arguments:
            - command_info (:obj:`dict`): Dict type; ``'learner_step'`` is read in the continuous case, \
                ``'envstep'`` otherwise.
        Returns:
           - collect_setting (:obj:`dict`): ``{'sigma': value}`` in the continuous case, otherwise \
                ``{'eps': value}``.
        """
        if self._cfg.continuous:
            # Continuous case: decay is indexed by learner steps.
            setting_key, step = 'sigma', command_info['learner_step']
        else:
            # Discrete case: decay is indexed by environment steps.
            setting_key, step = 'eps', command_info['envstep']
        return {setting_key: self.epsilon_greedy(step)}

    def _get_setting_learn(self, command_info: dict) -> dict:
        r"""
        Overview:
            Learn mode needs no command setting; always returns an empty dict.
        """
        return {}

    def _get_setting_eval(self, command_info: dict) -> dict:
        r"""
        Overview:
            Eval mode needs no command setting; always returns an empty dict.
        """
        return {}


@POLICY_REGISTRY.register('prompt_pg_command')
class PromptPGCommandModePolicy(PromptPGPolicy, DummyCommandModePolicy):
    r"""
    Overview:
        Registered in ``POLICY_REGISTRY`` as ``'prompt_pg_command'``.
        Subclasses ``PromptPGPolicy`` and ``DummyCommandModePolicy`` (empty settings for every mode).
    """
    pass