gomoku / DI-engine /ding /entry /cli_ditask.py
zjowowen's picture
init space
079c32c
raw
history blame
5.48 kB
import click
import os
import sys
import importlib
import importlib.util
import json
from click.core import Context, Option
from ding import __TITLE__, __VERSION__, __AUTHOR__, __AUTHOR_EMAIL__
from ding.framework import Parallel
from ding.entry.cli_parsers import PLATFORM_PARSERS
def print_version(ctx: Context, param: Option, value: bool) -> None:
    """Eager click callback that prints version and author info, then exits.

    Args:
        ctx: Click context of the running command; its ``exit`` method is
            called once the information has been printed.
        param: The option that triggered the callback (unused).
        value: Value of the flag; nothing happens when it is falsy.

    Returns:
        None. When ``value`` is falsy or click is in resilient parsing mode
        the function returns without printing anything.
    """
    should_print = value and not ctx.resilient_parsing
    if should_print:
        version_line = '{title}, version {version}.'.format(title=__TITLE__, version=__VERSION__)
        click.echo(version_line)
        author_line = 'Developed by {author}, {email}.'.format(author=__AUTHOR__, email=__AUTHOR_EMAIL__)
        click.echo(author_line)
        ctx.exit()
# Shared click context settings: expose the help page under both -h and --help.
CONTEXT_SETTINGS = {'help_option_names': ['-h', '--help']}
@click.command(context_settings=CONTEXT_SETTINGS)
# --- Meta ---
@click.option(
    '-v',
    '--version',
    is_flag=True,
    is_eager=True,
    expose_value=False,
    callback=print_version,
    help="Show package's version information.",
)
# --- Entry code location and worker count ---
@click.option(
    '-p',
    '--package',
    type=str,
    help="Your code package path, could be a directory or a zip file.",
)
@click.option(
    '--parallel-workers',
    type=int,
    default=1,
    help="Parallel worker number, default: 1",
)
# --- Networking ---
@click.option(
    '--protocol',
    type=click.Choice(["tcp", "ipc"]),
    default="tcp",
    help="Network protocol in parallel mode, default: tcp",
)
@click.option(
    "--ports",
    type=str,
    help="The port addresses that the tasks listen to, e.g. 50515,50516, default: k8s, local: 50515, slurm: 15151",
)
@click.option(
    "--attach-to",
    type=str,
    help="The addresses to connect to.",
)
@click.option(
    "--address",
    type=str,
    help="The address to listen to (without port).",
)
@click.option(
    "--labels",
    type=str,
    help="Labels.",
)
@click.option(
    "--node-ids",
    type=str,
    help="Candidate node ids.",
)
@click.option(
    "--topology",
    type=click.Choice(["alone", "mesh", "star"]),
    default="alone",
    help="Network topology, default: alone.",
)
# --- Platform integration ---
@click.option(
    "--platform-spec",
    type=str,
    help="Platform specific configure.",
)
@click.option(
    "--platform",
    type=str,
    help="Platform type: slurm, k8s.",
)
# --- Message queue ---
@click.option(
    "--mq-type",
    type=str,
    default="nng",
    help="Class type of message queue, i.e. nng, redis.",
)
@click.option(
    "--redis-host",
    type=str,
    help="Redis host.",
)
@click.option(
    "--redis-port",
    type=int,
    help="Redis port.",
)
# --- Entry function and start-up behaviour ---
@click.option(
    "-m",
    "--main",
    type=str,
    help="Main function of entry module.",
)
@click.option(
    "--startup-interval",
    type=int,
    default=1,
    help="Start up interval between each task.",
)
@click.option(
    "--local_rank",
    type=int,
    default=0,
    help="Compatibility with PyTorch DDP",
)
def cli_ditask(*args, **kwargs):
    # Command line entry point: every parsed option is forwarded unchanged to
    # _cli_ditask, which holds the actual logic.
    # NOTE: documented with comments instead of a docstring on purpose; click
    # would display a docstring as the command's help text.
    return _cli_ditask(*args, **kwargs)
def _parse_platform_args(platform: str, platform_spec: str, all_args: dict):
    """Convert platform specific settings into arguments for ``_cli_ditask``.

    Args:
        platform: Key into ``PLATFORM_PARSERS`` selecting the parser to use.
        platform_spec: Either a path ending in ``.json`` whose content is
            loaded as JSON, or a JSON document given inline as a string.
            A falsy value is handed to the parser untouched.
        all_args: Command line arguments collected by the caller. The keys
            ``"platform"`` and ``"platform_spec"`` are removed from this dict
            in place; the remaining items are passed to the parser as keyword
            arguments.

    Returns:
        Whatever the selected platform parser returns.

    Raises:
        SystemExit: With status 1 when ``platform_spec`` cannot be read or
            decoded as JSON, or when ``platform`` is not a known parser key.
        Exception: Any error raised by the platform parser is echoed and then
            re-raised unchanged.
    """
    if platform_spec:
        try:
            # splitext returns a (root, ext) tuple and ext keeps its leading
            # dot, so the extension has to be compared against ".json".
            if os.path.splitext(platform_spec)[1] == ".json":
                with open(platform_spec) as f:
                    platform_spec = json.load(f)
            else:
                platform_spec = json.loads(platform_spec)
        except (OSError, TypeError, ValueError):
            # OSError: unreadable file; ValueError: malformed JSON (includes
            # json.JSONDecodeError); TypeError: spec is not str/bytes.
            click.echo("platform_spec is not a valid json!")
            sys.exit(1)
    if platform not in PLATFORM_PARSERS:
        click.echo("platform type is invalid! type: {}".format(platform))
        sys.exit(1)
    # The parser receives the spec positionally, so drop the raw values from
    # the keyword arguments.
    all_args.pop("platform")
    all_args.pop("platform_spec")
    try:
        parsed_args = PLATFORM_PARSERS[platform](platform_spec, **all_args)
    except Exception as e:
        click.echo("error when parse platform spec configure: {}".format(e))
        raise

    return parsed_args
def _cli_ditask(
    package: str,
    main: str,
    parallel_workers: int,
    protocol: str,
    ports: str,
    attach_to: str,
    address: str,
    labels: str,
    node_ids: str,
    topology: str,
    mq_type: str,
    redis_host: str,
    redis_port: int,
    startup_interval: int,
    local_rank: int = 0,
    platform: str = None,
    platform_spec: str = None,
):
    """Import the user's entry function and launch it through ``Parallel.runner``.

    When ``platform`` is set, all arguments are first rewritten by
    ``_parse_platform_args`` and this function calls itself again with the
    result.

    Args:
        package: Path added to ``sys.path``; falls back to the current working
            directory when falsy.
        main: Dotted ``<module>.<function>`` path of the entry function. When
            ``None``, the module is named after ``package`` (extension
            stripped) and the function is ``main``.
        parallel_workers: Forwarded as ``n_parallel_workers``.
        protocol: Forwarded unchanged.
        ports: Comma separated port numbers, or an int. Defaults to 50515 when
            falsy. A single port is forwarded as an int, several as a list.
        attach_to: Comma separated addresses, forwarded as a list of stripped
            strings.
        address: Forwarded unchanged.
        labels: Comma separated labels, forwarded as a set of stripped strings.
        node_ids: Comma separated integers (forwarded as a list of ints), or an
            int that is forwarded as is.
        topology: Forwarded unchanged.
        mq_type: Forwarded unchanged.
        redis_host: Forwarded unchanged.
        redis_port: Forwarded unchanged.
        startup_interval: Forwarded unchanged.
        local_rank: Accepted but not used by this function.
        platform: Platform name; triggers platform specific argument parsing.
        platform_spec: Platform specific configuration handed to
            ``_parse_platform_args``.

    Returns:
        The result of the recursive call when ``platform`` is set, otherwise
        ``None``.

    Raises:
        ValueError: If ``main`` is given without a ``.`` separating module and
            function name.
    """
    # Snapshot of the call arguments. This must stay the first statement so
    # that the dict contains the parameters only.
    all_args = locals()
    if platform:
        parsed_args = _parse_platform_args(platform, platform_spec, all_args)
        return _cli_ditask(**parsed_args)

    # Parse entry point
    if not package:
        package = os.getcwd()
    sys.path.append(package)
    if main is None:
        # Take the basename of the absolute path: a plain basename() yields ''
        # for "pkg/" and '.' for ".", neither of which is an importable name.
        mod_name = os.path.basename(os.path.abspath(package))
        mod_name, _ = os.path.splitext(mod_name)
        func_name = "main"
    else:
        if "." not in main:
            raise ValueError("main should be in the form of <module>.<function>, got: {}".format(main))
        mod_name, func_name = main.rsplit(".", 1)
        root_mod_name = mod_name.split(".", 1)[0]
        sys.path.append(os.path.join(package, root_mod_name))
    mod = importlib.import_module(mod_name)
    main_func = getattr(mod, func_name)

    # Parse arguments
    ports = ports or 50515
    if not isinstance(ports, int):
        ports = [int(port) for port in ports.split(",")]
        # A single port is passed on as a bare int rather than a one-item list.
        ports = ports[0] if len(ports) == 1 else ports
    if attach_to:
        attach_to = [addr.strip() for addr in attach_to.split(",")]
    if labels:
        labels = {label.strip() for label in labels.split(",")}
    if node_ids and not isinstance(node_ids, int):
        node_ids = [int(node_id) for node_id in node_ids.split(",")]

    Parallel.runner(
        n_parallel_workers=parallel_workers,
        ports=ports,
        protocol=protocol,
        topology=topology,
        attach_to=attach_to,
        address=address,
        labels=labels,
        node_ids=node_ids,
        mq_type=mq_type,
        redis_host=redis_host,
        redis_port=redis_port,
        startup_interval=startup_interval
    )(main_func)