File size: 2,653 Bytes
e7d3e35
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
#!/usr/bin/env python

#
# This tool checks if evaluation is needed
#

import argparse
import os
import subprocess
import sys
import time
from pathlib import Path


repo_path = Path(__file__).parents[2]

# we have to deal with potentially overlapping slurm jobs running on different nodes, so we can't
# rely on PIDs of a running process. Will use a control file instead as the filesystem is shared.
#
# If that file is there it means:
#
# 1. either the eval is still running
# 2. the eval got aborted (e.g. gpu-oom)
#

# should fine tune - but surely 9h per checkpoint is plenty
reasonable_eval_time_in_secs = 9 * 60 * 60


def run_cmd(cmd, check=True):
    try:
        response = subprocess.run(
            cmd,
            stderr=subprocess.PIPE,
            stdout=subprocess.PIPE,
            check=check,
            encoding="utf-8",
        ).stdout.strip()
    except subprocess.CalledProcessError as exc:
        raise EnvironmentError(exc.stderr)

    return response


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("checkpoints_path", type=str, help="base dir with checkpoints")
    return parser.parse_args()


def exit(msg):
    print(msg)
    sys.exit()


def check_eval_crash(path):
    """Heuristics to decide whether to restart this opt_step-XXX checkpoint evaluation or not"""
    eval_0_completed_path = path / "start_run_evals_0_shots"
    eval_4_completed_path = path / "start_run_evals_4_shots"
    eval_perplexity_path = path / "start_run_evals_perplexity_validation"
    # complicated checks - has another job already started processing? or did it crash?
    for eval_start_path in [eval_0_completed_path, eval_4_completed_path, eval_perplexity_path]:
        if eval_start_path.exists():
            if eval_start_path.stat().st_mtime < time.time() - reasonable_eval_time_in_secs:
                print(f"[Y] {path} looks stale - Probably crashed - Restart evals")
                os.remove(eval_start_path)


def main():
    args = get_args()

    checkpoints_path = Path(args.checkpoints_path)
    if not (checkpoints_path.exists() and checkpoints_path.is_dir()):
        raise FileNotFoundError(f"can't find a directory '{checkpoints_path}'")

    checkpoint_dirs = list(checkpoints_path.glob("opt_step-*"))
    if len(checkpoint_dirs) == 0:
        exit("No checkpoints found, exiting")

    checkpoint_dirs_sorted = sorted(checkpoint_dirs, key=lambda x: int(str(x).split("-")[-1]))
    for i, checkpoint_dir in enumerate(checkpoint_dirs_sorted):
        print(f"\n*** Checking {checkpoint_dir} for evals")
        check_eval_crash(checkpoint_dir)


if __name__ == "__main__":
    main()