File size: 5,475 Bytes
079c32c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import click
import os
import sys
import importlib
import importlib.util
import json
from click.core import Context, Option

from ding import __TITLE__, __VERSION__, __AUTHOR__, __AUTHOR_EMAIL__
from ding.framework import Parallel
from ding.entry.cli_parsers import PLATFORM_PARSERS


def print_version(ctx: Context, param: Option, value: bool) -> None:
    """Eager click callback backing the ``-v/--version`` flag.

    When the flag is set, and click is not in resilient-parsing mode, it
    echoes the package title and version followed by the author and e-mail.
    It then terminates via ``ctx.exit()``. Otherwise it does nothing.
    """
    if value and not ctx.resilient_parsing:
        click.echo('{title}, version {version}.'.format(version=__VERSION__, title=__TITLE__))
        click.echo('Developed by {author}, {email}.'.format(email=__AUTHOR_EMAIL__, author=__AUTHOR__))
        ctx.exit()


# Shared click context settings: accept both -h and --help for help output.
CONTEXT_SETTINGS = {'help_option_names': ['-h', '--help']}


# Click command definition for the task launcher.
# `-v/--version` is handled eagerly by `print_version` and is not passed on
# (expose_value=False). Every other option below is forwarded as a keyword
# argument to `_cli_ditask`. Options without a default arrive as None.
# NOTE: intentionally no docstring -- click would use it as the command's --help text.
@click.command(context_settings=CONTEXT_SETTINGS)
@click.option(
    '-v',
    '--version',
    is_flag=True,
    callback=print_version,
    expose_value=False,
    is_eager=True,
    help="Show package's version information."
)
@click.option('-p', '--package', type=str, help="Your code package path, could be a directory or a zip file.")
@click.option('--parallel-workers', type=int, default=1, help="Parallel worker number, default: 1")
@click.option(
    '--protocol',
    type=click.Choice(["tcp", "ipc"]),
    default="tcp",
    help="Network protocol in parallel mode, default: tcp"
)
@click.option(
    "--ports",
    type=str,
    help="The port addresses that the tasks listen to, e.g. 50515,50516, default: k8s, local: 50515, slurm: 15151"
)
@click.option("--attach-to", type=str, help="The addresses to connect to.")
@click.option("--address", type=str, help="The address to listen to (without port).")
@click.option("--labels", type=str, help="Labels.")
@click.option("--node-ids", type=str, help="Candidate node ids.")
@click.option(
    "--topology",
    type=click.Choice(["alone", "mesh", "star"]),
    default="alone",
    help="Network topology, default: alone."
)
@click.option("--platform-spec", type=str, help="Platform specific configure.")
@click.option("--platform", type=str, help="Platform type: slurm, k8s.")
@click.option("--mq-type", type=str, default="nng", help="Class type of message queue, i.e. nng, redis.")
@click.option("--redis-host", type=str, help="Redis host.")
@click.option("--redis-port", type=int, help="Redis port.")
@click.option("-m", "--main", type=str, help="Main function of entry module.")
@click.option("--startup-interval", type=int, default=1, help="Start up interval between each task.")
# Accepted (underscore spelling kept) for compatibility with PyTorch DDP launchers;
# `_cli_ditask` receives it but does not use it.
@click.option("--local_rank", type=int, default=0, help="Compatibility with PyTorch DDP")
def cli_ditask(*args, **kwargs):
    # Thin wrapper: all parsing and launch logic lives in `_cli_ditask`. That
    # function re-enters itself with platform-derived arguments when
    # `--platform` is set.
    return _cli_ditask(*args, **kwargs)


def _parse_platform_args(platform: str, platform_spec: str, all_args: dict):
    """Translate platform-specific settings into ``_cli_ditask`` keyword arguments.

    Arguments:
        - platform: Key into ``PLATFORM_PARSERS`` selecting the platform parser.
        - platform_spec: Optional platform configuration. It is either a path to a
          ``.json`` file or an inline JSON string. A falsy value is passed to the
          parser unchanged.
        - all_args: The caller's keyword arguments. ``platform`` and
          ``platform_spec`` are removed from it *in place* before the remaining
          entries are forwarded to the parser.

    Returns:
        Whatever the selected platform parser returns. The caller uses it as
        keyword arguments for ``_cli_ditask``.

    Exits the process with status 1 in two cases: the spec file cannot be read
    or the spec is not valid JSON, or ``platform`` is not a known parser key.
    Any exception raised by the platform parser itself is reported and re-raised.
    """
    if platform_spec:
        # splitext returns a (root, ext) tuple and ext keeps its leading dot, so
        # compare the extension element rather than the whole tuple.
        is_json_file = os.path.splitext(platform_spec)[1].lower() == ".json"
        try:
            if is_json_file:
                # JSON text is UTF-8 (RFC 8259); don't depend on the locale encoding.
                with open(platform_spec, encoding="utf-8") as f:
                    platform_spec = json.load(f)
            else:
                platform_spec = json.loads(platform_spec)
        except OSError as e:
            click.echo("cannot read platform_spec file: {}".format(e))
            sys.exit(1)
        except ValueError:
            # Covers json.JSONDecodeError and UnicodeDecodeError (both ValueError subclasses).
            click.echo("platform_spec is not a valid json!")
            sys.exit(1)
    if platform not in PLATFORM_PARSERS:
        click.echo("platform type is invalid! type: {}".format(platform))
        sys.exit(1)
    all_args.pop("platform")
    all_args.pop("platform_spec")
    try:
        parsed_args = PLATFORM_PARSERS[platform](platform_spec, **all_args)
    except Exception as e:
        click.echo("error when parse platform spec configure: {}".format(e))
        raise

    return parsed_args


def _import_main_func(package: str, main: str):
    """Import the user's entry module found under ``package`` and return its main function.

    Side effects: ``package`` and ``<package>/<root module name>`` are appended to ``sys.path``.
    When ``main`` is None, the module name is the package's base name without extension,
    and the function name is ``"main"``. Otherwise ``main`` is split at its last dot into
    ``<module path>`` and ``<function name>``.
    """
    sys.path.append(package)
    if main is None:
        # abspath normalises the path first. Without it, "./my_pkg/" has an empty
        # basename and "." has basename "."; both make import_module fail.
        mod_name = os.path.splitext(os.path.basename(os.path.abspath(package)))[0]
        func_name = "main"
    else:
        mod_name, func_name = main.rsplit(".", 1)
    root_mod_name = mod_name.split(".", 1)[0]
    sys.path.append(os.path.join(package, root_mod_name))
    mod = importlib.import_module(mod_name)
    return getattr(mod, func_name)


def _cli_ditask(
    package: str,
    main: str,
    parallel_workers: int,
    protocol: str,
    ports: str,
    attach_to: str,
    address: str,
    labels: str,
    node_ids: str,
    topology: str,
    mq_type: str,
    redis_host: str,
    redis_port: int,
    startup_interval: int,
    local_rank: int = 0,
    platform: str = None,
    platform_spec: str = None,
):
    """Resolve the user's main function and hand it to ``Parallel.runner``.

    If ``platform`` is set, the arguments are first rewritten by
    ``_parse_platform_args``, and this function is re-entered with the result.
    Otherwise, the steps are:

    - ``package`` defaults to the current working directory.
    - ``ports`` defaults to 50515. A comma-separated string becomes an int, or a
      list of ints when several are given.
    - ``attach_to`` becomes a list of stripped strings.
    - ``labels`` becomes a set of stripped strings.
    - ``node_ids`` becomes a list of ints.

    All of these are passed to ``Parallel.runner``, and the returned object is
    called with the main function. ``local_rank`` is accepted for CLI
    compatibility and is not used here.
    """
    # Snapshot the parameters before any other local is bound, so the platform
    # parser receives exactly this call's arguments.
    all_args = locals()
    if platform:
        parsed_args = _parse_platform_args(platform, platform_spec, all_args)
        return _cli_ditask(**parsed_args)

    if not package:
        package = os.getcwd()
    main_func = _import_main_func(package, main)

    # Parse arguments
    ports = ports or 50515
    if not isinstance(ports, int):
        ports = [int(port) for port in ports.split(",")]
        if len(ports) == 1:
            ports = ports[0]
    if attach_to:
        attach_to = [addr.strip() for addr in attach_to.split(",")]
    if labels:
        labels = {label.strip() for label in labels.split(",")}
    if node_ids and not isinstance(node_ids, int):
        node_ids = [int(node_id) for node_id in node_ids.split(",")]
    Parallel.runner(
        n_parallel_workers=parallel_workers,
        ports=ports,
        protocol=protocol,
        topology=topology,
        attach_to=attach_to,
        address=address,
        labels=labels,
        node_ids=node_ids,
        mq_type=mq_type,
        redis_host=redis_host,
        redis_port=redis_port,
        startup_interval=startup_interval
    )(main_func)