File size: 13,003 Bytes
079c32c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 |
from typing import List, Union
import os
import copy
import click
from click.core import Context, Option
import numpy as np
from ding import __TITLE__, __VERSION__, __AUTHOR__, __AUTHOR_EMAIL__
from ding.config import read_config
from .predefined_config import get_predefined_config
def print_version(ctx: Context, param: Option, value: bool) -> None:
    """Eager ``click`` callback for ``-v/--version``: echo package title, version
    and author information, then terminate the CLI via ``ctx.exit()``.

    Does nothing when the flag is absent or the parser is in resilient mode
    (e.g. during shell completion).
    """
    if value and not ctx.resilient_parsing:
        click.echo('{title}, version {version}.'.format(title=__TITLE__, version=__VERSION__))
        click.echo('Developed by {author}, {email}.'.format(author=__AUTHOR__, email=__AUTHOR_EMAIL__))
        ctx.exit()
def print_registry(ctx: Context, param: Option, value: str):
    """Eager ``click`` callback for ``-q/--query-registry``: list every module
    registered under the registry named ``value`` (with its registration
    location), then terminate the CLI via ``ctx.exit()``.

    A ``None`` value means the option was not supplied, so normal CLI parsing
    continues untouched.
    """
    if value is None:
        return
    # Imported lazily so plain CLI invocations skip the registry machinery.
    from ding.utils import registries  # noqa
    if value in registries:
        registered_info = registries[value].query_details()
        click.echo('Available {}: [{}]'.format(value, '|'.join(registered_info.keys())))
        for alias, info in registered_info.items():
            click.echo('\t{}: registered at {}#{}'.format(alias, info[0], info[1]))
    else:
        click.echo('[ERROR]: not support registry name: {}'.format(value))
    ctx.exit()
# Shared click context settings: accept both -h and --help for the help text.
CONTEXT_SETTINGS = {'help_option_names': ['-h', '--help']}
@click.command(context_settings=CONTEXT_SETTINGS)
@click.option(
    '-v',
    '--version',
    is_flag=True,
    callback=print_version,
    expose_value=False,
    is_eager=True,
    help="Show package's version information."
)
@click.option(
    '-q',
    '--query-registry',
    type=str,
    callback=print_registry,
    expose_value=False,
    is_eager=True,
    help='query registered module or function, show name and path'
)
@click.option(
    '-m',
    '--mode',
    type=click.Choice(
        [
            'serial',
            'serial_onpolicy',
            'serial_sqil',
            'serial_dqfd',
            'serial_trex',
            'serial_trex_onpolicy',
            'parallel',
            'dist',
            'eval',
            'serial_reward_model',
            'serial_gail',
            'serial_offline',
            'serial_ngu',
        ]
    ),
    help='serial-train or parallel-train or dist-train or eval'
)
@click.option('-c', '--config', type=str, help='Path to DRL experiment config')
@click.option(
    '-s',
    '--seed',
    type=int,
    default=[0],
    multiple=True,
    help='random generator seed(for all the possible package: random, numpy, torch and user env)'
)
@click.option('-e', '--env', type=str, help='RL env name')
@click.option('-p', '--policy', type=str, help='DRL policy name')
@click.option('--exp-name', type=str, help='experiment directory name')
@click.option('--train-iter', type=str, default='1e8', help='Maximum policy update iterations in training')
@click.option('--env-step', type=str, default='1e8', help='Maximum collected environment steps for training')
@click.option('--load-path', type=str, default=None, help='Path to load ckpt')
@click.option('--replay-path', type=str, default=None, help='Path to save replay')
# the following arguments are only applied to dist mode
@click.option('--enable-total-log', type=bool, help='whether enable the total DI-engine system log', default=False)
@click.option('--disable-flask-log', type=bool, help='whether disable flask log', default=True)
@click.option(
    '-P', '--platform', type=click.Choice(['local', 'slurm', 'k8s']), help='local or slurm or k8s', default='local'
)
@click.option(
    '-M',
    '--module',
    type=click.Choice(['config', 'collector', 'learner', 'coordinator', 'learner_aggregator', 'spawn_learner']),
    help='dist module type'
)
@click.option('--module-name', type=str, help='dist module name')
@click.option('-cdh', '--coordinator-host', type=str, help='coordinator host', default='0.0.0.0')
@click.option('-cdp', '--coordinator-port', type=int, help='coordinator port')
@click.option('-lh', '--learner-host', type=str, help='learner host', default='0.0.0.0')
@click.option('-lp', '--learner-port', type=int, help='learner port')
@click.option('-clh', '--collector-host', type=str, help='collector host', default='0.0.0.0')
@click.option('-clp', '--collector-port', type=int, help='collector port')
@click.option('-agh', '--aggregator-host', type=str, help='aggregator slave host', default='0.0.0.0')
@click.option('-agp', '--aggregator-port', type=int, help='aggregator slave port')
@click.option('--add', type=click.Choice(['collector', 'learner']), help='add replicas type')
@click.option('--delete', type=click.Choice(['collector', 'learner']), help='delete replicas type')
@click.option('--restart', type=click.Choice(['collector', 'learner']), help='restart replicas type')
@click.option('--kubeconfig', type=str, default=None, help='the path of Kubernetes configuration file')
@click.option('-cdn', '--coordinator-name', type=str, default=None, help='coordinator name')
@click.option('-ns', '--namespace', type=str, default=None, help='job namespace')
@click.option('-rs', '--replicas', type=int, default=1, help='number of replicas to add/delete/restart')
@click.option('-rpn', '--restart-pod-name', type=str, default=None, help='restart pod name')
@click.option('--cpus', type=int, default=0, help='The requested CPU, read the value from DIJob yaml by default')
@click.option('--gpus', type=int, default=0, help='The requested GPU, read the value from DIJob yaml by default')
@click.option(
    '--memory', type=str, default=None, help='The requested Memory, read the value from DIJob yaml by default'
)
@click.option(
    '--profile',
    type=str,
    default=None,
    help='profile Time cost by cProfile, and save the files into the specified folder path'
)
def cli(
    # serial/eval
    mode: str,
    config: str,
    seed: Union[int, List],
    exp_name: str,
    env: str,
    policy: str,
    train_iter: str,  # transform into int
    env_step: str,  # transform into int
    load_path: str,
    replay_path: str,
    # parallel/dist
    platform: str,
    coordinator_host: str,
    coordinator_port: int,
    learner_host: str,
    learner_port: int,
    collector_host: str,
    collector_port: int,
    aggregator_host: str,
    aggregator_port: int,
    enable_total_log: bool,
    disable_flask_log: bool,
    module: str,
    module_name: str,
    # add/delete/restart
    add: str,
    delete: str,
    restart: str,
    kubeconfig: str,
    coordinator_name: str,
    namespace: str,
    replicas: int,
    cpus: int,
    gpus: int,
    memory: str,
    restart_pod_name: str,
    profile: str,
):
    """
    Main CLI entry point: dispatch to the training/evaluation pipeline selected
    by ``--mode``, once per requested ``--seed``.

    The selected pipeline module is imported lazily inside its branch, so only
    the dependencies of the chosen mode are loaded. With multiple seeds, each
    run executes inside its own ``seed{N}`` working directory (see bottom of
    the function). Raises ``RuntimeError`` when no mode is given and
    ``TypeError`` for an unexpected seed container type.
    """
    # Optionally wrap the whole run in cProfile; results are written to the
    # folder path given by --profile.
    if profile is not None:
        from ..utils.profiler_helper import Profiler
        profiler = Profiler()
        profiler.profile(profile)

    # CLI accepts scientific notation like "1e8", hence str -> float -> int.
    train_iter = int(float(train_iter))
    env_step = int(float(env_step))

    def run_single_pipeline(seed, config):
        # Run one pipeline for a single seed. Falls back to a predefined
        # config (looked up from --env/--policy) when no --config is given.
        if config is None:
            config = get_predefined_config(env, policy)
        else:
            config = read_config(config)
        # NOTE(review): config appears to be a (main_cfg, create_cfg) pair —
        # only the first element carries exp_name; confirm against read_config.
        if exp_name is not None:
            config[0].exp_name = exp_name

        if mode == 'serial':
            from .serial_entry import serial_pipeline
            serial_pipeline(config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_onpolicy':
            from .serial_entry_onpolicy import serial_pipeline_onpolicy
            serial_pipeline_onpolicy(config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_sqil':
            from .serial_entry_sqil import serial_pipeline_sqil
            # Imitation-style modes additionally prompt interactively for the
            # expert model's config path.
            expert_config = input("Enter the name of the config you used to generate your expert model: ")
            serial_pipeline_sqil(config, expert_config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_reward_model':
            from .serial_entry_reward_model_offpolicy import serial_pipeline_reward_model_offpolicy
            serial_pipeline_reward_model_offpolicy(config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_gail':
            from .serial_entry_gail import serial_pipeline_gail
            expert_config = input("Enter the name of the config you used to generate your expert model: ")
            serial_pipeline_gail(
                config, expert_config, seed, max_train_iter=train_iter, max_env_step=env_step, collect_data=True
            )
        elif mode == 'serial_dqfd':
            from .serial_entry_dqfd import serial_pipeline_dqfd
            expert_config = input("Enter the name of the config you used to generate your expert model: ")
            # DQFD constrains the expert config filename to be derived from the
            # DQFD config path itself ("<prefix>_dqfd_config.py").
            assert (expert_config == config[:config.find('_dqfd')] + '_dqfd_config.py'), "DQFD only supports "\
                + "the models used in q learning now; However, one should still type the DQFD config in this "\
                + "place, i.e., {}{}".format(config[:config.find('_dqfd')], '_dqfd_config.py')
            serial_pipeline_dqfd(config, expert_config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_trex':
            from .serial_entry_trex import serial_pipeline_trex
            serial_pipeline_trex(config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_trex_onpolicy':
            from .serial_entry_trex_onpolicy import serial_pipeline_trex_onpolicy
            serial_pipeline_trex_onpolicy(config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_offline':
            from .serial_entry_offline import serial_pipeline_offline
            # Offline modes do not take max_env_step (no environment collection).
            serial_pipeline_offline(config, seed, max_train_iter=train_iter)
        elif mode == 'serial_ngu':
            from .serial_entry_ngu import serial_pipeline_ngu
            serial_pipeline_ngu(config, seed, max_train_iter=train_iter)
        elif mode == 'parallel':
            from .parallel_entry import parallel_pipeline
            parallel_pipeline(config, seed, enable_total_log, disable_flask_log)
        elif mode == 'dist':
            # Dist mode sub-dispatches on --module (launch a component) or on
            # --add/--delete/--restart (mutate Kubernetes replicas).
            from .dist_entry import dist_launch_coordinator, dist_launch_collector, dist_launch_learner, \
                dist_prepare_config, dist_launch_learner_aggregator, dist_launch_spawn_learner, \
                dist_add_replicas, dist_delete_replicas, dist_restart_replicas
            if module == 'config':
                dist_prepare_config(
                    config, seed, platform, coordinator_host, learner_host, collector_host, coordinator_port,
                    learner_port, collector_port
                )
            elif module == 'coordinator':
                dist_launch_coordinator(config, seed, coordinator_port, disable_flask_log)
            elif module == 'learner_aggregator':
                dist_launch_learner_aggregator(
                    config, seed, aggregator_host, aggregator_port, module_name, disable_flask_log
                )
            elif module == 'collector':
                dist_launch_collector(config, seed, collector_port, module_name, disable_flask_log)
            elif module == 'learner':
                dist_launch_learner(config, seed, learner_port, module_name, disable_flask_log)
            elif module == 'spawn_learner':
                dist_launch_spawn_learner(config, seed, learner_port, module_name, disable_flask_log)
            elif add in ['collector', 'learner']:
                dist_add_replicas(add, kubeconfig, replicas, coordinator_name, namespace, cpus, gpus, memory)
            elif delete in ['collector', 'learner']:
                dist_delete_replicas(delete, kubeconfig, replicas, coordinator_name, namespace)
            elif restart in ['collector', 'learner']:
                dist_restart_replicas(restart, kubeconfig, coordinator_name, namespace, restart_pod_name)
            else:
                # No recognized dist sub-command was supplied.
                raise Exception
        elif mode == 'eval':
            from .application_entry import eval
            eval(config, seed, load_path=load_path, replay_path=replay_path)

    if mode is None:
        raise RuntimeError("Please indicate at least one argument.")
    # click's multiple=True yields a tuple of seeds (default is a single 0).
    if isinstance(seed, (list, tuple)):
        assert len(seed) > 0, "Please input at least 1 seed"
        if len(seed) == 1:  # necessary
            run_single_pipeline(seed[0], config)
        else:
            # Multi-seed run: create one sub-directory per seed under a shared
            # root, chdir into it for the run, then chdir back.
            if exp_name is None:
                multi_exp_root = os.path.basename(config).split('.')[0] + '_result'
            else:
                multi_exp_root = exp_name
            if not os.path.exists(multi_exp_root):
                os.makedirs(multi_exp_root)
            # Resolve the config path before chdir so relative paths still work.
            abs_config_path = os.path.abspath(config)
            origin_root = os.getcwd()
            for s in seed:
                seed_exp_root = os.path.join(multi_exp_root, 'seed{}'.format(s))
                if not os.path.exists(seed_exp_root):
                    os.makedirs(seed_exp_root)
                os.chdir(seed_exp_root)
                run_single_pipeline(s, abs_config_path)
                os.chdir(origin_root)
    else:
        raise TypeError("invalid seed type: {}".format(type(seed)))
|