File size: 4,038 Bytes
2d9a728
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import argparse
import os
import socket

from utils import has_slurm, random_port, runcmd

EXP_DIR_ENV_NAME = "VL_EXP_DIR"

# if key in hostname; apply the args in value to slurm.
DEFAULT_SLURM_ARGS = dict(login="-p gpu --mem=240GB -c 64 -t 2-00:00:00")


def get_default_slurm_args(hostname=None):
    """Get the default slurm args for the current (or given) cluster.

    Matches each key of ``DEFAULT_SLURM_ARGS`` as a substring of the
    hostname and returns the first matching value.

    Args:
        hostname (str, optional): hostname to match against. Defaults to
            ``socket.gethostname()``, so existing callers are unaffected.

    Returns:
        str: the slurm args for the first matching cluster entry, or an
        empty string when no entry matches.
    """
    if hostname is None:
        hostname = socket.gethostname()
    for key, slurm_args in DEFAULT_SLURM_ARGS.items():
        if key in hostname:
            return slurm_args
    return ""


def parse_args(argv=None):
    """Parse the command-line arguments for job submission.

    Args:
        argv (list[str], optional): argument list to parse instead of
            ``sys.argv[1:]``. Defaults to None (use the process args),
            so existing callers are unaffected.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser()

    # slurm-related options
    parser.add_argument("--slurm_args", type=str, default="", help="args for slurm.")
    parser.add_argument(
        "--no_slurm",
        action="store_true",
        help="If specified, will launch job without slurm",
    )
    parser.add_argument("--jobname", type=str, required=True, help="experiment name")
    parser.add_argument(
        "--dep_jobname", type=str, default="impossible_jobname", help="the dependent job name"
    )
    parser.add_argument("--nnodes", "-n", type=int, default=1, help="the number of nodes")
    parser.add_argument(
        "--ngpus", "-g", type=int, default=1, help="the number of gpus per nodes"
    )

    # task-related options
    parser.add_argument(
        "--task",
        type=str,
        required=True,
        help="one of: pretrain, retrieval, retrieval_mc, vqa.",
    )
    parser.add_argument("--config", type=str, required=True, help="config file name.")
    parser.add_argument("--model_args", type=str, default="", help="args for model")

    args = parser.parse_args(argv)
    return args


def get_output_dir(args):
    """Return the experiment output directory: ``$VL_EXP_DIR/<jobname>``."""
    base_dir = os.environ[EXP_DIR_ENV_NAME]
    return os.path.join(base_dir, args.jobname)


def prepare(args: argparse.Namespace):
    """Prepare for job submission: create the output dir and snapshot the code.

    Copies the current project directory into ``<output_dir>/code`` (via
    rsync, excluding ``*.out`` logs) so the job runs against a frozen copy
    of the source rather than the live working tree.

    Args:
        args (argparse.Namespace): the parsed command-line arguments;
            uses ``args.jobname`` (via ``get_output_dir``) and
            ``args.no_slurm``.

    Returns:
        str: the path to the copied source code inside the output dir.

    Raises:
        ValueError: when submitting via slurm and the output dir already
            exists (prevents clobbering a previous run).
        KeyError: if the ``VL_EXP_DIR`` environment variable is unset
            (raised inside ``get_output_dir``).
    """

    output_dir = get_output_dir(args)
    code_dir = os.path.join(output_dir, "code")
    project_dirname = os.path.basename(os.getcwd())

    # check output_dir exist
    if os.path.isdir(output_dir):
        # if using slurm, refuse to reuse an existing run directory;
        # a local (no-slurm) rerun falls through and reuses the
        # previously copied code snapshot.
        if has_slurm() and not args.no_slurm:
            raise ValueError(f"output_dir {output_dir} already exist. Exit.")
    else:
        # NOTE(review): os.mkdir requires $VL_EXP_DIR itself to already
        # exist (parents are not created) — presumably intentional; confirm.
        os.mkdir(output_dir)
        # copy code: rsync from the parent dir so the project folder name
        # is preserved under code_dir. Assumes rsync is on PATH.
        cmd = f"cd ..; rsync -ar {project_dirname} {code_dir} --exclude='*.out'"
        print(cmd)
        runcmd(cmd)
    return os.path.join(code_dir, project_dirname)


def submit_job(args: argparse.Namespace):
    """Build the launch command and submit the job via slurm or locally.

    Snapshots the code (see ``prepare``), assembles either an ``sbatch``
    or a plain ``bash`` invocation of ``tools/submit.sh``, records the
    full command in ``<output_dir>/cmd.txt`` for reproducibility, then
    executes it.

    Args:
        args (argparse.Namespace): The commandline args.

    Returns:
        None. The command is executed for its side effects.
    """
    output_dir = get_output_dir(args)
    # copy code
    code_dir = prepare(args)

    # run from inside the backed-up code so the job uses the frozen source;
    # reuse MASTER_PORT if already set, otherwise pick a random free port.
    master_port = os.environ.get("MASTER_PORT", random_port())
    init_cmd = f" cd {code_dir}; export MASTER_PORT={master_port}; "

    if has_slurm() and not args.no_slurm:
        # prepare slurm args. (local renamed from `bin`, which shadowed
        # the builtin of the same name.)
        mode = "slurm"
        default_slurm_args = get_default_slurm_args()
        # --dependency resolves the dependent job's id by name at submit
        # time; an unmatched name yields an empty dependency list.
        launcher = (
            f" sbatch --output {output_dir}/%j.out --error {output_dir}/%j.out"
            f" {default_slurm_args}"
            f" {args.slurm_args} --job-name={args.jobname} --nodes {args.nnodes} "
            f" --ntasks {args.nnodes} "
            f" --gpus-per-node={args.ngpus} "
            f" --dependency=$(squeue --noheader --format %i --name {args.dep_jobname}) "
        )
    else:
        mode = "local"
        launcher = "bash "

    # build job cmd
    job_cmd = (
        f" tasks/{args.task}.py"
        f" {args.config}"
        f" output_dir {output_dir}"
        f" {args.model_args}"
    )

    cmd = (
        f" {init_cmd} {launcher} "
        f" tools/submit.sh "
        f" {mode} {args.nnodes} {args.ngpus} {job_cmd} "
    )

    # save the exact command for later inspection / reruns
    with open(os.path.join(output_dir, "cmd.txt"), "w") as f:
        f.write(cmd)

    print(cmd)
    runcmd(cmd)


if __name__ == "__main__":
    # Entry point: parse the CLI arguments, then submit the job.
    cli_args = parse_args()
    submit_job(cli_args)