File size: 1,156 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
from typing import Tuple
from easydict import EasyDict
import sys
import importlib

# Environment short name -> package that holds its pre-defined config modules.
# get_predefined_config imports "<package>.<env>_<policy>_config" from here.
env_dict = dict(
    cartpole='dizoo.classic_control.cartpole.config',
    pendulum='dizoo.classic_control.pendulum.config',
)

# Policy short name -> dotted module path; every entry sits directly under
# the ``ding.policy`` package. (Not referenced by get_predefined_config.)
_POLICY_NAMES = (
    'dqn',
    'rainbow',
    'c51',
    'qrdqn',
    'iqn',
    'a2c',
    'impala',
    'ppo',
    'sqn',
    'r2d2',
    'ddpg',
    'td3',
    'sac',
)
policy_dict = {name: 'ding.policy.' + name for name in _POLICY_NAMES}


def get_predefined_config(env: str, policy: str) -> Tuple[EasyDict, EasyDict]:
    """Load the pre-defined (main, create) config pair for an env/policy combo.

    The config module ``<env_dict[env]>.<env>_<policy>_config`` is imported
    (which executes its top-level code), and the two module attributes
    ``<env>_<policy>_config`` and ``<env>_<policy>_create_config`` are returned.

    Args:
        env: Environment short name; must be a key of ``env_dict``.
        policy: Policy short name used to build the config module/attribute names.

    Returns:
        A ``(main_config, create_config)`` tuple.

    Raises:
        SystemExit: With code 1 (after printing a hint) when ``env`` is unknown,
            the config module cannot be imported, or it lacks either attribute.
    """
    config_name = '{}_{}_config'.format(env, policy)
    create_config_name = '{}_{}_create_config'.format(env, policy)
    try:
        # KeyError: env not registered in env_dict.
        # ImportError: no such config module (or it failed to import).
        # AttributeError: module exists but does not define the expected names.
        module = importlib.import_module(env_dict[env] + '.' + config_name)
        main_config = getattr(module, config_name)
        create_config = getattr(module, create_config_name)
    except (KeyError, ImportError, AttributeError):
        print("Please get started by other types, there is no related pre-defined config({})".format(config_name))
        sys.exit(1)
    # Return a tuple to match the annotated return type; unpacking and
    # indexing behave the same as with the previous list.
    return main_config, create_config