mazpie's picture
Initial commit
2d9a728
raw
history blame
4.04 kB
import argparse
import os
import socket
from utils import has_slurm, random_port, runcmd
# Name of the environment variable that holds the root directory of all experiments.
EXP_DIR_ENV_NAME = "VL_EXP_DIR"

# Per-cluster default slurm arguments: when a key occurs in the local hostname,
# the corresponding value is added to the sbatch command line.
DEFAULT_SLURM_ARGS = {
    "login": "-p gpu --mem=240GB -c 64 -t 2-00:00:00",
}
def get_default_slurm_args():
    """Look up the default slurm arguments for the machine we are running on.

    The local hostname is compared against the keys of ``DEFAULT_SLURM_ARGS``;
    a key matches when it occurs anywhere inside the hostname.

    Returns:
        str: The value of the first matching entry (in dict order), or an
        empty string when no key matches.
    """
    host = socket.gethostname()
    candidates = (
        slurm_args
        for marker, slurm_args in DEFAULT_SLURM_ARGS.items()
        if marker in host
    )
    return next(candidates, "")
def parse_args():
    """Parse the launcher's command-line options.

    Required options are ``--jobname``, ``--task`` and ``--config``; every
    other option has a default.

    Returns:
        argparse.Namespace: The parsed options.
    """
    # Each entry is (flags, keyword options for add_argument), listed in the
    # order in which the options are registered.
    option_table = (
        # slurm
        (("--slurm_args",), dict(type=str, default="", help="args for slurm.")),
        (
            ("--no_slurm",),
            dict(
                action="store_true",
                help="If specified, will launch job without slurm",
            ),
        ),
        (("--jobname",), dict(type=str, required=True, help="experiment name")),
        (
            ("--dep_jobname",),
            dict(type=str, default="impossible_jobname", help="the dependent job name"),
        ),
        (("--nnodes", "-n"), dict(type=int, default=1, help="the number of nodes")),
        (
            ("--ngpus", "-g"),
            dict(type=int, default=1, help="the number of gpus per nodes"),
        ),
        # task / model
        (
            ("--task",),
            dict(
                type=str,
                required=True,
                help="one of: pretrain, retrieval, retrieval_mc, vqa.",
            ),
        ),
        (("--config",), dict(type=str, required=True, help="config file name.")),
        (("--model_args",), dict(type=str, default="", help="args for model")),
    )

    cli = argparse.ArgumentParser()
    for flags, options in option_table:
        cli.add_argument(*flags, **options)
    return cli.parse_args()
def get_output_dir(args):
    """Return the output directory of the job described by ``args``.

    The directory is ``<experiment root>/<args.jobname>``, where the experiment
    root is read from the environment variable named by ``EXP_DIR_ENV_NAME``.

    Raises:
        KeyError: If that environment variable is not set.
    """
    exp_root = os.environ[EXP_DIR_ENV_NAME]
    return os.path.join(exp_root, args.jobname)
def prepare(args: argparse.Namespace):
    """Create the job's output directory and copy the source code into it.

    The current working directory is treated as the project checkout. It is
    copied with rsync (run from its parent directory, skipping ``*.out``
    files) to ``<output_dir>/code``.

    Args:
        args (argparse.Namespace): The commandline args. ``jobname`` selects
            the output directory; ``no_slurm`` is consulted when the
            directory already exists.

    Returns:
        str: The path to the copied source code,
        ``<output_dir>/code/<project dirname>``.

    Raises:
        ValueError: If the output directory already exists and the job is
            going to be submitted through slurm.
    """
    # Local import so this function stays self-contained.
    import shlex

    output_dir = get_output_dir(args)
    code_dir = os.path.join(output_dir, "code")
    project_dirname = os.path.basename(os.getcwd())

    if os.path.isdir(output_dir):
        # An existing directory is only rejected for slurm submissions;
        # otherwise it is reused as-is.
        if has_slurm() and not args.no_slurm:
            raise ValueError(f"output_dir {output_dir} already exist. Exit.")
    else:
        # makedirs (not mkdir) so that a missing experiment root is created
        # too instead of failing with FileNotFoundError.
        os.makedirs(output_dir)

    # Quote the paths so that spaces or shell metacharacters in them do not
    # break the command; plain paths are left unchanged by shlex.quote.
    cmd = (
        f"cd ..; rsync -ar {shlex.quote(project_dirname)} {shlex.quote(code_dir)}"
        " --exclude='*.out'"
    )
    print(cmd)
    runcmd(cmd)
    return os.path.join(code_dir, project_dirname)
def submit_job(args: argparse.Namespace):
    """Build the launch command for a job and run it.

    The source code is first copied via ``prepare``; the command then changes
    into that copy, exports ``MASTER_PORT`` and runs ``tools/submit.sh``,
    either through ``sbatch`` (slurm available and ``--no_slurm`` not given)
    or through ``bash``. The final command is written to
    ``<output_dir>/cmd.txt``, printed, and executed with ``runcmd``.

    Args:
        args (argparse.Namespace): The commandline args.

    Returns:
        None.
    """
    output_dir = get_output_dir(args)

    # Copy the code; the job is launched from inside the copy.
    backup_dir = prepare(args)

    # NOTE: random_port() is evaluated even when MASTER_PORT is already set.
    port = os.environ.get("MASTER_PORT", random_port())
    init_cmd = f" cd {backup_dir}; export MASTER_PORT={port}; "

    use_slurm = has_slurm() and not args.no_slurm
    if not use_slurm:
        mode = "local"
        launcher = "bash "
    else:
        mode = "slurm"
        cluster_defaults = get_default_slurm_args()
        # Cluster defaults come first, followed by the user-supplied slurm args.
        launcher_parts = [
            f" sbatch --output {output_dir}/%j.out --error {output_dir}/%j.out",
            f" {cluster_defaults}",
            f" {args.slurm_args} --job-name={args.jobname} --nodes {args.nnodes} ",
            f" --ntasks {args.nnodes} ",
            f" --gpus-per-node={args.ngpus} ",
            f" --dependency=$(squeue --noheader --format %i --name {args.dep_jobname}) ",
        ]
        launcher = "".join(launcher_parts)

    # Arguments handed to tools/submit.sh after mode / nnodes / ngpus.
    job_cmd = "".join(
        [
            f" tasks/{args.task}.py",
            f" {args.config}",
            f" output_dir {output_dir}",
            f" {args.model_args}",
        ]
    )

    cmd = "".join(
        [
            f" {init_cmd} {launcher} ",
            " tools/submit.sh ",
            f" {mode} {args.nnodes} {args.ngpus} {job_cmd} ",
        ]
    )

    # Keep a record of the exact command next to the job outputs.
    cmd_file = os.path.join(output_dir, "cmd.txt")
    with open(cmd_file, "w") as record:
        record.write(cmd)

    print(cmd)
    runcmd(cmd)
# Script entry point: parse the command line, then build and run the job command.
if __name__ == "__main__":
    args = parse_args()
    submit_job(args)