Spaces:
Runtime error
Runtime error
File size: 4,194 Bytes
217780a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 |
#!/usr/bin/env python
#
# This tool converts any deepspeed checkpoints found at given path to hf format
#
# Example:
#
# ./convert-checkpoints.py checkpoints-path
#
import argparse
import subprocess
import sys
import time
from pathlib import Path
# Repository root: two directory levels above the directory that holds this script.
repo_path = Path(__file__).parents[2]
# The converter script that gets launched once per checkpoint directory.
zero_checkpoint_to_hf_path = repo_path / "m4/models/zero_checkpoint_to_hf.py"
# We have to deal with potentially overlapping slurm jobs running on different nodes, so we can't
# rely on PIDs of a running process. We use a control file instead, as the filesystem is shared.
#
# If that file is there it means one of:
#
# 1. the conversion is still running
# 2. the conversion got aborted (e.g. cpu-oom)
#
# To detect aborted conversions we check whether the control file is older than a reasonable
# amount of time needed to perform such a conversion.
control_file_name = "started-convert-checkpoint"
# TODO: fine-tune this value - but surely 2h per checkpoint is plenty
reasonable_conversion_time_in_secs = 2 * 60 * 60
def run_cmd(cmd, check=True):
    """Run an external command and return its stripped stdout.

    Args:
        cmd: the command as a sequence of program and arguments, passed to
            `subprocess.run` unchanged (no shell is involved).
        check: when true, a non-zero exit status is turned into an exception;
            when false, the captured stdout is returned regardless of status.

    Returns:
        The command's stdout decoded as UTF-8 with surrounding whitespace removed.

    Raises:
        EnvironmentError: if `check` is true and the command exits with a non-zero
            status. The message is the command's captured stderr and the original
            `subprocess.CalledProcessError` is attached as `__cause__`.
    """
    try:
        response = subprocess.run(
            cmd,
            stderr=subprocess.PIPE,
            stdout=subprocess.PIPE,
            check=check,
            encoding="utf-8",
        ).stdout.strip()
    except subprocess.CalledProcessError as exc:
        # chain explicitly so the exit status and the command line stay reachable
        # from the re-raised error
        raise EnvironmentError(exc.stderr) from exc
    return response
def get_args():
    """Parse the command line and return the resulting namespace.

    The namespace carries `checkpoints_path` (positional string) and `force`
    (boolean, set by `-f` / `--force`).
    """
    cli = argparse.ArgumentParser()
    cli.add_argument(
        "checkpoints_path",
        type=str,
        help="base dir with checkpoints",
    )
    cli.add_argument(
        "-f",
        "--force",
        action="store_true",
        help="force rebuilding of all checkpoints",
    )
    parsed = cli.parse_args()
    return parsed
def exit(msg):
    """Print `msg` to stdout, then terminate the program with a success status."""
    print(msg)
    # same effect as sys.exit() with no argument: SystemExit carrying no exit code
    raise SystemExit
def should_process(path, force, control_file_path):
    """Heuristics to decide whether to convert this opt_step-XXX checkpoint or not.

    Args:
        path: the `opt_step-XXX` checkpoint directory, as a `pathlib.Path`.
        force: when true, convert even if a control file left by another run exists.
        control_file_path: path of the control file that marks a conversion in
            progress. When `None`, it defaults to
            `path / "unwrapped_model" / control_file_name`.

    Returns:
        `True` if the checkpoint should be converted now, `False` if it should be skipped.
    """
    target_dir = path / "unwrapped_model"

    # easy checks - the conversion is clearly completed
    # NOTE(review): because this check runs before the `force` check, --force does not
    # re-convert checkpoints that already have converted weights, although the flag's help
    # text says "force rebuilding of all checkpoints" - confirm which behavior is intended.
    if (target_dir / "pytorch_model.bin").exists() or (target_dir / "pytorch_model.bin.index.json").exists():
        print(f"[N] {path} appears to be already converted. Skipping")
        return False

    if force:
        print(f"[Y] Forced to re-convert {path}")
        return True

    # complicated checks - has another job already started processing? or did it crash?
    if control_file_path is None:
        control_file_path = target_dir / control_file_name

    if not control_file_path.exists():
        print(f"[Y] {path} is a new checkpoint. Converting")
        return True

    # a control file older than the allowed conversion time is treated as left over
    # from an aborted job
    if control_file_path.stat().st_mtime < time.time() - reasonable_conversion_time_in_secs:
        print(f"[Y] {path} looks stale - probably aborted job. Re-converting")
        return True

    print(
        f"[N] {path} either another job is converting it or less than"
        f" {reasonable_conversion_time_in_secs} secs has passed since it was launched. Skipping"
    )
    return False
def main():
    """Convert every eligible `opt_step-*` checkpoint found under the given base dir.

    Raises `FileNotFoundError` when the base dir is missing or is not a directory, and
    exits early with a message when it holds no `opt_step-*` entries.
    """
    args = get_args()

    checkpoints_path = Path(args.checkpoints_path)
    if not checkpoints_path.exists() or not checkpoints_path.is_dir():
        raise FileNotFoundError(f"can't find a directory '{checkpoints_path}'")

    checkpoint_dirs = list(checkpoints_path.glob("opt_step-*"))
    if not checkpoint_dirs:
        exit("No checkpoints found, exiting")

    # Each folder is evaluated right before it would be converted, so that overlapping
    # jobs which started at different times see each other's progress
    for checkpoint_dir in checkpoint_dirs:
        print(f"\n*** Checking {checkpoint_dir}")

        marker = checkpoint_dir / "unwrapped_model" / control_file_name
        if should_process(checkpoint_dir, args.force, marker):
            print(f"Launching conversion for {checkpoint_dir} - it could take a long time")

            # flock could close the race window here, but it would be pointless: each
            # cronjob is likely to run on a different node and flock only works within
            # a single node
            marker.touch()
            output = run_cmd([zero_checkpoint_to_hf_path, checkpoint_dir])
            # only reached on success - after a failed conversion the control file stays behind
            marker.unlink()

            print(output)
# Run the conversion only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()
|