from typing import List, Union
import os
import copy
import click
from click.core import Context, Option
import numpy as np
from ding import __TITLE__, __VERSION__, __AUTHOR__, __AUTHOR_EMAIL__
from ding.config import read_config
from .predefined_config import get_predefined_config
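
# Click eager callback: print the package title, version and author information, then exit.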
def print_version(ctx: Context, param: Option, value: bool) -> None:
    if not value or ctx.resilient_parsing:
        return
    click.echo('{title}, version {version}.'.format(title=__TITLE__, version=__VERSION__))
    click.echo('Developed by {author}, {email}.'.format(author=__AUTHOR__, email=__AUTHOR_EMAIL__))
    ctx.exit()
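
# Click eager callback: list the modules/functions registered under the given registry name, then exit.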
def print_registry(ctx: Context, param: Option, value: str):
    if value is None:
        return
    from ding.utils import registries  # noqa
    if value not in registries:
        click.echo('[ERROR]: unsupported registry name: {}'.format(value))
    else:
        registered_info = registries[value].query_details()
        click.echo('Available {}: [{}]'.format(value, '|'.join(registered_info.keys())))
        for alias, info in registered_info.items():
            click.echo('\t{}: registered at {}#{}'.format(alias, info[0], info[1]))
    ctx.exit()
CONTEXT_SETTINGS = dict(help_option_names=['-h', '--help'])
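
# Main DI-engine CLI entry. The options below cover serial/parallel/dist training, evaluation,
# and replica management (add/delete/restart) for Kubernetes-based deployments.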
@click.command(context_settings=CONTEXT_SETTINGS)
@click.option(
    '-v',
    '--version',
    is_flag=True,
    callback=print_version,
    expose_value=False,
    is_eager=True,
    help="Show package's version information."
)
@click.option(
    '-q',
    '--query-registry',
    type=str,
    callback=print_registry,
    expose_value=False,
    is_eager=True,
    help='Query registered modules or functions, show name and path.'
)
@click.option(
    '-m',
    '--mode',
    type=click.Choice(
        [
            'serial',
            'serial_onpolicy',
            'serial_sqil',
            'serial_dqfd',
            'serial_trex',
            'serial_trex_onpolicy',
            'parallel',
            'dist',
            'eval',
            'serial_reward_model',
            'serial_gail',
            'serial_offline',
            'serial_ngu',
        ]
    ),
    help='serial-train or parallel-train or dist-train or eval'
)
@click.option('-c', '--config', type=str, help='Path to DRL experiment config')
@click.option(
    '-s',
    '--seed',
    type=int,
    default=[0],
    multiple=True,
    help='Random generator seed (for all possible packages: random, numpy, torch and user env)'
)
@click.option('-e', '--env', type=str, help='RL env name')
@click.option('-p', '--policy', type=str, help='DRL policy name')
@click.option('--exp-name', type=str, help='experiment directory name')
@click.option('--train-iter', type=str, default='1e8', help='Maximum policy update iterations in training')
@click.option('--env-step', type=str, default='1e8', help='Maximum collected environment steps for training')
@click.option('--load-path', type=str, default=None, help='Path to load ckpt')
@click.option('--replay-path', type=str, default=None, help='Path to save replay')
# the following arguments only apply to dist mode
@click.option('--enable-total-log', type=bool, help='Whether to enable the full DI-engine system log', default=False)
@click.option('--disable-flask-log', type=bool, help='Whether to disable the Flask log', default=True)
@click.option(
    '-P', '--platform', type=click.Choice(['local', 'slurm', 'k8s']), help='local or slurm or k8s', default='local'
)
@click.option(
    '-M',
    '--module',
    type=click.Choice(['config', 'collector', 'learner', 'coordinator', 'learner_aggregator', 'spawn_learner']),
    help='dist module type'
)
@click.option('--module-name', type=str, help='dist module name')
@click.option('-cdh', '--coordinator-host', type=str, help='coordinator host', default='0.0.0.0')
@click.option('-cdp', '--coordinator-port', type=int, help='coordinator port')
@click.option('-lh', '--learner-host', type=str, help='learner host', default='0.0.0.0')
@click.option('-lp', '--learner-port', type=int, help='learner port')
@click.option('-clh', '--collector-host', type=str, help='collector host', default='0.0.0.0')
@click.option('-clp', '--collector-port', type=int, help='collector port')
@click.option('-agh', '--aggregator-host', type=str, help='aggregator slave host', default='0.0.0.0')
@click.option('-agp', '--aggregator-port', type=int, help='aggregator slave port')
@click.option('--add', type=click.Choice(['collector', 'learner']), help='add replicas type')
@click.option('--delete', type=click.Choice(['collector', 'learner']), help='delete replicas type')
@click.option('--restart', type=click.Choice(['collector', 'learner']), help='restart replicas type')
@click.option('--kubeconfig', type=str, default=None, help='the path of Kubernetes configuration file')
@click.option('-cdn', '--coordinator-name', type=str, default=None, help='coordinator name')
@click.option('-ns', '--namespace', type=str, default=None, help='job namespace')
@click.option('-rs', '--replicas', type=int, default=1, help='number of replicas to add/delete/restart')
@click.option('-rpn', '--restart-pod-name', type=str, default=None, help='restart pod name')
@click.option('--cpus', type=int, default=0, help='The requested CPU, read the value from DIJob yaml by default')
@click.option('--gpus', type=int, default=0, help='The requested GPU, read the value from DIJob yaml by default')
@click.option(
    '--memory', type=str, default=None, help='The requested memory, read the value from DIJob yaml by default'
)
@click.option(
    '--profile',
    type=str,
    default=None,
    help='Profile time cost with cProfile and save the result files into the specified folder path'
)
def cli(
    # serial/eval
    mode: str,
    config: str,
    seed: Union[int, List],
    exp_name: str,
    env: str,
    policy: str,
    train_iter: str,  # transform into int
    env_step: str,  # transform into int
    load_path: str,
    replay_path: str,
    # parallel/dist
    platform: str,
    coordinator_host: str,
    coordinator_port: int,
    learner_host: str,
    learner_port: int,
    collector_host: str,
    collector_port: int,
    aggregator_host: str,
    aggregator_port: int,
    enable_total_log: bool,
    disable_flask_log: bool,
    module: str,
    module_name: str,
    # add/delete/restart
    add: str,
    delete: str,
    restart: str,
    kubeconfig: str,
    coordinator_name: str,
    namespace: str,
    replicas: int,
    cpus: int,
    gpus: int,
    memory: str,
    restart_pod_name: str,
    profile: str,
):
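    # Optionally profile the whole run with cProfile, saving the result files to the given folder.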
    if profile is not None:
        from ..utils.profiler_helper import Profiler
        profiler = Profiler()
        profiler.profile(profile)

    train_iter = int(float(train_iter))
    env_step = int(float(env_step))
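
    # Run one complete pipeline for a single (seed, config) pair, dispatching on the selected mode.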
    def run_single_pipeline(seed, config):
        if config is None:
            config = get_predefined_config(env, policy)
        else:
            config = read_config(config)
        if exp_name is not None:
            config[0].exp_name = exp_name
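
        # Each mode imports its pipeline entry lazily inside the corresponding branch.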
        if mode == 'serial':
            from .serial_entry import serial_pipeline
            serial_pipeline(config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_onpolicy':
            from .serial_entry_onpolicy import serial_pipeline_onpolicy
            serial_pipeline_onpolicy(config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_sqil':
            from .serial_entry_sqil import serial_pipeline_sqil
            expert_config = input("Enter the name of the config you used to generate your expert model: ")
            serial_pipeline_sqil(config, expert_config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_reward_model':
            from .serial_entry_reward_model_offpolicy import serial_pipeline_reward_model_offpolicy
            serial_pipeline_reward_model_offpolicy(config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_gail':
            from .serial_entry_gail import serial_pipeline_gail
            expert_config = input("Enter the name of the config you used to generate your expert model: ")
            serial_pipeline_gail(
                config, expert_config, seed, max_train_iter=train_iter, max_env_step=env_step, collect_data=True
            )
        elif mode == 'serial_dqfd':
            from .serial_entry_dqfd import serial_pipeline_dqfd
            expert_config = input("Enter the name of the config you used to generate your expert model: ")
            assert (expert_config == config[:config.find('_dqfd')] + '_dqfd_config.py'), "DQFD only supports " \
                + "Q-learning models for now; however, the DQFD config should still be given " \
                + "here, i.e., {}{}".format(config[:config.find('_dqfd')], '_dqfd_config.py')
            serial_pipeline_dqfd(config, expert_config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_trex':
            from .serial_entry_trex import serial_pipeline_trex
            serial_pipeline_trex(config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_trex_onpolicy':
            from .serial_entry_trex_onpolicy import serial_pipeline_trex_onpolicy
            serial_pipeline_trex_onpolicy(config, seed, max_train_iter=train_iter, max_env_step=env_step)
        elif mode == 'serial_offline':
            from .serial_entry_offline import serial_pipeline_offline
            serial_pipeline_offline(config, seed, max_train_iter=train_iter)
        elif mode == 'serial_ngu':
            from .serial_entry_ngu import serial_pipeline_ngu
            serial_pipeline_ngu(config, seed, max_train_iter=train_iter)
        elif mode == 'parallel':
            from .parallel_entry import parallel_pipeline
            parallel_pipeline(config, seed, enable_total_log, disable_flask_log)
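        # dist mode: prepare the distributed config, launch an individual component (coordinator,
        # collector, learner, learner aggregator, spawn learner), or add/delete/restart replicas.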
        elif mode == 'dist':
            from .dist_entry import dist_launch_coordinator, dist_launch_collector, dist_launch_learner, \
                dist_prepare_config, dist_launch_learner_aggregator, dist_launch_spawn_learner, \
                dist_add_replicas, dist_delete_replicas, dist_restart_replicas
            if module == 'config':
                dist_prepare_config(
                    config, seed, platform, coordinator_host, learner_host, collector_host, coordinator_port,
                    learner_port, collector_port
                )
            elif module == 'coordinator':
                dist_launch_coordinator(config, seed, coordinator_port, disable_flask_log)
            elif module == 'learner_aggregator':
                dist_launch_learner_aggregator(
                    config, seed, aggregator_host, aggregator_port, module_name, disable_flask_log
                )
            elif module == 'collector':
                dist_launch_collector(config, seed, collector_port, module_name, disable_flask_log)
            elif module == 'learner':
                dist_launch_learner(config, seed, learner_port, module_name, disable_flask_log)
            elif module == 'spawn_learner':
                dist_launch_spawn_learner(config, seed, learner_port, module_name, disable_flask_log)
            elif add in ['collector', 'learner']:
                dist_add_replicas(add, kubeconfig, replicas, coordinator_name, namespace, cpus, gpus, memory)
            elif delete in ['collector', 'learner']:
                dist_delete_replicas(delete, kubeconfig, replicas, coordinator_name, namespace)
            elif restart in ['collector', 'learner']:
                dist_restart_replicas(restart, kubeconfig, coordinator_name, namespace, restart_pod_name)
            else:
                raise Exception
        elif mode == 'eval':
            from .application_entry import eval
            eval(config, seed, load_path=load_path, replay_path=replay_path)
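
    # Validate the arguments, then run once per seed: a single seed runs in the current directory,
    # while multiple seeds each run in their own sub-directory under a shared experiment root.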
    if mode is None:
        raise RuntimeError("Please specify the running mode with -m/--mode.")

    if isinstance(seed, (list, tuple)):
        assert len(seed) > 0, "Please input at least 1 seed"
        if len(seed) == 1:  # a single seed runs directly, without a per-seed sub-directory
            run_single_pipeline(seed[0], config)
        else:
            if exp_name is None:
                multi_exp_root = os.path.basename(config).split('.')[0] + '_result'
            else:
                multi_exp_root = exp_name
            if not os.path.exists(multi_exp_root):
                os.makedirs(multi_exp_root)
            abs_config_path = os.path.abspath(config)
            origin_root = os.getcwd()
            for s in seed:
                seed_exp_root = os.path.join(multi_exp_root, 'seed{}'.format(s))
                if not os.path.exists(seed_exp_root):
                    os.makedirs(seed_exp_root)
                os.chdir(seed_exp_root)
                run_single_pipeline(s, abs_config_path)
                os.chdir(origin_root)
    else:
        raise TypeError("invalid seed type: {}".format(type(seed)))
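
# Example invocations (assuming the package's console script is installed as `ding`, which maps to this
# `cli` function; the config and checkpoint file names below are illustrative):
#   ding -m serial -c cartpole_dqn_config.py -s 0
#   ding -m serial -c cartpole_dqn_config.py -s 0 -s 1 -s 2   # multi-seed, one sub-directory per seed
#   ding -m eval -c cartpole_dqn_config.py -s 0 --load-path ./ckpt_best.pth.tar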